##// END OF EJS Templates
s/IPython.parallel/ipython_parallel/
Min RK -
Show More
@@ -1,72 +1,72 b''
1 """The IPython ZMQ-based parallel computing interface.
1 """The IPython ZMQ-based parallel computing interface.
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import warnings
19 import warnings
20
20
21 import zmq
21 import zmq
22
22
23 from IPython.config.configurable import MultipleInstanceError
23 from IPython.config.configurable import MultipleInstanceError
24 from IPython.utils.zmqrelated import check_for_zmq
24 from IPython.utils.zmqrelated import check_for_zmq
25
25
26 min_pyzmq = '2.1.11'
26 min_pyzmq = '2.1.11'
27
27
28 check_for_zmq(min_pyzmq, 'IPython.parallel')
28 check_for_zmq(min_pyzmq, 'ipython_parallel')
29
29
30 from IPython.utils.pickleutil import Reference
30 from IPython.utils.pickleutil import Reference
31
31
32 from .client.asyncresult import *
32 from .client.asyncresult import *
33 from .client.client import Client
33 from .client.client import Client
34 from .client.remotefunction import *
34 from .client.remotefunction import *
35 from .client.view import *
35 from .client.view import *
36 from .controller.dependency import *
36 from .controller.dependency import *
37 from .error import *
37 from .error import *
38 from .util import interactive
38 from .util import interactive
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Functions
41 # Functions
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44
44
45 def bind_kernel(**kwargs):
45 def bind_kernel(**kwargs):
46 """Bind an Engine's Kernel to be used as a full IPython kernel.
46 """Bind an Engine's Kernel to be used as a full IPython kernel.
47
47
48 This allows a running Engine to be used simultaneously as a full IPython kernel
48 This allows a running Engine to be used simultaneously as a full IPython kernel
49 with the QtConsole or other frontends.
49 with the QtConsole or other frontends.
50
50
51 This function returns immediately.
51 This function returns immediately.
52 """
52 """
53 from IPython.kernel.zmq.kernelapp import IPKernelApp
53 from IPython.kernel.zmq.kernelapp import IPKernelApp
54 from IPython.parallel.apps.ipengineapp import IPEngineApp
54 from ipython_parallel.apps.ipengineapp import IPEngineApp
55
55
56 # first check for IPKernelApp, in which case this should be a no-op
56 # first check for IPKernelApp, in which case this should be a no-op
57 # because there is already a bound kernel
57 # because there is already a bound kernel
58 if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp):
58 if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp):
59 return
59 return
60
60
61 if IPEngineApp.initialized():
61 if IPEngineApp.initialized():
62 try:
62 try:
63 app = IPEngineApp.instance()
63 app = IPEngineApp.instance()
64 except MultipleInstanceError:
64 except MultipleInstanceError:
65 pass
65 pass
66 else:
66 else:
67 return app.bind_kernel(**kwargs)
67 return app.bind_kernel(**kwargs)
68
68
69 raise RuntimeError("bind_kernel be called from an IPEngineApp instance")
69 raise RuntimeError("bind_kernel be called from an IPEngineApp instance")
70
70
71
71
72
72
@@ -1,242 +1,242 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 The Base Application class for IPython.parallel apps
3 The Base Application class for ipython_parallel apps
4 """
4 """
5
5
6
6
7 import os
7 import os
8 import logging
8 import logging
9 import re
9 import re
10 import sys
10 import sys
11
11
12 from IPython.config.application import catch_config_error, LevelFormatter
12 from IPython.config.application import catch_config_error, LevelFormatter
13 from IPython.core import release
13 from IPython.core import release
14 from IPython.core.crashhandler import CrashHandler
14 from IPython.core.crashhandler import CrashHandler
15 from IPython.core.application import (
15 from IPython.core.application import (
16 BaseIPythonApplication,
16 BaseIPythonApplication,
17 base_aliases as base_ip_aliases,
17 base_aliases as base_ip_aliases,
18 base_flags as base_ip_flags
18 base_flags as base_ip_flags
19 )
19 )
20 from IPython.utils.path import expand_path
20 from IPython.utils.path import expand_path
21 from IPython.utils.process import check_pid
21 from IPython.utils.process import check_pid
22 from IPython.utils import py3compat
22 from IPython.utils import py3compat
23 from IPython.utils.py3compat import unicode_type
23 from IPython.utils.py3compat import unicode_type
24
24
25 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict
25 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Module errors
28 # Module errors
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 class PIDFileError(Exception):
31 class PIDFileError(Exception):
32 pass
32 pass
33
33
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Crash handler for this application
36 # Crash handler for this application
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38
38
39 class ParallelCrashHandler(CrashHandler):
39 class ParallelCrashHandler(CrashHandler):
40 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
40 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
41
41
42 def __init__(self, app):
42 def __init__(self, app):
43 contact_name = release.authors['Min'][0]
43 contact_name = release.authors['Min'][0]
44 contact_email = release.author_email
44 contact_email = release.author_email
45 bug_tracker = 'https://github.com/ipython/ipython/issues'
45 bug_tracker = 'https://github.com/ipython/ipython/issues'
46 super(ParallelCrashHandler,self).__init__(
46 super(ParallelCrashHandler,self).__init__(
47 app, contact_name, contact_email, bug_tracker
47 app, contact_name, contact_email, bug_tracker
48 )
48 )
49
49
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Main application
52 # Main application
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 base_aliases = {}
54 base_aliases = {}
55 base_aliases.update(base_ip_aliases)
55 base_aliases.update(base_ip_aliases)
56 base_aliases.update({
56 base_aliases.update({
57 'work-dir' : 'BaseParallelApplication.work_dir',
57 'work-dir' : 'BaseParallelApplication.work_dir',
58 'log-to-file' : 'BaseParallelApplication.log_to_file',
58 'log-to-file' : 'BaseParallelApplication.log_to_file',
59 'clean-logs' : 'BaseParallelApplication.clean_logs',
59 'clean-logs' : 'BaseParallelApplication.clean_logs',
60 'log-url' : 'BaseParallelApplication.log_url',
60 'log-url' : 'BaseParallelApplication.log_url',
61 'cluster-id' : 'BaseParallelApplication.cluster_id',
61 'cluster-id' : 'BaseParallelApplication.cluster_id',
62 })
62 })
63
63
64 base_flags = {
64 base_flags = {
65 'log-to-file' : (
65 'log-to-file' : (
66 {'BaseParallelApplication' : {'log_to_file' : True}},
66 {'BaseParallelApplication' : {'log_to_file' : True}},
67 "send log output to a file"
67 "send log output to a file"
68 )
68 )
69 }
69 }
70 base_flags.update(base_ip_flags)
70 base_flags.update(base_ip_flags)
71
71
72 class BaseParallelApplication(BaseIPythonApplication):
72 class BaseParallelApplication(BaseIPythonApplication):
73 """The base Application for IPython.parallel apps
73 """The base Application for ipython_parallel apps
74
74
75 Principal extensions to BaseIPythonApplication:
75 Principal extensions to BaseIPythonApplication:
76
76
77 * work_dir
77 * work_dir
78 * remote logging via pyzmq
78 * remote logging via pyzmq
79 * IOLoop instance
79 * IOLoop instance
80 """
80 """
81
81
82 crash_handler_class = ParallelCrashHandler
82 crash_handler_class = ParallelCrashHandler
83
83
84 def _log_level_default(self):
84 def _log_level_default(self):
85 # temporarily override default_log_level to INFO
85 # temporarily override default_log_level to INFO
86 return logging.INFO
86 return logging.INFO
87
87
88 def _log_format_default(self):
88 def _log_format_default(self):
89 """override default log format to include time"""
89 """override default log format to include time"""
90 return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"
90 return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"
91
91
92 work_dir = Unicode(py3compat.getcwd(), config=True,
92 work_dir = Unicode(py3compat.getcwd(), config=True,
93 help='Set the working dir for the process.'
93 help='Set the working dir for the process.'
94 )
94 )
95 def _work_dir_changed(self, name, old, new):
95 def _work_dir_changed(self, name, old, new):
96 self.work_dir = unicode_type(expand_path(new))
96 self.work_dir = unicode_type(expand_path(new))
97
97
98 log_to_file = Bool(config=True,
98 log_to_file = Bool(config=True,
99 help="whether to log to a file")
99 help="whether to log to a file")
100
100
101 clean_logs = Bool(False, config=True,
101 clean_logs = Bool(False, config=True,
102 help="whether to cleanup old logfiles before starting")
102 help="whether to cleanup old logfiles before starting")
103
103
104 log_url = Unicode('', config=True,
104 log_url = Unicode('', config=True,
105 help="The ZMQ URL of the iplogger to aggregate logging.")
105 help="The ZMQ URL of the iplogger to aggregate logging.")
106
106
107 cluster_id = Unicode('', config=True,
107 cluster_id = Unicode('', config=True,
108 help="""String id to add to runtime files, to prevent name collisions when
108 help="""String id to add to runtime files, to prevent name collisions when
109 using multiple clusters with a single profile simultaneously.
109 using multiple clusters with a single profile simultaneously.
110
110
111 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
111 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
112
112
113 Since this is text inserted into filenames, typical recommendations apply:
113 Since this is text inserted into filenames, typical recommendations apply:
114 Simple character strings are ideal, and spaces are not recommended (but should
114 Simple character strings are ideal, and spaces are not recommended (but should
115 generally work).
115 generally work).
116 """
116 """
117 )
117 )
118 def _cluster_id_changed(self, name, old, new):
118 def _cluster_id_changed(self, name, old, new):
119 self.name = self.__class__.name
119 self.name = self.__class__.name
120 if new:
120 if new:
121 self.name += '-%s'%new
121 self.name += '-%s'%new
122
122
123 def _config_files_default(self):
123 def _config_files_default(self):
124 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
124 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
125
125
126 loop = Instance('zmq.eventloop.ioloop.IOLoop')
126 loop = Instance('zmq.eventloop.ioloop.IOLoop')
127 def _loop_default(self):
127 def _loop_default(self):
128 from zmq.eventloop.ioloop import IOLoop
128 from zmq.eventloop.ioloop import IOLoop
129 return IOLoop.instance()
129 return IOLoop.instance()
130
130
131 aliases = Dict(base_aliases)
131 aliases = Dict(base_aliases)
132 flags = Dict(base_flags)
132 flags = Dict(base_flags)
133
133
134 @catch_config_error
134 @catch_config_error
135 def initialize(self, argv=None):
135 def initialize(self, argv=None):
136 """initialize the app"""
136 """initialize the app"""
137 super(BaseParallelApplication, self).initialize(argv)
137 super(BaseParallelApplication, self).initialize(argv)
138 self.to_work_dir()
138 self.to_work_dir()
139 self.reinit_logging()
139 self.reinit_logging()
140
140
141 def to_work_dir(self):
141 def to_work_dir(self):
142 wd = self.work_dir
142 wd = self.work_dir
143 if unicode_type(wd) != py3compat.getcwd():
143 if unicode_type(wd) != py3compat.getcwd():
144 os.chdir(wd)
144 os.chdir(wd)
145 self.log.info("Changing to working dir: %s" % wd)
145 self.log.info("Changing to working dir: %s" % wd)
146 # This is the working dir by now.
146 # This is the working dir by now.
147 sys.path.insert(0, '')
147 sys.path.insert(0, '')
148
148
149 def reinit_logging(self):
149 def reinit_logging(self):
150 # Remove old log files
150 # Remove old log files
151 log_dir = self.profile_dir.log_dir
151 log_dir = self.profile_dir.log_dir
152 if self.clean_logs:
152 if self.clean_logs:
153 for f in os.listdir(log_dir):
153 for f in os.listdir(log_dir):
154 if re.match(r'%s-\d+\.(log|err|out)' % self.name, f):
154 if re.match(r'%s-\d+\.(log|err|out)' % self.name, f):
155 try:
155 try:
156 os.remove(os.path.join(log_dir, f))
156 os.remove(os.path.join(log_dir, f))
157 except (OSError, IOError):
157 except (OSError, IOError):
158 # probably just conflict from sibling process
158 # probably just conflict from sibling process
159 # already removing it
159 # already removing it
160 pass
160 pass
161 if self.log_to_file:
161 if self.log_to_file:
162 # Start logging to the new log file
162 # Start logging to the new log file
163 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
163 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
164 logfile = os.path.join(log_dir, log_filename)
164 logfile = os.path.join(log_dir, log_filename)
165 open_log_file = open(logfile, 'w')
165 open_log_file = open(logfile, 'w')
166 else:
166 else:
167 open_log_file = None
167 open_log_file = None
168 if open_log_file is not None:
168 if open_log_file is not None:
169 while self.log.handlers:
169 while self.log.handlers:
170 self.log.removeHandler(self.log.handlers[0])
170 self.log.removeHandler(self.log.handlers[0])
171 self._log_handler = logging.StreamHandler(open_log_file)
171 self._log_handler = logging.StreamHandler(open_log_file)
172 self.log.addHandler(self._log_handler)
172 self.log.addHandler(self._log_handler)
173 else:
173 else:
174 self._log_handler = self.log.handlers[0]
174 self._log_handler = self.log.handlers[0]
175 # Add timestamps to log format:
175 # Add timestamps to log format:
176 self._log_formatter = LevelFormatter(self.log_format,
176 self._log_formatter = LevelFormatter(self.log_format,
177 datefmt=self.log_datefmt)
177 datefmt=self.log_datefmt)
178 self._log_handler.setFormatter(self._log_formatter)
178 self._log_handler.setFormatter(self._log_formatter)
179 # do not propagate log messages to root logger
179 # do not propagate log messages to root logger
180 # ipcluster app will sometimes print duplicate messages during shutdown
180 # ipcluster app will sometimes print duplicate messages during shutdown
181 # if this is 1 (default):
181 # if this is 1 (default):
182 self.log.propagate = False
182 self.log.propagate = False
183
183
184 def write_pid_file(self, overwrite=False):
184 def write_pid_file(self, overwrite=False):
185 """Create a .pid file in the pid_dir with my pid.
185 """Create a .pid file in the pid_dir with my pid.
186
186
187 This must be called after pre_construct, which sets `self.pid_dir`.
187 This must be called after pre_construct, which sets `self.pid_dir`.
188 This raises :exc:`PIDFileError` if the pid file exists already.
188 This raises :exc:`PIDFileError` if the pid file exists already.
189 """
189 """
190 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
190 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
191 if os.path.isfile(pid_file):
191 if os.path.isfile(pid_file):
192 pid = self.get_pid_from_file()
192 pid = self.get_pid_from_file()
193 if not overwrite:
193 if not overwrite:
194 raise PIDFileError(
194 raise PIDFileError(
195 'The pid file [%s] already exists. \nThis could mean that this '
195 'The pid file [%s] already exists. \nThis could mean that this '
196 'server is already running with [pid=%s].' % (pid_file, pid)
196 'server is already running with [pid=%s].' % (pid_file, pid)
197 )
197 )
198 with open(pid_file, 'w') as f:
198 with open(pid_file, 'w') as f:
199 self.log.info("Creating pid file: %s" % pid_file)
199 self.log.info("Creating pid file: %s" % pid_file)
200 f.write(repr(os.getpid())+'\n')
200 f.write(repr(os.getpid())+'\n')
201
201
202 def remove_pid_file(self):
202 def remove_pid_file(self):
203 """Remove the pid file.
203 """Remove the pid file.
204
204
205 This should be called at shutdown by registering a callback with
205 This should be called at shutdown by registering a callback with
206 :func:`reactor.addSystemEventTrigger`. This needs to return
206 :func:`reactor.addSystemEventTrigger`. This needs to return
207 ``None``.
207 ``None``.
208 """
208 """
209 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
209 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
210 if os.path.isfile(pid_file):
210 if os.path.isfile(pid_file):
211 try:
211 try:
212 self.log.info("Removing pid file: %s" % pid_file)
212 self.log.info("Removing pid file: %s" % pid_file)
213 os.remove(pid_file)
213 os.remove(pid_file)
214 except:
214 except:
215 self.log.warn("Error removing the pid file: %s" % pid_file)
215 self.log.warn("Error removing the pid file: %s" % pid_file)
216
216
217 def get_pid_from_file(self):
217 def get_pid_from_file(self):
218 """Get the pid from the pid file.
218 """Get the pid from the pid file.
219
219
220 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
220 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
221 """
221 """
222 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
222 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
223 if os.path.isfile(pid_file):
223 if os.path.isfile(pid_file):
224 with open(pid_file, 'r') as f:
224 with open(pid_file, 'r') as f:
225 s = f.read().strip()
225 s = f.read().strip()
226 try:
226 try:
227 pid = int(s)
227 pid = int(s)
228 except:
228 except:
229 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
229 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
230 return pid
230 return pid
231 else:
231 else:
232 raise PIDFileError('pid file not found: %s' % pid_file)
232 raise PIDFileError('pid file not found: %s' % pid_file)
233
233
234 def check_pid(self, pid):
234 def check_pid(self, pid):
235 try:
235 try:
236 return check_pid(pid)
236 return check_pid(pid)
237 except Exception:
237 except Exception:
238 self.log.warn(
238 self.log.warn(
239 "Could not determine whether pid %i is running. "
239 "Could not determine whether pid %i is running. "
240 " Making the likely assumption that it is."%pid
240 " Making the likely assumption that it is."%pid
241 )
241 )
242 return True
242 return True
@@ -1,596 +1,596 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """The ipcluster application."""
3 """The ipcluster application."""
4 from __future__ import print_function
4 from __future__ import print_function
5
5
6 import errno
6 import errno
7 import logging
7 import logging
8 import os
8 import os
9 import re
9 import re
10 import signal
10 import signal
11
11
12 from subprocess import check_call, CalledProcessError, PIPE
12 from subprocess import check_call, CalledProcessError, PIPE
13 import zmq
13 import zmq
14
14
15 from IPython.config.application import catch_config_error
15 from IPython.config.application import catch_config_error
16 from IPython.config.loader import Config
16 from IPython.config.loader import Config
17 from IPython.core.application import BaseIPythonApplication
17 from IPython.core.application import BaseIPythonApplication
18 from IPython.core.profiledir import ProfileDir
18 from IPython.core.profiledir import ProfileDir
19 from IPython.utils.daemonize import daemonize
19 from IPython.utils.daemonize import daemonize
20 from IPython.utils.importstring import import_item
20 from IPython.utils.importstring import import_item
21 from IPython.utils.py3compat import string_types
21 from IPython.utils.py3compat import string_types
22 from IPython.utils.sysinfo import num_cpus
22 from IPython.utils.sysinfo import num_cpus
23 from IPython.utils.traitlets import (Integer, Unicode, Bool, CFloat, Dict, List, Any,
23 from IPython.utils.traitlets import (Integer, Unicode, Bool, CFloat, Dict, List, Any,
24 DottedObjectName)
24 DottedObjectName)
25
25
26 from IPython.parallel.apps.baseapp import (
26 from ipython_parallel.apps.baseapp import (
27 BaseParallelApplication,
27 BaseParallelApplication,
28 PIDFileError,
28 PIDFileError,
29 base_flags, base_aliases
29 base_flags, base_aliases
30 )
30 )
31
31
32
32
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34 # Module level variables
34 # Module level variables
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36
36
37
37
38 _description = """Start an IPython cluster for parallel computing.
38 _description = """Start an IPython cluster for parallel computing.
39
39
40 An IPython cluster consists of 1 controller and 1 or more engines.
40 An IPython cluster consists of 1 controller and 1 or more engines.
41 This command automates the startup of these processes using a wide range of
41 This command automates the startup of these processes using a wide range of
42 startup methods (SSH, local processes, PBS, mpiexec, SGE, LSF, HTCondor,
42 startup methods (SSH, local processes, PBS, mpiexec, SGE, LSF, HTCondor,
43 Windows HPC Server 2008). To start a cluster with 4 engines on your
43 Windows HPC Server 2008). To start a cluster with 4 engines on your
44 local host simply do 'ipcluster start --n=4'. For more complex usage
44 local host simply do 'ipcluster start --n=4'. For more complex usage
45 you will typically do 'ipython profile create mycluster --parallel', then edit
45 you will typically do 'ipython profile create mycluster --parallel', then edit
46 configuration files, followed by 'ipcluster start --profile=mycluster --n=4'.
46 configuration files, followed by 'ipcluster start --profile=mycluster --n=4'.
47 """
47 """
48
48
49 _main_examples = """
49 _main_examples = """
50 ipcluster start --n=4 # start a 4 node cluster on localhost
50 ipcluster start --n=4 # start a 4 node cluster on localhost
51 ipcluster start -h # show the help string for the start subcmd
51 ipcluster start -h # show the help string for the start subcmd
52
52
53 ipcluster stop -h # show the help string for the stop subcmd
53 ipcluster stop -h # show the help string for the stop subcmd
54 ipcluster engines -h # show the help string for the engines subcmd
54 ipcluster engines -h # show the help string for the engines subcmd
55 """
55 """
56
56
57 _start_examples = """
57 _start_examples = """
58 ipython profile create mycluster --parallel # create mycluster profile
58 ipython profile create mycluster --parallel # create mycluster profile
59 ipcluster start --profile=mycluster --n=4 # start mycluster with 4 nodes
59 ipcluster start --profile=mycluster --n=4 # start mycluster with 4 nodes
60 """
60 """
61
61
62 _stop_examples = """
62 _stop_examples = """
63 ipcluster stop --profile=mycluster # stop a running cluster by profile name
63 ipcluster stop --profile=mycluster # stop a running cluster by profile name
64 """
64 """
65
65
66 _engines_examples = """
66 _engines_examples = """
67 ipcluster engines --profile=mycluster --n=4 # start 4 engines only
67 ipcluster engines --profile=mycluster --n=4 # start 4 engines only
68 """
68 """
69
69
70
70
71 # Exit codes for ipcluster
71 # Exit codes for ipcluster
72
72
73 # This will be the exit code if the ipcluster appears to be running because
73 # This will be the exit code if the ipcluster appears to be running because
74 # a .pid file exists
74 # a .pid file exists
75 ALREADY_STARTED = 10
75 ALREADY_STARTED = 10
76
76
77
77
78 # This will be the exit code if ipcluster stop is run, but there is no .pid
78 # This will be the exit code if ipcluster stop is run, but there is no .pid
79 # file to be found.
79 # file to be found.
80 ALREADY_STOPPED = 11
80 ALREADY_STOPPED = 11
81
81
82 # This will be the exit code if ipcluster engines is run, but there is no .pid
82 # This will be the exit code if ipcluster engines is run, but there is no .pid
83 # file to be found.
83 # file to be found.
84 NO_CLUSTER = 12
84 NO_CLUSTER = 12
85
85
86
86
87 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
88 # Utilities
88 # Utilities
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90
90
91 def find_launcher_class(clsname, kind):
91 def find_launcher_class(clsname, kind):
92 """Return a launcher for a given clsname and kind.
92 """Return a launcher for a given clsname and kind.
93
93
94 Parameters
94 Parameters
95 ==========
95 ==========
96 clsname : str
96 clsname : str
97 The full name of the launcher class, either with or without the
97 The full name of the launcher class, either with or without the
98 module path, or an abbreviation (MPI, SSH, SGE, PBS, LSF, HTCondor,
98 module path, or an abbreviation (MPI, SSH, SGE, PBS, LSF, HTCondor,
99 WindowsHPC).
99 WindowsHPC).
100 kind : str
100 kind : str
101 Either 'EngineSet' or 'Controller'.
101 Either 'EngineSet' or 'Controller'.
102 """
102 """
103 if '.' not in clsname:
103 if '.' not in clsname:
104 # not a module, presume it's the raw name in apps.launcher
104 # not a module, presume it's the raw name in apps.launcher
105 if kind and kind not in clsname:
105 if kind and kind not in clsname:
106 # doesn't match necessary full class name, assume it's
106 # doesn't match necessary full class name, assume it's
107 # just 'PBS' or 'MPI' etc prefix:
107 # just 'PBS' or 'MPI' etc prefix:
108 clsname = clsname + kind + 'Launcher'
108 clsname = clsname + kind + 'Launcher'
109 clsname = 'IPython.parallel.apps.launcher.'+clsname
109 clsname = 'ipython_parallel.apps.launcher.'+clsname
110 klass = import_item(clsname)
110 klass = import_item(clsname)
111 return klass
111 return klass
112
112
113 #-----------------------------------------------------------------------------
113 #-----------------------------------------------------------------------------
114 # Main application
114 # Main application
115 #-----------------------------------------------------------------------------
115 #-----------------------------------------------------------------------------
116
116
117 start_help = """Start an IPython cluster for parallel computing
117 start_help = """Start an IPython cluster for parallel computing
118
118
119 Start an ipython cluster by its profile name or cluster
119 Start an ipython cluster by its profile name or cluster
120 directory. Cluster directories contain configuration, log and
120 directory. Cluster directories contain configuration, log and
121 security related files and are named using the convention
121 security related files and are named using the convention
122 'profile_<name>' and should be creating using the 'start'
122 'profile_<name>' and should be creating using the 'start'
123 subcommand of 'ipcluster'. If your cluster directory is in
123 subcommand of 'ipcluster'. If your cluster directory is in
124 the cwd or the ipython directory, you can simply refer to it
124 the cwd or the ipython directory, you can simply refer to it
125 using its profile name, 'ipcluster start --n=4 --profile=<profile>`,
125 using its profile name, 'ipcluster start --n=4 --profile=<profile>`,
126 otherwise use the 'profile-dir' option.
126 otherwise use the 'profile-dir' option.
127 """
127 """
128 stop_help = """Stop a running IPython cluster
128 stop_help = """Stop a running IPython cluster
129
129
130 Stop a running ipython cluster by its profile name or cluster
130 Stop a running ipython cluster by its profile name or cluster
131 directory. Cluster directories are named using the convention
131 directory. Cluster directories are named using the convention
132 'profile_<name>'. If your cluster directory is in
132 'profile_<name>'. If your cluster directory is in
133 the cwd or the ipython directory, you can simply refer to it
133 the cwd or the ipython directory, you can simply refer to it
134 using its profile name, 'ipcluster stop --profile=<profile>`, otherwise
134 using its profile name, 'ipcluster stop --profile=<profile>`, otherwise
135 use the '--profile-dir' option.
135 use the '--profile-dir' option.
136 """
136 """
137 engines_help = """Start engines connected to an existing IPython cluster
137 engines_help = """Start engines connected to an existing IPython cluster
138
138
139 Start one or more engines to connect to an existing Cluster
139 Start one or more engines to connect to an existing Cluster
140 by profile name or cluster directory.
140 by profile name or cluster directory.
141 Cluster directories contain configuration, log and
141 Cluster directories contain configuration, log and
142 security related files and are named using the convention
142 security related files and are named using the convention
143 'profile_<name>' and should be creating using the 'start'
143 'profile_<name>' and should be creating using the 'start'
144 subcommand of 'ipcluster'. If your cluster directory is in
144 subcommand of 'ipcluster'. If your cluster directory is in
145 the cwd or the ipython directory, you can simply refer to it
145 the cwd or the ipython directory, you can simply refer to it
146 using its profile name, 'ipcluster engines --n=4 --profile=<profile>`,
146 using its profile name, 'ipcluster engines --n=4 --profile=<profile>`,
147 otherwise use the 'profile-dir' option.
147 otherwise use the 'profile-dir' option.
148 """
148 """
149 stop_aliases = dict(
149 stop_aliases = dict(
150 signal='IPClusterStop.signal',
150 signal='IPClusterStop.signal',
151 )
151 )
152 stop_aliases.update(base_aliases)
152 stop_aliases.update(base_aliases)
153
153
class IPClusterStop(BaseParallelApplication):
    """Application behind the ``ipcluster stop`` subcommand.

    Reads the cluster's pid from the pid file and stops that process:
    with ``os.kill`` and the configured signal on POSIX, or with
    ``taskkill`` (whole process tree, forced) on Windows.
    """
    name = u'ipcluster'
    description = stop_help
    examples = _stop_examples

    # Signal number delivered to the cluster process on POSIX.
    signal = Integer(signal.SIGINT, config=True,
        help="signal to use for stopping processes.")

    aliases = Dict(stop_aliases)

    def start(self):
        """Start the app for the stop subcommand.

        Exits with status ALREADY_STOPPED when the pid file cannot be read
        or when the recorded pid is not running; in both cases the pid file
        is removed first.
        """
        try:
            pid = self.get_pid_from_file()
        except PIDFileError:
            self.log.critical(
                'Could not read pid file, cluster is probably not running.'
            )
            # Exit with an unusual exit status that other processes
            # can watch for to learn how we exited.
            self.remove_pid_file()
            self.exit(ALREADY_STOPPED)

        if not self.check_pid(pid):
            self.log.critical(
                'Cluster [pid=%r] is not running.' % pid
            )
            self.remove_pid_file()
            # Same unusual exit status as above.
            self.exit(ALREADY_STOPPED)

        elif os.name=='posix':
            sig = self.signal
            self.log.info(
                "Stopping cluster [pid=%r] with [signal=%r]" % (pid, sig)
            )
            try:
                os.kill(pid, sig)
            except OSError:
                self.log.error("Stopping cluster failed, assuming already dead.",
                    exc_info=True)
                # NOTE(review): the pid file is only removed here when the
                # kill failed; presumably a signalled cluster removes its
                # own pid file on shutdown -- confirm against the original
                # indentation.
                self.remove_pid_file()
        elif os.name=='nt':
            try:
                # kill the whole tree
                check_call(['taskkill', '-pid', str(pid), '-t', '-f'], stdout=PIPE,stderr=PIPE)
            except (CalledProcessError, OSError):
                self.log.error("Stopping cluster failed, assuming already dead.",
                    exc_info=True)
            # NOTE(review): removed unconditionally on Windows, presumably
            # because a forced kill leaves the cluster no chance to clean
            # up -- confirm.
            self.remove_pid_file()
205
205
# Command-line aliases for ``ipcluster engines``: everything in
# base_aliases plus the engine-specific options below.
engine_aliases = {}
engine_aliases.update(base_aliases)
engine_aliases.update(dict(
    n='IPClusterEngines.n',
    engines='IPClusterEngines.engine_launcher_class',
    daemonize='IPClusterEngines.daemonize',
))

# Command-line flags for ``ipcluster engines``: everything in base_flags
# plus ``--daemonize``, which sets IPClusterEngines.daemonize to True.
engine_flags = {}
engine_flags.update(base_flags)

engine_flags.update(dict(
    daemonize=(
        {'IPClusterEngines' : {'daemonize' : True}},
        """run the cluster into the background (not available on Windows)""",
    )
))
class IPClusterEngines(BaseParallelApplication):
    """Application behind the ``ipcluster engines`` subcommand.

    Builds an EngineSet launcher from ``engine_launcher_class``, starts the
    engines from the event loop, and stops them again on SIGINT or when
    they shut down early.
    """

    name = u'ipcluster'
    description = engines_help
    examples = _engines_examples
    usage = None
    default_log_level = logging.INFO
    classes = List()
    def _classes_default(self):
        """Return ProfileDir plus every launcher class named *EngineSet*."""
        from ipython_parallel.apps import launcher
        engine_set_launchers = [
            cls for cls in launcher.all_launchers if 'EngineSet' in cls.__name__
        ]
        return [ProfileDir] + engine_set_launchers

    n = Integer(num_cpus(), config=True,
        help="""The number of engines to start. The default is to use one for each
        CPU on your machine""")

    engine_launcher = Any(config=True, help="Deprecated, use engine_launcher_class")
    def _engine_launcher_changed(self, name, old, new):
        """Forward an old-style string value to ``engine_launcher_class``."""
        if isinstance(new, string_types):
            self.log.warning("WARNING: %s.engine_launcher is deprecated as of 0.12,"
                    " use engine_launcher_class" % self.__class__.__name__)
            self.engine_launcher_class = new
    engine_launcher_class = DottedObjectName('LocalEngineSetLauncher',
        config=True,
        help="""The class for launching a set of Engines. Change this value
        to use various batch systems to launch your engines, such as PBS,SGE,MPI,etc.
        Each launcher class has its own set of configuration options, for making sure
        it will work in your environment.

        You can also write your own launcher, and specify its absolute import path,
        as in 'mymodule.launcher.FTLEnginesLauncher'.

        IPython's bundled examples include:

            Local : start engines locally as subprocesses [default]
            MPI : use mpiexec to launch engines in an MPI environment
            PBS : use PBS (qsub) to submit engines to a batch queue
            SGE : use SGE (qsub) to submit engines to a batch queue
            LSF : use LSF (bsub) to submit engines to a batch queue
            SSH : use SSH to start the engines
                        Note that SSH does *not* move the connection files
                        around, so you will likely have to do this manually
                        unless the machines are on a shared file system.
            HTCondor : use HTCondor to submit engines to a batch queue
            WindowsHPC : use Windows HPC

        If you are using one of IPython's builtin launchers, you can specify just the
        prefix, e.g:

            c.IPClusterEngines.engine_launcher_class = 'SSH'

        or:

            ipcluster start --engines=MPI

        """
        )
    daemonize = Bool(False, config=True,
        help="""Daemonize the ipcluster program. This implies --log-to-file.
        Not available on Windows.
        """)

    def _daemonize_changed(self, name, old, new):
        """Turn on ``log_to_file`` whenever daemonize is switched on."""
        if new:
            self.log_to_file = True

    # Delay before engines_started_ok is scheduled; while it is non-zero,
    # an engine shutdown is reported as an early failure.
    early_shutdown = Integer(30, config=True, help="The timeout (in seconds)")
    # Set once stop_launchers has run, so shutdown happens only once.
    _stopping = False

    aliases = Dict(engine_aliases)
    flags = Dict(engine_flags)

    @catch_config_error
    def initialize(self, argv=None):
        """Run the parent initialisation, then install signals and launchers."""
        super(IPClusterEngines, self).initialize(argv)
        self.init_signal()
        self.init_launchers()

    def init_launchers(self):
        """Instantiate the EngineSet launcher named by ``engine_launcher_class``."""
        self.engine_launcher = self.build_launcher(self.engine_launcher_class, 'EngineSet')

    def init_signal(self):
        """Route SIGINT to ``sigint_handler``."""
        signal.signal(signal.SIGINT, self.sigint_handler)

    def build_launcher(self, clsname, kind=None):
        """import and instantiate a Launcher based on importstring

        Logs a fatal message and exits with status 1 when the class cannot
        be found.
        """
        try:
            klass = find_launcher_class(clsname, kind)
        except (ImportError, KeyError):
            self.log.fatal("Could not import launcher class: %r"%clsname)
            self.exit(1)

        launcher = klass(
            work_dir=u'.', parent=self, log=self.log,
            profile_dir=self.profile_dir.location, cluster_id=self.cluster_id,
        )
        return launcher

    def engines_started_ok(self):
        """Mark startup as successful by clearing ``early_shutdown``."""
        self.log.info("Engines appear to have started successfully")
        self.early_shutdown = 0

    def start_engines(self):
        """Start the engines and arm the early-shutdown detection."""
        # Some EngineSetLaunchers ignore `n` and use their own engine count, such as SSH:
        n = getattr(self.engine_launcher, 'engine_count', self.n)
        self.log.info("Starting %s Engines with %s", n, self.engine_launcher_class)
        try:
            self.engine_launcher.start(self.n)
        except:
            # Log the failure, then let the original exception propagate.
            self.log.exception("Engine start failed")
            raise
        self.engine_launcher.on_stop(self.engines_stopped_early)
        if self.early_shutdown:
            self.loop.add_timeout(self.loop.time() + self.early_shutdown, self.engines_started_ok)

    def engines_stopped_early(self, r):
        """on_stop callback: report an early engine exit, then finish up.

        The error is only logged (and the launchers stopped) while
        ``early_shutdown`` is still non-zero and no shutdown is in progress.
        """
        if self.early_shutdown and not self._stopping:
            self.log.error("""
            Engines shutdown early, they probably failed to connect.

            Check the engine log files for output.

            If your controller and engines are not on the same machine, you probably
            have to instruct the controller to listen on an interface other than localhost.

            You can set this by adding "--ip='*'" to your ControllerLauncher.controller_args.

            Be sure to read our security docs before instructing your controller to listen on
            a public interface.
            """)
            self.stop_launchers()

        return self.engines_stopped(r)

    def engines_stopped(self, r):
        """Stop the event loop once the engines have stopped."""
        return self.loop.stop()

    def stop_engines(self):
        """Stop the engine launcher if it is running; return its result or None."""
        if self.engine_launcher.running:
            self.log.info("Stopping Engines...")
            d = self.engine_launcher.stop()
            return d
        else:
            return None

    def stop_launchers(self, r=None):
        """Stop the engines once, then stop the loop three seconds later."""
        if not self._stopping:
            self._stopping = True
            self.log.error("IPython cluster: stopping")
            self.stop_engines()
            # Wait a few seconds to let things shut down.
            self.loop.add_timeout(self.loop.time() + 3, self.loop.stop)

    def sigint_handler(self, signum, frame):
        """SIGINT handler: begin an orderly shutdown of the launchers."""
        self.log.debug("SIGINT received, stopping launchers...")
        self.stop_launchers()

    def start_logging(self):
        """Delete old controller/engine log files when ``clean_logs`` is set."""
        # Remove old log files of the controller and engine
        if self.clean_logs:
            log_dir = self.profile_dir.log_dir
            for f in os.listdir(log_dir):
                if re.match(r'ip(engine|controller)-.+\.(log|err|out)',f):
                    os.remove(os.path.join(log_dir, f))

    def start(self):
        """Start the app for the engines subcommand."""
        self.log.info("IPython cluster: started")

        # Now log and daemonize
        self.log.info(
            'Starting engines with [daemon=%r]' % self.daemonize
        )
        # TODO: Get daemonize working on Windows or as a Windows Server.
        if self.daemonize:
            if os.name=='posix':
                daemonize()

        self.loop.add_callback(self.start_engines)
        # No pid file is written by this subcommand.
        try:
            self.loop.start()
        except KeyboardInterrupt:
            pass
        except zmq.ZMQError as e:
            # An interrupted system call is not an error here; anything
            # else is re-raised.
            if e.errno != errno.EINTR:
                raise
416
416
# Command-line aliases for ``ipcluster start``: the engine aliases plus
# the controller-related options below.
start_aliases = {}
start_aliases.update(engine_aliases)
start_aliases.update(dict(
    delay='IPClusterStart.delay',
    controller='IPClusterStart.controller_launcher_class',
))
# Set by item assignment because the key contains a hyphen and so cannot
# be written as a keyword argument.
start_aliases['clean-logs'] = 'IPClusterStart.clean_logs'
424
424
class IPClusterStart(IPClusterEngines):
    """Application behind the ``ipcluster start`` subcommand.

    Extends IPClusterEngines with a controller launcher: the controller is
    started first, the engines follow after ``delay`` seconds, and a pid
    file is maintained for the lifetime of the event loop.
    """

    name = u'ipcluster'
    description = start_help
    examples = _start_examples
    default_log_level = logging.INFO
    auto_create = Bool(True, config=True,
        help="whether to create the profile_dir if it doesn't exist")
    classes = List()
    def _classes_default(self,):
        """Return ProfileDir, IPClusterEngines and every known launcher class."""
        from ipython_parallel.apps import launcher
        return [ProfileDir] + [IPClusterEngines] + launcher.all_launchers

    clean_logs = Bool(True, config=True,
        help="whether to cleanup old logs before starting")

    delay = CFloat(1., config=True,
        help="delay (in s) between starting the controller and the engines")

    controller_launcher = Any(config=True, help="Deprecated, use controller_launcher_class")
    def _controller_launcher_changed(self, name, old, new):
        """Forward an old-style string value to ``controller_launcher_class``."""
        if isinstance(new, string_types):
            # old 0.11-style config
            self.log.warning("WARNING: %s.controller_launcher is deprecated as of 0.12,"
                    " use controller_launcher_class" % self.__class__.__name__)
            self.controller_launcher_class = new
    controller_launcher_class = DottedObjectName('LocalControllerLauncher',
        config=True,
        help="""The class for launching a Controller. Change this value if you want
        your controller to also be launched by a batch system, such as PBS,SGE,MPI,etc.

        Each launcher class has its own set of configuration options, for making sure
        it will work in your environment.

        Note that using a batch launcher for the controller *does not* put it
        in the same batch job as the engines, so they will still start separately.

        IPython's bundled examples include:

            Local : start the controller locally as a subprocess
            MPI : use mpiexec to launch the controller in an MPI universe
            PBS : use PBS (qsub) to submit the controller to a batch queue
            SGE : use SGE (qsub) to submit the controller to a batch queue
            LSF : use LSF (bsub) to submit the controller to a batch queue
            HTCondor : use HTCondor to submit the controller to a batch queue
            SSH : use SSH to start the controller
            WindowsHPC : use Windows HPC

        If you are using one of IPython's builtin launchers, you can specify just the
        prefix, e.g:

            c.IPClusterStart.controller_launcher_class = 'SSH'

        or:

            ipcluster start --controller=MPI

        """
        )
    reset = Bool(False, config=True,
        help="Whether to reset config files as part of '--create'."
        )

    # flags = Dict(flags)
    aliases = Dict(start_aliases)

    def init_launchers(self):
        """Instantiate both the controller launcher and the engine launcher."""
        self.controller_launcher = self.build_launcher(self.controller_launcher_class, 'Controller')
        self.engine_launcher = self.build_launcher(self.engine_launcher_class, 'EngineSet')

    def engines_stopped(self, r):
        """prevent parent.engines_stopped from stopping everything on engine shutdown"""
        pass

    def start_controller(self):
        """Start the controller; a controller stop triggers ``stop_launchers``."""
        self.log.info("Starting Controller with %s", self.controller_launcher_class)
        self.controller_launcher.on_stop(self.stop_launchers)
        try:
            self.controller_launcher.start()
        except:
            # Log the failure, then let the original exception propagate.
            self.log.exception("Controller start failed")
            raise

    def stop_controller(self):
        """Stop the controller launcher if there is one and it is running."""
        if self.controller_launcher and self.controller_launcher.running:
            return self.controller_launcher.stop()

    def stop_launchers(self, r=None):
        """Stop the controller, then let the parent class stop the engines."""
        if not self._stopping:
            self.stop_controller()
            super(IPClusterStart, self).stop_launchers()

    def start(self):
        """Start the app for the start subcommand.

        Exits with status ALREADY_STARTED when the pid file names a running
        process; a stale pid file is removed instead.
        """
        # First see if the cluster is already running
        try:
            pid = self.get_pid_from_file()
        except PIDFileError:
            pass
        else:
            if self.check_pid(pid):
                self.log.critical(
                    'Cluster is already running with [pid=%s]. '
                    'use "ipcluster stop" to stop the cluster.' % pid
                )
                # Exit with an unusual exit status that other processes
                # can watch for to learn how we exited.
                self.exit(ALREADY_STARTED)
            else:
                self.remove_pid_file()

        # Now log and daemonize
        self.log.info(
            'Starting ipcluster with [daemon=%r]' % self.daemonize
        )
        # TODO: Get daemonize working on Windows or as a Windows Server.
        if self.daemonize:
            if os.name=='posix':
                daemonize()

        def start_controller_then_engines():
            # Controller first; the engines follow after `delay` seconds.
            self.start_controller()
            self.loop.add_timeout(self.loop.time() + self.delay, self.start_engines)
        self.loop.add_callback(start_controller_then_engines)
        # Now write the new pid file AFTER our new forked pid is active.
        self.write_pid_file()
        try:
            self.loop.start()
        except KeyboardInterrupt:
            pass
        except zmq.ZMQError as e:
            # An interrupted system call is not an error here; anything
            # else is re-raised.
            if e.errno != errno.EINTR:
                raise
        finally:
            self.remove_pid_file()
564
564
# Import-string prefix for the subcommand application classes; the
# subcommand table appends 'Start', 'Stop' or 'Engines' to it.
base = 'ipython_parallel.apps.ipclusterapp.IPCluster'
566
566
class IPClusterApp(BaseIPythonApplication):
    """Top-level ``ipcluster`` application: dispatches to a subcommand."""
    name = u'ipcluster'
    description = _description
    examples = _main_examples

    # Subcommand name -> (import string of the application class, help text).
    subcommands = {
        'start' : (base+'Start', start_help),
        'stop' : (base+'Stop', stop_help),
        'engines' : (base+'Engines', engines_help),
    }

    # no aliases or flags for parent App
    aliases = Dict()
    flags = Dict()

    def start(self):
        """Run the selected subcommand, or print usage and exit with status 1."""
        if self.subapp is None:
            # Format a sorted list rather than the raw keys view, so the
            # message is deterministic and does not read "dict_keys([...])"
            # on Python 3.
            print("No subcommand specified. Must specify one of: %s"%(sorted(self.subcommands)))
            print()
            self.print_description()
            self.print_subcommands()
            self.exit(1)
        else:
            return self.subapp.start()
591
591
# Module-level entry point.
launch_new_instance = IPClusterApp.launch_instance

if __name__ == '__main__':
    launch_new_instance()
596
596
@@ -1,548 +1,548 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008 The IPython Development Team
14 # Copyright (C) 2008 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 from __future__ import with_statement
24 from __future__ import with_statement
25
25
26 import json
26 import json
27 import os
27 import os
28 import stat
28 import stat
29 import sys
29 import sys
30
30
31 from multiprocessing import Process
31 from multiprocessing import Process
32 from signal import signal, SIGINT, SIGABRT, SIGTERM
32 from signal import signal, SIGINT, SIGABRT, SIGTERM
33
33
34 import zmq
34 import zmq
35 from zmq.devices import ProcessMonitoredQueue
35 from zmq.devices import ProcessMonitoredQueue
36 from zmq.log.handlers import PUBHandler
36 from zmq.log.handlers import PUBHandler
37
37
38 from IPython.core.profiledir import ProfileDir
38 from IPython.core.profiledir import ProfileDir
39
39
40 from IPython.parallel.apps.baseapp import (
40 from ipython_parallel.apps.baseapp import (
41 BaseParallelApplication,
41 BaseParallelApplication,
42 base_aliases,
42 base_aliases,
43 base_flags,
43 base_flags,
44 catch_config_error,
44 catch_config_error,
45 )
45 )
46 from IPython.utils.importstring import import_item
46 from IPython.utils.importstring import import_item
47 from IPython.utils.localinterfaces import localhost, public_ips
47 from IPython.utils.localinterfaces import localhost, public_ips
48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
49
49
50 from IPython.kernel.zmq.session import (
50 from IPython.kernel.zmq.session import (
51 Session, session_aliases, session_flags,
51 Session, session_aliases, session_flags,
52 )
52 )
53
53
54 from IPython.parallel.controller.heartmonitor import HeartMonitor
54 from ipython_parallel.controller.heartmonitor import HeartMonitor
55 from IPython.parallel.controller.hub import HubFactory
55 from ipython_parallel.controller.hub import HubFactory
56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
56 from ipython_parallel.controller.scheduler import TaskScheduler,launch_scheduler
57 from IPython.parallel.controller.dictdb import DictDB
57 from ipython_parallel.controller.dictdb import DictDB
58
58
59 from IPython.parallel.util import split_url, disambiguate_url, set_hwm
59 from ipython_parallel.util import split_url, disambiguate_url, set_hwm
60
60
61 # conditional import of SQLiteDB / MongoDB backend class
61 # conditional import of SQLiteDB / MongoDB backend class
62 real_dbs = []
62 real_dbs = []
63
63
64 try:
64 try:
65 from IPython.parallel.controller.sqlitedb import SQLiteDB
65 from ipython_parallel.controller.sqlitedb import SQLiteDB
66 except ImportError:
66 except ImportError:
67 pass
67 pass
68 else:
68 else:
69 real_dbs.append(SQLiteDB)
69 real_dbs.append(SQLiteDB)
70
70
71 try:
71 try:
72 from IPython.parallel.controller.mongodb import MongoDB
72 from ipython_parallel.controller.mongodb import MongoDB
73 except ImportError:
73 except ImportError:
74 pass
74 pass
75 else:
75 else:
76 real_dbs.append(MongoDB)
76 real_dbs.append(MongoDB)
77
77
78
78
79
79
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81 # Module level variables
81 # Module level variables
82 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
83
83
84
84
# Long description used as the application's ``description``.
_description = """Start the IPython controller for parallel computing.

The IPython controller provides a gateway between the IPython engines and
clients. The controller needs to be started before the engines and can be
configured using command line options or using a cluster directory. Cluster
directories contain config, log and security files and are usually located in
your ipython directory and named as "profile_name". See the `profile`
and `profile-dir` options for details.
"""

# Usage examples used as the application's ``examples``.
_examples = """
ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
ipcontroller --scheme=pure # use the pure zeromq scheduler
"""


#-----------------------------------------------------------------------------
# The main application
#-----------------------------------------------------------------------------

# Command-line flags, layered in this order: the shared parallel-app flags,
# the controller's own flags, then the Session flags.  A later layer
# replaces an earlier entry of the same name.
flags = dict(base_flags)
flags.update({
    'usethreads': (
        {'IPControllerApp': {'use_threads': True}},
        'Use threads instead of processes for the schedulers',
    ),
    'sqlitedb': (
        {'HubFactory': {'db_class': 'ipython_parallel.controller.sqlitedb.SQLiteDB'}},
        'use the SQLiteDB backend',
    ),
    'mongodb': (
        {'HubFactory': {'db_class': 'ipython_parallel.controller.mongodb.MongoDB'}},
        'use the MongoDB backend',
    ),
    'dictdb': (
        {'HubFactory': {'db_class': 'ipython_parallel.controller.dictdb.DictDB'}},
        'use the in-memory DictDB backend',
    ),
    'nodb': (
        {'HubFactory': {'db_class': 'ipython_parallel.controller.dictdb.NoDB'}},
        """use dummy DB backend, which doesn't store any information.

        This is the default as of IPython 0.13.

        To enable delayed or repeated retrieval of results from the Hub,
        select one of the true db backends.
        """,
    ),
    'reuse': (
        {'IPControllerApp': {'reuse_files': True}},
        'reuse existing json connection files',
    ),
    'restore': (
        {'IPControllerApp': {'restore_engines': True, 'reuse_files': True}},
        'Attempt to restore engines from a JSON file. '
        'For use when resuming a crashed controller',
    ),
})
flags.update(session_flags)

# Short command-line aliases mapped to the ``Class.trait`` they set.  The
# shared parallel-app and Session aliases are merged in afterwards.
aliases = {
    'ssh': 'IPControllerApp.ssh_server',
    'enginessh': 'IPControllerApp.engine_ssh_server',
    'location': 'IPControllerApp.location',

    'url': 'HubFactory.url',
    'ip': 'HubFactory.ip',
    'transport': 'HubFactory.transport',
    'port': 'HubFactory.regport',

    'ping': 'HeartMonitor.period',

    'scheme': 'TaskScheduler.scheme_name',
    'hwm': 'TaskScheduler.hwm',
}
aliases.update(base_aliases)
aliases.update(session_aliases)
149
149
class IPControllerApp(BaseParallelApplication):

    name = u'ipcontroller'
    description = _description
    examples = _examples
    # Classes listed for this application; the optional database backends
    # that imported successfully are appended.
    classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, DictDB] + real_dbs

    # change default to True
    auto_create = Bool(
        True,
        config=True,
        help="""Whether to create profile dir if it doesn't exist.""",
    )

    reuse_files = Bool(
        False,
        config=True,
        help="""Whether to reuse existing json connection files.
        If False, connection files will be removed on a clean exit.
        """,
    )
    restore_engines = Bool(
        False,
        config=True,
        help="""Reload engine state from JSON file
        """,
    )
    ssh_server = Unicode(
        u'',
        config=True,
        help="""ssh url for clients to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    engine_ssh_server = Unicode(
        u'',
        config=True,
        help="""ssh url for engines to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    location = Unicode(
        u'',
        config=True,
        help="""The external IP or domain name of the Controller, used for disambiguating
        engine and client connections.""",
    )
    import_statements = List(
        [],
        config=True,
        help="import statements to be run at startup. Necessary in some environments",
    )

    use_threads = Bool(
        False,
        config=True,
        help='Use threads instead of processes for the schedulers',
    )

    # Names of the two JSON connection files.
    engine_json_file = Unicode(
        'ipcontroller-engine.json',
        config=True,
        help="JSON filename where engine connection info will be stored.",
    )
    client_json_file = Unicode(
        'ipcontroller-client.json',
        config=True,
        help="JSON filename where client connection info will be stored.",
    )
196
196
def _cluster_id_changed(self, name, old, new):
    """Rename both JSON connection files after ``cluster_id`` changes.

    The parent class handler runs first; the engine and client file names
    are then rebuilt from ``self.name``.
    """
    super(IPControllerApp, self)._cluster_id_changed(name, old, new)
    prefix = self.name
    self.engine_json_file = "%s-engine.json" % prefix
    self.client_json_file = "%s-client.json" % prefix
201
201
202
202
# internal (neither trait is declared with config=True)
# devices and processes appended by init_schedulers
children = List()
# dotted path of the MonitoredQueue class imported by init_schedulers
mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
206
206
207 def _use_threads_changed(self, name, old, new):
207 def _use_threads_changed(self, name, old, new):
208 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
208 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
209
209
# Not declared with config=True; load_secondary_config sets it to False
# once the existing JSON files have been reloaded successfully.
write_connection_files = Bool(
    True,
    help="""Whether to write connection files to disk.
    True in all cases other than runs with `reuse_files=True` *after the first*
    """,
)

# Wrap the module-level alias and flag dicts as traits of the application.
aliases = Dict(aliases)
flags = Dict(flags)
218
218
219
219
def save_connection_dict(self, fname, cdict):
    """Save a connection dict to a JSON file in the profile's security dir.

    Parameters
    ----------
    fname : str
        File name; it is joined onto ``self.profile_dir.security_dir``.
    cdict : dict
        Connection info to serialise.  It must contain a ``'location'``
        key.  When that value is empty it is replaced, in place, by the
        last entry of ``public_ips()``, or by ``localhost()`` (with a
        warning) when no public IP is found.

    The file is created with owner read/write permission only, since the
    connection info may include the session key.
    """
    location = cdict['location']

    if not location:
        # Query the interfaces once; the same list supplies the address.
        ips = public_ips()
        if ips:
            location = ips[-1]
        else:
            location = localhost()
            self.log.warning("Could not identify this machine's IP, assuming %s."
            " You may need to specify '--location=<external_ip_address>' to help"
            " IPython decide when to connect via loopback.", location)
        cdict['location'] = location

    fname = os.path.join(self.profile_dir.security_dir, fname)
    self.log.info("writing connection info to %s", fname)

    owner_rw = stat.S_IRUSR|stat.S_IWUSR
    # Create the file already restricted to its owner, so there is no window
    # in which it sits on disk with the default (umask) permissions.
    fd = os.open(fname, os.O_WRONLY|os.O_CREAT|os.O_TRUNC, owner_rw)
    with os.fdopen(fd, 'w') as f:
        f.write(json.dumps(cdict, indent=2))
    # os.open only applies the mode to a newly created file; an existing
    # file keeps its old mode, so tighten it explicitly as well.
    os.chmod(fname, owner_rw)
240
240
def load_config_from_json(self):
    """Load config from the existing engine and client JSON connection files.

    Reads ``self.engine_json_file`` and ``self.client_json_file`` from the
    profile's security dir and copies the session key, interfaces and ports
    into ``self.config``.  ``self.location`` is always overwritten; the two
    ssh server settings are only filled in when they are currently empty.

    Raises
    ------
    IOError
        If either file cannot be opened.
    AssertionError
        If the two files disagree on a field they must share.  The caller
        (``load_secondary_config``) catches this type, so it is kept.
    """
    c = self.config
    self.log.debug("loading config from JSON")
    security_dir = self.profile_dir.security_dir

    # load engine config

    fname = os.path.join(security_dir, self.engine_json_file)
    self.log.info("loading connection info from %s", fname)
    with open(fname) as f:
        ecfg = json.load(f)

    # json gives unicode, Session.key wants bytes
    c.Session.key = ecfg['key'].encode('ascii')

    engine_transport, engine_ip = ecfg['interface'].split('://')

    c.HubFactory.engine_ip = engine_ip
    c.HubFactory.engine_transport = engine_transport

    self.location = ecfg['location']
    if not self.engine_ssh_server:
        self.engine_ssh_server = ecfg['ssh']

    # load client config

    fname = os.path.join(security_dir, self.client_json_file)
    self.log.info("loading connection info from %s", fname)
    with open(fname) as f:
        ccfg = json.load(f)

    for key in ('key', 'registration', 'pack', 'unpack', 'signature_scheme'):
        assert ccfg[key] == ecfg[key], "mismatch between engine and client info: %r" % key

    client_transport, client_ip = ccfg['interface'].split('://')

    c.HubFactory.client_transport = client_transport
    # Use the address parsed from the *client* interface.  This used to
    # reuse the engine's address, which is wrong whenever clients and
    # engines are served on different interfaces.
    c.HubFactory.client_ip = client_ip
    if not self.ssh_server:
        self.ssh_server = ccfg['ssh']

    # load port config: paired entries are (client side, engine side)
    c.HubFactory.regport = ecfg['registration']
    c.HubFactory.hb = (ecfg['hb_ping'], ecfg['hb_pong'])
    c.HubFactory.control = (ccfg['control'], ecfg['control'])
    c.HubFactory.mux = (ccfg['mux'], ecfg['mux'])
    c.HubFactory.task = (ccfg['task'], ecfg['task'])
    c.HubFactory.iopub = (ccfg['iopub'], ecfg['iopub'])
    c.HubFactory.notifier_port = ccfg['notification']
290
290
def cleanup_connection_files(self):
    """Delete the client and engine JSON connection files.

    Does nothing beyond a debug message when ``reuse_files`` is set.  A file
    that cannot be removed is logged as an error and does not stop the
    other one from being processed.
    """
    if self.reuse_files:
        self.log.debug("leaving JSON connection files for reuse")
        return

    self.log.debug("cleaning up JSON connection files")
    for basename in (self.client_json_file, self.engine_json_file):
        path = os.path.join(self.profile_dir.security_dir, basename)
        try:
            os.remove(path)
        except Exception as e:
            self.log.error("Failed to cleanup connection file: %s", e)
            continue
        self.log.debug(u"removed %s", path)
304
304
def load_secondary_config(self):
    """Secondary config: reload connection info from JSON when reusing files.

    With ``reuse_files`` set, tries ``load_config_from_json``.  On success
    ``write_connection_files`` is switched off; an ``AssertionError`` or
    ``IOError`` is logged and otherwise ignored.  The resulting config is
    then logged at debug level.
    """
    reloaded = False
    if self.reuse_files:
        try:
            self.load_config_from_json()
        except (AssertionError,IOError) as e:
            self.log.error("Could not load config from JSON: %s" % e)
        else:
            reloaded = True

    if reloaded:
        # successfully loaded config from JSON, and reuse=True:
        # no need to write back the same file
        self.write_connection_files = False

    self.log.debug("Config changed")
    self.log.debug(repr(self.config))
319
319
def init_hub(self):
    """Build the Hub via ``HubFactory`` and publish its connection info.

    Runs the configured import statements, constructs the factory and its
    hub, writes the client and engine JSON connection files when
    ``write_connection_files`` is set, points the hub at its engine state
    file (optionally restoring engines from it), and finally copies the
    session key into ``self.config``.

    A ``TraitError`` during construction propagates; any other exception is
    logged with its traceback and followed by ``self.exit(1)``.
    """
    config = self.config

    self.do_import_statements()

    try:
        self.factory = HubFactory(config=config, log=self.log)
        self.factory.init_hub()
    except TraitError:
        raise
    except Exception:
        self.log.error("Couldn't construct the Controller", exc_info=True)
        self.exit(1)

    if self.write_connection_files:
        # save to new json config files
        factory = self.factory
        session = factory.session
        # entries shared by the client and the engine file
        shared = {
            'key' : session.key.decode('ascii'),
            'location' : self.location,
            'pack' : session.packer,
            'unpack' : session.unpacker,
            'signature_scheme' : session.signature_scheme,
        }

        client_dict = {'ssh' : self.ssh_server}
        client_dict.update(factory.client_info)
        client_dict.update(shared)
        self.save_connection_dict(self.client_json_file, client_dict)

        engine_dict = {'ssh' : self.engine_ssh_server}
        engine_dict.update(factory.engine_info)
        engine_dict.update(shared)
        self.save_connection_dict(self.engine_json_file, engine_dict)

    state_name = "engines%s.json" % self.cluster_id
    hub = self.factory.hub
    hub.engine_state_file = os.path.join(self.profile_dir.log_dir, state_name)
    if self.restore_engines:
        hub._load_engine_state()
    # load key into config so other sessions in this process (TaskScheduler)
    # have the same value
    self.config.Session.key = self.factory.session.key
363
363
def init_schedulers(self):
    """Create the relay devices and the task scheduler.

    Builds the IOPub, mux and control MonitoredQueue devices (class taken
    from ``self.mq_class``), then sets up task handling according to
    ``TaskScheduler.scheme_name``.  Every device or process created here is
    appended to ``self.children``; nothing is started by this method.  When
    this pyzmq exposes ``SNDHWM``, the high-water marks of the relay
    devices are set to 0.
    """
    devices = self.children
    make_queue = import_item(str(self.mq_class))

    factory = self.factory
    ident = factory.session.bsession
    # disambiguate url, in case of *
    monitor_url = disambiguate_url(factory.monitor_url)

    # IOPub relay
    iopub = make_queue(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
    iopub.bind_in(factory.client_url('iopub'))
    iopub.setsockopt_in(zmq.IDENTITY, ident + b"_iopub")
    iopub.bind_out(factory.engine_url('iopub'))
    iopub.setsockopt_out(zmq.SUBSCRIBE, b'')
    iopub.connect_mon(monitor_url)
    iopub.daemon=True
    devices.append(iopub)

    # Multiplexer queue, then control queue: both are ROUTER/ROUTER relays
    # that differ only in channel name, monitor prefixes and identities.
    router_relays = (
        ('mux', b'in', b'out', b'mux_in', b'mux_out'),
        ('control', b'incontrol', b'outcontrol', b'control_in', b'control_out'),
    )
    for channel, in_prefix, out_prefix, in_ident, out_ident in router_relays:
        relay = make_queue(zmq.ROUTER, zmq.ROUTER, zmq.PUB, in_prefix, out_prefix)
        relay.bind_in(factory.client_url(channel))
        relay.setsockopt_in(zmq.IDENTITY, in_ident)
        relay.bind_out(factory.engine_url(channel))
        relay.setsockopt_out(zmq.IDENTITY, out_ident)
        relay.connect_mon(monitor_url)
        relay.daemon=True
        devices.append(relay)

    if 'TaskScheduler.scheme_name' in self.config:
        scheme = self.config.TaskScheduler.scheme_name
    else:
        scheme = TaskScheduler.scheme_name.get_default_value()

    # Task queue
    if scheme == 'pure':
        self.log.warn("task::using pure DEALER Task scheduler")
        task_q = make_queue(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
        task_q.bind_in(factory.client_url('task'))
        task_q.setsockopt_in(zmq.IDENTITY, b'task_in')
        task_q.bind_out(factory.engine_url('task'))
        task_q.setsockopt_out(zmq.IDENTITY, b'task_out')
        task_q.connect_mon(monitor_url)
        task_q.daemon=True
        devices.append(task_q)
    elif scheme == 'none':
        self.log.warn("task::using no Task scheduler")
    else:
        self.log.info("task::using Python %s Task scheduler"%scheme)
        sargs = (
            factory.client_url('task'),
            factory.engine_url('task'),
            monitor_url,
            disambiguate_url(factory.client_url('notification')),
            disambiguate_url(factory.client_url('registration')),
        )
        kwargs = dict(logname='scheduler', loglevel=self.log_level,
                      log_url = self.log_url, config=dict(self.config))
        if 'Process' in self.mq_class:
            # run the Python scheduler in a Process
            proc = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
            proc.daemon=True
            devices.append(proc)
        else:
            # single-threaded Controller
            kwargs['in_thread'] = True
            launch_scheduler(*sargs, **kwargs)

    # set unlimited HWM for all relay devices
    if not hasattr(zmq, 'SNDHWM'):
        return

    first = devices[0]
    first.setsockopt_in(zmq.RCVHWM, 0)
    first.setsockopt_out(zmq.SNDHWM, 0)

    for dev in devices[1:]:
        # entries without setsockopt_in (e.g. a scheduler Process) are skipped
        if not hasattr(dev, 'setsockopt_in'):
            continue
        dev.setsockopt_in(zmq.SNDHWM, 0)
        dev.setsockopt_in(zmq.RCVHWM, 0)
        dev.setsockopt_out(zmq.SNDHWM, 0)
        dev.setsockopt_out(zmq.RCVHWM, 0)
        dev.setsockopt_mon(zmq.SNDHWM, 0)
454
454
455
455
def terminate_children(self):
    """Terminate the child processes recorded in ``self.children``.

    For a ``ProcessMonitoredQueue`` the process is its ``launcher``; a
    ``Process`` is used directly; any other entry is ignored.  An
    ``OSError`` from ``terminate()`` is ignored.
    """
    procs = []
    for entry in self.children:
        if isinstance(entry, ProcessMonitoredQueue):
            procs.append(entry.launcher)
            continue
        if isinstance(entry, Process):
            procs.append(entry)

    if not procs:
        return

    self.log.critical("terminating children...")
    for proc in procs:
        try:
            proc.terminate()
        except OSError:
            # already dead
            pass
471
471
def handle_signal(self, sig, frame):
    """Signal handler: report the signal, then shut everything down.

    Parameters
    ----------
    sig : int
        Number of the received signal (only used in the log message).
    frame : object
        Interrupted stack frame as passed by the ``signal`` machinery; unused.
    """
    # Order matters: children are terminated before the event loop stops.
    self.log.critical("Received signal %i, shutting down", sig)
    self.terminate_children()
    self.loop.stop()
476
476
def init_signal(self):
    """Install ``self.handle_signal`` for SIGINT, SIGABRT and SIGTERM."""
    handler = self.handle_signal
    for signum in (SIGINT, SIGABRT, SIGTERM):
        signal(signum, handler)
480
480
def do_import_statements(self):
    """Execute every statement listed in ``self.import_statements``.

    Each statement is announced on the application log and then run with
    ``exec``.  A statement that raises is logged (with its traceback) and
    skipped, so one bad statement does not prevent the remaining ones from
    running.
    """
    # The local names ``statements`` and ``s`` are kept as-is because
    # ``locals()`` is handed to the executed code.
    statements = self.import_statements
    for s in statements:
        # ``self.log`` is used as a standard logger elsewhere in this class
        # (info/critical/addHandler); loggers have no ``msg`` method, so the
        # former ``self.log.msg`` calls raised AttributeError.
        self.log.info("Executing statement: '%s'", s)
        try:
            # NOTE(review): this runs configuration-supplied code verbatim;
            # only trusted configuration must reach this point.
            exec(s, globals(), locals())
        except Exception:
            # ``Exception`` rather than a bare except, so SystemExit and
            # KeyboardInterrupt still propagate.
            self.log.error("Error running statement: %s", s, exc_info=True)
489
489
def forward_logging(self):
    """Forward log records to ``self.log_url`` over a ZMQ PUB socket.

    Does nothing when ``log_url`` is empty.  Otherwise a PUB socket from the
    shared ZMQ context is connected to the URL and wrapped in a
    ``PUBHandler`` (root topic ``'controller'``, level ``self.log_level``)
    that is added to the application logger.
    """
    url = self.log_url
    if not url:
        return

    self.log.info("Forwarding logging to %s" % url)
    pub_socket = zmq.Context.instance().socket(zmq.PUB)
    pub_socket.connect(url)

    pub_handler = PUBHandler(pub_socket)
    pub_handler.root_topic = 'controller'
    pub_handler.setLevel(self.log_level)
    self.log.addHandler(pub_handler)
500
500
@catch_config_error
def initialize(self, argv=None):
    """Initialize the controller application.

    Runs the parent class initialization with ``argv`` and then, in this
    order: log forwarding, secondary config loading, hub setup and
    scheduler setup.
    """
    super(IPControllerApp, self).initialize(argv)
    # Looked up one at a time so each step is resolved right before it runs.
    for step_name in (
        'forward_logging',
        'load_secondary_config',
        'init_hub',
        'init_schedulers',
    ):
        getattr(self, step_name)()
508
508
def start(self):
    """Start the factory and child devices, then run the event loop.

    The pid file is (over)written before the loop starts.  A
    ``KeyboardInterrupt`` raised by the loop is logged rather than
    propagated, and connection files are cleaned up however the loop ends.
    """
    factory = self.factory

    # Start the subprocesses:
    factory.start()
    # children must be started before signals are setup,
    # otherwise signal-handling will fire multiple times
    for device in self.children:
        device.start()
    self.init_signal()

    self.write_pid_file(overwrite=True)

    try:
        factory.loop.start()
    except KeyboardInterrupt:
        self.log.critical("Interrupted, Exiting...\n")
    finally:
        self.cleanup_connection_files()
526
526
527
527
def launch_new_instance(*args, **kwargs):
    """Create and run the IPython controller"""
    if sys.platform == 'win32':
        # Windows has no proper fork, so multiprocessing re-imports this
        # module in every child; without this guard each child would start
        # another Controller, without end.
        #
        # this only comes up when IPython has been installed using vanilla
        # setuptools, and *not* distribute.
        import multiprocessing

        # the main process has name 'MainProcess'
        # subprocesses will have names like 'Process-1'
        if multiprocessing.current_process().name != 'MainProcess':
            # we are a subprocess, don't start another Controller!
            return
    return IPControllerApp.launch_instance(*args, **kwargs)


if __name__ == '__main__':
    launch_new_instance()
@@ -1,397 +1,397 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython engine application
4 The IPython engine application
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import json
24 import json
25 import os
25 import os
26 import sys
26 import sys
27 import time
27 import time
28
28
29 import zmq
29 import zmq
30 from zmq.eventloop import ioloop
30 from zmq.eventloop import ioloop
31
31
32 from IPython.core.profiledir import ProfileDir
32 from IPython.core.profiledir import ProfileDir
33 from IPython.parallel.apps.baseapp import (
33 from ipython_parallel.apps.baseapp import (
34 BaseParallelApplication,
34 BaseParallelApplication,
35 base_aliases,
35 base_aliases,
36 base_flags,
36 base_flags,
37 catch_config_error,
37 catch_config_error,
38 )
38 )
39 from IPython.kernel.zmq.log import EnginePUBHandler
39 from IPython.kernel.zmq.log import EnginePUBHandler
40 from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
40 from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
41 from IPython.kernel.zmq.kernelapp import IPKernelApp
41 from IPython.kernel.zmq.kernelapp import IPKernelApp
42 from IPython.kernel.zmq.session import (
42 from IPython.kernel.zmq.session import (
43 Session, session_aliases, session_flags
43 Session, session_aliases, session_flags
44 )
44 )
45 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
45 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
46
46
47 from IPython.config.configurable import Configurable
47 from IPython.config.configurable import Configurable
48
48
49 from IPython.parallel.engine.engine import EngineFactory
49 from ipython_parallel.engine.engine import EngineFactory
50 from IPython.parallel.util import disambiguate_ip_address
50 from ipython_parallel.util import disambiguate_ip_address
51
51
52 from IPython.utils.importstring import import_item
52 from IPython.utils.importstring import import_item
53 from IPython.utils.py3compat import cast_bytes
53 from IPython.utils.py3compat import cast_bytes
54 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float, Instance
54 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float, Instance
55
55
56
56
57 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
58 # Module level variables
58 # Module level variables
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60
60
61 _description = """Start an IPython engine for parallel computing.
61 _description = """Start an IPython engine for parallel computing.
62
62
63 IPython engines run in parallel and perform computations on behalf of a client
63 IPython engines run in parallel and perform computations on behalf of a client
64 and controller. A controller needs to be started before the engines. The
64 and controller. A controller needs to be started before the engines. The
65 engine can be configured using command line options or using a cluster
65 engine can be configured using command line options or using a cluster
66 directory. Cluster directories contain config, log and security files and are
66 directory. Cluster directories contain config, log and security files and are
67 usually located in your ipython directory and named as "profile_name".
67 usually located in your ipython directory and named as "profile_name".
68 See the `profile` and `profile-dir` options for details.
68 See the `profile` and `profile-dir` options for details.
69 """
69 """
70
70
71 _examples = """
71 _examples = """
72 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
72 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
73 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
73 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
74 """
74 """
75
75
76 #-----------------------------------------------------------------------------
76 #-----------------------------------------------------------------------------
77 # MPI configuration
77 # MPI configuration
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79
79
80 mpi4py_init = """from mpi4py import MPI as mpi
80 mpi4py_init = """from mpi4py import MPI as mpi
81 mpi.size = mpi.COMM_WORLD.Get_size()
81 mpi.size = mpi.COMM_WORLD.Get_size()
82 mpi.rank = mpi.COMM_WORLD.Get_rank()
82 mpi.rank = mpi.COMM_WORLD.Get_rank()
83 """
83 """
84
84
85
85
86 pytrilinos_init = """from PyTrilinos import Epetra
86 pytrilinos_init = """from PyTrilinos import Epetra
87 class SimpleStruct:
87 class SimpleStruct:
88 pass
88 pass
89 mpi = SimpleStruct()
89 mpi = SimpleStruct()
90 mpi.rank = 0
90 mpi.rank = 0
91 mpi.size = 0
91 mpi.size = 0
92 """
92 """
93
93
class MPI(Configurable):
    """Configurable for MPI initialization"""

    # Name of the MPI binding to enable; '' disables MPI.
    use = Unicode('', config=True,
        help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
        )

    # Code executed to initialize MPI; filled in from ``default_inits`` when
    # ``use`` changes and no script has been set.
    init_script = Unicode('', config=True,
        help="Initialization code for MPI")

    # Mapping from a ``use`` value to its default initialization script.
    default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
        config=True)

    def _use_changed(self, name, old, new):
        """Load the default init script for ``new`` unless one is already set."""
        if self.init_script:
            return
        self.init_script = self.default_inits.get(new, '')
110
110
111
111
112 #-----------------------------------------------------------------------------
112 #-----------------------------------------------------------------------------
113 # Main application
113 # Main application
114 #-----------------------------------------------------------------------------
114 #-----------------------------------------------------------------------------
# Command-line aliases: short option name -> configurable trait it sets.
aliases = {
    # application options
    'file': 'IPEngineApp.url_file',
    'c': 'IPEngineApp.startup_command',
    's': 'IPEngineApp.startup_script',

    # connection options
    'url': 'EngineFactory.url',
    'ssh': 'EngineFactory.sshserver',
    'sshkey': 'EngineFactory.sshkey',
    'ip': 'EngineFactory.ip',
    'transport': 'EngineFactory.transport',
    'port': 'EngineFactory.regport',
    'location': 'EngineFactory.location',

    'timeout': 'EngineFactory.timeout',

    'mpi': 'MPI.use',
}
# Base and session aliases are applied last, so they win on a name clash.
aliases.update(base_aliases)
aliases.update(session_aliases)

# Command-line flags: the base flags overlaid with the session flags.
flags = dict(base_flags)
flags.update(session_flags)
138
138
class IPEngineApp(BaseParallelApplication):
    """Application that starts one IPython parallel engine.

    Initialization sets up MPI (if configured), reads the controller's JSON
    connection file into the config, builds an ``EngineFactory`` and
    optionally forwards logging; ``start`` then runs the engine's event
    loop.
    """

    name = 'ipengine'
    description = _description
    examples = _examples
    classes = List([ZMQInteractiveShell, ProfileDir, Session, EngineFactory, Kernel, MPI])

    startup_script = Unicode(u'', config=True,
        help='specify a script to be run at startup')
    startup_command = Unicode('', config=True,
        help='specify a command to be run at startup')

    url_file = Unicode(u'', config=True,
        help="""The full location of the file containing the connection information for
        the controller. If this is not given, the file must be in the
        security directory of the cluster directory. This location is
        resolved using the `profile` or `profile_dir` options.""",
        )
    wait_for_url_file = Float(5, config=True,
        help="""The maximum number of seconds to wait for url_file to exist.
        This is useful for batch-systems and shared-filesystems where the
        controller and engine are started at the same time and it
        may take a moment for the controller to write the connector files.""")

    url_file_name = Unicode(u'ipcontroller-engine.json', config=True)

    def _cluster_id_changed(self, name, old, new):
        """Derive ``url_file_name`` from the new cluster id."""
        if new:
            base = 'ipcontroller-%s' % new
        else:
            base = 'ipcontroller'
        self.url_file_name = "%s-engine.json" % base

    log_url = Unicode('', config=True,
        help="""The URL for the iploggerapp instance, for forwarding
        logging to a central location.""")

    # an IPKernelApp instance, used to setup listening for shell frontends
    kernel_app = Instance(IPKernelApp)

    aliases = Dict(aliases)
    flags = Dict(flags)

    @property
    def kernel(self):
        """allow access to the Kernel object, so I look like IPKernelApp"""
        return self.engine.kernel

    def find_url_file(self):
        """Set the url file.

        Here we don't try to actually see if it exists or is valid, as that
        is handled by the connection logic.
        """
        # Find the actual controller key file
        if not self.url_file:
            self.url_file = os.path.join(
                self.profile_dir.security_dir,
                self.url_file_name
            )

    def _read_connector_file(self, max_tries=5, delay=0.5):
        """Return the parsed, non-empty JSON content of ``self.url_file``.

        The file may be caught empty or partially written, so a failed parse
        is retried up to ``max_tries`` times, sleeping ``delay`` seconds
        between attempts.

        Raises
        ------
        ValueError
            If the file still cannot be parsed (or parses to an empty value)
            after the last retry.
        """
        attempt = 0
        while True:
            attempt += 1
            # Re-open on every attempt: a handle that has been read once
            # sits at EOF, so re-reading it would only ever yield ''.
            with open(self.url_file) as f:
                raw = f.read()
            try:
                info = json.loads(raw)
                if not info:
                    raise ValueError(
                        "connection file %r has no content" % self.url_file)
                return info
            except ValueError:
                if attempt > max_tries:
                    raise
            time.sleep(delay)

    def load_connector_file(self):
        """load config from a JSON connector file,
        at a *lower* priority than command-line/config files.

        Location and ssh server are only taken from the file when they are
        not already configured; the session key, signature scheme,
        registration URL and (un)packer are always taken from the file.
        The parsed file is stored as ``self.connection_info``.
        """

        self.log.info("Loading url_file %r", self.url_file)
        config = self.config

        d = self._read_connector_file()

        # allow hand-override of location for disambiguation
        # and ssh-server
        if 'EngineFactory.location' not in config:
            config.EngineFactory.location = d['location']
        if 'EngineFactory.sshserver' not in config:
            config.EngineFactory.sshserver = d.get('ssh')

        location = config.EngineFactory.location

        proto, ip = d['interface'].split('://')
        ip = disambiguate_ip_address(ip, location)
        d['interface'] = '%s://%s' % (proto, ip)

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        config.Session.key = cast_bytes(d['key'])
        config.Session.signature_scheme = d['signature_scheme']

        config.EngineFactory.url = d['interface'] + ':%i' % d['registration']

        config.Session.packer = d['pack']
        config.Session.unpacker = d['unpack']

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        self.connection_info = d

    def bind_kernel(self, **kwargs):
        """Promote engine to listening kernel, accessible to frontends.

        Does nothing when a kernel app has already been created.  Keyword
        arguments are passed to ``IPKernelApp``; ``config``, ``log``,
        ``profile_dir`` and ``session`` default to this application's own.
        """
        if self.kernel_app is not None:
            return

        self.log.info("Opening ports for direct connections as an IPython kernel")

        kernel = self.kernel

        kwargs.setdefault('config', self.config)
        kwargs.setdefault('log', self.log)
        kwargs.setdefault('profile_dir', self.profile_dir)
        kwargs.setdefault('session', self.engine.session)

        app = self.kernel_app = IPKernelApp(**kwargs)

        # allow IPKernelApp.instance():
        IPKernelApp._instance = app

        app.init_connection_file()
        # relevant contents of init_sockets:

        app.shell_port = app._bind_socket(kernel.shell_streams[0], app.shell_port)
        app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)

        app.iopub_port = app._bind_socket(kernel.iopub_socket, app.iopub_port)
        app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)

        kernel.stdin_socket = self.engine.context.socket(zmq.ROUTER)
        app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
        app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)

        # start the heartbeat, and log connection info:

        app.init_heartbeat()

        app.log_connection_info()
        app.write_connection_file()

    def init_engine(self):
        """Locate the connection info and create the ``EngineFactory``.

        Waits up to ``wait_for_url_file`` seconds for the url file unless
        connection info was given by hand, loads the file when it exists,
        merges startup script/command into the kernel's exec lists, and
        exits with status 1 when no connection info is available or the
        engine cannot be created.
        """
        # This is the working dir by now.
        sys.path.insert(0, '')
        config = self.config
        self.find_url_file()

        # was the url manually specified?
        keys = set(self.config.EngineFactory.keys())
        keys = keys.union(set(self.config.RegistrationFactory.keys()))

        if keys.intersection(('ip', 'url', 'port')):
            # Connection info was specified, don't wait for the file
            url_specified = True
            self.wait_for_url_file = 0
        else:
            url_specified = False

        if self.wait_for_url_file and not os.path.exists(self.url_file):
            # ``warning`` rather than the deprecated ``warn`` alias.
            self.log.warning("url_file %r not found", self.url_file)
            self.log.warning("Waiting up to %.1f seconds for it to arrive.", self.wait_for_url_file)
            tic = time.time()
            while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
                # wait for url_file to exist, or until time limit
                time.sleep(0.1)

        if os.path.exists(self.url_file):
            self.load_connector_file()
        elif not url_specified:
            self.log.fatal("Fatal: url file never arrived: %s", self.url_file)
            self.exit(1)

        # Take exec_lines / exec_files from the first app section that
        # configures them; IPKernelApp has priority over InteractiveShellApp.
        exec_lines = []
        for app in ('IPKernelApp', 'InteractiveShellApp'):
            if '%s.exec_lines' % app in config:
                exec_lines = config[app].exec_lines
                break

        exec_files = []
        for app in ('IPKernelApp', 'InteractiveShellApp'):
            if '%s.exec_files' % app in config:
                exec_files = config[app].exec_files
                break

        config.IPKernelApp.exec_lines = exec_lines
        config.IPKernelApp.exec_files = exec_files

        # The lists were stored in the config just above, so these appends
        # are visible through config.IPKernelApp as well.
        if self.startup_script:
            exec_files.append(self.startup_script)
        if self.startup_command:
            exec_lines.append(self.startup_command)

        # Create the underlying shell class and Engine
        # NOTE(review): within this class ``connection_info`` is only
        # assigned by load_connector_file(); when the URL is given by hand
        # and no url file exists, the lookup below fails unless a base class
        # provides the attribute -- confirm.
        try:
            self.engine = EngineFactory(config=config, log=self.log,
                            connection_info=self.connection_info,
                        )
        except Exception:
            # ``Exception`` rather than a bare except, so SystemExit and
            # KeyboardInterrupt are not reported as engine start failures.
            self.log.error("Couldn't start the Engine", exc_info=True)
            self.exit(1)

    def forward_logging(self):
        """Forward log records to ``log_url`` over a PUB socket, if it is set."""
        if self.log_url:
            self.log.info("Forwarding logging to %s", self.log_url)
            context = self.engine.context
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            handler = EnginePUBHandler(self.engine, lsock)
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)

    def init_mpi(self):
        """Run the configured MPI init script in this module's globals.

        The module-level name ``mpi`` is set to ``None`` when no init script
        is configured or when running the script fails; a failure is logged
        and does not stop the engine from starting.
        """
        global mpi
        self.mpi = MPI(parent=self)

        mpi_import_statement = self.mpi.init_script
        if mpi_import_statement:
            try:
                self.log.info("Initializing MPI:")
                self.log.info(mpi_import_statement)
                exec(mpi_import_statement, globals())
            except Exception:
                # Still best-effort, but no longer silent.
                self.log.warning("MPI initialization failed, continuing without MPI",
                                 exc_info=True)
                mpi = None
        else:
            mpi = None

    @catch_config_error
    def initialize(self, argv=None):
        """Initialize the application: MPI, then the engine, then log forwarding."""
        super(IPEngineApp, self).initialize(argv)
        self.init_mpi()
        self.init_engine()
        self.forward_logging()

    def start(self):
        """Start the engine and run its event loop until interrupted."""
        self.engine.start()
        try:
            self.engine.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")


launch_new_instance = IPEngineApp.launch_instance


if __name__ == '__main__':
    launch_new_instance()
397
397
@@ -1,95 +1,95 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 A simple IPython logger application
4 A simple IPython logger application
5
5
6 Authors:
6 Authors:
7
7
8 * MinRK
8 * MinRK
9
9
10 """
10 """
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 import os
23 import os
24 import sys
24 import sys
25
25
26 import zmq
26 import zmq
27
27
28 from IPython.core.profiledir import ProfileDir
28 from IPython.core.profiledir import ProfileDir
29 from IPython.utils.traitlets import Bool, Dict, Unicode
29 from IPython.utils.traitlets import Bool, Dict, Unicode
30
30
31 from IPython.parallel.apps.baseapp import (
31 from ipython_parallel.apps.baseapp import (
32 BaseParallelApplication,
32 BaseParallelApplication,
33 base_aliases,
33 base_aliases,
34 catch_config_error,
34 catch_config_error,
35 )
35 )
36 from IPython.parallel.apps.logwatcher import LogWatcher
36 from ipython_parallel.apps.logwatcher import LogWatcher
37
37
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39 # Module level variables
39 # Module level variables
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41
41
#: Description text for this application (used as ``IPLoggerApp.description``).
_description = """Start an IPython logger for parallel computing.

IPython controllers and engines (and your own processes) can broadcast log messages
by registering a `zmq.log.handlers.PUBHandler` with the `logging` module. The
logger can be configured using command line options or using a cluster
directory. Cluster directories contain config, log and security files and are
usually located in your ipython directory and named as "profile_name".
See the `profile` and `profile-dir` options for details.
"""
51 """
52
52
53
53
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55 # Main application
55 # Main application
56 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
# Command-line aliases: a copy of the shared base aliases, extended with
# the two LogWatcher options.
aliases = dict(base_aliases)
aliases.update(url='LogWatcher.url', topics='LogWatcher.topics')
60
60
class IPLoggerApp(BaseParallelApplication):
    """Application that creates a ``LogWatcher`` and runs its event loop."""

    name = u'iplogger'
    description = _description
    classes = [LogWatcher, ProfileDir]
    aliases = Dict(aliases)

    @catch_config_error
    def initialize(self, argv=None):
        """Initialize the base application, then create the log watcher."""
        super(IPLoggerApp, self).initialize(argv)
        self.init_watcher()

    def init_watcher(self):
        """Create ``self.watcher``; log the error and exit with status 1 on failure."""
        try:
            self.watcher = LogWatcher(parent=self, log=self.log)
        except Exception:
            # Only genuine errors are reported as start-up failures;
            # KeyboardInterrupt/SystemExit propagate untouched.
            self.log.error("Couldn't start the LogWatcher", exc_info=True)
            self.exit(1)
        else:
            self.log.info("Listening for log messages on %r", self.watcher.url)

    def start(self):
        """Start the watcher and run its loop until interrupted."""
        self.watcher.start()
        try:
            self.watcher.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Logging Interrupted, shutting down...\n")
88
88
89
89
# Module-level alias for the application's launch entry point.
launch_new_instance = IPLoggerApp.launch_instance


# Launch the logger application when this module is run as a script.
if __name__ == '__main__':
    launch_new_instance()
95
95
@@ -1,1445 +1,1445 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Facilities for launching IPython processes asynchronously."""
2 """Facilities for launching IPython processes asynchronously."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import copy
7 import copy
8 import logging
8 import logging
9 import os
9 import os
10 import pipes
10 import pipes
11 import stat
11 import stat
12 import sys
12 import sys
13 import time
13 import time
14
14
# signal imports, handling various platforms, versions

from signal import SIGINT, SIGTERM
try:
    from signal import SIGKILL
except ImportError:
    # Windows
    # SIGKILL is not importable there, so SIGTERM stands in for it.
    SIGKILL=SIGTERM

try:
    # Windows >= 2.7, 3.2
    # When available, CTRL_C_EVENT replaces the SIGINT imported above.
    from signal import CTRL_C_EVENT as SIGINT
except ImportError:
    pass
29
29
30 from subprocess import Popen, PIPE, STDOUT
30 from subprocess import Popen, PIPE, STDOUT
try:
    from subprocess import check_output
except ImportError:
    # pre-2.7, define check_output with Popen
    def check_output(*args, **kwargs):
        """Run a command via ``Popen`` and return what it wrote to stdout.

        All arguments are forwarded to ``Popen``; ``stdout`` is always
        forced to ``PIPE``.  The exit status is not inspected.
        """
        kwargs['stdout'] = PIPE
        captured, _unused_err = Popen(*args, **kwargs).communicate()
        return captured
40
40
41 from zmq.eventloop import ioloop
41 from zmq.eventloop import ioloop
42
42
43 from IPython.config.application import Application
43 from IPython.config.application import Application
44 from IPython.config.configurable import LoggingConfigurable
44 from IPython.config.configurable import LoggingConfigurable
45 from IPython.utils.text import EvalFormatter
45 from IPython.utils.text import EvalFormatter
46 from IPython.utils.traitlets import (
46 from IPython.utils.traitlets import (
47 Any, Integer, CFloat, List, Unicode, Dict, Instance, HasTraits, CRegExp
47 Any, Integer, CFloat, List, Unicode, Dict, Instance, HasTraits, CRegExp
48 )
48 )
49 from IPython.utils.encoding import DEFAULT_ENCODING
49 from IPython.utils.encoding import DEFAULT_ENCODING
50 from IPython.utils.path import get_home_dir, ensure_dir_exists
50 from IPython.utils.path import get_home_dir, ensure_dir_exists
51 from IPython.utils.process import find_cmd, FindCmdError
51 from IPython.utils.process import find_cmd, FindCmdError
52 from IPython.utils.py3compat import iteritems, itervalues
52 from IPython.utils.py3compat import iteritems, itervalues
53
53
54 from .win32support import forward_read_events
54 from .win32support import forward_read_events
55
55
56 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
56 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
57
57
58 WINDOWS = os.name == 'nt'
58 WINDOWS = os.name == 'nt'
59
59
60 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
61 # Paths to the kernel apps
61 # Paths to the kernel apps
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63
63
# Default argv lists for launching each app: the current interpreter
# (``sys.executable``) running the corresponding module with ``-m``.
ipcluster_cmd_argv = [sys.executable, "-m", "ipython_parallel.cluster"]

ipengine_cmd_argv = [sys.executable, "-m", "ipython_parallel.engine"]

ipcontroller_cmd_argv = [sys.executable, "-m", "ipython_parallel.controller"]

if WINDOWS and sys.version_info < (3,):
    # `python -m package` doesn't work on Windows Python 2
    # due to weird multiprocessing bugs
    # and python -m module puts classes in the `__main__` module,
    # so instance checks get confused
    # Hence the engine and controller are started with ``-c`` and an
    # explicit import of ``main`` instead.
    ipengine_cmd_argv = [sys.executable, "-c", "from ipython_parallel.engine.__main__ import main; main()"]
    ipcontroller_cmd_argv = [sys.executable, "-c", "from ipython_parallel.controller.__main__ import main; main()"]
77
77
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # Base launchers and errors
79 # Base launchers and errors
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
class LauncherError(Exception):
    """Base class for the launcher exceptions defined in this module."""


class ProcessStateError(LauncherError):
    """Raised when an operation is invalid for the launcher's current state."""


class UnknownStatus(LauncherError):
    """Launcher error for a status that is not recognized.

    NOTE(review): not raised in the visible part of the file; confirm the
    exact meaning against its callers.
    """
92
92
93
93
class BaseLauncher(LoggingConfigurable):
    """An abstraction for starting, stopping and signaling a process.

    Subclasses implement ``find_args``, ``start``, ``stop`` and ``signal``.
    The launcher tracks a ``state`` string ('before', 'running' or 'after')
    and a list of callbacks to run once the process has stopped.
    """

    # work_dir is where the launchers run their child processes.  It is
    # usually the profile_dir, but need not be; a work_dir passed to
    # __init__ overrides the config value.  It should not be used to set
    # the work_dir of the actual engine and controller: use their own
    # config files, or the controller_args / engine_args attributes of the
    # launchers, to add the work_dir option.
    work_dir = Unicode(u'.')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    start_data = Any()
    stop_data = Any()

    def _loop_default(self):
        # Default for the ``loop`` trait.
        return ioloop.IOLoop.instance()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
        # One of: 'before', 'running', 'after'.
        self.state = 'before'
        self.stop_callbacks = []

    @property
    def args(self):
        """The list of cmd and args used to start the process.

        This is what is passed to :func:`spawnProcess`; the first element
        is the process name.
        """
        return self.find_args()

    def find_args(self):
        """Build the args list returned by the ``.args`` property.

        Subclasses implement this to construct the cmd and args.
        """
        raise NotImplementedError('find_args must be implemented in a subclass')

    @property
    def arg_str(self):
        """The program arguments joined into a single space-separated string."""
        return ' '.join(self.args)

    @property
    def running(self):
        """True while the launcher's state is 'running'."""
        return self.state == 'running'

    def start(self):
        """Start the process."""
        raise NotImplementedError('start must be implemented in a subclass')

    def stop(self):
        """Stop the process and notify observers of stopping.

        This method returns None immediately.  To observe the actual
        process stopping, see :meth:`on_stop`.
        """
        raise NotImplementedError('stop must be implemented in a subclass')

    def on_stop(self, f):
        """Register ``f`` to be called with this Launcher's stop_data.

        If the process has already finished, ``f`` is called right away and
        its result returned; otherwise it is queued until the stop.
        """
        if self.state != 'after':
            self.stop_callbacks.append(f)
            return None
        return f(self.stop_data)

    def notify_start(self, data):
        """Record that the process has started.

        Logs the startup, stores ``data`` as ``start_data`` and sets the
        state to 'running'.  Returns ``data`` unchanged so the method can
        be used as a pass-through callback.
        """
        process_name = self.args[0]
        self.log.debug('Process %r started: %r', process_name, data)
        self.start_data = data
        self.state = 'running'
        return data

    def notify_stop(self, data):
        """Record that the process has stopped and fire the stop callbacks.

        Logs the stop, stores ``data`` as ``stop_data``, sets the state to
        'after' and calls each callback registered via :meth:`on_stop`.
        Returns ``data`` unchanged.
        """
        process_name = self.args[0]
        self.log.debug('Process %r stopped: %r', process_name, data)
        self.stop_data = data
        self.state = 'after'
        # The count is fixed before any callback runs; callbacks are taken
        # from the end of the list, so the most recently added runs first.
        pending = len(self.stop_callbacks)
        for _ in range(pending):
            callback = self.stop_callbacks.pop()
            callback(data)
        return data

    def signal(self, sig):
        """Signal the process.

        Parameters
        ----------
        sig : str or int
            'KILL', 'INT', etc., or any signal number
        """
        raise NotImplementedError('signal must be implemented in a subclass')
203
203
class ClusterAppMixin(HasTraits):
    """MixIn for cluster args as traits"""
    profile_dir = Unicode('')
    cluster_id = Unicode('')

    @property
    def cluster_args(self):
        """Command-line arguments naming the profile directory and cluster id."""
        argv = ['--profile-dir', self.profile_dir]
        argv += ['--cluster-id', self.cluster_id]
        return argv
212
212
class ControllerMixin(ClusterAppMixin):
    """Mixin holding the configurable command and arguments for ipcontroller."""
    # Defaults to the module-level ipcontroller_cmd_argv list.
    controller_cmd = List(ipcontroller_cmd_argv, config=True,
        help="""Popen command to launch ipcontroller.""")
    # Command line arguments to ipcontroller.
    # By default: log to file at the numeric logging.INFO level.
    controller_args = List(['--log-to-file','--log-level=%i' % logging.INFO], config=True,
        help="""command-line args to pass to ipcontroller""")
219
219
class EngineMixin(ClusterAppMixin):
    """Mixin holding the configurable command and arguments for ipengine."""
    # Defaults to the module-level ipengine_cmd_argv list.
    engine_cmd = List(ipengine_cmd_argv, config=True,
        help="""command to launch the Engine.""")
    # Command line arguments for ipengine.
    # By default: log to file at the numeric logging.INFO level.
    engine_args = List(['--log-to-file','--log-level=%i' % logging.INFO], config=True,
        help="command-line arguments to pass to ipengine"
    )
227
227
228
228
229 #-----------------------------------------------------------------------------
229 #-----------------------------------------------------------------------------
230 # Local process launchers
230 # Local process launchers
231 #-----------------------------------------------------------------------------
231 #-----------------------------------------------------------------------------
232
232
233
233
class LocalProcessLauncher(BaseLauncher):
    """Start and stop an external process in an asynchronous manner.

    This will launch the external process with a working directory of
    ``self.work_dir``.
    """

    # This is used to construct self.args, which is passed to
    # spawnProcess.
    cmd_and_args = List([])
    poll_frequency = Integer(100) # in ms

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalProcessLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.process = None
        self.poller = None

    def find_args(self):
        """Return the configured ``cmd_and_args`` list."""
        return self.cmd_and_args

    def start(self):
        """Spawn the process and hook its output and exit status into the loop.

        Raises
        ------
        ProcessStateError
            If the launcher is not in the 'before' state.
        """
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        if self.state == 'before':
            self.process = Popen(self.args,
                stdout=PIPE,stderr=PIPE,stdin=PIPE,
                env=os.environ,
                cwd=self.work_dir
            )
            if WINDOWS:
                self.stdout = forward_read_events(self.process.stdout)
                self.stderr = forward_read_events(self.process.stderr)
            else:
                self.stdout = self.process.stdout.fileno()
                self.stderr = self.process.stderr.fileno()
            self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
            self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
            # Periodically check whether the process has exited.
            self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
            self.poller.start()
            self.notify_start(self.process.pid)
        else:
            s = 'The process was already started and has state: %r' % self.state
            raise ProcessStateError(s)

    def stop(self):
        """Stop the process via :meth:`interrupt_then_kill`."""
        return self.interrupt_then_kill()

    def signal(self, sig):
        """Send ``sig`` to the process; does nothing unless it is running."""
        if self.state == 'running':
            if WINDOWS and sig != SIGINT:
                # use Windows tree-kill for better child cleanup
                check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
            else:
                self.process.send_signal(sig)

    def interrupt_then_kill(self, delay=2.0):
        """Send INT, wait a delay and then send KILL."""
        try:
            self.signal(SIGINT)
        except Exception:
            # Best effort: the KILL below is scheduled regardless.
            self.log.debug("interrupt failed")
        self.killer = self.loop.add_timeout(self.loop.time() + delay, lambda : self.signal(SIGKILL))

    # callbacks, etc:

    def _log_output(self, line):
        """Log one line of child output, or poll for exit when it is empty."""
        if not line:
            # a stopped process will be readable but return empty strings
            self.poll()
            return
        # Strip only an actual trailing newline, so a final line without
        # one keeps its last character.
        newline = b'\n' if isinstance(line, bytes) else u'\n'
        if line.endswith(newline):
            line = line[:-1]
        self.log.debug(line)

    def handle_stdout(self, fd, events):
        """Loop handler: read one line of the child's stdout and log it."""
        if WINDOWS:
            line = self.stdout.recv()
        else:
            line = self.process.stdout.readline()
        self._log_output(line)

    def handle_stderr(self, fd, events):
        """Loop handler: read one line of the child's stderr and log it."""
        if WINDOWS:
            line = self.stderr.recv()
        else:
            line = self.process.stderr.readline()
        self._log_output(line)

    def poll(self):
        """Check the process exit status; clean up and notify once it has exited.

        Returns the exit status, or None while the process is still running.
        """
        status = self.process.poll()
        if status is not None:
            self.poller.stop()
            self.loop.remove_handler(self.stdout)
            self.loop.remove_handler(self.stderr)
            self.notify_stop(dict(exit_code=status, pid=self.process.pid))
        return status
331
331
class LocalControllerLauncher(LocalProcessLauncher, ControllerMixin):
    """Launch a controller as a regular external process."""

    def find_args(self):
        """Return the controller command, then cluster args, then controller args."""
        command = self.controller_cmd
        options = self.cluster_args + self.controller_args
        return command + options

    def start(self):
        """Start the controller by profile_dir."""
        parent = super(LocalControllerLauncher, self)
        return parent.start()
341
341
342
342
class LocalEngineLauncher(LocalProcessLauncher, EngineMixin):
    """Launch a single engine as a regular external process."""

    def find_args(self):
        """Return the engine command, then cluster args, then engine args."""
        command = self.engine_cmd
        options = self.cluster_args + self.engine_args
        return command + options
348
348
349
349
class LocalEngineSetLauncher(LocalEngineLauncher):
    """Launch a set of engines as regular external processes."""

    delay = CFloat(0.1, config=True,
        help="""delay (in seconds) between starting each engine after the first.
        This can help force the engines to get their ids in order, or limit
        process flood when starting many engines."""
    )

    # launcher class
    launcher_class = LocalEngineLauncher

    # index -> launcher for every engine that has not stopped yet
    launchers = Dict()
    # index -> stop data for every engine that has stopped
    stop_data = Dict()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalEngineSetLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )

    def start(self, n):
        """Start n engines by profile or profile_dir.

        Returns the list of values returned by each engine launcher's
        ``start``.
        """
        dlist = []
        for i in range(n):
            if i > 0:
                time.sleep(self.delay)
            el = self.launcher_class(work_dir=self.work_dir, parent=self, log=self.log,
                                    profile_dir=self.profile_dir, cluster_id=self.cluster_id,
            )

            # Copy the engine args over to each engine launcher.
            el.engine_cmd = copy.deepcopy(self.engine_cmd)
            el.engine_args = copy.deepcopy(self.engine_args)
            el.on_stop(self._notice_engine_stopped)
            d = el.start()
            self.launchers[i] = el
            dlist.append(d)
        self.notify_start(dlist)
        return dlist

    def find_args(self):
        """Return a placeholder args list naming the engine set."""
        return ['engine set']

    def signal(self, sig):
        """Send ``sig`` to every tracked engine; return the list of results."""
        return [el.signal(sig) for el in itervalues(self.launchers)]

    def interrupt_then_kill(self, delay=1.0):
        """Interrupt, then kill, every tracked engine; return the list of results."""
        return [el.interrupt_then_kill(delay) for el in itervalues(self.launchers)]

    def stop(self):
        """Stop all engines via :meth:`interrupt_then_kill`."""
        return self.interrupt_then_kill()

    def _notice_engine_stopped(self, data):
        """Stop callback for a single engine.

        Moves the engine whose pid matches ``data['pid']`` from
        ``launchers`` to ``stop_data``, and calls ``notify_stop`` once no
        launchers remain.
        """
        pid = data['pid']
        for idx, el in iteritems(self.launchers):
            if el.process.pid == pid:
                break
        else:
            # No tracked engine has this pid: do not remove an unrelated
            # launcher (or fail on an unbound index when none are tracked).
            self.log.warning("Stop notification for unknown engine pid %r", pid)
            return
        self.launchers.pop(idx)
        self.stop_data[idx] = data
        if not self.launchers:
            self.notify_stop(self.stop_data)
419
419
420
420
421 #-----------------------------------------------------------------------------
421 #-----------------------------------------------------------------------------
422 # MPI launchers
422 # MPI launchers
423 #-----------------------------------------------------------------------------
423 #-----------------------------------------------------------------------------
424
424
425
425
class MPILauncher(LocalProcessLauncher):
    """Launch an external process using mpiexec."""

    mpi_cmd = List(['mpiexec'], config=True,
        help="The mpiexec command to use in starting the process."
    )
    mpi_args = List([], config=True,
        help="The command line arguments to pass to mpiexec."
    )
    program = List(['date'],
        help="The program to start via mpiexec.")
    program_args = List([],
        help="The command line argument to the program."
    )
    # Number of process instances requested from mpiexec (its ``-n`` value).
    n = Integer(1)

    def __init__(self, *args, **kwargs):
        """Create the launcher, migrating config from deprecated class names.

        Any config section stored under one of the old ``MPIExec*`` class
        names is merged into the section of the matching ``MPI*`` name and
        a deprecation warning is logged.  All arguments are then passed
        unchanged to the parent constructor.
        """
        # deprecation for old MPIExec names:
        # ``config=None`` may be passed explicitly, in which case
        # ``kwargs.get('config', {})`` would hand back None; fall back to an
        # empty mapping so the lookups below cannot fail.
        config = kwargs.get('config') or {}
        for oldname in ('MPIExecLauncher', 'MPIExecControllerLauncher', 'MPIExecEngineSetLauncher'):
            deprecated = config.get(oldname)
            if deprecated:
                newname = oldname.replace('MPIExec', 'MPI')
                config[newname].update(deprecated)
                self.log.warn("WARNING: %s name has been deprecated, use %s", oldname, newname)

        super(MPILauncher, self).__init__(*args, **kwargs)

    def find_args(self):
        """Build self.args using all the fields."""
        return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
               self.program + self.program_args

    def start(self, n):
        """Start n instances of the program using mpiexec.

        Stores ``n`` on the launcher (it is read by ``find_args``) and
        returns the result of the parent class's ``start``.
        """
        self.n = n
        return super(MPILauncher, self).start()
463
463
464
464
class MPIControllerLauncher(MPILauncher, ControllerMixin):
    """Launch a controller using mpiexec."""

    # ``program`` and ``program_args`` are exposed as read-only properties
    # that map onto the controller settings.  ``find_args()`` can then use
    # the same attribute names for every Controller/EngineSet launcher,
    # whichever ``*_cmd`` / ``*_args`` attributes the class actually has.
    @property
    def program(self):
        """The command to run: ``controller_cmd``."""
        return self.controller_cmd

    @property
    def program_args(self):
        """``cluster_args`` followed by ``controller_args``."""
        combined = self.cluster_args + self.controller_args
        return combined

    def start(self):
        """Start the controller by profile_dir, as a single instance."""
        instances = 1
        return super(MPIControllerLauncher, self).start(instances)
482
482
483
483
class MPIEngineSetLauncher(MPILauncher, EngineMixin):
    """Launch engines using mpiexec"""

    # ``program`` and ``program_args`` are exposed as read-only properties
    # that map onto the engine settings.  ``find_args()`` can then use the
    # same attribute names for every Controller/EngineSet launcher,
    # whichever ``*_cmd`` / ``*_args`` attributes the class actually has.
    @property
    def program(self):
        """The command to run: ``engine_cmd``."""
        return self.engine_cmd

    @property
    def program_args(self):
        """``cluster_args`` followed by ``engine_args``."""
        combined = self.cluster_args + self.engine_args
        return combined

    def start(self, n):
        """Start n engines by profile or profile_dir.

        ``n`` is stored on the launcher and also handed to the parent
        ``start``; the parent's return value is returned.
        """
        self.n = n
        parent = super(MPIEngineSetLauncher, self)
        return parent.start(n)
502
502
503 # deprecated MPIExec names
503 # deprecated MPIExec names
class DeprecatedMPILauncher(object):
    """Mixin that logs a deprecation warning for old ``MPIExec*`` names."""

    def warn(self):
        """Log that this class name is deprecated and name its replacement.

        The replacement is this class's own name with ``MPIExec`` replaced
        by ``MPI``.  The message is sent through ``self.log.warn``.
        """
        deprecated_name = self.__class__.__name__
        replacement = deprecated_name.replace('MPIExec', 'MPI')
        message = "WARNING: %s name is deprecated, use %s"
        self.log.warn(message, deprecated_name, replacement)
509
509
class MPIExecLauncher(MPILauncher, DeprecatedMPILauncher):
    """Deprecated, use MPILauncher"""

    def __init__(self, *args, **kwargs):
        """Initialise like the parent classes, then log the deprecation warning."""
        parent = super(MPIExecLauncher, self)
        parent.__init__(*args, **kwargs)
        # Warn after construction has finished.
        self.warn()
515
515
class MPIExecControllerLauncher(MPIControllerLauncher, DeprecatedMPILauncher):
    """Deprecated, use MPIControllerLauncher"""

    def __init__(self, *args, **kwargs):
        """Initialise like the parent classes, then log the deprecation warning."""
        parent = super(MPIExecControllerLauncher, self)
        parent.__init__(*args, **kwargs)
        # Warn after construction has finished.
        self.warn()
521
521
class MPIExecEngineSetLauncher(MPIEngineSetLauncher, DeprecatedMPILauncher):
    """Deprecated, use MPIEngineSetLauncher"""

    def __init__(self, *args, **kwargs):
        """Initialise like the parent classes, then log the deprecation warning."""
        parent = super(MPIExecEngineSetLauncher, self)
        parent.__init__(*args, **kwargs)
        # Warn after construction has finished.
        self.warn()
527
527
528
528
529 #-----------------------------------------------------------------------------
529 #-----------------------------------------------------------------------------
530 # SSH launchers
530 # SSH launchers
531 #-----------------------------------------------------------------------------
531 #-----------------------------------------------------------------------------
532
532
533 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
533 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
534
534
class SSHLauncher(LocalProcessLauncher):
    """A minimal launcher for ssh.

    To be useful this will probably have to be extended to use the ``sshx``
    idea for environment variables.  There could be other things this needs
    as well.
    """

    ssh_cmd = List(['ssh'], config=True,
        help="command for starting ssh")
    ssh_args = List(['-tt'], config=True,
        help="args to pass to ssh")
    scp_cmd = List(['scp'], config=True,
        help="command for sending files")
    program = List(['date'],
        help="Program to launch via ssh")
    program_args = List([],
        help="args to pass to remote program")
    hostname = Unicode('', config=True,
        help="hostname on which to launch the program")
    user = Unicode('', config=True,
        help="username for ssh")
    location = Unicode('', config=True,
        help="user@hostname location for ssh in one setting")
    to_fetch = List([], config=True,
        help="List of (remote, local) files to fetch after starting")
    to_send = List([], config=True,
        help="List of (local, remote) files to send before starting")

    def _hostname_changed(self, name, old, new):
        """Rebuild ``location`` from the current user and the new hostname."""
        if self.user:
            self.location = u'%s@%s' % (self.user, new)
        else:
            self.location = new

    def _user_changed(self, name, old, new):
        """Rebuild ``location`` from the new user and the current hostname.

        An empty user yields the bare hostname, mirroring
        ``_hostname_changed``, instead of the invalid ``@hostname``.
        """
        if new:
            self.location = u'%s@%s' % (new, self.hostname)
        else:
            self.location = self.hostname

    def find_args(self):
        """Return the ssh command line: ssh command and args, the location,
        then the shell-quoted remote program and its arguments."""
        return self.ssh_cmd + self.ssh_args + [self.location] + \
               list(map(pipes.quote, self.program + self.program_args))

    def _send_file(self, local, remote):
        """send a single file

        Polls up to 10 times, one second apart, for ``local`` to exist,
        creates the remote parent directory via ssh ``mkdir -p`` and then
        copies the file with the configured scp command.
        """
        full_remote = "%s:%s" % (self.location, remote)
        for i in range(10):
            if os.path.exists(local):
                break
            self.log.debug("waiting for %s", local)
            time.sleep(1)
        remote_dir = os.path.dirname(remote)
        self.log.info("ensuring remote %s:%s/ exists", self.location, remote_dir)
        check_output(self.ssh_cmd + self.ssh_args + \
            [self.location, 'mkdir', '-p', '--', remote_dir]
        )
        self.log.info("sending %s to %s", local, full_remote)
        check_output(self.scp_cmd + [local, full_remote])

    def send_files(self):
        """send our files (called before start)"""
        if not self.to_send:
            return
        for local_file, remote_file in self.to_send:
            self._send_file(local_file, remote_file)

    def _fetch_file(self, remote, local):
        """fetch a single file

        Checks up to 10 times whether ``remote`` exists on the remote side,
        makes sure the local parent directory exists and then copies the
        file with the configured scp command.
        """
        full_remote = "%s:%s" % (self.location, remote)
        self.log.info("fetching %s from %s", local, full_remote)
        for i in range(10):
            # wait up to 10s for remote file to exist
            check = check_output(self.ssh_cmd + self.ssh_args + \
                [self.location, 'test -e', remote, "&& echo 'yes' || echo 'no'"])
            check = check.decode(DEFAULT_ENCODING, 'replace').strip()
            if check == u'no':
                time.sleep(1)
            elif check == u'yes':
                break
        local_dir = os.path.dirname(local)
        # Permission bits must be given in octal; the decimal literal 775
        # used before is 0o1407, not rwxrwxr-x.
        ensure_dir_exists(local_dir, 0o775)
        check_output(self.scp_cmd + [full_remote, local])

    def fetch_files(self):
        """fetch remote files (called after start)"""
        if not self.to_fetch:
            return
        for remote_file, local_file in self.to_fetch:
            self._fetch_file(remote_file, local_file)

    def start(self, hostname=None, user=None):
        """Send files, start the ssh process, then fetch files.

        ``hostname`` and ``user`` override the configured values when they
        are not None.
        """
        if hostname is not None:
            self.hostname = hostname
        if user is not None:
            self.user = user

        self.send_files()
        super(SSHLauncher, self).start()
        self.fetch_files()

    def signal(self, sig):
        """Close the ssh connection if the process is running.

        ``sig`` itself is not delivered; whatever its value, the ssh escape
        sequence ``~.`` is written to the process's stdin.
        """
        if self.state == 'running':
            # send escaped ssh connection-closer
            # Written as bytes: a subprocess pipe is a binary stream on
            # Python 3 unless opened in text mode, and a bytes literal is a
            # plain str on Python 2.
            # NOTE(review): assumes the parent class opens stdin in binary
            # mode -- confirm against LocalProcessLauncher.start.
            self.process.stdin.write(b'~.')
            self.process.stdin.flush()
640
640
class SSHClusterLauncher(SSHLauncher, ClusterAppMixin):
    """SSH launcher that tracks the profile directory to use on the remote side."""

    remote_profile_dir = Unicode('', config=True,
        help="""The remote profile_dir to use.

        If not specified, use calling profile, stripping out possible leading homedir.
        """)

    def _profile_dir_changed(self, name, old, new):
        """Derive ``remote_profile_dir`` from the new ``profile_dir`` when it is unset."""
        if self.remote_profile_dir:
            return
        # trigger remote_profile_dir_default logic again,
        # in case it was already triggered before profile_dir was set
        self.remote_profile_dir = self._strip_home(new)

    @staticmethod
    def _strip_home(path):
        """turns /home/you/.ipython/profile_foo into .ipython/profile_foo"""
        prefix = get_home_dir()
        if not prefix.endswith('/'):
            prefix = prefix + '/'

        # Paths outside the home directory are returned untouched.
        if not path.startswith(prefix):
            return path
        return path[len(prefix):]

    def _remote_profile_dir_default(self):
        """Default: the local ``profile_dir`` with the home prefix removed."""
        local_profile_dir = self.profile_dir
        return self._strip_home(local_profile_dir)

    def _cluster_id_changed(self, name, old, new):
        """Reject any non-empty cluster id."""
        if not new:
            return
        raise ValueError("cluster id not supported by SSH launchers")

    @property
    def cluster_args(self):
        """Command-line arguments that select the remote profile directory."""
        flag = '--profile-dir'
        return [flag, self.remote_profile_dir]
677
677
class SSHControllerLauncher(SSHClusterLauncher, ControllerMixin):
    """Launch a controller over ssh."""

    # ``program`` and ``program_args`` are exposed as read-only properties
    # that map onto the controller settings.  ``find_args()`` can then use
    # the same attribute names for every Controller/EngineSet launcher,
    # whichever ``*_cmd`` / ``*_args`` attributes the class actually has.

    def _controller_cmd_default(self):
        """Default controller command."""
        return ['ipcontroller']

    @property
    def program(self):
        """The command to run: ``controller_cmd``."""
        return self.controller_cmd

    @property
    def program_args(self):
        """``cluster_args`` followed by ``controller_args``."""
        combined = self.cluster_args + self.controller_args
        return combined

    def _to_fetch_default(self):
        """Default (remote, local) pairs: the two controller connection files
        in the ``security`` subdirectory of each profile directory."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            remote_path = os.path.join(self.remote_profile_dir, 'security', cf)
            local_path = os.path.join(self.profile_dir, 'security', cf)
            pairs.append((remote_path, local_path))
        return pairs
701
701
class SSHEngineLauncher(SSHClusterLauncher, EngineMixin):
    """Launch a single engine over ssh."""

    # ``program`` and ``program_args`` are exposed as read-only properties
    # that map onto the engine settings.  ``find_args()`` can then use the
    # same attribute names for every Controller/EngineSet launcher,
    # whichever ``*_cmd`` / ``*_args`` attributes the class actually has.

    def _engine_cmd_default(self):
        """Default engine command."""
        return ['ipengine']

    @property
    def program(self):
        """The command to run: ``engine_cmd``."""
        return self.engine_cmd

    @property
    def program_args(self):
        """``cluster_args`` followed by ``engine_args``."""
        combined = self.cluster_args + self.engine_args
        return combined

    def _to_send_default(self):
        """Default (local, remote) pairs: the two controller connection files
        in the ``security`` subdirectory of each profile directory."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            local_path = os.path.join(self.profile_dir, 'security', cf)
            remote_path = os.path.join(self.remote_profile_dir, 'security', cf)
            pairs.append((local_path, remote_path))
        return pairs
725
725
726
726
class SSHEngineSetLauncher(LocalEngineSetLauncher):
    """Launch a set of engines over ssh, as described by the ``engines`` dict."""

    launcher_class = SSHEngineLauncher
    engines = Dict(config=True,
        help="""dict of engines to launch. This is a dict by hostname of ints,
        corresponding to the number of engines to start on that host.""")

    def _engine_cmd_default(self):
        """Default engine command."""
        return ['ipengine']

    @property
    def engine_count(self):
        """determine engine count from `engines` dict"""
        count = 0
        for n in itervalues(self.engines):
            if isinstance(n, (tuple, list)):
                # (count, args) form: only the count matters here
                n, args = n
            count += n
        return count

    def start(self, n):
        """Start engines by profile or profile_dir.
        `n` is ignored, and the `engines` config property is used instead.

        Each key of ``engines`` is ``host`` or ``user@host``; each value is
        either an engine count or a ``(count, args)`` pair whose ``args``
        replace ``self.engine_args`` for that host.  Returns the list of
        values returned by each launcher's ``start``.
        """
        dlist = []
        for host, spec in iteritems(self.engines):
            # A distinct name keeps the (ignored) ``n`` parameter unshadowed.
            if isinstance(spec, (tuple, list)):
                count, args = spec
            else:
                count, args = spec, self.engine_args

            if '@' in host:
                user, host = host.split('@', 1)
            else:
                user = None
            for i in range(count):
                if i > 0:
                    time.sleep(self.delay)
                el = self.launcher_class(work_dir=self.work_dir, parent=self, log=self.log,
                                        profile_dir=self.profile_dir, cluster_id=self.cluster_id,
                )
                if i > 0:
                    # only send files for the first engine on each host
                    el.to_send = []

                # Copy the engine args over to each engine launcher.
                # Every launcher gets its own deep copy, so mutating one
                # launcher's lists cannot leak into the others or into this
                # set's own configuration.
                el.engine_cmd = copy.deepcopy(self.engine_cmd)
                el.engine_args = copy.deepcopy(args)
                el.on_stop(self._notice_engine_stopped)
                d = el.start(user=user, hostname=host)
                # NOTE(review): the key omits the user, so two entries for
                # the same host under different users would overwrite each
                # other here -- confirm whether that case must be supported.
                self.launchers["%s/%i" % (host, i)] = el
                dlist.append(d)
        self.notify_start(dlist)
        return dlist
781
781
782
782
class SSHProxyEngineSetLauncher(SSHClusterLauncher):
    """Launcher for calling
    `ipcluster engines` on a remote machine.

    Requires that remote profile is already configured.
    """

    # Number of engines requested; set by ``start``.
    n = Integer()
    ipcluster_cmd = List(['ipcluster'], config=True)

    @property
    def program(self):
        """The remote command: ``ipcluster_cmd`` plus the ``engines`` subcommand."""
        subcommand = ['engines']
        return self.ipcluster_cmd + subcommand

    @property
    def program_args(self):
        """Engine count and remote profile directory arguments."""
        engine_count = str(self.n)
        return ['-n', engine_count, '--profile-dir', self.remote_profile_dir]

    def _to_send_default(self):
        """Default (local, remote) pairs: the two controller connection files
        in the ``security`` subdirectory of each profile directory."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            local_path = os.path.join(self.profile_dir, 'security', cf)
            remote_path = os.path.join(self.remote_profile_dir, 'security', cf)
            pairs.append((local_path, remote_path))
        return pairs

    def start(self, n):
        """Remember the requested engine count and start the ssh process."""
        self.n = n
        parent = super(SSHProxyEngineSetLauncher, self)
        parent.start()
811
811
812
812
813 #-----------------------------------------------------------------------------
813 #-----------------------------------------------------------------------------
814 # Windows HPC Server 2008 scheduler launchers
814 # Windows HPC Server 2008 scheduler launchers
815 #-----------------------------------------------------------------------------
815 #-----------------------------------------------------------------------------
816
816
817
817
818 # This is only used on Windows.
818 # This is only used on Windows.
def find_job_cmd():
    """Return the command used to submit jobs.

    On Windows the ``job`` command is resolved through ``find_cmd``; when
    that lookup fails, and on every other platform, the bare name
    ``'job'`` is returned.
    """
    fallback = 'job'
    if not WINDOWS:
        return fallback
    try:
        return find_cmd('job')
    except (FindCmdError, ImportError):
        # ImportError will be raised if win32api is not installed
        return fallback
828
828
829
829
class WindowsHPCLauncher(BaseLauncher):
    """Base launcher that submits and cancels jobs with the Windows HPC ``job`` command."""

    job_id_regexp = CRegExp(r'\d+', config=True,
        help="""A regular expression used to get the job id from the output of the
        submit_command. """
        )
    job_file_name = Unicode(u'ipython_job.xml', config=True,
        help="The filename of the instantiated job script.")
    scheduler = Unicode('', config=True,
        help="The hostname of the scheduler to submit the job to.")
    job_cmd = Unicode(find_job_cmd(), config=True,
        help="The command for submitting jobs.")

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        """Forward ``work_dir``, ``config`` and any extra keywords to the parent."""
        super(WindowsHPCLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )

    # The full path to the instantiated job script.  This gets made dynamically
    # by combining the work_dir with the job_file_name.  (A ``job_file`` trait
    # used to be declared as well, but it was always overwritten by this
    # property in the class body and therefore never took effect.)
    @property
    def job_file(self):
        return os.path.join(self.work_dir, self.job_file_name)

    def write_job_file(self, n):
        """Write the job description for ``n`` processes; subclasses must implement."""
        raise NotImplementedError("Implement write_job_file in a subclass.")

    def find_args(self):
        """Return the placeholder argument list for this launcher."""
        return [u'job.exe']

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id.

        The first match of ``job_id_regexp`` in ``output`` is stored as
        ``self.job_id`` and returned.  Raises ``LauncherError`` when the
        pattern does not match.
        """
        m = self.job_id_regexp.search(output)
        if m is None:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        job_id = m.group()
        self.job_id = job_id
        self.log.info('Job started with id: %r', job_id)
        return job_id

    def start(self, n):
        """Start n copies of the process using the Win HPC job scheduler.

        Writes the job file, submits it with ``job_cmd`` and returns the job
        id parsed from the command's output.
        """
        self.write_job_file(n)
        args = [
            'submit',
            '/jobfile:%s' % self.job_file,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.debug("Starting Win HPC Job: %s", self.job_cmd + ' ' + ' '.join(args))

        output = check_output([self.job_cmd]+args,
            env=os.environ,
            cwd=self.work_dir,
            stderr=STDOUT
        )
        output = output.decode(DEFAULT_ENCODING, 'replace')
        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        """Cancel the job and return the output of the cancel command.

        If the cancel command fails, the job is assumed to be stopped
        already and a message saying so is used as the output.  Either way
        ``notify_stop`` is called with the job id and that output.
        """
        args = [
            'cancel',
            self.job_id,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.info("Stopping Win HPC Job: %s", self.job_cmd + ' ' + ' '.join(args))
        try:
            output = check_output([self.job_cmd]+args,
                env=os.environ,
                cwd=self.work_dir,
                stderr=STDOUT
            )
            output = output.decode(DEFAULT_ENCODING, 'replace')
        except Exception:
            # Best effort: any ordinary failure is treated as "already
            # stopped".  KeyboardInterrupt and SystemExit are no longer
            # swallowed, as they were by the former bare ``except:``.
            output = u'The job already appears to be stopped: %r' % self.job_id
        self.notify_stop(dict(job_id=self.job_id, output=output))  # Pass the output of the kill cmd
        return output
910
910
911
911
class WindowsHPCControllerLauncher(WindowsHPCLauncher, ClusterAppMixin):
    """Launch one ipcontroller through the Windows HPC job scheduler."""

    job_file_name = Unicode(
        u'ipcontroller_job.xml', config=True,
        help="WinHPC xml job file.",
    )
    controller_args = List(
        [], config=False,
        help="extra args to pass to ipcontroller",
    )

    @property
    def job_file(self):
        """Path of the job description file inside ``profile_dir``."""
        return os.path.join(self.profile_dir, self.job_file_name)

    def write_job_file(self, n):
        """Write the controller job description to ``job_file``.

        ``n`` is accepted for signature compatibility with the base class
        but is not used: the job always holds exactly one controller task.
        """
        controller_job = IPControllerJob(parent=self)
        task = IPControllerTask(parent=self)

        # The task's work directory is *not* the actual work directory of
        # the controller.  It is used as the base path for the stdout/stderr
        # files that the scheduler redirects to.
        task.work_directory = self.profile_dir

        # Cluster-wide args first, then the controller-specific extras.
        for extra_args in (self.cluster_args, self.controller_args):
            task.controller_args.extend(extra_args)
        controller_job.add_task(task)

        target = self.job_file
        self.log.debug("Writing job description file: %s", target)
        controller_job.write(target)

    def start(self):
        """Start the controller by profile_dir."""
        parent = super(WindowsHPCControllerLauncher, self)
        return parent.start(1)
942
942
943
943
class WindowsHPCEngineSetLauncher(WindowsHPCLauncher, ClusterAppMixin):
    """Launch a set of ipengines as a single Windows HPC job."""

    job_file_name = Unicode(u'ipengineset_job.xml', config=True,
        help="jobfile for ipengines job")
    engine_args = List([], config=False,
        help="extra args to pass to ipengine")

    def write_job_file(self, n):
        """Write a job description holding ``n`` engine tasks to ``job_file``."""
        job = IPEngineSetJob(parent=self)

        for _ in range(n):
            t = IPEngineTask(parent=self)
            # The task's work directory is *not* the actual work directory of
            # the engine.  It is used as the base path for the stdout/stderr
            # files that the scheduler redirects to.
            t.work_directory = self.profile_dir
            # Cluster-wide args first, then the engine-specific extras.
            t.engine_args.extend(self.cluster_args)
            t.engine_args.extend(self.engine_args)
            job.add_task(t)

        self.log.debug("Writing job description file: %s", self.job_file)
        job.write(self.job_file)

    @property
    def job_file(self):
        """Path of the job description file inside ``profile_dir``."""
        return os.path.join(self.profile_dir, self.job_file_name)

    def start(self, n):
        """Start n engines by profile_dir."""
        return super(WindowsHPCEngineSetLauncher, self).start(n)
975
975
976
976
977 #-----------------------------------------------------------------------------
977 #-----------------------------------------------------------------------------
978 # Batch (PBS) system launchers
978 # Batch (PBS) system launchers
979 #-----------------------------------------------------------------------------
979 #-----------------------------------------------------------------------------
980
980
class BatchClusterAppMixin(ClusterAppMixin):
    """ClusterApp mixin that updates the self.context dict, rather than cl-args."""

    def _profile_dir_changed(self, name, old, new):
        """Mirror the new trait value into the template context."""
        self.context[name] = new

    # cluster_id is mirrored into the context in exactly the same way
    _cluster_id_changed = _profile_dir_changed

    def _profile_dir_default(self):
        """Default profile_dir: empty, and recorded in the context."""
        default = ''
        self.context['profile_dir'] = default
        return default

    def _cluster_id_default(self):
        """Default cluster_id: empty, and recorded in the context."""
        default = ''
        self.context['cluster_id'] = default
        return default
993
993
994
994
class BatchSystemLauncher(BaseLauncher):
    """Launch an external process using a batch system.

    This class is designed to work with UNIX batch systems like PBS, LSF,
    GridEngine, etc.  The overall model is that there are different commands
    like qsub, qdel, etc. that handle the starting and stopping of the process.

    This class also has the notion of a batch script. The ``batch_template``
    attribute can be set to a string that is a template for the batch script.
    This template is instantiated using string formatting. Thus the template can
    use {n} for the number of instances. Subclasses can add additional variables
    to the template dict.
    """

    # Subclasses must fill these in.  See PBSEngineSet
    submit_command = List([''], config=True,
        help="The name of the command line program used to submit jobs.")
    delete_command = List([''], config=True,
        help="The name of the command line program used to delete jobs.")
    job_id_regexp = CRegExp('', config=True,
        help="""A regular expression used to get the job id from the output of the
        submit_command.""")
    job_id_regexp_group = Integer(0, config=True,
        help="""The group we wish to match in job_id_regexp (0 to match all)""")
    batch_template = Unicode('', config=True,
        help="The string that is the batch script template itself.")
    batch_template_file = Unicode(u'', config=True,
        help="The file that contains the batch template.")
    batch_file_name = Unicode(u'batch_script', config=True,
        help="The filename of the instantiated batch script.")
    queue = Unicode(u'', config=True,
        help="The PBS Queue.")

    def _queue_changed(self, name, old, new):
        """Mirror the new trait value into the template context."""
        self.context[name] = new

    n = Integer(1)
    _n_changed = _queue_changed

    # not configurable, override in subclasses
    # PBS Job Array regex
    job_array_regexp = CRegExp('')
    job_array_template = Unicode('')
    # PBS Queue regex
    queue_regexp = CRegExp('')
    queue_template = Unicode('')
    # The default batch template, override in subclasses
    default_template = Unicode('')
    # The full path to the instantiated batch script.
    batch_file = Unicode(u'')
    # the format dict used with batch_template:
    context = Dict()

    def _context_default(self):
        """load the default context with the default values for the basic keys

        because the _trait_changed methods only load the context if they
        are set to something other than the default value.
        """
        return dict(n=1, queue=u'', profile_dir=u'', cluster_id=u'')

    # the Formatter instance for rendering the templates:
    formatter = Instance(EvalFormatter, (), {})

    def find_args(self):
        """Return the submit command followed by the batch script path."""
        return self.submit_command + [self.batch_file]

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BatchSystemLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        # The instantiated script always lives directly in work_dir.
        self.batch_file = os.path.join(self.work_dir, self.batch_file_name)

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id.

        The id is also stored on ``self.job_id``.  Raises ``LauncherError``
        when ``job_id_regexp`` does not match ``output``.
        """
        m = self.job_id_regexp.search(output)
        if m is None:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        job_id = m.group(self.job_id_regexp_group)
        self.job_id = job_id
        self.log.info('Job submitted with job id: %r', job_id)
        return job_id

    def write_batch_script(self, n):
        """Instantiate and write the batch script to the work_dir.

        The template is chosen in priority order: ``batch_template``, then
        the contents of ``batch_template_file``, then ``default_template``.
        The rendered script is written to ``batch_file`` and made
        readable, writable and executable for the owner only.
        """
        self.n = n
        # first priority is batch_template if set
        if self.batch_template_file and not self.batch_template:
            # second priority is batch_template_file
            with open(self.batch_template_file) as f:
                self.batch_template = f.read()
        if not self.batch_template:
            # third (last) priority is default_template
            self.batch_template = self.default_template
            # add jobarray or queue lines to the default template
            # note that this is *only* when user did not specify a template.
            self._insert_queue_in_script()
            self._insert_job_array_in_script()
        script_as_string = self.formatter.format(self.batch_template, **self.context)
        self.log.debug('Writing batch script: %s', self.batch_file)
        with open(self.batch_file, 'w') as f:
            f.write(script_as_string)
        os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

    def _insert_after_first_line(self, line):
        """Splice ``line`` into ``batch_template`` right after its first line.

        ``partition`` is used rather than ``split('\\n', 1)`` so that a
        template consisting of a single line (no newline at all) gets the
        extra line appended instead of raising ValueError on unpacking.
        """
        firstline, _, rest = self.batch_template.partition('\n')
        self.batch_template = u'\n'.join([firstline, line, rest])

    def _insert_queue_in_script(self):
        """Inserts a queue if required into the batch script.
        """
        if self.queue and not self.queue_regexp.search(self.batch_template):
            self.log.debug("adding PBS queue settings to batch script")
            self._insert_after_first_line(self.queue_template)

    def _insert_job_array_in_script(self):
        """Inserts a job array if required into the batch script.
        """
        if not self.job_array_regexp.search(self.batch_template):
            self.log.debug("adding job array settings to batch script")
            self._insert_after_first_line(self.job_array_template)

    def start(self, n):
        """Start n copies of the process using a batch system.

        Writes the batch script, runs the submit command on it and returns
        the job id parsed from the command's output.
        """
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        # Here we save profile_dir in the context so they
        # can be used in the batch script template as {profile_dir}
        self.write_batch_script(n)
        output = check_output(self.args, env=os.environ)
        output = output.decode(DEFAULT_ENCODING, 'replace')

        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        """Run the delete command for the current job and return its output.

        Failures to run the command are logged and treated as empty output,
        so that ``notify_stop`` is still called.
        """
        cmd = self.delete_command + [self.job_id]
        try:
            p = Popen(cmd, env=os.environ, stdout=PIPE, stderr=PIPE)
            out, err = p.communicate()
            output = out + err
        except Exception:
            self.log.exception("Problem stopping cluster with command: %s", cmd)
            # Must be bytes: it is decoded below like real command output.
            output = b""
        output = output.decode(DEFAULT_ENCODING, 'replace')
        # Pass the output of the kill cmd
        self.notify_stop(dict(job_id=self.job_id, output=output))
        return output
1142
1142
1143
1143
class PBSLauncher(BatchSystemLauncher):
    """A BatchSystemLauncher subclass for PBS."""

    submit_command = List(['qsub'], config=True,
        help="The PBS submit command ['qsub']")
    delete_command = List(['qdel'], config=True,
        help="The PBS delete command ['qdel']")
    job_id_regexp = CRegExp(r'\d+', config=True,
        help=r"Regular expression for identifying the job ID [r'\d+']")

    batch_file = Unicode(u'')
    # Raw literals: the patterns are unchanged, but no longer rely on
    # invalid string escapes such as '\W' or '\$'.
    # Detects an existing "#PBS -t ..." job-array directive.
    job_array_regexp = CRegExp(r'#PBS\W+-t\W+[\w\d\-\$]+')
    job_array_template = Unicode('#PBS -t 1-{n}')
    # Detects an existing "#PBS -q ..." queue directive.
    queue_regexp = CRegExp(r'#PBS\W+-q\W+\$?\w+')
    queue_template = Unicode('#PBS -q {queue}')
1159
1159
1160
1160
class PBSControllerLauncher(PBSLauncher, BatchClusterAppMixin):
    """Launch a controller using PBS."""

    batch_file_name = Unicode(
        u'pbs_controller', config=True,
        help="batch file name for the controller job.",
    )
    # Shell script submitted when no template is configured; the controller
    # command line is shell-quoted and baked in when the class is defined.
    default_template = Unicode(
        '#!/bin/sh\n'
        '#PBS -V\n'
        '#PBS -N ipcontroller\n'
        '%s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"\n'
        % ' '.join(pipes.quote(arg) for arg in ipcontroller_cmd_argv)
    )

    def start(self):
        """Start the controller by profile or profile_dir."""
        parent = super(PBSControllerLauncher, self)
        return parent.start(1)
1175
1175
1176
1176
class PBSEngineSetLauncher(PBSLauncher, BatchClusterAppMixin):
    """Launch Engines using PBS"""

    batch_file_name = Unicode(
        u'pbs_engines', config=True,
        help="batch file name for the engine(s) job.",
    )
    # Shell script submitted when no template is configured; the engine
    # command line is shell-quoted and baked in when the class is defined.
    default_template = Unicode(
        u'#!/bin/sh\n'
        u'#PBS -V\n'
        u'#PBS -N ipengine\n'
        u'%s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"\n'
        % ' '.join(pipes.quote(arg) for arg in ipengine_cmd_argv)
    )
1186
1186
1187
1187
1188 #SGE is very similar to PBS
1188 #SGE is very similar to PBS
1189
1189
class SGELauncher(PBSLauncher):
    """Sun GridEngine is a PBS clone with slightly different syntax"""
    # Raw literals: the patterns are unchanged, but no longer rely on
    # invalid string escapes such as '\$' or '\W'.
    # Detects an existing "#$ -t" job-array directive.
    job_array_regexp = CRegExp(r'#\$\W+\-t')
    job_array_template = Unicode('#$ -t 1-{n}')
    # Detects an existing "#$ -q ..." queue directive.
    queue_regexp = CRegExp(r'#\$\W+-q\W+\$?\w+')
    queue_template = Unicode('#$ -q {queue}')
1196
1196
1197
1197
class SGEControllerLauncher(SGELauncher, BatchClusterAppMixin):
    """Launch a controller using SGE."""

    batch_file_name = Unicode(u'sge_controller', config=True,
        help="batch file name for the ipcontroller job.")
    # Script submitted when no template is configured; the controller
    # command line is shell-quoted and baked in when the class is defined.
    default_template= Unicode(u"""#$ -V
#$ -S /bin/sh
#$ -N ipcontroller
%s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
"""%(' '.join(map(pipes.quote, ipcontroller_cmd_argv))))

    def start(self):
        """Start the controller by profile or profile_dir."""
        return super(SGEControllerLauncher, self).start(1)
1212
1212
1213
1213
class SGEEngineSetLauncher(SGELauncher, BatchClusterAppMixin):
    """Launch Engines with SGE"""

    batch_file_name = Unicode(
        u'sge_engines', config=True,
        help="batch file name for the engine(s) job.",
    )
    # Script submitted when no template is configured; the engine command
    # line is shell-quoted and baked in when the class is defined.
    default_template = Unicode(
        '#$ -V\n'
        '#$ -S /bin/sh\n'
        '#$ -N ipengine\n'
        '%s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"\n'
        % ' '.join(pipes.quote(arg) for arg in ipengine_cmd_argv)
    )
1223
1223
1224
1224
1225 # LSF launchers
1225 # LSF launchers
1226
1226
class LSFLauncher(BatchSystemLauncher):
    """A BatchSystemLauncher subclass for LSF."""

    submit_command = List(['bsub'], config=True,
        help="The LSF submit command ['bsub']")
    delete_command = List(['bkill'], config=True,
        help="The LSF delete command ['bkill']")
    job_id_regexp = CRegExp(r'\d+', config=True,
        help=r"Regular expression for identifying the job ID [r'\d+']")

    batch_file = Unicode(u'')
    # Detects an existing "#BSUB -J name[a-b]" job-array directive.  The
    # whitespace after -J is optional, so both "-Jname[1-4]" and the
    # "-J name[1-4]" form produced by job_array_template are recognised.
    job_array_regexp = CRegExp(r'#BSUB[ \t]+-J[ \t]*\w+\[\d+-\d+\]')
    job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
    # Detects an existing "#BSUB -q ..." queue directive.
    queue_regexp = CRegExp(r'#BSUB[ \t]+-q[ \t]+\w+')
    queue_template = Unicode('#BSUB -q {queue}')

    def start(self, n):
        """Start n copies of the process using LSF batch system.

        This can't inherit from the base class because bsub expects
        to be piped a shell script in order to honor the #BSUB directives :
        bsub < script

        Returns the job id parsed from bsub's standard output.
        """
        # Here we save profile_dir in the context so they
        # can be used in the batch script template as {profile_dir}
        self.write_batch_script(n)
        args = self.args
        # The batch script is always the *last* argument; everything before
        # it is the (possibly multi-word) submit command.  The script path
        # is shell-quoted because the command line is run through a shell.
        piped_cmd = ' '.join(args[:-1]) + ' < ' + pipes.quote(args[-1])
        self.log.debug("Starting %s: %s", self.__class__.__name__, piped_cmd)
        p = Popen(piped_cmd, shell=True, env=os.environ, stdout=PIPE)
        output, err = p.communicate()
        output = output.decode(DEFAULT_ENCODING, 'replace')
        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id
1260
1260
1261
1261
class LSFControllerLauncher(LSFLauncher, BatchClusterAppMixin):
    """Launch a controller using LSF."""

    batch_file_name = Unicode(
        u'lsf_controller', config=True,
        help="batch file name for the controller job.",
    )
    # Shell script submitted when no template is configured.  '%%J' is an
    # escaped percent sign, so the finished script contains a literal '%J'.
    default_template = Unicode(
        '#!/bin/sh\n'
        '#BSUB -J ipcontroller\n'
        '#BSUB -oo ipcontroller.o.%%J\n'
        '#BSUB -eo ipcontroller.e.%%J\n'
        '%s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"\n'
        % ' '.join(pipes.quote(arg) for arg in ipcontroller_cmd_argv)
    )

    def start(self):
        """Start the controller by profile or profile_dir."""
        parent = super(LSFControllerLauncher, self)
        return parent.start(1)
1277
1277
1278
1278
class LSFEngineSetLauncher(LSFLauncher, BatchClusterAppMixin):
    """Launch Engines using LSF"""

    batch_file_name = Unicode(
        u'lsf_engines', config=True,
        help="batch file name for the engine(s) job.",
    )
    # Shell script submitted when no template is configured.  '%%J' is an
    # escaped percent sign, so the finished script contains a literal '%J'.
    default_template = Unicode(
        u'#!/bin/sh\n'
        u'#BSUB -oo ipengine.o.%%J\n'
        u'#BSUB -eo ipengine.e.%%J\n'
        u'%s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"\n'
        % ' '.join(pipes.quote(arg) for arg in ipengine_cmd_argv)
    )
1288
1288
1289
1289
1290
1290
1291 class HTCondorLauncher(BatchSystemLauncher):
1291 class HTCondorLauncher(BatchSystemLauncher):
1292 """A BatchSystemLauncher subclass for HTCondor.
1292 """A BatchSystemLauncher subclass for HTCondor.
1293
1293
1294 HTCondor requires that we launch the ipengine/ipcontroller scripts rather
1294 HTCondor requires that we launch the ipengine/ipcontroller scripts rather
1295 that the python instance but otherwise is very similar to PBS. This is because
1295 that the python instance but otherwise is very similar to PBS. This is because
1296 HTCondor destroys sys.executable when launching remote processes - a launched
1296 HTCondor destroys sys.executable when launching remote processes - a launched
1297 python process depends on sys.executable to effectively evaluate its
1297 python process depends on sys.executable to effectively evaluate its
1298 module search paths. Without it, regardless of which python interpreter you launch
1298 module search paths. Without it, regardless of which python interpreter you launch
1299 you will get the to built in module search paths.
1299 you will get the to built in module search paths.
1300
1300
1301 We use the ip{cluster, engine, controller} scripts as our executable to circumvent
1301 We use the ip{cluster, engine, controller} scripts as our executable to circumvent
1302 this - the mechanism of shebanged scripts means that the python binary will be
1302 this - the mechanism of shebanged scripts means that the python binary will be
1303 launched with argv[0] set to the *location of the ip{cluster, engine, controller}
1303 launched with argv[0] set to the *location of the ip{cluster, engine, controller}
1304 scripts on the remote node*. This means you need to take care that:
1304 scripts on the remote node*. This means you need to take care that:
1305
1305
1306 a. Your remote nodes have their paths configured correctly, with the ipengine and ipcontroller
1306 a. Your remote nodes have their paths configured correctly, with the ipengine and ipcontroller
1307 of the python environment you wish to execute code in having top precedence.
1307 of the python environment you wish to execute code in having top precedence.
1308 b. This functionality is untested on Windows.
1308 b. This functionality is untested on Windows.
1309
1309
1310 If you need different behavior, consider making your own template.
1310 If you need different behavior, consider making your own template.
1311 """
1311 """
1312
1312
1313 submit_command = List(['condor_submit'], config=True,
1313 submit_command = List(['condor_submit'], config=True,
1314 help="The HTCondor submit command ['condor_submit']")
1314 help="The HTCondor submit command ['condor_submit']")
1315 delete_command = List(['condor_rm'], config=True,
1315 delete_command = List(['condor_rm'], config=True,
1316 help="The HTCondor delete command ['condor_rm']")
1316 help="The HTCondor delete command ['condor_rm']")
1317 job_id_regexp = CRegExp(r'(\d+)\.$', config=True,
1317 job_id_regexp = CRegExp(r'(\d+)\.$', config=True,
1318 help="Regular expression for identifying the job ID [r'(\d+)\.$']")
1318 help="Regular expression for identifying the job ID [r'(\d+)\.$']")
1319 job_id_regexp_group = Integer(1, config=True,
1319 job_id_regexp_group = Integer(1, config=True,
1320 help="""The group we wish to match in job_id_regexp [1]""")
1320 help="""The group we wish to match in job_id_regexp [1]""")
1321
1321
1322 job_array_regexp = CRegExp('queue\W+\$')
1322 job_array_regexp = CRegExp('queue\W+\$')
1323 job_array_template = Unicode('queue {n}')
1323 job_array_template = Unicode('queue {n}')
1324
1324
1325
1325
1326 def _insert_job_array_in_script(self):
1326 def _insert_job_array_in_script(self):
1327 """Inserts a job array if required into the batch script.
1327 """Inserts a job array if required into the batch script.
1328 """
1328 """
1329 if not self.job_array_regexp.search(self.batch_template):
1329 if not self.job_array_regexp.search(self.batch_template):
1330 self.log.debug("adding job array settings to batch script")
1330 self.log.debug("adding job array settings to batch script")
1331 #HTCondor requires that the job array goes at the bottom of the script
1331 #HTCondor requires that the job array goes at the bottom of the script
1332 self.batch_template = '\n'.join([self.batch_template,
1332 self.batch_template = '\n'.join([self.batch_template,
1333 self.job_array_template])
1333 self.job_array_template])
1334
1334
1335 def _insert_queue_in_script(self):
1335 def _insert_queue_in_script(self):
1336 """AFAIK, HTCondor doesn't have a concept of multiple queues that can be
1336 """AFAIK, HTCondor doesn't have a concept of multiple queues that can be
1337 specified in the script.
1337 specified in the script.
1338 """
1338 """
1339 pass
1339 pass
1340
1340
1341
1341
1342 class HTCondorControllerLauncher(HTCondorLauncher, BatchClusterAppMixin):
1342 class HTCondorControllerLauncher(HTCondorLauncher, BatchClusterAppMixin):
1343 """Launch a controller using HTCondor."""
1343 """Launch a controller using HTCondor."""
1344
1344
1345 batch_file_name = Unicode(u'htcondor_controller', config=True,
1345 batch_file_name = Unicode(u'htcondor_controller', config=True,
1346 help="batch file name for the controller job.")
1346 help="batch file name for the controller job.")
1347 default_template = Unicode(r"""
1347 default_template = Unicode(r"""
1348 universe = vanilla
1348 universe = vanilla
1349 executable = ipcontroller
1349 executable = ipcontroller
1350 # by default we expect a shared file system
1350 # by default we expect a shared file system
1351 transfer_executable = False
1351 transfer_executable = False
1352 arguments = --log-to-file '--profile-dir={profile_dir}' --cluster-id='{cluster_id}'
1352 arguments = --log-to-file '--profile-dir={profile_dir}' --cluster-id='{cluster_id}'
1353 """)
1353 """)
1354
1354
1355 def start(self):
1355 def start(self):
1356 """Start the controller by profile or profile_dir."""
1356 """Start the controller by profile or profile_dir."""
1357 return super(HTCondorControllerLauncher, self).start(1)
1357 return super(HTCondorControllerLauncher, self).start(1)
1358
1358
1359
1359
1360 class HTCondorEngineSetLauncher(HTCondorLauncher, BatchClusterAppMixin):
1360 class HTCondorEngineSetLauncher(HTCondorLauncher, BatchClusterAppMixin):
1361 """Launch Engines using HTCondor"""
1361 """Launch Engines using HTCondor"""
1362 batch_file_name = Unicode(u'htcondor_engines', config=True,
1362 batch_file_name = Unicode(u'htcondor_engines', config=True,
1363 help="batch file name for the engine(s) job.")
1363 help="batch file name for the engine(s) job.")
1364 default_template = Unicode("""
1364 default_template = Unicode("""
1365 universe = vanilla
1365 universe = vanilla
1366 executable = ipengine
1366 executable = ipengine
1367 # by default we expect a shared file system
1367 # by default we expect a shared file system
1368 transfer_executable = False
1368 transfer_executable = False
1369 arguments = "--log-to-file '--profile-dir={profile_dir}' '--cluster-id={cluster_id}'"
1369 arguments = "--log-to-file '--profile-dir={profile_dir}' '--cluster-id={cluster_id}'"
1370 """)
1370 """)
1371
1371
1372
1372
1373 #-----------------------------------------------------------------------------
1373 #-----------------------------------------------------------------------------
1374 # A launcher for ipcluster itself!
1374 # A launcher for ipcluster itself!
1375 #-----------------------------------------------------------------------------
1375 #-----------------------------------------------------------------------------
1376
1376
1377
1377
1378 class IPClusterLauncher(LocalProcessLauncher):
1378 class IPClusterLauncher(LocalProcessLauncher):
1379 """Launch the ipcluster program in an external process."""
1379 """Launch the ipcluster program in an external process."""
1380
1380
1381 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1381 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1382 help="Popen command for ipcluster")
1382 help="Popen command for ipcluster")
1383 ipcluster_args = List(
1383 ipcluster_args = List(
1384 ['--clean-logs=True', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1384 ['--clean-logs=True', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1385 help="Command line arguments to pass to ipcluster.")
1385 help="Command line arguments to pass to ipcluster.")
1386 ipcluster_subcommand = Unicode('start')
1386 ipcluster_subcommand = Unicode('start')
1387 profile = Unicode('default')
1387 profile = Unicode('default')
1388 n = Integer(2)
1388 n = Integer(2)
1389
1389
1390 def find_args(self):
1390 def find_args(self):
1391 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1391 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1392 ['--n=%i'%self.n, '--profile=%s'%self.profile] + \
1392 ['--n=%i'%self.n, '--profile=%s'%self.profile] + \
1393 self.ipcluster_args
1393 self.ipcluster_args
1394
1394
1395 def start(self):
1395 def start(self):
1396 return super(IPClusterLauncher, self).start()
1396 return super(IPClusterLauncher, self).start()
1397
1397
1398 #-----------------------------------------------------------------------------
1398 #-----------------------------------------------------------------------------
1399 # Collections of launchers
1399 # Collections of launchers
1400 #-----------------------------------------------------------------------------
1400 #-----------------------------------------------------------------------------
1401
1401
1402 local_launchers = [
1402 local_launchers = [
1403 LocalControllerLauncher,
1403 LocalControllerLauncher,
1404 LocalEngineLauncher,
1404 LocalEngineLauncher,
1405 LocalEngineSetLauncher,
1405 LocalEngineSetLauncher,
1406 ]
1406 ]
1407 mpi_launchers = [
1407 mpi_launchers = [
1408 MPILauncher,
1408 MPILauncher,
1409 MPIControllerLauncher,
1409 MPIControllerLauncher,
1410 MPIEngineSetLauncher,
1410 MPIEngineSetLauncher,
1411 ]
1411 ]
1412 ssh_launchers = [
1412 ssh_launchers = [
1413 SSHLauncher,
1413 SSHLauncher,
1414 SSHControllerLauncher,
1414 SSHControllerLauncher,
1415 SSHEngineLauncher,
1415 SSHEngineLauncher,
1416 SSHEngineSetLauncher,
1416 SSHEngineSetLauncher,
1417 SSHProxyEngineSetLauncher,
1417 SSHProxyEngineSetLauncher,
1418 ]
1418 ]
1419 winhpc_launchers = [
1419 winhpc_launchers = [
1420 WindowsHPCLauncher,
1420 WindowsHPCLauncher,
1421 WindowsHPCControllerLauncher,
1421 WindowsHPCControllerLauncher,
1422 WindowsHPCEngineSetLauncher,
1422 WindowsHPCEngineSetLauncher,
1423 ]
1423 ]
1424 pbs_launchers = [
1424 pbs_launchers = [
1425 PBSLauncher,
1425 PBSLauncher,
1426 PBSControllerLauncher,
1426 PBSControllerLauncher,
1427 PBSEngineSetLauncher,
1427 PBSEngineSetLauncher,
1428 ]
1428 ]
1429 sge_launchers = [
1429 sge_launchers = [
1430 SGELauncher,
1430 SGELauncher,
1431 SGEControllerLauncher,
1431 SGEControllerLauncher,
1432 SGEEngineSetLauncher,
1432 SGEEngineSetLauncher,
1433 ]
1433 ]
1434 lsf_launchers = [
1434 lsf_launchers = [
1435 LSFLauncher,
1435 LSFLauncher,
1436 LSFControllerLauncher,
1436 LSFControllerLauncher,
1437 LSFEngineSetLauncher,
1437 LSFEngineSetLauncher,
1438 ]
1438 ]
1439 htcondor_launchers = [
1439 htcondor_launchers = [
1440 HTCondorLauncher,
1440 HTCondorLauncher,
1441 HTCondorControllerLauncher,
1441 HTCondorControllerLauncher,
1442 HTCondorEngineSetLauncher,
1442 HTCondorEngineSetLauncher,
1443 ]
1443 ]
1444 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1444 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1445 + pbs_launchers + sge_launchers + lsf_launchers + htcondor_launchers
1445 + pbs_launchers + sge_launchers + lsf_launchers + htcondor_launchers
@@ -1,703 +1,703 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import sys
8 import sys
9 import time
9 import time
10 from datetime import datetime
10 from datetime import datetime
11
11
12 from zmq import MessageTracker
12 from zmq import MessageTracker
13
13
14 from IPython.core.display import clear_output, display, display_pretty
14 from IPython.core.display import clear_output, display, display_pretty
15 from decorator import decorator
15 from decorator import decorator
16 from IPython.parallel import error
16 from ipython_parallel import error
17 from IPython.utils.py3compat import string_types
17 from IPython.utils.py3compat import string_types
18
18
19
19
20 def _raw_text(s):
20 def _raw_text(s):
21 display_pretty(s, raw=True)
21 display_pretty(s, raw=True)
22
22
23
23
24 # global empty tracker that's always done:
24 # global empty tracker that's always done:
25 finished_tracker = MessageTracker()
25 finished_tracker = MessageTracker()
26
26
27 @decorator
27 @decorator
28 def check_ready(f, self, *args, **kwargs):
28 def check_ready(f, self, *args, **kwargs):
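29 """Call self.wait(0) to sync state prior to calling the method, raising TimeoutError if the result is still not ready."""
29 """Call self.wait(0) to sync state prior to calling the method, raising TimeoutError if the result is still not ready."""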
30 self.wait(0)
30 self.wait(0)
31 if not self._ready:
31 if not self._ready:
32 raise error.TimeoutError("result not ready")
32 raise error.TimeoutError("result not ready")
33 return f(self, *args, **kwargs)
33 return f(self, *args, **kwargs)
34
34
35 class AsyncResult(object):
35 class AsyncResult(object):
36 """Class for representing results of non-blocking calls.
36 """Class for representing results of non-blocking calls.
37
37
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
38 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
39 """
39 """
40
40
41 msg_ids = None
41 msg_ids = None
42 _targets = None
42 _targets = None
43 _tracker = None
43 _tracker = None
44 _single_result = False
44 _single_result = False
45 owner = False,
45 owner = False,
46
46
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
47 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None,
48 owner=False,
48 owner=False,
49 ):
49 ):
50 if isinstance(msg_ids, string_types):
50 if isinstance(msg_ids, string_types):
51 # always a list
51 # always a list
52 msg_ids = [msg_ids]
52 msg_ids = [msg_ids]
53 self._single_result = True
53 self._single_result = True
54 else:
54 else:
55 self._single_result = False
55 self._single_result = False
56 if tracker is None:
56 if tracker is None:
57 # default to always done
57 # default to always done
58 tracker = finished_tracker
58 tracker = finished_tracker
59 self._client = client
59 self._client = client
60 self.msg_ids = msg_ids
60 self.msg_ids = msg_ids
61 self._fname=fname
61 self._fname=fname
62 self._targets = targets
62 self._targets = targets
63 self._tracker = tracker
63 self._tracker = tracker
64 self.owner = owner
64 self.owner = owner
65
65
66 self._ready = False
66 self._ready = False
67 self._outputs_ready = False
67 self._outputs_ready = False
68 self._success = None
68 self._success = None
69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
69 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
70
70
71 def __repr__(self):
71 def __repr__(self):
72 if self._ready:
72 if self._ready:
73 return "<%s: finished>"%(self.__class__.__name__)
73 return "<%s: finished>"%(self.__class__.__name__)
74 else:
74 else:
75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
75 return "<%s: %s>"%(self.__class__.__name__,self._fname)
76
76
77
77
78 def _reconstruct_result(self, res):
78 def _reconstruct_result(self, res):
79 """Reconstruct our result from actual result list (always a list)
79 """Reconstruct our result from actual result list (always a list)
80
80
81 Override me in subclasses for turning a list of results
81 Override me in subclasses for turning a list of results
82 into the expected form.
82 into the expected form.
83 """
83 """
84 if self._single_result:
84 if self._single_result:
85 return res[0]
85 return res[0]
86 else:
86 else:
87 return res
87 return res
88
88
89 def get(self, timeout=-1):
89 def get(self, timeout=-1):
90 """Return the result when it arrives.
90 """Return the result when it arrives.
91
91
92 If `timeout` is not ``None`` and the result does not arrive within
92 If `timeout` is not ``None`` and the result does not arrive within
93 `timeout` seconds then ``TimeoutError`` is raised. If the
93 `timeout` seconds then ``TimeoutError`` is raised. If the
94 remote call raised an exception then that exception will be reraised
94 remote call raised an exception then that exception will be reraised
95 by get() inside a `RemoteError`.
95 by get() inside a `RemoteError`.
96 """
96 """
97 if not self.ready():
97 if not self.ready():
98 self.wait(timeout)
98 self.wait(timeout)
99
99
100 if self._ready:
100 if self._ready:
101 if self._success:
101 if self._success:
102 return self._result
102 return self._result
103 else:
103 else:
104 raise self._exception
104 raise self._exception
105 else:
105 else:
106 raise error.TimeoutError("Result not ready.")
106 raise error.TimeoutError("Result not ready.")
107
107
108 def _check_ready(self):
108 def _check_ready(self):
109 if not self.ready():
109 if not self.ready():
110 raise error.TimeoutError("Result not ready.")
110 raise error.TimeoutError("Result not ready.")
111
111
112 def ready(self):
112 def ready(self):
113 """Return whether the call has completed."""
113 """Return whether the call has completed."""
114 if not self._ready:
114 if not self._ready:
115 self.wait(0)
115 self.wait(0)
116 elif not self._outputs_ready:
116 elif not self._outputs_ready:
117 self._wait_for_outputs(0)
117 self._wait_for_outputs(0)
118
118
119 return self._ready
119 return self._ready
120
120
121 def wait(self, timeout=-1):
121 def wait(self, timeout=-1):
122 """Wait until the result is available or until `timeout` seconds pass.
122 """Wait until the result is available or until `timeout` seconds pass.
123
123
124 This method always returns None.
124 This method always returns None.
125 """
125 """
126 if self._ready:
126 if self._ready:
127 self._wait_for_outputs(timeout)
127 self._wait_for_outputs(timeout)
128 return
128 return
129 self._ready = self._client.wait(self.msg_ids, timeout)
129 self._ready = self._client.wait(self.msg_ids, timeout)
130 if self._ready:
130 if self._ready:
131 try:
131 try:
132 results = list(map(self._client.results.get, self.msg_ids))
132 results = list(map(self._client.results.get, self.msg_ids))
133 self._result = results
133 self._result = results
134 if self._single_result:
134 if self._single_result:
135 r = results[0]
135 r = results[0]
136 if isinstance(r, Exception):
136 if isinstance(r, Exception):
137 raise r
137 raise r
138 else:
138 else:
139 results = error.collect_exceptions(results, self._fname)
139 results = error.collect_exceptions(results, self._fname)
140 self._result = self._reconstruct_result(results)
140 self._result = self._reconstruct_result(results)
141 except Exception as e:
141 except Exception as e:
142 self._exception = e
142 self._exception = e
143 self._success = False
143 self._success = False
144 else:
144 else:
145 self._success = True
145 self._success = True
146 finally:
146 finally:
147 if timeout is None or timeout < 0:
147 if timeout is None or timeout < 0:
148 # cutoff infinite wait at 10s
148 # cutoff infinite wait at 10s
149 timeout = 10
149 timeout = 10
150 self._wait_for_outputs(timeout)
150 self._wait_for_outputs(timeout)
151
151
152 if self.owner:
152 if self.owner:
153
153
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
154 self._metadata = [self._client.metadata.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
155 [self._client.results.pop(mid) for mid in self.msg_ids]
156
156
157
157
158
158
159 def successful(self):
159 def successful(self):
160 """Return whether the call completed without raising an exception.
160 """Return whether the call completed without raising an exception.
161
161
162 Will raise ``AssertionError`` if the result is not ready.
162 Will raise ``AssertionError`` if the result is not ready.
163 """
163 """
164 assert self.ready()
164 assert self.ready()
165 return self._success
165 return self._success
166
166
167 #----------------------------------------------------------------
167 #----------------------------------------------------------------
168 # Extra methods not in mp.pool.AsyncResult
168 # Extra methods not in mp.pool.AsyncResult
169 #----------------------------------------------------------------
169 #----------------------------------------------------------------
170
170
171 def get_dict(self, timeout=-1):
171 def get_dict(self, timeout=-1):
172 """Get the results as a dict, keyed by engine_id.
172 """Get the results as a dict, keyed by engine_id.
173
173
174 timeout behavior is described in `get()`.
174 timeout behavior is described in `get()`.
175 """
175 """
176
176
177 results = self.get(timeout)
177 results = self.get(timeout)
178 if self._single_result:
178 if self._single_result:
179 results = [results]
179 results = [results]
180 engine_ids = [ md['engine_id'] for md in self._metadata ]
180 engine_ids = [ md['engine_id'] for md in self._metadata ]
181
181
182
182
183 rdict = {}
183 rdict = {}
184 for engine_id, result in zip(engine_ids, results):
184 for engine_id, result in zip(engine_ids, results):
185 if engine_id in rdict:
185 if engine_id in rdict:
186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
186 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
187 engine_ids.count(engine_id), engine_id)
187 engine_ids.count(engine_id), engine_id)
188 )
188 )
189 else:
189 else:
190 rdict[engine_id] = result
190 rdict[engine_id] = result
191
191
192 return rdict
192 return rdict
193
193
194 @property
194 @property
195 def result(self):
195 def result(self):
196 """result property wrapper for `get(timeout=-1)`."""
196 """result property wrapper for `get(timeout=-1)`."""
197 return self.get()
197 return self.get()
198
198
199 # abbreviated alias:
199 # abbreviated alias:
200 r = result
200 r = result
201
201
202 @property
202 @property
203 def metadata(self):
203 def metadata(self):
204 """property for accessing execution metadata."""
204 """property for accessing execution metadata."""
205 if self._single_result:
205 if self._single_result:
206 return self._metadata[0]
206 return self._metadata[0]
207 else:
207 else:
208 return self._metadata
208 return self._metadata
209
209
210 @property
210 @property
211 def result_dict(self):
211 def result_dict(self):
212 """result property as a dict."""
212 """result property as a dict."""
213 return self.get_dict()
213 return self.get_dict()
214
214
215 def __dict__(self):
215 def __dict__(self):
216 return self.get_dict(0)
216 return self.get_dict(0)
217
217
218 def abort(self):
218 def abort(self):
219 """abort my tasks."""
219 """abort my tasks."""
220 assert not self.ready(), "Can't abort, I am already done!"
220 assert not self.ready(), "Can't abort, I am already done!"
221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
221 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
222
222
223 @property
223 @property
224 def sent(self):
224 def sent(self):
225 """check whether my messages have been sent."""
225 """check whether my messages have been sent."""
226 return self._tracker.done
226 return self._tracker.done
227
227
228 def wait_for_send(self, timeout=-1):
228 def wait_for_send(self, timeout=-1):
229 """wait for pyzmq send to complete.
229 """wait for pyzmq send to complete.
230
230
231 This is necessary when sending arrays that you intend to edit in-place.
231 This is necessary when sending arrays that you intend to edit in-place.
232 `timeout` is in seconds, and will raise TimeoutError if it is reached
232 `timeout` is in seconds, and will raise TimeoutError if it is reached
233 before the send completes.
233 before the send completes.
234 """
234 """
235 return self._tracker.wait(timeout)
235 return self._tracker.wait(timeout)
236
236
237 #-------------------------------------
237 #-------------------------------------
238 # dict-access
238 # dict-access
239 #-------------------------------------
239 #-------------------------------------
240
240
241 def __getitem__(self, key):
241 def __getitem__(self, key):
242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
242 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
243 """
243 """
244 if isinstance(key, int):
244 if isinstance(key, int):
245 self._check_ready()
245 self._check_ready()
246 return error.collect_exceptions([self._result[key]], self._fname)[0]
246 return error.collect_exceptions([self._result[key]], self._fname)[0]
247 elif isinstance(key, slice):
247 elif isinstance(key, slice):
248 self._check_ready()
248 self._check_ready()
249 return error.collect_exceptions(self._result[key], self._fname)
249 return error.collect_exceptions(self._result[key], self._fname)
250 elif isinstance(key, string_types):
250 elif isinstance(key, string_types):
251 # metadata proxy *does not* require that results are done
251 # metadata proxy *does not* require that results are done
252 self.wait(0)
252 self.wait(0)
253 values = [ md[key] for md in self._metadata ]
253 values = [ md[key] for md in self._metadata ]
254 if self._single_result:
254 if self._single_result:
255 return values[0]
255 return values[0]
256 else:
256 else:
257 return values
257 return values
258 else:
258 else:
259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
259 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
260
260
261 def __getattr__(self, key):
261 def __getattr__(self, key):
262 """getattr maps to getitem for convenient attr access to metadata."""
262 """getattr maps to getitem for convenient attr access to metadata."""
263 try:
263 try:
264 return self.__getitem__(key)
264 return self.__getitem__(key)
265 except (error.TimeoutError, KeyError):
265 except (error.TimeoutError, KeyError):
266 raise AttributeError("%r object has no attribute %r"%(
266 raise AttributeError("%r object has no attribute %r"%(
267 self.__class__.__name__, key))
267 self.__class__.__name__, key))
268
268
269 # asynchronous iterator:
269 # asynchronous iterator:
270 def __iter__(self):
270 def __iter__(self):
271 if self._single_result:
271 if self._single_result:
272 raise TypeError("AsyncResults with a single result are not iterable.")
272 raise TypeError("AsyncResults with a single result are not iterable.")
273 try:
273 try:
274 rlist = self.get(0)
274 rlist = self.get(0)
275 except error.TimeoutError:
275 except error.TimeoutError:
276 # wait for each result individually
276 # wait for each result individually
277 for msg_id in self.msg_ids:
277 for msg_id in self.msg_ids:
278 ar = AsyncResult(self._client, msg_id, self._fname)
278 ar = AsyncResult(self._client, msg_id, self._fname)
279 yield ar.get()
279 yield ar.get()
280 else:
280 else:
281 # already done
281 # already done
282 for r in rlist:
282 for r in rlist:
283 yield r
283 yield r
284
284
285 def __len__(self):
285 def __len__(self):
286 return len(self.msg_ids)
286 return len(self.msg_ids)
287
287
288 #-------------------------------------
288 #-------------------------------------
289 # Sugar methods and attributes
289 # Sugar methods and attributes
290 #-------------------------------------
290 #-------------------------------------
291
291
292 def timedelta(self, start, end, start_key=min, end_key=max):
292 def timedelta(self, start, end, start_key=min, end_key=max):
293 """compute the difference between two sets of timestamps
293 """compute the difference between two sets of timestamps
294
294
295 The default behavior is to use the earliest of the first
295 The default behavior is to use the earliest of the first
296 and the latest of the second list, but this can be changed
296 and the latest of the second list, but this can be changed
297 by passing a different `start_key` or `end_key`.
297 by passing a different `start_key` or `end_key`.
298
298
299 Parameters
299 Parameters
300 ----------
300 ----------
301
301
302 start : one or more datetime objects (e.g. ar.submitted)
302 start : one or more datetime objects (e.g. ar.submitted)
303 end : one or more datetime objects (e.g. ar.received)
303 end : one or more datetime objects (e.g. ar.received)
304 start_key : callable
304 start_key : callable
305 Function to call on `start` to extract the relevant
305 Function to call on `start` to extract the relevant
306 entry [default: min]
306 entry [default: min]
307 end_key : callable
307 end_key : callable
308 Function to call on `end` to extract the relevant
308 Function to call on `end` to extract the relevant
309 entry [default: max]
309 entry [default: max]
310
310
311 Returns
311 Returns
312 -------
312 -------
313
313
314 dt : float
314 dt : float
315 The time elapsed (in seconds) between the two selected timestamps.
315 The time elapsed (in seconds) between the two selected timestamps.
316 """
316 """
317 if not isinstance(start, datetime):
317 if not isinstance(start, datetime):
318 # handle single_result AsyncResults, where ar.stamp is single object,
318 # handle single_result AsyncResults, where ar.stamp is single object,
319 # not a list
319 # not a list
320 start = start_key(start)
320 start = start_key(start)
321 if not isinstance(end, datetime):
321 if not isinstance(end, datetime):
322 # handle single_result AsyncResults, where ar.stamp is single object,
322 # handle single_result AsyncResults, where ar.stamp is single object,
323 # not a list
323 # not a list
324 end = end_key(end)
324 end = end_key(end)
325 return (end - start).total_seconds()
325 return (end - start).total_seconds()
326
326
327 @property
327 @property
328 def progress(self):
328 def progress(self):
329 """the number of tasks which have been completed at this point.
329 """the number of tasks which have been completed at this point.
330
330
331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
331 Fractional progress would be given by 1.0 * ar.progress / len(ar)
332 """
332 """
333 self.wait(0)
333 self.wait(0)
334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
334 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
335
335
336 @property
336 @property
337 def elapsed(self):
337 def elapsed(self):
338 """elapsed time since initial submission"""
338 """elapsed time since initial submission"""
339 if self.ready():
339 if self.ready():
340 return self.wall_time
340 return self.wall_time
341
341
342 now = submitted = datetime.now()
342 now = submitted = datetime.now()
343 for msg_id in self.msg_ids:
343 for msg_id in self.msg_ids:
344 if msg_id in self._client.metadata:
344 if msg_id in self._client.metadata:
345 stamp = self._client.metadata[msg_id]['submitted']
345 stamp = self._client.metadata[msg_id]['submitted']
346 if stamp and stamp < submitted:
346 if stamp and stamp < submitted:
347 submitted = stamp
347 submitted = stamp
348 return (now-submitted).total_seconds()
348 return (now-submitted).total_seconds()
349
349
350 @property
350 @property
351 @check_ready
351 @check_ready
352 def serial_time(self):
352 def serial_time(self):
353 """serial computation time of a parallel calculation
353 """serial computation time of a parallel calculation
354
354
355 Computed as the sum of (completed-started) of each task
355 Computed as the sum of (completed-started) of each task
356 """
356 """
357 t = 0
357 t = 0
358 for md in self._metadata:
358 for md in self._metadata:
359 t += (md['completed'] - md['started']).total_seconds()
359 t += (md['completed'] - md['started']).total_seconds()
360 return t
360 return t
361
361
362 @property
362 @property
363 @check_ready
363 @check_ready
364 def wall_time(self):
364 def wall_time(self):
365 """actual computation time of a parallel calculation
365 """actual computation time of a parallel calculation
366
366
367 Computed as the time between the latest `received` stamp
367 Computed as the time between the latest `received` stamp
368 and the earliest `submitted`.
368 and the earliest `submitted`.
369
369
370 Only reliable if Client was spinning/waiting when the task finished, because
370 Only reliable if Client was spinning/waiting when the task finished, because
371 the `received` timestamp is created when a result is pulled off of the zmq queue,
371 the `received` timestamp is created when a result is pulled off of the zmq queue,
372 which happens as a result of `client.spin()`.
372 which happens as a result of `client.spin()`.
373
373
374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
374 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
375
375
376 """
376 """
377 return self.timedelta(self.submitted, self.received)
377 return self.timedelta(self.submitted, self.received)
378
378
def wait_interactive(self, interval=1., timeout=-1):
    """Wait for completion, printing a progress line at regular intervals.

    Parameters
    ----------

    interval : float
        passed to each `self.wait` call made between progress reports
    timeout : float or None
        stop waiting once this many seconds have elapsed;
        None or any negative value means no limit
    """
    if timeout is None:
        # None is accepted as an alias for "no limit"
        timeout = -1
    total = len(self)
    started = time.time()
    while True:
        if self.ready():
            break
        if timeout >= 0 and time.time() - started > timeout:
            break
        self.wait(interval)
        clear_output(wait=True)
        status = "%4i/%i tasks finished after %4i s" % (self.progress, total, self.elapsed)
        print(status, end="")
        sys.stdout.flush()
    print()
    print("done")
392
392
393 def _republish_displaypub(self, content, eid):
393 def _republish_displaypub(self, content, eid):
394 """republish individual displaypub content dicts"""
394 """republish individual displaypub content dicts"""
395 try:
395 try:
396 ip = get_ipython()
396 ip = get_ipython()
397 except NameError:
397 except NameError:
398 # displaypub is meaningless outside IPython
398 # displaypub is meaningless outside IPython
399 return
399 return
400 md = content['metadata'] or {}
400 md = content['metadata'] or {}
401 md['engine'] = eid
401 md['engine'] = eid
402 ip.display_pub.publish(data=content['data'], metadata=md)
402 ip.display_pub.publish(data=content['data'], metadata=md)
403
403
404 def _display_stream(self, text, prefix='', file=None):
404 def _display_stream(self, text, prefix='', file=None):
405 if not text:
405 if not text:
406 # nothing to display
406 # nothing to display
407 return
407 return
408 if file is None:
408 if file is None:
409 file = sys.stdout
409 file = sys.stdout
410 end = '' if text.endswith('\n') else '\n'
410 end = '' if text.endswith('\n') else '\n'
411
411
412 multiline = text.count('\n') > int(text.endswith('\n'))
412 multiline = text.count('\n') > int(text.endswith('\n'))
413 if prefix and multiline and not text.startswith('\n'):
413 if prefix and multiline and not text.startswith('\n'):
414 prefix = prefix + '\n'
414 prefix = prefix + '\n'
415 print("%s%s" % (prefix, text), file=file, end=end)
415 print("%s%s" % (prefix, text), file=file, end=end)
416
416
417
417
418 def _display_single_result(self):
418 def _display_single_result(self):
419 self._display_stream(self.stdout)
419 self._display_stream(self.stdout)
420 self._display_stream(self.stderr, file=sys.stderr)
420 self._display_stream(self.stderr, file=sys.stderr)
421
421
422 try:
422 try:
423 get_ipython()
423 get_ipython()
424 except NameError:
424 except NameError:
425 # displaypub is meaningless outside IPython
425 # displaypub is meaningless outside IPython
426 return
426 return
427
427
428 for output in self.outputs:
428 for output in self.outputs:
429 self._republish_displaypub(output, self.engine_id)
429 self._republish_displaypub(output, self.engine_id)
430
430
431 if self.execute_result is not None:
431 if self.execute_result is not None:
432 display(self.get())
432 display(self.get())
433
433
434 def _wait_for_outputs(self, timeout=-1):
434 def _wait_for_outputs(self, timeout=-1):
435 """wait for the 'status=idle' message that indicates we have all outputs
435 """wait for the 'status=idle' message that indicates we have all outputs
436 """
436 """
437 if self._outputs_ready or not self._success:
437 if self._outputs_ready or not self._success:
438 # don't wait on errors
438 # don't wait on errors
439 return
439 return
440
440
441 # cast None to -1 for infinite timeout
441 # cast None to -1 for infinite timeout
442 if timeout is None:
442 if timeout is None:
443 timeout = -1
443 timeout = -1
444
444
445 tic = time.time()
445 tic = time.time()
446 while True:
446 while True:
447 self._client._flush_iopub(self._client._iopub_socket)
447 self._client._flush_iopub(self._client._iopub_socket)
448 self._outputs_ready = all(md['outputs_ready']
448 self._outputs_ready = all(md['outputs_ready']
449 for md in self._metadata)
449 for md in self._metadata)
450 if self._outputs_ready or \
450 if self._outputs_ready or \
451 (timeout >= 0 and time.time() > tic + timeout):
451 (timeout >= 0 and time.time() > tic + timeout):
452 break
452 break
453 time.sleep(0.01)
453 time.sleep(0.01)
454
454
@check_ready
def display_outputs(self, groupby="type"):
    """republish the outputs of the computation

    Parameters
    ----------

    groupby : str [default: type]
        if 'type':
            Group outputs by type (show all stdout, then all stderr, etc.):

            [stdout:1] foo
            [stdout:2] foo
            [stderr:1] bar
            [stderr:2] bar
        if 'engine':
            Display outputs for each engine before moving on to the next:

            [stdout:1] foo
            [stderr:1] bar
            [stdout:2] foo
            [stderr:2] bar

        if 'order':
            Like 'type', but further collate individual displaypub
            outputs.  This is meant for cases of each command producing
            several plots, and you would like to see all of the first
            plots together, then all of the second plots, and so on.

    Raises
    ------

    ValueError
        if `groupby` is not one of 'type', 'engine', 'order'
    """
    if self._single_result:
        self._display_single_result()
        return

    stdouts = self.stdout
    stderrs = self.stderr
    execute_results = self.execute_result
    output_lists = self.outputs
    results = self.get()

    targets = self.engine_id

    # displaypub is meaningless outside IPython: in that case only the
    # plain stdout/stderr streams are shown
    try:
        get_ipython()
    except NameError:
        in_ipython = False
    else:
        in_ipython = True

    if groupby == "engine":
        for eid, stdout, stderr, outputs, r, execute_result in zip(
            targets, stdouts, stderrs, output_lists, results, execute_results
        ):
            self._display_stream(stdout, '[stdout:%i] ' % eid)
            self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)

            if not in_ipython:
                # keep looping: the remaining engines still have
                # stdout/stderr to print
                continue

            if outputs or execute_result is not None:
                _raw_text('[output:%i]' % eid)

            for output in outputs:
                self._republish_displaypub(output, eid)

            if execute_result is not None:
                display(r)

    elif groupby in ('type', 'order'):
        # republish stdout:
        for eid, stdout in zip(targets, stdouts):
            self._display_stream(stdout, '[stdout:%i] ' % eid)

        # republish stderr:
        for eid, stderr in zip(targets, stderrs):
            self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)

        if not in_ipython:
            return

        if groupby == 'order':
            output_dict = dict(zip(targets, output_lists))
            # the longest output list sets the number of rounds;
            # `or [0]` keeps max() from failing when there are no lists
            N = max([len(outputs) for outputs in output_lists] or [0])
            for i in range(N):
                for eid in targets:
                    outputs = output_dict[eid]
                    # show this engine's i-th output if it has one; engines
                    # with fewer outputs simply sit out the later rounds
                    if len(outputs) > i:
                        _raw_text('[output:%i]' % eid)
                        self._republish_displaypub(outputs[i], eid)
        else:
            # republish displaypub output
            for eid, outputs in zip(targets, output_lists):
                if outputs:
                    _raw_text('[output:%i]' % eid)
                for output in outputs:
                    self._republish_displaypub(output, eid)

        # finally, add execute_result:
        for eid, r, execute_result in zip(targets, results, execute_results):
            if execute_result is not None:
                display(r)

    else:
        raise ValueError("groupby must be one of 'type', 'engine', 'order', not %r" % groupby)
557
557
558
558
559
559
560
560
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    The gathered result is rebuilt through the map object's
    `joinPartitions`.

    Instances can be iterated at any time; iteration waits on results as
    they come in.

    With ordered=True results are yielded in the order they were submitted,
    with ordered=False the first results to arrive are yielded first.
    """

    def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        self._mapObject = mapObject
        self._single_result = False
        self.ordered = ordered

    def _reconstruct_result(self, res):
        """Join the partial results back together (the gather step)."""
        return self._mapObject.joinPartitions(res)

    # asynchronous iterator:
    def __iter__(self):
        if self.ordered:
            make_iter = self._ordered_iter
        else:
            make_iter = self._unordered_iter
        for item in make_iter():
            yield item

    # asynchronous ordered iterator:
    def _ordered_iter(self):
        """Yield results *as they arrive*, keeping submission order."""
        try:
            complete = self.get(0)
        except error.TimeoutError:
            finished = False
        else:
            finished = True

        if finished:
            # already done
            for item in complete:
                yield item
            return

        # not everything is in yet: wait for each result individually
        for msg_id in self.msg_ids:
            child = AsyncResult(self._client, msg_id, self._fname)
            partial = child.get()
            try:
                for item in partial:
                    yield item
            except TypeError:
                # flattened, not a list
                # this could get broken by flattened data that returns iterables
                # but most calls to map do not expose the `flatten` argument
                yield partial

    # asynchronous unordered iterator:
    def _unordered_iter(self):
        """Yield results *as they arrive*, first come first served."""
        try:
            complete = self.get(0)
        except error.TimeoutError:
            finished = False
        else:
            finished = True

        if finished:
            # already done
            for item in complete:
                yield item
            return

        pending = set(self.msg_ids)
        while pending:
            try:
                self._client.wait(pending, 1e-3)
            except error.TimeoutError:
                # a timeout only means that *some* jobs are still
                # outstanding, so it is safe to ignore
                pass
            # whatever is no longer outstanding is ready to be consumed
            ready = pending.difference(self._client.outstanding)
            # and no longer needs to be waited on
            pending = pending.difference(ready)
            while ready:
                msg_id = ready.pop()
                child = AsyncResult(self._client, msg_id, self._fname)
                partial = child.get()
                try:
                    for item in partial:
                        yield item
                except TypeError:
                    # flattened, not a list
                    # this could get broken by flattened data that returns iterables
                    # but most calls to map do not expose the `flatten` argument
                    yield partial
646
646
647
647
class AsyncHubResult(AsyncResult):
    """Class to wrap pending results that must be requested from the Hub.

    Note that waiting/polling on these objects requires polling the Hub over
    the network, so use `AsyncHubResult.wait()` sparingly.
    """

    def _wait_for_outputs(self, timeout=-1):
        """no-op, because HubResults are never incomplete"""
        self._outputs_ready = True

    def wait(self, timeout=-1):
        """Wait for the result to complete.

        Waits on locally outstanding messages first, then polls the Hub via
        `result_status` for the ones whose results are not held locally.
        A negative `timeout` waits without limit.
        """
        start = time.time()
        if self._ready:
            return

        local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
        local_ready = self._client.wait(local_ids, timeout)
        if local_ready:
            remote_ids = [m for m in self.msg_ids if m not in self._client.results]
            if remote_ids:
                status = self._client.result_status(remote_ids, status_only=False)
                still_pending = status['pending']
                while still_pending and (timeout < 0 or time.time() < start + timeout):
                    status = self._client.result_status(remote_ids, status_only=False)
                    still_pending = status['pending']
                    if still_pending:
                        # don't hammer the Hub while work is outstanding
                        time.sleep(0.1)
                if not still_pending:
                    self._ready = True
            else:
                self._ready = True

        if not self._ready:
            return

        try:
            results = list(map(self._client.results.get, self.msg_ids))
            self._result = results
            if self._single_result:
                first = results[0]
                if isinstance(first, Exception):
                    raise first
            else:
                results = error.collect_exceptions(results, self._fname)
            self._result = self._reconstruct_result(results)
        except Exception as e:
            self._exception = e
            self._success = False
        else:
            self._success = True
        finally:
            self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
            if self.owner:
                # this result owns its messages: drop the client's copies
                for mid in self.msg_ids:
                    self._client.metadata.pop(mid)
                for mid in self.msg_ids:
                    self._client.results.pop(mid)
701
701
702
702
703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
703 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1893 +1,1893 b''
1 """A semi-synchronous Client for IPython parallel"""
1 """A semi-synchronous Client for IPython parallel"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import os
8 import os
9 import json
9 import json
10 import sys
10 import sys
11 from threading import Thread, Event
11 from threading import Thread, Event
12 import time
12 import time
13 import warnings
13 import warnings
14 from datetime import datetime
14 from datetime import datetime
15 from getpass import getpass
15 from getpass import getpass
16 from pprint import pprint
16 from pprint import pprint
17
17
18 pjoin = os.path.join
18 pjoin = os.path.join
19
19
20 import zmq
20 import zmq
21
21
22 from IPython.config.configurable import MultipleInstanceError
22 from IPython.config.configurable import MultipleInstanceError
23 from IPython.core.application import BaseIPythonApplication
23 from IPython.core.application import BaseIPythonApplication
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25
25
26 from IPython.utils.capture import RichOutput
26 from IPython.utils.capture import RichOutput
27 from IPython.utils.coloransi import TermColors
27 from IPython.utils.coloransi import TermColors
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 from IPython.utils.localinterfaces import localhost, is_local_ip
29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 from IPython.utils.path import get_ipython_dir, compress_user
30 from IPython.utils.path import get_ipython_dir, compress_user
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 Dict, List, Bool, Set, Any)
33 Dict, List, Bool, Set, Any)
34 from decorator import decorator
34 from decorator import decorator
35
35
36 from IPython.parallel import Reference
36 from ipython_parallel import Reference
37 from IPython.parallel import error
37 from ipython_parallel import error
38 from IPython.parallel import util
38 from ipython_parallel import util
39
39
40 from IPython.kernel.zmq.session import Session, Message
40 from IPython.kernel.zmq.session import Session, Message
41 from IPython.kernel.zmq import serialize
41 from IPython.kernel.zmq import serialize
42
42
43 from .asyncresult import AsyncResult, AsyncHubResult
43 from .asyncresult import AsyncResult, AsyncHubResult
44 from .view import DirectView, LoadBalancedView
44 from .view import DirectView, LoadBalancedView
45
45
46 #--------------------------------------------------------------------------
46 #--------------------------------------------------------------------------
47 # Decorators for Client methods
47 # Decorators for Client methods
48 #--------------------------------------------------------------------------
48 #--------------------------------------------------------------------------
49
49
50
50
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state, then call the wrapped method and return its result."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
55 return f(self, *args, **kwargs)
56
56
57
57
58 #--------------------------------------------------------------------------
58 #--------------------------------------------------------------------------
59 # Classes
59 # Classes
60 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
61
61
62 _no_connection_file_msg = """
62 _no_connection_file_msg = """
63 Failed to connect because no Controller could be found.
63 Failed to connect because no Controller could be found.
64 Please double-check your profile and ensure that a cluster is running.
64 Please double-check your profile and ensure that a cluster is running.
65 """
65 """
66
66
class ExecuteReply(RichOutput):
    """wrapper for finished Execute results

    Item access and attribute access both fall through to the `metadata`
    mapping given at construction.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    # RichOutput overrides

    @property
    def source(self):
        """'source' of the execute_result ('' if absent); None without an execute_result."""
        execute_result = self.metadata['execute_result']
        if execute_result:
            return execute_result.get('source', '')

    @property
    def data(self):
        """'data' of the execute_result ({} if absent); None without an execute_result."""
        execute_result = self.metadata['execute_result']
        if execute_result:
            return execute_result.get('data', {})

    @property
    def _metadata(self):
        """'metadata' of the execute_result ({} if absent); None without an execute_result."""
        execute_result = self.metadata['execute_result']
        if execute_result:
            return execute_result.get('metadata', {})

    def display(self):
        """Publish this reply's data through IPython's display machinery."""
        from IPython.display import publish_display_data
        # NOTE(review): this passes the task metadata (self.metadata), not the
        # execute_result's display metadata (self._metadata) - confirm intent
        publish_display_data(self.data, self.metadata)

    def _repr_mime_(self, mime):
        """Return the data for `mime`, paired with its metadata when there is any.

        Returns None when there is no data for `mime`.
        """
        # `data` and `_metadata` are None when there is no execute_result;
        # treat that as empty rather than failing on `mime in None`
        data = self.data or {}
        if mime not in data:
            return
        metadata = self._metadata or {}
        if mime in metadata:
            return data[mime], metadata[mime]
        return data[mime]

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        if key == 'metadata':
            # only reached when `metadata` has not been set yet (e.g. while
            # the object is being copied or unpickled); reading
            # self.metadata here would recurse without end
            raise AttributeError(key)
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        execute_result = self.metadata['execute_result'] or {'data':{}}
        text_out = execute_result['data'].get('text/plain', '')
        if len(text_out) > 32:
            # keep the repr short
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        execute_result = self.metadata['execute_result'] or {'data':{}}
        text_out = execute_result['data'].get('text/plain', '')

        if not text_out:
            return

        try:
            ip = get_ipython()
        except NameError:
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )
152 )
153
153
154
154
class Metadata(dict):
    """dict subclass that starts out with the standard metadata keys.

    Keys can also be read and written as attributes.

    The key set is strict: item or attribute assignment to a key that is
    not already present raises (KeyError / AttributeError respectively).
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # fields with no value until the task provides one
        unset_keys = (
            'msg_id',
            'submitted',
            'started',
            'completed',
            'received',
            'engine_uuid',
            'engine_id',
            'follow',
            'after',
            'status',
            'execute_input',
            'execute_result',
            'error',
        )
        for name in unset_keys:
            dict.__setitem__(self, name, None)
        # fields that start out empty; the containers are created per instance
        dict.update(self, [
            ('stdout', ''),
            ('stderr', ''),
            ('outputs', []),
            ('data', {}),
            ('outputs_ready', False),
        ])
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        if key not in self:
            raise AttributeError(key)
        return self[key]

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key not in self:
            raise AttributeError(key)
        self[key] = value

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key not in self:
            raise KeyError(key)
        dict.__setitem__(self, key, value)
208
209
209
210 class Client(HasTraits):
210 class Client(HasTraits):
211 """A semi-synchronous client to the IPython ZMQ cluster
211 """A semi-synchronous client to the IPython ZMQ cluster
212
212
213 Parameters
213 Parameters
214 ----------
214 ----------
215
215
216 url_file : str/unicode; path to ipcontroller-client.json
216 url_file : str/unicode; path to ipcontroller-client.json
217 This JSON file should contain all the information needed to connect to a cluster,
217 This JSON file should contain all the information needed to connect to a cluster,
218 and is likely the only argument needed.
218 and is likely the only argument needed.
219 Connection information for the Hub's registration. If a json connector
219 Connection information for the Hub's registration. If a json connector
220 file is given, then likely no further configuration is necessary.
220 file is given, then likely no further configuration is necessary.
221 [Default: use profile]
221 [Default: use profile]
222 profile : bytes
222 profile : bytes
223 The name of the Cluster profile to be used to find connector information.
223 The name of the Cluster profile to be used to find connector information.
224 If run from an IPython application, the default profile will be the same
224 If run from an IPython application, the default profile will be the same
225 as the running application, otherwise it will be 'default'.
225 as the running application, otherwise it will be 'default'.
226 cluster_id : str
226 cluster_id : str
227 String id to be added to runtime files, to prevent name collisions when using
227 String id to be added to runtime files, to prevent name collisions when using
228 multiple clusters with a single profile simultaneously.
228 multiple clusters with a single profile simultaneously.
229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
229 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
230 Since this is text inserted into filenames, typical recommendations apply:
230 Since this is text inserted into filenames, typical recommendations apply:
231 Simple character strings are ideal, and spaces are not recommended (but
231 Simple character strings are ideal, and spaces are not recommended (but
232 should generally work)
232 should generally work)
233 context : zmq.Context
233 context : zmq.Context
234 Pass an existing zmq.Context instance, otherwise the client will create its own.
234 Pass an existing zmq.Context instance, otherwise the client will create its own.
235 debug : bool
235 debug : bool
236 flag for lots of message printing for debug purposes
236 flag for lots of message printing for debug purposes
237 timeout : int/float
237 timeout : int/float
238 time (in seconds) to wait for connection replies from the Hub
238 time (in seconds) to wait for connection replies from the Hub
239 [Default: 10]
239 [Default: 10]
240
240
241 #-------------- session related args ----------------
241 #-------------- session related args ----------------
242
242
243 config : Config object
243 config : Config object
244 If specified, this will be relayed to the Session for configuration
244 If specified, this will be relayed to the Session for configuration
245 username : str
245 username : str
246 set username for the session object
246 set username for the session object
247
247
248 #-------------- ssh related args ----------------
248 #-------------- ssh related args ----------------
249 # These are args for configuring the ssh tunnel to be used
249 # These are args for configuring the ssh tunnel to be used
250 # credentials are used to forward connections over ssh to the Controller
250 # credentials are used to forward connections over ssh to the Controller
251 # Note that the ip given in `addr` needs to be relative to sshserver
251 # Note that the ip given in `addr` needs to be relative to sshserver
252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
252 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
253 # and set sshserver as the same machine the Controller is on. However,
253 # and set sshserver as the same machine the Controller is on. However,
254 # the only requirement is that sshserver is able to see the Controller
254 # the only requirement is that sshserver is able to see the Controller
255 # (i.e. is within the same trusted network).
255 # (i.e. is within the same trusted network).
256
256
257 sshserver : str
257 sshserver : str
258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
258 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
259 If keyfile or password is specified, and this is not, it will default to
259 If keyfile or password is specified, and this is not, it will default to
260 the ip given in addr.
260 the ip given in addr.
261 sshkey : str; path to ssh private key file
261 sshkey : str; path to ssh private key file
262 This specifies a key to be used in ssh login, default None.
262 This specifies a key to be used in ssh login, default None.
263 Regular default ssh keys will be used without specifying this argument.
263 Regular default ssh keys will be used without specifying this argument.
264 password : str
264 password : str
265 Your ssh password to sshserver. Note that if this is left None,
265 Your ssh password to sshserver. Note that if this is left None,
266 you will be prompted for it if passwordless key based login is unavailable.
266 you will be prompted for it if passwordless key based login is unavailable.
267 paramiko : bool
267 paramiko : bool
268 flag for whether to use paramiko instead of shell ssh for tunneling.
268 flag for whether to use paramiko instead of shell ssh for tunneling.
269 [default: True on win32, False else]
269 [default: True on win32, False else]
270
270
271
271
272 Attributes
272 Attributes
273 ----------
273 ----------
274
274
275 ids : list of int engine IDs
275 ids : list of int engine IDs
276 requesting the ids attribute always synchronizes
276 requesting the ids attribute always synchronizes
277 the registration state. To request ids without synchronization,
277 the registration state. To request ids without synchronization,
278 use semi-private _ids attributes.
278 use semi-private _ids attributes.
279
279
280 history : list of msg_ids
280 history : list of msg_ids
281 a list of msg_ids, keeping track of all the execution
281 a list of msg_ids, keeping track of all the execution
282 messages you have submitted in order.
282 messages you have submitted in order.
283
283
284 outstanding : set of msg_ids
284 outstanding : set of msg_ids
285 a set of msg_ids that have been submitted, but whose
285 a set of msg_ids that have been submitted, but whose
286 results have not yet been received.
286 results have not yet been received.
287
287
288 results : dict
288 results : dict
289 a dict of all our results, keyed by msg_id
289 a dict of all our results, keyed by msg_id
290
290
291 block : bool
291 block : bool
292 determines default behavior when block not specified
292 determines default behavior when block not specified
293 in execution methods
293 in execution methods
294
294
295 Methods
295 Methods
296 -------
296 -------
297
297
298 spin
298 spin
299 flushes incoming results and registration state changes
299 flushes incoming results and registration state changes
300 control methods spin, and requesting `ids` also ensures up to date
300 control methods spin, and requesting `ids` also ensures up to date
301
301
302 wait
302 wait
303 wait on one or more msg_ids
303 wait on one or more msg_ids
304
304
305 execution methods
305 execution methods
306 apply
306 apply
307 legacy: execute, run
307 legacy: execute, run
308
308
309 data movement
309 data movement
310 push, pull, scatter, gather
310 push, pull, scatter, gather
311
311
312 query methods
312 query methods
313 queue_status, get_result, purge, result_status
313 queue_status, get_result, purge, result_status
314
314
315 control methods
315 control methods
316 abort, shutdown
316 abort, shutdown
317
317
318 """
318 """
319
319
320
320
321 block = Bool(False)
321 block = Bool(False)
322 outstanding = Set()
322 outstanding = Set()
323 results = Instance('collections.defaultdict', (dict,))
323 results = Instance('collections.defaultdict', (dict,))
324 metadata = Instance('collections.defaultdict', (Metadata,))
324 metadata = Instance('collections.defaultdict', (Metadata,))
325 history = List()
325 history = List()
326 debug = Bool(False)
326 debug = Bool(False)
327 _spin_thread = Any()
327 _spin_thread = Any()
328 _stop_spinning = Any()
328 _stop_spinning = Any()
329
329
330 profile=Unicode()
330 profile=Unicode()
def _profile_default(self):
    """Compute the default profile name.

    Returns the ``profile`` of the running IPython application when one is
    initialized and can be retrieved, otherwise ``u'default'``.
    """
    if not BaseIPythonApplication.initialized():
        return u'default'
    # an IPython app *might* be running, try to get its profile
    try:
        app = BaseIPythonApplication.instance()
        return app.profile
    except (AttributeError, MultipleInstanceError):
        # the initialized application could be a *different* subclass of
        # config.Application, which would raise one of these two errors.
        return u'default'
342
342
343
343
344 _outstanding_dict = Instance('collections.defaultdict', (set,))
344 _outstanding_dict = Instance('collections.defaultdict', (set,))
345 _ids = List()
345 _ids = List()
346 _connected=Bool(False)
346 _connected=Bool(False)
347 _ssh=Bool(False)
347 _ssh=Bool(False)
348 _context = Instance('zmq.Context')
348 _context = Instance('zmq.Context')
349 _config = Dict()
349 _config = Dict()
350 _engines=Instance(util.ReverseDict, (), {})
350 _engines=Instance(util.ReverseDict, (), {})
351 # _hub_socket=Instance('zmq.Socket')
351 # _hub_socket=Instance('zmq.Socket')
352 _query_socket=Instance('zmq.Socket')
352 _query_socket=Instance('zmq.Socket')
353 _control_socket=Instance('zmq.Socket')
353 _control_socket=Instance('zmq.Socket')
354 _iopub_socket=Instance('zmq.Socket')
354 _iopub_socket=Instance('zmq.Socket')
355 _notification_socket=Instance('zmq.Socket')
355 _notification_socket=Instance('zmq.Socket')
356 _mux_socket=Instance('zmq.Socket')
356 _mux_socket=Instance('zmq.Socket')
357 _task_socket=Instance('zmq.Socket')
357 _task_socket=Instance('zmq.Socket')
358 _task_scheme=Unicode()
358 _task_scheme=Unicode()
359 _closed = False
359 _closed = False
360 _ignored_control_replies=Integer(0)
360 _ignored_control_replies=Integer(0)
361 _ignored_hub_replies=Integer(0)
361 _ignored_hub_replies=Integer(0)
362
362
def __new__(self, *args, **kw):
    """Create the instance, forwarding only keyword arguments to HasTraits.

    Positional arguments are accepted and dropped here so that they do not
    make ``HasTraits.__new__`` raise.
    """
    # NOTE: despite its name, the first parameter is the class being created.
    instance = HasTraits.__new__(self, **kw)
    return instance
366
366
def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
        context=None, debug=False,
        sshserver=None, sshkey=None, password=None, paramiko=None,
        timeout=10, cluster_id=None, **extra_args
        ):
    """Load the connection file, build the Session and connect to the cluster.

    The connection file is ``url_file`` when given.  Otherwise it is looked
    up in the security directory of the profile directory, as
    ``ipcontroller-client.json`` or, when ``cluster_id`` is set,
    ``ipcontroller-<cluster_id>-client.json``.  Remaining keyword arguments
    (``extra_args``) are forwarded to ``Session``.

    Raises
    ------
    ValueError
        If ``url_file`` is a URL rather than a file, or the connection file
        lacks one of the ``pack`` / ``unpack`` / ``key`` /
        ``signature_scheme`` entries.
    IOError
        If no connection file can be found.

    Any exception raised while connecting is re-raised after the client has
    been closed with ``linger=0``.
    """
    if profile:
        super(Client, self).__init__(debug=debug, profile=profile)
    else:
        super(Client, self).__init__(debug=debug)
    if context is None:
        context = zmq.Context.instance()
    self._context = context
    self._stop_spinning = Event()

    if 'url_or_file' in extra_args:
        # deprecated spelling of url_file; pop it so that it is not
        # forwarded to Session(**extra_args) further down
        url_file = extra_args.pop('url_or_file')
        warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)

    if url_file and util.is_url(url_file):
        raise ValueError("single urls cannot be specified, url-files must be used.")

    self._setup_profile_dir(self.profile, profile_dir, ipython_dir)

    no_file_msg = '\n'.join([
        "You have attempted to connect to an IPython Cluster but no Controller could be found.",
        "Please double-check your configuration and ensure that a cluster is running.",
    ])

    if self._cd is not None:
        if url_file is None:
            # no explicit file: derive it from the profile's security dir
            if not cluster_id:
                client_json = 'ipcontroller-client.json'
            else:
                client_json = 'ipcontroller-%s-client.json' % cluster_id
            url_file = pjoin(self._cd.security_dir, client_json)
            if not os.path.exists(url_file):
                msg = '\n'.join([
                    "Connection file %r not found." % compress_user(url_file),
                    no_file_msg,
                ])
                raise IOError(msg)
    if url_file is None:
        raise IOError(no_file_msg)

    if not os.path.exists(url_file):
        # Connection file explicitly specified, but not found
        raise IOError("Connection file %r not found. Is a controller running?" % \
            compress_user(url_file)
        )

    with open(url_file) as f:
        cfg = json.load(f)

    self._task_scheme = cfg['task_scheme']

    # sync defaults from args, json:
    if sshserver:
        cfg['ssh'] = sshserver

    location = cfg.setdefault('location', None)

    proto,addr = cfg['interface'].split('://')
    addr = util.disambiguate_ip_address(addr, location)
    cfg['interface'] = "%s://%s" % (proto, addr)

    # turn interface,port into full urls:
    for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
        cfg[key] = cfg['interface'] + ':%i' % cfg[key]

    if location is not None and addr == localhost():
        # location specified, and connection is expected to be local
        if not is_local_ip(location) and not sshserver:
            # load ssh from JSON *only* if the controller is not on
            # this machine
            sshserver=cfg['ssh']
        if not is_local_ip(location) and not sshserver:
            # warn if no ssh specified, but SSH is probably needed
            # This is only a warning, because the most likely cause
            # is a local Controller on a laptop whose IP is dynamic
            warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
            RuntimeWarning)
    elif not sshserver:
        # otherwise sync with cfg
        sshserver = cfg['ssh']

    self._config = cfg

    self._ssh = bool(sshserver or sshkey or password)
    if self._ssh and not sshserver:
        # default to ssh via localhost.  `not sshserver` (rather than
        # `is None`) also covers an empty 'ssh' entry loaded from the
        # connection file when only sshkey/password were given.
        sshserver = addr
    if self._ssh and password is None:
        from zmq.ssh import tunnel
        if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
            password=False
        else:
            password = getpass("SSH Password for %s: "%sshserver)
    ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

    # configure and construct the session
    try:
        extra_args['packer'] = cfg['pack']
        extra_args['unpacker'] = cfg['unpack']
        extra_args['key'] = cast_bytes(cfg['key'])
        extra_args['signature_scheme'] = cfg['signature_scheme']
    except KeyError as exc:
        msg = '\n'.join([
            "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
            "If you are reusing connection files, remove them and start ipcontroller again."
        ])
        # exceptions have no `.message` attribute on Python 3; the missing
        # key is the first (only) argument of the KeyError
        raise ValueError(msg.format(exc.args[0]))

    self.session = Session(**extra_args)

    self._query_socket = self._context.socket(zmq.DEALER)

    if self._ssh:
        from zmq.ssh import tunnel
        tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
    else:
        self._query_socket.connect(cfg['registration'])

    self.session.debug = self.debug

    self._notification_handlers = {'registration_notification' : self._register_engine,
                                'unregistration_notification' : self._unregister_engine,
                                'shutdown_notification' : lambda msg: self.close(),
                                }
    self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                            'apply_reply' : self._handle_apply_reply}

    try:
        self._connect(sshserver, ssh_kwargs, timeout)
    except:
        # deliberately bare: close on *any* failure (including
        # KeyboardInterrupt), then let the original exception propagate
        self.close(linger=0)
        raise

    # last step: setup magics, if we are in IPython:

    try:
        ip = get_ipython()
    except NameError:
        return
    else:
        if 'px' not in ip.magics_manager.magics:
            # in IPython but we are the first Client.
            # activate a default view for parallel magics.
            self.activate()
520
520
def __del__(self):
    """Finalizer: delegate to ``close()``.

    This cleans up the sockets, but _not_ the context.
    """
    self.close()
524
524
def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
    """Locate the profile directory and store it in ``self._cd``.

    ``profile_dir`` takes precedence when given; otherwise ``profile`` is
    looked up by name under ``ipython_dir`` (``get_ipython_dir()`` when
    ``ipython_dir`` is None).  ``self._cd`` is set to None when neither is
    given or when the lookup raises ``ProfileDirError``.
    """
    if ipython_dir is None:
        ipython_dir = get_ipython_dir()

    located = None
    try:
        if profile_dir is not None:
            located = ProfileDir.find_profile_dir(profile_dir)
        elif profile is not None:
            located = ProfileDir.find_profile_dir_by_name(
                ipython_dir, profile)
    except ProfileDirError:
        # a failed lookup is not fatal: fall back to "no profile dir"
        located = None
    self._cd = located
542
542
def _update_engines(self, engines):
    """Merge a dict of the form ``{id: uuid}`` into ``_engines`` and ``_ids``.

    Keys are converted with ``int()``.  Afterwards, if the engine ids are no
    longer exactly ``0..len-1`` while the task scheme is ``'pure'`` and a
    task socket is open, task scheduling is stopped.
    """
    for key, engine_uuid in iteritems(engines):
        engine_id = int(key)
        if engine_id not in self._engines:
            self._ids.append(engine_id)
        self._engines[engine_id] = engine_uuid
    self._ids = sorted(self._ids)

    known_ids = sorted(self._engines.keys())
    contiguous_ids = list(range(len(self._engines)))
    if known_ids != contiguous_ids and \
            self._task_scheme == 'pure' and self._task_socket:
        self._stop_scheduling_tasks()
554
554
555 def _stop_scheduling_tasks(self):
555 def _stop_scheduling_tasks(self):
556 """Stop scheduling tasks because an engine has been unregistered
556 """Stop scheduling tasks because an engine has been unregistered
557 from a pure ZMQ scheduler.
557 from a pure ZMQ scheduler.
558 """
558 """
559 self._task_socket.close()
559 self._task_socket.close()
560 self._task_socket = None
560 self._task_socket = None
561 msg = "An engine has been unregistered, and we are using pure " +\
561 msg = "An engine has been unregistered, and we are using pure " +\
562 "ZMQ task scheduling. Task farming will be disabled."
562 "ZMQ task scheduling. Task farming will be disabled."
563 if self.outstanding:
563 if self.outstanding:
564 msg += " If you were running tasks when this happened, " +\
564 msg += " If you were running tasks when this happened, " +\
565 "some `outstanding` msg_ids may never resolve."
565 "some `outstanding` msg_ids may never resolve."
566 warnings.warn(msg, RuntimeWarning)
566 warnings.warn(msg, RuntimeWarning)
567
567
def _build_targets(self, targets):
    """Turn valid target IDs or 'all' into two lists.

    ``targets`` may be None or the string 'all' (every known engine), a
    single int (negative values index into ``self.ids``), a slice over the
    known ids, or a tuple/list/xrange of ints.

    Returns
    -------
    (idents, int_ids)
        ``idents`` holds the ``cast_bytes``-converted ``_engines`` entry of
        each target; ``int_ids`` is the list of integer engine ids.

    Raises
    ------
    error.NoEnginesRegistered
        If no engines are registered.
    TypeError
        For a string other than 'all', or an unsupported target type.
    IndexError
        If a single int target is not a known engine id.
    """
    if not self._ids and not self.ids:
        # nothing cached in _ids; `self.ids` was consulted as a last check
        # (flush notification socket if no engines yet, just in case)
        raise error.NoEnginesRegistered("Can't build targets without any engines")

    if targets is None:
        targets = self._ids
    elif isinstance(targets, string_types):
        if targets.lower() != 'all':
            raise TypeError("%r not valid str target, must be 'all'"%(targets))
        targets = self._ids
    elif isinstance(targets, int):
        if targets < 0:
            targets = self.ids[targets]
        if targets not in self._ids:
            raise IndexError("No such engine: %i"%targets)
        targets = [targets]

    if isinstance(targets, slice):
        # the slice selects positions in the id list, not engine ids
        positions = list(range(len(self._ids))[targets])
        current_ids = self.ids
        targets = [current_ids[pos] for pos in positions]

    if not isinstance(targets, (tuple, list, xrange)):
        raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

    idents = [cast_bytes(self._engines[t]) for t in targets]
    return idents, list(targets)
600
600
def _connect(self, sshserver, ssh_kwargs, timeout):
    """Set up all our socket connections to the cluster.

    This is called from __init__.  A ``connection_request`` is sent on the
    query socket; when the reply status is ``'ok'`` the mux, task,
    notification, control and iopub sockets are created and connected
    (tunnelled over ssh when ``self._ssh`` is set) and the engine table is
    populated from the reply.

    Parameters
    ----------
    sshserver : str
        ssh server passed to ``tunnel.tunnel_connection`` when tunnelling.
    ssh_kwargs : dict
        extra keyword arguments for ``tunnel.tunnel_connection``.
    timeout : int/float
        seconds to wait for the reply to the connection request.

    Raises
    ------
    error.TimeoutError
        If no reply arrives within ``timeout`` seconds.
    Exception
        If the reply status is not ``'ok'``.
    """

    # Maybe allow reconnecting?
    if self._connected:
        return

    def connect_socket(s, url):
        if self._ssh:
            from zmq.ssh import tunnel
            return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
        else:
            return s.connect(url)

    self.session.send(self._query_socket, 'connection_request')
    # use Poller because zmq.select has wrong units in pyzmq 2.1.7
    poller = zmq.Poller()
    poller.register(self._query_socket, zmq.POLLIN)
    # poll expects milliseconds, timeout is seconds
    evts = poller.poll(timeout*1000)
    if not evts:
        raise error.TimeoutError("Hub connection request timed out")
    idents,msg = self.session.recv(self._query_socket,mode=0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    # self._config['registration'] = dict(content)
    cfg = self._config
    if content['status'] != 'ok':
        raise Exception("Failed to connect!")

    self._mux_socket = self._context.socket(zmq.DEALER)
    connect_socket(self._mux_socket, cfg['mux'])

    self._task_socket = self._context.socket(zmq.DEALER)
    connect_socket(self._task_socket, cfg['task'])

    self._notification_socket = self._context.socket(zmq.SUB)
    self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
    connect_socket(self._notification_socket, cfg['notification'])

    self._control_socket = self._context.socket(zmq.DEALER)
    connect_socket(self._control_socket, cfg['control'])

    self._iopub_socket = self._context.socket(zmq.SUB)
    self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
    connect_socket(self._iopub_socket, cfg['iopub'])

    self._update_engines(dict(content['engines']))

    # Only mark the client connected once everything above succeeded, so a
    # timeout or a failure while connecting never leaves the flag set.
    self._connected = True
653
653
654 #--------------------------------------------------------------------------
654 #--------------------------------------------------------------------------
655 # handlers and callbacks for incoming messages
655 # handlers and callbacks for incoming messages
656 #--------------------------------------------------------------------------
656 #--------------------------------------------------------------------------
657
657
def _unwrap_exception(self, content):
    """Unwrap the exception in ``content`` and remap its engine_id to int.

    When the unwrapped exception carries ``engine_info``, its
    ``'engine_id'`` entry is set to the ``_engines`` lookup of its
    ``'engine_uuid'``.  Returns the unwrapped exception.
    """
    exc = error.unwrap_exception(content)
    if not exc.engine_info:
        return exc
    engine_uuid = exc.engine_info['engine_uuid']
    exc.engine_info['engine_id'] = self._engines[engine_uuid]
    return exc
667
667
668 def _extract_metadata(self, msg):
668 def _extract_metadata(self, msg):
669 header = msg['header']
669 header = msg['header']
670 parent = msg['parent_header']
670 parent = msg['parent_header']
671 msg_meta = msg['metadata']
671 msg_meta = msg['metadata']
672 content = msg['content']
672 content = msg['content']
673 md = {'msg_id' : parent['msg_id'],
673 md = {'msg_id' : parent['msg_id'],
674 'received' : datetime.now(),
674 'received' : datetime.now(),
675 'engine_uuid' : msg_meta.get('engine', None),
675 'engine_uuid' : msg_meta.get('engine', None),
676 'follow' : msg_meta.get('follow', []),
676 'follow' : msg_meta.get('follow', []),
677 'after' : msg_meta.get('after', []),
677 'after' : msg_meta.get('after', []),
678 'status' : content['status'],
678 'status' : content['status'],
679 }
679 }
680
680
681 if md['engine_uuid'] is not None:
681 if md['engine_uuid'] is not None:
682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
682 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
683
683
684 if 'date' in parent:
684 if 'date' in parent:
685 md['submitted'] = parent['date']
685 md['submitted'] = parent['date']
686 if 'started' in msg_meta:
686 if 'started' in msg_meta:
687 md['started'] = parse_date(msg_meta['started'])
687 md['started'] = parse_date(msg_meta['started'])
688 if 'date' in header:
688 if 'date' in header:
689 md['completed'] = header['date']
689 md['completed'] = header['date']
690 return md
690 return md
691
691
692 def _register_engine(self, msg):
692 def _register_engine(self, msg):
693 """Register a new engine, and update our connection info."""
693 """Register a new engine, and update our connection info."""
694 content = msg['content']
694 content = msg['content']
695 eid = content['id']
695 eid = content['id']
696 d = {eid : content['uuid']}
696 d = {eid : content['uuid']}
697 self._update_engines(d)
697 self._update_engines(d)
698
698
699 def _unregister_engine(self, msg):
699 def _unregister_engine(self, msg):
700 """Unregister an engine that has died."""
700 """Unregister an engine that has died."""
701 content = msg['content']
701 content = msg['content']
702 eid = int(content['id'])
702 eid = int(content['id'])
703 if eid in self._ids:
703 if eid in self._ids:
704 self._ids.remove(eid)
704 self._ids.remove(eid)
705 uuid = self._engines.pop(eid)
705 uuid = self._engines.pop(eid)
706
706
707 self._handle_stranded_msgs(eid, uuid)
707 self._handle_stranded_msgs(eid, uuid)
708
708
709 if self._task_socket and self._task_scheme == 'pure':
709 if self._task_socket and self._task_scheme == 'pure':
710 self._stop_scheduling_tasks()
710 self._stop_scheduling_tasks()
711
711
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    Every msg_id still listed as outstanding for the engine, and without
    an entry in ``self.results``, is resolved by passing a fake failed
    ``apply_reply`` (wrapping an ``EngineError``) to
    ``_handle_apply_reply``.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    of the unregistration and later receive the successful result.

    Parameters
    ----------

    eid : int
        id of the engine that unregistered; only used in the error message
    uuid :
        uuid of that engine; key into ``self._outstanding_dict`` and the
        ``engine`` metadata of the fake replies
    """
    outstanding = self._outstanding_dict[uuid]

    # iterate over a copy: _handle_apply_reply removes msg_ids from the
    # per-engine outstanding collection while we loop
    for msg_id in list(outstanding):
        if msg_id in self.results:
            # we already have a result for this message
            continue
        try:
            raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
        except error.EngineError:
            # NOTE(review): presumably wrap_exception() serializes the
            # exception currently being handled - confirm in error module
            content = error.wrap_exception()
        # build a fake message:
        msg = self.session.msg('apply_reply', content=content)
        msg['parent_header']['msg_id'] = msg_id
        msg['metadata']['engine'] = uuid
        self._handle_apply_reply(msg)
735
735
736 def _handle_execute_reply(self, msg):
736 def _handle_execute_reply(self, msg):
737 """Save the reply to an execute_request into our results.
737 """Save the reply to an execute_request into our results.
738
738
739 execute messages are never actually used. apply is used instead.
739 execute messages are never actually used. apply is used instead.
740 """
740 """
741
741
742 parent = msg['parent_header']
742 parent = msg['parent_header']
743 msg_id = parent['msg_id']
743 msg_id = parent['msg_id']
744 if msg_id not in self.outstanding:
744 if msg_id not in self.outstanding:
745 if msg_id in self.history:
745 if msg_id in self.history:
746 print("got stale result: %s"%msg_id)
746 print("got stale result: %s"%msg_id)
747 else:
747 else:
748 print("got unknown result: %s"%msg_id)
748 print("got unknown result: %s"%msg_id)
749 else:
749 else:
750 self.outstanding.remove(msg_id)
750 self.outstanding.remove(msg_id)
751
751
752 content = msg['content']
752 content = msg['content']
753 header = msg['header']
753 header = msg['header']
754
754
755 # construct metadata:
755 # construct metadata:
756 md = self.metadata[msg_id]
756 md = self.metadata[msg_id]
757 md.update(self._extract_metadata(msg))
757 md.update(self._extract_metadata(msg))
758 # is this redundant?
758 # is this redundant?
759 self.metadata[msg_id] = md
759 self.metadata[msg_id] = md
760
760
761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
761 e_outstanding = self._outstanding_dict[md['engine_uuid']]
762 if msg_id in e_outstanding:
762 if msg_id in e_outstanding:
763 e_outstanding.remove(msg_id)
763 e_outstanding.remove(msg_id)
764
764
765 # construct result:
765 # construct result:
766 if content['status'] == 'ok':
766 if content['status'] == 'ok':
767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
767 self.results[msg_id] = ExecuteReply(msg_id, content, md)
768 elif content['status'] == 'aborted':
768 elif content['status'] == 'aborted':
769 self.results[msg_id] = error.TaskAborted(msg_id)
769 self.results[msg_id] = error.TaskAborted(msg_id)
770 elif content['status'] == 'resubmitted':
770 elif content['status'] == 'resubmitted':
771 # TODO: handle resubmission
771 # TODO: handle resubmission
772 pass
772 pass
773 else:
773 else:
774 self.results[msg_id] = self._unwrap_exception(content)
774 self.results[msg_id] = self._unwrap_exception(content)
775
775
776 def _handle_apply_reply(self, msg):
776 def _handle_apply_reply(self, msg):
777 """Save the reply to an apply_request into our results."""
777 """Save the reply to an apply_request into our results."""
778 parent = msg['parent_header']
778 parent = msg['parent_header']
779 msg_id = parent['msg_id']
779 msg_id = parent['msg_id']
780 if msg_id not in self.outstanding:
780 if msg_id not in self.outstanding:
781 if msg_id in self.history:
781 if msg_id in self.history:
782 print("got stale result: %s"%msg_id)
782 print("got stale result: %s"%msg_id)
783 print(self.results[msg_id])
783 print(self.results[msg_id])
784 print(msg)
784 print(msg)
785 else:
785 else:
786 print("got unknown result: %s"%msg_id)
786 print("got unknown result: %s"%msg_id)
787 else:
787 else:
788 self.outstanding.remove(msg_id)
788 self.outstanding.remove(msg_id)
789 content = msg['content']
789 content = msg['content']
790 header = msg['header']
790 header = msg['header']
791
791
792 # construct metadata:
792 # construct metadata:
793 md = self.metadata[msg_id]
793 md = self.metadata[msg_id]
794 md.update(self._extract_metadata(msg))
794 md.update(self._extract_metadata(msg))
795 # is this redundant?
795 # is this redundant?
796 self.metadata[msg_id] = md
796 self.metadata[msg_id] = md
797
797
798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
798 e_outstanding = self._outstanding_dict[md['engine_uuid']]
799 if msg_id in e_outstanding:
799 if msg_id in e_outstanding:
800 e_outstanding.remove(msg_id)
800 e_outstanding.remove(msg_id)
801
801
802 # construct result:
802 # construct result:
803 if content['status'] == 'ok':
803 if content['status'] == 'ok':
804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
804 self.results[msg_id] = serialize.deserialize_object(msg['buffers'])[0]
805 elif content['status'] == 'aborted':
805 elif content['status'] == 'aborted':
806 self.results[msg_id] = error.TaskAborted(msg_id)
806 self.results[msg_id] = error.TaskAborted(msg_id)
807 elif content['status'] == 'resubmitted':
807 elif content['status'] == 'resubmitted':
808 # TODO: handle resubmission
808 # TODO: handle resubmission
809 pass
809 pass
810 else:
810 else:
811 self.results[msg_id] = self._unwrap_exception(content)
811 self.results[msg_id] = self._unwrap_exception(content)
812
812
def _flush_notifications(self):
    """Flush notifications of engine registrations waiting
    in ZMQ queue.

    Messages are read without blocking until none is left.  Each one is
    passed to the handler registered for its ``msg_type`` in
    ``self._notification_handlers``; a message type without a handler
    raises ``Exception``.
    """
    while True:
        _idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        if msg is None:
            return
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._notification_handlers.get(msg_type, None)
        if handler is None:
            raise Exception("Unhandled message type: %s" % msg_type)
        handler(msg)
827
827
def _flush_results(self, sock):
    """Flush task or queue results waiting in ZMQ queue.

    Messages are read from `sock` without blocking until none is left.
    Each one is passed to the handler registered for its ``msg_type`` in
    ``self._queue_handlers``; a message type without a handler raises
    ``Exception``.
    """
    while True:
        _idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            return
        if self.debug:
            pprint(msg)
        msg_type = msg['header']['msg_type']
        handler = self._queue_handlers.get(msg_type, None)
        if handler is None:
            raise Exception("Unhandled message type: %s" % msg_type)
        handler(msg)
841
841
842 def _flush_control(self, sock):
842 def _flush_control(self, sock):
843 """Flush replies from the control channel waiting
843 """Flush replies from the control channel waiting
844 in the ZMQ queue.
844 in the ZMQ queue.
845
845
846 Currently: ignore them."""
846 Currently: ignore them."""
847 if self._ignored_control_replies <= 0:
847 if self._ignored_control_replies <= 0:
848 return
848 return
849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
849 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
850 while msg is not None:
850 while msg is not None:
851 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
852 if self.debug:
852 if self.debug:
853 pprint(msg)
853 pprint(msg)
854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
854 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
855
855
856 def _flush_ignored_control(self):
856 def _flush_ignored_control(self):
857 """flush ignored control replies"""
857 """flush ignored control replies"""
858 while self._ignored_control_replies > 0:
858 while self._ignored_control_replies > 0:
859 self.session.recv(self._control_socket)
859 self.session.recv(self._control_socket)
860 self._ignored_control_replies -= 1
860 self._ignored_control_replies -= 1
861
861
def _flush_ignored_hub_replies(self):
    """Read and discard every message waiting on the query socket."""
    while True:
        _ident, msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        if msg is None:
            break
866
866
def _flush_iopub(self, sock):
    """Flush replies from the iopub channel waiting
    in the ZMQ queue.

    Messages are read from `sock` without blocking until none is left.
    Messages whose parent header is empty or belongs to another session
    are skipped, as are 'status' messages for msg_ids without metadata.
    All others update ``self.metadata[msg_id]``:

    - 'stream': text appended to the entry named by the stream
    - 'error': unwrapped exception stored as ``error``
    - 'execute_input': code stored as ``execute_input``
    - 'display_data': content appended to ``outputs``
    - 'execute_result': content stored as ``execute_result``
    - 'data_message': deserialized data merged into ``data``
    - 'status': ``outputs_ready`` set once the state is 'idle'
    """
    while True:
        _idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        if msg is None:
            break
        if self.debug:
            pprint(msg)
        parent = msg['parent_header']
        if not parent or parent['session'] != self.session.session:
            # ignore IOPub messages not from here
            continue
        msg_id = parent['msg_id']
        content = msg['content']
        msg_type = msg['header']['msg_type']

        if msg_type == 'status' and msg_id not in self.metadata:
            # ignore status messages if they aren't mine
            continue

        # init metadata:
        md = self.metadata[msg_id]

        if msg_type == 'stream':
            name = content['name']
            s = md[name] or ''
            md[name] = s + content['text']
        elif msg_type == 'error':
            md.update({'error' : self._unwrap_exception(content)})
        elif msg_type == 'execute_input':
            md.update({'execute_input' : content['code']})
        elif msg_type == 'display_data':
            md['outputs'].append(content)
        elif msg_type == 'execute_result':
            md['execute_result'] = content
        elif msg_type == 'data_message':
            data, _remainder = serialize.deserialize_object(msg['buffers'])
            md['data'].update(data)
        elif msg_type == 'status':
            # idle message comes after all outputs
            if content['execution_state'] == 'idle':
                md['outputs_ready'] = True
        else:
            # any other msg_type is deliberately ignored
            pass

        # redundant?
        self.metadata[msg_id] = md
920
920
921 #--------------------------------------------------------------------------
921 #--------------------------------------------------------------------------
922 # len, getitem
922 # len, getitem
923 #--------------------------------------------------------------------------
923 #--------------------------------------------------------------------------
924
924
925 def __len__(self):
925 def __len__(self):
926 """len(client) returns # of engines."""
926 """len(client) returns # of engines."""
927 return len(self.ids)
927 return len(self.ids)
928
928
def __getitem__(self, key):
    """Return the DirectView built by ``self.direct_view(key)``.

    `key` must be an int, slice, tuple, list or xrange; any other type
    raises TypeError.
    """
    if isinstance(key, (int, slice, tuple, list, xrange)):
        return self.direct_view(key)
    raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
937
937
938 def __iter__(self):
938 def __iter__(self):
939 """Since we define getitem, Client is iterable
939 """Since we define getitem, Client is iterable
940
940
941 but unless we also define __iter__, it won't work correctly unless engine IDs
941 but unless we also define __iter__, it won't work correctly unless engine IDs
942 start at zero and are continuous.
942 start at zero and are continuous.
943 """
943 """
944 for eid in self.ids:
944 for eid in self.ids:
945 yield self.direct_view(eid)
945 yield self.direct_view(eid)
946
946
947 #--------------------------------------------------------------------------
947 #--------------------------------------------------------------------------
948 # Begin public methods
948 # Begin public methods
949 #--------------------------------------------------------------------------
949 #--------------------------------------------------------------------------
950
950
@property
def ids(self):
    """Engine ids, refreshed from pending notifications on every access.

    Pending notifications are flushed first; the value returned is a new
    list, so callers cannot modify the client's own ``_ids``.
    """
    self._flush_notifications()
    snapshot = list(self._ids)
    return snapshot
957
957
def activate(self, targets='all', suffix=''):
    """Create a blocking DirectView and register it with IPython magics

    Defines the magics `%px, %autopx, %pxresult, %%px`

    Parameters
    ----------

    targets: int, list of ints, or 'all'
        The engines on which the view's magics will run
    suffix: str [default: '']
        The suffix, if any, for the magics.  This allows you to have
        multiple views associated with parallel magics at the same time.

        e.g. ``rc.activate(targets=0, suffix='0')`` will give you
        the magics ``%px0``, ``%pxresult0``, etc. for running magics just
        on engine 0.

    Returns
    -------

    The DirectView that was created, with ``block`` set to True.
    """
    dview = self.direct_view(targets)
    dview.block = True
    dview.activate(suffix)
    return dview
980
980
def close(self, linger=None):
    """Close my zmq Sockets

    Does nothing if the client is already closed.  Otherwise the spin
    thread is stopped and every trait whose name ends in "socket" that
    holds an open socket is closed.

    If `linger` is not None it is passed on to each socket's ``close`` to
    set the zmq LINGER socket option, which allows discarding of messages.
    """
    if self._closed:
        return
    self.stop_spin_thread()
    socket_names = [name for name in self.trait_names() if name.endswith("socket")]
    close_kwargs = {} if linger is None else {'linger': linger}
    for name in socket_names:
        sock = getattr(self, name)
        if sock is None or sock.closed:
            continue
        sock.close(**close_kwargs)
    self._closed = True
999
999
1000 def _spin_every(self, interval=1):
1000 def _spin_every(self, interval=1):
1001 """target func for use in spin_thread"""
1001 """target func for use in spin_thread"""
1002 while True:
1002 while True:
1003 if self._stop_spinning.is_set():
1003 if self._stop_spinning.is_set():
1004 return
1004 return
1005 time.sleep(interval)
1005 time.sleep(interval)
1006 self.spin()
1006 self.spin()
1007
1007
1008 def spin_thread(self, interval=1):
1008 def spin_thread(self, interval=1):
1009 """call Client.spin() in a background thread on some regular interval
1009 """call Client.spin() in a background thread on some regular interval
1010
1010
1011 This helps ensure that messages don't pile up too much in the zmq queue
1011 This helps ensure that messages don't pile up too much in the zmq queue
1012 while you are working on other things, or just leaving an idle terminal.
1012 while you are working on other things, or just leaving an idle terminal.
1013
1013
1014 It also helps limit potential padding of the `received` timestamp
1014 It also helps limit potential padding of the `received` timestamp
1015 on AsyncResult objects, used for timings.
1015 on AsyncResult objects, used for timings.
1016
1016
1017 Parameters
1017 Parameters
1018 ----------
1018 ----------
1019
1019
1020 interval : float, optional
1020 interval : float, optional
1021 The interval on which to spin the client in the background thread
1021 The interval on which to spin the client in the background thread
1022 (simply passed to time.sleep).
1022 (simply passed to time.sleep).
1023
1023
1024 Notes
1024 Notes
1025 -----
1025 -----
1026
1026
1027 For precision timing, you may want to use this method to put a bound
1027 For precision timing, you may want to use this method to put a bound
1028 on the jitter (in seconds) in `received` timestamps used
1028 on the jitter (in seconds) in `received` timestamps used
1029 in AsyncResult.wall_time.
1029 in AsyncResult.wall_time.
1030
1030
1031 """
1031 """
1032 if self._spin_thread is not None:
1032 if self._spin_thread is not None:
1033 self.stop_spin_thread()
1033 self.stop_spin_thread()
1034 self._stop_spinning.clear()
1034 self._stop_spinning.clear()
1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1035 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1036 self._spin_thread.daemon = True
1036 self._spin_thread.daemon = True
1037 self._spin_thread.start()
1037 self._spin_thread.start()
1038
1038
def stop_spin_thread(self):
    """Stop the background spin thread, if one exists.

    Sets the ``_stop_spinning`` event, joins the thread and resets
    ``_spin_thread`` to None.
    """
    if self._spin_thread is None:
        return
    self._stop_spinning.set()
    self._spin_thread.join()
    self._spin_thread = None
1045
1045
def spin(self):
    """Flush any registration notifications and execution results
    waiting in the ZMQ queue.

    Sockets are visited in a fixed order - notification, iopub, mux, task,
    control, query - and any socket that is not set is skipped.
    """
    if self._notification_socket:
        self._flush_notifications()

    # (socket attribute, flush method taking that socket), in flush order
    per_socket = (
        ('_iopub_socket', '_flush_iopub'),
        ('_mux_socket', '_flush_results'),
        ('_task_socket', '_flush_results'),
        ('_control_socket', '_flush_control'),
    )
    for sock_attr, flush_name in per_socket:
        if getattr(self, sock_attr):
            getattr(self, flush_name)(getattr(self, sock_attr))

    if self._query_socket:
        self._flush_ignored_hub_replies()
1062
1062
def wait(self, jobs=None, timeout=-1):
    """Block until the given `jobs` are done, or `timeout` seconds pass.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
        ints are indices to self.history
        strs are msg_ids
        default: wait on all outstanding messages
    timeout : float
        a time in seconds, after which to give up.
        default is -1, which means no timeout

    Returns
    -------

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    started = time.time()
    if jobs is None:
        theids = self.outstanding
    else:
        if isinstance(jobs, string_types + (int, AsyncResult)):
            jobs = [jobs]
        theids = set()
        for job in jobs:
            if isinstance(job, int):
                # index access
                theids.add(self.history[job])
            elif isinstance(job, AsyncResult):
                theids.update(job.msg_ids)
            else:
                theids.add(job)

    if not theids.intersection(self.outstanding):
        return True

    self.spin()
    while theids.intersection(self.outstanding):
        timed_out = timeout >= 0 and (time.time() - started) > timeout
        if timed_out:
            break
        # short pause between polls of the queues
        time.sleep(1e-3)
        self.spin()
    return not theids.intersection(self.outstanding)
1107
1107
1108 #--------------------------------------------------------------------------
1108 #--------------------------------------------------------------------------
1109 # Control methods
1109 # Control methods
1110 #--------------------------------------------------------------------------
1110 #--------------------------------------------------------------------------
1111
1111
@spin_first
def clear(self, targets=None, block=None):
    """Clear the namespace in target(s).

    Sends a 'clear_request' on the control socket to every target.  When
    blocking, one reply per target is received and the exception from the
    last reply whose status is not 'ok' is raised; otherwise the replies
    are counted as ignored control replies.

    Parameters
    ----------

    targets :
        passed to ``_build_targets`` to select the engines
    block : bool [default: self.block]
        whether to wait for the replies
    """
    if block is None:
        block = self.block
    targets = self._build_targets(targets)[0]
    for target in targets:
        self.session.send(self._control_socket, 'clear_request', content={}, ident=target)

    failure = False
    if not block:
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            _idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            reply = msg['content']
            if reply['status'] != 'ok':
                failure = self._unwrap_exception(reply)
    if failure:
        raise failure
1132
1132
1133
1133
@spin_first
def abort(self, jobs=None, targets=None, block=None):
    """Abort specific jobs from the execution queues of target(s).

    This is a mechanism to prevent jobs that have already been submitted
    from executing.

    Parameters
    ----------

    jobs : msg_id, list of msg_ids, or AsyncResult
        The jobs to be aborted

        If unspecified/None: abort all outstanding jobs.
    targets :
        passed to ``_build_targets`` to select the engines
    block : bool [default: self.block]
        whether to wait for the replies; when blocking, the exception from
        the last reply whose status is not 'ok' is raised

    Raises
    ------

    TypeError : if any job is neither a string nor an AsyncResult
    """
    if block is None:
        block = self.block
    if jobs is None:
        jobs = list(self.outstanding)
    targets = self._build_targets(targets)[0]

    if isinstance(jobs, string_types + (AsyncResult,)):
        jobs = [jobs]
    bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
    if bad_ids:
        raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])

    # flatten AsyncResults into their msg_ids
    msg_ids = []
    for job in jobs:
        if isinstance(job, AsyncResult):
            msg_ids.extend(job.msg_ids)
        else:
            msg_ids.append(job)

    content = dict(msg_ids=msg_ids)
    for target in targets:
        self.session.send(self._control_socket, 'abort_request',
                content=content, ident=target)

    failure = False
    if not block:
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        for _ in range(len(targets)):
            _idents, msg = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(msg)
            reply = msg['content']
            if reply['status'] != 'ok':
                failure = self._unwrap_exception(reply)
    if failure:
        raise failure
1182
1182
@spin_first
def shutdown(self, targets='all', restart=False, hub=False, block=None):
    """Shut down one or more engine processes and, optionally, the Hub.

    Parameters
    ----------

    targets: list of ints or 'all' [default: all]
        Which engines to shut down.  Forced to 'all' when `hub` is true.
    hub: bool [default: False]
        Whether to shut down the Hub as well, after the engines.
    block: bool [default: self.block]
        Whether to wait for the engines' shutdown replies.  Replies are
        always awaited when `hub` is true.
    restart: bool [default: False]
        NOT IMPLEMENTED; any true value raises NotImplementedError.

    Raises
    ------

    NotImplementedError
        If `restart` is true.
    Exception
        The unwrapped exception of the most recent reply whose status was
        not 'ok', raised after all replies have been read.
    """
    from ipython_parallel.error import NoEnginesRegistered
    if restart:
        raise NotImplementedError("Engine restart is not yet implemented")

    if block is None:
        block = self.block
    if hub:
        targets = 'all'
    try:
        targets = self._build_targets(targets)[0]
    except NoEnginesRegistered:
        # no engines to address; a Hub shutdown (if requested) still proceeds
        targets = []

    for engine in targets:
        self.session.send(self._control_socket, 'shutdown_request',
                          content={'restart': restart}, ident=engine)

    failure = False
    if not (block or hub):
        # not waiting: remember how many control replies to discard later
        self._ignored_control_replies += len(targets)
    else:
        self._flush_ignored_control()
        # exactly one reply is read per request sent above
        for _ in targets:
            _idents, reply = self.session.recv(self._control_socket, 0)
            if self.debug:
                pprint(reply)
            reply_content = reply['content']
            if reply_content['status'] != 'ok':
                failure = self._unwrap_exception(reply_content)

    if hub:
        # NOTE(review): the pause presumably lets engine shutdowns settle
        # before the Hub is asked to stop -- confirm
        time.sleep(0.25)
        self.session.send(self._query_socket, 'shutdown_request')
        _idents, reply = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(reply)
        reply_content = reply['content']
        if reply_content['status'] != 'ok':
            failure = self._unwrap_exception(reply_content)

    if failure:
        raise failure
1237
1237
1238 #--------------------------------------------------------------------------
1238 #--------------------------------------------------------------------------
1239 # Execution related methods
1239 # Execution related methods
1240 #--------------------------------------------------------------------------
1240 #--------------------------------------------------------------------------
1241
1241
def _maybe_raise(self, result):
    """Return `result` unchanged, unless it is a RemoteError, which is raised."""
    if not isinstance(result, error.RemoteError):
        return result
    raise result
1248
1248
def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
                       ident=None):
    """Build an "apply_request" message for `f` and send it over `socket`.

    This is the principal method with which all engine execution is performed by views.

    Parameters
    ----------

    socket
        The socket handed to `self.session.send`.
    f : callable or Reference
        The function to apply.
    args : tuple or list [default: empty list]
        Positional arguments for `f`.
    kwargs : dict [default: empty dict]
        Keyword arguments for `f`.
    metadata : dict [default: empty dict]
        Sent as the message metadata.
    track : bool [default: False]
        Forwarded to `self.session.send`.
    ident
        Forwarded to `self.session.send`; when it (or its last element, if
        it is a list) is a known engine ident, the msg_id is also recorded
        against that engine.

    Returns
    -------

    The message returned by `self.session.send`.

    Raises
    ------

    RuntimeError
        If the client has been closed.
    TypeError
        If `f`, `args`, `kwargs` or `metadata` has the wrong type.
    """
    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # fill in defaults
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if metadata is None:
        metadata = {}

    # validate arguments, in the order f, args, kwargs, metadata
    if not (callable(f) or isinstance(f, Reference)):
        raise TypeError("f must be callable, not %s"%type(f))
    for value, expected, complaint in (
        (args, (tuple, list), "args must be tuple or list, not %s"),
        (kwargs, dict, "kwargs must be dict, not %s"),
        (metadata, dict, "metadata must be dict, not %s"),
    ):
        if not isinstance(value, expected):
            raise TypeError(complaint%type(value))

    session = self.session
    bufs = serialize.pack_apply_message(
        f, args, kwargs,
        buffer_threshold=session.buffer_threshold,
        item_threshold=session.item_threshold,
    )
    msg = session.send(socket, "apply_request", buffers=bufs, ident=ident,
                       metadata=metadata, track=track)

    # client-side bookkeeping for the new request
    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine
        engine_ident = ident[-1] if isinstance(ident, list) else ident
        if engine_ident in self._engines.values():
            # save for later, in case of engine death
            self._outstanding_dict[engine_ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1295
1295
def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
    """Build an "execute_request" message for `code` and send it over `socket`.

    Parameters
    ----------

    socket
        The socket handed to `self.session.send`.
    code : str
        The source text to execute.
    silent : bool [default: True]
        Coerced with `bool()` and placed in the message content.
    metadata : dict [default: empty dict]
        Sent as the message metadata.
    ident
        Forwarded to `self.session.send`; when it (or its last element, if
        it is a list) is a known engine ident, the msg_id is also recorded
        against that engine.

    Returns
    -------

    The message returned by `self.session.send`.

    Raises
    ------

    RuntimeError
        If the client has been closed.
    TypeError
        If `code` is not text or `metadata` is not a dict.
    """
    if self._closed:
        raise RuntimeError("Client cannot be used after its sockets have been closed")

    # fill in defaults
    if metadata is None:
        metadata = {}

    # validate arguments
    if not isinstance(code, string_types):
        raise TypeError("code must be text, not %s" % type(code))
    if not isinstance(metadata, dict):
        raise TypeError("metadata must be dict, not %s" % type(metadata))

    request = dict(code=code, silent=bool(silent), user_expressions={})
    msg = self.session.send(socket, "execute_request", content=request, ident=ident,
                            metadata=metadata)

    # client-side bookkeeping for the new request
    msg_id = msg['header']['msg_id']
    self.outstanding.add(msg_id)
    if ident:
        # possibly routed to a specific engine
        engine_ident = ident[-1] if isinstance(ident, list) else ident
        if engine_ident in self._engines.values():
            # save for later, in case of engine death
            self._outstanding_dict[engine_ident].add(msg_id)
    self.history.append(msg_id)
    self.metadata[msg_id]['submitted'] = datetime.now()

    return msg
1332
1332
1333 #--------------------------------------------------------------------------
1333 #--------------------------------------------------------------------------
1334 # construct a View object
1334 # construct a View object
1335 #--------------------------------------------------------------------------
1335 #--------------------------------------------------------------------------
1336
1336
def load_balanced_view(self, targets=None):
    """construct a LoadBalancedView object.

    If no arguments are specified (or targets is 'all'), the view is
    created with targets=None, i.e. it is not restricted to a subset
    of engines.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The subset of engines across which to load-balance
    """
    if targets is None or targets == 'all':
        engine_ids = None
    else:
        engine_ids = self._build_targets(targets)[1]
    return LoadBalancedView(client=self, socket=self._task_socket, targets=engine_ids)
1354
1354
def direct_view(self, targets='all'):
    """construct a DirectView object.

    If no targets are specified, create a DirectView using all engines.

    rc.direct_view('all') is distinguished from rc[:] in that 'all' will
    evaluate the target engines at each execution, whereas rc[:] will connect to
    all *current* engines, and that list will not change.

    That is, 'all' will always use all engines, whereas rc[:] will not use
    engines added after the DirectView is constructed.

    Parameters
    ----------

    targets: list,slice,int,etc. [default: use all engines]
        The engines to use for the View.  A single int yields a view on
        that one engine rather than on a one-element list.
    """
    wants_single_engine = isinstance(targets, int)
    # 'all' is passed through untouched so it is lazily evaluated at each execution
    if targets != 'all':
        engine_ids = self._build_targets(targets)[1]
        targets = engine_ids[0] if wants_single_engine else engine_ids
    return DirectView(client=self, socket=self._mux_socket, targets=targets)
1380
1380
1381 #--------------------------------------------------------------------------
1381 #--------------------------------------------------------------------------
1382 # Query methods
1382 # Query methods
1383 #--------------------------------------------------------------------------
1383 #--------------------------------------------------------------------------
1384
1384
@spin_first
def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
    """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

    If the client already has the results, no request to the Hub will be made.

    This is a convenient way to construct AsyncResult objects, which are wrappers
    that include metadata about execution, and allow for awaiting results that
    were not submitted by this Client.

    It can also be a convenient way to retrieve the metadata associated with
    blocking execution, since an AsyncResult is always returned.

    Examples
    --------
    ::

        In [10]: r = client.apply()

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The history indices or msg_ids of the results to be retrieved.
        Defaults to -1, the most recent entry in the history.

    block : bool
        Whether to wait for the result to be done [default: self.block]
    owner : bool [default: True]
        Whether this AsyncResult should own the result.
        If so, calling `ar.get()` will remove data from the
        client's result and metadata cache.
        There should only be one owner of any given msg_id.

    Returns
    -------

    AsyncResult
        A single AsyncResult object will always be returned.

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub;
        returned when at least one requested msg_id is neither outstanding
        nor already in the local results cache.

    Raises
    ------

    TypeError
        If an entry is neither an int nor a str.
    """
    block = self.block if block is None else block
    if indices_or_msg_ids is None:
        indices_or_msg_ids = -1

    single_result = not isinstance(indices_or_msg_ids, (list, tuple))
    if single_result:
        indices_or_msg_ids = [indices_or_msg_ids]

    theids = []
    for index_or_id in indices_or_msg_ids:
        if isinstance(index_or_id, int):
            # integers are indices into the submission history
            index_or_id = self.history[index_or_id]
        if not isinstance(index_or_id, string_types):
            raise TypeError("indices must be str or int, not %r"%index_or_id)
        theids.append(index_or_id)

    # The Hub is only needed when some id is unknown locally; test that
    # directly instead of building local/remote lists by repeated scans.
    needs_hub = any(
        msg_id not in self.outstanding and msg_id not in self.results
        for msg_id in theids
    )

    # given a single msg_id initially, get_result should get the result itself,
    # not a length-one list
    if single_result:
        theids = theids[0]

    result_class = AsyncHubResult if needs_hub else AsyncResult
    ar = result_class(self, msg_ids=theids, owner=owner)

    if block:
        ar.wait()

    return ar
1462
1462
@spin_first
def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
    """Resubmit one or more tasks.

    in-flight tasks may not be resubmitted.

    Parameters
    ----------

    indices_or_msg_ids : integer history index, str msg_id, or list of either
        The history indices or msg_ids of the tasks to resubmit.
        Defaults to -1, the most recent entry in the history.

    metadata
        Accepted but not used by this method.

    block : bool
        Whether to wait for the result to be done [default: self.block]

    Returns
    -------

    AsyncHubResult
        A subclass of AsyncResult that retrieves results from the Hub,
        built from the new msg_ids reported by the Hub.

    Raises
    ------

    TypeError
        If an entry is neither an int nor a str.
    """
    wait = self.block if block is None else block
    requested = -1 if indices_or_msg_ids is None else indices_or_msg_ids
    if not isinstance(requested, (list, tuple)):
        requested = [requested]

    theids = []
    for entry in requested:
        # integers are indices into the submission history
        msg_id = self.history[entry] if isinstance(entry, int) else entry
        if not isinstance(msg_id, string_types):
            raise TypeError("indices must be str or int, not %r"%msg_id)
        theids.append(msg_id)

    self.session.send(self._query_socket, 'resubmit_request', dict(msg_ids=theids))

    # wait for the query socket to become readable, then take the reply
    zmq.select([self._query_socket], [], [])
    _idents, msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
    if self.debug:
        pprint(msg)
    reply = msg['content']
    if reply['status'] != 'ok':
        raise self._unwrap_exception(reply)

    # the reply maps each original msg_id to the id of its resubmission
    resubmitted = reply['resubmitted']
    ar = AsyncHubResult(self, msg_ids=[resubmitted[old_id] for old_id in theids])

    if wait:
        ar.wait()

    return ar
1520
1520
@spin_first
def result_status(self, msg_ids, status_only=True):
    """Check on the status of the result(s) of the apply request with `msg_ids`.

    If status_only is False, then the actual results will be retrieved, else
    only the status of the results will be checked.

    Results already present in the client's cache are reported from the
    cache; only the remaining msg_ids are requested from the Hub.

    Parameters
    ----------

    msg_ids : list of msg_ids
        if int:
            Passed as index to self.history for convenience.
    status_only : bool (default: True)
        if False:
            Retrieve the actual results of completed tasks.

    Returns
    -------

    results : dict
        There will always be the keys 'pending' and 'completed', which will
        be lists of msg_ids that are incomplete or complete. If `status_only`
        is False, then completed results will be keyed by their `msg_id`.

    Raises
    ------

    TypeError
        If an entry of `msg_ids` is neither an int nor a str.
    Exception
        The unwrapped Hub error if the reply status is not 'ok'; and, when
        `status_only` is False, the failure(s) of the retrieved tasks.
    """
    if not isinstance(msg_ids, (list,tuple)):
        msg_ids = [msg_ids]

    requested = []
    for msg_id in msg_ids:
        if isinstance(msg_id, int):
            msg_id = self.history[msg_id]
        if not isinstance(msg_id, string_types):
            raise TypeError("msg_ids must be str, not %r"%msg_id)
        requested.append(msg_id)

    completed = []
    local_results = {}
    theids = []

    # Local shortcut: split the request into cached and not-yet-cached ids.
    # A separate list is built rather than removing from the list being
    # iterated, which would skip the element following every cached id.
    # (To temporarily disable the shortcut, set `theids = requested`.)
    for msg_id in requested:
        if msg_id in self.results:
            completed.append(msg_id)
            local_results[msg_id] = self.results[msg_id]
        else:
            theids.append(msg_id)

    if theids: # some not locally cached
        content = dict(msg_ids=theids, status_only=status_only)
        msg = self.session.send(self._query_socket, "result_request", content=content)
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        buffers = msg['buffers']
    else:
        content = dict(completed=[],pending=[])

    content['completed'].extend(completed)

    if status_only:
        return content

    failures = []
    # load cached results into result:
    content.update(local_results)

    # update cache with results:
    # (`buffers` is only bound when theids is non-empty, which is also the
    # only case in which this loop runs)
    for msg_id in sorted(theids):
        if msg_id in content['completed']:
            rec = content[msg_id]
            parent = extract_dates(rec['header'])
            header = extract_dates(rec['result_header'])
            rcontent = rec['result_content']
            iodict = rec['io']
            if isinstance(rcontent, str):
                rcontent = self.session.unpack(rcontent)

            md = self.metadata[msg_id]
            md_msg = dict(
                content=rcontent,
                parent_header=parent,
                header=header,
                metadata=rec['result_metadata'],
            )
            md.update(self._extract_metadata(md_msg))
            if rec.get('received'):
                md['received'] = parse_date(rec['received'])
            md.update(iodict)

            if rcontent['status'] == 'ok':
                if header['msg_type'] == 'apply_reply':
                    # each apply result consumes its share of the buffers;
                    # the remainder is kept for the following records
                    res,buffers = serialize.deserialize_object(buffers)
                elif header['msg_type'] == 'execute_reply':
                    res = ExecuteReply(msg_id, rcontent, md)
                else:
                    raise KeyError("unhandled msg type: %r" % header['msg_type'])
            else:
                res = self._unwrap_exception(rcontent)
                failures.append(res)

            self.results[msg_id] = res
            content[msg_id] = res

    # a lone failed task is raised as-is rather than wrapped
    if len(theids) == 1 and failures:
        raise failures[0]

    error.collect_exceptions(failures, "result_status")
    return content
1632
1632
@spin_first
def queue_status(self, targets='all', verbose=False):
    """Fetch the status of engine queues.

    Parameters
    ----------

    targets : int/str/list of ints/strs
        the engines whose states are to be queried.
        default : all
    verbose : bool
        Whether to return lengths only, or lists of ids for each element

    Returns
    -------

    The re-keyed reply content; when `targets` is a single int, only the
    entry for that engine.

    Raises
    ------

    Exception
        The unwrapped error if the reply status is not 'ok'.
    """
    # 'all' is sent as targets=None so that it is resolved by the receiver
    engine_ids = None if targets == 'all' else self._build_targets(targets)[1]
    request = dict(targets=engine_ids, verbose=verbose)
    self.session.send(self._query_socket, "queue_request", content=request)
    _idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    reply = msg['content']
    # 'status' is popped so it does not appear among the returned queues
    if reply.pop('status') != 'ok':
        raise self._unwrap_exception(reply)
    queues = rekey(reply)
    return queues[targets] if isinstance(targets, int) else queues
1665
1665
def _build_msgids_from_target(self, targets=None):
    """Build a list of msg_ids from the list of engine targets.

    Returns the msg_ids in `self.metadata` whose "engine_uuid" is one of
    the identities resolved from `targets`; an empty list when `targets`
    is empty or None.
    """
    # an explicit guard is needed: _build_targets would otherwise use all engines
    if not targets:
        return []
    engine_uuids = self._build_targets(targets)[0]
    matching = []
    for msg_id in self.metadata:
        if self.metadata[msg_id]["engine_uuid"] in engine_uuids:
            matching.append(msg_id)
    return matching
1672
1672
def _build_msgids_from_jobs(self, jobs=None):
    """Build a list of msg_ids from "jobs".

    `jobs` may be a single msg_id string, a single AsyncResult, or a
    collection of either.  Strings are taken as msg_ids; AsyncResults
    contribute all of their msg_ids.  Empty or missing `jobs` yields [].

    Raises TypeError if any entry is neither a str nor an AsyncResult.
    """
    if not jobs:
        return []
    accepted = string_types + (AsyncResult,)
    if isinstance(jobs, accepted):
        jobs = [jobs]
    # reject the first entry of an unsupported type
    for obj in jobs:
        if not isinstance(obj, accepted):
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%obj)
    msg_ids = []
    for job in jobs:
        if isinstance(job, AsyncResult):
            msg_ids.extend(job.msg_ids)
        else:
            msg_ids.append(job)
    return msg_ids
1689
1689
def purge_local_results(self, jobs=None, targets=None):
    """Clears the client caches of results and their metadata.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_local_results('all')` to scrub everything from the Client's
    results and metadata caches.

    After this call all `AsyncResults` are invalid and should be discarded.

    If you must "reget" the results, you can still do so by using
    `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
    redownload the results from the hub if they are still available
    (i.e. `client.purge_hub_results(...)` has not been called).

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
        the msg_ids whose results should be purged.
    targets : int/list of ints
        The engines, by integer ID, whose entire result histories are to be purged.

    Raises
    ------

    ValueError : if neither `jobs` nor `targets` is given.
    RuntimeError : if any of the tasks to be purged are still outstanding.

    """
    # A bare engine id is wrapped in a list so that engine 0, which is falsy,
    # is not mistaken for "no targets" here or by the msg_id lookup below.
    if isinstance(targets, int) and not isinstance(targets, bool):
        targets = [targets]

    if not targets and not jobs:
        raise ValueError("Must specify at least one of `targets` and `jobs`")

    if jobs == 'all':
        if self.outstanding:
            raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
        self.results.clear()
        self.metadata.clear()
    else:
        msg_ids = set()
        msg_ids.update(self._build_msgids_from_target(targets))
        msg_ids.update(self._build_msgids_from_jobs(jobs))
        # refuse the whole request rather than purging only part of it
        still_outstanding = self.outstanding.intersection(msg_ids)
        if still_outstanding:
            raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
        for mid in msg_ids:
            self.results.pop(mid, None)
            self.metadata.pop(mid, None)
1738
1738
1739
1739
@spin_first
def purge_hub_results(self, jobs=None, targets=None):
    """Tell the Hub to forget results.

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_hub_results('all')` to scrub everything from the Hub's db.

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
        the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs
        The targets, by int_id, whose entire history is to be purged.

    Raises
    ------

    ValueError : if neither `jobs` nor `targets` is given.
    Exception : whatever `self._unwrap_exception` builds from a reply whose
        status is not 'ok'.
    """
    if targets is None:
        # keep sending an empty list, as the previous ``[]`` default did
        targets = []
    elif isinstance(targets, int) and not isinstance(targets, bool):
        # engine 0 is falsy; wrap bare ids so it still counts as a target
        targets = [targets]

    if not targets and not jobs:
        raise ValueError("Must specify at least one of `targets` and `jobs`")
    if targets:
        targets = self._build_targets(targets)[1]

    # construct msg_ids from jobs
    if jobs == 'all':
        msg_ids = jobs
    else:
        msg_ids = self._build_msgids_from_jobs(jobs)

    content = dict(engine_ids=targets, msg_ids=msg_ids)
    self.session.send(self._query_socket, "purge_request", content=content)
    idents, msg = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(msg)
    content = msg['content']
    if content['status'] != 'ok':
        raise self._unwrap_exception(content)
1778
1778
def purge_results(self, jobs=[], targets=[]):
    """Clears the cached results from both the hub and the local client

    Individual results can be purged by msg_id, or the entire
    history of specific targets can be purged.

    Use `purge_results('all')` to scrub every cached result from both the
    Hub's and the Client's db.

    Equivalent to calling `purge_local_results()` and then
    `purge_hub_results()` with the same arguments.

    Parameters
    ----------

    jobs : str or list of str or AsyncResult objects
        the msg_ids whose results should be forgotten.
    targets : int/str/list of ints/strs
        The targets, by int_id, whose entire history is to be purged.
    """
    # local cache first, then the hub
    purge_args = dict(jobs=jobs, targets=targets)
    self.purge_local_results(**purge_args)
    self.purge_hub_results(**purge_args)
1803
1803
def purge_everything(self):
    """Clears all content from previous Tasks from both the hub and the local client

    Calls `purge_results("all")`, then replaces `self.history` with a new
    empty list and empties the session's `digest_history`.
    """
    self.purge_results("all")

    # bookkeeping that purge_results does not touch
    self.history = []
    digests = self.session.digest_history
    digests.clear()
1813
1813
@spin_first
def hub_history(self):
    """Get the Hub's history

    Just like the Client, the Hub has a history, which is a list of msg_ids.
    This will contain the history of all clients, and, depending on configuration,
    may contain history across multiple cluster sessions.

    Any msg_id returned here is a valid argument to `get_result`.

    Returns
    -------

    msg_ids : list of strs
        list of all msg_ids, ordered by task submission time.

    Raises
    ------

    Exception : whatever `self._unwrap_exception` builds from a reply whose
        status is not 'ok'.
    """
    self.session.send(self._query_socket, "history_request", content={})
    idents, reply = self.session.recv(self._query_socket, 0)

    if self.debug:
        pprint(reply)

    payload = reply['content']
    if payload['status'] == 'ok':
        return payload['history']
    raise self._unwrap_exception(payload)
1841
1841
@spin_first
def db_query(self, query, keys=None):
    """Query the Hub's TaskRecord database

    This will return a list of task record dicts that match `query`

    Parameters
    ----------

    query : mongodb query dict
        The search dict. See mongodb query docs for details.
    keys : list of strs [optional]
        The subset of keys to be returned.  A single string is wrapped in a list.
        The default is to fetch everything but buffers.
        'msg_id' will *always* be included.

    Returns
    -------

    records : list of dicts
        The records from the reply, with header and timestamp fields
        converted and any buffers re-attached.
    """
    if isinstance(keys, string_types):
        keys = [keys]

    request = dict(query=query, keys=keys)
    self.session.send(self._query_socket, "db_request", content=request)
    idents, reply = self.session.recv(self._query_socket, 0)
    if self.debug:
        pprint(reply)

    reply_content = reply['content']
    if reply_content['status'] != 'ok':
        raise self._unwrap_exception(reply_content)

    records = reply_content['records']
    buffer_lens = reply_content['buffer_lens']
    result_buffer_lens = reply_content['result_buffer_lens']
    # all buffers arrive as one flat list; each record consumes its share
    # from the front, request buffers before result buffers
    remaining = reply['buffers']

    for idx, record in enumerate(records):
        # unpack datetime objects
        for header_key in ('header', 'result_header'):
            if header_key in record:
                record[header_key] = extract_dates(record[header_key])
        for date_key in ('submitted', 'started', 'completed', 'received'):
            if date_key in record:
                record[date_key] = parse_date(record[date_key])

        # relink buffers
        if buffer_lens is not None:
            count = buffer_lens[idx]
            record['buffers'] = remaining[:count]
            remaining = remaining[count:]
        if result_buffer_lens is not None:
            count = result_buffer_lens[idx]
            record['result_buffers'] = remaining[:count]
            remaining = remaining[count:]

    return records
1892
1892
# Public API of this module.
__all__ = ['Client']
@@ -1,1125 +1,1125 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import imp
8 import imp
9 import sys
9 import sys
10 import warnings
10 import warnings
11 from contextlib import contextmanager
11 from contextlib import contextmanager
12 from types import ModuleType
12 from types import ModuleType
13
13
14 import zmq
14 import zmq
15
15
16 from IPython.testing.skipdoctest import skip_doctest
16 from IPython.testing.skipdoctest import skip_doctest
17 from IPython.utils import pickleutil
17 from IPython.utils import pickleutil
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 )
20 )
21 from decorator import decorator
21 from decorator import decorator
22
22
23 from IPython.parallel import util
23 from ipython_parallel import util
24 from IPython.parallel.controller.dependency import Dependency, dependent
24 from ipython_parallel.controller.dependency import Dependency, dependent
25 from IPython.utils.py3compat import string_types, iteritems, PY3
25 from IPython.utils.py3compat import string_types, iteritems, PY3
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncResult, AsyncMapResult
28 from .asyncresult import AsyncResult, AsyncMapResult
29 from .remotefunction import ParallelFunction, parallel, remote, getname
29 from .remotefunction import ParallelFunction, parallel, remote, getname
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Decorators
32 # Decorators
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 @decorator
@decorator
def save_ids(f, self, *args, **kwargs):
    """Keep our history and outstanding attributes up to date after a method call.

    Every msg_id the client's history gained while `f` ran is appended to
    `self.history` and added to `self.outstanding`, whether `f` returned
    or raised.
    """
    n_previous = len(self.client.history)
    try:
        ret = f(self, *args, **kwargs)
    finally:
        # Slice from the old end of the client's history.  Counting back
        # from the end (``history[-n:]``) yields the *whole* list when
        # n == 0, which would wrongly adopt every msg_id the client has.
        msg_ids = self.client.history[n_previous:]
        self.history.extend(msg_ids)
        self.outstanding.update(msg_ids)
    return ret
47
48 @decorator
@decorator
def sync_results(f, self, *args, **kwargs):
    """sync relevant results from self.client to our results attribute.

    Nested calls (made while `self._in_sync_results` is already set) run `f`
    directly; only the outermost call resets the flag and invokes
    `self._sync_results()`, which it does even when `f` raises.
    """
    if self._in_sync_results:
        # already inside a synced call: the outermost one will sync
        return f(self, *args, **kwargs)

    self._in_sync_results = True
    try:
        return f(self, *args, **kwargs)
    finally:
        self._in_sync_results = False
        self._sync_results()
60
61 @decorator
@decorator
def spin_after(f, self, *args, **kwargs):
    """call spin after the method.

    `self.spin()` only runs when `f` returns normally; the value returned
    by `f` is passed back unchanged.
    """
    result = f(self, *args, **kwargs)
    self.spin()
    return result
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72 @skip_doctest
72 @skip_doctest
73 class View(HasTraits):
73 class View(HasTraits):
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
75
75
76 Don't use this class, use subclasses.
76 Don't use this class, use subclasses.
77
77
78 Methods
78 Methods
79 -------
79 -------
80
80
81 spin
81 spin
82 flushes incoming results and registration state changes
82 flushes incoming results and registration state changes
83 control methods spin, and requesting `ids` also ensures up to date
83 control methods spin, and requesting `ids` also ensures up to date
84
84
85 wait
85 wait
86 wait on one or more msg_ids
86 wait on one or more msg_ids
87
87
88 execution methods
88 execution methods
89 apply
89 apply
90 legacy: execute, run
90 legacy: execute, run
91
91
92 data movement
92 data movement
93 push, pull, scatter, gather
93 push, pull, scatter, gather
94
94
95 query methods
95 query methods
96 get_result, queue_status, purge_results, result_status
96 get_result, queue_status, purge_results, result_status
97
97
98 control methods
98 control methods
99 abort, shutdown
99 abort, shutdown
100
100
101 """
101 """
102 # flags
102 # flags
103 block=Bool(False)
103 block=Bool(False)
104 track=Bool(True)
104 track=Bool(True)
105 targets = Any()
105 targets = Any()
106
106
107 history=List()
107 history=List()
108 outstanding = Set()
108 outstanding = Set()
109 results = Dict()
109 results = Dict()
110 client = Instance('IPython.parallel.Client')
110 client = Instance('ipython_parallel.Client')
111
111
112 _socket = Instance('zmq.Socket')
112 _socket = Instance('zmq.Socket')
113 _flag_names = List(['targets', 'block', 'track'])
113 _flag_names = List(['targets', 'block', 'track'])
114 _in_sync_results = Bool(False)
114 _in_sync_results = Bool(False)
115 _targets = Any()
115 _targets = Any()
116 _idents = Any()
116 _idents = Any()
117
117
def __init__(self, client=None, socket=None, **flags):
    """Set up a view on `client`.

    Parameters
    ----------

    client : the Client this view submits through
        Its `results` and `block` attributes become this view's initial values.
    socket : the socket stored as `_socket`
    **flags : applied through `set_flags` after the client defaults are taken.
    """
    super(View, self).__init__(client=client, _socket=socket)
    # client defaults first, so that explicit flags can override them
    self.results = client.results
    self.block = client.block
    self.set_flags(**flags)

    assert self.__class__ is not View, "Don't use base View objects, use subclasses"
126
126
def __repr__(self):
    """Return ``<ClassName targets>``, abbreviating long target lists."""
    shown = str(self.targets)
    # anything over 16 characters is cut to its first 12 plus '...]'
    too_long = len(shown) > 16
    shown = shown[:12]+'...]' if too_long else shown
    return "<%s %s>"%(self.__class__.__name__, shown)
132
132
def __len__(self):
    """Number of engines this view addresses.

    The length of `targets` when it is a list, 1 when it is a single int,
    and otherwise ``len(self.client)``.
    """
    if isinstance(self.targets, list):
        return len(self.targets)
    if isinstance(self.targets, int):
        return 1
    return len(self.client)
140
140
def set_flags(self, **kwargs):
    """set my attribute flags by keyword.

    Views determine behavior with a few attributes (`block`, `track`, etc.).
    These attributes can be set all at once by name with this method.
    Only names listed in `self._flag_names` are accepted.

    Parameters
    ----------

    block : bool
        whether to wait for results
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit after arrays and buffers during non-copying
        sends.

    Raises
    ------

    KeyError : on the first name that is not a known flag; flags processed
        before it have already been set.
    """
    for flag, setting in kwargs.items():
        if flag not in self._flag_names:
            raise KeyError("Invalid name: %r"%flag)
        setattr(self, flag, setting)
162
162
@contextmanager
def temp_flags(self, **kwargs):
    """temporarily set flags, for use in `with` statements.

    See set_flags for permanent setting of flags

    Every flag named in `self._flag_names` is saved on entry and restored
    on exit, including when the block raises or when `kwargs` itself is
    rejected by `set_flags`.

    Examples
    --------

    >>> view.track=False
    ...
    >>> with view.temp_flags(track=True):
    ...    ar = view.apply(dostuff, my_big_array)
    ...    ar.tracker.wait() # wait for send to finish
    >>> view.track
    False

    """
    # preflight: save flags
    saved_flags = {}
    for f in self._flag_names:
        saved_flags[f] = getattr(self, f)
    try:
        # set temporaries inside the try: set_flags may raise after having
        # applied some of them, and those must be rolled back as well
        self.set_flags(**kwargs)
        # yield to the with-statement block
        yield
    finally:
        # postflight: restore saved flags
        self.set_flags(**saved_flags)
192
192
193
193
194 #----------------------------------------------------------------
194 #----------------------------------------------------------------
195 # apply
195 # apply
196 #----------------------------------------------------------------
196 #----------------------------------------------------------------
197
197
def _sync_results(self):
    """to be called by @sync_results decorator

    after submitting any tasks.

    Drops from `self.outstanding` every msg_id that the client no longer
    lists as outstanding.
    """
    finished = self.outstanding.difference(self.client.outstanding)
    self.outstanding = self.outstanding.difference(finished)
206
206
@sync_results
@save_ids
def _really_apply(self, f, args, kwargs, block=None, **options):
    """wrapper for client.send_apply_request

    The base class provides no implementation; subclasses must override it.
    """
    raise NotImplementedError("Implement in subclasses")
212
212
def apply(self, f, *args, **kwargs):
    """calls ``f(*args, **kwargs)`` on remote engines, returning the result.

    This method sets all apply flags via this View's attributes.

    Returns :class:`~ipython_parallel.client.asyncresult.AsyncResult`
    instance if ``self.block`` is False, otherwise the return value of
    ``f(*args, **kwargs)``.
    """
    # no overrides passed: _really_apply falls back on the view's own flags
    outcome = self._really_apply(f, args, kwargs)
    return outcome
223
223
def apply_async(self, f, *args, **kwargs):
    """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.

    Returns :class:`~ipython_parallel.client.asyncresult.AsyncResult` instance.
    """
    # block=False overrides the view's own `block` flag for this call
    pending = self._really_apply(f, args, kwargs, block=False)
    return pending
230
230
@spin_after
def apply_sync(self, f, *args, **kwargs):
    """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
    returning the result.
    """
    # block=True overrides the view's own `block` flag for this call
    outcome = self._really_apply(f, args, kwargs, block=True)
    return outcome
237
237
238 #----------------------------------------------------------------
238 #----------------------------------------------------------------
239 # wrappers for client and control methods
239 # wrappers for client and control methods
240 #----------------------------------------------------------------
240 #----------------------------------------------------------------
@sync_results
def spin(self):
    """spin the client, and sync"""
    client = self.client
    client.spin()
245
245
@sync_results
def wait(self, jobs=None, timeout=-1):
    """waits on one or more `jobs`, for up to `timeout` seconds.

    Parameters
    ----------

    jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
        ints are indices to self.history
        strs are msg_ids
        default: this view's whole history is passed to the client
    timeout : float
        a time in seconds, after which to give up.
        default is -1, which means no timeout

    Returns
    -------

    The value of ``self.client.wait(jobs, timeout)``:

    True : when all msg_ids are done
    False : timeout reached, some msg_ids still outstanding
    """
    targets_of_wait = self.history if jobs is None else jobs
    return self.client.wait(targets_of_wait, timeout)
270
270
def abort(self, jobs=None, targets=None, block=None):
    """Abort jobs on my engines.

    Parameters
    ----------

    jobs : None, str, list of strs, optional
        if None: abort every msg_id in this view's `outstanding` set.
        else: abort specific msg_id(s).
    targets : optional
        defaults to this view's `targets`.
    block : optional
        defaults to this view's `block` flag.

    Returns
    -------

    Whatever ``self.client.abort(...)`` returns.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets
    if jobs is None:
        jobs = list(self.outstanding)

    return self.client.abort(jobs=jobs, targets=targets, block=block)
286
286
def queue_status(self, targets=None, verbose=False):
    """Fetch the Queue status of my engines

    `targets` defaults to this view's own targets; the request itself is
    made by ``self.client.queue_status``.
    """
    if targets is None:
        targets = self.targets
    return self.client.queue_status(targets=targets, verbose=verbose)
291
291
def purge_results(self, jobs=[], targets=[]):
    """Instruct the controller to forget specific results.

    `targets` given as None or 'all' is replaced by this view's own targets;
    everything else is handed to ``self.client.purge_results`` unchanged.
    """
    use_own_targets = targets is None or targets == 'all'
    if use_own_targets:
        targets = self.targets
    return self.client.purge_results(jobs=jobs, targets=targets)
297
297
def shutdown(self, targets=None, restart=False, hub=False, block=None):
    """Terminates one or more engine processes, optionally including the hub.

    `block` falls back to ``self.block`` when None, and `targets` falls
    back to ``self.targets`` when it is None or ``'all'``.  Everything is
    then forwarded to ``self.client.shutdown``.
    """
    if block is None:
        block = self.block
    if targets is None or targets == 'all':
        targets = self.targets
    return self.client.shutdown(
        targets=targets, restart=restart, hub=hub, block=block,
    )
305
305
@spin_after
def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
    """return one or more results, specified by history index or msg_id.

    Integer entries are looked up in ``self.history`` before the request
    is handed to the client; with no argument the last history entry
    (index -1) is used.

    See :meth:`ipython_parallel.client.client.Client.get_result` for details.
    """
    wanted = -1 if indices_or_msg_ids is None else indices_or_msg_ids
    if isinstance(wanted, int):
        wanted = self.history[wanted]
    elif isinstance(wanted, (list, tuple, set)):
        # translate integer entries to msg_ids, leave the rest untouched
        wanted = [
            self.history[entry] if isinstance(entry, int) else entry
            for entry in wanted
        ]
    return self.client.get_result(wanted, block=block, owner=owner)
323
323
324 #-------------------------------------------------------------------
324 #-------------------------------------------------------------------
325 # Map
325 # Map
326 #-------------------------------------------------------------------
326 #-------------------------------------------------------------------
327
327
@sync_results
def map(self, f, *sequences, **kwargs):
    """override in subclasses; this base version always raises."""
    raise NotImplementedError()
332
332
def map_async(self, f, *sequences, **kwargs):
    """Parallel version of builtin :func:`python:map`, using this view's engines.

    This is equivalent to ``map(...block=False)``.

    See `self.map` for details.

    Raises
    ------

    TypeError
        if a `block` keyword argument is supplied.
    """
    if 'block' in kwargs:
        raise TypeError("map_async doesn't take a `block` keyword argument.")
    # force the non-blocking flavour of map
    return self.map(f, *sequences, **dict(kwargs, block=False))
344
344
def map_sync(self, f, *sequences, **kwargs):
    """Parallel version of builtin :func:`python:map`, using this view's engines.

    This is equivalent to ``map(...block=True)``.

    See `self.map` for details.

    Raises
    ------

    TypeError
        if a `block` keyword argument is supplied.
    """
    if 'block' in kwargs:
        raise TypeError("map_sync doesn't take a `block` keyword argument.")
    # force the blocking flavour of map
    return self.map(f, *sequences, **dict(kwargs, block=True))
356
356
def imap(self, f, *sequences, **kwargs):
    """Parallel version of :func:`itertools.imap`.

    Returns an iterator over the result of ``self.map_async``.

    See `self.map` for details.

    """
    pending = self.map_async(f, *sequences, **kwargs)
    return iter(pending)
365
365
366 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
367 # Decorators
367 # Decorators
368 #-------------------------------------------------------------------
368 #-------------------------------------------------------------------
369
369
def remote(self, block=None, **flags):
    """Decorator for making a RemoteFunction

    `block` falls back to ``self.block`` when None; the work is delegated
    to the module-level ``remote`` with this view as its first argument.
    """
    if block is None:
        block = self.block
    return remote(self, block=block, **flags)
374
374
def parallel(self, dist='b', block=None, **flags):
    """Decorator for making a ParallelFunction

    `block` falls back to ``self.block`` when None; the work is delegated
    to the module-level ``parallel`` with this view as its first argument.
    """
    if block is None:
        block = self.block
    return parallel(self, dist=dist, block=block, **flags)
379
379
380 @skip_doctest
380 @skip_doctest
381 class DirectView(View):
381 class DirectView(View):
382 """Direct Multiplexer View of one or more engines.
382 """Direct Multiplexer View of one or more engines.
383
383
384 These are created via indexed access to a client:
384 These are created via indexed access to a client:
385
385
386 >>> dv_1 = client[1]
386 >>> dv_1 = client[1]
387 >>> dv_all = client[:]
387 >>> dv_all = client[:]
388 >>> dv_even = client[::2]
388 >>> dv_even = client[::2]
389 >>> dv_some = client[1:3]
389 >>> dv_some = client[1:3]
390
390
391 This object provides dictionary access to engine namespaces:
391 This object provides dictionary access to engine namespaces:
392
392
393 # push a=5:
393 # push a=5:
394 >>> dv['a'] = 5
394 >>> dv['a'] = 5
395 # pull 'foo':
395 # pull 'foo':
396 >>> dv['foo']
396 >>> dv['foo']
397
397
398 """
398 """
399
399
def __init__(self, client=None, socket=None, targets=None):
    """Initialise the view by forwarding every argument to the parent class."""
    parent = super(DirectView, self)
    parent.__init__(client=client, socket=socket, targets=targets)
402
402
@property
def importer(self):
    """sync_imports(local=True) as a property.

    Each access calls ``self.sync_imports(True)`` and returns its result.

    See sync_imports for details.

    """
    local = True
    return self.sync_imports(local)
411
411
@contextmanager
def sync_imports(self, local=True, quiet=False):
    """Context Manager for performing simultaneous local and remote imports.

    'import x as y' will *not* work.  The 'as y' part will simply be ignored.

    If `local=True`, then the package will also be imported locally.

    If `quiet=True`, no output will be produced when attempting remote
    imports.

    Note that remote-only (`local=False`) imports have not been implemented:
    an import statement inside the block raises NotImplementedError.

    While the block is active, the builtin ``__import__`` is replaced by a
    wrapper that also submits the import to the engines with
    ``self.apply_async``.  After the block exits, ``get()`` is called on
    every submitted result, so remote failures surface there.

    >>> with view.sync_imports():
    ...    from numpy import recarray
    importing recarray from numpy on engine(s)

    """
    from IPython.utils.py3compat import builtin_mod
    local_import = builtin_mod.__import__
    modules = set()
    results = []

    @util.interactive
    def remote_import(name, fromlist, level):
        """the function to be passed to apply, that actually performs the import
        on the engine, and loads up the user namespace.
        """
        import sys
        user_ns = globals()
        mod = __import__(name, fromlist=fromlist, level=level)
        if fromlist:
            for key in fromlist:
                user_ns[key] = getattr(mod, key)
        else:
            user_ns[name] = sys.modules[name]

    def view_import(name, globals=None, locals=None, fromlist=(), level=0):
        """the drop-in replacement for __import__, that optionally imports
        locally as well.

        The defaults mirror the builtin ``__import__`` signature and are
        immutable, so no state is shared between calls.
        """
        # don't override nested imports
        save_import = builtin_mod.__import__
        builtin_mod.__import__ = local_import
        try:
            if imp.lock_held():
                # this is a side-effect import, don't do it remotely, or even
                # ignore the local effects
                return local_import(name, globals, locals, fromlist, level)

            if not local:
                # checked before touching the import lock, so there is
                # nothing to release on this path
                raise NotImplementedError("remote-only imports not yet implemented")

            imp.acquire_lock()
            try:
                mod = local_import(name, globals, locals, fromlist, level)
            finally:
                # release even when the local import fails; otherwise the
                # import lock stays held after an ImportError
                imp.release_lock()

            key = name+':'+','.join(fromlist or [])
            if level <= 0 and key not in modules:
                modules.add(key)
                if not quiet:
                    if fromlist:
                        print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
                    else:
                        print("importing %s on engine(s)"%name)
                results.append(self.apply_async(remote_import, name, fromlist, level))

            return mod
        finally:
            # restore override on every exit path, including the early
            # return and any exception
            builtin_mod.__import__ = save_import

    # override __import__
    builtin_mod.__import__ = view_import
    try:
        # enter the block
        yield
    except ImportError:
        if local:
            raise
        else:
            # ignore import errors if not doing local imports
            pass
    finally:
        # always restore __import__
        builtin_mod.__import__ = local_import

    for r in results:
        # raise possible remote ImportErrors here
        r.get()
500
500
def use_dill(self):
    """Expand serialization support with dill

    adds support for closures, etc.

    This calls IPython.utils.pickleutil.use_dill() here and on each engine
    (via ``self.apply``), and returns the result of that apply call.
    """
    enable_dill = pickleutil.use_dill
    # local process first, then the engines
    enable_dill()
    return self.apply(enable_dill)
510
510
def use_cloudpickle(self):
    """Expand serialization support with cloudpickle.

    Calls ``pickleutil.use_cloudpickle()`` here, then applies the same
    function through ``self.apply`` and returns that call's result.
    """
    enable_cloudpickle = pickleutil.use_cloudpickle
    # local process first, then the engines
    enable_cloudpickle()
    return self.apply(enable_cloudpickle)
516
516
517
517
@sync_results
@save_ids
def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
    """calls f(*args, **kwargs) on remote engines, returning the result.

    This method sets all of `apply`'s flags via this View's attributes.

    Parameters
    ----------

    f : callable

    args : list [default: empty]

    kwargs : dict [default: empty]

    targets : target list [default: self.targets]
        where to run
    block : bool [default: self.block]
        whether to block
    track : bool [default: self.track]
        whether to ask zmq to track the message, for safe non-copying sends

    Returns
    -------

    if block is false:
        returns AsyncResult
    else:
        returns the value of ``AsyncResult.get()``, or the AsyncResult
        itself if the wait is interrupted by KeyboardInterrupt.
        When `targets` is an integer, the AsyncResult is built from a
        single msg_id instead of a list of msg_ids.
    """
    if args is None:
        args = []
    if kwargs is None:
        kwargs = {}
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if targets is None:
        targets = self.targets

    _idents, _targets = self.client._build_targets(targets)
    msg_ids = []
    trackers = []
    # one apply request per engine identity
    for ident in _idents:
        msg = self.client.send_apply_request(
            self._socket, f, args, kwargs, track=track, ident=ident,
        )
        if track:
            trackers.append(msg['tracker'])
        msg_ids.append(msg['header']['msg_id'])

    if isinstance(targets, int):
        # single-engine request: unwrap the one msg_id
        msg_ids = msg_ids[0]

    if track is False:
        tracker = None
    else:
        tracker = zmq.MessageTracker(*trackers)

    ar = AsyncResult(
        self.client, msg_ids, fname=getname(f), targets=_targets,
        tracker=tracker, owner=True,
    )
    if not block:
        return ar
    try:
        return ar.get()
    except KeyboardInterrupt:
        # interrupted while waiting: hand back the pending result instead
        return ar
578
578
579
579
@sync_results
def map(self, f, *sequences, **kwargs):
    """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult

    Parallel version of builtin `map`, using this View's `targets`.

    There will be one task per target, so work will be chunked
    if the sequences are longer than `targets`.

    Results can be iterated as they are ready, but will become available in chunks.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool
        whether to wait for the result or not [default self.block]

    Returns
    -------


    If block=False
      An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
      An object like AsyncResult, but which reassembles the sequence of results
      into a single list. AsyncMapResults can be iterated through before all
      results are complete.
    else
      A list, the result of ``map(f,*sequences)``

    Raises
    ------

    TypeError
        for any keyword argument other than `block` or `track`.
    """

    block = kwargs.pop('block', self.block)
    unexpected = [k for k in kwargs if k not in ('block', 'track')]
    if unexpected:
        # report the first offending keyword
        raise TypeError("invalid keyword arg, %r"%unexpected[0])

    assert len(sequences) > 0, "must have some sequences to map onto!"
    parallel_f = ParallelFunction(self, f, block=block, **kwargs)
    return parallel_f.map(*sequences)
622
622
@sync_results
@save_ids
def execute(self, code, silent=True, targets=None, block=None):
    """Executes `code` on `targets` in blocking or nonblocking manner.

    ``execute`` is always `bound` (affects engine namespace)

    Parameters
    ----------

    code : str
        the code string to be executed
    silent : bool
        forwarded to ``self.client.send_execute_request``
    targets : optional
        where to run; ``self.targets`` is used when None
    block : bool
        whether or not to wait until done to return
        default: self.block

    Returns
    -------

    The AsyncResult for the request(s), in both the blocking and the
    non-blocking case.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets

    _idents, _targets = self.client._build_targets(targets)
    # one execute request per engine identity
    msg_ids = [
        self.client.send_execute_request(
            self._socket, code, silent=silent, ident=ident,
        )['header']['msg_id']
        for ident in _idents
    ]
    if isinstance(targets, int):
        # single-engine request: unwrap the one msg_id
        msg_ids = msg_ids[0]

    ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
    if block:
        try:
            ar.get()
        except KeyboardInterrupt:
            # an interrupted wait still returns the pending result below
            pass
    return ar
657
657
def run(self, filename, targets=None, block=None):
    """Execute contents of `filename` on my engine(s).

    This simply reads the contents of the file and calls `execute`.

    Parameters
    ----------

    filename : str
        The path to the file
    targets : int/str/list of ints/strs
        the engines on which to execute
        default : all
    block : bool
        whether or not to wait until done
        default: self.block

    """
    with open(filename, 'r') as source:
        contents = source.read()
    # add newline in case of trailing indented whitespace
    # which will cause SyntaxError
    code = contents + '\n'
    return self.execute(code, block=block, targets=targets)
681
681
def update(self, ns):
    """update remote namespace with dict `ns`

    Calls `push` with this view's own ``block`` and ``track`` settings.

    See `push` for details.
    """
    push = self.push
    return push(ns, block=self.block, track=self.track)
688
688
def push(self, ns, targets=None, block=None, track=None):
    """update remote namespace with dict `ns`

    Parameters
    ----------

    ns : dict
        dict of keys with which to update engine namespace(s)
    targets : optional
        where to push; ``self.targets`` is used when None
    block : bool [default : self.block]
        whether to wait to be notified of engine receipt
    track : bool [default : self.track]
        forwarded to ``self._really_apply``

    Raises
    ------

    TypeError
        if `ns` is not a dict.
    """
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if targets is None:
        targets = self.targets

    if not isinstance(ns, dict):
        raise TypeError("Must be a dict, not %s"%type(ns))
    # the namespace travels as the keyword arguments of util._push
    return self._really_apply(
        util._push, kwargs=ns, block=block, track=track, targets=targets,
    )
709
709
def get(self, key_s):
    """get object(s) by `key_s` from remote namespace

    Always pulls with ``block=True``.

    see `pull` for details.
    """
    wait_for_value = True
    return self.pull(key_s, block=wait_for_value)
717
717
def pull(self, names, targets=None, block=None):
    """get object(s) by `name` from remote namespace

    will return one object if it is a key.
    can also take a list of keys, in which case it will return a list of objects.

    Parameters
    ----------

    names : str, or list/tuple/set of strs
        the name(s) to fetch
    targets : optional
        where to pull from; ``self.targets`` is used when None
    block : bool [default : self.block]
        forwarded to ``self._really_apply``

    Raises
    ------

    TypeError
        if `names` is neither a string nor a list/tuple/set of strings.
    """
    block = block if block is not None else self.block
    targets = targets if targets is not None else self.targets
    # The unused ``applier`` local that used to be selected here has been
    # dropped: the request always goes through _really_apply below.
    if isinstance(names, string_types):
        pass
    elif isinstance(names, (list, tuple, set)):
        for key in names:
            if not isinstance(key, string_types):
                raise TypeError("keys must be str, not type %r"%type(key))
    else:
        raise TypeError("names must be strs, not %r"%names)
    return self._really_apply(util._pull, (names,), block=block, targets=targets)
736
736
def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
    """
    Partition a Python sequence and send the partitions to a set of engines.

    Each engine receives its partition under the name `key` via `push`.
    With `flatten` true, a partition of length one is pushed as its single
    element rather than as a sequence.

    Returns the AsyncResult covering all pushes when not blocking; when
    blocking, waits on it and returns None.
    """
    if block is None:
        block = self.block
    if track is None:
        track = self.track
    if targets is None:
        targets = self.targets

    # construct integer ID list:
    engine_ids = self.client._build_targets(targets)[1]

    partitioner = Map.dists[dist]()
    nparts = len(engine_ids)
    msg_ids = []
    trackers = []
    for index, engineid in enumerate(engine_ids):
        partition = partitioner.getPartition(seq, index, nparts)
        if flatten and len(partition) == 1:
            value = partition[0]
        else:
            value = partition
        pushed = self.push({key: value}, block=False, track=track, targets=engineid)
        msg_ids.extend(pushed.msg_ids)
        if track:
            trackers.append(pushed._tracker)

    tracker = zmq.MessageTracker(*trackers) if track else None

    result = AsyncResult(
        self.client, msg_ids, fname='scatter', targets=engine_ids,
        tracker=tracker, owner=True,
    )
    if block:
        result.wait()
        return None
    return result
775
775
@sync_results
@save_ids
def gather(self, key, dist='b', targets=None, block=None):
    """
    Gather a partitioned sequence on a set of engines as a single local seq.

    `key` is pulled from every target engine and the pieces are combined
    by an AsyncMapResult built with the `dist` map object.  Returns that
    AsyncMapResult when not blocking; when blocking, returns its ``get()``
    value, or the AsyncMapResult itself if the wait is interrupted by
    KeyboardInterrupt.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets
    mapObject = Map.dists[dist]()

    # construct integer ID list:
    engine_ids = self.client._build_targets(targets)[1]

    msg_ids = []
    for engineid in engine_ids:
        pulled = self.pull(key, block=False, targets=engineid)
        msg_ids.extend(pulled.msg_ids)

    gathered = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')

    if not block:
        return gathered
    try:
        return gathered.get()
    except KeyboardInterrupt:
        # interrupted while waiting: hand back the pending result instead
        return gathered
801
801
802 def __getitem__(self, key):
802 def __getitem__(self, key):
803 return self.get(key)
803 return self.get(key)
804
804
805 def __setitem__(self,key, value):
805 def __setitem__(self,key, value):
806 self.update({key:value})
806 self.update({key:value})
807
807
def clear(self, targets=None, block=None):
    """Clear the remote namespaces on my engines.

    `block` and `targets` fall back to ``self.block`` and ``self.targets``
    when None; the request is forwarded to ``self.client.clear``.
    """
    if block is None:
        block = self.block
    if targets is None:
        targets = self.targets
    return self.client.clear(targets=targets, block=block)
813
813
814 #----------------------------------------
814 #----------------------------------------
815 # activate for %px, %autopx, etc. magics
815 # activate for %px, %autopx, etc. magics
816 #----------------------------------------
816 #----------------------------------------
817
817
def activate(self, suffix=''):
    """Activate IPython magics associated with this View

    Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`

    Parameters
    ----------

    suffix: str [default: '']
        The suffix, if any, for the magics.  This allows you to have
        multiple views associated with parallel magics at the same time.

        e.g. ``rc[::2].activate(suffix='_even')`` will give you
        the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
        on the even engines.

    Outside of IPython (no ``get_ipython`` available) a notice is printed
    and nothing is registered.
    """

    from ipython_parallel.client.magics import ParallelMagics

    try:
        # This is injected into __builtins__.
        shell = get_ipython()
    except NameError:
        print("The IPython parallel magics (%px, etc.) only work within IPython.")
        return

    magics = ParallelMagics(shell, self, suffix)
    shell.magics_manager.register(magics)
846
846
847
847
848 @skip_doctest
848 @skip_doctest
849 class LoadBalancedView(View):
849 class LoadBalancedView(View):
850 """A load-balancing View that only executes via the Task scheduler.
850 """A load-balancing View that only executes via the Task scheduler.
851
851
852 Load-balanced views can be created with the client's `view` method:
852 Load-balanced views can be created with the client's `view` method:
853
853
854 >>> v = client.load_balanced_view()
854 >>> v = client.load_balanced_view()
855
855
856 or targets can be specified, to restrict the potential destinations:
856 or targets can be specified, to restrict the potential destinations:
857
857
858 >>> v = client.load_balanced_view([1,3])
858 >>> v = client.load_balanced_view([1,3])
859
859
860 which would restrict loadbalancing to between engines 1 and 3.
860 which would restrict loadbalancing to between engines 1 and 3.
861
861
862 """
862 """
863
863
864 follow=Any()
864 follow=Any()
865 after=Any()
865 after=Any()
866 timeout=CFloat()
866 timeout=CFloat()
867 retries = Integer(0)
867 retries = Integer(0)
868
868
869 _task_scheme = Any()
869 _task_scheme = Any()
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871
871
def __init__(self, client=None, socket=None, **flags):
    """Initialise through the parent constructor, then copy the client's task scheme."""
    parent = super(LoadBalancedView, self)
    parent.__init__(client=client, socket=socket, **flags)
    self._task_scheme = client._task_scheme
875
875
def _validate_dependency(self, dep):
    """validate a dependency.

    For use in `set_flags`.

    Accepted forms:

    * None, a string msg_id, an AsyncResult or a Dependency
    * a list, set or tuple containing only strings / AsyncResults
    * a dict with exactly the keys of ``Dependency().as_dict()`` whose
      'dependencies' entry is a list of strings

    Returns
    -------

    bool : True if `dep` is one of the accepted forms, False otherwise.
    """
    if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
        return True

    if isinstance(dep, (list, set, tuple)):
        return all(isinstance(d, string_types + (AsyncResult,)) for d in dep)

    if isinstance(dep, dict):
        if set(dep.keys()) != set(Dependency().as_dict().keys()):
            return False
        # The key check above only passes for as_dict()-shaped dicts, which
        # keep the msg_ids under 'dependencies'.  The previous lookup of
        # 'msg_ids' could therefore never succeed and raised KeyError.
        msg_ids = dep['dependencies']
        if not isinstance(msg_ids, list):
            return False
        return all(isinstance(d, string_types) for d in msg_ids)

    return False
899
899
def _render_dependency(self, dep):
    """helper for building jsonable dependencies from various input forms.

    A Dependency becomes its ``as_dict()`` form, an AsyncResult its
    ``msg_ids``, None an empty list; anything else is passed through the
    Dependency constructor and returned as a list.
    """
    if isinstance(dep, Dependency):
        return dep.as_dict()
    if isinstance(dep, AsyncResult):
        return dep.msg_ids
    if dep is None:
        return []
    # pass to Dependency constructor
    rendered = Dependency(dep)
    return list(rendered)
911
911
def set_flags(self, **kwargs):
    """set my attribute flags by keyword.

    A View is a wrapper for the Client's apply method, but with attributes
    that specify keyword arguments, those attributes can be set by keyword
    argument with this method.

    Parameters
    ----------

    block : bool
        whether to wait for results
    track : bool
        whether to create a MessageTracker to allow the user to
        safely edit after arrays and buffers during non-copying
        sends.

    after : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a time-based dependency.
        This job will only be run *after* the dependencies
        have been met.

    follow : Dependency or collection of msg_ids
        Only for load-balanced execution (targets=None)
        Specify a list of msg_ids as a location-based dependency.
        This job will only be run on an engine where this dependency
        is met.

    timeout : float/int or None
        Only for load-balanced execution (targets=None)
        Specify an amount of time (in seconds) for the scheduler to
        wait for dependencies to be met before failing with a
        DependencyTimeout.

    retries : int
        Number of times a task will be retried on failure.

    Raises
    ------

    ValueError
        for a `follow`/`after` value rejected by `_validate_dependency`,
        or a negative `timeout`
    TypeError
        for a `timeout` of an unsupported type (checked on Python 2 only)
    """

    super(LoadBalancedView, self).set_flags(**kwargs)

    for flag in ('follow', 'after'):
        if flag not in kwargs:
            continue
        value = kwargs[flag]
        if not self._validate_dependency(value):
            raise ValueError("Invalid dependency: %r"%value)
        setattr(self, flag, value)

    if 'timeout' not in kwargs:
        return

    t = kwargs['timeout']
    numeric_or_none = isinstance(t, (int, float, type(None)))
    # `long` only exists on Python 2, hence the PY3 guard before touching it
    if not numeric_or_none and (not PY3) and (not isinstance(t, long)):
        raise TypeError("Invalid type for timeout: %r"%type(t))
    if t is not None and t < 0:
        raise ValueError("Invalid timeout: %s"%t)
    self.timeout = t
968
968
969 @sync_results
969 @sync_results
970 @save_ids
970 @save_ids
def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
                  after=None, follow=None, timeout=None,
                  targets=None, retries=None):
    """calls f(*args, **kwargs) on a remote engine, returning the result.

    This method temporarily sets all of `apply`'s flags for a single call.

    Parameters
    ----------

    f : callable

    args : list [default: empty]

    kwargs : dict [default: empty]

    block : bool [default: self.block]
        whether to block
    track : bool [default: self.track]
        whether to ask zmq to track the message, for safe non-copying sends
    after : Dependency or collection of msg_ids [default: self.after]
        time-based dependency, rendered with `_render_dependency`
    follow : Dependency or collection of msg_ids [default: self.follow]
        location-based dependency, rendered with `_render_dependency`
    timeout : float/int or None [default: self.timeout]
        passed to the scheduler in the request metadata
    targets : [default: self.targets]
        engines the task may run on; resolved to identities through
        ``self.client._build_targets`` (None means no restriction)
    retries : int [default: self.retries]
        passed to the scheduler in the request metadata

    Returns
    -------

    if block is False:
        returns AsyncResult
    else:
        returns the value of ``AsyncResult.get()``; if the wait is
        interrupted by KeyboardInterrupt, the AsyncResult is returned

    Raises
    ------

    RuntimeError
        if the task socket is closed, or if the 'pure' task scheme is in
        use and any of follow/after/retries/targets/timeout was passed
    TypeError
        if `retries` is not an int
    """

    # validate whether we can run
    if self._socket.closed:
        msg = "Task farming is disabled"
        if self._task_scheme == 'pure':
            msg += " because the pure ZMQ scheduler cannot handle"
            msg += " disappearing engines."
        raise RuntimeError(msg)

    if self._task_scheme == 'pure':
        # pure zmq scheme doesn't support extra features
        # The parentheses matter: without them the second literal is a
        # separate no-op statement and the flag list is lost.
        msg = ("Pure ZMQ scheduler doesn't support the following flags: "
               "follow, after, retries, targets, timeout")
        if (follow or after or retries or targets or timeout):
            # hard fail on Scheduler flags
            raise RuntimeError(msg)
        if isinstance(f, dependent):
            # soft warn on functional dependencies
            warnings.warn(msg, RuntimeWarning)

    # build args
    args = [] if args is None else args
    kwargs = {} if kwargs is None else kwargs
    block = self.block if block is None else block
    track = self.track if track is None else track
    after = self.after if after is None else after
    retries = self.retries if retries is None else retries
    follow = self.follow if follow is None else follow
    timeout = self.timeout if timeout is None else timeout
    targets = self.targets if targets is None else targets

    if not isinstance(retries, int):
        raise TypeError('retries must be int, not %r'%type(retries))

    if targets is None:
        idents = []
    else:
        idents = self.client._build_targets(targets)[0]
        # ensure *not* bytes
        idents = [ ident.decode() for ident in idents ]

    after = self._render_dependency(after)
    follow = self._render_dependency(follow)
    metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)

    msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
                            metadata=metadata)
    tracker = None if track is False else msg['tracker']

    ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
        targets=None, tracker=tracker, owner=True,
    )
    if block:
        try:
            return ar.get()
        except KeyboardInterrupt:
            # interrupted while waiting: fall through and return the handle
            pass
    return ar
1062
1062
1063 @sync_results
1063 @sync_results
1064 @save_ids
1064 @save_ids
def map(self, f, *sequences, **kwargs):
    """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult

    Parallel version of builtin `map`, load-balanced by this View.

    `block`, `chunksize` and `ordered` can be specified by keyword only.

    Each `chunksize` elements will be a separate task, and will be
    load-balanced. This lets individual elements be available for iteration
    as soon as they arrive.

    Parameters
    ----------

    f : callable
        function to be mapped
    *sequences: one or more sequences of matching length
        the sequences to be distributed and passed to `f`
    block : bool [default self.block]
        whether to wait for the result or not
    track : bool
        accepted because it has always been documented here, but this
        method does not forward it to the ParallelFunction it builds
    chunksize : int [default 1]
        how many elements should be in each task.
    ordered : bool [default True]
        Whether the results should be gathered as they arrive, or enforce
        the order of submission.

        Only applies when iterating through AsyncMapResult as results arrive.
        Has no effect when block=True.

    Returns
    -------

    if block=False
        An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
        An object like AsyncResult, but which reassembles the sequence of results
        into a single list. AsyncMapResults can be iterated through before all
        results are complete.
    else
        A list, the result of ``map(f,*sequences)``

    Raises
    ------

    TypeError
        if any keyword other than block, chunksize, ordered or track is given
    """

    # default
    block = kwargs.get('block', self.block)
    chunksize = kwargs.get('chunksize', 1)
    ordered = kwargs.get('ordered', True)

    # set.difference_update() works in place and returns None, so the old
    # check could never fire; compute the leftovers with a plain difference.
    extra_keys = set(kwargs) - set(['block', 'chunksize', 'ordered', 'track'])
    if extra_keys:
        raise TypeError("Invalid kwargs: %s" % sorted(extra_keys))

    assert len(sequences) > 0, "must have some sequences to map onto!"

    pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
    return pf.map(*sequences)
1124
1124
1125 __all__ = ['LoadBalancedView', 'DirectView']
1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,3 +1,3 b''
if __name__ == '__main__':
    # Executed as a script: hand control to the ipcluster application.
    from ipython_parallel.apps import ipclusterapp as app
    app.launch_new_instance()
@@ -1,6 +1,6 b''
def main():
    """Launch the ipcontroller application."""
    from ipython_parallel.apps import ipcontrollerapp
    ipcontrollerapp.launch_new_instance()
4
4
5 if __name__ == '__main__':
5 if __name__ == '__main__':
6 main()
6 main()
@@ -1,229 +1,229 b''
1 """Dependency utilities
1 """Dependency utilities
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2013 The IPython Development Team
8 # Copyright (C) 2013 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 from types import ModuleType
14 from types import ModuleType
15
15
16 from IPython.parallel.client.asyncresult import AsyncResult
16 from ipython_parallel.client.asyncresult import AsyncResult
17 from IPython.parallel.error import UnmetDependency
17 from ipython_parallel.error import UnmetDependency
18 from IPython.parallel.util import interactive
18 from ipython_parallel.util import interactive
19 from IPython.utils import py3compat
19 from IPython.utils import py3compat
20 from IPython.utils.py3compat import string_types
20 from IPython.utils.py3compat import string_types
21 from IPython.utils.pickleutil import can, uncan
21 from IPython.utils.pickleutil import can, uncan
22
22
class depend(object):
    """Dependency decorator, for use with tasks.

    `@depend` lets you define a function for engine dependencies
    just like you use `apply` for tasks.


    Examples
    --------
    ::

        @depend(df, a,b, c=5)
        def f(m,n,p):
            ...

        view.apply(f, 1,2,3)

    will call df(a,b,c=5) on the engine, and if it returns False or
    raises an UnmetDependency error, then the task will not be run
    and another engine will be tried.
    """
    def __init__(self, _wrapped_f, *args, **kwargs):
        # the dependency function and the arguments it will be called with
        self.f, self.args, self.kwargs = _wrapped_f, args, kwargs

    def __call__(self, f):
        """Wrap `f` in a `dependent` bound to the stored dependency function."""
        wrapped = dependent(f, self.f, *self.args, **self.kwargs)
        return wrapped
50
50
class dependent(object):
    """A function that depends on another function.
    This is an object to prevent the closure used
    in traditional decorators, which are not picklable.
    """

    def __init__(self, _wrapped_f, _wrapped_df, *dargs, **dkwargs):
        self.f = _wrapped_f
        name = getattr(_wrapped_f, '__name__', 'f')
        # On Python 2 the name is kept in func_name and exposed through the
        # __name__ property defined below.
        setattr(self, '__name__' if py3compat.PY3 else 'func_name', name)
        self.df = _wrapped_df
        self.dargs = dargs
        self.dkwargs = dkwargs

    def check_dependency(self):
        """Call the dependency function; raise UnmetDependency if it returns False."""
        outcome = self.df(*self.dargs, **self.dkwargs)
        if outcome is False:
            raise UnmetDependency()

    def __call__(self, *args, **kwargs):
        """Call the wrapped function itself."""
        return self.f(*args, **kwargs)

    if not py3compat.PY3:
        @property
        def __name__(self):
            return self.func_name
79
79
80 @interactive
80 @interactive
def _require(*modules, **mapping):
    """Helper for @require decorator.

    Imports each named module into this function's globals, raising
    UnmetDependency for any that cannot be imported, then uncans every
    object in `mapping` into the same namespace.  Returns True.
    """
    from ipython_parallel.error import UnmetDependency
    from IPython.utils.pickleutil import uncan
    namespace = globals()
    for module_name in modules:
        try:
            # runs an import statement built from the given module name
            exec('import %s' % module_name, namespace)
        except ImportError:
            raise UnmetDependency(module_name)

    for name, canned in mapping.items():
        namespace[name] = uncan(canned, namespace)
    return True
95
95
def require(*objects, **mapping):
    """Simple decorator for requiring local objects and modules to be available
    when the decorated function is called on the engine.

    Modules specified by name or passed directly will be imported
    prior to calling the decorated function.

    Objects other than modules will be pushed as a part of the task.
    Functions can be passed positionally,
    and will be pushed to the engine with their __name__.
    Other objects can be passed by keyword arg.

    Examples::

        In [1]: @require('numpy')
           ...: def norm(a):
           ...:     return numpy.linalg.norm(a,2)

        In [2]: foo = lambda x: x*x
        In [3]: @require(foo)
           ...: def bar(a):
           ...:     return foo(1-a)
    """
    module_names = []
    for item in objects:
        # a module object is treated exactly like its name
        if isinstance(item, ModuleType):
            item = item.__name__

        if isinstance(item, string_types):
            module_names.append(item)
            continue
        if hasattr(item, '__name__'):
            mapping[item.__name__] = item
            continue
        raise TypeError("Objects other than modules and functions "
            "must be passed by kwarg, but got: %s" % type(item)
        )

    # can every pushed object so it can travel with the task
    canned = dict((name, can(obj)) for name, obj in mapping.items())
    return depend(_require, *module_names, **canned)
136
136
class Dependency(set):
    """An object for representing a set of msg_id dependencies.

    Subclassed from set().

    Parameters
    ----------
    dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
        The msg_ids to depend on
    all : bool [default True]
        Whether the dependency should be considered met when *all* depending tasks have completed
        or only when *any* have been completed.
    success : bool [default True]
        Whether to consider successes as fulfilling dependencies.
    failure : bool [default False]
        Whether to consider failures as fulfilling dependencies.

    If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
        as soon as the first depended-upon task fails.

    Raises
    ------
    TypeError
        if an element of `dependencies` is neither a string nor an AsyncResult
    ValueError
        if both `success` and `failure` are false
    """

    # class-level fallbacks; __init__ assigns per-instance values
    all=True
    success=True
    failure=True

    def __init__(self, dependencies=(), all=True, success=True, failure=False):
        # `dependencies` defaults to an immutable empty tuple rather than a
        # shared mutable list.
        if isinstance(dependencies, dict):
            # load from dict; keys missing from the dict fall back to the
            # keyword arguments (previously `all` ignored its argument)
            all = dependencies.get('all', all)
            success = dependencies.get('success', success)
            failure = dependencies.get('failure', failure)
            dependencies = dependencies.get('dependencies', [])
        ids = []

        # extract ids from various sources:
        if isinstance(dependencies, string_types + (AsyncResult,)):
            dependencies = [dependencies]
        for d in dependencies:
            if isinstance(d, string_types):
                ids.append(d)
            elif isinstance(d, AsyncResult):
                ids.extend(d.msg_ids)
            else:
                raise TypeError("invalid dependency type: %r"%type(d))

        set.__init__(self, ids)
        self.all = all
        if not (success or failure):
            raise ValueError("Must depend on at least one of successes or failures!")
        self.success=success
        self.failure = failure

    def check(self, completed, failed=None):
        """check whether our dependencies have been met."""
        if len(self) == 0:
            return True
        # collect the msg_ids that count as fulfilling this dependency
        against = set()
        if self.success:
            against = completed
        if failed is not None and self.failure:
            against = against.union(failed)
        if self.all:
            return self.issubset(against)
        else:
            return not self.isdisjoint(against)

    def unreachable(self, completed, failed=None):
        """return whether this dependency has become impossible."""
        if len(self) == 0:
            return False
        # collect the msg_ids whose outcome can no longer fulfil us
        against = set()
        if not self.success:
            against = completed
        if failed is not None and not self.failure:
            against = against.union(failed)
        if self.all:
            return not self.isdisjoint(against)
        else:
            return self.issubset(against)


    def as_dict(self):
        """Represent this dependency as a dict. For json compatibility."""
        return dict(
            dependencies=list(self),
            all=self.all,
            success=self.success,
            failure=self.failure
        )
226
226
227
227
228 __all__ = ['depend', 'require', 'dependent', 'Dependency']
228 __all__ = ['depend', 'require', 'dependent', 'Dependency']
229
229
@@ -1,193 +1,193 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 A multi-heart Heartbeat system using PUB and ROUTER sockets. pings are sent out on the PUB,
3 A multi-heart Heartbeat system using PUB and ROUTER sockets. pings are sent out on the PUB,
4 and hearts are tracked based on their DEALER identities.
4 and hearts are tracked based on their DEALER identities.
5 """
5 """
6
6
7 # Copyright (c) IPython Development Team.
7 # Copyright (c) IPython Development Team.
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9
9
10 from __future__ import print_function
10 from __future__ import print_function
11 import time
11 import time
12 import uuid
12 import uuid
13
13
14 import zmq
14 import zmq
15 from zmq.devices import ThreadDevice, ThreadMonitoredQueue
15 from zmq.devices import ThreadDevice, ThreadMonitoredQueue
16 from zmq.eventloop import ioloop, zmqstream
16 from zmq.eventloop import ioloop, zmqstream
17
17
18 from IPython.config.configurable import LoggingConfigurable
18 from IPython.config.configurable import LoggingConfigurable
19 from IPython.utils.py3compat import str_to_bytes
19 from IPython.utils.py3compat import str_to_bytes
20 from IPython.utils.traitlets import Set, Instance, CFloat, Integer, Dict, Bool
20 from IPython.utils.traitlets import Set, Instance, CFloat, Integer, Dict, Bool
21
21
22 from IPython.parallel.util import log_errors
22 from ipython_parallel.util import log_errors
23
23
class Heart(object):
    """A basic heart object for responding to a HeartMonitor.

    This is a simple wrapper with defaults for the most common
    Device model for responding to heartbeats.

    It builds a zmq device that runs in its own thread: a zmq.FORWARDER
    ThreadDevice, or a ThreadMonitoredQueue when `mon_addr` is given,
    defaulting to SUB/DEALER for in/out.

    Parameters
    ----------
    in_addr : str
        Address the device's input socket connects to.
    out_addr : str
        Address the device's output socket connects to.
    mon_addr : str, optional
        Address the monitor socket connects to. When None (default) no
        monitor socket is used.
    in_type, out_type, mon_type : int
        zmq socket types for the input, output and monitor sockets.
    heart_id : bytes, optional
        IDENTITY set on the output socket. A random one is generated when
        omitted. The value actually used is stored in ``self.id``.
    """
    device = None
    id = None

    def __init__(self, in_addr, out_addr, mon_addr=None, in_type=zmq.SUB, out_type=zmq.DEALER, mon_type=zmq.PUB, heart_id=None):
        if mon_addr is None:
            self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
        else:
            self.device = ThreadMonitoredQueue(in_type, out_type, mon_type, in_prefix=b"", out_prefix=b"")
        # do not allow the device to share global Context.instance,
        # which is the default behavior in pyzmq > 2.1.10
        self.device.context_factory = zmq.Context

        self.device.daemon = True
        self.device.connect_in(in_addr)
        self.device.connect_out(out_addr)
        if mon_addr is not None:
            self.device.connect_mon(mon_addr)
        if in_type == zmq.SUB:
            # subscribe to everything, otherwise a SUB socket receives nothing
            self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
        if heart_id is None:
            # Use the ASCII hex form rather than the raw 16 bytes: raw UUID
            # bytes can begin with a NUL byte, and ZeroMQ reserves identities
            # starting with binary zero for its own use.
            heart_id = uuid.uuid4().hex.encode('ascii')
        self.device.setsockopt_out(zmq.IDENTITY, heart_id)
        self.id = heart_id

    def start(self):
        """Start the underlying device and return whatever its start() returns."""
        return self.device.start()
58
58
59
59
class HeartMonitor(LoggingConfigurable):
    """A basic HeartMonitor class.

    Periodically sends a ping (the monitor's current ``lifetime`` as text)
    on ``pingstream`` and collects the replies arriving on ``pongstream``.
    Identities that answer are tracked in ``hearts``; identities that stop
    answering are put on probation and reported as failed once they exceed
    ``max_heartmonitor_misses`` consecutive misses.

    pingstream: a PUB stream
    pongstream: a ROUTER stream
    period: the period of the heartbeat in milliseconds
    """

    debug = Bool(False, config=True,
        help="""Whether to include every heartbeat in debugging output.

        Has to be set explicitly, because there will be *a lot* of output.
        """
    )
    period = Integer(3000, config=True,
        help='The frequency at which the Hub pings the engines for heartbeats '
        '(in ms)',
    )
    max_heartmonitor_misses = Integer(10, config=True,
        help='Allowed consecutive missed pings from controller Hub to engine before unregistering.',
    )

    pingstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    pongstream = Instance('zmq.eventloop.zmqstream.ZMQStream')
    loop = Instance('zmq.eventloop.ioloop.IOLoop')

    def _loop_default(self):
        return ioloop.IOLoop.instance()

    # not settable:
    hearts = Set()              # identities currently considered alive
    responses = Set()           # identities that answered since the last beat
    on_probation = Dict()       # identity -> consecutive missed beats
    last_ping = CFloat(0)       # lifetime value sent with the previous ping
    _new_handlers = Set()
    _failure_handlers = Set()
    lifetime = CFloat(0)        # accumulated time between beats, sent as the ping payload
    tic = CFloat(0)             # wall-clock time of the most recent beat

    def __init__(self, **kwargs):
        super(HeartMonitor, self).__init__(**kwargs)

        # every message on the pong stream is handled as a heartbeat reply
        self.pongstream.on_recv(self.handle_pong)

    def start(self):
        """Record the start time and schedule `beat` every `period` ms on the loop."""
        self.tic = time.time()
        self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
        self.caller.start()

    def add_new_heart_handler(self, handler):
        """add a new handler for new hearts"""
        self.log.debug("heartbeat::new_heart_handler: %s", handler)
        self._new_handlers.add(handler)

    def add_heart_failure_handler(self, handler):
        """add a new handler for heart failure"""
        self.log.debug("heartbeat::new heart failure handler: %s", handler)
        self._failure_handlers.add(handler)

    def beat(self):
        """Evaluate the replies to the previous ping, then send the next ping."""
        # process any pending replies before judging who answered
        self.pongstream.flush()
        self.last_ping = self.lifetime

        toc = time.time()
        self.lifetime += toc - self.tic
        self.tic = toc
        if self.debug:
            self.log.debug("heartbeat::sending %s", self.lifetime)
        goodhearts = self.hearts.intersection(self.responses)
        missed_beats = self.hearts.difference(goodhearts)
        newhearts = self.responses.difference(goodhearts)
        for heart in newhearts:
            self.handle_new_heart(heart)
        heartfailures, on_probation = self._check_missed(missed_beats, self.on_probation,
                                                         self.hearts)
        for failure in heartfailures:
            self.handle_heart_failure(failure)
        self.on_probation = on_probation
        self.responses = set()
        self.pingstream.send(str_to_bytes(str(self.lifetime)))
        # flush stream to force immediate socket send
        self.pingstream.flush()

    def _check_missed(self, missed_beats, on_probation, hearts):
        """Update heartbeats on probation, identifying any that have too many misses.

        Returns a ``(failures, new_probation)`` pair: the list of hearts whose
        miss count exceeded `max_heartmonitor_misses`, and a dict mapping the
        remaining missed hearts to their updated miss count. Hearts that did
        not miss this beat are absent from the new dict, so their count resets.
        """
        failures = []
        new_probation = {}
        for cur_heart in (b for b in missed_beats if b in hearts):
            miss_count = on_probation.get(cur_heart, 0) + 1
            self.log.info("heartbeat::missed %s : %s", cur_heart, miss_count)
            if miss_count > self.max_heartmonitor_misses:
                failures.append(cur_heart)
            else:
                new_probation[cur_heart] = miss_count
        return failures, new_probation

    def handle_new_heart(self, heart):
        """Notify the registered new-heart handlers (or log) and start tracking `heart`."""
        if self._new_handlers:
            for handler in self._new_handlers:
                handler(heart)
        else:
            self.log.info("heartbeat::yay, got new heart %s!", heart)
        self.hearts.add(heart)

    def handle_heart_failure(self, heart):
        """Notify the registered failure handlers (or log) and stop tracking `heart`.

        An exception raised by one handler is logged and does not prevent
        the remaining handlers from running.
        """
        if self._failure_handlers:
            for handler in self._failure_handlers:
                try:
                    handler(heart)
                except Exception:
                    self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
        else:
            self.log.info("heartbeat::Heart %s failed :(", heart)
        self.hearts.remove(heart)

    @log_errors
    def handle_pong(self, msg):
        "a heart just beat"
        # msg[0] is used as the heart's identity and msg[1] as the echoed ping payload
        current = str_to_bytes(str(self.lifetime))
        last = str_to_bytes(str(self.last_ping))
        if msg[1] == current:
            delta = time.time() - self.tic
            if self.debug:
                self.log.debug("heartbeat::heart %r took %.2f ms to respond", msg[0], 1000*delta)
            self.responses.add(msg[0])
        elif msg[1] == last:
            # reply to the previous ping: late, but still counted as a response
            delta = time.time() - self.tic + (self.lifetime - self.last_ping)
            self.log.warning("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
            self.responses.add(msg[0])
        else:
            self.log.warning("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
193
193
@@ -1,1438 +1,1438 b''
1 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
2
2
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5 """
5 """
6
6
7 # Copyright (c) IPython Development Team.
7 # Copyright (c) IPython Development Team.
8 # Distributed under the terms of the Modified BSD License.
8 # Distributed under the terms of the Modified BSD License.
9
9
10 from __future__ import print_function
10 from __future__ import print_function
11
11
12 import json
12 import json
13 import os
13 import os
14 import sys
14 import sys
15 import time
15 import time
16 from datetime import datetime
16 from datetime import datetime
17
17
18 import zmq
18 import zmq
19 from zmq.eventloop.zmqstream import ZMQStream
19 from zmq.eventloop.zmqstream import ZMQStream
20
20
21 # internal:
21 # internal:
22 from IPython.utils.importstring import import_item
22 from IPython.utils.importstring import import_item
23 from IPython.utils.jsonutil import extract_dates
23 from IPython.utils.jsonutil import extract_dates
24 from IPython.utils.localinterfaces import localhost
24 from IPython.utils.localinterfaces import localhost
25 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
25 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
26 from IPython.utils.traitlets import (
26 from IPython.utils.traitlets import (
27 HasTraits, Any, Instance, Integer, Unicode, Dict, Set, Tuple, DottedObjectName
27 HasTraits, Any, Instance, Integer, Unicode, Dict, Set, Tuple, DottedObjectName
28 )
28 )
29
29
30 from IPython.parallel import error, util
30 from ipython_parallel import error, util
31 from IPython.parallel.factory import RegistrationFactory
31 from ipython_parallel.factory import RegistrationFactory
32
32
33 from IPython.kernel.zmq.session import SessionFactory
33 from IPython.kernel.zmq.session import SessionFactory
34
34
35 from .heartmonitor import HeartMonitor
35 from .heartmonitor import HeartMonitor
36
36
37
37
def _passer(*args, **kwargs):
    """Accept any arguments, do nothing, and return None."""
    return None
40
40
def _printer(*args, **kwargs):
    """Print the positional-argument tuple, then the keyword-argument dict."""
    for payload in (args, kwargs):
        print(payload)
44
44
def empty_record():
    """Build a fresh record dict containing every record key.

    All fields are None except 'stdout' and 'stderr', which start out as
    empty strings.
    """
    unset_fields = (
        'msg_id',
        'header',
        'metadata',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'received',
        'result_header',
        'result_metadata',
        'result_content',
        'result_buffers',
        'queue',
        'execute_input',
        'execute_result',
        'error',
    )
    record = dict.fromkeys(unset_fields)
    # the output streams start as empty text rather than None
    record['stdout'] = ''
    record['stderr'] = ''
    return record
71
71
def init_record(msg):
    """Create the initial TaskRecord dict for a request message.

    The message id, header, content, metadata, buffers and submission date
    are taken from `msg` (which must provide 'header', 'content',
    'metadata' and 'buffers', with 'msg_id' and 'date' in the header).
    Every other field starts as None, except 'stdout' and 'stderr', which
    start as empty strings.
    """
    header = msg['header']
    # fields copied from the request
    record = {
        'msg_id': header['msg_id'],
        'header': header,
        'content': msg['content'],
        'metadata': msg['metadata'],
        'buffers': msg['buffers'],
        'submitted': header['date'],
    }
    # fields that are not known yet
    record.update(dict.fromkeys((
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'received',
        'result_header',
        'result_metadata',
        'result_content',
        'result_buffers',
        'queue',
        'execute_input',
        'execute_result',
        'error',
    )))
    record['stdout'] = ''
    record['stderr'] = ''
    return record
99
99
100
100
class EngineConnector(HasTraits):
    """Per-engine bookkeeping record used when tracking engine connections.

    Attributes are:

    id (int): engine ID
    uuid (unicode): engine UUID
    pending: set of msg_ids
    stallback: tornado timeout for stalled registration
    """

    # engine ID
    id = Integer(0)
    # engine UUID
    uuid = Unicode()
    # set of msg_ids
    pending = Set()
    # tornado timeout for stalled registration
    stallback = Any()
114
114
115
115
# Lower-case shortcut names mapped to the dotted import path of the
# corresponding DB backend class.
_db_shortcuts = dict(
    (shortcut, 'ipython_parallel.controller.' + target)
    for shortcut, target in (
        ('sqlitedb', 'sqlitedb.SQLiteDB'),
        ('mongodb', 'mongodb.MongoDB'),
        ('dictdb', 'dictdb.DictDB'),
        ('nodb', 'dictdb.NoDB'),
    )
)
122
122
123 class HubFactory(RegistrationFactory):
123 class HubFactory(RegistrationFactory):
124 """The Configurable for setting up a Hub."""
124 """The Configurable for setting up a Hub."""
125
125
126 # port-pairs for monitoredqueues:
126 # port-pairs for monitoredqueues:
127 hb = Tuple(Integer,Integer,config=True,
127 hb = Tuple(Integer,Integer,config=True,
128 help="""PUB/ROUTER Port pair for Engine heartbeats""")
128 help="""PUB/ROUTER Port pair for Engine heartbeats""")
129 def _hb_default(self):
129 def _hb_default(self):
130 return tuple(util.select_random_ports(2))
130 return tuple(util.select_random_ports(2))
131
131
132 mux = Tuple(Integer,Integer,config=True,
132 mux = Tuple(Integer,Integer,config=True,
133 help="""Client/Engine Port pair for MUX queue""")
133 help="""Client/Engine Port pair for MUX queue""")
134
134
135 def _mux_default(self):
135 def _mux_default(self):
136 return tuple(util.select_random_ports(2))
136 return tuple(util.select_random_ports(2))
137
137
138 task = Tuple(Integer,Integer,config=True,
138 task = Tuple(Integer,Integer,config=True,
139 help="""Client/Engine Port pair for Task queue""")
139 help="""Client/Engine Port pair for Task queue""")
140 def _task_default(self):
140 def _task_default(self):
141 return tuple(util.select_random_ports(2))
141 return tuple(util.select_random_ports(2))
142
142
143 control = Tuple(Integer,Integer,config=True,
143 control = Tuple(Integer,Integer,config=True,
144 help="""Client/Engine Port pair for Control queue""")
144 help="""Client/Engine Port pair for Control queue""")
145
145
146 def _control_default(self):
146 def _control_default(self):
147 return tuple(util.select_random_ports(2))
147 return tuple(util.select_random_ports(2))
148
148
149 iopub = Tuple(Integer,Integer,config=True,
149 iopub = Tuple(Integer,Integer,config=True,
150 help="""Client/Engine Port pair for IOPub relay""")
150 help="""Client/Engine Port pair for IOPub relay""")
151
151
152 def _iopub_default(self):
152 def _iopub_default(self):
153 return tuple(util.select_random_ports(2))
153 return tuple(util.select_random_ports(2))
154
154
155 # single ports:
155 # single ports:
156 mon_port = Integer(config=True,
156 mon_port = Integer(config=True,
157 help="""Monitor (SUB) port for queue traffic""")
157 help="""Monitor (SUB) port for queue traffic""")
158
158
159 def _mon_port_default(self):
159 def _mon_port_default(self):
160 return util.select_random_ports(1)[0]
160 return util.select_random_ports(1)[0]
161
161
162 notifier_port = Integer(config=True,
162 notifier_port = Integer(config=True,
163 help="""PUB port for sending engine status notifications""")
163 help="""PUB port for sending engine status notifications""")
164
164
165 def _notifier_port_default(self):
165 def _notifier_port_default(self):
166 return util.select_random_ports(1)[0]
166 return util.select_random_ports(1)[0]
167
167
168 engine_ip = Unicode(config=True,
168 engine_ip = Unicode(config=True,
169 help="IP on which to listen for engine connections. [default: loopback]")
169 help="IP on which to listen for engine connections. [default: loopback]")
170 def _engine_ip_default(self):
170 def _engine_ip_default(self):
171 return localhost()
171 return localhost()
172 engine_transport = Unicode('tcp', config=True,
172 engine_transport = Unicode('tcp', config=True,
173 help="0MQ transport for engine connections. [default: tcp]")
173 help="0MQ transport for engine connections. [default: tcp]")
174
174
175 client_ip = Unicode(config=True,
175 client_ip = Unicode(config=True,
176 help="IP on which to listen for client connections. [default: loopback]")
176 help="IP on which to listen for client connections. [default: loopback]")
177 client_transport = Unicode('tcp', config=True,
177 client_transport = Unicode('tcp', config=True,
178 help="0MQ transport for client connections. [default : tcp]")
178 help="0MQ transport for client connections. [default : tcp]")
179
179
180 monitor_ip = Unicode(config=True,
180 monitor_ip = Unicode(config=True,
181 help="IP on which to listen for monitor messages. [default: loopback]")
181 help="IP on which to listen for monitor messages. [default: loopback]")
182 monitor_transport = Unicode('tcp', config=True,
182 monitor_transport = Unicode('tcp', config=True,
183 help="0MQ transport for monitor messages. [default : tcp]")
183 help="0MQ transport for monitor messages. [default : tcp]")
184
184
185 _client_ip_default = _monitor_ip_default = _engine_ip_default
185 _client_ip_default = _monitor_ip_default = _engine_ip_default
186
186
187
187
188 monitor_url = Unicode('')
188 monitor_url = Unicode('')
189
189
190 db_class = DottedObjectName('NoDB',
190 db_class = DottedObjectName('NoDB',
191 config=True, help="""The class to use for the DB backend
191 config=True, help="""The class to use for the DB backend
192
192
193 Options include:
193 Options include:
194
194
195 SQLiteDB: SQLite
195 SQLiteDB: SQLite
196 MongoDB : use MongoDB
196 MongoDB : use MongoDB
197 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
197 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
198 NoDB : disable database altogether (default)
198 NoDB : disable database altogether (default)
199
199
200 """)
200 """)
201
201
202 registration_timeout = Integer(0, config=True,
202 registration_timeout = Integer(0, config=True,
203 help="Engine registration timeout in seconds [default: max(30,"
203 help="Engine registration timeout in seconds [default: max(30,"
204 "10*heartmonitor.period)]" )
204 "10*heartmonitor.period)]" )
205
205
206 def _registration_timeout_default(self):
206 def _registration_timeout_default(self):
207 if self.heartmonitor is None:
207 if self.heartmonitor is None:
208 # early initialization, this value will be ignored
208 # early initialization, this value will be ignored
209 return 0
209 return 0
210 # heartmonitor period is in milliseconds, so 10x in seconds is .01
210 # heartmonitor period is in milliseconds, so 10x in seconds is .01
211 return max(30, int(.01 * self.heartmonitor.period))
211 return max(30, int(.01 * self.heartmonitor.period))
212
212
213 # not configurable
213 # not configurable
214 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
214 db = Instance('ipython_parallel.controller.dictdb.BaseDB')
215 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
215 heartmonitor = Instance('ipython_parallel.controller.heartmonitor.HeartMonitor')
216
216
217 def _ip_changed(self, name, old, new):
217 def _ip_changed(self, name, old, new):
218 self.engine_ip = new
218 self.engine_ip = new
219 self.client_ip = new
219 self.client_ip = new
220 self.monitor_ip = new
220 self.monitor_ip = new
221 self._update_monitor_url()
221 self._update_monitor_url()
222
222
223 def _update_monitor_url(self):
223 def _update_monitor_url(self):
224 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
224 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
225
225
226 def _transport_changed(self, name, old, new):
226 def _transport_changed(self, name, old, new):
227 self.engine_transport = new
227 self.engine_transport = new
228 self.client_transport = new
228 self.client_transport = new
229 self.monitor_transport = new
229 self.monitor_transport = new
230 self._update_monitor_url()
230 self._update_monitor_url()
231
231
232 def __init__(self, **kwargs):
232 def __init__(self, **kwargs):
233 super(HubFactory, self).__init__(**kwargs)
233 super(HubFactory, self).__init__(**kwargs)
234 self._update_monitor_url()
234 self._update_monitor_url()
235
235
236
236
237 def construct(self):
237 def construct(self):
238 self.init_hub()
238 self.init_hub()
239
239
240 def start(self):
240 def start(self):
241 self.heartmonitor.start()
241 self.heartmonitor.start()
242 self.log.info("Heartmonitor started")
242 self.log.info("Heartmonitor started")
243
243
244 def client_url(self, channel):
244 def client_url(self, channel):
245 """return full zmq url for a named client channel"""
245 """return full zmq url for a named client channel"""
246 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
246 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
247
247
248 def engine_url(self, channel):
248 def engine_url(self, channel):
249 """return full zmq url for a named engine channel"""
249 """return full zmq url for a named engine channel"""
250 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
250 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
251
251
252 def init_hub(self):
252 def init_hub(self):
253 """construct Hub object"""
253 """construct Hub object"""
254
254
255 ctx = self.context
255 ctx = self.context
256 loop = self.loop
256 loop = self.loop
257 if 'TaskScheduler.scheme_name' in self.config:
257 if 'TaskScheduler.scheme_name' in self.config:
258 scheme = self.config.TaskScheduler.scheme_name
258 scheme = self.config.TaskScheduler.scheme_name
259 else:
259 else:
260 from .scheduler import TaskScheduler
260 from .scheduler import TaskScheduler
261 scheme = TaskScheduler.scheme_name.get_default_value()
261 scheme = TaskScheduler.scheme_name.get_default_value()
262
262
263 # build connection dicts
263 # build connection dicts
264 engine = self.engine_info = {
264 engine = self.engine_info = {
265 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
265 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
266 'registration' : self.regport,
266 'registration' : self.regport,
267 'control' : self.control[1],
267 'control' : self.control[1],
268 'mux' : self.mux[1],
268 'mux' : self.mux[1],
269 'hb_ping' : self.hb[0],
269 'hb_ping' : self.hb[0],
270 'hb_pong' : self.hb[1],
270 'hb_pong' : self.hb[1],
271 'task' : self.task[1],
271 'task' : self.task[1],
272 'iopub' : self.iopub[1],
272 'iopub' : self.iopub[1],
273 }
273 }
274
274
275 client = self.client_info = {
275 client = self.client_info = {
276 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
276 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
277 'registration' : self.regport,
277 'registration' : self.regport,
278 'control' : self.control[0],
278 'control' : self.control[0],
279 'mux' : self.mux[0],
279 'mux' : self.mux[0],
280 'task' : self.task[0],
280 'task' : self.task[0],
281 'task_scheme' : scheme,
281 'task_scheme' : scheme,
282 'iopub' : self.iopub[0],
282 'iopub' : self.iopub[0],
283 'notification' : self.notifier_port,
283 'notification' : self.notifier_port,
284 }
284 }
285
285
286 self.log.debug("Hub engine addrs: %s", self.engine_info)
286 self.log.debug("Hub engine addrs: %s", self.engine_info)
287 self.log.debug("Hub client addrs: %s", self.client_info)
287 self.log.debug("Hub client addrs: %s", self.client_info)
288
288
289 # Registrar socket
289 # Registrar socket
290 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
290 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
291 util.set_hwm(q, 0)
291 util.set_hwm(q, 0)
292 q.bind(self.client_url('registration'))
292 q.bind(self.client_url('registration'))
293 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
293 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
294 if self.client_ip != self.engine_ip:
294 if self.client_ip != self.engine_ip:
295 q.bind(self.engine_url('registration'))
295 q.bind(self.engine_url('registration'))
296 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
296 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
297
297
298 ### Engine connections ###
298 ### Engine connections ###
299
299
300 # heartbeat
300 # heartbeat
301 hpub = ctx.socket(zmq.PUB)
301 hpub = ctx.socket(zmq.PUB)
302 hpub.bind(self.engine_url('hb_ping'))
302 hpub.bind(self.engine_url('hb_ping'))
303 hrep = ctx.socket(zmq.ROUTER)
303 hrep = ctx.socket(zmq.ROUTER)
304 util.set_hwm(hrep, 0)
304 util.set_hwm(hrep, 0)
305 hrep.bind(self.engine_url('hb_pong'))
305 hrep.bind(self.engine_url('hb_pong'))
306 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
306 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
307 pingstream=ZMQStream(hpub,loop),
307 pingstream=ZMQStream(hpub,loop),
308 pongstream=ZMQStream(hrep,loop)
308 pongstream=ZMQStream(hrep,loop)
309 )
309 )
310
310
311 ### Client connections ###
311 ### Client connections ###
312
312
313 # Notifier socket
313 # Notifier socket
314 n = ZMQStream(ctx.socket(zmq.PUB), loop)
314 n = ZMQStream(ctx.socket(zmq.PUB), loop)
315 n.bind(self.client_url('notification'))
315 n.bind(self.client_url('notification'))
316
316
317 ### build and launch the queues ###
317 ### build and launch the queues ###
318
318
319 # monitor socket
319 # monitor socket
320 sub = ctx.socket(zmq.SUB)
320 sub = ctx.socket(zmq.SUB)
321 sub.setsockopt(zmq.SUBSCRIBE, b"")
321 sub.setsockopt(zmq.SUBSCRIBE, b"")
322 sub.bind(self.monitor_url)
322 sub.bind(self.monitor_url)
323 sub.bind('inproc://monitor')
323 sub.bind('inproc://monitor')
324 sub = ZMQStream(sub, loop)
324 sub = ZMQStream(sub, loop)
325
325
326 # connect the db
326 # connect the db
327 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
327 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
328 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
328 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
329 self.db = import_item(str(db_class))(session=self.session.session,
329 self.db = import_item(str(db_class))(session=self.session.session,
330 parent=self, log=self.log)
330 parent=self, log=self.log)
331 time.sleep(.25)
331 time.sleep(.25)
332
332
333 # resubmit stream
333 # resubmit stream
334 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
334 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
335 url = util.disambiguate_url(self.client_url('task'))
335 url = util.disambiguate_url(self.client_url('task'))
336 r.connect(url)
336 r.connect(url)
337
337
338 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
338 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
339 query=q, notifier=n, resubmit=r, db=self.db,
339 query=q, notifier=n, resubmit=r, db=self.db,
340 engine_info=self.engine_info, client_info=self.client_info,
340 engine_info=self.engine_info, client_info=self.client_info,
341 log=self.log, registration_timeout=self.registration_timeout)
341 log=self.log, registration_timeout=self.registration_timeout)
342
342
343
343
344 class Hub(SessionFactory):
344 class Hub(SessionFactory):
345 """The IPython Controller Hub with 0MQ connections
345 """The IPython Controller Hub with 0MQ connections
346
346
347 Parameters
347 Parameters
348 ==========
348 ==========
349 loop: zmq IOLoop instance
349 loop: zmq IOLoop instance
350 session: Session object
350 session: Session object
351 <removed> context: zmq context for creating new connections (?)
351 <removed> context: zmq context for creating new connections (?)
352 queue: ZMQStream for monitoring the command queue (SUB)
352 queue: ZMQStream for monitoring the command queue (SUB)
353 query: ZMQStream for engine registration and client queries requests (ROUTER)
353 query: ZMQStream for engine registration and client queries requests (ROUTER)
354 heartbeat: HeartMonitor object checking the pulse of the engines
354 heartbeat: HeartMonitor object checking the pulse of the engines
355 notifier: ZMQStream for broadcasting engine registration changes (PUB)
355 notifier: ZMQStream for broadcasting engine registration changes (PUB)
356 db: connection to db for out of memory logging of commands
356 db: connection to db for out of memory logging of commands
357 NotImplemented
357 NotImplemented
358 engine_info: dict of zmq connection information for engines to connect
358 engine_info: dict of zmq connection information for engines to connect
359 to the queues.
359 to the queues.
360 client_info: dict of zmq connection information for clients to connect
360 client_info: dict of zmq connection information for clients to connect
361 to the queues.
361 to the queues.
362 """
362 """
363
363
364 engine_state_file = Unicode()
364 engine_state_file = Unicode()
365
365
366 # internal data structures:
366 # internal data structures:
367 ids=Set() # engine IDs
367 ids=Set() # engine IDs
368 keytable=Dict()
368 keytable=Dict()
369 by_ident=Dict()
369 by_ident=Dict()
370 engines=Dict()
370 engines=Dict()
371 clients=Dict()
371 clients=Dict()
372 hearts=Dict()
372 hearts=Dict()
373 pending=Set()
373 pending=Set()
374 queues=Dict() # pending msg_ids keyed by engine_id
374 queues=Dict() # pending msg_ids keyed by engine_id
375 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
375 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
376 completed=Dict() # completed msg_ids keyed by engine_id
376 completed=Dict() # completed msg_ids keyed by engine_id
377 all_completed=Set() # completed msg_ids keyed by engine_id
377 all_completed=Set() # completed msg_ids keyed by engine_id
378 dead_engines=Set() # completed msg_ids keyed by engine_id
378 dead_engines=Set() # completed msg_ids keyed by engine_id
379 unassigned=Set() # set of task msg_ds not yet assigned a destination
379 unassigned=Set() # set of task msg_ds not yet assigned a destination
380 incoming_registrations=Dict()
380 incoming_registrations=Dict()
381 registration_timeout=Integer()
381 registration_timeout=Integer()
382 _idcounter=Integer(0)
382 _idcounter=Integer(0)
383
383
384 # objects from constructor:
384 # objects from constructor:
385 query=Instance(ZMQStream)
385 query=Instance(ZMQStream)
386 monitor=Instance(ZMQStream)
386 monitor=Instance(ZMQStream)
387 notifier=Instance(ZMQStream)
387 notifier=Instance(ZMQStream)
388 resubmit=Instance(ZMQStream)
388 resubmit=Instance(ZMQStream)
389 heartmonitor=Instance(HeartMonitor)
389 heartmonitor=Instance(HeartMonitor)
390 db=Instance(object)
390 db=Instance(object)
391 client_info=Dict()
391 client_info=Dict()
392 engine_info=Dict()
392 engine_info=Dict()
393
393
394
394
def __init__(self, **kwargs):
    """Build the Hub and hook its handlers up to the incoming streams.

    Every keyword argument is forwarded unchanged to the parent
    constructor.  The attributes this method relies on afterwards are:

    query: stream whose messages are routed to ``dispatch_query``
    monitor: stream whose messages are routed to
        ``dispatch_monitor_traffic``
    heartmonitor: object that reports new and failed hearts
    resubmit: stream whose replies are received and thrown away
    log: logger used for the creation notice
    """
    super(Hub, self).__init__(**kwargs)

    # Route incoming traffic on the two main streams to the dispatchers.
    stream_callbacks = (
        (self.query, self.dispatch_query),
        (self.monitor, self.dispatch_monitor_traffic),
    )
    for stream, callback in stream_callbacks:
        stream.on_recv(callback)

    # Heartbeat events drive unregistration / completion of registration.
    heart_monitor = self.heartmonitor
    heart_monitor.add_heart_failure_handler(self.handle_heart_failure)
    heart_monitor.add_new_heart_handler(self.handle_new_heart)

    # Monitor topics (bytes) -> handler; control traffic is not recorded.
    monitor_table = {}
    monitor_table[b'in'] = self.save_queue_request
    monitor_table[b'out'] = self.save_queue_result
    monitor_table[b'intask'] = self.save_task_request
    monitor_table[b'outtask'] = self.save_task_result
    monitor_table[b'tracktask'] = self.save_task_destination
    monitor_table[b'incontrol'] = _passer
    monitor_table[b'outcontrol'] = _passer
    monitor_table[b'iopub'] = self.save_iopub_message
    self.monitor_handlers = monitor_table

    # Query message types (text) -> handler.
    query_table = {}
    query_table['queue_request'] = self.queue_status
    query_table['result_request'] = self.get_results
    query_table['history_request'] = self.get_history
    query_table['db_request'] = self.db_query
    query_table['purge_request'] = self.purge_results
    query_table['load_request'] = self.check_load
    query_table['resubmit_request'] = self.resubmit_task
    query_table['shutdown_request'] = self.shutdown_request
    query_table['registration_request'] = self.register_engine
    query_table['unregistration_request'] = self.unregister_engine
    query_table['connection_request'] = self.connection_request
    self.query_handlers = query_table

    def ignore_reply(msg):
        # replies to resubmitted tasks are of no interest to the hub
        return None

    self.resubmit.on_recv(ignore_reply, copy=False)

    self.log.info("hub::created hub")
446
446
@property
def _next_id(self):
    """Generate a new ID.

    IDs are never reused: each access hands out the current value of
    ``_idcounter`` and advances the counter by one.  (The earlier scheme,
    which searched for the lowest id not currently in use, is gone.)
    """
    allocated = self._idcounter
    self._idcounter = allocated + 1
    return allocated
461
461
462 #-----------------------------------------------------------------------------
462 #-----------------------------------------------------------------------------
463 # message validation
463 # message validation
464 #-----------------------------------------------------------------------------
464 #-----------------------------------------------------------------------------
465
465
def _validate_targets(self, targets):
    """Normalize a ``targets`` argument into a collection of engine ids.

    targets: ``None`` (meaning every id in ``self.ids``), a single int or
        string, or an iterable of them.  Strings are looked up in
        ``self.by_ident`` (as bytes) and replaced by the mapped id when
        one is found.

    Returns ``self.ids`` itself when *targets* is None, otherwise a new
    list of resolved targets.

    Raises IndexError if any resolved target is not in ``self.ids`` or if
    the resulting list is empty.
    """
    if targets is None:
        # default to all
        return self.ids

    text_types = (str, unicode_type)
    if isinstance(targets, (int,) + text_types):
        # a lone target was given; treat it as a one-element list
        targets = [targets]

    def resolve(target):
        # raw identities are mapped to ids; anything else passes through
        if isinstance(target, text_types):
            return self.by_ident.get(cast_bytes(target), target)
        return target

    resolved = [resolve(target) for target in targets]
    unknown = [target for target in resolved if target not in self.ids]
    if unknown:
        raise IndexError("No Such Engine: %r" % unknown)
    if not resolved:
        raise IndexError("No Engines Registered")
    return resolved
488
488
489 #-----------------------------------------------------------------------------
489 #-----------------------------------------------------------------------------
490 # dispatch methods (1 per stream)
490 # dispatch methods (1 per stream)
491 #-----------------------------------------------------------------------------
491 #-----------------------------------------------------------------------------
492
492
493
493
@util.log_errors
def dispatch_monitor_traffic(self, msg):
    """Dispatch one message received on the monitor stream.

    msg: list of message parts; the first part is the topic used to pick
        a handler from ``self.monitor_handlers``, the rest is split into
        identities and payload by the session.

    Messages without identities and messages with an unknown topic are
    logged as errors and dropped.
    """
    topic = msg[0]
    self.log.debug("monitor traffic: %r", topic)
    try:
        idents, msg = self.session.feed_identities(msg[1:])
    except ValueError:
        # msg is left untouched here, so the error below shows all parts
        idents = []
    if not idents:
        self.log.error("Monitor message without topic: %r", msg)
        return

    handler = self.monitor_handlers.get(topic)
    if handler is None:
        self.log.error("Unrecognized monitor topic: %r", topic)
        return
    handler(idents, msg)
512
512
513
513
@util.log_errors
def dispatch_query(self, msg):
    """Route registration requests and queries from clients.

    msg: list of raw message parts received on the query stream.

    The message is split into identities and payload, deserialized, and
    handed to the entry of ``self.query_handlers`` matching its
    ``msg_type``.  A message without identities is logged and dropped.
    A message that cannot be deserialized, or whose type has no handler,
    is logged and answered with a ``hub_error`` reply addressed to the
    first identity.
    """
    try:
        idents, msg = self.session.feed_identities(msg)
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Bad Query Message: %r", msg)
        return
    client_id = idents[0]
    try:
        msg = self.session.deserialize(msg, content=True)
    except Exception:
        content = error.wrap_exception()
        self.log.error("Bad Query Message: %r", msg, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                          content=content)
        return

    # switch on message type:
    msg_type = msg['header']['msg_type']
    self.log.info("client::client %r requested %r", client_id, msg_type)
    handler = self.query_handlers.get(msg_type, None)
    try:
        # Raise explicitly instead of using ``assert``: assert statements
        # are stripped under ``python -O``, which would let a None handler
        # be called below.  AssertionError is kept so the wrapped error
        # reported to the client is the same as before.
        if handler is None:
            raise AssertionError("Bad Message Type: %r" % (msg_type,))
    except AssertionError:
        content = error.wrap_exception()
        self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                          content=content)
        return

    handler(idents, msg)
549
549
def dispatch_db(self, msg):
    """Placeholder for dispatching DB messages; always raises
    NotImplementedError."""
    raise NotImplementedError()
553
553
554 #---------------------------------------------------------------------------
554 #---------------------------------------------------------------------------
555 # handler methods (1 per event)
555 # handler methods (1 per event)
556 #---------------------------------------------------------------------------
556 #---------------------------------------------------------------------------
557
557
558 #----------------------- Heartbeat --------------------------------------
558 #----------------------- Heartbeat --------------------------------------
559
559
def handle_new_heart(self, heart):
    """New-heart callback for the heartbeater.

    Called when a new heart starts to beat.  If the heart belongs to a
    registration that is still waiting in ``incoming_registrations``,
    that registration is completed; any other heart is ignored.
    """
    self.log.debug("heartbeat::handle_new_heart(%r)", heart)
    if heart in self.incoming_registrations:
        self.finish_registration(heart)
        return
    self.log.info("heartbeat::ignoring new heart: %r", heart)
569
569
570
570
def handle_heart_failure(self, heart):
    """Heart-failure callback for the heartbeater.

    Called when a previously registered heart fails to respond to a beat
    request.  If the heart maps to an engine that is not yet listed in
    ``dead_engines``, that engine is unregistered; a heart that maps to
    no engine, or to an already-dead one, is only logged.

    heart: identity of the heart that stopped responding
    """
    self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
    eid = self.hearts.get(heart, None)
    if eid is None or self.keytable[eid] in self.dead_engines:
        self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
        return

    # Look the engine up only after the guard above.  Doing it first
    # raised KeyError for unknown hearts (eid is None), so the "ignore"
    # branch could never be reached for them.
    uuid = self.engines[eid].uuid
    self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
582
582
583 #----------------------- MUX Queue Traffic ------------------------------
583 #----------------------- MUX Queue Traffic ------------------------------
584
584
def save_queue_request(self, idents, msg):
    """Record a request that a client submitted to an engine's MUX queue.

    idents: identity prefix; its first two entries are read as
        (engine queue identity, client identity)
    msg: the still-serialized message parts

    The record is stored in the DB under the request's msg_id (merged
    with a record that already exists there, if any), and the msg_id is
    added to ``pending`` and to the engine's list in ``queues``.
    Malformed input and unregistered targets are logged and dropped.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return
    engine_ident, sender = idents[0], idents[1]

    try:
        msg = self.session.deserialize(msg)
    except Exception:
        self.log.error("queue::client %r sent invalid message to %r: %r", sender, engine_ident, msg, exc_info=True)
        return

    engine_id = self.by_ident.get(engine_ident)
    if engine_id is None:
        self.log.error("queue::target %r not registered", engine_ident)
        self.log.debug("queue:: valid are: %r", self.by_ident.keys())
        return

    record = init_record(msg)
    msg_id = record['msg_id']
    self.log.info("queue::client %r submitted request %r to %s", sender, msg_id, engine_id)
    # records hold text, not bytes
    record['engine_uuid'] = engine_ident.decode('ascii')
    record['client_uuid'] = msg['header']['session']
    record['queue'] = 'mux'

    try:
        # a record may already exist, e.g. if iopub traffic arrived first
        stored_record = self.db.get_record(msg_id)
        for key, stored in iteritems(stored_record):
            fresh = record.get(key, None)
            if not stored:
                continue
            if not fresh:
                # keep what the DB already knows
                record[key] = stored
            elif stored != fresh:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, fresh, key, stored)
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
    except KeyError:
        # no such record yet: create it
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)

    self.pending.add(msg_id)
    self.queues[engine_id].append(msg_id)
631
631
def save_queue_result(self, idents, msg):
    """Record the reply an engine sent for a MUX queue request.

    idents: identity prefix; its first two entries are read as
        (client identity, engine queue identity)
    msg: the still-serialized message parts

    A reply to a pending request moves the msg_id from ``pending`` /
    ``queues`` to ``all_completed`` / ``completed``.  The result is then
    written to the DB record of the parent msg_id.  Malformed input,
    replies from unregistered engines, replies without a parent, and
    replies to requests that are neither pending nor completed are
    dropped.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return
    recipient, engine_ident = idents[0], idents[1]

    try:
        msg = self.session.deserialize(msg)
    except Exception:
        self.log.error("queue::engine %r sent invalid message to %r: %r",
                       engine_ident, recipient, msg, exc_info=True)
        return

    engine_id = self.by_ident.get(engine_ident)
    if engine_id is None:
        self.log.error("queue::unknown engine %r is sending a reply: ", engine_ident)
        return

    parent_header = msg['parent_header']
    if not parent_header:
        return
    msg_id = parent_header['msg_id']

    if msg_id in self.pending:
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        self.queues[engine_id].remove(msg_id)
        self.completed[engine_id].append(msg_id)
        self.log.info("queue::request %r completed on %s", msg_id, engine_id)
    elif msg_id not in self.all_completed:
        # possibly a result from an engine that died before delivering it
        self.log.warn("queue:: unknown msg finished %r", msg_id)
        return

    # The record is updated even for already-completed requests, because
    # the unregistration could have been premature.
    reply_header = msg['header']
    metadata = msg['metadata']
    finished_at = reply_header['date']
    started_at = extract_dates(metadata.get('started', None))
    result = {
        'result_header': reply_header,
        'result_metadata': metadata,
        'result_content': msg['content'],
        'received': datetime.now(),
        'started': started_at,
        'completed': finished_at,
        'result_buffers': msg['buffers'],
    }
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error updating record %r", msg_id, exc_info=True)
684
684
685
685
686 #--------------------- Task Queue Traffic ------------------------------
686 #--------------------- Task Queue Traffic ------------------------------
687
687
def save_task_request(self, idents, msg):
    """Save the submission of a task.

    idents: identity prefix; the first entry is read as the client
        identity
    msg: the still-serialized message parts

    The task's msg_id is added to ``pending`` and ``unassigned``, and its
    record is written to the DB, merged with a record that already exists
    there, if any.  DB failures are logged, not raised.
    """
    sender = idents[0]

    try:
        msg = self.session.deserialize(msg)
    except Exception:
        self.log.error("task::client %r sent invalid task message: %r",
                       sender, msg, exc_info=True)
        return

    record = init_record(msg)
    header = msg['header']
    record['client_uuid'] = header['session']
    record['queue'] = 'task'
    msg_id = header['msg_id']
    self.pending.add(msg_id)
    self.unassigned.add(msg_id)

    try:
        # a record may already exist, e.g. if iopub traffic arrived first
        stored_record = self.db.get_record(msg_id)
        if stored_record['resubmitted']:
            # On resubmit, leave these keys as stored: submitted and
            # client_uuid are expected to differ, and buffers may be big
            # and should not have changed.  content and header are still
            # compared below, since that is cheap.
            for preserved in ('submitted', 'client_uuid', 'buffers'):
                record.pop(preserved)

        for key, stored in iteritems(stored_record):
            if key.endswith('buffers'):
                # buffers are never compared
                continue
            fresh = record.get(key, None)
            if not stored:
                continue
            if not fresh:
                record[key] = stored
            elif stored != fresh:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, fresh, key, stored)
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
    except KeyError:
        # no such record yet: create it
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)
    except Exception:
        self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
738
738
def save_task_result(self, idents, msg):
    """Save the result of a completed task.

    idents: identity prefix; the first entry is read as the client
        identity
    msg: the still-serialized message parts

    The parent msg_id is removed from ``unassigned``.  If it is pending,
    it is moved to ``all_completed``, the per-engine bookkeeping is
    updated when the reporting engine is known, and the result is written
    to the DB.  Results for tasks that are not pending are only logged.
    """
    recipient = idents[0]
    try:
        msg = self.session.deserialize(msg)
    except Exception:
        self.log.error("task::invalid task result message send to %r: %r",
                       recipient, msg, exc_info=True)
        return

    parent_header = msg['parent_header']
    if not parent_header:
        self.log.warn("Task %r had no parent!", msg)
        return
    msg_id = parent_header['msg_id']
    self.unassigned.discard(msg_id)

    header = msg['header']
    metadata = msg['metadata']
    engine_uuid = metadata.get('engine', u'')
    engine_id = self.by_ident.get(cast_bytes(engine_uuid))
    status = metadata.get('status', None)

    if msg_id not in self.pending:
        self.log.debug("task::unknown task %r finished", msg_id)
        return

    self.log.info("task::task %r finished on %s", msg_id, engine_id)
    self.pending.remove(msg_id)
    self.all_completed.add(msg_id)
    if engine_id is not None:
        # aborted tasks are not counted as completed on the engine
        if status != 'aborted':
            self.completed[engine_id].append(msg_id)
        engine_tasks = self.tasks[engine_id]
        if msg_id in engine_tasks:
            engine_tasks.remove(msg_id)

    finished_at = header['date']
    started_at = extract_dates(metadata.get('started', None))
    result = {
        'result_header': header,
        'result_metadata': msg['metadata'],
        'result_content': msg['content'],
        'started': started_at,
        'completed': finished_at,
        'received': datetime.now(),
        'engine_uuid': engine_uuid,
        'result_buffers': msg['buffers'],
    }
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
794
794
def save_task_destination(self, idents, msg):
    """Record which engine a task has been handed to.

    The tracking message is deserialized; the task is dropped from
    ``self.unassigned`` (if listed there), appended to the destination
    engine's entry in ``self.tasks``, and the engine uuid is written to
    the task's DB record.  ``idents`` is not used.
    """
    try:
        msg = self.session.deserialize(msg, content=True)
    except Exception:
        self.log.error("task::invalid task tracking message", exc_info=True)
        return

    info = msg['content']
    task_id = info['msg_id']
    engine_uuid = info['engine_id']
    # by_ident is keyed by the bytes form of the engine uuid
    engine_id = self.by_ident[cast_bytes(engine_uuid)]

    self.log.info("task::task %r arrived on %r", task_id, engine_id)
    if task_id in self.unassigned:
        self.unassigned.remove(task_id)

    self.tasks[engine_id].append(task_id)
    try:
        self.db.update_record(task_id, {'engine_uuid': engine_uuid})
    except Exception:
        self.log.error("DB Error saving task destination %r", task_id, exc_info=True)
819
819
820
820
def mia_task_request(self, idents, msg):
    """Placeholder handler for a "missing in action" task query.

    Not implemented: always raises ``NotImplementedError``.  Both
    arguments are accepted only so the handler signature matches.
    """
    raise NotImplementedError
826
826
827
827
828 #--------------------- IOPub Traffic ------------------------------
828 #--------------------- IOPub Traffic ------------------------------
829
829
def save_iopub_message(self, topics, msg):
    """Save an IOPub message into the db.

    Stream text is appended to whatever the task's record already holds
    under the stream's name.  ``error``, ``display_data`` and
    ``execute_result`` contents are stored under a key of the same name,
    and the ``code`` of an ``execute_input`` message is stored under
    ``execute_input``.  ``status`` and ``data_pub`` messages are not
    stored.  When the parent ``msg_id`` has no record yet, a fresh record
    is added instead of updated.

    ``topics`` is accepted for the handler signature but is not used.
    """
    try:
        msg = self.session.deserialize(msg, content=True)
    except Exception:
        self.log.error("iopub::invalid IOPub message", exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.debug("iopub::IOPub message lacks parent: %r", msg)
        return
    msg_id = parent['msg_id']
    msg_type = msg['header']['msg_type']
    content = msg['content']

    # ensure msg_id is in db
    try:
        rec = self.db.get_record(msg_id)
    except KeyError:
        rec = None

    # fields to write for this message; left empty when nothing is stored
    d = {}
    if msg_type == 'stream':
        name = content['name']
        s = '' if rec is None else rec[name]
        d[name] = s + content['text']
    elif msg_type == 'error':
        d['error'] = content
    elif msg_type == 'execute_input':
        d['execute_input'] = content['code']
    elif msg_type in ('display_data', 'execute_result'):
        d[msg_type] = content
    elif msg_type == 'status':
        pass
    elif msg_type == 'data_pub':
        # pass the argument to the logger instead of pre-formatting it
        self.log.info("ignored data_pub message for %s", msg_id)
    else:
        # Logger.warn is a deprecated alias of Logger.warning
        self.log.warning("unhandled iopub msg_type: %r", msg_type)

    if not d:
        return

    if rec is None:
        # new record
        rec = empty_record()
        rec['msg_id'] = msg_id
        rec.update(d)
        d = rec
        update_record = self.db.add_record
    else:
        update_record = self.db.update_record

    try:
        update_record(msg_id, d)
    except Exception:
        self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
890
890
891
891
892
892
893 #-------------------------------------------------------------------------
893 #-------------------------------------------------------------------------
894 # Registration requests
894 # Registration requests
895 #-------------------------------------------------------------------------
895 #-------------------------------------------------------------------------
896
896
def connection_request(self, client_id, msg):
    """Answer a client's connection request with the live engine table.

    The reply maps each engine id (as a string) from ``self.keytable`` to
    its uuid, leaving out uuids listed in ``self.dead_engines``.
    """
    self.log.info("client::client %r connected", client_id)
    live_engines = {
        str(eid): uuid
        for eid, uuid in iteritems(self.keytable)
        if uuid not in self.dead_engines
    }
    reply = {'status': 'ok', 'engines': live_engines}
    self.session.send(self.query, 'connection_reply', reply, parent=msg, ident=client_id)
907
907
def register_engine(self, reg, msg):
    """Register a new engine.

    A ``registration_reply`` is sent to ``reg`` carrying either the newly
    assigned engine id or a wrapped ``KeyError`` when the requested uuid
    is already registered or already waiting for its first heartbeat.
    On success the registration is completed immediately if the engine's
    heart is already beating; otherwise it is parked in
    ``self.incoming_registrations`` with a timeout that purges it.

    Returns the engine id that was allocated, or ``None`` when the
    message carries no ``uuid``.
    """
    content = msg['content']
    try:
        uuid = content['uuid']
    except KeyError:
        self.log.error("registration::queue not specified", exc_info=True)
        return

    eid = self._next_id

    self.log.debug("registration::register_engine(%i, %r)", eid, uuid)

    # by_ident, incoming_registrations and the heart monitor are all keyed
    # by the bytes form of the uuid
    heart = cast_bytes(uuid)

    content = dict(id=eid, status='ok', hb_period=self.heartmonitor.period)
    # Each failure raises and immediately catches a KeyError so that
    # error.wrap_exception() has an active exception to wrap.
    if heart in self.by_ident:
        try:
            raise KeyError("uuid %r in use" % uuid)
        except KeyError:
            content = error.wrap_exception()
            self.log.error("uuid %r in use", uuid, exc_info=True)
    else:
        for h, ec in iteritems(self.incoming_registrations):
            # NOTE(review): keys of incoming_registrations are stored as
            # bytes below, so on Python 3 this str comparison presumably
            # never matches and the elif reports the duplicate - confirm.
            if uuid == h:
                try:
                    raise KeyError("heart_id %r in use" % uuid)
                except KeyError:
                    self.log.error("heart_id %r in use", uuid, exc_info=True)
                    content = error.wrap_exception()
                break
            elif uuid == ec.uuid:
                try:
                    raise KeyError("uuid %r in use" % uuid)
                except KeyError:
                    self.log.error("uuid %r in use", uuid, exc_info=True)
                    content = error.wrap_exception()
                break

    self.session.send(self.query, "registration_reply",
            content=content,
            ident=reg)

    if content['status'] == 'ok':
        if heart in self.heartmonitor.hearts:
            # already beating
            self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid)
            self.finish_registration(heart)
        else:
            # not beating yet: park the registration and purge it if it
            # has not completed after registration_timeout
            t = self.loop.add_timeout(
                self.loop.time() + self.registration_timeout,
                lambda: self._purge_stalled_registration(heart),
            )
            self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid, stallback=t)
    else:
        self.log.error("registration::registration %i failed: %r", eid, content['evalue'])

    return eid
968
968
def unregister_engine(self, ident, msg):
    """Unregister an engine that explicitly requested to leave.

    The engine's uuid is added to ``self.dead_engines``, a delayed call to
    ``_handle_stranded_msgs`` is scheduled, the engine state is saved and,
    if a notifier is configured, an ``unregistration_notification`` is
    sent.  A message with no usable engine id, or with an id that is not
    in ``self.keytable``, is logged and ignored.
    """
    try:
        eid = msg['content']['id']
    except Exception:
        self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
        return
    self.log.info("registration::unregister_engine(%r)", eid)

    try:
        uuid = self.keytable[eid]
    except (KeyError, TypeError):
        # unknown (or unhashable) engine id: nothing to unregister
        self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
        return
    content = dict(id=eid, uuid=uuid)
    self.dead_engines.add(uuid)

    # deal with the engine's outstanding messages after registration_timeout
    self.loop.add_timeout(
        self.loop.time() + self.registration_timeout,
        lambda: self._handle_stranded_msgs(eid, uuid),
    )
    ############## TODO: HANDLE IT ################

    self._save_engine_state()

    if self.notifier:
        self.session.send(self.notifier, "unregistration_notification", content=content)
992
992
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    Every msg_id still listed in ``self.queues[eid]`` is marked completed
    and its DB record is given a wrapped ``EngineError`` as the result.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    that the result failed and later receive the actual result.
    """

    outstanding = self.queues[eid]

    for msg_id in outstanding:
        # tolerate a msg_id that is no longer pending, mirroring the
        # membership check save_task_result performs before removing
        if msg_id in self.pending:
            self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        # raise and catch so error.wrap_exception() has an active
        # exception to wrap
        try:
            raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
        except Exception:
            content = error.wrap_exception()
        # build a fake header:
        header = {}
        header['engine'] = uuid
        header['date'] = datetime.now()
        rec = dict(result_content=content, result_header=header, result_buffers=[])
        rec['completed'] = header['date']
        rec['engine_uuid'] = uuid
        try:
            self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1021
1021
1022
1022
def finish_registration(self, heart):
    """Complete an engine's registration once its heart has been seen.

    Pops the pending registration for ``heart``, cancels its stall
    timeout if one was set, creates the engine's bookkeeping entries,
    sends a ``registration_notification`` when a notifier is configured,
    and saves the engine state.
    """
    try:
        connector = self.incoming_registrations.pop(heart)
    except KeyError:
        self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
        return
    self.log.info("registration::finished registering engine %i:%s", connector.id, connector.uuid)

    # the registration made it in time; cancel the pending purge
    if connector.stallback is not None:
        self.loop.remove_timeout(connector.stallback)

    engine_id = connector.id
    self.ids.add(engine_id)
    self.keytable[engine_id] = connector.uuid
    self.engines[engine_id] = connector
    self.by_ident[cast_bytes(connector.uuid)] = connector.id

    # fresh, empty bookkeeping lists for the new engine
    self.queues[engine_id] = []
    self.tasks[engine_id] = []
    self.completed[engine_id] = []
    self.hearts[heart] = engine_id

    notification = {'id': engine_id, 'uuid': self.engines[engine_id].uuid}
    if self.notifier:
        self.session.send(self.notifier, "registration_notification", content=notification)
    self.log.info("engine::Engine Connected: %i", engine_id)

    self._save_engine_state()
1049
1049
1050 def _purge_stalled_registration(self, heart):
1050 def _purge_stalled_registration(self, heart):
1051 if heart in self.incoming_registrations:
1051 if heart in self.incoming_registrations:
1052 ec = self.incoming_registrations.pop(heart)
1052 ec = self.incoming_registrations.pop(heart)
1053 self.log.info("registration::purging stalled registration: %i", ec.id)
1053 self.log.info("registration::purging stalled registration: %i", ec.id)
1054 else:
1054 else:
1055 pass
1055 pass
1056
1056
1057 #-------------------------------------------------------------------------
1057 #-------------------------------------------------------------------------
1058 # Engine State
1058 # Engine State
1059 #-------------------------------------------------------------------------
1059 #-------------------------------------------------------------------------
1060
1060
1061
1061
1062 def _cleanup_engine_state_file(self):
1062 def _cleanup_engine_state_file(self):
1063 """cleanup engine state mapping"""
1063 """cleanup engine state mapping"""
1064
1064
1065 if os.path.exists(self.engine_state_file):
1065 if os.path.exists(self.engine_state_file):
1066 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1066 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1067 try:
1067 try:
1068 os.remove(self.engine_state_file)
1068 os.remove(self.engine_state_file)
1069 except IOError:
1069 except IOError:
1070 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1070 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1071
1071
1072
1072
def _save_engine_state(self):
    """Save the engine mapping to a JSON file.

    Writes the id -> uuid mapping of every engine whose uuid is not in
    ``self.dead_engines``, together with the current id counter, to
    ``self.engine_state_file``.  Does nothing when no state file is set.
    """
    if not self.engine_state_file:
        return
    self.log.debug("save engine state to %s", self.engine_state_file)

    engines = {
        eid: ec.uuid
        for eid, ec in iteritems(self.engines)
        if ec.uuid not in self.dead_engines
    }
    state = {
        'engines': engines,
        'next_id': self._idcounter,
    }

    with open(self.engine_state_file, 'w') as f:
        json.dump(state, f)
1090
1090
1091
1091
def _load_engine_state(self):
    """Load the engine mapping from the JSON state file.

    Each saved engine is replayed through ``finish_registration`` with its
    heart pre-marked as beating.  The notifier is set to ``None`` while
    replaying and restored afterwards, even if a registration raises.
    Finally the id counter is restored from the file.  Does nothing when
    the state file does not exist.
    """
    if not os.path.exists(self.engine_state_file):
        return

    self.log.info("loading engine state from %s", self.engine_state_file)

    with open(self.engine_state_file) as f:
        state = json.load(f)

    save_notifier = self.notifier
    self.notifier = None
    try:
        for eid, uuid in iteritems(state['engines']):
            heart = uuid.encode('ascii')
            # start with this heart as current and beating:
            self.heartmonitor.responses.add(heart)
            self.heartmonitor.hearts.add(heart)

            # JSON object keys are strings, so the id is converted back
            self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
            self.finish_registration(heart)
    finally:
        # never leave the hub with its notifier switched off
        self.notifier = save_notifier

    self._idcounter = state['next_id']
1116
1116
1117 #-------------------------------------------------------------------------
1117 #-------------------------------------------------------------------------
1118 # Client Requests
1118 # Client Requests
1119 #-------------------------------------------------------------------------
1119 #-------------------------------------------------------------------------
1120
1120
def shutdown_request(self, client_id, msg):
    """Handle a shutdown request.

    Replies to the requesting client, sends a ``shutdown_notice`` on the
    notifier when one is configured, and schedules ``self._shutdown`` one
    second later.  ``msg`` is not used.
    """
    self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
    # also notify other clients of shutdown; the notifier may be unset,
    # so guard it the same way the (un)registration notifications do
    if self.notifier:
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
    self.loop.add_timeout(self.loop.time() + 1, self._shutdown)
1127
1127
1128 def _shutdown(self):
1128 def _shutdown(self):
1129 self.log.info("hub::hub shutting down.")
1129 self.log.info("hub::hub shutting down.")
1130 time.sleep(0.1)
1130 time.sleep(0.1)
1131 sys.exit(0)
1131 sys.exit(0)
1132
1132
1133
1133
def check_load(self, client_id, msg):
    """Reply with the number of outstanding jobs on each requested target.

    For every validated target the ``load_reply`` content maps the
    target, as a string, to the combined length of its queue and task
    lists.  If the targets are missing or fail validation, a
    ``hub_error`` carrying the wrapped exception is sent instead.
    """
    content = msg['content']
    try:
        targets = content['targets']
        targets = self._validate_targets(targets)
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return

    content = dict(status='ok')
    for t in targets:
        # str(), as in queue_status: on Python 3 bytes(t) of an int yields
        # t zero bytes rather than the id as text
        content[str(t)] = len(self.queues[t]) + len(self.tasks[t])
    self.session.send(self.query, "load_reply", content=content, ident=client_id)
1150
1150
1151
1151
def queue_status(self, client_id, msg):
    """Return the Queue status of one or more targets.

    If verbose, return the msg_ids, else return len of each type.

    Keys:

    * queue (pending MUX jobs)
    * tasks (pending Task jobs)
    * completed (finished jobs from both queues)

    The reply also carries an ``unassigned`` entry (list or count).  If
    the targets are missing or fail validation, a ``hub_error`` carrying
    the wrapped exception is sent instead.
    """
    content = msg['content']
    try:
        # looked up inside the try, as check_load does, so a request
        # without 'targets' gets a hub_error reply instead of raising
        targets = content['targets']
        targets = self._validate_targets(targets)
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return
    verbose = content.get('verbose', False)
    content = dict(status='ok')
    for t in targets:
        queue = self.queues[t]
        completed = self.completed[t]
        tasks = self.tasks[t]
        if not verbose:
            # report only the counts
            queue = len(queue)
            completed = len(completed)
            tasks = len(tasks)
        content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
    content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
    self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1186
1186
def purge_results(self, client_id, msg):
    """Purge results from memory. This method is more valuable before we move
    to a DB based message storage mechanism.

    ``msg_ids`` in the request is either the string ``'all'`` (drop every
    record with a non-null ``completed`` field) or a list of msg_ids to
    drop; the request is refused if any listed msg_id is still pending.
    If that step succeeds, completed records of every engine listed in
    ``engine_ids`` are dropped too.  A ``purge_reply`` is always sent,
    carrying either ``status='ok'`` or the wrapped exception.
    """
    content = msg['content']
    self.log.info("Dropping records with %s", content)
    msg_ids = content.get('msg_ids', [])
    reply = dict(status='ok')
    if msg_ids == 'all':
        try:
            self.db.drop_matching_records(dict(completed={'$ne': None}))
        except Exception:
            reply = error.wrap_exception()
            self.log.exception("Error dropping records")
    else:
        pending = [m for m in msg_ids if (m in self.pending)]
        if pending:
            # raise and catch so error.wrap_exception() has an active
            # exception to wrap
            try:
                raise IndexError("msg pending: %r" % pending[0])
            except Exception:
                reply = error.wrap_exception()
                self.log.exception("Error dropping records")
        else:
            try:
                self.db.drop_matching_records(dict(msg_id={'$in': msg_ids}))
            except Exception:
                reply = error.wrap_exception()
                self.log.exception("Error dropping records")

    if reply['status'] == 'ok':
        eids = content.get('engine_ids', [])
        for eid in eids:
            if eid not in self.engines:
                try:
                    raise IndexError("No such engine: %i" % eid)
                except Exception:
                    reply = error.wrap_exception()
                    self.log.exception("Error dropping records")
                break
            uid = self.engines[eid].uuid
            try:
                self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne': None}))
            except Exception:
                reply = error.wrap_exception()
                self.log.exception("Error dropping records")
                break

    self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1234
1234
1235 def resubmit_task(self, client_id, msg):
1235 def resubmit_task(self, client_id, msg):
1236 """Resubmit one or more tasks."""
1236 """Resubmit one or more tasks."""
1237 def finish(reply):
1237 def finish(reply):
1238 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1238 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1239
1239
1240 content = msg['content']
1240 content = msg['content']
1241 msg_ids = content['msg_ids']
1241 msg_ids = content['msg_ids']
1242 reply = dict(status='ok')
1242 reply = dict(status='ok')
1243 try:
1243 try:
1244 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1244 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1245 'header', 'content', 'buffers'])
1245 'header', 'content', 'buffers'])
1246 except Exception:
1246 except Exception:
1247 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1247 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1248 return finish(error.wrap_exception())
1248 return finish(error.wrap_exception())
1249
1249
1250 # validate msg_ids
1250 # validate msg_ids
1251 found_ids = [ rec['msg_id'] for rec in records ]
1251 found_ids = [ rec['msg_id'] for rec in records ]
1252 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1252 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1253 if len(records) > len(msg_ids):
1253 if len(records) > len(msg_ids):
1254 try:
1254 try:
1255 raise RuntimeError("DB appears to be in an inconsistent state."
1255 raise RuntimeError("DB appears to be in an inconsistent state."
1256 "More matching records were found than should exist")
1256 "More matching records were found than should exist")
1257 except Exception:
1257 except Exception:
1258 self.log.exception("Failed to resubmit task")
1258 self.log.exception("Failed to resubmit task")
1259 return finish(error.wrap_exception())
1259 return finish(error.wrap_exception())
1260 elif len(records) < len(msg_ids):
1260 elif len(records) < len(msg_ids):
1261 missing = [ m for m in msg_ids if m not in found_ids ]
1261 missing = [ m for m in msg_ids if m not in found_ids ]
1262 try:
1262 try:
1263 raise KeyError("No such msg(s): %r" % missing)
1263 raise KeyError("No such msg(s): %r" % missing)
1264 except KeyError:
1264 except KeyError:
1265 self.log.exception("Failed to resubmit task")
1265 self.log.exception("Failed to resubmit task")
1266 return finish(error.wrap_exception())
1266 return finish(error.wrap_exception())
1267 elif pending_ids:
1267 elif pending_ids:
1268 pass
1268 pass
1269 # no need to raise on resubmit of pending task, now that we
1269 # no need to raise on resubmit of pending task, now that we
1270 # resubmit under new ID, but do we want to raise anyway?
1270 # resubmit under new ID, but do we want to raise anyway?
1271 # msg_id = invalid_ids[0]
1271 # msg_id = invalid_ids[0]
1272 # try:
1272 # try:
1273 # raise ValueError("Task(s) %r appears to be inflight" % )
1273 # raise ValueError("Task(s) %r appears to be inflight" % )
1274 # except Exception:
1274 # except Exception:
1275 # return finish(error.wrap_exception())
1275 # return finish(error.wrap_exception())
1276
1276
1277 # mapping of original IDs to resubmitted IDs
1277 # mapping of original IDs to resubmitted IDs
1278 resubmitted = {}
1278 resubmitted = {}
1279
1279
1280 # send the messages
1280 # send the messages
1281 for rec in records:
1281 for rec in records:
1282 header = rec['header']
1282 header = rec['header']
1283 msg = self.session.msg(header['msg_type'], parent=header)
1283 msg = self.session.msg(header['msg_type'], parent=header)
1284 msg_id = msg['msg_id']
1284 msg_id = msg['msg_id']
1285 msg['content'] = rec['content']
1285 msg['content'] = rec['content']
1286
1286
1287 # use the old header, but update msg_id and timestamp
1287 # use the old header, but update msg_id and timestamp
1288 fresh = msg['header']
1288 fresh = msg['header']
1289 header['msg_id'] = fresh['msg_id']
1289 header['msg_id'] = fresh['msg_id']
1290 header['date'] = fresh['date']
1290 header['date'] = fresh['date']
1291 msg['header'] = header
1291 msg['header'] = header
1292
1292
1293 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1293 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1294
1294
1295 resubmitted[rec['msg_id']] = msg_id
1295 resubmitted[rec['msg_id']] = msg_id
1296 self.pending.add(msg_id)
1296 self.pending.add(msg_id)
1297 msg['buffers'] = rec['buffers']
1297 msg['buffers'] = rec['buffers']
1298 try:
1298 try:
1299 self.db.add_record(msg_id, init_record(msg))
1299 self.db.add_record(msg_id, init_record(msg))
1300 except Exception:
1300 except Exception:
1301 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1301 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1302 return finish(error.wrap_exception())
1302 return finish(error.wrap_exception())
1303
1303
1304 finish(dict(status='ok', resubmitted=resubmitted))
1304 finish(dict(status='ok', resubmitted=resubmitted))
1305
1305
1306 # store the new IDs in the Task DB
1306 # store the new IDs in the Task DB
1307 for msg_id, resubmit_id in iteritems(resubmitted):
1307 for msg_id, resubmit_id in iteritems(resubmitted):
1308 try:
1308 try:
1309 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1309 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1310 except Exception:
1310 except Exception:
1311 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1311 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1312
1312
1313
1313
1314 def _extract_record(self, rec):
1314 def _extract_record(self, rec):
1315 """decompose a TaskRecord dict into subsection of reply for get_result"""
1315 """decompose a TaskRecord dict into subsection of reply for get_result"""
1316 io_dict = {}
1316 io_dict = {}
1317 for key in ('execute_input', 'execute_result', 'error', 'stdout', 'stderr'):
1317 for key in ('execute_input', 'execute_result', 'error', 'stdout', 'stderr'):
1318 io_dict[key] = rec[key]
1318 io_dict[key] = rec[key]
1319 content = {
1319 content = {
1320 'header': rec['header'],
1320 'header': rec['header'],
1321 'metadata': rec['metadata'],
1321 'metadata': rec['metadata'],
1322 'result_metadata': rec['result_metadata'],
1322 'result_metadata': rec['result_metadata'],
1323 'result_header' : rec['result_header'],
1323 'result_header' : rec['result_header'],
1324 'result_content': rec['result_content'],
1324 'result_content': rec['result_content'],
1325 'received' : rec['received'],
1325 'received' : rec['received'],
1326 'io' : io_dict,
1326 'io' : io_dict,
1327 }
1327 }
1328 if rec['result_buffers']:
1328 if rec['result_buffers']:
1329 buffers = list(map(bytes, rec['result_buffers']))
1329 buffers = list(map(bytes, rec['result_buffers']))
1330 else:
1330 else:
1331 buffers = []
1331 buffers = []
1332
1332
1333 return content, buffers
1333 return content, buffers
1334
1334
def get_results(self, client_id, msg):
    """Get the result of 1 or more messages.

    Reads ``msg_ids`` (and the optional ``status_only`` flag) from the
    request content and sends a ``result_reply`` on ``self.query`` to
    `client_id`.  The reply content lists the requested ids under
    ``'pending'`` or ``'completed'``; unless ``status_only`` is set, each
    completed id also maps to its extracted record, and the result buffers
    of those records are attached to the reply.

    If the DB lookup fails, or an id is unknown, the reply content is the
    wrapped exception instead.
    """
    request = msg['content']
    # de-duplicate and process the ids in a deterministic order
    msg_ids = sorted(set(request['msg_ids']))
    statusonly = request.get('status_only', False)
    pending = []
    completed = []
    content = dict(status='ok')
    content['pending'] = pending
    content['completed'] = completed
    buffers = []
    # msg_id -> record; stays empty for status-only requests (no DB lookup)
    records = {}
    if not statusonly:
        try:
            matches = self.db.find_records(dict(msg_id={'$in': msg_ids}))
            # turn match list into dict, for faster lookup
            for rec in matches:
                records[rec['msg_id']] = rec
        except Exception:
            content = error.wrap_exception()
            self.log.exception("Failed to get results")
            self.session.send(self.query, "result_reply", content=content,
                              parent=msg, ident=client_id)
            return

    for msg_id in msg_ids:
        if msg_id in self.pending:
            pending.append(msg_id)
        elif msg_id in self.all_completed:
            completed.append(msg_id)
            if not statusonly:
                c, bufs = self._extract_record(records[msg_id])
                content[msg_id] = c
                buffers.extend(bufs)
        elif msg_id in records:
            # Known to the DB only: decide from *this* id's record.
            # (Previously this tested the stale `rec` left over from the
            # lookup loop above, i.e. always the last matched record.)
            rec = records[msg_id]
            if rec['completed']:
                completed.append(msg_id)
                c, bufs = self._extract_record(rec)
                content[msg_id] = c
                buffers.extend(bufs)
            else:
                pending.append(msg_id)
        else:
            # raise-and-catch so wrap_exception() has a live exception
            try:
                raise KeyError('No such message: ' + msg_id)
            except KeyError:
                content = error.wrap_exception()
            break
    self.session.send(self.query, "result_reply", content=content,
                      parent=msg, ident=client_id,
                      buffers=buffers)
1387
1387
def get_history(self, client_id, msg):
    """Reply with the list of all msg_ids held in the task DB.

    Sends a ``history_reply`` on ``self.query`` to `client_id`.  On success
    the content carries the ids under ``'history'``; if the DB call raises,
    the failure is logged and the wrapped exception is sent instead.
    """
    try:
        history = self.db.get_history()
    except Exception:
        content = error.wrap_exception()
        self.log.exception("Failed to get history")
    else:
        content = {'status': 'ok', 'history': history}

    self.session.send(self.query, "history_reply", content=content,
                      parent=msg, ident=client_id)
1400
1400
def db_query(self, client_id, msg):
    """Perform a raw query on the task record database.

    The request content may carry a ``query`` dict (passed through
    ``extract_dates`` before use) and an optional ``keys`` list, both handed
    to ``self.db.find_records``.  A ``db_reply`` is sent on ``self.query``
    to `client_id`.

    ``buffers`` and ``result_buffers`` are popped out of every record and
    sent as the message buffers.  ``buffer_lens`` / ``result_buffer_lens``
    give the per-record buffer counts, and are only lists when the
    corresponding key was explicitly requested in ``keys`` (otherwise None).
    If the query raises, the failure is logged and the wrapped exception is
    sent as the content.
    """
    request = msg['content']
    query = extract_dates(request.get('query', {}))
    keys = request.get('keys', None)
    buffers = []
    try:
        records = self.db.find_records(query, keys)
    except Exception:
        content = error.wrap_exception()
        self.log.exception("DB query failed")
    else:
        # length bookkeeping only for buffer keys that were asked for by name
        explicit = keys is not None
        buffer_lens = [] if (explicit and 'buffers' in keys) else None
        result_buffer_lens = [] if (explicit and 'result_buffers' in keys) else None

        # extract buffers from reply content:
        for rec in records:
            # a record may lack the key, or store None; both mean no buffers
            request_bufs = rec.pop('buffers', None) or []
            if buffer_lens is not None:
                buffer_lens.append(len(request_bufs))
            buffers.extend(request_bufs)

            result_bufs = rec.pop('result_buffers', None) or []
            if result_buffer_lens is not None:
                result_buffer_lens.append(len(result_bufs))
            buffers.extend(result_bufs)

        content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                       result_buffer_lens=result_buffer_lens)
    self.session.send(self.query, "db_reply", content=content,
                      parent=msg, ident=client_id,
                      buffers=buffers)
1438
1438
@@ -1,849 +1,849 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 # Copyright (c) IPython Development Team.
8 # Copyright (c) IPython Development Team.
9 # Distributed under the terms of the Modified BSD License.
9 # Distributed under the terms of the Modified BSD License.
10
10
11 import logging
11 import logging
12 import sys
12 import sys
13 import time
13 import time
14
14
15 from collections import deque
15 from collections import deque
16 from datetime import datetime
16 from datetime import datetime
17 from random import randint, random
17 from random import randint, random
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import numpy
21 import numpy
22 except ImportError:
22 except ImportError:
23 numpy = None
23 numpy = None
24
24
25 import zmq
25 import zmq
26 from zmq.eventloop import ioloop, zmqstream
26 from zmq.eventloop import ioloop, zmqstream
27
27
28 # local imports
28 # local imports
29 from decorator import decorator
29 from decorator import decorator
30 from IPython.config.application import Application
30 from IPython.config.application import Application
31 from IPython.config.loader import Config
31 from IPython.config.loader import Config
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
33 from IPython.utils.py3compat import cast_bytes
33 from IPython.utils.py3compat import cast_bytes
34
34
35 from IPython.parallel import error, util
35 from ipython_parallel import error, util
36 from IPython.parallel.factory import SessionFactory
36 from ipython_parallel.factory import SessionFactory
37 from IPython.parallel.util import connect_logger, local_logger
37 from ipython_parallel.util import connect_logger, local_logger
38
38
39 from .dependency import Dependency
39 from .dependency import Dependency
40
40
@decorator
def logged(f, self, *args, **kwargs):
    """Log a debug line naming the wrapped method and its arguments, then call it."""
    method_name = f.__name__
    self.log.debug("scheduler::%s(*%s,**%s)", method_name, args, kwargs)
    return f(self, *args, **kwargs)
47
47
48 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
49 # Chooser functions
49 # Chooser functions
50 #----------------------------------------------------------------------
50 #----------------------------------------------------------------------
51
51
def plainrandom(loads):
    """Return a uniformly random index into `loads`.

    Only the length of `loads` matters; the load values are not read.
    """
    last = len(loads) - 1
    return randint(0, last)
56
56
def lru(loads):
    """Pick whoever is at the front of the line: always index 0.

    `loads` is never inspected.  The choice is only "least recently used"
    if the caller keeps `loads` in LRU order, oldest first.
    """
    front_of_line = 0
    return front_of_line
65
65
def twobin(loads):
    """Draw two random indices and return the smaller one.

    Only the length of `loads` is used.  With `loads` in LRU order (oldest
    first), the smaller index is the less recently used of the two draws.
    """
    last = len(loads) - 1
    first_pick = randint(0, last)
    second_pick = randint(0, last)
    if first_pick <= second_pick:
        return first_pick
    return second_pick
77
77
def weighted(loads):
    """Draw two indices with probability weighted by inverse load.

    Returns whichever of the two draws has the larger weight, i.e. the less
    loaded one; on a tie the first draw is returned.  Uses ``numpy``
    directly, so it needs the module-level ``numpy`` to be importable.
    """
    # the 1e-6 offset keeps the division finite for an idle engine and
    # weights a load of 0 a million times more than a load of 1
    weights = 1. / (1e-6 + numpy.array(loads))
    cumulative = weights.cumsum()
    total = cumulative[-1]

    # take both samples up front, then locate each on the cumulative scale
    x = random() * total
    y = random() * total

    def locate(point):
        # first index whose cumulative weight reaches `point`
        position = 0
        while cumulative[position] < point:
            position += 1
        return position

    first = locate(x)
    second = locate(y)
    return second if weights[second] > weights[first] else first
99
99
def leastload(loads):
    """Return the index of the smallest load.

    When the smallest load appears more than once, the first occurrence
    wins.  If `loads` is in LRU order, that is the least recently used of
    the least-loaded engines.
    """
    lowest = min(loads)
    return loads.index(lowest)
108
108
109 #---------------------------------------------------------------------
109 #---------------------------------------------------------------------
110 # Classes
110 # Classes
111 #---------------------------------------------------------------------
111 #---------------------------------------------------------------------
112
112
113
113
114 # store empty default dependency:
114 # store empty default dependency:
115 MET = Dependency([])
115 MET = Dependency([])
116
116
117
117
class Job(object):
    """Simple container for a job.

    Stores the constructor arguments as attributes, plus bookkeeping state:
    ``removed`` (lazy-delete flag), ``timestamp`` (creation time from
    ``time.time()``), ``timeout_id`` and ``blacklist`` (an empty set).

    Jobs order by ``timestamp``.
    """
    def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
                 targets, after, follow, timeout):
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.metadata = metadata
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        self.removed = False  # used for lazy-delete from sorted queue
        self.timestamp = time.time()
        self.timeout_id = 0
        self.blacklist = set()

    def __lt__(self, other):
        """Order jobs by creation timestamp (oldest first)."""
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        """Three-way timestamp comparison: -1, 0 or 1.

        Spelled without the ``cmp`` builtin, which does not exist on
        Python 3, so the method also works there.
        """
        mine, theirs = self.timestamp, other.timestamp
        return (mine > theirs) - (mine < theirs)

    @property
    def dependents(self):
        """Union of the ``follow`` and ``after`` dependencies."""
        return self.follow.union(self.after)
147
147
148
148
149 class TaskScheduler(SessionFactory):
149 class TaskScheduler(SessionFactory):
150 """Python TaskScheduler object.
150 """Python TaskScheduler object.
151
151
152 This is the simplest object that supports msg_id based
152 This is the simplest object that supports msg_id based
153 DAG dependencies. *Only* task msg_ids are checked, not
153 DAG dependencies. *Only* task msg_ids are checked, not
154 msg_ids of jobs submitted via the MUX queue.
154 msg_ids of jobs submitted via the MUX queue.
155
155
156 """
156 """
157
157
158 hwm = Integer(1, config=True,
158 hwm = Integer(1, config=True,
159 help="""specify the High Water Mark (HWM) for the downstream
159 help="""specify the High Water Mark (HWM) for the downstream
160 socket in the Task scheduler. This is the maximum number
160 socket in the Task scheduler. This is the maximum number
161 of allowed outstanding tasks on each engine.
161 of allowed outstanding tasks on each engine.
162
162
163 The default (1) means that only one task can be outstanding on each
163 The default (1) means that only one task can be outstanding on each
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
165 engines continue to be assigned tasks while they are working,
165 engines continue to be assigned tasks while they are working,
166 effectively hiding network latency behind computation, but can result
166 effectively hiding network latency behind computation, but can result
167 in an imbalance of work when submitting many heterogenous tasks all at
167 in an imbalance of work when submitting many heterogenous tasks all at
168 once. Any positive value greater than one is a compromise between the
168 once. Any positive value greater than one is a compromise between the
169 two.
169 two.
170
170
171 """
171 """
172 )
172 )
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
174 'leastload', config=True,
174 'leastload', config=True,
175 help="""select the task scheduler scheme [default: Python LRU]
175 help="""select the task scheduler scheme [default: Python LRU]
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
177 )
177 )
178 def _scheme_name_changed(self, old, new):
178 def _scheme_name_changed(self, old, new):
179 self.log.debug("Using scheme %r"%new)
179 self.log.debug("Using scheme %r"%new)
180 self.scheme = globals()[new]
180 self.scheme = globals()[new]
181
181
182 # input arguments:
182 # input arguments:
183 scheme = Instance(FunctionType) # function for determining the destination
183 scheme = Instance(FunctionType) # function for determining the destination
def _scheme_default(self):
    """Return the default chooser function, ``leastload``."""
    default_chooser = leastload
    return default_chooser
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
191
191
192 # internals:
192 # internals:
193 queue = Instance(deque) # sorted list of Jobs
193 queue = Instance(deque) # sorted list of Jobs
194 def _queue_default(self):
194 def _queue_default(self):
195 return deque()
195 return deque()
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
200 pending = Dict() # dict by engine_uuid of submitted tasks
200 pending = Dict() # dict by engine_uuid of submitted tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
204 clients = Dict() # dict by msg_id for who submitted the task
204 clients = Dict() # dict by msg_id for who submitted the task
205 targets = List() # list of target IDENTs
205 targets = List() # list of target IDENTs
206 loads = List() # list of engine loads
206 loads = List() # list of engine loads
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
208 all_completed = Set() # set of all completed tasks
208 all_completed = Set() # set of all completed tasks
209 all_failed = Set() # set of all failed tasks
209 all_failed = Set() # set of all failed tasks
210 all_done = Set() # set of all finished tasks=union(completed,failed)
210 all_done = Set() # set of all finished tasks=union(completed,failed)
211 all_ids = Set() # set of all submitted task IDs
211 all_ids = Set() # set of all submitted task IDs
212
212
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
214 # but ensure Bytes
214 # but ensure Bytes
215 def _ident_default(self):
215 def _ident_default(self):
216 return self.session.bsession
216 return self.session.bsession
217
217
def start(self):
    """Hook up the stream callbacks and request the current connections.

    Registers the receive callbacks on the query, engine, client and
    notifier streams, sends a ``connection_request`` on the query stream,
    and builds the table of notification handlers.
    """
    # hub query channel: listen for the reply, then ask for connections
    self.query_stream.on_recv(self.dispatch_query_reply)
    self.session.send(self.query_stream, "connection_request", {})

    # results from engines and submissions from clients
    self.engine_stream.on_recv(self.dispatch_result, copy=False)
    self.client_stream.on_recv(self.dispatch_submission, copy=False)

    # notification msg_type -> handler
    self._notification_handlers = {
        'registration_notification': self._register_engine,
        'unregistration_notification': self._unregister_engine,
    }
    self.notifier_stream.on_recv(self.dispatch_notification)

    self.log.info("Scheduler started [%s]" % self.scheme_name)
231
231
def resume_receiving(self):
    """Resume accepting jobs by re-attaching the client submission callback."""
    on_submission = self.dispatch_submission
    self.client_stream.on_recv(on_submission, copy=False)
235
235
def stop_receiving(self):
    """Stop accepting jobs while there are no engines.

    Clears the client stream's receive callback, so submissions are left
    in the ZMQ queue rather than dispatched.
    """
    no_callback = None
    self.client_stream.on_recv(no_callback)
240
240
241 #-----------------------------------------------------------------------
241 #-----------------------------------------------------------------------
242 # [Un]Registration Handling
242 # [Un]Registration Handling
243 #-----------------------------------------------------------------------
243 #-----------------------------------------------------------------------
244
244
245
245
def dispatch_query_reply(self, msg):
    """Handle the reply to our initial connection request.

    Registers every engine uuid listed under ``'engines'`` in the reply
    content.  A message that fails identity splitting or deserialization
    (``ValueError``) is logged with a warning and dropped.
    """
    try:
        idents, frames = self.session.feed_identities(msg)
    except ValueError:
        self.log.warn("task::Invalid Message: %r",msg)
        return

    try:
        reply = self.session.deserialize(frames)
    except ValueError:
        self.log.warn("task::Unauthorized message from: %r"%idents)
        return

    # 'engines' may be absent from the reply, in which case nothing is registered
    engines = reply['content'].get('engines', {})
    for uuid in engines.values():
        self._register_engine(cast_bytes(uuid))
262
262
263
263
@util.log_errors
def dispatch_notification(self, msg):
    """Dispatch register/unregister events.

    Looks up the handler for the message's ``msg_type`` in
    ``self._notification_handlers`` and calls it with the engine uuid from
    the message content (as bytes).  Malformed or unauthorized messages
    (``ValueError``) are logged with a warning and dropped; an unknown
    ``msg_type`` or a failing handler is logged as an error.
    """
    try:
        idents, frames = self.session.feed_identities(msg)
    except ValueError:
        self.log.warn("task::Invalid Message: %r",msg)
        return

    try:
        msg = self.session.deserialize(frames)
    except ValueError:
        self.log.warn("task::Unauthorized message from: %r"%idents)
        return

    msg_type = msg['header']['msg_type']
    handler = self._notification_handlers.get(msg_type, None)
    if handler is None:
        self.log.error("Unhandled message type: %r"%msg_type)
        return

    # a bad notification must not take down the scheduler: log and move on
    try:
        handler(cast_bytes(msg['content']['uuid']))
    except Exception:
        self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
288
288
289 def _register_engine(self, uid):
289 def _register_engine(self, uid):
290 """New engine with ident `uid` became available."""
290 """New engine with ident `uid` became available."""
291 # head of the line:
291 # head of the line:
292 self.targets.insert(0,uid)
292 self.targets.insert(0,uid)
293 self.loads.insert(0,0)
293 self.loads.insert(0,0)
294
294
295 # initialize sets
295 # initialize sets
296 self.completed[uid] = set()
296 self.completed[uid] = set()
297 self.failed[uid] = set()
297 self.failed[uid] = set()
298 self.pending[uid] = {}
298 self.pending[uid] = {}
299
299
300 # rescan the graph:
300 # rescan the graph:
301 self.update_graph(None)
301 self.update_graph(None)
302
302
303 def _unregister_engine(self, uid):
303 def _unregister_engine(self, uid):
304 """Existing engine with ident `uid` became unavailable."""
304 """Existing engine with ident `uid` became unavailable."""
305 if len(self.targets) == 1:
305 if len(self.targets) == 1:
306 # this was our only engine
306 # this was our only engine
307 pass
307 pass
308
308
309 # handle any potentially finished tasks:
309 # handle any potentially finished tasks:
310 self.engine_stream.flush()
310 self.engine_stream.flush()
311
311
312 # don't pop destinations, because they might be used later
312 # don't pop destinations, because they might be used later
313 # map(self.destinations.pop, self.completed.pop(uid))
313 # map(self.destinations.pop, self.completed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
314 # map(self.destinations.pop, self.failed.pop(uid))
315
315
316 # prevent this engine from receiving work
316 # prevent this engine from receiving work
317 idx = self.targets.index(uid)
317 idx = self.targets.index(uid)
318 self.targets.pop(idx)
318 self.targets.pop(idx)
319 self.loads.pop(idx)
319 self.loads.pop(idx)
320
320
321 # wait 5 seconds before cleaning up pending jobs, since the results might
321 # wait 5 seconds before cleaning up pending jobs, since the results might
322 # still be incoming
322 # still be incoming
323 if self.pending[uid]:
323 if self.pending[uid]:
324 self.loop.add_timeout(self.loop.time() + 5,
324 self.loop.add_timeout(self.loop.time() + 5,
325 lambda : self.handle_stranded_tasks(uid),
325 lambda : self.handle_stranded_tasks(uid),
326 )
326 )
327 else:
327 else:
328 self.completed.pop(uid)
328 self.completed.pop(uid)
329 self.failed.pop(uid)
329 self.failed.pop(uid)
330
330
331
331
def handle_stranded_tasks(self, engine):
    """Deal with jobs resident in an engine that died.

    For every job still pending on `engine`, a fake ``apply_reply``
    carrying an EngineError is built and pushed through
    ``dispatch_result`` as if the engine had sent it.  Afterwards the
    engine's completed/failed records are dropped.
    """
    lost = self.pending[engine]
    # Iterate over a snapshot of the ids: dispatching a reply removes the
    # entry from this very dict (see handle_result), and changing a dict's
    # size while iterating over it raises RuntimeError on Python 3.
    for msg_id in list(lost):
        if msg_id not in self.pending[engine]:
            # prevent double-handling of messages
            continue

        raw_msg = lost[msg_id].raw_msg
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        parent = self.session.unpack(msg[1].bytes)
        idents = [engine, idents[0]]

        # build fake error reply; the exception is raised and caught here
        # before wrap_exception() is called (presumably it wraps the
        # exception currently being handled -- confirm against `error`)
        try:
            raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
        except Exception:
            content = error.wrap_exception()
        # build fake metadata
        md = dict(
            status=u'error',
            engine=engine.decode('ascii'),
            date=datetime.now(),
        )
        msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
        raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
        # and dispatch it
        self.dispatch_result(raw_reply)

    # finally scrub completed/failed lists
    self.completed.pop(engine)
    self.failed.pop(engine)
364
364
365
365
366 #-----------------------------------------------------------------------
366 #-----------------------------------------------------------------------
367 # Job Submission
367 # Job Submission
368 #-----------------------------------------------------------------------
368 #-----------------------------------------------------------------------
369
369
370
370
@util.log_errors
def dispatch_submission(self, raw_msg):
    """Dispatch job submission to appropriate handlers.

    The raw message is parsed, forwarded to the monitor, and turned into a
    ``Job``.  Depending on its dependencies the job is then failed as
    unreachable, handed to ``maybe_run``, or queued with ``save_unmet``.
    An unparseable message is logged and dropped.
    """
    # ensure targets up to date:
    self.notifier_stream.flush()
    try:
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.deserialize(msg, content=False, copy=False)
    except Exception:
        # lazy logging args, matching the other "task::Invalid ..." messages
        self.log.error("task::Invalid task msg: %r", raw_msg, exc_info=True)
        return

    # send to monitor
    self.mon_stream.send_multipart([b'intask'] + raw_msg, copy=False)

    header = msg['header']
    md = msg['metadata']
    msg_id = header['msg_id']
    self.all_ids.add(msg_id)

    # get targets as a set of bytes objects
    # from a list of unicode objects
    targets = md.get('targets', [])
    targets = set(map(cast_bytes, targets))

    retries = md.get('retries', 0)
    self.retries[msg_id] = retries

    # time dependencies
    after = md.get('after', None)
    if after:
        after = Dependency(after)
        if after.all:
            # drop the ids that have already finished in an accepted way
            if after.success:
                after = Dependency(after.difference(self.all_completed),
                                   success=after.success,
                                   failure=after.failure,
                                   all=after.all,
                                   )
            if after.failure:
                after = Dependency(after.difference(self.all_failed),
                                   success=after.success,
                                   failure=after.failure,
                                   all=after.all,
                                   )
        if after.check(self.all_completed, self.all_failed):
            # recast as empty set, if `after` already met,
            # to prevent unnecessary set comparisons
            after = MET
    else:
        after = MET

    # location dependencies
    follow = Dependency(md.get('follow', []))

    timeout = md.get('timeout', None)
    if timeout:
        timeout = float(timeout)

    job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
              header=header, targets=targets, after=after, follow=follow,
              timeout=timeout, metadata=md,
              )
    # validate and reduce dependencies:
    for dep in after, follow:
        if not dep:  # empty dependency
            continue
        # check valid: a task may not depend on itself or on unknown ids
        if msg_id in dep or dep.difference(self.all_ids):
            # fail_unreachable only acts on jobs present in queue_map
            self.queue_map[msg_id] = job
            return self.fail_unreachable(msg_id, error.InvalidDependency)
        # check if unreachable:
        if dep.unreachable(self.all_completed, self.all_failed):
            self.queue_map[msg_id] = job
            return self.fail_unreachable(msg_id)

    if after.check(self.all_completed, self.all_failed):
        # time deps already met, try to run
        if not self.maybe_run(job):
            # can't run yet
            if msg_id not in self.all_failed:
                # could have failed as unreachable
                self.save_unmet(job)
    else:
        self.save_unmet(job)
457
457
def job_timeout(self, job, timeout_id):
    """Callback for a job's timeout.

    The job may or may not have been run at this point.  Only the most
    recently scheduled timeout (matching ``job.timeout_id``) has any
    effect, and the job is failed with TaskTimeout only if it is still
    waiting in ``queue_map``.
    """
    if timeout_id != job.timeout_id:
        # not the most recent call
        return

    current = time.time()
    if job.timeout >= (current + 1):
        self.log.warn("task %s timeout fired prematurely: %s > %s",
                      job.msg_id, job.timeout, current)

    if job.msg_id not in self.queue_map:
        return
    # still waiting, but ran out of time
    self.log.info("task %r timed out", job.msg_id)
    self.fail_unreachable(job.msg_id, error.TaskTimeout)
475
475
def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
    """Fail a queued task that has become unreachable.

    The job is removed from ``queue_map``, flagged for lazy deletion from
    the queue, recorded as done and failed, and an ``apply_reply`` wrapping
    an instance of `why` (ImpossibleDependency by default) is sent to the
    client and to the monitor.  The graph is then updated for the failure.

    If `msg_id` is not in ``queue_map`` an error is logged and nothing
    else happens.
    """
    if msg_id not in self.queue_map:
        self.log.error("task %r already failed!", msg_id)
        return
    job = self.queue_map.pop(msg_id)
    # lazy-delete from the queue
    job.removed = True
    for mid in job.dependents:
        if mid in self.graph:
            self.graph[mid].remove(msg_id)

    # The exception is raised and caught before wrap_exception() is called
    # (presumably it wraps the exception currently being handled --
    # confirm against `error`).  Only Exception is caught, so that
    # KeyboardInterrupt/SystemExit are never swallowed here.
    try:
        raise why()
    except Exception:
        content = error.wrap_exception()
    self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])

    self.all_done.add(msg_id)
    self.all_failed.add(msg_id)

    msg = self.session.send(self.client_stream, 'apply_reply', content,
                            parent=job.header, ident=job.idents)
    self.session.send(self.mon_stream, msg, ident=[b'outtask'] + job.idents)

    self.update_graph(msg_id, success=False)
503
503
def available_engines(self):
    """Return a list of available engine indices based on HWM.

    Without a high-water mark every index into ``targets`` is available;
    otherwise only the indices whose load is below ``hwm`` are returned.
    """
    count = len(self.targets)
    if not self.hwm:
        return list(range(count))
    return [idx for idx in range(count) if self.loads[idx] < self.hwm]
513
513
def maybe_run(self, job):
    """Check location dependencies, and run the job if they are met.

    Returns True when the job was handed to ``submit_task`` and False
    otherwise.  When no engine qualifies and the job's follow or target
    constraints are found to be impossible, the job is failed through
    ``fail_unreachable`` before False is returned.
    """
    msg_id = job.msg_id
    self.log.debug("Attempting to assign task %s", msg_id)
    candidates = self.available_engines()
    if not candidates:
        # no engines, definitely can't run
        return False

    if not (job.follow or job.targets or job.blacklist or self.hwm):
        # no constraints at all: let submit_task pick among every engine
        self.submit_task(job, None)
        return True

    # we need a can_run filter
    def can_run(idx):
        # check hwm
        if self.hwm and self.loads[idx] == self.hwm:
            return False
        target = self.targets[idx]
        # check blacklist
        if target in job.blacklist:
            return False
        # check targets
        if job.targets and target not in job.targets:
            return False
        # check follow
        return job.follow.check(self.completed[target], self.failed[target])

    indices = [idx for idx in candidates if can_run(idx)]
    if indices:
        self.submit_task(job, indices)
        return True

    # couldn't run
    follow = job.follow
    if follow.all:
        # check follow for impossibility: the followed tasks that already
        # finished must not have ended up at more than one destination
        relevant = set()
        if follow.success:
            relevant = self.all_completed
        if follow.failure:
            relevant = relevant.union(self.all_failed)
        dests = {self.destinations[m] for m in follow.intersection(relevant)}
        if len(dests) > 1:
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
            return False

    if job.targets:
        # check blacklist+targets for impossibility
        job.targets.difference_update(job.blacklist)
        if not (job.targets and job.targets.intersection(self.targets)):
            self.queue_map[msg_id] = job
            self.fail_unreachable(msg_id)
            return False

    return False
570
570
def save_unmet(self, job):
    """Save a message for later submission when its dependencies are met.

    The job is added to ``queue_map`` and ``queue``, registered in
    ``graph`` under every unfinished id it depends on, and, if it has a
    timeout, a ``job_timeout`` callback is scheduled on the loop.
    """
    msg_id = job.msg_id
    self.log.debug("Adding task %s to the queue", msg_id)
    self.queue_map[msg_id] = job
    self.queue.append(job)
    # track the ids in follow or after, but not those already finished
    for dep_id in job.after.union(job.follow).difference(self.all_done):
        if dep_id not in self.graph:
            self.graph[dep_id] = set()
        self.graph[dep_id].add(msg_id)

    # schedule timeout callback
    if job.timeout:
        # bump the id so that callbacks from earlier schedulings are ignored
        timeout_id = job.timeout_id = job.timeout_id + 1
        # The deadline must be on the loop's own clock, as in
        # _unregister_engine; time.time() only matches it by default.
        self.loop.add_timeout(
            self.loop.time() + job.timeout,
            lambda: self.job_timeout(job, timeout_id),
        )
589
589
590
590
def submit_task(self, job, indices=None):
    """Submit a task to any of a subset of our targets.

    ``self.scheme`` chooses a position from the loads of the engines named
    by `indices` (or from all loads when `indices` is empty or None).  The
    job is sent to that engine, recorded as pending on it, and the Hub is
    notified of the destination.
    """
    if indices:
        # the scheme answers with a position within the subset
        choice = self.scheme([self.loads[i] for i in indices])
        idx = indices[choice]
    else:
        idx = self.scheme(self.loads)
    target = self.targets[idx]

    # send job to the engine: its ident first, then the message itself
    self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
    self.engine_stream.send_multipart(job.raw_msg, copy=False)

    # update load
    self.add_job(idx)
    self.pending[target][job.msg_id] = job

    # notify Hub
    content = {'msg_id': job.msg_id, 'engine_id': target.decode('ascii')}
    self.session.send(
        self.mon_stream,
        'task_destination',
        content=content,
        ident=[b'tracktask', self.ident],
    )
612
612
613
613
614 #-----------------------------------------------------------------------
614 #-----------------------------------------------------------------------
615 # Result Handling
615 # Result Handling
616 #-----------------------------------------------------------------------
616 #-----------------------------------------------------------------------
617
617
618
618
@util.log_errors
def dispatch_result(self, raw_msg):
    """Dispatch method for result replies.

    A reply whose dependencies were met is either retried (failure with
    retries left) or relayed as a real result and forwarded to the Hub
    monitor; a reply reporting unmet dependencies goes to
    ``handle_unmet_dependency``.  An unparseable reply is logged and
    dropped.
    """
    try:
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.deserialize(msg, content=False, copy=False)
        engine = idents[0]
        # skip load-update for dead engines
        if engine in self.targets:
            self.finish_job(self.targets.index(engine))
    except Exception:
        self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
        return

    md = msg['metadata']
    parent = msg['parent_header']
    if not md.get('dependencies_met', True):
        self.handle_unmet_dependency(idents, parent)
        return

    success = (md['status'] == 'ok')
    msg_id = parent['msg_id']
    retries = self.retries[msg_id]
    if retries > 0 and not success:
        # failed, but it may still be retried
        self.retries[msg_id] = retries - 1
        self.handle_unmet_dependency(idents, parent)
        return

    del self.retries[msg_id]
    # relay to client and update graph
    self.handle_result(idents, parent, raw_msg, success)
    # send to Hub monitor
    self.mon_stream.send_multipart([b'outtask'] + raw_msg, copy=False)
654
654
def handle_result(self, idents, parent, raw_msg, success=True):
    """Handle a real task result, either success or failure.

    The reply is relayed to the client, the task is removed from the
    engine's pending jobs and recorded as completed or failed, and the
    dependency graph is updated.
    """
    # first, relay result to client
    engine, client = idents[0], idents[1]
    # swap_ids for ROUTER-ROUTER mirror
    raw_msg[:2] = [client, engine]
    self.client_stream.send_multipart(raw_msg, copy=False)

    # now, update our data structures
    msg_id = parent['msg_id']
    self.pending[engine].pop(msg_id)
    if success:
        per_engine, overall = self.completed, self.all_completed
    else:
        per_engine, overall = self.failed, self.all_failed
    per_engine[engine].add(msg_id)
    overall.add(msg_id)
    self.all_done.add(msg_id)
    self.destinations[msg_id] = engine

    self.update_graph(msg_id, success)
677
677
def handle_unmet_dependency(self, idents, parent):
    """Handle an unmet dependency.

    The job is taken off the replying engine's pending jobs and that
    engine is blacklisted for it.  The job is then failed (every target
    blacklisted), resubmitted, or put back in the queue.  With a
    high-water mark set, the graph is rescanned if the engine's load is
    now one below ``hwm``.
    """
    engine = idents[0]
    msg_id = parent['msg_id']

    job = self.pending[engine].pop(msg_id)
    job.blacklist.add(engine)

    if job.blacklist == job.targets:
        # every requested target has refused the job
        self.queue_map[msg_id] = job
        self.fail_unreachable(msg_id)
    elif not self.maybe_run(job) and msg_id not in self.all_failed:
        # resubmit failed: put it back in our dependency tree
        self.save_unmet(job)

    if not self.hwm:
        return
    if engine not in self.targets:
        # skip load-update for dead engines
        return
    idx = self.targets.index(engine)
    if self.loads[idx] == self.hwm - 1:
        self.update_graph(None)
703
703
def update_graph(self, dep_id=None, success=True):
    """dep_id just finished. Update our dependency
    graph and submit any jobs that just became runnable.

    Called with dep_id=None to update entire graph for hwm, but without
    finishing a task.

    `success` is accepted for the callers' benefit but is not read here.
    """
    # update any jobs that depended on the dependency
    msg_ids = self.graph.pop(dep_id, [])

    # recheck *all* jobs if
    # a) we have HWM and an engine just become no longer full
    # or b) dep_id was given as None
    using_queue = bool(
        dep_id is None
        or (self.hwm and any(load == self.hwm - 1 for load in self.loads))
    )
    if using_queue:
        jobs = self.queue
    else:
        jobs = deque(sorted(self.queue_map[msg_id] for msg_id in msg_ids))

    to_restore = []
    while jobs:
        job = jobs.popleft()
        if job.removed:
            # lazily deleted (see fail_unreachable): just drop it
            continue
        msg_id = job.msg_id

        put_it_back = True

        if job.after.unreachable(self.all_completed, self.all_failed)\
                or job.follow.unreachable(self.all_completed, self.all_failed):
            self.fail_unreachable(msg_id)
            put_it_back = False

        elif job.after.check(self.all_completed, self.all_failed):
            # time deps met, maybe run
            if self.maybe_run(job):
                put_it_back = False
                self.queue_map.pop(msg_id)
                for mid in job.dependents:
                    if mid in self.graph:
                        self.graph[mid].remove(msg_id)

                # abort the loop if we just filled up all of our engines.
                # avoids an O(N) operation in situation of full queue,
                # where graph update is triggered as soon as an engine
                # becomes non-full, and all tasks after the first are
                # checked, even though they can't run.
                if not self.available_engines():
                    break

        if using_queue and put_it_back:
            # popped a job from the queue but it neither ran nor failed,
            # so we need to put it back when we are done
            to_restore.append(job)

    # put back any tasks we popped but didn't run.  extendleft() pushes
    # the items one by one and so reverses them; feeding it the reversed
    # list puts the jobs back at the front in their original order.
    if using_queue:
        self.queue.extendleft(reversed(to_restore))
770
770
771 #----------------------------------------------------------------------
771 #----------------------------------------------------------------------
772 # methods to be overridden by subclasses
772 # methods to be overridden by subclasses
773 #----------------------------------------------------------------------
773 #----------------------------------------------------------------------
774
774
def add_job(self, idx):
    """Called after self.targets[idx] just got a job.

    Override with subclasses.  The default ordering is simple LRU: the
    engine and its load are moved to the end of their lists.  The default
    loads are the number of outstanding jobs.
    """
    self.loads[idx] += 1
    # move the engine that was just used to the back of the line
    self.targets.append(self.targets.pop(idx))
    self.loads.append(self.loads.pop(idx))
782
782
783
783
def finish_job(self, idx):
    """Called after self.targets[idx] just finished a job.

    Override with subclasses.  By default the engine's load is
    decremented by one.
    """
    self.loads[idx] = self.loads[idx] - 1
788
788
789
789
790
790
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
                     logname='root', log_url=None, loglevel=logging.DEBUG,
                     identity=b'task', in_thread=False):
    """Create the scheduler's sockets, build a TaskScheduler and start it.

    With `in_thread` the shared Context/IOLoop instances and the
    Application's logger are used and the function returns once the
    scheduler is started.  Otherwise fresh Context/IOLoop objects are
    created and the loop is run here until interrupted.
    """
    Stream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()

    # becomes the scheduler's client_stream
    ins = Stream(ctx.socket(zmq.ROUTER), loop)
    util.set_hwm(ins, 0)
    ins.setsockopt(zmq.IDENTITY, identity + b'_in')
    ins.bind(in_addr)

    # becomes the scheduler's engine_stream
    outs = Stream(ctx.socket(zmq.ROUTER), loop)
    util.set_hwm(outs, 0)
    outs.setsockopt(zmq.IDENTITY, identity + b'_out')
    outs.bind(out_addr)

    # monitor (PUB) and notification (SUB, subscribed to everything)
    mons = Stream(ctx.socket(zmq.PUB), loop)
    util.set_hwm(mons, 0)
    mons.connect(mon_addr)
    nots = Stream(ctx.socket(zmq.SUB), loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    querys = Stream(ctx.socket(zmq.DEALER), loop)
    querys.connect(reg_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    elif log_url:
        log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
    else:
        log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(
        client_stream=ins,
        engine_stream=outs,
        mon_stream=mons,
        notifier_stream=nots,
        query_stream=querys,
        loop=loop,
        log=log,
        config=config,
    )
    scheduler.start()
    if in_thread:
        return
    try:
        loop.start()
    except KeyboardInterrupt:
        scheduler.log.critical("Interrupted, exiting...")
849
849
@@ -1,6 +1,6 b''
1 def main():
1 def main():
2 from IPython.parallel.apps import ipengineapp as app
2 from ipython_parallel.apps import ipengineapp as app
3 app.launch_new_instance()
3 app.launch_new_instance()
4
4
5 if __name__ == '__main__':
5 if __name__ == '__main__':
6 main()
6 main()
@@ -1,301 +1,301 b''
1 """A simple engine that talks to a controller over 0MQ.
1 """A simple engine that talks to a controller over 0MQ.
2 it handles registration, etc. and launches a kernel
2 it handles registration, etc. and launches a kernel
3 connected to the Controller's Schedulers.
3 connected to the Controller's Schedulers.
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9 from __future__ import print_function
9 from __future__ import print_function
10
10
11 import sys
11 import sys
12 import time
12 import time
13 from getpass import getpass
13 from getpass import getpass
14
14
15 import zmq
15 import zmq
16 from zmq.eventloop import ioloop, zmqstream
16 from zmq.eventloop import ioloop, zmqstream
17
17
18 from IPython.utils.localinterfaces import localhost
18 from IPython.utils.localinterfaces import localhost
19 from IPython.utils.traitlets import (
19 from IPython.utils.traitlets import (
20 Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
20 Instance, Dict, Integer, Type, Float, Unicode, CBytes, Bool
21 )
21 )
22 from IPython.utils.py3compat import cast_bytes
22 from IPython.utils.py3compat import cast_bytes
23
23
24 from IPython.parallel.controller.heartmonitor import Heart
24 from ipython_parallel.controller.heartmonitor import Heart
25 from IPython.parallel.factory import RegistrationFactory
25 from ipython_parallel.factory import RegistrationFactory
26 from IPython.parallel.util import disambiguate_url
26 from ipython_parallel.util import disambiguate_url
27
27
28 from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
28 from IPython.kernel.zmq.ipkernel import IPythonKernel as Kernel
29 from IPython.kernel.zmq.kernelapp import IPKernelApp
29 from IPython.kernel.zmq.kernelapp import IPKernelApp
30
30
31 class EngineFactory(RegistrationFactory):
31 class EngineFactory(RegistrationFactory):
32 """IPython engine"""
32 """IPython engine"""
33
33
34 # configurables:
34 # configurables:
35 out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
35 out_stream_factory=Type('IPython.kernel.zmq.iostream.OutStream', config=True,
36 help="""The OutStream for handling stdout/err.
36 help="""The OutStream for handling stdout/err.
37 Typically 'IPython.kernel.zmq.iostream.OutStream'""")
37 Typically 'IPython.kernel.zmq.iostream.OutStream'""")
38 display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
38 display_hook_factory=Type('IPython.kernel.zmq.displayhook.ZMQDisplayHook', config=True,
39 help="""The class for handling displayhook.
39 help="""The class for handling displayhook.
40 Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
40 Typically 'IPython.kernel.zmq.displayhook.ZMQDisplayHook'""")
41 location=Unicode(config=True,
41 location=Unicode(config=True,
42 help="""The location (an IP address) of the controller. This is
42 help="""The location (an IP address) of the controller. This is
43 used for disambiguating URLs, to determine whether
43 used for disambiguating URLs, to determine whether
44 loopback should be used to connect or the public address.""")
44 loopback should be used to connect or the public address.""")
45 timeout=Float(5.0, config=True,
45 timeout=Float(5.0, config=True,
46 help="""The time (in seconds) to wait for the Controller to respond
46 help="""The time (in seconds) to wait for the Controller to respond
47 to registration requests before giving up.""")
47 to registration requests before giving up.""")
48 max_heartbeat_misses=Integer(50, config=True,
48 max_heartbeat_misses=Integer(50, config=True,
49 help="""The maximum number of times a check for the heartbeat ping of a
49 help="""The maximum number of times a check for the heartbeat ping of a
50 controller can be missed before shutting down the engine.
50 controller can be missed before shutting down the engine.
51
51
52 If set to 0, the check is disabled.""")
52 If set to 0, the check is disabled.""")
53 sshserver=Unicode(config=True,
53 sshserver=Unicode(config=True,
54 help="""The SSH server to use for tunneling connections to the Controller.""")
54 help="""The SSH server to use for tunneling connections to the Controller.""")
55 sshkey=Unicode(config=True,
55 sshkey=Unicode(config=True,
56 help="""The SSH private key file to use when tunneling connections to the Controller.""")
56 help="""The SSH private key file to use when tunneling connections to the Controller.""")
57 paramiko=Bool(sys.platform == 'win32', config=True,
57 paramiko=Bool(sys.platform == 'win32', config=True,
58 help="""Whether to use paramiko instead of openssh for tunnels.""")
58 help="""Whether to use paramiko instead of openssh for tunnels.""")
59
59
60 @property
60 @property
61 def tunnel_mod(self):
61 def tunnel_mod(self):
62 from zmq.ssh import tunnel
62 from zmq.ssh import tunnel
63 return tunnel
63 return tunnel
64
64
65
65
66 # not configurable:
66 # not configurable:
67 connection_info = Dict()
67 connection_info = Dict()
68 user_ns = Dict()
68 user_ns = Dict()
69 id = Integer(allow_none=True)
69 id = Integer(allow_none=True)
70 registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
70 registrar = Instance('zmq.eventloop.zmqstream.ZMQStream')
71 kernel = Instance(Kernel)
71 kernel = Instance(Kernel)
72 hb_check_period=Integer()
72 hb_check_period=Integer()
73
73
74 # States for the heartbeat monitoring
74 # States for the heartbeat monitoring
75 # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
75 # Initial values for monitored and pinged must satisfy "monitored > pinged == False" so that
76 # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
76 # during the first check no "missed" ping is reported. Must be floats for Python 3 compatibility.
77 _hb_last_pinged = 0.0
77 _hb_last_pinged = 0.0
78 _hb_last_monitored = 0.0
78 _hb_last_monitored = 0.0
79 _hb_missed_beats = 0
79 _hb_missed_beats = 0
80 # The zmq Stream which receives the pings from the Heart
80 # The zmq Stream which receives the pings from the Heart
81 _hb_listener = None
81 _hb_listener = None
82
82
83 bident = CBytes()
83 bident = CBytes()
84 ident = Unicode()
84 ident = Unicode()
85 def _ident_changed(self, name, old, new):
85 def _ident_changed(self, name, old, new):
86 self.bident = cast_bytes(new)
86 self.bident = cast_bytes(new)
87 using_ssh=Bool(False)
87 using_ssh=Bool(False)
88
88
89
89
90 def __init__(self, **kwargs):
90 def __init__(self, **kwargs):
91 super(EngineFactory, self).__init__(**kwargs)
91 super(EngineFactory, self).__init__(**kwargs)
92 self.ident = self.session.session
92 self.ident = self.session.session
93
93
94 def init_connector(self):
94 def init_connector(self):
95 """construct connection function, which handles tunnels."""
95 """construct connection function, which handles tunnels."""
96 self.using_ssh = bool(self.sshkey or self.sshserver)
96 self.using_ssh = bool(self.sshkey or self.sshserver)
97
97
98 if self.sshkey and not self.sshserver:
98 if self.sshkey and not self.sshserver:
99 # We are using ssh directly to the controller, tunneling localhost to localhost
99 # We are using ssh directly to the controller, tunneling localhost to localhost
100 self.sshserver = self.url.split('://')[1].split(':')[0]
100 self.sshserver = self.url.split('://')[1].split(':')[0]
101
101
102 if self.using_ssh:
102 if self.using_ssh:
103 if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
103 if self.tunnel_mod.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
104 password=False
104 password=False
105 else:
105 else:
106 password = getpass("SSH Password for %s: "%self.sshserver)
106 password = getpass("SSH Password for %s: "%self.sshserver)
107 else:
107 else:
108 password = False
108 password = False
109
109
110 def connect(s, url):
110 def connect(s, url):
111 url = disambiguate_url(url, self.location)
111 url = disambiguate_url(url, self.location)
112 if self.using_ssh:
112 if self.using_ssh:
113 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
113 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
114 return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
114 return self.tunnel_mod.tunnel_connection(s, url, self.sshserver,
115 keyfile=self.sshkey, paramiko=self.paramiko,
115 keyfile=self.sshkey, paramiko=self.paramiko,
116 password=password,
116 password=password,
117 )
117 )
118 else:
118 else:
119 return s.connect(url)
119 return s.connect(url)
120
120
121 def maybe_tunnel(url):
121 def maybe_tunnel(url):
122 """like connect, but don't complete the connection (for use by heartbeat)"""
122 """like connect, but don't complete the connection (for use by heartbeat)"""
123 url = disambiguate_url(url, self.location)
123 url = disambiguate_url(url, self.location)
124 if self.using_ssh:
124 if self.using_ssh:
125 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
125 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
126 url, tunnelobj = self.tunnel_mod.open_tunnel(url, self.sshserver,
126 url, tunnelobj = self.tunnel_mod.open_tunnel(url, self.sshserver,
127 keyfile=self.sshkey, paramiko=self.paramiko,
127 keyfile=self.sshkey, paramiko=self.paramiko,
128 password=password,
128 password=password,
129 )
129 )
130 return str(url)
130 return str(url)
131 return connect, maybe_tunnel
131 return connect, maybe_tunnel
132
132
133 def register(self):
133 def register(self):
134 """send the registration_request"""
134 """send the registration_request"""
135
135
136 self.log.info("Registering with controller at %s"%self.url)
136 self.log.info("Registering with controller at %s"%self.url)
137 ctx = self.context
137 ctx = self.context
138 connect,maybe_tunnel = self.init_connector()
138 connect,maybe_tunnel = self.init_connector()
139 reg = ctx.socket(zmq.DEALER)
139 reg = ctx.socket(zmq.DEALER)
140 reg.setsockopt(zmq.IDENTITY, self.bident)
140 reg.setsockopt(zmq.IDENTITY, self.bident)
141 connect(reg, self.url)
141 connect(reg, self.url)
142 self.registrar = zmqstream.ZMQStream(reg, self.loop)
142 self.registrar = zmqstream.ZMQStream(reg, self.loop)
143
143
144
144
145 content = dict(uuid=self.ident)
145 content = dict(uuid=self.ident)
146 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
146 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
147 # print (self.session.key)
147 # print (self.session.key)
148 self.session.send(self.registrar, "registration_request", content=content)
148 self.session.send(self.registrar, "registration_request", content=content)
149
149
150 def _report_ping(self, msg):
150 def _report_ping(self, msg):
151 """Callback for when the heartmonitor.Heart receives a ping"""
151 """Callback for when the heartmonitor.Heart receives a ping"""
152 #self.log.debug("Received a ping: %s", msg)
152 #self.log.debug("Received a ping: %s", msg)
153 self._hb_last_pinged = time.time()
153 self._hb_last_pinged = time.time()
154
154
155 def complete_registration(self, msg, connect, maybe_tunnel):
155 def complete_registration(self, msg, connect, maybe_tunnel):
156 # print msg
156 # print msg
157 self.loop.remove_timeout(self._abort_timeout)
157 self.loop.remove_timeout(self._abort_timeout)
158 ctx = self.context
158 ctx = self.context
159 loop = self.loop
159 loop = self.loop
160 identity = self.bident
160 identity = self.bident
161 idents,msg = self.session.feed_identities(msg)
161 idents,msg = self.session.feed_identities(msg)
162 msg = self.session.deserialize(msg)
162 msg = self.session.deserialize(msg)
163 content = msg['content']
163 content = msg['content']
164 info = self.connection_info
164 info = self.connection_info
165
165
166 def url(key):
166 def url(key):
167 """get zmq url for given channel"""
167 """get zmq url for given channel"""
168 return str(info["interface"] + ":%i" % info[key])
168 return str(info["interface"] + ":%i" % info[key])
169
169
170 if content['status'] == 'ok':
170 if content['status'] == 'ok':
171 self.id = int(content['id'])
171 self.id = int(content['id'])
172
172
173 # launch heartbeat
173 # launch heartbeat
174 # possibly forward hb ports with tunnels
174 # possibly forward hb ports with tunnels
175 hb_ping = maybe_tunnel(url('hb_ping'))
175 hb_ping = maybe_tunnel(url('hb_ping'))
176 hb_pong = maybe_tunnel(url('hb_pong'))
176 hb_pong = maybe_tunnel(url('hb_pong'))
177
177
178 hb_monitor = None
178 hb_monitor = None
179 if self.max_heartbeat_misses > 0:
179 if self.max_heartbeat_misses > 0:
180 # Add a monitor socket which will record the last time a ping was seen
180 # Add a monitor socket which will record the last time a ping was seen
181 mon = self.context.socket(zmq.SUB)
181 mon = self.context.socket(zmq.SUB)
182 mport = mon.bind_to_random_port('tcp://%s' % localhost())
182 mport = mon.bind_to_random_port('tcp://%s' % localhost())
183 mon.setsockopt(zmq.SUBSCRIBE, b"")
183 mon.setsockopt(zmq.SUBSCRIBE, b"")
184 self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
184 self._hb_listener = zmqstream.ZMQStream(mon, self.loop)
185 self._hb_listener.on_recv(self._report_ping)
185 self._hb_listener.on_recv(self._report_ping)
186
186
187
187
188 hb_monitor = "tcp://%s:%i" % (localhost(), mport)
188 hb_monitor = "tcp://%s:%i" % (localhost(), mport)
189
189
190 heart = Heart(hb_ping, hb_pong, hb_monitor , heart_id=identity)
190 heart = Heart(hb_ping, hb_pong, hb_monitor , heart_id=identity)
191 heart.start()
191 heart.start()
192
192
193 # create Shell Connections (MUX, Task, etc.):
193 # create Shell Connections (MUX, Task, etc.):
194 shell_addrs = url('mux'), url('task')
194 shell_addrs = url('mux'), url('task')
195
195
196 # Use only one shell stream for mux and tasks
196 # Use only one shell stream for mux and tasks
197 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
197 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
198 stream.setsockopt(zmq.IDENTITY, identity)
198 stream.setsockopt(zmq.IDENTITY, identity)
199 shell_streams = [stream]
199 shell_streams = [stream]
200 for addr in shell_addrs:
200 for addr in shell_addrs:
201 connect(stream, addr)
201 connect(stream, addr)
202
202
203 # control stream:
203 # control stream:
204 control_addr = url('control')
204 control_addr = url('control')
205 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
205 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
206 control_stream.setsockopt(zmq.IDENTITY, identity)
206 control_stream.setsockopt(zmq.IDENTITY, identity)
207 connect(control_stream, control_addr)
207 connect(control_stream, control_addr)
208
208
209 # create iopub stream:
209 # create iopub stream:
210 iopub_addr = url('iopub')
210 iopub_addr = url('iopub')
211 iopub_socket = ctx.socket(zmq.PUB)
211 iopub_socket = ctx.socket(zmq.PUB)
212 iopub_socket.setsockopt(zmq.IDENTITY, identity)
212 iopub_socket.setsockopt(zmq.IDENTITY, identity)
213 connect(iopub_socket, iopub_addr)
213 connect(iopub_socket, iopub_addr)
214
214
215 # disable history:
215 # disable history:
216 self.config.HistoryManager.hist_file = ':memory:'
216 self.config.HistoryManager.hist_file = ':memory:'
217
217
218 # Redirect input streams and set a display hook.
218 # Redirect input streams and set a display hook.
219 if self.out_stream_factory:
219 if self.out_stream_factory:
220 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
220 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
221 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
221 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
222 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
222 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
223 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
223 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
224 if self.display_hook_factory:
224 if self.display_hook_factory:
225 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
225 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
226 sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)
226 sys.displayhook.topic = cast_bytes('engine.%i.execute_result' % self.id)
227
227
228 self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
228 self.kernel = Kernel(parent=self, int_id=self.id, ident=self.ident, session=self.session,
229 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
229 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
230 loop=loop, user_ns=self.user_ns, log=self.log)
230 loop=loop, user_ns=self.user_ns, log=self.log)
231
231
232 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
232 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
233
233
234
234
235 # periodically check the heartbeat pings of the controller
235 # periodically check the heartbeat pings of the controller
236 # Should be started here and not in "start()" so that the right period can be taken
236 # Should be started here and not in "start()" so that the right period can be taken
237 # from the hubs HeartBeatMonitor.period
237 # from the hubs HeartBeatMonitor.period
238 if self.max_heartbeat_misses > 0:
238 if self.max_heartbeat_misses > 0:
239 # Use a slightly bigger check period than the hub signal period to not warn unnecessary
239 # Use a slightly bigger check period than the hub signal period to not warn unnecessary
240 self.hb_check_period = int(content['hb_period'])+10
240 self.hb_check_period = int(content['hb_period'])+10
241 self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms." , self.hb_check_period)
241 self.log.info("Starting to monitor the heartbeat signal from the hub every %i ms." , self.hb_check_period)
242 self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
242 self._hb_reporter = ioloop.PeriodicCallback(self._hb_monitor, self.hb_check_period, self.loop)
243 self._hb_reporter.start()
243 self._hb_reporter.start()
244 else:
244 else:
245 self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")
245 self.log.info("Monitoring of the heartbeat signal from the hub is not enabled.")
246
246
247
247
248 # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
248 # FIXME: This is a hack until IPKernelApp and IPEngineApp can be fully merged
249 app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
249 app = IPKernelApp(parent=self, shell=self.kernel.shell, kernel=self.kernel, log=self.log)
250 app.init_profile_dir()
250 app.init_profile_dir()
251 app.init_code()
251 app.init_code()
252
252
253 self.kernel.start()
253 self.kernel.start()
254 else:
254 else:
255 self.log.fatal("Registration Failed: %s"%msg)
255 self.log.fatal("Registration Failed: %s"%msg)
256 raise Exception("Registration Failed: %s"%msg)
256 raise Exception("Registration Failed: %s"%msg)
257
257
258 self.log.info("Completed registration with id %i"%self.id)
258 self.log.info("Completed registration with id %i"%self.id)
259
259
260
260
261 def abort(self):
261 def abort(self):
262 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
262 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
263 if self.url.startswith('127.'):
263 if self.url.startswith('127.'):
264 self.log.fatal("""
264 self.log.fatal("""
265 If the controller and engines are not on the same machine,
265 If the controller and engines are not on the same machine,
266 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
266 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
267 c.HubFactory.ip='*' # for all interfaces, internal and external
267 c.HubFactory.ip='*' # for all interfaces, internal and external
268 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
268 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
269 or tunnel connections via ssh.
269 or tunnel connections via ssh.
270 """)
270 """)
271 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
271 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
272 time.sleep(1)
272 time.sleep(1)
273 sys.exit(255)
273 sys.exit(255)
274
274
275 def _hb_monitor(self):
275 def _hb_monitor(self):
276 """Callback to monitor the heartbeat from the controller"""
276 """Callback to monitor the heartbeat from the controller"""
277 self._hb_listener.flush()
277 self._hb_listener.flush()
278 if self._hb_last_monitored > self._hb_last_pinged:
278 if self._hb_last_monitored > self._hb_last_pinged:
279 self._hb_missed_beats += 1
279 self._hb_missed_beats += 1
280 self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
280 self.log.warn("No heartbeat in the last %s ms (%s time(s) in a row).", self.hb_check_period, self._hb_missed_beats)
281 else:
281 else:
282 #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
282 #self.log.debug("Heartbeat received (after missing %s beats).", self._hb_missed_beats)
283 self._hb_missed_beats = 0
283 self._hb_missed_beats = 0
284
284
285 if self._hb_missed_beats >= self.max_heartbeat_misses:
285 if self._hb_missed_beats >= self.max_heartbeat_misses:
286 self.log.fatal("Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
286 self.log.fatal("Maximum number of heartbeats misses reached (%s times %s ms), shutting down.",
287 self.max_heartbeat_misses, self.hb_check_period)
287 self.max_heartbeat_misses, self.hb_check_period)
288 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
288 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
289 self.loop.stop()
289 self.loop.stop()
290
290
291 self._hb_last_monitored = time.time()
291 self._hb_last_monitored = time.time()
292
292
293
293
294 def start(self):
294 def start(self):
295 loop = self.loop
295 loop = self.loop
296 def _start():
296 def _start():
297 self.register()
297 self.register()
298 self._abort_timeout = loop.add_timeout(loop.time() + self.timeout, self.abort)
298 self._abort_timeout = loop.add_timeout(loop.time() + self.timeout, self.abort)
299 self.loop.add_callback(_start)
299 self.loop.add_callback(_start)
300
300
301
301
@@ -1,252 +1,252 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes and functions for kernel related errors and exceptions.
3 """Classes and functions for kernel related errors and exceptions.
4
4
5 Inheritance diagram:
5 Inheritance diagram:
6
6
7 .. inheritance-diagram:: IPython.parallel.error
7 .. inheritance-diagram:: ipython_parallel.error
8 :parts: 3
8 :parts: 3
9
9
10 Authors:
10 Authors:
11
11
12 * Brian Granger
12 * Brian Granger
13 * Min RK
13 * Min RK
14 """
14 """
15 from __future__ import print_function
15 from __future__ import print_function
16
16
17 import sys
17 import sys
18 import traceback
18 import traceback
19
19
20 from IPython.utils.py3compat import unicode_type
20 from IPython.utils.py3compat import unicode_type
21
21
22 __docformat__ = "restructuredtext en"
22 __docformat__ = "restructuredtext en"
23
23
24 # Tell nose to skip this module
24 # Tell nose to skip this module
25 __test__ = {}
25 __test__ = {}
26
26
27 #-------------------------------------------------------------------------------
27 #-------------------------------------------------------------------------------
28 # Copyright (C) 2008-2011 The IPython Development Team
28 # Copyright (C) 2008-2011 The IPython Development Team
29 #
29 #
30 # Distributed under the terms of the BSD License. The full license is in
30 # Distributed under the terms of the BSD License. The full license is in
31 # the file COPYING, distributed as part of this software.
31 # the file COPYING, distributed as part of this software.
32 #-------------------------------------------------------------------------------
32 #-------------------------------------------------------------------------------
33
33
34 #-------------------------------------------------------------------------------
34 #-------------------------------------------------------------------------------
35 # Error classes
35 # Error classes
36 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
37 class IPythonError(Exception):
37 class IPythonError(Exception):
38 """Base exception that all of our exceptions inherit from.
38 """Base exception that all of our exceptions inherit from.
39
39
40 This can be raised by code that doesn't have any more specific
40 This can be raised by code that doesn't have any more specific
41 information."""
41 information."""
42
42
43 pass
43 pass
44
44
45 class KernelError(IPythonError):
45 class KernelError(IPythonError):
46 pass
46 pass
47
47
48 class EngineError(KernelError):
48 class EngineError(KernelError):
49 pass
49 pass
50
50
51 class NoEnginesRegistered(KernelError):
51 class NoEnginesRegistered(KernelError):
52 pass
52 pass
53
53
54 class TaskAborted(KernelError):
54 class TaskAborted(KernelError):
55 pass
55 pass
56
56
57 class TaskTimeout(KernelError):
57 class TaskTimeout(KernelError):
58 pass
58 pass
59
59
60 class TimeoutError(KernelError):
60 class TimeoutError(KernelError):
61 pass
61 pass
62
62
63 class UnmetDependency(KernelError):
63 class UnmetDependency(KernelError):
64 pass
64 pass
65
65
66 class ImpossibleDependency(UnmetDependency):
66 class ImpossibleDependency(UnmetDependency):
67 pass
67 pass
68
68
69 class DependencyTimeout(ImpossibleDependency):
69 class DependencyTimeout(ImpossibleDependency):
70 pass
70 pass
71
71
72 class InvalidDependency(ImpossibleDependency):
72 class InvalidDependency(ImpossibleDependency):
73 pass
73 pass
74
74
75 class RemoteError(KernelError):
75 class RemoteError(KernelError):
76 """Error raised elsewhere"""
76 """Error raised elsewhere"""
77 ename=None
77 ename=None
78 evalue=None
78 evalue=None
79 traceback=None
79 traceback=None
80 engine_info=None
80 engine_info=None
81
81
82 def __init__(self, ename, evalue, traceback, engine_info=None):
82 def __init__(self, ename, evalue, traceback, engine_info=None):
83 self.ename=ename
83 self.ename=ename
84 self.evalue=evalue
84 self.evalue=evalue
85 self.traceback=traceback
85 self.traceback=traceback
86 self.engine_info=engine_info or {}
86 self.engine_info=engine_info or {}
87 self.args=(ename, evalue)
87 self.args=(ename, evalue)
88
88
89 def __repr__(self):
89 def __repr__(self):
90 engineid = self.engine_info.get('engine_id', ' ')
90 engineid = self.engine_info.get('engine_id', ' ')
91 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
91 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
92
92
93 def __str__(self):
93 def __str__(self):
94 return "%s(%s)" % (self.ename, self.evalue)
94 return "%s(%s)" % (self.ename, self.evalue)
95
95
96 def render_traceback(self):
96 def render_traceback(self):
97 """render traceback to a list of lines"""
97 """render traceback to a list of lines"""
98 return (self.traceback or "No traceback available").splitlines()
98 return (self.traceback or "No traceback available").splitlines()
99
99
100 def _render_traceback_(self):
100 def _render_traceback_(self):
101 """Special method for custom tracebacks within IPython.
101 """Special method for custom tracebacks within IPython.
102
102
103 This will be called by IPython instead of displaying the local traceback.
103 This will be called by IPython instead of displaying the local traceback.
104
104
105 It should return a traceback rendered as a list of lines.
105 It should return a traceback rendered as a list of lines.
106 """
106 """
107 return self.render_traceback()
107 return self.render_traceback()
108
108
109 def print_traceback(self, excid=None):
109 def print_traceback(self, excid=None):
110 """print my traceback"""
110 """print my traceback"""
111 print('\n'.join(self.render_traceback()))
111 print('\n'.join(self.render_traceback()))
112
112
113
113
114
114
115
115
116 class TaskRejectError(KernelError):
116 class TaskRejectError(KernelError):
117 """Exception to raise when a task should be rejected by an engine.
117 """Exception to raise when a task should be rejected by an engine.
118
118
119 This exception can be used to allow a task running on an engine to test
119 This exception can be used to allow a task running on an engine to test
120 if the engine (or the user's namespace on the engine) has the needed
120 if the engine (or the user's namespace on the engine) has the needed
121 task dependencies. If not, the task should raise this exception. For
121 task dependencies. If not, the task should raise this exception. For
122 the task to be retried on another engine, the task should be created
122 the task to be retried on another engine, the task should be created
123 with the `retries` argument > 1.
123 with the `retries` argument > 1.
124
124
125 The advantage of this approach over our older properties system is that
125 The advantage of this approach over our older properties system is that
126 tasks have full access to the user's namespace on the engines and the
126 tasks have full access to the user's namespace on the engines and the
127 properties don't have to be managed or tested by the controller.
127 properties don't have to be managed or tested by the controller.
128 """
128 """
129
129
130
130
131 class CompositeError(RemoteError):
131 class CompositeError(RemoteError):
132 """Error for representing possibly multiple errors on engines"""
132 """Error for representing possibly multiple errors on engines"""
133 tb_limit = 4 # limit on how many tracebacks to draw
133 tb_limit = 4 # limit on how many tracebacks to draw
134
134
135 def __init__(self, message, elist):
135 def __init__(self, message, elist):
136 Exception.__init__(self, *(message, elist))
136 Exception.__init__(self, *(message, elist))
137 # Don't use pack_exception because it will conflict with the .message
137 # Don't use pack_exception because it will conflict with the .message
138 # attribute that is being deprecated in 2.6 and beyond.
138 # attribute that is being deprecated in 2.6 and beyond.
139 self.msg = message
139 self.msg = message
140 self.elist = elist
140 self.elist = elist
141 self.args = [ e[0] for e in elist ]
141 self.args = [ e[0] for e in elist ]
142
142
143 def _get_engine_str(self, ei):
143 def _get_engine_str(self, ei):
144 if not ei:
144 if not ei:
145 return '[Engine Exception]'
145 return '[Engine Exception]'
146 else:
146 else:
147 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
147 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
148
148
149 def _get_traceback(self, ev):
149 def _get_traceback(self, ev):
150 try:
150 try:
151 tb = ev._ipython_traceback_text
151 tb = ev._ipython_traceback_text
152 except AttributeError:
152 except AttributeError:
153 return 'No traceback available'
153 return 'No traceback available'
154 else:
154 else:
155 return tb
155 return tb
156
156
157 def __str__(self):
157 def __str__(self):
158 s = str(self.msg)
158 s = str(self.msg)
159 for en, ev, etb, ei in self.elist[:self.tb_limit]:
159 for en, ev, etb, ei in self.elist[:self.tb_limit]:
160 engine_str = self._get_engine_str(ei)
160 engine_str = self._get_engine_str(ei)
161 s = s + '\n' + engine_str + en + ': ' + str(ev)
161 s = s + '\n' + engine_str + en + ': ' + str(ev)
162 if len(self.elist) > self.tb_limit:
162 if len(self.elist) > self.tb_limit:
163 s = s + '\n.... %i more exceptions ...' % (len(self.elist) - self.tb_limit)
163 s = s + '\n.... %i more exceptions ...' % (len(self.elist) - self.tb_limit)
164 return s
164 return s
165
165
166 def __repr__(self):
166 def __repr__(self):
167 return "CompositeError(%i)" % len(self.elist)
167 return "CompositeError(%i)" % len(self.elist)
168
168
169 def render_traceback(self, excid=None):
169 def render_traceback(self, excid=None):
170 """render one or all of my tracebacks to a list of lines"""
170 """render one or all of my tracebacks to a list of lines"""
171 lines = []
171 lines = []
172 if excid is None:
172 if excid is None:
173 for (en,ev,etb,ei) in self.elist[:self.tb_limit]:
173 for (en,ev,etb,ei) in self.elist[:self.tb_limit]:
174 lines.append(self._get_engine_str(ei))
174 lines.append(self._get_engine_str(ei))
175 lines.extend((etb or 'No traceback available').splitlines())
175 lines.extend((etb or 'No traceback available').splitlines())
176 lines.append('')
176 lines.append('')
177 if len(self.elist) > self.tb_limit:
177 if len(self.elist) > self.tb_limit:
178 lines.append(
178 lines.append(
179 '... %i more exceptions ...' % (len(self.elist) - self.tb_limit)
179 '... %i more exceptions ...' % (len(self.elist) - self.tb_limit)
180 )
180 )
181 else:
181 else:
182 try:
182 try:
183 en,ev,etb,ei = self.elist[excid]
183 en,ev,etb,ei = self.elist[excid]
184 except:
184 except:
185 raise IndexError("an exception with index %i does not exist"%excid)
185 raise IndexError("an exception with index %i does not exist"%excid)
186 else:
186 else:
187 lines.append(self._get_engine_str(ei))
187 lines.append(self._get_engine_str(ei))
188 lines.extend((etb or 'No traceback available').splitlines())
188 lines.extend((etb or 'No traceback available').splitlines())
189
189
190 return lines
190 return lines
191
191
192 def print_traceback(self, excid=None):
192 def print_traceback(self, excid=None):
193 print('\n'.join(self.render_traceback(excid)))
193 print('\n'.join(self.render_traceback(excid)))
194
194
195 def raise_exception(self, excid=0):
195 def raise_exception(self, excid=0):
196 try:
196 try:
197 en,ev,etb,ei = self.elist[excid]
197 en,ev,etb,ei = self.elist[excid]
198 except:
198 except:
199 raise IndexError("an exception with index %i does not exist"%excid)
199 raise IndexError("an exception with index %i does not exist"%excid)
200 else:
200 else:
201 raise RemoteError(en, ev, etb, ei)
201 raise RemoteError(en, ev, etb, ei)
202
202
203
203
204 def collect_exceptions(rdict_or_list, method='unspecified'):
204 def collect_exceptions(rdict_or_list, method='unspecified'):
205 """check a result dict for errors, and raise CompositeError if any exist.
205 """check a result dict for errors, and raise CompositeError if any exist.
206 Passthrough otherwise."""
206 Passthrough otherwise."""
207 elist = []
207 elist = []
208 if isinstance(rdict_or_list, dict):
208 if isinstance(rdict_or_list, dict):
209 rlist = rdict_or_list.values()
209 rlist = rdict_or_list.values()
210 else:
210 else:
211 rlist = rdict_or_list
211 rlist = rdict_or_list
212 for r in rlist:
212 for r in rlist:
213 if isinstance(r, RemoteError):
213 if isinstance(r, RemoteError):
214 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
214 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
215 # Sometimes we could have CompositeError in our list. Just take
215 # Sometimes we could have CompositeError in our list. Just take
216 # the errors out of them and put them in our new list. This
216 # the errors out of them and put them in our new list. This
217 # has the effect of flattening lists of CompositeErrors into one
217 # has the effect of flattening lists of CompositeErrors into one
218 # CompositeError
218 # CompositeError
219 if en=='CompositeError':
219 if en=='CompositeError':
220 for e in ev.elist:
220 for e in ev.elist:
221 elist.append(e)
221 elist.append(e)
222 else:
222 else:
223 elist.append((en, ev, etb, ei))
223 elist.append((en, ev, etb, ei))
224 if len(elist)==0:
224 if len(elist)==0:
225 return rdict_or_list
225 return rdict_or_list
226 else:
226 else:
227 msg = "one or more exceptions from call to method: %s" % (method)
227 msg = "one or more exceptions from call to method: %s" % (method)
228 # This silliness is needed so the debugger has access to the exception
228 # This silliness is needed so the debugger has access to the exception
229 # instance (e in this case)
229 # instance (e in this case)
230 try:
230 try:
231 raise CompositeError(msg, elist)
231 raise CompositeError(msg, elist)
232 except CompositeError as e:
232 except CompositeError as e:
233 raise e
233 raise e
234
234
235 def wrap_exception(engine_info={}):
235 def wrap_exception(engine_info={}):
236 etype, evalue, tb = sys.exc_info()
236 etype, evalue, tb = sys.exc_info()
237 stb = traceback.format_exception(etype, evalue, tb)
237 stb = traceback.format_exception(etype, evalue, tb)
238 exc_content = {
238 exc_content = {
239 'status' : 'error',
239 'status' : 'error',
240 'traceback' : stb,
240 'traceback' : stb,
241 'ename' : unicode_type(etype.__name__),
241 'ename' : unicode_type(etype.__name__),
242 'evalue' : unicode_type(evalue),
242 'evalue' : unicode_type(evalue),
243 'engine_info' : engine_info
243 'engine_info' : engine_info
244 }
244 }
245 return exc_content
245 return exc_content
246
246
247 def unwrap_exception(content):
247 def unwrap_exception(content):
248 err = RemoteError(content['ename'], content['evalue'],
248 err = RemoteError(content['ename'], content['evalue'],
249 ''.join(content['traceback']),
249 ''.join(content['traceback']),
250 content.get('engine_info', {}))
250 content.get('engine_info', {}))
251 return err
251 return err
252
252
@@ -1,73 +1,73 b''
1 """Base config factories.
1 """Base config factories.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 from IPython.utils.localinterfaces import localhost
19 from IPython.utils.localinterfaces import localhost
20 from IPython.utils.traitlets import Integer, Unicode
20 from IPython.utils.traitlets import Integer, Unicode
21
21
22 from IPython.parallel.util import select_random_ports
22 from ipython_parallel.util import select_random_ports
23 from IPython.kernel.zmq.session import SessionFactory
23 from IPython.kernel.zmq.session import SessionFactory
24
24
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26 # Classes
26 # Classes
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28
28
29
29
30 class RegistrationFactory(SessionFactory):
30 class RegistrationFactory(SessionFactory):
31 """The Base Configurable for objects that involve registration."""
31 """The Base Configurable for objects that involve registration."""
32
32
33 url = Unicode('', config=True,
33 url = Unicode('', config=True,
34 help="""The 0MQ url used for registration. This sets transport, ip, and port
34 help="""The 0MQ url used for registration. This sets transport, ip, and port
35 in one variable. For example: url='tcp://127.0.0.1:12345' or
35 in one variable. For example: url='tcp://127.0.0.1:12345' or
36 url='epgm://*:90210'"""
36 url='epgm://*:90210'"""
37 ) # url takes precedence over ip,regport,transport
37 ) # url takes precedence over ip,regport,transport
38 transport = Unicode('tcp', config=True,
38 transport = Unicode('tcp', config=True,
39 help="""The 0MQ transport for communications. This will likely be
39 help="""The 0MQ transport for communications. This will likely be
40 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
40 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
41 ip = Unicode(config=True,
41 ip = Unicode(config=True,
42 help="""The IP address for registration. This is generally either
42 help="""The IP address for registration. This is generally either
43 '127.0.0.1' for loopback only or '*' for all interfaces.
43 '127.0.0.1' for loopback only or '*' for all interfaces.
44 """)
44 """)
45 def _ip_default(self):
45 def _ip_default(self):
46 return localhost()
46 return localhost()
47 regport = Integer(config=True,
47 regport = Integer(config=True,
48 help="""The port on which the Hub listens for registration.""")
48 help="""The port on which the Hub listens for registration.""")
49 def _regport_default(self):
49 def _regport_default(self):
50 return select_random_ports(1)[0]
50 return select_random_ports(1)[0]
51
51
52 def __init__(self, **kwargs):
52 def __init__(self, **kwargs):
53 super(RegistrationFactory, self).__init__(**kwargs)
53 super(RegistrationFactory, self).__init__(**kwargs)
54 self._propagate_url()
54 self._propagate_url()
55 self._rebuild_url()
55 self._rebuild_url()
56 self.on_trait_change(self._propagate_url, 'url')
56 self.on_trait_change(self._propagate_url, 'url')
57 self.on_trait_change(self._rebuild_url, 'ip')
57 self.on_trait_change(self._rebuild_url, 'ip')
58 self.on_trait_change(self._rebuild_url, 'transport')
58 self.on_trait_change(self._rebuild_url, 'transport')
59 self.on_trait_change(self._rebuild_url, 'regport')
59 self.on_trait_change(self._rebuild_url, 'regport')
60
60
61 def _rebuild_url(self):
61 def _rebuild_url(self):
62 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
62 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
63
63
64 def _propagate_url(self):
64 def _propagate_url(self):
65 """Ensure self.url contains full transport://interface:port"""
65 """Ensure self.url contains full transport://interface:port"""
66 if self.url:
66 if self.url:
67 iface = self.url.split('://',1)
67 iface = self.url.split('://',1)
68 if len(iface) == 2:
68 if len(iface) == 2:
69 self.transport,iface = iface
69 self.transport,iface = iface
70 iface = iface.split(':')
70 iface = iface.split(':')
71 self.ip = iface[0]
71 self.ip = iface[0]
72 if iface[1]:
72 if iface[1]:
73 self.regport = int(iface[1])
73 self.regport = int(iface[1])
@@ -1,3 +1,3 b''
1 if __name__ == '__main__':
1 if __name__ == '__main__':
2 from IPython.parallel.apps import iploggerapp as app
2 from ipython_parallel.apps import iploggerapp as app
3 app.launch_new_instance()
3 app.launch_new_instance()
@@ -1,145 +1,145 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2 from __future__ import print_function
2 from __future__ import print_function
3
3
4 #-------------------------------------------------------------------------------
4 #-------------------------------------------------------------------------------
5 # Copyright (C) 2011 The IPython Development Team
5 # Copyright (C) 2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-------------------------------------------------------------------------------
9 #-------------------------------------------------------------------------------
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import tempfile
16 import tempfile
17 import time
17 import time
18 from subprocess import Popen, PIPE, STDOUT
18 from subprocess import Popen, PIPE, STDOUT
19
19
20 import nose
20 import nose
21
21
22 from IPython.utils.path import get_ipython_dir
22 from IPython.utils.path import get_ipython_dir
23 from IPython.parallel import Client, error
23 from ipython_parallel import Client, error
24 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
24 from ipython_parallel.apps.launcher import (LocalProcessLauncher,
25 ipengine_cmd_argv,
25 ipengine_cmd_argv,
26 ipcontroller_cmd_argv,
26 ipcontroller_cmd_argv,
27 SIGKILL,
27 SIGKILL,
28 ProcessStateError,
28 ProcessStateError,
29 )
29 )
30
30
31 # globals
31 # globals
32 launchers = []
32 launchers = []
33 blackhole = open(os.devnull, 'w')
33 blackhole = open(os.devnull, 'w')
34
34
35 # Launcher class
35 # Launcher class
36 class TestProcessLauncher(LocalProcessLauncher):
36 class TestProcessLauncher(LocalProcessLauncher):
37 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
37 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
38 def start(self):
38 def start(self):
39 if self.state == 'before':
39 if self.state == 'before':
40 # Store stdout & stderr to show with failing tests.
40 # Store stdout & stderr to show with failing tests.
41 # This is defined in IPython.testing.iptest
41 # This is defined in IPython.testing.iptest
42 self.process = Popen(self.args,
42 self.process = Popen(self.args,
43 stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT,
43 stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT,
44 env=os.environ,
44 env=os.environ,
45 cwd=self.work_dir
45 cwd=self.work_dir
46 )
46 )
47 self.notify_start(self.process.pid)
47 self.notify_start(self.process.pid)
48 self.poll = self.process.poll
48 self.poll = self.process.poll
49 else:
49 else:
50 s = 'The process was already started and has state: %r' % self.state
50 s = 'The process was already started and has state: %r' % self.state
51 raise ProcessStateError(s)
51 raise ProcessStateError(s)
52
52
53 # nose setup/teardown
53 # nose setup/teardown
54
54
55 def setup():
55 def setup():
56
56
57 # show tracebacks for RemoteErrors
57 # show tracebacks for RemoteErrors
58 class RemoteErrorWithTB(error.RemoteError):
58 class RemoteErrorWithTB(error.RemoteError):
59 def __str__(self):
59 def __str__(self):
60 s = super(RemoteErrorWithTB, self).__str__()
60 s = super(RemoteErrorWithTB, self).__str__()
61 return '\n'.join([s, self.traceback or ''])
61 return '\n'.join([s, self.traceback or ''])
62
62
63 error.RemoteError = RemoteErrorWithTB
63 error.RemoteError = RemoteErrorWithTB
64
64
65 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
65 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
66 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
66 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
67 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
67 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
68 for json in (engine_json, client_json):
68 for json in (engine_json, client_json):
69 if os.path.exists(json):
69 if os.path.exists(json):
70 os.remove(json)
70 os.remove(json)
71
71
72 cp = TestProcessLauncher()
72 cp = TestProcessLauncher()
73 cp.cmd_and_args = ipcontroller_cmd_argv + \
73 cp.cmd_and_args = ipcontroller_cmd_argv + \
74 ['--profile=iptest', '--log-level=20', '--ping=250', '--dictdb']
74 ['--profile=iptest', '--log-level=20', '--ping=250', '--dictdb']
75 cp.start()
75 cp.start()
76 launchers.append(cp)
76 launchers.append(cp)
77 tic = time.time()
77 tic = time.time()
78 while not os.path.exists(engine_json) or not os.path.exists(client_json):
78 while not os.path.exists(engine_json) or not os.path.exists(client_json):
79 if cp.poll() is not None:
79 if cp.poll() is not None:
80 raise RuntimeError("The test controller exited with status %s" % cp.poll())
80 raise RuntimeError("The test controller exited with status %s" % cp.poll())
81 elif time.time()-tic > 15:
81 elif time.time()-tic > 15:
82 raise RuntimeError("Timeout waiting for the test controller to start.")
82 raise RuntimeError("Timeout waiting for the test controller to start.")
83 time.sleep(0.1)
83 time.sleep(0.1)
84 add_engines(1)
84 add_engines(1)
85
85
86 def add_engines(n=1, profile='iptest', total=False):
86 def add_engines(n=1, profile='iptest', total=False):
87 """add a number of engines to a given profile.
87 """add a number of engines to a given profile.
88
88
89 If total is True, then already running engines are counted, and only
89 If total is True, then already running engines are counted, and only
90 the additional engines necessary (if any) are started.
90 the additional engines necessary (if any) are started.
91 """
91 """
92 rc = Client(profile=profile)
92 rc = Client(profile=profile)
93 base = len(rc)
93 base = len(rc)
94
94
95 if total:
95 if total:
96 n = max(n - base, 0)
96 n = max(n - base, 0)
97
97
98 eps = []
98 eps = []
99 for i in range(n):
99 for i in range(n):
100 ep = TestProcessLauncher()
100 ep = TestProcessLauncher()
101 ep.cmd_and_args = ipengine_cmd_argv + [
101 ep.cmd_and_args = ipengine_cmd_argv + [
102 '--profile=%s' % profile,
102 '--profile=%s' % profile,
103 '--log-level=50',
103 '--log-level=50',
104 '--InteractiveShell.colors=nocolor'
104 '--InteractiveShell.colors=nocolor'
105 ]
105 ]
106 ep.start()
106 ep.start()
107 launchers.append(ep)
107 launchers.append(ep)
108 eps.append(ep)
108 eps.append(ep)
109 tic = time.time()
109 tic = time.time()
110 while len(rc) < base+n:
110 while len(rc) < base+n:
111 if any([ ep.poll() is not None for ep in eps ]):
111 if any([ ep.poll() is not None for ep in eps ]):
112 raise RuntimeError("A test engine failed to start.")
112 raise RuntimeError("A test engine failed to start.")
113 elif time.time()-tic > 15:
113 elif time.time()-tic > 15:
114 raise RuntimeError("Timeout waiting for engines to connect.")
114 raise RuntimeError("Timeout waiting for engines to connect.")
115 time.sleep(.1)
115 time.sleep(.1)
116 rc.spin()
116 rc.spin()
117 rc.close()
117 rc.close()
118 return eps
118 return eps
119
119
120 def teardown():
120 def teardown():
121 try:
121 try:
122 time.sleep(1)
122 time.sleep(1)
123 except KeyboardInterrupt:
123 except KeyboardInterrupt:
124 return
124 return
125 while launchers:
125 while launchers:
126 p = launchers.pop()
126 p = launchers.pop()
127 if p.poll() is None:
127 if p.poll() is None:
128 try:
128 try:
129 p.stop()
129 p.stop()
130 except Exception as e:
130 except Exception as e:
131 print(e)
131 print(e)
132 pass
132 pass
133 if p.poll() is None:
133 if p.poll() is None:
134 try:
134 try:
135 time.sleep(.25)
135 time.sleep(.25)
136 except KeyboardInterrupt:
136 except KeyboardInterrupt:
137 return
137 return
138 if p.poll() is None:
138 if p.poll() is None:
139 try:
139 try:
140 print('cleaning up test process...')
140 print('cleaning up test process...')
141 p.signal(SIGKILL)
141 p.signal(SIGKILL)
142 except:
142 except:
143 print("couldn't shutdown process: ", p)
143 print("couldn't shutdown process: ", p)
144 blackhole.close()
144 blackhole.close()
145
145
@@ -1,192 +1,192 b''
1 """base class for parallel client tests
1 """base class for parallel client tests
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14 from __future__ import print_function
14 from __future__ import print_function
15
15
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import time
18 import time
19
19
20 from nose import SkipTest
20 from nose import SkipTest
21
21
22 import zmq
22 import zmq
23 from zmq.tests import BaseZMQTestCase
23 from zmq.tests import BaseZMQTestCase
24
24
25 from decorator import decorator
25 from decorator import decorator
26
26
27 from IPython.parallel import error
27 from ipython_parallel import error
28 from IPython.parallel import Client
28 from ipython_parallel import Client
29
29
30 from IPython.parallel.tests import launchers, add_engines
30 from ipython_parallel.tests import launchers, add_engines
31
31
32 # simple tasks for use in apply tests
32 # simple tasks for use in apply tests
33
33
34 def segfault():
34 def segfault():
35 """this will segfault"""
35 """this will segfault"""
36 import ctypes
36 import ctypes
37 ctypes.memset(-1,0,1)
37 ctypes.memset(-1,0,1)
38
38
39 def crash():
39 def crash():
40 """from stdlib crashers in the test suite"""
40 """from stdlib crashers in the test suite"""
41 import types
41 import types
42 if sys.platform.startswith('win'):
42 if sys.platform.startswith('win'):
43 import ctypes
43 import ctypes
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 if sys.version_info[0] >= 3:
46 if sys.version_info[0] >= 3:
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 args.insert(1, 0)
48 args.insert(1, 0)
49
49
50 co = types.CodeType(*args)
50 co = types.CodeType(*args)
51 exec(co)
51 exec(co)
52
52
53 def wait(n):
53 def wait(n):
54 """sleep for a time"""
54 """sleep for a time"""
55 import time
55 import time
56 time.sleep(n)
56 time.sleep(n)
57 return n
57 return n
58
58
59 def raiser(eclass):
59 def raiser(eclass):
60 """raise an exception"""
60 """raise an exception"""
61 raise eclass()
61 raise eclass()
62
62
63 def generate_output():
63 def generate_output():
64 """function for testing output
64 """function for testing output
65
65
66 publishes two outputs of each type, and returns
66 publishes two outputs of each type, and returns
67 a rich displayable object.
67 a rich displayable object.
68 """
68 """
69
69
70 import sys
70 import sys
71 from IPython.core.display import display, HTML, Math
71 from IPython.core.display import display, HTML, Math
72
72
73 print("stdout")
73 print("stdout")
74 print("stderr", file=sys.stderr)
74 print("stderr", file=sys.stderr)
75
75
76 display(HTML("<b>HTML</b>"))
76 display(HTML("<b>HTML</b>"))
77
77
78 print("stdout2")
78 print("stdout2")
79 print("stderr2", file=sys.stderr)
79 print("stderr2", file=sys.stderr)
80
80
81 display(Math(r"\alpha=\beta"))
81 display(Math(r"\alpha=\beta"))
82
82
83 return Math("42")
83 return Math("42")
84
84
85 # test decorator for skipping tests when libraries are unavailable
85 # test decorator for skipping tests when libraries are unavailable
86 def skip_without(*names):
86 def skip_without(*names):
87 """skip a test if some names are not importable"""
87 """skip a test if some names are not importable"""
88 @decorator
88 @decorator
89 def skip_without_names(f, *args, **kwargs):
89 def skip_without_names(f, *args, **kwargs):
90 """decorator to skip tests in the absence of numpy."""
90 """decorator to skip tests in the absence of numpy."""
91 for name in names:
91 for name in names:
92 try:
92 try:
93 __import__(name)
93 __import__(name)
94 except ImportError:
94 except ImportError:
95 raise SkipTest
95 raise SkipTest
96 return f(*args, **kwargs)
96 return f(*args, **kwargs)
97 return skip_without_names
97 return skip_without_names
98
98
99 #-------------------------------------------------------------------------------
99 #-------------------------------------------------------------------------------
100 # Classes
100 # Classes
101 #-------------------------------------------------------------------------------
101 #-------------------------------------------------------------------------------
102
102
103
103
104 class ClusterTestCase(BaseZMQTestCase):
104 class ClusterTestCase(BaseZMQTestCase):
105 timeout = 10
105 timeout = 10
106
106
107 def add_engines(self, n=1, block=True):
107 def add_engines(self, n=1, block=True):
108 """add multiple engines to our cluster"""
108 """add multiple engines to our cluster"""
109 self.engines.extend(add_engines(n))
109 self.engines.extend(add_engines(n))
110 if block:
110 if block:
111 self.wait_on_engines()
111 self.wait_on_engines()
112
112
113 def minimum_engines(self, n=1, block=True):
113 def minimum_engines(self, n=1, block=True):
114 """add engines until there are at least n connected"""
114 """add engines until there are at least n connected"""
115 self.engines.extend(add_engines(n, total=True))
115 self.engines.extend(add_engines(n, total=True))
116 if block:
116 if block:
117 self.wait_on_engines()
117 self.wait_on_engines()
118
118
119
119
120 def wait_on_engines(self, timeout=5):
120 def wait_on_engines(self, timeout=5):
121 """wait for our engines to connect."""
121 """wait for our engines to connect."""
122 n = len(self.engines)+self.base_engine_count
122 n = len(self.engines)+self.base_engine_count
123 tic = time.time()
123 tic = time.time()
124 while time.time()-tic < timeout and len(self.client.ids) < n:
124 while time.time()-tic < timeout and len(self.client.ids) < n:
125 time.sleep(0.1)
125 time.sleep(0.1)
126
126
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
127 assert not len(self.client.ids) < n, "waiting for engines timed out"
128
128
129 def client_wait(self, client, jobs=None, timeout=-1):
129 def client_wait(self, client, jobs=None, timeout=-1):
130 """my wait wrapper, sets a default finite timeout to avoid hangs"""
130 """my wait wrapper, sets a default finite timeout to avoid hangs"""
131 if timeout < 0:
131 if timeout < 0:
132 timeout = self.timeout
132 timeout = self.timeout
133 return Client.wait(client, jobs, timeout)
133 return Client.wait(client, jobs, timeout)
134
134
135 def connect_client(self):
135 def connect_client(self):
136 """connect a client with my Context, and track its sockets for cleanup"""
136 """connect a client with my Context, and track its sockets for cleanup"""
137 c = Client(profile='iptest', context=self.context)
137 c = Client(profile='iptest', context=self.context)
138 c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
138 c.wait = lambda *a, **kw: self.client_wait(c, *a, **kw)
139
139
140 for name in filter(lambda n:n.endswith('socket'), dir(c)):
140 for name in filter(lambda n:n.endswith('socket'), dir(c)):
141 s = getattr(c, name)
141 s = getattr(c, name)
142 s.setsockopt(zmq.LINGER, 0)
142 s.setsockopt(zmq.LINGER, 0)
143 self.sockets.append(s)
143 self.sockets.append(s)
144 return c
144 return c
145
145
146 def assertRaisesRemote(self, etype, f, *args, **kwargs):
146 def assertRaisesRemote(self, etype, f, *args, **kwargs):
147 try:
147 try:
148 try:
148 try:
149 f(*args, **kwargs)
149 f(*args, **kwargs)
150 except error.CompositeError as e:
150 except error.CompositeError as e:
151 e.raise_exception()
151 e.raise_exception()
152 except error.RemoteError as e:
152 except error.RemoteError as e:
153 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
153 self.assertEqual(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
154 else:
154 else:
155 self.fail("should have raised a RemoteError")
155 self.fail("should have raised a RemoteError")
156
156
157 def _wait_for(self, f, timeout=10):
157 def _wait_for(self, f, timeout=10):
158 """wait for a condition"""
158 """wait for a condition"""
159 tic = time.time()
159 tic = time.time()
160 while time.time() <= tic + timeout:
160 while time.time() <= tic + timeout:
161 if f():
161 if f():
162 return
162 return
163 time.sleep(0.1)
163 time.sleep(0.1)
164 self.client.spin()
164 self.client.spin()
165 if not f():
165 if not f():
166 print("Warning: Awaited condition never arrived")
166 print("Warning: Awaited condition never arrived")
167
167
168 def setUp(self):
168 def setUp(self):
169 BaseZMQTestCase.setUp(self)
169 BaseZMQTestCase.setUp(self)
170 self.client = self.connect_client()
170 self.client = self.connect_client()
171 # start every test with clean engine namespaces:
171 # start every test with clean engine namespaces:
172 self.client.clear(block=True)
172 self.client.clear(block=True)
173 self.base_engine_count=len(self.client.ids)
173 self.base_engine_count=len(self.client.ids)
174 self.engines=[]
174 self.engines=[]
175
175
176 def tearDown(self):
176 def tearDown(self):
177 # self.client.clear(block=True)
177 # self.client.clear(block=True)
178 # close fds:
178 # close fds:
179 for e in filter(lambda e: e.poll() is not None, launchers):
179 for e in filter(lambda e: e.poll() is not None, launchers):
180 launchers.remove(e)
180 launchers.remove(e)
181
181
182 # allow flushing of incoming messages to prevent crash on socket close
182 # allow flushing of incoming messages to prevent crash on socket close
183 self.client.wait(timeout=2)
183 self.client.wait(timeout=2)
184 # time.sleep(2)
184 # time.sleep(2)
185 self.client.spin()
185 self.client.spin()
186 self.client.close()
186 self.client.close()
187 BaseZMQTestCase.tearDown(self)
187 BaseZMQTestCase.tearDown(self)
188 # this will be redundant when pyzmq merges PR #88
188 # this will be redundant when pyzmq merges PR #88
189 # self.context.term()
189 # self.context.term()
190 # print tempfile.TemporaryFile().fileno(),
190 # print tempfile.TemporaryFile().fileno(),
191 # sys.stdout.flush()
191 # sys.stdout.flush()
192
192
@@ -1,342 +1,342 b''
1 """Tests for asyncresult.py"""
1 """Tests for asyncresult.py"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import time
6 import time
7
7
8 import nose.tools as nt
8 import nose.tools as nt
9
9
10 from IPython.utils.io import capture_output
10 from IPython.utils.io import capture_output
11
11
12 from IPython.parallel.error import TimeoutError
12 from ipython_parallel.error import TimeoutError
13 from IPython.parallel import error, Client
13 from ipython_parallel import error, Client
14 from IPython.parallel.tests import add_engines
14 from ipython_parallel.tests import add_engines
15 from .clienttest import ClusterTestCase
15 from .clienttest import ClusterTestCase
16 from IPython.utils.py3compat import iteritems
16 from IPython.utils.py3compat import iteritems
17
17
18 def setup():
18 def setup():
19 add_engines(2, total=True)
19 add_engines(2, total=True)
20
20
21 def wait(n):
21 def wait(n):
22 import time
22 import time
23 time.sleep(n)
23 time.sleep(n)
24 return n
24 return n
25
25
26 def echo(x):
26 def echo(x):
27 return x
27 return x
28
28
29 class AsyncResultTest(ClusterTestCase):
29 class AsyncResultTest(ClusterTestCase):
30
30
31 def test_single_result_view(self):
31 def test_single_result_view(self):
32 """various one-target views get the right value for single_result"""
32 """various one-target views get the right value for single_result"""
33 eid = self.client.ids[-1]
33 eid = self.client.ids[-1]
34 ar = self.client[eid].apply_async(lambda : 42)
34 ar = self.client[eid].apply_async(lambda : 42)
35 self.assertEqual(ar.get(), 42)
35 self.assertEqual(ar.get(), 42)
36 ar = self.client[[eid]].apply_async(lambda : 42)
36 ar = self.client[[eid]].apply_async(lambda : 42)
37 self.assertEqual(ar.get(), [42])
37 self.assertEqual(ar.get(), [42])
38 ar = self.client[-1:].apply_async(lambda : 42)
38 ar = self.client[-1:].apply_async(lambda : 42)
39 self.assertEqual(ar.get(), [42])
39 self.assertEqual(ar.get(), [42])
40
40
41 def test_get_after_done(self):
41 def test_get_after_done(self):
42 ar = self.client[-1].apply_async(lambda : 42)
42 ar = self.client[-1].apply_async(lambda : 42)
43 ar.wait()
43 ar.wait()
44 self.assertTrue(ar.ready())
44 self.assertTrue(ar.ready())
45 self.assertEqual(ar.get(), 42)
45 self.assertEqual(ar.get(), 42)
46 self.assertEqual(ar.get(), 42)
46 self.assertEqual(ar.get(), 42)
47
47
48 def test_get_before_done(self):
48 def test_get_before_done(self):
49 ar = self.client[-1].apply_async(wait, 0.1)
49 ar = self.client[-1].apply_async(wait, 0.1)
50 self.assertRaises(TimeoutError, ar.get, 0)
50 self.assertRaises(TimeoutError, ar.get, 0)
51 ar.wait(0)
51 ar.wait(0)
52 self.assertFalse(ar.ready())
52 self.assertFalse(ar.ready())
53 self.assertEqual(ar.get(), 0.1)
53 self.assertEqual(ar.get(), 0.1)
54
54
55 def test_get_after_error(self):
55 def test_get_after_error(self):
56 ar = self.client[-1].apply_async(lambda : 1/0)
56 ar = self.client[-1].apply_async(lambda : 1/0)
57 ar.wait(10)
57 ar.wait(10)
58 self.assertRaisesRemote(ZeroDivisionError, ar.get)
58 self.assertRaisesRemote(ZeroDivisionError, ar.get)
59 self.assertRaisesRemote(ZeroDivisionError, ar.get)
59 self.assertRaisesRemote(ZeroDivisionError, ar.get)
60 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
60 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
61
61
62 def test_get_dict(self):
62 def test_get_dict(self):
63 n = len(self.client)
63 n = len(self.client)
64 ar = self.client[:].apply_async(lambda : 5)
64 ar = self.client[:].apply_async(lambda : 5)
65 self.assertEqual(ar.get(), [5]*n)
65 self.assertEqual(ar.get(), [5]*n)
66 d = ar.get_dict()
66 d = ar.get_dict()
67 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
67 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
68 for eid,r in iteritems(d):
68 for eid,r in iteritems(d):
69 self.assertEqual(r, 5)
69 self.assertEqual(r, 5)
70
70
71 def test_get_dict_single(self):
71 def test_get_dict_single(self):
72 view = self.client[-1]
72 view = self.client[-1]
73 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
73 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
74 ar = view.apply_async(echo, v)
74 ar = view.apply_async(echo, v)
75 self.assertEqual(ar.get(), v)
75 self.assertEqual(ar.get(), v)
76 d = ar.get_dict()
76 d = ar.get_dict()
77 self.assertEqual(d, {view.targets : v})
77 self.assertEqual(d, {view.targets : v})
78
78
79 def test_get_dict_bad(self):
79 def test_get_dict_bad(self):
80 ar = self.client[:].apply_async(lambda : 5)
80 ar = self.client[:].apply_async(lambda : 5)
81 ar2 = self.client[:].apply_async(lambda : 5)
81 ar2 = self.client[:].apply_async(lambda : 5)
82 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
82 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
83 self.assertRaises(ValueError, ar.get_dict)
83 self.assertRaises(ValueError, ar.get_dict)
84
84
85 def test_list_amr(self):
85 def test_list_amr(self):
86 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
86 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
87 rlist = list(ar)
87 rlist = list(ar)
88
88
89 def test_getattr(self):
89 def test_getattr(self):
90 ar = self.client[:].apply_async(wait, 0.5)
90 ar = self.client[:].apply_async(wait, 0.5)
91 self.assertEqual(ar.engine_id, [None] * len(ar))
91 self.assertEqual(ar.engine_id, [None] * len(ar))
92 self.assertRaises(AttributeError, lambda : ar._foo)
92 self.assertRaises(AttributeError, lambda : ar._foo)
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
94 self.assertRaises(AttributeError, lambda : ar.foo)
94 self.assertRaises(AttributeError, lambda : ar.foo)
95 self.assertFalse(hasattr(ar, '__length_hint__'))
95 self.assertFalse(hasattr(ar, '__length_hint__'))
96 self.assertFalse(hasattr(ar, 'foo'))
96 self.assertFalse(hasattr(ar, 'foo'))
97 self.assertTrue(hasattr(ar, 'engine_id'))
97 self.assertTrue(hasattr(ar, 'engine_id'))
98 ar.get(5)
98 ar.get(5)
99 self.assertRaises(AttributeError, lambda : ar._foo)
99 self.assertRaises(AttributeError, lambda : ar._foo)
100 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
100 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
101 self.assertRaises(AttributeError, lambda : ar.foo)
101 self.assertRaises(AttributeError, lambda : ar.foo)
102 self.assertTrue(isinstance(ar.engine_id, list))
102 self.assertTrue(isinstance(ar.engine_id, list))
103 self.assertEqual(ar.engine_id, ar['engine_id'])
103 self.assertEqual(ar.engine_id, ar['engine_id'])
104 self.assertFalse(hasattr(ar, '__length_hint__'))
104 self.assertFalse(hasattr(ar, '__length_hint__'))
105 self.assertFalse(hasattr(ar, 'foo'))
105 self.assertFalse(hasattr(ar, 'foo'))
106 self.assertTrue(hasattr(ar, 'engine_id'))
106 self.assertTrue(hasattr(ar, 'engine_id'))
107
107
108 def test_getitem(self):
108 def test_getitem(self):
109 ar = self.client[:].apply_async(wait, 0.5)
109 ar = self.client[:].apply_async(wait, 0.5)
110 self.assertEqual(ar['engine_id'], [None] * len(ar))
110 self.assertEqual(ar['engine_id'], [None] * len(ar))
111 self.assertRaises(KeyError, lambda : ar['foo'])
111 self.assertRaises(KeyError, lambda : ar['foo'])
112 ar.get(5)
112 ar.get(5)
113 self.assertRaises(KeyError, lambda : ar['foo'])
113 self.assertRaises(KeyError, lambda : ar['foo'])
114 self.assertTrue(isinstance(ar['engine_id'], list))
114 self.assertTrue(isinstance(ar['engine_id'], list))
115 self.assertEqual(ar.engine_id, ar['engine_id'])
115 self.assertEqual(ar.engine_id, ar['engine_id'])
116
116
117 def test_single_result(self):
117 def test_single_result(self):
118 ar = self.client[-1].apply_async(wait, 0.5)
118 ar = self.client[-1].apply_async(wait, 0.5)
119 self.assertRaises(KeyError, lambda : ar['foo'])
119 self.assertRaises(KeyError, lambda : ar['foo'])
120 self.assertEqual(ar['engine_id'], None)
120 self.assertEqual(ar['engine_id'], None)
121 self.assertTrue(ar.get(5) == 0.5)
121 self.assertTrue(ar.get(5) == 0.5)
122 self.assertTrue(isinstance(ar['engine_id'], int))
122 self.assertTrue(isinstance(ar['engine_id'], int))
123 self.assertTrue(isinstance(ar.engine_id, int))
123 self.assertTrue(isinstance(ar.engine_id, int))
124 self.assertEqual(ar.engine_id, ar['engine_id'])
124 self.assertEqual(ar.engine_id, ar['engine_id'])
125
125
126 def test_abort(self):
126 def test_abort(self):
127 e = self.client[-1]
127 e = self.client[-1]
128 ar = e.execute('import time; time.sleep(1)', block=False)
128 ar = e.execute('import time; time.sleep(1)', block=False)
129 ar2 = e.apply_async(lambda : 2)
129 ar2 = e.apply_async(lambda : 2)
130 ar2.abort()
130 ar2.abort()
131 self.assertRaises(error.TaskAborted, ar2.get)
131 self.assertRaises(error.TaskAborted, ar2.get)
132 ar.get()
132 ar.get()
133
133
134 def test_len(self):
134 def test_len(self):
135 v = self.client.load_balanced_view()
135 v = self.client.load_balanced_view()
136 ar = v.map_async(lambda x: x, list(range(10)))
136 ar = v.map_async(lambda x: x, list(range(10)))
137 self.assertEqual(len(ar), 10)
137 self.assertEqual(len(ar), 10)
138 ar = v.apply_async(lambda x: x, list(range(10)))
138 ar = v.apply_async(lambda x: x, list(range(10)))
139 self.assertEqual(len(ar), 1)
139 self.assertEqual(len(ar), 1)
140 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
140 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
141 self.assertEqual(len(ar), len(self.client.ids))
141 self.assertEqual(len(ar), len(self.client.ids))
142
142
143 def test_wall_time_single(self):
143 def test_wall_time_single(self):
144 v = self.client.load_balanced_view()
144 v = self.client.load_balanced_view()
145 ar = v.apply_async(time.sleep, 0.25)
145 ar = v.apply_async(time.sleep, 0.25)
146 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
146 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
147 ar.get(2)
147 ar.get(2)
148 self.assertTrue(ar.wall_time < 1.)
148 self.assertTrue(ar.wall_time < 1.)
149 self.assertTrue(ar.wall_time > 0.2)
149 self.assertTrue(ar.wall_time > 0.2)
150
150
151 def test_wall_time_multi(self):
151 def test_wall_time_multi(self):
152 self.minimum_engines(4)
152 self.minimum_engines(4)
153 v = self.client[:]
153 v = self.client[:]
154 ar = v.apply_async(time.sleep, 0.25)
154 ar = v.apply_async(time.sleep, 0.25)
155 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
155 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
156 ar.get(2)
156 ar.get(2)
157 self.assertTrue(ar.wall_time < 1.)
157 self.assertTrue(ar.wall_time < 1.)
158 self.assertTrue(ar.wall_time > 0.2)
158 self.assertTrue(ar.wall_time > 0.2)
159
159
160 def test_serial_time_single(self):
160 def test_serial_time_single(self):
161 v = self.client.load_balanced_view()
161 v = self.client.load_balanced_view()
162 ar = v.apply_async(time.sleep, 0.25)
162 ar = v.apply_async(time.sleep, 0.25)
163 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
163 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
164 ar.get(2)
164 ar.get(2)
165 self.assertTrue(ar.serial_time < 1.)
165 self.assertTrue(ar.serial_time < 1.)
166 self.assertTrue(ar.serial_time > 0.2)
166 self.assertTrue(ar.serial_time > 0.2)
167
167
168 def test_serial_time_multi(self):
168 def test_serial_time_multi(self):
169 self.minimum_engines(4)
169 self.minimum_engines(4)
170 v = self.client[:]
170 v = self.client[:]
171 ar = v.apply_async(time.sleep, 0.25)
171 ar = v.apply_async(time.sleep, 0.25)
172 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
172 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
173 ar.get(2)
173 ar.get(2)
174 self.assertTrue(ar.serial_time < 2.)
174 self.assertTrue(ar.serial_time < 2.)
175 self.assertTrue(ar.serial_time > 0.8)
175 self.assertTrue(ar.serial_time > 0.8)
176
176
177 def test_elapsed_single(self):
177 def test_elapsed_single(self):
178 v = self.client.load_balanced_view()
178 v = self.client.load_balanced_view()
179 ar = v.apply_async(time.sleep, 0.25)
179 ar = v.apply_async(time.sleep, 0.25)
180 while not ar.ready():
180 while not ar.ready():
181 time.sleep(0.01)
181 time.sleep(0.01)
182 self.assertTrue(ar.elapsed < 1)
182 self.assertTrue(ar.elapsed < 1)
183 self.assertTrue(ar.elapsed < 1)
183 self.assertTrue(ar.elapsed < 1)
184 ar.get(2)
184 ar.get(2)
185
185
186 def test_elapsed_multi(self):
186 def test_elapsed_multi(self):
187 v = self.client[:]
187 v = self.client[:]
188 ar = v.apply_async(time.sleep, 0.25)
188 ar = v.apply_async(time.sleep, 0.25)
189 while not ar.ready():
189 while not ar.ready():
190 time.sleep(0.01)
190 time.sleep(0.01)
191 self.assertLess(ar.elapsed, 1)
191 self.assertLess(ar.elapsed, 1)
192 self.assertLess(ar.elapsed, 1)
192 self.assertLess(ar.elapsed, 1)
193 ar.get(2)
193 ar.get(2)
194
194
195 def test_hubresult_timestamps(self):
195 def test_hubresult_timestamps(self):
196 self.minimum_engines(4)
196 self.minimum_engines(4)
197 v = self.client[:]
197 v = self.client[:]
198 ar = v.apply_async(time.sleep, 0.25)
198 ar = v.apply_async(time.sleep, 0.25)
199 ar.get(2)
199 ar.get(2)
200 rc2 = Client(profile='iptest')
200 rc2 = Client(profile='iptest')
201 # must have try/finally to close second Client, otherwise
201 # must have try/finally to close second Client, otherwise
202 # will have dangling sockets causing problems
202 # will have dangling sockets causing problems
203 try:
203 try:
204 time.sleep(0.25)
204 time.sleep(0.25)
205 hr = rc2.get_result(ar.msg_ids)
205 hr = rc2.get_result(ar.msg_ids)
206 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
206 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
207 hr.get(1)
207 hr.get(1)
208 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
208 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
209 self.assertEqual(hr.serial_time, ar.serial_time)
209 self.assertEqual(hr.serial_time, ar.serial_time)
210 finally:
210 finally:
211 rc2.close()
211 rc2.close()
212
212
213 def test_display_empty_streams_single(self):
213 def test_display_empty_streams_single(self):
214 """empty stdout/err are not displayed (single result)"""
214 """empty stdout/err are not displayed (single result)"""
215 self.minimum_engines(1)
215 self.minimum_engines(1)
216
216
217 v = self.client[-1]
217 v = self.client[-1]
218 ar = v.execute("print (5555)")
218 ar = v.execute("print (5555)")
219 ar.get(5)
219 ar.get(5)
220 with capture_output() as io:
220 with capture_output() as io:
221 ar.display_outputs()
221 ar.display_outputs()
222 self.assertEqual(io.stderr, '')
222 self.assertEqual(io.stderr, '')
223 self.assertEqual('5555\n', io.stdout)
223 self.assertEqual('5555\n', io.stdout)
224
224
225 ar = v.execute("a=5")
225 ar = v.execute("a=5")
226 ar.get(5)
226 ar.get(5)
227 with capture_output() as io:
227 with capture_output() as io:
228 ar.display_outputs()
228 ar.display_outputs()
229 self.assertEqual(io.stderr, '')
229 self.assertEqual(io.stderr, '')
230 self.assertEqual(io.stdout, '')
230 self.assertEqual(io.stdout, '')
231
231
232 def test_display_empty_streams_type(self):
232 def test_display_empty_streams_type(self):
233 """empty stdout/err are not displayed (groupby type)"""
233 """empty stdout/err are not displayed (groupby type)"""
234 self.minimum_engines(1)
234 self.minimum_engines(1)
235
235
236 v = self.client[:]
236 v = self.client[:]
237 ar = v.execute("print (5555)")
237 ar = v.execute("print (5555)")
238 ar.get(5)
238 ar.get(5)
239 with capture_output() as io:
239 with capture_output() as io:
240 ar.display_outputs()
240 ar.display_outputs()
241 self.assertEqual(io.stderr, '')
241 self.assertEqual(io.stderr, '')
242 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
242 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
243 self.assertFalse('\n\n' in io.stdout, io.stdout)
243 self.assertFalse('\n\n' in io.stdout, io.stdout)
244 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
244 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
245
245
246 ar = v.execute("a=5")
246 ar = v.execute("a=5")
247 ar.get(5)
247 ar.get(5)
248 with capture_output() as io:
248 with capture_output() as io:
249 ar.display_outputs()
249 ar.display_outputs()
250 self.assertEqual(io.stderr, '')
250 self.assertEqual(io.stderr, '')
251 self.assertEqual(io.stdout, '')
251 self.assertEqual(io.stdout, '')
252
252
253 def test_display_empty_streams_engine(self):
253 def test_display_empty_streams_engine(self):
254 """empty stdout/err are not displayed (groupby engine)"""
254 """empty stdout/err are not displayed (groupby engine)"""
255 self.minimum_engines(1)
255 self.minimum_engines(1)
256
256
257 v = self.client[:]
257 v = self.client[:]
258 ar = v.execute("print (5555)")
258 ar = v.execute("print (5555)")
259 ar.get(5)
259 ar.get(5)
260 with capture_output() as io:
260 with capture_output() as io:
261 ar.display_outputs('engine')
261 ar.display_outputs('engine')
262 self.assertEqual(io.stderr, '')
262 self.assertEqual(io.stderr, '')
263 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
263 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
264 self.assertFalse('\n\n' in io.stdout, io.stdout)
264 self.assertFalse('\n\n' in io.stdout, io.stdout)
265 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
265 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
266
266
267 ar = v.execute("a=5")
267 ar = v.execute("a=5")
268 ar.get(5)
268 ar.get(5)
269 with capture_output() as io:
269 with capture_output() as io:
270 ar.display_outputs('engine')
270 ar.display_outputs('engine')
271 self.assertEqual(io.stderr, '')
271 self.assertEqual(io.stderr, '')
272 self.assertEqual(io.stdout, '')
272 self.assertEqual(io.stdout, '')
273
273
274 def test_await_data(self):
274 def test_await_data(self):
275 """asking for ar.data flushes outputs"""
275 """asking for ar.data flushes outputs"""
276 self.minimum_engines(1)
276 self.minimum_engines(1)
277
277
278 v = self.client[-1]
278 v = self.client[-1]
279 ar = v.execute('\n'.join([
279 ar = v.execute('\n'.join([
280 "import time",
280 "import time",
281 "from IPython.kernel.zmq.datapub import publish_data",
281 "from IPython.kernel.zmq.datapub import publish_data",
282 "for i in range(5):",
282 "for i in range(5):",
283 " publish_data(dict(i=i))",
283 " publish_data(dict(i=i))",
284 " time.sleep(0.1)",
284 " time.sleep(0.1)",
285 ]), block=False)
285 ]), block=False)
286 found = set()
286 found = set()
287 tic = time.time()
287 tic = time.time()
288 # timeout after 10s
288 # timeout after 10s
289 while time.time() <= tic + 10:
289 while time.time() <= tic + 10:
290 if ar.data:
290 if ar.data:
291 i = ar.data['i']
291 i = ar.data['i']
292 found.add(i)
292 found.add(i)
293 if i == 4:
293 if i == 4:
294 break
294 break
295 time.sleep(0.05)
295 time.sleep(0.05)
296
296
297 ar.get(5)
297 ar.get(5)
298 nt.assert_in(4, found)
298 nt.assert_in(4, found)
299 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
299 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
300
300
301 def test_not_single_result(self):
301 def test_not_single_result(self):
302 save_build = self.client._build_targets
302 save_build = self.client._build_targets
303 def single_engine(*a, **kw):
303 def single_engine(*a, **kw):
304 idents, targets = save_build(*a, **kw)
304 idents, targets = save_build(*a, **kw)
305 return idents[:1], targets[:1]
305 return idents[:1], targets[:1]
306 ids = single_engine('all')[1]
306 ids = single_engine('all')[1]
307 self.client._build_targets = single_engine
307 self.client._build_targets = single_engine
308 for targets in ('all', None, ids):
308 for targets in ('all', None, ids):
309 dv = self.client.direct_view(targets=targets)
309 dv = self.client.direct_view(targets=targets)
310 ar = dv.apply_async(lambda : 5)
310 ar = dv.apply_async(lambda : 5)
311 self.assertEqual(ar.get(10), [5])
311 self.assertEqual(ar.get(10), [5])
312 self.client._build_targets = save_build
312 self.client._build_targets = save_build
313
313
314 def test_owner_pop(self):
314 def test_owner_pop(self):
315 self.minimum_engines(1)
315 self.minimum_engines(1)
316
316
317 view = self.client[-1]
317 view = self.client[-1]
318 ar = view.apply_async(lambda : 1)
318 ar = view.apply_async(lambda : 1)
319 ar.get()
319 ar.get()
320 msg_id = ar.msg_ids[0]
320 msg_id = ar.msg_ids[0]
321 self.assertNotIn(msg_id, self.client.results)
321 self.assertNotIn(msg_id, self.client.results)
322 self.assertNotIn(msg_id, self.client.metadata)
322 self.assertNotIn(msg_id, self.client.metadata)
323
323
324 def test_non_owner(self):
324 def test_non_owner(self):
325 self.minimum_engines(1)
325 self.minimum_engines(1)
326
326
327 view = self.client[-1]
327 view = self.client[-1]
328 ar = view.apply_async(lambda : 1)
328 ar = view.apply_async(lambda : 1)
329 ar.owner = False
329 ar.owner = False
330 ar.get()
330 ar.get()
331 msg_id = ar.msg_ids[0]
331 msg_id = ar.msg_ids[0]
332 self.assertIn(msg_id, self.client.results)
332 self.assertIn(msg_id, self.client.results)
333 self.assertIn(msg_id, self.client.metadata)
333 self.assertIn(msg_id, self.client.metadata)
334 ar2 = self.client.get_result(msg_id, owner=True)
334 ar2 = self.client.get_result(msg_id, owner=True)
335 self.assertIs(type(ar2), type(ar))
335 self.assertIs(type(ar2), type(ar))
336 self.assertTrue(ar2.owner)
336 self.assertTrue(ar2.owner)
337 self.assertEqual(ar.get(), ar2.get())
337 self.assertEqual(ar.get(), ar2.get())
338 ar2.get()
338 ar2.get()
339 self.assertNotIn(msg_id, self.client.results)
339 self.assertNotIn(msg_id, self.client.results)
340 self.assertNotIn(msg_id, self.client.metadata)
340 self.assertNotIn(msg_id, self.client.metadata)
341
341
342
342
@@ -1,550 +1,550 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import division
6 from __future__ import division
7
7
8 import time
8 import time
9 from datetime import datetime
9 from datetime import datetime
10
10
11 import zmq
11 import zmq
12
12
13 from IPython import parallel
13 from IPython import parallel
14 from IPython.parallel.client import client as clientmod
14 from ipython_parallel.client import client as clientmod
15 from IPython.parallel import error
15 from ipython_parallel import error
16 from IPython.parallel import AsyncResult, AsyncHubResult
16 from ipython_parallel import AsyncResult, AsyncHubResult
17 from IPython.parallel import LoadBalancedView, DirectView
17 from ipython_parallel import LoadBalancedView, DirectView
18
18
19 from .clienttest import ClusterTestCase, segfault, wait, add_engines
19 from .clienttest import ClusterTestCase, segfault, wait, add_engines
20
20
21 def setup():
21 def setup():
22 add_engines(4, total=True)
22 add_engines(4, total=True)
23
23
24 class TestClient(ClusterTestCase):
24 class TestClient(ClusterTestCase):
25
25
26 def test_ids(self):
26 def test_ids(self):
27 n = len(self.client.ids)
27 n = len(self.client.ids)
28 self.add_engines(2)
28 self.add_engines(2)
29 self.assertEqual(len(self.client.ids), n+2)
29 self.assertEqual(len(self.client.ids), n+2)
30
30
31 def test_iter(self):
31 def test_iter(self):
32 self.minimum_engines(4)
32 self.minimum_engines(4)
33 engine_ids = [ view.targets for view in self.client ]
33 engine_ids = [ view.targets for view in self.client ]
34 self.assertEqual(engine_ids, self.client.ids)
34 self.assertEqual(engine_ids, self.client.ids)
35
35
36 def test_view_indexing(self):
36 def test_view_indexing(self):
37 """test index access for views"""
37 """test index access for views"""
38 self.minimum_engines(4)
38 self.minimum_engines(4)
39 targets = self.client._build_targets('all')[-1]
39 targets = self.client._build_targets('all')[-1]
40 v = self.client[:]
40 v = self.client[:]
41 self.assertEqual(v.targets, targets)
41 self.assertEqual(v.targets, targets)
42 t = self.client.ids[2]
42 t = self.client.ids[2]
43 v = self.client[t]
43 v = self.client[t]
44 self.assertTrue(isinstance(v, DirectView))
44 self.assertTrue(isinstance(v, DirectView))
45 self.assertEqual(v.targets, t)
45 self.assertEqual(v.targets, t)
46 t = self.client.ids[2:4]
46 t = self.client.ids[2:4]
47 v = self.client[t]
47 v = self.client[t]
48 self.assertTrue(isinstance(v, DirectView))
48 self.assertTrue(isinstance(v, DirectView))
49 self.assertEqual(v.targets, t)
49 self.assertEqual(v.targets, t)
50 v = self.client[::2]
50 v = self.client[::2]
51 self.assertTrue(isinstance(v, DirectView))
51 self.assertTrue(isinstance(v, DirectView))
52 self.assertEqual(v.targets, targets[::2])
52 self.assertEqual(v.targets, targets[::2])
53 v = self.client[1::3]
53 v = self.client[1::3]
54 self.assertTrue(isinstance(v, DirectView))
54 self.assertTrue(isinstance(v, DirectView))
55 self.assertEqual(v.targets, targets[1::3])
55 self.assertEqual(v.targets, targets[1::3])
56 v = self.client[:-3]
56 v = self.client[:-3]
57 self.assertTrue(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, targets[:-3])
58 self.assertEqual(v.targets, targets[:-3])
59 v = self.client[-1]
59 v = self.client[-1]
60 self.assertTrue(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[-1])
61 self.assertEqual(v.targets, targets[-1])
62 self.assertRaises(TypeError, lambda : self.client[None])
62 self.assertRaises(TypeError, lambda : self.client[None])
63
63
64 def test_lbview_targets(self):
64 def test_lbview_targets(self):
65 """test load_balanced_view targets"""
65 """test load_balanced_view targets"""
66 v = self.client.load_balanced_view()
66 v = self.client.load_balanced_view()
67 self.assertEqual(v.targets, None)
67 self.assertEqual(v.targets, None)
68 v = self.client.load_balanced_view(-1)
68 v = self.client.load_balanced_view(-1)
69 self.assertEqual(v.targets, [self.client.ids[-1]])
69 self.assertEqual(v.targets, [self.client.ids[-1]])
70 v = self.client.load_balanced_view('all')
70 v = self.client.load_balanced_view('all')
71 self.assertEqual(v.targets, None)
71 self.assertEqual(v.targets, None)
72
72
73 def test_dview_targets(self):
73 def test_dview_targets(self):
74 """test direct_view targets"""
74 """test direct_view targets"""
75 v = self.client.direct_view()
75 v = self.client.direct_view()
76 self.assertEqual(v.targets, 'all')
76 self.assertEqual(v.targets, 'all')
77 v = self.client.direct_view('all')
77 v = self.client.direct_view('all')
78 self.assertEqual(v.targets, 'all')
78 self.assertEqual(v.targets, 'all')
79 v = self.client.direct_view(-1)
79 v = self.client.direct_view(-1)
80 self.assertEqual(v.targets, self.client.ids[-1])
80 self.assertEqual(v.targets, self.client.ids[-1])
81
81
82 def test_lazy_all_targets(self):
82 def test_lazy_all_targets(self):
83 """test lazy evaluation of rc.direct_view('all')"""
83 """test lazy evaluation of rc.direct_view('all')"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86
86
87 def double(x):
87 def double(x):
88 return x*2
88 return x*2
89 seq = list(range(100))
89 seq = list(range(100))
90 ref = [ double(x) for x in seq ]
90 ref = [ double(x) for x in seq ]
91
91
92 # add some engines, which should be used
92 # add some engines, which should be used
93 self.add_engines(1)
93 self.add_engines(1)
94 n1 = len(self.client.ids)
94 n1 = len(self.client.ids)
95
95
96 # simple apply
96 # simple apply
97 r = v.apply_sync(lambda : 1)
97 r = v.apply_sync(lambda : 1)
98 self.assertEqual(r, [1] * n1)
98 self.assertEqual(r, [1] * n1)
99
99
100 # map goes through remotefunction
100 # map goes through remotefunction
101 r = v.map_sync(double, seq)
101 r = v.map_sync(double, seq)
102 self.assertEqual(r, ref)
102 self.assertEqual(r, ref)
103
103
104 # add a couple more engines, and try again
104 # add a couple more engines, and try again
105 self.add_engines(2)
105 self.add_engines(2)
106 n2 = len(self.client.ids)
106 n2 = len(self.client.ids)
107 self.assertNotEqual(n2, n1)
107 self.assertNotEqual(n2, n1)
108
108
109 # apply
109 # apply
110 r = v.apply_sync(lambda : 1)
110 r = v.apply_sync(lambda : 1)
111 self.assertEqual(r, [1] * n2)
111 self.assertEqual(r, [1] * n2)
112
112
113 # map
113 # map
114 r = v.map_sync(double, seq)
114 r = v.map_sync(double, seq)
115 self.assertEqual(r, ref)
115 self.assertEqual(r, ref)
116
116
117 def test_targets(self):
117 def test_targets(self):
118 """test various valid targets arguments"""
118 """test various valid targets arguments"""
119 build = self.client._build_targets
119 build = self.client._build_targets
120 ids = self.client.ids
120 ids = self.client.ids
121 idents,targets = build(None)
121 idents,targets = build(None)
122 self.assertEqual(ids, targets)
122 self.assertEqual(ids, targets)
123
123
124 def test_clear(self):
124 def test_clear(self):
125 """test clear behavior"""
125 """test clear behavior"""
126 self.minimum_engines(2)
126 self.minimum_engines(2)
127 v = self.client[:]
127 v = self.client[:]
128 v.block=True
128 v.block=True
129 v.push(dict(a=5))
129 v.push(dict(a=5))
130 v.pull('a')
130 v.pull('a')
131 id0 = self.client.ids[-1]
131 id0 = self.client.ids[-1]
132 self.client.clear(targets=id0, block=True)
132 self.client.clear(targets=id0, block=True)
133 a = self.client[:-1].get('a')
133 a = self.client[:-1].get('a')
134 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
134 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
135 self.client.clear(block=True)
135 self.client.clear(block=True)
136 for i in self.client.ids:
136 for i in self.client.ids:
137 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
137 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
138
138
139 def test_get_result(self):
139 def test_get_result(self):
140 """test getting results from the Hub."""
140 """test getting results from the Hub."""
141 c = clientmod.Client(profile='iptest')
141 c = clientmod.Client(profile='iptest')
142 t = c.ids[-1]
142 t = c.ids[-1]
143 ar = c[t].apply_async(wait, 1)
143 ar = c[t].apply_async(wait, 1)
144 # give the monitor time to notice the message
144 # give the monitor time to notice the message
145 time.sleep(.25)
145 time.sleep(.25)
146 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
146 ahr = self.client.get_result(ar.msg_ids[0], owner=False)
147 self.assertIsInstance(ahr, AsyncHubResult)
147 self.assertIsInstance(ahr, AsyncHubResult)
148 self.assertEqual(ahr.get(), ar.get())
148 self.assertEqual(ahr.get(), ar.get())
149 ar2 = self.client.get_result(ar.msg_ids[0])
149 ar2 = self.client.get_result(ar.msg_ids[0])
150 self.assertNotIsInstance(ar2, AsyncHubResult)
150 self.assertNotIsInstance(ar2, AsyncHubResult)
151 self.assertEqual(ahr.get(), ar2.get())
151 self.assertEqual(ahr.get(), ar2.get())
152 c.close()
152 c.close()
153
153
def test_get_execute_result(self):
    """Test getting execute results from the Hub.

    Same scenario as fetching an apply result, but the task is submitted
    with ``execute`` and the ``execute_result`` attributes are compared.
    """
    c = clientmod.Client(profile='iptest')
    # Close the extra client even if an assertion fails, so that a failing
    # test does not leave its connections open for the rest of the run.
    try:
        t = c.ids[-1]
        ar = c[t].execute("import time; time.sleep(1)", silent=False)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = self.client.get_result(ar.msg_ids[0], owner=False)
        self.assertIsInstance(ahr, AsyncHubResult)
        self.assertEqual(ahr.get().execute_result, ar.get().execute_result)
        ar2 = self.client.get_result(ar.msg_ids[0])
        self.assertNotIsInstance(ar2, AsyncHubResult)
        self.assertEqual(ahr.get(), ar2.get())
    finally:
        c.close()
173
173
def test_ids_list(self):
    """``client.ids`` must equal the internal id list without being that list."""
    snapshot = self.client.ids
    self.assertEqual(snapshot, self.client._ids)
    self.assertFalse(snapshot is self.client._ids)
    # editing the snapshot must not touch the client's own list
    last_id = snapshot[-1]
    snapshot.remove(last_id)
    self.assertNotEqual(snapshot, self.client._ids)
181
181
def test_queue_status(self):
    """Check the structure of ``queue_status`` for one engine and for all.

    A single-target status must be a dict with exactly the keys
    'completed', 'queue' and 'tasks'.  The full status must be a dict keyed
    by the engine ids plus 'unassigned', where every per-engine entry has
    that same structure.
    """
    expected_keys = ['completed', 'queue', 'tasks']

    id0 = self.client.ids[0]
    qs = self.client.queue_status(targets=id0)
    self.assertIsInstance(qs, dict)
    self.assertEqual(sorted(qs.keys()), expected_keys)

    allqs = self.client.queue_status()
    self.assertIsInstance(allqs, dict)
    # everything except 'unassigned' should be keyed by an engine id
    intkeys = list(allqs.keys())
    intkeys.remove('unassigned')
    print("intkeys", intkeys)
    ids = self.client.ids
    print("client.ids", ids)
    self.assertEqual(sorted(intkeys), sorted(ids))

    # drop the aggregate entry so that only per-engine entries remain
    del allqs['unassigned']
    for engine_status in allqs.values():
        self.assertIsInstance(engine_status, dict)
        self.assertEqual(sorted(engine_status.keys()), expected_keys)
202
202
def test_shutdown(self):
    """Shut down one engine and check that the client forgets about it."""
    id0 = self.client.ids[0]
    self.client.shutdown(id0, block=True)
    # Poll for the engine to disappear, but give up after ~10s (100ms poll):
    # an engine that never unregisters should fail this test rather than
    # hang the whole test run.
    for _ in range(100):
        if id0 not in self.client.ids:
            break
        time.sleep(0.1)
        self.client.spin()

    self.assertNotIn(id0, self.client.ids)
    self.assertRaises(IndexError, lambda : self.client[id0])
212
212
213 def test_result_status(self):
213 def test_result_status(self):
214 pass
214 pass
215 # to be written
215 # to be written
216
216
def test_db_query_dt(self):
    """Test db queries that filter on the 'submitted' timestamp."""
    hist = self.client.hub_history()
    # use the submission time of the middle record as the pivot
    pivot_id = hist[len(hist)//2]
    tic = self.client.db_query({'msg_id' : pivot_id})[0]['submitted']
    earlier = self.client.db_query({'submitted' : {'$lt' : tic}})
    later = self.client.db_query({'submitted' : {'$gte' : tic}})
    # together the two queries must account for the whole history
    self.assertEqual(len(earlier)+len(later),len(hist))
    for rec in earlier:
        self.assertTrue(rec['submitted'] < tic)
    for rec in later:
        self.assertTrue(rec['submitted'] >= tic)
    for rec in self.client.db_query({'submitted' : tic}):
        self.assertTrue(rec['submitted'] == tic)
232
232
def test_db_query_keys(self):
    """Test requesting only a subset of the record keys."""
    # msg_id is expected in every record on top of the requested keys
    expected = set(['msg_id', 'submitted', 'completed'])
    records = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
    for record in records:
        self.assertEqual(set(record.keys()), expected)
238
238
def test_db_query_default_keys(self):
    """A db_query without explicit keys must not return the buffers."""
    for record in self.client.db_query({'msg_id': {'$ne' : ''}}):
        present = set(record.keys())
        has_buffers = 'buffers' in present
        has_result_buffers = 'result_buffers' in present
        self.assertFalse(has_buffers, "'buffers' should not be in: %s" % present)
        self.assertFalse(has_result_buffers, "'result_buffers' should not be in: %s" % present)
246
246
def test_db_query_msg_id(self):
    """msg_id must be in every record, whichever keys were requested."""
    key_sets = (
        ['submitted', 'completed'],
        ['submitted'],
        ['msg_id'],
    )
    for requested in key_sets:
        records = self.client.db_query({'msg_id': {'$ne' : ''}},keys=requested)
        for record in records:
            self.assertTrue('msg_id' in record.keys())
258
258
def test_db_query_get_result(self):
    """pop in db_query shouldn't pop from result itself"""
    self.client[:].apply_sync(lambda : 1)
    # The query is issued for its effect only; the returned records are not
    # inspected here.
    self.client.db_query({'msg_id': {'$ne' : ''}})
    rc2 = clientmod.Client(profile='iptest')
    # Close the second client even if an assertion fails, so that a failing
    # test does not leave its connections open for the rest of the run.
    try:
        # If this bug is not fixed, this call will hang:
        ar = rc2.get_result(self.client.history[-1])
        ar.wait(2)
        self.assertTrue(ar.ready())
        ar.get()
    finally:
        rc2.close()
270
270
def test_db_query_in(self):
    """Test db queries using the '$in' and '$nin' operators."""
    hist = self.client.hub_history()
    even, odd = hist[::2], hist[1::2]
    # both queries are phrased against the even half of the history
    for operator, expected in (('$in', even), ('$nin', odd)):
        recs = self.client.db_query({ 'msg_id' : {operator : even}})
        found = set(r['msg_id'] for r in recs)
        self.assertEqual(set(expected), found)
282
282
def test_hub_history(self):
    """hub_history must be ordered by submission time and end with the newest task."""
    hist = self.client.hub_history()
    recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
    by_id = dict((rec['msg_id'], rec) for rec in recs)

    # submission times may never go backwards along the history
    previous = datetime(1984,1,1)
    for msg_id in hist:
        submitted = by_id[msg_id]['submitted']
        self.assertTrue(submitted >= previous)
        previous = submitted

    # a freshly completed task must show up as the last history entry
    ar = self.client[-1].apply_async(lambda : 1)
    ar.get()
    time.sleep(0.25)
    newest = self.client.hub_history()[-1:]
    self.assertEqual(newest,ar.msg_ids)
300
300
def _wait_for_idle(self):
    """Wait until the client, the Hub history and the queue status all report idle.

    Fails the calling test if the Hub has still not seen every request of
    this client, or still reports queued or unassigned tasks, once the
    polling below has run out.
    """
    rc = self.client

    def still_busy(status):
        # anything unassigned, or any engine with tasks/queue entries left
        if status['unassigned']:
            return True
        return any(
            status[eid]['tasks'] + status[eid]['queue']
            for eid in status if eid != 'unassigned'
        )

    # step 0. wait for local results
    # this should be sufficient 99% of the time.
    rc.wait(timeout=5)

    # step 1. wait for all requests to be noticed
    # timeout 5s, polling every 100ms
    msg_ids = set(rc.history)
    hub_hist = rc.hub_history()
    polls = 0
    while polls < 50 and msg_ids.difference(hub_hist):
        time.sleep(0.1)
        hub_hist = rc.hub_history()
        polls += 1

    unseen = msg_ids.difference(hub_hist)
    self.assertEqual(len(unseen), 0)

    # step 2. wait for all requests to be done
    # timeout 5s, polling every 100ms
    qs = rc.queue_status()
    polls = 0
    while polls < 50 and still_busy(qs):
        time.sleep(0.1)
        qs = rc.queue_status()
        polls += 1

    # ensure Hub up to date:
    self.assertEqual(qs['unassigned'], 0)
    for eid in qs:
        if eid == 'unassigned':
            continue
        self.assertEqual(qs[eid]['tasks'], 0)
        self.assertEqual(qs[eid]['queue'], 0)
337
337
338
338
def test_resubmit(self):
    """A resubmitted task must give a different random value than the original run."""
    def f():
        import random
        return random.random()
    view = self.client.load_balanced_view()
    original = view.apply_async(f)
    r1 = original.get(1)
    # give the Hub a chance to notice:
    self._wait_for_idle()
    resubmitted = self.client.resubmit(original.msg_ids)
    r2 = resubmitted.get(1)
    self.assertFalse(r1 == r2)
351
351
def test_resubmit_chain(self):
    """resubmit resubmitted tasks"""
    v = self.client.load_balanced_view()
    ar = v.apply_async(lambda x: x, 'x'*1024)
    ar.get()
    self._wait_for_idle()
    ars = [ar]

    for i in range(10):
        # Resubmit the newest link and record the result, so that each
        # round resubmits the previous resubmission instead of the
        # original task over and over.
        ars.append(self.client.resubmit(ars[-1].msg_ids))

    # wait for every link of the chain, resubmissions included
    for ar in ars:
        ar.get()
365
365
def test_resubmit_header(self):
    """resubmit shouldn't clobber the whole header"""
    def f():
        import random
        return random.random()
    view = self.client.load_balanced_view()
    view.retries = 1
    ar = view.apply_async(f)
    ar.get(1)
    # give the Hub a chance to notice:
    self._wait_for_idle()
    ahr = self.client.resubmit(ar.msg_ids)
    ahr.get(1)
    time.sleep(0.5)
    both_ids = ar.msg_ids + ahr.msg_ids
    records = self.client.db_query({'msg_id': {'$in': both_ids}}, keys='header')
    h1, h2 = [ rec['header'] for rec in records ]
    # only these two fields are expected to differ between the two headers
    fresh_fields = ('msg_id', 'date')
    for key in set(h1.keys()) | set(h2.keys()):
        if key not in fresh_fields:
            self.assertEqual(h1[key], h2[key])
        else:
            self.assertNotEqual(h1[key], h2[key])
387
387
def test_resubmit_aborted(self):
    """An aborted task can be resubmitted and its new result fetched."""
    def f():
        import random
        return random.random()
    view = self.client.load_balanced_view()
    # pin the view to one engine: a sleep queued ahead of the real
    # task holds it back, so it will get aborted
    last_engine = self.client.ids[-1]
    view.targets = [last_engine]
    blocker = view.apply_async(time.sleep, 0.5)
    ar = view.apply_async(f)
    ar.abort()
    self.assertRaises(error.TaskAborted, ar.get)
    # Give the Hub a chance to get up to date:
    self._wait_for_idle()
    ahr = self.client.resubmit(ar.msg_ids)
    ahr.get(1)
405
405
def test_resubmit_inflight(self):
    """Resubmit a task shortly after submitting it, before waiting on its result."""
    view = self.client.load_balanced_view()
    original = view.apply_async(time.sleep,1)
    # give the message a chance to arrive
    time.sleep(0.2)
    resubmitted = self.client.resubmit(original.msg_ids)
    # both the original and the resubmission must finish within 2s each
    original.get(2)
    resubmitted.get(2)
415
415
def test_resubmit_badkey(self):
    """Resubmitting an unknown msg_id must raise a remote KeyError."""
    unknown_ids = ['invalid']
    self.assertRaisesRemote(KeyError, self.client.resubmit, unknown_ids)
419
419
def test_purge_hub_results(self):
    """Purging one msg_id from the Hub must shorten the Hub history by one."""
    # ensure there are some tasks
    for i in range(5):
        self.client[:].apply_sync(lambda : 1)
    # Wait for the Hub to realise the result is done:
    # This prevents a race condition, where we
    # might purge a result the Hub still thinks is pending.
    self._wait_for_idle()
    rc2 = clientmod.Client(profile='iptest')
    # Close the second client even if an assertion fails, so that a failing
    # test does not leave its connections open for the rest of the run.
    try:
        hist = self.client.hub_history()
        ahr = rc2.get_result([hist[-1]])
        ahr.wait(10)
        self.client.purge_hub_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEqual(len(newhist)+1,len(hist))
        rc2.spin()
    finally:
        rc2.close()
437
437
def test_purge_local_results(self):
    """Purging one AsyncResult locally removes its entries from results and metadata."""
    # ensure there are some tasks
    submitted = []
    for _ in range(5):
        submitted.append(self.client[:].apply_async(lambda : 1))
    self._wait_for_idle()
    self.client.wait(10) # wait for the results to come back
    before = len(self.client.results)
    self.assertEqual(len(self.client.metadata),before)
    victim = submitted[-1]
    self.client.purge_local_results(victim)
    expected = before-len(victim)
    self.assertEqual(len(self.client.results),expected, msg="Not removed from results")
    self.assertEqual(len(self.client.metadata),expected, msg="Not removed from metadata")
450
450
def test_purge_local_results_outstanding(self):
    """A finished result can be purged locally; an outstanding one raises RuntimeError."""
    view = self.client[-1]
    finished = view.apply_async(lambda : 1)
    finished_id = finished.msg_ids[0]
    finished.owner = False
    finished.get()
    self._wait_for_idle()
    outstanding = view.apply_async(time.sleep, 1)

    # the finished task is cached locally until it is purged
    self.assertIn(finished_id, self.client.results)
    self.assertIn(finished_id, self.client.metadata)
    self.client.purge_local_results(finished)
    self.assertNotIn(finished_id, self.client.results)
    self.assertNotIn(finished_id, self.client.metadata)

    # purging is refused until the second task has been waited for
    self.assertRaises(RuntimeError, self.client.purge_local_results, outstanding)
    outstanding.get()
    self.client.purge_local_results(outstanding)
468
468
def test_purge_all_local_results_outstanding(self):
    """Purging 'all' local results raises RuntimeError while a task is outstanding."""
    view = self.client[-1]
    outstanding = view.apply_async(time.sleep, 1)
    self.assertRaises(RuntimeError, self.client.purge_local_results, 'all')
    # once the task has been waited for, the purge must go through
    outstanding.get()
    self.client.purge_local_results('all')
476
476
def test_purge_all_hub_results(self):
    """Purging 'all' Hub results must leave an empty Hub history."""
    self.client.purge_hub_results('all')
    remaining = self.client.hub_history()
    self.assertEqual(len(remaining), 0)
481
481
def test_purge_all_local_results(self):
    """Purging 'all' local results must empty both results and metadata."""
    rc = self.client
    rc.purge_local_results('all')
    self.assertEqual(len(rc.results), 0, msg="Results not empty")
    self.assertEqual(len(rc.metadata), 0, msg="metadata not empty")
486
486
def test_purge_all_results(self):
    """purge_results('all') must clear the local caches and the Hub history."""
    rc = self.client
    # ensure there are some tasks
    for _ in range(5):
        rc[:].apply_sync(lambda : 1)
    rc.wait(10)
    self._wait_for_idle()
    rc.purge_results('all')
    # local side
    self.assertEqual(len(rc.results), 0, msg="Results not empty")
    self.assertEqual(len(rc.metadata), 0, msg="metadata not empty")
    # Hub side
    remaining = rc.hub_history()
    self.assertEqual(len(remaining), 0, msg="hub history not empty")
498
498
def test_purge_everything(self):
    """purge_everything must clear results, metadata, bookkeeping and Hub history."""
    rc = self.client
    # ensure there are some tasks
    for _ in range(5):
        rc[:].apply_sync(lambda : 1)
    rc.wait(10)
    self._wait_for_idle()
    rc.purge_everything()
    # The client results
    self.assertEqual(len(rc.results), 0, msg="Results not empty")
    self.assertEqual(len(rc.metadata), 0, msg="metadata not empty")
    # The client "bookkeeping"
    self.assertEqual(len(rc.session.digest_history), 0, msg="session digest not empty")
    self.assertEqual(len(rc.history), 0, msg="client history not empty")
    # the hub results
    remaining = rc.hub_history()
    self.assertEqual(len(remaining), 0, msg="hub history not empty")
515
515
516
516
def test_spin_thread(self):
    """With the spin thread running, 'received' gets set without an explicit wait."""
    self.client.spin_thread(0.01)
    ar = self.client[-1].apply_async(lambda : 1)
    md = self.client.metadata[ar.msg_ids[0]]
    # 3s timeout, 100ms poll
    polls = 0
    while polls < 30:
        time.sleep(0.1)
        if md['received'] is not None:
            break
        polls += 1
    self.assertIsInstance(md['received'], datetime)
527
527
def test_stop_spin_thread(self):
    """After stop_spin_thread, 'received' must stay unset for a submitted task."""
    self.client.spin_thread(0.01)
    self.client.stop_spin_thread()
    ar = self.client[-1].apply_async(lambda : 1)
    md = self.client.metadata[ar.msg_ids[0]]
    # 500ms timeout, 100ms poll
    for i in range(5):
        time.sleep(0.1)
        # the second positional argument of assertIsNone is the failure
        # message, so only the value under test is passed
        self.assertIsNone(md['received'])
537
537
def test_activate(self):
    """``activate`` must register px magics under the given suffix and return the view."""
    magics = get_ipython().magics_manager.magics

    def check_registered(name):
        # the magic has to exist both as a line magic and as a cell magic
        self.assertTrue(name in magics['line'])
        self.assertTrue(name in magics['cell'])

    check_registered('px')
    view = self.client.activate(-1, '0')
    check_registered('px0')
    self.assertEqual(view.targets, self.client.ids[-1])
    view = self.client.activate('all', 'all')
    check_registered('pxall')
    self.assertEqual(view.targets, 'all')
@@ -1,314 +1,314 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import logging
22 import os
22 import os
23 import tempfile
23 import tempfile
24 import time
24 import time
25
25
26 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
27 from unittest import TestCase
27 from unittest import TestCase
28
28
29 from IPython.parallel import error
29 from ipython_parallel import error
30 from IPython.parallel.controller.dictdb import DictDB
30 from ipython_parallel.controller.dictdb import DictDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from ipython_parallel.controller.sqlitedb import SQLiteDB
32 from IPython.parallel.controller.hub import init_record, empty_record
32 from ipython_parallel.controller.hub import init_record, empty_record
33
33
34 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
35 from IPython.kernel.zmq.session import Session
35 from IPython.kernel.zmq.session import Session
36
36
37
37
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39 # TestCases
39 # TestCases
40 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
41
41
42
42
43 def setup():
43 def setup():
44 global temp_db
44 global temp_db
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46
46
47
47
48 class TaskDBTest:
48 class TaskDBTest:
49 def setUp(self):
49 def setUp(self):
50 self.session = Session()
50 self.session = Session()
51 self.db = self.create_db()
51 self.db = self.create_db()
52 self.load_records(16)
52 self.load_records(16)
53
53
54 def create_db(self):
54 def create_db(self):
55 raise NotImplementedError
55 raise NotImplementedError
56
56
57 def load_records(self, n=1, buffer_size=100):
57 def load_records(self, n=1, buffer_size=100):
58 """load n records for testing"""
58 """load n records for testing"""
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 time.sleep(0.1)
60 time.sleep(0.1)
61 msg_ids = []
61 msg_ids = []
62 for i in range(n):
62 for i in range(n):
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
64 msg['buffers'] = [os.urandom(buffer_size)]
64 msg['buffers'] = [os.urandom(buffer_size)]
65 rec = init_record(msg)
65 rec = init_record(msg)
66 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
67 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
68 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
69 return msg_ids
69 return msg_ids
70
70
71 def test_add_record(self):
71 def test_add_record(self):
72 before = self.db.get_history()
72 before = self.db.get_history()
73 self.load_records(5)
73 self.load_records(5)
74 after = self.db.get_history()
74 after = self.db.get_history()
75 self.assertEqual(len(after), len(before)+5)
75 self.assertEqual(len(after), len(before)+5)
76 self.assertEqual(after[:-5],before)
76 self.assertEqual(after[:-5],before)
77
77
78 def test_drop_record(self):
78 def test_drop_record(self):
79 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
80 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83
83
84 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
85 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
86 micro = dt.microsecond
86 micro = dt.microsecond
87 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
88 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
89
89
90 def test_update_record(self):
90 def test_update_record(self):
91 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
92 #
92 #
93 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
94 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
95 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
96 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
97 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
98 self.assertEqual(rec2['stdout'], 'hello there')
98 self.assertEqual(rec2['stdout'], 'hello there')
99 self.assertEqual(rec2['completed'], now)
99 self.assertEqual(rec2['completed'], now)
100 rec1.update(data)
100 rec1.update(data)
101 self.assertEqual(rec1, rec2)
101 self.assertEqual(rec1, rec2)
102
102
103 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
104 # """test updating nonexistent records"""
104 # """test updating nonexistent records"""
105 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
106 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108
108
109 def test_find_records_dt(self):
109 def test_find_records_dt(self):
110 """test finding records by date"""
110 """test finding records by date"""
111 hist = self.db.get_history()
111 hist = self.db.get_history()
112 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
113 tic = middle['submitted']
113 tic = middle['submitted']
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 self.assertEqual(len(before)+len(after),len(hist))
116 self.assertEqual(len(before)+len(after),len(hist))
117 for b in before:
117 for b in before:
118 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
119 for a in after:
119 for a in after:
120 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
121 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
122 for s in same:
122 for s in same:
123 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
124
124
125 def test_find_records_keys(self):
125 def test_find_records_keys(self):
126 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 for rec in found:
128 for rec in found:
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130
130
131 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
132 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 for rec in found:
134 for rec in found:
135 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 for rec in found:
137 for rec in found:
138 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 for rec in found:
140 for rec in found:
141 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
142
142
143 def test_find_records_in(self):
143 def test_find_records_in(self):
144 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
145 hist = self.db.get_history()
145 hist = self.db.get_history()
146 even = hist[::2]
146 even = hist[::2]
147 odd = hist[1::2]
147 odd = hist[1::2]
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
150 self.assertEqual(set(even), set(found))
150 self.assertEqual(set(even), set(found))
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
153 self.assertEqual(set(odd), set(found))
153 self.assertEqual(set(odd), set(found))
154
154
155 def test_get_history(self):
155 def test_get_history(self):
156 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
157 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
158 for msg_id in msg_ids:
158 for msg_id in msg_ids:
159 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
160 newt = rec['submitted']
160 newt = rec['submitted']
161 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
162 latest = newt
162 latest = newt
163 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
164 self.assertEqual(self.db.get_history()[-1],msg_id)
164 self.assertEqual(self.db.get_history()[-1],msg_id)
165
165
166 def test_datetime(self):
166 def test_datetime(self):
167 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
168 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
169 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
173 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
174
174
175 def test_drop_matching(self):
175 def test_drop_matching(self):
176 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
177 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
178 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
179 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
180 self.assertEqual(len(recs), 0)
180 self.assertEqual(len(recs), 0)
181
181
182 def test_null(self):
182 def test_null(self):
183 """test None comparison queries"""
183 """test None comparison queries"""
184 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
185
185
186 query = {'msg_id' : None}
186 query = {'msg_id' : None}
187 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
188 self.assertEqual(len(recs), 0)
188 self.assertEqual(len(recs), 0)
189
189
190 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
191 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
192 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
193
193
194 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
195 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
196 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
197 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
198 rec.pop('buffers')
198 rec.pop('buffers')
199 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
200 rec['header']['msg_id'] = 'fubar'
200 rec['header']['msg_id'] = 'fubar'
201 rec2 = self.db.get_record(msg_id)
201 rec2 = self.db.get_record(msg_id)
202 self.assertTrue('buffers' in rec2)
202 self.assertTrue('buffers' in rec2)
203 self.assertFalse('garbage' in rec2)
203 self.assertFalse('garbage' in rec2)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205
205
206 def test_pop_safe_find(self):
206 def test_pop_safe_find(self):
207 """editing query results shouldn't affect record [find]"""
207 """editing query results shouldn't affect record [find]"""
208 msg_id = self.db.get_history()[-1]
208 msg_id = self.db.get_history()[-1]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 rec.pop('buffers')
210 rec.pop('buffers')
211 rec['garbage'] = 'hello'
211 rec['garbage'] = 'hello'
212 rec['header']['msg_id'] = 'fubar'
212 rec['header']['msg_id'] = 'fubar'
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 self.assertTrue('buffers' in rec2)
214 self.assertTrue('buffers' in rec2)
215 self.assertFalse('garbage' in rec2)
215 self.assertFalse('garbage' in rec2)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217
217
218 def test_pop_safe_find_keys(self):
218 def test_pop_safe_find_keys(self):
219 """editing query results shouldn't affect record [find+keys]"""
219 """editing query results shouldn't affect record [find+keys]"""
220 msg_id = self.db.get_history()[-1]
220 msg_id = self.db.get_history()[-1]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 rec.pop('buffers')
222 rec.pop('buffers')
223 rec['garbage'] = 'hello'
223 rec['garbage'] = 'hello'
224 rec['header']['msg_id'] = 'fubar'
224 rec['header']['msg_id'] = 'fubar'
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 self.assertTrue('buffers' in rec2)
226 self.assertTrue('buffers' in rec2)
227 self.assertFalse('garbage' in rec2)
227 self.assertFalse('garbage' in rec2)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229
229
230
230
231 class TestDictBackend(TaskDBTest, TestCase):
231 class TestDictBackend(TaskDBTest, TestCase):
232
232
233 def create_db(self):
233 def create_db(self):
234 return DictDB()
234 return DictDB()
235
235
236 def test_cull_count(self):
236 def test_cull_count(self):
237 self.db = self.create_db() # skip the load-records init from setUp
237 self.db = self.create_db() # skip the load-records init from setUp
238 self.db.record_limit = 20
238 self.db.record_limit = 20
239 self.db.cull_fraction = 0.2
239 self.db.cull_fraction = 0.2
240 self.load_records(20)
240 self.load_records(20)
241 self.assertEqual(len(self.db.get_history()), 20)
241 self.assertEqual(len(self.db.get_history()), 20)
242 self.load_records(1)
242 self.load_records(1)
243 # 0.2 * 20 = 4, 21 - 4 = 17
243 # 0.2 * 20 = 4, 21 - 4 = 17
244 self.assertEqual(len(self.db.get_history()), 17)
244 self.assertEqual(len(self.db.get_history()), 17)
245 self.load_records(3)
245 self.load_records(3)
246 self.assertEqual(len(self.db.get_history()), 20)
246 self.assertEqual(len(self.db.get_history()), 20)
247 self.load_records(1)
247 self.load_records(1)
248 self.assertEqual(len(self.db.get_history()), 17)
248 self.assertEqual(len(self.db.get_history()), 17)
249
249
250 for i in range(25):
250 for i in range(25):
251 self.load_records(1)
251 self.load_records(1)
252 self.assertTrue(len(self.db.get_history()) >= 17)
252 self.assertTrue(len(self.db.get_history()) >= 17)
253 self.assertTrue(len(self.db.get_history()) <= 20)
253 self.assertTrue(len(self.db.get_history()) <= 20)
254
254
255 def test_cull_size(self):
255 def test_cull_size(self):
256 self.db = self.create_db() # skip the load-records init from setUp
256 self.db = self.create_db() # skip the load-records init from setUp
257 self.db.size_limit = 1000
257 self.db.size_limit = 1000
258 self.db.cull_fraction = 0.2
258 self.db.cull_fraction = 0.2
259 self.load_records(100, buffer_size=10)
259 self.load_records(100, buffer_size=10)
260 self.assertEqual(len(self.db.get_history()), 100)
260 self.assertEqual(len(self.db.get_history()), 100)
261 self.load_records(1, buffer_size=0)
261 self.load_records(1, buffer_size=0)
262 self.assertEqual(len(self.db.get_history()), 101)
262 self.assertEqual(len(self.db.get_history()), 101)
263 self.load_records(1, buffer_size=1)
263 self.load_records(1, buffer_size=1)
264 # 0.2 * 100 = 20, 101 - 20 = 81
264 # 0.2 * 100 = 20, 101 - 20 = 81
265 self.assertEqual(len(self.db.get_history()), 81)
265 self.assertEqual(len(self.db.get_history()), 81)
266
266
267 def test_cull_size_drop(self):
267 def test_cull_size_drop(self):
268 """dropping records updates tracked buffer size"""
268 """dropping records updates tracked buffer size"""
269 self.db = self.create_db() # skip the load-records init from setUp
269 self.db = self.create_db() # skip the load-records init from setUp
270 self.db.size_limit = 1000
270 self.db.size_limit = 1000
271 self.db.cull_fraction = 0.2
271 self.db.cull_fraction = 0.2
272 self.load_records(100, buffer_size=10)
272 self.load_records(100, buffer_size=10)
273 self.assertEqual(len(self.db.get_history()), 100)
273 self.assertEqual(len(self.db.get_history()), 100)
274 self.db.drop_record(self.db.get_history()[-1])
274 self.db.drop_record(self.db.get_history()[-1])
275 self.assertEqual(len(self.db.get_history()), 99)
275 self.assertEqual(len(self.db.get_history()), 99)
276 self.load_records(1, buffer_size=5)
276 self.load_records(1, buffer_size=5)
277 self.assertEqual(len(self.db.get_history()), 100)
277 self.assertEqual(len(self.db.get_history()), 100)
278 self.load_records(1, buffer_size=5)
278 self.load_records(1, buffer_size=5)
279 self.assertEqual(len(self.db.get_history()), 101)
279 self.assertEqual(len(self.db.get_history()), 101)
280 self.load_records(1, buffer_size=1)
280 self.load_records(1, buffer_size=1)
281 self.assertEqual(len(self.db.get_history()), 81)
281 self.assertEqual(len(self.db.get_history()), 81)
282
282
283 def test_cull_size_update(self):
283 def test_cull_size_update(self):
284 """updating records updates tracked buffer size"""
284 """updating records updates tracked buffer size"""
285 self.db = self.create_db() # skip the load-records init from setUp
285 self.db = self.create_db() # skip the load-records init from setUp
286 self.db.size_limit = 1000
286 self.db.size_limit = 1000
287 self.db.cull_fraction = 0.2
287 self.db.cull_fraction = 0.2
288 self.load_records(100, buffer_size=10)
288 self.load_records(100, buffer_size=10)
289 self.assertEqual(len(self.db.get_history()), 100)
289 self.assertEqual(len(self.db.get_history()), 100)
290 msg_id = self.db.get_history()[-1]
290 msg_id = self.db.get_history()[-1]
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
292 self.assertEqual(len(self.db.get_history()), 100)
292 self.assertEqual(len(self.db.get_history()), 100)
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
294 self.assertEqual(len(self.db.get_history()), 79)
294 self.assertEqual(len(self.db.get_history()), 79)
295
295
296 class TestSQLiteBackend(TaskDBTest, TestCase):
296 class TestSQLiteBackend(TaskDBTest, TestCase):
297
297
298 @dec.skip_without('sqlite3')
298 @dec.skip_without('sqlite3')
299 def create_db(self):
299 def create_db(self):
300 location, fname = os.path.split(temp_db)
300 location, fname = os.path.split(temp_db)
301 log = logging.getLogger('test')
301 log = logging.getLogger('test')
302 log.setLevel(logging.CRITICAL)
302 log.setLevel(logging.CRITICAL)
303 return SQLiteDB(location=location, fname=fname, log=log)
303 return SQLiteDB(location=location, fname=fname, log=log)
304
304
305 def tearDown(self):
305 def tearDown(self):
306 self.db._db.close()
306 self.db._db.close()
307
307
308
308
309 def teardown():
309 def teardown():
310 """cleanup task db file after all tests have run"""
310 """cleanup task db file after all tests have run"""
311 try:
311 try:
312 os.remove(temp_db)
312 os.remove(temp_db)
313 except:
313 except:
314 pass
314 pass
@@ -1,136 +1,136 b''
1 """Tests for dependency.py
1 """Tests for dependency.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 __docformat__ = "restructuredtext en"
8 __docformat__ = "restructuredtext en"
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Copyright (C) 2011 The IPython Development Team
11 # Copyright (C) 2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16
16
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-------------------------------------------------------------------------------
19 #-------------------------------------------------------------------------------
20
20
21 # import
21 # import
22 import os
22 import os
23
23
24 from IPython.utils.pickleutil import can, uncan
24 from IPython.utils.pickleutil import can, uncan
25
25
26 import IPython.parallel as pmod
26 import ipython_parallel as pmod
27 from IPython.parallel.util import interactive
27 from ipython_parallel.util import interactive
28
28
29 from IPython.parallel.tests import add_engines
29 from ipython_parallel.tests import add_engines
30 from .clienttest import ClusterTestCase
30 from .clienttest import ClusterTestCase
31
31
32 def setup():
32 def setup():
33 add_engines(1, total=True)
33 add_engines(1, total=True)
34
34
35 @pmod.require('time')
35 @pmod.require('time')
36 def wait(n):
36 def wait(n):
37 time.sleep(n)
37 time.sleep(n)
38 return n
38 return n
39
39
40 @pmod.interactive
40 @pmod.interactive
41 def func(x):
41 def func(x):
42 return x*x
42 return x*x
43
43
44 mixed = list(map(str, range(10)))
44 mixed = list(map(str, range(10)))
45 completed = list(map(str, range(0,10,2)))
45 completed = list(map(str, range(0,10,2)))
46 failed = list(map(str, range(1,10,2)))
46 failed = list(map(str, range(1,10,2)))
47
47
48 class DependencyTest(ClusterTestCase):
48 class DependencyTest(ClusterTestCase):
49
49
50 def setUp(self):
50 def setUp(self):
51 ClusterTestCase.setUp(self)
51 ClusterTestCase.setUp(self)
52 self.user_ns = {'__builtins__' : __builtins__}
52 self.user_ns = {'__builtins__' : __builtins__}
53 self.view = self.client.load_balanced_view()
53 self.view = self.client.load_balanced_view()
54 self.dview = self.client[-1]
54 self.dview = self.client[-1]
55 self.succeeded = set(map(str, range(0,25,2)))
55 self.succeeded = set(map(str, range(0,25,2)))
56 self.failed = set(map(str, range(1,25,2)))
56 self.failed = set(map(str, range(1,25,2)))
57
57
58 def assertMet(self, dep):
58 def assertMet(self, dep):
59 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
59 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
60
60
61 def assertUnmet(self, dep):
61 def assertUnmet(self, dep):
62 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
62 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
63
63
64 def assertUnreachable(self, dep):
64 def assertUnreachable(self, dep):
65 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
65 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
66
66
67 def assertReachable(self, dep):
67 def assertReachable(self, dep):
68 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
68 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
69
69
70 def cancan(self, f):
70 def cancan(self, f):
71 """decorator to pass through canning into self.user_ns"""
71 """decorator to pass through canning into self.user_ns"""
72 return uncan(can(f), self.user_ns)
72 return uncan(can(f), self.user_ns)
73
73
74 def test_require_imports(self):
74 def test_require_imports(self):
75 """test that @require imports names"""
75 """test that @require imports names"""
76 @self.cancan
76 @self.cancan
77 @pmod.require('base64')
77 @pmod.require('base64')
78 @interactive
78 @interactive
79 def encode(arg):
79 def encode(arg):
80 return base64.b64encode(arg)
80 return base64.b64encode(arg)
81 # must pass through canning to properly connect namespaces
81 # must pass through canning to properly connect namespaces
82 self.assertEqual(encode(b'foo'), b'Zm9v')
82 self.assertEqual(encode(b'foo'), b'Zm9v')
83
83
84 def test_success_only(self):
84 def test_success_only(self):
85 dep = pmod.Dependency(mixed, success=True, failure=False)
85 dep = pmod.Dependency(mixed, success=True, failure=False)
86 self.assertUnmet(dep)
86 self.assertUnmet(dep)
87 self.assertUnreachable(dep)
87 self.assertUnreachable(dep)
88 dep.all=False
88 dep.all=False
89 self.assertMet(dep)
89 self.assertMet(dep)
90 self.assertReachable(dep)
90 self.assertReachable(dep)
91 dep = pmod.Dependency(completed, success=True, failure=False)
91 dep = pmod.Dependency(completed, success=True, failure=False)
92 self.assertMet(dep)
92 self.assertMet(dep)
93 self.assertReachable(dep)
93 self.assertReachable(dep)
94 dep.all=False
94 dep.all=False
95 self.assertMet(dep)
95 self.assertMet(dep)
96 self.assertReachable(dep)
96 self.assertReachable(dep)
97
97
98 def test_failure_only(self):
98 def test_failure_only(self):
99 dep = pmod.Dependency(mixed, success=False, failure=True)
99 dep = pmod.Dependency(mixed, success=False, failure=True)
100 self.assertUnmet(dep)
100 self.assertUnmet(dep)
101 self.assertUnreachable(dep)
101 self.assertUnreachable(dep)
102 dep.all=False
102 dep.all=False
103 self.assertMet(dep)
103 self.assertMet(dep)
104 self.assertReachable(dep)
104 self.assertReachable(dep)
105 dep = pmod.Dependency(completed, success=False, failure=True)
105 dep = pmod.Dependency(completed, success=False, failure=True)
106 self.assertUnmet(dep)
106 self.assertUnmet(dep)
107 self.assertUnreachable(dep)
107 self.assertUnreachable(dep)
108 dep.all=False
108 dep.all=False
109 self.assertUnmet(dep)
109 self.assertUnmet(dep)
110 self.assertUnreachable(dep)
110 self.assertUnreachable(dep)
111
111
112 def test_require_function(self):
112 def test_require_function(self):
113
113
114 @pmod.interactive
114 @pmod.interactive
115 def bar(a):
115 def bar(a):
116 return func(a)
116 return func(a)
117
117
118 @pmod.require(func)
118 @pmod.require(func)
119 @pmod.interactive
119 @pmod.interactive
120 def bar2(a):
120 def bar2(a):
121 return func(a)
121 return func(a)
122
122
123 self.client[:].clear()
123 self.client[:].clear()
124 self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
124 self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
125 ar = self.view.apply_async(bar2, 5)
125 ar = self.view.apply_async(bar2, 5)
126 self.assertEqual(ar.get(5), func(5))
126 self.assertEqual(ar.get(5), func(5))
127
127
128 def test_require_object(self):
128 def test_require_object(self):
129
129
130 @pmod.require(foo=func)
130 @pmod.require(foo=func)
131 @pmod.interactive
131 @pmod.interactive
132 def bar(a):
132 def bar(a):
133 return foo(a)
133 return foo(a)
134
134
135 ar = self.view.apply_async(bar, 5)
135 ar = self.view.apply_async(bar, 5)
136 self.assertEqual(ar.get(5), func(5))
136 self.assertEqual(ar.get(5), func(5))
@@ -1,194 +1,194 b''
1 """Tests for launchers
1 """Tests for launchers
2
2
3 Doesn't actually start any subprocesses, but goes through the motions of constructing
3 Doesn't actually start any subprocesses, but goes through the motions of constructing
4 objects, which should test basic config.
4 objects, which should test basic config.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2013 The IPython Development Team
12 # Copyright (C) 2013 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21
21
22 import logging
22 import logging
23 import os
23 import os
24 import shutil
24 import shutil
25 import sys
25 import sys
26 import tempfile
26 import tempfile
27
27
28 from unittest import TestCase
28 from unittest import TestCase
29
29
30 from nose import SkipTest
30 from nose import SkipTest
31
31
32 from IPython.config import Config
32 from IPython.config import Config
33
33
34 from IPython.parallel.apps import launcher
34 from ipython_parallel.apps import launcher
35
35
36 from IPython.testing import decorators as dec
36 from IPython.testing import decorators as dec
37 from IPython.utils.py3compat import string_types
37 from IPython.utils.py3compat import string_types
38
38
39
39
40 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
41 # TestCase Mixins
41 # TestCase Mixins
42 #-------------------------------------------------------------------------------
42 #-------------------------------------------------------------------------------
43
43
44 class LauncherTest:
44 class LauncherTest:
45 """Mixin for generic launcher tests"""
45 """Mixin for generic launcher tests"""
46 def setUp(self):
46 def setUp(self):
47 self.profile_dir = tempfile.mkdtemp(prefix="profile_")
47 self.profile_dir = tempfile.mkdtemp(prefix="profile_")
48
48
49 def tearDown(self):
49 def tearDown(self):
50 shutil.rmtree(self.profile_dir)
50 shutil.rmtree(self.profile_dir)
51
51
52 @property
52 @property
53 def config(self):
53 def config(self):
54 return Config()
54 return Config()
55
55
56 def build_launcher(self, **kwargs):
56 def build_launcher(self, **kwargs):
57 kw = dict(
57 kw = dict(
58 work_dir=self.profile_dir,
58 work_dir=self.profile_dir,
59 profile_dir=self.profile_dir,
59 profile_dir=self.profile_dir,
60 config=self.config,
60 config=self.config,
61 cluster_id='',
61 cluster_id='',
62 log=logging.getLogger(),
62 log=logging.getLogger(),
63 )
63 )
64 kw.update(kwargs)
64 kw.update(kwargs)
65 return self.launcher_class(**kw)
65 return self.launcher_class(**kw)
66
66
67 def test_profile_dir_arg(self):
67 def test_profile_dir_arg(self):
68 launcher = self.build_launcher()
68 launcher = self.build_launcher()
69 self.assertTrue("--profile-dir" in launcher.cluster_args)
69 self.assertTrue("--profile-dir" in launcher.cluster_args)
70 self.assertTrue(self.profile_dir in launcher.cluster_args)
70 self.assertTrue(self.profile_dir in launcher.cluster_args)
71
71
72 def test_cluster_id_arg(self):
72 def test_cluster_id_arg(self):
73 launcher = self.build_launcher()
73 launcher = self.build_launcher()
74 self.assertTrue("--cluster-id" in launcher.cluster_args)
74 self.assertTrue("--cluster-id" in launcher.cluster_args)
75 idx = launcher.cluster_args.index("--cluster-id")
75 idx = launcher.cluster_args.index("--cluster-id")
76 self.assertEqual(launcher.cluster_args[idx+1], '')
76 self.assertEqual(launcher.cluster_args[idx+1], '')
77 launcher.cluster_id = 'foo'
77 launcher.cluster_id = 'foo'
78 self.assertEqual(launcher.cluster_args[idx+1], 'foo')
78 self.assertEqual(launcher.cluster_args[idx+1], 'foo')
79
79
80 def test_args(self):
80 def test_args(self):
81 launcher = self.build_launcher()
81 launcher = self.build_launcher()
82 for arg in launcher.args:
82 for arg in launcher.args:
83 self.assertTrue(isinstance(arg, string_types), str(arg))
83 self.assertTrue(isinstance(arg, string_types), str(arg))
84
84
85 class BatchTest:
85 class BatchTest:
86 """Tests for batch-system launchers (LSF, SGE, PBS)"""
86 """Tests for batch-system launchers (LSF, SGE, PBS)"""
87 def test_batch_template(self):
87 def test_batch_template(self):
88 launcher = self.build_launcher()
88 launcher = self.build_launcher()
89 batch_file = os.path.join(self.profile_dir, launcher.batch_file_name)
89 batch_file = os.path.join(self.profile_dir, launcher.batch_file_name)
90 self.assertEqual(launcher.batch_file, batch_file)
90 self.assertEqual(launcher.batch_file, batch_file)
91 launcher.write_batch_script(1)
91 launcher.write_batch_script(1)
92 self.assertTrue(os.path.isfile(batch_file))
92 self.assertTrue(os.path.isfile(batch_file))
93
93
94 class SSHTest:
94 class SSHTest:
95 """Tests for SSH launchers"""
95 """Tests for SSH launchers"""
96 def test_cluster_id_arg(self):
96 def test_cluster_id_arg(self):
97 raise SkipTest("SSH Launchers don't support cluster ID")
97 raise SkipTest("SSH Launchers don't support cluster ID")
98
98
99 def test_remote_profile_dir(self):
99 def test_remote_profile_dir(self):
100 cfg = Config()
100 cfg = Config()
101 launcher_cfg = getattr(cfg, self.launcher_class.__name__)
101 launcher_cfg = getattr(cfg, self.launcher_class.__name__)
102 launcher_cfg.remote_profile_dir = "foo"
102 launcher_cfg.remote_profile_dir = "foo"
103 launcher = self.build_launcher(config=cfg)
103 launcher = self.build_launcher(config=cfg)
104 self.assertEqual(launcher.remote_profile_dir, "foo")
104 self.assertEqual(launcher.remote_profile_dir, "foo")
105
105
106 def test_remote_profile_dir_default(self):
106 def test_remote_profile_dir_default(self):
107 launcher = self.build_launcher()
107 launcher = self.build_launcher()
108 self.assertEqual(launcher.remote_profile_dir, self.profile_dir)
108 self.assertEqual(launcher.remote_profile_dir, self.profile_dir)
109
109
110 #-------------------------------------------------------------------------------
110 #-------------------------------------------------------------------------------
111 # Controller Launcher Tests
111 # Controller Launcher Tests
112 #-------------------------------------------------------------------------------
112 #-------------------------------------------------------------------------------
113
113
114 class ControllerLauncherTest(LauncherTest):
114 class ControllerLauncherTest(LauncherTest):
115 """Tests for Controller Launchers"""
115 """Tests for Controller Launchers"""
116 pass
116 pass
117
117
118 class TestLocalControllerLauncher(ControllerLauncherTest, TestCase):
118 class TestLocalControllerLauncher(ControllerLauncherTest, TestCase):
119 launcher_class = launcher.LocalControllerLauncher
119 launcher_class = launcher.LocalControllerLauncher
120
120
121 class TestMPIControllerLauncher(ControllerLauncherTest, TestCase):
121 class TestMPIControllerLauncher(ControllerLauncherTest, TestCase):
122 launcher_class = launcher.MPIControllerLauncher
122 launcher_class = launcher.MPIControllerLauncher
123
123
124 class TestPBSControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
124 class TestPBSControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
125 launcher_class = launcher.PBSControllerLauncher
125 launcher_class = launcher.PBSControllerLauncher
126
126
127 class TestSGEControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
127 class TestSGEControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
128 launcher_class = launcher.SGEControllerLauncher
128 launcher_class = launcher.SGEControllerLauncher
129
129
130 class TestLSFControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
130 class TestLSFControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
131 launcher_class = launcher.LSFControllerLauncher
131 launcher_class = launcher.LSFControllerLauncher
132
132
133 class TestHTCondorControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
133 class TestHTCondorControllerLauncher(BatchTest, ControllerLauncherTest, TestCase):
134 launcher_class = launcher.HTCondorControllerLauncher
134 launcher_class = launcher.HTCondorControllerLauncher
135
135
136 class TestSSHControllerLauncher(SSHTest, ControllerLauncherTest, TestCase):
136 class TestSSHControllerLauncher(SSHTest, ControllerLauncherTest, TestCase):
137 launcher_class = launcher.SSHControllerLauncher
137 launcher_class = launcher.SSHControllerLauncher
138
138
139 #-------------------------------------------------------------------------------
139 #-------------------------------------------------------------------------------
140 # Engine Set Launcher Tests
140 # Engine Set Launcher Tests
141 #-------------------------------------------------------------------------------
141 #-------------------------------------------------------------------------------
142
142
143 class EngineSetLauncherTest(LauncherTest):
143 class EngineSetLauncherTest(LauncherTest):
144 """Tests for EngineSet launchers"""
144 """Tests for EngineSet launchers"""
145 pass
145 pass
146
146
147 class TestLocalEngineSetLauncher(EngineSetLauncherTest, TestCase):
147 class TestLocalEngineSetLauncher(EngineSetLauncherTest, TestCase):
148 launcher_class = launcher.LocalEngineSetLauncher
148 launcher_class = launcher.LocalEngineSetLauncher
149
149
150 class TestMPIEngineSetLauncher(EngineSetLauncherTest, TestCase):
150 class TestMPIEngineSetLauncher(EngineSetLauncherTest, TestCase):
151 launcher_class = launcher.MPIEngineSetLauncher
151 launcher_class = launcher.MPIEngineSetLauncher
152
152
153 class TestPBSEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
153 class TestPBSEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
154 launcher_class = launcher.PBSEngineSetLauncher
154 launcher_class = launcher.PBSEngineSetLauncher
155
155
156 class TestSGEEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
156 class TestSGEEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
157 launcher_class = launcher.SGEEngineSetLauncher
157 launcher_class = launcher.SGEEngineSetLauncher
158
158
159 class TestLSFEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
159 class TestLSFEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
160 launcher_class = launcher.LSFEngineSetLauncher
160 launcher_class = launcher.LSFEngineSetLauncher
161
161
162 class TestHTCondorEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
162 class TestHTCondorEngineSetLauncher(BatchTest, EngineSetLauncherTest, TestCase):
163 launcher_class = launcher.HTCondorEngineSetLauncher
163 launcher_class = launcher.HTCondorEngineSetLauncher
164
164
165 class TestSSHEngineSetLauncher(EngineSetLauncherTest, TestCase):
165 class TestSSHEngineSetLauncher(EngineSetLauncherTest, TestCase):
166 launcher_class = launcher.SSHEngineSetLauncher
166 launcher_class = launcher.SSHEngineSetLauncher
167
167
168 def test_cluster_id_arg(self):
168 def test_cluster_id_arg(self):
169 raise SkipTest("SSH Launchers don't support cluster ID")
169 raise SkipTest("SSH Launchers don't support cluster ID")
170
170
171 class TestSSHProxyEngineSetLauncher(SSHTest, LauncherTest, TestCase):
171 class TestSSHProxyEngineSetLauncher(SSHTest, LauncherTest, TestCase):
172 launcher_class = launcher.SSHProxyEngineSetLauncher
172 launcher_class = launcher.SSHProxyEngineSetLauncher
173
173
174 class TestSSHEngineLauncher(SSHTest, LauncherTest, TestCase):
174 class TestSSHEngineLauncher(SSHTest, LauncherTest, TestCase):
175 launcher_class = launcher.SSHEngineLauncher
175 launcher_class = launcher.SSHEngineLauncher
176
176
177 #-------------------------------------------------------------------------------
177 #-------------------------------------------------------------------------------
178 # Windows Launcher Tests
178 # Windows Launcher Tests
179 #-------------------------------------------------------------------------------
179 #-------------------------------------------------------------------------------
180
180
181 class WinHPCTest:
181 class WinHPCTest:
182 """Tests for WinHPC Launchers"""
182 """Tests for WinHPC Launchers"""
183 def test_batch_template(self):
183 def test_batch_template(self):
184 launcher = self.build_launcher()
184 launcher = self.build_launcher()
185 job_file = os.path.join(self.profile_dir, launcher.job_file_name)
185 job_file = os.path.join(self.profile_dir, launcher.job_file_name)
186 self.assertEqual(launcher.job_file, job_file)
186 self.assertEqual(launcher.job_file, job_file)
187 launcher.write_job_file(1)
187 launcher.write_job_file(1)
188 self.assertTrue(os.path.isfile(job_file))
188 self.assertTrue(os.path.isfile(job_file))
189
189
190 class TestWinHPCControllerLauncher(WinHPCTest, ControllerLauncherTest, TestCase):
190 class TestWinHPCControllerLauncher(WinHPCTest, ControllerLauncherTest, TestCase):
191 launcher_class = launcher.WindowsHPCControllerLauncher
191 launcher_class = launcher.WindowsHPCControllerLauncher
192
192
193 class TestWinHPCEngineSetLauncher(WinHPCTest, EngineSetLauncherTest, TestCase):
193 class TestWinHPCEngineSetLauncher(WinHPCTest, EngineSetLauncherTest, TestCase):
194 launcher_class = launcher.WindowsHPCEngineSetLauncher
194 launcher_class = launcher.WindowsHPCEngineSetLauncher
@@ -1,221 +1,221 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test LoadBalancedView objects
2 """test LoadBalancedView objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from nose import SkipTest
23 from nose import SkipTest
24 from nose.plugins.attrib import attr
24 from nose.plugins.attrib import attr
25
25
26 from IPython import parallel as pmod
26 from IPython import parallel as pmod
27 from IPython.parallel import error
27 from ipython_parallel import error
28
28
29 from IPython.parallel.tests import add_engines
29 from ipython_parallel.tests import add_engines
30
30
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32
32
33 def setup():
33 def setup():
34 add_engines(3, total=True)
34 add_engines(3, total=True)
35
35
36 class TestLoadBalancedView(ClusterTestCase):
36 class TestLoadBalancedView(ClusterTestCase):
37
37
38 def setUp(self):
38 def setUp(self):
39 ClusterTestCase.setUp(self)
39 ClusterTestCase.setUp(self)
40 self.view = self.client.load_balanced_view()
40 self.view = self.client.load_balanced_view()
41
41
42 @attr('crash')
42 @attr('crash')
43 def test_z_crash_task(self):
43 def test_z_crash_task(self):
44 """test graceful handling of engine death (balanced)"""
44 """test graceful handling of engine death (balanced)"""
45 # self.add_engines(1)
45 # self.add_engines(1)
46 ar = self.view.apply_async(crash)
46 ar = self.view.apply_async(crash)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 eid = ar.engine_id
48 eid = ar.engine_id
49 tic = time.time()
49 tic = time.time()
50 while eid in self.client.ids and time.time()-tic < 5:
50 while eid in self.client.ids and time.time()-tic < 5:
51 time.sleep(.01)
51 time.sleep(.01)
52 self.client.spin()
52 self.client.spin()
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54
54
55 def test_map(self):
55 def test_map(self):
56 def f(x):
56 def f(x):
57 return x**2
57 return x**2
58 data = list(range(16))
58 data = list(range(16))
59 r = self.view.map_sync(f, data)
59 r = self.view.map_sync(f, data)
60 self.assertEqual(r, list(map(f, data)))
60 self.assertEqual(r, list(map(f, data)))
61
61
62 def test_map_generator(self):
62 def test_map_generator(self):
63 def f(x):
63 def f(x):
64 return x**2
64 return x**2
65
65
66 data = list(range(16))
66 data = list(range(16))
67 r = self.view.map_sync(f, iter(data))
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, list(map(f, iter(data))))
68 self.assertEqual(r, list(map(f, iter(data))))
69
69
70 def test_map_short_first(self):
70 def test_map_short_first(self):
71 def f(x,y):
71 def f(x,y):
72 if y is None:
72 if y is None:
73 return y
73 return y
74 if x is None:
74 if x is None:
75 return x
75 return x
76 return x*y
76 return x*y
77 data = list(range(10))
77 data = list(range(10))
78 data2 = list(range(4))
78 data2 = list(range(4))
79
79
80 r = self.view.map_sync(f, data, data2)
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, list(map(f, data, data2)))
81 self.assertEqual(r, list(map(f, data, data2)))
82
82
83 def test_map_short_last(self):
83 def test_map_short_last(self):
84 def f(x,y):
84 def f(x,y):
85 if y is None:
85 if y is None:
86 return y
86 return y
87 if x is None:
87 if x is None:
88 return x
88 return x
89 return x*y
89 return x*y
90 data = list(range(4))
90 data = list(range(4))
91 data2 = list(range(10))
91 data2 = list(range(10))
92
92
93 r = self.view.map_sync(f, data, data2)
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, list(map(f, data, data2)))
94 self.assertEqual(r, list(map(f, data, data2)))
95
95
96 def test_map_unordered(self):
96 def test_map_unordered(self):
97 def f(x):
97 def f(x):
98 return x**2
98 return x**2
99 def slow_f(x):
99 def slow_f(x):
100 import time
100 import time
101 time.sleep(0.05*x)
101 time.sleep(0.05*x)
102 return x**2
102 return x**2
103 data = list(range(16,0,-1))
103 data = list(range(16,0,-1))
104 reference = list(map(f, data))
104 reference = list(map(f, data))
105
105
106 amr = self.view.map_async(slow_f, data, ordered=False)
106 amr = self.view.map_async(slow_f, data, ordered=False)
107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
108 # check individual elements, retrieved as they come
108 # check individual elements, retrieved as they come
109 # list comprehension uses __iter__
109 # list comprehension uses __iter__
110 astheycame = [ r for r in amr ]
110 astheycame = [ r for r in amr ]
111 # Ensure that at least one result came out of order:
111 # Ensure that at least one result came out of order:
112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
114
114
115 def test_map_ordered(self):
115 def test_map_ordered(self):
116 def f(x):
116 def f(x):
117 return x**2
117 return x**2
118 def slow_f(x):
118 def slow_f(x):
119 import time
119 import time
120 time.sleep(0.05*x)
120 time.sleep(0.05*x)
121 return x**2
121 return x**2
122 data = list(range(16,0,-1))
122 data = list(range(16,0,-1))
123 reference = list(map(f, data))
123 reference = list(map(f, data))
124
124
125 amr = self.view.map_async(slow_f, data)
125 amr = self.view.map_async(slow_f, data)
126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
127 # check individual elements, retrieved as they come
127 # check individual elements, retrieved as they come
128 # list(amr) uses __iter__
128 # list(amr) uses __iter__
129 astheycame = list(amr)
129 astheycame = list(amr)
130 # Ensure that results came in order
130 # Ensure that results came in order
131 self.assertEqual(astheycame, reference)
131 self.assertEqual(astheycame, reference)
132 self.assertEqual(amr.result, reference)
132 self.assertEqual(amr.result, reference)
133
133
134 def test_map_iterable(self):
134 def test_map_iterable(self):
135 """test map on iterables (balanced)"""
135 """test map on iterables (balanced)"""
136 view = self.view
136 view = self.view
137 # 101 is prime, so it won't be evenly distributed
137 # 101 is prime, so it won't be evenly distributed
138 arr = range(101)
138 arr = range(101)
139 # so that it will be an iterator, even in Python 3
139 # so that it will be an iterator, even in Python 3
140 it = iter(arr)
140 it = iter(arr)
141 r = view.map_sync(lambda x:x, arr)
141 r = view.map_sync(lambda x:x, arr)
142 self.assertEqual(r, list(arr))
142 self.assertEqual(r, list(arr))
143
143
144
144
145 def test_abort(self):
145 def test_abort(self):
146 view = self.view
146 view = self.view
147 ar = self.client[:].apply_async(time.sleep, .5)
147 ar = self.client[:].apply_async(time.sleep, .5)
148 ar = self.client[:].apply_async(time.sleep, .5)
148 ar = self.client[:].apply_async(time.sleep, .5)
149 time.sleep(0.2)
149 time.sleep(0.2)
150 ar2 = view.apply_async(lambda : 2)
150 ar2 = view.apply_async(lambda : 2)
151 ar3 = view.apply_async(lambda : 3)
151 ar3 = view.apply_async(lambda : 3)
152 view.abort(ar2)
152 view.abort(ar2)
153 view.abort(ar3.msg_ids)
153 view.abort(ar3.msg_ids)
154 self.assertRaises(error.TaskAborted, ar2.get)
154 self.assertRaises(error.TaskAborted, ar2.get)
155 self.assertRaises(error.TaskAborted, ar3.get)
155 self.assertRaises(error.TaskAborted, ar3.get)
156
156
157 def test_retries(self):
157 def test_retries(self):
158 self.minimum_engines(3)
158 self.minimum_engines(3)
159 view = self.view
159 view = self.view
160 def fail():
160 def fail():
161 assert False
161 assert False
162 for r in range(len(self.client)-1):
162 for r in range(len(self.client)-1):
163 with view.temp_flags(retries=r):
163 with view.temp_flags(retries=r):
164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
165
165
166 with view.temp_flags(retries=len(self.client), timeout=0.1):
166 with view.temp_flags(retries=len(self.client), timeout=0.1):
167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
168
168
169 def test_short_timeout(self):
169 def test_short_timeout(self):
170 self.minimum_engines(2)
170 self.minimum_engines(2)
171 view = self.view
171 view = self.view
172 def fail():
172 def fail():
173 import time
173 import time
174 time.sleep(0.25)
174 time.sleep(0.25)
175 assert False
175 assert False
176 with view.temp_flags(retries=1, timeout=0.01):
176 with view.temp_flags(retries=1, timeout=0.01):
177 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
177 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
178
178
179 def test_invalid_dependency(self):
179 def test_invalid_dependency(self):
180 view = self.view
180 view = self.view
181 with view.temp_flags(after='12345'):
181 with view.temp_flags(after='12345'):
182 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
182 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
183
183
184 def test_impossible_dependency(self):
184 def test_impossible_dependency(self):
185 self.minimum_engines(2)
185 self.minimum_engines(2)
186 view = self.client.load_balanced_view()
186 view = self.client.load_balanced_view()
187 ar1 = view.apply_async(lambda : 1)
187 ar1 = view.apply_async(lambda : 1)
188 ar1.get()
188 ar1.get()
189 e1 = ar1.engine_id
189 e1 = ar1.engine_id
190 e2 = e1
190 e2 = e1
191 while e2 == e1:
191 while e2 == e1:
192 ar2 = view.apply_async(lambda : 1)
192 ar2 = view.apply_async(lambda : 1)
193 ar2.get()
193 ar2.get()
194 e2 = ar2.engine_id
194 e2 = ar2.engine_id
195
195
196 with view.temp_flags(follow=[ar1, ar2]):
196 with view.temp_flags(follow=[ar1, ar2]):
197 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
197 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
198
198
199
199
200 def test_follow(self):
200 def test_follow(self):
201 ar = self.view.apply_async(lambda : 1)
201 ar = self.view.apply_async(lambda : 1)
202 ar.get()
202 ar.get()
203 ars = []
203 ars = []
204 first_id = ar.engine_id
204 first_id = ar.engine_id
205
205
206 self.view.follow = ar
206 self.view.follow = ar
207 for i in range(5):
207 for i in range(5):
208 ars.append(self.view.apply_async(lambda : 1))
208 ars.append(self.view.apply_async(lambda : 1))
209 self.view.wait(ars)
209 self.view.wait(ars)
210 for ar in ars:
210 for ar in ars:
211 self.assertEqual(ar.engine_id, first_id)
211 self.assertEqual(ar.engine_id, first_id)
212
212
213 def test_after(self):
213 def test_after(self):
214 view = self.view
214 view = self.view
215 ar = view.apply_async(time.sleep, 0.5)
215 ar = view.apply_async(time.sleep, 0.5)
216 with view.temp_flags(after=ar):
216 with view.temp_flags(after=ar):
217 ar2 = view.apply_async(lambda : 1)
217 ar2 = view.apply_async(lambda : 1)
218
218
219 ar.wait()
219 ar.wait()
220 ar2.wait()
220 ar2.wait()
221 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
221 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,374 +1,374 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Test Parallel magics
2 """Test Parallel magics
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import re
19 import re
20 import time
20 import time
21
21
22
22
23 from IPython.testing import decorators as dec
23 from IPython.testing import decorators as dec
24 from IPython.utils.io import capture_output
24 from IPython.utils.io import capture_output
25
25
26 from IPython import parallel as pmod
26 from IPython import parallel as pmod
27 from IPython.parallel import AsyncResult
27 from ipython_parallel import AsyncResult
28
28
29 from IPython.parallel.tests import add_engines
29 from ipython_parallel.tests import add_engines
30
30
31 from .clienttest import ClusterTestCase, generate_output
31 from .clienttest import ClusterTestCase, generate_output
32
32
33 def setup():
33 def setup():
34 add_engines(3, total=True)
34 add_engines(3, total=True)
35
35
36 class TestParallelMagics(ClusterTestCase):
36 class TestParallelMagics(ClusterTestCase):
37
37
38 def test_px_blocking(self):
38 def test_px_blocking(self):
39 ip = get_ipython()
39 ip = get_ipython()
40 v = self.client[-1:]
40 v = self.client[-1:]
41 v.activate()
41 v.activate()
42 v.block=True
42 v.block=True
43
43
44 ip.magic('px a=5')
44 ip.magic('px a=5')
45 self.assertEqual(v['a'], [5])
45 self.assertEqual(v['a'], [5])
46 ip.magic('px a=10')
46 ip.magic('px a=10')
47 self.assertEqual(v['a'], [10])
47 self.assertEqual(v['a'], [10])
48 # just 'print a' works ~99% of the time, but this ensures that
48 # just 'print a' works ~99% of the time, but this ensures that
49 # the stdout message has arrived when the result is finished:
49 # the stdout message has arrived when the result is finished:
50 with capture_output() as io:
50 with capture_output() as io:
51 ip.magic(
51 ip.magic(
52 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
52 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
53 )
53 )
54 self.assertIn('[stdout:', io.stdout)
54 self.assertIn('[stdout:', io.stdout)
55 self.assertNotIn('\n\n', io.stdout)
55 self.assertNotIn('\n\n', io.stdout)
56 assert io.stdout.rstrip().endswith('10')
56 assert io.stdout.rstrip().endswith('10')
57 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
57 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
58
58
59 def _check_generated_stderr(self, stderr, n):
59 def _check_generated_stderr(self, stderr, n):
60 expected = [
60 expected = [
61 r'\[stderr:\d+\]',
61 r'\[stderr:\d+\]',
62 '^stderr$',
62 '^stderr$',
63 '^stderr2$',
63 '^stderr2$',
64 ] * n
64 ] * n
65
65
66 self.assertNotIn('\n\n', stderr)
66 self.assertNotIn('\n\n', stderr)
67 lines = stderr.splitlines()
67 lines = stderr.splitlines()
68 self.assertEqual(len(lines), len(expected), stderr)
68 self.assertEqual(len(lines), len(expected), stderr)
69 for line,expect in zip(lines, expected):
69 for line,expect in zip(lines, expected):
70 if isinstance(expect, str):
70 if isinstance(expect, str):
71 expect = [expect]
71 expect = [expect]
72 for ex in expect:
72 for ex in expect:
73 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
73 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
74
74
75 def test_cellpx_block_args(self):
75 def test_cellpx_block_args(self):
76 """%%px --[no]block flags work"""
76 """%%px --[no]block flags work"""
77 ip = get_ipython()
77 ip = get_ipython()
78 v = self.client[-1:]
78 v = self.client[-1:]
79 v.activate()
79 v.activate()
80 v.block=False
80 v.block=False
81
81
82 for block in (True, False):
82 for block in (True, False):
83 v.block = block
83 v.block = block
84 ip.magic("pxconfig --verbose")
84 ip.magic("pxconfig --verbose")
85 with capture_output(display=False) as io:
85 with capture_output(display=False) as io:
86 ip.run_cell_magic("px", "", "1")
86 ip.run_cell_magic("px", "", "1")
87 if block:
87 if block:
88 assert io.stdout.startswith("Parallel"), io.stdout
88 assert io.stdout.startswith("Parallel"), io.stdout
89 else:
89 else:
90 assert io.stdout.startswith("Async"), io.stdout
90 assert io.stdout.startswith("Async"), io.stdout
91
91
92 with capture_output(display=False) as io:
92 with capture_output(display=False) as io:
93 ip.run_cell_magic("px", "--block", "1")
93 ip.run_cell_magic("px", "--block", "1")
94 assert io.stdout.startswith("Parallel"), io.stdout
94 assert io.stdout.startswith("Parallel"), io.stdout
95
95
96 with capture_output(display=False) as io:
96 with capture_output(display=False) as io:
97 ip.run_cell_magic("px", "--noblock", "1")
97 ip.run_cell_magic("px", "--noblock", "1")
98 assert io.stdout.startswith("Async"), io.stdout
98 assert io.stdout.startswith("Async"), io.stdout
99
99
100 def test_cellpx_groupby_engine(self):
100 def test_cellpx_groupby_engine(self):
101 """%%px --group-outputs=engine"""
101 """%%px --group-outputs=engine"""
102 ip = get_ipython()
102 ip = get_ipython()
103 v = self.client[:]
103 v = self.client[:]
104 v.block = True
104 v.block = True
105 v.activate()
105 v.activate()
106
106
107 v['generate_output'] = generate_output
107 v['generate_output'] = generate_output
108
108
109 with capture_output(display=False) as io:
109 with capture_output(display=False) as io:
110 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
110 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
111
111
112 self.assertNotIn('\n\n', io.stdout)
112 self.assertNotIn('\n\n', io.stdout)
113 lines = io.stdout.splitlines()
113 lines = io.stdout.splitlines()
114 expected = [
114 expected = [
115 r'\[stdout:\d+\]',
115 r'\[stdout:\d+\]',
116 'stdout',
116 'stdout',
117 'stdout2',
117 'stdout2',
118 r'\[output:\d+\]',
118 r'\[output:\d+\]',
119 r'IPython\.core\.display\.HTML',
119 r'IPython\.core\.display\.HTML',
120 r'IPython\.core\.display\.Math',
120 r'IPython\.core\.display\.Math',
121 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
121 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
122 ] * len(v)
122 ] * len(v)
123
123
124 self.assertEqual(len(lines), len(expected), io.stdout)
124 self.assertEqual(len(lines), len(expected), io.stdout)
125 for line,expect in zip(lines, expected):
125 for line,expect in zip(lines, expected):
126 if isinstance(expect, str):
126 if isinstance(expect, str):
127 expect = [expect]
127 expect = [expect]
128 for ex in expect:
128 for ex in expect:
129 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
129 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
130
130
131 self._check_generated_stderr(io.stderr, len(v))
131 self._check_generated_stderr(io.stderr, len(v))
132
132
133
133
134 def test_cellpx_groupby_order(self):
134 def test_cellpx_groupby_order(self):
135 """%%px --group-outputs=order"""
135 """%%px --group-outputs=order"""
136 ip = get_ipython()
136 ip = get_ipython()
137 v = self.client[:]
137 v = self.client[:]
138 v.block = True
138 v.block = True
139 v.activate()
139 v.activate()
140
140
141 v['generate_output'] = generate_output
141 v['generate_output'] = generate_output
142
142
143 with capture_output(display=False) as io:
143 with capture_output(display=False) as io:
144 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
144 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
145
145
146 self.assertNotIn('\n\n', io.stdout)
146 self.assertNotIn('\n\n', io.stdout)
147 lines = io.stdout.splitlines()
147 lines = io.stdout.splitlines()
148 expected = []
148 expected = []
149 expected.extend([
149 expected.extend([
150 r'\[stdout:\d+\]',
150 r'\[stdout:\d+\]',
151 'stdout',
151 'stdout',
152 'stdout2',
152 'stdout2',
153 ] * len(v))
153 ] * len(v))
154 expected.extend([
154 expected.extend([
155 r'\[output:\d+\]',
155 r'\[output:\d+\]',
156 'IPython.core.display.HTML',
156 'IPython.core.display.HTML',
157 ] * len(v))
157 ] * len(v))
158 expected.extend([
158 expected.extend([
159 r'\[output:\d+\]',
159 r'\[output:\d+\]',
160 'IPython.core.display.Math',
160 'IPython.core.display.Math',
161 ] * len(v))
161 ] * len(v))
162 expected.extend([
162 expected.extend([
163 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
163 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
164 ] * len(v))
164 ] * len(v))
165
165
166 self.assertEqual(len(lines), len(expected), io.stdout)
166 self.assertEqual(len(lines), len(expected), io.stdout)
167 for line,expect in zip(lines, expected):
167 for line,expect in zip(lines, expected):
168 if isinstance(expect, str):
168 if isinstance(expect, str):
169 expect = [expect]
169 expect = [expect]
170 for ex in expect:
170 for ex in expect:
171 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
171 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
172
172
173 self._check_generated_stderr(io.stderr, len(v))
173 self._check_generated_stderr(io.stderr, len(v))
174
174
175 def test_cellpx_groupby_type(self):
175 def test_cellpx_groupby_type(self):
176 """%%px --group-outputs=type"""
176 """%%px --group-outputs=type"""
177 ip = get_ipython()
177 ip = get_ipython()
178 v = self.client[:]
178 v = self.client[:]
179 v.block = True
179 v.block = True
180 v.activate()
180 v.activate()
181
181
182 v['generate_output'] = generate_output
182 v['generate_output'] = generate_output
183
183
184 with capture_output(display=False) as io:
184 with capture_output(display=False) as io:
185 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
185 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
186
186
187 self.assertNotIn('\n\n', io.stdout)
187 self.assertNotIn('\n\n', io.stdout)
188 lines = io.stdout.splitlines()
188 lines = io.stdout.splitlines()
189
189
190 expected = []
190 expected = []
191 expected.extend([
191 expected.extend([
192 r'\[stdout:\d+\]',
192 r'\[stdout:\d+\]',
193 'stdout',
193 'stdout',
194 'stdout2',
194 'stdout2',
195 ] * len(v))
195 ] * len(v))
196 expected.extend([
196 expected.extend([
197 r'\[output:\d+\]',
197 r'\[output:\d+\]',
198 r'IPython\.core\.display\.HTML',
198 r'IPython\.core\.display\.HTML',
199 r'IPython\.core\.display\.Math',
199 r'IPython\.core\.display\.Math',
200 ] * len(v))
200 ] * len(v))
201 expected.extend([
201 expected.extend([
202 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
202 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
203 ] * len(v))
203 ] * len(v))
204
204
205 self.assertEqual(len(lines), len(expected), io.stdout)
205 self.assertEqual(len(lines), len(expected), io.stdout)
206 for line,expect in zip(lines, expected):
206 for line,expect in zip(lines, expected):
207 if isinstance(expect, str):
207 if isinstance(expect, str):
208 expect = [expect]
208 expect = [expect]
209 for ex in expect:
209 for ex in expect:
210 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
210 assert re.search(ex, line) is not None, "Expected %r in %r" % (ex, line)
211
211
212 self._check_generated_stderr(io.stderr, len(v))
212 self._check_generated_stderr(io.stderr, len(v))
213
213
214
214
215 def test_px_nonblocking(self):
215 def test_px_nonblocking(self):
216 ip = get_ipython()
216 ip = get_ipython()
217 v = self.client[-1:]
217 v = self.client[-1:]
218 v.activate()
218 v.activate()
219 v.block=False
219 v.block=False
220
220
221 ip.magic('px a=5')
221 ip.magic('px a=5')
222 self.assertEqual(v['a'], [5])
222 self.assertEqual(v['a'], [5])
223 ip.magic('px a=10')
223 ip.magic('px a=10')
224 self.assertEqual(v['a'], [10])
224 self.assertEqual(v['a'], [10])
225 ip.magic('pxconfig --verbose')
225 ip.magic('pxconfig --verbose')
226 with capture_output() as io:
226 with capture_output() as io:
227 ar = ip.magic('px print (a)')
227 ar = ip.magic('px print (a)')
228 self.assertIsInstance(ar, AsyncResult)
228 self.assertIsInstance(ar, AsyncResult)
229 self.assertIn('Async', io.stdout)
229 self.assertIn('Async', io.stdout)
230 self.assertNotIn('[stdout:', io.stdout)
230 self.assertNotIn('[stdout:', io.stdout)
231 self.assertNotIn('\n\n', io.stdout)
231 self.assertNotIn('\n\n', io.stdout)
232
232
233 ar = ip.magic('px 1/0')
233 ar = ip.magic('px 1/0')
234 self.assertRaisesRemote(ZeroDivisionError, ar.get)
234 self.assertRaisesRemote(ZeroDivisionError, ar.get)
235
235
236 def test_autopx_blocking(self):
236 def test_autopx_blocking(self):
237 ip = get_ipython()
237 ip = get_ipython()
238 v = self.client[-1]
238 v = self.client[-1]
239 v.activate()
239 v.activate()
240 v.block=True
240 v.block=True
241
241
242 with capture_output(display=False) as io:
242 with capture_output(display=False) as io:
243 ip.magic('autopx')
243 ip.magic('autopx')
244 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
244 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
245 ip.run_cell('b*=2')
245 ip.run_cell('b*=2')
246 ip.run_cell('print (b)')
246 ip.run_cell('print (b)')
247 ip.run_cell('b')
247 ip.run_cell('b')
248 ip.run_cell("b/c")
248 ip.run_cell("b/c")
249 ip.magic('autopx')
249 ip.magic('autopx')
250
250
251 output = io.stdout
251 output = io.stdout
252
252
253 assert output.startswith('%autopx enabled'), output
253 assert output.startswith('%autopx enabled'), output
254 assert output.rstrip().endswith('%autopx disabled'), output
254 assert output.rstrip().endswith('%autopx disabled'), output
255 self.assertIn('ZeroDivisionError', output)
255 self.assertIn('ZeroDivisionError', output)
256 self.assertIn('\nOut[', output)
256 self.assertIn('\nOut[', output)
257 self.assertIn(': 24690', output)
257 self.assertIn(': 24690', output)
258 ar = v.get_result(-1)
258 ar = v.get_result(-1)
259 self.assertEqual(v['a'], 5)
259 self.assertEqual(v['a'], 5)
260 self.assertEqual(v['b'], 24690)
260 self.assertEqual(v['b'], 24690)
261 self.assertRaisesRemote(ZeroDivisionError, ar.get)
261 self.assertRaisesRemote(ZeroDivisionError, ar.get)
262
262
263 def test_autopx_nonblocking(self):
263 def test_autopx_nonblocking(self):
264 ip = get_ipython()
264 ip = get_ipython()
265 v = self.client[-1]
265 v = self.client[-1]
266 v.activate()
266 v.activate()
267 v.block=False
267 v.block=False
268
268
269 with capture_output() as io:
269 with capture_output() as io:
270 ip.magic('autopx')
270 ip.magic('autopx')
271 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
271 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
272 ip.run_cell('print (b)')
272 ip.run_cell('print (b)')
273 ip.run_cell('import time; time.sleep(0.1)')
273 ip.run_cell('import time; time.sleep(0.1)')
274 ip.run_cell("b/c")
274 ip.run_cell("b/c")
275 ip.run_cell('b*=2')
275 ip.run_cell('b*=2')
276 ip.magic('autopx')
276 ip.magic('autopx')
277
277
278 output = io.stdout.rstrip()
278 output = io.stdout.rstrip()
279
279
280 assert output.startswith('%autopx enabled'), output
280 assert output.startswith('%autopx enabled'), output
281 assert output.endswith('%autopx disabled'), output
281 assert output.endswith('%autopx disabled'), output
282 self.assertNotIn('ZeroDivisionError', output)
282 self.assertNotIn('ZeroDivisionError', output)
283 ar = v.get_result(-2)
283 ar = v.get_result(-2)
284 self.assertRaisesRemote(ZeroDivisionError, ar.get)
284 self.assertRaisesRemote(ZeroDivisionError, ar.get)
285 # prevent TaskAborted on pulls, due to ZeroDivisionError
285 # prevent TaskAborted on pulls, due to ZeroDivisionError
286 time.sleep(0.5)
286 time.sleep(0.5)
287 self.assertEqual(v['a'], 5)
287 self.assertEqual(v['a'], 5)
288 # b*=2 will not fire, due to abort
288 # b*=2 will not fire, due to abort
289 self.assertEqual(v['b'], 10)
289 self.assertEqual(v['b'], 10)
290
290
291 def test_result(self):
291 def test_result(self):
292 ip = get_ipython()
292 ip = get_ipython()
293 v = self.client[-1]
293 v = self.client[-1]
294 v.activate()
294 v.activate()
295 data = dict(a=111,b=222)
295 data = dict(a=111,b=222)
296 v.push(data, block=True)
296 v.push(data, block=True)
297
297
298 for name in ('a', 'b'):
298 for name in ('a', 'b'):
299 ip.magic('px ' + name)
299 ip.magic('px ' + name)
300 with capture_output(display=False) as io:
300 with capture_output(display=False) as io:
301 ip.magic('pxresult')
301 ip.magic('pxresult')
302 self.assertIn(str(data[name]), io.stdout)
302 self.assertIn(str(data[name]), io.stdout)
303
303
304 @dec.skipif_not_matplotlib
304 @dec.skipif_not_matplotlib
305 def test_px_pylab(self):
305 def test_px_pylab(self):
306 """%pylab works on engines"""
306 """%pylab works on engines"""
307 ip = get_ipython()
307 ip = get_ipython()
308 v = self.client[-1]
308 v = self.client[-1]
309 v.block = True
309 v.block = True
310 v.activate()
310 v.activate()
311
311
312 with capture_output() as io:
312 with capture_output() as io:
313 ip.magic("px %pylab inline")
313 ip.magic("px %pylab inline")
314
314
315 self.assertIn("Populating the interactive namespace from numpy and matplotlib", io.stdout)
315 self.assertIn("Populating the interactive namespace from numpy and matplotlib", io.stdout)
316
316
317 with capture_output(display=False) as io:
317 with capture_output(display=False) as io:
318 ip.magic("px plot(rand(100))")
318 ip.magic("px plot(rand(100))")
319 self.assertIn('Out[', io.stdout)
319 self.assertIn('Out[', io.stdout)
320 self.assertIn('matplotlib.lines', io.stdout)
320 self.assertIn('matplotlib.lines', io.stdout)
321
321
322 def test_pxconfig(self):
322 def test_pxconfig(self):
323 ip = get_ipython()
323 ip = get_ipython()
324 rc = self.client
324 rc = self.client
325 v = rc.activate(-1, '_tst')
325 v = rc.activate(-1, '_tst')
326 self.assertEqual(v.targets, rc.ids[-1])
326 self.assertEqual(v.targets, rc.ids[-1])
327 ip.magic("%pxconfig_tst -t :")
327 ip.magic("%pxconfig_tst -t :")
328 self.assertEqual(v.targets, rc.ids)
328 self.assertEqual(v.targets, rc.ids)
329 ip.magic("%pxconfig_tst -t ::2")
329 ip.magic("%pxconfig_tst -t ::2")
330 self.assertEqual(v.targets, rc.ids[::2])
330 self.assertEqual(v.targets, rc.ids[::2])
331 ip.magic("%pxconfig_tst -t 1::2")
331 ip.magic("%pxconfig_tst -t 1::2")
332 self.assertEqual(v.targets, rc.ids[1::2])
332 self.assertEqual(v.targets, rc.ids[1::2])
333 ip.magic("%pxconfig_tst -t 1")
333 ip.magic("%pxconfig_tst -t 1")
334 self.assertEqual(v.targets, 1)
334 self.assertEqual(v.targets, 1)
335 ip.magic("%pxconfig_tst --block")
335 ip.magic("%pxconfig_tst --block")
336 self.assertEqual(v.block, True)
336 self.assertEqual(v.block, True)
337 ip.magic("%pxconfig_tst --noblock")
337 ip.magic("%pxconfig_tst --noblock")
338 self.assertEqual(v.block, False)
338 self.assertEqual(v.block, False)
339
339
340 def test_cellpx_targets(self):
340 def test_cellpx_targets(self):
341 """%%px --targets doesn't change defaults"""
341 """%%px --targets doesn't change defaults"""
342 ip = get_ipython()
342 ip = get_ipython()
343 rc = self.client
343 rc = self.client
344 view = rc.activate(rc.ids)
344 view = rc.activate(rc.ids)
345 self.assertEqual(view.targets, rc.ids)
345 self.assertEqual(view.targets, rc.ids)
346 ip.magic('pxconfig --verbose')
346 ip.magic('pxconfig --verbose')
347 for cell in ("pass", "1/0"):
347 for cell in ("pass", "1/0"):
348 with capture_output(display=False) as io:
348 with capture_output(display=False) as io:
349 try:
349 try:
350 ip.run_cell_magic("px", "--targets all", cell)
350 ip.run_cell_magic("px", "--targets all", cell)
351 except pmod.RemoteError:
351 except pmod.RemoteError:
352 pass
352 pass
353 self.assertIn('engine(s): all', io.stdout)
353 self.assertIn('engine(s): all', io.stdout)
354 self.assertEqual(view.targets, rc.ids)
354 self.assertEqual(view.targets, rc.ids)
355
355
356
356
357 def test_cellpx_block(self):
357 def test_cellpx_block(self):
358 """%%px --block doesn't change default"""
358 """%%px --block doesn't change default"""
359 ip = get_ipython()
359 ip = get_ipython()
360 rc = self.client
360 rc = self.client
361 view = rc.activate(rc.ids)
361 view = rc.activate(rc.ids)
362 view.block = False
362 view.block = False
363 self.assertEqual(view.targets, rc.ids)
363 self.assertEqual(view.targets, rc.ids)
364 ip.magic('pxconfig --verbose')
364 ip.magic('pxconfig --verbose')
365 for cell in ("pass", "1/0"):
365 for cell in ("pass", "1/0"):
366 with capture_output(display=False) as io:
366 with capture_output(display=False) as io:
367 try:
367 try:
368 ip.run_cell_magic("px", "--block", cell)
368 ip.run_cell_magic("px", "--block", cell)
369 except pmod.RemoteError:
369 except pmod.RemoteError:
370 pass
370 pass
371 self.assertNotIn('Async', io.stdout)
371 self.assertNotIn('Async', io.stdout)
372 self.assertEqual(view.block, False)
372 self.assertEqual(view.block, False)
373
373
374
374
@@ -1,56 +1,56 b''
1 """Tests for mongodb backend
1 """Tests for mongodb backend
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import os
19 import os
20
20
21 from unittest import TestCase
21 from unittest import TestCase
22
22
23 from nose import SkipTest
23 from nose import SkipTest
24
24
25 from pymongo import Connection
25 from pymongo import Connection
26 from IPython.parallel.controller.mongodb import MongoDB
26 from ipython_parallel.controller.mongodb import MongoDB
27
27
28 from . import test_db
28 from . import test_db
29
29
30 conn_kwargs = {}
30 conn_kwargs = {}
31 if 'DB_IP' in os.environ:
31 if 'DB_IP' in os.environ:
32 conn_kwargs['host'] = os.environ['DB_IP']
32 conn_kwargs['host'] = os.environ['DB_IP']
33 if 'DBA_MONGODB_ADMIN_URI' in os.environ:
33 if 'DBA_MONGODB_ADMIN_URI' in os.environ:
34 # On ShiningPanda, we need a username and password to connect. They are
34 # On ShiningPanda, we need a username and password to connect. They are
35 # passed in a mongodb:// URI.
35 # passed in a mongodb:// URI.
36 conn_kwargs['host'] = os.environ['DBA_MONGODB_ADMIN_URI']
36 conn_kwargs['host'] = os.environ['DBA_MONGODB_ADMIN_URI']
37 if 'DB_PORT' in os.environ:
37 if 'DB_PORT' in os.environ:
38 conn_kwargs['port'] = int(os.environ['DB_PORT'])
38 conn_kwargs['port'] = int(os.environ['DB_PORT'])
39
39
40 try:
40 try:
41 c = Connection(**conn_kwargs)
41 c = Connection(**conn_kwargs)
42 except Exception:
42 except Exception:
43 c=None
43 c=None
44
44
45 class TestMongoBackend(test_db.TaskDBTest, TestCase):
45 class TestMongoBackend(test_db.TaskDBTest, TestCase):
46 """MongoDB backend tests"""
46 """MongoDB backend tests"""
47
47
48 def create_db(self):
48 def create_db(self):
49 try:
49 try:
50 return MongoDB(database='iptestdb', _connection=c)
50 return MongoDB(database='iptestdb', _connection=c)
51 except Exception:
51 except Exception:
52 raise SkipTest("Couldn't connect to mongodb")
52 raise SkipTest("Couldn't connect to mongodb")
53
53
54 def teardown(self):
54 def teardown(self):
55 if c is not None:
55 if c is not None:
56 c.drop_database('iptestdb')
56 c.drop_database('iptestdb')
@@ -1,843 +1,843 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects"""
2 """test View objects"""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import base64
7 import base64
8 import sys
8 import sys
9 import platform
9 import platform
10 import time
10 import time
11 from collections import namedtuple
11 from collections import namedtuple
12 from tempfile import NamedTemporaryFile
12 from tempfile import NamedTemporaryFile
13
13
14 import zmq
14 import zmq
15 from nose.plugins.attrib import attr
15 from nose.plugins.attrib import attr
16
16
17 from IPython.testing import decorators as dec
17 from IPython.testing import decorators as dec
18 from IPython.utils.io import capture_output
18 from IPython.utils.io import capture_output
19 from IPython.utils.py3compat import unicode_type
19 from IPython.utils.py3compat import unicode_type
20
20
21 from IPython import parallel as pmod
21 from IPython import parallel as pmod
22 from IPython.parallel import error
22 from ipython_parallel import error
23 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
23 from ipython_parallel import AsyncResult, AsyncHubResult, AsyncMapResult
24 from IPython.parallel.util import interactive
24 from ipython_parallel.util import interactive
25
25
26 from IPython.parallel.tests import add_engines
26 from ipython_parallel.tests import add_engines
27
27
28 from .clienttest import ClusterTestCase, crash, wait, skip_without
28 from .clienttest import ClusterTestCase, crash, wait, skip_without
29
29
30 def setup():
30 def setup():
31 add_engines(3, total=True)
31 add_engines(3, total=True)
32
32
33 point = namedtuple("point", "x y")
33 point = namedtuple("point", "x y")
34
34
35 class TestView(ClusterTestCase):
35 class TestView(ClusterTestCase):
36
36
37 def setUp(self):
37 def setUp(self):
38 # On Win XP, wait for resource cleanup, else parallel test group fails
38 # On Win XP, wait for resource cleanup, else parallel test group fails
39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
39 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
40 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
41 time.sleep(2)
41 time.sleep(2)
42 super(TestView, self).setUp()
42 super(TestView, self).setUp()
43
43
44 @attr('crash')
44 @attr('crash')
45 def test_z_crash_mux(self):
45 def test_z_crash_mux(self):
46 """test graceful handling of engine death (direct)"""
46 """test graceful handling of engine death (direct)"""
47 # self.add_engines(1)
47 # self.add_engines(1)
48 eid = self.client.ids[-1]
48 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
49 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
51 eid = ar.engine_id
52 tic = time.time()
52 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
53 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
54 time.sleep(.01)
55 self.client.spin()
55 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
57
58 def test_push_pull(self):
58 def test_push_pull(self):
59 """test pushing and pulling"""
59 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
60 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
61 t = self.client.ids[-1]
62 v = self.client[t]
62 v = self.client[t]
63 push = v.push
63 push = v.push
64 pull = v.pull
64 pull = v.pull
65 v.block=True
65 v.block=True
66 nengines = len(self.client)
66 nengines = len(self.client)
67 push({'data':data})
67 push({'data':data})
68 d = pull('data')
68 d = pull('data')
69 self.assertEqual(d, data)
69 self.assertEqual(d, data)
70 self.client[:].push({'data':data})
70 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
71 d = self.client[:].pull('data', block=True)
72 self.assertEqual(d, nengines*[data])
72 self.assertEqual(d, nengines*[data])
73 ar = push({'data':data}, block=False)
73 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
74 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
75 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
76 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
77 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
78 r = ar.get()
79 self.assertEqual(r, nengines*[data])
79 self.assertEqual(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
80 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
81 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEqual(r, nengines*[[10,20]])
82 self.assertEqual(r, nengines*[[10,20]])
83
83
84 def test_push_pull_function(self):
84 def test_push_pull_function(self):
85 "test pushing and pulling functions"
85 "test pushing and pulling functions"
86 def testf(x):
86 def testf(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 t = self.client.ids[-1]
89 t = self.client.ids[-1]
90 v = self.client[t]
90 v = self.client[t]
91 v.block=True
91 v.block=True
92 push = v.push
92 push = v.push
93 pull = v.pull
93 pull = v.pull
94 execute = v.execute
94 execute = v.execute
95 push({'testf':testf})
95 push({'testf':testf})
96 r = pull('testf')
96 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
97 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
98 execute('r = testf(10)')
99 r = pull('r')
99 r = pull('r')
100 self.assertEqual(r, testf(10))
100 self.assertEqual(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
102 ar.get()
103 ar = self.client[:].pull('testf', block=False)
103 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
104 rlist = ar.get()
105 for r in rlist:
105 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
106 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
107 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
108 r = pull(('testf','g'))
109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
109 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
110
110
def test_push_function_globals(self):
    """A pushed function looks up free names in the engine's globals."""
    @interactive
    def geta():
        return a

    engine = self.client[-1]
    engine.block = True
    engine['f'] = geta
    # Expected to fail remotely with NameError; the test assumes `a` is
    # not yet defined on the engine at this point.
    self.assertRaisesRemote(NameError, engine.execute, 'b=f()')
    # After defining `a` remotely, the same call must succeed.
    for code in ('a=5', 'b=f()'):
        engine.execute(code)
    self.assertEqual(engine['b'], 5)
124
124
def test_push_function_defaults(self):
    """A pushed function keeps its default argument values."""
    def echo(a=10):
        return a

    engine = self.client[-1]
    engine.block = True
    engine['f'] = echo
    # Call with no arguments so that only the default can supply `a`.
    engine.execute('b=f()')
    self.assertEqual(engine['b'], 10)
134
134
def test_get_result(self):
    """Fetch the result of a task submitted by a different client.

    A second client submits a task; the test's own client then retrieves
    it by message id, once with ``owner=False`` (expects an
    ``AsyncHubResult``) and once with the default (expects something that
    is not an ``AsyncHubResult``).  Both must yield the same value.
    """
    c = pmod.Client(profile='iptest')
    try:
        t = c.ids[-1]
        v = c[t]
        v2 = self.client[t]
        ar = v.apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        ahr = v2.get_result(ar.msg_ids[0], owner=False)
        self.assertIsInstance(ahr, AsyncHubResult)
        self.assertEqual(ahr.get(), ar.get())
        ar2 = v2.get_result(ar.msg_ids[0])
        self.assertNotIsInstance(ar2, AsyncHubResult)
        self.assertEqual(ahr.get(), ar2.get())
        c.spin()
    finally:
        # Close the extra client even when an assertion above fails, so a
        # failing test does not leak its connections.
        c.close()
153
153
def test_run_newline(self):
    """test that run appends newline to files"""
    import os

    # NOTE(review): the whitespace inside this literal was reconstructed
    # from a rendering that had lost indentation -- confirm against the
    # real file that the script body is laid out as intended.
    with NamedTemporaryFile('w', delete=False) as f:
        f.write("""def g():
                return 5
                """)
    try:
        v = self.client[-1]
        v.run(f.name, block=True)
        self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
    finally:
        # delete=False leaves the file on disk; remove it ourselves so the
        # test does not litter the temp directory.
        os.remove(f.name)
163
163
def test_apply_tracked(self):
    """Check the ``_tracker``/``sent`` attributes of apply results.

    Runs an identity apply with ``track=False`` and with ``track=True``
    and checks the tracker type and ``sent`` flag in both cases.
    """
    t = self.client.ids[-1]
    v = self.client[t]
    v.block = False

    def echo(n=1024*1024, **kwargs):
        # Send a string of n characters through an identity function,
        # with the given flags applied temporarily to the view.
        with v.temp_flags(**kwargs):
            return v.apply(lambda x: x, 'x'*n)

    ar = echo(1, track=False)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    self.assertTrue(ar.sent)

    ar = echo(track=True)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    self.assertEqual(ar.sent, ar._tracker.done)
    ar._tracker.wait()
    self.assertTrue(ar.sent)
181
181
def test_push_tracked(self):
    """Check the ``_tracker``/``sent`` attributes of push results."""
    t = self.client.ids[-1]
    ns = dict(x='x'*1024*1024)
    v = self.client[t]

    ar = v.push(ns, block=False, track=False)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    self.assertTrue(ar.sent)

    ar = v.push(ns, block=False, track=True)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    # Wait for the tracker before comparing it with `sent`.
    ar._tracker.wait()
    self.assertEqual(ar.sent, ar._tracker.done)
    self.assertTrue(ar.sent)
    ar.get()
196
196
def test_scatter_tracked(self):
    """Check the ``_tracker``/``sent`` attributes of scatter results."""
    t = self.client.ids
    x = 'x'*1024*1024

    ar = self.client[t].scatter('x', x, block=False, track=False)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    self.assertTrue(ar.sent)

    ar = self.client[t].scatter('x', x, block=False, track=True)
    self.assertIsInstance(ar._tracker, zmq.MessageTracker)
    self.assertEqual(ar.sent, ar._tracker.done)
    ar._tracker.wait()
    self.assertTrue(ar.sent)
    ar.get()
210
210
def test_remote_reference(self):
    """Pass a ``Reference`` to a remote name as an apply argument."""
    engine = self.client[-1]
    engine['a'] = 123
    # The identity function receives the Reference and should hand back
    # the value stored under 'a' on the engine.
    fetched = engine.apply_sync(lambda x: x, pmod.Reference('a'))
    self.assertEqual(fetched, 123)
217
217
218
218
def test_scatter_gather(self):
    """Scatter a list, gather it back, and gather an unknown name."""
    everyone = self.client[:]
    original = list(range(16))
    everyone.scatter('a', original)
    self.assertEqual(everyone.gather('a', block=True), original)
    # Gathering a name that was never scattered is expected to raise a
    # remote NameError.
    self.assertRaisesRemote(NameError, everyone.gather, 'asdf', block=True)
226
226
@skip_without('numpy')
def test_scatter_gather_numpy(self):
    """Scatter a numpy array and gather it back unchanged."""
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal
    view = self.client[:]
    a = numpy.arange(64)
    view.scatter('a', a, block=True)
    b = view.gather('a', block=True)
    assert_array_equal(b, a)
236
236
def test_scatter_gather_lazy(self):
    """scatter/gather with targets='all'"""
    lazy_view = self.client.direct_view(targets='all')
    payload = list(range(64))
    lazy_view.scatter('x', payload)
    self.assertEqual(lazy_view.gather('x', block=True), payload)
244
244
245
245
@dec.known_failure_py3
@skip_without('numpy')
def test_push_numpy_nocopy(self):
    """Arrays pushed to engines should arrive non-writeable.

    Checked for both item assignment (``view['A'] = a``) and an explicit
    ``view.push``.
    """
    import numpy

    @interactive
    def check_writeable(x):
        return x.flags.writeable

    view = self.client[:]
    a = numpy.arange(64)

    def assert_not_writeable(name):
        # One flag comes back per engine; all of them must be False.
        for flag in view.apply_sync(check_writeable, pmod.Reference(name)):
            self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")

    view['A'] = a
    assert_not_writeable('A')

    view.push(dict(B=a))
    assert_not_writeable('B')
263
263
@skip_without('numpy')
def test_apply_numpy(self):
    """view.apply(f, ndarray)"""
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal

    A = numpy.random.random((100,100))
    view = self.client[-1]
    # Round-trip the same data under several dtypes.
    for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
        B = A.astype(dt)
        C = view.apply_sync(lambda x:x, B)
        assert_array_equal(B,C)
276
276
@skip_without('numpy')
def test_apply_numpy_object_dtype(self):
    """view.apply(f, ndarray) with dtype=object"""
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal
    view = self.client[-1]

    # An array built from a dict.
    A = numpy.array([dict(a=5)])
    B = view.apply_sync(lambda x:x, A)
    assert_array_equal(A,B)

    # Structured array with one int field and one object field.
    A = numpy.array([(0, dict(b=10))], dtype=[('i', int), ('o', object)])
    B = view.apply_sync(lambda x:x, A)
    assert_array_equal(A,B)
291
291
@skip_without('numpy')
def test_push_pull_recarray(self):
    """push/pull recarrays"""
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal

    view = self.client[-1]

    R = numpy.array([
        (1, 'hi', 0.),
        (2**30, 'there', 2.5),
        (-99999, 'world', -12345.6789),
    ], [('n', int), ('s', '|S10'), ('f', float)])

    view['RR'] = R
    R2 = view['RR']

    # Ask the engine for the dtype and shape of its own copy.
    r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
    self.assertEqual(r_dtype, R.dtype)
    self.assertEqual(r_shape, R.shape)
    # The pulled-back copy must match the original as well.
    self.assertEqual(R2.dtype, R.dtype)
    self.assertEqual(R2.shape, R.shape)
    assert_array_equal(R2, R)
315
315
@skip_without('pandas')
def test_push_pull_timeseries(self):
    """push/pull pandas.TimeSeries"""
    import pandas

    local_ts = pandas.TimeSeries(list(range(10)))
    engine = self.client[-1]

    engine.push(dict(ts=local_ts), block=True)
    remote_ts = engine['ts']

    # Same type and the same values after the round trip.
    self.assertEqual(type(remote_ts), type(local_ts))
    self.assertTrue((local_ts == remote_ts).all())
330
330
def test_map(self):
    """map_sync gives the same results as the builtin map."""
    def f(x):
        return x**2

    everyone = self.client[:]
    inputs = list(range(16))
    remote = everyone.map_sync(f, inputs)
    self.assertEqual(remote, list(map(f, inputs)))
338
338
def test_map_empty_sequence(self):
    """map_sync over an empty list returns an empty list."""
    everyone = self.client[:]
    self.assertEqual(everyone.map_sync(lambda x: x, []), [])
343
343
def test_map_iterable(self):
    """test map on iterables (direct)"""
    everyone = self.client[:]
    # 101 is prime, so it won't be evenly distributed
    source = range(101)
    # Pass an explicit iterator so the input is one even in Python 3.
    mapped = everyone.map_sync(lambda x: x, iter(source))
    self.assertEqual(mapped, list(source))
353
353
@skip_without('numpy')
def test_map_numpy(self):
    """test map on numpy arrays (direct)"""
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal

    view = self.client[:]
    # 101 is prime, so it won't be evenly distributed
    arr = numpy.arange(101)
    r = view.map_sync(lambda x: x, arr)
    assert_array_equal(r, arr)
365
365
def test_scatter_gather_nonblocking(self):
    """Scatter and gather with block=False, then compare the result."""
    everyone = self.client[:]
    payload = list(range(16))
    everyone.scatter('a', payload, block=False)
    pending = everyone.gather('a', block=False)
    self.assertEqual(pending.get(), payload)
372
372
@skip_without('numpy')
def test_scatter_gather_numpy_nonblocking(self):
    """Non-blocking scatter/gather of a numpy array.

    Also checks the result types: scatter gives an ``AsyncResult`` and
    gather an ``AsyncMapResult``.
    """
    import numpy
    # Import from the public numpy.testing namespace rather than the
    # numpy.testing.utils compatibility module.
    from numpy.testing import assert_array_equal
    a = numpy.arange(64)
    view = self.client[:]
    ar = view.scatter('a', a, block=False)
    self.assertIsInstance(ar, AsyncResult)
    amr = view.gather('a', block=False)
    self.assertIsInstance(amr, AsyncMapResult)
    assert_array_equal(amr.get(), a)
384
384
def test_execute(self):
    """Non-blocking execute returns an AsyncResult with one entry per engine."""
    view = self.client[:]
    execute = view.execute
    ar = execute('c=30', block=False)
    self.assertIsInstance(ar, AsyncResult)
    ar = execute('d=[0,1,2]', block=False)
    self.client.wait(ar, 1)
    self.assertEqual(len(ar.get()), len(self.client))
    # The first statement must have taken effect wherever it ran.
    for c in view['c']:
        self.assertEqual(c, 30)
396
396
def test_abort(self):
    """Abort accepts either an AsyncResult or a list of message ids."""
    engine = self.client[-1]
    # A one-second sleep is submitted first, followed by the two tasks
    # that get aborted.
    blocker = engine.execute('import time; time.sleep(1)', block=False)
    second = engine.apply_async(lambda : 2)
    third = engine.apply_async(lambda : 3)
    engine.abort(second)
    engine.abort(third.msg_ids)
    for aborted in (second, third):
        self.assertRaises(error.TaskAborted, aborted.get)
406
406
def test_abort_all(self):
    """view.abort() aborts all outstanding tasks"""
    engine = self.client[-1]
    pending = [engine.apply_async(time.sleep, 0.25) for _ in range(10)]
    engine.abort()
    engine.wait(timeout=5)
    # Only the second half of the submissions is checked.
    for task in pending[5:]:
        self.assertRaises(error.TaskAborted, task.get)
415
415
def test_temp_flags(self):
    """temp_flags overrides a flag inside the block and restores it after."""
    engine = self.client[-1]
    engine.block = True
    with engine.temp_flags(block=False):
        inside = engine.block
        self.assertFalse(inside)
    after = engine.block
    self.assertTrue(after)
422
422
@dec.known_failure_py3
def test_importer(self):
    """Import a module under ``view.importer`` and use it from a remote call."""
    view = self.client[-1]
    view.clear(block=True)
    with view.importer:
        import re

    @interactive
    def findall(pat, s):
        # this globals() step isn't necessary in real code
        # only to prevent a closure in the test
        re = globals()['re']
        return re.findall(pat, s)

    # Raw string: '\w' in a plain literal is an invalid escape sequence.
    self.assertEqual(view.apply_sync(findall, r'\w+', 'hello world'), 'hello world'.split())
438
438
def test_unicode_execute(self):
    """test executing unicode strings"""
    engine = self.client[-1]
    engine.block = True
    # Choose the source text by the major version of the local Python.
    on_py3 = sys.version_info[0] >= 3
    code = "a='é'" if on_py3 else u"a=u'é'"
    engine.execute(code)
    self.assertEqual(engine['a'], u'é')
449
449
def test_unicode_apply_result(self):
    """test unicode apply results"""
    engine = self.client[-1]
    returned = engine.apply_sync(lambda : u'é')
    self.assertEqual(returned, u'é')
455
455
def test_unicode_apply_arg(self):
    """test passing unicode arguments to apply"""
    v = self.client[-1]

    @interactive
    def check_unicode(a, check):
        # `a` must arrive as text and `check` as its UTF-8 encoding.
        assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
        assert isinstance(check, bytes), "%r is not bytes"%check
        assert a.encode('utf8') == check, "%s != %s"%(a,check)

    for s in [ u'é', u'ßø®∫',u'asdf' ]:
        try:
            v.apply_sync(check_unicode, s, s.encode('utf8'))
        except error.RemoteError as e:
            if e.ename == 'AssertionError':
                # Report a failed remote assertion as a normal test failure.
                self.fail(e.evalue)
            else:
                # Bare raise keeps the original traceback on Python 2 too.
                raise
474
474
def test_map_reference(self):
    """view.map(<Reference>, *seqs) should work"""
    everyone = self.client[:]
    everyone.scatter('n', self.client.ids, flatten=True)
    everyone.execute("f = lambda x,y: x*y")

    ascending = list(range(10))
    descending = ascending[::-1]
    # Map the remotely defined `f` by reference over both sequences.
    remote = everyone.map_sync(pmod.Reference('f'), descending, ascending)
    local = [m*n for m, n in zip(descending, ascending)]
    self.assertEqual(remote, local)
486
486
def test_apply_reference(self):
    """view.apply(<Reference>, *args) should work"""
    everyone = self.client[:]
    # Scatter the engine ids as `n`; the remote `f` multiplies by `n`.
    everyone.scatter('n', self.client.ids, flatten=True)
    everyone.execute("f = lambda x: n*x")
    remote = everyone.apply_sync(pmod.Reference('f'), 5)
    local = [5*engine_id for engine_id in self.client.ids]
    self.assertEqual(remote, local)
496
496
def test_eval_reference(self):
    """A Reference may hold an expression ('g[0]'), not only a bare name."""
    first = self.client[self.client.ids[0]]
    first['g'] = list(range(5))
    indexed = pmod.Reference('g[0]')

    def echo(x):
        return x

    self.assertEqual(first.apply_sync(echo, indexed), 0)
503
503
def test_reference_nameerror(self):
    """A Reference to an undefined remote name raises a remote NameError."""
    first = self.client[self.client.ids[0]]
    missing = pmod.Reference('elvis_has_left')

    def echo(x):
        return x

    self.assertRaisesRemote(NameError, first.apply_sync, echo, missing)
509
509
def test_single_engine_map(self):
    """map_sync works on a view built from a single engine id."""
    first = self.client[self.client.ids[0]]
    values = list(range(5))
    negated = [-1*i for i in values]
    self.assertEqual(first.map_sync(lambda x: -1*x, values), negated)
516
516
def test_len(self):
    """len(view) makes sense"""
    # View built from a single engine id.
    self.assertEqual(len(self.client[self.client.ids[0]]), 1)
    # Full slice and the 'all' direct view match the number of ids.
    self.assertEqual(len(self.client[:]), len(self.client.ids))
    self.assertEqual(len(self.client.direct_view('all')), len(self.client.ids))
    # Bounded slices report their own size.
    self.assertEqual(len(self.client[:2]), 2)
    self.assertEqual(len(self.client[:1]), 1)
    # The load-balanced view also matches the number of ids.
    self.assertEqual(len(self.client.load_balanced_view()), len(self.client.ids))
531
531
532
532
533 # begin execute tests
533 # begin execute tests
534
534
def test_execute_reply(self):
    """Check the str() and text/plain data of a reply to executing "5"."""
    first = self.client[self.client.ids[0]]
    first.block = True
    reply = first.execute("5", silent=False).get()
    shown = "<ExecuteReply[%i]: 5>" % reply.execution_count
    self.assertEqual(str(reply), shown)
    self.assertEqual(reply.execute_result['data']['text/plain'], '5')
542
542
def test_execute_reply_rich(self):
    """Check the PNG and HTML reprs of replies to rich-display code."""
    e0 = self.client[self.client.ids[0]]
    e0.block = True
    e0.execute("from IPython.display import Image, HTML")
    ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
    er = ar.get()
    # base64.encodestring was removed in Python 3.9; prefer encodebytes
    # and fall back to the old name where encodebytes does not exist.
    try:
        b64encode = base64.encodebytes
    except AttributeError:
        b64encode = base64.encodestring
    b64data = b64encode(b'garbage').decode('ascii')
    self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
    ar = e0.execute("HTML('<b>bold</b>')", silent=False)
    er = ar.get()
    self.assertEqual(er._repr_html_(), "<b>bold</b>")
554
554
def test_execute_reply_stdout(self):
    """Output printed by executed code shows up on the reply's stdout."""
    first = self.client[self.client.ids[0]]
    first.block = True
    reply = first.execute("print (5)", silent=False).get()
    self.assertEqual(reply.stdout.strip(), '5')
561
561
def test_execute_result(self):
    """execute triggers execute_result with silent=False"""
    everyone = self.client[:]
    result = everyone.execute("5", silent=False, block=True)

    # One identical mime bundle is expected per entry in the view.
    wanted = [{'text/plain' : '5'}] * len(everyone)
    bundles = []
    for output in result.execute_result:
        bundles.append(output['data'])
    self.assertEqual(bundles, wanted)
570
570
def test_execute_silent(self):
    """execute does not trigger execute_result with silent=True"""
    everyone = self.client[:]
    result = everyone.execute("5", block=True)
    # `silent` is left at its default here; every result slot must be None.
    self.assertEqual(result.execute_result, [None] * len(everyone))
577
577
def test_execute_magic(self):
    """execute accepts IPython commands"""
    everyone = self.client[:]
    everyone.execute("a = 5")
    result = everyone.execute("%whos", block=True)
    # this will raise, if that failed
    result.get(5)
    wanted_row = ['a', 'int', '5']
    for stdout in result.stdout:
        rows = stdout.splitlines()
        self.assertEqual(rows[0].split(), ['Variable', 'Type', 'Data/Info'])
        # Search for the row describing `a`, starting at the third line.
        found = any(row.split() == wanted_row for row in rows[2:])
        self.assertTrue(found, "whos output wrong: %s" % stdout)
595
595
def test_execute_displaypub(self):
    """execute tracks display_pub output"""
    everyone = self.client[:]
    everyone.execute("from IPython.core.display import *")
    result = everyone.execute("[ display(i) for i in range(5) ]", block=True)

    wanted = [{u'text/plain' : unicode_type(j)} for j in range(5)]
    # Each entry of `outputs` must carry the five display bundles in order.
    for engine_outputs in result.outputs:
        bundles = []
        for output in engine_outputs:
            bundles.append(output['data'])
        self.assertEqual(bundles, wanted)
606
606
def test_apply_displaypub(self):
    """apply tracks display_pub output"""
    everyone = self.client[:]
    everyone.execute("from IPython.core.display import *")

    @interactive
    def publish():
        for i in range(5):
            display(i)

    result = everyone.apply_async(publish)
    result.get(5)
    wanted = [{u'text/plain' : unicode_type(j)} for j in range(5)]
    # Each entry of `outputs` must carry the five display bundles in order.
    for engine_outputs in result.outputs:
        bundles = []
        for output in engine_outputs:
            bundles.append(output['data'])
        self.assertEqual(bundles, wanted)
622
622
def test_execute_raises(self):
    """exceptions in execute requests raise appropriately"""
    engine = self.client[-1]
    failing = engine.execute("1/0")
    # The error is expected when the result is fetched (timeout 2).
    self.assertRaisesRemote(ZeroDivisionError, failing.get, 2)
628
628
def test_remoteerror_render_exception(self):
    """RemoteErrors get nice tracebacks"""
    engine = self.client[-1]
    failing = engine.execute("1/0")
    # Fetch the result inside the local IPython shell and capture what
    # it prints.
    shell = get_ipython()
    shell.user_ns['ar'] = failing
    with capture_output() as captured:
        shell.run_cell("ar.get(2)")

    self.assertTrue('ZeroDivisionError' in captured.stdout, captured.stdout)
639
639
def test_compositeerror_render_exception(self):
    """CompositeErrors get nice tracebacks.

    Runs ``1/0`` through the ``self.client[:]`` view, fetches the result
    inside the local shell, and counts marker strings in the captured
    stdout.  The expected number of rendered tracebacks is the smaller of
    ``CompositeError.tb_limit`` and the size of the view.
    """
    dview = self.client[:]
    failing = dview.execute("1/0")
    shell = get_ipython()
    # The cell below looks the result up by this name.
    shell.user_ns['ar'] = failing

    with capture_output() as captured:
        shell.run_cell("ar.get(2)")

    shown = min(error.CompositeError.tb_limit, len(dview))
    text = captured.stdout

    # (marker, expected occurrences); the exception name is expected
    # twice per rendered traceback, the other markers once.
    expectations = (
        ('ZeroDivisionError', shown * 2),
        ('by zero', shown),
        (':execute', shown),
    )
    for marker, occurrences in expectations:
        self.assertEqual(text.count(marker), occurrences, text)
655
655
def test_compositeerror_truncate(self):
    """Truncate CompositeErrors with many exceptions.

    Submits ``1/0`` through the ``self.client[:]`` view ten times, gathers
    every message id into a single result, and checks that the rendered
    traceback ends with a "more exceptions" line and that the printed
    output contains exactly ``tb_limit`` tracebacks.
    """
    view = self.client[:]
    msg_ids = []
    for _ in range(10):
        msg_ids.extend(view.execute("1/0").msg_ids)

    ar = self.client.get_result(msg_ids)
    # The context manager fails the test if nothing is raised and keeps the
    # exception reachable afterwards (a name bound by ``except ... as`` is
    # deleted at the end of the handler on Python 3).
    with self.assertRaises(error.CompositeError) as caught:
        ar.get()
    e = caught.exception

    lines = e.render_traceback()
    with capture_output() as io:
        e.print_traceback()

    self.assertIn("more exceptions", lines[-1])
    count = e.tb_limit

    # The exception name is expected twice per printed traceback, the
    # other markers once.
    self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
    self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
    self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
682
682
@dec.skipif_not_matplotlib
def test_magic_pylab(self):
    """%pylab works on engines.

    Enables inline pylab on the ``self.client[-1]`` view, plots random
    data, and checks that the reply carries exactly one output whose data
    includes an ``image/png`` entry.
    """
    view = self.client[-1]
    # at least check if this raised: the reply itself is not needed
    view.execute("%pylab inline").get(5)
    # include imports, in case user config
    ar = view.execute("plot(rand(100))", silent=False)
    reply = ar.get(5)
    self.assertEqual(len(reply.outputs), 1)
    output = reply.outputs[0]
    self.assertIn("data", output)
    data = output['data']
    self.assertIn("image/png", data)
698
698
def test_func_default_func(self):
    """An interactively defined function works as an apply default.

    The applied function receives another locally defined function only
    through its default argument and must return that function's result.
    """
    def produce():
        return 'foo'

    def call_default(f=produce):
        return f()

    target = self.client[-1]
    outcome = target.apply_async(call_default).get(10)
    self.assertEqual(outcome, 'foo')
def test_data_pub_single(self):
    """publish_data on a single-engine view ends with the last payload.

    The executed code publishes ``dict(i=i)`` for i in 0..4.  ``ar.data``
    must already be a dict before the result is waited on, and must equal
    ``dict(i=4)`` once ``get`` returns.
    """
    view = self.client[-1]
    ar = view.execute('\n'.join([
        'from IPython.kernel.zmq.datapub import publish_data',
        'for i in range(5):',
        ' publish_data(dict(i=i))'
    ]), block=False)
    # Checked before get(): the attribute has to be usable while pending.
    self.assertIsInstance(ar.data, dict)
    ar.get(5)
    self.assertEqual(ar.data, dict(i=4))
721
721
def test_data_pub(self):
    """publish_data through the ``self.client[:]`` view yields per-entry dicts.

    The executed code publishes ``dict(i=i)`` for i in 0..4.  Every item
    of ``data`` must be a dict before the result is waited on, and after
    ``get`` the data must be one ``dict(i=4)`` per entry of the result.
    """
    dview = self.client[:]
    source = '\n'.join([
        'from IPython.kernel.zmq.datapub import publish_data',
        'for i in range(5):',
        ' publish_data(dict(i=i))'
    ])
    result = dview.execute(source, block=False)
    # Checked before get(): the attribute has to be usable while pending.
    self.assertTrue(all(isinstance(item, dict) for item in result.data))
    result.get(5)
    self.assertEqual(result.data, [dict(i=4)] * len(result))
732
732
def test_can_list_arg(self):
    """args in lists are canned

    A ``Reference`` inside a list passed positionally comes back as the
    value stored under that name through the view.
    """
    target = self.client[-1]
    target['a'] = 128
    ref = pmod.Reference('a')
    outcome = target.apply_async(lambda x: x, [ref]).get(5)
    self.assertEqual(outcome, [128])
741
741
def test_can_dict_arg(self):
    """args in dicts are canned

    A ``Reference`` used as a dict value in a positional argument comes
    back as the value stored under that name through the view.
    """
    target = self.client[-1]
    target['a'] = 128
    ref = pmod.Reference('a')
    outcome = target.apply_async(lambda x: x, dict(foo=ref)).get(5)
    self.assertEqual(outcome, dict(foo=128))
750
750
def test_can_list_kwarg(self):
    """kwargs in lists are canned

    A ``Reference`` inside a list passed as the keyword argument ``x``
    comes back as the value stored under that name through the view.
    """
    target = self.client[-1]
    target['a'] = 128
    ref = pmod.Reference('a')
    outcome = target.apply_async(lambda x=5: x, x=[ref]).get(5)
    self.assertEqual(outcome, [128])
759
759
def test_can_dict_kwarg(self):
    """kwargs in dicts are canned

    A ``Reference`` used as a dict value in the keyword argument ``x``
    comes back as the value stored under that name through the view.
    """
    view = self.client[-1]
    view['a'] = 128
    rA = pmod.Reference('a')
    # Pass the dict by keyword: passing it positionally would only repeat
    # test_can_dict_arg and leave the kwarg path untested.
    ar = view.apply_async(lambda x=5: x, x=dict(foo=rA))
    r = ar.get(5)
    self.assertEqual(r, dict(foo=128))
768
768
def test_map_ref(self):
    """view.map works with references

    Scatters the sorted engine ids as ``rank`` and maps a doubling
    function over one ``Reference('rank')`` per entry of the view; the
    result must be the doubled ids in the same order.
    """
    dview = self.client[:]
    ranks = sorted(self.client.ids)
    dview.scatter('rank', ranks, flatten=True)

    rank_ref = pmod.Reference('rank')
    refs = [rank_ref] * len(dview)
    doubled = dview.map_async(lambda x: x*2, refs).get(5)
    self.assertEqual(doubled, [r * 2 for r in ranks])
779
779
def test_nested_getitem_setitem(self):
    """get and set with view['a.b']

    Creates an object ``a`` with attribute ``b = 128`` through the view,
    then checks that both an applied attribute read and ``view['a.b']``
    see 128, and see 0 after ``view['a.b'] = 0``.
    """
    target = self.client[-1]
    setup = '\n'.join([
        'class A(object): pass',
        'a = A()',
        'a.b = 128',
    ])
    target.execute(setup, block=True)
    ref = pmod.Reference('a')

    def remote_b():
        # Read ``a.b`` by applying an attribute lookup to the reference.
        return target.apply_sync(lambda x: x.b, ref)

    self.assertEqual(remote_b(), 128)
    self.assertEqual(target['a.b'], 128)

    target['a.b'] = 0

    self.assertEqual(remote_b(), 0)
    self.assertEqual(target['a.b'], 0)
799
799
def test_return_namedtuple(self):
    """A namedtuple built by an applied function keeps its fields.

    The applied function imports ``point`` itself and returns
    ``point(1, 2)``; the returned object must expose ``x == 1`` and
    ``y == 2``.
    """
    def namedtuplify(x, y):
        # Imported inside the function, so the import runs wherever the
        # function is applied.
        from ipython_parallel.tests.test_view import point
        return point(x, y)

    target = self.client[-1]
    returned = target.apply_sync(namedtuplify, 1, 2)
    self.assertEqual(returned.x, 1)
    self.assertEqual(returned.y, 2)
809
809
def test_apply_namedtuple(self):
    """A namedtuple passed as an apply argument keeps its fields.

    Sends ``point(1, 2)`` to a function that returns ``(p.y, p.x)`` and
    expects ``(2, 1)`` back.
    """
    def echoxy(p):
        return p.y, p.x

    target = self.client[-1]
    swapped = target.apply_sync(echoxy, point(1, 2))
    self.assertEqual(swapped, (2, 1))
817
817
def test_sync_imports(self):
    """Imports under ``sync_imports`` are reported and visible remotely.

    Imports ``IPython`` inside ``view.sync_imports()``, expects the
    captured stdout to mention it, and then checks through the view that
    ``IPython`` is present in the globals seen by an interactive function.
    """
    view = self.client[-1]
    with capture_output() as io:
        with view.sync_imports():
            # The import statement itself is what is under test.
            import IPython
    self.assertIn("IPython", io.stdout)

    @interactive
    def find_ipython():
        return 'IPython' in globals()

    # A bare ``assert`` is stripped under ``python -O``; use the TestCase
    # assertion so the check always runs.
    self.assertTrue(view.apply_sync(find_ipython))
830
830
def test_sync_imports_quiet(self):
    """``sync_imports(quiet=True)`` prints nothing but still imports.

    Imports ``IPython`` inside ``view.sync_imports(quiet=True)``, expects
    the captured stdout to be empty, and then checks through the view that
    ``IPython`` is present in the globals seen by an interactive function.
    """
    view = self.client[-1]
    with capture_output() as io:
        with view.sync_imports(quiet=True):
            # The import statement itself is what is under test.
            import IPython
    self.assertEqual(io.stdout, '')

    @interactive
    def find_ipython():
        return 'IPython' in globals()

    # A bare ``assert`` is stripped under ``python -O``; use the TestCase
    # assertion so the check always runs.
    self.assertTrue(view.apply_sync(find_ipython))
843
843
General Comments 0
You need to be logged in to leave comments. Login now