##// END OF EJS Templates
Fix parallel test suite
Thomas Kluyver -
Show More
@@ -1,275 +1,275 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 The Base Application class for IPython.parallel apps
3 The Base Application class for IPython.parallel apps
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Min RK
8 * Min RK
9
9
10 """
10 """
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2011 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 import os
23 import os
24 import logging
24 import logging
25 import re
25 import re
26 import sys
26 import sys
27
27
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29
29
30 from IPython.config.application import catch_config_error, LevelFormatter
30 from IPython.config.application import catch_config_error, LevelFormatter
31 from IPython.core import release
31 from IPython.core import release
32 from IPython.core.crashhandler import CrashHandler
32 from IPython.core.crashhandler import CrashHandler
33 from IPython.core.application import (
33 from IPython.core.application import (
34 BaseIPythonApplication,
34 BaseIPythonApplication,
35 base_aliases as base_ip_aliases,
35 base_aliases as base_ip_aliases,
36 base_flags as base_ip_flags
36 base_flags as base_ip_flags
37 )
37 )
38 from IPython.utils.path import expand_path
38 from IPython.utils.path import expand_path
39 from IPython.utils.py3compat import unicode_type
39 from IPython.utils.py3compat import unicode_type
40
40
41 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict
41 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Module errors
44 # Module errors
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 class PIDFileError(Exception):
47 class PIDFileError(Exception):
48 pass
48 pass
49
49
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Crash handler for this application
52 # Crash handler for this application
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54
54
55 class ParallelCrashHandler(CrashHandler):
55 class ParallelCrashHandler(CrashHandler):
56 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
56 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
57
57
58 def __init__(self, app):
58 def __init__(self, app):
59 contact_name = release.authors['Min'][0]
59 contact_name = release.authors['Min'][0]
60 contact_email = release.author_email
60 contact_email = release.author_email
61 bug_tracker = 'https://github.com/ipython/ipython/issues'
61 bug_tracker = 'https://github.com/ipython/ipython/issues'
62 super(ParallelCrashHandler,self).__init__(
62 super(ParallelCrashHandler,self).__init__(
63 app, contact_name, contact_email, bug_tracker
63 app, contact_name, contact_email, bug_tracker
64 )
64 )
65
65
66
66
67 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
68 # Main application
68 # Main application
69 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
70 base_aliases = {}
70 base_aliases = {}
71 base_aliases.update(base_ip_aliases)
71 base_aliases.update(base_ip_aliases)
72 base_aliases.update({
72 base_aliases.update({
73 'work-dir' : 'BaseParallelApplication.work_dir',
73 'work-dir' : 'BaseParallelApplication.work_dir',
74 'log-to-file' : 'BaseParallelApplication.log_to_file',
74 'log-to-file' : 'BaseParallelApplication.log_to_file',
75 'clean-logs' : 'BaseParallelApplication.clean_logs',
75 'clean-logs' : 'BaseParallelApplication.clean_logs',
76 'log-url' : 'BaseParallelApplication.log_url',
76 'log-url' : 'BaseParallelApplication.log_url',
77 'cluster-id' : 'BaseParallelApplication.cluster_id',
77 'cluster-id' : 'BaseParallelApplication.cluster_id',
78 })
78 })
79
79
80 base_flags = {
80 base_flags = {
81 'log-to-file' : (
81 'log-to-file' : (
82 {'BaseParallelApplication' : {'log_to_file' : True}},
82 {'BaseParallelApplication' : {'log_to_file' : True}},
83 "send log output to a file"
83 "send log output to a file"
84 )
84 )
85 }
85 }
86 base_flags.update(base_ip_flags)
86 base_flags.update(base_ip_flags)
87
87
88 class BaseParallelApplication(BaseIPythonApplication):
88 class BaseParallelApplication(BaseIPythonApplication):
89 """The base Application for IPython.parallel apps
89 """The base Application for IPython.parallel apps
90
90
91 Principle extensions to BaseIPyythonApplication:
91 Principle extensions to BaseIPyythonApplication:
92
92
93 * work_dir
93 * work_dir
94 * remote logging via pyzmq
94 * remote logging via pyzmq
95 * IOLoop instance
95 * IOLoop instance
96 """
96 """
97
97
98 crash_handler_class = ParallelCrashHandler
98 crash_handler_class = ParallelCrashHandler
99
99
100 def _log_level_default(self):
100 def _log_level_default(self):
101 # temporarily override default_log_level to INFO
101 # temporarily override default_log_level to INFO
102 return logging.INFO
102 return logging.INFO
103
103
104 def _log_format_default(self):
104 def _log_format_default(self):
105 """override default log format to include time"""
105 """override default log format to include time"""
106 return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"
106 return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"
107
107
108 work_dir = Unicode(os.getcwdu(), config=True,
108 work_dir = Unicode(os.getcwdu(), config=True,
109 help='Set the working dir for the process.'
109 help='Set the working dir for the process.'
110 )
110 )
111 def _work_dir_changed(self, name, old, new):
111 def _work_dir_changed(self, name, old, new):
112 self.work_dir = unicode_type(expand_path(new))
112 self.work_dir = unicode_type(expand_path(new))
113
113
114 log_to_file = Bool(config=True,
114 log_to_file = Bool(config=True,
115 help="whether to log to a file")
115 help="whether to log to a file")
116
116
117 clean_logs = Bool(False, config=True,
117 clean_logs = Bool(False, config=True,
118 help="whether to cleanup old logfiles before starting")
118 help="whether to cleanup old logfiles before starting")
119
119
120 log_url = Unicode('', config=True,
120 log_url = Unicode('', config=True,
121 help="The ZMQ URL of the iplogger to aggregate logging.")
121 help="The ZMQ URL of the iplogger to aggregate logging.")
122
122
123 cluster_id = Unicode('', config=True,
123 cluster_id = Unicode('', config=True,
124 help="""String id to add to runtime files, to prevent name collisions when
124 help="""String id to add to runtime files, to prevent name collisions when
125 using multiple clusters with a single profile simultaneously.
125 using multiple clusters with a single profile simultaneously.
126
126
127 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
127 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
128
128
129 Since this is text inserted into filenames, typical recommendations apply:
129 Since this is text inserted into filenames, typical recommendations apply:
130 Simple character strings are ideal, and spaces are not recommended (but should
130 Simple character strings are ideal, and spaces are not recommended (but should
131 generally work).
131 generally work).
132 """
132 """
133 )
133 )
134 def _cluster_id_changed(self, name, old, new):
134 def _cluster_id_changed(self, name, old, new):
135 self.name = self.__class__.name
135 self.name = self.__class__.name
136 if new:
136 if new:
137 self.name += '-%s'%new
137 self.name += '-%s'%new
138
138
139 def _config_files_default(self):
139 def _config_files_default(self):
140 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
140 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
141
141
142 loop = Instance('zmq.eventloop.ioloop.IOLoop')
142 loop = Instance('zmq.eventloop.ioloop.IOLoop')
143 def _loop_default(self):
143 def _loop_default(self):
144 from zmq.eventloop.ioloop import IOLoop
144 from zmq.eventloop.ioloop import IOLoop
145 return IOLoop.instance()
145 return IOLoop.instance()
146
146
147 aliases = Dict(base_aliases)
147 aliases = Dict(base_aliases)
148 flags = Dict(base_flags)
148 flags = Dict(base_flags)
149
149
150 @catch_config_error
150 @catch_config_error
151 def initialize(self, argv=None):
151 def initialize(self, argv=None):
152 """initialize the app"""
152 """initialize the app"""
153 super(BaseParallelApplication, self).initialize(argv)
153 super(BaseParallelApplication, self).initialize(argv)
154 self.to_work_dir()
154 self.to_work_dir()
155 self.reinit_logging()
155 self.reinit_logging()
156
156
157 def to_work_dir(self):
157 def to_work_dir(self):
158 wd = self.work_dir
158 wd = self.work_dir
159 if unicode_type(wd) != os.getcwdu():
159 if unicode_type(wd) != os.getcwdu():
160 os.chdir(wd)
160 os.chdir(wd)
161 self.log.info("Changing to working dir: %s" % wd)
161 self.log.info("Changing to working dir: %s" % wd)
162 # This is the working dir by now.
162 # This is the working dir by now.
163 sys.path.insert(0, '')
163 sys.path.insert(0, '')
164
164
165 def reinit_logging(self):
165 def reinit_logging(self):
166 # Remove old log files
166 # Remove old log files
167 log_dir = self.profile_dir.log_dir
167 log_dir = self.profile_dir.log_dir
168 if self.clean_logs:
168 if self.clean_logs:
169 for f in os.listdir(log_dir):
169 for f in os.listdir(log_dir):
170 if re.match(r'%s-\d+\.(log|err|out)' % self.name, f):
170 if re.match(r'%s-\d+\.(log|err|out)' % self.name, f):
171 try:
171 try:
172 os.remove(os.path.join(log_dir, f))
172 os.remove(os.path.join(log_dir, f))
173 except (OSError, IOError):
173 except (OSError, IOError):
174 # probably just conflict from sibling process
174 # probably just conflict from sibling process
175 # already removing it
175 # already removing it
176 pass
176 pass
177 if self.log_to_file:
177 if self.log_to_file:
178 # Start logging to the new log file
178 # Start logging to the new log file
179 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
179 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
180 logfile = os.path.join(log_dir, log_filename)
180 logfile = os.path.join(log_dir, log_filename)
181 open_log_file = open(logfile, 'w')
181 open_log_file = open(logfile, 'w')
182 else:
182 else:
183 open_log_file = None
183 open_log_file = None
184 if open_log_file is not None:
184 if open_log_file is not None:
185 while self.log.handlers:
185 while self.log.handlers:
186 self.log.removeHandler(self.log.handlers[0])
186 self.log.removeHandler(self.log.handlers[0])
187 self._log_handler = logging.StreamHandler(open_log_file)
187 self._log_handler = logging.StreamHandler(open_log_file)
188 self.log.addHandler(self._log_handler)
188 self.log.addHandler(self._log_handler)
189 else:
189 else:
190 self._log_handler = self.log.handlers[0]
190 self._log_handler = self.log.handlers[0]
191 # Add timestamps to log format:
191 # Add timestamps to log format:
192 self._log_formatter = LevelFormatter(self.log_format,
192 self._log_formatter = LevelFormatter(self.log_format,
193 datefmt=self.log_datefmt)
193 datefmt=self.log_datefmt)
194 self._log_handler.setFormatter(self._log_formatter)
194 self._log_handler.setFormatter(self._log_formatter)
195 # do not propagate log messages to root logger
195 # do not propagate log messages to root logger
196 # ipcluster app will sometimes print duplicate messages during shutdown
196 # ipcluster app will sometimes print duplicate messages during shutdown
197 # if this is 1 (default):
197 # if this is 1 (default):
198 self.log.propagate = False
198 self.log.propagate = False
199
199
200 def write_pid_file(self, overwrite=False):
200 def write_pid_file(self, overwrite=False):
201 """Create a .pid file in the pid_dir with my pid.
201 """Create a .pid file in the pid_dir with my pid.
202
202
203 This must be called after pre_construct, which sets `self.pid_dir`.
203 This must be called after pre_construct, which sets `self.pid_dir`.
204 This raises :exc:`PIDFileError` if the pid file exists already.
204 This raises :exc:`PIDFileError` if the pid file exists already.
205 """
205 """
206 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
206 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
207 if os.path.isfile(pid_file):
207 if os.path.isfile(pid_file):
208 pid = self.get_pid_from_file()
208 pid = self.get_pid_from_file()
209 if not overwrite:
209 if not overwrite:
210 raise PIDFileError(
210 raise PIDFileError(
211 'The pid file [%s] already exists. \nThis could mean that this '
211 'The pid file [%s] already exists. \nThis could mean that this '
212 'server is already running with [pid=%s].' % (pid_file, pid)
212 'server is already running with [pid=%s].' % (pid_file, pid)
213 )
213 )
214 with open(pid_file, 'w') as f:
214 with open(pid_file, 'w') as f:
215 self.log.info("Creating pid file: %s" % pid_file)
215 self.log.info("Creating pid file: %s" % pid_file)
216 f.write(repr(os.getpid())+'\n')
216 f.write(repr(os.getpid())+'\n')
217
217
218 def remove_pid_file(self):
218 def remove_pid_file(self):
219 """Remove the pid file.
219 """Remove the pid file.
220
220
221 This should be called at shutdown by registering a callback with
221 This should be called at shutdown by registering a callback with
222 :func:`reactor.addSystemEventTrigger`. This needs to return
222 :func:`reactor.addSystemEventTrigger`. This needs to return
223 ``None``.
223 ``None``.
224 """
224 """
225 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
225 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
226 if os.path.isfile(pid_file):
226 if os.path.isfile(pid_file):
227 try:
227 try:
228 self.log.info("Removing pid file: %s" % pid_file)
228 self.log.info("Removing pid file: %s" % pid_file)
229 os.remove(pid_file)
229 os.remove(pid_file)
230 except:
230 except:
231 self.log.warn("Error removing the pid file: %s" % pid_file)
231 self.log.warn("Error removing the pid file: %s" % pid_file)
232
232
233 def get_pid_from_file(self):
233 def get_pid_from_file(self):
234 """Get the pid from the pid file.
234 """Get the pid from the pid file.
235
235
236 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
236 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
237 """
237 """
238 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
238 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
239 if os.path.isfile(pid_file):
239 if os.path.isfile(pid_file):
240 with open(pid_file, 'r') as f:
240 with open(pid_file, 'r') as f:
241 s = f.read().strip()
241 s = f.read().strip()
242 try:
242 try:
243 pid = int(s)
243 pid = int(s)
244 except:
244 except:
245 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
245 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
246 return pid
246 return pid
247 else:
247 else:
248 raise PIDFileError('pid file not found: %s' % pid_file)
248 raise PIDFileError('pid file not found: %s' % pid_file)
249
249
250 def check_pid(self, pid):
250 def check_pid(self, pid):
251 if os.name == 'nt':
251 if os.name == 'nt':
252 try:
252 try:
253 import ctypes
253 import ctypes
254 # returns 0 if no such process (of ours) exists
254 # returns 0 if no such process (of ours) exists
255 # positive int otherwise
255 # positive int otherwise
256 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
256 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
257 except Exception:
257 except Exception:
258 self.log.warn(
258 self.log.warn(
259 "Could not determine whether pid %i is running via `OpenProcess`. "
259 "Could not determine whether pid %i is running via `OpenProcess`. "
260 " Making the likely assumption that it is."%pid
260 " Making the likely assumption that it is."%pid
261 )
261 )
262 return True
262 return True
263 return bool(p)
263 return bool(p)
264 else:
264 else:
265 try:
265 try:
266 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
266 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
267 output,_ = p.communicate()
267 output,_ = p.communicate()
268 except OSError:
268 except OSError:
269 self.log.warn(
269 self.log.warn(
270 "Could not determine whether pid %i is running via `ps x`. "
270 "Could not determine whether pid %i is running via `ps x`. "
271 " Making the likely assumption that it is."%pid
271 " Making the likely assumption that it is."%pid
272 )
272 )
273 return True
273 return True
274 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
274 pids = list(map(int, re.findall(r'^\W*\d+', output, re.MULTILINE)))
275 return pid in pids
275 return pid in pids
@@ -1,707 +1,707 b''
1 """AsyncResult objects for the client
1 """AsyncResult objects for the client
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 from __future__ import print_function
18 from __future__ import print_function
19
19
20 import sys
20 import sys
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23
23
24 from zmq import MessageTracker
24 from zmq import MessageTracker
25
25
26 from IPython.core.display import clear_output, display, display_pretty
26 from IPython.core.display import clear_output, display, display_pretty
27 from IPython.external.decorator import decorator
27 from IPython.external.decorator import decorator
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.utils.py3compat import string_types
29 from IPython.utils.py3compat import string_types
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Functions
32 # Functions
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 def _raw_text(s):
35 def _raw_text(s):
36 display_pretty(s, raw=True)
36 display_pretty(s, raw=True)
37
37
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39 # Classes
39 # Classes
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41
41
42 # global empty tracker that's always done:
42 # global empty tracker that's always done:
43 finished_tracker = MessageTracker()
43 finished_tracker = MessageTracker()
44
44
45 @decorator
45 @decorator
46 def check_ready(f, self, *args, **kwargs):
46 def check_ready(f, self, *args, **kwargs):
47 """Call spin() to sync state prior to calling the method."""
47 """Call spin() to sync state prior to calling the method."""
48 self.wait(0)
48 self.wait(0)
49 if not self._ready:
49 if not self._ready:
50 raise error.TimeoutError("result not ready")
50 raise error.TimeoutError("result not ready")
51 return f(self, *args, **kwargs)
51 return f(self, *args, **kwargs)
52
52
53 class AsyncResult(object):
53 class AsyncResult(object):
54 """Class for representing results of non-blocking calls.
54 """Class for representing results of non-blocking calls.
55
55
56 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
56 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
57 """
57 """
58
58
59 msg_ids = None
59 msg_ids = None
60 _targets = None
60 _targets = None
61 _tracker = None
61 _tracker = None
62 _single_result = False
62 _single_result = False
63
63
64 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
64 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
65 if isinstance(msg_ids, string_types):
65 if isinstance(msg_ids, string_types):
66 # always a list
66 # always a list
67 msg_ids = [msg_ids]
67 msg_ids = [msg_ids]
68 self._single_result = True
68 self._single_result = True
69 else:
69 else:
70 self._single_result = False
70 self._single_result = False
71 if tracker is None:
71 if tracker is None:
72 # default to always done
72 # default to always done
73 tracker = finished_tracker
73 tracker = finished_tracker
74 self._client = client
74 self._client = client
75 self.msg_ids = msg_ids
75 self.msg_ids = msg_ids
76 self._fname=fname
76 self._fname=fname
77 self._targets = targets
77 self._targets = targets
78 self._tracker = tracker
78 self._tracker = tracker
79
79
80 self._ready = False
80 self._ready = False
81 self._outputs_ready = False
81 self._outputs_ready = False
82 self._success = None
82 self._success = None
83 self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ]
83 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
84
84
85 def __repr__(self):
85 def __repr__(self):
86 if self._ready:
86 if self._ready:
87 return "<%s: finished>"%(self.__class__.__name__)
87 return "<%s: finished>"%(self.__class__.__name__)
88 else:
88 else:
89 return "<%s: %s>"%(self.__class__.__name__,self._fname)
89 return "<%s: %s>"%(self.__class__.__name__,self._fname)
90
90
91
91
92 def _reconstruct_result(self, res):
92 def _reconstruct_result(self, res):
93 """Reconstruct our result from actual result list (always a list)
93 """Reconstruct our result from actual result list (always a list)
94
94
95 Override me in subclasses for turning a list of results
95 Override me in subclasses for turning a list of results
96 into the expected form.
96 into the expected form.
97 """
97 """
98 if self._single_result:
98 if self._single_result:
99 return res[0]
99 return res[0]
100 else:
100 else:
101 return res
101 return res
102
102
103 def get(self, timeout=-1):
103 def get(self, timeout=-1):
104 """Return the result when it arrives.
104 """Return the result when it arrives.
105
105
106 If `timeout` is not ``None`` and the result does not arrive within
106 If `timeout` is not ``None`` and the result does not arrive within
107 `timeout` seconds then ``TimeoutError`` is raised. If the
107 `timeout` seconds then ``TimeoutError`` is raised. If the
108 remote call raised an exception then that exception will be reraised
108 remote call raised an exception then that exception will be reraised
109 by get() inside a `RemoteError`.
109 by get() inside a `RemoteError`.
110 """
110 """
111 if not self.ready():
111 if not self.ready():
112 self.wait(timeout)
112 self.wait(timeout)
113
113
114 if self._ready:
114 if self._ready:
115 if self._success:
115 if self._success:
116 return self._result
116 return self._result
117 else:
117 else:
118 raise self._exception
118 raise self._exception
119 else:
119 else:
120 raise error.TimeoutError("Result not ready.")
120 raise error.TimeoutError("Result not ready.")
121
121
122 def _check_ready(self):
122 def _check_ready(self):
123 if not self.ready():
123 if not self.ready():
124 raise error.TimeoutError("Result not ready.")
124 raise error.TimeoutError("Result not ready.")
125
125
126 def ready(self):
126 def ready(self):
127 """Return whether the call has completed."""
127 """Return whether the call has completed."""
128 if not self._ready:
128 if not self._ready:
129 self.wait(0)
129 self.wait(0)
130 elif not self._outputs_ready:
130 elif not self._outputs_ready:
131 self._wait_for_outputs(0)
131 self._wait_for_outputs(0)
132
132
133 return self._ready
133 return self._ready
134
134
135 def wait(self, timeout=-1):
135 def wait(self, timeout=-1):
136 """Wait until the result is available or until `timeout` seconds pass.
136 """Wait until the result is available or until `timeout` seconds pass.
137
137
138 This method always returns None.
138 This method always returns None.
139 """
139 """
140 if self._ready:
140 if self._ready:
141 self._wait_for_outputs(timeout)
141 self._wait_for_outputs(timeout)
142 return
142 return
143 self._ready = self._client.wait(self.msg_ids, timeout)
143 self._ready = self._client.wait(self.msg_ids, timeout)
144 if self._ready:
144 if self._ready:
145 try:
145 try:
146 results = map(self._client.results.get, self.msg_ids)
146 results = list(map(self._client.results.get, self.msg_ids))
147 self._result = results
147 self._result = results
148 if self._single_result:
148 if self._single_result:
149 r = results[0]
149 r = results[0]
150 if isinstance(r, Exception):
150 if isinstance(r, Exception):
151 raise r
151 raise r
152 else:
152 else:
153 results = error.collect_exceptions(results, self._fname)
153 results = error.collect_exceptions(results, self._fname)
154 self._result = self._reconstruct_result(results)
154 self._result = self._reconstruct_result(results)
155 except Exception as e:
155 except Exception as e:
156 self._exception = e
156 self._exception = e
157 self._success = False
157 self._success = False
158 else:
158 else:
159 self._success = True
159 self._success = True
160 finally:
160 finally:
161 if timeout is None or timeout < 0:
161 if timeout is None or timeout < 0:
162 # cutoff infinite wait at 10s
162 # cutoff infinite wait at 10s
163 timeout = 10
163 timeout = 10
164 self._wait_for_outputs(timeout)
164 self._wait_for_outputs(timeout)
165
165
166
166
167 def successful(self):
167 def successful(self):
168 """Return whether the call completed without raising an exception.
168 """Return whether the call completed without raising an exception.
169
169
170 Will raise ``AssertionError`` if the result is not ready.
170 Will raise ``AssertionError`` if the result is not ready.
171 """
171 """
172 assert self.ready()
172 assert self.ready()
173 return self._success
173 return self._success
174
174
175 #----------------------------------------------------------------
175 #----------------------------------------------------------------
176 # Extra methods not in mp.pool.AsyncResult
176 # Extra methods not in mp.pool.AsyncResult
177 #----------------------------------------------------------------
177 #----------------------------------------------------------------
178
178
179 def get_dict(self, timeout=-1):
179 def get_dict(self, timeout=-1):
180 """Get the results as a dict, keyed by engine_id.
180 """Get the results as a dict, keyed by engine_id.
181
181
182 timeout behavior is described in `get()`.
182 timeout behavior is described in `get()`.
183 """
183 """
184
184
185 results = self.get(timeout)
185 results = self.get(timeout)
186 if self._single_result:
186 if self._single_result:
187 results = [results]
187 results = [results]
188 engine_ids = [ md['engine_id'] for md in self._metadata ]
188 engine_ids = [ md['engine_id'] for md in self._metadata ]
189
189
190
190
191 rdict = {}
191 rdict = {}
192 for engine_id, result in zip(engine_ids, results):
192 for engine_id, result in zip(engine_ids, results):
193 if engine_id in rdict:
193 if engine_id in rdict:
194 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
194 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
195 engine_ids.count(engine_id), engine_id)
195 engine_ids.count(engine_id), engine_id)
196 )
196 )
197 else:
197 else:
198 rdict[engine_id] = result
198 rdict[engine_id] = result
199
199
200 return rdict
200 return rdict
201
201
202 @property
202 @property
203 def result(self):
203 def result(self):
204 """result property wrapper for `get(timeout=-1)`."""
204 """result property wrapper for `get(timeout=-1)`."""
205 return self.get()
205 return self.get()
206
206
207 # abbreviated alias:
207 # abbreviated alias:
208 r = result
208 r = result
209
209
210 @property
210 @property
211 def metadata(self):
211 def metadata(self):
212 """property for accessing execution metadata."""
212 """property for accessing execution metadata."""
213 if self._single_result:
213 if self._single_result:
214 return self._metadata[0]
214 return self._metadata[0]
215 else:
215 else:
216 return self._metadata
216 return self._metadata
217
217
218 @property
218 @property
219 def result_dict(self):
219 def result_dict(self):
220 """result property as a dict."""
220 """result property as a dict."""
221 return self.get_dict()
221 return self.get_dict()
222
222
223 def __dict__(self):
223 def __dict__(self):
224 return self.get_dict(0)
224 return self.get_dict(0)
225
225
226 def abort(self):
226 def abort(self):
227 """abort my tasks."""
227 """abort my tasks."""
228 assert not self.ready(), "Can't abort, I am already done!"
228 assert not self.ready(), "Can't abort, I am already done!"
229 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
229 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
230
230
231 @property
231 @property
232 def sent(self):
232 def sent(self):
233 """check whether my messages have been sent."""
233 """check whether my messages have been sent."""
234 return self._tracker.done
234 return self._tracker.done
235
235
236 def wait_for_send(self, timeout=-1):
236 def wait_for_send(self, timeout=-1):
237 """wait for pyzmq send to complete.
237 """wait for pyzmq send to complete.
238
238
239 This is necessary when sending arrays that you intend to edit in-place.
239 This is necessary when sending arrays that you intend to edit in-place.
240 `timeout` is in seconds, and will raise TimeoutError if it is reached
240 `timeout` is in seconds, and will raise TimeoutError if it is reached
241 before the send completes.
241 before the send completes.
242 """
242 """
243 return self._tracker.wait(timeout)
243 return self._tracker.wait(timeout)
244
244
245 #-------------------------------------
245 #-------------------------------------
246 # dict-access
246 # dict-access
247 #-------------------------------------
247 #-------------------------------------
248
248
249 def __getitem__(self, key):
249 def __getitem__(self, key):
250 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
250 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
251 """
251 """
252 if isinstance(key, int):
252 if isinstance(key, int):
253 self._check_ready()
253 self._check_ready()
254 return error.collect_exceptions([self._result[key]], self._fname)[0]
254 return error.collect_exceptions([self._result[key]], self._fname)[0]
255 elif isinstance(key, slice):
255 elif isinstance(key, slice):
256 self._check_ready()
256 self._check_ready()
257 return error.collect_exceptions(self._result[key], self._fname)
257 return error.collect_exceptions(self._result[key], self._fname)
258 elif isinstance(key, string_types):
258 elif isinstance(key, string_types):
259 # metadata proxy *does not* require that results are done
259 # metadata proxy *does not* require that results are done
260 self.wait(0)
260 self.wait(0)
261 values = [ md[key] for md in self._metadata ]
261 values = [ md[key] for md in self._metadata ]
262 if self._single_result:
262 if self._single_result:
263 return values[0]
263 return values[0]
264 else:
264 else:
265 return values
265 return values
266 else:
266 else:
267 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
267 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
268
268
269 def __getattr__(self, key):
269 def __getattr__(self, key):
270 """getattr maps to getitem for convenient attr access to metadata."""
270 """getattr maps to getitem for convenient attr access to metadata."""
271 try:
271 try:
272 return self.__getitem__(key)
272 return self.__getitem__(key)
273 except (error.TimeoutError, KeyError):
273 except (error.TimeoutError, KeyError):
274 raise AttributeError("%r object has no attribute %r"%(
274 raise AttributeError("%r object has no attribute %r"%(
275 self.__class__.__name__, key))
275 self.__class__.__name__, key))
276
276
277 # asynchronous iterator:
277 # asynchronous iterator:
278 def __iter__(self):
278 def __iter__(self):
279 if self._single_result:
279 if self._single_result:
280 raise TypeError("AsyncResults with a single result are not iterable.")
280 raise TypeError("AsyncResults with a single result are not iterable.")
281 try:
281 try:
282 rlist = self.get(0)
282 rlist = self.get(0)
283 except error.TimeoutError:
283 except error.TimeoutError:
284 # wait for each result individually
284 # wait for each result individually
285 for msg_id in self.msg_ids:
285 for msg_id in self.msg_ids:
286 ar = AsyncResult(self._client, msg_id, self._fname)
286 ar = AsyncResult(self._client, msg_id, self._fname)
287 yield ar.get()
287 yield ar.get()
288 else:
288 else:
289 # already done
289 # already done
290 for r in rlist:
290 for r in rlist:
291 yield r
291 yield r
292
292
293 def __len__(self):
293 def __len__(self):
294 return len(self.msg_ids)
294 return len(self.msg_ids)
295
295
296 #-------------------------------------
296 #-------------------------------------
297 # Sugar methods and attributes
297 # Sugar methods and attributes
298 #-------------------------------------
298 #-------------------------------------
299
299
300 def timedelta(self, start, end, start_key=min, end_key=max):
300 def timedelta(self, start, end, start_key=min, end_key=max):
301 """compute the difference between two sets of timestamps
301 """compute the difference between two sets of timestamps
302
302
303 The default behavior is to use the earliest of the first
303 The default behavior is to use the earliest of the first
304 and the latest of the second list, but this can be changed
304 and the latest of the second list, but this can be changed
305 by passing a different
305 by passing a different
306
306
307 Parameters
307 Parameters
308 ----------
308 ----------
309
309
310 start : one or more datetime objects (e.g. ar.submitted)
310 start : one or more datetime objects (e.g. ar.submitted)
311 end : one or more datetime objects (e.g. ar.received)
311 end : one or more datetime objects (e.g. ar.received)
312 start_key : callable
312 start_key : callable
313 Function to call on `start` to extract the relevant
313 Function to call on `start` to extract the relevant
314 entry [defalt: min]
314 entry [defalt: min]
315 end_key : callable
315 end_key : callable
316 Function to call on `end` to extract the relevant
316 Function to call on `end` to extract the relevant
317 entry [default: max]
317 entry [default: max]
318
318
319 Returns
319 Returns
320 -------
320 -------
321
321
322 dt : float
322 dt : float
323 The time elapsed (in seconds) between the two selected timestamps.
323 The time elapsed (in seconds) between the two selected timestamps.
324 """
324 """
325 if not isinstance(start, datetime):
325 if not isinstance(start, datetime):
326 # handle single_result AsyncResults, where ar.stamp is single object,
326 # handle single_result AsyncResults, where ar.stamp is single object,
327 # not a list
327 # not a list
328 start = start_key(start)
328 start = start_key(start)
329 if not isinstance(end, datetime):
329 if not isinstance(end, datetime):
330 # handle single_result AsyncResults, where ar.stamp is single object,
330 # handle single_result AsyncResults, where ar.stamp is single object,
331 # not a list
331 # not a list
332 end = end_key(end)
332 end = end_key(end)
333 return (end - start).total_seconds()
333 return (end - start).total_seconds()
334
334
335 @property
335 @property
336 def progress(self):
336 def progress(self):
337 """the number of tasks which have been completed at this point.
337 """the number of tasks which have been completed at this point.
338
338
339 Fractional progress would be given by 1.0 * ar.progress / len(ar)
339 Fractional progress would be given by 1.0 * ar.progress / len(ar)
340 """
340 """
341 self.wait(0)
341 self.wait(0)
342 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
342 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
343
343
344 @property
344 @property
345 def elapsed(self):
345 def elapsed(self):
346 """elapsed time since initial submission"""
346 """elapsed time since initial submission"""
347 if self.ready():
347 if self.ready():
348 return self.wall_time
348 return self.wall_time
349
349
350 now = submitted = datetime.now()
350 now = submitted = datetime.now()
351 for msg_id in self.msg_ids:
351 for msg_id in self.msg_ids:
352 if msg_id in self._client.metadata:
352 if msg_id in self._client.metadata:
353 stamp = self._client.metadata[msg_id]['submitted']
353 stamp = self._client.metadata[msg_id]['submitted']
354 if stamp and stamp < submitted:
354 if stamp and stamp < submitted:
355 submitted = stamp
355 submitted = stamp
356 return (now-submitted).total_seconds()
356 return (now-submitted).total_seconds()
357
357
358 @property
358 @property
359 @check_ready
359 @check_ready
360 def serial_time(self):
360 def serial_time(self):
361 """serial computation time of a parallel calculation
361 """serial computation time of a parallel calculation
362
362
363 Computed as the sum of (completed-started) of each task
363 Computed as the sum of (completed-started) of each task
364 """
364 """
365 t = 0
365 t = 0
366 for md in self._metadata:
366 for md in self._metadata:
367 t += (md['completed'] - md['started']).total_seconds()
367 t += (md['completed'] - md['started']).total_seconds()
368 return t
368 return t
369
369
370 @property
370 @property
371 @check_ready
371 @check_ready
372 def wall_time(self):
372 def wall_time(self):
373 """actual computation time of a parallel calculation
373 """actual computation time of a parallel calculation
374
374
375 Computed as the time between the latest `received` stamp
375 Computed as the time between the latest `received` stamp
376 and the earliest `submitted`.
376 and the earliest `submitted`.
377
377
378 Only reliable if Client was spinning/waiting when the task finished, because
378 Only reliable if Client was spinning/waiting when the task finished, because
379 the `received` timestamp is created when a result is pulled off of the zmq queue,
379 the `received` timestamp is created when a result is pulled off of the zmq queue,
380 which happens as a result of `client.spin()`.
380 which happens as a result of `client.spin()`.
381
381
382 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
382 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
383
383
384 """
384 """
385 return self.timedelta(self.submitted, self.received)
385 return self.timedelta(self.submitted, self.received)
386
386
387 def wait_interactive(self, interval=1., timeout=-1):
387 def wait_interactive(self, interval=1., timeout=-1):
388 """interactive wait, printing progress at regular intervals"""
388 """interactive wait, printing progress at regular intervals"""
389 if timeout is None:
389 if timeout is None:
390 timeout = -1
390 timeout = -1
391 N = len(self)
391 N = len(self)
392 tic = time.time()
392 tic = time.time()
393 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
393 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
394 self.wait(interval)
394 self.wait(interval)
395 clear_output(wait=True)
395 clear_output(wait=True)
396 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
396 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
397 sys.stdout.flush()
397 sys.stdout.flush()
398 print()
398 print()
399 print("done")
399 print("done")
400
400
401 def _republish_displaypub(self, content, eid):
401 def _republish_displaypub(self, content, eid):
402 """republish individual displaypub content dicts"""
402 """republish individual displaypub content dicts"""
403 try:
403 try:
404 ip = get_ipython()
404 ip = get_ipython()
405 except NameError:
405 except NameError:
406 # displaypub is meaningless outside IPython
406 # displaypub is meaningless outside IPython
407 return
407 return
408 md = content['metadata'] or {}
408 md = content['metadata'] or {}
409 md['engine'] = eid
409 md['engine'] = eid
410 ip.display_pub.publish(content['source'], content['data'], md)
410 ip.display_pub.publish(content['source'], content['data'], md)
411
411
412 def _display_stream(self, text, prefix='', file=None):
412 def _display_stream(self, text, prefix='', file=None):
413 if not text:
413 if not text:
414 # nothing to display
414 # nothing to display
415 return
415 return
416 if file is None:
416 if file is None:
417 file = sys.stdout
417 file = sys.stdout
418 end = '' if text.endswith('\n') else '\n'
418 end = '' if text.endswith('\n') else '\n'
419
419
420 multiline = text.count('\n') > int(text.endswith('\n'))
420 multiline = text.count('\n') > int(text.endswith('\n'))
421 if prefix and multiline and not text.startswith('\n'):
421 if prefix and multiline and not text.startswith('\n'):
422 prefix = prefix + '\n'
422 prefix = prefix + '\n'
423 print("%s%s" % (prefix, text), file=file, end=end)
423 print("%s%s" % (prefix, text), file=file, end=end)
424
424
425
425
426 def _display_single_result(self):
426 def _display_single_result(self):
427 self._display_stream(self.stdout)
427 self._display_stream(self.stdout)
428 self._display_stream(self.stderr, file=sys.stderr)
428 self._display_stream(self.stderr, file=sys.stderr)
429
429
430 try:
430 try:
431 get_ipython()
431 get_ipython()
432 except NameError:
432 except NameError:
433 # displaypub is meaningless outside IPython
433 # displaypub is meaningless outside IPython
434 return
434 return
435
435
436 for output in self.outputs:
436 for output in self.outputs:
437 self._republish_displaypub(output, self.engine_id)
437 self._republish_displaypub(output, self.engine_id)
438
438
439 if self.pyout is not None:
439 if self.pyout is not None:
440 display(self.get())
440 display(self.get())
441
441
442 def _wait_for_outputs(self, timeout=-1):
442 def _wait_for_outputs(self, timeout=-1):
443 """wait for the 'status=idle' message that indicates we have all outputs
443 """wait for the 'status=idle' message that indicates we have all outputs
444 """
444 """
445 if self._outputs_ready or not self._success:
445 if self._outputs_ready or not self._success:
446 # don't wait on errors
446 # don't wait on errors
447 return
447 return
448
448
449 # cast None to -1 for infinite timeout
449 # cast None to -1 for infinite timeout
450 if timeout is None:
450 if timeout is None:
451 timeout = -1
451 timeout = -1
452
452
453 tic = time.time()
453 tic = time.time()
454 while True:
454 while True:
455 self._client._flush_iopub(self._client._iopub_socket)
455 self._client._flush_iopub(self._client._iopub_socket)
456 self._outputs_ready = all(md['outputs_ready']
456 self._outputs_ready = all(md['outputs_ready']
457 for md in self._metadata)
457 for md in self._metadata)
458 if self._outputs_ready or \
458 if self._outputs_ready or \
459 (timeout >= 0 and time.time() > tic + timeout):
459 (timeout >= 0 and time.time() > tic + timeout):
460 break
460 break
461 time.sleep(0.01)
461 time.sleep(0.01)
462
462
463 @check_ready
463 @check_ready
464 def display_outputs(self, groupby="type"):
464 def display_outputs(self, groupby="type"):
465 """republish the outputs of the computation
465 """republish the outputs of the computation
466
466
467 Parameters
467 Parameters
468 ----------
468 ----------
469
469
470 groupby : str [default: type]
470 groupby : str [default: type]
471 if 'type':
471 if 'type':
472 Group outputs by type (show all stdout, then all stderr, etc.):
472 Group outputs by type (show all stdout, then all stderr, etc.):
473
473
474 [stdout:1] foo
474 [stdout:1] foo
475 [stdout:2] foo
475 [stdout:2] foo
476 [stderr:1] bar
476 [stderr:1] bar
477 [stderr:2] bar
477 [stderr:2] bar
478 if 'engine':
478 if 'engine':
479 Display outputs for each engine before moving on to the next:
479 Display outputs for each engine before moving on to the next:
480
480
481 [stdout:1] foo
481 [stdout:1] foo
482 [stderr:1] bar
482 [stderr:1] bar
483 [stdout:2] foo
483 [stdout:2] foo
484 [stderr:2] bar
484 [stderr:2] bar
485
485
486 if 'order':
486 if 'order':
487 Like 'type', but further collate individual displaypub
487 Like 'type', but further collate individual displaypub
488 outputs. This is meant for cases of each command producing
488 outputs. This is meant for cases of each command producing
489 several plots, and you would like to see all of the first
489 several plots, and you would like to see all of the first
490 plots together, then all of the second plots, and so on.
490 plots together, then all of the second plots, and so on.
491 """
491 """
492 if self._single_result:
492 if self._single_result:
493 self._display_single_result()
493 self._display_single_result()
494 return
494 return
495
495
496 stdouts = self.stdout
496 stdouts = self.stdout
497 stderrs = self.stderr
497 stderrs = self.stderr
498 pyouts = self.pyout
498 pyouts = self.pyout
499 output_lists = self.outputs
499 output_lists = self.outputs
500 results = self.get()
500 results = self.get()
501
501
502 targets = self.engine_id
502 targets = self.engine_id
503
503
504 if groupby == "engine":
504 if groupby == "engine":
505 for eid,stdout,stderr,outputs,r,pyout in zip(
505 for eid,stdout,stderr,outputs,r,pyout in zip(
506 targets, stdouts, stderrs, output_lists, results, pyouts
506 targets, stdouts, stderrs, output_lists, results, pyouts
507 ):
507 ):
508 self._display_stream(stdout, '[stdout:%i] ' % eid)
508 self._display_stream(stdout, '[stdout:%i] ' % eid)
509 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
509 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
510
510
511 try:
511 try:
512 get_ipython()
512 get_ipython()
513 except NameError:
513 except NameError:
514 # displaypub is meaningless outside IPython
514 # displaypub is meaningless outside IPython
515 return
515 return
516
516
517 if outputs or pyout is not None:
517 if outputs or pyout is not None:
518 _raw_text('[output:%i]' % eid)
518 _raw_text('[output:%i]' % eid)
519
519
520 for output in outputs:
520 for output in outputs:
521 self._republish_displaypub(output, eid)
521 self._republish_displaypub(output, eid)
522
522
523 if pyout is not None:
523 if pyout is not None:
524 display(r)
524 display(r)
525
525
526 elif groupby in ('type', 'order'):
526 elif groupby in ('type', 'order'):
527 # republish stdout:
527 # republish stdout:
528 for eid,stdout in zip(targets, stdouts):
528 for eid,stdout in zip(targets, stdouts):
529 self._display_stream(stdout, '[stdout:%i] ' % eid)
529 self._display_stream(stdout, '[stdout:%i] ' % eid)
530
530
531 # republish stderr:
531 # republish stderr:
532 for eid,stderr in zip(targets, stderrs):
532 for eid,stderr in zip(targets, stderrs):
533 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
533 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
534
534
535 try:
535 try:
536 get_ipython()
536 get_ipython()
537 except NameError:
537 except NameError:
538 # displaypub is meaningless outside IPython
538 # displaypub is meaningless outside IPython
539 return
539 return
540
540
541 if groupby == 'order':
541 if groupby == 'order':
542 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
542 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
543 N = max(len(outputs) for outputs in output_lists)
543 N = max(len(outputs) for outputs in output_lists)
544 for i in range(N):
544 for i in range(N):
545 for eid in targets:
545 for eid in targets:
546 outputs = output_dict[eid]
546 outputs = output_dict[eid]
547 if len(outputs) >= N:
547 if len(outputs) >= N:
548 _raw_text('[output:%i]' % eid)
548 _raw_text('[output:%i]' % eid)
549 self._republish_displaypub(outputs[i], eid)
549 self._republish_displaypub(outputs[i], eid)
550 else:
550 else:
551 # republish displaypub output
551 # republish displaypub output
552 for eid,outputs in zip(targets, output_lists):
552 for eid,outputs in zip(targets, output_lists):
553 if outputs:
553 if outputs:
554 _raw_text('[output:%i]' % eid)
554 _raw_text('[output:%i]' % eid)
555 for output in outputs:
555 for output in outputs:
556 self._republish_displaypub(output, eid)
556 self._republish_displaypub(output, eid)
557
557
558 # finally, add pyout:
558 # finally, add pyout:
559 for eid,r,pyout in zip(targets, results, pyouts):
559 for eid,r,pyout in zip(targets, results, pyouts):
560 if pyout is not None:
560 if pyout is not None:
561 display(r)
561 display(r)
562
562
563 else:
563 else:
564 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
564 raise ValueError("groupby must be one of 'type', 'engine', 'collate', not %r" % groupby)
565
565
566
566
567
567
568
568
569 class AsyncMapResult(AsyncResult):
569 class AsyncMapResult(AsyncResult):
570 """Class for representing results of non-blocking gathers.
570 """Class for representing results of non-blocking gathers.
571
571
572 This will properly reconstruct the gather.
572 This will properly reconstruct the gather.
573
573
574 This class is iterable at any time, and will wait on results as they come.
574 This class is iterable at any time, and will wait on results as they come.
575
575
576 If ordered=False, then the first results to arrive will come first, otherwise
576 If ordered=False, then the first results to arrive will come first, otherwise
577 results will be yielded in the order they were submitted.
577 results will be yielded in the order they were submitted.
578
578
579 """
579 """
580
580
581 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
581 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
582 AsyncResult.__init__(self, client, msg_ids, fname=fname)
582 AsyncResult.__init__(self, client, msg_ids, fname=fname)
583 self._mapObject = mapObject
583 self._mapObject = mapObject
584 self._single_result = False
584 self._single_result = False
585 self.ordered = ordered
585 self.ordered = ordered
586
586
587 def _reconstruct_result(self, res):
587 def _reconstruct_result(self, res):
588 """Perform the gather on the actual results."""
588 """Perform the gather on the actual results."""
589 return self._mapObject.joinPartitions(res)
589 return self._mapObject.joinPartitions(res)
590
590
591 # asynchronous iterator:
591 # asynchronous iterator:
592 def __iter__(self):
592 def __iter__(self):
593 it = self._ordered_iter if self.ordered else self._unordered_iter
593 it = self._ordered_iter if self.ordered else self._unordered_iter
594 for r in it():
594 for r in it():
595 yield r
595 yield r
596
596
597 # asynchronous ordered iterator:
597 # asynchronous ordered iterator:
598 def _ordered_iter(self):
598 def _ordered_iter(self):
599 """iterator for results *as they arrive*, preserving submission order."""
599 """iterator for results *as they arrive*, preserving submission order."""
600 try:
600 try:
601 rlist = self.get(0)
601 rlist = self.get(0)
602 except error.TimeoutError:
602 except error.TimeoutError:
603 # wait for each result individually
603 # wait for each result individually
604 for msg_id in self.msg_ids:
604 for msg_id in self.msg_ids:
605 ar = AsyncResult(self._client, msg_id, self._fname)
605 ar = AsyncResult(self._client, msg_id, self._fname)
606 rlist = ar.get()
606 rlist = ar.get()
607 try:
607 try:
608 for r in rlist:
608 for r in rlist:
609 yield r
609 yield r
610 except TypeError:
610 except TypeError:
611 # flattened, not a list
611 # flattened, not a list
612 # this could get broken by flattened data that returns iterables
612 # this could get broken by flattened data that returns iterables
613 # but most calls to map do not expose the `flatten` argument
613 # but most calls to map do not expose the `flatten` argument
614 yield rlist
614 yield rlist
615 else:
615 else:
616 # already done
616 # already done
617 for r in rlist:
617 for r in rlist:
618 yield r
618 yield r
619
619
620 # asynchronous unordered iterator:
620 # asynchronous unordered iterator:
621 def _unordered_iter(self):
621 def _unordered_iter(self):
622 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
622 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
623 try:
623 try:
624 rlist = self.get(0)
624 rlist = self.get(0)
625 except error.TimeoutError:
625 except error.TimeoutError:
626 pending = set(self.msg_ids)
626 pending = set(self.msg_ids)
627 while pending:
627 while pending:
628 try:
628 try:
629 self._client.wait(pending, 1e-3)
629 self._client.wait(pending, 1e-3)
630 except error.TimeoutError:
630 except error.TimeoutError:
631 # ignore timeout error, because that only means
631 # ignore timeout error, because that only means
632 # *some* jobs are outstanding
632 # *some* jobs are outstanding
633 pass
633 pass
634 # update ready set with those no longer outstanding:
634 # update ready set with those no longer outstanding:
635 ready = pending.difference(self._client.outstanding)
635 ready = pending.difference(self._client.outstanding)
636 # update pending to exclude those that are finished
636 # update pending to exclude those that are finished
637 pending = pending.difference(ready)
637 pending = pending.difference(ready)
638 while ready:
638 while ready:
639 msg_id = ready.pop()
639 msg_id = ready.pop()
640 ar = AsyncResult(self._client, msg_id, self._fname)
640 ar = AsyncResult(self._client, msg_id, self._fname)
641 rlist = ar.get()
641 rlist = ar.get()
642 try:
642 try:
643 for r in rlist:
643 for r in rlist:
644 yield r
644 yield r
645 except TypeError:
645 except TypeError:
646 # flattened, not a list
646 # flattened, not a list
647 # this could get broken by flattened data that returns iterables
647 # this could get broken by flattened data that returns iterables
648 # but most calls to map do not expose the `flatten` argument
648 # but most calls to map do not expose the `flatten` argument
649 yield rlist
649 yield rlist
650 else:
650 else:
651 # already done
651 # already done
652 for r in rlist:
652 for r in rlist:
653 yield r
653 yield r
654
654
655
655
656 class AsyncHubResult(AsyncResult):
656 class AsyncHubResult(AsyncResult):
657 """Class to wrap pending results that must be requested from the Hub.
657 """Class to wrap pending results that must be requested from the Hub.
658
658
659 Note that waiting/polling on these objects requires polling the Hubover the network,
659 Note that waiting/polling on these objects requires polling the Hubover the network,
660 so use `AsyncHubResult.wait()` sparingly.
660 so use `AsyncHubResult.wait()` sparingly.
661 """
661 """
662
662
663 def _wait_for_outputs(self, timeout=-1):
663 def _wait_for_outputs(self, timeout=-1):
664 """no-op, because HubResults are never incomplete"""
664 """no-op, because HubResults are never incomplete"""
665 self._outputs_ready = True
665 self._outputs_ready = True
666
666
667 def wait(self, timeout=-1):
667 def wait(self, timeout=-1):
668 """wait for result to complete."""
668 """wait for result to complete."""
669 start = time.time()
669 start = time.time()
670 if self._ready:
670 if self._ready:
671 return
671 return
672 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
672 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
673 local_ready = self._client.wait(local_ids, timeout)
673 local_ready = self._client.wait(local_ids, timeout)
674 if local_ready:
674 if local_ready:
675 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
675 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
676 if not remote_ids:
676 if not remote_ids:
677 self._ready = True
677 self._ready = True
678 else:
678 else:
679 rdict = self._client.result_status(remote_ids, status_only=False)
679 rdict = self._client.result_status(remote_ids, status_only=False)
680 pending = rdict['pending']
680 pending = rdict['pending']
681 while pending and (timeout < 0 or time.time() < start+timeout):
681 while pending and (timeout < 0 or time.time() < start+timeout):
682 rdict = self._client.result_status(remote_ids, status_only=False)
682 rdict = self._client.result_status(remote_ids, status_only=False)
683 pending = rdict['pending']
683 pending = rdict['pending']
684 if pending:
684 if pending:
685 time.sleep(0.1)
685 time.sleep(0.1)
686 if not pending:
686 if not pending:
687 self._ready = True
687 self._ready = True
688 if self._ready:
688 if self._ready:
689 try:
689 try:
690 results = map(self._client.results.get, self.msg_ids)
690 results = list(map(self._client.results.get, self.msg_ids))
691 self._result = results
691 self._result = results
692 if self._single_result:
692 if self._single_result:
693 r = results[0]
693 r = results[0]
694 if isinstance(r, Exception):
694 if isinstance(r, Exception):
695 raise r
695 raise r
696 else:
696 else:
697 results = error.collect_exceptions(results, self._fname)
697 results = error.collect_exceptions(results, self._fname)
698 self._result = self._reconstruct_result(results)
698 self._result = self._reconstruct_result(results)
699 except Exception as e:
699 except Exception as e:
700 self._exception = e
700 self._exception = e
701 self._success = False
701 self._success = False
702 else:
702 else:
703 self._success = True
703 self._success = True
704 finally:
704 finally:
705 self._metadata = map(self._client.metadata.get, self.msg_ids)
705 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
706
706
707 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
707 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1854 +1,1855 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 from __future__ import print_function
7 from __future__ import print_function
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import os
19 import os
20 import json
20 import json
21 import sys
21 import sys
22 from threading import Thread, Event
22 from threading import Thread, Event
23 import time
23 import time
24 import warnings
24 import warnings
25 from datetime import datetime
25 from datetime import datetime
26 from getpass import getpass
26 from getpass import getpass
27 from pprint import pprint
27 from pprint import pprint
28
28
29 pjoin = os.path.join
29 pjoin = os.path.join
30
30
31 import zmq
31 import zmq
32 # from zmq.eventloop import ioloop, zmqstream
32 # from zmq.eventloop import ioloop, zmqstream
33
33
34 from IPython.config.configurable import MultipleInstanceError
34 from IPython.config.configurable import MultipleInstanceError
35 from IPython.core.application import BaseIPythonApplication
35 from IPython.core.application import BaseIPythonApplication
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37
37
38 from IPython.utils.capture import RichOutput
38 from IPython.utils.capture import RichOutput
39 from IPython.utils.coloransi import TermColors
39 from IPython.utils.coloransi import TermColors
40 from IPython.utils.jsonutil import rekey
40 from IPython.utils.jsonutil import rekey
41 from IPython.utils.localinterfaces import localhost, is_local_ip
41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 from IPython.utils.path import get_ipython_dir
42 from IPython.utils.path import get_ipython_dir
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 Dict, List, Bool, Set, Any)
45 Dict, List, Bool, Set, Any)
46 from IPython.external.decorator import decorator
46 from IPython.external.decorator import decorator
47 from IPython.external.ssh import tunnel
47 from IPython.external.ssh import tunnel
48
48
49 from IPython.parallel import Reference
49 from IPython.parallel import Reference
50 from IPython.parallel import error
50 from IPython.parallel import error
51 from IPython.parallel import util
51 from IPython.parallel import util
52
52
53 from IPython.kernel.zmq.session import Session, Message
53 from IPython.kernel.zmq.session import Session, Message
54 from IPython.kernel.zmq import serialize
54 from IPython.kernel.zmq import serialize
55
55
56 from .asyncresult import AsyncResult, AsyncHubResult
56 from .asyncresult import AsyncResult, AsyncHubResult
57 from .view import DirectView, LoadBalancedView
57 from .view import DirectView, LoadBalancedView
58
58
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60 # Decorators for Client methods
60 # Decorators for Client methods
61 #--------------------------------------------------------------------------
61 #--------------------------------------------------------------------------
62
62
63 @decorator
63 @decorator
64 def spin_first(f, self, *args, **kwargs):
64 def spin_first(f, self, *args, **kwargs):
65 """Call spin() to sync state prior to calling the method."""
65 """Call spin() to sync state prior to calling the method."""
66 self.spin()
66 self.spin()
67 return f(self, *args, **kwargs)
67 return f(self, *args, **kwargs)
68
68
69
69
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71 # Classes
71 # Classes
72 #--------------------------------------------------------------------------
72 #--------------------------------------------------------------------------
73
73
74
74
75 class ExecuteReply(RichOutput):
75 class ExecuteReply(RichOutput):
76 """wrapper for finished Execute results"""
76 """wrapper for finished Execute results"""
77 def __init__(self, msg_id, content, metadata):
77 def __init__(self, msg_id, content, metadata):
78 self.msg_id = msg_id
78 self.msg_id = msg_id
79 self._content = content
79 self._content = content
80 self.execution_count = content['execution_count']
80 self.execution_count = content['execution_count']
81 self.metadata = metadata
81 self.metadata = metadata
82
82
83 # RichOutput overrides
83 # RichOutput overrides
84
84
85 @property
85 @property
86 def source(self):
86 def source(self):
87 pyout = self.metadata['pyout']
87 pyout = self.metadata['pyout']
88 if pyout:
88 if pyout:
89 return pyout.get('source', '')
89 return pyout.get('source', '')
90
90
91 @property
91 @property
92 def data(self):
92 def data(self):
93 pyout = self.metadata['pyout']
93 pyout = self.metadata['pyout']
94 if pyout:
94 if pyout:
95 return pyout.get('data', {})
95 return pyout.get('data', {})
96
96
97 @property
97 @property
98 def _metadata(self):
98 def _metadata(self):
99 pyout = self.metadata['pyout']
99 pyout = self.metadata['pyout']
100 if pyout:
100 if pyout:
101 return pyout.get('metadata', {})
101 return pyout.get('metadata', {})
102
102
103 def display(self):
103 def display(self):
104 from IPython.display import publish_display_data
104 from IPython.display import publish_display_data
105 publish_display_data(self.source, self.data, self.metadata)
105 publish_display_data(self.source, self.data, self.metadata)
106
106
107 def _repr_mime_(self, mime):
107 def _repr_mime_(self, mime):
108 if mime not in self.data:
108 if mime not in self.data:
109 return
109 return
110 data = self.data[mime]
110 data = self.data[mime]
111 if mime in self._metadata:
111 if mime in self._metadata:
112 return data, self._metadata[mime]
112 return data, self._metadata[mime]
113 else:
113 else:
114 return data
114 return data
115
115
116 def __getitem__(self, key):
116 def __getitem__(self, key):
117 return self.metadata[key]
117 return self.metadata[key]
118
118
119 def __getattr__(self, key):
119 def __getattr__(self, key):
120 if key not in self.metadata:
120 if key not in self.metadata:
121 raise AttributeError(key)
121 raise AttributeError(key)
122 return self.metadata[key]
122 return self.metadata[key]
123
123
124 def __repr__(self):
124 def __repr__(self):
125 pyout = self.metadata['pyout'] or {'data':{}}
125 pyout = self.metadata['pyout'] or {'data':{}}
126 text_out = pyout['data'].get('text/plain', '')
126 text_out = pyout['data'].get('text/plain', '')
127 if len(text_out) > 32:
127 if len(text_out) > 32:
128 text_out = text_out[:29] + '...'
128 text_out = text_out[:29] + '...'
129
129
130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
131
131
132 def _repr_pretty_(self, p, cycle):
132 def _repr_pretty_(self, p, cycle):
133 pyout = self.metadata['pyout'] or {'data':{}}
133 pyout = self.metadata['pyout'] or {'data':{}}
134 text_out = pyout['data'].get('text/plain', '')
134 text_out = pyout['data'].get('text/plain', '')
135
135
136 if not text_out:
136 if not text_out:
137 return
137 return
138
138
139 try:
139 try:
140 ip = get_ipython()
140 ip = get_ipython()
141 except NameError:
141 except NameError:
142 colors = "NoColor"
142 colors = "NoColor"
143 else:
143 else:
144 colors = ip.colors
144 colors = ip.colors
145
145
146 if colors == "NoColor":
146 if colors == "NoColor":
147 out = normal = ""
147 out = normal = ""
148 else:
148 else:
149 out = TermColors.Red
149 out = TermColors.Red
150 normal = TermColors.Normal
150 normal = TermColors.Normal
151
151
152 if '\n' in text_out and not text_out.startswith('\n'):
152 if '\n' in text_out and not text_out.startswith('\n'):
153 # add newline for multiline reprs
153 # add newline for multiline reprs
154 text_out = '\n' + text_out
154 text_out = '\n' + text_out
155
155
156 p.text(
156 p.text(
157 out + u'Out[%i:%i]: ' % (
157 out + u'Out[%i:%i]: ' % (
158 self.metadata['engine_id'], self.execution_count
158 self.metadata['engine_id'], self.execution_count
159 ) + normal + text_out
159 ) + normal + text_out
160 )
160 )
161
161
162
162
163 class Metadata(dict):
163 class Metadata(dict):
164 """Subclass of dict for initializing metadata values.
164 """Subclass of dict for initializing metadata values.
165
165
166 Attribute access works on keys.
166 Attribute access works on keys.
167
167
168 These objects have a strict set of keys - errors will raise if you try
168 These objects have a strict set of keys - errors will raise if you try
169 to add new keys.
169 to add new keys.
170 """
170 """
171 def __init__(self, *args, **kwargs):
171 def __init__(self, *args, **kwargs):
172 dict.__init__(self)
172 dict.__init__(self)
173 md = {'msg_id' : None,
173 md = {'msg_id' : None,
174 'submitted' : None,
174 'submitted' : None,
175 'started' : None,
175 'started' : None,
176 'completed' : None,
176 'completed' : None,
177 'received' : None,
177 'received' : None,
178 'engine_uuid' : None,
178 'engine_uuid' : None,
179 'engine_id' : None,
179 'engine_id' : None,
180 'follow' : None,
180 'follow' : None,
181 'after' : None,
181 'after' : None,
182 'status' : None,
182 'status' : None,
183
183
184 'pyin' : None,
184 'pyin' : None,
185 'pyout' : None,
185 'pyout' : None,
186 'pyerr' : None,
186 'pyerr' : None,
187 'stdout' : '',
187 'stdout' : '',
188 'stderr' : '',
188 'stderr' : '',
189 'outputs' : [],
189 'outputs' : [],
190 'data': {},
190 'data': {},
191 'outputs_ready' : False,
191 'outputs_ready' : False,
192 }
192 }
193 self.update(md)
193 self.update(md)
194 self.update(dict(*args, **kwargs))
194 self.update(dict(*args, **kwargs))
195
195
196 def __getattr__(self, key):
196 def __getattr__(self, key):
197 """getattr aliased to getitem"""
197 """getattr aliased to getitem"""
198 if key in self:
198 if key in self:
199 return self[key]
199 return self[key]
200 else:
200 else:
201 raise AttributeError(key)
201 raise AttributeError(key)
202
202
203 def __setattr__(self, key, value):
203 def __setattr__(self, key, value):
204 """setattr aliased to setitem, with strict"""
204 """setattr aliased to setitem, with strict"""
205 if key in self:
205 if key in self:
206 self[key] = value
206 self[key] = value
207 else:
207 else:
208 raise AttributeError(key)
208 raise AttributeError(key)
209
209
210 def __setitem__(self, key, value):
210 def __setitem__(self, key, value):
211 """strict static key enforcement"""
211 """strict static key enforcement"""
212 if key in self:
212 if key in self:
213 dict.__setitem__(self, key, value)
213 dict.__setitem__(self, key, value)
214 else:
214 else:
215 raise KeyError(key)
215 raise KeyError(key)
216
216
217
217
218 class Client(HasTraits):
218 class Client(HasTraits):
219 """A semi-synchronous client to the IPython ZMQ cluster
219 """A semi-synchronous client to the IPython ZMQ cluster
220
220
221 Parameters
221 Parameters
222 ----------
222 ----------
223
223
224 url_file : str/unicode; path to ipcontroller-client.json
224 url_file : str/unicode; path to ipcontroller-client.json
225 This JSON file should contain all the information needed to connect to a cluster,
225 This JSON file should contain all the information needed to connect to a cluster,
226 and is likely the only argument needed.
226 and is likely the only argument needed.
227 Connection information for the Hub's registration. If a json connector
227 Connection information for the Hub's registration. If a json connector
228 file is given, then likely no further configuration is necessary.
228 file is given, then likely no further configuration is necessary.
229 [Default: use profile]
229 [Default: use profile]
230 profile : bytes
230 profile : bytes
231 The name of the Cluster profile to be used to find connector information.
231 The name of the Cluster profile to be used to find connector information.
232 If run from an IPython application, the default profile will be the same
232 If run from an IPython application, the default profile will be the same
233 as the running application, otherwise it will be 'default'.
233 as the running application, otherwise it will be 'default'.
234 cluster_id : str
234 cluster_id : str
235 String id to added to runtime files, to prevent name collisions when using
235 String id to added to runtime files, to prevent name collisions when using
236 multiple clusters with a single profile simultaneously.
236 multiple clusters with a single profile simultaneously.
237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
238 Since this is text inserted into filenames, typical recommendations apply:
238 Since this is text inserted into filenames, typical recommendations apply:
239 Simple character strings are ideal, and spaces are not recommended (but
239 Simple character strings are ideal, and spaces are not recommended (but
240 should generally work)
240 should generally work)
241 context : zmq.Context
241 context : zmq.Context
242 Pass an existing zmq.Context instance, otherwise the client will create its own.
242 Pass an existing zmq.Context instance, otherwise the client will create its own.
243 debug : bool
243 debug : bool
244 flag for lots of message printing for debug purposes
244 flag for lots of message printing for debug purposes
245 timeout : int/float
245 timeout : int/float
246 time (in seconds) to wait for connection replies from the Hub
246 time (in seconds) to wait for connection replies from the Hub
247 [Default: 10]
247 [Default: 10]
248
248
249 #-------------- session related args ----------------
249 #-------------- session related args ----------------
250
250
251 config : Config object
251 config : Config object
252 If specified, this will be relayed to the Session for configuration
252 If specified, this will be relayed to the Session for configuration
253 username : str
253 username : str
254 set username for the session object
254 set username for the session object
255
255
256 #-------------- ssh related args ----------------
256 #-------------- ssh related args ----------------
257 # These are args for configuring the ssh tunnel to be used
257 # These are args for configuring the ssh tunnel to be used
258 # credentials are used to forward connections over ssh to the Controller
258 # credentials are used to forward connections over ssh to the Controller
259 # Note that the ip given in `addr` needs to be relative to sshserver
259 # Note that the ip given in `addr` needs to be relative to sshserver
260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
261 # and set sshserver as the same machine the Controller is on. However,
261 # and set sshserver as the same machine the Controller is on. However,
262 # the only requirement is that sshserver is able to see the Controller
262 # the only requirement is that sshserver is able to see the Controller
263 # (i.e. is within the same trusted network).
263 # (i.e. is within the same trusted network).
264
264
265 sshserver : str
265 sshserver : str
266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
267 If keyfile or password is specified, and this is not, it will default to
267 If keyfile or password is specified, and this is not, it will default to
268 the ip given in addr.
268 the ip given in addr.
269 sshkey : str; path to ssh private key file
269 sshkey : str; path to ssh private key file
270 This specifies a key to be used in ssh login, default None.
270 This specifies a key to be used in ssh login, default None.
271 Regular default ssh keys will be used without specifying this argument.
271 Regular default ssh keys will be used without specifying this argument.
272 password : str
272 password : str
273 Your ssh password to sshserver. Note that if this is left None,
273 Your ssh password to sshserver. Note that if this is left None,
274 you will be prompted for it if passwordless key based login is unavailable.
274 you will be prompted for it if passwordless key based login is unavailable.
275 paramiko : bool
275 paramiko : bool
276 flag for whether to use paramiko instead of shell ssh for tunneling.
276 flag for whether to use paramiko instead of shell ssh for tunneling.
277 [default: True on win32, False else]
277 [default: True on win32, False else]
278
278
279
279
280 Attributes
280 Attributes
281 ----------
281 ----------
282
282
283 ids : list of int engine IDs
283 ids : list of int engine IDs
284 requesting the ids attribute always synchronizes
284 requesting the ids attribute always synchronizes
285 the registration state. To request ids without synchronization,
285 the registration state. To request ids without synchronization,
286 use semi-private _ids attributes.
286 use semi-private _ids attributes.
287
287
288 history : list of msg_ids
288 history : list of msg_ids
289 a list of msg_ids, keeping track of all the execution
289 a list of msg_ids, keeping track of all the execution
290 messages you have submitted in order.
290 messages you have submitted in order.
291
291
292 outstanding : set of msg_ids
292 outstanding : set of msg_ids
293 a set of msg_ids that have been submitted, but whose
293 a set of msg_ids that have been submitted, but whose
294 results have not yet been received.
294 results have not yet been received.
295
295
296 results : dict
296 results : dict
297 a dict of all our results, keyed by msg_id
297 a dict of all our results, keyed by msg_id
298
298
299 block : bool
299 block : bool
300 determines default behavior when block not specified
300 determines default behavior when block not specified
301 in execution methods
301 in execution methods
302
302
303 Methods
303 Methods
304 -------
304 -------
305
305
306 spin
306 spin
307 flushes incoming results and registration state changes
307 flushes incoming results and registration state changes
308 control methods spin, and requesting `ids` also ensures up to date
308 control methods spin, and requesting `ids` also ensures up to date
309
309
310 wait
310 wait
311 wait on one or more msg_ids
311 wait on one or more msg_ids
312
312
313 execution methods
313 execution methods
314 apply
314 apply
315 legacy: execute, run
315 legacy: execute, run
316
316
317 data movement
317 data movement
318 push, pull, scatter, gather
318 push, pull, scatter, gather
319
319
320 query methods
320 query methods
321 queue_status, get_result, purge, result_status
321 queue_status, get_result, purge, result_status
322
322
323 control methods
323 control methods
324 abort, shutdown
324 abort, shutdown
325
325
326 """
326 """
327
327
328
328
329 block = Bool(False)
329 block = Bool(False)
330 outstanding = Set()
330 outstanding = Set()
331 results = Instance('collections.defaultdict', (dict,))
331 results = Instance('collections.defaultdict', (dict,))
332 metadata = Instance('collections.defaultdict', (Metadata,))
332 metadata = Instance('collections.defaultdict', (Metadata,))
333 history = List()
333 history = List()
334 debug = Bool(False)
334 debug = Bool(False)
335 _spin_thread = Any()
335 _spin_thread = Any()
336 _stop_spinning = Any()
336 _stop_spinning = Any()
337
337
338 profile=Unicode()
338 profile=Unicode()
339 def _profile_default(self):
339 def _profile_default(self):
340 if BaseIPythonApplication.initialized():
340 if BaseIPythonApplication.initialized():
341 # an IPython app *might* be running, try to get its profile
341 # an IPython app *might* be running, try to get its profile
342 try:
342 try:
343 return BaseIPythonApplication.instance().profile
343 return BaseIPythonApplication.instance().profile
344 except (AttributeError, MultipleInstanceError):
344 except (AttributeError, MultipleInstanceError):
345 # could be a *different* subclass of config.Application,
345 # could be a *different* subclass of config.Application,
346 # which would raise one of these two errors.
346 # which would raise one of these two errors.
347 return u'default'
347 return u'default'
348 else:
348 else:
349 return u'default'
349 return u'default'
350
350
351
351
352 _outstanding_dict = Instance('collections.defaultdict', (set,))
352 _outstanding_dict = Instance('collections.defaultdict', (set,))
353 _ids = List()
353 _ids = List()
354 _connected=Bool(False)
354 _connected=Bool(False)
355 _ssh=Bool(False)
355 _ssh=Bool(False)
356 _context = Instance('zmq.Context')
356 _context = Instance('zmq.Context')
357 _config = Dict()
357 _config = Dict()
358 _engines=Instance(util.ReverseDict, (), {})
358 _engines=Instance(util.ReverseDict, (), {})
359 # _hub_socket=Instance('zmq.Socket')
359 # _hub_socket=Instance('zmq.Socket')
360 _query_socket=Instance('zmq.Socket')
360 _query_socket=Instance('zmq.Socket')
361 _control_socket=Instance('zmq.Socket')
361 _control_socket=Instance('zmq.Socket')
362 _iopub_socket=Instance('zmq.Socket')
362 _iopub_socket=Instance('zmq.Socket')
363 _notification_socket=Instance('zmq.Socket')
363 _notification_socket=Instance('zmq.Socket')
364 _mux_socket=Instance('zmq.Socket')
364 _mux_socket=Instance('zmq.Socket')
365 _task_socket=Instance('zmq.Socket')
365 _task_socket=Instance('zmq.Socket')
366 _task_scheme=Unicode()
366 _task_scheme=Unicode()
367 _closed = False
367 _closed = False
368 _ignored_control_replies=Integer(0)
368 _ignored_control_replies=Integer(0)
369 _ignored_hub_replies=Integer(0)
369 _ignored_hub_replies=Integer(0)
370
370
371 def __new__(self, *args, **kw):
371 def __new__(self, *args, **kw):
372 # don't raise on positional args
372 # don't raise on positional args
373 return HasTraits.__new__(self, **kw)
373 return HasTraits.__new__(self, **kw)
374
374
375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
376 context=None, debug=False,
376 context=None, debug=False,
377 sshserver=None, sshkey=None, password=None, paramiko=None,
377 sshserver=None, sshkey=None, password=None, paramiko=None,
378 timeout=10, cluster_id=None, **extra_args
378 timeout=10, cluster_id=None, **extra_args
379 ):
379 ):
380 if profile:
380 if profile:
381 super(Client, self).__init__(debug=debug, profile=profile)
381 super(Client, self).__init__(debug=debug, profile=profile)
382 else:
382 else:
383 super(Client, self).__init__(debug=debug)
383 super(Client, self).__init__(debug=debug)
384 if context is None:
384 if context is None:
385 context = zmq.Context.instance()
385 context = zmq.Context.instance()
386 self._context = context
386 self._context = context
387 self._stop_spinning = Event()
387 self._stop_spinning = Event()
388
388
389 if 'url_or_file' in extra_args:
389 if 'url_or_file' in extra_args:
390 url_file = extra_args['url_or_file']
390 url_file = extra_args['url_or_file']
391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
392
392
393 if url_file and util.is_url(url_file):
393 if url_file and util.is_url(url_file):
394 raise ValueError("single urls cannot be specified, url-files must be used.")
394 raise ValueError("single urls cannot be specified, url-files must be used.")
395
395
396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
397
397
398 if self._cd is not None:
398 if self._cd is not None:
399 if url_file is None:
399 if url_file is None:
400 if not cluster_id:
400 if not cluster_id:
401 client_json = 'ipcontroller-client.json'
401 client_json = 'ipcontroller-client.json'
402 else:
402 else:
403 client_json = 'ipcontroller-%s-client.json' % cluster_id
403 client_json = 'ipcontroller-%s-client.json' % cluster_id
404 url_file = pjoin(self._cd.security_dir, client_json)
404 url_file = pjoin(self._cd.security_dir, client_json)
405 if url_file is None:
405 if url_file is None:
406 raise ValueError(
406 raise ValueError(
407 "I can't find enough information to connect to a hub!"
407 "I can't find enough information to connect to a hub!"
408 " Please specify at least one of url_file or profile."
408 " Please specify at least one of url_file or profile."
409 )
409 )
410
410
411 with open(url_file) as f:
411 with open(url_file) as f:
412 cfg = json.load(f)
412 cfg = json.load(f)
413
413
414 self._task_scheme = cfg['task_scheme']
414 self._task_scheme = cfg['task_scheme']
415
415
416 # sync defaults from args, json:
416 # sync defaults from args, json:
417 if sshserver:
417 if sshserver:
418 cfg['ssh'] = sshserver
418 cfg['ssh'] = sshserver
419
419
420 location = cfg.setdefault('location', None)
420 location = cfg.setdefault('location', None)
421
421
422 proto,addr = cfg['interface'].split('://')
422 proto,addr = cfg['interface'].split('://')
423 addr = util.disambiguate_ip_address(addr, location)
423 addr = util.disambiguate_ip_address(addr, location)
424 cfg['interface'] = "%s://%s" % (proto, addr)
424 cfg['interface'] = "%s://%s" % (proto, addr)
425
425
426 # turn interface,port into full urls:
426 # turn interface,port into full urls:
427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
429
429
430 url = cfg['registration']
430 url = cfg['registration']
431
431
432 if location is not None and addr == localhost():
432 if location is not None and addr == localhost():
433 # location specified, and connection is expected to be local
433 # location specified, and connection is expected to be local
434 if not is_local_ip(location) and not sshserver:
434 if not is_local_ip(location) and not sshserver:
435 # load ssh from JSON *only* if the controller is not on
435 # load ssh from JSON *only* if the controller is not on
436 # this machine
436 # this machine
437 sshserver=cfg['ssh']
437 sshserver=cfg['ssh']
438 if not is_local_ip(location) and not sshserver:
438 if not is_local_ip(location) and not sshserver:
439 # warn if no ssh specified, but SSH is probably needed
439 # warn if no ssh specified, but SSH is probably needed
440 # This is only a warning, because the most likely cause
440 # This is only a warning, because the most likely cause
441 # is a local Controller on a laptop whose IP is dynamic
441 # is a local Controller on a laptop whose IP is dynamic
442 warnings.warn("""
442 warnings.warn("""
443 Controller appears to be listening on localhost, but not on this machine.
443 Controller appears to be listening on localhost, but not on this machine.
444 If this is true, you should specify Client(...,sshserver='you@%s')
444 If this is true, you should specify Client(...,sshserver='you@%s')
445 or instruct your controller to listen on an external IP."""%location,
445 or instruct your controller to listen on an external IP."""%location,
446 RuntimeWarning)
446 RuntimeWarning)
447 elif not sshserver:
447 elif not sshserver:
448 # otherwise sync with cfg
448 # otherwise sync with cfg
449 sshserver = cfg['ssh']
449 sshserver = cfg['ssh']
450
450
451 self._config = cfg
451 self._config = cfg
452
452
453 self._ssh = bool(sshserver or sshkey or password)
453 self._ssh = bool(sshserver or sshkey or password)
454 if self._ssh and sshserver is None:
454 if self._ssh and sshserver is None:
455 # default to ssh via localhost
455 # default to ssh via localhost
456 sshserver = addr
456 sshserver = addr
457 if self._ssh and password is None:
457 if self._ssh and password is None:
458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
459 password=False
459 password=False
460 else:
460 else:
461 password = getpass("SSH Password for %s: "%sshserver)
461 password = getpass("SSH Password for %s: "%sshserver)
462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
463
463
464 # configure and construct the session
464 # configure and construct the session
465 try:
465 try:
466 extra_args['packer'] = cfg['pack']
466 extra_args['packer'] = cfg['pack']
467 extra_args['unpacker'] = cfg['unpack']
467 extra_args['unpacker'] = cfg['unpack']
468 extra_args['key'] = cast_bytes(cfg['key'])
468 extra_args['key'] = cast_bytes(cfg['key'])
469 extra_args['signature_scheme'] = cfg['signature_scheme']
469 extra_args['signature_scheme'] = cfg['signature_scheme']
470 except KeyError as exc:
470 except KeyError as exc:
471 msg = '\n'.join([
471 msg = '\n'.join([
472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
473 "If you are reusing connection files, remove them and start ipcontroller again."
473 "If you are reusing connection files, remove them and start ipcontroller again."
474 ])
474 ])
475 raise ValueError(msg.format(exc.message))
475 raise ValueError(msg.format(exc.message))
476
476
477 self.session = Session(**extra_args)
477 self.session = Session(**extra_args)
478
478
479 self._query_socket = self._context.socket(zmq.DEALER)
479 self._query_socket = self._context.socket(zmq.DEALER)
480
480
481 if self._ssh:
481 if self._ssh:
482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
483 else:
483 else:
484 self._query_socket.connect(cfg['registration'])
484 self._query_socket.connect(cfg['registration'])
485
485
486 self.session.debug = self.debug
486 self.session.debug = self.debug
487
487
488 self._notification_handlers = {'registration_notification' : self._register_engine,
488 self._notification_handlers = {'registration_notification' : self._register_engine,
489 'unregistration_notification' : self._unregister_engine,
489 'unregistration_notification' : self._unregister_engine,
490 'shutdown_notification' : lambda msg: self.close(),
490 'shutdown_notification' : lambda msg: self.close(),
491 }
491 }
492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
493 'apply_reply' : self._handle_apply_reply}
493 'apply_reply' : self._handle_apply_reply}
494
494
495 try:
495 try:
496 self._connect(sshserver, ssh_kwargs, timeout)
496 self._connect(sshserver, ssh_kwargs, timeout)
497 except:
497 except:
498 self.close(linger=0)
498 self.close(linger=0)
499 raise
499 raise
500
500
501 # last step: setup magics, if we are in IPython:
501 # last step: setup magics, if we are in IPython:
502
502
503 try:
503 try:
504 ip = get_ipython()
504 ip = get_ipython()
505 except NameError:
505 except NameError:
506 return
506 return
507 else:
507 else:
508 if 'px' not in ip.magics_manager.magics:
508 if 'px' not in ip.magics_manager.magics:
509 # in IPython but we are the first Client.
509 # in IPython but we are the first Client.
510 # activate a default view for parallel magics.
510 # activate a default view for parallel magics.
511 self.activate()
511 self.activate()
512
512
513 def __del__(self):
513 def __del__(self):
514 """cleanup sockets, but _not_ context."""
514 """cleanup sockets, but _not_ context."""
515 self.close()
515 self.close()
516
516
517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
518 if ipython_dir is None:
518 if ipython_dir is None:
519 ipython_dir = get_ipython_dir()
519 ipython_dir = get_ipython_dir()
520 if profile_dir is not None:
520 if profile_dir is not None:
521 try:
521 try:
522 self._cd = ProfileDir.find_profile_dir(profile_dir)
522 self._cd = ProfileDir.find_profile_dir(profile_dir)
523 return
523 return
524 except ProfileDirError:
524 except ProfileDirError:
525 pass
525 pass
526 elif profile is not None:
526 elif profile is not None:
527 try:
527 try:
528 self._cd = ProfileDir.find_profile_dir_by_name(
528 self._cd = ProfileDir.find_profile_dir_by_name(
529 ipython_dir, profile)
529 ipython_dir, profile)
530 return
530 return
531 except ProfileDirError:
531 except ProfileDirError:
532 pass
532 pass
533 self._cd = None
533 self._cd = None
534
534
535 def _update_engines(self, engines):
535 def _update_engines(self, engines):
536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
537 for k,v in iteritems(engines):
537 for k,v in iteritems(engines):
538 eid = int(k)
538 eid = int(k)
539 if eid not in self._engines:
539 if eid not in self._engines:
540 self._ids.append(eid)
540 self._ids.append(eid)
541 self._engines[eid] = v
541 self._engines[eid] = v
542 self._ids = sorted(self._ids)
542 self._ids = sorted(self._ids)
543 if sorted(self._engines.keys()) != range(len(self._engines)) and \
543 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
544 self._task_scheme == 'pure' and self._task_socket:
544 self._task_scheme == 'pure' and self._task_socket:
545 self._stop_scheduling_tasks()
545 self._stop_scheduling_tasks()
546
546
547 def _stop_scheduling_tasks(self):
547 def _stop_scheduling_tasks(self):
548 """Stop scheduling tasks because an engine has been unregistered
548 """Stop scheduling tasks because an engine has been unregistered
549 from a pure ZMQ scheduler.
549 from a pure ZMQ scheduler.
550 """
550 """
551 self._task_socket.close()
551 self._task_socket.close()
552 self._task_socket = None
552 self._task_socket = None
553 msg = "An engine has been unregistered, and we are using pure " +\
553 msg = "An engine has been unregistered, and we are using pure " +\
554 "ZMQ task scheduling. Task farming will be disabled."
554 "ZMQ task scheduling. Task farming will be disabled."
555 if self.outstanding:
555 if self.outstanding:
556 msg += " If you were running tasks when this happened, " +\
556 msg += " If you were running tasks when this happened, " +\
557 "some `outstanding` msg_ids may never resolve."
557 "some `outstanding` msg_ids may never resolve."
558 warnings.warn(msg, RuntimeWarning)
558 warnings.warn(msg, RuntimeWarning)
559
559
560 def _build_targets(self, targets):
560 def _build_targets(self, targets):
561 """Turn valid target IDs or 'all' into two lists:
561 """Turn valid target IDs or 'all' into two lists:
562 (int_ids, uuids).
562 (int_ids, uuids).
563 """
563 """
564 if not self._ids:
564 if not self._ids:
565 # flush notification socket if no engines yet, just in case
565 # flush notification socket if no engines yet, just in case
566 if not self.ids:
566 if not self.ids:
567 raise error.NoEnginesRegistered("Can't build targets without any engines")
567 raise error.NoEnginesRegistered("Can't build targets without any engines")
568
568
569 if targets is None:
569 if targets is None:
570 targets = self._ids
570 targets = self._ids
571 elif isinstance(targets, string_types):
571 elif isinstance(targets, string_types):
572 if targets.lower() == 'all':
572 if targets.lower() == 'all':
573 targets = self._ids
573 targets = self._ids
574 else:
574 else:
575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
576 elif isinstance(targets, int):
576 elif isinstance(targets, int):
577 if targets < 0:
577 if targets < 0:
578 targets = self.ids[targets]
578 targets = self.ids[targets]
579 if targets not in self._ids:
579 if targets not in self._ids:
580 raise IndexError("No such engine: %i"%targets)
580 raise IndexError("No such engine: %i"%targets)
581 targets = [targets]
581 targets = [targets]
582
582
583 if isinstance(targets, slice):
583 if isinstance(targets, slice):
584 indices = range(len(self._ids))[targets]
584 indices = list(range(len(self._ids))[targets])
585 ids = self.ids
585 ids = self.ids
586 targets = [ ids[i] for i in indices ]
586 targets = [ ids[i] for i in indices ]
587
587
588 if not isinstance(targets, (tuple, list, xrange)):
588 if not isinstance(targets, (tuple, list, xrange)):
589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
590
590
591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
592
592
593 def _connect(self, sshserver, ssh_kwargs, timeout):
593 def _connect(self, sshserver, ssh_kwargs, timeout):
594 """setup all our socket connections to the cluster. This is called from
594 """setup all our socket connections to the cluster. This is called from
595 __init__."""
595 __init__."""
596
596
597 # Maybe allow reconnecting?
597 # Maybe allow reconnecting?
598 if self._connected:
598 if self._connected:
599 return
599 return
600 self._connected=True
600 self._connected=True
601
601
602 def connect_socket(s, url):
602 def connect_socket(s, url):
603 if self._ssh:
603 if self._ssh:
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
605 else:
605 else:
606 return s.connect(url)
606 return s.connect(url)
607
607
608 self.session.send(self._query_socket, 'connection_request')
608 self.session.send(self._query_socket, 'connection_request')
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
610 poller = zmq.Poller()
610 poller = zmq.Poller()
611 poller.register(self._query_socket, zmq.POLLIN)
611 poller.register(self._query_socket, zmq.POLLIN)
612 # poll expects milliseconds, timeout is seconds
612 # poll expects milliseconds, timeout is seconds
613 evts = poller.poll(timeout*1000)
613 evts = poller.poll(timeout*1000)
614 if not evts:
614 if not evts:
615 raise error.TimeoutError("Hub connection request timed out")
615 raise error.TimeoutError("Hub connection request timed out")
616 idents,msg = self.session.recv(self._query_socket,mode=0)
616 idents,msg = self.session.recv(self._query_socket,mode=0)
617 if self.debug:
617 if self.debug:
618 pprint(msg)
618 pprint(msg)
619 content = msg['content']
619 content = msg['content']
620 # self._config['registration'] = dict(content)
620 # self._config['registration'] = dict(content)
621 cfg = self._config
621 cfg = self._config
622 if content['status'] == 'ok':
622 if content['status'] == 'ok':
623 self._mux_socket = self._context.socket(zmq.DEALER)
623 self._mux_socket = self._context.socket(zmq.DEALER)
624 connect_socket(self._mux_socket, cfg['mux'])
624 connect_socket(self._mux_socket, cfg['mux'])
625
625
626 self._task_socket = self._context.socket(zmq.DEALER)
626 self._task_socket = self._context.socket(zmq.DEALER)
627 connect_socket(self._task_socket, cfg['task'])
627 connect_socket(self._task_socket, cfg['task'])
628
628
629 self._notification_socket = self._context.socket(zmq.SUB)
629 self._notification_socket = self._context.socket(zmq.SUB)
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
631 connect_socket(self._notification_socket, cfg['notification'])
631 connect_socket(self._notification_socket, cfg['notification'])
632
632
633 self._control_socket = self._context.socket(zmq.DEALER)
633 self._control_socket = self._context.socket(zmq.DEALER)
634 connect_socket(self._control_socket, cfg['control'])
634 connect_socket(self._control_socket, cfg['control'])
635
635
636 self._iopub_socket = self._context.socket(zmq.SUB)
636 self._iopub_socket = self._context.socket(zmq.SUB)
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
638 connect_socket(self._iopub_socket, cfg['iopub'])
638 connect_socket(self._iopub_socket, cfg['iopub'])
639
639
640 self._update_engines(dict(content['engines']))
640 self._update_engines(dict(content['engines']))
641 else:
641 else:
642 self._connected = False
642 self._connected = False
643 raise Exception("Failed to connect!")
643 raise Exception("Failed to connect!")
644
644
645 #--------------------------------------------------------------------------
645 #--------------------------------------------------------------------------
646 # handlers and callbacks for incoming messages
646 # handlers and callbacks for incoming messages
647 #--------------------------------------------------------------------------
647 #--------------------------------------------------------------------------
648
648
649 def _unwrap_exception(self, content):
649 def _unwrap_exception(self, content):
650 """unwrap exception, and remap engine_id to int."""
650 """unwrap exception, and remap engine_id to int."""
651 e = error.unwrap_exception(content)
651 e = error.unwrap_exception(content)
652 # print e.traceback
652 # print e.traceback
653 if e.engine_info:
653 if e.engine_info:
654 e_uuid = e.engine_info['engine_uuid']
654 e_uuid = e.engine_info['engine_uuid']
655 eid = self._engines[e_uuid]
655 eid = self._engines[e_uuid]
656 e.engine_info['engine_id'] = eid
656 e.engine_info['engine_id'] = eid
657 return e
657 return e
658
658
659 def _extract_metadata(self, msg):
659 def _extract_metadata(self, msg):
660 header = msg['header']
660 header = msg['header']
661 parent = msg['parent_header']
661 parent = msg['parent_header']
662 msg_meta = msg['metadata']
662 msg_meta = msg['metadata']
663 content = msg['content']
663 content = msg['content']
664 md = {'msg_id' : parent['msg_id'],
664 md = {'msg_id' : parent['msg_id'],
665 'received' : datetime.now(),
665 'received' : datetime.now(),
666 'engine_uuid' : msg_meta.get('engine', None),
666 'engine_uuid' : msg_meta.get('engine', None),
667 'follow' : msg_meta.get('follow', []),
667 'follow' : msg_meta.get('follow', []),
668 'after' : msg_meta.get('after', []),
668 'after' : msg_meta.get('after', []),
669 'status' : content['status'],
669 'status' : content['status'],
670 }
670 }
671
671
672 if md['engine_uuid'] is not None:
672 if md['engine_uuid'] is not None:
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
674
674
675 if 'date' in parent:
675 if 'date' in parent:
676 md['submitted'] = parent['date']
676 md['submitted'] = parent['date']
677 if 'started' in msg_meta:
677 if 'started' in msg_meta:
678 md['started'] = msg_meta['started']
678 md['started'] = msg_meta['started']
679 if 'date' in header:
679 if 'date' in header:
680 md['completed'] = header['date']
680 md['completed'] = header['date']
681 return md
681 return md
682
682
683 def _register_engine(self, msg):
683 def _register_engine(self, msg):
684 """Register a new engine, and update our connection info."""
684 """Register a new engine, and update our connection info."""
685 content = msg['content']
685 content = msg['content']
686 eid = content['id']
686 eid = content['id']
687 d = {eid : content['uuid']}
687 d = {eid : content['uuid']}
688 self._update_engines(d)
688 self._update_engines(d)
689
689
690 def _unregister_engine(self, msg):
690 def _unregister_engine(self, msg):
691 """Unregister an engine that has died."""
691 """Unregister an engine that has died."""
692 content = msg['content']
692 content = msg['content']
693 eid = int(content['id'])
693 eid = int(content['id'])
694 if eid in self._ids:
694 if eid in self._ids:
695 self._ids.remove(eid)
695 self._ids.remove(eid)
696 uuid = self._engines.pop(eid)
696 uuid = self._engines.pop(eid)
697
697
698 self._handle_stranded_msgs(eid, uuid)
698 self._handle_stranded_msgs(eid, uuid)
699
699
700 if self._task_socket and self._task_scheme == 'pure':
700 if self._task_socket and self._task_scheme == 'pure':
701 self._stop_scheduling_tasks()
701 self._stop_scheduling_tasks()
702
702
703 def _handle_stranded_msgs(self, eid, uuid):
703 def _handle_stranded_msgs(self, eid, uuid):
704 """Handle messages known to be on an engine when the engine unregisters.
704 """Handle messages known to be on an engine when the engine unregisters.
705
705
706 It is possible that this will fire prematurely - that is, an engine will
706 It is possible that this will fire prematurely - that is, an engine will
707 go down after completing a result, and the client will be notified
707 go down after completing a result, and the client will be notified
708 of the unregistration and later receive the successful result.
708 of the unregistration and later receive the successful result.
709 """
709 """
710
710
711 outstanding = self._outstanding_dict[uuid]
711 outstanding = self._outstanding_dict[uuid]
712
712
713 for msg_id in list(outstanding):
713 for msg_id in list(outstanding):
714 if msg_id in self.results:
714 if msg_id in self.results:
715 # we already
715 # we already
716 continue
716 continue
717 try:
717 try:
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
719 except:
719 except:
720 content = error.wrap_exception()
720 content = error.wrap_exception()
721 # build a fake message:
721 # build a fake message:
722 msg = self.session.msg('apply_reply', content=content)
722 msg = self.session.msg('apply_reply', content=content)
723 msg['parent_header']['msg_id'] = msg_id
723 msg['parent_header']['msg_id'] = msg_id
724 msg['metadata']['engine'] = uuid
724 msg['metadata']['engine'] = uuid
725 self._handle_apply_reply(msg)
725 self._handle_apply_reply(msg)
726
726
727 def _handle_execute_reply(self, msg):
727 def _handle_execute_reply(self, msg):
728 """Save the reply to an execute_request into our results.
728 """Save the reply to an execute_request into our results.
729
729
730 execute messages are never actually used. apply is used instead.
730 execute messages are never actually used. apply is used instead.
731 """
731 """
732
732
733 parent = msg['parent_header']
733 parent = msg['parent_header']
734 msg_id = parent['msg_id']
734 msg_id = parent['msg_id']
735 if msg_id not in self.outstanding:
735 if msg_id not in self.outstanding:
736 if msg_id in self.history:
736 if msg_id in self.history:
737 print(("got stale result: %s"%msg_id))
737 print(("got stale result: %s"%msg_id))
738 else:
738 else:
739 print(("got unknown result: %s"%msg_id))
739 print(("got unknown result: %s"%msg_id))
740 else:
740 else:
741 self.outstanding.remove(msg_id)
741 self.outstanding.remove(msg_id)
742
742
743 content = msg['content']
743 content = msg['content']
744 header = msg['header']
744 header = msg['header']
745
745
746 # construct metadata:
746 # construct metadata:
747 md = self.metadata[msg_id]
747 md = self.metadata[msg_id]
748 md.update(self._extract_metadata(msg))
748 md.update(self._extract_metadata(msg))
749 # is this redundant?
749 # is this redundant?
750 self.metadata[msg_id] = md
750 self.metadata[msg_id] = md
751
751
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
753 if msg_id in e_outstanding:
753 if msg_id in e_outstanding:
754 e_outstanding.remove(msg_id)
754 e_outstanding.remove(msg_id)
755
755
756 # construct result:
756 # construct result:
757 if content['status'] == 'ok':
757 if content['status'] == 'ok':
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
759 elif content['status'] == 'aborted':
759 elif content['status'] == 'aborted':
760 self.results[msg_id] = error.TaskAborted(msg_id)
760 self.results[msg_id] = error.TaskAborted(msg_id)
761 elif content['status'] == 'resubmitted':
761 elif content['status'] == 'resubmitted':
762 # TODO: handle resubmission
762 # TODO: handle resubmission
763 pass
763 pass
764 else:
764 else:
765 self.results[msg_id] = self._unwrap_exception(content)
765 self.results[msg_id] = self._unwrap_exception(content)
766
766
767 def _handle_apply_reply(self, msg):
767 def _handle_apply_reply(self, msg):
768 """Save the reply to an apply_request into our results."""
768 """Save the reply to an apply_request into our results."""
769 parent = msg['parent_header']
769 parent = msg['parent_header']
770 msg_id = parent['msg_id']
770 msg_id = parent['msg_id']
771 if msg_id not in self.outstanding:
771 if msg_id not in self.outstanding:
772 if msg_id in self.history:
772 if msg_id in self.history:
773 print(("got stale result: %s"%msg_id))
773 print(("got stale result: %s"%msg_id))
774 print(self.results[msg_id])
774 print(self.results[msg_id])
775 print(msg)
775 print(msg)
776 else:
776 else:
777 print(("got unknown result: %s"%msg_id))
777 print(("got unknown result: %s"%msg_id))
778 else:
778 else:
779 self.outstanding.remove(msg_id)
779 self.outstanding.remove(msg_id)
780 content = msg['content']
780 content = msg['content']
781 header = msg['header']
781 header = msg['header']
782
782
783 # construct metadata:
783 # construct metadata:
784 md = self.metadata[msg_id]
784 md = self.metadata[msg_id]
785 md.update(self._extract_metadata(msg))
785 md.update(self._extract_metadata(msg))
786 # is this redundant?
786 # is this redundant?
787 self.metadata[msg_id] = md
787 self.metadata[msg_id] = md
788
788
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
790 if msg_id in e_outstanding:
790 if msg_id in e_outstanding:
791 e_outstanding.remove(msg_id)
791 e_outstanding.remove(msg_id)
792
792
793 # construct result:
793 # construct result:
794 if content['status'] == 'ok':
794 if content['status'] == 'ok':
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
796 elif content['status'] == 'aborted':
796 elif content['status'] == 'aborted':
797 self.results[msg_id] = error.TaskAborted(msg_id)
797 self.results[msg_id] = error.TaskAborted(msg_id)
798 elif content['status'] == 'resubmitted':
798 elif content['status'] == 'resubmitted':
799 # TODO: handle resubmission
799 # TODO: handle resubmission
800 pass
800 pass
801 else:
801 else:
802 self.results[msg_id] = self._unwrap_exception(content)
802 self.results[msg_id] = self._unwrap_exception(content)
803
803
804 def _flush_notifications(self):
804 def _flush_notifications(self):
805 """Flush notifications of engine registrations waiting
805 """Flush notifications of engine registrations waiting
806 in ZMQ queue."""
806 in ZMQ queue."""
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 while msg is not None:
808 while msg is not None:
809 if self.debug:
809 if self.debug:
810 pprint(msg)
810 pprint(msg)
811 msg_type = msg['header']['msg_type']
811 msg_type = msg['header']['msg_type']
812 handler = self._notification_handlers.get(msg_type, None)
812 handler = self._notification_handlers.get(msg_type, None)
813 if handler is None:
813 if handler is None:
814 raise Exception("Unhandled message type: %s" % msg_type)
814 raise Exception("Unhandled message type: %s" % msg_type)
815 else:
815 else:
816 handler(msg)
816 handler(msg)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
818
818
819 def _flush_results(self, sock):
819 def _flush_results(self, sock):
820 """Flush task or queue results waiting in ZMQ queue."""
820 """Flush task or queue results waiting in ZMQ queue."""
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 while msg is not None:
822 while msg is not None:
823 if self.debug:
823 if self.debug:
824 pprint(msg)
824 pprint(msg)
825 msg_type = msg['header']['msg_type']
825 msg_type = msg['header']['msg_type']
826 handler = self._queue_handlers.get(msg_type, None)
826 handler = self._queue_handlers.get(msg_type, None)
827 if handler is None:
827 if handler is None:
828 raise Exception("Unhandled message type: %s" % msg_type)
828 raise Exception("Unhandled message type: %s" % msg_type)
829 else:
829 else:
830 handler(msg)
830 handler(msg)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832
832
833 def _flush_control(self, sock):
833 def _flush_control(self, sock):
834 """Flush replies from the control channel waiting
834 """Flush replies from the control channel waiting
835 in the ZMQ queue.
835 in the ZMQ queue.
836
836
837 Currently: ignore them."""
837 Currently: ignore them."""
838 if self._ignored_control_replies <= 0:
838 if self._ignored_control_replies <= 0:
839 return
839 return
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 while msg is not None:
841 while msg is not None:
842 self._ignored_control_replies -= 1
842 self._ignored_control_replies -= 1
843 if self.debug:
843 if self.debug:
844 pprint(msg)
844 pprint(msg)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846
846
847 def _flush_ignored_control(self):
847 def _flush_ignored_control(self):
848 """flush ignored control replies"""
848 """flush ignored control replies"""
849 while self._ignored_control_replies > 0:
849 while self._ignored_control_replies > 0:
850 self.session.recv(self._control_socket)
850 self.session.recv(self._control_socket)
851 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
852
852
853 def _flush_ignored_hub_replies(self):
853 def _flush_ignored_hub_replies(self):
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
855 while msg is not None:
855 while msg is not None:
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
857
857
858 def _flush_iopub(self, sock):
858 def _flush_iopub(self, sock):
859 """Flush replies from the iopub channel waiting
859 """Flush replies from the iopub channel waiting
860 in the ZMQ queue.
860 in the ZMQ queue.
861 """
861 """
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
863 while msg is not None:
863 while msg is not None:
864 if self.debug:
864 if self.debug:
865 pprint(msg)
865 pprint(msg)
866 parent = msg['parent_header']
866 parent = msg['parent_header']
867 # ignore IOPub messages with no parent.
867 # ignore IOPub messages with no parent.
868 # Caused by print statements or warnings from before the first execution.
868 # Caused by print statements or warnings from before the first execution.
869 if not parent:
869 if not parent:
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 continue
871 continue
872 msg_id = parent['msg_id']
872 msg_id = parent['msg_id']
873 content = msg['content']
873 content = msg['content']
874 header = msg['header']
874 header = msg['header']
875 msg_type = msg['header']['msg_type']
875 msg_type = msg['header']['msg_type']
876
876
877 # init metadata:
877 # init metadata:
878 md = self.metadata[msg_id]
878 md = self.metadata[msg_id]
879
879
880 if msg_type == 'stream':
880 if msg_type == 'stream':
881 name = content['name']
881 name = content['name']
882 s = md[name] or ''
882 s = md[name] or ''
883 md[name] = s + content['data']
883 md[name] = s + content['data']
884 elif msg_type == 'pyerr':
884 elif msg_type == 'pyerr':
885 md.update({'pyerr' : self._unwrap_exception(content)})
885 md.update({'pyerr' : self._unwrap_exception(content)})
886 elif msg_type == 'pyin':
886 elif msg_type == 'pyin':
887 md.update({'pyin' : content['code']})
887 md.update({'pyin' : content['code']})
888 elif msg_type == 'display_data':
888 elif msg_type == 'display_data':
889 md['outputs'].append(content)
889 md['outputs'].append(content)
890 elif msg_type == 'pyout':
890 elif msg_type == 'pyout':
891 md['pyout'] = content
891 md['pyout'] = content
892 elif msg_type == 'data_message':
892 elif msg_type == 'data_message':
893 data, remainder = serialize.unserialize_object(msg['buffers'])
893 data, remainder = serialize.unserialize_object(msg['buffers'])
894 md['data'].update(data)
894 md['data'].update(data)
895 elif msg_type == 'status':
895 elif msg_type == 'status':
896 # idle message comes after all outputs
896 # idle message comes after all outputs
897 if content['execution_state'] == 'idle':
897 if content['execution_state'] == 'idle':
898 md['outputs_ready'] = True
898 md['outputs_ready'] = True
899 else:
899 else:
900 # unhandled msg_type (status, etc.)
900 # unhandled msg_type (status, etc.)
901 pass
901 pass
902
902
903 # reduntant?
903 # reduntant?
904 self.metadata[msg_id] = md
904 self.metadata[msg_id] = md
905
905
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
907
907
908 #--------------------------------------------------------------------------
908 #--------------------------------------------------------------------------
909 # len, getitem
909 # len, getitem
910 #--------------------------------------------------------------------------
910 #--------------------------------------------------------------------------
911
911
912 def __len__(self):
912 def __len__(self):
913 """len(client) returns # of engines."""
913 """len(client) returns # of engines."""
914 return len(self.ids)
914 return len(self.ids)
915
915
916 def __getitem__(self, key):
916 def __getitem__(self, key):
917 """index access returns DirectView multiplexer objects
917 """index access returns DirectView multiplexer objects
918
918
919 Must be int, slice, or list/tuple/xrange of ints"""
919 Must be int, slice, or list/tuple/xrange of ints"""
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
922 else:
922 else:
923 return self.direct_view(key)
923 return self.direct_view(key)
924
924
925 #--------------------------------------------------------------------------
925 #--------------------------------------------------------------------------
926 # Begin public methods
926 # Begin public methods
927 #--------------------------------------------------------------------------
927 #--------------------------------------------------------------------------
928
928
929 @property
929 @property
930 def ids(self):
930 def ids(self):
931 """Always up-to-date ids property."""
931 """Always up-to-date ids property."""
932 self._flush_notifications()
932 self._flush_notifications()
933 # always copy:
933 # always copy:
934 return list(self._ids)
934 return list(self._ids)
935
935
936 def activate(self, targets='all', suffix=''):
936 def activate(self, targets='all', suffix=''):
937 """Create a DirectView and register it with IPython magics
937 """Create a DirectView and register it with IPython magics
938
938
939 Defines the magics `%px, %autopx, %pxresult, %%px`
939 Defines the magics `%px, %autopx, %pxresult, %%px`
940
940
941 Parameters
941 Parameters
942 ----------
942 ----------
943
943
944 targets: int, list of ints, or 'all'
944 targets: int, list of ints, or 'all'
945 The engines on which the view's magics will run
945 The engines on which the view's magics will run
946 suffix: str [default: '']
946 suffix: str [default: '']
947 The suffix, if any, for the magics. This allows you to have
947 The suffix, if any, for the magics. This allows you to have
948 multiple views associated with parallel magics at the same time.
948 multiple views associated with parallel magics at the same time.
949
949
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
952 on engine 0.
952 on engine 0.
953 """
953 """
954 view = self.direct_view(targets)
954 view = self.direct_view(targets)
955 view.block = True
955 view.block = True
956 view.activate(suffix)
956 view.activate(suffix)
957 return view
957 return view
958
958
959 def close(self, linger=None):
959 def close(self, linger=None):
960 """Close my zmq Sockets
960 """Close my zmq Sockets
961
961
962 If `linger`, set the zmq LINGER socket option,
962 If `linger`, set the zmq LINGER socket option,
963 which allows discarding of messages.
963 which allows discarding of messages.
964 """
964 """
965 if self._closed:
965 if self._closed:
966 return
966 return
967 self.stop_spin_thread()
967 self.stop_spin_thread()
968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
969 for name in snames:
969 for name in snames:
970 socket = getattr(self, name)
970 socket = getattr(self, name)
971 if socket is not None and not socket.closed:
971 if socket is not None and not socket.closed:
972 if linger is not None:
972 if linger is not None:
973 socket.close(linger=linger)
973 socket.close(linger=linger)
974 else:
974 else:
975 socket.close()
975 socket.close()
976 self._closed = True
976 self._closed = True
977
977
978 def _spin_every(self, interval=1):
978 def _spin_every(self, interval=1):
979 """target func for use in spin_thread"""
979 """target func for use in spin_thread"""
980 while True:
980 while True:
981 if self._stop_spinning.is_set():
981 if self._stop_spinning.is_set():
982 return
982 return
983 time.sleep(interval)
983 time.sleep(interval)
984 self.spin()
984 self.spin()
985
985
986 def spin_thread(self, interval=1):
986 def spin_thread(self, interval=1):
987 """call Client.spin() in a background thread on some regular interval
987 """call Client.spin() in a background thread on some regular interval
988
988
989 This helps ensure that messages don't pile up too much in the zmq queue
989 This helps ensure that messages don't pile up too much in the zmq queue
990 while you are working on other things, or just leaving an idle terminal.
990 while you are working on other things, or just leaving an idle terminal.
991
991
992 It also helps limit potential padding of the `received` timestamp
992 It also helps limit potential padding of the `received` timestamp
993 on AsyncResult objects, used for timings.
993 on AsyncResult objects, used for timings.
994
994
995 Parameters
995 Parameters
996 ----------
996 ----------
997
997
998 interval : float, optional
998 interval : float, optional
999 The interval on which to spin the client in the background thread
999 The interval on which to spin the client in the background thread
1000 (simply passed to time.sleep).
1000 (simply passed to time.sleep).
1001
1001
1002 Notes
1002 Notes
1003 -----
1003 -----
1004
1004
1005 For precision timing, you may want to use this method to put a bound
1005 For precision timing, you may want to use this method to put a bound
1006 on the jitter (in seconds) in `received` timestamps used
1006 on the jitter (in seconds) in `received` timestamps used
1007 in AsyncResult.wall_time.
1007 in AsyncResult.wall_time.
1008
1008
1009 """
1009 """
1010 if self._spin_thread is not None:
1010 if self._spin_thread is not None:
1011 self.stop_spin_thread()
1011 self.stop_spin_thread()
1012 self._stop_spinning.clear()
1012 self._stop_spinning.clear()
1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1014 self._spin_thread.daemon = True
1014 self._spin_thread.daemon = True
1015 self._spin_thread.start()
1015 self._spin_thread.start()
1016
1016
1017 def stop_spin_thread(self):
1017 def stop_spin_thread(self):
1018 """stop background spin_thread, if any"""
1018 """stop background spin_thread, if any"""
1019 if self._spin_thread is not None:
1019 if self._spin_thread is not None:
1020 self._stop_spinning.set()
1020 self._stop_spinning.set()
1021 self._spin_thread.join()
1021 self._spin_thread.join()
1022 self._spin_thread = None
1022 self._spin_thread = None
1023
1023
1024 def spin(self):
1024 def spin(self):
1025 """Flush any registration notifications and execution results
1025 """Flush any registration notifications and execution results
1026 waiting in the ZMQ queue.
1026 waiting in the ZMQ queue.
1027 """
1027 """
1028 if self._notification_socket:
1028 if self._notification_socket:
1029 self._flush_notifications()
1029 self._flush_notifications()
1030 if self._iopub_socket:
1030 if self._iopub_socket:
1031 self._flush_iopub(self._iopub_socket)
1031 self._flush_iopub(self._iopub_socket)
1032 if self._mux_socket:
1032 if self._mux_socket:
1033 self._flush_results(self._mux_socket)
1033 self._flush_results(self._mux_socket)
1034 if self._task_socket:
1034 if self._task_socket:
1035 self._flush_results(self._task_socket)
1035 self._flush_results(self._task_socket)
1036 if self._control_socket:
1036 if self._control_socket:
1037 self._flush_control(self._control_socket)
1037 self._flush_control(self._control_socket)
1038 if self._query_socket:
1038 if self._query_socket:
1039 self._flush_ignored_hub_replies()
1039 self._flush_ignored_hub_replies()
1040
1040
1041 def wait(self, jobs=None, timeout=-1):
1041 def wait(self, jobs=None, timeout=-1):
1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1043
1043
1044 Parameters
1044 Parameters
1045 ----------
1045 ----------
1046
1046
1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1048 ints are indices to self.history
1048 ints are indices to self.history
1049 strs are msg_ids
1049 strs are msg_ids
1050 default: wait on all outstanding messages
1050 default: wait on all outstanding messages
1051 timeout : float
1051 timeout : float
1052 a time in seconds, after which to give up.
1052 a time in seconds, after which to give up.
1053 default is -1, which means no timeout
1053 default is -1, which means no timeout
1054
1054
1055 Returns
1055 Returns
1056 -------
1056 -------
1057
1057
1058 True : when all msg_ids are done
1058 True : when all msg_ids are done
1059 False : timeout reached, some msg_ids still outstanding
1059 False : timeout reached, some msg_ids still outstanding
1060 """
1060 """
1061 tic = time.time()
1061 tic = time.time()
1062 if jobs is None:
1062 if jobs is None:
1063 theids = self.outstanding
1063 theids = self.outstanding
1064 else:
1064 else:
1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1066 jobs = [jobs]
1066 jobs = [jobs]
1067 theids = set()
1067 theids = set()
1068 for job in jobs:
1068 for job in jobs:
1069 if isinstance(job, int):
1069 if isinstance(job, int):
1070 # index access
1070 # index access
1071 job = self.history[job]
1071 job = self.history[job]
1072 elif isinstance(job, AsyncResult):
1072 elif isinstance(job, AsyncResult):
1073 map(theids.add, job.msg_ids)
1073 theids.update(job.msg_ids)
1074 continue
1074 continue
1075 theids.add(job)
1075 theids.add(job)
1076 if not theids.intersection(self.outstanding):
1076 if not theids.intersection(self.outstanding):
1077 return True
1077 return True
1078 self.spin()
1078 self.spin()
1079 while theids.intersection(self.outstanding):
1079 while theids.intersection(self.outstanding):
1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1081 break
1081 break
1082 time.sleep(1e-3)
1082 time.sleep(1e-3)
1083 self.spin()
1083 self.spin()
1084 return len(theids.intersection(self.outstanding)) == 0
1084 return len(theids.intersection(self.outstanding)) == 0
1085
1085
1086 #--------------------------------------------------------------------------
1086 #--------------------------------------------------------------------------
1087 # Control methods
1087 # Control methods
1088 #--------------------------------------------------------------------------
1088 #--------------------------------------------------------------------------
1089
1089
1090 @spin_first
1090 @spin_first
1091 def clear(self, targets=None, block=None):
1091 def clear(self, targets=None, block=None):
1092 """Clear the namespace in target(s)."""
1092 """Clear the namespace in target(s)."""
1093 block = self.block if block is None else block
1093 block = self.block if block is None else block
1094 targets = self._build_targets(targets)[0]
1094 targets = self._build_targets(targets)[0]
1095 for t in targets:
1095 for t in targets:
1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1097 error = False
1097 error = False
1098 if block:
1098 if block:
1099 self._flush_ignored_control()
1099 self._flush_ignored_control()
1100 for i in range(len(targets)):
1100 for i in range(len(targets)):
1101 idents,msg = self.session.recv(self._control_socket,0)
1101 idents,msg = self.session.recv(self._control_socket,0)
1102 if self.debug:
1102 if self.debug:
1103 pprint(msg)
1103 pprint(msg)
1104 if msg['content']['status'] != 'ok':
1104 if msg['content']['status'] != 'ok':
1105 error = self._unwrap_exception(msg['content'])
1105 error = self._unwrap_exception(msg['content'])
1106 else:
1106 else:
1107 self._ignored_control_replies += len(targets)
1107 self._ignored_control_replies += len(targets)
1108 if error:
1108 if error:
1109 raise error
1109 raise error
1110
1110
1111
1111
1112 @spin_first
1112 @spin_first
1113 def abort(self, jobs=None, targets=None, block=None):
1113 def abort(self, jobs=None, targets=None, block=None):
1114 """Abort specific jobs from the execution queues of target(s).
1114 """Abort specific jobs from the execution queues of target(s).
1115
1115
1116 This is a mechanism to prevent jobs that have already been submitted
1116 This is a mechanism to prevent jobs that have already been submitted
1117 from executing.
1117 from executing.
1118
1118
1119 Parameters
1119 Parameters
1120 ----------
1120 ----------
1121
1121
1122 jobs : msg_id, list of msg_ids, or AsyncResult
1122 jobs : msg_id, list of msg_ids, or AsyncResult
1123 The jobs to be aborted
1123 The jobs to be aborted
1124
1124
1125 If unspecified/None: abort all outstanding jobs.
1125 If unspecified/None: abort all outstanding jobs.
1126
1126
1127 """
1127 """
1128 block = self.block if block is None else block
1128 block = self.block if block is None else block
1129 jobs = jobs if jobs is not None else list(self.outstanding)
1129 jobs = jobs if jobs is not None else list(self.outstanding)
1130 targets = self._build_targets(targets)[0]
1130 targets = self._build_targets(targets)[0]
1131
1131
1132 msg_ids = []
1132 msg_ids = []
1133 if isinstance(jobs, string_types + (AsyncResult,)):
1133 if isinstance(jobs, string_types + (AsyncResult,)):
1134 jobs = [jobs]
1134 jobs = [jobs]
1135 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1135 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1136 if bad_ids:
1136 if bad_ids:
1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1138 for j in jobs:
1138 for j in jobs:
1139 if isinstance(j, AsyncResult):
1139 if isinstance(j, AsyncResult):
1140 msg_ids.extend(j.msg_ids)
1140 msg_ids.extend(j.msg_ids)
1141 else:
1141 else:
1142 msg_ids.append(j)
1142 msg_ids.append(j)
1143 content = dict(msg_ids=msg_ids)
1143 content = dict(msg_ids=msg_ids)
1144 for t in targets:
1144 for t in targets:
1145 self.session.send(self._control_socket, 'abort_request',
1145 self.session.send(self._control_socket, 'abort_request',
1146 content=content, ident=t)
1146 content=content, ident=t)
1147 error = False
1147 error = False
1148 if block:
1148 if block:
1149 self._flush_ignored_control()
1149 self._flush_ignored_control()
1150 for i in range(len(targets)):
1150 for i in range(len(targets)):
1151 idents,msg = self.session.recv(self._control_socket,0)
1151 idents,msg = self.session.recv(self._control_socket,0)
1152 if self.debug:
1152 if self.debug:
1153 pprint(msg)
1153 pprint(msg)
1154 if msg['content']['status'] != 'ok':
1154 if msg['content']['status'] != 'ok':
1155 error = self._unwrap_exception(msg['content'])
1155 error = self._unwrap_exception(msg['content'])
1156 else:
1156 else:
1157 self._ignored_control_replies += len(targets)
1157 self._ignored_control_replies += len(targets)
1158 if error:
1158 if error:
1159 raise error
1159 raise error
1160
1160
1161 @spin_first
1161 @spin_first
1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1163 """Terminates one or more engine processes, optionally including the hub.
1163 """Terminates one or more engine processes, optionally including the hub.
1164
1164
1165 Parameters
1165 Parameters
1166 ----------
1166 ----------
1167
1167
1168 targets: list of ints or 'all' [default: all]
1168 targets: list of ints or 'all' [default: all]
1169 Which engines to shutdown.
1169 Which engines to shutdown.
1170 hub: bool [default: False]
1170 hub: bool [default: False]
1171 Whether to include the Hub. hub=True implies targets='all'.
1171 Whether to include the Hub. hub=True implies targets='all'.
1172 block: bool [default: self.block]
1172 block: bool [default: self.block]
1173 Whether to wait for clean shutdown replies or not.
1173 Whether to wait for clean shutdown replies or not.
1174 restart: bool [default: False]
1174 restart: bool [default: False]
1175 NOT IMPLEMENTED
1175 NOT IMPLEMENTED
1176 whether to restart engines after shutting them down.
1176 whether to restart engines after shutting them down.
1177 """
1177 """
1178 from IPython.parallel.error import NoEnginesRegistered
1178 from IPython.parallel.error import NoEnginesRegistered
1179 if restart:
1179 if restart:
1180 raise NotImplementedError("Engine restart is not yet implemented")
1180 raise NotImplementedError("Engine restart is not yet implemented")
1181
1181
1182 block = self.block if block is None else block
1182 block = self.block if block is None else block
1183 if hub:
1183 if hub:
1184 targets = 'all'
1184 targets = 'all'
1185 try:
1185 try:
1186 targets = self._build_targets(targets)[0]
1186 targets = self._build_targets(targets)[0]
1187 except NoEnginesRegistered:
1187 except NoEnginesRegistered:
1188 targets = []
1188 targets = []
1189 for t in targets:
1189 for t in targets:
1190 self.session.send(self._control_socket, 'shutdown_request',
1190 self.session.send(self._control_socket, 'shutdown_request',
1191 content={'restart':restart},ident=t)
1191 content={'restart':restart},ident=t)
1192 error = False
1192 error = False
1193 if block or hub:
1193 if block or hub:
1194 self._flush_ignored_control()
1194 self._flush_ignored_control()
1195 for i in range(len(targets)):
1195 for i in range(len(targets)):
1196 idents,msg = self.session.recv(self._control_socket, 0)
1196 idents,msg = self.session.recv(self._control_socket, 0)
1197 if self.debug:
1197 if self.debug:
1198 pprint(msg)
1198 pprint(msg)
1199 if msg['content']['status'] != 'ok':
1199 if msg['content']['status'] != 'ok':
1200 error = self._unwrap_exception(msg['content'])
1200 error = self._unwrap_exception(msg['content'])
1201 else:
1201 else:
1202 self._ignored_control_replies += len(targets)
1202 self._ignored_control_replies += len(targets)
1203
1203
1204 if hub:
1204 if hub:
1205 time.sleep(0.25)
1205 time.sleep(0.25)
1206 self.session.send(self._query_socket, 'shutdown_request')
1206 self.session.send(self._query_socket, 'shutdown_request')
1207 idents,msg = self.session.recv(self._query_socket, 0)
1207 idents,msg = self.session.recv(self._query_socket, 0)
1208 if self.debug:
1208 if self.debug:
1209 pprint(msg)
1209 pprint(msg)
1210 if msg['content']['status'] != 'ok':
1210 if msg['content']['status'] != 'ok':
1211 error = self._unwrap_exception(msg['content'])
1211 error = self._unwrap_exception(msg['content'])
1212
1212
1213 if error:
1213 if error:
1214 raise error
1214 raise error
1215
1215
1216 #--------------------------------------------------------------------------
1216 #--------------------------------------------------------------------------
1217 # Execution related methods
1217 # Execution related methods
1218 #--------------------------------------------------------------------------
1218 #--------------------------------------------------------------------------
1219
1219
1220 def _maybe_raise(self, result):
1220 def _maybe_raise(self, result):
1221 """wrapper for maybe raising an exception if apply failed."""
1221 """wrapper for maybe raising an exception if apply failed."""
1222 if isinstance(result, error.RemoteError):
1222 if isinstance(result, error.RemoteError):
1223 raise result
1223 raise result
1224
1224
1225 return result
1225 return result
1226
1226
1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1228 ident=None):
1228 ident=None):
1229 """construct and send an apply message via a socket.
1229 """construct and send an apply message via a socket.
1230
1230
1231 This is the principal method with which all engine execution is performed by views.
1231 This is the principal method with which all engine execution is performed by views.
1232 """
1232 """
1233
1233
1234 if self._closed:
1234 if self._closed:
1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1236
1236
1237 # defaults:
1237 # defaults:
1238 args = args if args is not None else []
1238 args = args if args is not None else []
1239 kwargs = kwargs if kwargs is not None else {}
1239 kwargs = kwargs if kwargs is not None else {}
1240 metadata = metadata if metadata is not None else {}
1240 metadata = metadata if metadata is not None else {}
1241
1241
1242 # validate arguments
1242 # validate arguments
1243 if not callable(f) and not isinstance(f, Reference):
1243 if not callable(f) and not isinstance(f, Reference):
1244 raise TypeError("f must be callable, not %s"%type(f))
1244 raise TypeError("f must be callable, not %s"%type(f))
1245 if not isinstance(args, (tuple, list)):
1245 if not isinstance(args, (tuple, list)):
1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1247 if not isinstance(kwargs, dict):
1247 if not isinstance(kwargs, dict):
1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1249 if not isinstance(metadata, dict):
1249 if not isinstance(metadata, dict):
1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1251
1251
1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1253 buffer_threshold=self.session.buffer_threshold,
1253 buffer_threshold=self.session.buffer_threshold,
1254 item_threshold=self.session.item_threshold,
1254 item_threshold=self.session.item_threshold,
1255 )
1255 )
1256
1256
1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1258 metadata=metadata, track=track)
1258 metadata=metadata, track=track)
1259
1259
1260 msg_id = msg['header']['msg_id']
1260 msg_id = msg['header']['msg_id']
1261 self.outstanding.add(msg_id)
1261 self.outstanding.add(msg_id)
1262 if ident:
1262 if ident:
1263 # possibly routed to a specific engine
1263 # possibly routed to a specific engine
1264 if isinstance(ident, list):
1264 if isinstance(ident, list):
1265 ident = ident[-1]
1265 ident = ident[-1]
1266 if ident in self._engines.values():
1266 if ident in self._engines.values():
1267 # save for later, in case of engine death
1267 # save for later, in case of engine death
1268 self._outstanding_dict[ident].add(msg_id)
1268 self._outstanding_dict[ident].add(msg_id)
1269 self.history.append(msg_id)
1269 self.history.append(msg_id)
1270 self.metadata[msg_id]['submitted'] = datetime.now()
1270 self.metadata[msg_id]['submitted'] = datetime.now()
1271
1271
1272 return msg
1272 return msg
1273
1273
1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1275 """construct and send an execute request via a socket.
1275 """construct and send an execute request via a socket.
1276
1276
1277 """
1277 """
1278
1278
1279 if self._closed:
1279 if self._closed:
1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1281
1281
1282 # defaults:
1282 # defaults:
1283 metadata = metadata if metadata is not None else {}
1283 metadata = metadata if metadata is not None else {}
1284
1284
1285 # validate arguments
1285 # validate arguments
1286 if not isinstance(code, string_types):
1286 if not isinstance(code, string_types):
1287 raise TypeError("code must be text, not %s" % type(code))
1287 raise TypeError("code must be text, not %s" % type(code))
1288 if not isinstance(metadata, dict):
1288 if not isinstance(metadata, dict):
1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1290
1290
1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1292
1292
1293
1293
1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1295 metadata=metadata)
1295 metadata=metadata)
1296
1296
1297 msg_id = msg['header']['msg_id']
1297 msg_id = msg['header']['msg_id']
1298 self.outstanding.add(msg_id)
1298 self.outstanding.add(msg_id)
1299 if ident:
1299 if ident:
1300 # possibly routed to a specific engine
1300 # possibly routed to a specific engine
1301 if isinstance(ident, list):
1301 if isinstance(ident, list):
1302 ident = ident[-1]
1302 ident = ident[-1]
1303 if ident in self._engines.values():
1303 if ident in self._engines.values():
1304 # save for later, in case of engine death
1304 # save for later, in case of engine death
1305 self._outstanding_dict[ident].add(msg_id)
1305 self._outstanding_dict[ident].add(msg_id)
1306 self.history.append(msg_id)
1306 self.history.append(msg_id)
1307 self.metadata[msg_id]['submitted'] = datetime.now()
1307 self.metadata[msg_id]['submitted'] = datetime.now()
1308
1308
1309 return msg
1309 return msg
1310
1310
1311 #--------------------------------------------------------------------------
1311 #--------------------------------------------------------------------------
1312 # construct a View object
1312 # construct a View object
1313 #--------------------------------------------------------------------------
1313 #--------------------------------------------------------------------------
1314
1314
1315 def load_balanced_view(self, targets=None):
1315 def load_balanced_view(self, targets=None):
1316 """construct a DirectView object.
1316 """construct a DirectView object.
1317
1317
1318 If no arguments are specified, create a LoadBalancedView
1318 If no arguments are specified, create a LoadBalancedView
1319 using all engines.
1319 using all engines.
1320
1320
1321 Parameters
1321 Parameters
1322 ----------
1322 ----------
1323
1323
1324 targets: list,slice,int,etc. [default: use all engines]
1324 targets: list,slice,int,etc. [default: use all engines]
1325 The subset of engines across which to load-balance
1325 The subset of engines across which to load-balance
1326 """
1326 """
1327 if targets == 'all':
1327 if targets == 'all':
1328 targets = None
1328 targets = None
1329 if targets is not None:
1329 if targets is not None:
1330 targets = self._build_targets(targets)[1]
1330 targets = self._build_targets(targets)[1]
1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1332
1332
1333 def direct_view(self, targets='all'):
1333 def direct_view(self, targets='all'):
1334 """construct a DirectView object.
1334 """construct a DirectView object.
1335
1335
1336 If no targets are specified, create a DirectView using all engines.
1336 If no targets are specified, create a DirectView using all engines.
1337
1337
1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1340 all *current* engines, and that list will not change.
1340 all *current* engines, and that list will not change.
1341
1341
1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1343 engines added after the DirectView is constructed.
1343 engines added after the DirectView is constructed.
1344
1344
1345 Parameters
1345 Parameters
1346 ----------
1346 ----------
1347
1347
1348 targets: list,slice,int,etc. [default: use all engines]
1348 targets: list,slice,int,etc. [default: use all engines]
1349 The engines to use for the View
1349 The engines to use for the View
1350 """
1350 """
1351 single = isinstance(targets, int)
1351 single = isinstance(targets, int)
1352 # allow 'all' to be lazily evaluated at each execution
1352 # allow 'all' to be lazily evaluated at each execution
1353 if targets != 'all':
1353 if targets != 'all':
1354 targets = self._build_targets(targets)[1]
1354 targets = self._build_targets(targets)[1]
1355 if single:
1355 if single:
1356 targets = targets[0]
1356 targets = targets[0]
1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1358
1358
1359 #--------------------------------------------------------------------------
1359 #--------------------------------------------------------------------------
1360 # Query methods
1360 # Query methods
1361 #--------------------------------------------------------------------------
1361 #--------------------------------------------------------------------------
1362
1362
1363 @spin_first
1363 @spin_first
1364 def get_result(self, indices_or_msg_ids=None, block=None):
1364 def get_result(self, indices_or_msg_ids=None, block=None):
1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1366
1366
1367 If the client already has the results, no request to the Hub will be made.
1367 If the client already has the results, no request to the Hub will be made.
1368
1368
1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1370 that include metadata about execution, and allow for awaiting results that
1370 that include metadata about execution, and allow for awaiting results that
1371 were not submitted by this Client.
1371 were not submitted by this Client.
1372
1372
1373 It can also be a convenient way to retrieve the metadata associated with
1373 It can also be a convenient way to retrieve the metadata associated with
1374 blocking execution, since it always retrieves
1374 blocking execution, since it always retrieves
1375
1375
1376 Examples
1376 Examples
1377 --------
1377 --------
1378 ::
1378 ::
1379
1379
1380 In [10]: r = client.apply()
1380 In [10]: r = client.apply()
1381
1381
1382 Parameters
1382 Parameters
1383 ----------
1383 ----------
1384
1384
1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1386 The indices or msg_ids of indices to be retrieved
1386 The indices or msg_ids of indices to be retrieved
1387
1387
1388 block : bool
1388 block : bool
1389 Whether to wait for the result to be done
1389 Whether to wait for the result to be done
1390
1390
1391 Returns
1391 Returns
1392 -------
1392 -------
1393
1393
1394 AsyncResult
1394 AsyncResult
1395 A single AsyncResult object will always be returned.
1395 A single AsyncResult object will always be returned.
1396
1396
1397 AsyncHubResult
1397 AsyncHubResult
1398 A subclass of AsyncResult that retrieves results from the Hub
1398 A subclass of AsyncResult that retrieves results from the Hub
1399
1399
1400 """
1400 """
1401 block = self.block if block is None else block
1401 block = self.block if block is None else block
1402 if indices_or_msg_ids is None:
1402 if indices_or_msg_ids is None:
1403 indices_or_msg_ids = -1
1403 indices_or_msg_ids = -1
1404
1404
1405 single_result = False
1405 single_result = False
1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1407 indices_or_msg_ids = [indices_or_msg_ids]
1407 indices_or_msg_ids = [indices_or_msg_ids]
1408 single_result = True
1408 single_result = True
1409
1409
1410 theids = []
1410 theids = []
1411 for id in indices_or_msg_ids:
1411 for id in indices_or_msg_ids:
1412 if isinstance(id, int):
1412 if isinstance(id, int):
1413 id = self.history[id]
1413 id = self.history[id]
1414 if not isinstance(id, string_types):
1414 if not isinstance(id, string_types):
1415 raise TypeError("indices must be str or int, not %r"%id)
1415 raise TypeError("indices must be str or int, not %r"%id)
1416 theids.append(id)
1416 theids.append(id)
1417
1417
1418 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1418 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1419 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1419 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1420
1420
1421 # given single msg_id initially, get_result shot get the result itself,
1421 # given single msg_id initially, get_result shot get the result itself,
1422 # not a length-one list
1422 # not a length-one list
1423 if single_result:
1423 if single_result:
1424 theids = theids[0]
1424 theids = theids[0]
1425
1425
1426 if remote_ids:
1426 if remote_ids:
1427 ar = AsyncHubResult(self, msg_ids=theids)
1427 ar = AsyncHubResult(self, msg_ids=theids)
1428 else:
1428 else:
1429 ar = AsyncResult(self, msg_ids=theids)
1429 ar = AsyncResult(self, msg_ids=theids)
1430
1430
1431 if block:
1431 if block:
1432 ar.wait()
1432 ar.wait()
1433
1433
1434 return ar
1434 return ar
1435
1435
1436 @spin_first
1436 @spin_first
1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1438 """Resubmit one or more tasks.
1438 """Resubmit one or more tasks.
1439
1439
1440 in-flight tasks may not be resubmitted.
1440 in-flight tasks may not be resubmitted.
1441
1441
1442 Parameters
1442 Parameters
1443 ----------
1443 ----------
1444
1444
1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1446 The indices or msg_ids of indices to be retrieved
1446 The indices or msg_ids of indices to be retrieved
1447
1447
1448 block : bool
1448 block : bool
1449 Whether to wait for the result to be done
1449 Whether to wait for the result to be done
1450
1450
1451 Returns
1451 Returns
1452 -------
1452 -------
1453
1453
1454 AsyncHubResult
1454 AsyncHubResult
1455 A subclass of AsyncResult that retrieves results from the Hub
1455 A subclass of AsyncResult that retrieves results from the Hub
1456
1456
1457 """
1457 """
1458 block = self.block if block is None else block
1458 block = self.block if block is None else block
1459 if indices_or_msg_ids is None:
1459 if indices_or_msg_ids is None:
1460 indices_or_msg_ids = -1
1460 indices_or_msg_ids = -1
1461
1461
1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1463 indices_or_msg_ids = [indices_or_msg_ids]
1463 indices_or_msg_ids = [indices_or_msg_ids]
1464
1464
1465 theids = []
1465 theids = []
1466 for id in indices_or_msg_ids:
1466 for id in indices_or_msg_ids:
1467 if isinstance(id, int):
1467 if isinstance(id, int):
1468 id = self.history[id]
1468 id = self.history[id]
1469 if not isinstance(id, string_types):
1469 if not isinstance(id, string_types):
1470 raise TypeError("indices must be str or int, not %r"%id)
1470 raise TypeError("indices must be str or int, not %r"%id)
1471 theids.append(id)
1471 theids.append(id)
1472
1472
1473 content = dict(msg_ids = theids)
1473 content = dict(msg_ids = theids)
1474
1474
1475 self.session.send(self._query_socket, 'resubmit_request', content)
1475 self.session.send(self._query_socket, 'resubmit_request', content)
1476
1476
1477 zmq.select([self._query_socket], [], [])
1477 zmq.select([self._query_socket], [], [])
1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1479 if self.debug:
1479 if self.debug:
1480 pprint(msg)
1480 pprint(msg)
1481 content = msg['content']
1481 content = msg['content']
1482 if content['status'] != 'ok':
1482 if content['status'] != 'ok':
1483 raise self._unwrap_exception(content)
1483 raise self._unwrap_exception(content)
1484 mapping = content['resubmitted']
1484 mapping = content['resubmitted']
1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1486
1486
1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1488
1488
1489 if block:
1489 if block:
1490 ar.wait()
1490 ar.wait()
1491
1491
1492 return ar
1492 return ar
1493
1493
1494 @spin_first
1494 @spin_first
1495 def result_status(self, msg_ids, status_only=True):
1495 def result_status(self, msg_ids, status_only=True):
1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1497
1497
1498 If status_only is False, then the actual results will be retrieved, else
1498 If status_only is False, then the actual results will be retrieved, else
1499 only the status of the results will be checked.
1499 only the status of the results will be checked.
1500
1500
1501 Parameters
1501 Parameters
1502 ----------
1502 ----------
1503
1503
1504 msg_ids : list of msg_ids
1504 msg_ids : list of msg_ids
1505 if int:
1505 if int:
1506 Passed as index to self.history for convenience.
1506 Passed as index to self.history for convenience.
1507 status_only : bool (default: True)
1507 status_only : bool (default: True)
1508 if False:
1508 if False:
1509 Retrieve the actual results of completed tasks.
1509 Retrieve the actual results of completed tasks.
1510
1510
1511 Returns
1511 Returns
1512 -------
1512 -------
1513
1513
1514 results : dict
1514 results : dict
1515 There will always be the keys 'pending' and 'completed', which will
1515 There will always be the keys 'pending' and 'completed', which will
1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1517 is False, then completed results will be keyed by their `msg_id`.
1517 is False, then completed results will be keyed by their `msg_id`.
1518 """
1518 """
1519 if not isinstance(msg_ids, (list,tuple)):
1519 if not isinstance(msg_ids, (list,tuple)):
1520 msg_ids = [msg_ids]
1520 msg_ids = [msg_ids]
1521
1521
1522 theids = []
1522 theids = []
1523 for msg_id in msg_ids:
1523 for msg_id in msg_ids:
1524 if isinstance(msg_id, int):
1524 if isinstance(msg_id, int):
1525 msg_id = self.history[msg_id]
1525 msg_id = self.history[msg_id]
1526 if not isinstance(msg_id, string_types):
1526 if not isinstance(msg_id, string_types):
1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1528 theids.append(msg_id)
1528 theids.append(msg_id)
1529
1529
1530 completed = []
1530 completed = []
1531 local_results = {}
1531 local_results = {}
1532
1532
1533 # comment this block out to temporarily disable local shortcut:
1533 # comment this block out to temporarily disable local shortcut:
1534 for msg_id in theids:
1534 for msg_id in theids:
1535 if msg_id in self.results:
1535 if msg_id in self.results:
1536 completed.append(msg_id)
1536 completed.append(msg_id)
1537 local_results[msg_id] = self.results[msg_id]
1537 local_results[msg_id] = self.results[msg_id]
1538 theids.remove(msg_id)
1538 theids.remove(msg_id)
1539
1539
1540 if theids: # some not locally cached
1540 if theids: # some not locally cached
1541 content = dict(msg_ids=theids, status_only=status_only)
1541 content = dict(msg_ids=theids, status_only=status_only)
1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1543 zmq.select([self._query_socket], [], [])
1543 zmq.select([self._query_socket], [], [])
1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1545 if self.debug:
1545 if self.debug:
1546 pprint(msg)
1546 pprint(msg)
1547 content = msg['content']
1547 content = msg['content']
1548 if content['status'] != 'ok':
1548 if content['status'] != 'ok':
1549 raise self._unwrap_exception(content)
1549 raise self._unwrap_exception(content)
1550 buffers = msg['buffers']
1550 buffers = msg['buffers']
1551 else:
1551 else:
1552 content = dict(completed=[],pending=[])
1552 content = dict(completed=[],pending=[])
1553
1553
1554 content['completed'].extend(completed)
1554 content['completed'].extend(completed)
1555
1555
1556 if status_only:
1556 if status_only:
1557 return content
1557 return content
1558
1558
1559 failures = []
1559 failures = []
1560 # load cached results into result:
1560 # load cached results into result:
1561 content.update(local_results)
1561 content.update(local_results)
1562
1562
1563 # update cache with results:
1563 # update cache with results:
1564 for msg_id in sorted(theids):
1564 for msg_id in sorted(theids):
1565 if msg_id in content['completed']:
1565 if msg_id in content['completed']:
1566 rec = content[msg_id]
1566 rec = content[msg_id]
1567 parent = rec['header']
1567 parent = rec['header']
1568 header = rec['result_header']
1568 header = rec['result_header']
1569 rcontent = rec['result_content']
1569 rcontent = rec['result_content']
1570 iodict = rec['io']
1570 iodict = rec['io']
1571 if isinstance(rcontent, str):
1571 if isinstance(rcontent, str):
1572 rcontent = self.session.unpack(rcontent)
1572 rcontent = self.session.unpack(rcontent)
1573
1573
1574 md = self.metadata[msg_id]
1574 md = self.metadata[msg_id]
1575 md_msg = dict(
1575 md_msg = dict(
1576 content=rcontent,
1576 content=rcontent,
1577 parent_header=parent,
1577 parent_header=parent,
1578 header=header,
1578 header=header,
1579 metadata=rec['result_metadata'],
1579 metadata=rec['result_metadata'],
1580 )
1580 )
1581 md.update(self._extract_metadata(md_msg))
1581 md.update(self._extract_metadata(md_msg))
1582 if rec.get('received'):
1582 if rec.get('received'):
1583 md['received'] = rec['received']
1583 md['received'] = rec['received']
1584 md.update(iodict)
1584 md.update(iodict)
1585
1585
1586 if rcontent['status'] == 'ok':
1586 if rcontent['status'] == 'ok':
1587 if header['msg_type'] == 'apply_reply':
1587 if header['msg_type'] == 'apply_reply':
1588 res,buffers = serialize.unserialize_object(buffers)
1588 res,buffers = serialize.unserialize_object(buffers)
1589 elif header['msg_type'] == 'execute_reply':
1589 elif header['msg_type'] == 'execute_reply':
1590 res = ExecuteReply(msg_id, rcontent, md)
1590 res = ExecuteReply(msg_id, rcontent, md)
1591 else:
1591 else:
1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1593 else:
1593 else:
1594 res = self._unwrap_exception(rcontent)
1594 res = self._unwrap_exception(rcontent)
1595 failures.append(res)
1595 failures.append(res)
1596
1596
1597 self.results[msg_id] = res
1597 self.results[msg_id] = res
1598 content[msg_id] = res
1598 content[msg_id] = res
1599
1599
1600 if len(theids) == 1 and failures:
1600 if len(theids) == 1 and failures:
1601 raise failures[0]
1601 raise failures[0]
1602
1602
1603 error.collect_exceptions(failures, "result_status")
1603 error.collect_exceptions(failures, "result_status")
1604 return content
1604 return content
1605
1605
1606 @spin_first
1606 @spin_first
1607 def queue_status(self, targets='all', verbose=False):
1607 def queue_status(self, targets='all', verbose=False):
1608 """Fetch the status of engine queues.
1608 """Fetch the status of engine queues.
1609
1609
1610 Parameters
1610 Parameters
1611 ----------
1611 ----------
1612
1612
1613 targets : int/str/list of ints/strs
1613 targets : int/str/list of ints/strs
1614 the engines whose states are to be queried.
1614 the engines whose states are to be queried.
1615 default : all
1615 default : all
1616 verbose : bool
1616 verbose : bool
1617 Whether to return lengths only, or lists of ids for each element
1617 Whether to return lengths only, or lists of ids for each element
1618 """
1618 """
1619 if targets == 'all':
1619 if targets == 'all':
1620 # allow 'all' to be evaluated on the engine
1620 # allow 'all' to be evaluated on the engine
1621 engine_ids = None
1621 engine_ids = None
1622 else:
1622 else:
1623 engine_ids = self._build_targets(targets)[1]
1623 engine_ids = self._build_targets(targets)[1]
1624 content = dict(targets=engine_ids, verbose=verbose)
1624 content = dict(targets=engine_ids, verbose=verbose)
1625 self.session.send(self._query_socket, "queue_request", content=content)
1625 self.session.send(self._query_socket, "queue_request", content=content)
1626 idents,msg = self.session.recv(self._query_socket, 0)
1626 idents,msg = self.session.recv(self._query_socket, 0)
1627 if self.debug:
1627 if self.debug:
1628 pprint(msg)
1628 pprint(msg)
1629 content = msg['content']
1629 content = msg['content']
1630 status = content.pop('status')
1630 status = content.pop('status')
1631 if status != 'ok':
1631 if status != 'ok':
1632 raise self._unwrap_exception(content)
1632 raise self._unwrap_exception(content)
1633 content = rekey(content)
1633 content = rekey(content)
1634 if isinstance(targets, int):
1634 if isinstance(targets, int):
1635 return content[targets]
1635 return content[targets]
1636 else:
1636 else:
1637 return content
1637 return content
1638
1638
1639 def _build_msgids_from_target(self, targets=None):
1639 def _build_msgids_from_target(self, targets=None):
1640 """Build a list of msg_ids from the list of engine targets"""
1640 """Build a list of msg_ids from the list of engine targets"""
1641 if not targets: # needed as _build_targets otherwise uses all engines
1641 if not targets: # needed as _build_targets otherwise uses all engines
1642 return []
1642 return []
1643 target_ids = self._build_targets(targets)[0]
1643 target_ids = self._build_targets(targets)[0]
1644 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1644 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1645
1645
1646 def _build_msgids_from_jobs(self, jobs=None):
1646 def _build_msgids_from_jobs(self, jobs=None):
1647 """Build a list of msg_ids from "jobs" """
1647 """Build a list of msg_ids from "jobs" """
1648 if not jobs:
1648 if not jobs:
1649 return []
1649 return []
1650 msg_ids = []
1650 msg_ids = []
1651 if isinstance(jobs, string_types + (AsyncResult,)):
1651 if isinstance(jobs, string_types + (AsyncResult,)):
1652 jobs = [jobs]
1652 jobs = [jobs]
1653 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult)), jobs)
1653 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1654 if bad_ids:
1654 if bad_ids:
1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1656 for j in jobs:
1656 for j in jobs:
1657 if isinstance(j, AsyncResult):
1657 if isinstance(j, AsyncResult):
1658 msg_ids.extend(j.msg_ids)
1658 msg_ids.extend(j.msg_ids)
1659 else:
1659 else:
1660 msg_ids.append(j)
1660 msg_ids.append(j)
1661 return msg_ids
1661 return msg_ids
1662
1662
1663 def purge_local_results(self, jobs=[], targets=[]):
1663 def purge_local_results(self, jobs=[], targets=[]):
1664 """Clears the client caches of results and frees such memory.
1664 """Clears the client caches of results and frees such memory.
1665
1665
1666 Individual results can be purged by msg_id, or the entire
1666 Individual results can be purged by msg_id, or the entire
1667 history of specific targets can be purged.
1667 history of specific targets can be purged.
1668
1668
1669 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1669 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1670
1670
1671 The client must have no outstanding tasks before purging the caches.
1671 The client must have no outstanding tasks before purging the caches.
1672 Raises `AssertionError` if there are still outstanding tasks.
1672 Raises `AssertionError` if there are still outstanding tasks.
1673
1673
1674 After this call all `AsyncResults` are invalid and should be discarded.
1674 After this call all `AsyncResults` are invalid and should be discarded.
1675
1675
1676 If you must "reget" the results, you can still do so by using
1676 If you must "reget" the results, you can still do so by using
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 redownload the results from the hub if they are still available
1678 redownload the results from the hub if they are still available
1679 (i.e `client.purge_hub_results(...)` has not been called.
1679 (i.e `client.purge_hub_results(...)` has not been called.
1680
1680
1681 Parameters
1681 Parameters
1682 ----------
1682 ----------
1683
1683
1684 jobs : str or list of str or AsyncResult objects
1684 jobs : str or list of str or AsyncResult objects
1685 the msg_ids whose results should be purged.
1685 the msg_ids whose results should be purged.
1686 targets : int/str/list of ints/strs
1686 targets : int/str/list of ints/strs
1687 The targets, by int_id, whose entire results are to be purged.
1687 The targets, by int_id, whose entire results are to be purged.
1688
1688
1689 default : None
1689 default : None
1690 """
1690 """
1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1692
1692
1693 if not targets and not jobs:
1693 if not targets and not jobs:
1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1695
1695
1696 if jobs == 'all':
1696 if jobs == 'all':
1697 self.results.clear()
1697 self.results.clear()
1698 self.metadata.clear()
1698 self.metadata.clear()
1699 return
1699 return
1700 else:
1700 else:
1701 msg_ids = []
1701 msg_ids = []
1702 msg_ids.extend(self._build_msgids_from_target(targets))
1702 msg_ids.extend(self._build_msgids_from_target(targets))
1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1704 map(self.results.pop, msg_ids)
1704 for mid in msg_ids:
1705 map(self.metadata.pop, msg_ids)
1705 self.results.pop(mid)
1706 self.metadata.pop(mid)
1706
1707
1707
1708
1708 @spin_first
1709 @spin_first
1709 def purge_hub_results(self, jobs=[], targets=[]):
1710 def purge_hub_results(self, jobs=[], targets=[]):
1710 """Tell the Hub to forget results.
1711 """Tell the Hub to forget results.
1711
1712
1712 Individual results can be purged by msg_id, or the entire
1713 Individual results can be purged by msg_id, or the entire
1713 history of specific targets can be purged.
1714 history of specific targets can be purged.
1714
1715
1715 Use `purge_results('all')` to scrub everything from the Hub's db.
1716 Use `purge_results('all')` to scrub everything from the Hub's db.
1716
1717
1717 Parameters
1718 Parameters
1718 ----------
1719 ----------
1719
1720
1720 jobs : str or list of str or AsyncResult objects
1721 jobs : str or list of str or AsyncResult objects
1721 the msg_ids whose results should be forgotten.
1722 the msg_ids whose results should be forgotten.
1722 targets : int/str/list of ints/strs
1723 targets : int/str/list of ints/strs
1723 The targets, by int_id, whose entire history is to be purged.
1724 The targets, by int_id, whose entire history is to be purged.
1724
1725
1725 default : None
1726 default : None
1726 """
1727 """
1727 if not targets and not jobs:
1728 if not targets and not jobs:
1728 raise ValueError("Must specify at least one of `targets` and `jobs`")
1729 raise ValueError("Must specify at least one of `targets` and `jobs`")
1729 if targets:
1730 if targets:
1730 targets = self._build_targets(targets)[1]
1731 targets = self._build_targets(targets)[1]
1731
1732
1732 # construct msg_ids from jobs
1733 # construct msg_ids from jobs
1733 if jobs == 'all':
1734 if jobs == 'all':
1734 msg_ids = jobs
1735 msg_ids = jobs
1735 else:
1736 else:
1736 msg_ids = self._build_msgids_from_jobs(jobs)
1737 msg_ids = self._build_msgids_from_jobs(jobs)
1737
1738
1738 content = dict(engine_ids=targets, msg_ids=msg_ids)
1739 content = dict(engine_ids=targets, msg_ids=msg_ids)
1739 self.session.send(self._query_socket, "purge_request", content=content)
1740 self.session.send(self._query_socket, "purge_request", content=content)
1740 idents, msg = self.session.recv(self._query_socket, 0)
1741 idents, msg = self.session.recv(self._query_socket, 0)
1741 if self.debug:
1742 if self.debug:
1742 pprint(msg)
1743 pprint(msg)
1743 content = msg['content']
1744 content = msg['content']
1744 if content['status'] != 'ok':
1745 if content['status'] != 'ok':
1745 raise self._unwrap_exception(content)
1746 raise self._unwrap_exception(content)
1746
1747
1747 def purge_results(self, jobs=[], targets=[]):
1748 def purge_results(self, jobs=[], targets=[]):
1748 """Clears the cached results from both the hub and the local client
1749 """Clears the cached results from both the hub and the local client
1749
1750
1750 Individual results can be purged by msg_id, or the entire
1751 Individual results can be purged by msg_id, or the entire
1751 history of specific targets can be purged.
1752 history of specific targets can be purged.
1752
1753
1753 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1754 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1754 the Client's db.
1755 the Client's db.
1755
1756
1756 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1757 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1757 the same arguments.
1758 the same arguments.
1758
1759
1759 Parameters
1760 Parameters
1760 ----------
1761 ----------
1761
1762
1762 jobs : str or list of str or AsyncResult objects
1763 jobs : str or list of str or AsyncResult objects
1763 the msg_ids whose results should be forgotten.
1764 the msg_ids whose results should be forgotten.
1764 targets : int/str/list of ints/strs
1765 targets : int/str/list of ints/strs
1765 The targets, by int_id, whose entire history is to be purged.
1766 The targets, by int_id, whose entire history is to be purged.
1766
1767
1767 default : None
1768 default : None
1768 """
1769 """
1769 self.purge_local_results(jobs=jobs, targets=targets)
1770 self.purge_local_results(jobs=jobs, targets=targets)
1770 self.purge_hub_results(jobs=jobs, targets=targets)
1771 self.purge_hub_results(jobs=jobs, targets=targets)
1771
1772
1772 def purge_everything(self):
1773 def purge_everything(self):
1773 """Clears all content from previous Tasks from both the hub and the local client
1774 """Clears all content from previous Tasks from both the hub and the local client
1774
1775
1775 In addition to calling `purge_results("all")` it also deletes the history and
1776 In addition to calling `purge_results("all")` it also deletes the history and
1776 other bookkeeping lists.
1777 other bookkeeping lists.
1777 """
1778 """
1778 self.purge_results("all")
1779 self.purge_results("all")
1779 self.history = []
1780 self.history = []
1780 self.session.digest_history.clear()
1781 self.session.digest_history.clear()
1781
1782
1782 @spin_first
1783 @spin_first
1783 def hub_history(self):
1784 def hub_history(self):
1784 """Get the Hub's history
1785 """Get the Hub's history
1785
1786
1786 Just like the Client, the Hub has a history, which is a list of msg_ids.
1787 Just like the Client, the Hub has a history, which is a list of msg_ids.
1787 This will contain the history of all clients, and, depending on configuration,
1788 This will contain the history of all clients, and, depending on configuration,
1788 may contain history across multiple cluster sessions.
1789 may contain history across multiple cluster sessions.
1789
1790
1790 Any msg_id returned here is a valid argument to `get_result`.
1791 Any msg_id returned here is a valid argument to `get_result`.
1791
1792
1792 Returns
1793 Returns
1793 -------
1794 -------
1794
1795
1795 msg_ids : list of strs
1796 msg_ids : list of strs
1796 list of all msg_ids, ordered by task submission time.
1797 list of all msg_ids, ordered by task submission time.
1797 """
1798 """
1798
1799
1799 self.session.send(self._query_socket, "history_request", content={})
1800 self.session.send(self._query_socket, "history_request", content={})
1800 idents, msg = self.session.recv(self._query_socket, 0)
1801 idents, msg = self.session.recv(self._query_socket, 0)
1801
1802
1802 if self.debug:
1803 if self.debug:
1803 pprint(msg)
1804 pprint(msg)
1804 content = msg['content']
1805 content = msg['content']
1805 if content['status'] != 'ok':
1806 if content['status'] != 'ok':
1806 raise self._unwrap_exception(content)
1807 raise self._unwrap_exception(content)
1807 else:
1808 else:
1808 return content['history']
1809 return content['history']
1809
1810
1810 @spin_first
1811 @spin_first
1811 def db_query(self, query, keys=None):
1812 def db_query(self, query, keys=None):
1812 """Query the Hub's TaskRecord database
1813 """Query the Hub's TaskRecord database
1813
1814
1814 This will return a list of task record dicts that match `query`
1815 This will return a list of task record dicts that match `query`
1815
1816
1816 Parameters
1817 Parameters
1817 ----------
1818 ----------
1818
1819
1819 query : mongodb query dict
1820 query : mongodb query dict
1820 The search dict. See mongodb query docs for details.
1821 The search dict. See mongodb query docs for details.
1821 keys : list of strs [optional]
1822 keys : list of strs [optional]
1822 The subset of keys to be returned. The default is to fetch everything but buffers.
1823 The subset of keys to be returned. The default is to fetch everything but buffers.
1823 'msg_id' will *always* be included.
1824 'msg_id' will *always* be included.
1824 """
1825 """
1825 if isinstance(keys, string_types):
1826 if isinstance(keys, string_types):
1826 keys = [keys]
1827 keys = [keys]
1827 content = dict(query=query, keys=keys)
1828 content = dict(query=query, keys=keys)
1828 self.session.send(self._query_socket, "db_request", content=content)
1829 self.session.send(self._query_socket, "db_request", content=content)
1829 idents, msg = self.session.recv(self._query_socket, 0)
1830 idents, msg = self.session.recv(self._query_socket, 0)
1830 if self.debug:
1831 if self.debug:
1831 pprint(msg)
1832 pprint(msg)
1832 content = msg['content']
1833 content = msg['content']
1833 if content['status'] != 'ok':
1834 if content['status'] != 'ok':
1834 raise self._unwrap_exception(content)
1835 raise self._unwrap_exception(content)
1835
1836
1836 records = content['records']
1837 records = content['records']
1837
1838
1838 buffer_lens = content['buffer_lens']
1839 buffer_lens = content['buffer_lens']
1839 result_buffer_lens = content['result_buffer_lens']
1840 result_buffer_lens = content['result_buffer_lens']
1840 buffers = msg['buffers']
1841 buffers = msg['buffers']
1841 has_bufs = buffer_lens is not None
1842 has_bufs = buffer_lens is not None
1842 has_rbufs = result_buffer_lens is not None
1843 has_rbufs = result_buffer_lens is not None
1843 for i,rec in enumerate(records):
1844 for i,rec in enumerate(records):
1844 # relink buffers
1845 # relink buffers
1845 if has_bufs:
1846 if has_bufs:
1846 blen = buffer_lens[i]
1847 blen = buffer_lens[i]
1847 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1848 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1848 if has_rbufs:
1849 if has_rbufs:
1849 blen = result_buffer_lens[i]
1850 blen = result_buffer_lens[i]
1850 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1851 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1851
1852
1852 return records
1853 return records
1853
1854
1854 __all__ = [ 'Client' ]
1855 __all__ = [ 'Client' ]
@@ -1,1118 +1,1119 b''
1 """Views of remote engines.
1 """Views of remote engines.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 from __future__ import print_function
7 from __future__ import print_function
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import imp
19 import imp
20 import sys
20 import sys
21 import warnings
21 import warnings
22 from contextlib import contextmanager
22 from contextlib import contextmanager
23 from types import ModuleType
23 from types import ModuleType
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.testing.skipdoctest import skip_doctest
27 from IPython.testing.skipdoctest import skip_doctest
28 from IPython.utils.traitlets import (
28 from IPython.utils.traitlets import (
29 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
30 )
30 )
31 from IPython.external.decorator import decorator
31 from IPython.external.decorator import decorator
32
32
33 from IPython.parallel import util
33 from IPython.parallel import util
34 from IPython.parallel.controller.dependency import Dependency, dependent
34 from IPython.parallel.controller.dependency import Dependency, dependent
35 from IPython.utils.py3compat import string_types, iteritems
35 from IPython.utils.py3compat import string_types, iteritems, PY3
36
36
37 from . import map as Map
37 from . import map as Map
38 from .asyncresult import AsyncResult, AsyncMapResult
38 from .asyncresult import AsyncResult, AsyncMapResult
39 from .remotefunction import ParallelFunction, parallel, remote, getname
39 from .remotefunction import ParallelFunction, parallel, remote, getname
40
40
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42 # Decorators
42 # Decorators
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44
44
45 @decorator
45 @decorator
46 def save_ids(f, self, *args, **kwargs):
46 def save_ids(f, self, *args, **kwargs):
47 """Keep our history and outstanding attributes up to date after a method call."""
47 """Keep our history and outstanding attributes up to date after a method call."""
48 n_previous = len(self.client.history)
48 n_previous = len(self.client.history)
49 try:
49 try:
50 ret = f(self, *args, **kwargs)
50 ret = f(self, *args, **kwargs)
51 finally:
51 finally:
52 nmsgs = len(self.client.history) - n_previous
52 nmsgs = len(self.client.history) - n_previous
53 msg_ids = self.client.history[-nmsgs:]
53 msg_ids = self.client.history[-nmsgs:]
54 self.history.extend(msg_ids)
54 self.history.extend(msg_ids)
55 map(self.outstanding.add, msg_ids)
55 self.outstanding.update(msg_ids)
56 return ret
56 return ret
57
57
58 @decorator
58 @decorator
59 def sync_results(f, self, *args, **kwargs):
59 def sync_results(f, self, *args, **kwargs):
60 """sync relevant results from self.client to our results attribute."""
60 """sync relevant results from self.client to our results attribute."""
61 if self._in_sync_results:
61 if self._in_sync_results:
62 return f(self, *args, **kwargs)
62 return f(self, *args, **kwargs)
63 self._in_sync_results = True
63 self._in_sync_results = True
64 try:
64 try:
65 ret = f(self, *args, **kwargs)
65 ret = f(self, *args, **kwargs)
66 finally:
66 finally:
67 self._in_sync_results = False
67 self._in_sync_results = False
68 self._sync_results()
68 self._sync_results()
69 return ret
69 return ret
70
70
71 @decorator
71 @decorator
72 def spin_after(f, self, *args, **kwargs):
72 def spin_after(f, self, *args, **kwargs):
73 """call spin after the method."""
73 """call spin after the method."""
74 ret = f(self, *args, **kwargs)
74 ret = f(self, *args, **kwargs)
75 self.spin()
75 self.spin()
76 return ret
76 return ret
77
77
78 #-----------------------------------------------------------------------------
78 #-----------------------------------------------------------------------------
79 # Classes
79 # Classes
80 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
81
81
82 @skip_doctest
82 @skip_doctest
83 class View(HasTraits):
83 class View(HasTraits):
84 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
84 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
85
85
86 Don't use this class, use subclasses.
86 Don't use this class, use subclasses.
87
87
88 Methods
88 Methods
89 -------
89 -------
90
90
91 spin
91 spin
92 flushes incoming results and registration state changes
92 flushes incoming results and registration state changes
93 control methods spin, and requesting `ids` also ensures up to date
93 control methods spin, and requesting `ids` also ensures up to date
94
94
95 wait
95 wait
96 wait on one or more msg_ids
96 wait on one or more msg_ids
97
97
98 execution methods
98 execution methods
99 apply
99 apply
100 legacy: execute, run
100 legacy: execute, run
101
101
102 data movement
102 data movement
103 push, pull, scatter, gather
103 push, pull, scatter, gather
104
104
105 query methods
105 query methods
106 get_result, queue_status, purge_results, result_status
106 get_result, queue_status, purge_results, result_status
107
107
108 control methods
108 control methods
109 abort, shutdown
109 abort, shutdown
110
110
111 """
111 """
112 # flags
112 # flags
113 block=Bool(False)
113 block=Bool(False)
114 track=Bool(True)
114 track=Bool(True)
115 targets = Any()
115 targets = Any()
116
116
117 history=List()
117 history=List()
118 outstanding = Set()
118 outstanding = Set()
119 results = Dict()
119 results = Dict()
120 client = Instance('IPython.parallel.Client')
120 client = Instance('IPython.parallel.Client')
121
121
122 _socket = Instance('zmq.Socket')
122 _socket = Instance('zmq.Socket')
123 _flag_names = List(['targets', 'block', 'track'])
123 _flag_names = List(['targets', 'block', 'track'])
124 _in_sync_results = Bool(False)
124 _in_sync_results = Bool(False)
125 _targets = Any()
125 _targets = Any()
126 _idents = Any()
126 _idents = Any()
127
127
128 def __init__(self, client=None, socket=None, **flags):
128 def __init__(self, client=None, socket=None, **flags):
129 super(View, self).__init__(client=client, _socket=socket)
129 super(View, self).__init__(client=client, _socket=socket)
130 self.results = client.results
130 self.results = client.results
131 self.block = client.block
131 self.block = client.block
132
132
133 self.set_flags(**flags)
133 self.set_flags(**flags)
134
134
135 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
135 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
136
136
137 def __repr__(self):
137 def __repr__(self):
138 strtargets = str(self.targets)
138 strtargets = str(self.targets)
139 if len(strtargets) > 16:
139 if len(strtargets) > 16:
140 strtargets = strtargets[:12]+'...]'
140 strtargets = strtargets[:12]+'...]'
141 return "<%s %s>"%(self.__class__.__name__, strtargets)
141 return "<%s %s>"%(self.__class__.__name__, strtargets)
142
142
143 def __len__(self):
143 def __len__(self):
144 if isinstance(self.targets, list):
144 if isinstance(self.targets, list):
145 return len(self.targets)
145 return len(self.targets)
146 elif isinstance(self.targets, int):
146 elif isinstance(self.targets, int):
147 return 1
147 return 1
148 else:
148 else:
149 return len(self.client)
149 return len(self.client)
150
150
151 def set_flags(self, **kwargs):
151 def set_flags(self, **kwargs):
152 """set my attribute flags by keyword.
152 """set my attribute flags by keyword.
153
153
154 Views determine behavior with a few attributes (`block`, `track`, etc.).
154 Views determine behavior with a few attributes (`block`, `track`, etc.).
155 These attributes can be set all at once by name with this method.
155 These attributes can be set all at once by name with this method.
156
156
157 Parameters
157 Parameters
158 ----------
158 ----------
159
159
160 block : bool
160 block : bool
161 whether to wait for results
161 whether to wait for results
162 track : bool
162 track : bool
163 whether to create a MessageTracker to allow the user to
163 whether to create a MessageTracker to allow the user to
164 safely edit after arrays and buffers during non-copying
164 safely edit after arrays and buffers during non-copying
165 sends.
165 sends.
166 """
166 """
167 for name, value in iteritems(kwargs):
167 for name, value in iteritems(kwargs):
168 if name not in self._flag_names:
168 if name not in self._flag_names:
169 raise KeyError("Invalid name: %r"%name)
169 raise KeyError("Invalid name: %r"%name)
170 else:
170 else:
171 setattr(self, name, value)
171 setattr(self, name, value)
172
172
173 @contextmanager
173 @contextmanager
174 def temp_flags(self, **kwargs):
174 def temp_flags(self, **kwargs):
175 """temporarily set flags, for use in `with` statements.
175 """temporarily set flags, for use in `with` statements.
176
176
177 See set_flags for permanent setting of flags
177 See set_flags for permanent setting of flags
178
178
179 Examples
179 Examples
180 --------
180 --------
181
181
182 >>> view.track=False
182 >>> view.track=False
183 ...
183 ...
184 >>> with view.temp_flags(track=True):
184 >>> with view.temp_flags(track=True):
185 ... ar = view.apply(dostuff, my_big_array)
185 ... ar = view.apply(dostuff, my_big_array)
186 ... ar.tracker.wait() # wait for send to finish
186 ... ar.tracker.wait() # wait for send to finish
187 >>> view.track
187 >>> view.track
188 False
188 False
189
189
190 """
190 """
191 # preflight: save flags, and set temporaries
191 # preflight: save flags, and set temporaries
192 saved_flags = {}
192 saved_flags = {}
193 for f in self._flag_names:
193 for f in self._flag_names:
194 saved_flags[f] = getattr(self, f)
194 saved_flags[f] = getattr(self, f)
195 self.set_flags(**kwargs)
195 self.set_flags(**kwargs)
196 # yield to the with-statement block
196 # yield to the with-statement block
197 try:
197 try:
198 yield
198 yield
199 finally:
199 finally:
200 # postflight: restore saved flags
200 # postflight: restore saved flags
201 self.set_flags(**saved_flags)
201 self.set_flags(**saved_flags)
202
202
203
203
204 #----------------------------------------------------------------
204 #----------------------------------------------------------------
205 # apply
205 # apply
206 #----------------------------------------------------------------
206 #----------------------------------------------------------------
207
207
208 def _sync_results(self):
208 def _sync_results(self):
209 """to be called by @sync_results decorator
209 """to be called by @sync_results decorator
210
210
211 after submitting any tasks.
211 after submitting any tasks.
212 """
212 """
213 delta = self.outstanding.difference(self.client.outstanding)
213 delta = self.outstanding.difference(self.client.outstanding)
214 completed = self.outstanding.intersection(delta)
214 completed = self.outstanding.intersection(delta)
215 self.outstanding = self.outstanding.difference(completed)
215 self.outstanding = self.outstanding.difference(completed)
216
216
217 @sync_results
217 @sync_results
218 @save_ids
218 @save_ids
219 def _really_apply(self, f, args, kwargs, block=None, **options):
219 def _really_apply(self, f, args, kwargs, block=None, **options):
220 """wrapper for client.send_apply_request"""
220 """wrapper for client.send_apply_request"""
221 raise NotImplementedError("Implement in subclasses")
221 raise NotImplementedError("Implement in subclasses")
222
222
223 def apply(self, f, *args, **kwargs):
223 def apply(self, f, *args, **kwargs):
224 """calls f(*args, **kwargs) on remote engines, returning the result.
224 """calls f(*args, **kwargs) on remote engines, returning the result.
225
225
226 This method sets all apply flags via this View's attributes.
226 This method sets all apply flags via this View's attributes.
227
227
228 if self.block is False:
228 if self.block is False:
229 returns AsyncResult
229 returns AsyncResult
230 else:
230 else:
231 returns actual result of f(*args, **kwargs)
231 returns actual result of f(*args, **kwargs)
232 """
232 """
233 return self._really_apply(f, args, kwargs)
233 return self._really_apply(f, args, kwargs)
234
234
235 def apply_async(self, f, *args, **kwargs):
235 def apply_async(self, f, *args, **kwargs):
236 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
236 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
237
237
238 returns AsyncResult
238 returns AsyncResult
239 """
239 """
240 return self._really_apply(f, args, kwargs, block=False)
240 return self._really_apply(f, args, kwargs, block=False)
241
241
242 @spin_after
242 @spin_after
243 def apply_sync(self, f, *args, **kwargs):
243 def apply_sync(self, f, *args, **kwargs):
244 """calls f(*args, **kwargs) on remote engines in a blocking manner,
244 """calls f(*args, **kwargs) on remote engines in a blocking manner,
245 returning the result.
245 returning the result.
246
246
247 returns: actual result of f(*args, **kwargs)
247 returns: actual result of f(*args, **kwargs)
248 """
248 """
249 return self._really_apply(f, args, kwargs, block=True)
249 return self._really_apply(f, args, kwargs, block=True)
250
250
251 #----------------------------------------------------------------
251 #----------------------------------------------------------------
252 # wrappers for client and control methods
252 # wrappers for client and control methods
253 #----------------------------------------------------------------
253 #----------------------------------------------------------------
254 @sync_results
254 @sync_results
255 def spin(self):
255 def spin(self):
256 """spin the client, and sync"""
256 """spin the client, and sync"""
257 self.client.spin()
257 self.client.spin()
258
258
259 @sync_results
259 @sync_results
260 def wait(self, jobs=None, timeout=-1):
260 def wait(self, jobs=None, timeout=-1):
261 """waits on one or more `jobs`, for up to `timeout` seconds.
261 """waits on one or more `jobs`, for up to `timeout` seconds.
262
262
263 Parameters
263 Parameters
264 ----------
264 ----------
265
265
266 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
266 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
267 ints are indices to self.history
267 ints are indices to self.history
268 strs are msg_ids
268 strs are msg_ids
269 default: wait on all outstanding messages
269 default: wait on all outstanding messages
270 timeout : float
270 timeout : float
271 a time in seconds, after which to give up.
271 a time in seconds, after which to give up.
272 default is -1, which means no timeout
272 default is -1, which means no timeout
273
273
274 Returns
274 Returns
275 -------
275 -------
276
276
277 True : when all msg_ids are done
277 True : when all msg_ids are done
278 False : timeout reached, some msg_ids still outstanding
278 False : timeout reached, some msg_ids still outstanding
279 """
279 """
280 if jobs is None:
280 if jobs is None:
281 jobs = self.history
281 jobs = self.history
282 return self.client.wait(jobs, timeout)
282 return self.client.wait(jobs, timeout)
283
283
284 def abort(self, jobs=None, targets=None, block=None):
284 def abort(self, jobs=None, targets=None, block=None):
285 """Abort jobs on my engines.
285 """Abort jobs on my engines.
286
286
287 Parameters
287 Parameters
288 ----------
288 ----------
289
289
290 jobs : None, str, list of strs, optional
290 jobs : None, str, list of strs, optional
291 if None: abort all jobs.
291 if None: abort all jobs.
292 else: abort specific msg_id(s).
292 else: abort specific msg_id(s).
293 """
293 """
294 block = block if block is not None else self.block
294 block = block if block is not None else self.block
295 targets = targets if targets is not None else self.targets
295 targets = targets if targets is not None else self.targets
296 jobs = jobs if jobs is not None else list(self.outstanding)
296 jobs = jobs if jobs is not None else list(self.outstanding)
297
297
298 return self.client.abort(jobs=jobs, targets=targets, block=block)
298 return self.client.abort(jobs=jobs, targets=targets, block=block)
299
299
300 def queue_status(self, targets=None, verbose=False):
300 def queue_status(self, targets=None, verbose=False):
301 """Fetch the Queue status of my engines"""
301 """Fetch the Queue status of my engines"""
302 targets = targets if targets is not None else self.targets
302 targets = targets if targets is not None else self.targets
303 return self.client.queue_status(targets=targets, verbose=verbose)
303 return self.client.queue_status(targets=targets, verbose=verbose)
304
304
305 def purge_results(self, jobs=[], targets=[]):
305 def purge_results(self, jobs=[], targets=[]):
306 """Instruct the controller to forget specific results."""
306 """Instruct the controller to forget specific results."""
307 if targets is None or targets == 'all':
307 if targets is None or targets == 'all':
308 targets = self.targets
308 targets = self.targets
309 return self.client.purge_results(jobs=jobs, targets=targets)
309 return self.client.purge_results(jobs=jobs, targets=targets)
310
310
311 def shutdown(self, targets=None, restart=False, hub=False, block=None):
311 def shutdown(self, targets=None, restart=False, hub=False, block=None):
312 """Terminates one or more engine processes, optionally including the hub.
312 """Terminates one or more engine processes, optionally including the hub.
313 """
313 """
314 block = self.block if block is None else block
314 block = self.block if block is None else block
315 if targets is None or targets == 'all':
315 if targets is None or targets == 'all':
316 targets = self.targets
316 targets = self.targets
317 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
317 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
318
318
319 @spin_after
319 @spin_after
320 def get_result(self, indices_or_msg_ids=None):
320 def get_result(self, indices_or_msg_ids=None):
321 """return one or more results, specified by history index or msg_id.
321 """return one or more results, specified by history index or msg_id.
322
322
323 See client.get_result for details.
323 See client.get_result for details.
324
324
325 """
325 """
326
326
327 if indices_or_msg_ids is None:
327 if indices_or_msg_ids is None:
328 indices_or_msg_ids = -1
328 indices_or_msg_ids = -1
329 if isinstance(indices_or_msg_ids, int):
329 if isinstance(indices_or_msg_ids, int):
330 indices_or_msg_ids = self.history[indices_or_msg_ids]
330 indices_or_msg_ids = self.history[indices_or_msg_ids]
331 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
331 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
332 indices_or_msg_ids = list(indices_or_msg_ids)
332 indices_or_msg_ids = list(indices_or_msg_ids)
333 for i,index in enumerate(indices_or_msg_ids):
333 for i,index in enumerate(indices_or_msg_ids):
334 if isinstance(index, int):
334 if isinstance(index, int):
335 indices_or_msg_ids[i] = self.history[index]
335 indices_or_msg_ids[i] = self.history[index]
336 return self.client.get_result(indices_or_msg_ids)
336 return self.client.get_result(indices_or_msg_ids)
337
337
338 #-------------------------------------------------------------------
338 #-------------------------------------------------------------------
339 # Map
339 # Map
340 #-------------------------------------------------------------------
340 #-------------------------------------------------------------------
341
341
342 @sync_results
342 @sync_results
343 def map(self, f, *sequences, **kwargs):
343 def map(self, f, *sequences, **kwargs):
344 """override in subclasses"""
344 """override in subclasses"""
345 raise NotImplementedError
345 raise NotImplementedError
346
346
347 def map_async(self, f, *sequences, **kwargs):
347 def map_async(self, f, *sequences, **kwargs):
348 """Parallel version of builtin `map`, using this view's engines.
348 """Parallel version of builtin `map`, using this view's engines.
349
349
350 This is equivalent to map(...block=False)
350 This is equivalent to map(...block=False)
351
351
352 See `self.map` for details.
352 See `self.map` for details.
353 """
353 """
354 if 'block' in kwargs:
354 if 'block' in kwargs:
355 raise TypeError("map_async doesn't take a `block` keyword argument.")
355 raise TypeError("map_async doesn't take a `block` keyword argument.")
356 kwargs['block'] = False
356 kwargs['block'] = False
357 return self.map(f,*sequences,**kwargs)
357 return self.map(f,*sequences,**kwargs)
358
358
359 def map_sync(self, f, *sequences, **kwargs):
359 def map_sync(self, f, *sequences, **kwargs):
360 """Parallel version of builtin `map`, using this view's engines.
360 """Parallel version of builtin `map`, using this view's engines.
361
361
362 This is equivalent to map(...block=True)
362 This is equivalent to map(...block=True)
363
363
364 See `self.map` for details.
364 See `self.map` for details.
365 """
365 """
366 if 'block' in kwargs:
366 if 'block' in kwargs:
367 raise TypeError("map_sync doesn't take a `block` keyword argument.")
367 raise TypeError("map_sync doesn't take a `block` keyword argument.")
368 kwargs['block'] = True
368 kwargs['block'] = True
369 return self.map(f,*sequences,**kwargs)
369 return self.map(f,*sequences,**kwargs)
370
370
371 def imap(self, f, *sequences, **kwargs):
371 def imap(self, f, *sequences, **kwargs):
372 """Parallel version of `itertools.imap`.
372 """Parallel version of `itertools.imap`.
373
373
374 See `self.map` for details.
374 See `self.map` for details.
375
375
376 """
376 """
377
377
378 return iter(self.map_async(f,*sequences, **kwargs))
378 return iter(self.map_async(f,*sequences, **kwargs))
379
379
380 #-------------------------------------------------------------------
380 #-------------------------------------------------------------------
381 # Decorators
381 # Decorators
382 #-------------------------------------------------------------------
382 #-------------------------------------------------------------------
383
383
384 def remote(self, block=None, **flags):
384 def remote(self, block=None, **flags):
385 """Decorator for making a RemoteFunction"""
385 """Decorator for making a RemoteFunction"""
386 block = self.block if block is None else block
386 block = self.block if block is None else block
387 return remote(self, block=block, **flags)
387 return remote(self, block=block, **flags)
388
388
389 def parallel(self, dist='b', block=None, **flags):
389 def parallel(self, dist='b', block=None, **flags):
390 """Decorator for making a ParallelFunction"""
390 """Decorator for making a ParallelFunction"""
391 block = self.block if block is None else block
391 block = self.block if block is None else block
392 return parallel(self, dist=dist, block=block, **flags)
392 return parallel(self, dist=dist, block=block, **flags)
393
393
394 @skip_doctest
394 @skip_doctest
395 class DirectView(View):
395 class DirectView(View):
396 """Direct Multiplexer View of one or more engines.
396 """Direct Multiplexer View of one or more engines.
397
397
398 These are created via indexed access to a client:
398 These are created via indexed access to a client:
399
399
400 >>> dv_1 = client[1]
400 >>> dv_1 = client[1]
401 >>> dv_all = client[:]
401 >>> dv_all = client[:]
402 >>> dv_even = client[::2]
402 >>> dv_even = client[::2]
403 >>> dv_some = client[1:3]
403 >>> dv_some = client[1:3]
404
404
405 This object provides dictionary access to engine namespaces:
405 This object provides dictionary access to engine namespaces:
406
406
407 # push a=5:
407 # push a=5:
408 >>> dv['a'] = 5
408 >>> dv['a'] = 5
409 # pull 'foo':
409 # pull 'foo':
410 >>> db['foo']
410 >>> db['foo']
411
411
412 """
412 """
413
413
414 def __init__(self, client=None, socket=None, targets=None):
414 def __init__(self, client=None, socket=None, targets=None):
415 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
415 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
416
416
417 @property
417 @property
418 def importer(self):
418 def importer(self):
419 """sync_imports(local=True) as a property.
419 """sync_imports(local=True) as a property.
420
420
421 See sync_imports for details.
421 See sync_imports for details.
422
422
423 """
423 """
424 return self.sync_imports(True)
424 return self.sync_imports(True)
425
425
426 @contextmanager
426 @contextmanager
427 def sync_imports(self, local=True, quiet=False):
427 def sync_imports(self, local=True, quiet=False):
428 """Context Manager for performing simultaneous local and remote imports.
428 """Context Manager for performing simultaneous local and remote imports.
429
429
430 'import x as y' will *not* work. The 'as y' part will simply be ignored.
430 'import x as y' will *not* work. The 'as y' part will simply be ignored.
431
431
432 If `local=True`, then the package will also be imported locally.
432 If `local=True`, then the package will also be imported locally.
433
433
434 If `quiet=True`, no output will be produced when attempting remote
434 If `quiet=True`, no output will be produced when attempting remote
435 imports.
435 imports.
436
436
437 Note that remote-only (`local=False`) imports have not been implemented.
437 Note that remote-only (`local=False`) imports have not been implemented.
438
438
439 >>> with view.sync_imports():
439 >>> with view.sync_imports():
440 ... from numpy import recarray
440 ... from numpy import recarray
441 importing recarray from numpy on engine(s)
441 importing recarray from numpy on engine(s)
442
442
443 """
443 """
444 from IPython.utils.py3compat import builtin_mod
444 from IPython.utils.py3compat import builtin_mod
445 local_import = builtin_mod.__import__
445 local_import = builtin_mod.__import__
446 modules = set()
446 modules = set()
447 results = []
447 results = []
448 @util.interactive
448 @util.interactive
449 def remote_import(name, fromlist, level):
449 def remote_import(name, fromlist, level):
450 """the function to be passed to apply, that actually performs the import
450 """the function to be passed to apply, that actually performs the import
451 on the engine, and loads up the user namespace.
451 on the engine, and loads up the user namespace.
452 """
452 """
453 import sys
453 import sys
454 user_ns = globals()
454 user_ns = globals()
455 mod = __import__(name, fromlist=fromlist, level=level)
455 mod = __import__(name, fromlist=fromlist, level=level)
456 if fromlist:
456 if fromlist:
457 for key in fromlist:
457 for key in fromlist:
458 user_ns[key] = getattr(mod, key)
458 user_ns[key] = getattr(mod, key)
459 else:
459 else:
460 user_ns[name] = sys.modules[name]
460 user_ns[name] = sys.modules[name]
461
461
462 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
462 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
463 """the drop-in replacement for __import__, that optionally imports
463 """the drop-in replacement for __import__, that optionally imports
464 locally as well.
464 locally as well.
465 """
465 """
466 # don't override nested imports
466 # don't override nested imports
467 save_import = builtin_mod.__import__
467 save_import = builtin_mod.__import__
468 builtin_mod.__import__ = local_import
468 builtin_mod.__import__ = local_import
469
469
470 if imp.lock_held():
470 if imp.lock_held():
471 # this is a side-effect import, don't do it remotely, or even
471 # this is a side-effect import, don't do it remotely, or even
472 # ignore the local effects
472 # ignore the local effects
473 return local_import(name, globals, locals, fromlist, level)
473 return local_import(name, globals, locals, fromlist, level)
474
474
475 imp.acquire_lock()
475 imp.acquire_lock()
476 if local:
476 if local:
477 mod = local_import(name, globals, locals, fromlist, level)
477 mod = local_import(name, globals, locals, fromlist, level)
478 else:
478 else:
479 raise NotImplementedError("remote-only imports not yet implemented")
479 raise NotImplementedError("remote-only imports not yet implemented")
480 imp.release_lock()
480 imp.release_lock()
481
481
482 key = name+':'+','.join(fromlist or [])
482 key = name+':'+','.join(fromlist or [])
483 if level <= 0 and key not in modules:
483 if level <= 0 and key not in modules:
484 modules.add(key)
484 modules.add(key)
485 if not quiet:
485 if not quiet:
486 if fromlist:
486 if fromlist:
487 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
487 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
488 else:
488 else:
489 print("importing %s on engine(s)"%name)
489 print("importing %s on engine(s)"%name)
490 results.append(self.apply_async(remote_import, name, fromlist, level))
490 results.append(self.apply_async(remote_import, name, fromlist, level))
491 # restore override
491 # restore override
492 builtin_mod.__import__ = save_import
492 builtin_mod.__import__ = save_import
493
493
494 return mod
494 return mod
495
495
496 # override __import__
496 # override __import__
497 builtin_mod.__import__ = view_import
497 builtin_mod.__import__ = view_import
498 try:
498 try:
499 # enter the block
499 # enter the block
500 yield
500 yield
501 except ImportError:
501 except ImportError:
502 if local:
502 if local:
503 raise
503 raise
504 else:
504 else:
505 # ignore import errors if not doing local imports
505 # ignore import errors if not doing local imports
506 pass
506 pass
507 finally:
507 finally:
508 # always restore __import__
508 # always restore __import__
509 builtin_mod.__import__ = local_import
509 builtin_mod.__import__ = local_import
510
510
511 for r in results:
511 for r in results:
512 # raise possible remote ImportErrors here
512 # raise possible remote ImportErrors here
513 r.get()
513 r.get()
514
514
515
515
516 @sync_results
516 @sync_results
517 @save_ids
517 @save_ids
518 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
518 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
519 """calls f(*args, **kwargs) on remote engines, returning the result.
519 """calls f(*args, **kwargs) on remote engines, returning the result.
520
520
521 This method sets all of `apply`'s flags via this View's attributes.
521 This method sets all of `apply`'s flags via this View's attributes.
522
522
523 Parameters
523 Parameters
524 ----------
524 ----------
525
525
526 f : callable
526 f : callable
527
527
528 args : list [default: empty]
528 args : list [default: empty]
529
529
530 kwargs : dict [default: empty]
530 kwargs : dict [default: empty]
531
531
532 targets : target list [default: self.targets]
532 targets : target list [default: self.targets]
533 where to run
533 where to run
534 block : bool [default: self.block]
534 block : bool [default: self.block]
535 whether to block
535 whether to block
536 track : bool [default: self.track]
536 track : bool [default: self.track]
537 whether to ask zmq to track the message, for safe non-copying sends
537 whether to ask zmq to track the message, for safe non-copying sends
538
538
539 Returns
539 Returns
540 -------
540 -------
541
541
542 if self.block is False:
542 if self.block is False:
543 returns AsyncResult
543 returns AsyncResult
544 else:
544 else:
545 returns actual result of f(*args, **kwargs) on the engine(s)
545 returns actual result of f(*args, **kwargs) on the engine(s)
546 This will be a list of self.targets is also a list (even length 1), or
546 This will be a list of self.targets is also a list (even length 1), or
547 the single result if self.targets is an integer engine id
547 the single result if self.targets is an integer engine id
548 """
548 """
549 args = [] if args is None else args
549 args = [] if args is None else args
550 kwargs = {} if kwargs is None else kwargs
550 kwargs = {} if kwargs is None else kwargs
551 block = self.block if block is None else block
551 block = self.block if block is None else block
552 track = self.track if track is None else track
552 track = self.track if track is None else track
553 targets = self.targets if targets is None else targets
553 targets = self.targets if targets is None else targets
554
554
555 _idents, _targets = self.client._build_targets(targets)
555 _idents, _targets = self.client._build_targets(targets)
556 msg_ids = []
556 msg_ids = []
557 trackers = []
557 trackers = []
558 for ident in _idents:
558 for ident in _idents:
559 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
559 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
560 ident=ident)
560 ident=ident)
561 if track:
561 if track:
562 trackers.append(msg['tracker'])
562 trackers.append(msg['tracker'])
563 msg_ids.append(msg['header']['msg_id'])
563 msg_ids.append(msg['header']['msg_id'])
564 if isinstance(targets, int):
564 if isinstance(targets, int):
565 msg_ids = msg_ids[0]
565 msg_ids = msg_ids[0]
566 tracker = None if track is False else zmq.MessageTracker(*trackers)
566 tracker = None if track is False else zmq.MessageTracker(*trackers)
567 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
567 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
568 if block:
568 if block:
569 try:
569 try:
570 return ar.get()
570 return ar.get()
571 except KeyboardInterrupt:
571 except KeyboardInterrupt:
572 pass
572 pass
573 return ar
573 return ar
574
574
575
575
576 @sync_results
576 @sync_results
577 def map(self, f, *sequences, **kwargs):
577 def map(self, f, *sequences, **kwargs):
578 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
578 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
579
579
580 Parallel version of builtin `map`, using this View's `targets`.
580 Parallel version of builtin `map`, using this View's `targets`.
581
581
582 There will be one task per target, so work will be chunked
582 There will be one task per target, so work will be chunked
583 if the sequences are longer than `targets`.
583 if the sequences are longer than `targets`.
584
584
585 Results can be iterated as they are ready, but will become available in chunks.
585 Results can be iterated as they are ready, but will become available in chunks.
586
586
587 Parameters
587 Parameters
588 ----------
588 ----------
589
589
590 f : callable
590 f : callable
591 function to be mapped
591 function to be mapped
592 *sequences: one or more sequences of matching length
592 *sequences: one or more sequences of matching length
593 the sequences to be distributed and passed to `f`
593 the sequences to be distributed and passed to `f`
594 block : bool
594 block : bool
595 whether to wait for the result or not [default self.block]
595 whether to wait for the result or not [default self.block]
596
596
597 Returns
597 Returns
598 -------
598 -------
599
599
600 if block=False:
600 if block=False:
601 AsyncMapResult
601 AsyncMapResult
602 An object like AsyncResult, but which reassembles the sequence of results
602 An object like AsyncResult, but which reassembles the sequence of results
603 into a single list. AsyncMapResults can be iterated through before all
603 into a single list. AsyncMapResults can be iterated through before all
604 results are complete.
604 results are complete.
605 else:
605 else:
606 list
606 list
607 the result of map(f,*sequences)
607 the result of map(f,*sequences)
608 """
608 """
609
609
610 block = kwargs.pop('block', self.block)
610 block = kwargs.pop('block', self.block)
611 for k in kwargs.keys():
611 for k in kwargs.keys():
612 if k not in ['block', 'track']:
612 if k not in ['block', 'track']:
613 raise TypeError("invalid keyword arg, %r"%k)
613 raise TypeError("invalid keyword arg, %r"%k)
614
614
615 assert len(sequences) > 0, "must have some sequences to map onto!"
615 assert len(sequences) > 0, "must have some sequences to map onto!"
616 pf = ParallelFunction(self, f, block=block, **kwargs)
616 pf = ParallelFunction(self, f, block=block, **kwargs)
617 return pf.map(*sequences)
617 return pf.map(*sequences)
618
618
619 @sync_results
619 @sync_results
620 @save_ids
620 @save_ids
621 def execute(self, code, silent=True, targets=None, block=None):
621 def execute(self, code, silent=True, targets=None, block=None):
622 """Executes `code` on `targets` in blocking or nonblocking manner.
622 """Executes `code` on `targets` in blocking or nonblocking manner.
623
623
624 ``execute`` is always `bound` (affects engine namespace)
624 ``execute`` is always `bound` (affects engine namespace)
625
625
626 Parameters
626 Parameters
627 ----------
627 ----------
628
628
629 code : str
629 code : str
630 the code string to be executed
630 the code string to be executed
631 block : bool
631 block : bool
632 whether or not to wait until done to return
632 whether or not to wait until done to return
633 default: self.block
633 default: self.block
634 """
634 """
635 block = self.block if block is None else block
635 block = self.block if block is None else block
636 targets = self.targets if targets is None else targets
636 targets = self.targets if targets is None else targets
637
637
638 _idents, _targets = self.client._build_targets(targets)
638 _idents, _targets = self.client._build_targets(targets)
639 msg_ids = []
639 msg_ids = []
640 trackers = []
640 trackers = []
641 for ident in _idents:
641 for ident in _idents:
642 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
642 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
643 msg_ids.append(msg['header']['msg_id'])
643 msg_ids.append(msg['header']['msg_id'])
644 if isinstance(targets, int):
644 if isinstance(targets, int):
645 msg_ids = msg_ids[0]
645 msg_ids = msg_ids[0]
646 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
646 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
647 if block:
647 if block:
648 try:
648 try:
649 ar.get()
649 ar.get()
650 except KeyboardInterrupt:
650 except KeyboardInterrupt:
651 pass
651 pass
652 return ar
652 return ar
653
653
654 def run(self, filename, targets=None, block=None):
654 def run(self, filename, targets=None, block=None):
655 """Execute contents of `filename` on my engine(s).
655 """Execute contents of `filename` on my engine(s).
656
656
657 This simply reads the contents of the file and calls `execute`.
657 This simply reads the contents of the file and calls `execute`.
658
658
659 Parameters
659 Parameters
660 ----------
660 ----------
661
661
662 filename : str
662 filename : str
663 The path to the file
663 The path to the file
664 targets : int/str/list of ints/strs
664 targets : int/str/list of ints/strs
665 the engines on which to execute
665 the engines on which to execute
666 default : all
666 default : all
667 block : bool
667 block : bool
668 whether or not to wait until done
668 whether or not to wait until done
669 default: self.block
669 default: self.block
670
670
671 """
671 """
672 with open(filename, 'r') as f:
672 with open(filename, 'r') as f:
673 # add newline in case of trailing indented whitespace
673 # add newline in case of trailing indented whitespace
674 # which will cause SyntaxError
674 # which will cause SyntaxError
675 code = f.read()+'\n'
675 code = f.read()+'\n'
676 return self.execute(code, block=block, targets=targets)
676 return self.execute(code, block=block, targets=targets)
677
677
678 def update(self, ns):
678 def update(self, ns):
679 """update remote namespace with dict `ns`
679 """update remote namespace with dict `ns`
680
680
681 See `push` for details.
681 See `push` for details.
682 """
682 """
683 return self.push(ns, block=self.block, track=self.track)
683 return self.push(ns, block=self.block, track=self.track)
684
684
685 def push(self, ns, targets=None, block=None, track=None):
685 def push(self, ns, targets=None, block=None, track=None):
686 """update remote namespace with dict `ns`
686 """update remote namespace with dict `ns`
687
687
688 Parameters
688 Parameters
689 ----------
689 ----------
690
690
691 ns : dict
691 ns : dict
692 dict of keys with which to update engine namespace(s)
692 dict of keys with which to update engine namespace(s)
693 block : bool [default : self.block]
693 block : bool [default : self.block]
694 whether to wait to be notified of engine receipt
694 whether to wait to be notified of engine receipt
695
695
696 """
696 """
697
697
698 block = block if block is not None else self.block
698 block = block if block is not None else self.block
699 track = track if track is not None else self.track
699 track = track if track is not None else self.track
700 targets = targets if targets is not None else self.targets
700 targets = targets if targets is not None else self.targets
701 # applier = self.apply_sync if block else self.apply_async
701 # applier = self.apply_sync if block else self.apply_async
702 if not isinstance(ns, dict):
702 if not isinstance(ns, dict):
703 raise TypeError("Must be a dict, not %s"%type(ns))
703 raise TypeError("Must be a dict, not %s"%type(ns))
704 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
704 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
705
705
706 def get(self, key_s):
706 def get(self, key_s):
707 """get object(s) by `key_s` from remote namespace
707 """get object(s) by `key_s` from remote namespace
708
708
709 see `pull` for details.
709 see `pull` for details.
710 """
710 """
711 # block = block if block is not None else self.block
711 # block = block if block is not None else self.block
712 return self.pull(key_s, block=True)
712 return self.pull(key_s, block=True)
713
713
714 def pull(self, names, targets=None, block=None):
714 def pull(self, names, targets=None, block=None):
715 """get object(s) by `name` from remote namespace
715 """get object(s) by `name` from remote namespace
716
716
717 will return one object if it is a key.
717 will return one object if it is a key.
718 can also take a list of keys, in which case it will return a list of objects.
718 can also take a list of keys, in which case it will return a list of objects.
719 """
719 """
720 block = block if block is not None else self.block
720 block = block if block is not None else self.block
721 targets = targets if targets is not None else self.targets
721 targets = targets if targets is not None else self.targets
722 applier = self.apply_sync if block else self.apply_async
722 applier = self.apply_sync if block else self.apply_async
723 if isinstance(names, string_types):
723 if isinstance(names, string_types):
724 pass
724 pass
725 elif isinstance(names, (list,tuple,set)):
725 elif isinstance(names, (list,tuple,set)):
726 for key in names:
726 for key in names:
727 if not isinstance(key, string_types):
727 if not isinstance(key, string_types):
728 raise TypeError("keys must be str, not type %r"%type(key))
728 raise TypeError("keys must be str, not type %r"%type(key))
729 else:
729 else:
730 raise TypeError("names must be strs, not %r"%names)
730 raise TypeError("names must be strs, not %r"%names)
731 return self._really_apply(util._pull, (names,), block=block, targets=targets)
731 return self._really_apply(util._pull, (names,), block=block, targets=targets)
732
732
733 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
733 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
734 """
734 """
735 Partition a Python sequence and send the partitions to a set of engines.
735 Partition a Python sequence and send the partitions to a set of engines.
736 """
736 """
737 block = block if block is not None else self.block
737 block = block if block is not None else self.block
738 track = track if track is not None else self.track
738 track = track if track is not None else self.track
739 targets = targets if targets is not None else self.targets
739 targets = targets if targets is not None else self.targets
740
740
741 # construct integer ID list:
741 # construct integer ID list:
742 targets = self.client._build_targets(targets)[1]
742 targets = self.client._build_targets(targets)[1]
743
743
744 mapObject = Map.dists[dist]()
744 mapObject = Map.dists[dist]()
745 nparts = len(targets)
745 nparts = len(targets)
746 msg_ids = []
746 msg_ids = []
747 trackers = []
747 trackers = []
748 for index, engineid in enumerate(targets):
748 for index, engineid in enumerate(targets):
749 partition = mapObject.getPartition(seq, index, nparts)
749 partition = mapObject.getPartition(seq, index, nparts)
750 if flatten and len(partition) == 1:
750 if flatten and len(partition) == 1:
751 ns = {key: partition[0]}
751 ns = {key: partition[0]}
752 else:
752 else:
753 ns = {key: partition}
753 ns = {key: partition}
754 r = self.push(ns, block=False, track=track, targets=engineid)
754 r = self.push(ns, block=False, track=track, targets=engineid)
755 msg_ids.extend(r.msg_ids)
755 msg_ids.extend(r.msg_ids)
756 if track:
756 if track:
757 trackers.append(r._tracker)
757 trackers.append(r._tracker)
758
758
759 if track:
759 if track:
760 tracker = zmq.MessageTracker(*trackers)
760 tracker = zmq.MessageTracker(*trackers)
761 else:
761 else:
762 tracker = None
762 tracker = None
763
763
764 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
764 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
765 if block:
765 if block:
766 r.wait()
766 r.wait()
767 else:
767 else:
768 return r
768 return r
769
769
770 @sync_results
770 @sync_results
771 @save_ids
771 @save_ids
772 def gather(self, key, dist='b', targets=None, block=None):
772 def gather(self, key, dist='b', targets=None, block=None):
773 """
773 """
774 Gather a partitioned sequence on a set of engines as a single local seq.
774 Gather a partitioned sequence on a set of engines as a single local seq.
775 """
775 """
776 block = block if block is not None else self.block
776 block = block if block is not None else self.block
777 targets = targets if targets is not None else self.targets
777 targets = targets if targets is not None else self.targets
778 mapObject = Map.dists[dist]()
778 mapObject = Map.dists[dist]()
779 msg_ids = []
779 msg_ids = []
780
780
781 # construct integer ID list:
781 # construct integer ID list:
782 targets = self.client._build_targets(targets)[1]
782 targets = self.client._build_targets(targets)[1]
783
783
784 for index, engineid in enumerate(targets):
784 for index, engineid in enumerate(targets):
785 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
785 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
786
786
787 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
787 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
788
788
789 if block:
789 if block:
790 try:
790 try:
791 return r.get()
791 return r.get()
792 except KeyboardInterrupt:
792 except KeyboardInterrupt:
793 pass
793 pass
794 return r
794 return r
795
795
796 def __getitem__(self, key):
796 def __getitem__(self, key):
797 return self.get(key)
797 return self.get(key)
798
798
799 def __setitem__(self,key, value):
799 def __setitem__(self,key, value):
800 self.update({key:value})
800 self.update({key:value})
801
801
802 def clear(self, targets=None, block=None):
802 def clear(self, targets=None, block=None):
803 """Clear the remote namespaces on my engines."""
803 """Clear the remote namespaces on my engines."""
804 block = block if block is not None else self.block
804 block = block if block is not None else self.block
805 targets = targets if targets is not None else self.targets
805 targets = targets if targets is not None else self.targets
806 return self.client.clear(targets=targets, block=block)
806 return self.client.clear(targets=targets, block=block)
807
807
808 #----------------------------------------
808 #----------------------------------------
809 # activate for %px, %autopx, etc. magics
809 # activate for %px, %autopx, etc. magics
810 #----------------------------------------
810 #----------------------------------------
811
811
812 def activate(self, suffix=''):
812 def activate(self, suffix=''):
813 """Activate IPython magics associated with this View
813 """Activate IPython magics associated with this View
814
814
815 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
815 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
816
816
817 Parameters
817 Parameters
818 ----------
818 ----------
819
819
820 suffix: str [default: '']
820 suffix: str [default: '']
821 The suffix, if any, for the magics. This allows you to have
821 The suffix, if any, for the magics. This allows you to have
822 multiple views associated with parallel magics at the same time.
822 multiple views associated with parallel magics at the same time.
823
823
824 e.g. ``rc[::2].activate(suffix='_even')`` will give you
824 e.g. ``rc[::2].activate(suffix='_even')`` will give you
825 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
825 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
826 on the even engines.
826 on the even engines.
827 """
827 """
828
828
829 from IPython.parallel.client.magics import ParallelMagics
829 from IPython.parallel.client.magics import ParallelMagics
830
830
831 try:
831 try:
832 # This is injected into __builtins__.
832 # This is injected into __builtins__.
833 ip = get_ipython()
833 ip = get_ipython()
834 except NameError:
834 except NameError:
835 print("The IPython parallel magics (%px, etc.) only work within IPython.")
835 print("The IPython parallel magics (%px, etc.) only work within IPython.")
836 return
836 return
837
837
838 M = ParallelMagics(ip, self, suffix)
838 M = ParallelMagics(ip, self, suffix)
839 ip.magics_manager.register(M)
839 ip.magics_manager.register(M)
840
840
841
841
842 @skip_doctest
842 @skip_doctest
843 class LoadBalancedView(View):
843 class LoadBalancedView(View):
844 """An load-balancing View that only executes via the Task scheduler.
844 """An load-balancing View that only executes via the Task scheduler.
845
845
846 Load-balanced views can be created with the client's `view` method:
846 Load-balanced views can be created with the client's `view` method:
847
847
848 >>> v = client.load_balanced_view()
848 >>> v = client.load_balanced_view()
849
849
850 or targets can be specified, to restrict the potential destinations:
850 or targets can be specified, to restrict the potential destinations:
851
851
852 >>> v = client.client.load_balanced_view([1,3])
852 >>> v = client.client.load_balanced_view([1,3])
853
853
854 which would restrict loadbalancing to between engines 1 and 3.
854 which would restrict loadbalancing to between engines 1 and 3.
855
855
856 """
856 """
857
857
858 follow=Any()
858 follow=Any()
859 after=Any()
859 after=Any()
860 timeout=CFloat()
860 timeout=CFloat()
861 retries = Integer(0)
861 retries = Integer(0)
862
862
863 _task_scheme = Any()
863 _task_scheme = Any()
864 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
864 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
865
865
866 def __init__(self, client=None, socket=None, **flags):
866 def __init__(self, client=None, socket=None, **flags):
867 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
867 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
868 self._task_scheme=client._task_scheme
868 self._task_scheme=client._task_scheme
869
869
870 def _validate_dependency(self, dep):
870 def _validate_dependency(self, dep):
871 """validate a dependency.
871 """validate a dependency.
872
872
873 For use in `set_flags`.
873 For use in `set_flags`.
874 """
874 """
875 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
875 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
876 return True
876 return True
877 elif isinstance(dep, (list,set, tuple)):
877 elif isinstance(dep, (list,set, tuple)):
878 for d in dep:
878 for d in dep:
879 if not isinstance(d, string_types + (AsyncResult,)):
879 if not isinstance(d, string_types + (AsyncResult,)):
880 return False
880 return False
881 elif isinstance(dep, dict):
881 elif isinstance(dep, dict):
882 if set(dep.keys()) != set(Dependency().as_dict().keys()):
882 if set(dep.keys()) != set(Dependency().as_dict().keys()):
883 return False
883 return False
884 if not isinstance(dep['msg_ids'], list):
884 if not isinstance(dep['msg_ids'], list):
885 return False
885 return False
886 for d in dep['msg_ids']:
886 for d in dep['msg_ids']:
887 if not isinstance(d, string_types):
887 if not isinstance(d, string_types):
888 return False
888 return False
889 else:
889 else:
890 return False
890 return False
891
891
892 return True
892 return True
893
893
894 def _render_dependency(self, dep):
894 def _render_dependency(self, dep):
895 """helper for building jsonable dependencies from various input forms."""
895 """helper for building jsonable dependencies from various input forms."""
896 if isinstance(dep, Dependency):
896 if isinstance(dep, Dependency):
897 return dep.as_dict()
897 return dep.as_dict()
898 elif isinstance(dep, AsyncResult):
898 elif isinstance(dep, AsyncResult):
899 return dep.msg_ids
899 return dep.msg_ids
900 elif dep is None:
900 elif dep is None:
901 return []
901 return []
902 else:
902 else:
903 # pass to Dependency constructor
903 # pass to Dependency constructor
904 return list(Dependency(dep))
904 return list(Dependency(dep))
905
905
906 def set_flags(self, **kwargs):
906 def set_flags(self, **kwargs):
907 """set my attribute flags by keyword.
907 """set my attribute flags by keyword.
908
908
909 A View is a wrapper for the Client's apply method, but with attributes
909 A View is a wrapper for the Client's apply method, but with attributes
910 that specify keyword arguments, those attributes can be set by keyword
910 that specify keyword arguments, those attributes can be set by keyword
911 argument with this method.
911 argument with this method.
912
912
913 Parameters
913 Parameters
914 ----------
914 ----------
915
915
916 block : bool
916 block : bool
917 whether to wait for results
917 whether to wait for results
918 track : bool
918 track : bool
919 whether to create a MessageTracker to allow the user to
919 whether to create a MessageTracker to allow the user to
920 safely edit after arrays and buffers during non-copying
920 safely edit after arrays and buffers during non-copying
921 sends.
921 sends.
922
922
923 after : Dependency or collection of msg_ids
923 after : Dependency or collection of msg_ids
924 Only for load-balanced execution (targets=None)
924 Only for load-balanced execution (targets=None)
925 Specify a list of msg_ids as a time-based dependency.
925 Specify a list of msg_ids as a time-based dependency.
926 This job will only be run *after* the dependencies
926 This job will only be run *after* the dependencies
927 have been met.
927 have been met.
928
928
929 follow : Dependency or collection of msg_ids
929 follow : Dependency or collection of msg_ids
930 Only for load-balanced execution (targets=None)
930 Only for load-balanced execution (targets=None)
931 Specify a list of msg_ids as a location-based dependency.
931 Specify a list of msg_ids as a location-based dependency.
932 This job will only be run on an engine where this dependency
932 This job will only be run on an engine where this dependency
933 is met.
933 is met.
934
934
935 timeout : float/int or None
935 timeout : float/int or None
936 Only for load-balanced execution (targets=None)
936 Only for load-balanced execution (targets=None)
937 Specify an amount of time (in seconds) for the scheduler to
937 Specify an amount of time (in seconds) for the scheduler to
938 wait for dependencies to be met before failing with a
938 wait for dependencies to be met before failing with a
939 DependencyTimeout.
939 DependencyTimeout.
940
940
941 retries : int
941 retries : int
942 Number of times a task will be retried on failure.
942 Number of times a task will be retried on failure.
943 """
943 """
944
944
945 super(LoadBalancedView, self).set_flags(**kwargs)
945 super(LoadBalancedView, self).set_flags(**kwargs)
946 for name in ('follow', 'after'):
946 for name in ('follow', 'after'):
947 if name in kwargs:
947 if name in kwargs:
948 value = kwargs[name]
948 value = kwargs[name]
949 if self._validate_dependency(value):
949 if self._validate_dependency(value):
950 setattr(self, name, value)
950 setattr(self, name, value)
951 else:
951 else:
952 raise ValueError("Invalid dependency: %r"%value)
952 raise ValueError("Invalid dependency: %r"%value)
953 if 'timeout' in kwargs:
953 if 'timeout' in kwargs:
954 t = kwargs['timeout']
954 t = kwargs['timeout']
955 if not isinstance(t, (int, long, float, type(None))):
955 if not isinstance(t, (int, float, type(None))):
956 if (not PY3) and (not isinstance(t, long)):
956 raise TypeError("Invalid type for timeout: %r"%type(t))
957 raise TypeError("Invalid type for timeout: %r"%type(t))
957 if t is not None:
958 if t is not None:
958 if t < 0:
959 if t < 0:
959 raise ValueError("Invalid timeout: %s"%t)
960 raise ValueError("Invalid timeout: %s"%t)
960 self.timeout = t
961 self.timeout = t
961
962
962 @sync_results
963 @sync_results
963 @save_ids
964 @save_ids
964 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
965 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
965 after=None, follow=None, timeout=None,
966 after=None, follow=None, timeout=None,
966 targets=None, retries=None):
967 targets=None, retries=None):
967 """calls f(*args, **kwargs) on a remote engine, returning the result.
968 """calls f(*args, **kwargs) on a remote engine, returning the result.
968
969
969 This method temporarily sets all of `apply`'s flags for a single call.
970 This method temporarily sets all of `apply`'s flags for a single call.
970
971
971 Parameters
972 Parameters
972 ----------
973 ----------
973
974
974 f : callable
975 f : callable
975
976
976 args : list [default: empty]
977 args : list [default: empty]
977
978
978 kwargs : dict [default: empty]
979 kwargs : dict [default: empty]
979
980
980 block : bool [default: self.block]
981 block : bool [default: self.block]
981 whether to block
982 whether to block
982 track : bool [default: self.track]
983 track : bool [default: self.track]
983 whether to ask zmq to track the message, for safe non-copying sends
984 whether to ask zmq to track the message, for safe non-copying sends
984
985
985 !!!!!! TODO: THE REST HERE !!!!
986 !!!!!! TODO: THE REST HERE !!!!
986
987
987 Returns
988 Returns
988 -------
989 -------
989
990
990 if self.block is False:
991 if self.block is False:
991 returns AsyncResult
992 returns AsyncResult
992 else:
993 else:
993 returns actual result of f(*args, **kwargs) on the engine(s)
994 returns actual result of f(*args, **kwargs) on the engine(s)
994 This will be a list of self.targets is also a list (even length 1), or
995 This will be a list of self.targets is also a list (even length 1), or
995 the single result if self.targets is an integer engine id
996 the single result if self.targets is an integer engine id
996 """
997 """
997
998
998 # validate whether we can run
999 # validate whether we can run
999 if self._socket.closed:
1000 if self._socket.closed:
1000 msg = "Task farming is disabled"
1001 msg = "Task farming is disabled"
1001 if self._task_scheme == 'pure':
1002 if self._task_scheme == 'pure':
1002 msg += " because the pure ZMQ scheduler cannot handle"
1003 msg += " because the pure ZMQ scheduler cannot handle"
1003 msg += " disappearing engines."
1004 msg += " disappearing engines."
1004 raise RuntimeError(msg)
1005 raise RuntimeError(msg)
1005
1006
1006 if self._task_scheme == 'pure':
1007 if self._task_scheme == 'pure':
1007 # pure zmq scheme doesn't support extra features
1008 # pure zmq scheme doesn't support extra features
1008 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1009 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1009 "follow, after, retries, targets, timeout"
1010 "follow, after, retries, targets, timeout"
1010 if (follow or after or retries or targets or timeout):
1011 if (follow or after or retries or targets or timeout):
1011 # hard fail on Scheduler flags
1012 # hard fail on Scheduler flags
1012 raise RuntimeError(msg)
1013 raise RuntimeError(msg)
1013 if isinstance(f, dependent):
1014 if isinstance(f, dependent):
1014 # soft warn on functional dependencies
1015 # soft warn on functional dependencies
1015 warnings.warn(msg, RuntimeWarning)
1016 warnings.warn(msg, RuntimeWarning)
1016
1017
1017 # build args
1018 # build args
1018 args = [] if args is None else args
1019 args = [] if args is None else args
1019 kwargs = {} if kwargs is None else kwargs
1020 kwargs = {} if kwargs is None else kwargs
1020 block = self.block if block is None else block
1021 block = self.block if block is None else block
1021 track = self.track if track is None else track
1022 track = self.track if track is None else track
1022 after = self.after if after is None else after
1023 after = self.after if after is None else after
1023 retries = self.retries if retries is None else retries
1024 retries = self.retries if retries is None else retries
1024 follow = self.follow if follow is None else follow
1025 follow = self.follow if follow is None else follow
1025 timeout = self.timeout if timeout is None else timeout
1026 timeout = self.timeout if timeout is None else timeout
1026 targets = self.targets if targets is None else targets
1027 targets = self.targets if targets is None else targets
1027
1028
1028 if not isinstance(retries, int):
1029 if not isinstance(retries, int):
1029 raise TypeError('retries must be int, not %r'%type(retries))
1030 raise TypeError('retries must be int, not %r'%type(retries))
1030
1031
1031 if targets is None:
1032 if targets is None:
1032 idents = []
1033 idents = []
1033 else:
1034 else:
1034 idents = self.client._build_targets(targets)[0]
1035 idents = self.client._build_targets(targets)[0]
1035 # ensure *not* bytes
1036 # ensure *not* bytes
1036 idents = [ ident.decode() for ident in idents ]
1037 idents = [ ident.decode() for ident in idents ]
1037
1038
1038 after = self._render_dependency(after)
1039 after = self._render_dependency(after)
1039 follow = self._render_dependency(follow)
1040 follow = self._render_dependency(follow)
1040 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1041 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1041
1042
1042 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1043 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1043 metadata=metadata)
1044 metadata=metadata)
1044 tracker = None if track is False else msg['tracker']
1045 tracker = None if track is False else msg['tracker']
1045
1046
1046 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1047 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1047
1048
1048 if block:
1049 if block:
1049 try:
1050 try:
1050 return ar.get()
1051 return ar.get()
1051 except KeyboardInterrupt:
1052 except KeyboardInterrupt:
1052 pass
1053 pass
1053 return ar
1054 return ar
1054
1055
1055 @sync_results
1056 @sync_results
1056 @save_ids
1057 @save_ids
1057 def map(self, f, *sequences, **kwargs):
1058 def map(self, f, *sequences, **kwargs):
1058 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1059 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1059
1060
1060 Parallel version of builtin `map`, load-balanced by this View.
1061 Parallel version of builtin `map`, load-balanced by this View.
1061
1062
1062 `block`, and `chunksize` can be specified by keyword only.
1063 `block`, and `chunksize` can be specified by keyword only.
1063
1064
1064 Each `chunksize` elements will be a separate task, and will be
1065 Each `chunksize` elements will be a separate task, and will be
1065 load-balanced. This lets individual elements be available for iteration
1066 load-balanced. This lets individual elements be available for iteration
1066 as soon as they arrive.
1067 as soon as they arrive.
1067
1068
1068 Parameters
1069 Parameters
1069 ----------
1070 ----------
1070
1071
1071 f : callable
1072 f : callable
1072 function to be mapped
1073 function to be mapped
1073 *sequences: one or more sequences of matching length
1074 *sequences: one or more sequences of matching length
1074 the sequences to be distributed and passed to `f`
1075 the sequences to be distributed and passed to `f`
1075 block : bool [default self.block]
1076 block : bool [default self.block]
1076 whether to wait for the result or not
1077 whether to wait for the result or not
1077 track : bool
1078 track : bool
1078 whether to create a MessageTracker to allow the user to
1079 whether to create a MessageTracker to allow the user to
1079 safely edit after arrays and buffers during non-copying
1080 safely edit after arrays and buffers during non-copying
1080 sends.
1081 sends.
1081 chunksize : int [default 1]
1082 chunksize : int [default 1]
1082 how many elements should be in each task.
1083 how many elements should be in each task.
1083 ordered : bool [default True]
1084 ordered : bool [default True]
1084 Whether the results should be gathered as they arrive, or enforce
1085 Whether the results should be gathered as they arrive, or enforce
1085 the order of submission.
1086 the order of submission.
1086
1087
1087 Only applies when iterating through AsyncMapResult as results arrive.
1088 Only applies when iterating through AsyncMapResult as results arrive.
1088 Has no effect when block=True.
1089 Has no effect when block=True.
1089
1090
1090 Returns
1091 Returns
1091 -------
1092 -------
1092
1093
1093 if block=False:
1094 if block=False:
1094 AsyncMapResult
1095 AsyncMapResult
1095 An object like AsyncResult, but which reassembles the sequence of results
1096 An object like AsyncResult, but which reassembles the sequence of results
1096 into a single list. AsyncMapResults can be iterated through before all
1097 into a single list. AsyncMapResults can be iterated through before all
1097 results are complete.
1098 results are complete.
1098 else:
1099 else:
1099 the result of map(f,*sequences)
1100 the result of map(f,*sequences)
1100
1101
1101 """
1102 """
1102
1103
1103 # default
1104 # default
1104 block = kwargs.get('block', self.block)
1105 block = kwargs.get('block', self.block)
1105 chunksize = kwargs.get('chunksize', 1)
1106 chunksize = kwargs.get('chunksize', 1)
1106 ordered = kwargs.get('ordered', True)
1107 ordered = kwargs.get('ordered', True)
1107
1108
1108 keyset = set(kwargs.keys())
1109 keyset = set(kwargs.keys())
1109 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1110 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1110 if extra_keys:
1111 if extra_keys:
1111 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1112 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1112
1113
1113 assert len(sequences) > 0, "must have some sequences to map onto!"
1114 assert len(sequences) > 0, "must have some sequences to map onto!"
1114
1115
1115 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1116 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1116 return pf.map(*sequences)
1117 return pf.map(*sequences)
1117
1118
1118 __all__ = ['LoadBalancedView', 'DirectView']
1119 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,226 +1,230 b''
1 """Dependency utilities
1 """Dependency utilities
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2013 The IPython Development Team
8 # Copyright (C) 2013 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 from types import ModuleType
14 from types import ModuleType
15
15
16 from IPython.parallel.client.asyncresult import AsyncResult
16 from IPython.parallel.client.asyncresult import AsyncResult
17 from IPython.parallel.error import UnmetDependency
17 from IPython.parallel.error import UnmetDependency
18 from IPython.parallel.util import interactive
18 from IPython.parallel.util import interactive
19 from IPython.utils import py3compat
19 from IPython.utils import py3compat
20 from IPython.utils.py3compat import string_types
20 from IPython.utils.py3compat import string_types
21 from IPython.utils.pickleutil import can, uncan
21 from IPython.utils.pickleutil import can, uncan
22
22
23 class depend(object):
23 class depend(object):
24 """Dependency decorator, for use with tasks.
24 """Dependency decorator, for use with tasks.
25
25
26 `@depend` lets you define a function for engine dependencies
26 `@depend` lets you define a function for engine dependencies
27 just like you use `apply` for tasks.
27 just like you use `apply` for tasks.
28
28
29
29
30 Examples
30 Examples
31 --------
31 --------
32 ::
32 ::
33
33
34 @depend(df, a,b, c=5)
34 @depend(df, a,b, c=5)
35 def f(m,n,p)
35 def f(m,n,p)
36
36
37 view.apply(f, 1,2,3)
37 view.apply(f, 1,2,3)
38
38
39 will call df(a,b,c=5) on the engine, and if it returns False or
39 will call df(a,b,c=5) on the engine, and if it returns False or
40 raises an UnmetDependency error, then the task will not be run
40 raises an UnmetDependency error, then the task will not be run
41 and another engine will be tried.
41 and another engine will be tried.
42 """
42 """
43 def __init__(self, f, *args, **kwargs):
43 def __init__(self, f, *args, **kwargs):
44 self.f = f
44 self.f = f
45 self.args = args
45 self.args = args
46 self.kwargs = kwargs
46 self.kwargs = kwargs
47
47
48 def __call__(self, f):
48 def __call__(self, f):
49 return dependent(f, self.f, *self.args, **self.kwargs)
49 return dependent(f, self.f, *self.args, **self.kwargs)
50
50
51 class dependent(object):
51 class dependent(object):
52 """A function that depends on another function.
52 """A function that depends on another function.
53 This is an object to prevent the closure used
53 This is an object to prevent the closure used
54 in traditional decorators, which are not picklable.
54 in traditional decorators, which are not picklable.
55 """
55 """
56
56
57 def __init__(self, f, df, *dargs, **dkwargs):
57 def __init__(self, f, df, *dargs, **dkwargs):
58 self.f = f
58 self.f = f
59 self.__name__ = getattr(f, '__name__', 'f')
59 name = getattr(f, '__name__', 'f')
60 if py3compat.PY3:
61 self.__name__ = name
62 else:
63 self.func_name = name
60 self.df = df
64 self.df = df
61 self.dargs = dargs
65 self.dargs = dargs
62 self.dkwargs = dkwargs
66 self.dkwargs = dkwargs
63
67
64 def check_dependency(self):
68 def check_dependency(self):
65 if self.df(*self.dargs, **self.dkwargs) is False:
69 if self.df(*self.dargs, **self.dkwargs) is False:
66 raise UnmetDependency()
70 raise UnmetDependency()
67
71
68 def __call__(self, *args, **kwargs):
72 def __call__(self, *args, **kwargs):
69 return self.f(*args, **kwargs)
73 return self.f(*args, **kwargs)
70
74
71 if not py3compat.PY3:
75 if not py3compat.PY3:
72 @property
76 @property
73 def __name__(self):
77 def __name__(self):
74 return self.__name__
78 return self.__name__
75
79
76 @interactive
80 @interactive
77 def _require(*modules, **mapping):
81 def _require(*modules, **mapping):
78 """Helper for @require decorator."""
82 """Helper for @require decorator."""
79 from IPython.parallel.error import UnmetDependency
83 from IPython.parallel.error import UnmetDependency
80 from IPython.utils.pickleutil import uncan
84 from IPython.utils.pickleutil import uncan
81 user_ns = globals()
85 user_ns = globals()
82 for name in modules:
86 for name in modules:
83 try:
87 try:
84 exec('import %s' % name, user_ns)
88 exec('import %s' % name, user_ns)
85 except ImportError:
89 except ImportError:
86 raise UnmetDependency(name)
90 raise UnmetDependency(name)
87
91
88 for name, cobj in mapping.items():
92 for name, cobj in mapping.items():
89 user_ns[name] = uncan(cobj, user_ns)
93 user_ns[name] = uncan(cobj, user_ns)
90 return True
94 return True
91
95
92 def require(*objects, **mapping):
96 def require(*objects, **mapping):
93 """Simple decorator for requiring local objects and modules to be available
97 """Simple decorator for requiring local objects and modules to be available
94 when the decorated function is called on the engine.
98 when the decorated function is called on the engine.
95
99
96 Modules specified by name or passed directly will be imported
100 Modules specified by name or passed directly will be imported
97 prior to calling the decorated function.
101 prior to calling the decorated function.
98
102
99 Objects other than modules will be pushed as a part of the task.
103 Objects other than modules will be pushed as a part of the task.
100 Functions can be passed positionally,
104 Functions can be passed positionally,
101 and will be pushed to the engine with their __name__.
105 and will be pushed to the engine with their __name__.
102 Other objects can be passed by keyword arg.
106 Other objects can be passed by keyword arg.
103
107
104 Examples
108 Examples
105 --------
109 --------
106
110
107 In [1]: @require('numpy')
111 In [1]: @require('numpy')
108 ...: def norm(a):
112 ...: def norm(a):
109 ...: return numpy.linalg.norm(a,2)
113 ...: return numpy.linalg.norm(a,2)
110
114
111 In [2]: foo = lambda x: x*x
115 In [2]: foo = lambda x: x*x
112 In [3]: @require(foo)
116 In [3]: @require(foo)
113 ...: def bar(a):
117 ...: def bar(a):
114 ...: return foo(1-a)
118 ...: return foo(1-a)
115 """
119 """
116 names = []
120 names = []
117 for obj in objects:
121 for obj in objects:
118 if isinstance(obj, ModuleType):
122 if isinstance(obj, ModuleType):
119 obj = obj.__name__
123 obj = obj.__name__
120
124
121 if isinstance(obj, string_types):
125 if isinstance(obj, string_types):
122 names.append(obj)
126 names.append(obj)
123 elif hasattr(obj, '__name__'):
127 elif hasattr(obj, '__name__'):
124 mapping[obj.__name__] = obj
128 mapping[obj.__name__] = obj
125 else:
129 else:
126 raise TypeError("Objects other than modules and functions "
130 raise TypeError("Objects other than modules and functions "
127 "must be passed by kwarg, but got: %s" % type(obj)
131 "must be passed by kwarg, but got: %s" % type(obj)
128 )
132 )
129
133
130 for name, obj in mapping.items():
134 for name, obj in mapping.items():
131 mapping[name] = can(obj)
135 mapping[name] = can(obj)
132 return depend(_require, *names, **mapping)
136 return depend(_require, *names, **mapping)
133
137
134 class Dependency(set):
138 class Dependency(set):
135 """An object for representing a set of msg_id dependencies.
139 """An object for representing a set of msg_id dependencies.
136
140
137 Subclassed from set().
141 Subclassed from set().
138
142
139 Parameters
143 Parameters
140 ----------
144 ----------
141 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
145 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
142 The msg_ids to depend on
146 The msg_ids to depend on
143 all : bool [default True]
147 all : bool [default True]
144 Whether the dependency should be considered met when *all* depending tasks have completed
148 Whether the dependency should be considered met when *all* depending tasks have completed
145 or only when *any* have been completed.
149 or only when *any* have been completed.
146 success : bool [default True]
150 success : bool [default True]
147 Whether to consider successes as fulfilling dependencies.
151 Whether to consider successes as fulfilling dependencies.
148 failure : bool [default False]
152 failure : bool [default False]
149 Whether to consider failures as fulfilling dependencies.
153 Whether to consider failures as fulfilling dependencies.
150
154
151 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
155 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
152 as soon as the first depended-upon task fails.
156 as soon as the first depended-upon task fails.
153 """
157 """
154
158
155 all=True
159 all=True
156 success=True
160 success=True
157 failure=True
161 failure=True
158
162
159 def __init__(self, dependencies=[], all=True, success=True, failure=False):
163 def __init__(self, dependencies=[], all=True, success=True, failure=False):
160 if isinstance(dependencies, dict):
164 if isinstance(dependencies, dict):
161 # load from dict
165 # load from dict
162 all = dependencies.get('all', True)
166 all = dependencies.get('all', True)
163 success = dependencies.get('success', success)
167 success = dependencies.get('success', success)
164 failure = dependencies.get('failure', failure)
168 failure = dependencies.get('failure', failure)
165 dependencies = dependencies.get('dependencies', [])
169 dependencies = dependencies.get('dependencies', [])
166 ids = []
170 ids = []
167
171
168 # extract ids from various sources:
172 # extract ids from various sources:
169 if isinstance(dependencies, string_types + (AsyncResult,)):
173 if isinstance(dependencies, string_types + (AsyncResult,)):
170 dependencies = [dependencies]
174 dependencies = [dependencies]
171 for d in dependencies:
175 for d in dependencies:
172 if isinstance(d, string_types):
176 if isinstance(d, string_types):
173 ids.append(d)
177 ids.append(d)
174 elif isinstance(d, AsyncResult):
178 elif isinstance(d, AsyncResult):
175 ids.extend(d.msg_ids)
179 ids.extend(d.msg_ids)
176 else:
180 else:
177 raise TypeError("invalid dependency type: %r"%type(d))
181 raise TypeError("invalid dependency type: %r"%type(d))
178
182
179 set.__init__(self, ids)
183 set.__init__(self, ids)
180 self.all = all
184 self.all = all
181 if not (success or failure):
185 if not (success or failure):
182 raise ValueError("Must depend on at least one of successes or failures!")
186 raise ValueError("Must depend on at least one of successes or failures!")
183 self.success=success
187 self.success=success
184 self.failure = failure
188 self.failure = failure
185
189
186 def check(self, completed, failed=None):
190 def check(self, completed, failed=None):
187 """check whether our dependencies have been met."""
191 """check whether our dependencies have been met."""
188 if len(self) == 0:
192 if len(self) == 0:
189 return True
193 return True
190 against = set()
194 against = set()
191 if self.success:
195 if self.success:
192 against = completed
196 against = completed
193 if failed is not None and self.failure:
197 if failed is not None and self.failure:
194 against = against.union(failed)
198 against = against.union(failed)
195 if self.all:
199 if self.all:
196 return self.issubset(against)
200 return self.issubset(against)
197 else:
201 else:
198 return not self.isdisjoint(against)
202 return not self.isdisjoint(against)
199
203
200 def unreachable(self, completed, failed=None):
204 def unreachable(self, completed, failed=None):
201 """return whether this dependency has become impossible."""
205 """return whether this dependency has become impossible."""
202 if len(self) == 0:
206 if len(self) == 0:
203 return False
207 return False
204 against = set()
208 against = set()
205 if not self.success:
209 if not self.success:
206 against = completed
210 against = completed
207 if failed is not None and not self.failure:
211 if failed is not None and not self.failure:
208 against = against.union(failed)
212 against = against.union(failed)
209 if self.all:
213 if self.all:
210 return not self.isdisjoint(against)
214 return not self.isdisjoint(against)
211 else:
215 else:
212 return self.issubset(against)
216 return self.issubset(against)
213
217
214
218
215 def as_dict(self):
219 def as_dict(self):
216 """Represent this dependency as a dict. For json compatibility."""
220 """Represent this dependency as a dict. For json compatibility."""
217 return dict(
221 return dict(
218 dependencies=list(self),
222 dependencies=list(self),
219 all=self.all,
223 all=self.all,
220 success=self.success,
224 success=self.success,
221 failure=self.failure
225 failure=self.failure
222 )
226 )
223
227
224
228
225 __all__ = ['depend', 'require', 'dependent', 'Dependency']
229 __all__ = ['depend', 'require', 'dependent', 'Dependency']
226
230
@@ -1,190 +1,192 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 A multi-heart Heartbeat system using PUB and ROUTER sockets. pings are sent out on the PUB,
3 A multi-heart Heartbeat system using PUB and ROUTER sockets. pings are sent out on the PUB,
4 and hearts are tracked based on their DEALER identities.
4 and hearts are tracked based on their DEALER identities.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 from __future__ import print_function
17 from __future__ import print_function
18 import time
18 import time
19 import uuid
19 import uuid
20
20
21 import zmq
21 import zmq
22 from zmq.devices import ThreadDevice, ThreadMonitoredQueue
22 from zmq.devices import ThreadDevice, ThreadMonitoredQueue
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 from IPython.config.configurable import LoggingConfigurable
25 from IPython.config.configurable import LoggingConfigurable
26 from IPython.utils.py3compat import str_to_bytes
26 from IPython.utils.py3compat import str_to_bytes
27 from IPython.utils.traitlets import Set, Instance, CFloat, Integer, Dict
27 from IPython.utils.traitlets import Set, Instance, CFloat, Integer, Dict
28
28
29 from IPython.parallel.util import log_errors
29 from IPython.parallel.util import log_errors
30
30
31 class Heart(object):
31 class Heart(object):
32 """A basic heart object for responding to a HeartMonitor.
32 """A basic heart object for responding to a HeartMonitor.
33 This is a simple wrapper with defaults for the most common
33 This is a simple wrapper with defaults for the most common
34 Device model for responding to heartbeats.
34 Device model for responding to heartbeats.
35
35
36 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
36 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
37 SUB/DEALER for in/out.
37 SUB/DEALER for in/out.
38
38
39 You can specify the DEALER's IDENTITY via the optional heart_id argument."""
39 You can specify the DEALER's IDENTITY via the optional heart_id argument."""
40 device=None
40 device=None
41 id=None
41 id=None
42 def __init__(self, in_addr, out_addr, mon_addr=None, in_type=zmq.SUB, out_type=zmq.DEALER, mon_type=zmq.PUB, heart_id=None):
42 def __init__(self, in_addr, out_addr, mon_addr=None, in_type=zmq.SUB, out_type=zmq.DEALER, mon_type=zmq.PUB, heart_id=None):
43 if mon_addr is None:
43 if mon_addr is None:
44 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
44 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
45 else:
45 else:
46 self.device = ThreadMonitoredQueue(in_type, out_type, mon_type, in_prefix=b"", out_prefix=b"")
46 self.device = ThreadMonitoredQueue(in_type, out_type, mon_type, in_prefix=b"", out_prefix=b"")
47 # do not allow the device to share global Context.instance,
47 # do not allow the device to share global Context.instance,
48 # which is the default behavior in pyzmq > 2.1.10
48 # which is the default behavior in pyzmq > 2.1.10
49 self.device.context_factory = zmq.Context
49 self.device.context_factory = zmq.Context
50
50
51 self.device.daemon=True
51 self.device.daemon=True
52 self.device.connect_in(in_addr)
52 self.device.connect_in(in_addr)
53 self.device.connect_out(out_addr)
53 self.device.connect_out(out_addr)
54 if mon_addr is not None:
54 if mon_addr is not None:
55 self.device.connect_mon(mon_addr)
55 self.device.connect_mon(mon_addr)
56 if in_type == zmq.SUB:
56 if in_type == zmq.SUB:
57 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
57 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
58 if heart_id is None:
58 if heart_id is None:
59 heart_id = uuid.uuid4().bytes
59 heart_id = uuid.uuid4().bytes
60 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
60 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
61 self.id = heart_id
61 self.id = heart_id
62
62
63 def start(self):
63 def start(self):
64 return self.device.start()
64 return self.device.start()
65
65
66
66
67 class HeartMonitor(LoggingConfigurable):
67 class HeartMonitor(LoggingConfigurable):
68 """A basic HeartMonitor class
68 """A basic HeartMonitor class
69 pingstream: a PUB stream
69 pingstream: a PUB stream
70 pongstream: an ROUTER stream
70 pongstream: an ROUTER stream
71 period: the period of the heartbeat in milliseconds"""
71 period: the period of the heartbeat in milliseconds"""
72
72
73 period = Integer(3000, config=True,
73 period = Integer(3000, config=True,
74 help='The frequency at which the Hub pings the engines for heartbeats '
74 help='The frequency at which the Hub pings the engines for heartbeats '
75 '(in ms)',
75 '(in ms)',
76 )
76 )
77 max_heartmonitor_misses = Integer(10, config=True,
77 max_heartmonitor_misses = Integer(10, config=True,
78 help='Allowed consecutive missed pings from controller Hub to engine before unregistering.',
78 help='Allowed consecutive missed pings from controller Hub to engine before unregistering.',
79 )
79 )
80
80
81 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
81 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
82 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
82 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
83 loop = Instance('zmq.eventloop.ioloop.IOLoop')
83 loop = Instance('zmq.eventloop.ioloop.IOLoop')
84 def _loop_default(self):
84 def _loop_default(self):
85 return ioloop.IOLoop.instance()
85 return ioloop.IOLoop.instance()
86
86
87 # not settable:
87 # not settable:
88 hearts=Set()
88 hearts=Set()
89 responses=Set()
89 responses=Set()
90 on_probation=Dict()
90 on_probation=Dict()
91 last_ping=CFloat(0)
91 last_ping=CFloat(0)
92 _new_handlers = Set()
92 _new_handlers = Set()
93 _failure_handlers = Set()
93 _failure_handlers = Set()
94 lifetime = CFloat(0)
94 lifetime = CFloat(0)
95 tic = CFloat(0)
95 tic = CFloat(0)
96
96
97 def __init__(self, **kwargs):
97 def __init__(self, **kwargs):
98 super(HeartMonitor, self).__init__(**kwargs)
98 super(HeartMonitor, self).__init__(**kwargs)
99
99
100 self.pongstream.on_recv(self.handle_pong)
100 self.pongstream.on_recv(self.handle_pong)
101
101
102 def start(self):
102 def start(self):
103 self.tic = time.time()
103 self.tic = time.time()
104 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
104 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
105 self.caller.start()
105 self.caller.start()
106
106
107 def add_new_heart_handler(self, handler):
107 def add_new_heart_handler(self, handler):
108 """add a new handler for new hearts"""
108 """add a new handler for new hearts"""
109 self.log.debug("heartbeat::new_heart_handler: %s", handler)
109 self.log.debug("heartbeat::new_heart_handler: %s", handler)
110 self._new_handlers.add(handler)
110 self._new_handlers.add(handler)
111
111
112 def add_heart_failure_handler(self, handler):
112 def add_heart_failure_handler(self, handler):
113 """add a new handler for heart failure"""
113 """add a new handler for heart failure"""
114 self.log.debug("heartbeat::new heart failure handler: %s", handler)
114 self.log.debug("heartbeat::new heart failure handler: %s", handler)
115 self._failure_handlers.add(handler)
115 self._failure_handlers.add(handler)
116
116
117 def beat(self):
117 def beat(self):
118 self.pongstream.flush()
118 self.pongstream.flush()
119 self.last_ping = self.lifetime
119 self.last_ping = self.lifetime
120
120
121 toc = time.time()
121 toc = time.time()
122 self.lifetime += toc-self.tic
122 self.lifetime += toc-self.tic
123 self.tic = toc
123 self.tic = toc
124 self.log.debug("heartbeat::sending %s", self.lifetime)
124 self.log.debug("heartbeat::sending %s", self.lifetime)
125 goodhearts = self.hearts.intersection(self.responses)
125 goodhearts = self.hearts.intersection(self.responses)
126 missed_beats = self.hearts.difference(goodhearts)
126 missed_beats = self.hearts.difference(goodhearts)
127 newhearts = self.responses.difference(goodhearts)
127 newhearts = self.responses.difference(goodhearts)
128 map(self.handle_new_heart, newhearts)
128 for heart in newhearts:
129 self.handle_new_heart(heart)
129 heartfailures, on_probation = self._check_missed(missed_beats, self.on_probation,
130 heartfailures, on_probation = self._check_missed(missed_beats, self.on_probation,
130 self.hearts)
131 self.hearts)
131 map(self.handle_heart_failure, heartfailures)
132 for failure in heartfailures:
133 self.handle_heart_failure(failure)
132 self.on_probation = on_probation
134 self.on_probation = on_probation
133 self.responses = set()
135 self.responses = set()
134 #print self.on_probation, self.hearts
136 #print self.on_probation, self.hearts
135 # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
137 # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
136 self.pingstream.send(str_to_bytes(str(self.lifetime)))
138 self.pingstream.send(str_to_bytes(str(self.lifetime)))
137 # flush stream to force immediate socket send
139 # flush stream to force immediate socket send
138 self.pingstream.flush()
140 self.pingstream.flush()
139
141
140 def _check_missed(self, missed_beats, on_probation, hearts):
142 def _check_missed(self, missed_beats, on_probation, hearts):
141 """Update heartbeats on probation, identifying any that have too many misses.
143 """Update heartbeats on probation, identifying any that have too many misses.
142 """
144 """
143 failures = []
145 failures = []
144 new_probation = {}
146 new_probation = {}
145 for cur_heart in (b for b in missed_beats if b in hearts):
147 for cur_heart in (b for b in missed_beats if b in hearts):
146 miss_count = on_probation.get(cur_heart, 0) + 1
148 miss_count = on_probation.get(cur_heart, 0) + 1
147 self.log.info("heartbeat::missed %s : %s" % (cur_heart, miss_count))
149 self.log.info("heartbeat::missed %s : %s" % (cur_heart, miss_count))
148 if miss_count > self.max_heartmonitor_misses:
150 if miss_count > self.max_heartmonitor_misses:
149 failures.append(cur_heart)
151 failures.append(cur_heart)
150 else:
152 else:
151 new_probation[cur_heart] = miss_count
153 new_probation[cur_heart] = miss_count
152 return failures, new_probation
154 return failures, new_probation
153
155
154 def handle_new_heart(self, heart):
156 def handle_new_heart(self, heart):
155 if self._new_handlers:
157 if self._new_handlers:
156 for handler in self._new_handlers:
158 for handler in self._new_handlers:
157 handler(heart)
159 handler(heart)
158 else:
160 else:
159 self.log.info("heartbeat::yay, got new heart %s!", heart)
161 self.log.info("heartbeat::yay, got new heart %s!", heart)
160 self.hearts.add(heart)
162 self.hearts.add(heart)
161
163
162 def handle_heart_failure(self, heart):
164 def handle_heart_failure(self, heart):
163 if self._failure_handlers:
165 if self._failure_handlers:
164 for handler in self._failure_handlers:
166 for handler in self._failure_handlers:
165 try:
167 try:
166 handler(heart)
168 handler(heart)
167 except Exception as e:
169 except Exception as e:
168 self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
170 self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
169 pass
171 pass
170 else:
172 else:
171 self.log.info("heartbeat::Heart %s failed :(", heart)
173 self.log.info("heartbeat::Heart %s failed :(", heart)
172 self.hearts.remove(heart)
174 self.hearts.remove(heart)
173
175
174
176
175 @log_errors
177 @log_errors
176 def handle_pong(self, msg):
178 def handle_pong(self, msg):
177 "a heart just beat"
179 "a heart just beat"
178 current = str_to_bytes(str(self.lifetime))
180 current = str_to_bytes(str(self.lifetime))
179 last = str_to_bytes(str(self.last_ping))
181 last = str_to_bytes(str(self.last_ping))
180 if msg[1] == current:
182 if msg[1] == current:
181 delta = time.time()-self.tic
183 delta = time.time()-self.tic
182 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
184 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
183 self.responses.add(msg[0])
185 self.responses.add(msg[0])
184 elif msg[1] == last:
186 elif msg[1] == last:
185 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
187 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
186 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
188 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
187 self.responses.add(msg[0])
189 self.responses.add(msg[0])
188 else:
190 else:
189 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
191 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
190
192
@@ -1,1421 +1,1421 b''
1 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
2 This is the master object that handles connections from engines and clients,
2 This is the master object that handles connections from engines and clients,
3 and monitors traffic through the various queues.
3 and monitors traffic through the various queues.
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 """
8 """
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 from __future__ import print_function
19 from __future__ import print_function
20
20
21 import json
21 import json
22 import os
22 import os
23 import sys
23 import sys
24 import time
24 import time
25 from datetime import datetime
25 from datetime import datetime
26
26
27 import zmq
27 import zmq
28 from zmq.eventloop import ioloop
28 from zmq.eventloop import ioloop
29 from zmq.eventloop.zmqstream import ZMQStream
29 from zmq.eventloop.zmqstream import ZMQStream
30
30
31 # internal:
31 # internal:
32 from IPython.utils.importstring import import_item
32 from IPython.utils.importstring import import_item
33 from IPython.utils.localinterfaces import localhost
33 from IPython.utils.localinterfaces import localhost
34 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
34 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
35 from IPython.utils.traitlets import (
35 from IPython.utils.traitlets import (
36 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
36 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 )
37 )
38
38
39 from IPython.parallel import error, util
39 from IPython.parallel import error, util
40 from IPython.parallel.factory import RegistrationFactory
40 from IPython.parallel.factory import RegistrationFactory
41
41
42 from IPython.kernel.zmq.session import SessionFactory
42 from IPython.kernel.zmq.session import SessionFactory
43
43
44 from .heartmonitor import HeartMonitor
44 from .heartmonitor import HeartMonitor
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Code
47 # Code
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50 def _passer(*args, **kwargs):
50 def _passer(*args, **kwargs):
51 return
51 return
52
52
53 def _printer(*args, **kwargs):
53 def _printer(*args, **kwargs):
54 print (args)
54 print (args)
55 print (kwargs)
55 print (kwargs)
56
56
57 def empty_record():
57 def empty_record():
58 """Return an empty dict with all record keys."""
58 """Return an empty dict with all record keys."""
59 return {
59 return {
60 'msg_id' : None,
60 'msg_id' : None,
61 'header' : None,
61 'header' : None,
62 'metadata' : None,
62 'metadata' : None,
63 'content': None,
63 'content': None,
64 'buffers': None,
64 'buffers': None,
65 'submitted': None,
65 'submitted': None,
66 'client_uuid' : None,
66 'client_uuid' : None,
67 'engine_uuid' : None,
67 'engine_uuid' : None,
68 'started': None,
68 'started': None,
69 'completed': None,
69 'completed': None,
70 'resubmitted': None,
70 'resubmitted': None,
71 'received': None,
71 'received': None,
72 'result_header' : None,
72 'result_header' : None,
73 'result_metadata' : None,
73 'result_metadata' : None,
74 'result_content' : None,
74 'result_content' : None,
75 'result_buffers' : None,
75 'result_buffers' : None,
76 'queue' : None,
76 'queue' : None,
77 'pyin' : None,
77 'pyin' : None,
78 'pyout': None,
78 'pyout': None,
79 'pyerr': None,
79 'pyerr': None,
80 'stdout': '',
80 'stdout': '',
81 'stderr': '',
81 'stderr': '',
82 }
82 }
83
83
84 def init_record(msg):
84 def init_record(msg):
85 """Initialize a TaskRecord based on a request."""
85 """Initialize a TaskRecord based on a request."""
86 header = msg['header']
86 header = msg['header']
87 return {
87 return {
88 'msg_id' : header['msg_id'],
88 'msg_id' : header['msg_id'],
89 'header' : header,
89 'header' : header,
90 'content': msg['content'],
90 'content': msg['content'],
91 'metadata': msg['metadata'],
91 'metadata': msg['metadata'],
92 'buffers': msg['buffers'],
92 'buffers': msg['buffers'],
93 'submitted': header['date'],
93 'submitted': header['date'],
94 'client_uuid' : None,
94 'client_uuid' : None,
95 'engine_uuid' : None,
95 'engine_uuid' : None,
96 'started': None,
96 'started': None,
97 'completed': None,
97 'completed': None,
98 'resubmitted': None,
98 'resubmitted': None,
99 'received': None,
99 'received': None,
100 'result_header' : None,
100 'result_header' : None,
101 'result_metadata': None,
101 'result_metadata': None,
102 'result_content' : None,
102 'result_content' : None,
103 'result_buffers' : None,
103 'result_buffers' : None,
104 'queue' : None,
104 'queue' : None,
105 'pyin' : None,
105 'pyin' : None,
106 'pyout': None,
106 'pyout': None,
107 'pyerr': None,
107 'pyerr': None,
108 'stdout': '',
108 'stdout': '',
109 'stderr': '',
109 'stderr': '',
110 }
110 }
111
111
112
112
113 class EngineConnector(HasTraits):
113 class EngineConnector(HasTraits):
114 """A simple object for accessing the various zmq connections of an object.
114 """A simple object for accessing the various zmq connections of an object.
115 Attributes are:
115 Attributes are:
116 id (int): engine ID
116 id (int): engine ID
117 uuid (unicode): engine UUID
117 uuid (unicode): engine UUID
118 pending: set of msg_ids
118 pending: set of msg_ids
119 stallback: DelayedCallback for stalled registration
119 stallback: DelayedCallback for stalled registration
120 """
120 """
121
121
122 id = Integer(0)
122 id = Integer(0)
123 uuid = Unicode()
123 uuid = Unicode()
124 pending = Set()
124 pending = Set()
125 stallback = Instance(ioloop.DelayedCallback)
125 stallback = Instance(ioloop.DelayedCallback)
126
126
127
127
128 _db_shortcuts = {
128 _db_shortcuts = {
129 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
129 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
130 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
130 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
131 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
131 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
132 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
132 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
133 }
133 }
134
134
135 class HubFactory(RegistrationFactory):
135 class HubFactory(RegistrationFactory):
136 """The Configurable for setting up a Hub."""
136 """The Configurable for setting up a Hub."""
137
137
138 # port-pairs for monitoredqueues:
138 # port-pairs for monitoredqueues:
139 hb = Tuple(Integer,Integer,config=True,
139 hb = Tuple(Integer,Integer,config=True,
140 help="""PUB/ROUTER Port pair for Engine heartbeats""")
140 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 def _hb_default(self):
141 def _hb_default(self):
142 return tuple(util.select_random_ports(2))
142 return tuple(util.select_random_ports(2))
143
143
144 mux = Tuple(Integer,Integer,config=True,
144 mux = Tuple(Integer,Integer,config=True,
145 help="""Client/Engine Port pair for MUX queue""")
145 help="""Client/Engine Port pair for MUX queue""")
146
146
147 def _mux_default(self):
147 def _mux_default(self):
148 return tuple(util.select_random_ports(2))
148 return tuple(util.select_random_ports(2))
149
149
150 task = Tuple(Integer,Integer,config=True,
150 task = Tuple(Integer,Integer,config=True,
151 help="""Client/Engine Port pair for Task queue""")
151 help="""Client/Engine Port pair for Task queue""")
152 def _task_default(self):
152 def _task_default(self):
153 return tuple(util.select_random_ports(2))
153 return tuple(util.select_random_ports(2))
154
154
155 control = Tuple(Integer,Integer,config=True,
155 control = Tuple(Integer,Integer,config=True,
156 help="""Client/Engine Port pair for Control queue""")
156 help="""Client/Engine Port pair for Control queue""")
157
157
158 def _control_default(self):
158 def _control_default(self):
159 return tuple(util.select_random_ports(2))
159 return tuple(util.select_random_ports(2))
160
160
161 iopub = Tuple(Integer,Integer,config=True,
161 iopub = Tuple(Integer,Integer,config=True,
162 help="""Client/Engine Port pair for IOPub relay""")
162 help="""Client/Engine Port pair for IOPub relay""")
163
163
164 def _iopub_default(self):
164 def _iopub_default(self):
165 return tuple(util.select_random_ports(2))
165 return tuple(util.select_random_ports(2))
166
166
167 # single ports:
167 # single ports:
168 mon_port = Integer(config=True,
168 mon_port = Integer(config=True,
169 help="""Monitor (SUB) port for queue traffic""")
169 help="""Monitor (SUB) port for queue traffic""")
170
170
171 def _mon_port_default(self):
171 def _mon_port_default(self):
172 return util.select_random_ports(1)[0]
172 return util.select_random_ports(1)[0]
173
173
174 notifier_port = Integer(config=True,
174 notifier_port = Integer(config=True,
175 help="""PUB port for sending engine status notifications""")
175 help="""PUB port for sending engine status notifications""")
176
176
177 def _notifier_port_default(self):
177 def _notifier_port_default(self):
178 return util.select_random_ports(1)[0]
178 return util.select_random_ports(1)[0]
179
179
180 engine_ip = Unicode(config=True,
180 engine_ip = Unicode(config=True,
181 help="IP on which to listen for engine connections. [default: loopback]")
181 help="IP on which to listen for engine connections. [default: loopback]")
182 def _engine_ip_default(self):
182 def _engine_ip_default(self):
183 return localhost()
183 return localhost()
184 engine_transport = Unicode('tcp', config=True,
184 engine_transport = Unicode('tcp', config=True,
185 help="0MQ transport for engine connections. [default: tcp]")
185 help="0MQ transport for engine connections. [default: tcp]")
186
186
187 client_ip = Unicode(config=True,
187 client_ip = Unicode(config=True,
188 help="IP on which to listen for client connections. [default: loopback]")
188 help="IP on which to listen for client connections. [default: loopback]")
189 client_transport = Unicode('tcp', config=True,
189 client_transport = Unicode('tcp', config=True,
190 help="0MQ transport for client connections. [default : tcp]")
190 help="0MQ transport for client connections. [default : tcp]")
191
191
192 monitor_ip = Unicode(config=True,
192 monitor_ip = Unicode(config=True,
193 help="IP on which to listen for monitor messages. [default: loopback]")
193 help="IP on which to listen for monitor messages. [default: loopback]")
194 monitor_transport = Unicode('tcp', config=True,
194 monitor_transport = Unicode('tcp', config=True,
195 help="0MQ transport for monitor messages. [default : tcp]")
195 help="0MQ transport for monitor messages. [default : tcp]")
196
196
197 _client_ip_default = _monitor_ip_default = _engine_ip_default
197 _client_ip_default = _monitor_ip_default = _engine_ip_default
198
198
199
199
200 monitor_url = Unicode('')
200 monitor_url = Unicode('')
201
201
202 db_class = DottedObjectName('NoDB',
202 db_class = DottedObjectName('NoDB',
203 config=True, help="""The class to use for the DB backend
203 config=True, help="""The class to use for the DB backend
204
204
205 Options include:
205 Options include:
206
206
207 SQLiteDB: SQLite
207 SQLiteDB: SQLite
208 MongoDB : use MongoDB
208 MongoDB : use MongoDB
209 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
209 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
210 NoDB : disable database altogether (default)
210 NoDB : disable database altogether (default)
211
211
212 """)
212 """)
213
213
214 # not configurable
214 # not configurable
215 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
215 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
216 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
216 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
217
217
218 def _ip_changed(self, name, old, new):
218 def _ip_changed(self, name, old, new):
219 self.engine_ip = new
219 self.engine_ip = new
220 self.client_ip = new
220 self.client_ip = new
221 self.monitor_ip = new
221 self.monitor_ip = new
222 self._update_monitor_url()
222 self._update_monitor_url()
223
223
224 def _update_monitor_url(self):
224 def _update_monitor_url(self):
225 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
225 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
226
226
227 def _transport_changed(self, name, old, new):
227 def _transport_changed(self, name, old, new):
228 self.engine_transport = new
228 self.engine_transport = new
229 self.client_transport = new
229 self.client_transport = new
230 self.monitor_transport = new
230 self.monitor_transport = new
231 self._update_monitor_url()
231 self._update_monitor_url()
232
232
233 def __init__(self, **kwargs):
233 def __init__(self, **kwargs):
234 super(HubFactory, self).__init__(**kwargs)
234 super(HubFactory, self).__init__(**kwargs)
235 self._update_monitor_url()
235 self._update_monitor_url()
236
236
237
237
238 def construct(self):
238 def construct(self):
239 self.init_hub()
239 self.init_hub()
240
240
241 def start(self):
241 def start(self):
242 self.heartmonitor.start()
242 self.heartmonitor.start()
243 self.log.info("Heartmonitor started")
243 self.log.info("Heartmonitor started")
244
244
245 def client_url(self, channel):
245 def client_url(self, channel):
246 """return full zmq url for a named client channel"""
246 """return full zmq url for a named client channel"""
247 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
247 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
248
248
249 def engine_url(self, channel):
249 def engine_url(self, channel):
250 """return full zmq url for a named engine channel"""
250 """return full zmq url for a named engine channel"""
251 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
251 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
252
252
253 def init_hub(self):
253 def init_hub(self):
254 """construct Hub object"""
254 """construct Hub object"""
255
255
256 ctx = self.context
256 ctx = self.context
257 loop = self.loop
257 loop = self.loop
258 if 'TaskScheduler.scheme_name' in self.config:
258 if 'TaskScheduler.scheme_name' in self.config:
259 scheme = self.config.TaskScheduler.scheme_name
259 scheme = self.config.TaskScheduler.scheme_name
260 else:
260 else:
261 from .scheduler import TaskScheduler
261 from .scheduler import TaskScheduler
262 scheme = TaskScheduler.scheme_name.get_default_value()
262 scheme = TaskScheduler.scheme_name.get_default_value()
263
263
264 # build connection dicts
264 # build connection dicts
265 engine = self.engine_info = {
265 engine = self.engine_info = {
266 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
266 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
267 'registration' : self.regport,
267 'registration' : self.regport,
268 'control' : self.control[1],
268 'control' : self.control[1],
269 'mux' : self.mux[1],
269 'mux' : self.mux[1],
270 'hb_ping' : self.hb[0],
270 'hb_ping' : self.hb[0],
271 'hb_pong' : self.hb[1],
271 'hb_pong' : self.hb[1],
272 'task' : self.task[1],
272 'task' : self.task[1],
273 'iopub' : self.iopub[1],
273 'iopub' : self.iopub[1],
274 }
274 }
275
275
276 client = self.client_info = {
276 client = self.client_info = {
277 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
277 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
278 'registration' : self.regport,
278 'registration' : self.regport,
279 'control' : self.control[0],
279 'control' : self.control[0],
280 'mux' : self.mux[0],
280 'mux' : self.mux[0],
281 'task' : self.task[0],
281 'task' : self.task[0],
282 'task_scheme' : scheme,
282 'task_scheme' : scheme,
283 'iopub' : self.iopub[0],
283 'iopub' : self.iopub[0],
284 'notification' : self.notifier_port,
284 'notification' : self.notifier_port,
285 }
285 }
286
286
287 self.log.debug("Hub engine addrs: %s", self.engine_info)
287 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 self.log.debug("Hub client addrs: %s", self.client_info)
288 self.log.debug("Hub client addrs: %s", self.client_info)
289
289
290 # Registrar socket
290 # Registrar socket
291 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
291 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
292 util.set_hwm(q, 0)
292 util.set_hwm(q, 0)
293 q.bind(self.client_url('registration'))
293 q.bind(self.client_url('registration'))
294 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
294 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
295 if self.client_ip != self.engine_ip:
295 if self.client_ip != self.engine_ip:
296 q.bind(self.engine_url('registration'))
296 q.bind(self.engine_url('registration'))
297 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
297 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
298
298
299 ### Engine connections ###
299 ### Engine connections ###
300
300
301 # heartbeat
301 # heartbeat
302 hpub = ctx.socket(zmq.PUB)
302 hpub = ctx.socket(zmq.PUB)
303 hpub.bind(self.engine_url('hb_ping'))
303 hpub.bind(self.engine_url('hb_ping'))
304 hrep = ctx.socket(zmq.ROUTER)
304 hrep = ctx.socket(zmq.ROUTER)
305 util.set_hwm(hrep, 0)
305 util.set_hwm(hrep, 0)
306 hrep.bind(self.engine_url('hb_pong'))
306 hrep.bind(self.engine_url('hb_pong'))
307 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
307 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
308 pingstream=ZMQStream(hpub,loop),
308 pingstream=ZMQStream(hpub,loop),
309 pongstream=ZMQStream(hrep,loop)
309 pongstream=ZMQStream(hrep,loop)
310 )
310 )
311
311
312 ### Client connections ###
312 ### Client connections ###
313
313
314 # Notifier socket
314 # Notifier socket
315 n = ZMQStream(ctx.socket(zmq.PUB), loop)
315 n = ZMQStream(ctx.socket(zmq.PUB), loop)
316 n.bind(self.client_url('notification'))
316 n.bind(self.client_url('notification'))
317
317
318 ### build and launch the queues ###
318 ### build and launch the queues ###
319
319
320 # monitor socket
320 # monitor socket
321 sub = ctx.socket(zmq.SUB)
321 sub = ctx.socket(zmq.SUB)
322 sub.setsockopt(zmq.SUBSCRIBE, b"")
322 sub.setsockopt(zmq.SUBSCRIBE, b"")
323 sub.bind(self.monitor_url)
323 sub.bind(self.monitor_url)
324 sub.bind('inproc://monitor')
324 sub.bind('inproc://monitor')
325 sub = ZMQStream(sub, loop)
325 sub = ZMQStream(sub, loop)
326
326
327 # connect the db
327 # connect the db
328 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
328 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
329 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
329 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
330 self.db = import_item(str(db_class))(session=self.session.session,
330 self.db = import_item(str(db_class))(session=self.session.session,
331 parent=self, log=self.log)
331 parent=self, log=self.log)
332 time.sleep(.25)
332 time.sleep(.25)
333
333
334 # resubmit stream
334 # resubmit stream
335 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
335 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
336 url = util.disambiguate_url(self.client_url('task'))
336 url = util.disambiguate_url(self.client_url('task'))
337 r.connect(url)
337 r.connect(url)
338
338
339 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
339 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
340 query=q, notifier=n, resubmit=r, db=self.db,
340 query=q, notifier=n, resubmit=r, db=self.db,
341 engine_info=self.engine_info, client_info=self.client_info,
341 engine_info=self.engine_info, client_info=self.client_info,
342 log=self.log)
342 log=self.log)
343
343
344
344
345 class Hub(SessionFactory):
345 class Hub(SessionFactory):
346 """The IPython Controller Hub with 0MQ connections
346 """The IPython Controller Hub with 0MQ connections
347
347
348 Parameters
348 Parameters
349 ==========
349 ==========
350 loop: zmq IOLoop instance
350 loop: zmq IOLoop instance
351 session: Session object
351 session: Session object
352 <removed> context: zmq context for creating new connections (?)
352 <removed> context: zmq context for creating new connections (?)
353 queue: ZMQStream for monitoring the command queue (SUB)
353 queue: ZMQStream for monitoring the command queue (SUB)
354 query: ZMQStream for engine registration and client queries requests (ROUTER)
354 query: ZMQStream for engine registration and client queries requests (ROUTER)
355 heartbeat: HeartMonitor object checking the pulse of the engines
355 heartbeat: HeartMonitor object checking the pulse of the engines
356 notifier: ZMQStream for broadcasting engine registration changes (PUB)
356 notifier: ZMQStream for broadcasting engine registration changes (PUB)
357 db: connection to db for out of memory logging of commands
357 db: connection to db for out of memory logging of commands
358 NotImplemented
358 NotImplemented
359 engine_info: dict of zmq connection information for engines to connect
359 engine_info: dict of zmq connection information for engines to connect
360 to the queues.
360 to the queues.
361 client_info: dict of zmq connection information for engines to connect
361 client_info: dict of zmq connection information for engines to connect
362 to the queues.
362 to the queues.
363 """
363 """
364
364
365 engine_state_file = Unicode()
365 engine_state_file = Unicode()
366
366
367 # internal data structures:
367 # internal data structures:
368 ids=Set() # engine IDs
368 ids=Set() # engine IDs
369 keytable=Dict()
369 keytable=Dict()
370 by_ident=Dict()
370 by_ident=Dict()
371 engines=Dict()
371 engines=Dict()
372 clients=Dict()
372 clients=Dict()
373 hearts=Dict()
373 hearts=Dict()
374 pending=Set()
374 pending=Set()
375 queues=Dict() # pending msg_ids keyed by engine_id
375 queues=Dict() # pending msg_ids keyed by engine_id
376 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
376 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
377 completed=Dict() # completed msg_ids keyed by engine_id
377 completed=Dict() # completed msg_ids keyed by engine_id
378 all_completed=Set() # completed msg_ids keyed by engine_id
378 all_completed=Set() # completed msg_ids keyed by engine_id
379 dead_engines=Set() # completed msg_ids keyed by engine_id
379 dead_engines=Set() # completed msg_ids keyed by engine_id
380 unassigned=Set() # set of task msg_ds not yet assigned a destination
380 unassigned=Set() # set of task msg_ds not yet assigned a destination
381 incoming_registrations=Dict()
381 incoming_registrations=Dict()
382 registration_timeout=Integer()
382 registration_timeout=Integer()
383 _idcounter=Integer(0)
383 _idcounter=Integer(0)
384
384
385 # objects from constructor:
385 # objects from constructor:
386 query=Instance(ZMQStream)
386 query=Instance(ZMQStream)
387 monitor=Instance(ZMQStream)
387 monitor=Instance(ZMQStream)
388 notifier=Instance(ZMQStream)
388 notifier=Instance(ZMQStream)
389 resubmit=Instance(ZMQStream)
389 resubmit=Instance(ZMQStream)
390 heartmonitor=Instance(HeartMonitor)
390 heartmonitor=Instance(HeartMonitor)
391 db=Instance(object)
391 db=Instance(object)
392 client_info=Dict()
392 client_info=Dict()
393 engine_info=Dict()
393 engine_info=Dict()
394
394
395
395
396 def __init__(self, **kwargs):
396 def __init__(self, **kwargs):
397 """
397 """
398 # universal:
398 # universal:
399 loop: IOLoop for creating future connections
399 loop: IOLoop for creating future connections
400 session: streamsession for sending serialized data
400 session: streamsession for sending serialized data
401 # engine:
401 # engine:
402 queue: ZMQStream for monitoring queue messages
402 queue: ZMQStream for monitoring queue messages
403 query: ZMQStream for engine+client registration and client requests
403 query: ZMQStream for engine+client registration and client requests
404 heartbeat: HeartMonitor object for tracking engines
404 heartbeat: HeartMonitor object for tracking engines
405 # extra:
405 # extra:
406 db: ZMQStream for db connection (NotImplemented)
406 db: ZMQStream for db connection (NotImplemented)
407 engine_info: zmq address/protocol dict for engine connections
407 engine_info: zmq address/protocol dict for engine connections
408 client_info: zmq address/protocol dict for client connections
408 client_info: zmq address/protocol dict for client connections
409 """
409 """
410
410
411 super(Hub, self).__init__(**kwargs)
411 super(Hub, self).__init__(**kwargs)
412 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
412 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
413
413
414 # register our callbacks
414 # register our callbacks
415 self.query.on_recv(self.dispatch_query)
415 self.query.on_recv(self.dispatch_query)
416 self.monitor.on_recv(self.dispatch_monitor_traffic)
416 self.monitor.on_recv(self.dispatch_monitor_traffic)
417
417
418 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
418 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
419 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
419 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
420
420
421 self.monitor_handlers = {b'in' : self.save_queue_request,
421 self.monitor_handlers = {b'in' : self.save_queue_request,
422 b'out': self.save_queue_result,
422 b'out': self.save_queue_result,
423 b'intask': self.save_task_request,
423 b'intask': self.save_task_request,
424 b'outtask': self.save_task_result,
424 b'outtask': self.save_task_result,
425 b'tracktask': self.save_task_destination,
425 b'tracktask': self.save_task_destination,
426 b'incontrol': _passer,
426 b'incontrol': _passer,
427 b'outcontrol': _passer,
427 b'outcontrol': _passer,
428 b'iopub': self.save_iopub_message,
428 b'iopub': self.save_iopub_message,
429 }
429 }
430
430
431 self.query_handlers = {'queue_request': self.queue_status,
431 self.query_handlers = {'queue_request': self.queue_status,
432 'result_request': self.get_results,
432 'result_request': self.get_results,
433 'history_request': self.get_history,
433 'history_request': self.get_history,
434 'db_request': self.db_query,
434 'db_request': self.db_query,
435 'purge_request': self.purge_results,
435 'purge_request': self.purge_results,
436 'load_request': self.check_load,
436 'load_request': self.check_load,
437 'resubmit_request': self.resubmit_task,
437 'resubmit_request': self.resubmit_task,
438 'shutdown_request': self.shutdown_request,
438 'shutdown_request': self.shutdown_request,
439 'registration_request' : self.register_engine,
439 'registration_request' : self.register_engine,
440 'unregistration_request' : self.unregister_engine,
440 'unregistration_request' : self.unregister_engine,
441 'connection_request': self.connection_request,
441 'connection_request': self.connection_request,
442 }
442 }
443
443
444 # ignore resubmit replies
444 # ignore resubmit replies
445 self.resubmit.on_recv(lambda msg: None, copy=False)
445 self.resubmit.on_recv(lambda msg: None, copy=False)
446
446
447 self.log.info("hub::created hub")
447 self.log.info("hub::created hub")
448
448
449 @property
449 @property
450 def _next_id(self):
450 def _next_id(self):
451 """gemerate a new ID.
451 """gemerate a new ID.
452
452
453 No longer reuse old ids, just count from 0."""
453 No longer reuse old ids, just count from 0."""
454 newid = self._idcounter
454 newid = self._idcounter
455 self._idcounter += 1
455 self._idcounter += 1
456 return newid
456 return newid
457 # newid = 0
457 # newid = 0
458 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
458 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
459 # # print newid, self.ids, self.incoming_registrations
459 # # print newid, self.ids, self.incoming_registrations
460 # while newid in self.ids or newid in incoming:
460 # while newid in self.ids or newid in incoming:
461 # newid += 1
461 # newid += 1
462 # return newid
462 # return newid
463
463
464 #-----------------------------------------------------------------------------
464 #-----------------------------------------------------------------------------
465 # message validation
465 # message validation
466 #-----------------------------------------------------------------------------
466 #-----------------------------------------------------------------------------
467
467
468 def _validate_targets(self, targets):
468 def _validate_targets(self, targets):
469 """turn any valid targets argument into a list of integer ids"""
469 """turn any valid targets argument into a list of integer ids"""
470 if targets is None:
470 if targets is None:
471 # default to all
471 # default to all
472 return self.ids
472 return self.ids
473
473
474 if isinstance(targets, (int,str,unicode_type)):
474 if isinstance(targets, (int,str,unicode_type)):
475 # only one target specified
475 # only one target specified
476 targets = [targets]
476 targets = [targets]
477 _targets = []
477 _targets = []
478 for t in targets:
478 for t in targets:
479 # map raw identities to ids
479 # map raw identities to ids
480 if isinstance(t, (str,unicode_type)):
480 if isinstance(t, (str,unicode_type)):
481 t = self.by_ident.get(cast_bytes(t), t)
481 t = self.by_ident.get(cast_bytes(t), t)
482 _targets.append(t)
482 _targets.append(t)
483 targets = _targets
483 targets = _targets
484 bad_targets = [ t for t in targets if t not in self.ids ]
484 bad_targets = [ t for t in targets if t not in self.ids ]
485 if bad_targets:
485 if bad_targets:
486 raise IndexError("No Such Engine: %r" % bad_targets)
486 raise IndexError("No Such Engine: %r" % bad_targets)
487 if not targets:
487 if not targets:
488 raise IndexError("No Engines Registered")
488 raise IndexError("No Engines Registered")
489 return targets
489 return targets
490
490
491 #-----------------------------------------------------------------------------
491 #-----------------------------------------------------------------------------
492 # dispatch methods (1 per stream)
492 # dispatch methods (1 per stream)
493 #-----------------------------------------------------------------------------
493 #-----------------------------------------------------------------------------
494
494
495
495
496 @util.log_errors
496 @util.log_errors
497 def dispatch_monitor_traffic(self, msg):
497 def dispatch_monitor_traffic(self, msg):
498 """all ME and Task queue messages come through here, as well as
498 """all ME and Task queue messages come through here, as well as
499 IOPub traffic."""
499 IOPub traffic."""
500 self.log.debug("monitor traffic: %r", msg[0])
500 self.log.debug("monitor traffic: %r", msg[0])
501 switch = msg[0]
501 switch = msg[0]
502 try:
502 try:
503 idents, msg = self.session.feed_identities(msg[1:])
503 idents, msg = self.session.feed_identities(msg[1:])
504 except ValueError:
504 except ValueError:
505 idents=[]
505 idents=[]
506 if not idents:
506 if not idents:
507 self.log.error("Monitor message without topic: %r", msg)
507 self.log.error("Monitor message without topic: %r", msg)
508 return
508 return
509 handler = self.monitor_handlers.get(switch, None)
509 handler = self.monitor_handlers.get(switch, None)
510 if handler is not None:
510 if handler is not None:
511 handler(idents, msg)
511 handler(idents, msg)
512 else:
512 else:
513 self.log.error("Unrecognized monitor topic: %r", switch)
513 self.log.error("Unrecognized monitor topic: %r", switch)
514
514
515
515
516 @util.log_errors
516 @util.log_errors
517 def dispatch_query(self, msg):
517 def dispatch_query(self, msg):
518 """Route registration requests and queries from clients."""
518 """Route registration requests and queries from clients."""
519 try:
519 try:
520 idents, msg = self.session.feed_identities(msg)
520 idents, msg = self.session.feed_identities(msg)
521 except ValueError:
521 except ValueError:
522 idents = []
522 idents = []
523 if not idents:
523 if not idents:
524 self.log.error("Bad Query Message: %r", msg)
524 self.log.error("Bad Query Message: %r", msg)
525 return
525 return
526 client_id = idents[0]
526 client_id = idents[0]
527 try:
527 try:
528 msg = self.session.unserialize(msg, content=True)
528 msg = self.session.unserialize(msg, content=True)
529 except Exception:
529 except Exception:
530 content = error.wrap_exception()
530 content = error.wrap_exception()
531 self.log.error("Bad Query Message: %r", msg, exc_info=True)
531 self.log.error("Bad Query Message: %r", msg, exc_info=True)
532 self.session.send(self.query, "hub_error", ident=client_id,
532 self.session.send(self.query, "hub_error", ident=client_id,
533 content=content)
533 content=content)
534 return
534 return
535 # print client_id, header, parent, content
535 # print client_id, header, parent, content
536 #switch on message type:
536 #switch on message type:
537 msg_type = msg['header']['msg_type']
537 msg_type = msg['header']['msg_type']
538 self.log.info("client::client %r requested %r", client_id, msg_type)
538 self.log.info("client::client %r requested %r", client_id, msg_type)
539 handler = self.query_handlers.get(msg_type, None)
539 handler = self.query_handlers.get(msg_type, None)
540 try:
540 try:
541 assert handler is not None, "Bad Message Type: %r" % msg_type
541 assert handler is not None, "Bad Message Type: %r" % msg_type
542 except:
542 except:
543 content = error.wrap_exception()
543 content = error.wrap_exception()
544 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
544 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
545 self.session.send(self.query, "hub_error", ident=client_id,
545 self.session.send(self.query, "hub_error", ident=client_id,
546 content=content)
546 content=content)
547 return
547 return
548
548
549 else:
549 else:
550 handler(idents, msg)
550 handler(idents, msg)
551
551
552 def dispatch_db(self, msg):
552 def dispatch_db(self, msg):
553 """"""
553 """"""
554 raise NotImplementedError
554 raise NotImplementedError
555
555
556 #---------------------------------------------------------------------------
556 #---------------------------------------------------------------------------
557 # handler methods (1 per event)
557 # handler methods (1 per event)
558 #---------------------------------------------------------------------------
558 #---------------------------------------------------------------------------
559
559
560 #----------------------- Heartbeat --------------------------------------
560 #----------------------- Heartbeat --------------------------------------
561
561
562 def handle_new_heart(self, heart):
562 def handle_new_heart(self, heart):
563 """handler to attach to heartbeater.
563 """handler to attach to heartbeater.
564 Called when a new heart starts to beat.
564 Called when a new heart starts to beat.
565 Triggers completion of registration."""
565 Triggers completion of registration."""
566 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
566 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
567 if heart not in self.incoming_registrations:
567 if heart not in self.incoming_registrations:
568 self.log.info("heartbeat::ignoring new heart: %r", heart)
568 self.log.info("heartbeat::ignoring new heart: %r", heart)
569 else:
569 else:
570 self.finish_registration(heart)
570 self.finish_registration(heart)
571
571
572
572
573 def handle_heart_failure(self, heart):
573 def handle_heart_failure(self, heart):
574 """handler to attach to heartbeater.
574 """handler to attach to heartbeater.
575 called when a previously registered heart fails to respond to beat request.
575 called when a previously registered heart fails to respond to beat request.
576 triggers unregistration"""
576 triggers unregistration"""
577 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
577 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
578 eid = self.hearts.get(heart, None)
578 eid = self.hearts.get(heart, None)
579 uuid = self.engines[eid].uuid
579 uuid = self.engines[eid].uuid
580 if eid is None or self.keytable[eid] in self.dead_engines:
580 if eid is None or self.keytable[eid] in self.dead_engines:
581 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
581 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
582 else:
582 else:
583 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
583 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
584
584
585 #----------------------- MUX Queue Traffic ------------------------------
585 #----------------------- MUX Queue Traffic ------------------------------
586
586
587 def save_queue_request(self, idents, msg):
587 def save_queue_request(self, idents, msg):
588 if len(idents) < 2:
588 if len(idents) < 2:
589 self.log.error("invalid identity prefix: %r", idents)
589 self.log.error("invalid identity prefix: %r", idents)
590 return
590 return
591 queue_id, client_id = idents[:2]
591 queue_id, client_id = idents[:2]
592 try:
592 try:
593 msg = self.session.unserialize(msg)
593 msg = self.session.unserialize(msg)
594 except Exception:
594 except Exception:
595 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
595 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
596 return
596 return
597
597
598 eid = self.by_ident.get(queue_id, None)
598 eid = self.by_ident.get(queue_id, None)
599 if eid is None:
599 if eid is None:
600 self.log.error("queue::target %r not registered", queue_id)
600 self.log.error("queue::target %r not registered", queue_id)
601 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
601 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
602 return
602 return
603 record = init_record(msg)
603 record = init_record(msg)
604 msg_id = record['msg_id']
604 msg_id = record['msg_id']
605 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
605 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
606 # Unicode in records
606 # Unicode in records
607 record['engine_uuid'] = queue_id.decode('ascii')
607 record['engine_uuid'] = queue_id.decode('ascii')
608 record['client_uuid'] = msg['header']['session']
608 record['client_uuid'] = msg['header']['session']
609 record['queue'] = 'mux'
609 record['queue'] = 'mux'
610
610
611 try:
611 try:
612 # it's posible iopub arrived first:
612 # it's posible iopub arrived first:
613 existing = self.db.get_record(msg_id)
613 existing = self.db.get_record(msg_id)
614 for key,evalue in iteritems(existing):
614 for key,evalue in iteritems(existing):
615 rvalue = record.get(key, None)
615 rvalue = record.get(key, None)
616 if evalue and rvalue and evalue != rvalue:
616 if evalue and rvalue and evalue != rvalue:
617 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
617 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
618 elif evalue and not rvalue:
618 elif evalue and not rvalue:
619 record[key] = evalue
619 record[key] = evalue
620 try:
620 try:
621 self.db.update_record(msg_id, record)
621 self.db.update_record(msg_id, record)
622 except Exception:
622 except Exception:
623 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
623 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
624 except KeyError:
624 except KeyError:
625 try:
625 try:
626 self.db.add_record(msg_id, record)
626 self.db.add_record(msg_id, record)
627 except Exception:
627 except Exception:
628 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
628 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
629
629
630
630
631 self.pending.add(msg_id)
631 self.pending.add(msg_id)
632 self.queues[eid].append(msg_id)
632 self.queues[eid].append(msg_id)
633
633
634 def save_queue_result(self, idents, msg):
634 def save_queue_result(self, idents, msg):
635 if len(idents) < 2:
635 if len(idents) < 2:
636 self.log.error("invalid identity prefix: %r", idents)
636 self.log.error("invalid identity prefix: %r", idents)
637 return
637 return
638
638
639 client_id, queue_id = idents[:2]
639 client_id, queue_id = idents[:2]
640 try:
640 try:
641 msg = self.session.unserialize(msg)
641 msg = self.session.unserialize(msg)
642 except Exception:
642 except Exception:
643 self.log.error("queue::engine %r sent invalid message to %r: %r",
643 self.log.error("queue::engine %r sent invalid message to %r: %r",
644 queue_id, client_id, msg, exc_info=True)
644 queue_id, client_id, msg, exc_info=True)
645 return
645 return
646
646
647 eid = self.by_ident.get(queue_id, None)
647 eid = self.by_ident.get(queue_id, None)
648 if eid is None:
648 if eid is None:
649 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
649 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
650 return
650 return
651
651
652 parent = msg['parent_header']
652 parent = msg['parent_header']
653 if not parent:
653 if not parent:
654 return
654 return
655 msg_id = parent['msg_id']
655 msg_id = parent['msg_id']
656 if msg_id in self.pending:
656 if msg_id in self.pending:
657 self.pending.remove(msg_id)
657 self.pending.remove(msg_id)
658 self.all_completed.add(msg_id)
658 self.all_completed.add(msg_id)
659 self.queues[eid].remove(msg_id)
659 self.queues[eid].remove(msg_id)
660 self.completed[eid].append(msg_id)
660 self.completed[eid].append(msg_id)
661 self.log.info("queue::request %r completed on %s", msg_id, eid)
661 self.log.info("queue::request %r completed on %s", msg_id, eid)
662 elif msg_id not in self.all_completed:
662 elif msg_id not in self.all_completed:
663 # it could be a result from a dead engine that died before delivering the
663 # it could be a result from a dead engine that died before delivering the
664 # result
664 # result
665 self.log.warn("queue:: unknown msg finished %r", msg_id)
665 self.log.warn("queue:: unknown msg finished %r", msg_id)
666 return
666 return
667 # update record anyway, because the unregistration could have been premature
667 # update record anyway, because the unregistration could have been premature
668 rheader = msg['header']
668 rheader = msg['header']
669 md = msg['metadata']
669 md = msg['metadata']
670 completed = rheader['date']
670 completed = rheader['date']
671 started = md.get('started', None)
671 started = md.get('started', None)
672 result = {
672 result = {
673 'result_header' : rheader,
673 'result_header' : rheader,
674 'result_metadata': md,
674 'result_metadata': md,
675 'result_content': msg['content'],
675 'result_content': msg['content'],
676 'received': datetime.now(),
676 'received': datetime.now(),
677 'started' : started,
677 'started' : started,
678 'completed' : completed
678 'completed' : completed
679 }
679 }
680
680
681 result['result_buffers'] = msg['buffers']
681 result['result_buffers'] = msg['buffers']
682 try:
682 try:
683 self.db.update_record(msg_id, result)
683 self.db.update_record(msg_id, result)
684 except Exception:
684 except Exception:
685 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
685 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
686
686
687
687
688 #--------------------- Task Queue Traffic ------------------------------
688 #--------------------- Task Queue Traffic ------------------------------
689
689
690 def save_task_request(self, idents, msg):
690 def save_task_request(self, idents, msg):
691 """Save the submission of a task."""
691 """Save the submission of a task."""
692 client_id = idents[0]
692 client_id = idents[0]
693
693
694 try:
694 try:
695 msg = self.session.unserialize(msg)
695 msg = self.session.unserialize(msg)
696 except Exception:
696 except Exception:
697 self.log.error("task::client %r sent invalid task message: %r",
697 self.log.error("task::client %r sent invalid task message: %r",
698 client_id, msg, exc_info=True)
698 client_id, msg, exc_info=True)
699 return
699 return
700 record = init_record(msg)
700 record = init_record(msg)
701
701
702 record['client_uuid'] = msg['header']['session']
702 record['client_uuid'] = msg['header']['session']
703 record['queue'] = 'task'
703 record['queue'] = 'task'
704 header = msg['header']
704 header = msg['header']
705 msg_id = header['msg_id']
705 msg_id = header['msg_id']
706 self.pending.add(msg_id)
706 self.pending.add(msg_id)
707 self.unassigned.add(msg_id)
707 self.unassigned.add(msg_id)
708 try:
708 try:
709 # it's posible iopub arrived first:
709 # it's posible iopub arrived first:
710 existing = self.db.get_record(msg_id)
710 existing = self.db.get_record(msg_id)
711 if existing['resubmitted']:
711 if existing['resubmitted']:
712 for key in ('submitted', 'client_uuid', 'buffers'):
712 for key in ('submitted', 'client_uuid', 'buffers'):
713 # don't clobber these keys on resubmit
713 # don't clobber these keys on resubmit
714 # submitted and client_uuid should be different
714 # submitted and client_uuid should be different
715 # and buffers might be big, and shouldn't have changed
715 # and buffers might be big, and shouldn't have changed
716 record.pop(key)
716 record.pop(key)
717 # still check content,header which should not change
717 # still check content,header which should not change
718 # but are not expensive to compare as buffers
718 # but are not expensive to compare as buffers
719
719
720 for key,evalue in iteritems(existing):
720 for key,evalue in iteritems(existing):
721 if key.endswith('buffers'):
721 if key.endswith('buffers'):
722 # don't compare buffers
722 # don't compare buffers
723 continue
723 continue
724 rvalue = record.get(key, None)
724 rvalue = record.get(key, None)
725 if evalue and rvalue and evalue != rvalue:
725 if evalue and rvalue and evalue != rvalue:
726 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
726 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
727 elif evalue and not rvalue:
727 elif evalue and not rvalue:
728 record[key] = evalue
728 record[key] = evalue
729 try:
729 try:
730 self.db.update_record(msg_id, record)
730 self.db.update_record(msg_id, record)
731 except Exception:
731 except Exception:
732 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
732 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
733 except KeyError:
733 except KeyError:
734 try:
734 try:
735 self.db.add_record(msg_id, record)
735 self.db.add_record(msg_id, record)
736 except Exception:
736 except Exception:
737 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
737 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
738 except Exception:
738 except Exception:
739 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
739 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
740
740
741 def save_task_result(self, idents, msg):
741 def save_task_result(self, idents, msg):
742 """save the result of a completed task."""
742 """save the result of a completed task."""
743 client_id = idents[0]
743 client_id = idents[0]
744 try:
744 try:
745 msg = self.session.unserialize(msg)
745 msg = self.session.unserialize(msg)
746 except Exception:
746 except Exception:
747 self.log.error("task::invalid task result message send to %r: %r",
747 self.log.error("task::invalid task result message send to %r: %r",
748 client_id, msg, exc_info=True)
748 client_id, msg, exc_info=True)
749 return
749 return
750
750
751 parent = msg['parent_header']
751 parent = msg['parent_header']
752 if not parent:
752 if not parent:
753 # print msg
753 # print msg
754 self.log.warn("Task %r had no parent!", msg)
754 self.log.warn("Task %r had no parent!", msg)
755 return
755 return
756 msg_id = parent['msg_id']
756 msg_id = parent['msg_id']
757 if msg_id in self.unassigned:
757 if msg_id in self.unassigned:
758 self.unassigned.remove(msg_id)
758 self.unassigned.remove(msg_id)
759
759
760 header = msg['header']
760 header = msg['header']
761 md = msg['metadata']
761 md = msg['metadata']
762 engine_uuid = md.get('engine', u'')
762 engine_uuid = md.get('engine', u'')
763 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
763 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
764
764
765 status = md.get('status', None)
765 status = md.get('status', None)
766
766
767 if msg_id in self.pending:
767 if msg_id in self.pending:
768 self.log.info("task::task %r finished on %s", msg_id, eid)
768 self.log.info("task::task %r finished on %s", msg_id, eid)
769 self.pending.remove(msg_id)
769 self.pending.remove(msg_id)
770 self.all_completed.add(msg_id)
770 self.all_completed.add(msg_id)
771 if eid is not None:
771 if eid is not None:
772 if status != 'aborted':
772 if status != 'aborted':
773 self.completed[eid].append(msg_id)
773 self.completed[eid].append(msg_id)
774 if msg_id in self.tasks[eid]:
774 if msg_id in self.tasks[eid]:
775 self.tasks[eid].remove(msg_id)
775 self.tasks[eid].remove(msg_id)
776 completed = header['date']
776 completed = header['date']
777 started = md.get('started', None)
777 started = md.get('started', None)
778 result = {
778 result = {
779 'result_header' : header,
779 'result_header' : header,
780 'result_metadata': msg['metadata'],
780 'result_metadata': msg['metadata'],
781 'result_content': msg['content'],
781 'result_content': msg['content'],
782 'started' : started,
782 'started' : started,
783 'completed' : completed,
783 'completed' : completed,
784 'received' : datetime.now(),
784 'received' : datetime.now(),
785 'engine_uuid': engine_uuid,
785 'engine_uuid': engine_uuid,
786 }
786 }
787
787
788 result['result_buffers'] = msg['buffers']
788 result['result_buffers'] = msg['buffers']
789 try:
789 try:
790 self.db.update_record(msg_id, result)
790 self.db.update_record(msg_id, result)
791 except Exception:
791 except Exception:
792 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
792 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
793
793
794 else:
794 else:
795 self.log.debug("task::unknown task %r finished", msg_id)
795 self.log.debug("task::unknown task %r finished", msg_id)
796
796
797 def save_task_destination(self, idents, msg):
797 def save_task_destination(self, idents, msg):
798 try:
798 try:
799 msg = self.session.unserialize(msg, content=True)
799 msg = self.session.unserialize(msg, content=True)
800 except Exception:
800 except Exception:
801 self.log.error("task::invalid task tracking message", exc_info=True)
801 self.log.error("task::invalid task tracking message", exc_info=True)
802 return
802 return
803 content = msg['content']
803 content = msg['content']
804 # print (content)
804 # print (content)
805 msg_id = content['msg_id']
805 msg_id = content['msg_id']
806 engine_uuid = content['engine_id']
806 engine_uuid = content['engine_id']
807 eid = self.by_ident[cast_bytes(engine_uuid)]
807 eid = self.by_ident[cast_bytes(engine_uuid)]
808
808
809 self.log.info("task::task %r arrived on %r", msg_id, eid)
809 self.log.info("task::task %r arrived on %r", msg_id, eid)
810 if msg_id in self.unassigned:
810 if msg_id in self.unassigned:
811 self.unassigned.remove(msg_id)
811 self.unassigned.remove(msg_id)
812 # else:
812 # else:
813 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
813 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
814
814
815 self.tasks[eid].append(msg_id)
815 self.tasks[eid].append(msg_id)
816 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
816 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
817 try:
817 try:
818 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
818 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
819 except Exception:
819 except Exception:
820 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
820 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
821
821
822
822
823 def mia_task_request(self, idents, msg):
823 def mia_task_request(self, idents, msg):
824 raise NotImplementedError
824 raise NotImplementedError
825 client_id = idents[0]
825 client_id = idents[0]
826 # content = dict(mia=self.mia,status='ok')
826 # content = dict(mia=self.mia,status='ok')
827 # self.session.send('mia_reply', content=content, idents=client_id)
827 # self.session.send('mia_reply', content=content, idents=client_id)
828
828
829
829
830 #--------------------- IOPub Traffic ------------------------------
830 #--------------------- IOPub Traffic ------------------------------
831
831
832 def save_iopub_message(self, topics, msg):
832 def save_iopub_message(self, topics, msg):
833 """save an iopub message into the db"""
833 """save an iopub message into the db"""
834 # print (topics)
834 # print (topics)
835 try:
835 try:
836 msg = self.session.unserialize(msg, content=True)
836 msg = self.session.unserialize(msg, content=True)
837 except Exception:
837 except Exception:
838 self.log.error("iopub::invalid IOPub message", exc_info=True)
838 self.log.error("iopub::invalid IOPub message", exc_info=True)
839 return
839 return
840
840
841 parent = msg['parent_header']
841 parent = msg['parent_header']
842 if not parent:
842 if not parent:
843 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
843 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
844 return
844 return
845 msg_id = parent['msg_id']
845 msg_id = parent['msg_id']
846 msg_type = msg['header']['msg_type']
846 msg_type = msg['header']['msg_type']
847 content = msg['content']
847 content = msg['content']
848
848
849 # ensure msg_id is in db
849 # ensure msg_id is in db
850 try:
850 try:
851 rec = self.db.get_record(msg_id)
851 rec = self.db.get_record(msg_id)
852 except KeyError:
852 except KeyError:
853 rec = empty_record()
853 rec = empty_record()
854 rec['msg_id'] = msg_id
854 rec['msg_id'] = msg_id
855 self.db.add_record(msg_id, rec)
855 self.db.add_record(msg_id, rec)
856 # stream
856 # stream
857 d = {}
857 d = {}
858 if msg_type == 'stream':
858 if msg_type == 'stream':
859 name = content['name']
859 name = content['name']
860 s = rec[name] or ''
860 s = rec[name] or ''
861 d[name] = s + content['data']
861 d[name] = s + content['data']
862
862
863 elif msg_type == 'pyerr':
863 elif msg_type == 'pyerr':
864 d['pyerr'] = content
864 d['pyerr'] = content
865 elif msg_type == 'pyin':
865 elif msg_type == 'pyin':
866 d['pyin'] = content['code']
866 d['pyin'] = content['code']
867 elif msg_type in ('display_data', 'pyout'):
867 elif msg_type in ('display_data', 'pyout'):
868 d[msg_type] = content
868 d[msg_type] = content
869 elif msg_type == 'status':
869 elif msg_type == 'status':
870 pass
870 pass
871 elif msg_type == 'data_pub':
871 elif msg_type == 'data_pub':
872 self.log.info("ignored data_pub message for %s" % msg_id)
872 self.log.info("ignored data_pub message for %s" % msg_id)
873 else:
873 else:
874 self.log.warn("unhandled iopub msg_type: %r", msg_type)
874 self.log.warn("unhandled iopub msg_type: %r", msg_type)
875
875
876 if not d:
876 if not d:
877 return
877 return
878
878
879 try:
879 try:
880 self.db.update_record(msg_id, d)
880 self.db.update_record(msg_id, d)
881 except Exception:
881 except Exception:
882 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
882 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
883
883
884
884
885
885
886 #-------------------------------------------------------------------------
886 #-------------------------------------------------------------------------
887 # Registration requests
887 # Registration requests
888 #-------------------------------------------------------------------------
888 #-------------------------------------------------------------------------
889
889
890 def connection_request(self, client_id, msg):
890 def connection_request(self, client_id, msg):
891 """Reply with connection addresses for clients."""
891 """Reply with connection addresses for clients."""
892 self.log.info("client::client %r connected", client_id)
892 self.log.info("client::client %r connected", client_id)
893 content = dict(status='ok')
893 content = dict(status='ok')
894 jsonable = {}
894 jsonable = {}
895 for k,v in iteritems(self.keytable):
895 for k,v in iteritems(self.keytable):
896 if v not in self.dead_engines:
896 if v not in self.dead_engines:
897 jsonable[str(k)] = v
897 jsonable[str(k)] = v
898 content['engines'] = jsonable
898 content['engines'] = jsonable
899 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
899 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
900
900
901 def register_engine(self, reg, msg):
901 def register_engine(self, reg, msg):
902 """Register a new engine."""
902 """Register a new engine."""
903 content = msg['content']
903 content = msg['content']
904 try:
904 try:
905 uuid = content['uuid']
905 uuid = content['uuid']
906 except KeyError:
906 except KeyError:
907 self.log.error("registration::queue not specified", exc_info=True)
907 self.log.error("registration::queue not specified", exc_info=True)
908 return
908 return
909
909
910 eid = self._next_id
910 eid = self._next_id
911
911
912 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
912 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
913
913
914 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
914 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
915 # check if requesting available IDs:
915 # check if requesting available IDs:
916 if cast_bytes(uuid) in self.by_ident:
916 if cast_bytes(uuid) in self.by_ident:
917 try:
917 try:
918 raise KeyError("uuid %r in use" % uuid)
918 raise KeyError("uuid %r in use" % uuid)
919 except:
919 except:
920 content = error.wrap_exception()
920 content = error.wrap_exception()
921 self.log.error("uuid %r in use", uuid, exc_info=True)
921 self.log.error("uuid %r in use", uuid, exc_info=True)
922 else:
922 else:
923 for h, ec in iteritems(self.incoming_registrations):
923 for h, ec in iteritems(self.incoming_registrations):
924 if uuid == h:
924 if uuid == h:
925 try:
925 try:
926 raise KeyError("heart_id %r in use" % uuid)
926 raise KeyError("heart_id %r in use" % uuid)
927 except:
927 except:
928 self.log.error("heart_id %r in use", uuid, exc_info=True)
928 self.log.error("heart_id %r in use", uuid, exc_info=True)
929 content = error.wrap_exception()
929 content = error.wrap_exception()
930 break
930 break
931 elif uuid == ec.uuid:
931 elif uuid == ec.uuid:
932 try:
932 try:
933 raise KeyError("uuid %r in use" % uuid)
933 raise KeyError("uuid %r in use" % uuid)
934 except:
934 except:
935 self.log.error("uuid %r in use", uuid, exc_info=True)
935 self.log.error("uuid %r in use", uuid, exc_info=True)
936 content = error.wrap_exception()
936 content = error.wrap_exception()
937 break
937 break
938
938
939 msg = self.session.send(self.query, "registration_reply",
939 msg = self.session.send(self.query, "registration_reply",
940 content=content,
940 content=content,
941 ident=reg)
941 ident=reg)
942
942
943 heart = cast_bytes(uuid)
943 heart = cast_bytes(uuid)
944
944
945 if content['status'] == 'ok':
945 if content['status'] == 'ok':
946 if heart in self.heartmonitor.hearts:
946 if heart in self.heartmonitor.hearts:
947 # already beating
947 # already beating
948 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
948 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
949 self.finish_registration(heart)
949 self.finish_registration(heart)
950 else:
950 else:
951 purge = lambda : self._purge_stalled_registration(heart)
951 purge = lambda : self._purge_stalled_registration(heart)
952 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
952 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
953 dc.start()
953 dc.start()
954 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
954 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
955 else:
955 else:
956 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
956 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
957
957
958 return eid
958 return eid
959
959
960 def unregister_engine(self, ident, msg):
960 def unregister_engine(self, ident, msg):
961 """Unregister an engine that explicitly requested to leave."""
961 """Unregister an engine that explicitly requested to leave."""
962 try:
962 try:
963 eid = msg['content']['id']
963 eid = msg['content']['id']
964 except:
964 except:
965 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
965 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
966 return
966 return
967 self.log.info("registration::unregister_engine(%r)", eid)
967 self.log.info("registration::unregister_engine(%r)", eid)
968 # print (eid)
968 # print (eid)
969 uuid = self.keytable[eid]
969 uuid = self.keytable[eid]
970 content=dict(id=eid, uuid=uuid)
970 content=dict(id=eid, uuid=uuid)
971 self.dead_engines.add(uuid)
971 self.dead_engines.add(uuid)
972 # self.ids.remove(eid)
972 # self.ids.remove(eid)
973 # uuid = self.keytable.pop(eid)
973 # uuid = self.keytable.pop(eid)
974 #
974 #
975 # ec = self.engines.pop(eid)
975 # ec = self.engines.pop(eid)
976 # self.hearts.pop(ec.heartbeat)
976 # self.hearts.pop(ec.heartbeat)
977 # self.by_ident.pop(ec.queue)
977 # self.by_ident.pop(ec.queue)
978 # self.completed.pop(eid)
978 # self.completed.pop(eid)
979 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
979 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
980 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
980 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
981 dc.start()
981 dc.start()
982 ############## TODO: HANDLE IT ################
982 ############## TODO: HANDLE IT ################
983
983
984 self._save_engine_state()
984 self._save_engine_state()
985
985
986 if self.notifier:
986 if self.notifier:
987 self.session.send(self.notifier, "unregistration_notification", content=content)
987 self.session.send(self.notifier, "unregistration_notification", content=content)
988
988
989 def _handle_stranded_msgs(self, eid, uuid):
989 def _handle_stranded_msgs(self, eid, uuid):
990 """Handle messages known to be on an engine when the engine unregisters.
990 """Handle messages known to be on an engine when the engine unregisters.
991
991
992 It is possible that this will fire prematurely - that is, an engine will
992 It is possible that this will fire prematurely - that is, an engine will
993 go down after completing a result, and the client will be notified
993 go down after completing a result, and the client will be notified
994 that the result failed and later receive the actual result.
994 that the result failed and later receive the actual result.
995 """
995 """
996
996
997 outstanding = self.queues[eid]
997 outstanding = self.queues[eid]
998
998
999 for msg_id in outstanding:
999 for msg_id in outstanding:
1000 self.pending.remove(msg_id)
1000 self.pending.remove(msg_id)
1001 self.all_completed.add(msg_id)
1001 self.all_completed.add(msg_id)
1002 try:
1002 try:
1003 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1003 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1004 except:
1004 except:
1005 content = error.wrap_exception()
1005 content = error.wrap_exception()
1006 # build a fake header:
1006 # build a fake header:
1007 header = {}
1007 header = {}
1008 header['engine'] = uuid
1008 header['engine'] = uuid
1009 header['date'] = datetime.now()
1009 header['date'] = datetime.now()
1010 rec = dict(result_content=content, result_header=header, result_buffers=[])
1010 rec = dict(result_content=content, result_header=header, result_buffers=[])
1011 rec['completed'] = header['date']
1011 rec['completed'] = header['date']
1012 rec['engine_uuid'] = uuid
1012 rec['engine_uuid'] = uuid
1013 try:
1013 try:
1014 self.db.update_record(msg_id, rec)
1014 self.db.update_record(msg_id, rec)
1015 except Exception:
1015 except Exception:
1016 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1016 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1017
1017
1018
1018
1019 def finish_registration(self, heart):
1019 def finish_registration(self, heart):
1020 """Second half of engine registration, called after our HeartMonitor
1020 """Second half of engine registration, called after our HeartMonitor
1021 has received a beat from the Engine's Heart."""
1021 has received a beat from the Engine's Heart."""
1022 try:
1022 try:
1023 ec = self.incoming_registrations.pop(heart)
1023 ec = self.incoming_registrations.pop(heart)
1024 except KeyError:
1024 except KeyError:
1025 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
1025 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
1026 return
1026 return
1027 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1027 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1028 if ec.stallback is not None:
1028 if ec.stallback is not None:
1029 ec.stallback.stop()
1029 ec.stallback.stop()
1030 eid = ec.id
1030 eid = ec.id
1031 self.ids.add(eid)
1031 self.ids.add(eid)
1032 self.keytable[eid] = ec.uuid
1032 self.keytable[eid] = ec.uuid
1033 self.engines[eid] = ec
1033 self.engines[eid] = ec
1034 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1034 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1035 self.queues[eid] = list()
1035 self.queues[eid] = list()
1036 self.tasks[eid] = list()
1036 self.tasks[eid] = list()
1037 self.completed[eid] = list()
1037 self.completed[eid] = list()
1038 self.hearts[heart] = eid
1038 self.hearts[heart] = eid
1039 content = dict(id=eid, uuid=self.engines[eid].uuid)
1039 content = dict(id=eid, uuid=self.engines[eid].uuid)
1040 if self.notifier:
1040 if self.notifier:
1041 self.session.send(self.notifier, "registration_notification", content=content)
1041 self.session.send(self.notifier, "registration_notification", content=content)
1042 self.log.info("engine::Engine Connected: %i", eid)
1042 self.log.info("engine::Engine Connected: %i", eid)
1043
1043
1044 self._save_engine_state()
1044 self._save_engine_state()
1045
1045
1046 def _purge_stalled_registration(self, heart):
1046 def _purge_stalled_registration(self, heart):
1047 if heart in self.incoming_registrations:
1047 if heart in self.incoming_registrations:
1048 ec = self.incoming_registrations.pop(heart)
1048 ec = self.incoming_registrations.pop(heart)
1049 self.log.info("registration::purging stalled registration: %i", ec.id)
1049 self.log.info("registration::purging stalled registration: %i", ec.id)
1050 else:
1050 else:
1051 pass
1051 pass
1052
1052
1053 #-------------------------------------------------------------------------
1053 #-------------------------------------------------------------------------
1054 # Engine State
1054 # Engine State
1055 #-------------------------------------------------------------------------
1055 #-------------------------------------------------------------------------
1056
1056
1057
1057
1058 def _cleanup_engine_state_file(self):
1058 def _cleanup_engine_state_file(self):
1059 """cleanup engine state mapping"""
1059 """cleanup engine state mapping"""
1060
1060
1061 if os.path.exists(self.engine_state_file):
1061 if os.path.exists(self.engine_state_file):
1062 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1062 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1063 try:
1063 try:
1064 os.remove(self.engine_state_file)
1064 os.remove(self.engine_state_file)
1065 except IOError:
1065 except IOError:
1066 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1066 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1067
1067
1068
1068
1069 def _save_engine_state(self):
1069 def _save_engine_state(self):
1070 """save engine mapping to JSON file"""
1070 """save engine mapping to JSON file"""
1071 if not self.engine_state_file:
1071 if not self.engine_state_file:
1072 return
1072 return
1073 self.log.debug("save engine state to %s" % self.engine_state_file)
1073 self.log.debug("save engine state to %s" % self.engine_state_file)
1074 state = {}
1074 state = {}
1075 engines = {}
1075 engines = {}
1076 for eid, ec in iteritems(self.engines):
1076 for eid, ec in iteritems(self.engines):
1077 if ec.uuid not in self.dead_engines:
1077 if ec.uuid not in self.dead_engines:
1078 engines[eid] = ec.uuid
1078 engines[eid] = ec.uuid
1079
1079
1080 state['engines'] = engines
1080 state['engines'] = engines
1081
1081
1082 state['next_id'] = self._idcounter
1082 state['next_id'] = self._idcounter
1083
1083
1084 with open(self.engine_state_file, 'w') as f:
1084 with open(self.engine_state_file, 'w') as f:
1085 json.dump(state, f)
1085 json.dump(state, f)
1086
1086
1087
1087
1088 def _load_engine_state(self):
1088 def _load_engine_state(self):
1089 """load engine mapping from JSON file"""
1089 """load engine mapping from JSON file"""
1090 if not os.path.exists(self.engine_state_file):
1090 if not os.path.exists(self.engine_state_file):
1091 return
1091 return
1092
1092
1093 self.log.info("loading engine state from %s" % self.engine_state_file)
1093 self.log.info("loading engine state from %s" % self.engine_state_file)
1094
1094
1095 with open(self.engine_state_file) as f:
1095 with open(self.engine_state_file) as f:
1096 state = json.load(f)
1096 state = json.load(f)
1097
1097
1098 save_notifier = self.notifier
1098 save_notifier = self.notifier
1099 self.notifier = None
1099 self.notifier = None
1100 for eid, uuid in iteritems(state['engines']):
1100 for eid, uuid in iteritems(state['engines']):
1101 heart = uuid.encode('ascii')
1101 heart = uuid.encode('ascii')
1102 # start with this heart as current and beating:
1102 # start with this heart as current and beating:
1103 self.heartmonitor.responses.add(heart)
1103 self.heartmonitor.responses.add(heart)
1104 self.heartmonitor.hearts.add(heart)
1104 self.heartmonitor.hearts.add(heart)
1105
1105
1106 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1106 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1107 self.finish_registration(heart)
1107 self.finish_registration(heart)
1108
1108
1109 self.notifier = save_notifier
1109 self.notifier = save_notifier
1110
1110
1111 self._idcounter = state['next_id']
1111 self._idcounter = state['next_id']
1112
1112
1113 #-------------------------------------------------------------------------
1113 #-------------------------------------------------------------------------
1114 # Client Requests
1114 # Client Requests
1115 #-------------------------------------------------------------------------
1115 #-------------------------------------------------------------------------
1116
1116
1117 def shutdown_request(self, client_id, msg):
1117 def shutdown_request(self, client_id, msg):
1118 """handle shutdown request."""
1118 """handle shutdown request."""
1119 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1119 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1120 # also notify other clients of shutdown
1120 # also notify other clients of shutdown
1121 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1121 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1122 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1122 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1123 dc.start()
1123 dc.start()
1124
1124
1125 def _shutdown(self):
1125 def _shutdown(self):
1126 self.log.info("hub::hub shutting down.")
1126 self.log.info("hub::hub shutting down.")
1127 time.sleep(0.1)
1127 time.sleep(0.1)
1128 sys.exit(0)
1128 sys.exit(0)
1129
1129
1130
1130
1131 def check_load(self, client_id, msg):
1131 def check_load(self, client_id, msg):
1132 content = msg['content']
1132 content = msg['content']
1133 try:
1133 try:
1134 targets = content['targets']
1134 targets = content['targets']
1135 targets = self._validate_targets(targets)
1135 targets = self._validate_targets(targets)
1136 except:
1136 except:
1137 content = error.wrap_exception()
1137 content = error.wrap_exception()
1138 self.session.send(self.query, "hub_error",
1138 self.session.send(self.query, "hub_error",
1139 content=content, ident=client_id)
1139 content=content, ident=client_id)
1140 return
1140 return
1141
1141
1142 content = dict(status='ok')
1142 content = dict(status='ok')
1143 # loads = {}
1143 # loads = {}
1144 for t in targets:
1144 for t in targets:
1145 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1145 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1146 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1146 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1147
1147
1148
1148
1149 def queue_status(self, client_id, msg):
1149 def queue_status(self, client_id, msg):
1150 """Return the Queue status of one or more targets.
1150 """Return the Queue status of one or more targets.
1151 if verbose: return the msg_ids
1151 if verbose: return the msg_ids
1152 else: return len of each type.
1152 else: return len of each type.
1153 keys: queue (pending MUX jobs)
1153 keys: queue (pending MUX jobs)
1154 tasks (pending Task jobs)
1154 tasks (pending Task jobs)
1155 completed (finished jobs from both queues)"""
1155 completed (finished jobs from both queues)"""
1156 content = msg['content']
1156 content = msg['content']
1157 targets = content['targets']
1157 targets = content['targets']
1158 try:
1158 try:
1159 targets = self._validate_targets(targets)
1159 targets = self._validate_targets(targets)
1160 except:
1160 except:
1161 content = error.wrap_exception()
1161 content = error.wrap_exception()
1162 self.session.send(self.query, "hub_error",
1162 self.session.send(self.query, "hub_error",
1163 content=content, ident=client_id)
1163 content=content, ident=client_id)
1164 return
1164 return
1165 verbose = content.get('verbose', False)
1165 verbose = content.get('verbose', False)
1166 content = dict(status='ok')
1166 content = dict(status='ok')
1167 for t in targets:
1167 for t in targets:
1168 queue = self.queues[t]
1168 queue = self.queues[t]
1169 completed = self.completed[t]
1169 completed = self.completed[t]
1170 tasks = self.tasks[t]
1170 tasks = self.tasks[t]
1171 if not verbose:
1171 if not verbose:
1172 queue = len(queue)
1172 queue = len(queue)
1173 completed = len(completed)
1173 completed = len(completed)
1174 tasks = len(tasks)
1174 tasks = len(tasks)
1175 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1175 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1176 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1176 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1177 # print (content)
1177 # print (content)
1178 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1178 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1179
1179
1180 def purge_results(self, client_id, msg):
1180 def purge_results(self, client_id, msg):
1181 """Purge results from memory. This method is more valuable before we move
1181 """Purge results from memory. This method is more valuable before we move
1182 to a DB based message storage mechanism."""
1182 to a DB based message storage mechanism."""
1183 content = msg['content']
1183 content = msg['content']
1184 self.log.info("Dropping records with %s", content)
1184 self.log.info("Dropping records with %s", content)
1185 msg_ids = content.get('msg_ids', [])
1185 msg_ids = content.get('msg_ids', [])
1186 reply = dict(status='ok')
1186 reply = dict(status='ok')
1187 if msg_ids == 'all':
1187 if msg_ids == 'all':
1188 try:
1188 try:
1189 self.db.drop_matching_records(dict(completed={'$ne':None}))
1189 self.db.drop_matching_records(dict(completed={'$ne':None}))
1190 except Exception:
1190 except Exception:
1191 reply = error.wrap_exception()
1191 reply = error.wrap_exception()
1192 else:
1192 else:
1193 pending = filter(lambda m: m in self.pending, msg_ids)
1193 pending = [m for m in msg_ids if (m in self.pending)]
1194 if pending:
1194 if pending:
1195 try:
1195 try:
1196 raise IndexError("msg pending: %r" % pending[0])
1196 raise IndexError("msg pending: %r" % pending[0])
1197 except:
1197 except:
1198 reply = error.wrap_exception()
1198 reply = error.wrap_exception()
1199 else:
1199 else:
1200 try:
1200 try:
1201 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1201 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1202 except Exception:
1202 except Exception:
1203 reply = error.wrap_exception()
1203 reply = error.wrap_exception()
1204
1204
1205 if reply['status'] == 'ok':
1205 if reply['status'] == 'ok':
1206 eids = content.get('engine_ids', [])
1206 eids = content.get('engine_ids', [])
1207 for eid in eids:
1207 for eid in eids:
1208 if eid not in self.engines:
1208 if eid not in self.engines:
1209 try:
1209 try:
1210 raise IndexError("No such engine: %i" % eid)
1210 raise IndexError("No such engine: %i" % eid)
1211 except:
1211 except:
1212 reply = error.wrap_exception()
1212 reply = error.wrap_exception()
1213 break
1213 break
1214 uid = self.engines[eid].uuid
1214 uid = self.engines[eid].uuid
1215 try:
1215 try:
1216 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1216 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1217 except Exception:
1217 except Exception:
1218 reply = error.wrap_exception()
1218 reply = error.wrap_exception()
1219 break
1219 break
1220
1220
1221 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1221 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1222
1222
1223 def resubmit_task(self, client_id, msg):
1223 def resubmit_task(self, client_id, msg):
1224 """Resubmit one or more tasks."""
1224 """Resubmit one or more tasks."""
1225 def finish(reply):
1225 def finish(reply):
1226 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1226 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1227
1227
1228 content = msg['content']
1228 content = msg['content']
1229 msg_ids = content['msg_ids']
1229 msg_ids = content['msg_ids']
1230 reply = dict(status='ok')
1230 reply = dict(status='ok')
1231 try:
1231 try:
1232 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1232 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1233 'header', 'content', 'buffers'])
1233 'header', 'content', 'buffers'])
1234 except Exception:
1234 except Exception:
1235 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1235 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1236 return finish(error.wrap_exception())
1236 return finish(error.wrap_exception())
1237
1237
1238 # validate msg_ids
1238 # validate msg_ids
1239 found_ids = [ rec['msg_id'] for rec in records ]
1239 found_ids = [ rec['msg_id'] for rec in records ]
1240 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1240 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1241 if len(records) > len(msg_ids):
1241 if len(records) > len(msg_ids):
1242 try:
1242 try:
1243 raise RuntimeError("DB appears to be in an inconsistent state."
1243 raise RuntimeError("DB appears to be in an inconsistent state."
1244 "More matching records were found than should exist")
1244 "More matching records were found than should exist")
1245 except Exception:
1245 except Exception:
1246 return finish(error.wrap_exception())
1246 return finish(error.wrap_exception())
1247 elif len(records) < len(msg_ids):
1247 elif len(records) < len(msg_ids):
1248 missing = [ m for m in msg_ids if m not in found_ids ]
1248 missing = [ m for m in msg_ids if m not in found_ids ]
1249 try:
1249 try:
1250 raise KeyError("No such msg(s): %r" % missing)
1250 raise KeyError("No such msg(s): %r" % missing)
1251 except KeyError:
1251 except KeyError:
1252 return finish(error.wrap_exception())
1252 return finish(error.wrap_exception())
1253 elif pending_ids:
1253 elif pending_ids:
1254 pass
1254 pass
1255 # no need to raise on resubmit of pending task, now that we
1255 # no need to raise on resubmit of pending task, now that we
1256 # resubmit under new ID, but do we want to raise anyway?
1256 # resubmit under new ID, but do we want to raise anyway?
1257 # msg_id = invalid_ids[0]
1257 # msg_id = invalid_ids[0]
1258 # try:
1258 # try:
1259 # raise ValueError("Task(s) %r appears to be inflight" % )
1259 # raise ValueError("Task(s) %r appears to be inflight" % )
1260 # except Exception:
1260 # except Exception:
1261 # return finish(error.wrap_exception())
1261 # return finish(error.wrap_exception())
1262
1262
1263 # mapping of original IDs to resubmitted IDs
1263 # mapping of original IDs to resubmitted IDs
1264 resubmitted = {}
1264 resubmitted = {}
1265
1265
1266 # send the messages
1266 # send the messages
1267 for rec in records:
1267 for rec in records:
1268 header = rec['header']
1268 header = rec['header']
1269 msg = self.session.msg(header['msg_type'], parent=header)
1269 msg = self.session.msg(header['msg_type'], parent=header)
1270 msg_id = msg['msg_id']
1270 msg_id = msg['msg_id']
1271 msg['content'] = rec['content']
1271 msg['content'] = rec['content']
1272
1272
1273 # use the old header, but update msg_id and timestamp
1273 # use the old header, but update msg_id and timestamp
1274 fresh = msg['header']
1274 fresh = msg['header']
1275 header['msg_id'] = fresh['msg_id']
1275 header['msg_id'] = fresh['msg_id']
1276 header['date'] = fresh['date']
1276 header['date'] = fresh['date']
1277 msg['header'] = header
1277 msg['header'] = header
1278
1278
1279 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1279 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1280
1280
1281 resubmitted[rec['msg_id']] = msg_id
1281 resubmitted[rec['msg_id']] = msg_id
1282 self.pending.add(msg_id)
1282 self.pending.add(msg_id)
1283 msg['buffers'] = rec['buffers']
1283 msg['buffers'] = rec['buffers']
1284 try:
1284 try:
1285 self.db.add_record(msg_id, init_record(msg))
1285 self.db.add_record(msg_id, init_record(msg))
1286 except Exception:
1286 except Exception:
1287 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1287 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1288 return finish(error.wrap_exception())
1288 return finish(error.wrap_exception())
1289
1289
1290 finish(dict(status='ok', resubmitted=resubmitted))
1290 finish(dict(status='ok', resubmitted=resubmitted))
1291
1291
1292 # store the new IDs in the Task DB
1292 # store the new IDs in the Task DB
1293 for msg_id, resubmit_id in iteritems(resubmitted):
1293 for msg_id, resubmit_id in iteritems(resubmitted):
1294 try:
1294 try:
1295 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1295 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1296 except Exception:
1296 except Exception:
1297 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1297 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1298
1298
1299
1299
1300 def _extract_record(self, rec):
1300 def _extract_record(self, rec):
1301 """decompose a TaskRecord dict into subsection of reply for get_result"""
1301 """decompose a TaskRecord dict into subsection of reply for get_result"""
1302 io_dict = {}
1302 io_dict = {}
1303 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1303 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1304 io_dict[key] = rec[key]
1304 io_dict[key] = rec[key]
1305 content = {
1305 content = {
1306 'header': rec['header'],
1306 'header': rec['header'],
1307 'metadata': rec['metadata'],
1307 'metadata': rec['metadata'],
1308 'result_metadata': rec['result_metadata'],
1308 'result_metadata': rec['result_metadata'],
1309 'result_header' : rec['result_header'],
1309 'result_header' : rec['result_header'],
1310 'result_content': rec['result_content'],
1310 'result_content': rec['result_content'],
1311 'received' : rec['received'],
1311 'received' : rec['received'],
1312 'io' : io_dict,
1312 'io' : io_dict,
1313 }
1313 }
1314 if rec['result_buffers']:
1314 if rec['result_buffers']:
1315 buffers = map(bytes, rec['result_buffers'])
1315 buffers = list(map(bytes, rec['result_buffers']))
1316 else:
1316 else:
1317 buffers = []
1317 buffers = []
1318
1318
1319 return content, buffers
1319 return content, buffers
1320
1320
1321 def get_results(self, client_id, msg):
1321 def get_results(self, client_id, msg):
1322 """Get the result of 1 or more messages."""
1322 """Get the result of 1 or more messages."""
1323 content = msg['content']
1323 content = msg['content']
1324 msg_ids = sorted(set(content['msg_ids']))
1324 msg_ids = sorted(set(content['msg_ids']))
1325 statusonly = content.get('status_only', False)
1325 statusonly = content.get('status_only', False)
1326 pending = []
1326 pending = []
1327 completed = []
1327 completed = []
1328 content = dict(status='ok')
1328 content = dict(status='ok')
1329 content['pending'] = pending
1329 content['pending'] = pending
1330 content['completed'] = completed
1330 content['completed'] = completed
1331 buffers = []
1331 buffers = []
1332 if not statusonly:
1332 if not statusonly:
1333 try:
1333 try:
1334 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1334 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1335 # turn match list into dict, for faster lookup
1335 # turn match list into dict, for faster lookup
1336 records = {}
1336 records = {}
1337 for rec in matches:
1337 for rec in matches:
1338 records[rec['msg_id']] = rec
1338 records[rec['msg_id']] = rec
1339 except Exception:
1339 except Exception:
1340 content = error.wrap_exception()
1340 content = error.wrap_exception()
1341 self.session.send(self.query, "result_reply", content=content,
1341 self.session.send(self.query, "result_reply", content=content,
1342 parent=msg, ident=client_id)
1342 parent=msg, ident=client_id)
1343 return
1343 return
1344 else:
1344 else:
1345 records = {}
1345 records = {}
1346 for msg_id in msg_ids:
1346 for msg_id in msg_ids:
1347 if msg_id in self.pending:
1347 if msg_id in self.pending:
1348 pending.append(msg_id)
1348 pending.append(msg_id)
1349 elif msg_id in self.all_completed:
1349 elif msg_id in self.all_completed:
1350 completed.append(msg_id)
1350 completed.append(msg_id)
1351 if not statusonly:
1351 if not statusonly:
1352 c,bufs = self._extract_record(records[msg_id])
1352 c,bufs = self._extract_record(records[msg_id])
1353 content[msg_id] = c
1353 content[msg_id] = c
1354 buffers.extend(bufs)
1354 buffers.extend(bufs)
1355 elif msg_id in records:
1355 elif msg_id in records:
1356 if rec['completed']:
1356 if rec['completed']:
1357 completed.append(msg_id)
1357 completed.append(msg_id)
1358 c,bufs = self._extract_record(records[msg_id])
1358 c,bufs = self._extract_record(records[msg_id])
1359 content[msg_id] = c
1359 content[msg_id] = c
1360 buffers.extend(bufs)
1360 buffers.extend(bufs)
1361 else:
1361 else:
1362 pending.append(msg_id)
1362 pending.append(msg_id)
1363 else:
1363 else:
1364 try:
1364 try:
1365 raise KeyError('No such message: '+msg_id)
1365 raise KeyError('No such message: '+msg_id)
1366 except:
1366 except:
1367 content = error.wrap_exception()
1367 content = error.wrap_exception()
1368 break
1368 break
1369 self.session.send(self.query, "result_reply", content=content,
1369 self.session.send(self.query, "result_reply", content=content,
1370 parent=msg, ident=client_id,
1370 parent=msg, ident=client_id,
1371 buffers=buffers)
1371 buffers=buffers)
1372
1372
1373 def get_history(self, client_id, msg):
1373 def get_history(self, client_id, msg):
1374 """Get a list of all msg_ids in our DB records"""
1374 """Get a list of all msg_ids in our DB records"""
1375 try:
1375 try:
1376 msg_ids = self.db.get_history()
1376 msg_ids = self.db.get_history()
1377 except Exception as e:
1377 except Exception as e:
1378 content = error.wrap_exception()
1378 content = error.wrap_exception()
1379 else:
1379 else:
1380 content = dict(status='ok', history=msg_ids)
1380 content = dict(status='ok', history=msg_ids)
1381
1381
1382 self.session.send(self.query, "history_reply", content=content,
1382 self.session.send(self.query, "history_reply", content=content,
1383 parent=msg, ident=client_id)
1383 parent=msg, ident=client_id)
1384
1384
1385 def db_query(self, client_id, msg):
1385 def db_query(self, client_id, msg):
1386 """Perform a raw query on the task record database."""
1386 """Perform a raw query on the task record database."""
1387 content = msg['content']
1387 content = msg['content']
1388 query = content.get('query', {})
1388 query = content.get('query', {})
1389 keys = content.get('keys', None)
1389 keys = content.get('keys', None)
1390 buffers = []
1390 buffers = []
1391 empty = list()
1391 empty = list()
1392 try:
1392 try:
1393 records = self.db.find_records(query, keys)
1393 records = self.db.find_records(query, keys)
1394 except Exception as e:
1394 except Exception as e:
1395 content = error.wrap_exception()
1395 content = error.wrap_exception()
1396 else:
1396 else:
1397 # extract buffers from reply content:
1397 # extract buffers from reply content:
1398 if keys is not None:
1398 if keys is not None:
1399 buffer_lens = [] if 'buffers' in keys else None
1399 buffer_lens = [] if 'buffers' in keys else None
1400 result_buffer_lens = [] if 'result_buffers' in keys else None
1400 result_buffer_lens = [] if 'result_buffers' in keys else None
1401 else:
1401 else:
1402 buffer_lens = None
1402 buffer_lens = None
1403 result_buffer_lens = None
1403 result_buffer_lens = None
1404
1404
1405 for rec in records:
1405 for rec in records:
1406 # buffers may be None, so double check
1406 # buffers may be None, so double check
1407 b = rec.pop('buffers', empty) or empty
1407 b = rec.pop('buffers', empty) or empty
1408 if buffer_lens is not None:
1408 if buffer_lens is not None:
1409 buffer_lens.append(len(b))
1409 buffer_lens.append(len(b))
1410 buffers.extend(b)
1410 buffers.extend(b)
1411 rb = rec.pop('result_buffers', empty) or empty
1411 rb = rec.pop('result_buffers', empty) or empty
1412 if result_buffer_lens is not None:
1412 if result_buffer_lens is not None:
1413 result_buffer_lens.append(len(rb))
1413 result_buffer_lens.append(len(rb))
1414 buffers.extend(rb)
1414 buffers.extend(rb)
1415 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1415 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1416 result_buffer_lens=result_buffer_lens)
1416 result_buffer_lens=result_buffer_lens)
1417 # self.log.debug (content)
1417 # self.log.debug (content)
1418 self.session.send(self.query, "db_reply", content=content,
1418 self.session.send(self.query, "db_reply", content=content,
1419 parent=msg, ident=client_id,
1419 parent=msg, ident=client_id,
1420 buffers=buffers)
1420 buffers=buffers)
1421
1421
@@ -1,122 +1,122 b''
1 """A TaskRecord backend using mongodb
1 """A TaskRecord backend using mongodb
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 from pymongo import Connection
14 from pymongo import Connection
15
15
16 # bson.Binary import moved
16 # bson.Binary import moved
17 try:
17 try:
18 from bson.binary import Binary
18 from bson.binary import Binary
19 except ImportError:
19 except ImportError:
20 from bson import Binary
20 from bson import Binary
21
21
22 from IPython.utils.traitlets import Dict, List, Unicode, Instance
22 from IPython.utils.traitlets import Dict, List, Unicode, Instance
23
23
24 from .dictdb import BaseDB
24 from .dictdb import BaseDB
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # MongoDB class
27 # MongoDB class
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30 class MongoDB(BaseDB):
30 class MongoDB(BaseDB):
31 """MongoDB TaskRecord backend."""
31 """MongoDB TaskRecord backend."""
32
32
33 connection_args = List(config=True,
33 connection_args = List(config=True,
34 help="""Positional arguments to be passed to pymongo.Connection. Only
34 help="""Positional arguments to be passed to pymongo.Connection. Only
35 necessary if the default mongodb configuration does not point to your
35 necessary if the default mongodb configuration does not point to your
36 mongod instance.""")
36 mongod instance.""")
37 connection_kwargs = Dict(config=True,
37 connection_kwargs = Dict(config=True,
38 help="""Keyword arguments to be passed to pymongo.Connection. Only
38 help="""Keyword arguments to be passed to pymongo.Connection. Only
39 necessary if the default mongodb configuration does not point to your
39 necessary if the default mongodb configuration does not point to your
40 mongod instance."""
40 mongod instance."""
41 )
41 )
42 database = Unicode("ipython-tasks", config=True,
42 database = Unicode("ipython-tasks", config=True,
43 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
43 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
44 a new database will be created with the Hub's IDENT. Specifying the database will result
44 a new database will be created with the Hub's IDENT. Specifying the database will result
45 in tasks from previous sessions being available via Clients' db_query and
45 in tasks from previous sessions being available via Clients' db_query and
46 get_result methods.""")
46 get_result methods.""")
47
47
48 _connection = Instance(Connection) # pymongo connection
48 _connection = Instance(Connection) # pymongo connection
49
49
50 def __init__(self, **kwargs):
50 def __init__(self, **kwargs):
51 super(MongoDB, self).__init__(**kwargs)
51 super(MongoDB, self).__init__(**kwargs)
52 if self._connection is None:
52 if self._connection is None:
53 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
53 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
54 if not self.database:
54 if not self.database:
55 self.database = self.session
55 self.database = self.session
56 self._db = self._connection[self.database]
56 self._db = self._connection[self.database]
57 self._records = self._db['task_records']
57 self._records = self._db['task_records']
58 self._records.ensure_index('msg_id', unique=True)
58 self._records.ensure_index('msg_id', unique=True)
59 self._records.ensure_index('submitted') # for sorting history
59 self._records.ensure_index('submitted') # for sorting history
60 # for rec in self._records.find
60 # for rec in self._records.find
61
61
62 def _binary_buffers(self, rec):
62 def _binary_buffers(self, rec):
63 for key in ('buffers', 'result_buffers'):
63 for key in ('buffers', 'result_buffers'):
64 if rec.get(key, None):
64 if rec.get(key, None):
65 rec[key] = map(Binary, rec[key])
65 rec[key] = list(map(Binary, rec[key]))
66 return rec
66 return rec
67
67
68 def add_record(self, msg_id, rec):
68 def add_record(self, msg_id, rec):
69 """Add a new Task Record, by msg_id."""
69 """Add a new Task Record, by msg_id."""
70 # print rec
70 # print rec
71 rec = self._binary_buffers(rec)
71 rec = self._binary_buffers(rec)
72 self._records.insert(rec)
72 self._records.insert(rec)
73
73
74 def get_record(self, msg_id):
74 def get_record(self, msg_id):
75 """Get a specific Task Record, by msg_id."""
75 """Get a specific Task Record, by msg_id."""
76 r = self._records.find_one({'msg_id': msg_id})
76 r = self._records.find_one({'msg_id': msg_id})
77 if not r:
77 if not r:
78 # r will be '' if nothing is found
78 # r will be '' if nothing is found
79 raise KeyError(msg_id)
79 raise KeyError(msg_id)
80 return r
80 return r
81
81
82 def update_record(self, msg_id, rec):
82 def update_record(self, msg_id, rec):
83 """Update the data in an existing record."""
83 """Update the data in an existing record."""
84 rec = self._binary_buffers(rec)
84 rec = self._binary_buffers(rec)
85
85
86 self._records.update({'msg_id':msg_id}, {'$set': rec})
86 self._records.update({'msg_id':msg_id}, {'$set': rec})
87
87
88 def drop_matching_records(self, check):
88 def drop_matching_records(self, check):
89 """Remove a record from the DB."""
89 """Remove a record from the DB."""
90 self._records.remove(check)
90 self._records.remove(check)
91
91
92 def drop_record(self, msg_id):
92 def drop_record(self, msg_id):
93 """Remove a record from the DB."""
93 """Remove a record from the DB."""
94 self._records.remove({'msg_id':msg_id})
94 self._records.remove({'msg_id':msg_id})
95
95
96 def find_records(self, check, keys=None):
96 def find_records(self, check, keys=None):
97 """Find records matching a query dict, optionally extracting subset of keys.
97 """Find records matching a query dict, optionally extracting subset of keys.
98
98
99 Returns list of matching records.
99 Returns list of matching records.
100
100
101 Parameters
101 Parameters
102 ----------
102 ----------
103
103
104 check: dict
104 check: dict
105 mongodb-style query argument
105 mongodb-style query argument
106 keys: list of strs [optional]
106 keys: list of strs [optional]
107 if specified, the subset of keys to extract. msg_id will *always* be
107 if specified, the subset of keys to extract. msg_id will *always* be
108 included.
108 included.
109 """
109 """
110 if keys and 'msg_id' not in keys:
110 if keys and 'msg_id' not in keys:
111 keys.append('msg_id')
111 keys.append('msg_id')
112 matches = list(self._records.find(check,keys))
112 matches = list(self._records.find(check,keys))
113 for rec in matches:
113 for rec in matches:
114 rec.pop('_id')
114 rec.pop('_id')
115 return matches
115 return matches
116
116
117 def get_history(self):
117 def get_history(self):
118 """get all msg_ids, ordered by time submitted."""
118 """get all msg_ids, ordered by time submitted."""
119 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
119 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
120 return [ rec['msg_id'] for rec in cursor ]
120 return [ rec['msg_id'] for rec in cursor ]
121
121
122
122
@@ -1,860 +1,859 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 import logging
22 import logging
23 import sys
23 import sys
24 import time
24 import time
25
25
26 from collections import deque
26 from collections import deque
27 from datetime import datetime
27 from datetime import datetime
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 from IPython.utils.py3compat import cast_bytes
44 from IPython.utils.py3compat import cast_bytes
45
45
46 from IPython.parallel import error, util
46 from IPython.parallel import error, util
47 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.factory import SessionFactory
48 from IPython.parallel.util import connect_logger, local_logger
48 from IPython.parallel.util import connect_logger, local_logger
49
49
50 from .dependency import Dependency
50 from .dependency import Dependency
51
51
52 @decorator
52 @decorator
53 def logged(f,self,*args,**kwargs):
53 def logged(f,self,*args,**kwargs):
54 # print ("#--------------------")
54 # print ("#--------------------")
55 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
55 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
56 # print ("#--")
56 # print ("#--")
57 return f(self,*args, **kwargs)
57 return f(self,*args, **kwargs)
58
58
59 #----------------------------------------------------------------------
59 #----------------------------------------------------------------------
60 # Chooser functions
60 # Chooser functions
61 #----------------------------------------------------------------------
61 #----------------------------------------------------------------------
62
62
63 def plainrandom(loads):
63 def plainrandom(loads):
64 """Plain random pick."""
64 """Plain random pick."""
65 n = len(loads)
65 n = len(loads)
66 return randint(0,n-1)
66 return randint(0,n-1)
67
67
68 def lru(loads):
68 def lru(loads):
69 """Always pick the front of the line.
69 """Always pick the front of the line.
70
70
71 The content of `loads` is ignored.
71 The content of `loads` is ignored.
72
72
73 Assumes LRU ordering of loads, with oldest first.
73 Assumes LRU ordering of loads, with oldest first.
74 """
74 """
75 return 0
75 return 0
76
76
77 def twobin(loads):
77 def twobin(loads):
78 """Pick two at random, use the LRU of the two.
78 """Pick two at random, use the LRU of the two.
79
79
80 The content of loads is ignored.
80 The content of loads is ignored.
81
81
82 Assumes LRU ordering of loads, with oldest first.
82 Assumes LRU ordering of loads, with oldest first.
83 """
83 """
84 n = len(loads)
84 n = len(loads)
85 a = randint(0,n-1)
85 a = randint(0,n-1)
86 b = randint(0,n-1)
86 b = randint(0,n-1)
87 return min(a,b)
87 return min(a,b)
88
88
89 def weighted(loads):
89 def weighted(loads):
90 """Pick two at random using inverse load as weight.
90 """Pick two at random using inverse load as weight.
91
91
92 Return the less loaded of the two.
92 Return the less loaded of the two.
93 """
93 """
94 # weight 0 a million times more than 1:
94 # weight 0 a million times more than 1:
95 weights = 1./(1e-6+numpy.array(loads))
95 weights = 1./(1e-6+numpy.array(loads))
96 sums = weights.cumsum()
96 sums = weights.cumsum()
97 t = sums[-1]
97 t = sums[-1]
98 x = random()*t
98 x = random()*t
99 y = random()*t
99 y = random()*t
100 idx = 0
100 idx = 0
101 idy = 0
101 idy = 0
102 while sums[idx] < x:
102 while sums[idx] < x:
103 idx += 1
103 idx += 1
104 while sums[idy] < y:
104 while sums[idy] < y:
105 idy += 1
105 idy += 1
106 if weights[idy] > weights[idx]:
106 if weights[idy] > weights[idx]:
107 return idy
107 return idy
108 else:
108 else:
109 return idx
109 return idx
110
110
111 def leastload(loads):
111 def leastload(loads):
112 """Always choose the lowest load.
112 """Always choose the lowest load.
113
113
114 If the lowest load occurs more than once, the first
114 If the lowest load occurs more than once, the first
115 occurance will be used. If loads has LRU ordering, this means
115 occurance will be used. If loads has LRU ordering, this means
116 the LRU of those with the lowest load is chosen.
116 the LRU of those with the lowest load is chosen.
117 """
117 """
118 return loads.index(min(loads))
118 return loads.index(min(loads))
119
119
120 #---------------------------------------------------------------------
120 #---------------------------------------------------------------------
121 # Classes
121 # Classes
122 #---------------------------------------------------------------------
122 #---------------------------------------------------------------------
123
123
124
124
125 # store empty default dependency:
125 # store empty default dependency:
126 MET = Dependency([])
126 MET = Dependency([])
127
127
128
128
129 class Job(object):
129 class Job(object):
130 """Simple container for a job"""
130 """Simple container for a job"""
131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 targets, after, follow, timeout):
132 targets, after, follow, timeout):
133 self.msg_id = msg_id
133 self.msg_id = msg_id
134 self.raw_msg = raw_msg
134 self.raw_msg = raw_msg
135 self.idents = idents
135 self.idents = idents
136 self.msg = msg
136 self.msg = msg
137 self.header = header
137 self.header = header
138 self.metadata = metadata
138 self.metadata = metadata
139 self.targets = targets
139 self.targets = targets
140 self.after = after
140 self.after = after
141 self.follow = follow
141 self.follow = follow
142 self.timeout = timeout
142 self.timeout = timeout
143
143
144 self.removed = False # used for lazy-delete from sorted queue
144 self.removed = False # used for lazy-delete from sorted queue
145 self.timestamp = time.time()
145 self.timestamp = time.time()
146 self.timeout_id = 0
146 self.timeout_id = 0
147 self.blacklist = set()
147 self.blacklist = set()
148
148
149 def __lt__(self, other):
149 def __lt__(self, other):
150 return self.timestamp < other.timestamp
150 return self.timestamp < other.timestamp
151
151
152 def __cmp__(self, other):
152 def __cmp__(self, other):
153 return cmp(self.timestamp, other.timestamp)
153 return cmp(self.timestamp, other.timestamp)
154
154
155 @property
155 @property
156 def dependents(self):
156 def dependents(self):
157 return self.follow.union(self.after)
157 return self.follow.union(self.after)
158
158
159
159
160 class TaskScheduler(SessionFactory):
160 class TaskScheduler(SessionFactory):
161 """Python TaskScheduler object.
161 """Python TaskScheduler object.
162
162
163 This is the simplest object that supports msg_id based
163 This is the simplest object that supports msg_id based
164 DAG dependencies. *Only* task msg_ids are checked, not
164 DAG dependencies. *Only* task msg_ids are checked, not
165 msg_ids of jobs submitted via the MUX queue.
165 msg_ids of jobs submitted via the MUX queue.
166
166
167 """
167 """
168
168
169 hwm = Integer(1, config=True,
169 hwm = Integer(1, config=True,
170 help="""specify the High Water Mark (HWM) for the downstream
170 help="""specify the High Water Mark (HWM) for the downstream
171 socket in the Task scheduler. This is the maximum number
171 socket in the Task scheduler. This is the maximum number
172 of allowed outstanding tasks on each engine.
172 of allowed outstanding tasks on each engine.
173
173
174 The default (1) means that only one task can be outstanding on each
174 The default (1) means that only one task can be outstanding on each
175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
176 engines continue to be assigned tasks while they are working,
176 engines continue to be assigned tasks while they are working,
177 effectively hiding network latency behind computation, but can result
177 effectively hiding network latency behind computation, but can result
178 in an imbalance of work when submitting many heterogenous tasks all at
178 in an imbalance of work when submitting many heterogenous tasks all at
179 once. Any positive value greater than one is a compromise between the
179 once. Any positive value greater than one is a compromise between the
180 two.
180 two.
181
181
182 """
182 """
183 )
183 )
184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
185 'leastload', config=True, allow_none=False,
185 'leastload', config=True, allow_none=False,
186 help="""select the task scheduler scheme [default: Python LRU]
186 help="""select the task scheduler scheme [default: Python LRU]
187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
188 )
188 )
189 def _scheme_name_changed(self, old, new):
189 def _scheme_name_changed(self, old, new):
190 self.log.debug("Using scheme %r"%new)
190 self.log.debug("Using scheme %r"%new)
191 self.scheme = globals()[new]
191 self.scheme = globals()[new]
192
192
193 # input arguments:
193 # input arguments:
194 scheme = Instance(FunctionType) # function for determining the destination
194 scheme = Instance(FunctionType) # function for determining the destination
195 def _scheme_default(self):
195 def _scheme_default(self):
196 return leastload
196 return leastload
197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
202
202
203 # internals:
203 # internals:
204 queue = Instance(deque) # sorted list of Jobs
204 queue = Instance(deque) # sorted list of Jobs
205 def _queue_default(self):
205 def _queue_default(self):
206 return deque()
206 return deque()
207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
211 pending = Dict() # dict by engine_uuid of submitted tasks
211 pending = Dict() # dict by engine_uuid of submitted tasks
212 completed = Dict() # dict by engine_uuid of completed tasks
212 completed = Dict() # dict by engine_uuid of completed tasks
213 failed = Dict() # dict by engine_uuid of failed tasks
213 failed = Dict() # dict by engine_uuid of failed tasks
214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
215 clients = Dict() # dict by msg_id for who submitted the task
215 clients = Dict() # dict by msg_id for who submitted the task
216 targets = List() # list of target IDENTs
216 targets = List() # list of target IDENTs
217 loads = List() # list of engine loads
217 loads = List() # list of engine loads
218 # full = Set() # set of IDENTs that have HWM outstanding tasks
218 # full = Set() # set of IDENTs that have HWM outstanding tasks
219 all_completed = Set() # set of all completed tasks
219 all_completed = Set() # set of all completed tasks
220 all_failed = Set() # set of all failed tasks
220 all_failed = Set() # set of all failed tasks
221 all_done = Set() # set of all finished tasks=union(completed,failed)
221 all_done = Set() # set of all finished tasks=union(completed,failed)
222 all_ids = Set() # set of all submitted task IDs
222 all_ids = Set() # set of all submitted task IDs
223
223
224 ident = CBytes() # ZMQ identity. This should just be self.session.session
224 ident = CBytes() # ZMQ identity. This should just be self.session.session
225 # but ensure Bytes
225 # but ensure Bytes
226 def _ident_default(self):
226 def _ident_default(self):
227 return self.session.bsession
227 return self.session.bsession
228
228
229 def start(self):
229 def start(self):
230 self.query_stream.on_recv(self.dispatch_query_reply)
230 self.query_stream.on_recv(self.dispatch_query_reply)
231 self.session.send(self.query_stream, "connection_request", {})
231 self.session.send(self.query_stream, "connection_request", {})
232
232
233 self.engine_stream.on_recv(self.dispatch_result, copy=False)
233 self.engine_stream.on_recv(self.dispatch_result, copy=False)
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235
235
236 self._notification_handlers = dict(
236 self._notification_handlers = dict(
237 registration_notification = self._register_engine,
237 registration_notification = self._register_engine,
238 unregistration_notification = self._unregister_engine
238 unregistration_notification = self._unregister_engine
239 )
239 )
240 self.notifier_stream.on_recv(self.dispatch_notification)
240 self.notifier_stream.on_recv(self.dispatch_notification)
241 self.log.info("Scheduler started [%s]" % self.scheme_name)
241 self.log.info("Scheduler started [%s]" % self.scheme_name)
242
242
243 def resume_receiving(self):
243 def resume_receiving(self):
244 """Resume accepting jobs."""
244 """Resume accepting jobs."""
245 self.client_stream.on_recv(self.dispatch_submission, copy=False)
245 self.client_stream.on_recv(self.dispatch_submission, copy=False)
246
246
247 def stop_receiving(self):
247 def stop_receiving(self):
248 """Stop accepting jobs while there are no engines.
248 """Stop accepting jobs while there are no engines.
249 Leave them in the ZMQ queue."""
249 Leave them in the ZMQ queue."""
250 self.client_stream.on_recv(None)
250 self.client_stream.on_recv(None)
251
251
252 #-----------------------------------------------------------------------
252 #-----------------------------------------------------------------------
253 # [Un]Registration Handling
253 # [Un]Registration Handling
254 #-----------------------------------------------------------------------
254 #-----------------------------------------------------------------------
255
255
256
256
257 def dispatch_query_reply(self, msg):
257 def dispatch_query_reply(self, msg):
258 """handle reply to our initial connection request"""
258 """handle reply to our initial connection request"""
259 try:
259 try:
260 idents,msg = self.session.feed_identities(msg)
260 idents,msg = self.session.feed_identities(msg)
261 except ValueError:
261 except ValueError:
262 self.log.warn("task::Invalid Message: %r",msg)
262 self.log.warn("task::Invalid Message: %r",msg)
263 return
263 return
264 try:
264 try:
265 msg = self.session.unserialize(msg)
265 msg = self.session.unserialize(msg)
266 except ValueError:
266 except ValueError:
267 self.log.warn("task::Unauthorized message from: %r"%idents)
267 self.log.warn("task::Unauthorized message from: %r"%idents)
268 return
268 return
269
269
270 content = msg['content']
270 content = msg['content']
271 for uuid in content.get('engines', {}).values():
271 for uuid in content.get('engines', {}).values():
272 self._register_engine(cast_bytes(uuid))
272 self._register_engine(cast_bytes(uuid))
273
273
274
274
275 @util.log_errors
275 @util.log_errors
276 def dispatch_notification(self, msg):
276 def dispatch_notification(self, msg):
277 """dispatch register/unregister events."""
277 """dispatch register/unregister events."""
278 try:
278 try:
279 idents,msg = self.session.feed_identities(msg)
279 idents,msg = self.session.feed_identities(msg)
280 except ValueError:
280 except ValueError:
281 self.log.warn("task::Invalid Message: %r",msg)
281 self.log.warn("task::Invalid Message: %r",msg)
282 return
282 return
283 try:
283 try:
284 msg = self.session.unserialize(msg)
284 msg = self.session.unserialize(msg)
285 except ValueError:
285 except ValueError:
286 self.log.warn("task::Unauthorized message from: %r"%idents)
286 self.log.warn("task::Unauthorized message from: %r"%idents)
287 return
287 return
288
288
289 msg_type = msg['header']['msg_type']
289 msg_type = msg['header']['msg_type']
290
290
291 handler = self._notification_handlers.get(msg_type, None)
291 handler = self._notification_handlers.get(msg_type, None)
292 if handler is None:
292 if handler is None:
293 self.log.error("Unhandled message type: %r"%msg_type)
293 self.log.error("Unhandled message type: %r"%msg_type)
294 else:
294 else:
295 try:
295 try:
296 handler(cast_bytes(msg['content']['uuid']))
296 handler(cast_bytes(msg['content']['uuid']))
297 except Exception:
297 except Exception:
298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299
299
300 def _register_engine(self, uid):
300 def _register_engine(self, uid):
301 """New engine with ident `uid` became available."""
301 """New engine with ident `uid` became available."""
302 # head of the line:
302 # head of the line:
303 self.targets.insert(0,uid)
303 self.targets.insert(0,uid)
304 self.loads.insert(0,0)
304 self.loads.insert(0,0)
305
305
306 # initialize sets
306 # initialize sets
307 self.completed[uid] = set()
307 self.completed[uid] = set()
308 self.failed[uid] = set()
308 self.failed[uid] = set()
309 self.pending[uid] = {}
309 self.pending[uid] = {}
310
310
311 # rescan the graph:
311 # rescan the graph:
312 self.update_graph(None)
312 self.update_graph(None)
313
313
314 def _unregister_engine(self, uid):
314 def _unregister_engine(self, uid):
315 """Existing engine with ident `uid` became unavailable."""
315 """Existing engine with ident `uid` became unavailable."""
316 if len(self.targets) == 1:
316 if len(self.targets) == 1:
317 # this was our only engine
317 # this was our only engine
318 pass
318 pass
319
319
320 # handle any potentially finished tasks:
320 # handle any potentially finished tasks:
321 self.engine_stream.flush()
321 self.engine_stream.flush()
322
322
323 # don't pop destinations, because they might be used later
323 # don't pop destinations, because they might be used later
324 # map(self.destinations.pop, self.completed.pop(uid))
324 # map(self.destinations.pop, self.completed.pop(uid))
325 # map(self.destinations.pop, self.failed.pop(uid))
325 # map(self.destinations.pop, self.failed.pop(uid))
326
326
327 # prevent this engine from receiving work
327 # prevent this engine from receiving work
328 idx = self.targets.index(uid)
328 idx = self.targets.index(uid)
329 self.targets.pop(idx)
329 self.targets.pop(idx)
330 self.loads.pop(idx)
330 self.loads.pop(idx)
331
331
332 # wait 5 seconds before cleaning up pending jobs, since the results might
332 # wait 5 seconds before cleaning up pending jobs, since the results might
333 # still be incoming
333 # still be incoming
334 if self.pending[uid]:
334 if self.pending[uid]:
335 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
335 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
336 dc.start()
336 dc.start()
337 else:
337 else:
338 self.completed.pop(uid)
338 self.completed.pop(uid)
339 self.failed.pop(uid)
339 self.failed.pop(uid)
340
340
341
341
342 def handle_stranded_tasks(self, engine):
342 def handle_stranded_tasks(self, engine):
343 """Deal with jobs resident in an engine that died."""
343 """Deal with jobs resident in an engine that died."""
344 lost = self.pending[engine]
344 lost = self.pending[engine]
345 for msg_id in lost.keys():
345 for msg_id in lost.keys():
346 if msg_id not in self.pending[engine]:
346 if msg_id not in self.pending[engine]:
347 # prevent double-handling of messages
347 # prevent double-handling of messages
348 continue
348 continue
349
349
350 raw_msg = lost[msg_id].raw_msg
350 raw_msg = lost[msg_id].raw_msg
351 idents,msg = self.session.feed_identities(raw_msg, copy=False)
351 idents,msg = self.session.feed_identities(raw_msg, copy=False)
352 parent = self.session.unpack(msg[1].bytes)
352 parent = self.session.unpack(msg[1].bytes)
353 idents = [engine, idents[0]]
353 idents = [engine, idents[0]]
354
354
355 # build fake error reply
355 # build fake error reply
356 try:
356 try:
357 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
357 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
358 except:
358 except:
359 content = error.wrap_exception()
359 content = error.wrap_exception()
360 # build fake metadata
360 # build fake metadata
361 md = dict(
361 md = dict(
362 status=u'error',
362 status=u'error',
363 engine=engine.decode('ascii'),
363 engine=engine.decode('ascii'),
364 date=datetime.now(),
364 date=datetime.now(),
365 )
365 )
366 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
366 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
367 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
367 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
368 # and dispatch it
368 # and dispatch it
369 self.dispatch_result(raw_reply)
369 self.dispatch_result(raw_reply)
370
370
371 # finally scrub completed/failed lists
371 # finally scrub completed/failed lists
372 self.completed.pop(engine)
372 self.completed.pop(engine)
373 self.failed.pop(engine)
373 self.failed.pop(engine)
374
374
375
375
376 #-----------------------------------------------------------------------
376 #-----------------------------------------------------------------------
377 # Job Submission
377 # Job Submission
378 #-----------------------------------------------------------------------
378 #-----------------------------------------------------------------------
379
379
380
380
381 @util.log_errors
381 @util.log_errors
382 def dispatch_submission(self, raw_msg):
382 def dispatch_submission(self, raw_msg):
383 """Dispatch job submission to appropriate handlers."""
383 """Dispatch job submission to appropriate handlers."""
384 # ensure targets up to date:
384 # ensure targets up to date:
385 self.notifier_stream.flush()
385 self.notifier_stream.flush()
386 try:
386 try:
387 idents, msg = self.session.feed_identities(raw_msg, copy=False)
387 idents, msg = self.session.feed_identities(raw_msg, copy=False)
388 msg = self.session.unserialize(msg, content=False, copy=False)
388 msg = self.session.unserialize(msg, content=False, copy=False)
389 except Exception:
389 except Exception:
390 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
390 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
391 return
391 return
392
392
393
393
394 # send to monitor
394 # send to monitor
395 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
395 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
396
396
397 header = msg['header']
397 header = msg['header']
398 md = msg['metadata']
398 md = msg['metadata']
399 msg_id = header['msg_id']
399 msg_id = header['msg_id']
400 self.all_ids.add(msg_id)
400 self.all_ids.add(msg_id)
401
401
402 # get targets as a set of bytes objects
402 # get targets as a set of bytes objects
403 # from a list of unicode objects
403 # from a list of unicode objects
404 targets = md.get('targets', [])
404 targets = md.get('targets', [])
405 targets = map(cast_bytes, targets)
405 targets = set(map(cast_bytes, targets))
406 targets = set(targets)
407
406
408 retries = md.get('retries', 0)
407 retries = md.get('retries', 0)
409 self.retries[msg_id] = retries
408 self.retries[msg_id] = retries
410
409
411 # time dependencies
410 # time dependencies
412 after = md.get('after', None)
411 after = md.get('after', None)
413 if after:
412 if after:
414 after = Dependency(after)
413 after = Dependency(after)
415 if after.all:
414 if after.all:
416 if after.success:
415 if after.success:
417 after = Dependency(after.difference(self.all_completed),
416 after = Dependency(after.difference(self.all_completed),
418 success=after.success,
417 success=after.success,
419 failure=after.failure,
418 failure=after.failure,
420 all=after.all,
419 all=after.all,
421 )
420 )
422 if after.failure:
421 if after.failure:
423 after = Dependency(after.difference(self.all_failed),
422 after = Dependency(after.difference(self.all_failed),
424 success=after.success,
423 success=after.success,
425 failure=after.failure,
424 failure=after.failure,
426 all=after.all,
425 all=after.all,
427 )
426 )
428 if after.check(self.all_completed, self.all_failed):
427 if after.check(self.all_completed, self.all_failed):
429 # recast as empty set, if `after` already met,
428 # recast as empty set, if `after` already met,
430 # to prevent unnecessary set comparisons
429 # to prevent unnecessary set comparisons
431 after = MET
430 after = MET
432 else:
431 else:
433 after = MET
432 after = MET
434
433
435 # location dependencies
434 # location dependencies
436 follow = Dependency(md.get('follow', []))
435 follow = Dependency(md.get('follow', []))
437
436
438 timeout = md.get('timeout', None)
437 timeout = md.get('timeout', None)
439 if timeout:
438 if timeout:
440 timeout = float(timeout)
439 timeout = float(timeout)
441
440
442 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
441 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
443 header=header, targets=targets, after=after, follow=follow,
442 header=header, targets=targets, after=after, follow=follow,
444 timeout=timeout, metadata=md,
443 timeout=timeout, metadata=md,
445 )
444 )
446 # validate and reduce dependencies:
445 # validate and reduce dependencies:
447 for dep in after,follow:
446 for dep in after,follow:
448 if not dep: # empty dependency
447 if not dep: # empty dependency
449 continue
448 continue
450 # check valid:
449 # check valid:
451 if msg_id in dep or dep.difference(self.all_ids):
450 if msg_id in dep or dep.difference(self.all_ids):
452 self.queue_map[msg_id] = job
451 self.queue_map[msg_id] = job
453 return self.fail_unreachable(msg_id, error.InvalidDependency)
452 return self.fail_unreachable(msg_id, error.InvalidDependency)
454 # check if unreachable:
453 # check if unreachable:
455 if dep.unreachable(self.all_completed, self.all_failed):
454 if dep.unreachable(self.all_completed, self.all_failed):
456 self.queue_map[msg_id] = job
455 self.queue_map[msg_id] = job
457 return self.fail_unreachable(msg_id)
456 return self.fail_unreachable(msg_id)
458
457
459 if after.check(self.all_completed, self.all_failed):
458 if after.check(self.all_completed, self.all_failed):
460 # time deps already met, try to run
459 # time deps already met, try to run
461 if not self.maybe_run(job):
460 if not self.maybe_run(job):
462 # can't run yet
461 # can't run yet
463 if msg_id not in self.all_failed:
462 if msg_id not in self.all_failed:
464 # could have failed as unreachable
463 # could have failed as unreachable
465 self.save_unmet(job)
464 self.save_unmet(job)
466 else:
465 else:
467 self.save_unmet(job)
466 self.save_unmet(job)
468
467
469 def job_timeout(self, job, timeout_id):
468 def job_timeout(self, job, timeout_id):
470 """callback for a job's timeout.
469 """callback for a job's timeout.
471
470
472 The job may or may not have been run at this point.
471 The job may or may not have been run at this point.
473 """
472 """
474 if job.timeout_id != timeout_id:
473 if job.timeout_id != timeout_id:
475 # not the most recent call
474 # not the most recent call
476 return
475 return
477 now = time.time()
476 now = time.time()
478 if job.timeout >= (now + 1):
477 if job.timeout >= (now + 1):
479 self.log.warn("task %s timeout fired prematurely: %s > %s",
478 self.log.warn("task %s timeout fired prematurely: %s > %s",
480 job.msg_id, job.timeout, now
479 job.msg_id, job.timeout, now
481 )
480 )
482 if job.msg_id in self.queue_map:
481 if job.msg_id in self.queue_map:
483 # still waiting, but ran out of time
482 # still waiting, but ran out of time
484 self.log.info("task %r timed out", job.msg_id)
483 self.log.info("task %r timed out", job.msg_id)
485 self.fail_unreachable(job.msg_id, error.TaskTimeout)
484 self.fail_unreachable(job.msg_id, error.TaskTimeout)
486
485
487 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
486 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
488 """a task has become unreachable, send a reply with an ImpossibleDependency
487 """a task has become unreachable, send a reply with an ImpossibleDependency
489 error."""
488 error."""
490 if msg_id not in self.queue_map:
489 if msg_id not in self.queue_map:
491 self.log.error("task %r already failed!", msg_id)
490 self.log.error("task %r already failed!", msg_id)
492 return
491 return
493 job = self.queue_map.pop(msg_id)
492 job = self.queue_map.pop(msg_id)
494 # lazy-delete from the queue
493 # lazy-delete from the queue
495 job.removed = True
494 job.removed = True
496 for mid in job.dependents:
495 for mid in job.dependents:
497 if mid in self.graph:
496 if mid in self.graph:
498 self.graph[mid].remove(msg_id)
497 self.graph[mid].remove(msg_id)
499
498
500 try:
499 try:
501 raise why()
500 raise why()
502 except:
501 except:
503 content = error.wrap_exception()
502 content = error.wrap_exception()
504 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
503 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
505
504
506 self.all_done.add(msg_id)
505 self.all_done.add(msg_id)
507 self.all_failed.add(msg_id)
506 self.all_failed.add(msg_id)
508
507
509 msg = self.session.send(self.client_stream, 'apply_reply', content,
508 msg = self.session.send(self.client_stream, 'apply_reply', content,
510 parent=job.header, ident=job.idents)
509 parent=job.header, ident=job.idents)
511 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
510 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
512
511
513 self.update_graph(msg_id, success=False)
512 self.update_graph(msg_id, success=False)
514
513
515 def available_engines(self):
514 def available_engines(self):
516 """return a list of available engine indices based on HWM"""
515 """return a list of available engine indices based on HWM"""
517 if not self.hwm:
516 if not self.hwm:
518 return range(len(self.targets))
517 return list(range(len(self.targets)))
519 available = []
518 available = []
520 for idx in range(len(self.targets)):
519 for idx in range(len(self.targets)):
521 if self.loads[idx] < self.hwm:
520 if self.loads[idx] < self.hwm:
522 available.append(idx)
521 available.append(idx)
523 return available
522 return available
524
523
525 def maybe_run(self, job):
524 def maybe_run(self, job):
526 """check location dependencies, and run if they are met."""
525 """check location dependencies, and run if they are met."""
527 msg_id = job.msg_id
526 msg_id = job.msg_id
528 self.log.debug("Attempting to assign task %s", msg_id)
527 self.log.debug("Attempting to assign task %s", msg_id)
529 available = self.available_engines()
528 available = self.available_engines()
530 if not available:
529 if not available:
531 # no engines, definitely can't run
530 # no engines, definitely can't run
532 return False
531 return False
533
532
534 if job.follow or job.targets or job.blacklist or self.hwm:
533 if job.follow or job.targets or job.blacklist or self.hwm:
535 # we need a can_run filter
534 # we need a can_run filter
536 def can_run(idx):
535 def can_run(idx):
537 # check hwm
536 # check hwm
538 if self.hwm and self.loads[idx] == self.hwm:
537 if self.hwm and self.loads[idx] == self.hwm:
539 return False
538 return False
540 target = self.targets[idx]
539 target = self.targets[idx]
541 # check blacklist
540 # check blacklist
542 if target in job.blacklist:
541 if target in job.blacklist:
543 return False
542 return False
544 # check targets
543 # check targets
545 if job.targets and target not in job.targets:
544 if job.targets and target not in job.targets:
546 return False
545 return False
547 # check follow
546 # check follow
548 return job.follow.check(self.completed[target], self.failed[target])
547 return job.follow.check(self.completed[target], self.failed[target])
549
548
550 indices = filter(can_run, available)
549 indices = list(filter(can_run, available))
551
550
552 if not indices:
551 if not indices:
553 # couldn't run
552 # couldn't run
554 if job.follow.all:
553 if job.follow.all:
555 # check follow for impossibility
554 # check follow for impossibility
556 dests = set()
555 dests = set()
557 relevant = set()
556 relevant = set()
558 if job.follow.success:
557 if job.follow.success:
559 relevant = self.all_completed
558 relevant = self.all_completed
560 if job.follow.failure:
559 if job.follow.failure:
561 relevant = relevant.union(self.all_failed)
560 relevant = relevant.union(self.all_failed)
562 for m in job.follow.intersection(relevant):
561 for m in job.follow.intersection(relevant):
563 dests.add(self.destinations[m])
562 dests.add(self.destinations[m])
564 if len(dests) > 1:
563 if len(dests) > 1:
565 self.queue_map[msg_id] = job
564 self.queue_map[msg_id] = job
566 self.fail_unreachable(msg_id)
565 self.fail_unreachable(msg_id)
567 return False
566 return False
568 if job.targets:
567 if job.targets:
569 # check blacklist+targets for impossibility
568 # check blacklist+targets for impossibility
570 job.targets.difference_update(job.blacklist)
569 job.targets.difference_update(job.blacklist)
571 if not job.targets or not job.targets.intersection(self.targets):
570 if not job.targets or not job.targets.intersection(self.targets):
572 self.queue_map[msg_id] = job
571 self.queue_map[msg_id] = job
573 self.fail_unreachable(msg_id)
572 self.fail_unreachable(msg_id)
574 return False
573 return False
575 return False
574 return False
576 else:
575 else:
577 indices = None
576 indices = None
578
577
579 self.submit_task(job, indices)
578 self.submit_task(job, indices)
580 return True
579 return True
581
580
582 def save_unmet(self, job):
581 def save_unmet(self, job):
583 """Save a message for later submission when its dependencies are met."""
582 """Save a message for later submission when its dependencies are met."""
584 msg_id = job.msg_id
583 msg_id = job.msg_id
585 self.log.debug("Adding task %s to the queue", msg_id)
584 self.log.debug("Adding task %s to the queue", msg_id)
586 self.queue_map[msg_id] = job
585 self.queue_map[msg_id] = job
587 self.queue.append(job)
586 self.queue.append(job)
588 # track the ids in follow or after, but not those already finished
587 # track the ids in follow or after, but not those already finished
589 for dep_id in job.after.union(job.follow).difference(self.all_done):
588 for dep_id in job.after.union(job.follow).difference(self.all_done):
590 if dep_id not in self.graph:
589 if dep_id not in self.graph:
591 self.graph[dep_id] = set()
590 self.graph[dep_id] = set()
592 self.graph[dep_id].add(msg_id)
591 self.graph[dep_id].add(msg_id)
593
592
594 # schedule timeout callback
593 # schedule timeout callback
595 if job.timeout:
594 if job.timeout:
596 timeout_id = job.timeout_id = job.timeout_id + 1
595 timeout_id = job.timeout_id = job.timeout_id + 1
597 self.loop.add_timeout(time.time() + job.timeout,
596 self.loop.add_timeout(time.time() + job.timeout,
598 lambda : self.job_timeout(job, timeout_id)
597 lambda : self.job_timeout(job, timeout_id)
599 )
598 )
600
599
601
600
602 def submit_task(self, job, indices=None):
601 def submit_task(self, job, indices=None):
603 """Submit a task to any of a subset of our targets."""
602 """Submit a task to any of a subset of our targets."""
604 if indices:
603 if indices:
605 loads = [self.loads[i] for i in indices]
604 loads = [self.loads[i] for i in indices]
606 else:
605 else:
607 loads = self.loads
606 loads = self.loads
608 idx = self.scheme(loads)
607 idx = self.scheme(loads)
609 if indices:
608 if indices:
610 idx = indices[idx]
609 idx = indices[idx]
611 target = self.targets[idx]
610 target = self.targets[idx]
612 # print (target, map(str, msg[:3]))
611 # print (target, map(str, msg[:3]))
613 # send job to the engine
612 # send job to the engine
614 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
613 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
615 self.engine_stream.send_multipart(job.raw_msg, copy=False)
614 self.engine_stream.send_multipart(job.raw_msg, copy=False)
616 # update load
615 # update load
617 self.add_job(idx)
616 self.add_job(idx)
618 self.pending[target][job.msg_id] = job
617 self.pending[target][job.msg_id] = job
619 # notify Hub
618 # notify Hub
620 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
619 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
621 self.session.send(self.mon_stream, 'task_destination', content=content,
620 self.session.send(self.mon_stream, 'task_destination', content=content,
622 ident=[b'tracktask',self.ident])
621 ident=[b'tracktask',self.ident])
623
622
624
623
625 #-----------------------------------------------------------------------
624 #-----------------------------------------------------------------------
626 # Result Handling
625 # Result Handling
627 #-----------------------------------------------------------------------
626 #-----------------------------------------------------------------------
628
627
629
628
630 @util.log_errors
629 @util.log_errors
631 def dispatch_result(self, raw_msg):
630 def dispatch_result(self, raw_msg):
632 """dispatch method for result replies"""
631 """dispatch method for result replies"""
633 try:
632 try:
634 idents,msg = self.session.feed_identities(raw_msg, copy=False)
633 idents,msg = self.session.feed_identities(raw_msg, copy=False)
635 msg = self.session.unserialize(msg, content=False, copy=False)
634 msg = self.session.unserialize(msg, content=False, copy=False)
636 engine = idents[0]
635 engine = idents[0]
637 try:
636 try:
638 idx = self.targets.index(engine)
637 idx = self.targets.index(engine)
639 except ValueError:
638 except ValueError:
640 pass # skip load-update for dead engines
639 pass # skip load-update for dead engines
641 else:
640 else:
642 self.finish_job(idx)
641 self.finish_job(idx)
643 except Exception:
642 except Exception:
644 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
643 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
645 return
644 return
646
645
647 md = msg['metadata']
646 md = msg['metadata']
648 parent = msg['parent_header']
647 parent = msg['parent_header']
649 if md.get('dependencies_met', True):
648 if md.get('dependencies_met', True):
650 success = (md['status'] == 'ok')
649 success = (md['status'] == 'ok')
651 msg_id = parent['msg_id']
650 msg_id = parent['msg_id']
652 retries = self.retries[msg_id]
651 retries = self.retries[msg_id]
653 if not success and retries > 0:
652 if not success and retries > 0:
654 # failed
653 # failed
655 self.retries[msg_id] = retries - 1
654 self.retries[msg_id] = retries - 1
656 self.handle_unmet_dependency(idents, parent)
655 self.handle_unmet_dependency(idents, parent)
657 else:
656 else:
658 del self.retries[msg_id]
657 del self.retries[msg_id]
659 # relay to client and update graph
658 # relay to client and update graph
660 self.handle_result(idents, parent, raw_msg, success)
659 self.handle_result(idents, parent, raw_msg, success)
661 # send to Hub monitor
660 # send to Hub monitor
662 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
661 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
663 else:
662 else:
664 self.handle_unmet_dependency(idents, parent)
663 self.handle_unmet_dependency(idents, parent)
665
664
666 def handle_result(self, idents, parent, raw_msg, success=True):
665 def handle_result(self, idents, parent, raw_msg, success=True):
667 """handle a real task result, either success or failure"""
666 """handle a real task result, either success or failure"""
668 # first, relay result to client
667 # first, relay result to client
669 engine = idents[0]
668 engine = idents[0]
670 client = idents[1]
669 client = idents[1]
671 # swap_ids for ROUTER-ROUTER mirror
670 # swap_ids for ROUTER-ROUTER mirror
672 raw_msg[:2] = [client,engine]
671 raw_msg[:2] = [client,engine]
673 # print (map(str, raw_msg[:4]))
672 # print (map(str, raw_msg[:4]))
674 self.client_stream.send_multipart(raw_msg, copy=False)
673 self.client_stream.send_multipart(raw_msg, copy=False)
675 # now, update our data structures
674 # now, update our data structures
676 msg_id = parent['msg_id']
675 msg_id = parent['msg_id']
677 self.pending[engine].pop(msg_id)
676 self.pending[engine].pop(msg_id)
678 if success:
677 if success:
679 self.completed[engine].add(msg_id)
678 self.completed[engine].add(msg_id)
680 self.all_completed.add(msg_id)
679 self.all_completed.add(msg_id)
681 else:
680 else:
682 self.failed[engine].add(msg_id)
681 self.failed[engine].add(msg_id)
683 self.all_failed.add(msg_id)
682 self.all_failed.add(msg_id)
684 self.all_done.add(msg_id)
683 self.all_done.add(msg_id)
685 self.destinations[msg_id] = engine
684 self.destinations[msg_id] = engine
686
685
687 self.update_graph(msg_id, success)
686 self.update_graph(msg_id, success)
688
687
689 def handle_unmet_dependency(self, idents, parent):
688 def handle_unmet_dependency(self, idents, parent):
690 """handle an unmet dependency"""
689 """handle an unmet dependency"""
691 engine = idents[0]
690 engine = idents[0]
692 msg_id = parent['msg_id']
691 msg_id = parent['msg_id']
693
692
694 job = self.pending[engine].pop(msg_id)
693 job = self.pending[engine].pop(msg_id)
695 job.blacklist.add(engine)
694 job.blacklist.add(engine)
696
695
697 if job.blacklist == job.targets:
696 if job.blacklist == job.targets:
698 self.queue_map[msg_id] = job
697 self.queue_map[msg_id] = job
699 self.fail_unreachable(msg_id)
698 self.fail_unreachable(msg_id)
700 elif not self.maybe_run(job):
699 elif not self.maybe_run(job):
701 # resubmit failed
700 # resubmit failed
702 if msg_id not in self.all_failed:
701 if msg_id not in self.all_failed:
703 # put it back in our dependency tree
702 # put it back in our dependency tree
704 self.save_unmet(job)
703 self.save_unmet(job)
705
704
706 if self.hwm:
705 if self.hwm:
707 try:
706 try:
708 idx = self.targets.index(engine)
707 idx = self.targets.index(engine)
709 except ValueError:
708 except ValueError:
710 pass # skip load-update for dead engines
709 pass # skip load-update for dead engines
711 else:
710 else:
712 if self.loads[idx] == self.hwm-1:
711 if self.loads[idx] == self.hwm-1:
713 self.update_graph(None)
712 self.update_graph(None)
714
713
715 def update_graph(self, dep_id=None, success=True):
714 def update_graph(self, dep_id=None, success=True):
716 """dep_id just finished. Update our dependency
715 """dep_id just finished. Update our dependency
717 graph and submit any jobs that just became runnable.
716 graph and submit any jobs that just became runnable.
718
717
719 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
718 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
720 """
719 """
721 # print ("\n\n***********")
720 # print ("\n\n***********")
722 # pprint (dep_id)
721 # pprint (dep_id)
723 # pprint (self.graph)
722 # pprint (self.graph)
724 # pprint (self.queue_map)
723 # pprint (self.queue_map)
725 # pprint (self.all_completed)
724 # pprint (self.all_completed)
726 # pprint (self.all_failed)
725 # pprint (self.all_failed)
727 # print ("\n\n***********\n\n")
726 # print ("\n\n***********\n\n")
728 # update any jobs that depended on the dependency
727 # update any jobs that depended on the dependency
729 msg_ids = self.graph.pop(dep_id, [])
728 msg_ids = self.graph.pop(dep_id, [])
730
729
731 # recheck *all* jobs if
730 # recheck *all* jobs if
732 # a) we have HWM and an engine just become no longer full
731 # a) we have HWM and an engine just become no longer full
733 # or b) dep_id was given as None
732 # or b) dep_id was given as None
734
733
735 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
734 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
736 jobs = self.queue
735 jobs = self.queue
737 using_queue = True
736 using_queue = True
738 else:
737 else:
739 using_queue = False
738 using_queue = False
740 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
739 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
741
740
742 to_restore = []
741 to_restore = []
743 while jobs:
742 while jobs:
744 job = jobs.popleft()
743 job = jobs.popleft()
745 if job.removed:
744 if job.removed:
746 continue
745 continue
747 msg_id = job.msg_id
746 msg_id = job.msg_id
748
747
749 put_it_back = True
748 put_it_back = True
750
749
751 if job.after.unreachable(self.all_completed, self.all_failed)\
750 if job.after.unreachable(self.all_completed, self.all_failed)\
752 or job.follow.unreachable(self.all_completed, self.all_failed):
751 or job.follow.unreachable(self.all_completed, self.all_failed):
753 self.fail_unreachable(msg_id)
752 self.fail_unreachable(msg_id)
754 put_it_back = False
753 put_it_back = False
755
754
756 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
755 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
757 if self.maybe_run(job):
756 if self.maybe_run(job):
758 put_it_back = False
757 put_it_back = False
759 self.queue_map.pop(msg_id)
758 self.queue_map.pop(msg_id)
760 for mid in job.dependents:
759 for mid in job.dependents:
761 if mid in self.graph:
760 if mid in self.graph:
762 self.graph[mid].remove(msg_id)
761 self.graph[mid].remove(msg_id)
763
762
764 # abort the loop if we just filled up all of our engines.
763 # abort the loop if we just filled up all of our engines.
765 # avoids an O(N) operation in situation of full queue,
764 # avoids an O(N) operation in situation of full queue,
766 # where graph update is triggered as soon as an engine becomes
765 # where graph update is triggered as soon as an engine becomes
767 # non-full, and all tasks after the first are checked,
766 # non-full, and all tasks after the first are checked,
768 # even though they can't run.
767 # even though they can't run.
769 if not self.available_engines():
768 if not self.available_engines():
770 break
769 break
771
770
772 if using_queue and put_it_back:
771 if using_queue and put_it_back:
773 # popped a job from the queue but it neither ran nor failed,
772 # popped a job from the queue but it neither ran nor failed,
774 # so we need to put it back when we are done
773 # so we need to put it back when we are done
775 # make sure to_restore preserves the same ordering
774 # make sure to_restore preserves the same ordering
776 to_restore.append(job)
775 to_restore.append(job)
777
776
778 # put back any tasks we popped but didn't run
777 # put back any tasks we popped but didn't run
779 if using_queue:
778 if using_queue:
780 self.queue.extendleft(to_restore)
779 self.queue.extendleft(to_restore)
781
780
782 #----------------------------------------------------------------------
781 #----------------------------------------------------------------------
783 # methods to be overridden by subclasses
782 # methods to be overridden by subclasses
784 #----------------------------------------------------------------------
783 #----------------------------------------------------------------------
785
784
786 def add_job(self, idx):
785 def add_job(self, idx):
787 """Called after self.targets[idx] just got the job with header.
786 """Called after self.targets[idx] just got the job with header.
788 Override with subclasses. The default ordering is simple LRU.
787 Override with subclasses. The default ordering is simple LRU.
789 The default loads are the number of outstanding jobs."""
788 The default loads are the number of outstanding jobs."""
790 self.loads[idx] += 1
789 self.loads[idx] += 1
791 for lis in (self.targets, self.loads):
790 for lis in (self.targets, self.loads):
792 lis.append(lis.pop(idx))
791 lis.append(lis.pop(idx))
793
792
794
793
795 def finish_job(self, idx):
794 def finish_job(self, idx):
796 """Called after self.targets[idx] just finished a job.
795 """Called after self.targets[idx] just finished a job.
797 Override with subclasses."""
796 Override with subclasses."""
798 self.loads[idx] -= 1
797 self.loads[idx] -= 1
799
798
800
799
801
800
802 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
801 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
803 logname='root', log_url=None, loglevel=logging.DEBUG,
802 logname='root', log_url=None, loglevel=logging.DEBUG,
804 identity=b'task', in_thread=False):
803 identity=b'task', in_thread=False):
805
804
806 ZMQStream = zmqstream.ZMQStream
805 ZMQStream = zmqstream.ZMQStream
807
806
808 if config:
807 if config:
809 # unwrap dict back into Config
808 # unwrap dict back into Config
810 config = Config(config)
809 config = Config(config)
811
810
812 if in_thread:
811 if in_thread:
813 # use instance() to get the same Context/Loop as our parent
812 # use instance() to get the same Context/Loop as our parent
814 ctx = zmq.Context.instance()
813 ctx = zmq.Context.instance()
815 loop = ioloop.IOLoop.instance()
814 loop = ioloop.IOLoop.instance()
816 else:
815 else:
817 # in a process, don't use instance()
816 # in a process, don't use instance()
818 # for safety with multiprocessing
817 # for safety with multiprocessing
819 ctx = zmq.Context()
818 ctx = zmq.Context()
820 loop = ioloop.IOLoop()
819 loop = ioloop.IOLoop()
821 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
820 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
822 util.set_hwm(ins, 0)
821 util.set_hwm(ins, 0)
823 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
822 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
824 ins.bind(in_addr)
823 ins.bind(in_addr)
825
824
826 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
825 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
827 util.set_hwm(outs, 0)
826 util.set_hwm(outs, 0)
828 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
827 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
829 outs.bind(out_addr)
828 outs.bind(out_addr)
830 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
829 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
831 util.set_hwm(mons, 0)
830 util.set_hwm(mons, 0)
832 mons.connect(mon_addr)
831 mons.connect(mon_addr)
833 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
832 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
834 nots.setsockopt(zmq.SUBSCRIBE, b'')
833 nots.setsockopt(zmq.SUBSCRIBE, b'')
835 nots.connect(not_addr)
834 nots.connect(not_addr)
836
835
837 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
836 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
838 querys.connect(reg_addr)
837 querys.connect(reg_addr)
839
838
840 # setup logging.
839 # setup logging.
841 if in_thread:
840 if in_thread:
842 log = Application.instance().log
841 log = Application.instance().log
843 else:
842 else:
844 if log_url:
843 if log_url:
845 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
844 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
846 else:
845 else:
847 log = local_logger(logname, loglevel)
846 log = local_logger(logname, loglevel)
848
847
849 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
848 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
850 mon_stream=mons, notifier_stream=nots,
849 mon_stream=mons, notifier_stream=nots,
851 query_stream=querys,
850 query_stream=querys,
852 loop=loop, log=log,
851 loop=loop, log=log,
853 config=config)
852 config=config)
854 scheduler.start()
853 scheduler.start()
855 if not in_thread:
854 if not in_thread:
856 try:
855 try:
857 loop.start()
856 loop.start()
858 except KeyboardInterrupt:
857 except KeyboardInterrupt:
859 scheduler.log.critical("Interrupted, exiting...")
858 scheduler.log.critical("Interrupted, exiting...")
860
859
@@ -1,422 +1,422 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import json
14 import json
15 import os
15 import os
16 try:
16 try:
17 import cPickle as pickle
17 import cPickle as pickle
18 except ImportError:
18 except ImportError:
19 import pickle
19 import pickle
20 from datetime import datetime
20 from datetime import datetime
21
21
22 try:
22 try:
23 import sqlite3
23 import sqlite3
24 except ImportError:
24 except ImportError:
25 sqlite3 = None
25 sqlite3 = None
26
26
27 from zmq.eventloop import ioloop
27 from zmq.eventloop import ioloop
28
28
29 from IPython.utils.traitlets import Unicode, Instance, List, Dict
29 from IPython.utils.traitlets import Unicode, Instance, List, Dict
30 from .dictdb import BaseDB
30 from .dictdb import BaseDB
31 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
31 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
32 from IPython.utils.py3compat import iteritems
32 from IPython.utils.py3compat import iteritems
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # SQLite operators, adapters, and converters
35 # SQLite operators, adapters, and converters
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 try:
38 try:
39 buffer
39 buffer
40 except NameError:
40 except NameError:
41 # py3k
41 # py3k
42 buffer = memoryview
42 buffer = memoryview
43
43
44 operators = {
44 operators = {
45 '$lt' : "<",
45 '$lt' : "<",
46 '$gt' : ">",
46 '$gt' : ">",
47 # null is handled weird with ==,!=
47 # null is handled weird with ==,!=
48 '$eq' : "=",
48 '$eq' : "=",
49 '$ne' : "!=",
49 '$ne' : "!=",
50 '$lte': "<=",
50 '$lte': "<=",
51 '$gte': ">=",
51 '$gte': ">=",
52 '$in' : ('=', ' OR '),
52 '$in' : ('=', ' OR '),
53 '$nin': ('!=', ' AND '),
53 '$nin': ('!=', ' AND '),
54 # '$all': None,
54 # '$all': None,
55 # '$mod': None,
55 # '$mod': None,
56 # '$exists' : None
56 # '$exists' : None
57 }
57 }
58 null_operators = {
58 null_operators = {
59 '=' : "IS NULL",
59 '=' : "IS NULL",
60 '!=' : "IS NOT NULL",
60 '!=' : "IS NOT NULL",
61 }
61 }
62
62
63 def _adapt_dict(d):
63 def _adapt_dict(d):
64 return json.dumps(d, default=date_default)
64 return json.dumps(d, default=date_default)
65
65
66 def _convert_dict(ds):
66 def _convert_dict(ds):
67 if ds is None:
67 if ds is None:
68 return ds
68 return ds
69 else:
69 else:
70 if isinstance(ds, bytes):
70 if isinstance(ds, bytes):
71 # If I understand the sqlite doc correctly, this will always be utf8
71 # If I understand the sqlite doc correctly, this will always be utf8
72 ds = ds.decode('utf8')
72 ds = ds.decode('utf8')
73 return extract_dates(json.loads(ds))
73 return extract_dates(json.loads(ds))
74
74
75 def _adapt_bufs(bufs):
75 def _adapt_bufs(bufs):
76 # this is *horrible*
76 # this is *horrible*
77 # copy buffers into single list and pickle it:
77 # copy buffers into single list and pickle it:
78 if bufs and isinstance(bufs[0], (bytes, buffer)):
78 if bufs and isinstance(bufs[0], (bytes, buffer)):
79 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
79 return sqlite3.Binary(pickle.dumps(list(map(bytes, bufs)),-1))
80 elif bufs:
80 elif bufs:
81 return bufs
81 return bufs
82 else:
82 else:
83 return None
83 return None
84
84
85 def _convert_bufs(bs):
85 def _convert_bufs(bs):
86 if bs is None:
86 if bs is None:
87 return []
87 return []
88 else:
88 else:
89 return pickle.loads(bytes(bs))
89 return pickle.loads(bytes(bs))
90
90
91 #-----------------------------------------------------------------------------
91 #-----------------------------------------------------------------------------
92 # SQLiteDB class
92 # SQLiteDB class
93 #-----------------------------------------------------------------------------
93 #-----------------------------------------------------------------------------
94
94
95 class SQLiteDB(BaseDB):
95 class SQLiteDB(BaseDB):
96 """SQLite3 TaskRecord backend."""
96 """SQLite3 TaskRecord backend."""
97
97
98 filename = Unicode('tasks.db', config=True,
98 filename = Unicode('tasks.db', config=True,
99 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
99 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
100 location = Unicode('', config=True,
100 location = Unicode('', config=True,
101 help="""The directory containing the sqlite task database. The default
101 help="""The directory containing the sqlite task database. The default
102 is to use the cluster_dir location.""")
102 is to use the cluster_dir location.""")
103 table = Unicode("ipython-tasks", config=True,
103 table = Unicode("ipython-tasks", config=True,
104 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
104 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
105 a new table will be created with the Hub's IDENT. Specifying the table will result
105 a new table will be created with the Hub's IDENT. Specifying the table will result
106 in tasks from previous sessions being available via Clients' db_query and
106 in tasks from previous sessions being available via Clients' db_query and
107 get_result methods.""")
107 get_result methods.""")
108
108
109 if sqlite3 is not None:
109 if sqlite3 is not None:
110 _db = Instance('sqlite3.Connection')
110 _db = Instance('sqlite3.Connection')
111 else:
111 else:
112 _db = None
112 _db = None
113 # the ordered list of column names
113 # the ordered list of column names
114 _keys = List(['msg_id' ,
114 _keys = List(['msg_id' ,
115 'header' ,
115 'header' ,
116 'metadata',
116 'metadata',
117 'content',
117 'content',
118 'buffers',
118 'buffers',
119 'submitted',
119 'submitted',
120 'client_uuid' ,
120 'client_uuid' ,
121 'engine_uuid' ,
121 'engine_uuid' ,
122 'started',
122 'started',
123 'completed',
123 'completed',
124 'resubmitted',
124 'resubmitted',
125 'received',
125 'received',
126 'result_header' ,
126 'result_header' ,
127 'result_metadata',
127 'result_metadata',
128 'result_content' ,
128 'result_content' ,
129 'result_buffers' ,
129 'result_buffers' ,
130 'queue' ,
130 'queue' ,
131 'pyin' ,
131 'pyin' ,
132 'pyout',
132 'pyout',
133 'pyerr',
133 'pyerr',
134 'stdout',
134 'stdout',
135 'stderr',
135 'stderr',
136 ])
136 ])
137 # sqlite datatypes for checking that db is current format
137 # sqlite datatypes for checking that db is current format
138 _types = Dict({'msg_id' : 'text' ,
138 _types = Dict({'msg_id' : 'text' ,
139 'header' : 'dict text',
139 'header' : 'dict text',
140 'metadata' : 'dict text',
140 'metadata' : 'dict text',
141 'content' : 'dict text',
141 'content' : 'dict text',
142 'buffers' : 'bufs blob',
142 'buffers' : 'bufs blob',
143 'submitted' : 'timestamp',
143 'submitted' : 'timestamp',
144 'client_uuid' : 'text',
144 'client_uuid' : 'text',
145 'engine_uuid' : 'text',
145 'engine_uuid' : 'text',
146 'started' : 'timestamp',
146 'started' : 'timestamp',
147 'completed' : 'timestamp',
147 'completed' : 'timestamp',
148 'resubmitted' : 'text',
148 'resubmitted' : 'text',
149 'received' : 'timestamp',
149 'received' : 'timestamp',
150 'result_header' : 'dict text',
150 'result_header' : 'dict text',
151 'result_metadata' : 'dict text',
151 'result_metadata' : 'dict text',
152 'result_content' : 'dict text',
152 'result_content' : 'dict text',
153 'result_buffers' : 'bufs blob',
153 'result_buffers' : 'bufs blob',
154 'queue' : 'text',
154 'queue' : 'text',
155 'pyin' : 'text',
155 'pyin' : 'text',
156 'pyout' : 'text',
156 'pyout' : 'text',
157 'pyerr' : 'text',
157 'pyerr' : 'text',
158 'stdout' : 'text',
158 'stdout' : 'text',
159 'stderr' : 'text',
159 'stderr' : 'text',
160 })
160 })
161
161
162 def __init__(self, **kwargs):
162 def __init__(self, **kwargs):
163 super(SQLiteDB, self).__init__(**kwargs)
163 super(SQLiteDB, self).__init__(**kwargs)
164 if sqlite3 is None:
164 if sqlite3 is None:
165 raise ImportError("SQLiteDB requires sqlite3")
165 raise ImportError("SQLiteDB requires sqlite3")
166 if not self.table:
166 if not self.table:
167 # use session, and prefix _, since starting with # is illegal
167 # use session, and prefix _, since starting with # is illegal
168 self.table = '_'+self.session.replace('-','_')
168 self.table = '_'+self.session.replace('-','_')
169 if not self.location:
169 if not self.location:
170 # get current profile
170 # get current profile
171 from IPython.core.application import BaseIPythonApplication
171 from IPython.core.application import BaseIPythonApplication
172 if BaseIPythonApplication.initialized():
172 if BaseIPythonApplication.initialized():
173 app = BaseIPythonApplication.instance()
173 app = BaseIPythonApplication.instance()
174 if app.profile_dir is not None:
174 if app.profile_dir is not None:
175 self.location = app.profile_dir.location
175 self.location = app.profile_dir.location
176 else:
176 else:
177 self.location = u'.'
177 self.location = u'.'
178 else:
178 else:
179 self.location = u'.'
179 self.location = u'.'
180 self._init_db()
180 self._init_db()
181
181
182 # register db commit as 2s periodic callback
182 # register db commit as 2s periodic callback
183 # to prevent clogging pipes
183 # to prevent clogging pipes
184 # assumes we are being run in a zmq ioloop app
184 # assumes we are being run in a zmq ioloop app
185 loop = ioloop.IOLoop.instance()
185 loop = ioloop.IOLoop.instance()
186 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
186 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
187 pc.start()
187 pc.start()
188
188
189 def _defaults(self, keys=None):
189 def _defaults(self, keys=None):
190 """create an empty record"""
190 """create an empty record"""
191 d = {}
191 d = {}
192 keys = self._keys if keys is None else keys
192 keys = self._keys if keys is None else keys
193 for key in keys:
193 for key in keys:
194 d[key] = None
194 d[key] = None
195 return d
195 return d
196
196
197 def _check_table(self):
197 def _check_table(self):
198 """Ensure that an incorrect table doesn't exist
198 """Ensure that an incorrect table doesn't exist
199
199
200 If a bad (old) table does exist, return False
200 If a bad (old) table does exist, return False
201 """
201 """
202 cursor = self._db.execute("PRAGMA table_info('%s')"%self.table)
202 cursor = self._db.execute("PRAGMA table_info('%s')"%self.table)
203 lines = cursor.fetchall()
203 lines = cursor.fetchall()
204 if not lines:
204 if not lines:
205 # table does not exist
205 # table does not exist
206 return True
206 return True
207 types = {}
207 types = {}
208 keys = []
208 keys = []
209 for line in lines:
209 for line in lines:
210 keys.append(line[1])
210 keys.append(line[1])
211 types[line[1]] = line[2]
211 types[line[1]] = line[2]
212 if self._keys != keys:
212 if self._keys != keys:
213 # key mismatch
213 # key mismatch
214 self.log.warn('keys mismatch')
214 self.log.warn('keys mismatch')
215 return False
215 return False
216 for key in self._keys:
216 for key in self._keys:
217 if types[key] != self._types[key]:
217 if types[key] != self._types[key]:
218 self.log.warn(
218 self.log.warn(
219 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
219 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
220 )
220 )
221 return False
221 return False
222 return True
222 return True
223
223
224 def _init_db(self):
224 def _init_db(self):
225 """Connect to the database and get new session number."""
225 """Connect to the database and get new session number."""
226 # register adapters
226 # register adapters
227 sqlite3.register_adapter(dict, _adapt_dict)
227 sqlite3.register_adapter(dict, _adapt_dict)
228 sqlite3.register_converter('dict', _convert_dict)
228 sqlite3.register_converter('dict', _convert_dict)
229 sqlite3.register_adapter(list, _adapt_bufs)
229 sqlite3.register_adapter(list, _adapt_bufs)
230 sqlite3.register_converter('bufs', _convert_bufs)
230 sqlite3.register_converter('bufs', _convert_bufs)
231 # connect to the db
231 # connect to the db
232 dbfile = os.path.join(self.location, self.filename)
232 dbfile = os.path.join(self.location, self.filename)
233 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
233 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
234 # isolation_level = None)#,
234 # isolation_level = None)#,
235 cached_statements=64)
235 cached_statements=64)
236 # print dir(self._db)
236 # print dir(self._db)
237 first_table = previous_table = self.table
237 first_table = previous_table = self.table
238 i=0
238 i=0
239 while not self._check_table():
239 while not self._check_table():
240 i+=1
240 i+=1
241 self.table = first_table+'_%i'%i
241 self.table = first_table+'_%i'%i
242 self.log.warn(
242 self.log.warn(
243 "Table %s exists and doesn't match db format, trying %s"%
243 "Table %s exists and doesn't match db format, trying %s"%
244 (previous_table, self.table)
244 (previous_table, self.table)
245 )
245 )
246 previous_table = self.table
246 previous_table = self.table
247
247
248 self._db.execute("""CREATE TABLE IF NOT EXISTS '%s'
248 self._db.execute("""CREATE TABLE IF NOT EXISTS '%s'
249 (msg_id text PRIMARY KEY,
249 (msg_id text PRIMARY KEY,
250 header dict text,
250 header dict text,
251 metadata dict text,
251 metadata dict text,
252 content dict text,
252 content dict text,
253 buffers bufs blob,
253 buffers bufs blob,
254 submitted timestamp,
254 submitted timestamp,
255 client_uuid text,
255 client_uuid text,
256 engine_uuid text,
256 engine_uuid text,
257 started timestamp,
257 started timestamp,
258 completed timestamp,
258 completed timestamp,
259 resubmitted text,
259 resubmitted text,
260 received timestamp,
260 received timestamp,
261 result_header dict text,
261 result_header dict text,
262 result_metadata dict text,
262 result_metadata dict text,
263 result_content dict text,
263 result_content dict text,
264 result_buffers bufs blob,
264 result_buffers bufs blob,
265 queue text,
265 queue text,
266 pyin text,
266 pyin text,
267 pyout text,
267 pyout text,
268 pyerr text,
268 pyerr text,
269 stdout text,
269 stdout text,
270 stderr text)
270 stderr text)
271 """%self.table)
271 """%self.table)
272 self._db.commit()
272 self._db.commit()
273
273
274 def _dict_to_list(self, d):
274 def _dict_to_list(self, d):
275 """turn a mongodb-style record dict into a list."""
275 """turn a mongodb-style record dict into a list."""
276
276
277 return [ d[key] for key in self._keys ]
277 return [ d[key] for key in self._keys ]
278
278
279 def _list_to_dict(self, line, keys=None):
279 def _list_to_dict(self, line, keys=None):
280 """Inverse of dict_to_list"""
280 """Inverse of dict_to_list"""
281 keys = self._keys if keys is None else keys
281 keys = self._keys if keys is None else keys
282 d = self._defaults(keys)
282 d = self._defaults(keys)
283 for key,value in zip(keys, line):
283 for key,value in zip(keys, line):
284 d[key] = value
284 d[key] = value
285
285
286 return d
286 return d
287
287
288 def _render_expression(self, check):
288 def _render_expression(self, check):
289 """Turn a mongodb-style search dict into an SQL query."""
289 """Turn a mongodb-style search dict into an SQL query."""
290 expressions = []
290 expressions = []
291 args = []
291 args = []
292
292
293 skeys = set(check.keys())
293 skeys = set(check.keys())
294 skeys.difference_update(set(self._keys))
294 skeys.difference_update(set(self._keys))
295 skeys.difference_update(set(['buffers', 'result_buffers']))
295 skeys.difference_update(set(['buffers', 'result_buffers']))
296 if skeys:
296 if skeys:
297 raise KeyError("Illegal testing key(s): %s"%skeys)
297 raise KeyError("Illegal testing key(s): %s"%skeys)
298
298
299 for name,sub_check in iteritems(check):
299 for name,sub_check in iteritems(check):
300 if isinstance(sub_check, dict):
300 if isinstance(sub_check, dict):
301 for test,value in iteritems(sub_check):
301 for test,value in iteritems(sub_check):
302 try:
302 try:
303 op = operators[test]
303 op = operators[test]
304 except KeyError:
304 except KeyError:
305 raise KeyError("Unsupported operator: %r"%test)
305 raise KeyError("Unsupported operator: %r"%test)
306 if isinstance(op, tuple):
306 if isinstance(op, tuple):
307 op, join = op
307 op, join = op
308
308
309 if value is None and op in null_operators:
309 if value is None and op in null_operators:
310 expr = "%s %s" % (name, null_operators[op])
310 expr = "%s %s" % (name, null_operators[op])
311 else:
311 else:
312 expr = "%s %s ?"%(name, op)
312 expr = "%s %s ?"%(name, op)
313 if isinstance(value, (tuple,list)):
313 if isinstance(value, (tuple,list)):
314 if op in null_operators and any([v is None for v in value]):
314 if op in null_operators and any([v is None for v in value]):
315 # equality tests don't work with NULL
315 # equality tests don't work with NULL
316 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
316 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
317 expr = '( %s )'%( join.join([expr]*len(value)) )
317 expr = '( %s )'%( join.join([expr]*len(value)) )
318 args.extend(value)
318 args.extend(value)
319 else:
319 else:
320 args.append(value)
320 args.append(value)
321 expressions.append(expr)
321 expressions.append(expr)
322 else:
322 else:
323 # it's an equality check
323 # it's an equality check
324 if sub_check is None:
324 if sub_check is None:
325 expressions.append("%s IS NULL" % name)
325 expressions.append("%s IS NULL" % name)
326 else:
326 else:
327 expressions.append("%s = ?"%name)
327 expressions.append("%s = ?"%name)
328 args.append(sub_check)
328 args.append(sub_check)
329
329
330 expr = " AND ".join(expressions)
330 expr = " AND ".join(expressions)
331 return expr, args
331 return expr, args
332
332
333 def add_record(self, msg_id, rec):
333 def add_record(self, msg_id, rec):
334 """Add a new Task Record, by msg_id."""
334 """Add a new Task Record, by msg_id."""
335 d = self._defaults()
335 d = self._defaults()
336 d.update(rec)
336 d.update(rec)
337 d['msg_id'] = msg_id
337 d['msg_id'] = msg_id
338 line = self._dict_to_list(d)
338 line = self._dict_to_list(d)
339 tups = '(%s)'%(','.join(['?']*len(line)))
339 tups = '(%s)'%(','.join(['?']*len(line)))
340 self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line)
340 self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line)
341 # self._db.commit()
341 # self._db.commit()
342
342
343 def get_record(self, msg_id):
343 def get_record(self, msg_id):
344 """Get a specific Task Record, by msg_id."""
344 """Get a specific Task Record, by msg_id."""
345 cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
345 cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
346 line = cursor.fetchone()
346 line = cursor.fetchone()
347 if line is None:
347 if line is None:
348 raise KeyError("No such msg: %r"%msg_id)
348 raise KeyError("No such msg: %r"%msg_id)
349 return self._list_to_dict(line)
349 return self._list_to_dict(line)
350
350
351 def update_record(self, msg_id, rec):
351 def update_record(self, msg_id, rec):
352 """Update the data in an existing record."""
352 """Update the data in an existing record."""
353 query = "UPDATE '%s' SET "%self.table
353 query = "UPDATE '%s' SET "%self.table
354 sets = []
354 sets = []
355 keys = sorted(rec.keys())
355 keys = sorted(rec.keys())
356 values = []
356 values = []
357 for key in keys:
357 for key in keys:
358 sets.append('%s = ?'%key)
358 sets.append('%s = ?'%key)
359 values.append(rec[key])
359 values.append(rec[key])
360 query += ', '.join(sets)
360 query += ', '.join(sets)
361 query += ' WHERE msg_id == ?'
361 query += ' WHERE msg_id == ?'
362 values.append(msg_id)
362 values.append(msg_id)
363 self._db.execute(query, values)
363 self._db.execute(query, values)
364 # self._db.commit()
364 # self._db.commit()
365
365
366 def drop_record(self, msg_id):
366 def drop_record(self, msg_id):
367 """Remove a record from the DB."""
367 """Remove a record from the DB."""
368 self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
368 self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
369 # self._db.commit()
369 # self._db.commit()
370
370
371 def drop_matching_records(self, check):
371 def drop_matching_records(self, check):
372 """Remove a record from the DB."""
372 """Remove a record from the DB."""
373 expr,args = self._render_expression(check)
373 expr,args = self._render_expression(check)
374 query = "DELETE FROM '%s' WHERE %s"%(self.table, expr)
374 query = "DELETE FROM '%s' WHERE %s"%(self.table, expr)
375 self._db.execute(query,args)
375 self._db.execute(query,args)
376 # self._db.commit()
376 # self._db.commit()
377
377
378 def find_records(self, check, keys=None):
378 def find_records(self, check, keys=None):
379 """Find records matching a query dict, optionally extracting subset of keys.
379 """Find records matching a query dict, optionally extracting subset of keys.
380
380
381 Returns list of matching records.
381 Returns list of matching records.
382
382
383 Parameters
383 Parameters
384 ----------
384 ----------
385
385
386 check: dict
386 check: dict
387 mongodb-style query argument
387 mongodb-style query argument
388 keys: list of strs [optional]
388 keys: list of strs [optional]
389 if specified, the subset of keys to extract. msg_id will *always* be
389 if specified, the subset of keys to extract. msg_id will *always* be
390 included.
390 included.
391 """
391 """
392 if keys:
392 if keys:
393 bad_keys = [ key for key in keys if key not in self._keys ]
393 bad_keys = [ key for key in keys if key not in self._keys ]
394 if bad_keys:
394 if bad_keys:
395 raise KeyError("Bad record key(s): %s"%bad_keys)
395 raise KeyError("Bad record key(s): %s"%bad_keys)
396
396
397 if keys:
397 if keys:
398 # ensure msg_id is present and first:
398 # ensure msg_id is present and first:
399 if 'msg_id' in keys:
399 if 'msg_id' in keys:
400 keys.remove('msg_id')
400 keys.remove('msg_id')
401 keys.insert(0, 'msg_id')
401 keys.insert(0, 'msg_id')
402 req = ', '.join(keys)
402 req = ', '.join(keys)
403 else:
403 else:
404 req = '*'
404 req = '*'
405 expr,args = self._render_expression(check)
405 expr,args = self._render_expression(check)
406 query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr)
406 query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr)
407 cursor = self._db.execute(query, args)
407 cursor = self._db.execute(query, args)
408 matches = cursor.fetchall()
408 matches = cursor.fetchall()
409 records = []
409 records = []
410 for line in matches:
410 for line in matches:
411 rec = self._list_to_dict(line, keys)
411 rec = self._list_to_dict(line, keys)
412 records.append(rec)
412 records.append(rec)
413 return records
413 return records
414
414
415 def get_history(self):
415 def get_history(self):
416 """get all msg_ids, ordered by time submitted."""
416 """get all msg_ids, ordered by time submitted."""
417 query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table
417 query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table
418 cursor = self._db.execute(query)
418 cursor = self._db.execute(query)
419 # will be a list of length 1 tuples
419 # will be a list of length 1 tuples
420 return [ tup[0] for tup in cursor.fetchall()]
420 return [ tup[0] for tup in cursor.fetchall()]
421
421
422 __all__ = ['SQLiteDB'] No newline at end of file
422 __all__ = ['SQLiteDB']
@@ -1,326 +1,326 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import time
19 import time
20
20
21 import nose.tools as nt
21 import nose.tools as nt
22
22
23 from IPython.utils.io import capture_output
23 from IPython.utils.io import capture_output
24
24
25 from IPython.parallel.error import TimeoutError
25 from IPython.parallel.error import TimeoutError
26 from IPython.parallel import error, Client
26 from IPython.parallel import error, Client
27 from IPython.parallel.tests import add_engines
27 from IPython.parallel.tests import add_engines
28 from .clienttest import ClusterTestCase
28 from .clienttest import ClusterTestCase
29 from IPython.utils.py3compat import iteritems
29 from IPython.utils.py3compat import iteritems
30
30
31 def setup():
31 def setup():
32 add_engines(2, total=True)
32 add_engines(2, total=True)
33
33
34 def wait(n):
34 def wait(n):
35 import time
35 import time
36 time.sleep(n)
36 time.sleep(n)
37 return n
37 return n
38
38
39 def echo(x):
39 def echo(x):
40 return x
40 return x
41
41
42 class AsyncResultTest(ClusterTestCase):
42 class AsyncResultTest(ClusterTestCase):
43
43
44 def test_single_result_view(self):
44 def test_single_result_view(self):
45 """various one-target views get the right value for single_result"""
45 """various one-target views get the right value for single_result"""
46 eid = self.client.ids[-1]
46 eid = self.client.ids[-1]
47 ar = self.client[eid].apply_async(lambda : 42)
47 ar = self.client[eid].apply_async(lambda : 42)
48 self.assertEqual(ar.get(), 42)
48 self.assertEqual(ar.get(), 42)
49 ar = self.client[[eid]].apply_async(lambda : 42)
49 ar = self.client[[eid]].apply_async(lambda : 42)
50 self.assertEqual(ar.get(), [42])
50 self.assertEqual(ar.get(), [42])
51 ar = self.client[-1:].apply_async(lambda : 42)
51 ar = self.client[-1:].apply_async(lambda : 42)
52 self.assertEqual(ar.get(), [42])
52 self.assertEqual(ar.get(), [42])
53
53
54 def test_get_after_done(self):
54 def test_get_after_done(self):
55 ar = self.client[-1].apply_async(lambda : 42)
55 ar = self.client[-1].apply_async(lambda : 42)
56 ar.wait()
56 ar.wait()
57 self.assertTrue(ar.ready())
57 self.assertTrue(ar.ready())
58 self.assertEqual(ar.get(), 42)
58 self.assertEqual(ar.get(), 42)
59 self.assertEqual(ar.get(), 42)
59 self.assertEqual(ar.get(), 42)
60
60
61 def test_get_before_done(self):
61 def test_get_before_done(self):
62 ar = self.client[-1].apply_async(wait, 0.1)
62 ar = self.client[-1].apply_async(wait, 0.1)
63 self.assertRaises(TimeoutError, ar.get, 0)
63 self.assertRaises(TimeoutError, ar.get, 0)
64 ar.wait(0)
64 ar.wait(0)
65 self.assertFalse(ar.ready())
65 self.assertFalse(ar.ready())
66 self.assertEqual(ar.get(), 0.1)
66 self.assertEqual(ar.get(), 0.1)
67
67
68 def test_get_after_error(self):
68 def test_get_after_error(self):
69 ar = self.client[-1].apply_async(lambda : 1/0)
69 ar = self.client[-1].apply_async(lambda : 1/0)
70 ar.wait(10)
70 ar.wait(10)
71 self.assertRaisesRemote(ZeroDivisionError, ar.get)
71 self.assertRaisesRemote(ZeroDivisionError, ar.get)
72 self.assertRaisesRemote(ZeroDivisionError, ar.get)
72 self.assertRaisesRemote(ZeroDivisionError, ar.get)
73 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
73 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
74
74
75 def test_get_dict(self):
75 def test_get_dict(self):
76 n = len(self.client)
76 n = len(self.client)
77 ar = self.client[:].apply_async(lambda : 5)
77 ar = self.client[:].apply_async(lambda : 5)
78 self.assertEqual(ar.get(), [5]*n)
78 self.assertEqual(ar.get(), [5]*n)
79 d = ar.get_dict()
79 d = ar.get_dict()
80 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
80 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
81 for eid,r in iteritems(d):
81 for eid,r in iteritems(d):
82 self.assertEqual(r, 5)
82 self.assertEqual(r, 5)
83
83
84 def test_get_dict_single(self):
84 def test_get_dict_single(self):
85 view = self.client[-1]
85 view = self.client[-1]
86 for v in (range(5), 5, ('abc', 'def'), 'string'):
86 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
87 ar = view.apply_async(echo, v)
87 ar = view.apply_async(echo, v)
88 self.assertEqual(ar.get(), v)
88 self.assertEqual(ar.get(), v)
89 d = ar.get_dict()
89 d = ar.get_dict()
90 self.assertEqual(d, {view.targets : v})
90 self.assertEqual(d, {view.targets : v})
91
91
92 def test_get_dict_bad(self):
92 def test_get_dict_bad(self):
93 ar = self.client[:].apply_async(lambda : 5)
93 ar = self.client[:].apply_async(lambda : 5)
94 ar2 = self.client[:].apply_async(lambda : 5)
94 ar2 = self.client[:].apply_async(lambda : 5)
95 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
95 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
96 self.assertRaises(ValueError, ar.get_dict)
96 self.assertRaises(ValueError, ar.get_dict)
97
97
98 def test_list_amr(self):
98 def test_list_amr(self):
99 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
99 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
100 rlist = list(ar)
100 rlist = list(ar)
101
101
102 def test_getattr(self):
102 def test_getattr(self):
103 ar = self.client[:].apply_async(wait, 0.5)
103 ar = self.client[:].apply_async(wait, 0.5)
104 self.assertEqual(ar.engine_id, [None] * len(ar))
104 self.assertEqual(ar.engine_id, [None] * len(ar))
105 self.assertRaises(AttributeError, lambda : ar._foo)
105 self.assertRaises(AttributeError, lambda : ar._foo)
106 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
106 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
107 self.assertRaises(AttributeError, lambda : ar.foo)
107 self.assertRaises(AttributeError, lambda : ar.foo)
108 self.assertFalse(hasattr(ar, '__length_hint__'))
108 self.assertFalse(hasattr(ar, '__length_hint__'))
109 self.assertFalse(hasattr(ar, 'foo'))
109 self.assertFalse(hasattr(ar, 'foo'))
110 self.assertTrue(hasattr(ar, 'engine_id'))
110 self.assertTrue(hasattr(ar, 'engine_id'))
111 ar.get(5)
111 ar.get(5)
112 self.assertRaises(AttributeError, lambda : ar._foo)
112 self.assertRaises(AttributeError, lambda : ar._foo)
113 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
113 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
114 self.assertRaises(AttributeError, lambda : ar.foo)
114 self.assertRaises(AttributeError, lambda : ar.foo)
115 self.assertTrue(isinstance(ar.engine_id, list))
115 self.assertTrue(isinstance(ar.engine_id, list))
116 self.assertEqual(ar.engine_id, ar['engine_id'])
116 self.assertEqual(ar.engine_id, ar['engine_id'])
117 self.assertFalse(hasattr(ar, '__length_hint__'))
117 self.assertFalse(hasattr(ar, '__length_hint__'))
118 self.assertFalse(hasattr(ar, 'foo'))
118 self.assertFalse(hasattr(ar, 'foo'))
119 self.assertTrue(hasattr(ar, 'engine_id'))
119 self.assertTrue(hasattr(ar, 'engine_id'))
120
120
121 def test_getitem(self):
121 def test_getitem(self):
122 ar = self.client[:].apply_async(wait, 0.5)
122 ar = self.client[:].apply_async(wait, 0.5)
123 self.assertEqual(ar['engine_id'], [None] * len(ar))
123 self.assertEqual(ar['engine_id'], [None] * len(ar))
124 self.assertRaises(KeyError, lambda : ar['foo'])
124 self.assertRaises(KeyError, lambda : ar['foo'])
125 ar.get(5)
125 ar.get(5)
126 self.assertRaises(KeyError, lambda : ar['foo'])
126 self.assertRaises(KeyError, lambda : ar['foo'])
127 self.assertTrue(isinstance(ar['engine_id'], list))
127 self.assertTrue(isinstance(ar['engine_id'], list))
128 self.assertEqual(ar.engine_id, ar['engine_id'])
128 self.assertEqual(ar.engine_id, ar['engine_id'])
129
129
130 def test_single_result(self):
130 def test_single_result(self):
131 ar = self.client[-1].apply_async(wait, 0.5)
131 ar = self.client[-1].apply_async(wait, 0.5)
132 self.assertRaises(KeyError, lambda : ar['foo'])
132 self.assertRaises(KeyError, lambda : ar['foo'])
133 self.assertEqual(ar['engine_id'], None)
133 self.assertEqual(ar['engine_id'], None)
134 self.assertTrue(ar.get(5) == 0.5)
134 self.assertTrue(ar.get(5) == 0.5)
135 self.assertTrue(isinstance(ar['engine_id'], int))
135 self.assertTrue(isinstance(ar['engine_id'], int))
136 self.assertTrue(isinstance(ar.engine_id, int))
136 self.assertTrue(isinstance(ar.engine_id, int))
137 self.assertEqual(ar.engine_id, ar['engine_id'])
137 self.assertEqual(ar.engine_id, ar['engine_id'])
138
138
139 def test_abort(self):
139 def test_abort(self):
140 e = self.client[-1]
140 e = self.client[-1]
141 ar = e.execute('import time; time.sleep(1)', block=False)
141 ar = e.execute('import time; time.sleep(1)', block=False)
142 ar2 = e.apply_async(lambda : 2)
142 ar2 = e.apply_async(lambda : 2)
143 ar2.abort()
143 ar2.abort()
144 self.assertRaises(error.TaskAborted, ar2.get)
144 self.assertRaises(error.TaskAborted, ar2.get)
145 ar.get()
145 ar.get()
146
146
147 def test_len(self):
147 def test_len(self):
148 v = self.client.load_balanced_view()
148 v = self.client.load_balanced_view()
149 ar = v.map_async(lambda x: x, range(10))
149 ar = v.map_async(lambda x: x, list(range(10)))
150 self.assertEqual(len(ar), 10)
150 self.assertEqual(len(ar), 10)
151 ar = v.apply_async(lambda x: x, range(10))
151 ar = v.apply_async(lambda x: x, list(range(10)))
152 self.assertEqual(len(ar), 1)
152 self.assertEqual(len(ar), 1)
153 ar = self.client[:].apply_async(lambda x: x, range(10))
153 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
154 self.assertEqual(len(ar), len(self.client.ids))
154 self.assertEqual(len(ar), len(self.client.ids))
155
155
156 def test_wall_time_single(self):
156 def test_wall_time_single(self):
157 v = self.client.load_balanced_view()
157 v = self.client.load_balanced_view()
158 ar = v.apply_async(time.sleep, 0.25)
158 ar = v.apply_async(time.sleep, 0.25)
159 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
159 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
160 ar.get(2)
160 ar.get(2)
161 self.assertTrue(ar.wall_time < 1.)
161 self.assertTrue(ar.wall_time < 1.)
162 self.assertTrue(ar.wall_time > 0.2)
162 self.assertTrue(ar.wall_time > 0.2)
163
163
164 def test_wall_time_multi(self):
164 def test_wall_time_multi(self):
165 self.minimum_engines(4)
165 self.minimum_engines(4)
166 v = self.client[:]
166 v = self.client[:]
167 ar = v.apply_async(time.sleep, 0.25)
167 ar = v.apply_async(time.sleep, 0.25)
168 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
168 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
169 ar.get(2)
169 ar.get(2)
170 self.assertTrue(ar.wall_time < 1.)
170 self.assertTrue(ar.wall_time < 1.)
171 self.assertTrue(ar.wall_time > 0.2)
171 self.assertTrue(ar.wall_time > 0.2)
172
172
173 def test_serial_time_single(self):
173 def test_serial_time_single(self):
174 v = self.client.load_balanced_view()
174 v = self.client.load_balanced_view()
175 ar = v.apply_async(time.sleep, 0.25)
175 ar = v.apply_async(time.sleep, 0.25)
176 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
176 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
177 ar.get(2)
177 ar.get(2)
178 self.assertTrue(ar.serial_time < 1.)
178 self.assertTrue(ar.serial_time < 1.)
179 self.assertTrue(ar.serial_time > 0.2)
179 self.assertTrue(ar.serial_time > 0.2)
180
180
181 def test_serial_time_multi(self):
181 def test_serial_time_multi(self):
182 self.minimum_engines(4)
182 self.minimum_engines(4)
183 v = self.client[:]
183 v = self.client[:]
184 ar = v.apply_async(time.sleep, 0.25)
184 ar = v.apply_async(time.sleep, 0.25)
185 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
185 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
186 ar.get(2)
186 ar.get(2)
187 self.assertTrue(ar.serial_time < 2.)
187 self.assertTrue(ar.serial_time < 2.)
188 self.assertTrue(ar.serial_time > 0.8)
188 self.assertTrue(ar.serial_time > 0.8)
189
189
190 def test_elapsed_single(self):
190 def test_elapsed_single(self):
191 v = self.client.load_balanced_view()
191 v = self.client.load_balanced_view()
192 ar = v.apply_async(time.sleep, 0.25)
192 ar = v.apply_async(time.sleep, 0.25)
193 while not ar.ready():
193 while not ar.ready():
194 time.sleep(0.01)
194 time.sleep(0.01)
195 self.assertTrue(ar.elapsed < 1)
195 self.assertTrue(ar.elapsed < 1)
196 self.assertTrue(ar.elapsed < 1)
196 self.assertTrue(ar.elapsed < 1)
197 ar.get(2)
197 ar.get(2)
198
198
199 def test_elapsed_multi(self):
199 def test_elapsed_multi(self):
200 v = self.client[:]
200 v = self.client[:]
201 ar = v.apply_async(time.sleep, 0.25)
201 ar = v.apply_async(time.sleep, 0.25)
202 while not ar.ready():
202 while not ar.ready():
203 time.sleep(0.01)
203 time.sleep(0.01)
204 self.assertTrue(ar.elapsed < 1)
204 self.assertTrue(ar.elapsed < 1)
205 self.assertTrue(ar.elapsed < 1)
205 self.assertTrue(ar.elapsed < 1)
206 ar.get(2)
206 ar.get(2)
207
207
208 def test_hubresult_timestamps(self):
208 def test_hubresult_timestamps(self):
209 self.minimum_engines(4)
209 self.minimum_engines(4)
210 v = self.client[:]
210 v = self.client[:]
211 ar = v.apply_async(time.sleep, 0.25)
211 ar = v.apply_async(time.sleep, 0.25)
212 ar.get(2)
212 ar.get(2)
213 rc2 = Client(profile='iptest')
213 rc2 = Client(profile='iptest')
214 # must have try/finally to close second Client, otherwise
214 # must have try/finally to close second Client, otherwise
215 # will have dangling sockets causing problems
215 # will have dangling sockets causing problems
216 try:
216 try:
217 time.sleep(0.25)
217 time.sleep(0.25)
218 hr = rc2.get_result(ar.msg_ids)
218 hr = rc2.get_result(ar.msg_ids)
219 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
219 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
220 hr.get(1)
220 hr.get(1)
221 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
221 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
222 self.assertEqual(hr.serial_time, ar.serial_time)
222 self.assertEqual(hr.serial_time, ar.serial_time)
223 finally:
223 finally:
224 rc2.close()
224 rc2.close()
225
225
226 def test_display_empty_streams_single(self):
226 def test_display_empty_streams_single(self):
227 """empty stdout/err are not displayed (single result)"""
227 """empty stdout/err are not displayed (single result)"""
228 self.minimum_engines(1)
228 self.minimum_engines(1)
229
229
230 v = self.client[-1]
230 v = self.client[-1]
231 ar = v.execute("print (5555)")
231 ar = v.execute("print (5555)")
232 ar.get(5)
232 ar.get(5)
233 with capture_output() as io:
233 with capture_output() as io:
234 ar.display_outputs()
234 ar.display_outputs()
235 self.assertEqual(io.stderr, '')
235 self.assertEqual(io.stderr, '')
236 self.assertEqual('5555\n', io.stdout)
236 self.assertEqual('5555\n', io.stdout)
237
237
238 ar = v.execute("a=5")
238 ar = v.execute("a=5")
239 ar.get(5)
239 ar.get(5)
240 with capture_output() as io:
240 with capture_output() as io:
241 ar.display_outputs()
241 ar.display_outputs()
242 self.assertEqual(io.stderr, '')
242 self.assertEqual(io.stderr, '')
243 self.assertEqual(io.stdout, '')
243 self.assertEqual(io.stdout, '')
244
244
245 def test_display_empty_streams_type(self):
245 def test_display_empty_streams_type(self):
246 """empty stdout/err are not displayed (groupby type)"""
246 """empty stdout/err are not displayed (groupby type)"""
247 self.minimum_engines(1)
247 self.minimum_engines(1)
248
248
249 v = self.client[:]
249 v = self.client[:]
250 ar = v.execute("print (5555)")
250 ar = v.execute("print (5555)")
251 ar.get(5)
251 ar.get(5)
252 with capture_output() as io:
252 with capture_output() as io:
253 ar.display_outputs()
253 ar.display_outputs()
254 self.assertEqual(io.stderr, '')
254 self.assertEqual(io.stderr, '')
255 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
255 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
256 self.assertFalse('\n\n' in io.stdout, io.stdout)
256 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
257 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
258
258
259 ar = v.execute("a=5")
259 ar = v.execute("a=5")
260 ar.get(5)
260 ar.get(5)
261 with capture_output() as io:
261 with capture_output() as io:
262 ar.display_outputs()
262 ar.display_outputs()
263 self.assertEqual(io.stderr, '')
263 self.assertEqual(io.stderr, '')
264 self.assertEqual(io.stdout, '')
264 self.assertEqual(io.stdout, '')
265
265
266 def test_display_empty_streams_engine(self):
266 def test_display_empty_streams_engine(self):
267 """empty stdout/err are not displayed (groupby engine)"""
267 """empty stdout/err are not displayed (groupby engine)"""
268 self.minimum_engines(1)
268 self.minimum_engines(1)
269
269
270 v = self.client[:]
270 v = self.client[:]
271 ar = v.execute("print (5555)")
271 ar = v.execute("print (5555)")
272 ar.get(5)
272 ar.get(5)
273 with capture_output() as io:
273 with capture_output() as io:
274 ar.display_outputs('engine')
274 ar.display_outputs('engine')
275 self.assertEqual(io.stderr, '')
275 self.assertEqual(io.stderr, '')
276 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
276 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
277 self.assertFalse('\n\n' in io.stdout, io.stdout)
277 self.assertFalse('\n\n' in io.stdout, io.stdout)
278 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
278 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
279
279
280 ar = v.execute("a=5")
280 ar = v.execute("a=5")
281 ar.get(5)
281 ar.get(5)
282 with capture_output() as io:
282 with capture_output() as io:
283 ar.display_outputs('engine')
283 ar.display_outputs('engine')
284 self.assertEqual(io.stderr, '')
284 self.assertEqual(io.stderr, '')
285 self.assertEqual(io.stdout, '')
285 self.assertEqual(io.stdout, '')
286
286
287 def test_await_data(self):
287 def test_await_data(self):
288 """asking for ar.data flushes outputs"""
288 """asking for ar.data flushes outputs"""
289 self.minimum_engines(1)
289 self.minimum_engines(1)
290
290
291 v = self.client[-1]
291 v = self.client[-1]
292 ar = v.execute('\n'.join([
292 ar = v.execute('\n'.join([
293 "import time",
293 "import time",
294 "from IPython.kernel.zmq.datapub import publish_data",
294 "from IPython.kernel.zmq.datapub import publish_data",
295 "for i in range(5):",
295 "for i in range(5):",
296 " publish_data(dict(i=i))",
296 " publish_data(dict(i=i))",
297 " time.sleep(0.1)",
297 " time.sleep(0.1)",
298 ]), block=False)
298 ]), block=False)
299 found = set()
299 found = set()
300 tic = time.time()
300 tic = time.time()
301 # timeout after 10s
301 # timeout after 10s
302 while time.time() <= tic + 10:
302 while time.time() <= tic + 10:
303 if ar.data:
303 if ar.data:
304 found.add(ar.data['i'])
304 found.add(ar.data['i'])
305 if ar.data['i'] == 4:
305 if ar.data['i'] == 4:
306 break
306 break
307 time.sleep(0.05)
307 time.sleep(0.05)
308
308
309 ar.get(5)
309 ar.get(5)
310 nt.assert_in(4, found)
310 nt.assert_in(4, found)
311 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
311 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
312
312
313 def test_not_single_result(self):
313 def test_not_single_result(self):
314 save_build = self.client._build_targets
314 save_build = self.client._build_targets
315 def single_engine(*a, **kw):
315 def single_engine(*a, **kw):
316 idents, targets = save_build(*a, **kw)
316 idents, targets = save_build(*a, **kw)
317 return idents[:1], targets[:1]
317 return idents[:1], targets[:1]
318 ids = single_engine('all')[1]
318 ids = single_engine('all')[1]
319 self.client._build_targets = single_engine
319 self.client._build_targets = single_engine
320 for targets in ('all', None, ids):
320 for targets in ('all', None, ids):
321 dv = self.client.direct_view(targets=targets)
321 dv = self.client.direct_view(targets=targets)
322 ar = dv.apply_async(lambda : 5)
322 ar = dv.apply_async(lambda : 5)
323 self.assertEqual(ar.get(10), [5])
323 self.assertEqual(ar.get(10), [5])
324 self.client._build_targets = save_build
324 self.client._build_targets = save_build
325
325
326
326
@@ -1,517 +1,517 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython import parallel
27 from IPython import parallel
28 from IPython.parallel.client import client as clientmod
28 from IPython.parallel.client import client as clientmod
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import AsyncResult, AsyncHubResult
31 from IPython.parallel import LoadBalancedView, DirectView
31 from IPython.parallel import LoadBalancedView, DirectView
32
32
33 from .clienttest import ClusterTestCase, segfault, wait, add_engines
33 from .clienttest import ClusterTestCase, segfault, wait, add_engines
34
34
35 def setup():
35 def setup():
36 add_engines(4, total=True)
36 add_engines(4, total=True)
37
37
38 class TestClient(ClusterTestCase):
38 class TestClient(ClusterTestCase):
39
39
40 def test_ids(self):
40 def test_ids(self):
41 n = len(self.client.ids)
41 n = len(self.client.ids)
42 self.add_engines(2)
42 self.add_engines(2)
43 self.assertEqual(len(self.client.ids), n+2)
43 self.assertEqual(len(self.client.ids), n+2)
44
44
45 def test_view_indexing(self):
45 def test_view_indexing(self):
46 """test index access for views"""
46 """test index access for views"""
47 self.minimum_engines(4)
47 self.minimum_engines(4)
48 targets = self.client._build_targets('all')[-1]
48 targets = self.client._build_targets('all')[-1]
49 v = self.client[:]
49 v = self.client[:]
50 self.assertEqual(v.targets, targets)
50 self.assertEqual(v.targets, targets)
51 t = self.client.ids[2]
51 t = self.client.ids[2]
52 v = self.client[t]
52 v = self.client[t]
53 self.assertTrue(isinstance(v, DirectView))
53 self.assertTrue(isinstance(v, DirectView))
54 self.assertEqual(v.targets, t)
54 self.assertEqual(v.targets, t)
55 t = self.client.ids[2:4]
55 t = self.client.ids[2:4]
56 v = self.client[t]
56 v = self.client[t]
57 self.assertTrue(isinstance(v, DirectView))
57 self.assertTrue(isinstance(v, DirectView))
58 self.assertEqual(v.targets, t)
58 self.assertEqual(v.targets, t)
59 v = self.client[::2]
59 v = self.client[::2]
60 self.assertTrue(isinstance(v, DirectView))
60 self.assertTrue(isinstance(v, DirectView))
61 self.assertEqual(v.targets, targets[::2])
61 self.assertEqual(v.targets, targets[::2])
62 v = self.client[1::3]
62 v = self.client[1::3]
63 self.assertTrue(isinstance(v, DirectView))
63 self.assertTrue(isinstance(v, DirectView))
64 self.assertEqual(v.targets, targets[1::3])
64 self.assertEqual(v.targets, targets[1::3])
65 v = self.client[:-3]
65 v = self.client[:-3]
66 self.assertTrue(isinstance(v, DirectView))
66 self.assertTrue(isinstance(v, DirectView))
67 self.assertEqual(v.targets, targets[:-3])
67 self.assertEqual(v.targets, targets[:-3])
68 v = self.client[-1]
68 v = self.client[-1]
69 self.assertTrue(isinstance(v, DirectView))
69 self.assertTrue(isinstance(v, DirectView))
70 self.assertEqual(v.targets, targets[-1])
70 self.assertEqual(v.targets, targets[-1])
71 self.assertRaises(TypeError, lambda : self.client[None])
71 self.assertRaises(TypeError, lambda : self.client[None])
72
72
73 def test_lbview_targets(self):
73 def test_lbview_targets(self):
74 """test load_balanced_view targets"""
74 """test load_balanced_view targets"""
75 v = self.client.load_balanced_view()
75 v = self.client.load_balanced_view()
76 self.assertEqual(v.targets, None)
76 self.assertEqual(v.targets, None)
77 v = self.client.load_balanced_view(-1)
77 v = self.client.load_balanced_view(-1)
78 self.assertEqual(v.targets, [self.client.ids[-1]])
78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 v = self.client.load_balanced_view('all')
79 v = self.client.load_balanced_view('all')
80 self.assertEqual(v.targets, None)
80 self.assertEqual(v.targets, None)
81
81
82 def test_dview_targets(self):
82 def test_dview_targets(self):
83 """test direct_view targets"""
83 """test direct_view targets"""
84 v = self.client.direct_view()
84 v = self.client.direct_view()
85 self.assertEqual(v.targets, 'all')
85 self.assertEqual(v.targets, 'all')
86 v = self.client.direct_view('all')
86 v = self.client.direct_view('all')
87 self.assertEqual(v.targets, 'all')
87 self.assertEqual(v.targets, 'all')
88 v = self.client.direct_view(-1)
88 v = self.client.direct_view(-1)
89 self.assertEqual(v.targets, self.client.ids[-1])
89 self.assertEqual(v.targets, self.client.ids[-1])
90
90
91 def test_lazy_all_targets(self):
91 def test_lazy_all_targets(self):
92 """test lazy evaluation of rc.direct_view('all')"""
92 """test lazy evaluation of rc.direct_view('all')"""
93 v = self.client.direct_view()
93 v = self.client.direct_view()
94 self.assertEqual(v.targets, 'all')
94 self.assertEqual(v.targets, 'all')
95
95
96 def double(x):
96 def double(x):
97 return x*2
97 return x*2
98 seq = range(100)
98 seq = list(range(100))
99 ref = [ double(x) for x in seq ]
99 ref = [ double(x) for x in seq ]
100
100
101 # add some engines, which should be used
101 # add some engines, which should be used
102 self.add_engines(1)
102 self.add_engines(1)
103 n1 = len(self.client.ids)
103 n1 = len(self.client.ids)
104
104
105 # simple apply
105 # simple apply
106 r = v.apply_sync(lambda : 1)
106 r = v.apply_sync(lambda : 1)
107 self.assertEqual(r, [1] * n1)
107 self.assertEqual(r, [1] * n1)
108
108
109 # map goes through remotefunction
109 # map goes through remotefunction
110 r = v.map_sync(double, seq)
110 r = v.map_sync(double, seq)
111 self.assertEqual(r, ref)
111 self.assertEqual(r, ref)
112
112
113 # add a couple more engines, and try again
113 # add a couple more engines, and try again
114 self.add_engines(2)
114 self.add_engines(2)
115 n2 = len(self.client.ids)
115 n2 = len(self.client.ids)
116 self.assertNotEqual(n2, n1)
116 self.assertNotEqual(n2, n1)
117
117
118 # apply
118 # apply
119 r = v.apply_sync(lambda : 1)
119 r = v.apply_sync(lambda : 1)
120 self.assertEqual(r, [1] * n2)
120 self.assertEqual(r, [1] * n2)
121
121
122 # map
122 # map
123 r = v.map_sync(double, seq)
123 r = v.map_sync(double, seq)
124 self.assertEqual(r, ref)
124 self.assertEqual(r, ref)
125
125
126 def test_targets(self):
126 def test_targets(self):
127 """test various valid targets arguments"""
127 """test various valid targets arguments"""
128 build = self.client._build_targets
128 build = self.client._build_targets
129 ids = self.client.ids
129 ids = self.client.ids
130 idents,targets = build(None)
130 idents,targets = build(None)
131 self.assertEqual(ids, targets)
131 self.assertEqual(ids, targets)
132
132
133 def test_clear(self):
133 def test_clear(self):
134 """test clear behavior"""
134 """test clear behavior"""
135 self.minimum_engines(2)
135 self.minimum_engines(2)
136 v = self.client[:]
136 v = self.client[:]
137 v.block=True
137 v.block=True
138 v.push(dict(a=5))
138 v.push(dict(a=5))
139 v.pull('a')
139 v.pull('a')
140 id0 = self.client.ids[-1]
140 id0 = self.client.ids[-1]
141 self.client.clear(targets=id0, block=True)
141 self.client.clear(targets=id0, block=True)
142 a = self.client[:-1].get('a')
142 a = self.client[:-1].get('a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 self.client.clear(block=True)
144 self.client.clear(block=True)
145 for i in self.client.ids:
145 for i in self.client.ids:
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147
147
148 def test_get_result(self):
148 def test_get_result(self):
149 """test getting results from the Hub."""
149 """test getting results from the Hub."""
150 c = clientmod.Client(profile='iptest')
150 c = clientmod.Client(profile='iptest')
151 t = c.ids[-1]
151 t = c.ids[-1]
152 ar = c[t].apply_async(wait, 1)
152 ar = c[t].apply_async(wait, 1)
153 # give the monitor time to notice the message
153 # give the monitor time to notice the message
154 time.sleep(.25)
154 time.sleep(.25)
155 ahr = self.client.get_result(ar.msg_ids[0])
155 ahr = self.client.get_result(ar.msg_ids[0])
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 self.assertEqual(ahr.get(), ar.get())
157 self.assertEqual(ahr.get(), ar.get())
158 ar2 = self.client.get_result(ar.msg_ids[0])
158 ar2 = self.client.get_result(ar.msg_ids[0])
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 c.close()
160 c.close()
161
161
162 def test_get_execute_result(self):
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
165 t = c.ids[-1]
166 cell = '\n'.join([
166 cell = '\n'.join([
167 'import time',
167 'import time',
168 'time.sleep(0.25)',
168 'time.sleep(0.25)',
169 '5'
169 '5'
170 ])
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
172 # give the monitor time to notice the message
173 time.sleep(.25)
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids[0])
174 ahr = self.client.get_result(ar.msg_ids[0])
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids[0])
177 ar2 = self.client.get_result(ar.msg_ids[0])
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
179 c.close()
180
180
181 def test_ids_list(self):
181 def test_ids_list(self):
182 """test client.ids"""
182 """test client.ids"""
183 ids = self.client.ids
183 ids = self.client.ids
184 self.assertEqual(ids, self.client._ids)
184 self.assertEqual(ids, self.client._ids)
185 self.assertFalse(ids is self.client._ids)
185 self.assertFalse(ids is self.client._ids)
186 ids.remove(ids[-1])
186 ids.remove(ids[-1])
187 self.assertNotEqual(ids, self.client._ids)
187 self.assertNotEqual(ids, self.client._ids)
188
188
189 def test_queue_status(self):
189 def test_queue_status(self):
190 ids = self.client.ids
190 ids = self.client.ids
191 id0 = ids[0]
191 id0 = ids[0]
192 qs = self.client.queue_status(targets=id0)
192 qs = self.client.queue_status(targets=id0)
193 self.assertTrue(isinstance(qs, dict))
193 self.assertTrue(isinstance(qs, dict))
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 allqs = self.client.queue_status()
195 allqs = self.client.queue_status()
196 self.assertTrue(isinstance(allqs, dict))
196 self.assertTrue(isinstance(allqs, dict))
197 intkeys = list(allqs.keys())
197 intkeys = list(allqs.keys())
198 intkeys.remove('unassigned')
198 intkeys.remove('unassigned')
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 unassigned = allqs.pop('unassigned')
200 unassigned = allqs.pop('unassigned')
201 for eid,qs in allqs.items():
201 for eid,qs in allqs.items():
202 self.assertTrue(isinstance(qs, dict))
202 self.assertTrue(isinstance(qs, dict))
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204
204
205 def test_shutdown(self):
205 def test_shutdown(self):
206 ids = self.client.ids
206 ids = self.client.ids
207 id0 = ids[0]
207 id0 = ids[0]
208 self.client.shutdown(id0, block=True)
208 self.client.shutdown(id0, block=True)
209 while id0 in self.client.ids:
209 while id0 in self.client.ids:
210 time.sleep(0.1)
210 time.sleep(0.1)
211 self.client.spin()
211 self.client.spin()
212
212
213 self.assertRaises(IndexError, lambda : self.client[id0])
213 self.assertRaises(IndexError, lambda : self.client[id0])
214
214
215 def test_result_status(self):
215 def test_result_status(self):
216 pass
216 pass
217 # to be written
217 # to be written
218
218
219 def test_db_query_dt(self):
219 def test_db_query_dt(self):
220 """test db query by date"""
220 """test db query by date"""
221 hist = self.client.hub_history()
221 hist = self.client.hub_history()
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 tic = middle['submitted']
223 tic = middle['submitted']
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 self.assertEqual(len(before)+len(after),len(hist))
226 self.assertEqual(len(before)+len(after),len(hist))
227 for b in before:
227 for b in before:
228 self.assertTrue(b['submitted'] < tic)
228 self.assertTrue(b['submitted'] < tic)
229 for a in after:
229 for a in after:
230 self.assertTrue(a['submitted'] >= tic)
230 self.assertTrue(a['submitted'] >= tic)
231 same = self.client.db_query({'submitted' : tic})
231 same = self.client.db_query({'submitted' : tic})
232 for s in same:
232 for s in same:
233 self.assertTrue(s['submitted'] == tic)
233 self.assertTrue(s['submitted'] == tic)
234
234
235 def test_db_query_keys(self):
235 def test_db_query_keys(self):
236 """test extracting subset of record keys"""
236 """test extracting subset of record keys"""
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 for rec in found:
238 for rec in found:
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240
240
241 def test_db_query_default_keys(self):
241 def test_db_query_default_keys(self):
242 """default db_query excludes buffers"""
242 """default db_query excludes buffers"""
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 for rec in found:
244 for rec in found:
245 keys = set(rec.keys())
245 keys = set(rec.keys())
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248
248
249 def test_db_query_msg_id(self):
249 def test_db_query_msg_id(self):
250 """ensure msg_id is always in db queries"""
250 """ensure msg_id is always in db queries"""
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 for rec in found:
252 for rec in found:
253 self.assertTrue('msg_id' in rec.keys())
253 self.assertTrue('msg_id' in rec.keys())
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 for rec in found:
255 for rec in found:
256 self.assertTrue('msg_id' in rec.keys())
256 self.assertTrue('msg_id' in rec.keys())
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 for rec in found:
258 for rec in found:
259 self.assertTrue('msg_id' in rec.keys())
259 self.assertTrue('msg_id' in rec.keys())
260
260
261 def test_db_query_get_result(self):
261 def test_db_query_get_result(self):
262 """pop in db_query shouldn't pop from result itself"""
262 """pop in db_query shouldn't pop from result itself"""
263 self.client[:].apply_sync(lambda : 1)
263 self.client[:].apply_sync(lambda : 1)
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 rc2 = clientmod.Client(profile='iptest')
265 rc2 = clientmod.Client(profile='iptest')
266 # If this bug is not fixed, this call will hang:
266 # If this bug is not fixed, this call will hang:
267 ar = rc2.get_result(self.client.history[-1])
267 ar = rc2.get_result(self.client.history[-1])
268 ar.wait(2)
268 ar.wait(2)
269 self.assertTrue(ar.ready())
269 self.assertTrue(ar.ready())
270 ar.get()
270 ar.get()
271 rc2.close()
271 rc2.close()
272
272
273 def test_db_query_in(self):
273 def test_db_query_in(self):
274 """test db query with '$in','$nin' operators"""
274 """test db query with '$in','$nin' operators"""
275 hist = self.client.hub_history()
275 hist = self.client.hub_history()
276 even = hist[::2]
276 even = hist[::2]
277 odd = hist[1::2]
277 odd = hist[1::2]
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 found = [ r['msg_id'] for r in recs ]
279 found = [ r['msg_id'] for r in recs ]
280 self.assertEqual(set(even), set(found))
280 self.assertEqual(set(even), set(found))
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 found = [ r['msg_id'] for r in recs ]
282 found = [ r['msg_id'] for r in recs ]
283 self.assertEqual(set(odd), set(found))
283 self.assertEqual(set(odd), set(found))
284
284
285 def test_hub_history(self):
285 def test_hub_history(self):
286 hist = self.client.hub_history()
286 hist = self.client.hub_history()
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 recdict = {}
288 recdict = {}
289 for rec in recs:
289 for rec in recs:
290 recdict[rec['msg_id']] = rec
290 recdict[rec['msg_id']] = rec
291
291
292 latest = datetime(1984,1,1)
292 latest = datetime(1984,1,1)
293 for msg_id in hist:
293 for msg_id in hist:
294 rec = recdict[msg_id]
294 rec = recdict[msg_id]
295 newt = rec['submitted']
295 newt = rec['submitted']
296 self.assertTrue(newt >= latest)
296 self.assertTrue(newt >= latest)
297 latest = newt
297 latest = newt
298 ar = self.client[-1].apply_async(lambda : 1)
298 ar = self.client[-1].apply_async(lambda : 1)
299 ar.get()
299 ar.get()
300 time.sleep(0.25)
300 time.sleep(0.25)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302
302
303 def _wait_for_idle(self):
303 def _wait_for_idle(self):
304 """wait for an engine to become idle, according to the Hub"""
304 """wait for an engine to become idle, according to the Hub"""
305 rc = self.client
305 rc = self.client
306
306
307 # step 1. wait for all requests to be noticed
307 # step 1. wait for all requests to be noticed
308 # timeout 5s, polling every 100ms
308 # timeout 5s, polling every 100ms
309 msg_ids = set(rc.history)
309 msg_ids = set(rc.history)
310 hub_hist = rc.hub_history()
310 hub_hist = rc.hub_history()
311 for i in range(50):
311 for i in range(50):
312 if msg_ids.difference(hub_hist):
312 if msg_ids.difference(hub_hist):
313 time.sleep(0.1)
313 time.sleep(0.1)
314 hub_hist = rc.hub_history()
314 hub_hist = rc.hub_history()
315 else:
315 else:
316 break
316 break
317
317
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319
319
320 # step 2. wait for all requests to be done
320 # step 2. wait for all requests to be done
321 # timeout 5s, polling every 100ms
321 # timeout 5s, polling every 100ms
322 qs = rc.queue_status()
322 qs = rc.queue_status()
323 for i in range(50):
323 for i in range(50):
324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 time.sleep(0.1)
325 time.sleep(0.1)
326 qs = rc.queue_status()
326 qs = rc.queue_status()
327 else:
327 else:
328 break
328 break
329
329
330 # ensure Hub up to date:
330 # ensure Hub up to date:
331 self.assertEqual(qs['unassigned'], 0)
331 self.assertEqual(qs['unassigned'], 0)
332 for eid in rc.ids:
332 for eid in rc.ids:
333 self.assertEqual(qs[eid]['tasks'], 0)
333 self.assertEqual(qs[eid]['tasks'], 0)
334
334
335
335
336 def test_resubmit(self):
336 def test_resubmit(self):
337 def f():
337 def f():
338 import random
338 import random
339 return random.random()
339 return random.random()
340 v = self.client.load_balanced_view()
340 v = self.client.load_balanced_view()
341 ar = v.apply_async(f)
341 ar = v.apply_async(f)
342 r1 = ar.get(1)
342 r1 = ar.get(1)
343 # give the Hub a chance to notice:
343 # give the Hub a chance to notice:
344 self._wait_for_idle()
344 self._wait_for_idle()
345 ahr = self.client.resubmit(ar.msg_ids)
345 ahr = self.client.resubmit(ar.msg_ids)
346 r2 = ahr.get(1)
346 r2 = ahr.get(1)
347 self.assertFalse(r1 == r2)
347 self.assertFalse(r1 == r2)
348
348
349 def test_resubmit_chain(self):
349 def test_resubmit_chain(self):
350 """resubmit resubmitted tasks"""
350 """resubmit resubmitted tasks"""
351 v = self.client.load_balanced_view()
351 v = self.client.load_balanced_view()
352 ar = v.apply_async(lambda x: x, 'x'*1024)
352 ar = v.apply_async(lambda x: x, 'x'*1024)
353 ar.get()
353 ar.get()
354 self._wait_for_idle()
354 self._wait_for_idle()
355 ars = [ar]
355 ars = [ar]
356
356
357 for i in range(10):
357 for i in range(10):
358 ar = ars[-1]
358 ar = ars[-1]
359 ar2 = self.client.resubmit(ar.msg_ids)
359 ar2 = self.client.resubmit(ar.msg_ids)
360
360
361 [ ar.get() for ar in ars ]
361 [ ar.get() for ar in ars ]
362
362
363 def test_resubmit_header(self):
363 def test_resubmit_header(self):
364 """resubmit shouldn't clobber the whole header"""
364 """resubmit shouldn't clobber the whole header"""
365 def f():
365 def f():
366 import random
366 import random
367 return random.random()
367 return random.random()
368 v = self.client.load_balanced_view()
368 v = self.client.load_balanced_view()
369 v.retries = 1
369 v.retries = 1
370 ar = v.apply_async(f)
370 ar = v.apply_async(f)
371 r1 = ar.get(1)
371 r1 = ar.get(1)
372 # give the Hub a chance to notice:
372 # give the Hub a chance to notice:
373 self._wait_for_idle()
373 self._wait_for_idle()
374 ahr = self.client.resubmit(ar.msg_ids)
374 ahr = self.client.resubmit(ar.msg_ids)
375 ahr.get(1)
375 ahr.get(1)
376 time.sleep(0.5)
376 time.sleep(0.5)
377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 h1,h2 = [ r['header'] for r in records ]
378 h1,h2 = [ r['header'] for r in records ]
379 for key in set(h1.keys()).union(set(h2.keys())):
379 for key in set(h1.keys()).union(set(h2.keys())):
380 if key in ('msg_id', 'date'):
380 if key in ('msg_id', 'date'):
381 self.assertNotEqual(h1[key], h2[key])
381 self.assertNotEqual(h1[key], h2[key])
382 else:
382 else:
383 self.assertEqual(h1[key], h2[key])
383 self.assertEqual(h1[key], h2[key])
384
384
385 def test_resubmit_aborted(self):
385 def test_resubmit_aborted(self):
386 def f():
386 def f():
387 import random
387 import random
388 return random.random()
388 return random.random()
389 v = self.client.load_balanced_view()
389 v = self.client.load_balanced_view()
390 # restrict to one engine, so we can put a sleep
390 # restrict to one engine, so we can put a sleep
391 # ahead of the task, so it will get aborted
391 # ahead of the task, so it will get aborted
392 eid = self.client.ids[-1]
392 eid = self.client.ids[-1]
393 v.targets = [eid]
393 v.targets = [eid]
394 sleep = v.apply_async(time.sleep, 0.5)
394 sleep = v.apply_async(time.sleep, 0.5)
395 ar = v.apply_async(f)
395 ar = v.apply_async(f)
396 ar.abort()
396 ar.abort()
397 self.assertRaises(error.TaskAborted, ar.get)
397 self.assertRaises(error.TaskAborted, ar.get)
398 # Give the Hub a chance to get up to date:
398 # Give the Hub a chance to get up to date:
399 self._wait_for_idle()
399 self._wait_for_idle()
400 ahr = self.client.resubmit(ar.msg_ids)
400 ahr = self.client.resubmit(ar.msg_ids)
401 r2 = ahr.get(1)
401 r2 = ahr.get(1)
402
402
403 def test_resubmit_inflight(self):
403 def test_resubmit_inflight(self):
404 """resubmit of inflight task"""
404 """resubmit of inflight task"""
405 v = self.client.load_balanced_view()
405 v = self.client.load_balanced_view()
406 ar = v.apply_async(time.sleep,1)
406 ar = v.apply_async(time.sleep,1)
407 # give the message a chance to arrive
407 # give the message a chance to arrive
408 time.sleep(0.2)
408 time.sleep(0.2)
409 ahr = self.client.resubmit(ar.msg_ids)
409 ahr = self.client.resubmit(ar.msg_ids)
410 ar.get(2)
410 ar.get(2)
411 ahr.get(2)
411 ahr.get(2)
412
412
413 def test_resubmit_badkey(self):
413 def test_resubmit_badkey(self):
414 """ensure KeyError on resubmit of nonexistant task"""
414 """ensure KeyError on resubmit of nonexistant task"""
415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416
416
417 def test_purge_hub_results(self):
417 def test_purge_hub_results(self):
418 # ensure there are some tasks
418 # ensure there are some tasks
419 for i in range(5):
419 for i in range(5):
420 self.client[:].apply_sync(lambda : 1)
420 self.client[:].apply_sync(lambda : 1)
421 # Wait for the Hub to realise the result is done:
421 # Wait for the Hub to realise the result is done:
422 # This prevents a race condition, where we
422 # This prevents a race condition, where we
423 # might purge a result the Hub still thinks is pending.
423 # might purge a result the Hub still thinks is pending.
424 self._wait_for_idle()
424 self._wait_for_idle()
425 rc2 = clientmod.Client(profile='iptest')
425 rc2 = clientmod.Client(profile='iptest')
426 hist = self.client.hub_history()
426 hist = self.client.hub_history()
427 ahr = rc2.get_result([hist[-1]])
427 ahr = rc2.get_result([hist[-1]])
428 ahr.wait(10)
428 ahr.wait(10)
429 self.client.purge_hub_results(hist[-1])
429 self.client.purge_hub_results(hist[-1])
430 newhist = self.client.hub_history()
430 newhist = self.client.hub_history()
431 self.assertEqual(len(newhist)+1,len(hist))
431 self.assertEqual(len(newhist)+1,len(hist))
432 rc2.spin()
432 rc2.spin()
433 rc2.close()
433 rc2.close()
434
434
435 def test_purge_local_results(self):
435 def test_purge_local_results(self):
436 # ensure there are some tasks
436 # ensure there are some tasks
437 res = []
437 res = []
438 for i in range(5):
438 for i in range(5):
439 res.append(self.client[:].apply_async(lambda : 1))
439 res.append(self.client[:].apply_async(lambda : 1))
440 self._wait_for_idle()
440 self._wait_for_idle()
441 self.client.wait(10) # wait for the results to come back
441 self.client.wait(10) # wait for the results to come back
442 before = len(self.client.results)
442 before = len(self.client.results)
443 self.assertEqual(len(self.client.metadata),before)
443 self.assertEqual(len(self.client.metadata),before)
444 self.client.purge_local_results(res[-1])
444 self.client.purge_local_results(res[-1])
445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447
447
448 def test_purge_all_hub_results(self):
448 def test_purge_all_hub_results(self):
449 self.client.purge_hub_results('all')
449 self.client.purge_hub_results('all')
450 hist = self.client.hub_history()
450 hist = self.client.hub_history()
451 self.assertEqual(len(hist), 0)
451 self.assertEqual(len(hist), 0)
452
452
453 def test_purge_all_local_results(self):
453 def test_purge_all_local_results(self):
454 self.client.purge_local_results('all')
454 self.client.purge_local_results('all')
455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
457
457
458 def test_purge_all_results(self):
458 def test_purge_all_results(self):
459 # ensure there are some tasks
459 # ensure there are some tasks
460 for i in range(5):
460 for i in range(5):
461 self.client[:].apply_sync(lambda : 1)
461 self.client[:].apply_sync(lambda : 1)
462 self.client.wait(10)
462 self.client.wait(10)
463 self._wait_for_idle()
463 self._wait_for_idle()
464 self.client.purge_results('all')
464 self.client.purge_results('all')
465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
467 hist = self.client.hub_history()
467 hist = self.client.hub_history()
468 self.assertEqual(len(hist), 0, msg="hub history not empty")
468 self.assertEqual(len(hist), 0, msg="hub history not empty")
469
469
470 def test_purge_everything(self):
470 def test_purge_everything(self):
471 # ensure there are some tasks
471 # ensure there are some tasks
472 for i in range(5):
472 for i in range(5):
473 self.client[:].apply_sync(lambda : 1)
473 self.client[:].apply_sync(lambda : 1)
474 self.client.wait(10)
474 self.client.wait(10)
475 self._wait_for_idle()
475 self._wait_for_idle()
476 self.client.purge_everything()
476 self.client.purge_everything()
477 # The client results
477 # The client results
478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
480 # The client "bookkeeping"
480 # The client "bookkeeping"
481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
483 # the hub results
483 # the hub results
484 hist = self.client.hub_history()
484 hist = self.client.hub_history()
485 self.assertEqual(len(hist), 0, msg="hub history not empty")
485 self.assertEqual(len(hist), 0, msg="hub history not empty")
486
486
487
487
488 def test_spin_thread(self):
488 def test_spin_thread(self):
489 self.client.spin_thread(0.01)
489 self.client.spin_thread(0.01)
490 ar = self.client[-1].apply_async(lambda : 1)
490 ar = self.client[-1].apply_async(lambda : 1)
491 time.sleep(0.1)
491 time.sleep(0.1)
492 self.assertTrue(ar.wall_time < 0.1,
492 self.assertTrue(ar.wall_time < 0.1,
493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
494 )
494 )
495
495
496 def test_stop_spin_thread(self):
496 def test_stop_spin_thread(self):
497 self.client.spin_thread(0.01)
497 self.client.spin_thread(0.01)
498 self.client.stop_spin_thread()
498 self.client.stop_spin_thread()
499 ar = self.client[-1].apply_async(lambda : 1)
499 ar = self.client[-1].apply_async(lambda : 1)
500 time.sleep(0.15)
500 time.sleep(0.15)
501 self.assertTrue(ar.wall_time > 0.1,
501 self.assertTrue(ar.wall_time > 0.1,
502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
503 )
503 )
504
504
505 def test_activate(self):
505 def test_activate(self):
506 ip = get_ipython()
506 ip = get_ipython()
507 magics = ip.magics_manager.magics
507 magics = ip.magics_manager.magics
508 self.assertTrue('px' in magics['line'])
508 self.assertTrue('px' in magics['line'])
509 self.assertTrue('px' in magics['cell'])
509 self.assertTrue('px' in magics['cell'])
510 v0 = self.client.activate(-1, '0')
510 v0 = self.client.activate(-1, '0')
511 self.assertTrue('px0' in magics['line'])
511 self.assertTrue('px0' in magics['line'])
512 self.assertTrue('px0' in magics['cell'])
512 self.assertTrue('px0' in magics['cell'])
513 self.assertEqual(v0.targets, self.client.ids[-1])
513 self.assertEqual(v0.targets, self.client.ids[-1])
514 v0 = self.client.activate('all', 'all')
514 v0 = self.client.activate('all', 'all')
515 self.assertTrue('pxall' in magics['line'])
515 self.assertTrue('pxall' in magics['line'])
516 self.assertTrue('pxall' in magics['cell'])
516 self.assertTrue('pxall' in magics['cell'])
517 self.assertEqual(v0.targets, 'all')
517 self.assertEqual(v0.targets, 'all')
@@ -1,136 +1,136 b''
1 """Tests for dependency.py
1 """Tests for dependency.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 __docformat__ = "restructuredtext en"
8 __docformat__ = "restructuredtext en"
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Copyright (C) 2011 The IPython Development Team
11 # Copyright (C) 2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16
16
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-------------------------------------------------------------------------------
19 #-------------------------------------------------------------------------------
20
20
21 # import
21 # import
22 import os
22 import os
23
23
24 from IPython.utils.pickleutil import can, uncan
24 from IPython.utils.pickleutil import can, uncan
25
25
26 import IPython.parallel as pmod
26 import IPython.parallel as pmod
27 from IPython.parallel.util import interactive
27 from IPython.parallel.util import interactive
28
28
29 from IPython.parallel.tests import add_engines
29 from IPython.parallel.tests import add_engines
30 from .clienttest import ClusterTestCase
30 from .clienttest import ClusterTestCase
31
31
32 def setup():
32 def setup():
33 add_engines(1, total=True)
33 add_engines(1, total=True)
34
34
35 @pmod.require('time')
35 @pmod.require('time')
36 def wait(n):
36 def wait(n):
37 time.sleep(n)
37 time.sleep(n)
38 return n
38 return n
39
39
40 @pmod.interactive
40 @pmod.interactive
41 def func(x):
41 def func(x):
42 return x*x
42 return x*x
43
43
44 mixed = map(str, range(10))
44 mixed = list(map(str, range(10)))
45 completed = map(str, range(0,10,2))
45 completed = list(map(str, range(0,10,2)))
46 failed = map(str, range(1,10,2))
46 failed = list(map(str, range(1,10,2)))
47
47
48 class DependencyTest(ClusterTestCase):
48 class DependencyTest(ClusterTestCase):
49
49
50 def setUp(self):
50 def setUp(self):
51 ClusterTestCase.setUp(self)
51 ClusterTestCase.setUp(self)
52 self.user_ns = {'__builtins__' : __builtins__}
52 self.user_ns = {'__builtins__' : __builtins__}
53 self.view = self.client.load_balanced_view()
53 self.view = self.client.load_balanced_view()
54 self.dview = self.client[-1]
54 self.dview = self.client[-1]
55 self.succeeded = set(map(str, range(0,25,2)))
55 self.succeeded = set(map(str, range(0,25,2)))
56 self.failed = set(map(str, range(1,25,2)))
56 self.failed = set(map(str, range(1,25,2)))
57
57
58 def assertMet(self, dep):
58 def assertMet(self, dep):
59 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
59 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
60
60
61 def assertUnmet(self, dep):
61 def assertUnmet(self, dep):
62 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
62 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
63
63
64 def assertUnreachable(self, dep):
64 def assertUnreachable(self, dep):
65 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
65 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
66
66
67 def assertReachable(self, dep):
67 def assertReachable(self, dep):
68 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
68 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
69
69
70 def cancan(self, f):
70 def cancan(self, f):
71 """decorator to pass through canning into self.user_ns"""
71 """decorator to pass through canning into self.user_ns"""
72 return uncan(can(f), self.user_ns)
72 return uncan(can(f), self.user_ns)
73
73
74 def test_require_imports(self):
74 def test_require_imports(self):
75 """test that @require imports names"""
75 """test that @require imports names"""
76 @self.cancan
76 @self.cancan
77 @pmod.require('urllib')
77 @pmod.require('base64')
78 @interactive
78 @interactive
79 def encode(dikt):
79 def encode(arg):
80 return urllib.urlencode(dikt)
80 return base64.b64encode(arg)
81 # must pass through canning to properly connect namespaces
81 # must pass through canning to properly connect namespaces
82 self.assertEqual(encode(dict(a=5)), 'a=5')
82 self.assertEqual(encode(b'foo'), b'Zm9v')
83
83
84 def test_success_only(self):
84 def test_success_only(self):
85 dep = pmod.Dependency(mixed, success=True, failure=False)
85 dep = pmod.Dependency(mixed, success=True, failure=False)
86 self.assertUnmet(dep)
86 self.assertUnmet(dep)
87 self.assertUnreachable(dep)
87 self.assertUnreachable(dep)
88 dep.all=False
88 dep.all=False
89 self.assertMet(dep)
89 self.assertMet(dep)
90 self.assertReachable(dep)
90 self.assertReachable(dep)
91 dep = pmod.Dependency(completed, success=True, failure=False)
91 dep = pmod.Dependency(completed, success=True, failure=False)
92 self.assertMet(dep)
92 self.assertMet(dep)
93 self.assertReachable(dep)
93 self.assertReachable(dep)
94 dep.all=False
94 dep.all=False
95 self.assertMet(dep)
95 self.assertMet(dep)
96 self.assertReachable(dep)
96 self.assertReachable(dep)
97
97
98 def test_failure_only(self):
98 def test_failure_only(self):
99 dep = pmod.Dependency(mixed, success=False, failure=True)
99 dep = pmod.Dependency(mixed, success=False, failure=True)
100 self.assertUnmet(dep)
100 self.assertUnmet(dep)
101 self.assertUnreachable(dep)
101 self.assertUnreachable(dep)
102 dep.all=False
102 dep.all=False
103 self.assertMet(dep)
103 self.assertMet(dep)
104 self.assertReachable(dep)
104 self.assertReachable(dep)
105 dep = pmod.Dependency(completed, success=False, failure=True)
105 dep = pmod.Dependency(completed, success=False, failure=True)
106 self.assertUnmet(dep)
106 self.assertUnmet(dep)
107 self.assertUnreachable(dep)
107 self.assertUnreachable(dep)
108 dep.all=False
108 dep.all=False
109 self.assertUnmet(dep)
109 self.assertUnmet(dep)
110 self.assertUnreachable(dep)
110 self.assertUnreachable(dep)
111
111
112 def test_require_function(self):
112 def test_require_function(self):
113
113
114 @pmod.interactive
114 @pmod.interactive
115 def bar(a):
115 def bar(a):
116 return func(a)
116 return func(a)
117
117
118 @pmod.require(func)
118 @pmod.require(func)
119 @pmod.interactive
119 @pmod.interactive
120 def bar2(a):
120 def bar2(a):
121 return func(a)
121 return func(a)
122
122
123 self.client[:].clear()
123 self.client[:].clear()
124 self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
124 self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
125 ar = self.view.apply_async(bar2, 5)
125 ar = self.view.apply_async(bar2, 5)
126 self.assertEqual(ar.get(5), func(5))
126 self.assertEqual(ar.get(5), func(5))
127
127
128 def test_require_object(self):
128 def test_require_object(self):
129
129
130 @pmod.require(foo=func)
130 @pmod.require(foo=func)
131 @pmod.interactive
131 @pmod.interactive
132 def bar(a):
132 def bar(a):
133 return foo(a)
133 return foo(a)
134
134
135 ar = self.view.apply_async(bar, 5)
135 ar = self.view.apply_async(bar, 5)
136 self.assertEqual(ar.get(5), func(5))
136 self.assertEqual(ar.get(5), func(5))
@@ -1,221 +1,221 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test LoadBalancedView objects
2 """test LoadBalancedView objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from nose import SkipTest
23 from nose import SkipTest
24 from nose.plugins.attrib import attr
24 from nose.plugins.attrib import attr
25
25
26 from IPython import parallel as pmod
26 from IPython import parallel as pmod
27 from IPython.parallel import error
27 from IPython.parallel import error
28
28
29 from IPython.parallel.tests import add_engines
29 from IPython.parallel.tests import add_engines
30
30
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32
32
33 def setup():
33 def setup():
34 add_engines(3, total=True)
34 add_engines(3, total=True)
35
35
36 class TestLoadBalancedView(ClusterTestCase):
36 class TestLoadBalancedView(ClusterTestCase):
37
37
38 def setUp(self):
38 def setUp(self):
39 ClusterTestCase.setUp(self)
39 ClusterTestCase.setUp(self)
40 self.view = self.client.load_balanced_view()
40 self.view = self.client.load_balanced_view()
41
41
42 @attr('crash')
42 @attr('crash')
43 def test_z_crash_task(self):
43 def test_z_crash_task(self):
44 """test graceful handling of engine death (balanced)"""
44 """test graceful handling of engine death (balanced)"""
45 # self.add_engines(1)
45 # self.add_engines(1)
46 ar = self.view.apply_async(crash)
46 ar = self.view.apply_async(crash)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 eid = ar.engine_id
48 eid = ar.engine_id
49 tic = time.time()
49 tic = time.time()
50 while eid in self.client.ids and time.time()-tic < 5:
50 while eid in self.client.ids and time.time()-tic < 5:
51 time.sleep(.01)
51 time.sleep(.01)
52 self.client.spin()
52 self.client.spin()
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54
54
55 def test_map(self):
55 def test_map(self):
56 def f(x):
56 def f(x):
57 return x**2
57 return x**2
58 data = range(16)
58 data = list(range(16))
59 r = self.view.map_sync(f, data)
59 r = self.view.map_sync(f, data)
60 self.assertEqual(r, map(f, data))
60 self.assertEqual(r, list(map(f, data)))
61
61
62 def test_map_generator(self):
62 def test_map_generator(self):
63 def f(x):
63 def f(x):
64 return x**2
64 return x**2
65
65
66 data = range(16)
66 data = list(range(16))
67 r = self.view.map_sync(f, iter(data))
67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
68 self.assertEqual(r, list(map(f, iter(data))))
69
69
70 def test_map_short_first(self):
70 def test_map_short_first(self):
71 def f(x,y):
71 def f(x,y):
72 if y is None:
72 if y is None:
73 return y
73 return y
74 if x is None:
74 if x is None:
75 return x
75 return x
76 return x*y
76 return x*y
77 data = range(10)
77 data = list(range(10))
78 data2 = range(4)
78 data2 = list(range(4))
79
79
80 r = self.view.map_sync(f, data, data2)
80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
81 self.assertEqual(r, list(map(f, data, data2)))
82
82
83 def test_map_short_last(self):
83 def test_map_short_last(self):
84 def f(x,y):
84 def f(x,y):
85 if y is None:
85 if y is None:
86 return y
86 return y
87 if x is None:
87 if x is None:
88 return x
88 return x
89 return x*y
89 return x*y
90 data = range(4)
90 data = list(range(4))
91 data2 = range(10)
91 data2 = list(range(10))
92
92
93 r = self.view.map_sync(f, data, data2)
93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
94 self.assertEqual(r, list(map(f, data, data2)))
95
95
96 def test_map_unordered(self):
96 def test_map_unordered(self):
97 def f(x):
97 def f(x):
98 return x**2
98 return x**2
99 def slow_f(x):
99 def slow_f(x):
100 import time
100 import time
101 time.sleep(0.05*x)
101 time.sleep(0.05*x)
102 return x**2
102 return x**2
103 data = range(16,0,-1)
103 data = list(range(16,0,-1))
104 reference = map(f, data)
104 reference = list(map(f, data))
105
105
106 amr = self.view.map_async(slow_f, data, ordered=False)
106 amr = self.view.map_async(slow_f, data, ordered=False)
107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
108 # check individual elements, retrieved as they come
108 # check individual elements, retrieved as they come
109 # list comprehension uses __iter__
109 # list comprehension uses __iter__
110 astheycame = [ r for r in amr ]
110 astheycame = [ r for r in amr ]
111 # Ensure that at least one result came out of order:
111 # Ensure that at least one result came out of order:
112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
114
114
115 def test_map_ordered(self):
115 def test_map_ordered(self):
116 def f(x):
116 def f(x):
117 return x**2
117 return x**2
118 def slow_f(x):
118 def slow_f(x):
119 import time
119 import time
120 time.sleep(0.05*x)
120 time.sleep(0.05*x)
121 return x**2
121 return x**2
122 data = range(16,0,-1)
122 data = list(range(16,0,-1))
123 reference = map(f, data)
123 reference = list(map(f, data))
124
124
125 amr = self.view.map_async(slow_f, data)
125 amr = self.view.map_async(slow_f, data)
126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
127 # check individual elements, retrieved as they come
127 # check individual elements, retrieved as they come
128 # list(amr) uses __iter__
128 # list(amr) uses __iter__
129 astheycame = list(amr)
129 astheycame = list(amr)
130 # Ensure that results came in order
130 # Ensure that results came in order
131 self.assertEqual(astheycame, reference)
131 self.assertEqual(astheycame, reference)
132 self.assertEqual(amr.result, reference)
132 self.assertEqual(amr.result, reference)
133
133
134 def test_map_iterable(self):
134 def test_map_iterable(self):
135 """test map on iterables (balanced)"""
135 """test map on iterables (balanced)"""
136 view = self.view
136 view = self.view
137 # 101 is prime, so it won't be evenly distributed
137 # 101 is prime, so it won't be evenly distributed
138 arr = range(101)
138 arr = range(101)
139 # so that it will be an iterator, even in Python 3
139 # so that it will be an iterator, even in Python 3
140 it = iter(arr)
140 it = iter(arr)
141 r = view.map_sync(lambda x:x, arr)
141 r = view.map_sync(lambda x:x, arr)
142 self.assertEqual(r, list(arr))
142 self.assertEqual(r, list(arr))
143
143
144
144
145 def test_abort(self):
145 def test_abort(self):
146 view = self.view
146 view = self.view
147 ar = self.client[:].apply_async(time.sleep, .5)
147 ar = self.client[:].apply_async(time.sleep, .5)
148 ar = self.client[:].apply_async(time.sleep, .5)
148 ar = self.client[:].apply_async(time.sleep, .5)
149 time.sleep(0.2)
149 time.sleep(0.2)
150 ar2 = view.apply_async(lambda : 2)
150 ar2 = view.apply_async(lambda : 2)
151 ar3 = view.apply_async(lambda : 3)
151 ar3 = view.apply_async(lambda : 3)
152 view.abort(ar2)
152 view.abort(ar2)
153 view.abort(ar3.msg_ids)
153 view.abort(ar3.msg_ids)
154 self.assertRaises(error.TaskAborted, ar2.get)
154 self.assertRaises(error.TaskAborted, ar2.get)
155 self.assertRaises(error.TaskAborted, ar3.get)
155 self.assertRaises(error.TaskAborted, ar3.get)
156
156
157 def test_retries(self):
157 def test_retries(self):
158 self.minimum_engines(3)
158 self.minimum_engines(3)
159 view = self.view
159 view = self.view
160 def fail():
160 def fail():
161 assert False
161 assert False
162 for r in range(len(self.client)-1):
162 for r in range(len(self.client)-1):
163 with view.temp_flags(retries=r):
163 with view.temp_flags(retries=r):
164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
165
165
166 with view.temp_flags(retries=len(self.client), timeout=0.1):
166 with view.temp_flags(retries=len(self.client), timeout=0.1):
167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
168
168
169 def test_short_timeout(self):
169 def test_short_timeout(self):
170 self.minimum_engines(2)
170 self.minimum_engines(2)
171 view = self.view
171 view = self.view
172 def fail():
172 def fail():
173 import time
173 import time
174 time.sleep(0.25)
174 time.sleep(0.25)
175 assert False
175 assert False
176 with view.temp_flags(retries=1, timeout=0.01):
176 with view.temp_flags(retries=1, timeout=0.01):
177 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
177 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
178
178
179 def test_invalid_dependency(self):
179 def test_invalid_dependency(self):
180 view = self.view
180 view = self.view
181 with view.temp_flags(after='12345'):
181 with view.temp_flags(after='12345'):
182 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
182 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
183
183
184 def test_impossible_dependency(self):
184 def test_impossible_dependency(self):
185 self.minimum_engines(2)
185 self.minimum_engines(2)
186 view = self.client.load_balanced_view()
186 view = self.client.load_balanced_view()
187 ar1 = view.apply_async(lambda : 1)
187 ar1 = view.apply_async(lambda : 1)
188 ar1.get()
188 ar1.get()
189 e1 = ar1.engine_id
189 e1 = ar1.engine_id
190 e2 = e1
190 e2 = e1
191 while e2 == e1:
191 while e2 == e1:
192 ar2 = view.apply_async(lambda : 1)
192 ar2 = view.apply_async(lambda : 1)
193 ar2.get()
193 ar2.get()
194 e2 = ar2.engine_id
194 e2 = ar2.engine_id
195
195
196 with view.temp_flags(follow=[ar1, ar2]):
196 with view.temp_flags(follow=[ar1, ar2]):
197 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
197 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
198
198
199
199
200 def test_follow(self):
200 def test_follow(self):
201 ar = self.view.apply_async(lambda : 1)
201 ar = self.view.apply_async(lambda : 1)
202 ar.get()
202 ar.get()
203 ars = []
203 ars = []
204 first_id = ar.engine_id
204 first_id = ar.engine_id
205
205
206 self.view.follow = ar
206 self.view.follow = ar
207 for i in range(5):
207 for i in range(5):
208 ars.append(self.view.apply_async(lambda : 1))
208 ars.append(self.view.apply_async(lambda : 1))
209 self.view.wait(ars)
209 self.view.wait(ars)
210 for ar in ars:
210 for ar in ars:
211 self.assertEqual(ar.engine_id, first_id)
211 self.assertEqual(ar.engine_id, first_id)
212
212
213 def test_after(self):
213 def test_after(self):
214 view = self.view
214 view = self.view
215 ar = view.apply_async(time.sleep, 0.5)
215 ar = view.apply_async(time.sleep, 0.5)
216 with view.temp_flags(after=ar):
216 with view.temp_flags(after=ar):
217 ar2 = view.apply_async(lambda : 1)
217 ar2 = view.apply_async(lambda : 1)
218
218
219 ar.wait()
219 ar.wait()
220 ar2.wait()
220 ar2.wait()
221 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
221 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,835 +1,835 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import base64
19 import base64
20 import sys
20 import sys
21 import platform
21 import platform
22 import time
22 import time
23 from collections import namedtuple
23 from collections import namedtuple
24 from tempfile import mktemp
24 from tempfile import mktemp
25
25
26 import zmq
26 import zmq
27 from nose.plugins.attrib import attr
27 from nose.plugins.attrib import attr
28
28
29 from IPython.testing import decorators as dec
29 from IPython.testing import decorators as dec
30 from IPython.utils.io import capture_output
30 from IPython.utils.io import capture_output
31 from IPython.utils.py3compat import unicode_type
31 from IPython.utils.py3compat import unicode_type
32
32
33 from IPython import parallel as pmod
33 from IPython import parallel as pmod
34 from IPython.parallel import error
34 from IPython.parallel import error
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 from IPython.parallel.util import interactive
36 from IPython.parallel.util import interactive
37
37
38 from IPython.parallel.tests import add_engines
38 from IPython.parallel.tests import add_engines
39
39
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41
41
42 def setup():
42 def setup():
43 add_engines(3, total=True)
43 add_engines(3, total=True)
44
44
45 point = namedtuple("point", "x y")
45 point = namedtuple("point", "x y")
46
46
47 class TestView(ClusterTestCase):
47 class TestView(ClusterTestCase):
48
48
49 def setUp(self):
49 def setUp(self):
50 # On Win XP, wait for resource cleanup, else parallel test group fails
50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 time.sleep(2)
53 time.sleep(2)
54 super(TestView, self).setUp()
54 super(TestView, self).setUp()
55
55
56 @attr('crash')
56 @attr('crash')
57 def test_z_crash_mux(self):
57 def test_z_crash_mux(self):
58 """test graceful handling of engine death (direct)"""
58 """test graceful handling of engine death (direct)"""
59 # self.add_engines(1)
59 # self.add_engines(1)
60 eid = self.client.ids[-1]
60 eid = self.client.ids[-1]
61 ar = self.client[eid].apply_async(crash)
61 ar = self.client[eid].apply_async(crash)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 eid = ar.engine_id
63 eid = ar.engine_id
64 tic = time.time()
64 tic = time.time()
65 while eid in self.client.ids and time.time()-tic < 5:
65 while eid in self.client.ids and time.time()-tic < 5:
66 time.sleep(.01)
66 time.sleep(.01)
67 self.client.spin()
67 self.client.spin()
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69
69
70 def test_push_pull(self):
70 def test_push_pull(self):
71 """test pushing and pulling"""
71 """test pushing and pulling"""
72 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 t = self.client.ids[-1]
73 t = self.client.ids[-1]
74 v = self.client[t]
74 v = self.client[t]
75 push = v.push
75 push = v.push
76 pull = v.pull
76 pull = v.pull
77 v.block=True
77 v.block=True
78 nengines = len(self.client)
78 nengines = len(self.client)
79 push({'data':data})
79 push({'data':data})
80 d = pull('data')
80 d = pull('data')
81 self.assertEqual(d, data)
81 self.assertEqual(d, data)
82 self.client[:].push({'data':data})
82 self.client[:].push({'data':data})
83 d = self.client[:].pull('data', block=True)
83 d = self.client[:].pull('data', block=True)
84 self.assertEqual(d, nengines*[data])
84 self.assertEqual(d, nengines*[data])
85 ar = push({'data':data}, block=False)
85 ar = push({'data':data}, block=False)
86 self.assertTrue(isinstance(ar, AsyncResult))
86 self.assertTrue(isinstance(ar, AsyncResult))
87 r = ar.get()
87 r = ar.get()
88 ar = self.client[:].pull('data', block=False)
88 ar = self.client[:].pull('data', block=False)
89 self.assertTrue(isinstance(ar, AsyncResult))
89 self.assertTrue(isinstance(ar, AsyncResult))
90 r = ar.get()
90 r = ar.get()
91 self.assertEqual(r, nengines*[data])
91 self.assertEqual(r, nengines*[data])
92 self.client[:].push(dict(a=10,b=20))
92 self.client[:].push(dict(a=10,b=20))
93 r = self.client[:].pull(('a','b'), block=True)
93 r = self.client[:].pull(('a','b'), block=True)
94 self.assertEqual(r, nengines*[[10,20]])
94 self.assertEqual(r, nengines*[[10,20]])
95
95
96 def test_push_pull_function(self):
96 def test_push_pull_function(self):
97 "test pushing and pulling functions"
97 "test pushing and pulling functions"
98 def testf(x):
98 def testf(x):
99 return 2.0*x
99 return 2.0*x
100
100
101 t = self.client.ids[-1]
101 t = self.client.ids[-1]
102 v = self.client[t]
102 v = self.client[t]
103 v.block=True
103 v.block=True
104 push = v.push
104 push = v.push
105 pull = v.pull
105 pull = v.pull
106 execute = v.execute
106 execute = v.execute
107 push({'testf':testf})
107 push({'testf':testf})
108 r = pull('testf')
108 r = pull('testf')
109 self.assertEqual(r(1.0), testf(1.0))
109 self.assertEqual(r(1.0), testf(1.0))
110 execute('r = testf(10)')
110 execute('r = testf(10)')
111 r = pull('r')
111 r = pull('r')
112 self.assertEqual(r, testf(10))
112 self.assertEqual(r, testf(10))
113 ar = self.client[:].push({'testf':testf}, block=False)
113 ar = self.client[:].push({'testf':testf}, block=False)
114 ar.get()
114 ar.get()
115 ar = self.client[:].pull('testf', block=False)
115 ar = self.client[:].pull('testf', block=False)
116 rlist = ar.get()
116 rlist = ar.get()
117 for r in rlist:
117 for r in rlist:
118 self.assertEqual(r(1.0), testf(1.0))
118 self.assertEqual(r(1.0), testf(1.0))
119 execute("def g(x): return x*x")
119 execute("def g(x): return x*x")
120 r = pull(('testf','g'))
120 r = pull(('testf','g'))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122
122
123 def test_push_function_globals(self):
123 def test_push_function_globals(self):
124 """test that pushed functions have access to globals"""
124 """test that pushed functions have access to globals"""
125 @interactive
125 @interactive
126 def geta():
126 def geta():
127 return a
127 return a
128 # self.add_engines(1)
128 # self.add_engines(1)
129 v = self.client[-1]
129 v = self.client[-1]
130 v.block=True
130 v.block=True
131 v['f'] = geta
131 v['f'] = geta
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 v.execute('a=5')
133 v.execute('a=5')
134 v.execute('b=f()')
134 v.execute('b=f()')
135 self.assertEqual(v['b'], 5)
135 self.assertEqual(v['b'], 5)
136
136
137 def test_push_function_defaults(self):
137 def test_push_function_defaults(self):
138 """test that pushed functions preserve default args"""
138 """test that pushed functions preserve default args"""
139 def echo(a=10):
139 def echo(a=10):
140 return a
140 return a
141 v = self.client[-1]
141 v = self.client[-1]
142 v.block=True
142 v.block=True
143 v['f'] = echo
143 v['f'] = echo
144 v.execute('b=f()')
144 v.execute('b=f()')
145 self.assertEqual(v['b'], 10)
145 self.assertEqual(v['b'], 10)
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = pmod.Client(profile='iptest')
149 c = pmod.Client(profile='iptest')
150 # self.add_engines(1)
150 # self.add_engines(1)
151 t = c.ids[-1]
151 t = c.ids[-1]
152 v = c[t]
152 v = c[t]
153 v2 = self.client[t]
153 v2 = self.client[t]
154 ar = v.apply_async(wait, 1)
154 ar = v.apply_async(wait, 1)
155 # give the monitor time to notice the message
155 # give the monitor time to notice the message
156 time.sleep(.25)
156 time.sleep(.25)
157 ahr = v2.get_result(ar.msg_ids[0])
157 ahr = v2.get_result(ar.msg_ids[0])
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 self.assertEqual(ahr.get(), ar.get())
159 self.assertEqual(ahr.get(), ar.get())
160 ar2 = v2.get_result(ar.msg_ids[0])
160 ar2 = v2.get_result(ar.msg_ids[0])
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 c.spin()
162 c.spin()
163 c.close()
163 c.close()
164
164
165 def test_run_newline(self):
165 def test_run_newline(self):
166 """test that run appends newline to files"""
166 """test that run appends newline to files"""
167 tmpfile = mktemp()
167 tmpfile = mktemp()
168 with open(tmpfile, 'w') as f:
168 with open(tmpfile, 'w') as f:
169 f.write("""def g():
169 f.write("""def g():
170 return 5
170 return 5
171 """)
171 """)
172 v = self.client[-1]
172 v = self.client[-1]
173 v.run(tmpfile, block=True)
173 v.run(tmpfile, block=True)
174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175
175
176 def test_apply_tracked(self):
176 def test_apply_tracked(self):
177 """test tracking for apply"""
177 """test tracking for apply"""
178 # self.add_engines(1)
178 # self.add_engines(1)
179 t = self.client.ids[-1]
179 t = self.client.ids[-1]
180 v = self.client[t]
180 v = self.client[t]
181 v.block=False
181 v.block=False
182 def echo(n=1024*1024, **kwargs):
182 def echo(n=1024*1024, **kwargs):
183 with v.temp_flags(**kwargs):
183 with v.temp_flags(**kwargs):
184 return v.apply(lambda x: x, 'x'*n)
184 return v.apply(lambda x: x, 'x'*n)
185 ar = echo(1, track=False)
185 ar = echo(1, track=False)
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 self.assertTrue(ar.sent)
187 self.assertTrue(ar.sent)
188 ar = echo(track=True)
188 ar = echo(track=True)
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 self.assertEqual(ar.sent, ar._tracker.done)
190 self.assertEqual(ar.sent, ar._tracker.done)
191 ar._tracker.wait()
191 ar._tracker.wait()
192 self.assertTrue(ar.sent)
192 self.assertTrue(ar.sent)
193
193
194 def test_push_tracked(self):
194 def test_push_tracked(self):
195 t = self.client.ids[-1]
195 t = self.client.ids[-1]
196 ns = dict(x='x'*1024*1024)
196 ns = dict(x='x'*1024*1024)
197 v = self.client[t]
197 v = self.client[t]
198 ar = v.push(ns, block=False, track=False)
198 ar = v.push(ns, block=False, track=False)
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(ar.sent)
200 self.assertTrue(ar.sent)
201
201
202 ar = v.push(ns, block=False, track=True)
202 ar = v.push(ns, block=False, track=True)
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 ar._tracker.wait()
204 ar._tracker.wait()
205 self.assertEqual(ar.sent, ar._tracker.done)
205 self.assertEqual(ar.sent, ar._tracker.done)
206 self.assertTrue(ar.sent)
206 self.assertTrue(ar.sent)
207 ar.get()
207 ar.get()
208
208
209 def test_scatter_tracked(self):
209 def test_scatter_tracked(self):
210 t = self.client.ids
210 t = self.client.ids
211 x='x'*1024*1024
211 x='x'*1024*1024
212 ar = self.client[t].scatter('x', x, block=False, track=False)
212 ar = self.client[t].scatter('x', x, block=False, track=False)
213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 self.assertTrue(ar.sent)
214 self.assertTrue(ar.sent)
215
215
216 ar = self.client[t].scatter('x', x, block=False, track=True)
216 ar = self.client[t].scatter('x', x, block=False, track=True)
217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 self.assertEqual(ar.sent, ar._tracker.done)
218 self.assertEqual(ar.sent, ar._tracker.done)
219 ar._tracker.wait()
219 ar._tracker.wait()
220 self.assertTrue(ar.sent)
220 self.assertTrue(ar.sent)
221 ar.get()
221 ar.get()
222
222
223 def test_remote_reference(self):
223 def test_remote_reference(self):
224 v = self.client[-1]
224 v = self.client[-1]
225 v['a'] = 123
225 v['a'] = 123
226 ra = pmod.Reference('a')
226 ra = pmod.Reference('a')
227 b = v.apply_sync(lambda x: x, ra)
227 b = v.apply_sync(lambda x: x, ra)
228 self.assertEqual(b, 123)
228 self.assertEqual(b, 123)
229
229
230
230
231 def test_scatter_gather(self):
231 def test_scatter_gather(self):
232 view = self.client[:]
232 view = self.client[:]
233 seq1 = range(16)
233 seq1 = list(range(16))
234 view.scatter('a', seq1)
234 view.scatter('a', seq1)
235 seq2 = view.gather('a', block=True)
235 seq2 = view.gather('a', block=True)
236 self.assertEqual(seq2, seq1)
236 self.assertEqual(seq2, seq1)
237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238
238
239 @skip_without('numpy')
239 @skip_without('numpy')
240 def test_scatter_gather_numpy(self):
240 def test_scatter_gather_numpy(self):
241 import numpy
241 import numpy
242 from numpy.testing.utils import assert_array_equal
242 from numpy.testing.utils import assert_array_equal
243 view = self.client[:]
243 view = self.client[:]
244 a = numpy.arange(64)
244 a = numpy.arange(64)
245 view.scatter('a', a, block=True)
245 view.scatter('a', a, block=True)
246 b = view.gather('a', block=True)
246 b = view.gather('a', block=True)
247 assert_array_equal(b, a)
247 assert_array_equal(b, a)
248
248
249 def test_scatter_gather_lazy(self):
249 def test_scatter_gather_lazy(self):
250 """scatter/gather with targets='all'"""
250 """scatter/gather with targets='all'"""
251 view = self.client.direct_view(targets='all')
251 view = self.client.direct_view(targets='all')
252 x = range(64)
252 x = list(range(64))
253 view.scatter('x', x)
253 view.scatter('x', x)
254 gathered = view.gather('x', block=True)
254 gathered = view.gather('x', block=True)
255 self.assertEqual(gathered, x)
255 self.assertEqual(gathered, x)
256
256
257
257
258 @dec.known_failure_py3
258 @dec.known_failure_py3
259 @skip_without('numpy')
259 @skip_without('numpy')
260 def test_push_numpy_nocopy(self):
260 def test_push_numpy_nocopy(self):
261 import numpy
261 import numpy
262 view = self.client[:]
262 view = self.client[:]
263 a = numpy.arange(64)
263 a = numpy.arange(64)
264 view['A'] = a
264 view['A'] = a
265 @interactive
265 @interactive
266 def check_writeable(x):
266 def check_writeable(x):
267 return x.flags.writeable
267 return x.flags.writeable
268
268
269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271
271
272 view.push(dict(B=a))
272 view.push(dict(B=a))
273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275
275
276 @skip_without('numpy')
276 @skip_without('numpy')
277 def test_apply_numpy(self):
277 def test_apply_numpy(self):
278 """view.apply(f, ndarray)"""
278 """view.apply(f, ndarray)"""
279 import numpy
279 import numpy
280 from numpy.testing.utils import assert_array_equal
280 from numpy.testing.utils import assert_array_equal
281
281
282 A = numpy.random.random((100,100))
282 A = numpy.random.random((100,100))
283 view = self.client[-1]
283 view = self.client[-1]
284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 B = A.astype(dt)
285 B = A.astype(dt)
286 C = view.apply_sync(lambda x:x, B)
286 C = view.apply_sync(lambda x:x, B)
287 assert_array_equal(B,C)
287 assert_array_equal(B,C)
288
288
289 @skip_without('numpy')
289 @skip_without('numpy')
290 def test_push_pull_recarray(self):
290 def test_push_pull_recarray(self):
291 """push/pull recarrays"""
291 """push/pull recarrays"""
292 import numpy
292 import numpy
293 from numpy.testing.utils import assert_array_equal
293 from numpy.testing.utils import assert_array_equal
294
294
295 view = self.client[-1]
295 view = self.client[-1]
296
296
297 R = numpy.array([
297 R = numpy.array([
298 (1, 'hi', 0.),
298 (1, 'hi', 0.),
299 (2**30, 'there', 2.5),
299 (2**30, 'there', 2.5),
300 (-99999, 'world', -12345.6789),
300 (-99999, 'world', -12345.6789),
301 ], [('n', int), ('s', '|S10'), ('f', float)])
301 ], [('n', int), ('s', '|S10'), ('f', float)])
302
302
303 view['RR'] = R
303 view['RR'] = R
304 R2 = view['RR']
304 R2 = view['RR']
305
305
306 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
306 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
307 self.assertEqual(r_dtype, R.dtype)
307 self.assertEqual(r_dtype, R.dtype)
308 self.assertEqual(r_shape, R.shape)
308 self.assertEqual(r_shape, R.shape)
309 self.assertEqual(R2.dtype, R.dtype)
309 self.assertEqual(R2.dtype, R.dtype)
310 self.assertEqual(R2.shape, R.shape)
310 self.assertEqual(R2.shape, R.shape)
311 assert_array_equal(R2, R)
311 assert_array_equal(R2, R)
312
312
313 @skip_without('pandas')
313 @skip_without('pandas')
314 def test_push_pull_timeseries(self):
314 def test_push_pull_timeseries(self):
315 """push/pull pandas.TimeSeries"""
315 """push/pull pandas.TimeSeries"""
316 import pandas
316 import pandas
317
317
318 ts = pandas.TimeSeries(range(10))
318 ts = pandas.TimeSeries(list(range(10)))
319
319
320 view = self.client[-1]
320 view = self.client[-1]
321
321
322 view.push(dict(ts=ts), block=True)
322 view.push(dict(ts=ts), block=True)
323 rts = view['ts']
323 rts = view['ts']
324
324
325 self.assertEqual(type(rts), type(ts))
325 self.assertEqual(type(rts), type(ts))
326 self.assertTrue((ts == rts).all())
326 self.assertTrue((ts == rts).all())
327
327
328 def test_map(self):
328 def test_map(self):
329 view = self.client[:]
329 view = self.client[:]
330 def f(x):
330 def f(x):
331 return x**2
331 return x**2
332 data = range(16)
332 data = list(range(16))
333 r = view.map_sync(f, data)
333 r = view.map_sync(f, data)
334 self.assertEqual(r, map(f, data))
334 self.assertEqual(r, list(map(f, data)))
335
335
336 def test_map_iterable(self):
336 def test_map_iterable(self):
337 """test map on iterables (direct)"""
337 """test map on iterables (direct)"""
338 view = self.client[:]
338 view = self.client[:]
339 # 101 is prime, so it won't be evenly distributed
339 # 101 is prime, so it won't be evenly distributed
340 arr = range(101)
340 arr = range(101)
341 # ensure it will be an iterator, even in Python 3
341 # ensure it will be an iterator, even in Python 3
342 it = iter(arr)
342 it = iter(arr)
343 r = view.map_sync(lambda x: x, it)
343 r = view.map_sync(lambda x: x, it)
344 self.assertEqual(r, list(arr))
344 self.assertEqual(r, list(arr))
345
345
346 @skip_without('numpy')
346 @skip_without('numpy')
347 def test_map_numpy(self):
347 def test_map_numpy(self):
348 """test map on numpy arrays (direct)"""
348 """test map on numpy arrays (direct)"""
349 import numpy
349 import numpy
350 from numpy.testing.utils import assert_array_equal
350 from numpy.testing.utils import assert_array_equal
351
351
352 view = self.client[:]
352 view = self.client[:]
353 # 101 is prime, so it won't be evenly distributed
353 # 101 is prime, so it won't be evenly distributed
354 arr = numpy.arange(101)
354 arr = numpy.arange(101)
355 r = view.map_sync(lambda x: x, arr)
355 r = view.map_sync(lambda x: x, arr)
356 assert_array_equal(r, arr)
356 assert_array_equal(r, arr)
357
357
358 def test_scatter_gather_nonblocking(self):
358 def test_scatter_gather_nonblocking(self):
359 data = range(16)
359 data = list(range(16))
360 view = self.client[:]
360 view = self.client[:]
361 view.scatter('a', data, block=False)
361 view.scatter('a', data, block=False)
362 ar = view.gather('a', block=False)
362 ar = view.gather('a', block=False)
363 self.assertEqual(ar.get(), data)
363 self.assertEqual(ar.get(), data)
364
364
365 @skip_without('numpy')
365 @skip_without('numpy')
366 def test_scatter_gather_numpy_nonblocking(self):
366 def test_scatter_gather_numpy_nonblocking(self):
367 import numpy
367 import numpy
368 from numpy.testing.utils import assert_array_equal
368 from numpy.testing.utils import assert_array_equal
369 a = numpy.arange(64)
369 a = numpy.arange(64)
370 view = self.client[:]
370 view = self.client[:]
371 ar = view.scatter('a', a, block=False)
371 ar = view.scatter('a', a, block=False)
372 self.assertTrue(isinstance(ar, AsyncResult))
372 self.assertTrue(isinstance(ar, AsyncResult))
373 amr = view.gather('a', block=False)
373 amr = view.gather('a', block=False)
374 self.assertTrue(isinstance(amr, AsyncMapResult))
374 self.assertTrue(isinstance(amr, AsyncMapResult))
375 assert_array_equal(amr.get(), a)
375 assert_array_equal(amr.get(), a)
376
376
377 def test_execute(self):
377 def test_execute(self):
378 view = self.client[:]
378 view = self.client[:]
379 # self.client.debug=True
379 # self.client.debug=True
380 execute = view.execute
380 execute = view.execute
381 ar = execute('c=30', block=False)
381 ar = execute('c=30', block=False)
382 self.assertTrue(isinstance(ar, AsyncResult))
382 self.assertTrue(isinstance(ar, AsyncResult))
383 ar = execute('d=[0,1,2]', block=False)
383 ar = execute('d=[0,1,2]', block=False)
384 self.client.wait(ar, 1)
384 self.client.wait(ar, 1)
385 self.assertEqual(len(ar.get()), len(self.client))
385 self.assertEqual(len(ar.get()), len(self.client))
386 for c in view['c']:
386 for c in view['c']:
387 self.assertEqual(c, 30)
387 self.assertEqual(c, 30)
388
388
389 def test_abort(self):
389 def test_abort(self):
390 view = self.client[-1]
390 view = self.client[-1]
391 ar = view.execute('import time; time.sleep(1)', block=False)
391 ar = view.execute('import time; time.sleep(1)', block=False)
392 ar2 = view.apply_async(lambda : 2)
392 ar2 = view.apply_async(lambda : 2)
393 ar3 = view.apply_async(lambda : 3)
393 ar3 = view.apply_async(lambda : 3)
394 view.abort(ar2)
394 view.abort(ar2)
395 view.abort(ar3.msg_ids)
395 view.abort(ar3.msg_ids)
396 self.assertRaises(error.TaskAborted, ar2.get)
396 self.assertRaises(error.TaskAborted, ar2.get)
397 self.assertRaises(error.TaskAborted, ar3.get)
397 self.assertRaises(error.TaskAborted, ar3.get)
398
398
399 def test_abort_all(self):
399 def test_abort_all(self):
400 """view.abort() aborts all outstanding tasks"""
400 """view.abort() aborts all outstanding tasks"""
401 view = self.client[-1]
401 view = self.client[-1]
402 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
402 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
403 view.abort()
403 view.abort()
404 view.wait(timeout=5)
404 view.wait(timeout=5)
405 for ar in ars[5:]:
405 for ar in ars[5:]:
406 self.assertRaises(error.TaskAborted, ar.get)
406 self.assertRaises(error.TaskAborted, ar.get)
407
407
408 def test_temp_flags(self):
408 def test_temp_flags(self):
409 view = self.client[-1]
409 view = self.client[-1]
410 view.block=True
410 view.block=True
411 with view.temp_flags(block=False):
411 with view.temp_flags(block=False):
412 self.assertFalse(view.block)
412 self.assertFalse(view.block)
413 self.assertTrue(view.block)
413 self.assertTrue(view.block)
414
414
415 @dec.known_failure_py3
415 @dec.known_failure_py3
416 def test_importer(self):
416 def test_importer(self):
417 view = self.client[-1]
417 view = self.client[-1]
418 view.clear(block=True)
418 view.clear(block=True)
419 with view.importer:
419 with view.importer:
420 import re
420 import re
421
421
422 @interactive
422 @interactive
423 def findall(pat, s):
423 def findall(pat, s):
424 # this globals() step isn't necessary in real code
424 # this globals() step isn't necessary in real code
425 # only to prevent a closure in the test
425 # only to prevent a closure in the test
426 re = globals()['re']
426 re = globals()['re']
427 return re.findall(pat, s)
427 return re.findall(pat, s)
428
428
429 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
429 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
430
430
431 def test_unicode_execute(self):
431 def test_unicode_execute(self):
432 """test executing unicode strings"""
432 """test executing unicode strings"""
433 v = self.client[-1]
433 v = self.client[-1]
434 v.block=True
434 v.block=True
435 if sys.version_info[0] >= 3:
435 if sys.version_info[0] >= 3:
436 code="a='é'"
436 code="a='é'"
437 else:
437 else:
438 code=u"a=u'é'"
438 code=u"a=u'é'"
439 v.execute(code)
439 v.execute(code)
440 self.assertEqual(v['a'], u'é')
440 self.assertEqual(v['a'], u'é')
441
441
442 def test_unicode_apply_result(self):
442 def test_unicode_apply_result(self):
443 """test unicode apply results"""
443 """test unicode apply results"""
444 v = self.client[-1]
444 v = self.client[-1]
445 r = v.apply_sync(lambda : u'é')
445 r = v.apply_sync(lambda : u'é')
446 self.assertEqual(r, u'é')
446 self.assertEqual(r, u'é')
447
447
448 def test_unicode_apply_arg(self):
448 def test_unicode_apply_arg(self):
449 """test passing unicode arguments to apply"""
449 """test passing unicode arguments to apply"""
450 v = self.client[-1]
450 v = self.client[-1]
451
451
452 @interactive
452 @interactive
453 def check_unicode(a, check):
453 def check_unicode(a, check):
454 assert isinstance(a, unicode_type), "%r is not unicode"%a
454 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
455 assert isinstance(check, bytes), "%r is not bytes"%check
455 assert isinstance(check, bytes), "%r is not bytes"%check
456 assert a.encode('utf8') == check, "%s != %s"%(a,check)
456 assert a.encode('utf8') == check, "%s != %s"%(a,check)
457
457
458 for s in [ u'é', u'ßø®∫',u'asdf' ]:
458 for s in [ u'é', u'ßø®∫',u'asdf' ]:
459 try:
459 try:
460 v.apply_sync(check_unicode, s, s.encode('utf8'))
460 v.apply_sync(check_unicode, s, s.encode('utf8'))
461 except error.RemoteError as e:
461 except error.RemoteError as e:
462 if e.ename == 'AssertionError':
462 if e.ename == 'AssertionError':
463 self.fail(e.evalue)
463 self.fail(e.evalue)
464 else:
464 else:
465 raise e
465 raise e
466
466
467 def test_map_reference(self):
467 def test_map_reference(self):
468 """view.map(<Reference>, *seqs) should work"""
468 """view.map(<Reference>, *seqs) should work"""
469 v = self.client[:]
469 v = self.client[:]
470 v.scatter('n', self.client.ids, flatten=True)
470 v.scatter('n', self.client.ids, flatten=True)
471 v.execute("f = lambda x,y: x*y")
471 v.execute("f = lambda x,y: x*y")
472 rf = pmod.Reference('f')
472 rf = pmod.Reference('f')
473 nlist = list(range(10))
473 nlist = list(range(10))
474 mlist = nlist[::-1]
474 mlist = nlist[::-1]
475 expected = [ m*n for m,n in zip(mlist, nlist) ]
475 expected = [ m*n for m,n in zip(mlist, nlist) ]
476 result = v.map_sync(rf, mlist, nlist)
476 result = v.map_sync(rf, mlist, nlist)
477 self.assertEqual(result, expected)
477 self.assertEqual(result, expected)
478
478
479 def test_apply_reference(self):
479 def test_apply_reference(self):
480 """view.apply(<Reference>, *args) should work"""
480 """view.apply(<Reference>, *args) should work"""
481 v = self.client[:]
481 v = self.client[:]
482 v.scatter('n', self.client.ids, flatten=True)
482 v.scatter('n', self.client.ids, flatten=True)
483 v.execute("f = lambda x: n*x")
483 v.execute("f = lambda x: n*x")
484 rf = pmod.Reference('f')
484 rf = pmod.Reference('f')
485 result = v.apply_sync(rf, 5)
485 result = v.apply_sync(rf, 5)
486 expected = [ 5*id for id in self.client.ids ]
486 expected = [ 5*id for id in self.client.ids ]
487 self.assertEqual(result, expected)
487 self.assertEqual(result, expected)
488
488
489 def test_eval_reference(self):
489 def test_eval_reference(self):
490 v = self.client[self.client.ids[0]]
490 v = self.client[self.client.ids[0]]
491 v['g'] = range(5)
491 v['g'] = list(range(5))
492 rg = pmod.Reference('g[0]')
492 rg = pmod.Reference('g[0]')
493 echo = lambda x:x
493 echo = lambda x:x
494 self.assertEqual(v.apply_sync(echo, rg), 0)
494 self.assertEqual(v.apply_sync(echo, rg), 0)
495
495
496 def test_reference_nameerror(self):
496 def test_reference_nameerror(self):
497 v = self.client[self.client.ids[0]]
497 v = self.client[self.client.ids[0]]
498 r = pmod.Reference('elvis_has_left')
498 r = pmod.Reference('elvis_has_left')
499 echo = lambda x:x
499 echo = lambda x:x
500 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
500 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
501
501
502 def test_single_engine_map(self):
502 def test_single_engine_map(self):
503 e0 = self.client[self.client.ids[0]]
503 e0 = self.client[self.client.ids[0]]
504 r = range(5)
504 r = list(range(5))
505 check = [ -1*i for i in r ]
505 check = [ -1*i for i in r ]
506 result = e0.map_sync(lambda x: -1*x, r)
506 result = e0.map_sync(lambda x: -1*x, r)
507 self.assertEqual(result, check)
507 self.assertEqual(result, check)
508
508
509 def test_len(self):
509 def test_len(self):
510 """len(view) makes sense"""
510 """len(view) makes sense"""
511 e0 = self.client[self.client.ids[0]]
511 e0 = self.client[self.client.ids[0]]
512 self.assertEqual(len(e0), 1)
512 self.assertEqual(len(e0), 1)
513 v = self.client[:]
513 v = self.client[:]
514 self.assertEqual(len(v), len(self.client.ids))
514 self.assertEqual(len(v), len(self.client.ids))
515 v = self.client.direct_view('all')
515 v = self.client.direct_view('all')
516 self.assertEqual(len(v), len(self.client.ids))
516 self.assertEqual(len(v), len(self.client.ids))
517 v = self.client[:2]
517 v = self.client[:2]
518 self.assertEqual(len(v), 2)
518 self.assertEqual(len(v), 2)
519 v = self.client[:1]
519 v = self.client[:1]
520 self.assertEqual(len(v), 1)
520 self.assertEqual(len(v), 1)
521 v = self.client.load_balanced_view()
521 v = self.client.load_balanced_view()
522 self.assertEqual(len(v), len(self.client.ids))
522 self.assertEqual(len(v), len(self.client.ids))
523
523
524
524
525 # begin execute tests
525 # begin execute tests
526
526
527 def test_execute_reply(self):
527 def test_execute_reply(self):
528 e0 = self.client[self.client.ids[0]]
528 e0 = self.client[self.client.ids[0]]
529 e0.block = True
529 e0.block = True
530 ar = e0.execute("5", silent=False)
530 ar = e0.execute("5", silent=False)
531 er = ar.get()
531 er = ar.get()
532 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
532 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
533 self.assertEqual(er.pyout['data']['text/plain'], '5')
533 self.assertEqual(er.pyout['data']['text/plain'], '5')
534
534
535 def test_execute_reply_rich(self):
535 def test_execute_reply_rich(self):
536 e0 = self.client[self.client.ids[0]]
536 e0 = self.client[self.client.ids[0]]
537 e0.block = True
537 e0.block = True
538 e0.execute("from IPython.display import Image, HTML")
538 e0.execute("from IPython.display import Image, HTML")
539 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
539 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
540 er = ar.get()
540 er = ar.get()
541 b64data = base64.encodestring(b'garbage').decode('ascii')
541 b64data = base64.encodestring(b'garbage').decode('ascii')
542 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
542 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
543 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
543 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
544 er = ar.get()
544 er = ar.get()
545 self.assertEqual(er._repr_html_(), "<b>bold</b>")
545 self.assertEqual(er._repr_html_(), "<b>bold</b>")
546
546
547 def test_execute_reply_stdout(self):
547 def test_execute_reply_stdout(self):
548 e0 = self.client[self.client.ids[0]]
548 e0 = self.client[self.client.ids[0]]
549 e0.block = True
549 e0.block = True
550 ar = e0.execute("print (5)", silent=False)
550 ar = e0.execute("print (5)", silent=False)
551 er = ar.get()
551 er = ar.get()
552 self.assertEqual(er.stdout.strip(), '5')
552 self.assertEqual(er.stdout.strip(), '5')
553
553
554 def test_execute_pyout(self):
554 def test_execute_pyout(self):
555 """execute triggers pyout with silent=False"""
555 """execute triggers pyout with silent=False"""
556 view = self.client[:]
556 view = self.client[:]
557 ar = view.execute("5", silent=False, block=True)
557 ar = view.execute("5", silent=False, block=True)
558
558
559 expected = [{'text/plain' : '5'}] * len(view)
559 expected = [{'text/plain' : '5'}] * len(view)
560 mimes = [ out['data'] for out in ar.pyout ]
560 mimes = [ out['data'] for out in ar.pyout ]
561 self.assertEqual(mimes, expected)
561 self.assertEqual(mimes, expected)
562
562
563 def test_execute_silent(self):
563 def test_execute_silent(self):
564 """execute does not trigger pyout with silent=True"""
564 """execute does not trigger pyout with silent=True"""
565 view = self.client[:]
565 view = self.client[:]
566 ar = view.execute("5", block=True)
566 ar = view.execute("5", block=True)
567 expected = [None] * len(view)
567 expected = [None] * len(view)
568 self.assertEqual(ar.pyout, expected)
568 self.assertEqual(ar.pyout, expected)
569
569
570 def test_execute_magic(self):
570 def test_execute_magic(self):
571 """execute accepts IPython commands"""
571 """execute accepts IPython commands"""
572 view = self.client[:]
572 view = self.client[:]
573 view.execute("a = 5")
573 view.execute("a = 5")
574 ar = view.execute("%whos", block=True)
574 ar = view.execute("%whos", block=True)
575 # this will raise, if that failed
575 # this will raise, if that failed
576 ar.get(5)
576 ar.get(5)
577 for stdout in ar.stdout:
577 for stdout in ar.stdout:
578 lines = stdout.splitlines()
578 lines = stdout.splitlines()
579 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
579 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
580 found = False
580 found = False
581 for line in lines[2:]:
581 for line in lines[2:]:
582 split = line.split()
582 split = line.split()
583 if split == ['a', 'int', '5']:
583 if split == ['a', 'int', '5']:
584 found = True
584 found = True
585 break
585 break
586 self.assertTrue(found, "whos output wrong: %s" % stdout)
586 self.assertTrue(found, "whos output wrong: %s" % stdout)
587
587
588 def test_execute_displaypub(self):
588 def test_execute_displaypub(self):
589 """execute tracks display_pub output"""
589 """execute tracks display_pub output"""
590 view = self.client[:]
590 view = self.client[:]
591 view.execute("from IPython.core.display import *")
591 view.execute("from IPython.core.display import *")
592 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
592 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
593
593
594 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
594 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
595 for outputs in ar.outputs:
595 for outputs in ar.outputs:
596 mimes = [ out['data'] for out in outputs ]
596 mimes = [ out['data'] for out in outputs ]
597 self.assertEqual(mimes, expected)
597 self.assertEqual(mimes, expected)
598
598
599 def test_apply_displaypub(self):
599 def test_apply_displaypub(self):
600 """apply tracks display_pub output"""
600 """apply tracks display_pub output"""
601 view = self.client[:]
601 view = self.client[:]
602 view.execute("from IPython.core.display import *")
602 view.execute("from IPython.core.display import *")
603
603
604 @interactive
604 @interactive
605 def publish():
605 def publish():
606 [ display(i) for i in range(5) ]
606 [ display(i) for i in range(5) ]
607
607
608 ar = view.apply_async(publish)
608 ar = view.apply_async(publish)
609 ar.get(5)
609 ar.get(5)
610 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
610 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
611 for outputs in ar.outputs:
611 for outputs in ar.outputs:
612 mimes = [ out['data'] for out in outputs ]
612 mimes = [ out['data'] for out in outputs ]
613 self.assertEqual(mimes, expected)
613 self.assertEqual(mimes, expected)
614
614
615 def test_execute_raises(self):
615 def test_execute_raises(self):
616 """exceptions in execute requests raise appropriately"""
616 """exceptions in execute requests raise appropriately"""
617 view = self.client[-1]
617 view = self.client[-1]
618 ar = view.execute("1/0")
618 ar = view.execute("1/0")
619 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
619 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
620
620
621 def test_remoteerror_render_exception(self):
621 def test_remoteerror_render_exception(self):
622 """RemoteErrors get nice tracebacks"""
622 """RemoteErrors get nice tracebacks"""
623 view = self.client[-1]
623 view = self.client[-1]
624 ar = view.execute("1/0")
624 ar = view.execute("1/0")
625 ip = get_ipython()
625 ip = get_ipython()
626 ip.user_ns['ar'] = ar
626 ip.user_ns['ar'] = ar
627 with capture_output() as io:
627 with capture_output() as io:
628 ip.run_cell("ar.get(2)")
628 ip.run_cell("ar.get(2)")
629
629
630 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
630 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
631
631
632 def test_compositeerror_render_exception(self):
632 def test_compositeerror_render_exception(self):
633 """CompositeErrors get nice tracebacks"""
633 """CompositeErrors get nice tracebacks"""
634 view = self.client[:]
634 view = self.client[:]
635 ar = view.execute("1/0")
635 ar = view.execute("1/0")
636 ip = get_ipython()
636 ip = get_ipython()
637 ip.user_ns['ar'] = ar
637 ip.user_ns['ar'] = ar
638
638
639 with capture_output() as io:
639 with capture_output() as io:
640 ip.run_cell("ar.get(2)")
640 ip.run_cell("ar.get(2)")
641
641
642 count = min(error.CompositeError.tb_limit, len(view))
642 count = min(error.CompositeError.tb_limit, len(view))
643
643
644 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
644 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
645 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
645 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
646 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
646 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
647
647
648 def test_compositeerror_truncate(self):
648 def test_compositeerror_truncate(self):
649 """Truncate CompositeErrors with many exceptions"""
649 """Truncate CompositeErrors with many exceptions"""
650 view = self.client[:]
650 view = self.client[:]
651 msg_ids = []
651 msg_ids = []
652 for i in range(10):
652 for i in range(10):
653 ar = view.execute("1/0")
653 ar = view.execute("1/0")
654 msg_ids.extend(ar.msg_ids)
654 msg_ids.extend(ar.msg_ids)
655
655
656 ar = self.client.get_result(msg_ids)
656 ar = self.client.get_result(msg_ids)
657 try:
657 try:
658 ar.get()
658 ar.get()
659 except error.CompositeError as _e:
659 except error.CompositeError as _e:
660 e = _e
660 e = _e
661 else:
661 else:
662 self.fail("Should have raised CompositeError")
662 self.fail("Should have raised CompositeError")
663
663
664 lines = e.render_traceback()
664 lines = e.render_traceback()
665 with capture_output() as io:
665 with capture_output() as io:
666 e.print_traceback()
666 e.print_traceback()
667
667
668 self.assertTrue("more exceptions" in lines[-1])
668 self.assertTrue("more exceptions" in lines[-1])
669 count = e.tb_limit
669 count = e.tb_limit
670
670
671 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
671 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
672 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
672 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
673 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
673 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
674
674
675 @dec.skipif_not_matplotlib
675 @dec.skipif_not_matplotlib
676 def test_magic_pylab(self):
676 def test_magic_pylab(self):
677 """%pylab works on engines"""
677 """%pylab works on engines"""
678 view = self.client[-1]
678 view = self.client[-1]
679 ar = view.execute("%pylab inline")
679 ar = view.execute("%pylab inline")
680 # at least check if this raised:
680 # at least check if this raised:
681 reply = ar.get(5)
681 reply = ar.get(5)
682 # include imports, in case user config
682 # include imports, in case user config
683 ar = view.execute("plot(rand(100))", silent=False)
683 ar = view.execute("plot(rand(100))", silent=False)
684 reply = ar.get(5)
684 reply = ar.get(5)
685 self.assertEqual(len(reply.outputs), 1)
685 self.assertEqual(len(reply.outputs), 1)
686 output = reply.outputs[0]
686 output = reply.outputs[0]
687 self.assertTrue("data" in output)
687 self.assertTrue("data" in output)
688 data = output['data']
688 data = output['data']
689 self.assertTrue("image/png" in data)
689 self.assertTrue("image/png" in data)
690
690
691 def test_func_default_func(self):
691 def test_func_default_func(self):
692 """interactively defined function as apply func default"""
692 """interactively defined function as apply func default"""
693 def foo():
693 def foo():
694 return 'foo'
694 return 'foo'
695
695
696 def bar(f=foo):
696 def bar(f=foo):
697 return f()
697 return f()
698
698
699 view = self.client[-1]
699 view = self.client[-1]
700 ar = view.apply_async(bar)
700 ar = view.apply_async(bar)
701 r = ar.get(10)
701 r = ar.get(10)
702 self.assertEqual(r, 'foo')
702 self.assertEqual(r, 'foo')
703 def test_data_pub_single(self):
703 def test_data_pub_single(self):
704 view = self.client[-1]
704 view = self.client[-1]
705 ar = view.execute('\n'.join([
705 ar = view.execute('\n'.join([
706 'from IPython.kernel.zmq.datapub import publish_data',
706 'from IPython.kernel.zmq.datapub import publish_data',
707 'for i in range(5):',
707 'for i in range(5):',
708 ' publish_data(dict(i=i))'
708 ' publish_data(dict(i=i))'
709 ]), block=False)
709 ]), block=False)
710 self.assertTrue(isinstance(ar.data, dict))
710 self.assertTrue(isinstance(ar.data, dict))
711 ar.get(5)
711 ar.get(5)
712 self.assertEqual(ar.data, dict(i=4))
712 self.assertEqual(ar.data, dict(i=4))
713
713
714 def test_data_pub(self):
714 def test_data_pub(self):
715 view = self.client[:]
715 view = self.client[:]
716 ar = view.execute('\n'.join([
716 ar = view.execute('\n'.join([
717 'from IPython.kernel.zmq.datapub import publish_data',
717 'from IPython.kernel.zmq.datapub import publish_data',
718 'for i in range(5):',
718 'for i in range(5):',
719 ' publish_data(dict(i=i))'
719 ' publish_data(dict(i=i))'
720 ]), block=False)
720 ]), block=False)
721 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
721 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
722 ar.get(5)
722 ar.get(5)
723 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
723 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
724
724
725 def test_can_list_arg(self):
725 def test_can_list_arg(self):
726 """args in lists are canned"""
726 """args in lists are canned"""
727 view = self.client[-1]
727 view = self.client[-1]
728 view['a'] = 128
728 view['a'] = 128
729 rA = pmod.Reference('a')
729 rA = pmod.Reference('a')
730 ar = view.apply_async(lambda x: x, [rA])
730 ar = view.apply_async(lambda x: x, [rA])
731 r = ar.get(5)
731 r = ar.get(5)
732 self.assertEqual(r, [128])
732 self.assertEqual(r, [128])
733
733
734 def test_can_dict_arg(self):
734 def test_can_dict_arg(self):
735 """args in dicts are canned"""
735 """args in dicts are canned"""
736 view = self.client[-1]
736 view = self.client[-1]
737 view['a'] = 128
737 view['a'] = 128
738 rA = pmod.Reference('a')
738 rA = pmod.Reference('a')
739 ar = view.apply_async(lambda x: x, dict(foo=rA))
739 ar = view.apply_async(lambda x: x, dict(foo=rA))
740 r = ar.get(5)
740 r = ar.get(5)
741 self.assertEqual(r, dict(foo=128))
741 self.assertEqual(r, dict(foo=128))
742
742
743 def test_can_list_kwarg(self):
743 def test_can_list_kwarg(self):
744 """kwargs in lists are canned"""
744 """kwargs in lists are canned"""
745 view = self.client[-1]
745 view = self.client[-1]
746 view['a'] = 128
746 view['a'] = 128
747 rA = pmod.Reference('a')
747 rA = pmod.Reference('a')
748 ar = view.apply_async(lambda x=5: x, x=[rA])
748 ar = view.apply_async(lambda x=5: x, x=[rA])
749 r = ar.get(5)
749 r = ar.get(5)
750 self.assertEqual(r, [128])
750 self.assertEqual(r, [128])
751
751
752 def test_can_dict_kwarg(self):
752 def test_can_dict_kwarg(self):
753 """kwargs in dicts are canned"""
753 """kwargs in dicts are canned"""
754 view = self.client[-1]
754 view = self.client[-1]
755 view['a'] = 128
755 view['a'] = 128
756 rA = pmod.Reference('a')
756 rA = pmod.Reference('a')
757 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
757 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
758 r = ar.get(5)
758 r = ar.get(5)
759 self.assertEqual(r, dict(foo=128))
759 self.assertEqual(r, dict(foo=128))
760
760
761 def test_map_ref(self):
761 def test_map_ref(self):
762 """view.map works with references"""
762 """view.map works with references"""
763 view = self.client[:]
763 view = self.client[:]
764 ranks = sorted(self.client.ids)
764 ranks = sorted(self.client.ids)
765 view.scatter('rank', ranks, flatten=True)
765 view.scatter('rank', ranks, flatten=True)
766 rrank = pmod.Reference('rank')
766 rrank = pmod.Reference('rank')
767
767
768 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
768 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
769 drank = amr.get(5)
769 drank = amr.get(5)
770 self.assertEqual(drank, [ r*2 for r in ranks ])
770 self.assertEqual(drank, [ r*2 for r in ranks ])
771
771
772 def test_nested_getitem_setitem(self):
772 def test_nested_getitem_setitem(self):
773 """get and set with view['a.b']"""
773 """get and set with view['a.b']"""
774 view = self.client[-1]
774 view = self.client[-1]
775 view.execute('\n'.join([
775 view.execute('\n'.join([
776 'class A(object): pass',
776 'class A(object): pass',
777 'a = A()',
777 'a = A()',
778 'a.b = 128',
778 'a.b = 128',
779 ]), block=True)
779 ]), block=True)
780 ra = pmod.Reference('a')
780 ra = pmod.Reference('a')
781
781
782 r = view.apply_sync(lambda x: x.b, ra)
782 r = view.apply_sync(lambda x: x.b, ra)
783 self.assertEqual(r, 128)
783 self.assertEqual(r, 128)
784 self.assertEqual(view['a.b'], 128)
784 self.assertEqual(view['a.b'], 128)
785
785
786 view['a.b'] = 0
786 view['a.b'] = 0
787
787
788 r = view.apply_sync(lambda x: x.b, ra)
788 r = view.apply_sync(lambda x: x.b, ra)
789 self.assertEqual(r, 0)
789 self.assertEqual(r, 0)
790 self.assertEqual(view['a.b'], 0)
790 self.assertEqual(view['a.b'], 0)
791
791
792 def test_return_namedtuple(self):
792 def test_return_namedtuple(self):
793 def namedtuplify(x, y):
793 def namedtuplify(x, y):
794 from IPython.parallel.tests.test_view import point
794 from IPython.parallel.tests.test_view import point
795 return point(x, y)
795 return point(x, y)
796
796
797 view = self.client[-1]
797 view = self.client[-1]
798 p = view.apply_sync(namedtuplify, 1, 2)
798 p = view.apply_sync(namedtuplify, 1, 2)
799 self.assertEqual(p.x, 1)
799 self.assertEqual(p.x, 1)
800 self.assertEqual(p.y, 2)
800 self.assertEqual(p.y, 2)
801
801
802 def test_apply_namedtuple(self):
802 def test_apply_namedtuple(self):
803 def echoxy(p):
803 def echoxy(p):
804 return p.y, p.x
804 return p.y, p.x
805
805
806 view = self.client[-1]
806 view = self.client[-1]
807 tup = view.apply_sync(echoxy, point(1, 2))
807 tup = view.apply_sync(echoxy, point(1, 2))
808 self.assertEqual(tup, (2,1))
808 self.assertEqual(tup, (2,1))
809
809
810 def test_sync_imports(self):
810 def test_sync_imports(self):
811 view = self.client[-1]
811 view = self.client[-1]
812 with capture_output() as io:
812 with capture_output() as io:
813 with view.sync_imports():
813 with view.sync_imports():
814 import IPython
814 import IPython
815 self.assertIn("IPython", io.stdout)
815 self.assertIn("IPython", io.stdout)
816
816
817 @interactive
817 @interactive
818 def find_ipython():
818 def find_ipython():
819 return 'IPython' in globals()
819 return 'IPython' in globals()
820
820
821 assert view.apply_sync(find_ipython)
821 assert view.apply_sync(find_ipython)
822
822
823 def test_sync_imports_quiet(self):
823 def test_sync_imports_quiet(self):
824 view = self.client[-1]
824 view = self.client[-1]
825 with capture_output() as io:
825 with capture_output() as io:
826 with view.sync_imports(quiet=True):
826 with view.sync_imports(quiet=True):
827 import IPython
827 import IPython
828 self.assertEqual(io.stdout, '')
828 self.assertEqual(io.stdout, '')
829
829
830 @interactive
830 @interactive
831 def find_ipython():
831 def find_ipython():
832 return 'IPython' in globals()
832 return 'IPython' in globals()
833
833
834 assert view.apply_sync(find_ipython)
834 assert view.apply_sync(find_ipython)
835
835
@@ -1,369 +1,369 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 from IPython.utils.py3compat import string_types, iteritems, itervalues
47 from IPython.utils.py3compat import string_types, iteritems, itervalues
48 from IPython.kernel.zmq.log import EnginePUBHandler
48 from IPython.kernel.zmq.log import EnginePUBHandler
49 from IPython.kernel.zmq.serialize import (
49 from IPython.kernel.zmq.serialize import (
50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 )
51 )
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Classes
54 # Classes
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 class Namespace(dict):
57 class Namespace(dict):
58 """Subclass of dict for attribute access to keys."""
58 """Subclass of dict for attribute access to keys."""
59
59
60 def __getattr__(self, key):
60 def __getattr__(self, key):
61 """getattr aliased to getitem"""
61 """getattr aliased to getitem"""
62 if key in self:
62 if key in self:
63 return self[key]
63 return self[key]
64 else:
64 else:
65 raise NameError(key)
65 raise NameError(key)
66
66
67 def __setattr__(self, key, value):
67 def __setattr__(self, key, value):
68 """setattr aliased to setitem, with strict"""
68 """setattr aliased to setitem, with strict"""
69 if hasattr(dict, key):
69 if hasattr(dict, key):
70 raise KeyError("Cannot override dict keys %r"%key)
70 raise KeyError("Cannot override dict keys %r"%key)
71 self[key] = value
71 self[key] = value
72
72
73
73
74 class ReverseDict(dict):
74 class ReverseDict(dict):
75 """simple double-keyed subset of dict methods."""
75 """simple double-keyed subset of dict methods."""
76
76
77 def __init__(self, *args, **kwargs):
77 def __init__(self, *args, **kwargs):
78 dict.__init__(self, *args, **kwargs)
78 dict.__init__(self, *args, **kwargs)
79 self._reverse = dict()
79 self._reverse = dict()
80 for key, value in iteritems(self):
80 for key, value in iteritems(self):
81 self._reverse[value] = key
81 self._reverse[value] = key
82
82
83 def __getitem__(self, key):
83 def __getitem__(self, key):
84 try:
84 try:
85 return dict.__getitem__(self, key)
85 return dict.__getitem__(self, key)
86 except KeyError:
86 except KeyError:
87 return self._reverse[key]
87 return self._reverse[key]
88
88
89 def __setitem__(self, key, value):
89 def __setitem__(self, key, value):
90 if key in self._reverse:
90 if key in self._reverse:
91 raise KeyError("Can't have key %r on both sides!"%key)
91 raise KeyError("Can't have key %r on both sides!"%key)
92 dict.__setitem__(self, key, value)
92 dict.__setitem__(self, key, value)
93 self._reverse[value] = key
93 self._reverse[value] = key
94
94
95 def pop(self, key):
95 def pop(self, key):
96 value = dict.pop(self, key)
96 value = dict.pop(self, key)
97 self._reverse.pop(value)
97 self._reverse.pop(value)
98 return value
98 return value
99
99
100 def get(self, key, default=None):
100 def get(self, key, default=None):
101 try:
101 try:
102 return self[key]
102 return self[key]
103 except KeyError:
103 except KeyError:
104 return default
104 return default
105
105
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107 # Functions
107 # Functions
108 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
109
109
110 @decorator
110 @decorator
111 def log_errors(f, self, *args, **kwargs):
111 def log_errors(f, self, *args, **kwargs):
112 """decorator to log unhandled exceptions raised in a method.
112 """decorator to log unhandled exceptions raised in a method.
113
113
114 For use wrapping on_recv callbacks, so that exceptions
114 For use wrapping on_recv callbacks, so that exceptions
115 do not cause the stream to be closed.
115 do not cause the stream to be closed.
116 """
116 """
117 try:
117 try:
118 return f(self, *args, **kwargs)
118 return f(self, *args, **kwargs)
119 except Exception:
119 except Exception:
120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121
121
122
122
123 def is_url(url):
123 def is_url(url):
124 """boolean check for whether a string is a zmq url"""
124 """boolean check for whether a string is a zmq url"""
125 if '://' not in url:
125 if '://' not in url:
126 return False
126 return False
127 proto, addr = url.split('://', 1)
127 proto, addr = url.split('://', 1)
128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 return False
129 return False
130 return True
130 return True
131
131
132 def validate_url(url):
132 def validate_url(url):
133 """validate a url for zeromq"""
133 """validate a url for zeromq"""
134 if not isinstance(url, string_types):
134 if not isinstance(url, string_types):
135 raise TypeError("url must be a string, not %r"%type(url))
135 raise TypeError("url must be a string, not %r"%type(url))
136 url = url.lower()
136 url = url.lower()
137
137
138 proto_addr = url.split('://')
138 proto_addr = url.split('://')
139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 proto, addr = proto_addr
140 proto, addr = proto_addr
141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142
142
143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 # author: Remi Sabourin
144 # author: Remi Sabourin
145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146
146
147 if proto == 'tcp':
147 if proto == 'tcp':
148 lis = addr.split(':')
148 lis = addr.split(':')
149 assert len(lis) == 2, 'Invalid url: %r'%url
149 assert len(lis) == 2, 'Invalid url: %r'%url
150 addr,s_port = lis
150 addr,s_port = lis
151 try:
151 try:
152 port = int(s_port)
152 port = int(s_port)
153 except ValueError:
153 except ValueError:
154 raise AssertionError("Invalid port %r in url: %r"%(port, url))
154 raise AssertionError("Invalid port %r in url: %r"%(port, url))
155
155
156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157
157
158 else:
158 else:
159 # only validate tcp urls currently
159 # only validate tcp urls currently
160 pass
160 pass
161
161
162 return True
162 return True
163
163
164
164
165 def validate_url_container(container):
165 def validate_url_container(container):
166 """validate a potentially nested collection of urls."""
166 """validate a potentially nested collection of urls."""
167 if isinstance(container, string_types):
167 if isinstance(container, string_types):
168 url = container
168 url = container
169 return validate_url(url)
169 return validate_url(url)
170 elif isinstance(container, dict):
170 elif isinstance(container, dict):
171 container = itervalues(container)
171 container = itervalues(container)
172
172
173 for element in container:
173 for element in container:
174 validate_url_container(element)
174 validate_url_container(element)
175
175
176
176
177 def split_url(url):
177 def split_url(url):
178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 proto_addr = url.split('://')
179 proto_addr = url.split('://')
180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 proto, addr = proto_addr
181 proto, addr = proto_addr
182 lis = addr.split(':')
182 lis = addr.split(':')
183 assert len(lis) == 2, 'Invalid url: %r'%url
183 assert len(lis) == 2, 'Invalid url: %r'%url
184 addr,s_port = lis
184 addr,s_port = lis
185 return proto,addr,s_port
185 return proto,addr,s_port
186
186
187 def disambiguate_ip_address(ip, location=None):
187 def disambiguate_ip_address(ip, location=None):
188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 ones, based on the location (default interpretation of location is localhost)."""
189 ones, based on the location (default interpretation of location is localhost)."""
190 if ip in ('0.0.0.0', '*'):
190 if ip in ('0.0.0.0', '*'):
191 if location is None or is_public_ip(location) or not public_ips():
191 if location is None or is_public_ip(location) or not public_ips():
192 # If location is unspecified or cannot be determined, assume local
192 # If location is unspecified or cannot be determined, assume local
193 ip = localhost()
193 ip = localhost()
194 elif location:
194 elif location:
195 return location
195 return location
196 return ip
196 return ip
197
197
198 def disambiguate_url(url, location=None):
198 def disambiguate_url(url, location=None):
199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 ones, based on the location (default interpretation is localhost).
200 ones, based on the location (default interpretation is localhost).
201
201
202 This is for zeromq urls, such as tcp://*:10101."""
202 This is for zeromq urls, such as tcp://*:10101."""
203 try:
203 try:
204 proto,ip,port = split_url(url)
204 proto,ip,port = split_url(url)
205 except AssertionError:
205 except AssertionError:
206 # probably not tcp url; could be ipc, etc.
206 # probably not tcp url; could be ipc, etc.
207 return url
207 return url
208
208
209 ip = disambiguate_ip_address(ip,location)
209 ip = disambiguate_ip_address(ip,location)
210
210
211 return "%s://%s:%s"%(proto,ip,port)
211 return "%s://%s:%s"%(proto,ip,port)
212
212
213
213
214 #--------------------------------------------------------------------------
214 #--------------------------------------------------------------------------
215 # helpers for implementing old MEC API via view.apply
215 # helpers for implementing old MEC API via view.apply
216 #--------------------------------------------------------------------------
216 #--------------------------------------------------------------------------
217
217
218 def interactive(f):
218 def interactive(f):
219 """decorator for making functions appear as interactively defined.
219 """decorator for making functions appear as interactively defined.
220 This results in the function being linked to the user_ns as globals()
220 This results in the function being linked to the user_ns as globals()
221 instead of the module globals().
221 instead of the module globals().
222 """
222 """
223 f.__module__ = '__main__'
223 f.__module__ = '__main__'
224 return f
224 return f
225
225
226 @interactive
226 @interactive
227 def _push(**ns):
227 def _push(**ns):
228 """helper method for implementing `client.push` via `client.apply`"""
228 """helper method for implementing `client.push` via `client.apply`"""
229 user_ns = globals()
229 user_ns = globals()
230 tmp = '_IP_PUSH_TMP_'
230 tmp = '_IP_PUSH_TMP_'
231 while tmp in user_ns:
231 while tmp in user_ns:
232 tmp = tmp + '_'
232 tmp = tmp + '_'
233 try:
233 try:
234 for name, value in iteritems(ns):
234 for name, value in ns.items():
235 user_ns[tmp] = value
235 user_ns[tmp] = value
236 exec("%s = %s" % (name, tmp), user_ns)
236 exec("%s = %s" % (name, tmp), user_ns)
237 finally:
237 finally:
238 user_ns.pop(tmp, None)
238 user_ns.pop(tmp, None)
239
239
240 @interactive
240 @interactive
241 def _pull(keys):
241 def _pull(keys):
242 """helper method for implementing `client.pull` via `client.apply`"""
242 """helper method for implementing `client.pull` via `client.apply`"""
243 if isinstance(keys, (list,tuple, set)):
243 if isinstance(keys, (list,tuple, set)):
244 return map(lambda key: eval(key, globals()), keys)
244 return [eval(key, globals()) for key in keys]
245 else:
245 else:
246 return eval(keys, globals())
246 return eval(keys, globals())
247
247
248 @interactive
248 @interactive
249 def _execute(code):
249 def _execute(code):
250 """helper method for implementing `client.execute` via `client.apply`"""
250 """helper method for implementing `client.execute` via `client.apply`"""
251 exec(code, globals())
251 exec(code, globals())
252
252
253 #--------------------------------------------------------------------------
253 #--------------------------------------------------------------------------
254 # extra process management utilities
254 # extra process management utilities
255 #--------------------------------------------------------------------------
255 #--------------------------------------------------------------------------
256
256
257 _random_ports = set()
257 _random_ports = set()
258
258
259 def select_random_ports(n):
259 def select_random_ports(n):
260 """Selects and return n random ports that are available."""
260 """Selects and return n random ports that are available."""
261 ports = []
261 ports = []
262 for i in range(n):
262 for i in range(n):
263 sock = socket.socket()
263 sock = socket.socket()
264 sock.bind(('', 0))
264 sock.bind(('', 0))
265 while sock.getsockname()[1] in _random_ports:
265 while sock.getsockname()[1] in _random_ports:
266 sock.close()
266 sock.close()
267 sock = socket.socket()
267 sock = socket.socket()
268 sock.bind(('', 0))
268 sock.bind(('', 0))
269 ports.append(sock)
269 ports.append(sock)
270 for i, sock in enumerate(ports):
270 for i, sock in enumerate(ports):
271 port = sock.getsockname()[1]
271 port = sock.getsockname()[1]
272 sock.close()
272 sock.close()
273 ports[i] = port
273 ports[i] = port
274 _random_ports.add(port)
274 _random_ports.add(port)
275 return ports
275 return ports
276
276
277 def signal_children(children):
277 def signal_children(children):
278 """Relay interupt/term signals to children, for more solid process cleanup."""
278 """Relay interupt/term signals to children, for more solid process cleanup."""
279 def terminate_children(sig, frame):
279 def terminate_children(sig, frame):
280 log = Application.instance().log
280 log = Application.instance().log
281 log.critical("Got signal %i, terminating children..."%sig)
281 log.critical("Got signal %i, terminating children..."%sig)
282 for child in children:
282 for child in children:
283 child.terminate()
283 child.terminate()
284
284
285 sys.exit(sig != SIGINT)
285 sys.exit(sig != SIGINT)
286 # sys.exit(sig)
286 # sys.exit(sig)
287 for sig in (SIGINT, SIGABRT, SIGTERM):
287 for sig in (SIGINT, SIGABRT, SIGTERM):
288 signal(sig, terminate_children)
288 signal(sig, terminate_children)
289
289
290 def generate_exec_key(keyfile):
290 def generate_exec_key(keyfile):
291 import uuid
291 import uuid
292 newkey = str(uuid.uuid4())
292 newkey = str(uuid.uuid4())
293 with open(keyfile, 'w') as f:
293 with open(keyfile, 'w') as f:
294 # f.write('ipython-key ')
294 # f.write('ipython-key ')
295 f.write(newkey+'\n')
295 f.write(newkey+'\n')
296 # set user-only RW permissions (0600)
296 # set user-only RW permissions (0600)
297 # this will have no effect on Windows
297 # this will have no effect on Windows
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299
299
300
300
301 def integer_loglevel(loglevel):
301 def integer_loglevel(loglevel):
302 try:
302 try:
303 loglevel = int(loglevel)
303 loglevel = int(loglevel)
304 except ValueError:
304 except ValueError:
305 if isinstance(loglevel, str):
305 if isinstance(loglevel, str):
306 loglevel = getattr(logging, loglevel)
306 loglevel = getattr(logging, loglevel)
307 return loglevel
307 return loglevel
308
308
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 logger = logging.getLogger(logname)
310 logger = logging.getLogger(logname)
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 # don't add a second PUBHandler
312 # don't add a second PUBHandler
313 return
313 return
314 loglevel = integer_loglevel(loglevel)
314 loglevel = integer_loglevel(loglevel)
315 lsock = context.socket(zmq.PUB)
315 lsock = context.socket(zmq.PUB)
316 lsock.connect(iface)
316 lsock.connect(iface)
317 handler = handlers.PUBHandler(lsock)
317 handler = handlers.PUBHandler(lsock)
318 handler.setLevel(loglevel)
318 handler.setLevel(loglevel)
319 handler.root_topic = root
319 handler.root_topic = root
320 logger.addHandler(handler)
320 logger.addHandler(handler)
321 logger.setLevel(loglevel)
321 logger.setLevel(loglevel)
322
322
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 logger = logging.getLogger()
324 logger = logging.getLogger()
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 # don't add a second PUBHandler
326 # don't add a second PUBHandler
327 return
327 return
328 loglevel = integer_loglevel(loglevel)
328 loglevel = integer_loglevel(loglevel)
329 lsock = context.socket(zmq.PUB)
329 lsock = context.socket(zmq.PUB)
330 lsock.connect(iface)
330 lsock.connect(iface)
331 handler = EnginePUBHandler(engine, lsock)
331 handler = EnginePUBHandler(engine, lsock)
332 handler.setLevel(loglevel)
332 handler.setLevel(loglevel)
333 logger.addHandler(handler)
333 logger.addHandler(handler)
334 logger.setLevel(loglevel)
334 logger.setLevel(loglevel)
335 return logger
335 return logger
336
336
337 def local_logger(logname, loglevel=logging.DEBUG):
337 def local_logger(logname, loglevel=logging.DEBUG):
338 loglevel = integer_loglevel(loglevel)
338 loglevel = integer_loglevel(loglevel)
339 logger = logging.getLogger(logname)
339 logger = logging.getLogger(logname)
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 # don't add a second StreamHandler
341 # don't add a second StreamHandler
342 return
342 return
343 handler = logging.StreamHandler()
343 handler = logging.StreamHandler()
344 handler.setLevel(loglevel)
344 handler.setLevel(loglevel)
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 datefmt="%Y-%m-%d %H:%M:%S")
346 datefmt="%Y-%m-%d %H:%M:%S")
347 handler.setFormatter(formatter)
347 handler.setFormatter(formatter)
348
348
349 logger.addHandler(handler)
349 logger.addHandler(handler)
350 logger.setLevel(loglevel)
350 logger.setLevel(loglevel)
351 return logger
351 return logger
352
352
353 def set_hwm(sock, hwm=0):
353 def set_hwm(sock, hwm=0):
354 """set zmq High Water Mark on a socket
354 """set zmq High Water Mark on a socket
355
355
356 in a way that always works for various pyzmq / libzmq versions.
356 in a way that always works for various pyzmq / libzmq versions.
357 """
357 """
358 import zmq
358 import zmq
359
359
360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
361 opt = getattr(zmq, key, None)
361 opt = getattr(zmq, key, None)
362 if opt is None:
362 if opt is None:
363 continue
363 continue
364 try:
364 try:
365 sock.setsockopt(opt, hwm)
365 sock.setsockopt(opt, hwm)
366 except zmq.ZMQError:
366 except zmq.ZMQError:
367 pass
367 pass
368
368
369 No newline at end of file
369
General Comments 0
You need to be logged in to leave comments. Login now