Fix parallel test suite
Thomas Kluyver
@@ -1,275 +1,275 @@
1 1 # encoding: utf-8
2 2 """
3 3 The Base Application class for IPython.parallel apps
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Min RK
9 9
10 10 """
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2008-2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 import os
24 24 import logging
25 25 import re
26 26 import sys
27 27
28 28 from subprocess import Popen, PIPE
29 29
30 30 from IPython.config.application import catch_config_error, LevelFormatter
31 31 from IPython.core import release
32 32 from IPython.core.crashhandler import CrashHandler
33 33 from IPython.core.application import (
34 34 BaseIPythonApplication,
35 35 base_aliases as base_ip_aliases,
36 36 base_flags as base_ip_flags
37 37 )
38 38 from IPython.utils.path import expand_path
39 39 from IPython.utils.py3compat import unicode_type
40 40
41 41 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Module errors
45 45 #-----------------------------------------------------------------------------
46 46
47 47 class PIDFileError(Exception):
48 48 pass
49 49
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Crash handler for this application
53 53 #-----------------------------------------------------------------------------
54 54
55 55 class ParallelCrashHandler(CrashHandler):
56 56 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
57 57
58 58 def __init__(self, app):
59 59 contact_name = release.authors['Min'][0]
60 60 contact_email = release.author_email
61 61 bug_tracker = 'https://github.com/ipython/ipython/issues'
62 62 super(ParallelCrashHandler,self).__init__(
63 63 app, contact_name, contact_email, bug_tracker
64 64 )
65 65
66 66
67 67 #-----------------------------------------------------------------------------
68 68 # Main application
69 69 #-----------------------------------------------------------------------------
70 70 base_aliases = {}
71 71 base_aliases.update(base_ip_aliases)
72 72 base_aliases.update({
73 73 'work-dir' : 'BaseParallelApplication.work_dir',
74 74 'log-to-file' : 'BaseParallelApplication.log_to_file',
75 75 'clean-logs' : 'BaseParallelApplication.clean_logs',
76 76 'log-url' : 'BaseParallelApplication.log_url',
77 77 'cluster-id' : 'BaseParallelApplication.cluster_id',
78 78 })
79 79
80 80 base_flags = {
81 81 'log-to-file' : (
82 82 {'BaseParallelApplication' : {'log_to_file' : True}},
83 83 "send log output to a file"
84 84 )
85 85 }
86 86 base_flags.update(base_ip_flags)
87 87
88 88 class BaseParallelApplication(BaseIPythonApplication):
89 89 """The base Application for IPython.parallel apps
90 90
91 91 Principal extensions to BaseIPythonApplication:
92 92
93 93 * work_dir
94 94 * remote logging via pyzmq
95 95 * IOLoop instance
96 96 """
97 97
98 98 crash_handler_class = ParallelCrashHandler
99 99
100 100 def _log_level_default(self):
101 101 # temporarily override default_log_level to INFO
102 102 return logging.INFO
103 103
104 104 def _log_format_default(self):
105 105 """override default log format to include time"""
106 106 return u"%(asctime)s.%(msecs).03d [%(name)s]%(highlevel)s %(message)s"
107 107
108 108 work_dir = Unicode(os.getcwdu(), config=True,
109 109 help='Set the working dir for the process.'
110 110 )
111 111 def _work_dir_changed(self, name, old, new):
112 112 self.work_dir = unicode_type(expand_path(new))
113 113
114 114 log_to_file = Bool(config=True,
115 115 help="whether to log to a file")
116 116
117 117 clean_logs = Bool(False, config=True,
118 118 help="whether to cleanup old logfiles before starting")
119 119
120 120 log_url = Unicode('', config=True,
121 121 help="The ZMQ URL of the iplogger to aggregate logging.")
122 122
123 123 cluster_id = Unicode('', config=True,
124 124 help="""String id to add to runtime files, to prevent name collisions when
125 125 using multiple clusters with a single profile simultaneously.
126 126
127 127 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
128 128
129 129 Since this is text inserted into filenames, typical recommendations apply:
130 130 Simple character strings are ideal, and spaces are not recommended (but should
131 131 generally work).
132 132 """
133 133 )
134 134 def _cluster_id_changed(self, name, old, new):
135 135 self.name = self.__class__.name
136 136 if new:
137 137 self.name += '-%s'%new
138 138
139 139 def _config_files_default(self):
140 140 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
141 141
142 142 loop = Instance('zmq.eventloop.ioloop.IOLoop')
143 143 def _loop_default(self):
144 144 from zmq.eventloop.ioloop import IOLoop
145 145 return IOLoop.instance()
146 146
147 147 aliases = Dict(base_aliases)
148 148 flags = Dict(base_flags)
149 149
150 150 @catch_config_error
151 151 def initialize(self, argv=None):
152 152 """initialize the app"""
153 153 super(BaseParallelApplication, self).initialize(argv)
154 154 self.to_work_dir()
155 155 self.reinit_logging()
156 156
157 157 def to_work_dir(self):
158 158 wd = self.work_dir
159 159 if unicode_type(wd) != os.getcwdu():
160 160 os.chdir(wd)
161 161 self.log.info("Changing to working dir: %s" % wd)
162 162 # This is the working dir by now.
163 163 sys.path.insert(0, '')
164 164
165 165 def reinit_logging(self):
166 166 # Remove old log files
167 167 log_dir = self.profile_dir.log_dir
168 168 if self.clean_logs:
169 169 for f in os.listdir(log_dir):
170 170 if re.match(r'%s-\d+\.(log|err|out)' % self.name, f):
171 171 try:
172 172 os.remove(os.path.join(log_dir, f))
173 173 except (OSError, IOError):
174 174 # probably just conflict from sibling process
175 175 # already removing it
176 176 pass
177 177 if self.log_to_file:
178 178 # Start logging to the new log file
179 179 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
180 180 logfile = os.path.join(log_dir, log_filename)
181 181 open_log_file = open(logfile, 'w')
182 182 else:
183 183 open_log_file = None
184 184 if open_log_file is not None:
185 185 while self.log.handlers:
186 186 self.log.removeHandler(self.log.handlers[0])
187 187 self._log_handler = logging.StreamHandler(open_log_file)
188 188 self.log.addHandler(self._log_handler)
189 189 else:
190 190 self._log_handler = self.log.handlers[0]
191 191 # Add timestamps to log format:
192 192 self._log_formatter = LevelFormatter(self.log_format,
193 193 datefmt=self.log_datefmt)
194 194 self._log_handler.setFormatter(self._log_formatter)
195 195 # do not propagate log messages to root logger
196 196 # ipcluster app will sometimes print duplicate messages during shutdown
197 197 # if this is 1 (default):
198 198 self.log.propagate = False
199 199
200 200 def write_pid_file(self, overwrite=False):
201 201 """Create a .pid file in the pid_dir with my pid.
202 202
203 203 This must be called after pre_construct, which sets `self.pid_dir`.
204 204 This raises :exc:`PIDFileError` if the pid file exists already.
205 205 """
206 206 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
207 207 if os.path.isfile(pid_file):
208 208 pid = self.get_pid_from_file()
209 209 if not overwrite:
210 210 raise PIDFileError(
211 211 'The pid file [%s] already exists. \nThis could mean that this '
212 212 'server is already running with [pid=%s].' % (pid_file, pid)
213 213 )
214 214 with open(pid_file, 'w') as f:
215 215 self.log.info("Creating pid file: %s" % pid_file)
216 216 f.write(repr(os.getpid())+'\n')
217 217
218 218 def remove_pid_file(self):
219 219 """Remove the pid file.
220 220
221 221 This should be called at shutdown by registering a callback with
222 222 :func:`reactor.addSystemEventTrigger`. This needs to return
223 223 ``None``.
224 224 """
225 225 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
226 226 if os.path.isfile(pid_file):
227 227 try:
228 228 self.log.info("Removing pid file: %s" % pid_file)
229 229 os.remove(pid_file)
230 230 except:
231 231 self.log.warn("Error removing the pid file: %s" % pid_file)
232 232
233 233 def get_pid_from_file(self):
234 234 """Get the pid from the pid file.
235 235
236 236 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
237 237 """
238 238 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
239 239 if os.path.isfile(pid_file):
240 240 with open(pid_file, 'r') as f:
241 241 s = f.read().strip()
242 242 try:
243 243 pid = int(s)
244 244 except:
245 245 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
246 246 return pid
247 247 else:
248 248 raise PIDFileError('pid file not found: %s' % pid_file)
249 249
250 250 def check_pid(self, pid):
251 251 if os.name == 'nt':
252 252 try:
253 253 import ctypes
254 254 # returns 0 if no such process (of ours) exists
255 255 # positive int otherwise
256 256 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
257 257 except Exception:
258 258 self.log.warn(
259 259 "Could not determine whether pid %i is running via `OpenProcess`. "
260 260 " Making the likely assumption that it is."%pid
261 261 )
262 262 return True
263 263 return bool(p)
264 264 else:
265 265 try:
266 266 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
267 267 output,_ = p.communicate()
268 268 except OSError:
269 269 self.log.warn(
270 270 "Could not determine whether pid %i is running via `ps x`. "
271 271 " Making the likely assumption that it is."%pid
272 272 )
273 273 return True
274 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
274 pids = list(map(int, re.findall(r'^\W*\d+', output, re.MULTILINE)))
275 275 return pid in pids
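
The substance of this commit is visible in pairs like the `pids = map(...)` change above: on Python 3, `map()` and `filter()` return lazy one-shot iterators instead of lists, so code that indexes the result or tests membership more than once silently breaks. A minimal sketch of the failure mode and the fix (illustrative only, not part of the diff):

    # Python 3 semantics
    pids = map(int, ['12', '34'])        # lazy one-shot iterator, not a list
    print(12 in pids)                    # True, but consumes the iterator
    print(12 in pids)                    # False: the values are already gone
    try:
        map(int, ['12', '34'])[0]        # indexing is not supported at all
    except TypeError as e:
        print(e)                         # 'map' object is not subscriptable
    pids = list(map(int, ['12', '34']))  # the fix: materialize eagerly
    print(12 in pids, pids[0])           # True 12 -- repeatable and indexable
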
@@ -1,707 +1,707 @@
1 1 """AsyncResult objects for the client
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import print_function
19 19
20 20 import sys
21 21 import time
22 22 from datetime import datetime
23 23
24 24 from zmq import MessageTracker
25 25
26 26 from IPython.core.display import clear_output, display, display_pretty
27 27 from IPython.external.decorator import decorator
28 28 from IPython.parallel import error
29 29 from IPython.utils.py3compat import string_types
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Functions
33 33 #-----------------------------------------------------------------------------
34 34
35 35 def _raw_text(s):
36 36 display_pretty(s, raw=True)
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Classes
40 40 #-----------------------------------------------------------------------------
41 41
42 42 # global empty tracker that's always done:
43 43 finished_tracker = MessageTracker()
44 44
45 45 @decorator
46 46 def check_ready(f, self, *args, **kwargs):
47 47 """Call spin() to sync state prior to calling the method."""
48 48 self.wait(0)
49 49 if not self._ready:
50 50 raise error.TimeoutError("result not ready")
51 51 return f(self, *args, **kwargs)
52 52
53 53 class AsyncResult(object):
54 54 """Class for representing results of non-blocking calls.
55 55
56 56 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
57 57 """
58 58
59 59 msg_ids = None
60 60 _targets = None
61 61 _tracker = None
62 62 _single_result = False
63 63
64 64 def __init__(self, client, msg_ids, fname='unknown', targets=None, tracker=None):
65 65 if isinstance(msg_ids, string_types):
66 66 # always a list
67 67 msg_ids = [msg_ids]
68 68 self._single_result = True
69 69 else:
70 70 self._single_result = False
71 71 if tracker is None:
72 72 # default to always done
73 73 tracker = finished_tracker
74 74 self._client = client
75 75 self.msg_ids = msg_ids
76 76 self._fname=fname
77 77 self._targets = targets
78 78 self._tracker = tracker
79 79
80 80 self._ready = False
81 81 self._outputs_ready = False
82 82 self._success = None
83 self._metadata = [ self._client.metadata.get(id) for id in self.msg_ids ]
83 self._metadata = [self._client.metadata[id] for id in self.msg_ids]
84 84
85 85 def __repr__(self):
86 86 if self._ready:
87 87 return "<%s: finished>"%(self.__class__.__name__)
88 88 else:
89 89 return "<%s: %s>"%(self.__class__.__name__,self._fname)
90 90
91 91
92 92 def _reconstruct_result(self, res):
93 93 """Reconstruct our result from actual result list (always a list)
94 94
95 95 Override me in subclasses for turning a list of results
96 96 into the expected form.
97 97 """
98 98 if self._single_result:
99 99 return res[0]
100 100 else:
101 101 return res
102 102
103 103 def get(self, timeout=-1):
104 104 """Return the result when it arrives.
105 105
106 106 If `timeout` is not ``None`` and the result does not arrive within
107 107 `timeout` seconds then ``TimeoutError`` is raised. If the
108 108 remote call raised an exception then that exception will be reraised
109 109 by get() inside a `RemoteError`.
110 110 """
111 111 if not self.ready():
112 112 self.wait(timeout)
113 113
114 114 if self._ready:
115 115 if self._success:
116 116 return self._result
117 117 else:
118 118 raise self._exception
119 119 else:
120 120 raise error.TimeoutError("Result not ready.")
121 121
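
For context, a minimal usage sketch of `get()` (assumes a running `ipcluster`; not part of the diff):

    from IPython.parallel import Client, error
    rc = Client()                        # connect via the default profile
    ar = rc[:].apply_async(lambda: 42)   # one task on every engine
    try:
        print(ar.get(timeout=5))         # e.g. [42, 42]
    except error.TimeoutError:
        print('still pending')           # raised if results miss the deadline
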
122 122 def _check_ready(self):
123 123 if not self.ready():
124 124 raise error.TimeoutError("Result not ready.")
125 125
126 126 def ready(self):
127 127 """Return whether the call has completed."""
128 128 if not self._ready:
129 129 self.wait(0)
130 130 elif not self._outputs_ready:
131 131 self._wait_for_outputs(0)
132 132
133 133 return self._ready
134 134
135 135 def wait(self, timeout=-1):
136 136 """Wait until the result is available or until `timeout` seconds pass.
137 137
138 138 This method always returns None.
139 139 """
140 140 if self._ready:
141 141 self._wait_for_outputs(timeout)
142 142 return
143 143 self._ready = self._client.wait(self.msg_ids, timeout)
144 144 if self._ready:
145 145 try:
146 results = map(self._client.results.get, self.msg_ids)
146 results = list(map(self._client.results.get, self.msg_ids))
147 147 self._result = results
148 148 if self._single_result:
149 149 r = results[0]
150 150 if isinstance(r, Exception):
151 151 raise r
152 152 else:
153 153 results = error.collect_exceptions(results, self._fname)
154 154 self._result = self._reconstruct_result(results)
155 155 except Exception as e:
156 156 self._exception = e
157 157 self._success = False
158 158 else:
159 159 self._success = True
160 160 finally:
161 161 if timeout is None or timeout < 0:
162 162 # cutoff infinite wait at 10s
163 163 timeout = 10
164 164 self._wait_for_outputs(timeout)
165 165
166 166
167 167 def successful(self):
168 168 """Return whether the call completed without raising an exception.
169 169
170 170 Will raise ``AssertionError`` if the result is not ready.
171 171 """
172 172 assert self.ready()
173 173 return self._success
174 174
175 175 #----------------------------------------------------------------
176 176 # Extra methods not in mp.pool.AsyncResult
177 177 #----------------------------------------------------------------
178 178
179 179 def get_dict(self, timeout=-1):
180 180 """Get the results as a dict, keyed by engine_id.
181 181
182 182 timeout behavior is described in `get()`.
183 183 """
184 184
185 185 results = self.get(timeout)
186 186 if self._single_result:
187 187 results = [results]
188 188 engine_ids = [ md['engine_id'] for md in self._metadata ]
189 189
190 190
191 191 rdict = {}
192 192 for engine_id, result in zip(engine_ids, results):
193 193 if engine_id in rdict:
194 194 raise ValueError("Cannot build dict, %i jobs ran on engine #%i" % (
195 195 engine_ids.count(engine_id), engine_id)
196 196 )
197 197 else:
198 198 rdict[engine_id] = result
199 199
200 200 return rdict
201 201
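
A sketch of `get_dict()` under the same assumptions (one task per engine, running cluster; not part of the diff):

    import os
    from IPython.parallel import Client
    rc = Client()
    ar = rc[:].apply_async(os.getpid)    # exactly one task per engine
    print(ar.get_dict())                 # e.g. {0: 12345, 1: 12346}
    # get_dict() raises ValueError if any engine ran more than one task
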
202 202 @property
203 203 def result(self):
204 204 """result property wrapper for `get(timeout=-1)`."""
205 205 return self.get()
206 206
207 207 # abbreviated alias:
208 208 r = result
209 209
210 210 @property
211 211 def metadata(self):
212 212 """property for accessing execution metadata."""
213 213 if self._single_result:
214 214 return self._metadata[0]
215 215 else:
216 216 return self._metadata
217 217
218 218 @property
219 219 def result_dict(self):
220 220 """result property as a dict."""
221 221 return self.get_dict()
222 222
223 223 def __dict__(self):
224 224 return self.get_dict(0)
225 225
226 226 def abort(self):
227 227 """abort my tasks."""
228 228 assert not self.ready(), "Can't abort, I am already done!"
229 229 return self._client.abort(self.msg_ids, targets=self._targets, block=True)
230 230
231 231 @property
232 232 def sent(self):
233 233 """check whether my messages have been sent."""
234 234 return self._tracker.done
235 235
236 236 def wait_for_send(self, timeout=-1):
237 237 """wait for pyzmq send to complete.
238 238
239 239 This is necessary when sending arrays that you intend to edit in-place.
240 240 `timeout` is in seconds, and will raise TimeoutError if it is reached
241 241 before the send completes.
242 242 """
243 243 return self._tracker.wait(timeout)
244 244
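
A sketch of the in-place-edit guard the docstring describes, assuming `push()` accepts the `track` flag that feeds the tracker used above (not part of the diff):

    import numpy as np
    from IPython.parallel import Client
    rc = Client()                                      # assumes a running ipcluster
    view = rc[:]
    a = np.zeros(10**6)
    ar = view.push({'a': a}, track=True, block=False)  # track the zmq send (assumed flag)
    ar.wait_for_send(10)                               # block until zmq releases the buffer
    a += 1                                             # only now is in-place mutation safe
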
245 245 #-------------------------------------
246 246 # dict-access
247 247 #-------------------------------------
248 248
249 249 def __getitem__(self, key):
250 250 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
251 251 """
252 252 if isinstance(key, int):
253 253 self._check_ready()
254 254 return error.collect_exceptions([self._result[key]], self._fname)[0]
255 255 elif isinstance(key, slice):
256 256 self._check_ready()
257 257 return error.collect_exceptions(self._result[key], self._fname)
258 258 elif isinstance(key, string_types):
259 259 # metadata proxy *does not* require that results are done
260 260 self.wait(0)
261 261 values = [ md[key] for md in self._metadata ]
262 262 if self._single_result:
263 263 return values[0]
264 264 else:
265 265 return values
266 266 else:
267 267 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
268 268
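
A sketch of the three key types in action (running cluster assumed; not part of the diff):

    from IPython.parallel import Client
    rc = Client()
    ar = rc[:].apply_async(lambda: 42)
    ar.get()
    print(ar[0])                 # int key: first engine's result value
    print(ar['engine_id'])       # str key: metadata, one entry per engine
    print(ar.engine_id)          # __getattr__ below maps to the same lookup
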
269 269 def __getattr__(self, key):
270 270 """getattr maps to getitem for convenient attr access to metadata."""
271 271 try:
272 272 return self.__getitem__(key)
273 273 except (error.TimeoutError, KeyError):
274 274 raise AttributeError("%r object has no attribute %r"%(
275 275 self.__class__.__name__, key))
276 276
277 277 # asynchronous iterator:
278 278 def __iter__(self):
279 279 if self._single_result:
280 280 raise TypeError("AsyncResults with a single result are not iterable.")
281 281 try:
282 282 rlist = self.get(0)
283 283 except error.TimeoutError:
284 284 # wait for each result individually
285 285 for msg_id in self.msg_ids:
286 286 ar = AsyncResult(self._client, msg_id, self._fname)
287 287 yield ar.get()
288 288 else:
289 289 # already done
290 290 for r in rlist:
291 291 yield r
292 292
293 293 def __len__(self):
294 294 return len(self.msg_ids)
295 295
296 296 #-------------------------------------
297 297 # Sugar methods and attributes
298 298 #-------------------------------------
299 299
300 300 def timedelta(self, start, end, start_key=min, end_key=max):
301 301 """compute the difference between two sets of timestamps
302 302
303 303 The default behavior is to use the earliest of the first
304 304 and the latest of the second list, but this can be changed
305 305 by passing a different start_key or end_key.
306 306
307 307 Parameters
308 308 ----------
309 309
310 310 start : one or more datetime objects (e.g. ar.submitted)
311 311 end : one or more datetime objects (e.g. ar.received)
312 312 start_key : callable
313 313 Function to call on `start` to extract the relevant
314 314 entry [default: min]
315 315 end_key : callable
316 316 Function to call on `end` to extract the relevant
317 317 entry [default: max]
318 318
319 319 Returns
320 320 -------
321 321
322 322 dt : float
323 323 The time elapsed (in seconds) between the two selected timestamps.
324 324 """
325 325 if not isinstance(start, datetime):
326 326 # handle single_result AsyncResults, where ar.stamp is a single object,
327 327 # not a list
328 328 start = start_key(start)
329 329 if not isinstance(end, datetime):
330 330 # handle single_result AsyncResults, where ar.stamp is a single object,
331 331 # not a list
332 332 end = end_key(end)
333 333 return (end - start).total_seconds()
334 334
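
A sketch of `timedelta` on other timestamp pairs (the metadata keys `submitted`, `started`, `completed`, and `received` are all datetimes; running cluster assumed, not part of the diff):

    from IPython.parallel import Client
    rc = Client()
    ar = rc[:].apply_async(lambda: None)
    ar.get()
    print(ar.timedelta(ar.submitted, ar.started))   # scheduling delay, in seconds
    print(ar.timedelta(ar.started, ar.completed))   # compute span across engines
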
335 335 @property
336 336 def progress(self):
337 337 """the number of tasks which have been completed at this point.
338 338
339 339 Fractional progress would be given by 1.0 * ar.progress / len(ar)
340 340 """
341 341 self.wait(0)
342 342 return len(self) - len(set(self.msg_ids).intersection(self._client.outstanding))
343 343
344 344 @property
345 345 def elapsed(self):
346 346 """elapsed time since initial submission"""
347 347 if self.ready():
348 348 return self.wall_time
349 349
350 350 now = submitted = datetime.now()
351 351 for msg_id in self.msg_ids:
352 352 if msg_id in self._client.metadata:
353 353 stamp = self._client.metadata[msg_id]['submitted']
354 354 if stamp and stamp < submitted:
355 355 submitted = stamp
356 356 return (now-submitted).total_seconds()
357 357
358 358 @property
359 359 @check_ready
360 360 def serial_time(self):
361 361 """serial computation time of a parallel calculation
362 362
363 363 Computed as the sum of (completed-started) of each task
364 364 """
365 365 t = 0
366 366 for md in self._metadata:
367 367 t += (md['completed'] - md['started']).total_seconds()
368 368 return t
369 369
370 370 @property
371 371 @check_ready
372 372 def wall_time(self):
373 373 """actual computation time of a parallel calculation
374 374
375 375 Computed as the time between the latest `received` stamp
376 376 and the earliest `submitted`.
377 377
378 378 Only reliable if Client was spinning/waiting when the task finished, because
379 379 the `received` timestamp is created when a result is pulled off of the zmq queue,
380 380 which happens as a result of `client.spin()`.
381 381
382 382 For similar comparison of other timestamp pairs, check out AsyncResult.timedelta.
383 383
384 384 """
385 385 return self.timedelta(self.submitted, self.received)
386 386
387 387 def wait_interactive(self, interval=1., timeout=-1):
388 388 """interactive wait, printing progress at regular intervals"""
389 389 if timeout is None:
390 390 timeout = -1
391 391 N = len(self)
392 392 tic = time.time()
393 393 while not self.ready() and (timeout < 0 or time.time() - tic <= timeout):
394 394 self.wait(interval)
395 395 clear_output(wait=True)
396 396 print("%4i/%i tasks finished after %4i s" % (self.progress, N, self.elapsed), end="")
397 397 sys.stdout.flush()
398 398 print()
399 399 print("done")
400 400
401 401 def _republish_displaypub(self, content, eid):
402 402 """republish individual displaypub content dicts"""
403 403 try:
404 404 ip = get_ipython()
405 405 except NameError:
406 406 # displaypub is meaningless outside IPython
407 407 return
408 408 md = content['metadata'] or {}
409 409 md['engine'] = eid
410 410 ip.display_pub.publish(content['source'], content['data'], md)
411 411
412 412 def _display_stream(self, text, prefix='', file=None):
413 413 if not text:
414 414 # nothing to display
415 415 return
416 416 if file is None:
417 417 file = sys.stdout
418 418 end = '' if text.endswith('\n') else '\n'
419 419
420 420 multiline = text.count('\n') > int(text.endswith('\n'))
421 421 if prefix and multiline and not text.startswith('\n'):
422 422 prefix = prefix + '\n'
423 423 print("%s%s" % (prefix, text), file=file, end=end)
424 424
425 425
426 426 def _display_single_result(self):
427 427 self._display_stream(self.stdout)
428 428 self._display_stream(self.stderr, file=sys.stderr)
429 429
430 430 try:
431 431 get_ipython()
432 432 except NameError:
433 433 # displaypub is meaningless outside IPython
434 434 return
435 435
436 436 for output in self.outputs:
437 437 self._republish_displaypub(output, self.engine_id)
438 438
439 439 if self.pyout is not None:
440 440 display(self.get())
441 441
442 442 def _wait_for_outputs(self, timeout=-1):
443 443 """wait for the 'status=idle' message that indicates we have all outputs
444 444 """
445 445 if self._outputs_ready or not self._success:
446 446 # don't wait on errors
447 447 return
448 448
449 449 # cast None to -1 for infinite timeout
450 450 if timeout is None:
451 451 timeout = -1
452 452
453 453 tic = time.time()
454 454 while True:
455 455 self._client._flush_iopub(self._client._iopub_socket)
456 456 self._outputs_ready = all(md['outputs_ready']
457 457 for md in self._metadata)
458 458 if self._outputs_ready or \
459 459 (timeout >= 0 and time.time() > tic + timeout):
460 460 break
461 461 time.sleep(0.01)
462 462
463 463 @check_ready
464 464 def display_outputs(self, groupby="type"):
465 465 """republish the outputs of the computation
466 466
467 467 Parameters
468 468 ----------
469 469
470 470 groupby : str [default: type]
471 471 if 'type':
472 472 Group outputs by type (show all stdout, then all stderr, etc.):
473 473
474 474 [stdout:1] foo
475 475 [stdout:2] foo
476 476 [stderr:1] bar
477 477 [stderr:2] bar
478 478 if 'engine':
479 479 Display outputs for each engine before moving on to the next:
480 480
481 481 [stdout:1] foo
482 482 [stderr:1] bar
483 483 [stdout:2] foo
484 484 [stderr:2] bar
485 485
486 486 if 'order':
487 487 Like 'type', but further collate individual displaypub
488 488 outputs. This is meant for cases of each command producing
489 489 several plots, and you would like to see all of the first
490 490 plots together, then all of the second plots, and so on.
491 491 """
492 492 if self._single_result:
493 493 self._display_single_result()
494 494 return
495 495
496 496 stdouts = self.stdout
497 497 stderrs = self.stderr
498 498 pyouts = self.pyout
499 499 output_lists = self.outputs
500 500 results = self.get()
501 501
502 502 targets = self.engine_id
503 503
504 504 if groupby == "engine":
505 505 for eid,stdout,stderr,outputs,r,pyout in zip(
506 506 targets, stdouts, stderrs, output_lists, results, pyouts
507 507 ):
508 508 self._display_stream(stdout, '[stdout:%i] ' % eid)
509 509 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
510 510
511 511 try:
512 512 get_ipython()
513 513 except NameError:
514 514 # displaypub is meaningless outside IPython
515 515 return
516 516
517 517 if outputs or pyout is not None:
518 518 _raw_text('[output:%i]' % eid)
519 519
520 520 for output in outputs:
521 521 self._republish_displaypub(output, eid)
522 522
523 523 if pyout is not None:
524 524 display(r)
525 525
526 526 elif groupby in ('type', 'order'):
527 527 # republish stdout:
528 528 for eid,stdout in zip(targets, stdouts):
529 529 self._display_stream(stdout, '[stdout:%i] ' % eid)
530 530
531 531 # republish stderr:
532 532 for eid,stderr in zip(targets, stderrs):
533 533 self._display_stream(stderr, '[stderr:%i] ' % eid, file=sys.stderr)
534 534
535 535 try:
536 536 get_ipython()
537 537 except NameError:
538 538 # displaypub is meaningless outside IPython
539 539 return
540 540
541 541 if groupby == 'order':
542 542 output_dict = dict((eid, outputs) for eid,outputs in zip(targets, output_lists))
543 543 N = max(len(outputs) for outputs in output_lists)
544 544 for i in range(N):
545 545 for eid in targets:
546 546 outputs = output_dict[eid]
547 547 if len(outputs) >= N:
548 548 _raw_text('[output:%i]' % eid)
549 549 self._republish_displaypub(outputs[i], eid)
550 550 else:
551 551 # republish displaypub output
552 552 for eid,outputs in zip(targets, output_lists):
553 553 if outputs:
554 554 _raw_text('[output:%i]' % eid)
555 555 for output in outputs:
556 556 self._republish_displaypub(output, eid)
557 557
558 558 # finally, add pyout:
559 559 for eid,r,pyout in zip(targets, results, pyouts):
560 560 if pyout is not None:
561 561 display(r)
562 562
563 563 else:
564 564 raise ValueError("groupby must be one of 'type', 'engine', 'order', not %r" % groupby)
565 565
566 566
567 567
568 568
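A sketch of `display_outputs` from an IPython session (running cluster assumed; engine stdout is republished with `[stdout:N]` prefixes; not part of the diff):

    from IPython.parallel import Client
    rc = Client()
    ar = rc[:].execute("import os; print(os.getpid())", block=False)
    ar.get()
    ar.display_outputs(groupby='engine')   # [stdout:0] ..., then [stdout:1] ...
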
569 569 class AsyncMapResult(AsyncResult):
570 570 """Class for representing results of non-blocking gathers.
571 571
572 572 This will properly reconstruct the gather.
573 573
574 574 This class is iterable at any time, and will wait on results as they come.
575 575
576 576 If ordered=False, then the first results to arrive will come first, otherwise
577 577 results will be yielded in the order they were submitted.
578 578
579 579 """
580 580
581 581 def __init__(self, client, msg_ids, mapObject, fname='', ordered=True):
582 582 AsyncResult.__init__(self, client, msg_ids, fname=fname)
583 583 self._mapObject = mapObject
584 584 self._single_result = False
585 585 self.ordered = ordered
586 586
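
A sketch of ordered vs. unordered iteration, assuming `map_async` forwards the `ordered` flag to this class (not part of the diff):

    from IPython.parallel import Client
    rc = Client()                          # assumes a running ipcluster
    lview = rc.load_balanced_view()
    amr = lview.map_async(lambda x: x * x, range(8), ordered=False)
    for r in amr:                          # unordered: fastest results first
        print(r)
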
587 587 def _reconstruct_result(self, res):
588 588 """Perform the gather on the actual results."""
589 589 return self._mapObject.joinPartitions(res)
590 590
591 591 # asynchronous iterator:
592 592 def __iter__(self):
593 593 it = self._ordered_iter if self.ordered else self._unordered_iter
594 594 for r in it():
595 595 yield r
596 596
597 597 # asynchronous ordered iterator:
598 598 def _ordered_iter(self):
599 599 """iterator for results *as they arrive*, preserving submission order."""
600 600 try:
601 601 rlist = self.get(0)
602 602 except error.TimeoutError:
603 603 # wait for each result individually
604 604 for msg_id in self.msg_ids:
605 605 ar = AsyncResult(self._client, msg_id, self._fname)
606 606 rlist = ar.get()
607 607 try:
608 608 for r in rlist:
609 609 yield r
610 610 except TypeError:
611 611 # flattened, not a list
612 612 # this could get broken by flattened data that returns iterables
613 613 # but most calls to map do not expose the `flatten` argument
614 614 yield rlist
615 615 else:
616 616 # already done
617 617 for r in rlist:
618 618 yield r
619 619
620 620 # asynchronous unordered iterator:
621 621 def _unordered_iter(self):
622 622 """iterator for results *as they arrive*, on FCFS basis, ignoring submission order."""
623 623 try:
624 624 rlist = self.get(0)
625 625 except error.TimeoutError:
626 626 pending = set(self.msg_ids)
627 627 while pending:
628 628 try:
629 629 self._client.wait(pending, 1e-3)
630 630 except error.TimeoutError:
631 631 # ignore timeout error, because that only means
632 632 # *some* jobs are outstanding
633 633 pass
634 634 # update ready set with those no longer outstanding:
635 635 ready = pending.difference(self._client.outstanding)
636 636 # update pending to exclude those that are finished
637 637 pending = pending.difference(ready)
638 638 while ready:
639 639 msg_id = ready.pop()
640 640 ar = AsyncResult(self._client, msg_id, self._fname)
641 641 rlist = ar.get()
642 642 try:
643 643 for r in rlist:
644 644 yield r
645 645 except TypeError:
646 646 # flattened, not a list
647 647 # this could get broken by flattened data that returns iterables
648 648 # but most calls to map do not expose the `flatten` argument
649 649 yield rlist
650 650 else:
651 651 # already done
652 652 for r in rlist:
653 653 yield r
654 654
655 655
656 656 class AsyncHubResult(AsyncResult):
657 657 """Class to wrap pending results that must be requested from the Hub.
658 658
659 659 Note that waiting/polling on these objects requires polling the Hub over the network,
660 660 so use `AsyncHubResult.wait()` sparingly.
661 661 """
662 662
663 663 def _wait_for_outputs(self, timeout=-1):
664 664 """no-op, because HubResults are never incomplete"""
665 665 self._outputs_ready = True
666 666
667 667 def wait(self, timeout=-1):
668 668 """wait for result to complete."""
669 669 start = time.time()
670 670 if self._ready:
671 671 return
672 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
672 local_ids = [m for m in self.msg_ids if m in self._client.outstanding]
673 673 local_ready = self._client.wait(local_ids, timeout)
674 674 if local_ready:
675 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
675 remote_ids = [m for m in self.msg_ids if m not in self._client.results]
676 676 if not remote_ids:
677 677 self._ready = True
678 678 else:
679 679 rdict = self._client.result_status(remote_ids, status_only=False)
680 680 pending = rdict['pending']
681 681 while pending and (timeout < 0 or time.time() < start+timeout):
682 682 rdict = self._client.result_status(remote_ids, status_only=False)
683 683 pending = rdict['pending']
684 684 if pending:
685 685 time.sleep(0.1)
686 686 if not pending:
687 687 self._ready = True
688 688 if self._ready:
689 689 try:
690 results = map(self._client.results.get, self.msg_ids)
690 results = list(map(self._client.results.get, self.msg_ids))
691 691 self._result = results
692 692 if self._single_result:
693 693 r = results[0]
694 694 if isinstance(r, Exception):
695 695 raise r
696 696 else:
697 697 results = error.collect_exceptions(results, self._fname)
698 698 self._result = self._reconstruct_result(results)
699 699 except Exception as e:
700 700 self._exception = e
701 701 self._success = False
702 702 else:
703 703 self._success = True
704 704 finally:
705 self._metadata = map(self._client.metadata.get, self.msg_ids)
705 self._metadata = [self._client.metadata[mid] for mid in self.msg_ids]
706 706
707 707 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult']
@@ -1,1854 +1,1855 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import os
20 20 import json
21 21 import sys
22 22 from threading import Thread, Event
23 23 import time
24 24 import warnings
25 25 from datetime import datetime
26 26 from getpass import getpass
27 27 from pprint import pprint
28 28
29 29 pjoin = os.path.join
30 30
31 31 import zmq
32 32 # from zmq.eventloop import ioloop, zmqstream
33 33
34 34 from IPython.config.configurable import MultipleInstanceError
35 35 from IPython.core.application import BaseIPythonApplication
36 36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37 37
38 38 from IPython.utils.capture import RichOutput
39 39 from IPython.utils.coloransi import TermColors
40 40 from IPython.utils.jsonutil import rekey
41 41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 42 from IPython.utils.path import get_ipython_dir
43 43 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
44 44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 45 Dict, List, Bool, Set, Any)
46 46 from IPython.external.decorator import decorator
47 47 from IPython.external.ssh import tunnel
48 48
49 49 from IPython.parallel import Reference
50 50 from IPython.parallel import error
51 51 from IPython.parallel import util
52 52
53 53 from IPython.kernel.zmq.session import Session, Message
54 54 from IPython.kernel.zmq import serialize
55 55
56 56 from .asyncresult import AsyncResult, AsyncHubResult
57 57 from .view import DirectView, LoadBalancedView
58 58
59 59 #--------------------------------------------------------------------------
60 60 # Decorators for Client methods
61 61 #--------------------------------------------------------------------------
62 62
63 63 @decorator
64 64 def spin_first(f, self, *args, **kwargs):
65 65 """Call spin() to sync state prior to calling the method."""
66 66 self.spin()
67 67 return f(self, *args, **kwargs)
68 68
69 69
70 70 #--------------------------------------------------------------------------
71 71 # Classes
72 72 #--------------------------------------------------------------------------
73 73
74 74
75 75 class ExecuteReply(RichOutput):
76 76 """wrapper for finished Execute results"""
77 77 def __init__(self, msg_id, content, metadata):
78 78 self.msg_id = msg_id
79 79 self._content = content
80 80 self.execution_count = content['execution_count']
81 81 self.metadata = metadata
82 82
83 83 # RichOutput overrides
84 84
85 85 @property
86 86 def source(self):
87 87 pyout = self.metadata['pyout']
88 88 if pyout:
89 89 return pyout.get('source', '')
90 90
91 91 @property
92 92 def data(self):
93 93 pyout = self.metadata['pyout']
94 94 if pyout:
95 95 return pyout.get('data', {})
96 96
97 97 @property
98 98 def _metadata(self):
99 99 pyout = self.metadata['pyout']
100 100 if pyout:
101 101 return pyout.get('metadata', {})
102 102
103 103 def display(self):
104 104 from IPython.display import publish_display_data
105 105 publish_display_data(self.source, self.data, self.metadata)
106 106
107 107 def _repr_mime_(self, mime):
108 108 if mime not in self.data:
109 109 return
110 110 data = self.data[mime]
111 111 if mime in self._metadata:
112 112 return data, self._metadata[mime]
113 113 else:
114 114 return data
115 115
116 116 def __getitem__(self, key):
117 117 return self.metadata[key]
118 118
119 119 def __getattr__(self, key):
120 120 if key not in self.metadata:
121 121 raise AttributeError(key)
122 122 return self.metadata[key]
123 123
124 124 def __repr__(self):
125 125 pyout = self.metadata['pyout'] or {'data':{}}
126 126 text_out = pyout['data'].get('text/plain', '')
127 127 if len(text_out) > 32:
128 128 text_out = text_out[:29] + '...'
129 129
130 130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
131 131
132 132 def _repr_pretty_(self, p, cycle):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 text_out = pyout['data'].get('text/plain', '')
135 135
136 136 if not text_out:
137 137 return
138 138
139 139 try:
140 140 ip = get_ipython()
141 141 except NameError:
142 142 colors = "NoColor"
143 143 else:
144 144 colors = ip.colors
145 145
146 146 if colors == "NoColor":
147 147 out = normal = ""
148 148 else:
149 149 out = TermColors.Red
150 150 normal = TermColors.Normal
151 151
152 152 if '\n' in text_out and not text_out.startswith('\n'):
153 153 # add newline for multiline reprs
154 154 text_out = '\n' + text_out
155 155
156 156 p.text(
157 157 out + u'Out[%i:%i]: ' % (
158 158 self.metadata['engine_id'], self.execution_count
159 159 ) + normal + text_out
160 160 )
161 161
162 162
163 163 class Metadata(dict):
164 164 """Subclass of dict for initializing metadata values.
165 165
166 166 Attribute access works on keys.
167 167
168 168 These objects have a strict set of keys - errors will raise if you try
169 169 to add new keys.
170 170 """
171 171 def __init__(self, *args, **kwargs):
172 172 dict.__init__(self)
173 173 md = {'msg_id' : None,
174 174 'submitted' : None,
175 175 'started' : None,
176 176 'completed' : None,
177 177 'received' : None,
178 178 'engine_uuid' : None,
179 179 'engine_id' : None,
180 180 'follow' : None,
181 181 'after' : None,
182 182 'status' : None,
183 183
184 184 'pyin' : None,
185 185 'pyout' : None,
186 186 'pyerr' : None,
187 187 'stdout' : '',
188 188 'stderr' : '',
189 189 'outputs' : [],
190 190 'data': {},
191 191 'outputs_ready' : False,
192 192 }
193 193 self.update(md)
194 194 self.update(dict(*args, **kwargs))
195 195
196 196 def __getattr__(self, key):
197 197 """getattr aliased to getitem"""
198 198 if key in self:
199 199 return self[key]
200 200 else:
201 201 raise AttributeError(key)
202 202
203 203 def __setattr__(self, key, value):
204 204 """setattr aliased to setitem, with strict"""
205 205 if key in self:
206 206 self[key] = value
207 207 else:
208 208 raise AttributeError(key)
209 209
210 210 def __setitem__(self, key, value):
211 211 """strict static key enforcement"""
212 212 if key in self:
213 213 dict.__setitem__(self, key, value)
214 214 else:
215 215 raise KeyError(key)
216 216
217 217
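A sketch of the strict-key behaviour described above (runnable against this module; not part of the diff):

    md = Metadata(engine_id=0, status='ok')
    print(md.engine_id, md['status'])      # attribute and item access are aliases
    try:
        md['no_such_key'] = 1              # unknown keys are rejected outright
    except KeyError as e:
        print('rejected:', e)
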
218 218 class Client(HasTraits):
219 219 """A semi-synchronous client to the IPython ZMQ cluster
220 220
221 221 Parameters
222 222 ----------
223 223
224 224 url_file : str/unicode; path to ipcontroller-client.json
225 225 This JSON file should contain all the information needed to connect to a cluster,
226 226 and is likely the only argument needed.
227 227 Connection information for the Hub's registration. If a json connector
228 228 file is given, then likely no further configuration is necessary.
229 229 [Default: use profile]
230 230 profile : bytes
231 231 The name of the Cluster profile to be used to find connector information.
232 232 If run from an IPython application, the default profile will be the same
233 233 as the running application, otherwise it will be 'default'.
234 234 cluster_id : str
235 235 String id to add to runtime files, to prevent name collisions when using
236 236 multiple clusters with a single profile simultaneously.
237 237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
238 238 Since this is text inserted into filenames, typical recommendations apply:
239 239 Simple character strings are ideal, and spaces are not recommended (but
240 240 should generally work)
241 241 context : zmq.Context
242 242 Pass an existing zmq.Context instance, otherwise the client will create its own.
243 243 debug : bool
244 244 flag for lots of message printing for debug purposes
245 245 timeout : int/float
246 246 time (in seconds) to wait for connection replies from the Hub
247 247 [Default: 10]
248 248
249 249 #-------------- session related args ----------------
250 250
251 251 config : Config object
252 252 If specified, this will be relayed to the Session for configuration
253 253 username : str
254 254 set username for the session object
255 255
256 256 #-------------- ssh related args ----------------
257 257 # These are args for configuring the ssh tunnel to be used
258 258 # credentials are used to forward connections over ssh to the Controller
259 259 # Note that the ip given in `addr` needs to be relative to sshserver
260 260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
261 261 # and set sshserver as the same machine the Controller is on. However,
262 262 # the only requirement is that sshserver is able to see the Controller
263 263 # (i.e. is within the same trusted network).
264 264
265 265 sshserver : str
266 266 A string of the form passed to ssh, e.g. 'server.tld' or 'user@server.tld:port'
267 267 If keyfile or password is specified, and this is not, it will default to
268 268 the ip given in addr.
269 269 sshkey : str; path to ssh private key file
270 270 This specifies a key to be used in ssh login, default None.
271 271 Regular default ssh keys will be used without specifying this argument.
272 272 password : str
273 273 Your ssh password to sshserver. Note that if this is left None,
274 274 you will be prompted for it if passwordless key based login is unavailable.
275 275 paramiko : bool
276 276 flag for whether to use paramiko instead of shell ssh for tunneling.
277 277 [default: True on win32, False else]
278 278
279 279
280 280 Attributes
281 281 ----------
282 282
283 283 ids : list of int engine IDs
284 284 requesting the ids attribute always synchronizes
285 285 the registration state. To request ids without synchronization,
286 286 use the semi-private _ids attribute.
287 287
288 288 history : list of msg_ids
289 289 a list of msg_ids, keeping track of all the execution
290 290 messages you have submitted in order.
291 291
292 292 outstanding : set of msg_ids
293 293 a set of msg_ids that have been submitted, but whose
294 294 results have not yet been received.
295 295
296 296 results : dict
297 297 a dict of all our results, keyed by msg_id
298 298
299 299 block : bool
300 300 determines default behavior when block not specified
301 301 in execution methods
302 302
303 303 Methods
304 304 -------
305 305
306 306 spin
307 307 flushes incoming results and registration state changes
308 308 control methods spin, and requesting `ids` also ensures up to date
309 309
310 310 wait
311 311 wait on one or more msg_ids
312 312
313 313 execution methods
314 314 apply
315 315 legacy: execute, run
316 316
317 317 data movement
318 318 push, pull, scatter, gather
319 319
320 320 query methods
321 321 queue_status, get_result, purge, result_status
322 322
323 323 control methods
324 324 abort, shutdown
325 325
326 326 """
327 327
328 328
329 329 block = Bool(False)
330 330 outstanding = Set()
331 331 results = Instance('collections.defaultdict', (dict,))
332 332 metadata = Instance('collections.defaultdict', (Metadata,))
333 333 history = List()
334 334 debug = Bool(False)
335 335 _spin_thread = Any()
336 336 _stop_spinning = Any()
337 337
338 338 profile=Unicode()
339 339 def _profile_default(self):
340 340 if BaseIPythonApplication.initialized():
341 341 # an IPython app *might* be running, try to get its profile
342 342 try:
343 343 return BaseIPythonApplication.instance().profile
344 344 except (AttributeError, MultipleInstanceError):
345 345 # could be a *different* subclass of config.Application,
346 346 # which would raise one of these two errors.
347 347 return u'default'
348 348 else:
349 349 return u'default'
350 350
351 351
352 352 _outstanding_dict = Instance('collections.defaultdict', (set,))
353 353 _ids = List()
354 354 _connected=Bool(False)
355 355 _ssh=Bool(False)
356 356 _context = Instance('zmq.Context')
357 357 _config = Dict()
358 358 _engines=Instance(util.ReverseDict, (), {})
359 359 # _hub_socket=Instance('zmq.Socket')
360 360 _query_socket=Instance('zmq.Socket')
361 361 _control_socket=Instance('zmq.Socket')
362 362 _iopub_socket=Instance('zmq.Socket')
363 363 _notification_socket=Instance('zmq.Socket')
364 364 _mux_socket=Instance('zmq.Socket')
365 365 _task_socket=Instance('zmq.Socket')
366 366 _task_scheme=Unicode()
367 367 _closed = False
368 368 _ignored_control_replies=Integer(0)
369 369 _ignored_hub_replies=Integer(0)
370 370
371 371 def __new__(self, *args, **kw):
372 372 # don't raise on positional args
373 373 return HasTraits.__new__(self, **kw)
374 374
375 375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
376 376 context=None, debug=False,
377 377 sshserver=None, sshkey=None, password=None, paramiko=None,
378 378 timeout=10, cluster_id=None, **extra_args
379 379 ):
380 380 if profile:
381 381 super(Client, self).__init__(debug=debug, profile=profile)
382 382 else:
383 383 super(Client, self).__init__(debug=debug)
384 384 if context is None:
385 385 context = zmq.Context.instance()
386 386 self._context = context
387 387 self._stop_spinning = Event()
388 388
389 389 if 'url_or_file' in extra_args:
390 390 url_file = extra_args['url_or_file']
391 391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
392 392
393 393 if url_file and util.is_url(url_file):
394 394 raise ValueError("single urls cannot be specified, url-files must be used.")
395 395
396 396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
397 397
398 398 if self._cd is not None:
399 399 if url_file is None:
400 400 if not cluster_id:
401 401 client_json = 'ipcontroller-client.json'
402 402 else:
403 403 client_json = 'ipcontroller-%s-client.json' % cluster_id
404 404 url_file = pjoin(self._cd.security_dir, client_json)
405 405 if url_file is None:
406 406 raise ValueError(
407 407 "I can't find enough information to connect to a hub!"
408 408 " Please specify at least one of url_file or profile."
409 409 )
410 410
411 411 with open(url_file) as f:
412 412 cfg = json.load(f)
413 413
414 414 self._task_scheme = cfg['task_scheme']
415 415
416 416 # sync defaults from args, json:
417 417 if sshserver:
418 418 cfg['ssh'] = sshserver
419 419
420 420 location = cfg.setdefault('location', None)
421 421
422 422 proto,addr = cfg['interface'].split('://')
423 423 addr = util.disambiguate_ip_address(addr, location)
424 424 cfg['interface'] = "%s://%s" % (proto, addr)
425 425
426 426 # turn interface,port into full urls:
427 427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
428 428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
429 429
430 430 url = cfg['registration']
431 431
432 432 if location is not None and addr == localhost():
433 433 # location specified, and connection is expected to be local
434 434 if not is_local_ip(location) and not sshserver:
435 435 # load ssh from JSON *only* if the controller is not on
436 436 # this machine
437 437 sshserver=cfg['ssh']
438 438 if not is_local_ip(location) and not sshserver:
439 439 # warn if no ssh specified, but SSH is probably needed
440 440 # This is only a warning, because the most likely cause
441 441 # is a local Controller on a laptop whose IP is dynamic
442 442 warnings.warn("""
443 443 Controller appears to be listening on localhost, but not on this machine.
444 444 If this is true, you should specify Client(...,sshserver='you@%s')
445 445 or instruct your controller to listen on an external IP."""%location,
446 446 RuntimeWarning)
447 447 elif not sshserver:
448 448 # otherwise sync with cfg
449 449 sshserver = cfg['ssh']
450 450
451 451 self._config = cfg
452 452
453 453 self._ssh = bool(sshserver or sshkey or password)
454 454 if self._ssh and sshserver is None:
455 455 # default to ssh via localhost
456 456 sshserver = addr
457 457 if self._ssh and password is None:
458 458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
459 459 password=False
460 460 else:
461 461 password = getpass("SSH Password for %s: "%sshserver)
462 462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
463 463
464 464 # configure and construct the session
465 465 try:
466 466 extra_args['packer'] = cfg['pack']
467 467 extra_args['unpacker'] = cfg['unpack']
468 468 extra_args['key'] = cast_bytes(cfg['key'])
469 469 extra_args['signature_scheme'] = cfg['signature_scheme']
470 470 except KeyError as exc:
471 471 msg = '\n'.join([
472 472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
473 473 "If you are reusing connection files, remove them and start ipcontroller again."
474 474 ])
475 475 raise ValueError(msg.format(exc.message))
476 476
477 477 self.session = Session(**extra_args)
478 478
479 479 self._query_socket = self._context.socket(zmq.DEALER)
480 480
481 481 if self._ssh:
482 482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
483 483 else:
484 484 self._query_socket.connect(cfg['registration'])
485 485
486 486 self.session.debug = self.debug
487 487
488 488 self._notification_handlers = {'registration_notification' : self._register_engine,
489 489 'unregistration_notification' : self._unregister_engine,
490 490 'shutdown_notification' : lambda msg: self.close(),
491 491 }
492 492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
493 493 'apply_reply' : self._handle_apply_reply}
494 494
495 495 try:
496 496 self._connect(sshserver, ssh_kwargs, timeout)
497 497 except:
498 498 self.close(linger=0)
499 499 raise
500 500
501 501 # last step: setup magics, if we are in IPython:
502 502
503 503 try:
504 504 ip = get_ipython()
505 505 except NameError:
506 506 return
507 507 else:
508 508 if 'px' not in ip.magics_manager.magics:
509 509 # in IPython but we are the first Client.
510 510 # activate a default view for parallel magics.
511 511 self.activate()
512 512
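A sketch of the common construction paths the long docstring above describes (profile names and paths are placeholders; not part of the diff):

    from IPython.parallel import Client
    rc = Client()                                       # default profile's client JSON
    rc = Client(profile='mpi')                          # a named profile, if it exists
    rc = Client('/path/to/ipcontroller-client.json')    # explicit connection file
    rc = Client(sshserver='user@gateway.example.com')   # tunnel to a remote controller
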
513 513 def __del__(self):
514 514 """cleanup sockets, but _not_ context."""
515 515 self.close()
516 516
517 517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
518 518 if ipython_dir is None:
519 519 ipython_dir = get_ipython_dir()
520 520 if profile_dir is not None:
521 521 try:
522 522 self._cd = ProfileDir.find_profile_dir(profile_dir)
523 523 return
524 524 except ProfileDirError:
525 525 pass
526 526 elif profile is not None:
527 527 try:
528 528 self._cd = ProfileDir.find_profile_dir_by_name(
529 529 ipython_dir, profile)
530 530 return
531 531 except ProfileDirError:
532 532 pass
533 533 self._cd = None
534 534
535 535 def _update_engines(self, engines):
536 536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
537 537 for k,v in iteritems(engines):
538 538 eid = int(k)
539 539 if eid not in self._engines:
540 540 self._ids.append(eid)
541 541 self._engines[eid] = v
542 542 self._ids = sorted(self._ids)
543 if sorted(self._engines.keys()) != range(len(self._engines)) and \
543 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
544 544 self._task_scheme == 'pure' and self._task_socket:
545 545 self._stop_scheduling_tasks()
546 546
547 547 def _stop_scheduling_tasks(self):
548 548 """Stop scheduling tasks because an engine has been unregistered
549 549 from a pure ZMQ scheduler.
550 550 """
551 551 self._task_socket.close()
552 552 self._task_socket = None
553 553 msg = "An engine has been unregistered, and we are using pure " +\
554 554 "ZMQ task scheduling. Task farming will be disabled."
555 555 if self.outstanding:
556 556 msg += " If you were running tasks when this happened, " +\
557 557 "some `outstanding` msg_ids may never resolve."
558 558 warnings.warn(msg, RuntimeWarning)
559 559
560 560 def _build_targets(self, targets):
561 561 """Turn valid target IDs or 'all' into two lists:
562 562 (int_ids, uuids).
563 563 """
564 564 if not self._ids:
565 565 # flush notification socket if no engines yet, just in case
566 566 if not self.ids:
567 567 raise error.NoEnginesRegistered("Can't build targets without any engines")
568 568
569 569 if targets is None:
570 570 targets = self._ids
571 571 elif isinstance(targets, string_types):
572 572 if targets.lower() == 'all':
573 573 targets = self._ids
574 574 else:
575 575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
576 576 elif isinstance(targets, int):
577 577 if targets < 0:
578 578 targets = self.ids[targets]
579 579 if targets not in self._ids:
580 580 raise IndexError("No such engine: %i"%targets)
581 581 targets = [targets]
582 582
583 583 if isinstance(targets, slice):
584 indices = range(len(self._ids))[targets]
584 indices = list(range(len(self._ids))[targets])
585 585 ids = self.ids
586 586 targets = [ ids[i] for i in indices ]
587 587
588 588 if not isinstance(targets, (tuple, list, xrange)):
589 589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
590 590
591 591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
592 592
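A sketch of the target forms `_build_targets` resolves, via the public indexing API (running cluster assumed; not part of the diff):

    from IPython.parallel import Client
    rc = Client()
    rc[0]                      # a single engine by integer id
    rc[-1]                     # negative ints index the sorted id list
    rc[::2]                    # slices select from the id list
    rc.direct_view('all')      # the string 'all' targets every engine
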
593 593 def _connect(self, sshserver, ssh_kwargs, timeout):
594 594 """setup all our socket connections to the cluster. This is called from
595 595 __init__."""
596 596
597 597 # Maybe allow reconnecting?
598 598 if self._connected:
599 599 return
600 600 self._connected=True
601 601
602 602 def connect_socket(s, url):
603 603 if self._ssh:
604 604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
605 605 else:
606 606 return s.connect(url)
607 607
608 608 self.session.send(self._query_socket, 'connection_request')
609 609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
610 610 poller = zmq.Poller()
611 611 poller.register(self._query_socket, zmq.POLLIN)
612 612 # poll expects milliseconds, timeout is seconds
613 613 evts = poller.poll(timeout*1000)
614 614 if not evts:
615 615 raise error.TimeoutError("Hub connection request timed out")
616 616 idents,msg = self.session.recv(self._query_socket,mode=0)
617 617 if self.debug:
618 618 pprint(msg)
619 619 content = msg['content']
620 620 # self._config['registration'] = dict(content)
621 621 cfg = self._config
622 622 if content['status'] == 'ok':
623 623 self._mux_socket = self._context.socket(zmq.DEALER)
624 624 connect_socket(self._mux_socket, cfg['mux'])
625 625
626 626 self._task_socket = self._context.socket(zmq.DEALER)
627 627 connect_socket(self._task_socket, cfg['task'])
628 628
629 629 self._notification_socket = self._context.socket(zmq.SUB)
630 630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
631 631 connect_socket(self._notification_socket, cfg['notification'])
632 632
633 633 self._control_socket = self._context.socket(zmq.DEALER)
634 634 connect_socket(self._control_socket, cfg['control'])
635 635
636 636 self._iopub_socket = self._context.socket(zmq.SUB)
637 637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
638 638 connect_socket(self._iopub_socket, cfg['iopub'])
639 639
640 640 self._update_engines(dict(content['engines']))
641 641 else:
642 642 self._connected = False
643 643 raise Exception("Failed to connect!")
644 644
645 645 #--------------------------------------------------------------------------
646 646 # handlers and callbacks for incoming messages
647 647 #--------------------------------------------------------------------------
648 648
649 649 def _unwrap_exception(self, content):
650 650 """unwrap exception, and remap engine_id to int."""
651 651 e = error.unwrap_exception(content)
652 652 # print e.traceback
653 653 if e.engine_info:
654 654 e_uuid = e.engine_info['engine_uuid']
655 655 eid = self._engines[e_uuid]
656 656 e.engine_info['engine_id'] = eid
657 657 return e
658 658
659 659 def _extract_metadata(self, msg):
660 660 header = msg['header']
661 661 parent = msg['parent_header']
662 662 msg_meta = msg['metadata']
663 663 content = msg['content']
664 664 md = {'msg_id' : parent['msg_id'],
665 665 'received' : datetime.now(),
666 666 'engine_uuid' : msg_meta.get('engine', None),
667 667 'follow' : msg_meta.get('follow', []),
668 668 'after' : msg_meta.get('after', []),
669 669 'status' : content['status'],
670 670 }
671 671
672 672 if md['engine_uuid'] is not None:
673 673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
674 674
675 675 if 'date' in parent:
676 676 md['submitted'] = parent['date']
677 677 if 'started' in msg_meta:
678 678 md['started'] = msg_meta['started']
679 679 if 'date' in header:
680 680 md['completed'] = header['date']
681 681 return md
682 682
683 683 def _register_engine(self, msg):
684 684 """Register a new engine, and update our connection info."""
685 685 content = msg['content']
686 686 eid = content['id']
687 687 d = {eid : content['uuid']}
688 688 self._update_engines(d)
689 689
690 690 def _unregister_engine(self, msg):
691 691 """Unregister an engine that has died."""
692 692 content = msg['content']
693 693 eid = int(content['id'])
694 694 if eid in self._ids:
695 695 self._ids.remove(eid)
696 696 uuid = self._engines.pop(eid)
697 697
698 698 self._handle_stranded_msgs(eid, uuid)
699 699
700 700 if self._task_socket and self._task_scheme == 'pure':
701 701 self._stop_scheduling_tasks()
702 702
703 703 def _handle_stranded_msgs(self, eid, uuid):
704 704 """Handle messages known to be on an engine when the engine unregisters.
705 705
706 706 It is possible that this will fire prematurely - that is, an engine will
707 707 go down after completing a result, and the client will be notified
708 708 of the unregistration and later receive the successful result.
709 709 """
710 710
711 711 outstanding = self._outstanding_dict[uuid]
712 712
713 713 for msg_id in list(outstanding):
714 714 if msg_id in self.results:
715 715                 # we already have this result; nothing to do
716 716 continue
717 717 try:
718 718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
719 719 except:
720 720 content = error.wrap_exception()
721 721 # build a fake message:
722 722 msg = self.session.msg('apply_reply', content=content)
723 723 msg['parent_header']['msg_id'] = msg_id
724 724 msg['metadata']['engine'] = uuid
725 725 self._handle_apply_reply(msg)
726 726
727 727 def _handle_execute_reply(self, msg):
728 728 """Save the reply to an execute_request into our results.
729 729
730 730 execute messages are never actually used. apply is used instead.
731 731 """
732 732
733 733 parent = msg['parent_header']
734 734 msg_id = parent['msg_id']
735 735 if msg_id not in self.outstanding:
736 736 if msg_id in self.history:
737 737                 print("got stale result: %s" % msg_id)
738 738             else:
739 739                 print("got unknown result: %s" % msg_id)
740 740 else:
741 741 self.outstanding.remove(msg_id)
742 742
743 743 content = msg['content']
744 744 header = msg['header']
745 745
746 746 # construct metadata:
747 747 md = self.metadata[msg_id]
748 748 md.update(self._extract_metadata(msg))
749 749 # is this redundant?
750 750 self.metadata[msg_id] = md
751 751
752 752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
753 753 if msg_id in e_outstanding:
754 754 e_outstanding.remove(msg_id)
755 755
756 756 # construct result:
757 757 if content['status'] == 'ok':
758 758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
759 759 elif content['status'] == 'aborted':
760 760 self.results[msg_id] = error.TaskAborted(msg_id)
761 761 elif content['status'] == 'resubmitted':
762 762 # TODO: handle resubmission
763 763 pass
764 764 else:
765 765 self.results[msg_id] = self._unwrap_exception(content)
766 766
767 767 def _handle_apply_reply(self, msg):
768 768 """Save the reply to an apply_request into our results."""
769 769 parent = msg['parent_header']
770 770 msg_id = parent['msg_id']
771 771 if msg_id not in self.outstanding:
772 772 if msg_id in self.history:
773 773                 print("got stale result: %s" % msg_id)
774 774                 print(self.results[msg_id])
775 775                 print(msg)
776 776             else:
777 777                 print("got unknown result: %s" % msg_id)
778 778 else:
779 779 self.outstanding.remove(msg_id)
780 780 content = msg['content']
781 781 header = msg['header']
782 782
783 783 # construct metadata:
784 784 md = self.metadata[msg_id]
785 785 md.update(self._extract_metadata(msg))
786 786 # is this redundant?
787 787 self.metadata[msg_id] = md
788 788
789 789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
790 790 if msg_id in e_outstanding:
791 791 e_outstanding.remove(msg_id)
792 792
793 793 # construct result:
794 794 if content['status'] == 'ok':
795 795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
796 796 elif content['status'] == 'aborted':
797 797 self.results[msg_id] = error.TaskAborted(msg_id)
798 798 elif content['status'] == 'resubmitted':
799 799 # TODO: handle resubmission
800 800 pass
801 801 else:
802 802 self.results[msg_id] = self._unwrap_exception(content)
803 803
804 804 def _flush_notifications(self):
805 805 """Flush notifications of engine registrations waiting
806 806 in ZMQ queue."""
807 807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 808 while msg is not None:
809 809 if self.debug:
810 810 pprint(msg)
811 811 msg_type = msg['header']['msg_type']
812 812 handler = self._notification_handlers.get(msg_type, None)
813 813 if handler is None:
814 814 raise Exception("Unhandled message type: %s" % msg_type)
815 815 else:
816 816 handler(msg)
817 817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
818 818
819 819 def _flush_results(self, sock):
820 820 """Flush task or queue results waiting in ZMQ queue."""
821 821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 822 while msg is not None:
823 823 if self.debug:
824 824 pprint(msg)
825 825 msg_type = msg['header']['msg_type']
826 826 handler = self._queue_handlers.get(msg_type, None)
827 827 if handler is None:
828 828 raise Exception("Unhandled message type: %s" % msg_type)
829 829 else:
830 830 handler(msg)
831 831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832 832
833 833 def _flush_control(self, sock):
834 834 """Flush replies from the control channel waiting
835 835 in the ZMQ queue.
836 836
837 837 Currently: ignore them."""
838 838 if self._ignored_control_replies <= 0:
839 839 return
840 840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 841 while msg is not None:
842 842 self._ignored_control_replies -= 1
843 843 if self.debug:
844 844 pprint(msg)
845 845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846 846
847 847 def _flush_ignored_control(self):
848 848 """flush ignored control replies"""
849 849 while self._ignored_control_replies > 0:
850 850 self.session.recv(self._control_socket)
851 851 self._ignored_control_replies -= 1
852 852
853 853 def _flush_ignored_hub_replies(self):
854 854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
855 855 while msg is not None:
856 856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
857 857
858 858 def _flush_iopub(self, sock):
859 859 """Flush replies from the iopub channel waiting
860 860 in the ZMQ queue.
861 861 """
862 862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
863 863 while msg is not None:
864 864 if self.debug:
865 865 pprint(msg)
866 866 parent = msg['parent_header']
867 867 # ignore IOPub messages with no parent.
868 868 # Caused by print statements or warnings from before the first execution.
869 869 if not parent:
870 870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 871 continue
872 872 msg_id = parent['msg_id']
873 873 content = msg['content']
874 874 header = msg['header']
875 875 msg_type = msg['header']['msg_type']
876 876
877 877 # init metadata:
878 878 md = self.metadata[msg_id]
879 879
880 880 if msg_type == 'stream':
881 881 name = content['name']
882 882 s = md[name] or ''
883 883 md[name] = s + content['data']
884 884 elif msg_type == 'pyerr':
885 885 md.update({'pyerr' : self._unwrap_exception(content)})
886 886 elif msg_type == 'pyin':
887 887 md.update({'pyin' : content['code']})
888 888 elif msg_type == 'display_data':
889 889 md['outputs'].append(content)
890 890 elif msg_type == 'pyout':
891 891 md['pyout'] = content
892 892 elif msg_type == 'data_message':
893 893 data, remainder = serialize.unserialize_object(msg['buffers'])
894 894 md['data'].update(data)
895 895 elif msg_type == 'status':
896 896 # idle message comes after all outputs
897 897 if content['execution_state'] == 'idle':
898 898 md['outputs_ready'] = True
899 899 else:
900 900                 # unhandled msg_type ('status' is handled above)
901 901 pass
902 902
903 903             # redundant?
904 904 self.metadata[msg_id] = md
905 905
906 906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
907 907
908 908 #--------------------------------------------------------------------------
909 909 # len, getitem
910 910 #--------------------------------------------------------------------------
911 911
912 912 def __len__(self):
913 913 """len(client) returns # of engines."""
914 914 return len(self.ids)
915 915
916 916 def __getitem__(self, key):
917 917 """index access returns DirectView multiplexer objects
918 918
919 919 Must be int, slice, or list/tuple/xrange of ints"""
920 920 if not isinstance(key, (int, slice, tuple, list, xrange)):
921 921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
922 922 else:
923 923 return self.direct_view(key)
924 924
925 925 #--------------------------------------------------------------------------
926 926 # Begin public methods
927 927 #--------------------------------------------------------------------------
928 928
929 929 @property
930 930 def ids(self):
931 931 """Always up-to-date ids property."""
932 932 self._flush_notifications()
933 933 # always copy:
934 934 return list(self._ids)
935 935
936 936 def activate(self, targets='all', suffix=''):
937 937 """Create a DirectView and register it with IPython magics
938 938
939 939 Defines the magics `%px, %autopx, %pxresult, %%px`
940 940
941 941 Parameters
942 942 ----------
943 943
944 944 targets: int, list of ints, or 'all'
945 945 The engines on which the view's magics will run
946 946 suffix: str [default: '']
947 947 The suffix, if any, for the magics. This allows you to have
948 948 multiple views associated with parallel magics at the same time.
949 949
950 950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
951 951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
952 952 on engine 0.
953 953 """
954 954 view = self.direct_view(targets)
955 955 view.block = True
956 956 view.activate(suffix)
957 957 return view
958 958
959 959 def close(self, linger=None):
960 960 """Close my zmq Sockets
961 961
962 962         If `linger` is given, set the zmq LINGER socket option,
963 963         which allows discarding of unsent messages after a timeout.
964 964 """
965 965 if self._closed:
966 966 return
967 967 self.stop_spin_thread()
968 968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
969 969 for name in snames:
970 970 socket = getattr(self, name)
971 971 if socket is not None and not socket.closed:
972 972 if linger is not None:
973 973 socket.close(linger=linger)
974 974 else:
975 975 socket.close()
976 976 self._closed = True
977 977
978 978 def _spin_every(self, interval=1):
979 979 """target func for use in spin_thread"""
980 980 while True:
981 981 if self._stop_spinning.is_set():
982 982 return
983 983 time.sleep(interval)
984 984 self.spin()
985 985
986 986 def spin_thread(self, interval=1):
987 987 """call Client.spin() in a background thread on some regular interval
988 988
989 989 This helps ensure that messages don't pile up too much in the zmq queue
990 990 while you are working on other things, or just leaving an idle terminal.
991 991
992 992 It also helps limit potential padding of the `received` timestamp
993 993 on AsyncResult objects, used for timings.
994 994
995 995 Parameters
996 996 ----------
997 997
998 998 interval : float, optional
999 999 The interval on which to spin the client in the background thread
1000 1000 (simply passed to time.sleep).
1001 1001
1002 1002 Notes
1003 1003 -----
1004 1004
1005 1005 For precision timing, you may want to use this method to put a bound
1006 1006 on the jitter (in seconds) in `received` timestamps used
1007 1007 in AsyncResult.wall_time.
1008 1008
1009 1009 """
1010 1010 if self._spin_thread is not None:
1011 1011 self.stop_spin_thread()
1012 1012 self._stop_spinning.clear()
1013 1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1014 1014 self._spin_thread.daemon = True
1015 1015 self._spin_thread.start()
1016 1016
1017 1017 def stop_spin_thread(self):
1018 1018 """stop background spin_thread, if any"""
1019 1019 if self._spin_thread is not None:
1020 1020 self._stop_spinning.set()
1021 1021 self._spin_thread.join()
1022 1022 self._spin_thread = None
1023 1023
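    # Usage sketch (editor's addition; `do_other_work` is hypothetical): keep
    # results and `received` timestamps fresh while busy elsewhere, given
    # rc = Client():
    #
    #   rc.spin_thread(interval=0.1)   # flush queues in the background, 10x/sec
    #   do_other_work()                # rc keeps spinning meanwhile
    #   rc.stop_spin_thread()          # stop before closing the client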
1024 1024 def spin(self):
1025 1025 """Flush any registration notifications and execution results
1026 1026 waiting in the ZMQ queue.
1027 1027 """
1028 1028 if self._notification_socket:
1029 1029 self._flush_notifications()
1030 1030 if self._iopub_socket:
1031 1031 self._flush_iopub(self._iopub_socket)
1032 1032 if self._mux_socket:
1033 1033 self._flush_results(self._mux_socket)
1034 1034 if self._task_socket:
1035 1035 self._flush_results(self._task_socket)
1036 1036 if self._control_socket:
1037 1037 self._flush_control(self._control_socket)
1038 1038 if self._query_socket:
1039 1039 self._flush_ignored_hub_replies()
1040 1040
1041 1041 def wait(self, jobs=None, timeout=-1):
1042 1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1043 1043
1044 1044 Parameters
1045 1045 ----------
1046 1046
1047 1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1048 1048 ints are indices to self.history
1049 1049 strs are msg_ids
1050 1050 default: wait on all outstanding messages
1051 1051 timeout : float
1052 1052 a time in seconds, after which to give up.
1053 1053 default is -1, which means no timeout
1054 1054
1055 1055 Returns
1056 1056 -------
1057 1057
1058 1058 True : when all msg_ids are done
1059 1059 False : timeout reached, some msg_ids still outstanding
1060 1060 """
1061 1061 tic = time.time()
1062 1062 if jobs is None:
1063 1063 theids = self.outstanding
1064 1064 else:
1065 1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1066 1066 jobs = [jobs]
1067 1067 theids = set()
1068 1068 for job in jobs:
1069 1069 if isinstance(job, int):
1070 1070 # index access
1071 1071 job = self.history[job]
1072 1072 elif isinstance(job, AsyncResult):
1073 map(theids.add, job.msg_ids)
1073 theids.update(job.msg_ids)
1074 1074 continue
1075 1075 theids.add(job)
1076 1076 if not theids.intersection(self.outstanding):
1077 1077 return True
1078 1078 self.spin()
1079 1079 while theids.intersection(self.outstanding):
1080 1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1081 1081 break
1082 1082 time.sleep(1e-3)
1083 1083 self.spin()
1084 1084 return len(theids.intersection(self.outstanding)) == 0
1085 1085
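    # Usage sketch (editor's addition), assuming rc = Client() with live engines:
    #
    #   ar = rc[:].apply_async(sum, [1, 2, 3])
    #   if rc.wait([ar], timeout=5):       # AsyncResults are accepted directly
    #       print(ar.get())                # [6, ...] one result per engine
    #   else:
    #       print("still outstanding:", rc.outstanding)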
1086 1086 #--------------------------------------------------------------------------
1087 1087 # Control methods
1088 1088 #--------------------------------------------------------------------------
1089 1089
1090 1090 @spin_first
1091 1091 def clear(self, targets=None, block=None):
1092 1092 """Clear the namespace in target(s)."""
1093 1093 block = self.block if block is None else block
1094 1094 targets = self._build_targets(targets)[0]
1095 1095 for t in targets:
1096 1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1097 1097 error = False
1098 1098 if block:
1099 1099 self._flush_ignored_control()
1100 1100 for i in range(len(targets)):
1101 1101 idents,msg = self.session.recv(self._control_socket,0)
1102 1102 if self.debug:
1103 1103 pprint(msg)
1104 1104 if msg['content']['status'] != 'ok':
1105 1105 error = self._unwrap_exception(msg['content'])
1106 1106 else:
1107 1107 self._ignored_control_replies += len(targets)
1108 1108 if error:
1109 1109 raise error
1110 1110
1111 1111
1112 1112 @spin_first
1113 1113 def abort(self, jobs=None, targets=None, block=None):
1114 1114 """Abort specific jobs from the execution queues of target(s).
1115 1115
1116 1116 This is a mechanism to prevent jobs that have already been submitted
1117 1117 from executing.
1118 1118
1119 1119 Parameters
1120 1120 ----------
1121 1121
1122 1122 jobs : msg_id, list of msg_ids, or AsyncResult
1123 1123 The jobs to be aborted
1124 1124
1125 1125 If unspecified/None: abort all outstanding jobs.
1126 1126
1127 1127 """
1128 1128 block = self.block if block is None else block
1129 1129 jobs = jobs if jobs is not None else list(self.outstanding)
1130 1130 targets = self._build_targets(targets)[0]
1131 1131
1132 1132 msg_ids = []
1133 1133 if isinstance(jobs, string_types + (AsyncResult,)):
1134 1134 jobs = [jobs]
1135 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1135 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1136 1136 if bad_ids:
1137 1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1138 1138 for j in jobs:
1139 1139 if isinstance(j, AsyncResult):
1140 1140 msg_ids.extend(j.msg_ids)
1141 1141 else:
1142 1142 msg_ids.append(j)
1143 1143 content = dict(msg_ids=msg_ids)
1144 1144 for t in targets:
1145 1145 self.session.send(self._control_socket, 'abort_request',
1146 1146 content=content, ident=t)
1147 1147 error = False
1148 1148 if block:
1149 1149 self._flush_ignored_control()
1150 1150 for i in range(len(targets)):
1151 1151 idents,msg = self.session.recv(self._control_socket,0)
1152 1152 if self.debug:
1153 1153 pprint(msg)
1154 1154 if msg['content']['status'] != 'ok':
1155 1155 error = self._unwrap_exception(msg['content'])
1156 1156 else:
1157 1157 self._ignored_control_replies += len(targets)
1158 1158 if error:
1159 1159 raise error
1160 1160
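    # Usage sketch (editor's addition): abort work that is queued but not yet
    # running. Assumes rc = Client() and `import time` has been done:
    #
    #   ar = rc[:].apply_async(time.sleep, 60)   # fills the engine queues
    #   rc.abort(ar)                             # abort by AsyncResult
    #   rc.abort()                               # or abort everything outstanding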
1161 1161 @spin_first
1162 1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1163 1163 """Terminates one or more engine processes, optionally including the hub.
1164 1164
1165 1165 Parameters
1166 1166 ----------
1167 1167
1168 1168 targets: list of ints or 'all' [default: all]
1169 1169 Which engines to shutdown.
1170 1170 hub: bool [default: False]
1171 1171 Whether to include the Hub. hub=True implies targets='all'.
1172 1172 block: bool [default: self.block]
1173 1173 Whether to wait for clean shutdown replies or not.
1174 1174 restart: bool [default: False]
1175 1175 NOT IMPLEMENTED
1176 1176 whether to restart engines after shutting them down.
1177 1177 """
1178 1178 from IPython.parallel.error import NoEnginesRegistered
1179 1179 if restart:
1180 1180 raise NotImplementedError("Engine restart is not yet implemented")
1181 1181
1182 1182 block = self.block if block is None else block
1183 1183 if hub:
1184 1184 targets = 'all'
1185 1185 try:
1186 1186 targets = self._build_targets(targets)[0]
1187 1187 except NoEnginesRegistered:
1188 1188 targets = []
1189 1189 for t in targets:
1190 1190 self.session.send(self._control_socket, 'shutdown_request',
1191 1191 content={'restart':restart},ident=t)
1192 1192 error = False
1193 1193 if block or hub:
1194 1194 self._flush_ignored_control()
1195 1195 for i in range(len(targets)):
1196 1196 idents,msg = self.session.recv(self._control_socket, 0)
1197 1197 if self.debug:
1198 1198 pprint(msg)
1199 1199 if msg['content']['status'] != 'ok':
1200 1200 error = self._unwrap_exception(msg['content'])
1201 1201 else:
1202 1202 self._ignored_control_replies += len(targets)
1203 1203
1204 1204 if hub:
1205 1205 time.sleep(0.25)
1206 1206 self.session.send(self._query_socket, 'shutdown_request')
1207 1207 idents,msg = self.session.recv(self._query_socket, 0)
1208 1208 if self.debug:
1209 1209 pprint(msg)
1210 1210 if msg['content']['status'] != 'ok':
1211 1211 error = self._unwrap_exception(msg['content'])
1212 1212
1213 1213 if error:
1214 1214 raise error
1215 1215
1216 1216 #--------------------------------------------------------------------------
1217 1217 # Execution related methods
1218 1218 #--------------------------------------------------------------------------
1219 1219
1220 1220 def _maybe_raise(self, result):
1221 1221 """wrapper for maybe raising an exception if apply failed."""
1222 1222 if isinstance(result, error.RemoteError):
1223 1223 raise result
1224 1224
1225 1225 return result
1226 1226
1227 1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1228 1228 ident=None):
1229 1229 """construct and send an apply message via a socket.
1230 1230
1231 1231 This is the principal method with which all engine execution is performed by views.
1232 1232 """
1233 1233
1234 1234 if self._closed:
1235 1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1236 1236
1237 1237 # defaults:
1238 1238 args = args if args is not None else []
1239 1239 kwargs = kwargs if kwargs is not None else {}
1240 1240 metadata = metadata if metadata is not None else {}
1241 1241
1242 1242 # validate arguments
1243 1243 if not callable(f) and not isinstance(f, Reference):
1244 1244 raise TypeError("f must be callable, not %s"%type(f))
1245 1245 if not isinstance(args, (tuple, list)):
1246 1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1247 1247 if not isinstance(kwargs, dict):
1248 1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1249 1249 if not isinstance(metadata, dict):
1250 1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1251 1251
1252 1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1253 1253 buffer_threshold=self.session.buffer_threshold,
1254 1254 item_threshold=self.session.item_threshold,
1255 1255 )
1256 1256
1257 1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1258 1258 metadata=metadata, track=track)
1259 1259
1260 1260 msg_id = msg['header']['msg_id']
1261 1261 self.outstanding.add(msg_id)
1262 1262 if ident:
1263 1263 # possibly routed to a specific engine
1264 1264 if isinstance(ident, list):
1265 1265 ident = ident[-1]
1266 1266 if ident in self._engines.values():
1267 1267 # save for later, in case of engine death
1268 1268 self._outstanding_dict[ident].add(msg_id)
1269 1269 self.history.append(msg_id)
1270 1270 self.metadata[msg_id]['submitted'] = datetime.now()
1271 1271
1272 1272 return msg
1273 1273
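    # Low-level sketch (editor's addition): Views drive execution through this
    # method roughly as follows; calling it directly like this is illustrative
    # only, and the routing details are an assumption:
    #
    #   uuids, ids = rc._build_targets(0)               # route to engine 0
    #   msg = rc.send_apply_request(rc._mux_socket, pow, args=(2, 10),
    #                               ident=uuids[0])
    #   rc.wait([msg['header']['msg_id']])
    #   print(rc.results[msg['header']['msg_id']])      # 1024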
1274 1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1275 1275 """construct and send an execute request via a socket.
1276 1276
1277 1277 """
1278 1278
1279 1279 if self._closed:
1280 1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1281 1281
1282 1282 # defaults:
1283 1283 metadata = metadata if metadata is not None else {}
1284 1284
1285 1285 # validate arguments
1286 1286 if not isinstance(code, string_types):
1287 1287 raise TypeError("code must be text, not %s" % type(code))
1288 1288 if not isinstance(metadata, dict):
1289 1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1290 1290
1291 1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1292 1292
1293 1293
1294 1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1295 1295 metadata=metadata)
1296 1296
1297 1297 msg_id = msg['header']['msg_id']
1298 1298 self.outstanding.add(msg_id)
1299 1299 if ident:
1300 1300 # possibly routed to a specific engine
1301 1301 if isinstance(ident, list):
1302 1302 ident = ident[-1]
1303 1303 if ident in self._engines.values():
1304 1304 # save for later, in case of engine death
1305 1305 self._outstanding_dict[ident].add(msg_id)
1306 1306 self.history.append(msg_id)
1307 1307 self.metadata[msg_id]['submitted'] = datetime.now()
1308 1308
1309 1309 return msg
1310 1310
1311 1311 #--------------------------------------------------------------------------
1312 1312 # construct a View object
1313 1313 #--------------------------------------------------------------------------
1314 1314
1315 1315 def load_balanced_view(self, targets=None):
1316 1316         """construct a LoadBalancedView object.
1317 1317
1318 1318 If no arguments are specified, create a LoadBalancedView
1319 1319 using all engines.
1320 1320
1321 1321 Parameters
1322 1322 ----------
1323 1323
1324 1324 targets: list,slice,int,etc. [default: use all engines]
1325 1325 The subset of engines across which to load-balance
1326 1326 """
1327 1327 if targets == 'all':
1328 1328 targets = None
1329 1329 if targets is not None:
1330 1330 targets = self._build_targets(targets)[1]
1331 1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1332 1332
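    # Usage sketch (editor's addition), assuming rc = Client():
    #
    #   lview = rc.load_balanced_view()          # balance across all engines
    #   sub   = rc.load_balanced_view([0, 2])    # or restrict to a subset
    #   ar = lview.map_async(str, range(10))     # tasks land wherever there is room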
1333 1333 def direct_view(self, targets='all'):
1334 1334 """construct a DirectView object.
1335 1335
1336 1336 If no targets are specified, create a DirectView using all engines.
1337 1337
1338 1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1339 1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1340 1340 all *current* engines, and that list will not change.
1341 1341
1342 1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1343 1343 engines added after the DirectView is constructed.
1344 1344
1345 1345 Parameters
1346 1346 ----------
1347 1347
1348 1348 targets: list,slice,int,etc. [default: use all engines]
1349 1349 The engines to use for the View
1350 1350 """
1351 1351 single = isinstance(targets, int)
1352 1352 # allow 'all' to be lazily evaluated at each execution
1353 1353 if targets != 'all':
1354 1354 targets = self._build_targets(targets)[1]
1355 1355 if single:
1356 1356 targets = targets[0]
1357 1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1358 1358
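    # Usage sketch (editor's addition), assuming rc = Client():
    #
    #   dv_all = rc.direct_view()     # lazy 'all': sees engines that join later
    #   dv_now = rc[:]                # fixed: only the engines registered now
    #   dv_0   = rc.direct_view(0)    # single-engine view, targets=0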
1359 1359 #--------------------------------------------------------------------------
1360 1360 # Query methods
1361 1361 #--------------------------------------------------------------------------
1362 1362
1363 1363 @spin_first
1364 1364 def get_result(self, indices_or_msg_ids=None, block=None):
1365 1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1366 1366
1367 1367 If the client already has the results, no request to the Hub will be made.
1368 1368
1369 1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1370 1370 that include metadata about execution, and allow for awaiting results that
1371 1371 were not submitted by this Client.
1372 1372
1373 1373 It can also be a convenient way to retrieve the metadata associated with
1374 1374         blocking execution, since it always retrieves the metadata of each result.
1375 1375
1376 1376 Examples
1377 1377 --------
1378 1378 ::
1379 1379
1380 1380             In [10]: ar = client.get_result(msg_id)  # msg_id from an earlier submission
1381 1381
1382 1382 Parameters
1383 1383 ----------
1384 1384
1385 1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1386 1386 The indices or msg_ids of indices to be retrieved
1387 1387
1388 1388 block : bool
1389 1389 Whether to wait for the result to be done
1390 1390
1391 1391 Returns
1392 1392 -------
1393 1393
1394 1394 AsyncResult
1395 1395 A single AsyncResult object will always be returned.
1396 1396
1397 1397 AsyncHubResult
1398 1398 A subclass of AsyncResult that retrieves results from the Hub
1399 1399
1400 1400 """
1401 1401 block = self.block if block is None else block
1402 1402 if indices_or_msg_ids is None:
1403 1403 indices_or_msg_ids = -1
1404 1404
1405 1405 single_result = False
1406 1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1407 1407 indices_or_msg_ids = [indices_or_msg_ids]
1408 1408 single_result = True
1409 1409
1410 1410 theids = []
1411 1411 for id in indices_or_msg_ids:
1412 1412 if isinstance(id, int):
1413 1413 id = self.history[id]
1414 1414 if not isinstance(id, string_types):
1415 1415 raise TypeError("indices must be str or int, not %r"%id)
1416 1416 theids.append(id)
1417 1417
1418 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1419 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1418 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1419 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1420 1420
1421 1421         # given a single msg_id initially, get_result should get the result itself,
1422 1422 # not a length-one list
1423 1423 if single_result:
1424 1424 theids = theids[0]
1425 1425
1426 1426 if remote_ids:
1427 1427 ar = AsyncHubResult(self, msg_ids=theids)
1428 1428 else:
1429 1429 ar = AsyncResult(self, msg_ids=theids)
1430 1430
1431 1431 if block:
1432 1432 ar.wait()
1433 1433
1434 1434 return ar
1435 1435
1436 1436 @spin_first
1437 1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1438 1438 """Resubmit one or more tasks.
1439 1439
1440 1440 in-flight tasks may not be resubmitted.
1441 1441
1442 1442 Parameters
1443 1443 ----------
1444 1444
1445 1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1446 1446 The indices or msg_ids of indices to be retrieved
1447 1447
1448 1448 block : bool
1449 1449 Whether to wait for the result to be done
1450 1450
1451 1451 Returns
1452 1452 -------
1453 1453
1454 1454 AsyncHubResult
1455 1455 A subclass of AsyncResult that retrieves results from the Hub
1456 1456
1457 1457 """
1458 1458 block = self.block if block is None else block
1459 1459 if indices_or_msg_ids is None:
1460 1460 indices_or_msg_ids = -1
1461 1461
1462 1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1463 1463 indices_or_msg_ids = [indices_or_msg_ids]
1464 1464
1465 1465 theids = []
1466 1466 for id in indices_or_msg_ids:
1467 1467 if isinstance(id, int):
1468 1468 id = self.history[id]
1469 1469 if not isinstance(id, string_types):
1470 1470 raise TypeError("indices must be str or int, not %r"%id)
1471 1471 theids.append(id)
1472 1472
1473 1473 content = dict(msg_ids = theids)
1474 1474
1475 1475 self.session.send(self._query_socket, 'resubmit_request', content)
1476 1476
1477 1477 zmq.select([self._query_socket], [], [])
1478 1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1479 1479 if self.debug:
1480 1480 pprint(msg)
1481 1481 content = msg['content']
1482 1482 if content['status'] != 'ok':
1483 1483 raise self._unwrap_exception(content)
1484 1484 mapping = content['resubmitted']
1485 1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1486 1486
1487 1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1488 1488
1489 1489 if block:
1490 1490 ar.wait()
1491 1491
1492 1492 return ar
1493 1493
1494 1494 @spin_first
1495 1495 def result_status(self, msg_ids, status_only=True):
1496 1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1497 1497
1498 1498 If status_only is False, then the actual results will be retrieved, else
1499 1499 only the status of the results will be checked.
1500 1500
1501 1501 Parameters
1502 1502 ----------
1503 1503
1504 1504 msg_ids : list of msg_ids
1505 1505 if int:
1506 1506 Passed as index to self.history for convenience.
1507 1507 status_only : bool (default: True)
1508 1508 if False:
1509 1509 Retrieve the actual results of completed tasks.
1510 1510
1511 1511 Returns
1512 1512 -------
1513 1513
1514 1514 results : dict
1515 1515 There will always be the keys 'pending' and 'completed', which will
1516 1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1517 1517 is False, then completed results will be keyed by their `msg_id`.
1518 1518 """
1519 1519 if not isinstance(msg_ids, (list,tuple)):
1520 1520 msg_ids = [msg_ids]
1521 1521
1522 1522 theids = []
1523 1523 for msg_id in msg_ids:
1524 1524 if isinstance(msg_id, int):
1525 1525 msg_id = self.history[msg_id]
1526 1526 if not isinstance(msg_id, string_types):
1527 1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1528 1528 theids.append(msg_id)
1529 1529
1530 1530 completed = []
1531 1531 local_results = {}
1532 1532
1533 1533 # comment this block out to temporarily disable local shortcut:
1534 1534 for msg_id in theids:
1535 1535 if msg_id in self.results:
1536 1536 completed.append(msg_id)
1537 1537 local_results[msg_id] = self.results[msg_id]
1538 1538 theids.remove(msg_id)
1539 1539
1540 1540 if theids: # some not locally cached
1541 1541 content = dict(msg_ids=theids, status_only=status_only)
1542 1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1543 1543 zmq.select([self._query_socket], [], [])
1544 1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1545 1545 if self.debug:
1546 1546 pprint(msg)
1547 1547 content = msg['content']
1548 1548 if content['status'] != 'ok':
1549 1549 raise self._unwrap_exception(content)
1550 1550 buffers = msg['buffers']
1551 1551 else:
1552 1552 content = dict(completed=[],pending=[])
1553 1553
1554 1554 content['completed'].extend(completed)
1555 1555
1556 1556 if status_only:
1557 1557 return content
1558 1558
1559 1559 failures = []
1560 1560 # load cached results into result:
1561 1561 content.update(local_results)
1562 1562
1563 1563 # update cache with results:
1564 1564 for msg_id in sorted(theids):
1565 1565 if msg_id in content['completed']:
1566 1566 rec = content[msg_id]
1567 1567 parent = rec['header']
1568 1568 header = rec['result_header']
1569 1569 rcontent = rec['result_content']
1570 1570 iodict = rec['io']
1571 1571 if isinstance(rcontent, str):
1572 1572 rcontent = self.session.unpack(rcontent)
1573 1573
1574 1574 md = self.metadata[msg_id]
1575 1575 md_msg = dict(
1576 1576 content=rcontent,
1577 1577 parent_header=parent,
1578 1578 header=header,
1579 1579 metadata=rec['result_metadata'],
1580 1580 )
1581 1581 md.update(self._extract_metadata(md_msg))
1582 1582 if rec.get('received'):
1583 1583 md['received'] = rec['received']
1584 1584 md.update(iodict)
1585 1585
1586 1586 if rcontent['status'] == 'ok':
1587 1587 if header['msg_type'] == 'apply_reply':
1588 1588 res,buffers = serialize.unserialize_object(buffers)
1589 1589 elif header['msg_type'] == 'execute_reply':
1590 1590 res = ExecuteReply(msg_id, rcontent, md)
1591 1591 else:
1592 1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1593 1593 else:
1594 1594 res = self._unwrap_exception(rcontent)
1595 1595 failures.append(res)
1596 1596
1597 1597 self.results[msg_id] = res
1598 1598 content[msg_id] = res
1599 1599
1600 1600 if len(theids) == 1 and failures:
1601 1601 raise failures[0]
1602 1602
1603 1603 error.collect_exceptions(failures, "result_status")
1604 1604 return content
1605 1605
1606 1606 @spin_first
1607 1607 def queue_status(self, targets='all', verbose=False):
1608 1608 """Fetch the status of engine queues.
1609 1609
1610 1610 Parameters
1611 1611 ----------
1612 1612
1613 1613 targets : int/str/list of ints/strs
1614 1614 the engines whose states are to be queried.
1615 1615 default : all
1616 1616 verbose : bool
1617 1617 Whether to return lengths only, or lists of ids for each element
1618 1618 """
1619 1619 if targets == 'all':
1620 1620 # allow 'all' to be evaluated on the engine
1621 1621 engine_ids = None
1622 1622 else:
1623 1623 engine_ids = self._build_targets(targets)[1]
1624 1624 content = dict(targets=engine_ids, verbose=verbose)
1625 1625 self.session.send(self._query_socket, "queue_request", content=content)
1626 1626 idents,msg = self.session.recv(self._query_socket, 0)
1627 1627 if self.debug:
1628 1628 pprint(msg)
1629 1629 content = msg['content']
1630 1630 status = content.pop('status')
1631 1631 if status != 'ok':
1632 1632 raise self._unwrap_exception(content)
1633 1633 content = rekey(content)
1634 1634 if isinstance(targets, int):
1635 1635 return content[targets]
1636 1636 else:
1637 1637 return content
1638 1638
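    # Usage sketch (editor's addition), assuming rc = Client(); the exact keys
    # come from the Hub, roughly {'queue': n, 'completed': n, 'tasks': n}:
    #
    #   rc.queue_status()              # dict keyed by engine id (plus 'unassigned')
    #   rc.queue_status(0)             # just engine 0's counts
    #   rc.queue_status(verbose=True)  # lists of msg_ids instead of counts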
1639 1639 def _build_msgids_from_target(self, targets=None):
1640 1640 """Build a list of msg_ids from the list of engine targets"""
1641 1641 if not targets: # needed as _build_targets otherwise uses all engines
1642 1642 return []
1643 1643 target_ids = self._build_targets(targets)[0]
1644 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1644 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1645 1645
1646 1646 def _build_msgids_from_jobs(self, jobs=None):
1647 1647 """Build a list of msg_ids from "jobs" """
1648 1648 if not jobs:
1649 1649 return []
1650 1650 msg_ids = []
1651 1651 if isinstance(jobs, string_types + (AsyncResult,)):
1652 1652 jobs = [jobs]
1653 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult)), jobs)
1653 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1654 1654 if bad_ids:
1655 1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1656 1656 for j in jobs:
1657 1657 if isinstance(j, AsyncResult):
1658 1658 msg_ids.extend(j.msg_ids)
1659 1659 else:
1660 1660 msg_ids.append(j)
1661 1661 return msg_ids
1662 1662
1663 1663 def purge_local_results(self, jobs=[], targets=[]):
1664 1664         """Clears the client caches of results, freeing the associated memory.
1665 1665
1666 1666 Individual results can be purged by msg_id, or the entire
1667 1667 history of specific targets can be purged.
1668 1668
1669 1669         Use `purge_local_results('all')` to scrub everything from the Client's db.
1670 1670
1671 1671 The client must have no outstanding tasks before purging the caches.
1672 1672 Raises `AssertionError` if there are still outstanding tasks.
1673 1673
1674 1674 After this call all `AsyncResults` are invalid and should be discarded.
1675 1675
1676 1676 If you must "reget" the results, you can still do so by using
1677 1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 1678 redownload the results from the hub if they are still available
1679 1679         (i.e. `client.purge_hub_results(...)` has not been called).
1680 1680
1681 1681 Parameters
1682 1682 ----------
1683 1683
1684 1684 jobs : str or list of str or AsyncResult objects
1685 1685 the msg_ids whose results should be purged.
1686 1686 targets : int/str/list of ints/strs
1687 1687 The targets, by int_id, whose entire results are to be purged.
1688 1688
1689 1689 default : None
1690 1690 """
1691 1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1692 1692
1693 1693 if not targets and not jobs:
1694 1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1695 1695
1696 1696 if jobs == 'all':
1697 1697 self.results.clear()
1698 1698 self.metadata.clear()
1699 1699 return
1700 1700 else:
1701 1701 msg_ids = []
1702 1702 msg_ids.extend(self._build_msgids_from_target(targets))
1703 1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1704 map(self.results.pop, msg_ids)
1705 map(self.metadata.pop, msg_ids)
1704 for mid in msg_ids:
1705 self.results.pop(mid)
1706 self.metadata.pop(mid)
1706 1707
1707 1708
1708 1709 @spin_first
1709 1710 def purge_hub_results(self, jobs=[], targets=[]):
1710 1711 """Tell the Hub to forget results.
1711 1712
1712 1713 Individual results can be purged by msg_id, or the entire
1713 1714 history of specific targets can be purged.
1714 1715
1715 1716 Use `purge_results('all')` to scrub everything from the Hub's db.
1716 1717
1717 1718 Parameters
1718 1719 ----------
1719 1720
1720 1721 jobs : str or list of str or AsyncResult objects
1721 1722 the msg_ids whose results should be forgotten.
1722 1723 targets : int/str/list of ints/strs
1723 1724 The targets, by int_id, whose entire history is to be purged.
1724 1725
1725 1726 default : None
1726 1727 """
1727 1728 if not targets and not jobs:
1728 1729 raise ValueError("Must specify at least one of `targets` and `jobs`")
1729 1730 if targets:
1730 1731 targets = self._build_targets(targets)[1]
1731 1732
1732 1733 # construct msg_ids from jobs
1733 1734 if jobs == 'all':
1734 1735 msg_ids = jobs
1735 1736 else:
1736 1737 msg_ids = self._build_msgids_from_jobs(jobs)
1737 1738
1738 1739 content = dict(engine_ids=targets, msg_ids=msg_ids)
1739 1740 self.session.send(self._query_socket, "purge_request", content=content)
1740 1741 idents, msg = self.session.recv(self._query_socket, 0)
1741 1742 if self.debug:
1742 1743 pprint(msg)
1743 1744 content = msg['content']
1744 1745 if content['status'] != 'ok':
1745 1746 raise self._unwrap_exception(content)
1746 1747
1747 1748 def purge_results(self, jobs=[], targets=[]):
1748 1749 """Clears the cached results from both the hub and the local client
1749 1750
1750 1751 Individual results can be purged by msg_id, or the entire
1751 1752 history of specific targets can be purged.
1752 1753
1753 1754 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1754 1755 the Client's db.
1755 1756
1756 1757 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1757 1758 the same arguments.
1758 1759
1759 1760 Parameters
1760 1761 ----------
1761 1762
1762 1763 jobs : str or list of str or AsyncResult objects
1763 1764 the msg_ids whose results should be forgotten.
1764 1765 targets : int/str/list of ints/strs
1765 1766 The targets, by int_id, whose entire history is to be purged.
1766 1767
1767 1768 default : None
1768 1769 """
1769 1770 self.purge_local_results(jobs=jobs, targets=targets)
1770 1771 self.purge_hub_results(jobs=jobs, targets=targets)
1771 1772
1772 1773 def purge_everything(self):
1773 1774 """Clears all content from previous Tasks from both the hub and the local client
1774 1775
1775 1776 In addition to calling `purge_results("all")` it also deletes the history and
1776 1777 other bookkeeping lists.
1777 1778 """
1778 1779 self.purge_results("all")
1779 1780 self.history = []
1780 1781 self.session.digest_history.clear()
1781 1782
1782 1783 @spin_first
1783 1784 def hub_history(self):
1784 1785 """Get the Hub's history
1785 1786
1786 1787 Just like the Client, the Hub has a history, which is a list of msg_ids.
1787 1788 This will contain the history of all clients, and, depending on configuration,
1788 1789 may contain history across multiple cluster sessions.
1789 1790
1790 1791 Any msg_id returned here is a valid argument to `get_result`.
1791 1792
1792 1793 Returns
1793 1794 -------
1794 1795
1795 1796 msg_ids : list of strs
1796 1797 list of all msg_ids, ordered by task submission time.
1797 1798 """
1798 1799
1799 1800 self.session.send(self._query_socket, "history_request", content={})
1800 1801 idents, msg = self.session.recv(self._query_socket, 0)
1801 1802
1802 1803 if self.debug:
1803 1804 pprint(msg)
1804 1805 content = msg['content']
1805 1806 if content['status'] != 'ok':
1806 1807 raise self._unwrap_exception(content)
1807 1808 else:
1808 1809 return content['history']
1809 1810
1810 1811 @spin_first
1811 1812 def db_query(self, query, keys=None):
1812 1813 """Query the Hub's TaskRecord database
1813 1814
1814 1815 This will return a list of task record dicts that match `query`
1815 1816
1816 1817 Parameters
1817 1818 ----------
1818 1819
1819 1820 query : mongodb query dict
1820 1821 The search dict. See mongodb query docs for details.
1821 1822 keys : list of strs [optional]
1822 1823 The subset of keys to be returned. The default is to fetch everything but buffers.
1823 1824 'msg_id' will *always* be included.
1824 1825 """
1825 1826 if isinstance(keys, string_types):
1826 1827 keys = [keys]
1827 1828 content = dict(query=query, keys=keys)
1828 1829 self.session.send(self._query_socket, "db_request", content=content)
1829 1830 idents, msg = self.session.recv(self._query_socket, 0)
1830 1831 if self.debug:
1831 1832 pprint(msg)
1832 1833 content = msg['content']
1833 1834 if content['status'] != 'ok':
1834 1835 raise self._unwrap_exception(content)
1835 1836
1836 1837 records = content['records']
1837 1838
1838 1839 buffer_lens = content['buffer_lens']
1839 1840 result_buffer_lens = content['result_buffer_lens']
1840 1841 buffers = msg['buffers']
1841 1842 has_bufs = buffer_lens is not None
1842 1843 has_rbufs = result_buffer_lens is not None
1843 1844 for i,rec in enumerate(records):
1844 1845 # relink buffers
1845 1846 if has_bufs:
1846 1847 blen = buffer_lens[i]
1847 1848 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1848 1849 if has_rbufs:
1849 1850 blen = result_buffer_lens[i]
1850 1851 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1851 1852
1852 1853 return records
1853 1854
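    # Usage sketch (editor's addition): queries use mongodb-style operator
    # dicts whatever the actual db backend. Assumes rc = Client() with history:
    #
    #   rc.db_query({'msg_id': {'$in': rc.history[-10:]}},
    #               keys=['msg_id', 'completed', 'engine_uuid'])
    #   rc.db_query({'completed': {'$ne': None}})    # every finished task record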
1854 1855 __all__ = [ 'Client' ]
@@ -1,1118 +1,1119 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import imp
20 20 import sys
21 21 import warnings
22 22 from contextlib import contextmanager
23 23 from types import ModuleType
24 24
25 25 import zmq
26 26
27 27 from IPython.testing.skipdoctest import skip_doctest
28 28 from IPython.utils.traitlets import (
29 29 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
30 30 )
31 31 from IPython.external.decorator import decorator
32 32
33 33 from IPython.parallel import util
34 34 from IPython.parallel.controller.dependency import Dependency, dependent
35 from IPython.utils.py3compat import string_types, iteritems
35 from IPython.utils.py3compat import string_types, iteritems, PY3
36 36
37 37 from . import map as Map
38 38 from .asyncresult import AsyncResult, AsyncMapResult
39 39 from .remotefunction import ParallelFunction, parallel, remote, getname
40 40
41 41 #-----------------------------------------------------------------------------
42 42 # Decorators
43 43 #-----------------------------------------------------------------------------
44 44
45 45 @decorator
46 46 def save_ids(f, self, *args, **kwargs):
47 47 """Keep our history and outstanding attributes up to date after a method call."""
48 48 n_previous = len(self.client.history)
49 49 try:
50 50 ret = f(self, *args, **kwargs)
51 51 finally:
52 52 nmsgs = len(self.client.history) - n_previous
53 53 msg_ids = self.client.history[-nmsgs:]
54 54 self.history.extend(msg_ids)
55 map(self.outstanding.add, msg_ids)
55 self.outstanding.update(msg_ids)
56 56 return ret
57 57
58 58 @decorator
59 59 def sync_results(f, self, *args, **kwargs):
60 60 """sync relevant results from self.client to our results attribute."""
61 61 if self._in_sync_results:
62 62 return f(self, *args, **kwargs)
63 63 self._in_sync_results = True
64 64 try:
65 65 ret = f(self, *args, **kwargs)
66 66 finally:
67 67 self._in_sync_results = False
68 68 self._sync_results()
69 69 return ret
70 70
71 71 @decorator
72 72 def spin_after(f, self, *args, **kwargs):
73 73 """call spin after the method."""
74 74 ret = f(self, *args, **kwargs)
75 75 self.spin()
76 76 return ret
77 77
78 78 #-----------------------------------------------------------------------------
79 79 # Classes
80 80 #-----------------------------------------------------------------------------
81 81
82 82 @skip_doctest
83 83 class View(HasTraits):
84 84 """Base View class for more convenient apply(f,*args,**kwargs) syntax via attributes.
85 85
86 86 Don't use this class, use subclasses.
87 87
88 88 Methods
89 89 -------
90 90
91 91 spin
92 92 flushes incoming results and registration state changes
93 93 control methods spin, and requesting `ids` also ensures up to date
94 94
95 95 wait
96 96 wait on one or more msg_ids
97 97
98 98 execution methods
99 99 apply
100 100 legacy: execute, run
101 101
102 102 data movement
103 103 push, pull, scatter, gather
104 104
105 105 query methods
106 106 get_result, queue_status, purge_results, result_status
107 107
108 108 control methods
109 109 abort, shutdown
110 110
111 111 """
112 112 # flags
113 113 block=Bool(False)
114 114 track=Bool(True)
115 115 targets = Any()
116 116
117 117 history=List()
118 118 outstanding = Set()
119 119 results = Dict()
120 120 client = Instance('IPython.parallel.Client')
121 121
122 122 _socket = Instance('zmq.Socket')
123 123 _flag_names = List(['targets', 'block', 'track'])
124 124 _in_sync_results = Bool(False)
125 125 _targets = Any()
126 126 _idents = Any()
127 127
128 128 def __init__(self, client=None, socket=None, **flags):
129 129 super(View, self).__init__(client=client, _socket=socket)
130 130 self.results = client.results
131 131 self.block = client.block
132 132
133 133 self.set_flags(**flags)
134 134
135 135 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
136 136
137 137 def __repr__(self):
138 138 strtargets = str(self.targets)
139 139 if len(strtargets) > 16:
140 140 strtargets = strtargets[:12]+'...]'
141 141 return "<%s %s>"%(self.__class__.__name__, strtargets)
142 142
143 143 def __len__(self):
144 144 if isinstance(self.targets, list):
145 145 return len(self.targets)
146 146 elif isinstance(self.targets, int):
147 147 return 1
148 148 else:
149 149 return len(self.client)
150 150
151 151 def set_flags(self, **kwargs):
152 152 """set my attribute flags by keyword.
153 153
154 154 Views determine behavior with a few attributes (`block`, `track`, etc.).
155 155 These attributes can be set all at once by name with this method.
156 156
157 157 Parameters
158 158 ----------
159 159
160 160 block : bool
161 161 whether to wait for results
162 162 track : bool
163 163 whether to create a MessageTracker to allow the user to
164 164 know when it is safe to edit arrays and buffers again
165 165 after non-copying sends.
166 166 """
167 167 for name, value in iteritems(kwargs):
168 168 if name not in self._flag_names:
169 169 raise KeyError("Invalid name: %r"%name)
170 170 else:
171 171 setattr(self, name, value)
172 172
173 173 @contextmanager
174 174 def temp_flags(self, **kwargs):
175 175 """temporarily set flags, for use in `with` statements.
176 176
177 177 See set_flags for permanent setting of flags
178 178
179 179 Examples
180 180 --------
181 181
182 182 >>> view.track=False
183 183 ...
184 184 >>> with view.temp_flags(track=True):
185 185 ... ar = view.apply(dostuff, my_big_array)
186 186 ... ar.tracker.wait() # wait for send to finish
187 187 >>> view.track
188 188 False
189 189
190 190 """
191 191 # preflight: save flags, and set temporaries
192 192 saved_flags = {}
193 193 for f in self._flag_names:
194 194 saved_flags[f] = getattr(self, f)
195 195 self.set_flags(**kwargs)
196 196 # yield to the with-statement block
197 197 try:
198 198 yield
199 199 finally:
200 200 # postflight: restore saved flags
201 201 self.set_flags(**saved_flags)
202 202
203 203
204 204 #----------------------------------------------------------------
205 205 # apply
206 206 #----------------------------------------------------------------
207 207
208 208 def _sync_results(self):
209 209 """to be called by @sync_results decorator
210 210
211 211 after submitting any tasks.
212 212 """
213 213 delta = self.outstanding.difference(self.client.outstanding)
214 214 completed = self.outstanding.intersection(delta)
215 215 self.outstanding = self.outstanding.difference(completed)
216 216
217 217 @sync_results
218 218 @save_ids
219 219 def _really_apply(self, f, args, kwargs, block=None, **options):
220 220 """wrapper for client.send_apply_request"""
221 221 raise NotImplementedError("Implement in subclasses")
222 222
223 223 def apply(self, f, *args, **kwargs):
224 224 """calls f(*args, **kwargs) on remote engines, returning the result.
225 225
226 226 This method sets all apply flags via this View's attributes.
227 227
228 228 if self.block is False:
229 229 returns AsyncResult
230 230 else:
231 231 returns actual result of f(*args, **kwargs)
232 232 """
233 233 return self._really_apply(f, args, kwargs)
234 234
235 235 def apply_async(self, f, *args, **kwargs):
236 236 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
237 237
238 238 returns AsyncResult
239 239 """
240 240 return self._really_apply(f, args, kwargs, block=False)
241 241
242 242 @spin_after
243 243 def apply_sync(self, f, *args, **kwargs):
244 244 """calls f(*args, **kwargs) on remote engines in a blocking manner,
245 245 returning the result.
246 246
247 247 returns: actual result of f(*args, **kwargs)
248 248 """
249 249 return self._really_apply(f, args, kwargs, block=True)
250 250
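# Usage sketch (editor's addition): the three apply flavors side by side,
# for any View instance `view`:
#
#   view.apply(pow, 2, 10)        # blocking behavior follows view.block
#   view.apply_async(pow, 2, 10)  # always returns an AsyncResult immediately
#   view.apply_sync(pow, 2, 10)   # always blocks; returns the actual result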
251 251 #----------------------------------------------------------------
252 252 # wrappers for client and control methods
253 253 #----------------------------------------------------------------
254 254 @sync_results
255 255 def spin(self):
256 256 """spin the client, and sync"""
257 257 self.client.spin()
258 258
259 259 @sync_results
260 260 def wait(self, jobs=None, timeout=-1):
261 261 """waits on one or more `jobs`, for up to `timeout` seconds.
262 262
263 263 Parameters
264 264 ----------
265 265
266 266 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
267 267 ints are indices to self.history
268 268 strs are msg_ids
269 269 default: wait on all outstanding messages
270 270 timeout : float
271 271 a time in seconds, after which to give up.
272 272 default is -1, which means no timeout
273 273
274 274 Returns
275 275 -------
276 276
277 277 True : when all msg_ids are done
278 278 False : timeout reached, some msg_ids still outstanding
279 279 """
280 280 if jobs is None:
281 281 jobs = self.history
282 282 return self.client.wait(jobs, timeout)
283 283
284 284 def abort(self, jobs=None, targets=None, block=None):
285 285 """Abort jobs on my engines.
286 286
287 287 Parameters
288 288 ----------
289 289
290 290 jobs : None, str, list of strs, optional
291 291 if None: abort all jobs.
292 292 else: abort specific msg_id(s).
293 293 """
294 294 block = block if block is not None else self.block
295 295 targets = targets if targets is not None else self.targets
296 296 jobs = jobs if jobs is not None else list(self.outstanding)
297 297
298 298 return self.client.abort(jobs=jobs, targets=targets, block=block)
299 299
300 300 def queue_status(self, targets=None, verbose=False):
301 301 """Fetch the Queue status of my engines"""
302 302 targets = targets if targets is not None else self.targets
303 303 return self.client.queue_status(targets=targets, verbose=verbose)
304 304
305 305 def purge_results(self, jobs=[], targets=[]):
306 306 """Instruct the controller to forget specific results."""
307 307 if targets is None or targets == 'all':
308 308 targets = self.targets
309 309 return self.client.purge_results(jobs=jobs, targets=targets)
310 310
311 311 def shutdown(self, targets=None, restart=False, hub=False, block=None):
312 312 """Terminates one or more engine processes, optionally including the hub.
313 313 """
314 314 block = self.block if block is None else block
315 315 if targets is None or targets == 'all':
316 316 targets = self.targets
317 317 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
318 318
319 319 @spin_after
320 320 def get_result(self, indices_or_msg_ids=None):
321 321 """return one or more results, specified by history index or msg_id.
322 322
323 323 See client.get_result for details.
324 324
325 325 """
326 326
327 327 if indices_or_msg_ids is None:
328 328 indices_or_msg_ids = -1
329 329 if isinstance(indices_or_msg_ids, int):
330 330 indices_or_msg_ids = self.history[indices_or_msg_ids]
331 331 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
332 332 indices_or_msg_ids = list(indices_or_msg_ids)
333 333 for i,index in enumerate(indices_or_msg_ids):
334 334 if isinstance(index, int):
335 335 indices_or_msg_ids[i] = self.history[index]
336 336 return self.client.get_result(indices_or_msg_ids)
337 337
338 338 #-------------------------------------------------------------------
339 339 # Map
340 340 #-------------------------------------------------------------------
341 341
342 342 @sync_results
343 343 def map(self, f, *sequences, **kwargs):
344 344 """override in subclasses"""
345 345 raise NotImplementedError
346 346
347 347 def map_async(self, f, *sequences, **kwargs):
348 348 """Parallel version of builtin `map`, using this view's engines.
349 349
350 350 This is equivalent to map(...block=False)
351 351
352 352 See `self.map` for details.
353 353 """
354 354 if 'block' in kwargs:
355 355 raise TypeError("map_async doesn't take a `block` keyword argument.")
356 356 kwargs['block'] = False
357 357 return self.map(f,*sequences,**kwargs)
358 358
359 359 def map_sync(self, f, *sequences, **kwargs):
360 360 """Parallel version of builtin `map`, using this view's engines.
361 361
362 362 This is equivalent to map(...block=True)
363 363
364 364 See `self.map` for details.
365 365 """
366 366 if 'block' in kwargs:
367 367 raise TypeError("map_sync doesn't take a `block` keyword argument.")
368 368 kwargs['block'] = True
369 369 return self.map(f,*sequences,**kwargs)
370 370
371 371 def imap(self, f, *sequences, **kwargs):
372 372 """Parallel version of `itertools.imap`.
373 373
374 374 See `self.map` for details.
375 375
376 376 """
377 377
378 378 return iter(self.map_async(f,*sequences, **kwargs))
379 379
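# Example (sketch): the map family mirrors apply/apply_async/apply_sync.
#
# >>> v.map_sync(lambda x: 2 * x, range(8))        # [0, 2, 4, ..., 14]
# >>> amr = v.map_async(lambda x: 2 * x, range(8)) # AsyncMapResult
# >>> list(v.imap(lambda x: 2 * x, range(8)))      # results as they arrive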
380 380 #-------------------------------------------------------------------
381 381 # Decorators
382 382 #-------------------------------------------------------------------
383 383
384 384 def remote(self, block=None, **flags):
385 385 """Decorator for making a RemoteFunction"""
386 386 block = self.block if block is None else block
387 387 return remote(self, block=block, **flags)
388 388
389 389 def parallel(self, dist='b', block=None, **flags):
390 390 """Decorator for making a ParallelFunction"""
391 391 block = self.block if block is None else block
392 392 return parallel(self, dist=dist, block=block, **flags)
393 393
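# Example (sketch): the decorator forms of apply and map.
#
# >>> @v.remote(block=True)
# ... def remote_pid():
# ...     import os
# ...     return os.getpid()
# >>> remote_pid() # executes on the view's engines
#
# >>> @v.parallel(block=True)
# ... def double(x):
# ...     return 2 * x
# >>> double.map(range(8)) # element-wise, distributed across engines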
394 394 @skip_doctest
395 395 class DirectView(View):
396 396 """Direct Multiplexer View of one or more engines.
397 397
398 398 These are created via indexed access to a client:
399 399
400 400 >>> dv_1 = client[1]
401 401 >>> dv_all = client[:]
402 402 >>> dv_even = client[::2]
403 403 >>> dv_some = client[1:3]
404 404
405 405 This object provides dictionary access to engine namespaces:
406 406
407 407 # push a=5:
408 408 >>> dv['a'] = 5
409 409 # pull 'foo':
410 410 >>> dv['foo']
411 411
412 412 """
413 413
414 414 def __init__(self, client=None, socket=None, targets=None):
415 415 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
416 416
417 417 @property
418 418 def importer(self):
419 419 """sync_imports(local=True) as a property.
420 420
421 421 See sync_imports for details.
422 422
423 423 """
424 424 return self.sync_imports(True)
425 425
426 426 @contextmanager
427 427 def sync_imports(self, local=True, quiet=False):
428 428 """Context Manager for performing simultaneous local and remote imports.
429 429
430 430 'import x as y' will *not* work. The 'as y' part will simply be ignored.
431 431
432 432 If `local=True`, then the package will also be imported locally.
433 433
434 434 If `quiet=True`, no output will be produced when attempting remote
435 435 imports.
436 436
437 437 Note that remote-only (`local=False`) imports have not been implemented.
438 438
439 439 >>> with view.sync_imports():
440 440 ... from numpy import recarray
441 441 importing recarray from numpy on engine(s)
442 442
443 443 """
444 444 from IPython.utils.py3compat import builtin_mod
445 445 local_import = builtin_mod.__import__
446 446 modules = set()
447 447 results = []
448 448 @util.interactive
449 449 def remote_import(name, fromlist, level):
450 450 """the function to be passed to apply, that actually performs the import
451 451 on the engine, and loads up the user namespace.
452 452 """
453 453 import sys
454 454 user_ns = globals()
455 455 mod = __import__(name, fromlist=fromlist, level=level)
456 456 if fromlist:
457 457 for key in fromlist:
458 458 user_ns[key] = getattr(mod, key)
459 459 else:
460 460 user_ns[name] = sys.modules[name]
461 461
462 462 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
463 463 """the drop-in replacement for __import__, that optionally imports
464 464 locally as well.
465 465 """
466 466 # don't override nested imports
467 467 save_import = builtin_mod.__import__
468 468 builtin_mod.__import__ = local_import
469 469
470 470 if imp.lock_held():
471 471 # this is a side-effect import, don't do it remotely, or even
472 472 # ignore the local effects
473 473 return local_import(name, globals, locals, fromlist, level)
474 474
475 475 imp.acquire_lock()
476 476 if local:
477 477 mod = local_import(name, globals, locals, fromlist, level)
478 478 else:
479 479 raise NotImplementedError("remote-only imports not yet implemented")
480 480 imp.release_lock()
481 481
482 482 key = name+':'+','.join(fromlist or [])
483 483 if level <= 0 and key not in modules:
484 484 modules.add(key)
485 485 if not quiet:
486 486 if fromlist:
487 487 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
488 488 else:
489 489 print("importing %s on engine(s)"%name)
490 490 results.append(self.apply_async(remote_import, name, fromlist, level))
491 491 # restore override
492 492 builtin_mod.__import__ = save_import
493 493
494 494 return mod
495 495
496 496 # override __import__
497 497 builtin_mod.__import__ = view_import
498 498 try:
499 499 # enter the block
500 500 yield
501 501 except ImportError:
502 502 if local:
503 503 raise
504 504 else:
505 505 # ignore import errors if not doing local imports
506 506 pass
507 507 finally:
508 508 # always restore __import__
509 509 builtin_mod.__import__ = local_import
510 510
511 511 for r in results:
512 512 # raise possible remote ImportErrors here
513 513 r.get()
514 514
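# Example (sketch): the quiet flag documented above suppresses the
# "importing ... on engine(s)" message.
#
# >>> with dv.sync_imports(quiet=True):
# ...     import numpy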
515 515
516 516 @sync_results
517 517 @save_ids
518 518 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
519 519 """calls f(*args, **kwargs) on remote engines, returning the result.
520 520
521 521 This method sets all of `apply`'s flags via this View's attributes.
522 522
523 523 Parameters
524 524 ----------
525 525
526 526 f : callable
527 527
528 528 args : list [default: empty]
529 529
530 530 kwargs : dict [default: empty]
531 531
532 532 targets : target list [default: self.targets]
533 533 where to run
534 534 block : bool [default: self.block]
535 535 whether to block
536 536 track : bool [default: self.track]
537 537 whether to ask zmq to track the message, for safe non-copying sends
538 538
539 539 Returns
540 540 -------
541 541
542 542 if self.block is False:
543 543 returns AsyncResult
544 544 else:
545 545 returns actual result of f(*args, **kwargs) on the engine(s)
546 546 This will be a list if self.targets is also a list (even length 1), or
547 547 the single result if self.targets is an integer engine id
548 548 """
549 549 args = [] if args is None else args
550 550 kwargs = {} if kwargs is None else kwargs
551 551 block = self.block if block is None else block
552 552 track = self.track if track is None else track
553 553 targets = self.targets if targets is None else targets
554 554
555 555 _idents, _targets = self.client._build_targets(targets)
556 556 msg_ids = []
557 557 trackers = []
558 558 for ident in _idents:
559 559 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
560 560 ident=ident)
561 561 if track:
562 562 trackers.append(msg['tracker'])
563 563 msg_ids.append(msg['header']['msg_id'])
564 564 if isinstance(targets, int):
565 565 msg_ids = msg_ids[0]
566 566 tracker = None if track is False else zmq.MessageTracker(*trackers)
567 567 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
568 568 if block:
569 569 try:
570 570 return ar.get()
571 571 except KeyboardInterrupt:
572 572 pass
573 573 return ar
574 574
575 575
576 576 @sync_results
577 577 def map(self, f, *sequences, **kwargs):
578 578 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
579 579
580 580 Parallel version of builtin `map`, using this View's `targets`.
581 581
582 582 There will be one task per target, so work will be chunked
583 583 if the sequences are longer than `targets`.
584 584
585 585 Results can be iterated as they are ready, but will become available in chunks.
586 586
587 587 Parameters
588 588 ----------
589 589
590 590 f : callable
591 591 function to be mapped
592 592 *sequences: one or more sequences of matching length
593 593 the sequences to be distributed and passed to `f`
594 594 block : bool
595 595 whether to wait for the result or not [default self.block]
596 596
597 597 Returns
598 598 -------
599 599
600 600 if block=False:
601 601 AsyncMapResult
602 602 An object like AsyncResult, but which reassembles the sequence of results
603 603 into a single list. AsyncMapResults can be iterated through before all
604 604 results are complete.
605 605 else:
606 606 list
607 607 the result of map(f,*sequences)
608 608 """
609 609
610 610 block = kwargs.pop('block', self.block)
611 611 for k in kwargs.keys():
612 612 if k not in ['block', 'track']:
613 613 raise TypeError("invalid keyword arg, %r"%k)
614 614
615 615 assert len(sequences) > 0, "must have some sequences to map onto!"
616 616 pf = ParallelFunction(self, f, block=block, **kwargs)
617 617 return pf.map(*sequences)
618 618
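# Example (sketch): with 4 engines, range(8) is split into 4 two-element
# chunks, one task per engine.
#
# >>> dv.map_sync(lambda x: x ** 2, range(8)) # [0, 1, 4, ..., 49]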
619 619 @sync_results
620 620 @save_ids
621 621 def execute(self, code, silent=True, targets=None, block=None):
622 622 """Executes `code` on `targets` in blocking or nonblocking manner.
623 623
624 624 ``execute`` is always `bound` (affects engine namespace)
625 625
626 626 Parameters
627 627 ----------
628 628
629 629 code : str
630 630 the code string to be executed
631 631 block : bool
632 632 whether or not to wait until done to return
633 633 default: self.block
634 634 """
635 635 block = self.block if block is None else block
636 636 targets = self.targets if targets is None else targets
637 637
638 638 _idents, _targets = self.client._build_targets(targets)
639 639 msg_ids = []
640 640 trackers = []
641 641 for ident in _idents:
642 642 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
643 643 msg_ids.append(msg['header']['msg_id'])
644 644 if isinstance(targets, int):
645 645 msg_ids = msg_ids[0]
646 646 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
647 647 if block:
648 648 try:
649 649 ar.get()
650 650 except KeyboardInterrupt:
651 651 pass
652 652 return ar
653 653
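# Example (sketch): execute mutates the engines' namespaces, so a
# subsequent pull sees the new name.
#
# >>> dv.execute('a = 5', block=True)
# >>> dv['a'] # [5, 5, ...], one per engine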
654 654 def run(self, filename, targets=None, block=None):
655 655 """Execute contents of `filename` on my engine(s).
656 656
657 657 This simply reads the contents of the file and calls `execute`.
658 658
659 659 Parameters
660 660 ----------
661 661
662 662 filename : str
663 663 The path to the file
664 664 targets : int/str/list of ints/strs
665 665 the engines on which to execute
666 666 default : all
667 667 block : bool
668 668 whether or not to wait until done
669 669 default: self.block
670 670
671 671 """
672 672 with open(filename, 'r') as f:
673 673 # add newline in case of trailing indented whitespace
674 674 # which will cause SyntaxError
675 675 code = f.read()+'\n'
676 676 return self.execute(code, block=block, targets=targets)
677 677
678 678 def update(self, ns):
679 679 """update remote namespace with dict `ns`
680 680
681 681 See `push` for details.
682 682 """
683 683 return self.push(ns, block=self.block, track=self.track)
684 684
685 685 def push(self, ns, targets=None, block=None, track=None):
686 686 """update remote namespace with dict `ns`
687 687
688 688 Parameters
689 689 ----------
690 690
691 691 ns : dict
692 692 dict of keys with which to update engine namespace(s)
693 693 block : bool [default : self.block]
694 694 whether to wait to be notified of engine receipt
695 695
696 696 """
697 697
698 698 block = block if block is not None else self.block
699 699 track = track if track is not None else self.track
700 700 targets = targets if targets is not None else self.targets
701 701 # applier = self.apply_sync if block else self.apply_async
702 702 if not isinstance(ns, dict):
703 703 raise TypeError("Must be a dict, not %s"%type(ns))
704 704 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
705 705
706 706 def get(self, key_s):
707 707 """get object(s) by `key_s` from remote namespace
708 708
709 709 see `pull` for details.
710 710 """
711 711 # block = block if block is not None else self.block
712 712 return self.pull(key_s, block=True)
713 713
714 714 def pull(self, names, targets=None, block=None):
715 715 """get object(s) by `name` from remote namespace
716 716
717 717 will return one object if it is a key.
718 718 can also take a list of keys, in which case it will return a list of objects.
719 719 """
720 720 block = block if block is not None else self.block
721 721 targets = targets if targets is not None else self.targets
722 722 applier = self.apply_sync if block else self.apply_async
723 723 if isinstance(names, string_types):
724 724 pass
725 725 elif isinstance(names, (list,tuple,set)):
726 726 for key in names:
727 727 if not isinstance(key, string_types):
728 728 raise TypeError("keys must be str, not type %r"%type(key))
729 729 else:
730 730 raise TypeError("names must be strs, not %r"%names)
731 731 return self._really_apply(util._pull, (names,), block=block, targets=targets)
732 732
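# Example (sketch): push/pull round-trip, and the dict-style sugar that
# routes through update/get.
#
# >>> dv.push({'a': 1, 'b': 2}, block=True)
# >>> dv.pull('a', block=True)          # [1, 1, ...], one per engine
# >>> dv.pull(('a', 'b'), block=True)   # [[1, 2], [1, 2], ...]
# >>> dv['b']                           # same as dv.pull('b', block=True)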
733 733 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
734 734 """
735 735 Partition a Python sequence and send the partitions to a set of engines.
736 736 """
737 737 block = block if block is not None else self.block
738 738 track = track if track is not None else self.track
739 739 targets = targets if targets is not None else self.targets
740 740
741 741 # construct integer ID list:
742 742 targets = self.client._build_targets(targets)[1]
743 743
744 744 mapObject = Map.dists[dist]()
745 745 nparts = len(targets)
746 746 msg_ids = []
747 747 trackers = []
748 748 for index, engineid in enumerate(targets):
749 749 partition = mapObject.getPartition(seq, index, nparts)
750 750 if flatten and len(partition) == 1:
751 751 ns = {key: partition[0]}
752 752 else:
753 753 ns = {key: partition}
754 754 r = self.push(ns, block=False, track=track, targets=engineid)
755 755 msg_ids.extend(r.msg_ids)
756 756 if track:
757 757 trackers.append(r._tracker)
758 758
759 759 if track:
760 760 tracker = zmq.MessageTracker(*trackers)
761 761 else:
762 762 tracker = None
763 763
764 764 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
765 765 if block:
766 766 r.wait()
767 767 else:
768 768 return r
769 769
770 770 @sync_results
771 771 @save_ids
772 772 def gather(self, key, dist='b', targets=None, block=None):
773 773 """
774 774 Gather a partitioned sequence on a set of engines as a single local seq.
775 775 """
776 776 block = block if block is not None else self.block
777 777 targets = targets if targets is not None else self.targets
778 778 mapObject = Map.dists[dist]()
779 779 msg_ids = []
780 780
781 781 # construct integer ID list:
782 782 targets = self.client._build_targets(targets)[1]
783 783
784 784 for index, engineid in enumerate(targets):
785 785 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
786 786
787 787 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
788 788
789 789 if block:
790 790 try:
791 791 return r.get()
792 792 except KeyboardInterrupt:
793 793 pass
794 794 return r
795 795
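# Example (sketch): scatter partitions a sequence across the engines;
# gather reassembles it in engine order.
#
# >>> dv.scatter('x', list(range(8)), block=True) # 4 engines: 2 items each
# >>> dv['x']                   # [[0, 1], [2, 3], [4, 5], [6, 7]]
# >>> dv.gather('x', block=True) # [0, 1, 2, 3, 4, 5, 6, 7]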
796 796 def __getitem__(self, key):
797 797 return self.get(key)
798 798
799 799 def __setitem__(self,key, value):
800 800 self.update({key:value})
801 801
802 802 def clear(self, targets=None, block=None):
803 803 """Clear the remote namespaces on my engines."""
804 804 block = block if block is not None else self.block
805 805 targets = targets if targets is not None else self.targets
806 806 return self.client.clear(targets=targets, block=block)
807 807
808 808 #----------------------------------------
809 809 # activate for %px, %autopx, etc. magics
810 810 #----------------------------------------
811 811
812 812 def activate(self, suffix=''):
813 813 """Activate IPython magics associated with this View
814 814
815 815 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
816 816
817 817 Parameters
818 818 ----------
819 819
820 820 suffix: str [default: '']
821 821 The suffix, if any, for the magics. This allows you to have
822 822 multiple views associated with parallel magics at the same time.
823 823
824 824 e.g. ``rc[::2].activate(suffix='_even')`` will give you
825 825 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
826 826 on the even engines.
827 827 """
828 828
829 829 from IPython.parallel.client.magics import ParallelMagics
830 830
831 831 try:
832 832 # This is injected into __builtins__.
833 833 ip = get_ipython()
834 834 except NameError:
835 835 print("The IPython parallel magics (%px, etc.) only work within IPython.")
836 836 return
837 837
838 838 M = ParallelMagics(ip, self, suffix)
839 839 ip.magics_manager.register(M)
840 840
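# Example (sketch, inside an IPython session):
#
# In [1]: rc[::2].activate(suffix='_even')
# In [2]: %px_even import os   # runs only on the even engines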
841 841
842 842 @skip_doctest
843 843 class LoadBalancedView(View):
844 844 """An load-balancing View that only executes via the Task scheduler.
845 845
846 846 Load-balanced views can be created with the client's `view` method:
847 847
848 848 >>> v = client.load_balanced_view()
849 849
850 850 or targets can be specified, to restrict the potential destinations:
851 851
852 852 >>> v = client.load_balanced_view([1,3])
853 853
854 854 which would restrict load-balancing to engines 1 and 3.
855 855
856 856 """
857 857
858 858 follow=Any()
859 859 after=Any()
860 860 timeout=CFloat()
861 861 retries = Integer(0)
862 862
863 863 _task_scheme = Any()
864 864 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
865 865
866 866 def __init__(self, client=None, socket=None, **flags):
867 867 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
868 868 self._task_scheme=client._task_scheme
869 869
870 870 def _validate_dependency(self, dep):
871 871 """validate a dependency.
872 872
873 873 For use in `set_flags`.
874 874 """
875 875 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
876 876 return True
877 877 elif isinstance(dep, (list,set, tuple)):
878 878 for d in dep:
879 879 if not isinstance(d, string_types + (AsyncResult,)):
880 880 return False
881 881 elif isinstance(dep, dict):
882 882 if set(dep.keys()) != set(Dependency().as_dict().keys()):
883 883 return False
884 884 if not isinstance(dep['msg_ids'], list):
885 885 return False
886 886 for d in dep['msg_ids']:
887 887 if not isinstance(d, string_types):
888 888 return False
889 889 else:
890 890 return False
891 891
892 892 return True
893 893
894 894 def _render_dependency(self, dep):
895 895 """helper for building jsonable dependencies from various input forms."""
896 896 if isinstance(dep, Dependency):
897 897 return dep.as_dict()
898 898 elif isinstance(dep, AsyncResult):
899 899 return dep.msg_ids
900 900 elif dep is None:
901 901 return []
902 902 else:
903 903 # pass to Dependency constructor
904 904 return list(Dependency(dep))
905 905
906 906 def set_flags(self, **kwargs):
907 907 """set my attribute flags by keyword.
908 908
909 909 A View is a wrapper for the Client's apply method, with attributes
910 910 that specify keyword arguments. Those attributes can be set by
911 911 keyword argument with this method.
912 912
913 913 Parameters
914 914 ----------
915 915
916 916 block : bool
917 917 whether to wait for results
918 918 track : bool
919 919 whether to create a MessageTracker to allow the user to
920 920 know when it is safe to edit arrays and buffers again
921 921 after non-copying sends.
922 922
923 923 after : Dependency or collection of msg_ids
924 924 Only for load-balanced execution (targets=None)
925 925 Specify a list of msg_ids as a time-based dependency.
926 926 This job will only be run *after* the dependencies
927 927 have been met.
928 928
929 929 follow : Dependency or collection of msg_ids
930 930 Only for load-balanced execution (targets=None)
931 931 Specify a list of msg_ids as a location-based dependency.
932 932 This job will only be run on an engine where this dependency
933 933 is met.
934 934
935 935 timeout : float/int or None
936 936 Only for load-balanced execution (targets=None)
937 937 Specify an amount of time (in seconds) for the scheduler to
938 938 wait for dependencies to be met before failing with a
939 939 DependencyTimeout.
940 940
941 941 retries : int
942 942 Number of times a task will be retried on failure.
943 943 """
944 944
945 945 super(LoadBalancedView, self).set_flags(**kwargs)
946 946 for name in ('follow', 'after'):
947 947 if name in kwargs:
948 948 value = kwargs[name]
949 949 if self._validate_dependency(value):
950 950 setattr(self, name, value)
951 951 else:
952 952 raise ValueError("Invalid dependency: %r"%value)
953 953 if 'timeout' in kwargs:
954 954 t = kwargs['timeout']
955 if not isinstance(t, (int, long, float, type(None))):
955 if not isinstance(t, (int, float, type(None))):
956 if PY3 or not isinstance(t, long):
956 957 raise TypeError("Invalid type for timeout: %r"%type(t))
957 958 if t is not None:
958 959 if t < 0:
959 960 raise ValueError("Invalid timeout: %s"%t)
960 961 self.timeout = t
961 962
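# Example (sketch): chaining tasks with the flags documented above
# (`build` and `query` are hypothetical callables).
#
# >>> ar = lbv.apply_async(build)
# >>> lbv.set_flags(after=[ar], retries=2, timeout=60)
# >>> ar2 = lbv.apply_async(query) # scheduled only after `ar` completes
# >>> lbv.set_flags(after=None)    # clear the dependency again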
962 963 @sync_results
963 964 @save_ids
964 965 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
965 966 after=None, follow=None, timeout=None,
966 967 targets=None, retries=None):
967 968 """calls f(*args, **kwargs) on a remote engine, returning the result.
968 969
969 970 This method temporarily sets all of `apply`'s flags for a single call.
970 971
971 972 Parameters
972 973 ----------
973 974
974 975 f : callable
975 976
976 977 args : list [default: empty]
977 978
978 979 kwargs : dict [default: empty]
979 980
980 981 block : bool [default: self.block]
981 982 whether to block
982 983 track : bool [default: self.track]
983 984 whether to ask zmq to track the message, for safe non-copying sends
984 985
985 986 after, follow, timeout, retries, targets : see `set_flags` above; each defaults to the corresponding View attribute, overridden for this call only.
986 987
987 988 Returns
988 989 -------
989 990
990 991 if self.block is False:
991 992 returns AsyncResult
992 993 else:
993 994 returns actual result of f(*args, **kwargs) on the engine(s)
994 995 This will be a list if self.targets is also a list (even length 1), or
995 996 the single result if self.targets is an integer engine id
996 997 """
997 998
998 999 # validate whether we can run
999 1000 if self._socket.closed:
1000 1001 msg = "Task farming is disabled"
1001 1002 if self._task_scheme == 'pure':
1002 1003 msg += " because the pure ZMQ scheduler cannot handle"
1003 1004 msg += " disappearing engines."
1004 1005 raise RuntimeError(msg)
1005 1006
1006 1007 if self._task_scheme == 'pure':
1007 1008 # pure zmq scheme doesn't support extra features
1008 1009 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1009 1010 "follow, after, retries, targets, timeout"
1010 1011 if (follow or after or retries or targets or timeout):
1011 1012 # hard fail on Scheduler flags
1012 1013 raise RuntimeError(msg)
1013 1014 if isinstance(f, dependent):
1014 1015 # soft warn on functional dependencies
1015 1016 warnings.warn(msg, RuntimeWarning)
1016 1017
1017 1018 # build args
1018 1019 args = [] if args is None else args
1019 1020 kwargs = {} if kwargs is None else kwargs
1020 1021 block = self.block if block is None else block
1021 1022 track = self.track if track is None else track
1022 1023 after = self.after if after is None else after
1023 1024 retries = self.retries if retries is None else retries
1024 1025 follow = self.follow if follow is None else follow
1025 1026 timeout = self.timeout if timeout is None else timeout
1026 1027 targets = self.targets if targets is None else targets
1027 1028
1028 1029 if not isinstance(retries, int):
1029 1030 raise TypeError('retries must be int, not %r'%type(retries))
1030 1031
1031 1032 if targets is None:
1032 1033 idents = []
1033 1034 else:
1034 1035 idents = self.client._build_targets(targets)[0]
1035 1036 # ensure *not* bytes
1036 1037 idents = [ ident.decode() for ident in idents ]
1037 1038
1038 1039 after = self._render_dependency(after)
1039 1040 follow = self._render_dependency(follow)
1040 1041 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1041 1042
1042 1043 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1043 1044 metadata=metadata)
1044 1045 tracker = None if track is False else msg['tracker']
1045 1046
1046 1047 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1047 1048
1048 1049 if block:
1049 1050 try:
1050 1051 return ar.get()
1051 1052 except KeyboardInterrupt:
1052 1053 pass
1053 1054 return ar
1054 1055
1055 1056 @sync_results
1056 1057 @save_ids
1057 1058 def map(self, f, *sequences, **kwargs):
1058 1059 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1059 1060
1060 1061 Parallel version of builtin `map`, load-balanced by this View.
1061 1062
1062 1063 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1063 1064
1064 1065 Each `chunksize` elements will be a separate task, and will be
1065 1066 load-balanced. This lets individual elements be available for iteration
1066 1067 as soon as they arrive.
1067 1068
1068 1069 Parameters
1069 1070 ----------
1070 1071
1071 1072 f : callable
1072 1073 function to be mapped
1073 1074 *sequences: one or more sequences of matching length
1074 1075 the sequences to be distributed and passed to `f`
1075 1076 block : bool [default self.block]
1076 1077 whether to wait for the result or not
1077 1078 track : bool
1078 1079 whether to create a MessageTracker to allow the user to
1079 1080 know when it is safe to edit arrays and buffers again
1080 1081 after non-copying sends.
1081 1082 chunksize : int [default 1]
1082 1083 how many elements should be in each task.
1083 1084 ordered : bool [default True]
1084 1085 Whether the results should be gathered as they arrive, or enforce
1085 1086 the order of submission.
1086 1087
1087 1088 Only applies when iterating through AsyncMapResult as results arrive.
1088 1089 Has no effect when block=True.
1089 1090
1090 1091 Returns
1091 1092 -------
1092 1093
1093 1094 if block=False:
1094 1095 AsyncMapResult
1095 1096 An object like AsyncResult, but which reassembles the sequence of results
1096 1097 into a single list. AsyncMapResults can be iterated through before all
1097 1098 results are complete.
1098 1099 else:
1099 1100 the result of map(f,*sequences)
1100 1101
1101 1102 """
1102 1103
1103 1104 # default
1104 1105 block = kwargs.get('block', self.block)
1105 1106 chunksize = kwargs.get('chunksize', 1)
1106 1107 ordered = kwargs.get('ordered', True)
1107 1108
1108 1109 keyset = set(kwargs.keys())
1109 1110 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1110 1111 if extra_keys:
1111 1112 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1112 1113
1113 1114 assert len(sequences) > 0, "must have some sequences to map onto!"
1114 1115
1115 1116 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1116 1117 return pf.map(*sequences)
1117 1118
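# Example (sketch): chunksize trades scheduling overhead against balance;
# ordered=False yields whichever chunk finishes first.
#
# >>> amr = lbv.map(fn, range(100), chunksize=10, ordered=False)
# >>> for item in amr:  # `fn` is a placeholder callable
# ...     use(item)     # `use` likewise hypothetical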
1118 1119 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,226 +1,230 b''
1 1 """Dependency utilities
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2013 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 from types import ModuleType
15 15
16 16 from IPython.parallel.client.asyncresult import AsyncResult
17 17 from IPython.parallel.error import UnmetDependency
18 18 from IPython.parallel.util import interactive
19 19 from IPython.utils import py3compat
20 20 from IPython.utils.py3compat import string_types
21 21 from IPython.utils.pickleutil import can, uncan
22 22
23 23 class depend(object):
24 24 """Dependency decorator, for use with tasks.
25 25
26 26 `@depend` lets you define a function for engine dependencies
27 27 just like you use `apply` for tasks.
28 28
29 29
30 30 Examples
31 31 --------
32 32 ::
33 33
34 34 @depend(df, a,b, c=5)
35 35 def f(m,n,p):
36 36
37 37 view.apply(f, 1,2,3)
38 38
39 39 will call df(a,b,c=5) on the engine, and if it returns False or
40 40 raises an UnmetDependency error, then the task will not be run
41 41 and another engine will be tried.
42 42 """
43 43 def __init__(self, f, *args, **kwargs):
44 44 self.f = f
45 45 self.args = args
46 46 self.kwargs = kwargs
47 47
48 48 def __call__(self, f):
49 49 return dependent(f, self.f, *self.args, **self.kwargs)
50 50
51 51 class dependent(object):
52 52 """A function that depends on another function.
53 53 This is an object rather than the closure a traditional
54 54 decorator would create, because closures are not picklable.
55 55 """
56 56
57 57 def __init__(self, f, df, *dargs, **dkwargs):
58 58 self.f = f
59 self.__name__ = getattr(f, '__name__', 'f')
59 name = getattr(f, '__name__', 'f')
60 if py3compat.PY3:
61 self.__name__ = name
62 else:
63 self.func_name = name
60 64 self.df = df
61 65 self.dargs = dargs
62 66 self.dkwargs = dkwargs
63 67
64 68 def check_dependency(self):
65 69 if self.df(*self.dargs, **self.dkwargs) is False:
66 70 raise UnmetDependency()
67 71
68 72 def __call__(self, *args, **kwargs):
69 73 return self.f(*args, **kwargs)
70 74
71 75 if not py3compat.PY3:
72 76 @property
73 77 def __name__(self):
74 78 return self.func_name
75 79
76 80 @interactive
77 81 def _require(*modules, **mapping):
78 82 """Helper for @require decorator."""
79 83 from IPython.parallel.error import UnmetDependency
80 84 from IPython.utils.pickleutil import uncan
81 85 user_ns = globals()
82 86 for name in modules:
83 87 try:
84 88 exec('import %s' % name, user_ns)
85 89 except ImportError:
86 90 raise UnmetDependency(name)
87 91
88 92 for name, cobj in mapping.items():
89 93 user_ns[name] = uncan(cobj, user_ns)
90 94 return True
91 95
92 96 def require(*objects, **mapping):
93 97 """Simple decorator for requiring local objects and modules to be available
94 98 when the decorated function is called on the engine.
95 99
96 100 Modules specified by name or passed directly will be imported
97 101 prior to calling the decorated function.
98 102
99 103 Objects other than modules will be pushed as a part of the task.
100 104 Functions can be passed positionally,
101 105 and will be pushed to the engine with their __name__.
102 106 Other objects can be passed by keyword arg.
103 107
104 108 Examples
105 109 --------
106 110
107 111 In [1]: @require('numpy')
108 112 ...: def norm(a):
109 113 ...: return numpy.linalg.norm(a,2)
110 114
111 115 In [2]: foo = lambda x: x*x
112 116 In [3]: @require(foo)
113 117 ...: def bar(a):
114 118 ...: return foo(1-a)
115 119 """
116 120 names = []
117 121 for obj in objects:
118 122 if isinstance(obj, ModuleType):
119 123 obj = obj.__name__
120 124
121 125 if isinstance(obj, string_types):
122 126 names.append(obj)
123 127 elif hasattr(obj, '__name__'):
124 128 mapping[obj.__name__] = obj
125 129 else:
126 130 raise TypeError("Objects other than modules and functions "
127 131 "must be passed by kwarg, but got: %s" % type(obj)
128 132 )
129 133
130 134 for name, obj in mapping.items():
131 135 mapping[name] = can(obj)
132 136 return depend(_require, *names, **mapping)
133 137
134 138 class Dependency(set):
135 139 """An object for representing a set of msg_id dependencies.
136 140
137 141 Subclassed from set().
138 142
139 143 Parameters
140 144 ----------
141 145 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
142 146 The msg_ids to depend on
143 147 all : bool [default True]
144 148 Whether the dependency should be considered met only when *all*
145 149 depended-upon tasks have completed, or as soon as *any* one has.
146 150 success : bool [default True]
147 151 Whether to consider successes as fulfilling dependencies.
148 152 failure : bool [default False]
149 153 Whether to consider failures as fulfilling dependencies.
150 154
151 155 If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
152 156 as soon as the first depended-upon task fails.
153 157 """
154 158
155 159 all=True
156 160 success=True
157 161 failure=False
158 162
159 163 def __init__(self, dependencies=[], all=True, success=True, failure=False):
160 164 if isinstance(dependencies, dict):
161 165 # load from dict
162 166 all = dependencies.get('all', True)
163 167 success = dependencies.get('success', success)
164 168 failure = dependencies.get('failure', failure)
165 169 dependencies = dependencies.get('dependencies', [])
166 170 ids = []
167 171
168 172 # extract ids from various sources:
169 173 if isinstance(dependencies, string_types + (AsyncResult,)):
170 174 dependencies = [dependencies]
171 175 for d in dependencies:
172 176 if isinstance(d, string_types):
173 177 ids.append(d)
174 178 elif isinstance(d, AsyncResult):
175 179 ids.extend(d.msg_ids)
176 180 else:
177 181 raise TypeError("invalid dependency type: %r"%type(d))
178 182
179 183 set.__init__(self, ids)
180 184 self.all = all
181 185 if not (success or failure):
182 186 raise ValueError("Must depend on at least one of successes or failures!")
183 187 self.success=success
184 188 self.failure = failure
185 189
186 190 def check(self, completed, failed=None):
187 191 """check whether our dependencies have been met."""
188 192 if len(self) == 0:
189 193 return True
190 194 against = set()
191 195 if self.success:
192 196 against = completed
193 197 if failed is not None and self.failure:
194 198 against = against.union(failed)
195 199 if self.all:
196 200 return self.issubset(against)
197 201 else:
198 202 return not self.isdisjoint(against)
199 203
200 204 def unreachable(self, completed, failed=None):
201 205 """return whether this dependency has become impossible."""
202 206 if len(self) == 0:
203 207 return False
204 208 against = set()
205 209 if not self.success:
206 210 against = completed
207 211 if failed is not None and not self.failure:
208 212 against = against.union(failed)
209 213 if self.all:
210 214 return not self.isdisjoint(against)
211 215 else:
212 216 return self.issubset(against)
213 217
214 218
215 219 def as_dict(self):
216 220 """Represent this dependency as a dict. For json compatibility."""
217 221 return dict(
218 222 dependencies=list(self),
219 223 all=self.all,
220 224 success=self.success,
221 225 failure=self.failure
222 226 )
223 227
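# Example (sketch): a Dependency is a set of msg_ids plus matching rules,
# so `check` can be exercised directly.
#
# >>> d = Dependency(['a', 'b'], all=False) # met once *any* completes
# >>> d.check(completed=set(['a']))
# True
# >>> d.check(completed=set())
# False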
224 228
225 229 __all__ = ['depend', 'require', 'dependent', 'Dependency']
226 230
@@ -1,190 +1,192 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and ROUTER sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their DEALER identities.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18 import time
19 19 import uuid
20 20
21 21 import zmq
22 22 from zmq.devices import ThreadDevice, ThreadMonitoredQueue
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.config.configurable import LoggingConfigurable
26 26 from IPython.utils.py3compat import str_to_bytes
27 27 from IPython.utils.traitlets import Set, Instance, CFloat, Integer, Dict
28 28
29 29 from IPython.parallel.util import log_errors
30 30
31 31 class Heart(object):
32 32 """A basic heart object for responding to a HeartMonitor.
33 33 This is a simple wrapper with defaults for the most common
34 34 Device model for responding to heartbeats.
35 35
36 36 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
37 37 SUB/DEALER for in/out.
38 38
39 39 You can specify the DEALER's IDENTITY via the optional heart_id argument."""
40 40 device=None
41 41 id=None
42 42 def __init__(self, in_addr, out_addr, mon_addr=None, in_type=zmq.SUB, out_type=zmq.DEALER, mon_type=zmq.PUB, heart_id=None):
43 43 if mon_addr is None:
44 44 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
45 45 else:
46 46 self.device = ThreadMonitoredQueue(in_type, out_type, mon_type, in_prefix=b"", out_prefix=b"")
47 47 # do not allow the device to share global Context.instance,
48 48 # which is the default behavior in pyzmq > 2.1.10
49 49 self.device.context_factory = zmq.Context
50 50
51 51 self.device.daemon=True
52 52 self.device.connect_in(in_addr)
53 53 self.device.connect_out(out_addr)
54 54 if mon_addr is not None:
55 55 self.device.connect_mon(mon_addr)
56 56 if in_type == zmq.SUB:
57 57 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
58 58 if heart_id is None:
59 59 heart_id = uuid.uuid4().bytes
60 60 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
61 61 self.id = heart_id
62 62
63 63 def start(self):
64 64 return self.device.start()
65 65
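# Example (sketch): a minimal heart wired to a monitor's ping/pong pair
# (the tcp addresses are placeholders for the hub's hb_ping/hb_pong urls).
#
# heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556',
#               heart_id=b'engine-0')
# heart.start() # runs the FORWARDER device in a background thread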
66 66
67 67 class HeartMonitor(LoggingConfigurable):
68 68 """A basic HeartMonitor class
69 69 pingstream: a PUB stream
70 70 pongstream: a ROUTER stream
71 71 period: the period of the heartbeat in milliseconds"""
72 72
73 73 period = Integer(3000, config=True,
74 74 help='The period (in ms) at which the Hub pings the engines for '
75 75 'heartbeats',
76 76 )
77 77 max_heartmonitor_misses = Integer(10, config=True,
78 78 help='Allowed consecutive missed pings from controller Hub to engine before unregistering.',
79 79 )
80 80
81 81 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
82 82 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
83 83 loop = Instance('zmq.eventloop.ioloop.IOLoop')
84 84 def _loop_default(self):
85 85 return ioloop.IOLoop.instance()
86 86
87 87 # not settable:
88 88 hearts=Set()
89 89 responses=Set()
90 90 on_probation=Dict()
91 91 last_ping=CFloat(0)
92 92 _new_handlers = Set()
93 93 _failure_handlers = Set()
94 94 lifetime = CFloat(0)
95 95 tic = CFloat(0)
96 96
97 97 def __init__(self, **kwargs):
98 98 super(HeartMonitor, self).__init__(**kwargs)
99 99
100 100 self.pongstream.on_recv(self.handle_pong)
101 101
102 102 def start(self):
103 103 self.tic = time.time()
104 104 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
105 105 self.caller.start()
106 106
107 107 def add_new_heart_handler(self, handler):
108 108 """add a new handler for new hearts"""
109 109 self.log.debug("heartbeat::new_heart_handler: %s", handler)
110 110 self._new_handlers.add(handler)
111 111
112 112 def add_heart_failure_handler(self, handler):
113 113 """add a new handler for heart failure"""
114 114 self.log.debug("heartbeat::new heart failure handler: %s", handler)
115 115 self._failure_handlers.add(handler)
116 116
117 117 def beat(self):
118 118 self.pongstream.flush()
119 119 self.last_ping = self.lifetime
120 120
121 121 toc = time.time()
122 122 self.lifetime += toc-self.tic
123 123 self.tic = toc
124 124 self.log.debug("heartbeat::sending %s", self.lifetime)
125 125 goodhearts = self.hearts.intersection(self.responses)
126 126 missed_beats = self.hearts.difference(goodhearts)
127 127 newhearts = self.responses.difference(goodhearts)
128 map(self.handle_new_heart, newhearts)
128 for heart in newhearts:
129 self.handle_new_heart(heart)
129 130 heartfailures, on_probation = self._check_missed(missed_beats, self.on_probation,
130 131 self.hearts)
131 map(self.handle_heart_failure, heartfailures)
132 for failure in heartfailures:
133 self.handle_heart_failure(failure)
132 134 self.on_probation = on_probation
133 135 self.responses = set()
134 136 #print self.on_probation, self.hearts
135 137 # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
136 138 self.pingstream.send(str_to_bytes(str(self.lifetime)))
137 139 # flush stream to force immediate socket send
138 140 self.pingstream.flush()
139 141
140 142 def _check_missed(self, missed_beats, on_probation, hearts):
141 143 """Update heartbeats on probation, identifying any that have too many misses.
142 144 """
143 145 failures = []
144 146 new_probation = {}
145 147 for cur_heart in (b for b in missed_beats if b in hearts):
146 148 miss_count = on_probation.get(cur_heart, 0) + 1
147 149 self.log.info("heartbeat::missed %s : %s" % (cur_heart, miss_count))
148 150 if miss_count > self.max_heartmonitor_misses:
149 151 failures.append(cur_heart)
150 152 else:
151 153 new_probation[cur_heart] = miss_count
152 154 return failures, new_probation
153 155
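# Example (sketch; `monitor` stands for a constructed HeartMonitor): with
# max_heartmonitor_misses = 2, a heart already at two misses that misses
# again is reported as failed.
#
# failures, probation = monitor._check_missed(
#     missed_beats={b'h'}, on_probation={b'h': 2}, hearts={b'h'})
# # miss_count becomes 3 > 2, so failures == [b'h'] and probation == {}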
154 156 def handle_new_heart(self, heart):
155 157 if self._new_handlers:
156 158 for handler in self._new_handlers:
157 159 handler(heart)
158 160 else:
159 161 self.log.info("heartbeat::yay, got new heart %s!", heart)
160 162 self.hearts.add(heart)
161 163
162 164 def handle_heart_failure(self, heart):
163 165 if self._failure_handlers:
164 166 for handler in self._failure_handlers:
165 167 try:
166 168 handler(heart)
167 169 except Exception as e:
168 170 self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
169 171 pass
170 172 else:
171 173 self.log.info("heartbeat::Heart %s failed :(", heart)
172 174 self.hearts.remove(heart)
173 175
174 176
175 177 @log_errors
176 178 def handle_pong(self, msg):
177 179 "a heart just beat"
178 180 current = str_to_bytes(str(self.lifetime))
179 181 last = str_to_bytes(str(self.last_ping))
180 182 if msg[1] == current:
181 183 delta = time.time()-self.tic
182 184 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
183 185 self.responses.add(msg[0])
184 186 elif msg[1] == last:
185 187 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
186 188 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
187 189 self.responses.add(msg[0])
188 190 else:
189 191 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
190 192
@@ -1,1421 +1,1421 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import json
22 22 import os
23 23 import sys
24 24 import time
25 25 from datetime import datetime
26 26
27 27 import zmq
28 28 from zmq.eventloop import ioloop
29 29 from zmq.eventloop.zmqstream import ZMQStream
30 30
31 31 # internal:
32 32 from IPython.utils.importstring import import_item
33 33 from IPython.utils.localinterfaces import localhost
34 34 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
35 35 from IPython.utils.traitlets import (
36 36 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 37 )
38 38
39 39 from IPython.parallel import error, util
40 40 from IPython.parallel.factory import RegistrationFactory
41 41
42 42 from IPython.kernel.zmq.session import SessionFactory
43 43
44 44 from .heartmonitor import HeartMonitor
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Code
48 48 #-----------------------------------------------------------------------------
49 49
50 50 def _passer(*args, **kwargs):
51 51 return
52 52
53 53 def _printer(*args, **kwargs):
54 54 print (args)
55 55 print (kwargs)
56 56
57 57 def empty_record():
58 58 """Return an empty dict with all record keys."""
59 59 return {
60 60 'msg_id' : None,
61 61 'header' : None,
62 62 'metadata' : None,
63 63 'content': None,
64 64 'buffers': None,
65 65 'submitted': None,
66 66 'client_uuid' : None,
67 67 'engine_uuid' : None,
68 68 'started': None,
69 69 'completed': None,
70 70 'resubmitted': None,
71 71 'received': None,
72 72 'result_header' : None,
73 73 'result_metadata' : None,
74 74 'result_content' : None,
75 75 'result_buffers' : None,
76 76 'queue' : None,
77 77 'pyin' : None,
78 78 'pyout': None,
79 79 'pyerr': None,
80 80 'stdout': '',
81 81 'stderr': '',
82 82 }
83 83
84 84 def init_record(msg):
85 85 """Initialize a TaskRecord based on a request."""
86 86 header = msg['header']
87 87 return {
88 88 'msg_id' : header['msg_id'],
89 89 'header' : header,
90 90 'content': msg['content'],
91 91 'metadata': msg['metadata'],
92 92 'buffers': msg['buffers'],
93 93 'submitted': header['date'],
94 94 'client_uuid' : None,
95 95 'engine_uuid' : None,
96 96 'started': None,
97 97 'completed': None,
98 98 'resubmitted': None,
99 99 'received': None,
100 100 'result_header' : None,
101 101 'result_metadata': None,
102 102 'result_content' : None,
103 103 'result_buffers' : None,
104 104 'queue' : None,
105 105 'pyin' : None,
106 106 'pyout': None,
107 107 'pyerr': None,
108 108 'stdout': '',
109 109 'stderr': '',
110 110 }
111 111
112 112
113 113 class EngineConnector(HasTraits):
114 114 """A simple object for accessing the various zmq connections of an object.
115 115 Attributes are:
116 116 id (int): engine ID
117 117 uuid (unicode): engine UUID
118 118 pending: set of msg_ids
119 119 stallback: DelayedCallback for stalled registration
120 120 """
121 121
122 122 id = Integer(0)
123 123 uuid = Unicode()
124 124 pending = Set()
125 125 stallback = Instance(ioloop.DelayedCallback)
126 126
127 127
128 128 _db_shortcuts = {
129 129 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
130 130 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
131 131 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
132 132 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
133 133 }
134 134
135 135 class HubFactory(RegistrationFactory):
136 136 """The Configurable for setting up a Hub."""
137 137
138 138 # port-pairs for monitoredqueues:
139 139 hb = Tuple(Integer,Integer,config=True,
140 140 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 141 def _hb_default(self):
142 142 return tuple(util.select_random_ports(2))
143 143
144 144 mux = Tuple(Integer,Integer,config=True,
145 145 help="""Client/Engine Port pair for MUX queue""")
146 146
147 147 def _mux_default(self):
148 148 return tuple(util.select_random_ports(2))
149 149
150 150 task = Tuple(Integer,Integer,config=True,
151 151 help="""Client/Engine Port pair for Task queue""")
152 152 def _task_default(self):
153 153 return tuple(util.select_random_ports(2))
154 154
155 155 control = Tuple(Integer,Integer,config=True,
156 156 help="""Client/Engine Port pair for Control queue""")
157 157
158 158 def _control_default(self):
159 159 return tuple(util.select_random_ports(2))
160 160
161 161 iopub = Tuple(Integer,Integer,config=True,
162 162 help="""Client/Engine Port pair for IOPub relay""")
163 163
164 164 def _iopub_default(self):
165 165 return tuple(util.select_random_ports(2))
166 166
167 167 # single ports:
168 168 mon_port = Integer(config=True,
169 169 help="""Monitor (SUB) port for queue traffic""")
170 170
171 171 def _mon_port_default(self):
172 172 return util.select_random_ports(1)[0]
173 173
174 174 notifier_port = Integer(config=True,
175 175 help="""PUB port for sending engine status notifications""")
176 176
177 177 def _notifier_port_default(self):
178 178 return util.select_random_ports(1)[0]
179 179
180 180 engine_ip = Unicode(config=True,
181 181 help="IP on which to listen for engine connections. [default: loopback]")
182 182 def _engine_ip_default(self):
183 183 return localhost()
184 184 engine_transport = Unicode('tcp', config=True,
185 185 help="0MQ transport for engine connections. [default: tcp]")
186 186
187 187 client_ip = Unicode(config=True,
188 188 help="IP on which to listen for client connections. [default: loopback]")
189 189 client_transport = Unicode('tcp', config=True,
190 190 help="0MQ transport for client connections. [default : tcp]")
191 191
192 192 monitor_ip = Unicode(config=True,
193 193 help="IP on which to listen for monitor messages. [default: loopback]")
194 194 monitor_transport = Unicode('tcp', config=True,
195 195 help="0MQ transport for monitor messages. [default : tcp]")
196 196
197 197 _client_ip_default = _monitor_ip_default = _engine_ip_default
198 198
199 199
200 200 monitor_url = Unicode('')
201 201
202 202 db_class = DottedObjectName('NoDB',
203 203 config=True, help="""The class to use for the DB backend
204 204
205 205 Options include:
206 206
207 207 SQLiteDB: SQLite
208 208 MongoDB : use MongoDB
209 209 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
210 210 NoDB : disable database altogether (default)
211 211
212 212 """)
213 213
214 214 # not configurable
215 215 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
216 216 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
217 217
218 218 def _ip_changed(self, name, old, new):
219 219 self.engine_ip = new
220 220 self.client_ip = new
221 221 self.monitor_ip = new
222 222 self._update_monitor_url()
223 223
224 224 def _update_monitor_url(self):
225 225 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
226 226
227 227 def _transport_changed(self, name, old, new):
228 228 self.engine_transport = new
229 229 self.client_transport = new
230 230 self.monitor_transport = new
231 231 self._update_monitor_url()
232 232
233 233 def __init__(self, **kwargs):
234 234 super(HubFactory, self).__init__(**kwargs)
235 235 self._update_monitor_url()
236 236
237 237
238 238 def construct(self):
239 239 self.init_hub()
240 240
241 241 def start(self):
242 242 self.heartmonitor.start()
243 243 self.log.info("Heartmonitor started")
244 244
245 245 def client_url(self, channel):
246 246 """return full zmq url for a named client channel"""
247 247 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
248 248
249 249 def engine_url(self, channel):
250 250 """return full zmq url for a named engine channel"""
251 251 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
252 252
253 253 def init_hub(self):
254 254 """construct Hub object"""
255 255
256 256 ctx = self.context
257 257 loop = self.loop
258 258 if 'TaskScheduler.scheme_name' in self.config:
259 259 scheme = self.config.TaskScheduler.scheme_name
260 260 else:
261 261 from .scheduler import TaskScheduler
262 262 scheme = TaskScheduler.scheme_name.get_default_value()
263 263
264 264 # build connection dicts
265 265 engine = self.engine_info = {
266 266 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
267 267 'registration' : self.regport,
268 268 'control' : self.control[1],
269 269 'mux' : self.mux[1],
270 270 'hb_ping' : self.hb[0],
271 271 'hb_pong' : self.hb[1],
272 272 'task' : self.task[1],
273 273 'iopub' : self.iopub[1],
274 274 }
275 275
276 276 client = self.client_info = {
277 277 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
278 278 'registration' : self.regport,
279 279 'control' : self.control[0],
280 280 'mux' : self.mux[0],
281 281 'task' : self.task[0],
282 282 'task_scheme' : scheme,
283 283 'iopub' : self.iopub[0],
284 284 'notification' : self.notifier_port,
285 285 }
286 286
287 287 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 288 self.log.debug("Hub client addrs: %s", self.client_info)
289 289
290 290 # Registrar socket
291 291 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
292 292 util.set_hwm(q, 0)
293 293 q.bind(self.client_url('registration'))
294 294 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
295 295 if self.client_ip != self.engine_ip:
296 296 q.bind(self.engine_url('registration'))
297 297 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
298 298
299 299 ### Engine connections ###
300 300
301 301 # heartbeat
302 302 hpub = ctx.socket(zmq.PUB)
303 303 hpub.bind(self.engine_url('hb_ping'))
304 304 hrep = ctx.socket(zmq.ROUTER)
305 305 util.set_hwm(hrep, 0)
306 306 hrep.bind(self.engine_url('hb_pong'))
307 307 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
308 308 pingstream=ZMQStream(hpub,loop),
309 309 pongstream=ZMQStream(hrep,loop)
310 310 )
311 311
312 312 ### Client connections ###
313 313
314 314 # Notifier socket
315 315 n = ZMQStream(ctx.socket(zmq.PUB), loop)
316 316 n.bind(self.client_url('notification'))
317 317
318 318 ### build and launch the queues ###
319 319
320 320 # monitor socket
321 321 sub = ctx.socket(zmq.SUB)
322 322 sub.setsockopt(zmq.SUBSCRIBE, b"")
323 323 sub.bind(self.monitor_url)
324 324 sub.bind('inproc://monitor')
325 325 sub = ZMQStream(sub, loop)
326 326
327 327 # connect the db
328 328 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
329 329 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
330 330 self.db = import_item(str(db_class))(session=self.session.session,
331 331 parent=self, log=self.log)
332 332 time.sleep(.25)
333 333
334 334 # resubmit stream
335 335 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
336 336 url = util.disambiguate_url(self.client_url('task'))
337 337 r.connect(url)
338 338
339 339 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
340 340 query=q, notifier=n, resubmit=r, db=self.db,
341 341 engine_info=self.engine_info, client_info=self.client_info,
342 342 log=self.log)
343 343
344 344
345 345 class Hub(SessionFactory):
346 346 """The IPython Controller Hub with 0MQ connections
347 347
348 348 Parameters
349 349 ==========
350 350 loop: zmq IOLoop instance
351 351 session: Session object
352 352 <removed> context: zmq context for creating new connections (?)
353 353 queue: ZMQStream for monitoring the command queue (SUB)
354 354 query: ZMQStream for engine registration and client query requests (ROUTER)
355 355 heartbeat: HeartMonitor object checking the pulse of the engines
356 356 notifier: ZMQStream for broadcasting engine registration changes (PUB)
357 357 db: connection to db for out of memory logging of commands
358 358 NotImplemented
359 359 engine_info: dict of zmq connection information for engines to connect
360 360 to the queues.
361 361 client_info: dict of zmq connection information for clients to connect
362 362 to the queues.
363 363 """
364 364
365 365 engine_state_file = Unicode()
366 366
367 367 # internal data structures:
368 368 ids=Set() # engine IDs
369 369 keytable=Dict()
370 370 by_ident=Dict()
371 371 engines=Dict()
372 372 clients=Dict()
373 373 hearts=Dict()
374 374 pending=Set()
375 375 queues=Dict() # pending msg_ids keyed by engine_id
376 376 tasks=Dict() # pending msg_ids submitted as tasks, keyed by engine_id
377 377 completed=Dict() # completed msg_ids keyed by engine_id
378 378 all_completed=Set() # set of all completed msg_ids
379 379 dead_engines=Set() # set of uuids of engines that have unregistered or died
380 380 unassigned=Set() # set of task msg_ids not yet assigned a destination
381 381 incoming_registrations=Dict()
382 382 registration_timeout=Integer()
383 383 _idcounter=Integer(0)
384 384
385 385 # objects from constructor:
386 386 query=Instance(ZMQStream)
387 387 monitor=Instance(ZMQStream)
388 388 notifier=Instance(ZMQStream)
389 389 resubmit=Instance(ZMQStream)
390 390 heartmonitor=Instance(HeartMonitor)
391 391 db=Instance(object)
392 392 client_info=Dict()
393 393 engine_info=Dict()
394 394
395 395
396 396 def __init__(self, **kwargs):
397 397 """
398 398 # universal:
399 399 loop: IOLoop for creating future connections
400 400 session: streamsession for sending serialized data
401 401 # engine:
402 402 queue: ZMQStream for monitoring queue messages
403 403 query: ZMQStream for engine+client registration and client requests
404 404 heartbeat: HeartMonitor object for tracking engines
405 405 # extra:
406 406 db: ZMQStream for db connection (NotImplemented)
407 407 engine_info: zmq address/protocol dict for engine connections
408 408 client_info: zmq address/protocol dict for client connections
409 409 """
410 410
411 411 super(Hub, self).__init__(**kwargs)
412 412 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
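# note: registration_timeout is in milliseconds (it feeds the
# ioloop.DelayedCallback below): at least 10s, or five heartbeat periods.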
413 413
414 414 # register our callbacks
415 415 self.query.on_recv(self.dispatch_query)
416 416 self.monitor.on_recv(self.dispatch_monitor_traffic)
417 417
418 418 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
419 419 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
420 420
421 421 self.monitor_handlers = {b'in' : self.save_queue_request,
422 422 b'out': self.save_queue_result,
423 423 b'intask': self.save_task_request,
424 424 b'outtask': self.save_task_result,
425 425 b'tracktask': self.save_task_destination,
426 426 b'incontrol': _passer,
427 427 b'outcontrol': _passer,
428 428 b'iopub': self.save_iopub_message,
429 429 }
430 430
431 431 self.query_handlers = {'queue_request': self.queue_status,
432 432 'result_request': self.get_results,
433 433 'history_request': self.get_history,
434 434 'db_request': self.db_query,
435 435 'purge_request': self.purge_results,
436 436 'load_request': self.check_load,
437 437 'resubmit_request': self.resubmit_task,
438 438 'shutdown_request': self.shutdown_request,
439 439 'registration_request' : self.register_engine,
440 440 'unregistration_request' : self.unregister_engine,
441 441 'connection_request': self.connection_request,
442 442 }
443 443
444 444 # ignore resubmit replies
445 445 self.resubmit.on_recv(lambda msg: None, copy=False)
446 446
447 447 self.log.info("hub::created hub")
448 448
449 449 @property
450 450 def _next_id(self):
451 451 """gemerate a new ID.
452 452
453 453 No longer reuse old ids, just count from 0."""
454 454 newid = self._idcounter
455 455 self._idcounter += 1
456 456 return newid
457 457 # newid = 0
458 458 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
459 459 # # print newid, self.ids, self.incoming_registrations
460 460 # while newid in self.ids or newid in incoming:
461 461 # newid += 1
462 462 # return newid
463 463
464 464 #-----------------------------------------------------------------------------
465 465 # message validation
466 466 #-----------------------------------------------------------------------------
467 467
468 468 def _validate_targets(self, targets):
469 469 """turn any valid targets argument into a list of integer ids"""
470 470 if targets is None:
471 471 # default to all
472 472 return self.ids
473 473
474 474 if isinstance(targets, (int,str,unicode_type)):
475 475 # only one target specified
476 476 targets = [targets]
477 477 _targets = []
478 478 for t in targets:
479 479 # map raw identities to ids
480 480 if isinstance(t, (str,unicode_type)):
481 481 t = self.by_ident.get(cast_bytes(t), t)
482 482 _targets.append(t)
483 483 targets = _targets
484 484 bad_targets = [ t for t in targets if t not in self.ids ]
485 485 if bad_targets:
486 486 raise IndexError("No Such Engine: %r" % bad_targets)
487 487 if not targets:
488 488 raise IndexError("No Engines Registered")
489 489 return targets
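# illustrative calls (hypothetical table: engines 0-3, ident b'abc' -> 2):
# _validate_targets(None) -> [0, 1, 2, 3]
# _validate_targets('abc') -> [2] (raw identity resolved via by_ident)
# _validate_targets([0, 7]) -> IndexError("No Such Engine: [7]")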
490 490
491 491 #-----------------------------------------------------------------------------
492 492 # dispatch methods (1 per stream)
493 493 #-----------------------------------------------------------------------------
494 494
495 495
496 496 @util.log_errors
497 497 def dispatch_monitor_traffic(self, msg):
498 498 """all ME and Task queue messages come through here, as well as
499 499 IOPub traffic."""
500 500 self.log.debug("monitor traffic: %r", msg[0])
501 501 switch = msg[0]
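# the first frame is the monitor topic, e.g. b'in', b'outtask', or b'iopub',
# matching the keys of self.monitor_handlers set up in __init__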
502 502 try:
503 503 idents, msg = self.session.feed_identities(msg[1:])
504 504 except ValueError:
505 505 idents=[]
506 506 if not idents:
507 507 self.log.error("Monitor message without topic: %r", msg)
508 508 return
509 509 handler = self.monitor_handlers.get(switch, None)
510 510 if handler is not None:
511 511 handler(idents, msg)
512 512 else:
513 513 self.log.error("Unrecognized monitor topic: %r", switch)
514 514
515 515
516 516 @util.log_errors
517 517 def dispatch_query(self, msg):
518 518 """Route registration requests and queries from clients."""
519 519 try:
520 520 idents, msg = self.session.feed_identities(msg)
521 521 except ValueError:
522 522 idents = []
523 523 if not idents:
524 524 self.log.error("Bad Query Message: %r", msg)
525 525 return
526 526 client_id = idents[0]
527 527 try:
528 528 msg = self.session.unserialize(msg, content=True)
529 529 except Exception:
530 530 content = error.wrap_exception()
531 531 self.log.error("Bad Query Message: %r", msg, exc_info=True)
532 532 self.session.send(self.query, "hub_error", ident=client_id,
533 533 content=content)
534 534 return
535 535 # print client_id, header, parent, content
536 536 #switch on message type:
537 537 msg_type = msg['header']['msg_type']
538 538 self.log.info("client::client %r requested %r", client_id, msg_type)
539 539 handler = self.query_handlers.get(msg_type, None)
540 540 try:
541 541 assert handler is not None, "Bad Message Type: %r" % msg_type
542 542 except:
543 543 content = error.wrap_exception()
544 544 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
545 545 self.session.send(self.query, "hub_error", ident=client_id,
546 546 content=content)
547 547 return
548 548
549 549 else:
550 550 handler(idents, msg)
551 551
552 552 def dispatch_db(self, msg):
553 553 """"""
554 554 raise NotImplementedError
555 555
556 556 #---------------------------------------------------------------------------
557 557 # handler methods (1 per event)
558 558 #---------------------------------------------------------------------------
559 559
560 560 #----------------------- Heartbeat --------------------------------------
561 561
562 562 def handle_new_heart(self, heart):
563 563 """handler to attach to heartbeater.
564 564 Called when a new heart starts to beat.
565 565 Triggers completion of registration."""
566 566 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
567 567 if heart not in self.incoming_registrations:
568 568 self.log.info("heartbeat::ignoring new heart: %r", heart)
569 569 else:
570 570 self.finish_registration(heart)
571 571
572 572
573 573 def handle_heart_failure(self, heart):
574 574 """handler to attach to heartbeater.
575 575 called when a previously registered heart fails to respond to beat request.
576 576 triggers unregistration"""
577 577 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
578 578 eid = self.hearts.get(heart, None)
579 579 if eid is None or self.keytable[eid] in self.dead_engines:
580 580 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
581 581 else:
582 582 uuid = self.engines[eid].uuid
583 583 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
584 584
585 585 #----------------------- MUX Queue Traffic ------------------------------
586 586
587 587 def save_queue_request(self, idents, msg):
588 588 if len(idents) < 2:
589 589 self.log.error("invalid identity prefix: %r", idents)
590 590 return
591 591 queue_id, client_id = idents[:2]
592 592 try:
593 593 msg = self.session.unserialize(msg)
594 594 except Exception:
595 595 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
596 596 return
597 597
598 598 eid = self.by_ident.get(queue_id, None)
599 599 if eid is None:
600 600 self.log.error("queue::target %r not registered", queue_id)
601 601 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
602 602 return
603 603 record = init_record(msg)
604 604 msg_id = record['msg_id']
605 605 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
606 606 # Unicode in records
607 607 record['engine_uuid'] = queue_id.decode('ascii')
608 608 record['client_uuid'] = msg['header']['session']
609 609 record['queue'] = 'mux'
610 610
611 611 try:
612 612 # it's possible iopub arrived first:
613 613 existing = self.db.get_record(msg_id)
614 614 for key,evalue in iteritems(existing):
615 615 rvalue = record.get(key, None)
616 616 if evalue and rvalue and evalue != rvalue:
617 617 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
618 618 elif evalue and not rvalue:
619 619 record[key] = evalue
620 620 try:
621 621 self.db.update_record(msg_id, record)
622 622 except Exception:
623 623 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
624 624 except KeyError:
625 625 try:
626 626 self.db.add_record(msg_id, record)
627 627 except Exception:
628 628 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
629 629
630 630
631 631 self.pending.add(msg_id)
632 632 self.queues[eid].append(msg_id)
633 633
634 634 def save_queue_result(self, idents, msg):
635 635 if len(idents) < 2:
636 636 self.log.error("invalid identity prefix: %r", idents)
637 637 return
638 638
639 639 client_id, queue_id = idents[:2]
640 640 try:
641 641 msg = self.session.unserialize(msg)
642 642 except Exception:
643 643 self.log.error("queue::engine %r sent invalid message to %r: %r",
644 644 queue_id, client_id, msg, exc_info=True)
645 645 return
646 646
647 647 eid = self.by_ident.get(queue_id, None)
648 648 if eid is None:
649 649 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
650 650 return
651 651
652 652 parent = msg['parent_header']
653 653 if not parent:
654 654 return
655 655 msg_id = parent['msg_id']
656 656 if msg_id in self.pending:
657 657 self.pending.remove(msg_id)
658 658 self.all_completed.add(msg_id)
659 659 self.queues[eid].remove(msg_id)
660 660 self.completed[eid].append(msg_id)
661 661 self.log.info("queue::request %r completed on %s", msg_id, eid)
662 662 elif msg_id not in self.all_completed:
663 663 # it could be a result from a dead engine that died before delivering the
664 664 # result
665 665 self.log.warn("queue:: unknown msg finished %r", msg_id)
666 666 return
667 667 # update record anyway, because the unregistration could have been premature
668 668 rheader = msg['header']
669 669 md = msg['metadata']
670 670 completed = rheader['date']
671 671 started = md.get('started', None)
672 672 result = {
673 673 'result_header' : rheader,
674 674 'result_metadata': md,
675 675 'result_content': msg['content'],
676 676 'received': datetime.now(),
677 677 'started' : started,
678 678 'completed' : completed
679 679 }
680 680
681 681 result['result_buffers'] = msg['buffers']
682 682 try:
683 683 self.db.update_record(msg_id, result)
684 684 except Exception:
685 685 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
686 686
687 687
688 688 #--------------------- Task Queue Traffic ------------------------------
689 689
690 690 def save_task_request(self, idents, msg):
691 691 """Save the submission of a task."""
692 692 client_id = idents[0]
693 693
694 694 try:
695 695 msg = self.session.unserialize(msg)
696 696 except Exception:
697 697 self.log.error("task::client %r sent invalid task message: %r",
698 698 client_id, msg, exc_info=True)
699 699 return
700 700 record = init_record(msg)
701 701
702 702 record['client_uuid'] = msg['header']['session']
703 703 record['queue'] = 'task'
704 704 header = msg['header']
705 705 msg_id = header['msg_id']
706 706 self.pending.add(msg_id)
707 707 self.unassigned.add(msg_id)
708 708 try:
709 709 # it's possible iopub arrived first:
710 710 existing = self.db.get_record(msg_id)
711 711 if existing['resubmitted']:
712 712 for key in ('submitted', 'client_uuid', 'buffers'):
713 713 # don't clobber these keys on resubmit
714 714 # submitted and client_uuid should be different
715 715 # and buffers might be big, and shouldn't have changed
716 716 record.pop(key)
717 717 # still check content and header, which should not change
718 718 # and are not as expensive to compare as buffers
719 719
720 720 for key,evalue in iteritems(existing):
721 721 if key.endswith('buffers'):
722 722 # don't compare buffers
723 723 continue
724 724 rvalue = record.get(key, None)
725 725 if evalue and rvalue and evalue != rvalue:
726 726 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
727 727 elif evalue and not rvalue:
728 728 record[key] = evalue
729 729 try:
730 730 self.db.update_record(msg_id, record)
731 731 except Exception:
732 732 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
733 733 except KeyError:
734 734 try:
735 735 self.db.add_record(msg_id, record)
736 736 except Exception:
737 737 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
738 738 except Exception:
739 739 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
740 740
741 741 def save_task_result(self, idents, msg):
742 742 """save the result of a completed task."""
743 743 client_id = idents[0]
744 744 try:
745 745 msg = self.session.unserialize(msg)
746 746 except Exception:
747 747 self.log.error("task::invalid task result message send to %r: %r",
748 748 client_id, msg, exc_info=True)
749 749 return
750 750
751 751 parent = msg['parent_header']
752 752 if not parent:
753 753 # print msg
754 754 self.log.warn("Task %r had no parent!", msg)
755 755 return
756 756 msg_id = parent['msg_id']
757 757 if msg_id in self.unassigned:
758 758 self.unassigned.remove(msg_id)
759 759
760 760 header = msg['header']
761 761 md = msg['metadata']
762 762 engine_uuid = md.get('engine', u'')
763 763 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
764 764
765 765 status = md.get('status', None)
766 766
767 767 if msg_id in self.pending:
768 768 self.log.info("task::task %r finished on %s", msg_id, eid)
769 769 self.pending.remove(msg_id)
770 770 self.all_completed.add(msg_id)
771 771 if eid is not None:
772 772 if status != 'aborted':
773 773 self.completed[eid].append(msg_id)
774 774 if msg_id in self.tasks[eid]:
775 775 self.tasks[eid].remove(msg_id)
776 776 completed = header['date']
777 777 started = md.get('started', None)
778 778 result = {
779 779 'result_header' : header,
780 780 'result_metadata': msg['metadata'],
781 781 'result_content': msg['content'],
782 782 'started' : started,
783 783 'completed' : completed,
784 784 'received' : datetime.now(),
785 785 'engine_uuid': engine_uuid,
786 786 }
787 787
788 788 result['result_buffers'] = msg['buffers']
789 789 try:
790 790 self.db.update_record(msg_id, result)
791 791 except Exception:
792 792 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
793 793
794 794 else:
795 795 self.log.debug("task::unknown task %r finished", msg_id)
796 796
797 797 def save_task_destination(self, idents, msg):
798 798 try:
799 799 msg = self.session.unserialize(msg, content=True)
800 800 except Exception:
801 801 self.log.error("task::invalid task tracking message", exc_info=True)
802 802 return
803 803 content = msg['content']
804 804 # print (content)
805 805 msg_id = content['msg_id']
806 806 engine_uuid = content['engine_id']
807 807 eid = self.by_ident[cast_bytes(engine_uuid)]
808 808
809 809 self.log.info("task::task %r arrived on %r", msg_id, eid)
810 810 if msg_id in self.unassigned:
811 811 self.unassigned.remove(msg_id)
812 812 # else:
813 813 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
814 814
815 815 self.tasks[eid].append(msg_id)
816 816 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
817 817 try:
818 818 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
819 819 except Exception:
820 820 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
821 821
822 822
823 823 def mia_task_request(self, idents, msg):
824 824 raise NotImplementedError
825 825 client_id = idents[0]
826 826 # content = dict(mia=self.mia,status='ok')
827 827 # self.session.send('mia_reply', content=content, idents=client_id)
828 828
829 829
830 830 #--------------------- IOPub Traffic ------------------------------
831 831
832 832 def save_iopub_message(self, topics, msg):
833 833 """save an iopub message into the db"""
834 834 # print (topics)
835 835 try:
836 836 msg = self.session.unserialize(msg, content=True)
837 837 except Exception:
838 838 self.log.error("iopub::invalid IOPub message", exc_info=True)
839 839 return
840 840
841 841 parent = msg['parent_header']
842 842 if not parent:
843 843 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
844 844 return
845 845 msg_id = parent['msg_id']
846 846 msg_type = msg['header']['msg_type']
847 847 content = msg['content']
848 848
849 849 # ensure msg_id is in db
850 850 try:
851 851 rec = self.db.get_record(msg_id)
852 852 except KeyError:
853 853 rec = empty_record()
854 854 rec['msg_id'] = msg_id
855 855 self.db.add_record(msg_id, rec)
856 856 # stream
857 857 d = {}
858 858 if msg_type == 'stream':
859 859 name = content['name']
860 860 s = rec[name] or ''
861 861 d[name] = s + content['data']
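# successive stream chunks are concatenated, so repeated stdout messages
# accumulate into a single 'stdout' field on the stored record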
862 862
863 863 elif msg_type == 'pyerr':
864 864 d['pyerr'] = content
865 865 elif msg_type == 'pyin':
866 866 d['pyin'] = content['code']
867 867 elif msg_type in ('display_data', 'pyout'):
868 868 d[msg_type] = content
869 869 elif msg_type == 'status':
870 870 pass
871 871 elif msg_type == 'data_pub':
872 872 self.log.info("ignored data_pub message for %s" % msg_id)
873 873 else:
874 874 self.log.warn("unhandled iopub msg_type: %r", msg_type)
875 875
876 876 if not d:
877 877 return
878 878
879 879 try:
880 880 self.db.update_record(msg_id, d)
881 881 except Exception:
882 882 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
883 883
884 884
885 885
886 886 #-------------------------------------------------------------------------
887 887 # Registration requests
888 888 #-------------------------------------------------------------------------
889 889
890 890 def connection_request(self, client_id, msg):
891 891 """Reply with connection addresses for clients."""
892 892 self.log.info("client::client %r connected", client_id)
893 893 content = dict(status='ok')
894 894 jsonable = {}
895 895 for k,v in iteritems(self.keytable):
896 896 if v not in self.dead_engines:
897 897 jsonable[str(k)] = v
898 898 content['engines'] = jsonable
899 899 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
900 900
901 901 def register_engine(self, reg, msg):
902 902 """Register a new engine."""
903 903 content = msg['content']
904 904 try:
905 905 uuid = content['uuid']
906 906 except KeyError:
907 907 self.log.error("registration::queue not specified", exc_info=True)
908 908 return
909 909
910 910 eid = self._next_id
911 911
912 912 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
913 913
914 914 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
915 915 # check if requesting available IDs:
916 916 if cast_bytes(uuid) in self.by_ident:
917 917 try:
918 918 raise KeyError("uuid %r in use" % uuid)
919 919 except:
920 920 content = error.wrap_exception()
921 921 self.log.error("uuid %r in use", uuid, exc_info=True)
922 922 else:
923 923 for h, ec in iteritems(self.incoming_registrations):
924 924 if uuid == h:
925 925 try:
926 926 raise KeyError("heart_id %r in use" % uuid)
927 927 except:
928 928 self.log.error("heart_id %r in use", uuid, exc_info=True)
929 929 content = error.wrap_exception()
930 930 break
931 931 elif uuid == ec.uuid:
932 932 try:
933 933 raise KeyError("uuid %r in use" % uuid)
934 934 except:
935 935 self.log.error("uuid %r in use", uuid, exc_info=True)
936 936 content = error.wrap_exception()
937 937 break
938 938
939 939 msg = self.session.send(self.query, "registration_reply",
940 940 content=content,
941 941 ident=reg)
942 942
943 943 heart = cast_bytes(uuid)
944 944
945 945 if content['status'] == 'ok':
946 946 if heart in self.heartmonitor.hearts:
947 947 # already beating
948 948 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
949 949 self.finish_registration(heart)
950 950 else:
951 951 purge = lambda : self._purge_stalled_registration(heart)
952 952 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
953 953 dc.start()
954 954 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
955 955 else:
956 956 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
957 957
958 958 return eid
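# message shapes, as read from this method (a sketch, not a spec):
# request: content = {'uuid': '<engine uuid>'}
# ok reply: content = {'id': eid, 'status': 'ok', 'hb_period': <ms>}
# failure reply: content = error.wrap_exception(), which carries 'evalue'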
959 959
960 960 def unregister_engine(self, ident, msg):
961 961 """Unregister an engine that explicitly requested to leave."""
962 962 try:
963 963 eid = msg['content']['id']
964 964 except:
965 965 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
966 966 return
967 967 self.log.info("registration::unregister_engine(%r)", eid)
968 968 # print (eid)
969 969 uuid = self.keytable[eid]
970 970 content=dict(id=eid, uuid=uuid)
971 971 self.dead_engines.add(uuid)
972 972 # self.ids.remove(eid)
973 973 # uuid = self.keytable.pop(eid)
974 974 #
975 975 # ec = self.engines.pop(eid)
976 976 # self.hearts.pop(ec.heartbeat)
977 977 # self.by_ident.pop(ec.queue)
978 978 # self.completed.pop(eid)
979 979 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
980 980 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
981 981 dc.start()
982 982 ############## TODO: HANDLE IT ################
983 983
984 984 self._save_engine_state()
985 985
986 986 if self.notifier:
987 987 self.session.send(self.notifier, "unregistration_notification", content=content)
988 988
989 989 def _handle_stranded_msgs(self, eid, uuid):
990 990 """Handle messages known to be on an engine when the engine unregisters.
991 991
992 992 It is possible that this will fire prematurely - that is, an engine will
993 993 go down after completing a result, and the client will be notified
994 994 that the result failed and later receive the actual result.
995 995 """
996 996
997 997 outstanding = self.queues[eid]
998 998
999 999 for msg_id in outstanding:
1000 1000 self.pending.remove(msg_id)
1001 1001 self.all_completed.add(msg_id)
1002 1002 try:
1003 1003 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1004 1004 except:
1005 1005 content = error.wrap_exception()
1006 1006 # build a fake header:
1007 1007 header = {}
1008 1008 header['engine'] = uuid
1009 1009 header['date'] = datetime.now()
1010 1010 rec = dict(result_content=content, result_header=header, result_buffers=[])
1011 1011 rec['completed'] = header['date']
1012 1012 rec['engine_uuid'] = uuid
1013 1013 try:
1014 1014 self.db.update_record(msg_id, rec)
1015 1015 except Exception:
1016 1016 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1017 1017
1018 1018
1019 1019 def finish_registration(self, heart):
1020 1020 """Second half of engine registration, called after our HeartMonitor
1021 1021 has received a beat from the Engine's Heart."""
1022 1022 try:
1023 1023 ec = self.incoming_registrations.pop(heart)
1024 1024 except KeyError:
1025 1025 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
1026 1026 return
1027 1027 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1028 1028 if ec.stallback is not None:
1029 1029 ec.stallback.stop()
1030 1030 eid = ec.id
1031 1031 self.ids.add(eid)
1032 1032 self.keytable[eid] = ec.uuid
1033 1033 self.engines[eid] = ec
1034 1034 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1035 1035 self.queues[eid] = list()
1036 1036 self.tasks[eid] = list()
1037 1037 self.completed[eid] = list()
1038 1038 self.hearts[heart] = eid
1039 1039 content = dict(id=eid, uuid=self.engines[eid].uuid)
1040 1040 if self.notifier:
1041 1041 self.session.send(self.notifier, "registration_notification", content=content)
1042 1042 self.log.info("engine::Engine Connected: %i", eid)
1043 1043
1044 1044 self._save_engine_state()
1045 1045
1046 1046 def _purge_stalled_registration(self, heart):
1047 1047 if heart in self.incoming_registrations:
1048 1048 ec = self.incoming_registrations.pop(heart)
1049 1049 self.log.info("registration::purging stalled registration: %i", ec.id)
1050 1050 else:
1051 1051 pass
1052 1052
1053 1053 #-------------------------------------------------------------------------
1054 1054 # Engine State
1055 1055 #-------------------------------------------------------------------------
1056 1056
1057 1057
1058 1058 def _cleanup_engine_state_file(self):
1059 1059 """cleanup engine state mapping"""
1060 1060
1061 1061 if os.path.exists(self.engine_state_file):
1062 1062 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1063 1063 try:
1064 1064 os.remove(self.engine_state_file)
1065 1065 except IOError:
1066 1066 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1067 1067
1068 1068
1069 1069 def _save_engine_state(self):
1070 1070 """save engine mapping to JSON file"""
1071 1071 if not self.engine_state_file:
1072 1072 return
1073 1073 self.log.debug("save engine state to %s" % self.engine_state_file)
1074 1074 state = {}
1075 1075 engines = {}
1076 1076 for eid, ec in iteritems(self.engines):
1077 1077 if ec.uuid not in self.dead_engines:
1078 1078 engines[eid] = ec.uuid
1079 1079
1080 1080 state['engines'] = engines
1081 1081
1082 1082 state['next_id'] = self._idcounter
1083 1083
1084 1084 with open(self.engine_state_file, 'w') as f:
1085 1085 json.dump(state, f)
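# resulting file layout (illustrative values):
# {"engines": {"0": "uuid-a", "1": "uuid-b"}, "next_id": 2}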
1086 1086
1087 1087
1088 1088 def _load_engine_state(self):
1089 1089 """load engine mapping from JSON file"""
1090 1090 if not os.path.exists(self.engine_state_file):
1091 1091 return
1092 1092
1093 1093 self.log.info("loading engine state from %s" % self.engine_state_file)
1094 1094
1095 1095 with open(self.engine_state_file) as f:
1096 1096 state = json.load(f)
1097 1097
1098 1098 save_notifier = self.notifier
1099 1099 self.notifier = None
1100 1100 for eid, uuid in iteritems(state['engines']):
1101 1101 heart = uuid.encode('ascii')
1102 1102 # start with this heart as current and beating:
1103 1103 self.heartmonitor.responses.add(heart)
1104 1104 self.heartmonitor.hearts.add(heart)
1105 1105
1106 1106 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1107 1107 self.finish_registration(heart)
1108 1108
1109 1109 self.notifier = save_notifier
1110 1110
1111 1111 self._idcounter = state['next_id']
1112 1112
1113 1113 #-------------------------------------------------------------------------
1114 1114 # Client Requests
1115 1115 #-------------------------------------------------------------------------
1116 1116
1117 1117 def shutdown_request(self, client_id, msg):
1118 1118 """handle shutdown request."""
1119 1119 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1120 1120 # also notify other clients of shutdown
1121 1121 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1122 1122 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1123 1123 dc.start()
1124 1124
1125 1125 def _shutdown(self):
1126 1126 self.log.info("hub::hub shutting down.")
1127 1127 time.sleep(0.1)
1128 1128 sys.exit(0)
1129 1129
1130 1130
1131 1131 def check_load(self, client_id, msg):
1132 1132 content = msg['content']
1133 1133 try:
1134 1134 targets = content['targets']
1135 1135 targets = self._validate_targets(targets)
1136 1136 except:
1137 1137 content = error.wrap_exception()
1138 1138 self.session.send(self.query, "hub_error",
1139 1139 content=content, ident=client_id)
1140 1140 return
1141 1141
1142 1142 content = dict(status='ok')
1143 1143 # loads = {}
1144 1144 for t in targets:
1145 1145 content[str(t)] = len(self.queues[t])+len(self.tasks[t])
1146 1146 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1147 1147
1148 1148
1149 1149 def queue_status(self, client_id, msg):
1150 1150 """Return the Queue status of one or more targets.
1151 1151 if verbose: return the msg_ids
1152 1152 else: return len of each type.
1153 1153 keys: queue (pending MUX jobs)
1154 1154 tasks (pending Task jobs)
1155 1155 completed (finished jobs from both queues)"""
1156 1156 content = msg['content']
1157 1157 targets = content['targets']
1158 1158 try:
1159 1159 targets = self._validate_targets(targets)
1160 1160 except:
1161 1161 content = error.wrap_exception()
1162 1162 self.session.send(self.query, "hub_error",
1163 1163 content=content, ident=client_id)
1164 1164 return
1165 1165 verbose = content.get('verbose', False)
1166 1166 content = dict(status='ok')
1167 1167 for t in targets:
1168 1168 queue = self.queues[t]
1169 1169 completed = self.completed[t]
1170 1170 tasks = self.tasks[t]
1171 1171 if not verbose:
1172 1172 queue = len(queue)
1173 1173 completed = len(completed)
1174 1174 tasks = len(tasks)
1175 1175 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1176 1176 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1177 1177 # print (content)
1178 1178 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1179 1179
1180 1180 def purge_results(self, client_id, msg):
1181 1181 """Purge results from memory. This method is more valuable before we move
1182 1182 to a DB based message storage mechanism."""
1183 1183 content = msg['content']
1184 1184 self.log.info("Dropping records with %s", content)
1185 1185 msg_ids = content.get('msg_ids', [])
1186 1186 reply = dict(status='ok')
1187 1187 if msg_ids == 'all':
1188 1188 try:
1189 1189 self.db.drop_matching_records(dict(completed={'$ne':None}))
1190 1190 except Exception:
1191 1191 reply = error.wrap_exception()
1192 1192 else:
1193 pending = filter(lambda m: m in self.pending, msg_ids)
1193 pending = [m for m in msg_ids if (m in self.pending)]
1194 1194 if pending:
1195 1195 try:
1196 1196 raise IndexError("msg pending: %r" % pending[0])
1197 1197 except:
1198 1198 reply = error.wrap_exception()
1199 1199 else:
1200 1200 try:
1201 1201 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1202 1202 except Exception:
1203 1203 reply = error.wrap_exception()
1204 1204
1205 1205 if reply['status'] == 'ok':
1206 1206 eids = content.get('engine_ids', [])
1207 1207 for eid in eids:
1208 1208 if eid not in self.engines:
1209 1209 try:
1210 1210 raise IndexError("No such engine: %i" % eid)
1211 1211 except:
1212 1212 reply = error.wrap_exception()
1213 1213 break
1214 1214 uid = self.engines[eid].uuid
1215 1215 try:
1216 1216 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1217 1217 except Exception:
1218 1218 reply = error.wrap_exception()
1219 1219 break
1220 1220
1221 1221 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1222 1222
1223 1223 def resubmit_task(self, client_id, msg):
1224 1224 """Resubmit one or more tasks."""
1225 1225 def finish(reply):
1226 1226 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1227 1227
1228 1228 content = msg['content']
1229 1229 msg_ids = content['msg_ids']
1230 1230 reply = dict(status='ok')
1231 1231 try:
1232 1232 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1233 1233 'header', 'content', 'buffers'])
1234 1234 except Exception:
1235 1235 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1236 1236 return finish(error.wrap_exception())
1237 1237
1238 1238 # validate msg_ids
1239 1239 found_ids = [ rec['msg_id'] for rec in records ]
1240 1240 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1241 1241 if len(records) > len(msg_ids):
1242 1242 try:
1243 1243 raise RuntimeError("DB appears to be in an inconsistent state."
1244 1244 "More matching records were found than should exist")
1245 1245 except Exception:
1246 1246 return finish(error.wrap_exception())
1247 1247 elif len(records) < len(msg_ids):
1248 1248 missing = [ m for m in msg_ids if m not in found_ids ]
1249 1249 try:
1250 1250 raise KeyError("No such msg(s): %r" % missing)
1251 1251 except KeyError:
1252 1252 return finish(error.wrap_exception())
1253 1253 elif pending_ids:
1254 1254 pass
1255 1255 # no need to raise on resubmit of pending task, now that we
1256 1256 # resubmit under new ID, but do we want to raise anyway?
1257 1257 # msg_id = invalid_ids[0]
1258 1258 # try:
1259 1259 # raise ValueError("Task(s) %r appears to be inflight" % )
1260 1260 # except Exception:
1261 1261 # return finish(error.wrap_exception())
1262 1262
1263 1263 # mapping of original IDs to resubmitted IDs
1264 1264 resubmitted = {}
1265 1265
1266 1266 # send the messages
1267 1267 for rec in records:
1268 1268 header = rec['header']
1269 1269 msg = self.session.msg(header['msg_type'], parent=header)
1270 1270 msg_id = msg['msg_id']
1271 1271 msg['content'] = rec['content']
1272 1272
1273 1273 # use the old header, but update msg_id and timestamp
1274 1274 fresh = msg['header']
1275 1275 header['msg_id'] = fresh['msg_id']
1276 1276 header['date'] = fresh['date']
1277 1277 msg['header'] = header
1278 1278
1279 1279 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1280 1280
1281 1281 resubmitted[rec['msg_id']] = msg_id
1282 1282 self.pending.add(msg_id)
1283 1283 msg['buffers'] = rec['buffers']
1284 1284 try:
1285 1285 self.db.add_record(msg_id, init_record(msg))
1286 1286 except Exception:
1287 1287 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1288 1288 return finish(error.wrap_exception())
1289 1289
1290 1290 finish(dict(status='ok', resubmitted=resubmitted))
1291 1291
1292 1292 # store the new IDs in the Task DB
1293 1293 for msg_id, resubmit_id in iteritems(resubmitted):
1294 1294 try:
1295 1295 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1296 1296 except Exception:
1297 1297 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1298 1298
1299 1299
1300 1300 def _extract_record(self, rec):
1301 1301 """decompose a TaskRecord dict into subsection of reply for get_result"""
1302 1302 io_dict = {}
1303 1303 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1304 1304 io_dict[key] = rec[key]
1305 1305 content = {
1306 1306 'header': rec['header'],
1307 1307 'metadata': rec['metadata'],
1308 1308 'result_metadata': rec['result_metadata'],
1309 1309 'result_header' : rec['result_header'],
1310 1310 'result_content': rec['result_content'],
1311 1311 'received' : rec['received'],
1312 1312 'io' : io_dict,
1313 1313 }
1314 1314 if rec['result_buffers']:
1315 buffers = map(bytes, rec['result_buffers'])
1315 buffers = list(map(bytes, rec['result_buffers']))
1316 1316 else:
1317 1317 buffers = []
1318 1318
1319 1319 return content, buffers
1320 1320
1321 1321 def get_results(self, client_id, msg):
1322 1322 """Get the result of 1 or more messages."""
1323 1323 content = msg['content']
1324 1324 msg_ids = sorted(set(content['msg_ids']))
1325 1325 statusonly = content.get('status_only', False)
1326 1326 pending = []
1327 1327 completed = []
1328 1328 content = dict(status='ok')
1329 1329 content['pending'] = pending
1330 1330 content['completed'] = completed
1331 1331 buffers = []
1332 1332 if not statusonly:
1333 1333 try:
1334 1334 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1335 1335 # turn match list into dict, for faster lookup
1336 1336 records = {}
1337 1337 for rec in matches:
1338 1338 records[rec['msg_id']] = rec
1339 1339 except Exception:
1340 1340 content = error.wrap_exception()
1341 1341 self.session.send(self.query, "result_reply", content=content,
1342 1342 parent=msg, ident=client_id)
1343 1343 return
1344 1344 else:
1345 1345 records = {}
1346 1346 for msg_id in msg_ids:
1347 1347 if msg_id in self.pending:
1348 1348 pending.append(msg_id)
1349 1349 elif msg_id in self.all_completed:
1350 1350 completed.append(msg_id)
1351 1351 if not statusonly:
1352 1352 c,bufs = self._extract_record(records[msg_id])
1353 1353 content[msg_id] = c
1354 1354 buffers.extend(bufs)
1355 1355 elif msg_id in records:
1356 1356 if records[msg_id]['completed']:
1357 1357 completed.append(msg_id)
1358 1358 c,bufs = self._extract_record(records[msg_id])
1359 1359 content[msg_id] = c
1360 1360 buffers.extend(bufs)
1361 1361 else:
1362 1362 pending.append(msg_id)
1363 1363 else:
1364 1364 try:
1365 1365 raise KeyError('No such message: '+msg_id)
1366 1366 except:
1367 1367 content = error.wrap_exception()
1368 1368 break
1369 1369 self.session.send(self.query, "result_reply", content=content,
1370 1370 parent=msg, ident=client_id,
1371 1371 buffers=buffers)
1372 1372
1373 1373 def get_history(self, client_id, msg):
1374 1374 """Get a list of all msg_ids in our DB records"""
1375 1375 try:
1376 1376 msg_ids = self.db.get_history()
1377 1377 except Exception as e:
1378 1378 content = error.wrap_exception()
1379 1379 else:
1380 1380 content = dict(status='ok', history=msg_ids)
1381 1381
1382 1382 self.session.send(self.query, "history_reply", content=content,
1383 1383 parent=msg, ident=client_id)
1384 1384
1385 1385 def db_query(self, client_id, msg):
1386 1386 """Perform a raw query on the task record database."""
1387 1387 content = msg['content']
1388 1388 query = content.get('query', {})
1389 1389 keys = content.get('keys', None)
1390 1390 buffers = []
1391 1391 empty = list()
1392 1392 try:
1393 1393 records = self.db.find_records(query, keys)
1394 1394 except Exception as e:
1395 1395 content = error.wrap_exception()
1396 1396 else:
1397 1397 # extract buffers from reply content:
1398 1398 if keys is not None:
1399 1399 buffer_lens = [] if 'buffers' in keys else None
1400 1400 result_buffer_lens = [] if 'result_buffers' in keys else None
1401 1401 else:
1402 1402 buffer_lens = None
1403 1403 result_buffer_lens = None
1404 1404
1405 1405 for rec in records:
1406 1406 # buffers may be None, so double check
1407 1407 b = rec.pop('buffers', empty) or empty
1408 1408 if buffer_lens is not None:
1409 1409 buffer_lens.append(len(b))
1410 1410 buffers.extend(b)
1411 1411 rb = rec.pop('result_buffers', empty) or empty
1412 1412 if result_buffer_lens is not None:
1413 1413 result_buffer_lens.append(len(rb))
1414 1414 buffers.extend(rb)
1415 1415 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1416 1416 result_buffer_lens=result_buffer_lens)
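# e.g. two records with 1 and 3 buffers give buffer_lens = [1, 3] and a flat
# `buffers` list of 4 frames, so the client can re-split them per record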
1417 1417 # self.log.debug (content)
1418 1418 self.session.send(self.query, "db_reply", content=content,
1419 1419 parent=msg, ident=client_id,
1420 1420 buffers=buffers)
1421 1421
@@ -1,122 +1,122 b''
1 1 """A TaskRecord backend using mongodb
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 from pymongo import Connection
15 15
16 16 # bson.Binary import moved
17 17 try:
18 18 from bson.binary import Binary
19 19 except ImportError:
20 20 from bson import Binary
21 21
22 22 from IPython.utils.traitlets import Dict, List, Unicode, Instance
23 23
24 24 from .dictdb import BaseDB
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # MongoDB class
28 28 #-----------------------------------------------------------------------------
29 29
30 30 class MongoDB(BaseDB):
31 31 """MongoDB TaskRecord backend."""
32 32
33 33 connection_args = List(config=True,
34 34 help="""Positional arguments to be passed to pymongo.Connection. Only
35 35 necessary if the default mongodb configuration does not point to your
36 36 mongod instance.""")
37 37 connection_kwargs = Dict(config=True,
38 38 help="""Keyword arguments to be passed to pymongo.Connection. Only
39 39 necessary if the default mongodb configuration does not point to your
40 40 mongod instance."""
41 41 )
42 42 database = Unicode("ipython-tasks", config=True,
43 43 help="""The MongoDB database name to use for storing tasks for this session. If unspecified,
44 44 a new database will be created with the Hub's IDENT. Specifying the database will result
45 45 in tasks from previous sessions being available via Clients' db_query and
46 46 get_result methods.""")
47 47
48 48 _connection = Instance(Connection) # pymongo connection
49 49
50 50 def __init__(self, **kwargs):
51 51 super(MongoDB, self).__init__(**kwargs)
52 52 if self._connection is None:
53 53 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
54 54 if not self.database:
55 55 self.database = self.session
56 56 self._db = self._connection[self.database]
57 57 self._records = self._db['task_records']
58 58 self._records.ensure_index('msg_id', unique=True)
59 59 self._records.ensure_index('submitted') # for sorting history
60 60 # for rec in self._records.find
61 61
62 62 def _binary_buffers(self, rec):
63 63 for key in ('buffers', 'result_buffers'):
64 64 if rec.get(key, None):
65 rec[key] = map(Binary, rec[key])
65 rec[key] = list(map(Binary, rec[key]))
66 66 return rec
67 67
68 68 def add_record(self, msg_id, rec):
69 69 """Add a new Task Record, by msg_id."""
70 70 # print rec
71 71 rec = self._binary_buffers(rec)
72 72 self._records.insert(rec)
73 73
74 74 def get_record(self, msg_id):
75 75 """Get a specific Task Record, by msg_id."""
76 76 r = self._records.find_one({'msg_id': msg_id})
77 77 if not r:
78 78 # r will be '' if nothing is found
79 79 raise KeyError(msg_id)
80 80 return r
81 81
82 82 def update_record(self, msg_id, rec):
83 83 """Update the data in an existing record."""
84 84 rec = self._binary_buffers(rec)
85 85
86 86 self._records.update({'msg_id':msg_id}, {'$set': rec})
87 87
88 88 def drop_matching_records(self, check):
89 89 """Remove a record from the DB."""
90 90 self._records.remove(check)
91 91
92 92 def drop_record(self, msg_id):
93 93 """Remove a record from the DB."""
94 94 self._records.remove({'msg_id':msg_id})
95 95
96 96 def find_records(self, check, keys=None):
97 97 """Find records matching a query dict, optionally extracting subset of keys.
98 98
99 99 Returns list of matching records.
100 100
101 101 Parameters
102 102 ----------
103 103
104 104 check: dict
105 105 mongodb-style query argument
106 106 keys: list of strs [optional]
107 107 if specified, the subset of keys to extract. msg_id will *always* be
108 108 included.
109 109 """
110 110 if keys and 'msg_id' not in keys:
111 111 keys.append('msg_id')
112 112 matches = list(self._records.find(check,keys))
113 113 for rec in matches:
114 114 rec.pop('_id')
115 115 return matches
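# usage sketch (hypothetical query and keys):
# db.find_records({'completed': {'$ne': None}}, keys=['header'])
# -> [{'msg_id': ..., 'header': {...}}, ...] ('msg_id' is always included)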
116 116
117 117 def get_history(self):
118 118 """get all msg_ids, ordered by time submitted."""
119 119 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
120 120 return [ rec['msg_id'] for rec in cursor ]
121 121
122 122
@@ -1,860 +1,859 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 import logging
23 23 import sys
24 24 import time
25 25
26 26 from collections import deque
27 27 from datetime import datetime
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 44 from IPython.utils.py3compat import cast_bytes
45 45
46 46 from IPython.parallel import error, util
47 47 from IPython.parallel.factory import SessionFactory
48 48 from IPython.parallel.util import connect_logger, local_logger
49 49
50 50 from .dependency import Dependency
51 51
52 52 @decorator
53 53 def logged(f,self,*args,**kwargs):
54 54 # print ("#--------------------")
55 55 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
56 56 # print ("#--")
57 57 return f(self,*args, **kwargs)
58 58
59 59 #----------------------------------------------------------------------
60 60 # Chooser functions
61 61 #----------------------------------------------------------------------
62 62
63 63 def plainrandom(loads):
64 64 """Plain random pick."""
65 65 n = len(loads)
66 66 return randint(0,n-1)
67 67
68 68 def lru(loads):
69 69 """Always pick the front of the line.
70 70
71 71 The content of `loads` is ignored.
72 72
73 73 Assumes LRU ordering of loads, with oldest first.
74 74 """
75 75 return 0
76 76
77 77 def twobin(loads):
78 78 """Pick two at random, use the LRU of the two.
79 79
80 80 The content of loads is ignored.
81 81
82 82 Assumes LRU ordering of loads, with oldest first.
83 83 """
84 84 n = len(loads)
85 85 a = randint(0,n-1)
86 86 b = randint(0,n-1)
87 87 return min(a,b)
88 88
89 89 def weighted(loads):
90 90 """Pick two at random using inverse load as weight.
91 91
92 92 Return the less loaded of the two.
93 93 """
94 94 # weight 0 a million times more than 1:
95 95 weights = 1./(1e-6+numpy.array(loads))
96 96 sums = weights.cumsum()
97 97 t = sums[-1]
98 98 x = random()*t
99 99 y = random()*t
100 100 idx = 0
101 101 idy = 0
102 102 while sums[idx] < x:
103 103 idx += 1
104 104 while sums[idy] < y:
105 105 idy += 1
106 106 if weights[idy] > weights[idx]:
107 107 return idy
108 108 else:
109 109 return idx
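# e.g. loads = [0, 1] gives weights of roughly [1e6, 1]: an idle engine is
# about a million times more likely to be chosen than a fully loaded one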
110 110
111 111 def leastload(loads):
112 112 """Always choose the lowest load.
113 113
114 114 If the lowest load occurs more than once, the first
115 115 occurrence will be used. If loads has LRU ordering, this means
116 116 the LRU of those with the lowest load is chosen.
117 117 """
118 118 return loads.index(min(loads))
119 119
120 120 #---------------------------------------------------------------------
121 121 # Classes
122 122 #---------------------------------------------------------------------
123 123
124 124
125 125 # store empty default dependency:
126 126 MET = Dependency([])
127 127
128 128
129 129 class Job(object):
130 130 """Simple container for a job"""
131 131 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 132 targets, after, follow, timeout):
133 133 self.msg_id = msg_id
134 134 self.raw_msg = raw_msg
135 135 self.idents = idents
136 136 self.msg = msg
137 137 self.header = header
138 138 self.metadata = metadata
139 139 self.targets = targets
140 140 self.after = after
141 141 self.follow = follow
142 142 self.timeout = timeout
143 143
144 144 self.removed = False # used for lazy-delete from sorted queue
145 145 self.timestamp = time.time()
146 146 self.timeout_id = 0
147 147 self.blacklist = set()
148 148
149 149 def __lt__(self, other):
150 150 return self.timestamp < other.timestamp
151 151
152 152 def __cmp__(self, other):
153 153 return cmp(self.timestamp, other.timestamp)
154 154
155 155 @property
156 156 def dependents(self):
157 157 return self.follow.union(self.after)
158 158
159 159
160 160 class TaskScheduler(SessionFactory):
161 161 """Python TaskScheduler object.
162 162
163 163 This is the simplest object that supports msg_id based
164 164 DAG dependencies. *Only* task msg_ids are checked, not
165 165 msg_ids of jobs submitted via the MUX queue.
166 166
167 167 """
168 168
169 169 hwm = Integer(1, config=True,
170 170 help="""specify the High Water Mark (HWM) for the downstream
171 171 socket in the Task scheduler. This is the maximum number
172 172 of allowed outstanding tasks on each engine.
173 173
174 174 The default (1) means that only one task can be outstanding on each
175 175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
176 176 engines continue to be assigned tasks while they are working,
177 177 effectively hiding network latency behind computation, but can result
178 178 in an imbalance of work when submitting many heterogeneous tasks all at
179 179 once. Any positive value greater than one is a compromise between the
180 180 two.
181 181
182 182 """
183 183 )
184 184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
185 185 'leastload', config=True, allow_none=False,
186 186 help="""select the task scheduler scheme [default: Python LRU]
187 187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
188 188 )
189 189 def _scheme_name_changed(self, old, new):
190 190 self.log.debug("Using scheme %r"%new)
191 191 self.scheme = globals()[new]
192 192
193 193 # input arguments:
194 194 scheme = Instance(FunctionType) # function for determining the destination
195 195 def _scheme_default(self):
196 196 return leastload
197 197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
198 198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
199 199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
200 200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
201 201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
202 202
203 203 # internals:
204 204 queue = Instance(deque) # sorted list of Jobs
205 205 def _queue_default(self):
206 206 return deque()
207 207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
208 208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
209 209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
210 210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
211 211 pending = Dict() # dict by engine_uuid of submitted tasks
212 212 completed = Dict() # dict by engine_uuid of completed tasks
213 213 failed = Dict() # dict by engine_uuid of failed tasks
214 214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
215 215 clients = Dict() # dict by msg_id for who submitted the task
216 216 targets = List() # list of target IDENTs
217 217 loads = List() # list of engine loads
218 218 # full = Set() # set of IDENTs that have HWM outstanding tasks
219 219 all_completed = Set() # set of all completed tasks
220 220 all_failed = Set() # set of all failed tasks
221 221 all_done = Set() # set of all finished tasks=union(completed,failed)
222 222 all_ids = Set() # set of all submitted task IDs
223 223
224 224 ident = CBytes() # ZMQ identity. This should just be self.session.session
225 225 # but ensure Bytes
226 226 def _ident_default(self):
227 227 return self.session.bsession
228 228
229 229 def start(self):
230 230 self.query_stream.on_recv(self.dispatch_query_reply)
231 231 self.session.send(self.query_stream, "connection_request", {})
232 232
233 233 self.engine_stream.on_recv(self.dispatch_result, copy=False)
234 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 235
236 236 self._notification_handlers = dict(
237 237 registration_notification = self._register_engine,
238 238 unregistration_notification = self._unregister_engine
239 239 )
240 240 self.notifier_stream.on_recv(self.dispatch_notification)
241 241 self.log.info("Scheduler started [%s]" % self.scheme_name)
242 242
243 243 def resume_receiving(self):
244 244 """Resume accepting jobs."""
245 245 self.client_stream.on_recv(self.dispatch_submission, copy=False)
246 246
247 247 def stop_receiving(self):
248 248 """Stop accepting jobs while there are no engines.
249 249 Leave them in the ZMQ queue."""
250 250 self.client_stream.on_recv(None)
251 251
252 252 #-----------------------------------------------------------------------
253 253 # [Un]Registration Handling
254 254 #-----------------------------------------------------------------------
255 255
256 256
257 257 def dispatch_query_reply(self, msg):
258 258 """handle reply to our initial connection request"""
259 259 try:
260 260 idents,msg = self.session.feed_identities(msg)
261 261 except ValueError:
262 262 self.log.warn("task::Invalid Message: %r",msg)
263 263 return
264 264 try:
265 265 msg = self.session.unserialize(msg)
266 266 except ValueError:
267 267 self.log.warn("task::Unauthorized message from: %r"%idents)
268 268 return
269 269
270 270 content = msg['content']
271 271 for uuid in content.get('engines', {}).values():
272 272 self._register_engine(cast_bytes(uuid))
273 273
274 274
275 275 @util.log_errors
276 276 def dispatch_notification(self, msg):
277 277 """dispatch register/unregister events."""
278 278 try:
279 279 idents,msg = self.session.feed_identities(msg)
280 280 except ValueError:
281 281 self.log.warn("task::Invalid Message: %r",msg)
282 282 return
283 283 try:
284 284 msg = self.session.unserialize(msg)
285 285 except ValueError:
286 286 self.log.warn("task::Unauthorized message from: %r"%idents)
287 287 return
288 288
289 289 msg_type = msg['header']['msg_type']
290 290
291 291 handler = self._notification_handlers.get(msg_type, None)
292 292 if handler is None:
293 293 self.log.error("Unhandled message type: %r"%msg_type)
294 294 else:
295 295 try:
296 296 handler(cast_bytes(msg['content']['uuid']))
297 297 except Exception:
298 298 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299 299
300 300 def _register_engine(self, uid):
301 301 """New engine with ident `uid` became available."""
302 302 # head of the line:
303 303 self.targets.insert(0,uid)
304 304 self.loads.insert(0,0)
305 305
306 306 # initialize sets
307 307 self.completed[uid] = set()
308 308 self.failed[uid] = set()
309 309 self.pending[uid] = {}
310 310
311 311 # rescan the graph:
312 312 self.update_graph(None)
313 313
314 314 def _unregister_engine(self, uid):
315 315 """Existing engine with ident `uid` became unavailable."""
316 316 if len(self.targets) == 1:
317 317 # this was our only engine
318 318 pass
319 319
320 320 # handle any potentially finished tasks:
321 321 self.engine_stream.flush()
322 322
323 323 # don't pop destinations, because they might be used later
324 324 # map(self.destinations.pop, self.completed.pop(uid))
325 325 # map(self.destinations.pop, self.failed.pop(uid))
326 326
327 327 # prevent this engine from receiving work
328 328 idx = self.targets.index(uid)
329 329 self.targets.pop(idx)
330 330 self.loads.pop(idx)
331 331
332 332 # wait 5 seconds before cleaning up pending jobs, since the results might
333 333 # still be incoming
334 334 if self.pending[uid]:
335 335 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
336 336 dc.start()
337 337 else:
338 338 self.completed.pop(uid)
339 339 self.failed.pop(uid)
340 340
341 341
342 342 def handle_stranded_tasks(self, engine):
343 343 """Deal with jobs resident in an engine that died."""
344 344 lost = self.pending[engine]
345 for msg_id in lost.keys():
345 for msg_id in list(lost.keys()):
346 346 if msg_id not in self.pending[engine]:
347 347 # prevent double-handling of messages
348 348 continue
349 349
350 350 raw_msg = lost[msg_id].raw_msg
351 351 idents,msg = self.session.feed_identities(raw_msg, copy=False)
352 352 parent = self.session.unpack(msg[1].bytes)
353 353 idents = [engine, idents[0]]
354 354
355 355 # build fake error reply
356 356 try:
357 357 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
358 358 except:
359 359 content = error.wrap_exception()
360 360 # build fake metadata
361 361 md = dict(
362 362 status=u'error',
363 363 engine=engine.decode('ascii'),
364 364 date=datetime.now(),
365 365 )
366 366 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
367 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
367 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
368 368 # and dispatch it
369 369 self.dispatch_result(raw_reply)
370 370
371 371 # finally scrub completed/failed lists
372 372 self.completed.pop(engine)
373 373 self.failed.pop(engine)
374 374
375 375
376 376 #-----------------------------------------------------------------------
377 377 # Job Submission
378 378 #-----------------------------------------------------------------------
379 379
380 380
381 381 @util.log_errors
382 382 def dispatch_submission(self, raw_msg):
383 383 """Dispatch job submission to appropriate handlers."""
384 384 # ensure targets up to date:
385 385 self.notifier_stream.flush()
386 386 try:
387 387 idents, msg = self.session.feed_identities(raw_msg, copy=False)
388 388 msg = self.session.unserialize(msg, content=False, copy=False)
389 389 except Exception:
390 390 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
391 391 return
392 392
393 393
394 394 # send to monitor
395 395 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
396 396
397 397 header = msg['header']
398 398 md = msg['metadata']
399 399 msg_id = header['msg_id']
400 400 self.all_ids.add(msg_id)
401 401
402 402 # get targets as a set of bytes objects
403 403 # from a list of unicode objects
404 404 targets = md.get('targets', [])
405 targets = map(cast_bytes, targets)
406 targets = set(targets)
405 targets = set(map(cast_bytes, targets))
407 406
408 407 retries = md.get('retries', 0)
409 408 self.retries[msg_id] = retries
410 409
411 410 # time dependencies
412 411 after = md.get('after', None)
413 412 if after:
414 413 after = Dependency(after)
415 414 if after.all:
416 415 if after.success:
417 416 after = Dependency(after.difference(self.all_completed),
418 417 success=after.success,
419 418 failure=after.failure,
420 419 all=after.all,
421 420 )
422 421 if after.failure:
423 422 after = Dependency(after.difference(self.all_failed),
424 423 success=after.success,
425 424 failure=after.failure,
426 425 all=after.all,
427 426 )
428 427 if after.check(self.all_completed, self.all_failed):
429 428 # recast as empty set, if `after` already met,
430 429 # to prevent unnecessary set comparisons
431 430 after = MET
432 431 else:
433 432 after = MET
434 433
435 434 # location dependencies
436 435 follow = Dependency(md.get('follow', []))
437 436
438 437 timeout = md.get('timeout', None)
439 438 if timeout:
440 439 timeout = float(timeout)
441 440
442 441 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
443 442 header=header, targets=targets, after=after, follow=follow,
444 443 timeout=timeout, metadata=md,
445 444 )
446 445 # validate and reduce dependencies:
447 446 for dep in after,follow:
448 447 if not dep: # empty dependency
449 448 continue
450 449 # check valid:
451 450 if msg_id in dep or dep.difference(self.all_ids):
452 451 self.queue_map[msg_id] = job
453 452 return self.fail_unreachable(msg_id, error.InvalidDependency)
454 453 # check if unreachable:
455 454 if dep.unreachable(self.all_completed, self.all_failed):
456 455 self.queue_map[msg_id] = job
457 456 return self.fail_unreachable(msg_id)
458 457
459 458 if after.check(self.all_completed, self.all_failed):
460 459 # time deps already met, try to run
461 460 if not self.maybe_run(job):
462 461 # can't run yet
463 462 if msg_id not in self.all_failed:
464 463 # could have failed as unreachable
465 464 self.save_unmet(job)
466 465 else:
467 466 self.save_unmet(job)
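# An illustrative sketch (not part of the dispatch path) of how a time
# dependency resolves, assuming the Dependency API used above, where
# check() and unreachable() take the sets of completed and failed msg_ids:
#
#   dep = Dependency(['a', 'b'], success=True, failure=False, all=True)
#   dep.check({'a', 'b'}, set())    # both succeeded -> True, job may run
#   dep.unreachable(set(), {'a'})   # 'a' failed -> True, job can never run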
468 467
469 468 def job_timeout(self, job, timeout_id):
470 469 """callback for a job's timeout.
471 470
472 471 The job may or may not have been run at this point.
473 472 """
474 473 if job.timeout_id != timeout_id:
475 474 # not the most recent call
476 475 return
477 476 now = time.time()
478 477 if job.timeout >= (now + 1):
479 478 self.log.warn("task %s timeout fired prematurely: %s > %s",
480 479 job.msg_id, job.timeout, now
481 480 )
482 481 if job.msg_id in self.queue_map:
483 482 # still waiting, but ran out of time
484 483 self.log.info("task %r timed out", job.msg_id)
485 484 self.fail_unreachable(job.msg_id, error.TaskTimeout)
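# timeout_id guards against stale callbacks: save_unmet bumps job.timeout_id
# each time the job is queued, so a callback scheduled with timeout_id=1
# finds job.timeout_id == 2 after a requeue and returns without acting;
# only the most recently scheduled timeout can fail the task.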
486 485
487 486 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
488 487 """a task has become unreachable, send a reply with an ImpossibleDependency
489 488 error."""
490 489 if msg_id not in self.queue_map:
491 490 self.log.error("task %r already failed!", msg_id)
492 491 return
493 492 job = self.queue_map.pop(msg_id)
494 493 # lazy-delete from the queue
495 494 job.removed = True
496 495 for mid in job.dependents:
497 496 if mid in self.graph:
498 497 self.graph[mid].remove(msg_id)
499 498
500 499 try:
501 500 raise why()
502 501 except:
503 502 content = error.wrap_exception()
504 503 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
505 504
506 505 self.all_done.add(msg_id)
507 506 self.all_failed.add(msg_id)
508 507
509 508 msg = self.session.send(self.client_stream, 'apply_reply', content,
510 509 parent=job.header, ident=job.idents)
511 510 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
512 511
513 512 self.update_graph(msg_id, success=False)
514 513
515 514 def available_engines(self):
516 515 """return a list of available engine indices based on HWM"""
517 516 if not self.hwm:
518 return range(len(self.targets))
517 return list(range(len(self.targets)))
519 518 available = []
520 519 for idx in range(len(self.targets)):
521 520 if self.loads[idx] < self.hwm:
522 521 available.append(idx)
523 522 return available
524 523
525 524 def maybe_run(self, job):
526 525 """check location dependencies, and run if they are met."""
527 526 msg_id = job.msg_id
528 527 self.log.debug("Attempting to assign task %s", msg_id)
529 528 available = self.available_engines()
530 529 if not available:
531 530 # no engines, definitely can't run
532 531 return False
533 532
534 533 if job.follow or job.targets or job.blacklist or self.hwm:
535 534 # we need a can_run filter
536 535 def can_run(idx):
537 536 # check hwm
538 537 if self.hwm and self.loads[idx] == self.hwm:
539 538 return False
540 539 target = self.targets[idx]
541 540 # check blacklist
542 541 if target in job.blacklist:
543 542 return False
544 543 # check targets
545 544 if job.targets and target not in job.targets:
546 545 return False
547 546 # check follow
548 547 return job.follow.check(self.completed[target], self.failed[target])
549 548
550 indices = filter(can_run, available)
549 indices = list(filter(can_run, available))
551 550
552 551 if not indices:
553 552 # couldn't run
554 553 if job.follow.all:
555 554 # check follow for impossibility
556 555 dests = set()
557 556 relevant = set()
558 557 if job.follow.success:
559 558 relevant = self.all_completed
560 559 if job.follow.failure:
561 560 relevant = relevant.union(self.all_failed)
562 561 for m in job.follow.intersection(relevant):
563 562 dests.add(self.destinations[m])
564 563 if len(dests) > 1:
565 564 self.queue_map[msg_id] = job
566 565 self.fail_unreachable(msg_id)
567 566 return False
568 567 if job.targets:
569 568 # check blacklist+targets for impossibility
570 569 job.targets.difference_update(job.blacklist)
571 570 if not job.targets or not job.targets.intersection(self.targets):
572 571 self.queue_map[msg_id] = job
573 572 self.fail_unreachable(msg_id)
574 573 return False
575 574 return False
576 575 else:
577 576 indices = None
578 577
579 578 self.submit_task(job, indices)
580 579 return True
581 580
582 581 def save_unmet(self, job):
583 582 """Save a message for later submission when its dependencies are met."""
584 583 msg_id = job.msg_id
585 584 self.log.debug("Adding task %s to the queue", msg_id)
586 585 self.queue_map[msg_id] = job
587 586 self.queue.append(job)
588 587 # track the ids in follow or after, but not those already finished
589 588 for dep_id in job.after.union(job.follow).difference(self.all_done):
590 589 if dep_id not in self.graph:
591 590 self.graph[dep_id] = set()
592 591 self.graph[dep_id].add(msg_id)
593 592
594 593 # schedule timeout callback
595 594 if job.timeout:
596 595 timeout_id = job.timeout_id = job.timeout_id + 1
597 596 self.loop.add_timeout(time.time() + job.timeout,
598 597 lambda : self.job_timeout(job, timeout_id)
599 598 )
600 599
601 600
602 601 def submit_task(self, job, indices=None):
603 602 """Submit a task to any of a subset of our targets."""
604 603 if indices:
605 604 loads = [self.loads[i] for i in indices]
606 605 else:
607 606 loads = self.loads
608 607 idx = self.scheme(loads)
609 608 if indices:
610 609 idx = indices[idx]
611 610 target = self.targets[idx]
612 611 # print (target, map(str, msg[:3]))
613 612 # send job to the engine
614 613 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
615 614 self.engine_stream.send_multipart(job.raw_msg, copy=False)
616 615 # update load
617 616 self.add_job(idx)
618 617 self.pending[target][job.msg_id] = job
619 618 # notify Hub
620 619 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
621 620 self.session.send(self.mon_stream, 'task_destination', content=content,
622 621 ident=[b'tracktask',self.ident])
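# self.scheme maps the list of loads to the chosen index; a least-load
# scheme, for example, reduces to loads.index(min(loads)).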
623 622
624 623
625 624 #-----------------------------------------------------------------------
626 625 # Result Handling
627 626 #-----------------------------------------------------------------------
628 627
629 628
630 629 @util.log_errors
631 630 def dispatch_result(self, raw_msg):
632 631 """dispatch method for result replies"""
633 632 try:
634 633 idents,msg = self.session.feed_identities(raw_msg, copy=False)
635 634 msg = self.session.unserialize(msg, content=False, copy=False)
636 635 engine = idents[0]
637 636 try:
638 637 idx = self.targets.index(engine)
639 638 except ValueError:
640 639 pass # skip load-update for dead engines
641 640 else:
642 641 self.finish_job(idx)
643 642 except Exception:
644 643 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
645 644 return
646 645
647 646 md = msg['metadata']
648 647 parent = msg['parent_header']
649 648 if md.get('dependencies_met', True):
650 649 success = (md['status'] == 'ok')
651 650 msg_id = parent['msg_id']
652 651 retries = self.retries[msg_id]
653 652 if not success and retries > 0:
654 653 # failed
655 654 self.retries[msg_id] = retries - 1
656 655 self.handle_unmet_dependency(idents, parent)
657 656 else:
658 657 del self.retries[msg_id]
659 658 # relay to client and update graph
660 659 self.handle_result(idents, parent, raw_msg, success)
661 660 # send to Hub monitor
662 661 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
663 662 else:
664 663 self.handle_unmet_dependency(idents, parent)
665 664
666 665 def handle_result(self, idents, parent, raw_msg, success=True):
667 666 """handle a real task result, either success or failure"""
668 667 # first, relay result to client
669 668 engine = idents[0]
670 669 client = idents[1]
671 670 # swap_ids for ROUTER-ROUTER mirror
672 671 raw_msg[:2] = [client,engine]
673 672 # print (map(str, raw_msg[:4]))
674 673 self.client_stream.send_multipart(raw_msg, copy=False)
675 674 # now, update our data structures
676 675 msg_id = parent['msg_id']
677 676 self.pending[engine].pop(msg_id)
678 677 if success:
679 678 self.completed[engine].add(msg_id)
680 679 self.all_completed.add(msg_id)
681 680 else:
682 681 self.failed[engine].add(msg_id)
683 682 self.all_failed.add(msg_id)
684 683 self.all_done.add(msg_id)
685 684 self.destinations[msg_id] = engine
686 685
687 686 self.update_graph(msg_id, success)
688 687
689 688 def handle_unmet_dependency(self, idents, parent):
690 689 """handle an unmet dependency"""
691 690 engine = idents[0]
692 691 msg_id = parent['msg_id']
693 692
694 693 job = self.pending[engine].pop(msg_id)
695 694 job.blacklist.add(engine)
696 695
697 696 if job.blacklist == job.targets:
698 697 self.queue_map[msg_id] = job
699 698 self.fail_unreachable(msg_id)
700 699 elif not self.maybe_run(job):
701 700 # resubmit failed
702 701 if msg_id not in self.all_failed:
703 702 # put it back in our dependency tree
704 703 self.save_unmet(job)
705 704
706 705 if self.hwm:
707 706 try:
708 707 idx = self.targets.index(engine)
709 708 except ValueError:
710 709 pass # skip load-update for dead engines
711 710 else:
712 711 if self.loads[idx] == self.hwm-1:
713 712 self.update_graph(None)
714 713
715 714 def update_graph(self, dep_id=None, success=True):
716 715 """dep_id just finished. Update our dependency
717 716 graph and submit any jobs that just became runnable.
718 717
719 718 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
720 719 """
721 720 # print ("\n\n***********")
722 721 # pprint (dep_id)
723 722 # pprint (self.graph)
724 723 # pprint (self.queue_map)
725 724 # pprint (self.all_completed)
726 725 # pprint (self.all_failed)
727 726 # print ("\n\n***********\n\n")
728 727 # update any jobs that depended on the dependency
729 728 msg_ids = self.graph.pop(dep_id, [])
730 729
731 730 # recheck *all* jobs if
732 731 # a) we have HWM and an engine just became no longer full
733 732 # or b) dep_id was given as None
734 733
735 734 if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
736 735 jobs = self.queue
737 736 using_queue = True
738 737 else:
739 738 using_queue = False
740 739 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
741 740
742 741 to_restore = []
743 742 while jobs:
744 743 job = jobs.popleft()
745 744 if job.removed:
746 745 continue
747 746 msg_id = job.msg_id
748 747
749 748 put_it_back = True
750 749
751 750 if job.after.unreachable(self.all_completed, self.all_failed)\
752 751 or job.follow.unreachable(self.all_completed, self.all_failed):
753 752 self.fail_unreachable(msg_id)
754 753 put_it_back = False
755 754
756 755 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
757 756 if self.maybe_run(job):
758 757 put_it_back = False
759 758 self.queue_map.pop(msg_id)
760 759 for mid in job.dependents:
761 760 if mid in self.graph:
762 761 self.graph[mid].remove(msg_id)
763 762
764 763 # abort the loop if we just filled up all of our engines.
765 764 # avoids an O(N) operation in situation of full queue,
766 765 # where graph update is triggered as soon as an engine becomes
767 766 # non-full, and all tasks after the first are checked,
768 767 # even though they can't run.
769 768 if not self.available_engines():
770 769 break
771 770
772 771 if using_queue and put_it_back:
773 772 # popped a job from the queue but it neither ran nor failed,
774 773 # so we need to put it back when we are done
775 774 # make sure to_restore preserves the same ordering
776 775 to_restore.append(job)
777 776
778 777 # put back any tasks we popped but didn't run
779 778 if using_queue:
780 779 self.queue.extendleft(to_restore)
781 780
782 781 #----------------------------------------------------------------------
783 782 # methods to be overridden by subclasses
784 783 #----------------------------------------------------------------------
785 784
786 785 def add_job(self, idx):
787 786 """Called after self.targets[idx] just got the job with header.
788 787 Override with subclasses. The default ordering is simple LRU.
789 788 The default loads are the number of outstanding jobs."""
790 789 self.loads[idx] += 1
791 790 for lis in (self.targets, self.loads):
792 791 lis.append(lis.pop(idx))
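# e.g. with targets [a, b, c] and loads [0, 1, 0], add_job(0) bumps a's load
# and rotates it to the back: targets -> [b, c, a], loads -> [1, 0, 1],
# so the least recently used engines drift toward the front of the list.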
793 792
794 793
795 794 def finish_job(self, idx):
796 795 """Called after self.targets[idx] just finished a job.
797 796 Override in subclasses."""
798 797 self.loads[idx] -= 1
799 798
800 799
801 800
802 801 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
803 802 logname='root', log_url=None, loglevel=logging.DEBUG,
804 803 identity=b'task', in_thread=False):
805 804
806 805 ZMQStream = zmqstream.ZMQStream
807 806
808 807 if config:
809 808 # unwrap dict back into Config
810 809 config = Config(config)
811 810
812 811 if in_thread:
813 812 # use instance() to get the same Context/Loop as our parent
814 813 ctx = zmq.Context.instance()
815 814 loop = ioloop.IOLoop.instance()
816 815 else:
817 816 # in a process, don't use instance()
818 817 # for safety with multiprocessing
819 818 ctx = zmq.Context()
820 819 loop = ioloop.IOLoop()
821 820 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
822 821 util.set_hwm(ins, 0)
823 822 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
824 823 ins.bind(in_addr)
825 824
826 825 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
827 826 util.set_hwm(outs, 0)
828 827 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
829 828 outs.bind(out_addr)
830 829 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
831 830 util.set_hwm(mons, 0)
832 831 mons.connect(mon_addr)
833 832 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
834 833 nots.setsockopt(zmq.SUBSCRIBE, b'')
835 834 nots.connect(not_addr)
836 835
837 836 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
838 837 querys.connect(reg_addr)
839 838
840 839 # setup logging.
841 840 if in_thread:
842 841 log = Application.instance().log
843 842 else:
844 843 if log_url:
845 844 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
846 845 else:
847 846 log = local_logger(logname, loglevel)
848 847
849 848 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
850 849 mon_stream=mons, notifier_stream=nots,
851 850 query_stream=querys,
852 851 loop=loop, log=log,
853 852 config=config)
854 853 scheduler.start()
855 854 if not in_thread:
856 855 try:
857 856 loop.start()
858 857 except KeyboardInterrupt:
859 858 scheduler.log.critical("Interrupted, exiting...")
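# A minimal invocation sketch; the addresses are illustrative, and in a real
# cluster they come from the Hub's connection info:
#
# launch_scheduler('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556',
#                  'tcp://127.0.0.1:5557', 'tcp://127.0.0.1:5558',
#                  'tcp://127.0.0.1:5559',
#                  logname='scheduler', loglevel=logging.INFO)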
860 859
@@ -1,422 +1,422 b''
1 1 """A TaskRecord backend using sqlite3
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import json
15 15 import os
16 16 try:
17 17 import cPickle as pickle
18 18 except ImportError:
19 19 import pickle
20 20 from datetime import datetime
21 21
22 22 try:
23 23 import sqlite3
24 24 except ImportError:
25 25 sqlite3 = None
26 26
27 27 from zmq.eventloop import ioloop
28 28
29 29 from IPython.utils.traitlets import Unicode, Instance, List, Dict
30 30 from .dictdb import BaseDB
31 31 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
32 32 from IPython.utils.py3compat import iteritems
33 33
34 34 #-----------------------------------------------------------------------------
35 35 # SQLite operators, adapters, and converters
36 36 #-----------------------------------------------------------------------------
37 37
38 38 try:
39 39 buffer
40 40 except NameError:
41 41 # py3k
42 42 buffer = memoryview
43 43
44 44 operators = {
45 45 '$lt' : "<",
46 46 '$gt' : ">",
47 47 # null is handled weirdly with ==, !=
48 48 '$eq' : "=",
49 49 '$ne' : "!=",
50 50 '$lte': "<=",
51 51 '$gte': ">=",
52 52 '$in' : ('=', ' OR '),
53 53 '$nin': ('!=', ' AND '),
54 54 # '$all': None,
55 55 # '$mod': None,
56 56 # '$exists' : None
57 57 }
58 58 null_operators = {
59 59 '=' : "IS NULL",
60 60 '!=' : "IS NOT NULL",
61 61 }
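# Illustrative renderings, as produced by SQLiteDB._render_expression below:
#   {'submitted': {'$lt': tic}}  -> "submitted < ?" with args [tic]
#   {'msg_id': {'$in': [a, b]}}  -> "( msg_id = ? OR msg_id = ? )" with args [a, b]
#   {'completed': {'$ne': None}} -> "completed IS NOT NULL" with no args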
62 62
63 63 def _adapt_dict(d):
64 64 return json.dumps(d, default=date_default)
65 65
66 66 def _convert_dict(ds):
67 67 if ds is None:
68 68 return ds
69 69 else:
70 70 if isinstance(ds, bytes):
71 71 # If I understand the sqlite doc correctly, this will always be utf8
72 72 ds = ds.decode('utf8')
73 73 return extract_dates(json.loads(ds))
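# Round-trip sketch: dicts are stored as JSON text, with datetimes serialized
# by date_default and revived by extract_dates, so (assuming those helpers
# stay symmetric) _convert_dict(_adapt_dict(d)) recovers a date-bearing dict.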
74 74
75 75 def _adapt_bufs(bufs):
76 76 # this is *horrible*
77 77 # copy buffers into single list and pickle it:
78 78 if bufs and isinstance(bufs[0], (bytes, buffer)):
79 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
79 return sqlite3.Binary(pickle.dumps(list(map(bytes, bufs)),-1))
80 80 elif bufs:
81 81 return bufs
82 82 else:
83 83 return None
84 84
85 85 def _convert_bufs(bs):
86 86 if bs is None:
87 87 return []
88 88 else:
89 89 return pickle.loads(bytes(bs))
90 90
91 91 #-----------------------------------------------------------------------------
92 92 # SQLiteDB class
93 93 #-----------------------------------------------------------------------------
94 94
95 95 class SQLiteDB(BaseDB):
96 96 """SQLite3 TaskRecord backend."""
97 97
98 98 filename = Unicode('tasks.db', config=True,
99 99 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
100 100 location = Unicode('', config=True,
101 101 help="""The directory containing the sqlite task database. The default
102 102 is to use the cluster_dir location.""")
103 103 table = Unicode("ipython-tasks", config=True,
104 104 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
105 105 a new table will be created with the Hub's IDENT. Specifying the table will result
106 106 in tasks from previous sessions being available via Clients' db_query and
107 107 get_result methods.""")
108 108
109 109 if sqlite3 is not None:
110 110 _db = Instance('sqlite3.Connection')
111 111 else:
112 112 _db = None
113 113 # the ordered list of column names
114 114 _keys = List(['msg_id' ,
115 115 'header' ,
116 116 'metadata',
117 117 'content',
118 118 'buffers',
119 119 'submitted',
120 120 'client_uuid' ,
121 121 'engine_uuid' ,
122 122 'started',
123 123 'completed',
124 124 'resubmitted',
125 125 'received',
126 126 'result_header' ,
127 127 'result_metadata',
128 128 'result_content' ,
129 129 'result_buffers' ,
130 130 'queue' ,
131 131 'pyin' ,
132 132 'pyout',
133 133 'pyerr',
134 134 'stdout',
135 135 'stderr',
136 136 ])
137 137 # sqlite datatypes, used to check that the db matches the current format
138 138 _types = Dict({'msg_id' : 'text' ,
139 139 'header' : 'dict text',
140 140 'metadata' : 'dict text',
141 141 'content' : 'dict text',
142 142 'buffers' : 'bufs blob',
143 143 'submitted' : 'timestamp',
144 144 'client_uuid' : 'text',
145 145 'engine_uuid' : 'text',
146 146 'started' : 'timestamp',
147 147 'completed' : 'timestamp',
148 148 'resubmitted' : 'text',
149 149 'received' : 'timestamp',
150 150 'result_header' : 'dict text',
151 151 'result_metadata' : 'dict text',
152 152 'result_content' : 'dict text',
153 153 'result_buffers' : 'bufs blob',
154 154 'queue' : 'text',
155 155 'pyin' : 'text',
156 156 'pyout' : 'text',
157 157 'pyerr' : 'text',
158 158 'stdout' : 'text',
159 159 'stderr' : 'text',
160 160 })
161 161
162 162 def __init__(self, **kwargs):
163 163 super(SQLiteDB, self).__init__(**kwargs)
164 164 if sqlite3 is None:
165 165 raise ImportError("SQLiteDB requires sqlite3")
166 166 if not self.table:
167 167 # use session, prefixed with _, since a table name may not start with a number
168 168 self.table = '_'+self.session.replace('-','_')
169 169 if not self.location:
170 170 # get current profile
171 171 from IPython.core.application import BaseIPythonApplication
172 172 if BaseIPythonApplication.initialized():
173 173 app = BaseIPythonApplication.instance()
174 174 if app.profile_dir is not None:
175 175 self.location = app.profile_dir.location
176 176 else:
177 177 self.location = u'.'
178 178 else:
179 179 self.location = u'.'
180 180 self._init_db()
181 181
182 182 # register db commit as 2s periodic callback
183 183 # to prevent clogging pipes
184 184 # assumes we are being run in a zmq ioloop app
185 185 loop = ioloop.IOLoop.instance()
186 186 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
187 187 pc.start()
188 188
189 189 def _defaults(self, keys=None):
190 190 """create an empty record"""
191 191 d = {}
192 192 keys = self._keys if keys is None else keys
193 193 for key in keys:
194 194 d[key] = None
195 195 return d
196 196
197 197 def _check_table(self):
198 198 """Ensure that an incorrect table doesn't exist
199 199
200 200 If a bad (old) table does exist, return False
201 201 """
202 202 cursor = self._db.execute("PRAGMA table_info('%s')"%self.table)
203 203 lines = cursor.fetchall()
204 204 if not lines:
205 205 # table does not exist
206 206 return True
207 207 types = {}
208 208 keys = []
209 209 for line in lines:
210 210 keys.append(line[1])
211 211 types[line[1]] = line[2]
212 212 if self._keys != keys:
213 213 # key mismatch
214 214 self.log.warn('keys mismatch')
215 215 return False
216 216 for key in self._keys:
217 217 if types[key] != self._types[key]:
218 218 self.log.warn(
219 219 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
220 220 )
221 221 return False
222 222 return True
223 223
224 224 def _init_db(self):
225 225 """Connect to the database and get new session number."""
226 226 # register adapters
227 227 sqlite3.register_adapter(dict, _adapt_dict)
228 228 sqlite3.register_converter('dict', _convert_dict)
229 229 sqlite3.register_adapter(list, _adapt_bufs)
230 230 sqlite3.register_converter('bufs', _convert_bufs)
231 231 # connect to the db
232 232 dbfile = os.path.join(self.location, self.filename)
233 233 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
234 234 # isolation_level = None)#,
235 235 cached_statements=64)
236 236 # print dir(self._db)
237 237 first_table = previous_table = self.table
238 238 i=0
239 239 while not self._check_table():
240 240 i+=1
241 241 self.table = first_table+'_%i'%i
242 242 self.log.warn(
243 243 "Table %s exists and doesn't match db format, trying %s"%
244 244 (previous_table, self.table)
245 245 )
246 246 previous_table = self.table
247 247
248 248 self._db.execute("""CREATE TABLE IF NOT EXISTS '%s'
249 249 (msg_id text PRIMARY KEY,
250 250 header dict text,
251 251 metadata dict text,
252 252 content dict text,
253 253 buffers bufs blob,
254 254 submitted timestamp,
255 255 client_uuid text,
256 256 engine_uuid text,
257 257 started timestamp,
258 258 completed timestamp,
259 259 resubmitted text,
260 260 received timestamp,
261 261 result_header dict text,
262 262 result_metadata dict text,
263 263 result_content dict text,
264 264 result_buffers bufs blob,
265 265 queue text,
266 266 pyin text,
267 267 pyout text,
268 268 pyerr text,
269 269 stdout text,
270 270 stderr text)
271 271 """%self.table)
272 272 self._db.commit()
273 273
274 274 def _dict_to_list(self, d):
275 275 """turn a mongodb-style record dict into a list."""
276 276
277 277 return [ d[key] for key in self._keys ]
278 278
279 279 def _list_to_dict(self, line, keys=None):
280 280 """Inverse of dict_to_list"""
281 281 keys = self._keys if keys is None else keys
282 282 d = self._defaults(keys)
283 283 for key,value in zip(keys, line):
284 284 d[key] = value
285 285
286 286 return d
287 287
288 288 def _render_expression(self, check):
289 289 """Turn a mongodb-style search dict into an SQL query."""
290 290 expressions = []
291 291 args = []
292 292
293 293 skeys = set(check.keys())
294 294 skeys.difference_update(set(self._keys))
295 295 skeys.difference_update(set(['buffers', 'result_buffers']))
296 296 if skeys:
297 297 raise KeyError("Illegal testing key(s): %s"%skeys)
298 298
299 299 for name,sub_check in iteritems(check):
300 300 if isinstance(sub_check, dict):
301 301 for test,value in iteritems(sub_check):
302 302 try:
303 303 op = operators[test]
304 304 except KeyError:
305 305 raise KeyError("Unsupported operator: %r"%test)
306 306 if isinstance(op, tuple):
307 307 op, join = op
308 308
309 309 if value is None and op in null_operators:
310 310 expr = "%s %s" % (name, null_operators[op])
311 311 else:
312 312 expr = "%s %s ?"%(name, op)
313 313 if isinstance(value, (tuple,list)):
314 314 if op in null_operators and any([v is None for v in value]):
315 315 # equality tests don't work with NULL
316 316 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
317 317 expr = '( %s )'%( join.join([expr]*len(value)) )
318 318 args.extend(value)
319 319 else:
320 320 args.append(value)
321 321 expressions.append(expr)
322 322 else:
323 323 # it's an equality check
324 324 if sub_check is None:
325 325 expressions.append("%s IS NULL" % name)
326 326 else:
327 327 expressions.append("%s = ?"%name)
328 328 args.append(sub_check)
329 329
330 330 expr = " AND ".join(expressions)
331 331 return expr, args
332 332
333 333 def add_record(self, msg_id, rec):
334 334 """Add a new Task Record, by msg_id."""
335 335 d = self._defaults()
336 336 d.update(rec)
337 337 d['msg_id'] = msg_id
338 338 line = self._dict_to_list(d)
339 339 tups = '(%s)'%(','.join(['?']*len(line)))
340 340 self._db.execute("INSERT INTO '%s' VALUES %s"%(self.table, tups), line)
341 341 # self._db.commit()
342 342
343 343 def get_record(self, msg_id):
344 344 """Get a specific Task Record, by msg_id."""
345 345 cursor = self._db.execute("""SELECT * FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
346 346 line = cursor.fetchone()
347 347 if line is None:
348 348 raise KeyError("No such msg: %r"%msg_id)
349 349 return self._list_to_dict(line)
350 350
351 351 def update_record(self, msg_id, rec):
352 352 """Update the data in an existing record."""
353 353 query = "UPDATE '%s' SET "%self.table
354 354 sets = []
355 355 keys = sorted(rec.keys())
356 356 values = []
357 357 for key in keys:
358 358 sets.append('%s = ?'%key)
359 359 values.append(rec[key])
360 360 query += ', '.join(sets)
361 361 query += ' WHERE msg_id == ?'
362 362 values.append(msg_id)
363 363 self._db.execute(query, values)
364 364 # self._db.commit()
365 365
366 366 def drop_record(self, msg_id):
367 367 """Remove a record from the DB."""
368 368 self._db.execute("""DELETE FROM '%s' WHERE msg_id==?"""%self.table, (msg_id,))
369 369 # self._db.commit()
370 370
371 371 def drop_matching_records(self, check):
372 372 """Remove a record from the DB."""
373 373 expr,args = self._render_expression(check)
374 374 query = "DELETE FROM '%s' WHERE %s"%(self.table, expr)
375 375 self._db.execute(query,args)
376 376 # self._db.commit()
377 377
378 378 def find_records(self, check, keys=None):
379 379 """Find records matching a query dict, optionally extracting subset of keys.
380 380
381 381 Returns list of matching records.
382 382
383 383 Parameters
384 384 ----------
385 385
386 386 check: dict
387 387 mongodb-style query argument
388 388 keys: list of strs [optional]
389 389 if specified, the subset of keys to extract. msg_id will *always* be
390 390 included.
391 391 """
392 392 if keys:
393 393 bad_keys = [ key for key in keys if key not in self._keys ]
394 394 if bad_keys:
395 395 raise KeyError("Bad record key(s): %s"%bad_keys)
396 396
397 397 if keys:
398 398 # ensure msg_id is present and first:
399 399 if 'msg_id' in keys:
400 400 keys.remove('msg_id')
401 401 keys.insert(0, 'msg_id')
402 402 req = ', '.join(keys)
403 403 else:
404 404 req = '*'
405 405 expr,args = self._render_expression(check)
406 406 query = """SELECT %s FROM '%s' WHERE %s"""%(req, self.table, expr)
407 407 cursor = self._db.execute(query, args)
408 408 matches = cursor.fetchall()
409 409 records = []
410 410 for line in matches:
411 411 rec = self._list_to_dict(line, keys)
412 412 records.append(rec)
413 413 return records
414 414
415 415 def get_history(self):
416 416 """get all msg_ids, ordered by time submitted."""
417 417 query = """SELECT msg_id FROM '%s' ORDER by submitted ASC"""%self.table
418 418 cursor = self._db.execute(query)
419 419 # will be a list of length 1 tuples
420 420 return [ tup[0] for tup in cursor.fetchall()]
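# A minimal usage sketch (illustrative names; assumes a zmq IOLoop app is
# running, since __init__ registers a periodic commit on the current IOLoop):
#
#   db = SQLiteDB(location=u'/tmp', table=u'mytasks')
#   db.add_record(msg_id, dict(header=header, submitted=datetime.now()))
#   db.update_record(msg_id, dict(completed=datetime.now()))
#   recs = db.find_records({'completed': {'$ne': None}}, keys=['msg_id'])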
421 421
422 422 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,326 +1,326 b''
1 1 """Tests for asyncresult.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import time
20 20
21 21 import nose.tools as nt
22 22
23 23 from IPython.utils.io import capture_output
24 24
25 25 from IPython.parallel.error import TimeoutError
26 26 from IPython.parallel import error, Client
27 27 from IPython.parallel.tests import add_engines
28 28 from .clienttest import ClusterTestCase
29 29 from IPython.utils.py3compat import iteritems
30 30
31 31 def setup():
32 32 add_engines(2, total=True)
33 33
34 34 def wait(n):
35 35 import time
36 36 time.sleep(n)
37 37 return n
38 38
39 39 def echo(x):
40 40 return x
41 41
42 42 class AsyncResultTest(ClusterTestCase):
43 43
44 44 def test_single_result_view(self):
45 45 """various one-target views get the right value for single_result"""
46 46 eid = self.client.ids[-1]
47 47 ar = self.client[eid].apply_async(lambda : 42)
48 48 self.assertEqual(ar.get(), 42)
49 49 ar = self.client[[eid]].apply_async(lambda : 42)
50 50 self.assertEqual(ar.get(), [42])
51 51 ar = self.client[-1:].apply_async(lambda : 42)
52 52 self.assertEqual(ar.get(), [42])
53 53
54 54 def test_get_after_done(self):
55 55 ar = self.client[-1].apply_async(lambda : 42)
56 56 ar.wait()
57 57 self.assertTrue(ar.ready())
58 58 self.assertEqual(ar.get(), 42)
59 59 self.assertEqual(ar.get(), 42)
60 60
61 61 def test_get_before_done(self):
62 62 ar = self.client[-1].apply_async(wait, 0.1)
63 63 self.assertRaises(TimeoutError, ar.get, 0)
64 64 ar.wait(0)
65 65 self.assertFalse(ar.ready())
66 66 self.assertEqual(ar.get(), 0.1)
67 67
68 68 def test_get_after_error(self):
69 69 ar = self.client[-1].apply_async(lambda : 1/0)
70 70 ar.wait(10)
71 71 self.assertRaisesRemote(ZeroDivisionError, ar.get)
72 72 self.assertRaisesRemote(ZeroDivisionError, ar.get)
73 73 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
74 74
75 75 def test_get_dict(self):
76 76 n = len(self.client)
77 77 ar = self.client[:].apply_async(lambda : 5)
78 78 self.assertEqual(ar.get(), [5]*n)
79 79 d = ar.get_dict()
80 80 self.assertEqual(sorted(d.keys()), sorted(self.client.ids))
81 81 for eid,r in iteritems(d):
82 82 self.assertEqual(r, 5)
83 83
84 84 def test_get_dict_single(self):
85 85 view = self.client[-1]
86 for v in (range(5), 5, ('abc', 'def'), 'string'):
86 for v in (list(range(5)), 5, ('abc', 'def'), 'string'):
87 87 ar = view.apply_async(echo, v)
88 88 self.assertEqual(ar.get(), v)
89 89 d = ar.get_dict()
90 90 self.assertEqual(d, {view.targets : v})
91 91
92 92 def test_get_dict_bad(self):
93 93 ar = self.client[:].apply_async(lambda : 5)
94 94 ar2 = self.client[:].apply_async(lambda : 5)
95 95 ar = self.client.get_result(ar.msg_ids + ar2.msg_ids)
96 96 self.assertRaises(ValueError, ar.get_dict)
97 97
98 98 def test_list_amr(self):
99 99 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
100 100 rlist = list(ar)
101 101
102 102 def test_getattr(self):
103 103 ar = self.client[:].apply_async(wait, 0.5)
104 104 self.assertEqual(ar.engine_id, [None] * len(ar))
105 105 self.assertRaises(AttributeError, lambda : ar._foo)
106 106 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
107 107 self.assertRaises(AttributeError, lambda : ar.foo)
108 108 self.assertFalse(hasattr(ar, '__length_hint__'))
109 109 self.assertFalse(hasattr(ar, 'foo'))
110 110 self.assertTrue(hasattr(ar, 'engine_id'))
111 111 ar.get(5)
112 112 self.assertRaises(AttributeError, lambda : ar._foo)
113 113 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
114 114 self.assertRaises(AttributeError, lambda : ar.foo)
115 115 self.assertTrue(isinstance(ar.engine_id, list))
116 116 self.assertEqual(ar.engine_id, ar['engine_id'])
117 117 self.assertFalse(hasattr(ar, '__length_hint__'))
118 118 self.assertFalse(hasattr(ar, 'foo'))
119 119 self.assertTrue(hasattr(ar, 'engine_id'))
120 120
121 121 def test_getitem(self):
122 122 ar = self.client[:].apply_async(wait, 0.5)
123 123 self.assertEqual(ar['engine_id'], [None] * len(ar))
124 124 self.assertRaises(KeyError, lambda : ar['foo'])
125 125 ar.get(5)
126 126 self.assertRaises(KeyError, lambda : ar['foo'])
127 127 self.assertTrue(isinstance(ar['engine_id'], list))
128 128 self.assertEqual(ar.engine_id, ar['engine_id'])
129 129
130 130 def test_single_result(self):
131 131 ar = self.client[-1].apply_async(wait, 0.5)
132 132 self.assertRaises(KeyError, lambda : ar['foo'])
133 133 self.assertEqual(ar['engine_id'], None)
134 134 self.assertTrue(ar.get(5) == 0.5)
135 135 self.assertTrue(isinstance(ar['engine_id'], int))
136 136 self.assertTrue(isinstance(ar.engine_id, int))
137 137 self.assertEqual(ar.engine_id, ar['engine_id'])
138 138
139 139 def test_abort(self):
140 140 e = self.client[-1]
141 141 ar = e.execute('import time; time.sleep(1)', block=False)
142 142 ar2 = e.apply_async(lambda : 2)
143 143 ar2.abort()
144 144 self.assertRaises(error.TaskAborted, ar2.get)
145 145 ar.get()
146 146
147 147 def test_len(self):
148 148 v = self.client.load_balanced_view()
149 ar = v.map_async(lambda x: x, range(10))
149 ar = v.map_async(lambda x: x, list(range(10)))
150 150 self.assertEqual(len(ar), 10)
151 ar = v.apply_async(lambda x: x, range(10))
151 ar = v.apply_async(lambda x: x, list(range(10)))
152 152 self.assertEqual(len(ar), 1)
153 ar = self.client[:].apply_async(lambda x: x, range(10))
153 ar = self.client[:].apply_async(lambda x: x, list(range(10)))
154 154 self.assertEqual(len(ar), len(self.client.ids))
155 155
156 156 def test_wall_time_single(self):
157 157 v = self.client.load_balanced_view()
158 158 ar = v.apply_async(time.sleep, 0.25)
159 159 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
160 160 ar.get(2)
161 161 self.assertTrue(ar.wall_time < 1.)
162 162 self.assertTrue(ar.wall_time > 0.2)
163 163
164 164 def test_wall_time_multi(self):
165 165 self.minimum_engines(4)
166 166 v = self.client[:]
167 167 ar = v.apply_async(time.sleep, 0.25)
168 168 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
169 169 ar.get(2)
170 170 self.assertTrue(ar.wall_time < 1.)
171 171 self.assertTrue(ar.wall_time > 0.2)
172 172
173 173 def test_serial_time_single(self):
174 174 v = self.client.load_balanced_view()
175 175 ar = v.apply_async(time.sleep, 0.25)
176 176 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
177 177 ar.get(2)
178 178 self.assertTrue(ar.serial_time < 1.)
179 179 self.assertTrue(ar.serial_time > 0.2)
180 180
181 181 def test_serial_time_multi(self):
182 182 self.minimum_engines(4)
183 183 v = self.client[:]
184 184 ar = v.apply_async(time.sleep, 0.25)
185 185 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
186 186 ar.get(2)
187 187 self.assertTrue(ar.serial_time < 2.)
188 188 self.assertTrue(ar.serial_time > 0.8)
189 189
190 190 def test_elapsed_single(self):
191 191 v = self.client.load_balanced_view()
192 192 ar = v.apply_async(time.sleep, 0.25)
193 193 while not ar.ready():
194 194 time.sleep(0.01)
195 195 self.assertTrue(ar.elapsed < 1)
196 196 self.assertTrue(ar.elapsed < 1)
197 197 ar.get(2)
198 198
199 199 def test_elapsed_multi(self):
200 200 v = self.client[:]
201 201 ar = v.apply_async(time.sleep, 0.25)
202 202 while not ar.ready():
203 203 time.sleep(0.01)
204 204 self.assertTrue(ar.elapsed < 1)
205 205 self.assertTrue(ar.elapsed < 1)
206 206 ar.get(2)
207 207
208 208 def test_hubresult_timestamps(self):
209 209 self.minimum_engines(4)
210 210 v = self.client[:]
211 211 ar = v.apply_async(time.sleep, 0.25)
212 212 ar.get(2)
213 213 rc2 = Client(profile='iptest')
214 214 # must have try/finally to close second Client, otherwise
215 215 # will have dangling sockets causing problems
216 216 try:
217 217 time.sleep(0.25)
218 218 hr = rc2.get_result(ar.msg_ids)
219 219 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
220 220 hr.get(1)
221 221 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
222 222 self.assertEqual(hr.serial_time, ar.serial_time)
223 223 finally:
224 224 rc2.close()
225 225
226 226 def test_display_empty_streams_single(self):
227 227 """empty stdout/err are not displayed (single result)"""
228 228 self.minimum_engines(1)
229 229
230 230 v = self.client[-1]
231 231 ar = v.execute("print (5555)")
232 232 ar.get(5)
233 233 with capture_output() as io:
234 234 ar.display_outputs()
235 235 self.assertEqual(io.stderr, '')
236 236 self.assertEqual('5555\n', io.stdout)
237 237
238 238 ar = v.execute("a=5")
239 239 ar.get(5)
240 240 with capture_output() as io:
241 241 ar.display_outputs()
242 242 self.assertEqual(io.stderr, '')
243 243 self.assertEqual(io.stdout, '')
244 244
245 245 def test_display_empty_streams_type(self):
246 246 """empty stdout/err are not displayed (groupby type)"""
247 247 self.minimum_engines(1)
248 248
249 249 v = self.client[:]
250 250 ar = v.execute("print (5555)")
251 251 ar.get(5)
252 252 with capture_output() as io:
253 253 ar.display_outputs()
254 254 self.assertEqual(io.stderr, '')
255 255 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
256 256 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 257 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
258 258
259 259 ar = v.execute("a=5")
260 260 ar.get(5)
261 261 with capture_output() as io:
262 262 ar.display_outputs()
263 263 self.assertEqual(io.stderr, '')
264 264 self.assertEqual(io.stdout, '')
265 265
266 266 def test_display_empty_streams_engine(self):
267 267 """empty stdout/err are not displayed (groupby engine)"""
268 268 self.minimum_engines(1)
269 269
270 270 v = self.client[:]
271 271 ar = v.execute("print (5555)")
272 272 ar.get(5)
273 273 with capture_output() as io:
274 274 ar.display_outputs('engine')
275 275 self.assertEqual(io.stderr, '')
276 276 self.assertEqual(io.stdout.count('5555'), len(v), io.stdout)
277 277 self.assertFalse('\n\n' in io.stdout, io.stdout)
278 278 self.assertEqual(io.stdout.count('[stdout:'), len(v), io.stdout)
279 279
280 280 ar = v.execute("a=5")
281 281 ar.get(5)
282 282 with capture_output() as io:
283 283 ar.display_outputs('engine')
284 284 self.assertEqual(io.stderr, '')
285 285 self.assertEqual(io.stdout, '')
286 286
287 287 def test_await_data(self):
288 288 """asking for ar.data flushes outputs"""
289 289 self.minimum_engines(1)
290 290
291 291 v = self.client[-1]
292 292 ar = v.execute('\n'.join([
293 293 "import time",
294 294 "from IPython.kernel.zmq.datapub import publish_data",
295 295 "for i in range(5):",
296 296 " publish_data(dict(i=i))",
297 297 " time.sleep(0.1)",
298 298 ]), block=False)
299 299 found = set()
300 300 tic = time.time()
301 301 # timeout after 10s
302 302 while time.time() <= tic + 10:
303 303 if ar.data:
304 304 found.add(ar.data['i'])
305 305 if ar.data['i'] == 4:
306 306 break
307 307 time.sleep(0.05)
308 308
309 309 ar.get(5)
310 310 nt.assert_in(4, found)
311 311 self.assertTrue(len(found) > 1, "should have seen data multiple times, but got: %s" % found)
312 312
313 313 def test_not_single_result(self):
314 314 save_build = self.client._build_targets
315 315 def single_engine(*a, **kw):
316 316 idents, targets = save_build(*a, **kw)
317 317 return idents[:1], targets[:1]
318 318 ids = single_engine('all')[1]
319 319 self.client._build_targets = single_engine
320 320 for targets in ('all', None, ids):
321 321 dv = self.client.direct_view(targets=targets)
322 322 ar = dv.apply_async(lambda : 5)
323 323 self.assertEqual(ar.get(10), [5])
324 324 self.client._build_targets = save_build
325 325
326 326
@@ -1,517 +1,517 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython import parallel
28 28 from IPython.parallel.client import client as clientmod
29 29 from IPython.parallel import error
30 30 from IPython.parallel import AsyncResult, AsyncHubResult
31 31 from IPython.parallel import LoadBalancedView, DirectView
32 32
33 33 from .clienttest import ClusterTestCase, segfault, wait, add_engines
34 34
35 35 def setup():
36 36 add_engines(4, total=True)
37 37
38 38 class TestClient(ClusterTestCase):
39 39
40 40 def test_ids(self):
41 41 n = len(self.client.ids)
42 42 self.add_engines(2)
43 43 self.assertEqual(len(self.client.ids), n+2)
44 44
45 45 def test_view_indexing(self):
46 46 """test index access for views"""
47 47 self.minimum_engines(4)
48 48 targets = self.client._build_targets('all')[-1]
49 49 v = self.client[:]
50 50 self.assertEqual(v.targets, targets)
51 51 t = self.client.ids[2]
52 52 v = self.client[t]
53 53 self.assertTrue(isinstance(v, DirectView))
54 54 self.assertEqual(v.targets, t)
55 55 t = self.client.ids[2:4]
56 56 v = self.client[t]
57 57 self.assertTrue(isinstance(v, DirectView))
58 58 self.assertEqual(v.targets, t)
59 59 v = self.client[::2]
60 60 self.assertTrue(isinstance(v, DirectView))
61 61 self.assertEqual(v.targets, targets[::2])
62 62 v = self.client[1::3]
63 63 self.assertTrue(isinstance(v, DirectView))
64 64 self.assertEqual(v.targets, targets[1::3])
65 65 v = self.client[:-3]
66 66 self.assertTrue(isinstance(v, DirectView))
67 67 self.assertEqual(v.targets, targets[:-3])
68 68 v = self.client[-1]
69 69 self.assertTrue(isinstance(v, DirectView))
70 70 self.assertEqual(v.targets, targets[-1])
71 71 self.assertRaises(TypeError, lambda : self.client[None])
72 72
73 73 def test_lbview_targets(self):
74 74 """test load_balanced_view targets"""
75 75 v = self.client.load_balanced_view()
76 76 self.assertEqual(v.targets, None)
77 77 v = self.client.load_balanced_view(-1)
78 78 self.assertEqual(v.targets, [self.client.ids[-1]])
79 79 v = self.client.load_balanced_view('all')
80 80 self.assertEqual(v.targets, None)
81 81
82 82 def test_dview_targets(self):
83 83 """test direct_view targets"""
84 84 v = self.client.direct_view()
85 85 self.assertEqual(v.targets, 'all')
86 86 v = self.client.direct_view('all')
87 87 self.assertEqual(v.targets, 'all')
88 88 v = self.client.direct_view(-1)
89 89 self.assertEqual(v.targets, self.client.ids[-1])
90 90
91 91 def test_lazy_all_targets(self):
92 92 """test lazy evaluation of rc.direct_view('all')"""
93 93 v = self.client.direct_view()
94 94 self.assertEqual(v.targets, 'all')
95 95
96 96 def double(x):
97 97 return x*2
98 seq = range(100)
98 seq = list(range(100))
99 99 ref = [ double(x) for x in seq ]
100 100
101 101 # add some engines, which should be used
102 102 self.add_engines(1)
103 103 n1 = len(self.client.ids)
104 104
105 105 # simple apply
106 106 r = v.apply_sync(lambda : 1)
107 107 self.assertEqual(r, [1] * n1)
108 108
109 109 # map goes through remotefunction
110 110 r = v.map_sync(double, seq)
111 111 self.assertEqual(r, ref)
112 112
113 113 # add a couple more engines, and try again
114 114 self.add_engines(2)
115 115 n2 = len(self.client.ids)
116 116 self.assertNotEqual(n2, n1)
117 117
118 118 # apply
119 119 r = v.apply_sync(lambda : 1)
120 120 self.assertEqual(r, [1] * n2)
121 121
122 122 # map
123 123 r = v.map_sync(double, seq)
124 124 self.assertEqual(r, ref)
125 125
126 126 def test_targets(self):
127 127 """test various valid targets arguments"""
128 128 build = self.client._build_targets
129 129 ids = self.client.ids
130 130 idents,targets = build(None)
131 131 self.assertEqual(ids, targets)
132 132
133 133 def test_clear(self):
134 134 """test clear behavior"""
135 135 self.minimum_engines(2)
136 136 v = self.client[:]
137 137 v.block=True
138 138 v.push(dict(a=5))
139 139 v.pull('a')
140 140 id0 = self.client.ids[-1]
141 141 self.client.clear(targets=id0, block=True)
142 142 a = self.client[:-1].get('a')
143 143 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
144 144 self.client.clear(block=True)
145 145 for i in self.client.ids:
146 146 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
147 147
148 148 def test_get_result(self):
149 149 """test getting results from the Hub."""
150 150 c = clientmod.Client(profile='iptest')
151 151 t = c.ids[-1]
152 152 ar = c[t].apply_async(wait, 1)
153 153 # give the monitor time to notice the message
154 154 time.sleep(.25)
155 155 ahr = self.client.get_result(ar.msg_ids[0])
156 156 self.assertTrue(isinstance(ahr, AsyncHubResult))
157 157 self.assertEqual(ahr.get(), ar.get())
158 158 ar2 = self.client.get_result(ar.msg_ids[0])
159 159 self.assertFalse(isinstance(ar2, AsyncHubResult))
160 160 c.close()
161 161
162 162 def test_get_execute_result(self):
163 163 """test getting execute results from the Hub."""
164 164 c = clientmod.Client(profile='iptest')
165 165 t = c.ids[-1]
166 166 cell = '\n'.join([
167 167 'import time',
168 168 'time.sleep(0.25)',
169 169 '5'
170 170 ])
171 171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 172 # give the monitor time to notice the message
173 173 time.sleep(.25)
174 174 ahr = self.client.get_result(ar.msg_ids[0])
175 175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 176 self.assertEqual(ahr.get().pyout, ar.get().pyout)
177 177 ar2 = self.client.get_result(ar.msg_ids[0])
178 178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 179 c.close()
180 180
181 181 def test_ids_list(self):
182 182 """test client.ids"""
183 183 ids = self.client.ids
184 184 self.assertEqual(ids, self.client._ids)
185 185 self.assertFalse(ids is self.client._ids)
186 186 ids.remove(ids[-1])
187 187 self.assertNotEqual(ids, self.client._ids)
188 188
189 189 def test_queue_status(self):
190 190 ids = self.client.ids
191 191 id0 = ids[0]
192 192 qs = self.client.queue_status(targets=id0)
193 193 self.assertTrue(isinstance(qs, dict))
194 194 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
195 195 allqs = self.client.queue_status()
196 196 self.assertTrue(isinstance(allqs, dict))
197 197 intkeys = list(allqs.keys())
198 198 intkeys.remove('unassigned')
199 199 self.assertEqual(sorted(intkeys), sorted(self.client.ids))
200 200 unassigned = allqs.pop('unassigned')
201 201 for eid,qs in allqs.items():
202 202 self.assertTrue(isinstance(qs, dict))
203 203 self.assertEqual(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
204 204
205 205 def test_shutdown(self):
206 206 ids = self.client.ids
207 207 id0 = ids[0]
208 208 self.client.shutdown(id0, block=True)
209 209 while id0 in self.client.ids:
210 210 time.sleep(0.1)
211 211 self.client.spin()
212 212
213 213 self.assertRaises(IndexError, lambda : self.client[id0])
214 214
215 215 def test_result_status(self):
216 216 pass
217 217 # to be written
218 218
219 219 def test_db_query_dt(self):
220 220 """test db query by date"""
221 221 hist = self.client.hub_history()
222 222 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
223 223 tic = middle['submitted']
224 224 before = self.client.db_query({'submitted' : {'$lt' : tic}})
225 225 after = self.client.db_query({'submitted' : {'$gte' : tic}})
226 226 self.assertEqual(len(before)+len(after),len(hist))
227 227 for b in before:
228 228 self.assertTrue(b['submitted'] < tic)
229 229 for a in after:
230 230 self.assertTrue(a['submitted'] >= tic)
231 231 same = self.client.db_query({'submitted' : tic})
232 232 for s in same:
233 233 self.assertTrue(s['submitted'] == tic)
234 234
235 235 def test_db_query_keys(self):
236 236 """test extracting subset of record keys"""
237 237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
238 238 for rec in found:
239 239 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
240 240
241 241 def test_db_query_default_keys(self):
242 242 """default db_query excludes buffers"""
243 243 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 244 for rec in found:
245 245 keys = set(rec.keys())
246 246 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
247 247 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
248 248
249 249 def test_db_query_msg_id(self):
250 250 """ensure msg_id is always in db queries"""
251 251 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
252 252 for rec in found:
253 253 self.assertTrue('msg_id' in rec.keys())
254 254 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
255 255 for rec in found:
256 256 self.assertTrue('msg_id' in rec.keys())
257 257 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
258 258 for rec in found:
259 259 self.assertTrue('msg_id' in rec.keys())
260 260
261 261 def test_db_query_get_result(self):
262 262 """pop in db_query shouldn't pop from result itself"""
263 263 self.client[:].apply_sync(lambda : 1)
264 264 found = self.client.db_query({'msg_id': {'$ne' : ''}})
265 265 rc2 = clientmod.Client(profile='iptest')
266 266 # If this bug is not fixed, this call will hang:
267 267 ar = rc2.get_result(self.client.history[-1])
268 268 ar.wait(2)
269 269 self.assertTrue(ar.ready())
270 270 ar.get()
271 271 rc2.close()
272 272
273 273 def test_db_query_in(self):
274 274 """test db query with '$in','$nin' operators"""
275 275 hist = self.client.hub_history()
276 276 even = hist[::2]
277 277 odd = hist[1::2]
278 278 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
279 279 found = [ r['msg_id'] for r in recs ]
280 280 self.assertEqual(set(even), set(found))
281 281 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
282 282 found = [ r['msg_id'] for r in recs ]
283 283 self.assertEqual(set(odd), set(found))
284 284
285 285 def test_hub_history(self):
286 286 hist = self.client.hub_history()
287 287 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
288 288 recdict = {}
289 289 for rec in recs:
290 290 recdict[rec['msg_id']] = rec
291 291
292 292 latest = datetime(1984,1,1)
293 293 for msg_id in hist:
294 294 rec = recdict[msg_id]
295 295 newt = rec['submitted']
296 296 self.assertTrue(newt >= latest)
297 297 latest = newt
298 298 ar = self.client[-1].apply_async(lambda : 1)
299 299 ar.get()
300 300 time.sleep(0.25)
301 301 self.assertEqual(self.client.hub_history()[-1:],ar.msg_ids)
302 302
303 303 def _wait_for_idle(self):
304 304 """wait for an engine to become idle, according to the Hub"""
305 305 rc = self.client
306 306
307 307 # step 1. wait for all requests to be noticed
308 308 # timeout 5s, polling every 100ms
309 309 msg_ids = set(rc.history)
310 310 hub_hist = rc.hub_history()
311 311 for i in range(50):
312 312 if msg_ids.difference(hub_hist):
313 313 time.sleep(0.1)
314 314 hub_hist = rc.hub_history()
315 315 else:
316 316 break
317 317
318 318 self.assertEqual(len(msg_ids.difference(hub_hist)), 0)
319 319
320 320 # step 2. wait for all requests to be done
321 321 # timeout 5s, polling every 100ms
322 322 qs = rc.queue_status()
323 323 for i in range(50):
324 324 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
325 325 time.sleep(0.1)
326 326 qs = rc.queue_status()
327 327 else:
328 328 break
329 329
330 330 # ensure Hub up to date:
331 331 self.assertEqual(qs['unassigned'], 0)
332 332 for eid in rc.ids:
333 333 self.assertEqual(qs[eid]['tasks'], 0)
334 334
335 335
336 336 def test_resubmit(self):
337 337 def f():
338 338 import random
339 339 return random.random()
340 340 v = self.client.load_balanced_view()
341 341 ar = v.apply_async(f)
342 342 r1 = ar.get(1)
343 343 # give the Hub a chance to notice:
344 344 self._wait_for_idle()
345 345 ahr = self.client.resubmit(ar.msg_ids)
346 346 r2 = ahr.get(1)
347 347 self.assertFalse(r1 == r2)
348 348
349 349 def test_resubmit_chain(self):
350 350 """resubmit resubmitted tasks"""
351 351 v = self.client.load_balanced_view()
352 352 ar = v.apply_async(lambda x: x, 'x'*1024)
353 353 ar.get()
354 354 self._wait_for_idle()
355 355 ars = [ar]
356 356
357 357 for i in range(10):
358 358 ar = ars[-1]
359 359             ar2 = self.client.resubmit(ar.msg_ids)
                    ars.append(ar2)  # grow the chain so each pass resubmits the previous resubmission
360 360
361 361 [ ar.get() for ar in ars ]
362 362
363 363 def test_resubmit_header(self):
364 364 """resubmit shouldn't clobber the whole header"""
365 365 def f():
366 366 import random
367 367 return random.random()
368 368 v = self.client.load_balanced_view()
369 369 v.retries = 1
370 370 ar = v.apply_async(f)
371 371 r1 = ar.get(1)
372 372 # give the Hub a chance to notice:
373 373 self._wait_for_idle()
374 374 ahr = self.client.resubmit(ar.msg_ids)
375 375 ahr.get(1)
376 376 time.sleep(0.5)
377 377 records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
378 378 h1,h2 = [ r['header'] for r in records ]
379 379 for key in set(h1.keys()).union(set(h2.keys())):
380 380 if key in ('msg_id', 'date'):
381 381 self.assertNotEqual(h1[key], h2[key])
382 382 else:
383 383 self.assertEqual(h1[key], h2[key])
384 384
385 385 def test_resubmit_aborted(self):
386 386 def f():
387 387 import random
388 388 return random.random()
389 389 v = self.client.load_balanced_view()
390 390 # restrict to one engine, so we can put a sleep
391 391 # ahead of the task, so it will get aborted
392 392 eid = self.client.ids[-1]
393 393 v.targets = [eid]
394 394 sleep = v.apply_async(time.sleep, 0.5)
395 395 ar = v.apply_async(f)
396 396 ar.abort()
397 397 self.assertRaises(error.TaskAborted, ar.get)
398 398 # Give the Hub a chance to get up to date:
399 399 self._wait_for_idle()
400 400 ahr = self.client.resubmit(ar.msg_ids)
401 401 r2 = ahr.get(1)
402 402
403 403 def test_resubmit_inflight(self):
404 404 """resubmit of inflight task"""
405 405 v = self.client.load_balanced_view()
406 406 ar = v.apply_async(time.sleep,1)
407 407 # give the message a chance to arrive
408 408 time.sleep(0.2)
409 409 ahr = self.client.resubmit(ar.msg_ids)
410 410 ar.get(2)
411 411 ahr.get(2)
412 412
413 413 def test_resubmit_badkey(self):
414 414         """ensure KeyError on resubmit of nonexistent task"""
415 415 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
416 416
417 417 def test_purge_hub_results(self):
418 418 # ensure there are some tasks
419 419 for i in range(5):
420 420 self.client[:].apply_sync(lambda : 1)
421 421 # Wait for the Hub to realise the result is done:
422 422 # This prevents a race condition, where we
423 423 # might purge a result the Hub still thinks is pending.
424 424 self._wait_for_idle()
425 425 rc2 = clientmod.Client(profile='iptest')
426 426 hist = self.client.hub_history()
427 427 ahr = rc2.get_result([hist[-1]])
428 428 ahr.wait(10)
429 429 self.client.purge_hub_results(hist[-1])
430 430 newhist = self.client.hub_history()
431 431 self.assertEqual(len(newhist)+1,len(hist))
432 432 rc2.spin()
433 433 rc2.close()
434 434
435 435 def test_purge_local_results(self):
436 436 # ensure there are some tasks
437 437 res = []
438 438 for i in range(5):
439 439 res.append(self.client[:].apply_async(lambda : 1))
440 440 self._wait_for_idle()
441 441 self.client.wait(10) # wait for the results to come back
442 442 before = len(self.client.results)
443 443 self.assertEqual(len(self.client.metadata),before)
444 444 self.client.purge_local_results(res[-1])
445 445 self.assertEqual(len(self.client.results),before-len(res[-1]), msg="Not removed from results")
446 446 self.assertEqual(len(self.client.metadata),before-len(res[-1]), msg="Not removed from metadata")
447 447
448 448 def test_purge_all_hub_results(self):
449 449 self.client.purge_hub_results('all')
450 450 hist = self.client.hub_history()
451 451 self.assertEqual(len(hist), 0)
452 452
453 453 def test_purge_all_local_results(self):
454 454 self.client.purge_local_results('all')
455 455 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
456 456 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
457 457
458 458 def test_purge_all_results(self):
459 459 # ensure there are some tasks
460 460 for i in range(5):
461 461 self.client[:].apply_sync(lambda : 1)
462 462 self.client.wait(10)
463 463 self._wait_for_idle()
464 464 self.client.purge_results('all')
465 465 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
466 466 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
467 467 hist = self.client.hub_history()
468 468 self.assertEqual(len(hist), 0, msg="hub history not empty")
469 469
470 470 def test_purge_everything(self):
471 471 # ensure there are some tasks
472 472 for i in range(5):
473 473 self.client[:].apply_sync(lambda : 1)
474 474 self.client.wait(10)
475 475 self._wait_for_idle()
476 476 self.client.purge_everything()
477 477 # The client results
478 478 self.assertEqual(len(self.client.results), 0, msg="Results not empty")
479 479 self.assertEqual(len(self.client.metadata), 0, msg="metadata not empty")
480 480 # The client "bookkeeping"
481 481 self.assertEqual(len(self.client.session.digest_history), 0, msg="session digest not empty")
482 482 self.assertEqual(len(self.client.history), 0, msg="client history not empty")
483 483 # the hub results
484 484 hist = self.client.hub_history()
485 485 self.assertEqual(len(hist), 0, msg="hub history not empty")
486 486
487 487
488 488 def test_spin_thread(self):
489 489 self.client.spin_thread(0.01)
490 490 ar = self.client[-1].apply_async(lambda : 1)
491 491 time.sleep(0.1)
492 492 self.assertTrue(ar.wall_time < 0.1,
493 493 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
494 494 )
495 495
496 496 def test_stop_spin_thread(self):
497 497 self.client.spin_thread(0.01)
498 498 self.client.stop_spin_thread()
499 499 ar = self.client[-1].apply_async(lambda : 1)
500 500 time.sleep(0.15)
501 501 self.assertTrue(ar.wall_time > 0.1,
502 502 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
503 503 )
504 504
505 505 def test_activate(self):
506 506 ip = get_ipython()
507 507 magics = ip.magics_manager.magics
508 508 self.assertTrue('px' in magics['line'])
509 509 self.assertTrue('px' in magics['cell'])
510 510 v0 = self.client.activate(-1, '0')
511 511 self.assertTrue('px0' in magics['line'])
512 512 self.assertTrue('px0' in magics['cell'])
513 513 self.assertEqual(v0.targets, self.client.ids[-1])
514 514 v0 = self.client.activate('all', 'all')
515 515 self.assertTrue('pxall' in magics['line'])
516 516 self.assertTrue('pxall' in magics['cell'])
517 517 self.assertEqual(v0.targets, 'all')
@@ -1,136 +1,136 b''
1 1 """Tests for dependency.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 __docformat__ = "restructuredtext en"
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Copyright (C) 2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-------------------------------------------------------------------------------
16 16
17 17 #-------------------------------------------------------------------------------
18 18 # Imports
19 19 #-------------------------------------------------------------------------------
20 20
21 21 # import
22 22 import os
23 23
24 24 from IPython.utils.pickleutil import can, uncan
25 25
26 26 import IPython.parallel as pmod
27 27 from IPython.parallel.util import interactive
28 28
29 29 from IPython.parallel.tests import add_engines
30 30 from .clienttest import ClusterTestCase
31 31
32 32 def setup():
33 33 add_engines(1, total=True)
34 34
35 35 @pmod.require('time')
36 36 def wait(n):
37 37 time.sleep(n)
38 38 return n
39 39
40 40 @pmod.interactive
41 41 def func(x):
42 42 return x*x
43 43
44 mixed = map(str, range(10))
45 completed = map(str, range(0,10,2))
46 failed = map(str, range(1,10,2))
44 mixed = list(map(str, range(10)))
45 completed = list(map(str, range(0,10,2)))
46 failed = list(map(str, range(1,10,2)))
47 47
48 48 class DependencyTest(ClusterTestCase):
49 49
50 50 def setUp(self):
51 51 ClusterTestCase.setUp(self)
52 52 self.user_ns = {'__builtins__' : __builtins__}
53 53 self.view = self.client.load_balanced_view()
54 54 self.dview = self.client[-1]
55 55 self.succeeded = set(map(str, range(0,25,2)))
56 56 self.failed = set(map(str, range(1,25,2)))
57 57
58 58 def assertMet(self, dep):
59 59 self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")
60 60
61 61 def assertUnmet(self, dep):
62 62 self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")
63 63
64 64 def assertUnreachable(self, dep):
65 65 self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")
66 66
67 67 def assertReachable(self, dep):
68 68 self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")
69 69
70 70 def cancan(self, f):
71 71 """decorator to pass through canning into self.user_ns"""
72 72 return uncan(can(f), self.user_ns)
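        # Rough sketch of the round-trip performed above (names are from
        # IPython.utils.pickleutil; the lambda is illustrative only):
        #
        #   canned = can(lambda x: x + 1)    # picklable stand-in for the function
        #   f = uncan(canned, {'__builtins__': __builtins__})
        #   assert f(1) == 2                 # rebuilt against the given namespace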
73 73
74 74 def test_require_imports(self):
75 75 """test that @require imports names"""
76 76 @self.cancan
77 @pmod.require('urllib')
77 @pmod.require('base64')
78 78 @interactive
79 def encode(dikt):
80 return urllib.urlencode(dikt)
79 def encode(arg):
80 return base64.b64encode(arg)
81 81 # must pass through canning to properly connect namespaces
82 self.assertEqual(encode(dict(a=5)), 'a=5')
82 self.assertEqual(encode(b'foo'), b'Zm9v')
83 83
84 84 def test_success_only(self):
85 85 dep = pmod.Dependency(mixed, success=True, failure=False)
86 86 self.assertUnmet(dep)
87 87 self.assertUnreachable(dep)
88 88 dep.all=False
89 89 self.assertMet(dep)
90 90 self.assertReachable(dep)
91 91 dep = pmod.Dependency(completed, success=True, failure=False)
92 92 self.assertMet(dep)
93 93 self.assertReachable(dep)
94 94 dep.all=False
95 95 self.assertMet(dep)
96 96 self.assertReachable(dep)
97 97
98 98 def test_failure_only(self):
99 99 dep = pmod.Dependency(mixed, success=False, failure=True)
100 100 self.assertUnmet(dep)
101 101 self.assertUnreachable(dep)
102 102 dep.all=False
103 103 self.assertMet(dep)
104 104 self.assertReachable(dep)
105 105 dep = pmod.Dependency(completed, success=False, failure=True)
106 106 self.assertUnmet(dep)
107 107 self.assertUnreachable(dep)
108 108 dep.all=False
109 109 self.assertUnmet(dep)
110 110 self.assertUnreachable(dep)
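    # In the two tests above, a Dependency(ids, success=..., failure=...) is
    # met once its msg_ids have finished with an allowed status, and dep.all
    # toggles between requiring every id and any single one.  "Unreachable"
    # means enough ids already finished the wrong way that the dependency
    # can never be met.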
111 111
112 112 def test_require_function(self):
113 113
114 114 @pmod.interactive
115 115 def bar(a):
116 116 return func(a)
117 117
118 118 @pmod.require(func)
119 119 @pmod.interactive
120 120 def bar2(a):
121 121 return func(a)
122 122
123 123 self.client[:].clear()
124 124 self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
125 125 ar = self.view.apply_async(bar2, 5)
126 126 self.assertEqual(ar.get(5), func(5))
127 127
128 128 def test_require_object(self):
129 129
130 130 @pmod.require(foo=func)
131 131 @pmod.interactive
132 132 def bar(a):
133 133 return foo(a)
134 134
135 135 ar = self.view.apply_async(bar, 5)
136 136 self.assertEqual(ar.get(5), func(5))
@@ -1,221 +1,221 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test LoadBalancedView objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from nose import SkipTest
24 24 from nose.plugins.attrib import attr
25 25
26 26 from IPython import parallel as pmod
27 27 from IPython.parallel import error
28 28
29 29 from IPython.parallel.tests import add_engines
30 30
31 31 from .clienttest import ClusterTestCase, crash, wait, skip_without
32 32
33 33 def setup():
34 34 add_engines(3, total=True)
35 35
36 36 class TestLoadBalancedView(ClusterTestCase):
37 37
38 38 def setUp(self):
39 39 ClusterTestCase.setUp(self)
40 40 self.view = self.client.load_balanced_view()
41 41
42 42 @attr('crash')
43 43 def test_z_crash_task(self):
44 44 """test graceful handling of engine death (balanced)"""
45 45 # self.add_engines(1)
46 46 ar = self.view.apply_async(crash)
47 47 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 48 eid = ar.engine_id
49 49 tic = time.time()
50 50 while eid in self.client.ids and time.time()-tic < 5:
51 51 time.sleep(.01)
52 52 self.client.spin()
53 53 self.assertFalse(eid in self.client.ids, "Engine should have died")
54 54
55 55 def test_map(self):
56 56 def f(x):
57 57 return x**2
58 data = range(16)
58 data = list(range(16))
59 59 r = self.view.map_sync(f, data)
60 self.assertEqual(r, map(f, data))
60 self.assertEqual(r, list(map(f, data)))
61 61
62 62 def test_map_generator(self):
63 63 def f(x):
64 64 return x**2
65 65
66 data = range(16)
66 data = list(range(16))
67 67 r = self.view.map_sync(f, iter(data))
68 self.assertEqual(r, map(f, iter(data)))
68 self.assertEqual(r, list(map(f, iter(data))))
69 69
70 70 def test_map_short_first(self):
71 71 def f(x,y):
72 72 if y is None:
73 73 return y
74 74 if x is None:
75 75 return x
76 76 return x*y
77 data = range(10)
78 data2 = range(4)
77 data = list(range(10))
78 data2 = list(range(4))
79 79
80 80 r = self.view.map_sync(f, data, data2)
81 self.assertEqual(r, map(f, data, data2))
81 self.assertEqual(r, list(map(f, data, data2)))
82 82
83 83 def test_map_short_last(self):
84 84 def f(x,y):
85 85 if y is None:
86 86 return y
87 87 if x is None:
88 88 return x
89 89 return x*y
90 data = range(4)
91 data2 = range(10)
90 data = list(range(4))
91 data2 = list(range(10))
92 92
93 93 r = self.view.map_sync(f, data, data2)
94 self.assertEqual(r, map(f, data, data2))
94 self.assertEqual(r, list(map(f, data, data2)))
95 95
96 96 def test_map_unordered(self):
97 97 def f(x):
98 98 return x**2
99 99 def slow_f(x):
100 100 import time
101 101 time.sleep(0.05*x)
102 102 return x**2
103 data = range(16,0,-1)
104 reference = map(f, data)
103 data = list(range(16,0,-1))
104 reference = list(map(f, data))
105 105
106 106 amr = self.view.map_async(slow_f, data, ordered=False)
107 107 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
108 108 # check individual elements, retrieved as they come
109 109 # list comprehension uses __iter__
110 110 astheycame = [ r for r in amr ]
111 111 # Ensure that at least one result came out of order:
112 112 self.assertNotEqual(astheycame, reference, "should not have preserved order")
113 113 self.assertEqual(sorted(astheycame, reverse=True), reference, "result corrupted")
114 114
115 115 def test_map_ordered(self):
116 116 def f(x):
117 117 return x**2
118 118 def slow_f(x):
119 119 import time
120 120 time.sleep(0.05*x)
121 121 return x**2
122 data = range(16,0,-1)
123 reference = map(f, data)
122 data = list(range(16,0,-1))
123 reference = list(map(f, data))
124 124
125 125 amr = self.view.map_async(slow_f, data)
126 126 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
127 127 # check individual elements, retrieved as they come
128 128 # list(amr) uses __iter__
129 129 astheycame = list(amr)
130 130 # Ensure that results came in order
131 131 self.assertEqual(astheycame, reference)
132 132 self.assertEqual(amr.result, reference)
133 133
134 134 def test_map_iterable(self):
135 135 """test map on iterables (balanced)"""
136 136 view = self.view
137 137 # 101 is prime, so it won't be evenly distributed
138 138 arr = range(101)
139 139 # so that it will be an iterator, even in Python 3
140 140 it = iter(arr)
141 141         r = view.map_sync(lambda x:x, it)
142 142 self.assertEqual(r, list(arr))
143 143
144 144
145 145 def test_abort(self):
146 146 view = self.view
147 147 ar = self.client[:].apply_async(time.sleep, .5)
148 148 ar = self.client[:].apply_async(time.sleep, .5)
149 149 time.sleep(0.2)
150 150 ar2 = view.apply_async(lambda : 2)
151 151 ar3 = view.apply_async(lambda : 3)
152 152 view.abort(ar2)
153 153 view.abort(ar3.msg_ids)
154 154 self.assertRaises(error.TaskAborted, ar2.get)
155 155 self.assertRaises(error.TaskAborted, ar3.get)
156 156
157 157 def test_retries(self):
158 158 self.minimum_engines(3)
159 159 view = self.view
160 160 def fail():
161 161 assert False
162 162 for r in range(len(self.client)-1):
163 163 with view.temp_flags(retries=r):
164 164 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
165 165
166 166 with view.temp_flags(retries=len(self.client), timeout=0.1):
167 167 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
168 168
169 169 def test_short_timeout(self):
170 170 self.minimum_engines(2)
171 171 view = self.view
172 172 def fail():
173 173 import time
174 174 time.sleep(0.25)
175 175 assert False
176 176 with view.temp_flags(retries=1, timeout=0.01):
177 177 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
178 178
179 179 def test_invalid_dependency(self):
180 180 view = self.view
181 181 with view.temp_flags(after='12345'):
182 182 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
183 183
184 184 def test_impossible_dependency(self):
185 185 self.minimum_engines(2)
186 186 view = self.client.load_balanced_view()
187 187 ar1 = view.apply_async(lambda : 1)
188 188 ar1.get()
189 189 e1 = ar1.engine_id
190 190 e2 = e1
191 191 while e2 == e1:
192 192 ar2 = view.apply_async(lambda : 1)
193 193 ar2.get()
194 194 e2 = ar2.engine_id
195 195
196 196 with view.temp_flags(follow=[ar1, ar2]):
197 197 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
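        # follow=[ar1, ar2] requires the new task to run on the same engine
        # as both prior tasks; since they ran on different engines, the
        # scheduler must flag the dependency as impossible.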
198 198
199 199
200 200 def test_follow(self):
201 201 ar = self.view.apply_async(lambda : 1)
202 202 ar.get()
203 203 ars = []
204 204 first_id = ar.engine_id
205 205
206 206 self.view.follow = ar
207 207 for i in range(5):
208 208 ars.append(self.view.apply_async(lambda : 1))
209 209 self.view.wait(ars)
210 210 for ar in ars:
211 211 self.assertEqual(ar.engine_id, first_id)
212 212
213 213 def test_after(self):
214 214 view = self.view
215 215 ar = view.apply_async(time.sleep, 0.5)
216 216 with view.temp_flags(after=ar):
217 217 ar2 = view.apply_async(lambda : 1)
218 218
219 219 ar.wait()
220 220 ar2.wait()
221 221         self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar2.started, ar.completed))
@@ -1,835 +1,835 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import base64
20 20 import sys
21 21 import platform
22 22 import time
23 23 from collections import namedtuple
24 24 from tempfile import mktemp
25 25
26 26 import zmq
27 27 from nose.plugins.attrib import attr
28 28
29 29 from IPython.testing import decorators as dec
30 30 from IPython.utils.io import capture_output
31 31 from IPython.utils.py3compat import unicode_type
32 32
33 33 from IPython import parallel as pmod
34 34 from IPython.parallel import error
35 35 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
36 36 from IPython.parallel.util import interactive
37 37
38 38 from IPython.parallel.tests import add_engines
39 39
40 40 from .clienttest import ClusterTestCase, crash, wait, skip_without
41 41
42 42 def setup():
43 43 add_engines(3, total=True)
44 44
45 45 point = namedtuple("point", "x y")
46 46
47 47 class TestView(ClusterTestCase):
48 48
49 49 def setUp(self):
50 50 # On Win XP, wait for resource cleanup, else parallel test group fails
51 51 if platform.system() == "Windows" and platform.win32_ver()[0] == "XP":
52 52 # 1 sec fails. 1.5 sec seems ok. Using 2 sec for margin of safety
53 53 time.sleep(2)
54 54 super(TestView, self).setUp()
55 55
56 56 @attr('crash')
57 57 def test_z_crash_mux(self):
58 58 """test graceful handling of engine death (direct)"""
59 59 # self.add_engines(1)
60 60 eid = self.client.ids[-1]
61 61 ar = self.client[eid].apply_async(crash)
62 62 self.assertRaisesRemote(error.EngineError, ar.get, 10)
63 63 eid = ar.engine_id
64 64 tic = time.time()
65 65 while eid in self.client.ids and time.time()-tic < 5:
66 66 time.sleep(.01)
67 67 self.client.spin()
68 68 self.assertFalse(eid in self.client.ids, "Engine should have died")
69 69
70 70 def test_push_pull(self):
71 71 """test pushing and pulling"""
72 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
72 data = dict(a=10, b=1.05, c=list(range(10)), d={'e':(1,2),'f':'hi'})
73 73 t = self.client.ids[-1]
74 74 v = self.client[t]
75 75 push = v.push
76 76 pull = v.pull
77 77 v.block=True
78 78 nengines = len(self.client)
79 79 push({'data':data})
80 80 d = pull('data')
81 81 self.assertEqual(d, data)
82 82 self.client[:].push({'data':data})
83 83 d = self.client[:].pull('data', block=True)
84 84 self.assertEqual(d, nengines*[data])
85 85 ar = push({'data':data}, block=False)
86 86 self.assertTrue(isinstance(ar, AsyncResult))
87 87 r = ar.get()
88 88 ar = self.client[:].pull('data', block=False)
89 89 self.assertTrue(isinstance(ar, AsyncResult))
90 90 r = ar.get()
91 91 self.assertEqual(r, nengines*[data])
92 92 self.client[:].push(dict(a=10,b=20))
93 93 r = self.client[:].pull(('a','b'), block=True)
94 94 self.assertEqual(r, nengines*[[10,20]])
95 95
96 96 def test_push_pull_function(self):
97 97 "test pushing and pulling functions"
98 98 def testf(x):
99 99 return 2.0*x
100 100
101 101 t = self.client.ids[-1]
102 102 v = self.client[t]
103 103 v.block=True
104 104 push = v.push
105 105 pull = v.pull
106 106 execute = v.execute
107 107 push({'testf':testf})
108 108 r = pull('testf')
109 109 self.assertEqual(r(1.0), testf(1.0))
110 110 execute('r = testf(10)')
111 111 r = pull('r')
112 112 self.assertEqual(r, testf(10))
113 113 ar = self.client[:].push({'testf':testf}, block=False)
114 114 ar.get()
115 115 ar = self.client[:].pull('testf', block=False)
116 116 rlist = ar.get()
117 117 for r in rlist:
118 118 self.assertEqual(r(1.0), testf(1.0))
119 119 execute("def g(x): return x*x")
120 120 r = pull(('testf','g'))
121 121 self.assertEqual((r[0](10),r[1](10)), (testf(10), 100))
122 122
123 123 def test_push_function_globals(self):
124 124 """test that pushed functions have access to globals"""
125 125 @interactive
126 126 def geta():
127 127 return a
128 128 # self.add_engines(1)
129 129 v = self.client[-1]
130 130 v.block=True
131 131 v['f'] = geta
132 132 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
133 133 v.execute('a=5')
134 134 v.execute('b=f()')
135 135 self.assertEqual(v['b'], 5)
136 136
137 137 def test_push_function_defaults(self):
138 138 """test that pushed functions preserve default args"""
139 139 def echo(a=10):
140 140 return a
141 141 v = self.client[-1]
142 142 v.block=True
143 143 v['f'] = echo
144 144 v.execute('b=f()')
145 145 self.assertEqual(v['b'], 10)
146 146
147 147 def test_get_result(self):
148 148 """test getting results from the Hub."""
149 149 c = pmod.Client(profile='iptest')
150 150 # self.add_engines(1)
151 151 t = c.ids[-1]
152 152 v = c[t]
153 153 v2 = self.client[t]
154 154 ar = v.apply_async(wait, 1)
155 155 # give the monitor time to notice the message
156 156 time.sleep(.25)
157 157 ahr = v2.get_result(ar.msg_ids[0])
158 158 self.assertTrue(isinstance(ahr, AsyncHubResult))
159 159 self.assertEqual(ahr.get(), ar.get())
160 160 ar2 = v2.get_result(ar.msg_ids[0])
161 161 self.assertFalse(isinstance(ar2, AsyncHubResult))
162 162 c.spin()
163 163 c.close()
164 164
165 165 def test_run_newline(self):
166 166 """test that run appends newline to files"""
167 167 tmpfile = mktemp()
168 168 with open(tmpfile, 'w') as f:
169 169 f.write("""def g():
170 170 return 5
171 171 """)
172 172 v = self.client[-1]
173 173 v.run(tmpfile, block=True)
174 174 self.assertEqual(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
175 175
176 176 def test_apply_tracked(self):
177 177 """test tracking for apply"""
178 178 # self.add_engines(1)
179 179 t = self.client.ids[-1]
180 180 v = self.client[t]
181 181 v.block=False
182 182 def echo(n=1024*1024, **kwargs):
183 183 with v.temp_flags(**kwargs):
184 184 return v.apply(lambda x: x, 'x'*n)
185 185 ar = echo(1, track=False)
186 186 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
187 187 self.assertTrue(ar.sent)
188 188 ar = echo(track=True)
189 189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 190 self.assertEqual(ar.sent, ar._tracker.done)
191 191 ar._tracker.wait()
192 192 self.assertTrue(ar.sent)
193 193
194 194 def test_push_tracked(self):
195 195 t = self.client.ids[-1]
196 196 ns = dict(x='x'*1024*1024)
197 197 v = self.client[t]
198 198 ar = v.push(ns, block=False, track=False)
199 199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 200 self.assertTrue(ar.sent)
201 201
202 202 ar = v.push(ns, block=False, track=True)
203 203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 204 ar._tracker.wait()
205 205 self.assertEqual(ar.sent, ar._tracker.done)
206 206 self.assertTrue(ar.sent)
207 207 ar.get()
208 208
209 209 def test_scatter_tracked(self):
210 210 t = self.client.ids
211 211 x='x'*1024*1024
212 212 ar = self.client[t].scatter('x', x, block=False, track=False)
213 213 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
214 214 self.assertTrue(ar.sent)
215 215
216 216 ar = self.client[t].scatter('x', x, block=False, track=True)
217 217 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
218 218 self.assertEqual(ar.sent, ar._tracker.done)
219 219 ar._tracker.wait()
220 220 self.assertTrue(ar.sent)
221 221 ar.get()
222 222
223 223 def test_remote_reference(self):
224 224 v = self.client[-1]
225 225 v['a'] = 123
226 226 ra = pmod.Reference('a')
227 227 b = v.apply_sync(lambda x: x, ra)
228 228 self.assertEqual(b, 123)
229 229
230 230
231 231 def test_scatter_gather(self):
232 232 view = self.client[:]
233 seq1 = range(16)
233 seq1 = list(range(16))
234 234 view.scatter('a', seq1)
235 235 seq2 = view.gather('a', block=True)
236 236 self.assertEqual(seq2, seq1)
237 237 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
238 238
239 239 @skip_without('numpy')
240 240 def test_scatter_gather_numpy(self):
241 241 import numpy
242 242 from numpy.testing.utils import assert_array_equal
243 243 view = self.client[:]
244 244 a = numpy.arange(64)
245 245 view.scatter('a', a, block=True)
246 246 b = view.gather('a', block=True)
247 247 assert_array_equal(b, a)
248 248
249 249 def test_scatter_gather_lazy(self):
250 250 """scatter/gather with targets='all'"""
251 251 view = self.client.direct_view(targets='all')
252 x = range(64)
252 x = list(range(64))
253 253 view.scatter('x', x)
254 254 gathered = view.gather('x', block=True)
255 255 self.assertEqual(gathered, x)
256 256
257 257
258 258 @dec.known_failure_py3
259 259 @skip_without('numpy')
260 260 def test_push_numpy_nocopy(self):
261 261 import numpy
262 262 view = self.client[:]
263 263 a = numpy.arange(64)
264 264 view['A'] = a
265 265 @interactive
266 266 def check_writeable(x):
267 267 return x.flags.writeable
268 268
269 269 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
270 270 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
271 271
272 272 view.push(dict(B=a))
273 273 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
274 274 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
275 275
276 276 @skip_without('numpy')
277 277 def test_apply_numpy(self):
278 278 """view.apply(f, ndarray)"""
279 279 import numpy
280 280 from numpy.testing.utils import assert_array_equal
281 281
282 282 A = numpy.random.random((100,100))
283 283 view = self.client[-1]
284 284 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
285 285 B = A.astype(dt)
286 286 C = view.apply_sync(lambda x:x, B)
287 287 assert_array_equal(B,C)
288 288
289 289 @skip_without('numpy')
290 290 def test_push_pull_recarray(self):
291 291 """push/pull recarrays"""
292 292 import numpy
293 293 from numpy.testing.utils import assert_array_equal
294 294
295 295 view = self.client[-1]
296 296
297 297 R = numpy.array([
298 298 (1, 'hi', 0.),
299 299 (2**30, 'there', 2.5),
300 300 (-99999, 'world', -12345.6789),
301 301 ], [('n', int), ('s', '|S10'), ('f', float)])
302 302
303 303 view['RR'] = R
304 304 R2 = view['RR']
305 305
306 306 r_dtype, r_shape = view.apply_sync(interactive(lambda : (RR.dtype, RR.shape)))
307 307 self.assertEqual(r_dtype, R.dtype)
308 308 self.assertEqual(r_shape, R.shape)
309 309 self.assertEqual(R2.dtype, R.dtype)
310 310 self.assertEqual(R2.shape, R.shape)
311 311 assert_array_equal(R2, R)
312 312
313 313 @skip_without('pandas')
314 314 def test_push_pull_timeseries(self):
315 315 """push/pull pandas.TimeSeries"""
316 316 import pandas
317 317
318 ts = pandas.TimeSeries(range(10))
318 ts = pandas.TimeSeries(list(range(10)))
319 319
320 320 view = self.client[-1]
321 321
322 322 view.push(dict(ts=ts), block=True)
323 323 rts = view['ts']
324 324
325 325 self.assertEqual(type(rts), type(ts))
326 326 self.assertTrue((ts == rts).all())
327 327
328 328 def test_map(self):
329 329 view = self.client[:]
330 330 def f(x):
331 331 return x**2
332 data = range(16)
332 data = list(range(16))
333 333 r = view.map_sync(f, data)
334 self.assertEqual(r, map(f, data))
334 self.assertEqual(r, list(map(f, data)))
335 335
336 336 def test_map_iterable(self):
337 337 """test map on iterables (direct)"""
338 338 view = self.client[:]
339 339 # 101 is prime, so it won't be evenly distributed
340 340 arr = range(101)
341 341 # ensure it will be an iterator, even in Python 3
342 342 it = iter(arr)
343 343 r = view.map_sync(lambda x: x, it)
344 344 self.assertEqual(r, list(arr))
345 345
346 346 @skip_without('numpy')
347 347 def test_map_numpy(self):
348 348 """test map on numpy arrays (direct)"""
349 349 import numpy
350 350 from numpy.testing.utils import assert_array_equal
351 351
352 352 view = self.client[:]
353 353 # 101 is prime, so it won't be evenly distributed
354 354 arr = numpy.arange(101)
355 355 r = view.map_sync(lambda x: x, arr)
356 356 assert_array_equal(r, arr)
357 357
358 358 def test_scatter_gather_nonblocking(self):
359 data = range(16)
359 data = list(range(16))
360 360 view = self.client[:]
361 361 view.scatter('a', data, block=False)
362 362 ar = view.gather('a', block=False)
363 363 self.assertEqual(ar.get(), data)
364 364
365 365 @skip_without('numpy')
366 366 def test_scatter_gather_numpy_nonblocking(self):
367 367 import numpy
368 368 from numpy.testing.utils import assert_array_equal
369 369 a = numpy.arange(64)
370 370 view = self.client[:]
371 371 ar = view.scatter('a', a, block=False)
372 372 self.assertTrue(isinstance(ar, AsyncResult))
373 373 amr = view.gather('a', block=False)
374 374 self.assertTrue(isinstance(amr, AsyncMapResult))
375 375 assert_array_equal(amr.get(), a)
376 376
377 377 def test_execute(self):
378 378 view = self.client[:]
379 379 # self.client.debug=True
380 380 execute = view.execute
381 381 ar = execute('c=30', block=False)
382 382 self.assertTrue(isinstance(ar, AsyncResult))
383 383 ar = execute('d=[0,1,2]', block=False)
384 384 self.client.wait(ar, 1)
385 385 self.assertEqual(len(ar.get()), len(self.client))
386 386 for c in view['c']:
387 387 self.assertEqual(c, 30)
388 388
389 389 def test_abort(self):
390 390 view = self.client[-1]
391 391 ar = view.execute('import time; time.sleep(1)', block=False)
392 392 ar2 = view.apply_async(lambda : 2)
393 393 ar3 = view.apply_async(lambda : 3)
394 394 view.abort(ar2)
395 395 view.abort(ar3.msg_ids)
396 396 self.assertRaises(error.TaskAborted, ar2.get)
397 397 self.assertRaises(error.TaskAborted, ar3.get)
398 398
399 399 def test_abort_all(self):
400 400 """view.abort() aborts all outstanding tasks"""
401 401 view = self.client[-1]
402 402 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
403 403 view.abort()
404 404 view.wait(timeout=5)
405 405 for ar in ars[5:]:
406 406 self.assertRaises(error.TaskAborted, ar.get)
407 407
408 408 def test_temp_flags(self):
409 409 view = self.client[-1]
410 410 view.block=True
411 411 with view.temp_flags(block=False):
412 412 self.assertFalse(view.block)
413 413 self.assertTrue(view.block)
414 414
415 415 @dec.known_failure_py3
416 416 def test_importer(self):
417 417 view = self.client[-1]
418 418 view.clear(block=True)
419 419 with view.importer:
420 420 import re
421 421
422 422 @interactive
423 423 def findall(pat, s):
424 424 # this globals() step isn't necessary in real code
425 425 # only to prevent a closure in the test
426 426 re = globals()['re']
427 427 return re.findall(pat, s)
428 428
429 429 self.assertEqual(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
430 430
431 431 def test_unicode_execute(self):
432 432 """test executing unicode strings"""
433 433 v = self.client[-1]
434 434 v.block=True
435 435 if sys.version_info[0] >= 3:
436 436 code="a='é'"
437 437 else:
438 438 code=u"a=u'é'"
439 439 v.execute(code)
440 440 self.assertEqual(v['a'], u'é')
441 441
442 442 def test_unicode_apply_result(self):
443 443 """test unicode apply results"""
444 444 v = self.client[-1]
445 445 r = v.apply_sync(lambda : u'é')
446 446 self.assertEqual(r, u'é')
447 447
448 448 def test_unicode_apply_arg(self):
449 449 """test passing unicode arguments to apply"""
450 450 v = self.client[-1]
451 451
452 452 @interactive
453 453 def check_unicode(a, check):
454 assert isinstance(a, unicode_type), "%r is not unicode"%a
454 assert not isinstance(a, bytes), "%r is bytes, not unicode"%a
455 455 assert isinstance(check, bytes), "%r is not bytes"%check
456 456 assert a.encode('utf8') == check, "%s != %s"%(a,check)
457 457
458 458 for s in [ u'é', u'ßø®∫',u'asdf' ]:
459 459 try:
460 460 v.apply_sync(check_unicode, s, s.encode('utf8'))
461 461 except error.RemoteError as e:
462 462 if e.ename == 'AssertionError':
463 463 self.fail(e.evalue)
464 464 else:
465 465 raise e
466 466
467 467 def test_map_reference(self):
468 468 """view.map(<Reference>, *seqs) should work"""
469 469 v = self.client[:]
470 470 v.scatter('n', self.client.ids, flatten=True)
471 471 v.execute("f = lambda x,y: x*y")
472 472 rf = pmod.Reference('f')
473 473 nlist = list(range(10))
474 474 mlist = nlist[::-1]
475 475 expected = [ m*n for m,n in zip(mlist, nlist) ]
476 476 result = v.map_sync(rf, mlist, nlist)
477 477 self.assertEqual(result, expected)
478 478
479 479 def test_apply_reference(self):
480 480 """view.apply(<Reference>, *args) should work"""
481 481 v = self.client[:]
482 482 v.scatter('n', self.client.ids, flatten=True)
483 483 v.execute("f = lambda x: n*x")
484 484 rf = pmod.Reference('f')
485 485 result = v.apply_sync(rf, 5)
486 486 expected = [ 5*id for id in self.client.ids ]
487 487 self.assertEqual(result, expected)
488 488
489 489 def test_eval_reference(self):
490 490 v = self.client[self.client.ids[0]]
491 v['g'] = range(5)
491 v['g'] = list(range(5))
492 492 rg = pmod.Reference('g[0]')
493 493 echo = lambda x:x
494 494 self.assertEqual(v.apply_sync(echo, rg), 0)
495 495
496 496 def test_reference_nameerror(self):
497 497 v = self.client[self.client.ids[0]]
498 498 r = pmod.Reference('elvis_has_left')
499 499 echo = lambda x:x
500 500 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
501 501
502 502 def test_single_engine_map(self):
503 503 e0 = self.client[self.client.ids[0]]
504 r = range(5)
504 r = list(range(5))
505 505 check = [ -1*i for i in r ]
506 506 result = e0.map_sync(lambda x: -1*x, r)
507 507 self.assertEqual(result, check)
508 508
509 509 def test_len(self):
510 510 """len(view) makes sense"""
511 511 e0 = self.client[self.client.ids[0]]
512 512 self.assertEqual(len(e0), 1)
513 513 v = self.client[:]
514 514 self.assertEqual(len(v), len(self.client.ids))
515 515 v = self.client.direct_view('all')
516 516 self.assertEqual(len(v), len(self.client.ids))
517 517 v = self.client[:2]
518 518 self.assertEqual(len(v), 2)
519 519 v = self.client[:1]
520 520 self.assertEqual(len(v), 1)
521 521 v = self.client.load_balanced_view()
522 522 self.assertEqual(len(v), len(self.client.ids))
523 523
524 524
525 525 # begin execute tests
526 526
527 527 def test_execute_reply(self):
528 528 e0 = self.client[self.client.ids[0]]
529 529 e0.block = True
530 530 ar = e0.execute("5", silent=False)
531 531 er = ar.get()
532 532 self.assertEqual(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
533 533 self.assertEqual(er.pyout['data']['text/plain'], '5')
534 534
535 535 def test_execute_reply_rich(self):
536 536 e0 = self.client[self.client.ids[0]]
537 537 e0.block = True
538 538 e0.execute("from IPython.display import Image, HTML")
539 539 ar = e0.execute("Image(data=b'garbage', format='png', width=10)", silent=False)
540 540 er = ar.get()
541 541 b64data = base64.encodestring(b'garbage').decode('ascii')
542 542 self.assertEqual(er._repr_png_(), (b64data, dict(width=10)))
543 543 ar = e0.execute("HTML('<b>bold</b>')", silent=False)
544 544 er = ar.get()
545 545 self.assertEqual(er._repr_html_(), "<b>bold</b>")
546 546
547 547 def test_execute_reply_stdout(self):
548 548 e0 = self.client[self.client.ids[0]]
549 549 e0.block = True
550 550 ar = e0.execute("print (5)", silent=False)
551 551 er = ar.get()
552 552 self.assertEqual(er.stdout.strip(), '5')
553 553
554 554 def test_execute_pyout(self):
555 555 """execute triggers pyout with silent=False"""
556 556 view = self.client[:]
557 557 ar = view.execute("5", silent=False, block=True)
558 558
559 559 expected = [{'text/plain' : '5'}] * len(view)
560 560 mimes = [ out['data'] for out in ar.pyout ]
561 561 self.assertEqual(mimes, expected)
562 562
563 563 def test_execute_silent(self):
564 564 """execute does not trigger pyout with silent=True"""
565 565 view = self.client[:]
566 566 ar = view.execute("5", block=True)
567 567 expected = [None] * len(view)
568 568 self.assertEqual(ar.pyout, expected)
569 569
570 570 def test_execute_magic(self):
571 571 """execute accepts IPython commands"""
572 572 view = self.client[:]
573 573 view.execute("a = 5")
574 574 ar = view.execute("%whos", block=True)
575 575 # this will raise, if that failed
576 576 ar.get(5)
577 577 for stdout in ar.stdout:
578 578 lines = stdout.splitlines()
579 579 self.assertEqual(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
580 580 found = False
581 581 for line in lines[2:]:
582 582 split = line.split()
583 583 if split == ['a', 'int', '5']:
584 584 found = True
585 585 break
586 586 self.assertTrue(found, "whos output wrong: %s" % stdout)
587 587
588 588 def test_execute_displaypub(self):
589 589 """execute tracks display_pub output"""
590 590 view = self.client[:]
591 591 view.execute("from IPython.core.display import *")
592 592 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
593 593
594 594 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
595 595 for outputs in ar.outputs:
596 596 mimes = [ out['data'] for out in outputs ]
597 597 self.assertEqual(mimes, expected)
598 598
599 599 def test_apply_displaypub(self):
600 600 """apply tracks display_pub output"""
601 601 view = self.client[:]
602 602 view.execute("from IPython.core.display import *")
603 603
604 604 @interactive
605 605 def publish():
606 606 [ display(i) for i in range(5) ]
607 607
608 608 ar = view.apply_async(publish)
609 609 ar.get(5)
610 610 expected = [ {u'text/plain' : unicode_type(j)} for j in range(5) ]
611 611 for outputs in ar.outputs:
612 612 mimes = [ out['data'] for out in outputs ]
613 613 self.assertEqual(mimes, expected)
614 614
615 615 def test_execute_raises(self):
616 616 """exceptions in execute requests raise appropriately"""
617 617 view = self.client[-1]
618 618 ar = view.execute("1/0")
619 619 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
620 620
621 621 def test_remoteerror_render_exception(self):
622 622 """RemoteErrors get nice tracebacks"""
623 623 view = self.client[-1]
624 624 ar = view.execute("1/0")
625 625 ip = get_ipython()
626 626 ip.user_ns['ar'] = ar
627 627 with capture_output() as io:
628 628 ip.run_cell("ar.get(2)")
629 629
630 630 self.assertTrue('ZeroDivisionError' in io.stdout, io.stdout)
631 631
632 632 def test_compositeerror_render_exception(self):
633 633 """CompositeErrors get nice tracebacks"""
634 634 view = self.client[:]
635 635 ar = view.execute("1/0")
636 636 ip = get_ipython()
637 637 ip.user_ns['ar'] = ar
638 638
639 639 with capture_output() as io:
640 640 ip.run_cell("ar.get(2)")
641 641
642 642 count = min(error.CompositeError.tb_limit, len(view))
643 643
644 644 self.assertEqual(io.stdout.count('ZeroDivisionError'), count * 2, io.stdout)
645 645 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
646 646 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
647 647
648 648 def test_compositeerror_truncate(self):
649 649 """Truncate CompositeErrors with many exceptions"""
650 650 view = self.client[:]
651 651 msg_ids = []
652 652 for i in range(10):
653 653 ar = view.execute("1/0")
654 654 msg_ids.extend(ar.msg_ids)
655 655
656 656 ar = self.client.get_result(msg_ids)
657 657 try:
658 658 ar.get()
659 659 except error.CompositeError as _e:
660 660 e = _e
661 661 else:
662 662 self.fail("Should have raised CompositeError")
663 663
664 664 lines = e.render_traceback()
665 665 with capture_output() as io:
666 666 e.print_traceback()
667 667
668 668 self.assertTrue("more exceptions" in lines[-1])
669 669 count = e.tb_limit
670 670
671 671 self.assertEqual(io.stdout.count('ZeroDivisionError'), 2 * count, io.stdout)
672 672 self.assertEqual(io.stdout.count('by zero'), count, io.stdout)
673 673 self.assertEqual(io.stdout.count(':execute'), count, io.stdout)
674 674
675 675 @dec.skipif_not_matplotlib
676 676 def test_magic_pylab(self):
677 677 """%pylab works on engines"""
678 678 view = self.client[-1]
679 679 ar = view.execute("%pylab inline")
680 680 # at least check if this raised:
681 681 reply = ar.get(5)
682 682 # include imports, in case user config
683 683 ar = view.execute("plot(rand(100))", silent=False)
684 684 reply = ar.get(5)
685 685 self.assertEqual(len(reply.outputs), 1)
686 686 output = reply.outputs[0]
687 687 self.assertTrue("data" in output)
688 688 data = output['data']
689 689 self.assertTrue("image/png" in data)
690 690
691 691 def test_func_default_func(self):
692 692 """interactively defined function as apply func default"""
693 693 def foo():
694 694 return 'foo'
695 695
696 696 def bar(f=foo):
697 697 return f()
698 698
699 699 view = self.client[-1]
700 700 ar = view.apply_async(bar)
701 701 r = ar.get(10)
702 702 self.assertEqual(r, 'foo')
703 703 def test_data_pub_single(self):
704 704 view = self.client[-1]
705 705 ar = view.execute('\n'.join([
706 706 'from IPython.kernel.zmq.datapub import publish_data',
707 707 'for i in range(5):',
708 708 ' publish_data(dict(i=i))'
709 709 ]), block=False)
710 710 self.assertTrue(isinstance(ar.data, dict))
711 711 ar.get(5)
712 712 self.assertEqual(ar.data, dict(i=4))
713 713
714 714 def test_data_pub(self):
715 715 view = self.client[:]
716 716 ar = view.execute('\n'.join([
717 717 'from IPython.kernel.zmq.datapub import publish_data',
718 718 'for i in range(5):',
719 719 ' publish_data(dict(i=i))'
720 720 ]), block=False)
721 721 self.assertTrue(all(isinstance(d, dict) for d in ar.data))
722 722 ar.get(5)
723 723 self.assertEqual(ar.data, [dict(i=4)] * len(ar))
724 724
725 725 def test_can_list_arg(self):
726 726 """args in lists are canned"""
727 727 view = self.client[-1]
728 728 view['a'] = 128
729 729 rA = pmod.Reference('a')
730 730 ar = view.apply_async(lambda x: x, [rA])
731 731 r = ar.get(5)
732 732 self.assertEqual(r, [128])
733 733
734 734 def test_can_dict_arg(self):
735 735 """args in dicts are canned"""
736 736 view = self.client[-1]
737 737 view['a'] = 128
738 738 rA = pmod.Reference('a')
739 739 ar = view.apply_async(lambda x: x, dict(foo=rA))
740 740 r = ar.get(5)
741 741 self.assertEqual(r, dict(foo=128))
742 742
743 743 def test_can_list_kwarg(self):
744 744 """kwargs in lists are canned"""
745 745 view = self.client[-1]
746 746 view['a'] = 128
747 747 rA = pmod.Reference('a')
748 748 ar = view.apply_async(lambda x=5: x, x=[rA])
749 749 r = ar.get(5)
750 750 self.assertEqual(r, [128])
751 751
752 752 def test_can_dict_kwarg(self):
753 753 """kwargs in dicts are canned"""
754 754 view = self.client[-1]
755 755 view['a'] = 128
756 756 rA = pmod.Reference('a')
757 757 ar = view.apply_async(lambda x=5: x, dict(foo=rA))
758 758 r = ar.get(5)
759 759 self.assertEqual(r, dict(foo=128))
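    # The four tests above rely on canning recursing into list and dict
    # containers: each pmod.Reference('a') placeholder inside the argument
    # is resolved to the engine-side value (128 here) before the call runs.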
760 760
761 761 def test_map_ref(self):
762 762 """view.map works with references"""
763 763 view = self.client[:]
764 764 ranks = sorted(self.client.ids)
765 765 view.scatter('rank', ranks, flatten=True)
766 766 rrank = pmod.Reference('rank')
767 767
768 768 amr = view.map_async(lambda x: x*2, [rrank] * len(view))
769 769 drank = amr.get(5)
770 770 self.assertEqual(drank, [ r*2 for r in ranks ])
771 771
772 772 def test_nested_getitem_setitem(self):
773 773 """get and set with view['a.b']"""
774 774 view = self.client[-1]
775 775 view.execute('\n'.join([
776 776 'class A(object): pass',
777 777 'a = A()',
778 778 'a.b = 128',
779 779 ]), block=True)
780 780 ra = pmod.Reference('a')
781 781
782 782 r = view.apply_sync(lambda x: x.b, ra)
783 783 self.assertEqual(r, 128)
784 784 self.assertEqual(view['a.b'], 128)
785 785
786 786 view['a.b'] = 0
787 787
788 788 r = view.apply_sync(lambda x: x.b, ra)
789 789 self.assertEqual(r, 0)
790 790 self.assertEqual(view['a.b'], 0)
791 791
792 792 def test_return_namedtuple(self):
793 793 def namedtuplify(x, y):
794 794 from IPython.parallel.tests.test_view import point
795 795 return point(x, y)
796 796
797 797 view = self.client[-1]
798 798 p = view.apply_sync(namedtuplify, 1, 2)
799 799 self.assertEqual(p.x, 1)
800 800 self.assertEqual(p.y, 2)
801 801
802 802 def test_apply_namedtuple(self):
803 803 def echoxy(p):
804 804 return p.y, p.x
805 805
806 806 view = self.client[-1]
807 807 tup = view.apply_sync(echoxy, point(1, 2))
808 808 self.assertEqual(tup, (2,1))
809 809
810 810 def test_sync_imports(self):
811 811 view = self.client[-1]
812 812 with capture_output() as io:
813 813 with view.sync_imports():
814 814 import IPython
815 815 self.assertIn("IPython", io.stdout)
816 816
817 817 @interactive
818 818 def find_ipython():
819 819 return 'IPython' in globals()
820 820
821 821 assert view.apply_sync(find_ipython)
822 822
823 823 def test_sync_imports_quiet(self):
824 824 view = self.client[-1]
825 825 with capture_output() as io:
826 826 with view.sync_imports(quiet=True):
827 827 import IPython
828 828 self.assertEqual(io.stdout, '')
829 829
830 830 @interactive
831 831 def find_ipython():
832 832 return 'IPython' in globals()
833 833
834 834 assert view.apply_sync(find_ipython)
835 835
@@ -1,369 +1,369 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except ImportError:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 47 from IPython.utils.py3compat import string_types, iteritems, itervalues
48 48 from IPython.kernel.zmq.log import EnginePUBHandler
49 49 from IPython.kernel.zmq.serialize import (
50 50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 51 )
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Classes
55 55 #-----------------------------------------------------------------------------
56 56
57 57 class Namespace(dict):
58 58 """Subclass of dict for attribute access to keys."""
59 59
60 60 def __getattr__(self, key):
61 61 """getattr aliased to getitem"""
62 62 if key in self:
63 63 return self[key]
64 64 else:
65 65 raise NameError(key)
66 66
67 67 def __setattr__(self, key, value):
68 68 """setattr aliased to setitem, with strict"""
69 69 if hasattr(dict, key):
70 70 raise KeyError("Cannot override dict keys %r"%key)
71 71 self[key] = value
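# Illustrative usage of Namespace (sketch):
#
#   ns = Namespace(a=1)
#   ns.b = 2            # __setattr__ aliased to __setitem__
#   ns['a'], ns.b       # -> (1, 2)
#   ns.missing          # raises NameError, not AttributeError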
72 72
73 73
74 74 class ReverseDict(dict):
75 75 """simple double-keyed subset of dict methods."""
76 76
77 77 def __init__(self, *args, **kwargs):
78 78 dict.__init__(self, *args, **kwargs)
79 79 self._reverse = dict()
80 80 for key, value in iteritems(self):
81 81 self._reverse[value] = key
82 82
83 83 def __getitem__(self, key):
84 84 try:
85 85 return dict.__getitem__(self, key)
86 86 except KeyError:
87 87 return self._reverse[key]
88 88
89 89 def __setitem__(self, key, value):
90 90 if key in self._reverse:
91 91 raise KeyError("Can't have key %r on both sides!"%key)
92 92 dict.__setitem__(self, key, value)
93 93 self._reverse[value] = key
94 94
95 95 def pop(self, key):
96 96 value = dict.pop(self, key)
97 97 self._reverse.pop(value)
98 98 return value
99 99
100 100 def get(self, key, default=None):
101 101 try:
102 102 return self[key]
103 103 except KeyError:
104 104 return default
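# Illustrative usage of ReverseDict (sketch):
#
#   rd = ReverseDict()
#   rd['a'] = 'alpha'
#   rd['a']         # -> 'alpha' (normal lookup)
#   rd['alpha']     # -> 'a'     (falls back to the reverse mapping)
#   rd.pop('a')     # removes the pair from both directions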
105 105
106 106 #-----------------------------------------------------------------------------
107 107 # Functions
108 108 #-----------------------------------------------------------------------------
109 109
110 110 @decorator
111 111 def log_errors(f, self, *args, **kwargs):
112 112 """decorator to log unhandled exceptions raised in a method.
113 113
114 114 For use wrapping on_recv callbacks, so that exceptions
115 115 do not cause the stream to be closed.
116 116 """
117 117 try:
118 118 return f(self, *args, **kwargs)
119 119 except Exception:
120 120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121 121
122 122
123 123 def is_url(url):
124 124 """boolean check for whether a string is a zmq url"""
125 125 if '://' not in url:
126 126 return False
127 127 proto, addr = url.split('://', 1)
128 128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 129 return False
130 130 return True
131 131
132 132 def validate_url(url):
133 133 """validate a url for zeromq"""
134 134 if not isinstance(url, string_types):
135 135 raise TypeError("url must be a string, not %r"%type(url))
136 136 url = url.lower()
137 137
138 138 proto_addr = url.split('://')
139 139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 140 proto, addr = proto_addr
141 141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142 142
143 143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 144 # author: Remi Sabourin
145 145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146 146
147 147 if proto == 'tcp':
148 148 lis = addr.split(':')
149 149 assert len(lis) == 2, 'Invalid url: %r'%url
150 150 addr,s_port = lis
151 151 try:
152 152 port = int(s_port)
153 153 except ValueError:
154 154             raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
155 155
156 156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157 157
158 158 else:
159 159 # only validate tcp urls currently
160 160 pass
161 161
162 162 return True
163 163
164 164
165 165 def validate_url_container(container):
166 166 """validate a potentially nested collection of urls."""
167 167 if isinstance(container, string_types):
168 168 url = container
169 169 return validate_url(url)
170 170 elif isinstance(container, dict):
171 171 container = itervalues(container)
172 172
173 173 for element in container:
174 174 validate_url_container(element)
175 175
176 176
177 177 def split_url(url):
178 178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 179 proto_addr = url.split('://')
180 180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 181 proto, addr = proto_addr
182 182 lis = addr.split(':')
183 183 assert len(lis) == 2, 'Invalid url: %r'%url
184 184 addr,s_port = lis
185 185 return proto,addr,s_port
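# e.g. (sketch): split_url('tcp://127.0.0.1:10101') -> ('tcp', '127.0.0.1', '10101')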
186 186
187 187 def disambiguate_ip_address(ip, location=None):
188 188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 189 ones, based on the location (default interpretation of location is localhost)."""
190 190 if ip in ('0.0.0.0', '*'):
191 191 if location is None or is_public_ip(location) or not public_ips():
192 192 # If location is unspecified or cannot be determined, assume local
193 193 ip = localhost()
194 194 elif location:
195 195 return location
196 196 return ip
197 197
198 198 def disambiguate_url(url, location=None):
199 199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 200 ones, based on the location (default interpretation is localhost).
201 201
202 202 This is for zeromq urls, such as tcp://*:10101."""
203 203 try:
204 204 proto,ip,port = split_url(url)
205 205 except AssertionError:
206 206 # probably not tcp url; could be ipc, etc.
207 207 return url
208 208
209 209 ip = disambiguate_ip_address(ip,location)
210 210
211 211 return "%s://%s:%s"%(proto,ip,port)
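# e.g. (sketch): on a machine where localhost() is '127.0.0.1',
#   disambiguate_url('tcp://*:10101') -> 'tcp://127.0.0.1:10101'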
212 212
213 213
214 214 #--------------------------------------------------------------------------
215 215 # helpers for implementing old MEC API via view.apply
216 216 #--------------------------------------------------------------------------
217 217
218 218 def interactive(f):
219 219 """decorator for making functions appear as interactively defined.
220 220 This results in the function being linked to the user_ns as globals()
221 221 instead of the module globals().
222 222 """
223 223 f.__module__ = '__main__'
224 224 return f
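# Sketch: without the decorator, a function defined in a module and shipped
# to an engine resolves globals in that module; with it, names resolve in
# the engine's user namespace:
#
#   @interactive
#   def getx():
#       return x    # 'x' is looked up in the engine's user_ns at call time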
225 225
226 226 @interactive
227 227 def _push(**ns):
228 228 """helper method for implementing `client.push` via `client.apply`"""
229 229 user_ns = globals()
230 230 tmp = '_IP_PUSH_TMP_'
231 231 while tmp in user_ns:
232 232 tmp = tmp + '_'
233 233 try:
234 for name, value in iteritems(ns):
234 for name, value in ns.items():
235 235 user_ns[tmp] = value
236 236 exec("%s = %s" % (name, tmp), user_ns)
237 237 finally:
238 238 user_ns.pop(tmp, None)
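    # The temporary name avoids exec'ing a repr of each value: the object is
    # first bound to _IP_PUSH_TMP_ in user_ns, then rebound to its real name
    # via exec, so any picklable object can be pushed, not just literals.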
239 239
240 240 @interactive
241 241 def _pull(keys):
242 242 """helper method for implementing `client.pull` via `client.apply`"""
243 243 if isinstance(keys, (list,tuple, set)):
244 return map(lambda key: eval(key, globals()), keys)
244 return [eval(key, globals()) for key in keys]
245 245 else:
246 246 return eval(keys, globals())
247 247
248 248 @interactive
249 249 def _execute(code):
250 250 """helper method for implementing `client.execute` via `client.apply`"""
251 251 exec(code, globals())
252 252
253 253 #--------------------------------------------------------------------------
254 254 # extra process management utilities
255 255 #--------------------------------------------------------------------------
256 256
257 257 _random_ports = set()
258 258
259 259 def select_random_ports(n):
260 260 """Select and return n random ports that are available."""
261 261 ports = []
262 262 for i in range(n):
263 263 sock = socket.socket()
264 264 sock.bind(('', 0))
265 265 while sock.getsockname()[1] in _random_ports:
266 266 sock.close()
267 267 sock = socket.socket()
268 268 sock.bind(('', 0))
269 269 ports.append(sock)
270 270 for i, sock in enumerate(ports):
271 271 port = sock.getsockname()[1]
272 272 sock.close()
273 273 ports[i] = port
274 274 _random_ports.add(port)
275 275 return ports
276 276
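Usage sketch; note the design trade-off: each port is bound once to prove it is currently free, then closed, so a small race window remains before the caller rebinds it, while the module-level _random_ports set keeps one process from handing out the same port twice:

    ports = select_random_ports(3)    # e.g. [52101, 52102, 52110]
    # nothing holds these ports reserved after return; bind them promptly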
277 277 def signal_children(children):
278 278 """Relay interrupt/term signals to children, for more solid process cleanup."""
279 279 def terminate_children(sig, frame):
280 280 log = Application.instance().log
281 281 log.critical("Got signal %i, terminating children..."%sig)
282 282 for child in children:
283 283 child.terminate()
284 284
285 285 sys.exit(sig != SIGINT)
286 286 # sys.exit(sig)
287 287 for sig in (SIGINT, SIGABRT, SIGTERM):
288 288 signal(sig, terminate_children)
289 289
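A hedged sketch of wiring this up for launcher subprocesses (the sleep commands are placeholders):

    from subprocess import Popen
    children = [Popen(['sleep', '60']), Popen(['sleep', '60'])]
    signal_children(children)
    # a later SIGINT/SIGTERM/SIGABRT terminates both children, then exits
    # with status 0 for SIGINT and 1 otherwise, per sys.exit(sig != SIGINT)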
290 290 def generate_exec_key(keyfile):
291 291 import uuid
292 292 newkey = str(uuid.uuid4())
293 293 with open(keyfile, 'w') as f:
294 294 # f.write('ipython-key ')
295 295 f.write(newkey+'\n')
296 296 # set user-only RW permissions (0600)
297 297 # this will have no effect on Windows
298 298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299 299
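Sketch of typical use; the path below is hypothetical:

    keyfile = '/tmp/profile_default/security/exec_key'   # made-up location
    generate_exec_key(keyfile)
    with open(keyfile) as f:
        key = f.read().strip()    # the fresh uuid4 string; file mode is 0600 on posix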
300 300
301 301 def integer_loglevel(loglevel):
302 302 try:
303 303 loglevel = int(loglevel)
304 304 except ValueError:
305 305 if isinstance(loglevel, str):
306 306 loglevel = getattr(logging, loglevel)
307 307 return loglevel
308 308
309 309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 310 logger = logging.getLogger(logname)
311 311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 312 # don't add a second PUBHandler
313 313 return
314 314 loglevel = integer_loglevel(loglevel)
315 315 lsock = context.socket(zmq.PUB)
316 316 lsock.connect(iface)
317 317 handler = handlers.PUBHandler(lsock)
318 318 handler.setLevel(loglevel)
319 319 handler.root_topic = root
320 320 logger.addHandler(handler)
321 321 logger.setLevel(loglevel)
322 322
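Hedged sketch: forward this process's records to a remote log aggregator over zmq (the iface address is made up; handlers is assumed to be pyzmq's zmq.log.handlers, per the PUBHandler use above):

    import zmq
    ctx = zmq.Context.instance()
    connect_logger('ipengine', ctx, 'tcp://127.0.0.1:20202',
                   root='ip.engine', loglevel='INFO')
    logging.getLogger('ipengine').info('connected')  # published under the 'ip.engine' topic tree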
323 323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 324 logger = logging.getLogger()
325 325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 326 # don't add a second PUBHandler
327 327 return logger  # already connected; return the existing logger rather than None
328 328 loglevel = integer_loglevel(loglevel)
329 329 lsock = context.socket(zmq.PUB)
330 330 lsock.connect(iface)
331 331 handler = EnginePUBHandler(engine, lsock)
332 332 handler.setLevel(loglevel)
333 333 logger.addHandler(handler)
334 334 logger.setLevel(loglevel)
335 335 return logger
336 336
337 337 def local_logger(logname, loglevel=logging.DEBUG):
338 338 loglevel = integer_loglevel(loglevel)
339 339 logger = logging.getLogger(logname)
340 340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 341 # don't add a second StreamHandler
342 342 return logger  # already configured; return the existing logger rather than None
343 343 handler = logging.StreamHandler()
344 344 handler.setLevel(loglevel)
345 345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 346 datefmt="%Y-%m-%d %H:%M:%S")
347 347 handler.setFormatter(formatter)
348 348
349 349 logger.addHandler(handler)
350 350 logger.setLevel(loglevel)
351 351 return logger
352 352
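Usage sketch for the console case; string levels resolve via integer_loglevel:

    log = local_logger('ipcluster', 'INFO')
    log.info('engines started')
    # e.g. 2014-01-01 12:00:00.000 [ipcluster] engines started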
353 353 def set_hwm(sock, hwm=0):
354 354 """set zmq High Water Mark on a socket
355 355
356 356 in a way that always works for various pyzmq / libzmq versions.
357 357 """
358 358 import zmq
359 359
360 360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
361 361 opt = getattr(zmq, key, None)
362 362 if opt is None:
363 363 continue
364 364 try:
365 365 sock.setsockopt(opt, hwm)
366 366 except zmq.ZMQError:
367 367 pass
368 368
369 369 No newline at end of file
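Finally, a sketch of set_hwm in use: it tries the old unified HWM option (libzmq 2.x) and the split SNDHWM/RCVHWM options (3.x+), silently skipping whichever the linked libzmq rejects:

    import zmq
    ctx = zmq.Context.instance()
    pub = ctx.socket(zmq.PUB)
    set_hwm(pub, 1024)   # cap queued messages; hwm=0 means unlimited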