##// END OF EJS Templates
add Client.spin_thread()...
MinRK -
Show More
@@ -1,1454 +1,1505 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 from threading import Thread, Event
21 22 import time
22 23 import warnings
23 24 from datetime import datetime
24 25 from getpass import getpass
25 26 from pprint import pprint
26 27
27 28 pjoin = os.path.join
28 29
29 30 import zmq
30 31 # from zmq.eventloop import ioloop, zmqstream
31 32
32 33 from IPython.config.configurable import MultipleInstanceError
33 34 from IPython.core.application import BaseIPythonApplication
34 35
35 36 from IPython.utils.jsonutil import rekey
36 37 from IPython.utils.localinterfaces import LOCAL_IPS
37 38 from IPython.utils.path import get_ipython_dir
38 39 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
39 Dict, List, Bool, Set)
40 Dict, List, Bool, Set, Any)
40 41 from IPython.external.decorator import decorator
41 42 from IPython.external.ssh import tunnel
42 43
43 44 from IPython.parallel import Reference
44 45 from IPython.parallel import error
45 46 from IPython.parallel import util
46 47
47 48 from IPython.zmq.session import Session, Message
48 49
49 50 from .asyncresult import AsyncResult, AsyncHubResult
50 51 from IPython.core.profiledir import ProfileDir, ProfileDirError
51 52 from .view import DirectView, LoadBalancedView
52 53
53 54 if sys.version_info[0] >= 3:
54 55 # xrange is used in a couple 'isinstance' tests in py2
55 56 # should be just 'range' in 3k
56 57 xrange = range
57 58
58 59 #--------------------------------------------------------------------------
59 60 # Decorators for Client methods
60 61 #--------------------------------------------------------------------------
61 62
62 63 @decorator
63 64 def spin_first(f, self, *args, **kwargs):
64 65 """Call spin() to sync state prior to calling the method."""
65 66 self.spin()
66 67 return f(self, *args, **kwargs)
67 68
68 69
69 70 #--------------------------------------------------------------------------
70 71 # Classes
71 72 #--------------------------------------------------------------------------
72 73
73 74 class Metadata(dict):
74 75 """Subclass of dict for initializing metadata values.
75 76
76 77 Attribute access works on keys.
77 78
78 79 These objects have a strict set of keys - errors will raise if you try
79 80 to add new keys.
80 81 """
81 82 def __init__(self, *args, **kwargs):
82 83 dict.__init__(self)
83 84 md = {'msg_id' : None,
84 85 'submitted' : None,
85 86 'started' : None,
86 87 'completed' : None,
87 88 'received' : None,
88 89 'engine_uuid' : None,
89 90 'engine_id' : None,
90 91 'follow' : None,
91 92 'after' : None,
92 93 'status' : None,
93 94
94 95 'pyin' : None,
95 96 'pyout' : None,
96 97 'pyerr' : None,
97 98 'stdout' : '',
98 99 'stderr' : '',
99 100 }
100 101 self.update(md)
101 102 self.update(dict(*args, **kwargs))
102 103
103 104 def __getattr__(self, key):
104 105 """getattr aliased to getitem"""
105 106 if key in self.iterkeys():
106 107 return self[key]
107 108 else:
108 109 raise AttributeError(key)
109 110
110 111 def __setattr__(self, key, value):
111 112 """setattr aliased to setitem, with strict"""
112 113 if key in self.iterkeys():
113 114 self[key] = value
114 115 else:
115 116 raise AttributeError(key)
116 117
117 118 def __setitem__(self, key, value):
118 119 """strict static key enforcement"""
119 120 if key in self.iterkeys():
120 121 dict.__setitem__(self, key, value)
121 122 else:
122 123 raise KeyError(key)
123 124
124 125
125 126 class Client(HasTraits):
126 127 """A semi-synchronous client to the IPython ZMQ cluster
127 128
128 129 Parameters
129 130 ----------
130 131
131 132 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
132 133 Connection information for the Hub's registration. If a json connector
133 134 file is given, then likely no further configuration is necessary.
134 135 [Default: use profile]
135 136 profile : bytes
136 137 The name of the Cluster profile to be used to find connector information.
137 138 If run from an IPython application, the default profile will be the same
138 139 as the running application, otherwise it will be 'default'.
139 140 context : zmq.Context
140 141 Pass an existing zmq.Context instance, otherwise the client will create its own.
141 142 debug : bool
142 143 flag for lots of message printing for debug purposes
143 144 timeout : int/float
144 145 time (in seconds) to wait for connection replies from the Hub
145 146 [Default: 10]
146 147
147 148 #-------------- session related args ----------------
148 149
149 150 config : Config object
150 151 If specified, this will be relayed to the Session for configuration
151 152 username : str
152 153 set username for the session object
153 154 packer : str (import_string) or callable
154 155 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
155 156 function to serialize messages. Must support same input as
156 157 JSON, and output must be bytes.
157 158 You can pass a callable directly as `pack`
158 159 unpacker : str (import_string) or callable
159 160 The inverse of packer. Only necessary if packer is specified as *not* one
160 161 of 'json' or 'pickle'.
161 162
162 163 #-------------- ssh related args ----------------
163 164 # These are args for configuring the ssh tunnel to be used
164 165 # credentials are used to forward connections over ssh to the Controller
165 166 # Note that the ip given in `addr` needs to be relative to sshserver
166 167 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
167 168 # and set sshserver as the same machine the Controller is on. However,
168 169 # the only requirement is that sshserver is able to see the Controller
169 170 # (i.e. is within the same trusted network).
170 171
171 172 sshserver : str
172 173 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
173 174 If keyfile or password is specified, and this is not, it will default to
174 175 the ip given in addr.
175 176 sshkey : str; path to ssh private key file
176 177 This specifies a key to be used in ssh login, default None.
177 178 Regular default ssh keys will be used without specifying this argument.
178 179 password : str
179 180 Your ssh password to sshserver. Note that if this is left None,
180 181 you will be prompted for it if passwordless key based login is unavailable.
181 182 paramiko : bool
182 183 flag for whether to use paramiko instead of shell ssh for tunneling.
183 184 [default: True on win32, False else]
184 185
185 186 ------- exec authentication args -------
186 187 If even localhost is untrusted, you can have some protection against
187 188 unauthorized execution by signing messages with HMAC digests.
188 189 Messages are still sent as cleartext, so if someone can snoop your
189 190 loopback traffic this will not protect your privacy, but will prevent
190 191 unauthorized execution.
191 192
192 193 exec_key : str
193 194 an authentication key or file containing a key
194 195 default: None
195 196
196 197
197 198 Attributes
198 199 ----------
199 200
200 201 ids : list of int engine IDs
201 202 requesting the ids attribute always synchronizes
202 203 the registration state. To request ids without synchronization,
203 204 use semi-private _ids attributes.
204 205
205 206 history : list of msg_ids
206 207 a list of msg_ids, keeping track of all the execution
207 208 messages you have submitted in order.
208 209
209 210 outstanding : set of msg_ids
210 211 a set of msg_ids that have been submitted, but whose
211 212 results have not yet been received.
212 213
213 214 results : dict
214 215 a dict of all our results, keyed by msg_id
215 216
216 217 block : bool
217 218 determines default behavior when block not specified
218 219 in execution methods
219 220
220 221 Methods
221 222 -------
222 223
223 224 spin
224 225 flushes incoming results and registration state changes
225 226 control methods spin, and requesting `ids` also ensures up to date
226 227
227 228 wait
228 229 wait on one or more msg_ids
229 230
230 231 execution methods
231 232 apply
232 233 legacy: execute, run
233 234
234 235 data movement
235 236 push, pull, scatter, gather
236 237
237 238 query methods
238 239 queue_status, get_result, purge, result_status
239 240
240 241 control methods
241 242 abort, shutdown
242 243
243 244 """
244 245
245 246
246 247 block = Bool(False)
247 248 outstanding = Set()
248 249 results = Instance('collections.defaultdict', (dict,))
249 250 metadata = Instance('collections.defaultdict', (Metadata,))
250 251 history = List()
251 252 debug = Bool(False)
253 _spin_thread = Any()
254 _stop_spinning = Any()
252 255
253 256 profile=Unicode()
254 257 def _profile_default(self):
255 258 if BaseIPythonApplication.initialized():
256 259 # an IPython app *might* be running, try to get its profile
257 260 try:
258 261 return BaseIPythonApplication.instance().profile
259 262 except (AttributeError, MultipleInstanceError):
260 263 # could be a *different* subclass of config.Application,
261 264 # which would raise one of these two errors.
262 265 return u'default'
263 266 else:
264 267 return u'default'
265 268
266 269
267 270 _outstanding_dict = Instance('collections.defaultdict', (set,))
268 271 _ids = List()
269 272 _connected=Bool(False)
270 273 _ssh=Bool(False)
271 274 _context = Instance('zmq.Context')
272 275 _config = Dict()
273 276 _engines=Instance(util.ReverseDict, (), {})
274 277 # _hub_socket=Instance('zmq.Socket')
275 278 _query_socket=Instance('zmq.Socket')
276 279 _control_socket=Instance('zmq.Socket')
277 280 _iopub_socket=Instance('zmq.Socket')
278 281 _notification_socket=Instance('zmq.Socket')
279 282 _mux_socket=Instance('zmq.Socket')
280 283 _task_socket=Instance('zmq.Socket')
281 284 _task_scheme=Unicode()
282 285 _closed = False
283 286 _ignored_control_replies=Integer(0)
284 287 _ignored_hub_replies=Integer(0)
285 288
286 289 def __new__(self, *args, **kw):
287 290 # don't raise on positional args
288 291 return HasTraits.__new__(self, **kw)
289 292
290 293 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
291 294 context=None, debug=False, exec_key=None,
292 295 sshserver=None, sshkey=None, password=None, paramiko=None,
293 296 timeout=10, **extra_args
294 297 ):
295 298 if profile:
296 299 super(Client, self).__init__(debug=debug, profile=profile)
297 300 else:
298 301 super(Client, self).__init__(debug=debug)
299 302 if context is None:
300 303 context = zmq.Context.instance()
301 304 self._context = context
305 self._stop_spinning = Event()
302 306
303 307 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
304 308 if self._cd is not None:
305 309 if url_or_file is None:
306 310 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
307 311 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
308 312 " Please specify at least one of url_or_file or profile."
309 313
310 314 if not util.is_url(url_or_file):
311 315 # it's not a url, try for a file
312 316 if not os.path.exists(url_or_file):
313 317 if self._cd:
314 318 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
315 319 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
316 320 with open(url_or_file) as f:
317 321 cfg = json.loads(f.read())
318 322 else:
319 323 cfg = {'url':url_or_file}
320 324
321 325 # sync defaults from args, json:
322 326 if sshserver:
323 327 cfg['ssh'] = sshserver
324 328 if exec_key:
325 329 cfg['exec_key'] = exec_key
326 330 exec_key = cfg['exec_key']
327 331 location = cfg.setdefault('location', None)
328 332 cfg['url'] = util.disambiguate_url(cfg['url'], location)
329 333 url = cfg['url']
330 334 proto,addr,port = util.split_url(url)
331 335 if location is not None and addr == '127.0.0.1':
332 336 # location specified, and connection is expected to be local
333 337 if location not in LOCAL_IPS and not sshserver:
334 338 # load ssh from JSON *only* if the controller is not on
335 339 # this machine
336 340 sshserver=cfg['ssh']
337 341 if location not in LOCAL_IPS and not sshserver:
338 342 # warn if no ssh specified, but SSH is probably needed
339 343 # This is only a warning, because the most likely cause
340 344 # is a local Controller on a laptop whose IP is dynamic
341 345 warnings.warn("""
342 346 Controller appears to be listening on localhost, but not on this machine.
343 347 If this is true, you should specify Client(...,sshserver='you@%s')
344 348 or instruct your controller to listen on an external IP."""%location,
345 349 RuntimeWarning)
346 350 elif not sshserver:
347 351 # otherwise sync with cfg
348 352 sshserver = cfg['ssh']
349 353
350 354 self._config = cfg
351 355
352 356 self._ssh = bool(sshserver or sshkey or password)
353 357 if self._ssh and sshserver is None:
354 358 # default to ssh via localhost
355 359 sshserver = url.split('://')[1].split(':')[0]
356 360 if self._ssh and password is None:
357 361 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
358 362 password=False
359 363 else:
360 364 password = getpass("SSH Password for %s: "%sshserver)
361 365 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
362 366
363 367 # configure and construct the session
364 368 if exec_key is not None:
365 369 if os.path.isfile(exec_key):
366 370 extra_args['keyfile'] = exec_key
367 371 else:
368 372 exec_key = util.asbytes(exec_key)
369 373 extra_args['key'] = exec_key
370 374 self.session = Session(**extra_args)
371 375
372 376 self._query_socket = self._context.socket(zmq.DEALER)
373 377 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
374 378 if self._ssh:
375 379 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
376 380 else:
377 381 self._query_socket.connect(url)
378 382
379 383 self.session.debug = self.debug
380 384
381 385 self._notification_handlers = {'registration_notification' : self._register_engine,
382 386 'unregistration_notification' : self._unregister_engine,
383 387 'shutdown_notification' : lambda msg: self.close(),
384 388 }
385 389 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
386 390 'apply_reply' : self._handle_apply_reply}
387 391 self._connect(sshserver, ssh_kwargs, timeout)
388 392
389 393 def __del__(self):
390 394 """cleanup sockets, but _not_ context."""
391 395 self.close()
392 396
393 397 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 398 if ipython_dir is None:
395 399 ipython_dir = get_ipython_dir()
396 400 if profile_dir is not None:
397 401 try:
398 402 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 403 return
400 404 except ProfileDirError:
401 405 pass
402 406 elif profile is not None:
403 407 try:
404 408 self._cd = ProfileDir.find_profile_dir_by_name(
405 409 ipython_dir, profile)
406 410 return
407 411 except ProfileDirError:
408 412 pass
409 413 self._cd = None
410 414
411 415 def _update_engines(self, engines):
412 416 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 417 for k,v in engines.iteritems():
414 418 eid = int(k)
415 419 self._engines[eid] = v
416 420 self._ids.append(eid)
417 421 self._ids = sorted(self._ids)
418 422 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 423 self._task_scheme == 'pure' and self._task_socket:
420 424 self._stop_scheduling_tasks()
421 425
422 426 def _stop_scheduling_tasks(self):
423 427 """Stop scheduling tasks because an engine has been unregistered
424 428 from a pure ZMQ scheduler.
425 429 """
426 430 self._task_socket.close()
427 431 self._task_socket = None
428 432 msg = "An engine has been unregistered, and we are using pure " +\
429 433 "ZMQ task scheduling. Task farming will be disabled."
430 434 if self.outstanding:
431 435 msg += " If you were running tasks when this happened, " +\
432 436 "some `outstanding` msg_ids may never resolve."
433 437 warnings.warn(msg, RuntimeWarning)
434 438
435 439 def _build_targets(self, targets):
436 440 """Turn valid target IDs or 'all' into two lists:
437 441 (int_ids, uuids).
438 442 """
439 443 if not self._ids:
440 444 # flush notification socket if no engines yet, just in case
441 445 if not self.ids:
442 446 raise error.NoEnginesRegistered("Can't build targets without any engines")
443 447
444 448 if targets is None:
445 449 targets = self._ids
446 450 elif isinstance(targets, basestring):
447 451 if targets.lower() == 'all':
448 452 targets = self._ids
449 453 else:
450 454 raise TypeError("%r not valid str target, must be 'all'"%(targets))
451 455 elif isinstance(targets, int):
452 456 if targets < 0:
453 457 targets = self.ids[targets]
454 458 if targets not in self._ids:
455 459 raise IndexError("No such engine: %i"%targets)
456 460 targets = [targets]
457 461
458 462 if isinstance(targets, slice):
459 463 indices = range(len(self._ids))[targets]
460 464 ids = self.ids
461 465 targets = [ ids[i] for i in indices ]
462 466
463 467 if not isinstance(targets, (tuple, list, xrange)):
464 468 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
465 469
466 470 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467 471
468 472 def _connect(self, sshserver, ssh_kwargs, timeout):
469 473 """setup all our socket connections to the cluster. This is called from
470 474 __init__."""
471 475
472 476 # Maybe allow reconnecting?
473 477 if self._connected:
474 478 return
475 479 self._connected=True
476 480
477 481 def connect_socket(s, url):
478 482 url = util.disambiguate_url(url, self._config['location'])
479 483 if self._ssh:
480 484 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
481 485 else:
482 486 return s.connect(url)
483 487
484 488 self.session.send(self._query_socket, 'connection_request')
485 489 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
486 490 poller = zmq.Poller()
487 491 poller.register(self._query_socket, zmq.POLLIN)
488 492 # poll expects milliseconds, timeout is seconds
489 493 evts = poller.poll(timeout*1000)
490 494 if not evts:
491 495 raise error.TimeoutError("Hub connection request timed out")
492 496 idents,msg = self.session.recv(self._query_socket,mode=0)
493 497 if self.debug:
494 498 pprint(msg)
495 499 msg = Message(msg)
496 500 content = msg.content
497 501 self._config['registration'] = dict(content)
498 502 if content.status == 'ok':
499 503 ident = self.session.bsession
500 504 if content.mux:
501 505 self._mux_socket = self._context.socket(zmq.DEALER)
502 506 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
503 507 connect_socket(self._mux_socket, content.mux)
504 508 if content.task:
505 509 self._task_scheme, task_addr = content.task
506 510 self._task_socket = self._context.socket(zmq.DEALER)
507 511 self._task_socket.setsockopt(zmq.IDENTITY, ident)
508 512 connect_socket(self._task_socket, task_addr)
509 513 if content.notification:
510 514 self._notification_socket = self._context.socket(zmq.SUB)
511 515 connect_socket(self._notification_socket, content.notification)
512 516 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
513 517 # if content.query:
514 518 # self._query_socket = self._context.socket(zmq.DEALER)
515 519 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
516 520 # connect_socket(self._query_socket, content.query)
517 521 if content.control:
518 522 self._control_socket = self._context.socket(zmq.DEALER)
519 523 self._control_socket.setsockopt(zmq.IDENTITY, ident)
520 524 connect_socket(self._control_socket, content.control)
521 525 if content.iopub:
522 526 self._iopub_socket = self._context.socket(zmq.SUB)
523 527 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
524 528 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
525 529 connect_socket(self._iopub_socket, content.iopub)
526 530 self._update_engines(dict(content.engines))
527 531 else:
528 532 self._connected = False
529 533 raise Exception("Failed to connect!")
530 534
531 535 #--------------------------------------------------------------------------
532 536 # handlers and callbacks for incoming messages
533 537 #--------------------------------------------------------------------------
534 538
535 539 def _unwrap_exception(self, content):
536 540 """unwrap exception, and remap engine_id to int."""
537 541 e = error.unwrap_exception(content)
538 542 # print e.traceback
539 543 if e.engine_info:
540 544 e_uuid = e.engine_info['engine_uuid']
541 545 eid = self._engines[e_uuid]
542 546 e.engine_info['engine_id'] = eid
543 547 return e
544 548
545 549 def _extract_metadata(self, header, parent, content):
546 550 md = {'msg_id' : parent['msg_id'],
547 551 'received' : datetime.now(),
548 552 'engine_uuid' : header.get('engine', None),
549 553 'follow' : parent.get('follow', []),
550 554 'after' : parent.get('after', []),
551 555 'status' : content['status'],
552 556 }
553 557
554 558 if md['engine_uuid'] is not None:
555 559 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556 560
557 561 if 'date' in parent:
558 562 md['submitted'] = parent['date']
559 563 if 'started' in header:
560 564 md['started'] = header['started']
561 565 if 'date' in header:
562 566 md['completed'] = header['date']
563 567 return md
564 568
565 569 def _register_engine(self, msg):
566 570 """Register a new engine, and update our connection info."""
567 571 content = msg['content']
568 572 eid = content['id']
569 573 d = {eid : content['queue']}
570 574 self._update_engines(d)
571 575
572 576 def _unregister_engine(self, msg):
573 577 """Unregister an engine that has died."""
574 578 content = msg['content']
575 579 eid = int(content['id'])
576 580 if eid in self._ids:
577 581 self._ids.remove(eid)
578 582 uuid = self._engines.pop(eid)
579 583
580 584 self._handle_stranded_msgs(eid, uuid)
581 585
582 586 if self._task_socket and self._task_scheme == 'pure':
583 587 self._stop_scheduling_tasks()
584 588
585 589 def _handle_stranded_msgs(self, eid, uuid):
586 590 """Handle messages known to be on an engine when the engine unregisters.
587 591
588 592 It is possible that this will fire prematurely - that is, an engine will
589 593 go down after completing a result, and the client will be notified
590 594 of the unregistration and later receive the successful result.
591 595 """
592 596
593 597 outstanding = self._outstanding_dict[uuid]
594 598
595 599 for msg_id in list(outstanding):
596 600 if msg_id in self.results:
597 601 # we already
598 602 continue
599 603 try:
600 604 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
601 605 except:
602 606 content = error.wrap_exception()
603 607 # build a fake message:
604 608 parent = {}
605 609 header = {}
606 610 parent['msg_id'] = msg_id
607 611 header['engine'] = uuid
608 612 header['date'] = datetime.now()
609 613 msg = dict(parent_header=parent, header=header, content=content)
610 614 self._handle_apply_reply(msg)
611 615
612 616 def _handle_execute_reply(self, msg):
613 617 """Save the reply to an execute_request into our results.
614 618
615 619 execute messages are never actually used. apply is used instead.
616 620 """
617 621
618 622 parent = msg['parent_header']
619 623 msg_id = parent['msg_id']
620 624 if msg_id not in self.outstanding:
621 625 if msg_id in self.history:
622 626 print ("got stale result: %s"%msg_id)
623 627 else:
624 628 print ("got unknown result: %s"%msg_id)
625 629 else:
626 630 self.outstanding.remove(msg_id)
627 631 self.results[msg_id] = self._unwrap_exception(msg['content'])
628 632
629 633 def _handle_apply_reply(self, msg):
630 634 """Save the reply to an apply_request into our results."""
631 635 parent = msg['parent_header']
632 636 msg_id = parent['msg_id']
633 637 if msg_id not in self.outstanding:
634 638 if msg_id in self.history:
635 639 print ("got stale result: %s"%msg_id)
636 640 print self.results[msg_id]
637 641 print msg
638 642 else:
639 643 print ("got unknown result: %s"%msg_id)
640 644 else:
641 645 self.outstanding.remove(msg_id)
642 646 content = msg['content']
643 647 header = msg['header']
644 648
645 649 # construct metadata:
646 650 md = self.metadata[msg_id]
647 651 md.update(self._extract_metadata(header, parent, content))
648 652 # is this redundant?
649 653 self.metadata[msg_id] = md
650 654
651 655 e_outstanding = self._outstanding_dict[md['engine_uuid']]
652 656 if msg_id in e_outstanding:
653 657 e_outstanding.remove(msg_id)
654 658
655 659 # construct result:
656 660 if content['status'] == 'ok':
657 661 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
658 662 elif content['status'] == 'aborted':
659 663 self.results[msg_id] = error.TaskAborted(msg_id)
660 664 elif content['status'] == 'resubmitted':
661 665 # TODO: handle resubmission
662 666 pass
663 667 else:
664 668 self.results[msg_id] = self._unwrap_exception(content)
665 669
666 670 def _flush_notifications(self):
667 671 """Flush notifications of engine registrations waiting
668 672 in ZMQ queue."""
669 673 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 674 while msg is not None:
671 675 if self.debug:
672 676 pprint(msg)
673 677 msg_type = msg['header']['msg_type']
674 678 handler = self._notification_handlers.get(msg_type, None)
675 679 if handler is None:
676 680 raise Exception("Unhandled message type: %s"%msg.msg_type)
677 681 else:
678 682 handler(msg)
679 683 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680 684
681 685 def _flush_results(self, sock):
682 686 """Flush task or queue results waiting in ZMQ queue."""
683 687 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 688 while msg is not None:
685 689 if self.debug:
686 690 pprint(msg)
687 691 msg_type = msg['header']['msg_type']
688 692 handler = self._queue_handlers.get(msg_type, None)
689 693 if handler is None:
690 694 raise Exception("Unhandled message type: %s"%msg.msg_type)
691 695 else:
692 696 handler(msg)
693 697 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 698
695 699 def _flush_control(self, sock):
696 700 """Flush replies from the control channel waiting
697 701 in the ZMQ queue.
698 702
699 703 Currently: ignore them."""
700 704 if self._ignored_control_replies <= 0:
701 705 return
702 706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 707 while msg is not None:
704 708 self._ignored_control_replies -= 1
705 709 if self.debug:
706 710 pprint(msg)
707 711 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708 712
709 713 def _flush_ignored_control(self):
710 714 """flush ignored control replies"""
711 715 while self._ignored_control_replies > 0:
712 716 self.session.recv(self._control_socket)
713 717 self._ignored_control_replies -= 1
714 718
715 719 def _flush_ignored_hub_replies(self):
716 720 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 721 while msg is not None:
718 722 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719 723
720 724 def _flush_iopub(self, sock):
721 725 """Flush replies from the iopub channel waiting
722 726 in the ZMQ queue.
723 727 """
724 728 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
725 729 while msg is not None:
726 730 if self.debug:
727 731 pprint(msg)
728 732 parent = msg['parent_header']
729 733 # ignore IOPub messages with no parent.
730 734 # Caused by print statements or warnings from before the first execution.
731 735 if not parent:
732 736 continue
733 737 msg_id = parent['msg_id']
734 738 content = msg['content']
735 739 header = msg['header']
736 740 msg_type = msg['header']['msg_type']
737 741
738 742 # init metadata:
739 743 md = self.metadata[msg_id]
740 744
741 745 if msg_type == 'stream':
742 746 name = content['name']
743 747 s = md[name] or ''
744 748 md[name] = s + content['data']
745 749 elif msg_type == 'pyerr':
746 750 md.update({'pyerr' : self._unwrap_exception(content)})
747 751 elif msg_type == 'pyin':
748 752 md.update({'pyin' : content['code']})
749 753 else:
750 754 md.update({msg_type : content.get('data', '')})
751 755
752 756 # reduntant?
753 757 self.metadata[msg_id] = md
754 758
755 759 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
756 760
757 761 #--------------------------------------------------------------------------
758 762 # len, getitem
759 763 #--------------------------------------------------------------------------
760 764
761 765 def __len__(self):
762 766 """len(client) returns # of engines."""
763 767 return len(self.ids)
764 768
765 769 def __getitem__(self, key):
766 770 """index access returns DirectView multiplexer objects
767 771
768 772 Must be int, slice, or list/tuple/xrange of ints"""
769 773 if not isinstance(key, (int, slice, tuple, list, xrange)):
770 774 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
771 775 else:
772 776 return self.direct_view(key)
773 777
774 778 #--------------------------------------------------------------------------
775 779 # Begin public methods
776 780 #--------------------------------------------------------------------------
777 781
778 782 @property
779 783 def ids(self):
780 784 """Always up-to-date ids property."""
781 785 self._flush_notifications()
782 786 # always copy:
783 787 return list(self._ids)
784 788
785 789 def close(self):
786 790 if self._closed:
787 791 return
792 self.stop_spin_thread()
788 793 snames = filter(lambda n: n.endswith('socket'), dir(self))
789 794 for socket in map(lambda name: getattr(self, name), snames):
790 795 if isinstance(socket, zmq.Socket) and not socket.closed:
791 796 socket.close()
792 797 self._closed = True
793 798
799 def _spin_every(self, interval=1):
800 """target func for use in spin_thread"""
801 while True:
802 if self._stop_spinning.is_set():
803 return
804 time.sleep(interval)
805 self.spin()
806
807 def spin_thread(self, interval=1):
808 """call Client.spin() in a background thread on some regular interval
809
810 This helps ensure that messages don't pile up too much in the zmq queue
811 while you are working on other things, or just leaving an idle terminal.
812
813 It also helps limit potential padding of the `received` timestamp
814 on AsyncResult objects, used for timings.
815
816 Parameters
817 ----------
818
819 interval : float, optional
820 The interval on which to spin the client in the background thread
821 (simply passed to time.sleep).
822
823 Notes
824 -----
825
826 For precision timing, you may want to use this method to put a bound
827 on the jitter (in seconds) in `received` timestamps used
828 in AsyncResult.wall_time.
829
830 """
831 if self._spin_thread is not None:
832 self.stop_spin_thread()
833 self._stop_spinning.clear()
834 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
835 self._spin_thread.daemon = True
836 self._spin_thread.start()
837
838 def stop_spin_thread(self):
839 """stop background spin_thread, if any"""
840 if self._spin_thread is not None:
841 self._stop_spinning.set()
842 self._spin_thread.join()
843 self._spin_thread = None
844
794 845 def spin(self):
795 846 """Flush any registration notifications and execution results
796 847 waiting in the ZMQ queue.
797 848 """
798 849 if self._notification_socket:
799 850 self._flush_notifications()
800 851 if self._mux_socket:
801 852 self._flush_results(self._mux_socket)
802 853 if self._task_socket:
803 854 self._flush_results(self._task_socket)
804 855 if self._control_socket:
805 856 self._flush_control(self._control_socket)
806 857 if self._iopub_socket:
807 858 self._flush_iopub(self._iopub_socket)
808 859 if self._query_socket:
809 860 self._flush_ignored_hub_replies()
810 861
811 862 def wait(self, jobs=None, timeout=-1):
812 863 """waits on one or more `jobs`, for up to `timeout` seconds.
813 864
814 865 Parameters
815 866 ----------
816 867
817 868 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
818 869 ints are indices to self.history
819 870 strs are msg_ids
820 871 default: wait on all outstanding messages
821 872 timeout : float
822 873 a time in seconds, after which to give up.
823 874 default is -1, which means no timeout
824 875
825 876 Returns
826 877 -------
827 878
828 879 True : when all msg_ids are done
829 880 False : timeout reached, some msg_ids still outstanding
830 881 """
831 882 tic = time.time()
832 883 if jobs is None:
833 884 theids = self.outstanding
834 885 else:
835 886 if isinstance(jobs, (int, basestring, AsyncResult)):
836 887 jobs = [jobs]
837 888 theids = set()
838 889 for job in jobs:
839 890 if isinstance(job, int):
840 891 # index access
841 892 job = self.history[job]
842 893 elif isinstance(job, AsyncResult):
843 894 map(theids.add, job.msg_ids)
844 895 continue
845 896 theids.add(job)
846 897 if not theids.intersection(self.outstanding):
847 898 return True
848 899 self.spin()
849 900 while theids.intersection(self.outstanding):
850 901 if timeout >= 0 and ( time.time()-tic ) > timeout:
851 902 break
852 903 time.sleep(1e-3)
853 904 self.spin()
854 905 return len(theids.intersection(self.outstanding)) == 0
855 906
856 907 #--------------------------------------------------------------------------
857 908 # Control methods
858 909 #--------------------------------------------------------------------------
859 910
860 911 @spin_first
861 912 def clear(self, targets=None, block=None):
862 913 """Clear the namespace in target(s)."""
863 914 block = self.block if block is None else block
864 915 targets = self._build_targets(targets)[0]
865 916 for t in targets:
866 917 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
867 918 error = False
868 919 if block:
869 920 self._flush_ignored_control()
870 921 for i in range(len(targets)):
871 922 idents,msg = self.session.recv(self._control_socket,0)
872 923 if self.debug:
873 924 pprint(msg)
874 925 if msg['content']['status'] != 'ok':
875 926 error = self._unwrap_exception(msg['content'])
876 927 else:
877 928 self._ignored_control_replies += len(targets)
878 929 if error:
879 930 raise error
880 931
881 932
882 933 @spin_first
883 934 def abort(self, jobs=None, targets=None, block=None):
884 935 """Abort specific jobs from the execution queues of target(s).
885 936
886 937 This is a mechanism to prevent jobs that have already been submitted
887 938 from executing.
888 939
889 940 Parameters
890 941 ----------
891 942
892 943 jobs : msg_id, list of msg_ids, or AsyncResult
893 944 The jobs to be aborted
894 945
895 946 If unspecified/None: abort all outstanding jobs.
896 947
897 948 """
898 949 block = self.block if block is None else block
899 950 jobs = jobs if jobs is not None else list(self.outstanding)
900 951 targets = self._build_targets(targets)[0]
901 952
902 953 msg_ids = []
903 954 if isinstance(jobs, (basestring,AsyncResult)):
904 955 jobs = [jobs]
905 956 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
906 957 if bad_ids:
907 958 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
908 959 for j in jobs:
909 960 if isinstance(j, AsyncResult):
910 961 msg_ids.extend(j.msg_ids)
911 962 else:
912 963 msg_ids.append(j)
913 964 content = dict(msg_ids=msg_ids)
914 965 for t in targets:
915 966 self.session.send(self._control_socket, 'abort_request',
916 967 content=content, ident=t)
917 968 error = False
918 969 if block:
919 970 self._flush_ignored_control()
920 971 for i in range(len(targets)):
921 972 idents,msg = self.session.recv(self._control_socket,0)
922 973 if self.debug:
923 974 pprint(msg)
924 975 if msg['content']['status'] != 'ok':
925 976 error = self._unwrap_exception(msg['content'])
926 977 else:
927 978 self._ignored_control_replies += len(targets)
928 979 if error:
929 980 raise error
930 981
931 982 @spin_first
932 983 def shutdown(self, targets=None, restart=False, hub=False, block=None):
933 984 """Terminates one or more engine processes, optionally including the hub."""
934 985 block = self.block if block is None else block
935 986 if hub:
936 987 targets = 'all'
937 988 targets = self._build_targets(targets)[0]
938 989 for t in targets:
939 990 self.session.send(self._control_socket, 'shutdown_request',
940 991 content={'restart':restart},ident=t)
941 992 error = False
942 993 if block or hub:
943 994 self._flush_ignored_control()
944 995 for i in range(len(targets)):
945 996 idents,msg = self.session.recv(self._control_socket, 0)
946 997 if self.debug:
947 998 pprint(msg)
948 999 if msg['content']['status'] != 'ok':
949 1000 error = self._unwrap_exception(msg['content'])
950 1001 else:
951 1002 self._ignored_control_replies += len(targets)
952 1003
953 1004 if hub:
954 1005 time.sleep(0.25)
955 1006 self.session.send(self._query_socket, 'shutdown_request')
956 1007 idents,msg = self.session.recv(self._query_socket, 0)
957 1008 if self.debug:
958 1009 pprint(msg)
959 1010 if msg['content']['status'] != 'ok':
960 1011 error = self._unwrap_exception(msg['content'])
961 1012
962 1013 if error:
963 1014 raise error
964 1015
965 1016 #--------------------------------------------------------------------------
966 1017 # Execution related methods
967 1018 #--------------------------------------------------------------------------
968 1019
969 1020 def _maybe_raise(self, result):
970 1021 """wrapper for maybe raising an exception if apply failed."""
971 1022 if isinstance(result, error.RemoteError):
972 1023 raise result
973 1024
974 1025 return result
975 1026
976 1027 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
977 1028 ident=None):
978 1029 """construct and send an apply message via a socket.
979 1030
980 1031 This is the principal method with which all engine execution is performed by views.
981 1032 """
982 1033
983 1034 assert not self._closed, "cannot use me anymore, I'm closed!"
984 1035 # defaults:
985 1036 args = args if args is not None else []
986 1037 kwargs = kwargs if kwargs is not None else {}
987 1038 subheader = subheader if subheader is not None else {}
988 1039
989 1040 # validate arguments
990 1041 if not callable(f) and not isinstance(f, Reference):
991 1042 raise TypeError("f must be callable, not %s"%type(f))
992 1043 if not isinstance(args, (tuple, list)):
993 1044 raise TypeError("args must be tuple or list, not %s"%type(args))
994 1045 if not isinstance(kwargs, dict):
995 1046 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
996 1047 if not isinstance(subheader, dict):
997 1048 raise TypeError("subheader must be dict, not %s"%type(subheader))
998 1049
999 1050 bufs = util.pack_apply_message(f,args,kwargs)
1000 1051
1001 1052 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1002 1053 subheader=subheader, track=track)
1003 1054
1004 1055 msg_id = msg['header']['msg_id']
1005 1056 self.outstanding.add(msg_id)
1006 1057 if ident:
1007 1058 # possibly routed to a specific engine
1008 1059 if isinstance(ident, list):
1009 1060 ident = ident[-1]
1010 1061 if ident in self._engines.values():
1011 1062 # save for later, in case of engine death
1012 1063 self._outstanding_dict[ident].add(msg_id)
1013 1064 self.history.append(msg_id)
1014 1065 self.metadata[msg_id]['submitted'] = datetime.now()
1015 1066
1016 1067 return msg
1017 1068
1018 1069 #--------------------------------------------------------------------------
1019 1070 # construct a View object
1020 1071 #--------------------------------------------------------------------------
1021 1072
1022 1073 def load_balanced_view(self, targets=None):
1023 1074 """construct a DirectView object.
1024 1075
1025 1076 If no arguments are specified, create a LoadBalancedView
1026 1077 using all engines.
1027 1078
1028 1079 Parameters
1029 1080 ----------
1030 1081
1031 1082 targets: list,slice,int,etc. [default: use all engines]
1032 1083 The subset of engines across which to load-balance
1033 1084 """
1034 1085 if targets == 'all':
1035 1086 targets = None
1036 1087 if targets is not None:
1037 1088 targets = self._build_targets(targets)[1]
1038 1089 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1039 1090
1040 1091 def direct_view(self, targets='all'):
1041 1092 """construct a DirectView object.
1042 1093
1043 1094 If no targets are specified, create a DirectView using all engines.
1044 1095
1045 1096 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1046 1097 evaluate the target engines at each execution, whereas rc[:] will connect to
1047 1098 all *current* engines, and that list will not change.
1048 1099
1049 1100 That is, 'all' will always use all engines, whereas rc[:] will not use
1050 1101 engines added after the DirectView is constructed.
1051 1102
1052 1103 Parameters
1053 1104 ----------
1054 1105
1055 1106 targets: list,slice,int,etc. [default: use all engines]
1056 1107 The engines to use for the View
1057 1108 """
1058 1109 single = isinstance(targets, int)
1059 1110 # allow 'all' to be lazily evaluated at each execution
1060 1111 if targets != 'all':
1061 1112 targets = self._build_targets(targets)[1]
1062 1113 if single:
1063 1114 targets = targets[0]
1064 1115 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1065 1116
1066 1117 #--------------------------------------------------------------------------
1067 1118 # Query methods
1068 1119 #--------------------------------------------------------------------------
1069 1120
1070 1121 @spin_first
1071 1122 def get_result(self, indices_or_msg_ids=None, block=None):
1072 1123 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1073 1124
1074 1125 If the client already has the results, no request to the Hub will be made.
1075 1126
1076 1127 This is a convenient way to construct AsyncResult objects, which are wrappers
1077 1128 that include metadata about execution, and allow for awaiting results that
1078 1129 were not submitted by this Client.
1079 1130
1080 1131 It can also be a convenient way to retrieve the metadata associated with
1081 1132 blocking execution, since it always retrieves
1082 1133
1083 1134 Examples
1084 1135 --------
1085 1136 ::
1086 1137
1087 1138 In [10]: r = client.apply()
1088 1139
1089 1140 Parameters
1090 1141 ----------
1091 1142
1092 1143 indices_or_msg_ids : integer history index, str msg_id, or list of either
1093 1144 The indices or msg_ids of indices to be retrieved
1094 1145
1095 1146 block : bool
1096 1147 Whether to wait for the result to be done
1097 1148
1098 1149 Returns
1099 1150 -------
1100 1151
1101 1152 AsyncResult
1102 1153 A single AsyncResult object will always be returned.
1103 1154
1104 1155 AsyncHubResult
1105 1156 A subclass of AsyncResult that retrieves results from the Hub
1106 1157
1107 1158 """
1108 1159 block = self.block if block is None else block
1109 1160 if indices_or_msg_ids is None:
1110 1161 indices_or_msg_ids = -1
1111 1162
1112 1163 if not isinstance(indices_or_msg_ids, (list,tuple)):
1113 1164 indices_or_msg_ids = [indices_or_msg_ids]
1114 1165
1115 1166 theids = []
1116 1167 for id in indices_or_msg_ids:
1117 1168 if isinstance(id, int):
1118 1169 id = self.history[id]
1119 1170 if not isinstance(id, basestring):
1120 1171 raise TypeError("indices must be str or int, not %r"%id)
1121 1172 theids.append(id)
1122 1173
1123 1174 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1124 1175 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1125 1176
1126 1177 if remote_ids:
1127 1178 ar = AsyncHubResult(self, msg_ids=theids)
1128 1179 else:
1129 1180 ar = AsyncResult(self, msg_ids=theids)
1130 1181
1131 1182 if block:
1132 1183 ar.wait()
1133 1184
1134 1185 return ar
1135 1186
1136 1187 @spin_first
1137 1188 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1138 1189 """Resubmit one or more tasks.
1139 1190
1140 1191 in-flight tasks may not be resubmitted.
1141 1192
1142 1193 Parameters
1143 1194 ----------
1144 1195
1145 1196 indices_or_msg_ids : integer history index, str msg_id, or list of either
1146 1197 The indices or msg_ids of indices to be retrieved
1147 1198
1148 1199 block : bool
1149 1200 Whether to wait for the result to be done
1150 1201
1151 1202 Returns
1152 1203 -------
1153 1204
1154 1205 AsyncHubResult
1155 1206 A subclass of AsyncResult that retrieves results from the Hub
1156 1207
1157 1208 """
1158 1209 block = self.block if block is None else block
1159 1210 if indices_or_msg_ids is None:
1160 1211 indices_or_msg_ids = -1
1161 1212
1162 1213 if not isinstance(indices_or_msg_ids, (list,tuple)):
1163 1214 indices_or_msg_ids = [indices_or_msg_ids]
1164 1215
1165 1216 theids = []
1166 1217 for id in indices_or_msg_ids:
1167 1218 if isinstance(id, int):
1168 1219 id = self.history[id]
1169 1220 if not isinstance(id, basestring):
1170 1221 raise TypeError("indices must be str or int, not %r"%id)
1171 1222 theids.append(id)
1172 1223
1173 1224 for msg_id in theids:
1174 1225 self.outstanding.discard(msg_id)
1175 1226 if msg_id in self.history:
1176 1227 self.history.remove(msg_id)
1177 1228 self.results.pop(msg_id, None)
1178 1229 self.metadata.pop(msg_id, None)
1179 1230 content = dict(msg_ids = theids)
1180 1231
1181 1232 self.session.send(self._query_socket, 'resubmit_request', content)
1182 1233
1183 1234 zmq.select([self._query_socket], [], [])
1184 1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1185 1236 if self.debug:
1186 1237 pprint(msg)
1187 1238 content = msg['content']
1188 1239 if content['status'] != 'ok':
1189 1240 raise self._unwrap_exception(content)
1190 1241
1191 1242 ar = AsyncHubResult(self, msg_ids=theids)
1192 1243
1193 1244 if block:
1194 1245 ar.wait()
1195 1246
1196 1247 return ar
1197 1248
1198 1249 @spin_first
1199 1250 def result_status(self, msg_ids, status_only=True):
1200 1251 """Check on the status of the result(s) of the apply request with `msg_ids`.
1201 1252
1202 1253 If status_only is False, then the actual results will be retrieved, else
1203 1254 only the status of the results will be checked.
1204 1255
1205 1256 Parameters
1206 1257 ----------
1207 1258
1208 1259 msg_ids : list of msg_ids
1209 1260 if int:
1210 1261 Passed as index to self.history for convenience.
1211 1262 status_only : bool (default: True)
1212 1263 if False:
1213 1264 Retrieve the actual results of completed tasks.
1214 1265
1215 1266 Returns
1216 1267 -------
1217 1268
1218 1269 results : dict
1219 1270 There will always be the keys 'pending' and 'completed', which will
1220 1271 be lists of msg_ids that are incomplete or complete. If `status_only`
1221 1272 is False, then completed results will be keyed by their `msg_id`.
1222 1273 """
1223 1274 if not isinstance(msg_ids, (list,tuple)):
1224 1275 msg_ids = [msg_ids]
1225 1276
1226 1277 theids = []
1227 1278 for msg_id in msg_ids:
1228 1279 if isinstance(msg_id, int):
1229 1280 msg_id = self.history[msg_id]
1230 1281 if not isinstance(msg_id, basestring):
1231 1282 raise TypeError("msg_ids must be str, not %r"%msg_id)
1232 1283 theids.append(msg_id)
1233 1284
1234 1285 completed = []
1235 1286 local_results = {}
1236 1287
1237 1288 # comment this block out to temporarily disable local shortcut:
1238 1289 for msg_id in theids:
1239 1290 if msg_id in self.results:
1240 1291 completed.append(msg_id)
1241 1292 local_results[msg_id] = self.results[msg_id]
1242 1293 theids.remove(msg_id)
1243 1294
1244 1295 if theids: # some not locally cached
1245 1296 content = dict(msg_ids=theids, status_only=status_only)
1246 1297 msg = self.session.send(self._query_socket, "result_request", content=content)
1247 1298 zmq.select([self._query_socket], [], [])
1248 1299 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1249 1300 if self.debug:
1250 1301 pprint(msg)
1251 1302 content = msg['content']
1252 1303 if content['status'] != 'ok':
1253 1304 raise self._unwrap_exception(content)
1254 1305 buffers = msg['buffers']
1255 1306 else:
1256 1307 content = dict(completed=[],pending=[])
1257 1308
1258 1309 content['completed'].extend(completed)
1259 1310
1260 1311 if status_only:
1261 1312 return content
1262 1313
1263 1314 failures = []
1264 1315 # load cached results into result:
1265 1316 content.update(local_results)
1266 1317
1267 1318 # update cache with results:
1268 1319 for msg_id in sorted(theids):
1269 1320 if msg_id in content['completed']:
1270 1321 rec = content[msg_id]
1271 1322 parent = rec['header']
1272 1323 header = rec['result_header']
1273 1324 rcontent = rec['result_content']
1274 1325 iodict = rec['io']
1275 1326 if isinstance(rcontent, str):
1276 1327 rcontent = self.session.unpack(rcontent)
1277 1328
1278 1329 md = self.metadata[msg_id]
1279 1330 md.update(self._extract_metadata(header, parent, rcontent))
1280 1331 if rec.get('received'):
1281 1332 md['received'] = rec['received']
1282 1333 md.update(iodict)
1283 1334
1284 1335 if rcontent['status'] == 'ok':
1285 1336 res,buffers = util.unserialize_object(buffers)
1286 1337 else:
1287 1338 print rcontent
1288 1339 res = self._unwrap_exception(rcontent)
1289 1340 failures.append(res)
1290 1341
1291 1342 self.results[msg_id] = res
1292 1343 content[msg_id] = res
1293 1344
1294 1345 if len(theids) == 1 and failures:
1295 1346 raise failures[0]
1296 1347
1297 1348 error.collect_exceptions(failures, "result_status")
1298 1349 return content
1299 1350
1300 1351 @spin_first
1301 1352 def queue_status(self, targets='all', verbose=False):
1302 1353 """Fetch the status of engine queues.
1303 1354
1304 1355 Parameters
1305 1356 ----------
1306 1357
1307 1358 targets : int/str/list of ints/strs
1308 1359 the engines whose states are to be queried.
1309 1360 default : all
1310 1361 verbose : bool
1311 1362 Whether to return lengths only, or lists of ids for each element
1312 1363 """
1313 1364 if targets == 'all':
1314 1365 # allow 'all' to be evaluated on the engine
1315 1366 engine_ids = None
1316 1367 else:
1317 1368 engine_ids = self._build_targets(targets)[1]
1318 1369 content = dict(targets=engine_ids, verbose=verbose)
1319 1370 self.session.send(self._query_socket, "queue_request", content=content)
1320 1371 idents,msg = self.session.recv(self._query_socket, 0)
1321 1372 if self.debug:
1322 1373 pprint(msg)
1323 1374 content = msg['content']
1324 1375 status = content.pop('status')
1325 1376 if status != 'ok':
1326 1377 raise self._unwrap_exception(content)
1327 1378 content = rekey(content)
1328 1379 if isinstance(targets, int):
1329 1380 return content[targets]
1330 1381 else:
1331 1382 return content
1332 1383
1333 1384 @spin_first
1334 1385 def purge_results(self, jobs=[], targets=[]):
1335 1386 """Tell the Hub to forget results.
1336 1387
1337 1388 Individual results can be purged by msg_id, or the entire
1338 1389 history of specific targets can be purged.
1339 1390
1340 1391 Use `purge_results('all')` to scrub everything from the Hub's db.
1341 1392
1342 1393 Parameters
1343 1394 ----------
1344 1395
1345 1396 jobs : str or list of str or AsyncResult objects
1346 1397 the msg_ids whose results should be forgotten.
1347 1398 targets : int/str/list of ints/strs
1348 1399 The targets, by int_id, whose entire history is to be purged.
1349 1400
1350 1401 default : None
1351 1402 """
1352 1403 if not targets and not jobs:
1353 1404 raise ValueError("Must specify at least one of `targets` and `jobs`")
1354 1405 if targets:
1355 1406 targets = self._build_targets(targets)[1]
1356 1407
1357 1408 # construct msg_ids from jobs
1358 1409 if jobs == 'all':
1359 1410 msg_ids = jobs
1360 1411 else:
1361 1412 msg_ids = []
1362 1413 if isinstance(jobs, (basestring,AsyncResult)):
1363 1414 jobs = [jobs]
1364 1415 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1365 1416 if bad_ids:
1366 1417 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1367 1418 for j in jobs:
1368 1419 if isinstance(j, AsyncResult):
1369 1420 msg_ids.extend(j.msg_ids)
1370 1421 else:
1371 1422 msg_ids.append(j)
1372 1423
1373 1424 content = dict(engine_ids=targets, msg_ids=msg_ids)
1374 1425 self.session.send(self._query_socket, "purge_request", content=content)
1375 1426 idents, msg = self.session.recv(self._query_socket, 0)
1376 1427 if self.debug:
1377 1428 pprint(msg)
1378 1429 content = msg['content']
1379 1430 if content['status'] != 'ok':
1380 1431 raise self._unwrap_exception(content)
1381 1432
1382 1433 @spin_first
1383 1434 def hub_history(self):
1384 1435 """Get the Hub's history
1385 1436
1386 1437 Just like the Client, the Hub has a history, which is a list of msg_ids.
1387 1438 This will contain the history of all clients, and, depending on configuration,
1388 1439 may contain history across multiple cluster sessions.
1389 1440
1390 1441 Any msg_id returned here is a valid argument to `get_result`.
1391 1442
1392 1443 Returns
1393 1444 -------
1394 1445
1395 1446 msg_ids : list of strs
1396 1447 list of all msg_ids, ordered by task submission time.
1397 1448 """
1398 1449
1399 1450 self.session.send(self._query_socket, "history_request", content={})
1400 1451 idents, msg = self.session.recv(self._query_socket, 0)
1401 1452
1402 1453 if self.debug:
1403 1454 pprint(msg)
1404 1455 content = msg['content']
1405 1456 if content['status'] != 'ok':
1406 1457 raise self._unwrap_exception(content)
1407 1458 else:
1408 1459 return content['history']
1409 1460
1410 1461 @spin_first
1411 1462 def db_query(self, query, keys=None):
1412 1463 """Query the Hub's TaskRecord database
1413 1464
1414 1465 This will return a list of task record dicts that match `query`
1415 1466
1416 1467 Parameters
1417 1468 ----------
1418 1469
1419 1470 query : mongodb query dict
1420 1471 The search dict. See mongodb query docs for details.
1421 1472 keys : list of strs [optional]
1422 1473 The subset of keys to be returned. The default is to fetch everything but buffers.
1423 1474 'msg_id' will *always* be included.
1424 1475 """
1425 1476 if isinstance(keys, basestring):
1426 1477 keys = [keys]
1427 1478 content = dict(query=query, keys=keys)
1428 1479 self.session.send(self._query_socket, "db_request", content=content)
1429 1480 idents, msg = self.session.recv(self._query_socket, 0)
1430 1481 if self.debug:
1431 1482 pprint(msg)
1432 1483 content = msg['content']
1433 1484 if content['status'] != 'ok':
1434 1485 raise self._unwrap_exception(content)
1435 1486
1436 1487 records = content['records']
1437 1488
1438 1489 buffer_lens = content['buffer_lens']
1439 1490 result_buffer_lens = content['result_buffer_lens']
1440 1491 buffers = msg['buffers']
1441 1492 has_bufs = buffer_lens is not None
1442 1493 has_rbufs = result_buffer_lens is not None
1443 1494 for i,rec in enumerate(records):
1444 1495 # relink buffers
1445 1496 if has_bufs:
1446 1497 blen = buffer_lens[i]
1447 1498 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1448 1499 if has_rbufs:
1449 1500 blen = result_buffer_lens[i]
1450 1501 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1451 1502
1452 1503 return records
1453 1504
1454 1505 __all__ = [ 'Client' ]
@@ -1,319 +1,337 b''
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython.parallel.client import client as clientmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult
30 30 from IPython.parallel import LoadBalancedView, DirectView
31 31
32 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33 33
34 34 def setup():
35 35 add_engines(4, total=True)
36 36
37 37 class TestClient(ClusterTestCase):
38 38
39 39 def test_ids(self):
40 40 n = len(self.client.ids)
41 41 self.add_engines(2)
42 42 self.assertEquals(len(self.client.ids), n+2)
43 43
44 44 def test_view_indexing(self):
45 45 """test index access for views"""
46 46 self.minimum_engines(4)
47 47 targets = self.client._build_targets('all')[-1]
48 48 v = self.client[:]
49 49 self.assertEquals(v.targets, targets)
50 50 t = self.client.ids[2]
51 51 v = self.client[t]
52 52 self.assert_(isinstance(v, DirectView))
53 53 self.assertEquals(v.targets, t)
54 54 t = self.client.ids[2:4]
55 55 v = self.client[t]
56 56 self.assert_(isinstance(v, DirectView))
57 57 self.assertEquals(v.targets, t)
58 58 v = self.client[::2]
59 59 self.assert_(isinstance(v, DirectView))
60 60 self.assertEquals(v.targets, targets[::2])
61 61 v = self.client[1::3]
62 62 self.assert_(isinstance(v, DirectView))
63 63 self.assertEquals(v.targets, targets[1::3])
64 64 v = self.client[:-3]
65 65 self.assert_(isinstance(v, DirectView))
66 66 self.assertEquals(v.targets, targets[:-3])
67 67 v = self.client[-1]
68 68 self.assert_(isinstance(v, DirectView))
69 69 self.assertEquals(v.targets, targets[-1])
70 70 self.assertRaises(TypeError, lambda : self.client[None])
71 71
72 72 def test_lbview_targets(self):
73 73 """test load_balanced_view targets"""
74 74 v = self.client.load_balanced_view()
75 75 self.assertEquals(v.targets, None)
76 76 v = self.client.load_balanced_view(-1)
77 77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 78 v = self.client.load_balanced_view('all')
79 79 self.assertEquals(v.targets, None)
80 80
81 81 def test_dview_targets(self):
82 82 """test direct_view targets"""
83 83 v = self.client.direct_view()
84 84 self.assertEquals(v.targets, 'all')
85 85 v = self.client.direct_view('all')
86 86 self.assertEquals(v.targets, 'all')
87 87 v = self.client.direct_view(-1)
88 88 self.assertEquals(v.targets, self.client.ids[-1])
89 89
90 90 def test_lazy_all_targets(self):
91 91 """test lazy evaluation of rc.direct_view('all')"""
92 92 v = self.client.direct_view()
93 93 self.assertEquals(v.targets, 'all')
94 94
95 95 def double(x):
96 96 return x*2
97 97 seq = range(100)
98 98 ref = [ double(x) for x in seq ]
99 99
100 100 # add some engines, which should be used
101 101 self.add_engines(1)
102 102 n1 = len(self.client.ids)
103 103
104 104 # simple apply
105 105 r = v.apply_sync(lambda : 1)
106 106 self.assertEquals(r, [1] * n1)
107 107
108 108 # map goes through remotefunction
109 109 r = v.map_sync(double, seq)
110 110 self.assertEquals(r, ref)
111 111
112 112 # add a couple more engines, and try again
113 113 self.add_engines(2)
114 114 n2 = len(self.client.ids)
115 115 self.assertNotEquals(n2, n1)
116 116
117 117 # apply
118 118 r = v.apply_sync(lambda : 1)
119 119 self.assertEquals(r, [1] * n2)
120 120
121 121 # map
122 122 r = v.map_sync(double, seq)
123 123 self.assertEquals(r, ref)
124 124
125 125 def test_targets(self):
126 126 """test various valid targets arguments"""
127 127 build = self.client._build_targets
128 128 ids = self.client.ids
129 129 idents,targets = build(None)
130 130 self.assertEquals(ids, targets)
131 131
132 132 def test_clear(self):
133 133 """test clear behavior"""
134 134 self.minimum_engines(2)
135 135 v = self.client[:]
136 136 v.block=True
137 137 v.push(dict(a=5))
138 138 v.pull('a')
139 139 id0 = self.client.ids[-1]
140 140 self.client.clear(targets=id0, block=True)
141 141 a = self.client[:-1].get('a')
142 142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 143 self.client.clear(block=True)
144 144 for i in self.client.ids:
145 145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146 146
147 147 def test_get_result(self):
148 148 """test getting results from the Hub."""
149 149 c = clientmod.Client(profile='iptest')
150 150 t = c.ids[-1]
151 151 ar = c[t].apply_async(wait, 1)
152 152 # give the monitor time to notice the message
153 153 time.sleep(.25)
154 154 ahr = self.client.get_result(ar.msg_ids)
155 155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 156 self.assertEquals(ahr.get(), ar.get())
157 157 ar2 = self.client.get_result(ar.msg_ids)
158 158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 159 c.close()
160 160
161 161 def test_ids_list(self):
162 162 """test client.ids"""
163 163 ids = self.client.ids
164 164 self.assertEquals(ids, self.client._ids)
165 165 self.assertFalse(ids is self.client._ids)
166 166 ids.remove(ids[-1])
167 167 self.assertNotEquals(ids, self.client._ids)
168 168
169 169 def test_queue_status(self):
170 170 ids = self.client.ids
171 171 id0 = ids[0]
172 172 qs = self.client.queue_status(targets=id0)
173 173 self.assertTrue(isinstance(qs, dict))
174 174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
175 175 allqs = self.client.queue_status()
176 176 self.assertTrue(isinstance(allqs, dict))
177 177 intkeys = list(allqs.keys())
178 178 intkeys.remove('unassigned')
179 179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
180 180 unassigned = allqs.pop('unassigned')
181 181 for eid,qs in allqs.items():
182 182 self.assertTrue(isinstance(qs, dict))
183 183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
184 184
185 185 def test_shutdown(self):
186 186 ids = self.client.ids
187 187 id0 = ids[0]
188 188 self.client.shutdown(id0, block=True)
189 189 while id0 in self.client.ids:
190 190 time.sleep(0.1)
191 191 self.client.spin()
192 192
193 193 self.assertRaises(IndexError, lambda : self.client[id0])
194 194
195 195 def test_result_status(self):
196 196 pass
197 197 # to be written
198 198
199 199 def test_db_query_dt(self):
200 200 """test db query by date"""
201 201 hist = self.client.hub_history()
202 202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
203 203 tic = middle['submitted']
204 204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
205 205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
206 206 self.assertEquals(len(before)+len(after),len(hist))
207 207 for b in before:
208 208 self.assertTrue(b['submitted'] < tic)
209 209 for a in after:
210 210 self.assertTrue(a['submitted'] >= tic)
211 211 same = self.client.db_query({'submitted' : tic})
212 212 for s in same:
213 213 self.assertTrue(s['submitted'] == tic)
214 214
215 215 def test_db_query_keys(self):
216 216 """test extracting subset of record keys"""
217 217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
218 218 for rec in found:
219 219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
220 220
221 221 def test_db_query_default_keys(self):
222 222 """default db_query excludes buffers"""
223 223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
224 224 for rec in found:
225 225 keys = set(rec.keys())
226 226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
227 227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
228 228
229 229 def test_db_query_msg_id(self):
230 230 """ensure msg_id is always in db queries"""
231 231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
232 232 for rec in found:
233 233 self.assertTrue('msg_id' in rec.keys())
234 234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
235 235 for rec in found:
236 236 self.assertTrue('msg_id' in rec.keys())
237 237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
238 238 for rec in found:
239 239 self.assertTrue('msg_id' in rec.keys())
240 240
241 241 def test_db_query_in(self):
242 242 """test db query with '$in','$nin' operators"""
243 243 hist = self.client.hub_history()
244 244 even = hist[::2]
245 245 odd = hist[1::2]
246 246 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
247 247 found = [ r['msg_id'] for r in recs ]
248 248 self.assertEquals(set(even), set(found))
249 249 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
250 250 found = [ r['msg_id'] for r in recs ]
251 251 self.assertEquals(set(odd), set(found))
252 252
253 253 def test_hub_history(self):
254 254 hist = self.client.hub_history()
255 255 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
256 256 recdict = {}
257 257 for rec in recs:
258 258 recdict[rec['msg_id']] = rec
259 259
260 260 latest = datetime(1984,1,1)
261 261 for msg_id in hist:
262 262 rec = recdict[msg_id]
263 263 newt = rec['submitted']
264 264 self.assertTrue(newt >= latest)
265 265 latest = newt
266 266 ar = self.client[-1].apply_async(lambda : 1)
267 267 ar.get()
268 268 time.sleep(0.25)
269 269 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
270 270
271 271 def test_resubmit(self):
272 272 def f():
273 273 import random
274 274 return random.random()
275 275 v = self.client.load_balanced_view()
276 276 ar = v.apply_async(f)
277 277 r1 = ar.get(1)
278 278 # give the Hub a chance to notice:
279 279 time.sleep(0.5)
280 280 ahr = self.client.resubmit(ar.msg_ids)
281 281 r2 = ahr.get(1)
282 282 self.assertFalse(r1 == r2)
283 283
284 284 def test_resubmit_inflight(self):
285 285 """ensure ValueError on resubmit of inflight task"""
286 286 v = self.client.load_balanced_view()
287 287 ar = v.apply_async(time.sleep,1)
288 288 # give the message a chance to arrive
289 289 time.sleep(0.2)
290 290 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
291 291 ar.get(2)
292 292
293 293 def test_resubmit_badkey(self):
294 294 """ensure KeyError on resubmit of nonexistant task"""
295 295 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
296 296
297 297 def test_purge_results(self):
298 298 # ensure there are some tasks
299 299 for i in range(5):
300 300 self.client[:].apply_sync(lambda : 1)
301 301 # Wait for the Hub to realise the result is done:
302 302 # This prevents a race condition, where we
303 303 # might purge a result the Hub still thinks is pending.
304 304 time.sleep(0.1)
305 305 rc2 = clientmod.Client(profile='iptest')
306 306 hist = self.client.hub_history()
307 307 ahr = rc2.get_result([hist[-1]])
308 308 ahr.wait(10)
309 309 self.client.purge_results(hist[-1])
310 310 newhist = self.client.hub_history()
311 311 self.assertEquals(len(newhist)+1,len(hist))
312 312 rc2.spin()
313 313 rc2.close()
314 314
315 315 def test_purge_all_results(self):
316 316 self.client.purge_results('all')
317 317 hist = self.client.hub_history()
318 318 self.assertEquals(len(hist), 0)
319 319
320 def test_spin_thread(self):
321 self.client.spin_thread(0.01)
322 ar = self.client[-1].apply_async(lambda : 1)
323 time.sleep(0.1)
324 self.assertTrue(ar.wall_time < 0.1,
325 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
326 )
327
328 def test_stop_spin_thread(self):
329 self.client.spin_thread(0.01)
330 self.client.stop_spin_thread()
331 ar = self.client[-1].apply_async(lambda : 1)
332 time.sleep(0.15)
333 self.assertTrue(ar.wall_time > 0.1,
334 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
335 )
336
337
General Comments 0
You need to be logged in to leave comments. Login now