##// END OF EJS Templates
fix & test HubResults from execute requests
MinRK -
Show More
@@ -1,1717 +1,1721
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52
53 53 from .asyncresult import AsyncResult, AsyncHubResult
54 54 from .view import DirectView, LoadBalancedView
55 55
if sys.version_info[0] >= 3:
    # Python 3 removed the `xrange` builtin; alias it to `range` so the
    # isinstance checks below (which accept xrange objects on Python 2)
    # keep working unchanged.
    xrange = range
60 60
61 61 #--------------------------------------------------------------------------
62 62 # Decorators for Client methods
63 63 #--------------------------------------------------------------------------
64 64
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Decorator for Client methods that must see an up-to-date view of the
    cluster (results, registration state) before they run.
    """
    self.spin()
    return f(self, *args, **kwargs)
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Classes
74 74 #--------------------------------------------------------------------------
75 75
76 76
class ExecuteReply(object):
    """wrapper for finished Execute results

    Indexing and attribute access are forwarded to the reply's metadata
    dict; the various ``_repr_*_`` hooks expose the engine's displayhook
    output (pyout) to IPython's rich display machinery.
    """
    def __init__(self, msg_id, content, metadata):
        self.msg_id = msg_id
        self._content = content
        self.execution_count = content['execution_count']
        self.metadata = metadata

    def _pyout_data(self):
        """Return the pyout mimebundle dict, or an empty dict.

        The 'pyout' metadata entry is None until the engine publishes a
        displayhook result, so every display hook must tolerate its absence.
        """
        pyout = self.metadata['pyout'] or {'data':{}}
        return pyout['data']

    def __getitem__(self, key):
        return self.metadata[key]

    def __getattr__(self, key):
        if key not in self.metadata:
            raise AttributeError(key)
        return self.metadata[key]

    def __repr__(self):
        text_out = self._pyout_data().get('text/plain', '')
        if len(text_out) > 32:
            # truncate long reprs to keep the summary one-line
            text_out = text_out[:29] + '...'

        return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)

    def _repr_pretty_(self, p, cycle):
        """IPython pretty-print hook: mimic an ``Out[engine:count]:`` prompt."""
        text_out = self._pyout_data().get('text/plain', '')

        if not text_out:
            return

        try:
            ip = get_ipython()
        except NameError:
            colors = "NoColor"
        else:
            colors = ip.colors

        if colors == "NoColor":
            out = normal = ""
        else:
            out = TermColors.Red
            normal = TermColors.Normal

        if '\n' in text_out and not text_out.startswith('\n'):
            # add newline for multiline reprs
            text_out = '\n' + text_out

        p.text(
            out + u'Out[%i:%i]: ' % (
                self.metadata['engine_id'], self.execution_count
            ) + normal + text_out
        )

    # Rich-display hooks: each returns the matching mimetype entry from
    # pyout, or None if the engine produced no such representation.

    def _repr_html_(self):
        return self._pyout_data().get("text/html")

    def _repr_latex_(self):
        return self._pyout_data().get("text/latex")

    def _repr_json_(self):
        return self._pyout_data().get("application/json")

    def _repr_javascript_(self):
        return self._pyout_data().get("application/javascript")

    def _repr_png_(self):
        return self._pyout_data().get("image/png")

    def _repr_jpeg_(self):
        return self._pyout_data().get("image/jpeg")

    def _repr_svg_(self):
        return self._pyout_data().get("image/svg+xml")
158 158
159 159
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full, fixed key set; __setitem__/__setattr__ reject anything else
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
              'outputs' : [],
              'outputs_ready' : False,
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # direct containment test instead of scanning iterkeys():
        # O(1), and forward-compatible with Python 3 (no iterkeys there)
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
212 212
213 213
214 214 class Client(HasTraits):
215 215 """A semi-synchronous client to the IPython ZMQ cluster
216 216
217 217 Parameters
218 218 ----------
219 219
220 220 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
221 221 Connection information for the Hub's registration. If a json connector
222 222 file is given, then likely no further configuration is necessary.
223 223 [Default: use profile]
224 224 profile : bytes
225 225 The name of the Cluster profile to be used to find connector information.
226 226 If run from an IPython application, the default profile will be the same
227 227 as the running application, otherwise it will be 'default'.
228 228 context : zmq.Context
229 229 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 230 debug : bool
231 231 flag for lots of message printing for debug purposes
232 232 timeout : int/float
233 233 time (in seconds) to wait for connection replies from the Hub
234 234 [Default: 10]
235 235
236 236 #-------------- session related args ----------------
237 237
238 238 config : Config object
239 239 If specified, this will be relayed to the Session for configuration
240 240 username : str
241 241 set username for the session object
242 242 packer : str (import_string) or callable
243 243 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
244 244 function to serialize messages. Must support same input as
245 245 JSON, and output must be bytes.
246 246 You can pass a callable directly as `pack`
247 247 unpacker : str (import_string) or callable
248 248 The inverse of packer. Only necessary if packer is specified as *not* one
249 249 of 'json' or 'pickle'.
250 250
251 251 #-------------- ssh related args ----------------
252 252 # These are args for configuring the ssh tunnel to be used
253 253 # credentials are used to forward connections over ssh to the Controller
254 254 # Note that the ip given in `addr` needs to be relative to sshserver
255 255 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
256 256 # and set sshserver as the same machine the Controller is on. However,
257 257 # the only requirement is that sshserver is able to see the Controller
258 258 # (i.e. is within the same trusted network).
259 259
260 260 sshserver : str
261 261 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
262 262 If keyfile or password is specified, and this is not, it will default to
263 263 the ip given in addr.
264 264 sshkey : str; path to ssh private key file
265 265 This specifies a key to be used in ssh login, default None.
266 266 Regular default ssh keys will be used without specifying this argument.
267 267 password : str
268 268 Your ssh password to sshserver. Note that if this is left None,
269 269 you will be prompted for it if passwordless key based login is unavailable.
270 270 paramiko : bool
271 271 flag for whether to use paramiko instead of shell ssh for tunneling.
272 272 [default: True on win32, False else]
273 273
274 274 ------- exec authentication args -------
275 275 If even localhost is untrusted, you can have some protection against
276 276 unauthorized execution by signing messages with HMAC digests.
277 277 Messages are still sent as cleartext, so if someone can snoop your
278 278 loopback traffic this will not protect your privacy, but will prevent
279 279 unauthorized execution.
280 280
281 281 exec_key : str
282 282 an authentication key or file containing a key
283 283 default: None
284 284
285 285
286 286 Attributes
287 287 ----------
288 288
289 289 ids : list of int engine IDs
290 290 requesting the ids attribute always synchronizes
291 291 the registration state. To request ids without synchronization,
292 292 use semi-private _ids attributes.
293 293
294 294 history : list of msg_ids
295 295 a list of msg_ids, keeping track of all the execution
296 296 messages you have submitted in order.
297 297
298 298 outstanding : set of msg_ids
299 299 a set of msg_ids that have been submitted, but whose
300 300 results have not yet been received.
301 301
302 302 results : dict
303 303 a dict of all our results, keyed by msg_id
304 304
305 305 block : bool
306 306 determines default behavior when block not specified
307 307 in execution methods
308 308
309 309 Methods
310 310 -------
311 311
312 312 spin
313 313 flushes incoming results and registration state changes
314 314 control methods spin, and requesting `ids` also ensures up to date
315 315
316 316 wait
317 317 wait on one or more msg_ids
318 318
319 319 execution methods
320 320 apply
321 321 legacy: execute, run
322 322
323 323 data movement
324 324 push, pull, scatter, gather
325 325
326 326 query methods
327 327 queue_status, get_result, purge, result_status
328 328
329 329 control methods
330 330 abort, shutdown
331 331
332 332 """
333 333
334 334
    # --- public client state -------------------------------------------
    block = Bool(False)     # default blocking behavior for execution methods
    outstanding = Set()     # msg_ids submitted but not yet resolved
    results = Instance('collections.defaultdict', (dict,))       # msg_id -> result
    metadata = Instance('collections.defaultdict', (Metadata,))  # msg_id -> Metadata
    history = List()        # all execution msg_ids, in submission order
    debug = Bool(False)     # if True, pprint every message sent/received
    _spin_thread = Any()    # background spin thread, if any -- presumably set by a spin_thread() method; confirm
    _stop_spinning = Any()  # threading.Event, set in __init__, used to stop the spin thread

    profile=Unicode()
    def _profile_default(self):
        """Dynamic default for the `profile` trait.

        If an IPython application is running, inherit its profile;
        otherwise fall back to u'default'.
        """
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'
356 356
357 357
    # --- private connection state --------------------------------------
    _outstanding_dict = Instance('collections.defaultdict', (set,))  # engine_uuid -> outstanding msg_ids
    _ids = List()              # registered engine ids (unsynchronized; see `ids` in the class docstring)
    _connected=Bool(False)     # set True once _connect() has run
    _ssh=Bool(False)           # whether connections are tunneled over ssh
    _context = Instance('zmq.Context')
    _config = Dict()           # connection info assembled from args + JSON connector file
    _engines=Instance(util.ReverseDict, (), {})   # bidirectional engine id <-> uuid mapping
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')          # DEALER: registration/queries to the Hub
    _control_socket=Instance('zmq.Socket')        # DEALER: control requests (abort, shutdown)
    _iopub_socket=Instance('zmq.Socket')          # SUB: engines' stdout/stderr/display traffic
    _notification_socket=Instance('zmq.Socket')   # SUB: engine (un)registration notifications
    _mux_socket=Instance('zmq.Socket')            # DEALER: direct (multiplexed) execution
    _task_socket=Instance('zmq.Socket')           # DEALER: load-balanced execution
    _task_scheme=Unicode()     # task scheduler scheme; 'pure' means a pure-ZMQ scheduler
    _closed = False
    _ignored_control_replies=Integer(0)   # count of control replies sent but deliberately not awaited
    _ignored_hub_replies=Integer(0)       # count of hub replies sent but deliberately not awaited
376 376
    def __new__(self, *args, **kw):
        # don't raise on positional args: HasTraits.__new__ would reject
        # them, but Client is commonly constructed as Client(url_or_file);
        # __init__ consumes the positionals.
        return HasTraits.__new__(self, **kw)
380 380
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Resolve connection info, build the Session, and connect to the Hub.

        See the class docstring for the meaning of the parameters.
        ``**extra_args`` is relayed to the Session constructor.

        Raises ValueError if neither a connector file nor a profile can be
        found, and IOError if a named connection file does not exist.
        """
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context
        self._stop_spinning = Event()

        # locate the profile dir, and from it the default connector file
        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        if url_or_file is None:
            raise ValueError(
                "I can't find enough information to connect to a hub!"
                " Please specify at least one of url_or_file or profile."
            )

        if not util.is_url(url_or_file):
            # it's not a url, try for a file
            if not os.path.exists(url_or_file):
                # bare filenames are resolved relative to the security dir
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                if not os.path.exists(url_or_file):
                    raise IOError("Connection file not found: %r" % url_or_file)
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # a path: let the Session load the key from the file
                extra_args['keyfile'] = exec_key
            else:
                # a literal key: must be bytes
                exec_key = cast_bytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.DEALER)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

        # last step: setup magics, if we are in IPython:

        try:
            ip = get_ipython()
        except NameError:
            return
        else:
            if 'px' not in ip.magics_manager.magics:
                # in IPython but we are the first Client.
                # activate a default view for parallel magics.
                self.activate()
496 496
    def __del__(self):
        """cleanup sockets, but _not_ context.

        The zmq Context may be the shared global instance
        (zmq.Context.instance()), so its termination is left to its owner.
        """
        self.close()
500 500
501 501 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
502 502 if ipython_dir is None:
503 503 ipython_dir = get_ipython_dir()
504 504 if profile_dir is not None:
505 505 try:
506 506 self._cd = ProfileDir.find_profile_dir(profile_dir)
507 507 return
508 508 except ProfileDirError:
509 509 pass
510 510 elif profile is not None:
511 511 try:
512 512 self._cd = ProfileDir.find_profile_dir_by_name(
513 513 ipython_dir, profile)
514 514 return
515 515 except ProfileDirError:
516 516 pass
517 517 self._cd = None
518 518
    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            # keys may arrive as strings (JSON); normalize to int ids
            eid = int(k)
            self._engines[eid] = v
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        # a pure ZMQ scheduler requires a contiguous 0..n-1 id set; once ids
        # have gaps (an engine died), task farming must be shut down.
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()
529 529
530 530 def _stop_scheduling_tasks(self):
531 531 """Stop scheduling tasks because an engine has been unregistered
532 532 from a pure ZMQ scheduler.
533 533 """
534 534 self._task_socket.close()
535 535 self._task_socket = None
536 536 msg = "An engine has been unregistered, and we are using pure " +\
537 537 "ZMQ task scheduling. Task farming will be disabled."
538 538 if self.outstanding:
539 539 msg += " If you were running tasks when this happened, " +\
540 540 "some `outstanding` msg_ids may never resolve."
541 541 warnings.warn(msg, RuntimeWarning)
542 542
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (uuids, int_ids).

        Parameters
        ----------
        targets : None, 'all', int, slice, or tuple/list/xrange of ints
            None means all registered engines.  A negative int indexes
            from the end of the sorted id list, like normal indexing.

        Returns
        -------
        (uuids, int_ids) : engine identities as bytes (for ZMQ routing),
            and the matching integer engine ids.

        Raises
        ------
        NoEnginesRegistered : if no engines are connected at all.
        IndexError : for an unknown engine id.
        TypeError : for an invalid targets type or string.
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index: resolve against the synchronized id list
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # slice the positions, then map positions back to engine ids
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [cast_bytes(self._engines[t]) for t in targets], list(targets)
575 575
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, then
        connects the mux/task/notification/control/iopub sockets to the
        addresses the Hub returns.  Raises TimeoutError if the Hub does not
        reply within `timeout` seconds.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve relative addresses, and tunnel when using SSH
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # all DEALER sockets share our session identity for routing
            ident = self.session.bsession
            if content.mux:
                self._mux_socket = self._context.socket(zmq.DEALER)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.DEALER)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # subscribe to everything; filtering happens in the handlers
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.DEALER)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.DEALER)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our engine table from the Hub's registration reply
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
638 638
639 639 #--------------------------------------------------------------------------
640 640 # handlers and callbacks for incoming messages
641 641 #--------------------------------------------------------------------------
642 642
643 643 def _unwrap_exception(self, content):
644 644 """unwrap exception, and remap engine_id to int."""
645 645 e = error.unwrap_exception(content)
646 646 # print e.traceback
647 647 if e.engine_info:
648 648 e_uuid = e.engine_info['engine_uuid']
649 649 eid = self._engines[e_uuid]
650 650 e.engine_info['engine_id'] = eid
651 651 return e
652 652
653 653 def _extract_metadata(self, header, parent, content):
654 654 md = {'msg_id' : parent['msg_id'],
655 655 'received' : datetime.now(),
656 656 'engine_uuid' : header.get('engine', None),
657 657 'follow' : parent.get('follow', []),
658 658 'after' : parent.get('after', []),
659 659 'status' : content['status'],
660 660 }
661 661
662 662 if md['engine_uuid'] is not None:
663 663 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
664 664
665 665 if 'date' in parent:
666 666 md['submitted'] = parent['date']
667 667 if 'started' in header:
668 668 md['started'] = header['started']
669 669 if 'date' in header:
670 670 md['completed'] = header['date']
671 671 return md
672 672
673 673 def _register_engine(self, msg):
674 674 """Register a new engine, and update our connection info."""
675 675 content = msg['content']
676 676 eid = content['id']
677 677 d = {eid : content['queue']}
678 678 self._update_engines(d)
679 679
680 680 def _unregister_engine(self, msg):
681 681 """Unregister an engine that has died."""
682 682 content = msg['content']
683 683 eid = int(content['id'])
684 684 if eid in self._ids:
685 685 self._ids.remove(eid)
686 686 uuid = self._engines.pop(eid)
687 687
688 688 self._handle_stranded_msgs(eid, uuid)
689 689
690 690 if self._task_socket and self._task_scheme == 'pure':
691 691 self._stop_scheduling_tasks()
692 692
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.

        Parameters
        ----------
        eid : int
            the dead engine's integer id
        uuid : engine queue identity
            used to look up the engine's outstanding msg_ids
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the real result; don't fabricate a failure
                continue
            try:
                # raise inside try so wrap_exception() can capture a live
                # traceback for the EngineError
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)
719 719
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.

        On success the result is stored as an ExecuteReply wrapper; on
        failure as the remote exception or TaskAborted marker.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # tolerate duplicate/unexpected replies; just note them on stdout
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)

        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        # also clear the per-engine outstanding set
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = ExecuteReply(msg_id, content, md)
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
759 759
760 760 def _handle_apply_reply(self, msg):
761 761 """Save the reply to an apply_request into our results."""
762 762 parent = msg['parent_header']
763 763 msg_id = parent['msg_id']
764 764 if msg_id not in self.outstanding:
765 765 if msg_id in self.history:
766 766 print ("got stale result: %s"%msg_id)
767 767 print self.results[msg_id]
768 768 print msg
769 769 else:
770 770 print ("got unknown result: %s"%msg_id)
771 771 else:
772 772 self.outstanding.remove(msg_id)
773 773 content = msg['content']
774 774 header = msg['header']
775 775
776 776 # construct metadata:
777 777 md = self.metadata[msg_id]
778 778 md.update(self._extract_metadata(header, parent, content))
779 779 # is this redundant?
780 780 self.metadata[msg_id] = md
781 781
782 782 e_outstanding = self._outstanding_dict[md['engine_uuid']]
783 783 if msg_id in e_outstanding:
784 784 e_outstanding.remove(msg_id)
785 785
786 786 # construct result:
787 787 if content['status'] == 'ok':
788 788 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
789 789 elif content['status'] == 'aborted':
790 790 self.results[msg_id] = error.TaskAborted(msg_id)
791 791 elif content['status'] == 'resubmitted':
792 792 # TODO: handle resubmission
793 793 pass
794 794 else:
795 795 self.results[msg_id] = self._unwrap_exception(content)
796 796
797 797 def _flush_notifications(self):
798 798 """Flush notifications of engine registrations waiting
799 799 in ZMQ queue."""
800 800 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
801 801 while msg is not None:
802 802 if self.debug:
803 803 pprint(msg)
804 804 msg_type = msg['header']['msg_type']
805 805 handler = self._notification_handlers.get(msg_type, None)
806 806 if handler is None:
807 807 raise Exception("Unhandled message type: %s"%msg.msg_type)
808 808 else:
809 809 handler(msg)
810 810 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
811 811
812 812 def _flush_results(self, sock):
813 813 """Flush task or queue results waiting in ZMQ queue."""
814 814 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
815 815 while msg is not None:
816 816 if self.debug:
817 817 pprint(msg)
818 818 msg_type = msg['header']['msg_type']
819 819 handler = self._queue_handlers.get(msg_type, None)
820 820 if handler is None:
821 821 raise Exception("Unhandled message type: %s"%msg.msg_type)
822 822 else:
823 823 handler(msg)
824 824 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
825 825
826 826 def _flush_control(self, sock):
827 827 """Flush replies from the control channel waiting
828 828 in the ZMQ queue.
829 829
830 830 Currently: ignore them."""
831 831 if self._ignored_control_replies <= 0:
832 832 return
833 833 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
834 834 while msg is not None:
835 835 self._ignored_control_replies -= 1
836 836 if self.debug:
837 837 pprint(msg)
838 838 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
839 839
    def _flush_ignored_control(self):
        """flush ignored control replies"""
        # blocking recv: each counted reply is known to be in flight, so
        # wait for it.  Decrement only after a successful recv, so an
        # interrupted drain can be resumed.
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
845 845
846 846 def _flush_ignored_hub_replies(self):
847 847 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
848 848 while msg is not None:
849 849 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
850 850
851 851 def _flush_iopub(self, sock):
852 852 """Flush replies from the iopub channel waiting
853 853 in the ZMQ queue.
854 854 """
855 855 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
856 856 while msg is not None:
857 857 if self.debug:
858 858 pprint(msg)
859 859 parent = msg['parent_header']
860 860 # ignore IOPub messages with no parent.
861 861 # Caused by print statements or warnings from before the first execution.
862 862 if not parent:
863 863 continue
864 864 msg_id = parent['msg_id']
865 865 content = msg['content']
866 866 header = msg['header']
867 867 msg_type = msg['header']['msg_type']
868 868
869 869 # init metadata:
870 870 md = self.metadata[msg_id]
871 871
872 872 if msg_type == 'stream':
873 873 name = content['name']
874 874 s = md[name] or ''
875 875 md[name] = s + content['data']
876 876 elif msg_type == 'pyerr':
877 877 md.update({'pyerr' : self._unwrap_exception(content)})
878 878 elif msg_type == 'pyin':
879 879 md.update({'pyin' : content['code']})
880 880 elif msg_type == 'display_data':
881 881 md['outputs'].append(content)
882 882 elif msg_type == 'pyout':
883 883 md['pyout'] = content
884 884 elif msg_type == 'status':
885 885 # idle message comes after all outputs
886 886 if content['execution_state'] == 'idle':
887 887 md['outputs_ready'] = True
888 888 else:
889 889 # unhandled msg_type (status, etc.)
890 890 pass
891 891
892 892 # reduntant?
893 893 self.metadata[msg_id] = md
894 894
895 895 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
896 896
897 897 #--------------------------------------------------------------------------
898 898 # len, getitem
899 899 #--------------------------------------------------------------------------
900 900
    def __len__(self):
        """len(client) returns # of engines."""
        # `ids` is a property that flushes notifications first,
        # so the count is current as of this call
        return len(self.ids)
904 904
905 905 def __getitem__(self, key):
906 906 """index access returns DirectView multiplexer objects
907 907
908 908 Must be int, slice, or list/tuple/xrange of ints"""
909 909 if not isinstance(key, (int, slice, tuple, list, xrange)):
910 910 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
911 911 else:
912 912 return self.direct_view(key)
913 913
914 914 #--------------------------------------------------------------------------
915 915 # Begin public methods
916 916 #--------------------------------------------------------------------------
917 917
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # flush registration/unregistration notifications first so the
        # list reflects engines that joined or left since the last check
        self._flush_notifications()
        # always copy:
        return list(self._ids)
924 924
    def activate(self, targets='all', suffix=''):
        """Create a DirectView and register it with IPython magics

        Defines the magics `%px, %autopx, %pxresult, %%px`

        Parameters
        ----------

        targets: int, list of ints, or 'all'
            The engines on which the view's magics will run
        suffix: str [default: '']
            The suffix, if any, for the magics.  This allows you to have
            multiple views associated with parallel magics at the same time.

            e.g. ``rc.activate(targets=0, suffix='0')`` will give you
            the magics ``%px0``, ``%pxresult0``, etc. for running magics just
            on engine 0.

        Returns
        -------
        view : DirectView
            The view whose magics were registered.
        """
        view = self.direct_view(targets)
        # parallel magics assume blocking execution
        view.block = True
        view.activate(suffix)
        return view
947 947
948 948 def close(self):
949 949 if self._closed:
950 950 return
951 951 self.stop_spin_thread()
952 952 snames = filter(lambda n: n.endswith('socket'), dir(self))
953 953 for socket in map(lambda name: getattr(self, name), snames):
954 954 if isinstance(socket, zmq.Socket) and not socket.closed:
955 955 socket.close()
956 956 self._closed = True
957 957
958 958 def _spin_every(self, interval=1):
959 959 """target func for use in spin_thread"""
960 960 while True:
961 961 if self._stop_spinning.is_set():
962 962 return
963 963 time.sleep(interval)
964 964 self.spin()
965 965
    def spin_thread(self, interval=1):
        """call Client.spin() in a background thread on some regular interval

        This helps ensure that messages don't pile up too much in the zmq queue
        while you are working on other things, or just leaving an idle terminal.

        It also helps limit potential padding of the `received` timestamp
        on AsyncResult objects, used for timings.

        Parameters
        ----------

        interval : float, optional
            The interval on which to spin the client in the background thread
            (simply passed to time.sleep).

        Notes
        -----

        For precision timing, you may want to use this method to put a bound
        on the jitter (in seconds) in `received` timestamps used
        in AsyncResult.wall_time.

        """
        if self._spin_thread is not None:
            # replace any previously running spin thread
            self.stop_spin_thread()
        self._stop_spinning.clear()
        self._spin_thread = Thread(target=self._spin_every, args=(interval,))
        # daemon thread: don't block interpreter exit
        self._spin_thread.daemon = True
        self._spin_thread.start()
996 996
997 997 def stop_spin_thread(self):
998 998 """stop background spin_thread, if any"""
999 999 if self._spin_thread is not None:
1000 1000 self._stop_spinning.set()
1001 1001 self._spin_thread.join()
1002 1002 self._spin_thread = None
1003 1003
    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        # only flush channels whose sockets exist (are truthy)
        if self._notification_socket:
            self._flush_notifications()
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()
1020 1020
1021 1021 def wait(self, jobs=None, timeout=-1):
1022 1022 """waits on one or more `jobs`, for up to `timeout` seconds.
1023 1023
1024 1024 Parameters
1025 1025 ----------
1026 1026
1027 1027 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1028 1028 ints are indices to self.history
1029 1029 strs are msg_ids
1030 1030 default: wait on all outstanding messages
1031 1031 timeout : float
1032 1032 a time in seconds, after which to give up.
1033 1033 default is -1, which means no timeout
1034 1034
1035 1035 Returns
1036 1036 -------
1037 1037
1038 1038 True : when all msg_ids are done
1039 1039 False : timeout reached, some msg_ids still outstanding
1040 1040 """
1041 1041 tic = time.time()
1042 1042 if jobs is None:
1043 1043 theids = self.outstanding
1044 1044 else:
1045 1045 if isinstance(jobs, (int, basestring, AsyncResult)):
1046 1046 jobs = [jobs]
1047 1047 theids = set()
1048 1048 for job in jobs:
1049 1049 if isinstance(job, int):
1050 1050 # index access
1051 1051 job = self.history[job]
1052 1052 elif isinstance(job, AsyncResult):
1053 1053 map(theids.add, job.msg_ids)
1054 1054 continue
1055 1055 theids.add(job)
1056 1056 if not theids.intersection(self.outstanding):
1057 1057 return True
1058 1058 self.spin()
1059 1059 while theids.intersection(self.outstanding):
1060 1060 if timeout >= 0 and ( time.time()-tic ) > timeout:
1061 1061 break
1062 1062 time.sleep(1e-3)
1063 1063 self.spin()
1064 1064 return len(theids.intersection(self.outstanding)) == 0
1065 1065
1066 1066 #--------------------------------------------------------------------------
1067 1067 # Control methods
1068 1068 #--------------------------------------------------------------------------
1069 1069
1070 1070 @spin_first
1071 1071 def clear(self, targets=None, block=None):
1072 1072 """Clear the namespace in target(s)."""
1073 1073 block = self.block if block is None else block
1074 1074 targets = self._build_targets(targets)[0]
1075 1075 for t in targets:
1076 1076 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1077 1077 error = False
1078 1078 if block:
1079 1079 self._flush_ignored_control()
1080 1080 for i in range(len(targets)):
1081 1081 idents,msg = self.session.recv(self._control_socket,0)
1082 1082 if self.debug:
1083 1083 pprint(msg)
1084 1084 if msg['content']['status'] != 'ok':
1085 1085 error = self._unwrap_exception(msg['content'])
1086 1086 else:
1087 1087 self._ignored_control_replies += len(targets)
1088 1088 if error:
1089 1089 raise error
1090 1090
1091 1091
    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted

            If unspecified/None: abort all outstanding jobs.

        """
        block = self.block if block is None else block
        jobs = jobs if jobs is not None else list(self.outstanding)
        targets = self._build_targets(targets)[0]

        # normalize jobs into a flat list of msg_id strings
        msg_ids = []
        if isinstance(jobs, (basestring,AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        # abort requests travel on the control channel, bypassing the queue
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            # one reply per targeted engine
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            # discard replies later instead of waiting now
            self._ignored_control_replies += len(targets)
        if error:
            raise error
1140 1140
    @spin_first
    def shutdown(self, targets='all', restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub.

        Parameters
        ----------

        targets: list of ints or 'all' [default: all]
            Which engines to shutdown.
        hub: bool [default: False]
            Whether to include the Hub.  hub=True implies targets='all'.
        block: bool [default: self.block]
            Whether to wait for clean shutdown replies or not.
        restart: bool [default: False]
            NOT IMPLEMENTED
            whether to restart engines after shutting them down.

        Raises
        ------
        NotImplementedError
            If `restart` is requested.
        """

        if restart:
            raise NotImplementedError("Engine restart is not yet implemented")

        block = self.block if block is None else block
        if hub:
            # shutting down the hub requires shutting down all engines
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        # always wait for engine replies when the hub is going down too,
        # so engines are stopped before the hub disappears
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            # brief pause to let engine shutdowns propagate before
            # asking the hub itself to shut down
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
1192 1192
1193 1193 #--------------------------------------------------------------------------
1194 1194 # Execution related methods
1195 1195 #--------------------------------------------------------------------------
1196 1196
1197 1197 def _maybe_raise(self, result):
1198 1198 """wrapper for maybe raising an exception if apply failed."""
1199 1199 if isinstance(result, error.RemoteError):
1200 1200 raise result
1201 1201
1202 1202 return result
1203 1203
    def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
                            ident=None):
        """construct and send an apply message via a socket.

        This is the principal method with which all engine execution is performed by views.

        Parameters
        ----------
        socket : zmq socket
            The socket on which to send the request.
        f : callable or Reference
            The function to apply remotely.
        args : tuple or list [optional]
            Positional arguments for `f`.
        kwargs : dict [optional]
            Keyword arguments for `f`.
        subheader : dict [optional]
            Extra fields for the message subheader.
        track : bool [default: False]
            Whether to track message delivery (passed to session.send).
        ident : routing identity [optional]
            Identity (or list of identities) for routing to a specific engine.

        Returns
        -------
        msg : dict
            The message that was sent, as returned by session.send.

        Raises
        ------
        RuntimeError
            If the client's sockets have been closed.
        TypeError
            If any argument fails validation.
        """

        if self._closed:
            raise RuntimeError("Client cannot be used after its sockets have been closed")

        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not callable(f) and not isinstance(f, Reference):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s"%type(subheader))

        # serialize f/args/kwargs into message buffers
        bufs = util.pack_apply_message(f,args,kwargs)

        msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
                            subheader=subheader, track=track)

        msg_id = msg['header']['msg_id']
        # track this request until its reply arrives
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1247 1247
    def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
        """construct and send an execute request via a socket.

        Parameters
        ----------
        socket : zmq socket
            The socket on which to send the request.
        code : str
            The code to execute remotely.
        silent : bool [default: True]
            Passed through in the request content as the `silent` flag.
        subheader : dict [optional]
            Extra fields for the message subheader.
        ident : routing identity [optional]
            Identity (or list of identities) for routing to a specific engine.

        Returns
        -------
        msg : dict
            The message that was sent, as returned by session.send.

        Raises
        ------
        RuntimeError
            If the client's sockets have been closed.
        TypeError
            If `code` or `subheader` fails validation.
        """

        if self._closed:
            raise RuntimeError("Client cannot be used after its sockets have been closed")

        # defaults:
        subheader = subheader if subheader is not None else {}

        # validate arguments
        if not isinstance(code, basestring):
            raise TypeError("code must be text, not %s" % type(code))
        if not isinstance(subheader, dict):
            raise TypeError("subheader must be dict, not %s" % type(subheader))

        content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})


        msg = self.session.send(socket, "execute_request", content=content, ident=ident,
                            subheader=subheader)

        msg_id = msg['header']['msg_id']
        # track this request until its reply arrives
        self.outstanding.add(msg_id)
        if ident:
            # possibly routed to a specific engine
            if isinstance(ident, list):
                ident = ident[-1]
            if ident in self._engines.values():
                # save for later, in case of engine death
                self._outstanding_dict[ident].add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()

        return msg
1284 1284
1285 1285 #--------------------------------------------------------------------------
1286 1286 # construct a View object
1287 1287 #--------------------------------------------------------------------------
1288 1288
    def load_balanced_view(self, targets=None):
        """construct a LoadBalancedView object.

        If no arguments are specified, create a LoadBalancedView
        using all engines.

        Parameters
        ----------

        targets: list,slice,int,etc. [default: use all engines]
            The subset of engines across which to load-balance
        """
        # 'all' and None both mean: let the scheduler use every engine
        if targets == 'all':
            targets = None
        if targets is not None:
            targets = self._build_targets(targets)[1]
        return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1306 1306
1307 1307 def direct_view(self, targets='all'):
1308 1308 """construct a DirectView object.
1309 1309
1310 1310 If no targets are specified, create a DirectView using all engines.
1311 1311
1312 1312 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1313 1313 evaluate the target engines at each execution, whereas rc[:] will connect to
1314 1314 all *current* engines, and that list will not change.
1315 1315
1316 1316 That is, 'all' will always use all engines, whereas rc[:] will not use
1317 1317 engines added after the DirectView is constructed.
1318 1318
1319 1319 Parameters
1320 1320 ----------
1321 1321
1322 1322 targets: list,slice,int,etc. [default: use all engines]
1323 1323 The engines to use for the View
1324 1324 """
1325 1325 single = isinstance(targets, int)
1326 1326 # allow 'all' to be lazily evaluated at each execution
1327 1327 if targets != 'all':
1328 1328 targets = self._build_targets(targets)[1]
1329 1329 if single:
1330 1330 targets = targets[0]
1331 1331 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1332 1332
1333 1333 #--------------------------------------------------------------------------
1334 1334 # Query methods
1335 1335 #--------------------------------------------------------------------------
1336 1336
    @spin_first
    def get_result(self, indices_or_msg_ids=None, block=None):
        """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.

        If the client already has the results, no request to the Hub will be made.

        This is a convenient way to construct AsyncResult objects, which are wrappers
        that include metadata about execution, and allow for awaiting results that
        were not submitted by this Client.

        It can also be a convenient way to retrieve the metadata associated with
        blocking execution, since it always retrieves the full record.

        Examples
        --------
        ::

            In [10]: r = client.apply()

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncResult
            A single AsyncResult object will always be returned.

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default: most recent submission
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # resolve history indices to msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        # any id we didn't submit / haven't cached must come from the Hub
        local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
        remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)

        if remote_ids:
            ar = AsyncHubResult(self, msg_ids=theids)
        else:
            ar = AsyncResult(self, msg_ids=theids)

        if block:
            ar.wait()

        return ar
1402 1402
    @spin_first
    def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
        """Resubmit one or more tasks.

        in-flight tasks may not be resubmitted.

        Parameters
        ----------

        indices_or_msg_ids : integer history index, str msg_id, or list of either
            The indices or msg_ids of indices to be retrieved

        block : bool
            Whether to wait for the result to be done

        Returns
        -------

        AsyncHubResult
            A subclass of AsyncResult that retrieves results from the Hub

        """
        block = self.block if block is None else block
        if indices_or_msg_ids is None:
            # default: most recent submission
            indices_or_msg_ids = -1

        if not isinstance(indices_or_msg_ids, (list,tuple)):
            indices_or_msg_ids = [indices_or_msg_ids]

        # resolve history indices to msg_id strings
        theids = []
        for id in indices_or_msg_ids:
            if isinstance(id, int):
                id = self.history[id]
            if not isinstance(id, basestring):
                raise TypeError("indices must be str or int, not %r"%id)
            theids.append(id)

        content = dict(msg_ids = theids)

        self.session.send(self._query_socket, 'resubmit_request', content)

        # block until the Hub's reply is available, then receive it
        zmq.select([self._query_socket], [], [])
        idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
        # the Hub assigns fresh msg_ids to the resubmitted tasks
        mapping = content['resubmitted']
        new_ids = [ mapping[msg_id] for msg_id in theids ]

        ar = AsyncHubResult(self, msg_ids=new_ids)

        if block:
            ar.wait()

        return ar
1460 1460
1461 1461 @spin_first
1462 1462 def result_status(self, msg_ids, status_only=True):
1463 1463 """Check on the status of the result(s) of the apply request with `msg_ids`.
1464 1464
1465 1465 If status_only is False, then the actual results will be retrieved, else
1466 1466 only the status of the results will be checked.
1467 1467
1468 1468 Parameters
1469 1469 ----------
1470 1470
1471 1471 msg_ids : list of msg_ids
1472 1472 if int:
1473 1473 Passed as index to self.history for convenience.
1474 1474 status_only : bool (default: True)
1475 1475 if False:
1476 1476 Retrieve the actual results of completed tasks.
1477 1477
1478 1478 Returns
1479 1479 -------
1480 1480
1481 1481 results : dict
1482 1482 There will always be the keys 'pending' and 'completed', which will
1483 1483 be lists of msg_ids that are incomplete or complete. If `status_only`
1484 1484 is False, then completed results will be keyed by their `msg_id`.
1485 1485 """
1486 1486 if not isinstance(msg_ids, (list,tuple)):
1487 1487 msg_ids = [msg_ids]
1488 1488
1489 1489 theids = []
1490 1490 for msg_id in msg_ids:
1491 1491 if isinstance(msg_id, int):
1492 1492 msg_id = self.history[msg_id]
1493 1493 if not isinstance(msg_id, basestring):
1494 1494 raise TypeError("msg_ids must be str, not %r"%msg_id)
1495 1495 theids.append(msg_id)
1496 1496
1497 1497 completed = []
1498 1498 local_results = {}
1499 1499
1500 1500 # comment this block out to temporarily disable local shortcut:
1501 1501 for msg_id in theids:
1502 1502 if msg_id in self.results:
1503 1503 completed.append(msg_id)
1504 1504 local_results[msg_id] = self.results[msg_id]
1505 1505 theids.remove(msg_id)
1506 1506
1507 1507 if theids: # some not locally cached
1508 1508 content = dict(msg_ids=theids, status_only=status_only)
1509 1509 msg = self.session.send(self._query_socket, "result_request", content=content)
1510 1510 zmq.select([self._query_socket], [], [])
1511 1511 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1512 1512 if self.debug:
1513 1513 pprint(msg)
1514 1514 content = msg['content']
1515 1515 if content['status'] != 'ok':
1516 1516 raise self._unwrap_exception(content)
1517 1517 buffers = msg['buffers']
1518 1518 else:
1519 1519 content = dict(completed=[],pending=[])
1520 1520
1521 1521 content['completed'].extend(completed)
1522 1522
1523 1523 if status_only:
1524 1524 return content
1525 1525
1526 1526 failures = []
1527 1527 # load cached results into result:
1528 1528 content.update(local_results)
1529 1529
1530 1530 # update cache with results:
1531 1531 for msg_id in sorted(theids):
1532 1532 if msg_id in content['completed']:
1533 1533 rec = content[msg_id]
1534 1534 parent = rec['header']
1535 1535 header = rec['result_header']
1536 1536 rcontent = rec['result_content']
1537 1537 iodict = rec['io']
1538 1538 if isinstance(rcontent, str):
1539 1539 rcontent = self.session.unpack(rcontent)
1540 1540
1541 1541 md = self.metadata[msg_id]
1542 1542 md.update(self._extract_metadata(header, parent, rcontent))
1543 1543 if rec.get('received'):
1544 1544 md['received'] = rec['received']
1545 1545 md.update(iodict)
1546 1546
1547 1547 if rcontent['status'] == 'ok':
1548 if header['msg_type'] == 'apply_reply':
1548 1549 res,buffers = util.unserialize_object(buffers)
1550 elif header['msg_type'] == 'execute_reply':
1551 res = ExecuteReply(msg_id, rcontent, md)
1552 else:
1553 raise KeyError("unhandled msg type: %r" % header[msg_type])
1549 1554 else:
1550 print rcontent
1551 1555 res = self._unwrap_exception(rcontent)
1552 1556 failures.append(res)
1553 1557
1554 1558 self.results[msg_id] = res
1555 1559 content[msg_id] = res
1556 1560
1557 1561 if len(theids) == 1 and failures:
1558 1562 raise failures[0]
1559 1563
1560 1564 error.collect_exceptions(failures, "result_status")
1561 1565 return content
1562 1566
    @spin_first
    def queue_status(self, targets='all', verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------

        targets : int/str/list of ints/strs
            the engines whose states are to be queried.
            default : all
        verbose : bool
            Whether to return lengths only, or lists of ids for each element

        Returns
        -------
        content : dict
            Status keyed by engine id, or the single engine's status
            when `targets` is an int.
        """
        if targets == 'all':
            # allow 'all' to be evaluated on the engine
            engine_ids = None
        else:
            engine_ids = self._build_targets(targets)[1]
        content = dict(targets=engine_ids, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        # blocking recv of the Hub's reply
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise self._unwrap_exception(content)
        # rekey: convert engine-id keys back from strings
        content = rekey(content)
        if isinstance(targets, int):
            return content[targets]
        else:
            return content
1595 1599
    @spin_first
    def purge_results(self, jobs=[], targets=[]):
        """Tell the Hub to forget results.

        Individual results can be purged by msg_id, or the entire
        history of specific targets can be purged.

        Use `purge_results('all')` to scrub everything from the Hub's db.

        Parameters
        ----------

        jobs : str or list of str or AsyncResult objects
            the msg_ids whose results should be forgotten.
        targets : int/str/list of ints/strs
            The targets, by int_id, whose entire history is to be purged.

            default : None

        Raises
        ------
        ValueError
            If neither `jobs` nor `targets` is given.
        """
        if not targets and not jobs:
            raise ValueError("Must specify at least one of `targets` and `jobs`")
        if targets:
            targets = self._build_targets(targets)[1]

        # construct msg_ids from jobs
        if jobs == 'all':
            # pass the literal 'all' through to the Hub
            msg_ids = jobs
        else:
            msg_ids = []
            if isinstance(jobs, (basestring,AsyncResult)):
                jobs = [jobs]
            bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
            if bad_ids:
                raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
            for j in jobs:
                if isinstance(j, AsyncResult):
                    msg_ids.extend(j.msg_ids)
                else:
                    msg_ids.append(j)

        content = dict(engine_ids=targets, msg_ids=msg_ids)
        self.session.send(self._query_socket, "purge_request", content=content)
        # blocking recv of the Hub's acknowledgement
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)
1644 1648
1645 1649 @spin_first
1646 1650 def hub_history(self):
1647 1651 """Get the Hub's history
1648 1652
1649 1653 Just like the Client, the Hub has a history, which is a list of msg_ids.
1650 1654 This will contain the history of all clients, and, depending on configuration,
1651 1655 may contain history across multiple cluster sessions.
1652 1656
1653 1657 Any msg_id returned here is a valid argument to `get_result`.
1654 1658
1655 1659 Returns
1656 1660 -------
1657 1661
1658 1662 msg_ids : list of strs
1659 1663 list of all msg_ids, ordered by task submission time.
1660 1664 """
1661 1665
1662 1666 self.session.send(self._query_socket, "history_request", content={})
1663 1667 idents, msg = self.session.recv(self._query_socket, 0)
1664 1668
1665 1669 if self.debug:
1666 1670 pprint(msg)
1667 1671 content = msg['content']
1668 1672 if content['status'] != 'ok':
1669 1673 raise self._unwrap_exception(content)
1670 1674 else:
1671 1675 return content['history']
1672 1676
    @spin_first
    def db_query(self, query, keys=None):
        """Query the Hub's TaskRecord database

        This will return a list of task record dicts that match `query`

        Parameters
        ----------

        query : mongodb query dict
            The search dict. See mongodb query docs for details.
        keys : list of strs [optional]
            The subset of keys to be returned.  The default is to fetch everything but buffers.
            'msg_id' will *always* be included.

        Returns
        -------
        records : list of dicts
            Matching task records, with any buffers re-attached.
        """
        # accept a single key as shorthand for a one-element list
        if isinstance(keys, basestring):
            keys = [keys]
        content = dict(query=query, keys=keys)
        self.session.send(self._query_socket, "db_request", content=content)
        # blocking recv of the Hub's reply
        idents, msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        if content['status'] != 'ok':
            raise self._unwrap_exception(content)

        records = content['records']

        # buffers for all records arrive flattened in msg['buffers'];
        # the *_lens lists say how many belong to each record
        buffer_lens = content['buffer_lens']
        result_buffer_lens = content['result_buffer_lens']
        buffers = msg['buffers']
        has_bufs = buffer_lens is not None
        has_rbufs = result_buffer_lens is not None
        for i,rec in enumerate(records):
            # relink buffers
            if has_bufs:
                blen = buffer_lens[i]
                rec['buffers'], buffers = buffers[:blen],buffers[blen:]
            if has_rbufs:
                blen = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]

        return records
1716 1720
1717 1721 __all__ = [ 'Client' ]
@@ -1,1314 +1,1321
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.py3compat import cast_bytes
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
def empty_record():
    """Return an empty dict with all record keys."""
    # every field defaults to None except the stream accumulators,
    # which start as empty strings so they can be appended to
    record = dict.fromkeys((
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'received',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    record['stdout'] = ''
    record['stderr'] = ''
    return record
78 78
def init_record(msg):
    """Initialize a TaskRecord based on a request."""
    header = msg['header']
    # start with every field blank, then fill in what the request provides
    record = dict.fromkeys((
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'received',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    record.update(
        msg_id=header['msg_id'],
        header=header,
        content=msg['content'],
        buffers=msg['buffers'],
        submitted=header['date'],
        stdout='',
        stderr='',
    )
    return record
104 104
105 105
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an engine.

    Attributes are:
    id (int): engine ID assigned by the Hub
    queue (bytes): zmq identity of the engine's queue socket
    control (bytes): zmq identity of the engine's control socket
    registration (bytes): zmq identity of the registration socket
    heartbeat (bytes): zmq identity of the heartbeat socket
    pending (set): pending msg_ids for this engine

    NOTE(review): an earlier docstring listed a `uuid` attribute, but no such
    trait is defined here.
    """
    id=Integer(0)          # integer engine id
    queue=CBytes()         # identity of queue socket
    control=CBytes()       # identity of control socket
    registration=CBytes()  # identity of registration socket
    heartbeat=CBytes()     # identity of heartbeat socket
    pending=Set()          # outstanding msg_ids
121 121
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the (mostly randomly-selected) port configuration for all Hub
    sockets, and builds the Hub and its zmq streams in `init_hub`.
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Integer,Integer,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        return tuple(util.select_random_ports(2))

    mux = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Integer(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Integer(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived from monitor_transport/ip/port; kept in sync by the
    # _*_changed handlers below
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # setting `ip` fans out to all three specific IPs
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # setting `transport` fans out to all three specific transports
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        self.init_hub()

    def start(self):
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """Construct the Hub: bind all sockets and build the Hub object.

        Does not start the event loop.
        """
        # interface templates with a %i placeholder for the port
        client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
        if self.client_ip != self.engine_ip:
            # bind a second time on the engine interface when it differs
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.ROUTER)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribes to everything
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # fall back to the TaskScheduler class default when unconfigured
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)

        # resubmit stream: lets the Hub re-inject tasks into the scheduler
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.bsession)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
301 301
302 302
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
        NotImplemented
    engine_info: dict of zmq connection information for engines to connect
        to the queues.
    client_info: dict of zmq connection information for clients to connect
        to the queues.
    """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict() # maps engine id -> queue identity (bytes)
    by_ident=Dict() # maps queue identity (bytes) -> engine id
    engines=Dict() # maps engine id -> EngineConnector
    clients=Dict()
    hearts=Dict() # maps heart identity -> engine id
    pending=Set() # all in-flight msg_ids
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # flat set of every completed msg_id
    dead_engines=Set() # queue identities of engines that have been unregistered
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict()
    registration_timeout=Integer()
    _idcounter=Integer(0)

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
350 350
    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least two heartbeat periods to finish registering
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # task entry is a (scheme, url) pair; only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # routes monitor messages by their (bytes) topic prefix
        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
                                }

        # routes client queries by msg_type
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
412 412
413 413 @property
414 414 def _next_id(self):
415 415 """gemerate a new ID.
416 416
417 417 No longer reuse old ids, just count from 0."""
418 418 newid = self._idcounter
419 419 self._idcounter += 1
420 420 return newid
421 421 # newid = 0
422 422 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
423 423 # # print newid, self.ids, self.incoming_registrations
424 424 # while newid in self.ids or newid in incoming:
425 425 # newid += 1
426 426 # return newid
427 427
428 428 #-----------------------------------------------------------------------------
429 429 # message validation
430 430 #-----------------------------------------------------------------------------
431 431
432 432 def _validate_targets(self, targets):
433 433 """turn any valid targets argument into a list of integer ids"""
434 434 if targets is None:
435 435 # default to all
436 436 return self.ids
437 437
438 438 if isinstance(targets, (int,str,unicode)):
439 439 # only one target specified
440 440 targets = [targets]
441 441 _targets = []
442 442 for t in targets:
443 443 # map raw identities to ids
444 444 if isinstance(t, (str,unicode)):
445 445 t = self.by_ident.get(cast_bytes(t), t)
446 446 _targets.append(t)
447 447 targets = _targets
448 448 bad_targets = [ t for t in targets if t not in self.ids ]
449 449 if bad_targets:
450 450 raise IndexError("No Such Engine: %r" % bad_targets)
451 451 if not targets:
452 452 raise IndexError("No Engines Registered")
453 453 return targets
454 454
455 455 #-----------------------------------------------------------------------------
456 456 # dispatch methods (1 per stream)
457 457 #-----------------------------------------------------------------------------
458 458
459 459
460 460 @util.log_errors
461 461 def dispatch_monitor_traffic(self, msg):
462 462 """all ME and Task queue messages come through here, as well as
463 463 IOPub traffic."""
464 464 self.log.debug("monitor traffic: %r", msg[0])
465 465 switch = msg[0]
466 466 try:
467 467 idents, msg = self.session.feed_identities(msg[1:])
468 468 except ValueError:
469 469 idents=[]
470 470 if not idents:
471 471 self.log.error("Monitor message without topic: %r", msg)
472 472 return
473 473 handler = self.monitor_handlers.get(switch, None)
474 474 if handler is not None:
475 475 handler(idents, msg)
476 476 else:
477 477 self.log.error("Unrecognized monitor topic: %r", switch)
478 478
479 479
    @util.log_errors
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r", msg)
            return
        # the first identity is the client, used to address the reply
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            # report the failure back to the client rather than dropping it
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r", msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['header']['msg_type']
        self.log.info("client::client %r requested %r", client_id, msg_type)
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %r" % msg_type
        except:
            # unknown msg_type: wrap the assertion error into a hub_error reply
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            handler(idents, msg)
515 515
516 516 def dispatch_db(self, msg):
517 517 """"""
518 518 raise NotImplementedError
519 519
520 520 #---------------------------------------------------------------------------
521 521 # handler methods (1 per event)
522 522 #---------------------------------------------------------------------------
523 523
524 524 #----------------------- Heartbeat --------------------------------------
525 525
526 526 def handle_new_heart(self, heart):
527 527 """handler to attach to heartbeater.
528 528 Called when a new heart starts to beat.
529 529 Triggers completion of registration."""
530 530 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
531 531 if heart not in self.incoming_registrations:
532 532 self.log.info("heartbeat::ignoring new heart: %r", heart)
533 533 else:
534 534 self.finish_registration(heart)
535 535
536 536
537 537 def handle_heart_failure(self, heart):
538 538 """handler to attach to heartbeater.
539 539 called when a previously registered heart fails to respond to beat request.
540 540 triggers unregistration"""
541 541 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
542 542 eid = self.hearts.get(heart, None)
543 543 queue = self.engines[eid].queue
544 544 if eid is None or self.keytable[eid] in self.dead_engines:
545 545 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
546 546 else:
547 547 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
548 548
549 549 #----------------------- MUX Queue Traffic ------------------------------
550 550
    def save_queue_request(self, idents, msg):
        """Save a request submitted to the MUX queue into the TaskRecord db."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        # MUX traffic is prefixed with (engine queue identity, client identity)
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # merge: keep existing values, warn on conflicts, fill gaps
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no existing record: this is the first we hear of msg_id
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
597 597
    def save_queue_result(self, idents, msg):
        """Save the reply to a MUX-queue request into the TaskRecord db."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        # reply direction: (client identity, engine queue identity)
        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # cannot associate a reply with a request without a parent
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            # normal completion: move msg_id from pending to completed
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'received': datetime.now(),
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
648 648
649 649
650 650 #--------------------- Task Queue Traffic ------------------------------
651 651
    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        # no engine assigned yet; save_task_destination removes it later
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

            # merge: keep existing values, warn on conflicts, fill gaps
            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no existing record: first sighting of this msg_id
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
702 702
    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::invalid task result message send to %r: %r",
                    client_id, msg, exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!", msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        # the scheduler stamps the executing engine's uuid into the reply header
        engine_uuid = header.get('engine', u'')
        eid = self.by_ident.get(cast_bytes(engine_uuid), None)

        status = header.get('status', None)

        if msg_id in self.pending:
            self.log.info("task::task %r finished on %s", msg_id, eid)
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                # aborted tasks never ran, so don't count as completed work
                if status != 'aborted':
                    self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'received' : datetime.now(),
                'engine_uuid': engine_uuid,
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished", msg_id)
756 756
757 757 def save_task_destination(self, idents, msg):
758 758 try:
759 759 msg = self.session.unserialize(msg, content=True)
760 760 except Exception:
761 761 self.log.error("task::invalid task tracking message", exc_info=True)
762 762 return
763 763 content = msg['content']
764 764 # print (content)
765 765 msg_id = content['msg_id']
766 766 engine_uuid = content['engine_id']
767 767 eid = self.by_ident[cast_bytes(engine_uuid)]
768 768
769 769 self.log.info("task::task %r arrived on %r", msg_id, eid)
770 770 if msg_id in self.unassigned:
771 771 self.unassigned.remove(msg_id)
772 772 # else:
773 773 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
774 774
775 775 self.tasks[eid].append(msg_id)
776 776 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
777 777 try:
778 778 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
779 779 except Exception:
780 780 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
781 781
782 782
783 783 def mia_task_request(self, idents, msg):
784 784 raise NotImplementedError
785 785 client_id = idents[0]
786 786 # content = dict(mia=self.mia,status='ok')
787 787 # self.session.send('mia_reply', content=content, idents=client_id)
788 788
789 789
790 790 #--------------------- IOPub Traffic ------------------------------
791 791
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.warn("iopub::IOPub message lacks parent: %r", msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            # iopub can arrive before the request itself; start a blank record
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append to the accumulated stdout/stderr already in the record
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        elif msg_type in ('display_data', 'pyout'):
            # store the full content dict for rich output
            d[msg_type] = content
        elif msg_type == 'status':
            # busy/idle status updates are not recorded
            pass
        else:
            self.log.warn("unhandled iopub msg_type: %r", msg_type)

        if not d:
            # nothing to persist
            return

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
834 841
835 842
836 843
837 844 #-------------------------------------------------------------------------
838 845 # Registration requests
839 846 #-------------------------------------------------------------------------
840 847
841 848 def connection_request(self, client_id, msg):
842 849 """Reply with connection addresses for clients."""
843 850 self.log.info("client::client %r connected", client_id)
844 851 content = dict(status='ok')
845 852 content.update(self.client_info)
846 853 jsonable = {}
847 854 for k,v in self.keytable.iteritems():
848 855 if v not in self.dead_engines:
849 856 jsonable[str(k)] = v.decode('ascii')
850 857 content['engines'] = jsonable
851 858 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
852 859
    def register_engine(self, reg, msg):
        """Register a new engine.

        Validates that the requested queue id and heartbeat id are unused,
        replies with a 'registration_reply' (ok + engine id, or a wrapped
        error), and stages the registration until the engine's heart is
        confirmed beating.
        """
        content = msg['content']
        try:
            queue = cast_bytes(content['queue'])
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        if heart:
            heart = cast_bytes(heart)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            # queue id already taken by a registered engine
            try:
                raise KeyError("queue_id %r in use" % queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use", queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use" % heart)
            except:
                self.log.error("heart_id %r in use", heart, exc_info=True)
                content = error.wrap_exception()
        else:
            # also reject collisions with registrations still in flight
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use" % heart)
                    except:
                        self.log.error("heart_id %r in use", heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use" % queue)
                    except:
                        self.log.error("queue_id %r in use", queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating: finish registration immediately
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                # stage the registration; purge it if the heart never beats
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
        return eid
919 926
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave.

        Marks the engine dead, schedules cleanup of any messages stranded on
        it, and notifies clients via an 'unregistration_notification'.
        """
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)", eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid.decode('ascii'))
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        # delay stranded-message handling, in case results are still in flight
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
946 953
947 954 def _handle_stranded_msgs(self, eid, uuid):
948 955 """Handle messages known to be on an engine when the engine unregisters.
949 956
950 957 It is possible that this will fire prematurely - that is, an engine will
951 958 go down after completing a result, and the client will be notified
952 959 that the result failed and later receive the actual result.
953 960 """
954 961
955 962 outstanding = self.queues[eid]
956 963
957 964 for msg_id in outstanding:
958 965 self.pending.remove(msg_id)
959 966 self.all_completed.add(msg_id)
960 967 try:
961 968 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
962 969 except:
963 970 content = error.wrap_exception()
964 971 # build a fake header:
965 972 header = {}
966 973 header['engine'] = uuid
967 974 header['date'] = datetime.now()
968 975 rec = dict(result_content=content, result_header=header, result_buffers=[])
969 976 rec['completed'] = header['date']
970 977 rec['engine_uuid'] = uuid
971 978 try:
972 979 self.db.update_record(msg_id, rec)
973 980 except Exception:
974 981 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
975 982
976 983
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart.

        Promotes a staged registration into a live EngineConnector, cancels
        the stalled-registration purge timer, and notifies clients.
        """
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r", eid, queue)
        if purge is not None:
            # cancel the pending stalled-registration purge
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        # per-engine bookkeeping for pending/completed work
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i", eid)
1002 1009
1003 1010 def _purge_stalled_registration(self, heart):
1004 1011 if heart in self.incoming_registrations:
1005 1012 eid = self.incoming_registrations.pop(heart)[0]
1006 1013 self.log.info("registration::purging stalled registration: %i", eid)
1007 1014 else:
1008 1015 pass
1009 1016
1010 1017 #-------------------------------------------------------------------------
1011 1018 # Client Requests
1012 1019 #-------------------------------------------------------------------------
1013 1020
1014 1021 def shutdown_request(self, client_id, msg):
1015 1022 """handle shutdown request."""
1016 1023 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1017 1024 # also notify other clients of shutdown
1018 1025 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1019 1026 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1020 1027 dc.start()
1021 1028
    def _shutdown(self):
        """Actually terminate the Hub process.

        Scheduled via a DelayedCallback from shutdown_request so the
        shutdown replies can go out first.
        """
        self.log.info("hub::hub shutting down.")
        # brief pause to let pending zmq sends flush before exiting
        time.sleep(0.1)
        sys.exit(0)
1026 1033
1027 1034
1028 1035 def check_load(self, client_id, msg):
1029 1036 content = msg['content']
1030 1037 try:
1031 1038 targets = content['targets']
1032 1039 targets = self._validate_targets(targets)
1033 1040 except:
1034 1041 content = error.wrap_exception()
1035 1042 self.session.send(self.query, "hub_error",
1036 1043 content=content, ident=client_id)
1037 1044 return
1038 1045
1039 1046 content = dict(status='ok')
1040 1047 # loads = {}
1041 1048 for t in targets:
1042 1049 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1043 1050 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1044 1051
1045 1052
1046 1053 def queue_status(self, client_id, msg):
1047 1054 """Return the Queue status of one or more targets.
1048 1055 if verbose: return the msg_ids
1049 1056 else: return len of each type.
1050 1057 keys: queue (pending MUX jobs)
1051 1058 tasks (pending Task jobs)
1052 1059 completed (finished jobs from both queues)"""
1053 1060 content = msg['content']
1054 1061 targets = content['targets']
1055 1062 try:
1056 1063 targets = self._validate_targets(targets)
1057 1064 except:
1058 1065 content = error.wrap_exception()
1059 1066 self.session.send(self.query, "hub_error",
1060 1067 content=content, ident=client_id)
1061 1068 return
1062 1069 verbose = content.get('verbose', False)
1063 1070 content = dict(status='ok')
1064 1071 for t in targets:
1065 1072 queue = self.queues[t]
1066 1073 completed = self.completed[t]
1067 1074 tasks = self.tasks[t]
1068 1075 if not verbose:
1069 1076 queue = len(queue)
1070 1077 completed = len(completed)
1071 1078 tasks = len(tasks)
1072 1079 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1073 1080 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1074 1081 # print (content)
1075 1082 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1076 1083
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism.

        Drops completed records by msg_id ('all' for everything completed)
        and/or by engine id, refusing to purge still-pending messages.
        Replies with 'purge_reply' (ok or a wrapped error).
        """
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            # refuse to purge anything that is still pending
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            # optionally also purge all completed records per engine
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].queue
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1119 1126
    def resubmit_task(self, client_id, msg):
        """Resubmit one or more tasks.

        Each requested task is re-sent under a *new* msg_id (with a fresh
        timestamp but otherwise the original header), and the mapping of
        old->new ids is returned in the 'resubmit_reply'.
        """
        def finish(reply):
            self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

        content = msg['content']
        msg_ids = content['msg_ids']
        reply = dict(status='ok')
        try:
            records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
                'header', 'content', 'buffers'])
        except Exception:
            self.log.error('db::db error finding tasks to resubmit', exc_info=True)
            return finish(error.wrap_exception())

        # validate msg_ids
        found_ids = [ rec['msg_id'] for rec in records ]
        pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
        if len(records) > len(msg_ids):
            # duplicate records for one id: the DB is corrupt
            try:
                raise RuntimeError("DB appears to be in an inconsistent state."
                    "More matching records were found than should exist")
            except Exception:
                return finish(error.wrap_exception())
        elif len(records) < len(msg_ids):
            missing = [ m for m in msg_ids if m not in found_ids ]
            try:
                raise KeyError("No such msg(s): %r" % missing)
            except KeyError:
                return finish(error.wrap_exception())
        elif pending_ids:
            pass
            # no need to raise on resubmit of pending task, now that we
            # resubmit under new ID, but do we want to raise anyway?
            # msg_id = invalid_ids[0]
            # try:
            #     raise ValueError("Task(s) %r appears to be inflight" % )
            # except Exception:
            #     return finish(error.wrap_exception())

        # mapping of original IDs to resubmitted IDs
        resubmitted = {}

        # send the messages
        for rec in records:
            header = rec['header']
            msg = self.session.msg(header['msg_type'], parent=header)
            msg_id = msg['msg_id']
            msg['content'] = rec['content']

            # use the old header, but update msg_id and timestamp
            fresh = msg['header']
            header['msg_id'] = fresh['msg_id']
            header['date'] = fresh['date']
            msg['header'] = header

            self.session.send(self.resubmit, msg, buffers=rec['buffers'])

            resubmitted[rec['msg_id']] = msg_id
            self.pending.add(msg_id)
            msg['buffers'] = rec['buffers']
            try:
                self.db.add_record(msg_id, init_record(msg))
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)

        # reply to the client before the (possibly slow) DB bookkeeping below
        finish(dict(status='ok', resubmitted=resubmitted))

        # store the new IDs in the Task DB
        for msg_id, resubmit_id in resubmitted.iteritems():
            try:
                self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1194 1201
1195 1202
1196 1203 def _extract_record(self, rec):
1197 1204 """decompose a TaskRecord dict into subsection of reply for get_result"""
1198 1205 io_dict = {}
1199 1206 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1200 1207 io_dict[key] = rec[key]
1201 1208 content = { 'result_content': rec['result_content'],
1202 1209 'header': rec['header'],
1203 1210 'result_header' : rec['result_header'],
1204 1211 'received' : rec['received'],
1205 1212 'io' : io_dict,
1206 1213 }
1207 1214 if rec['result_buffers']:
1208 1215 buffers = map(bytes, rec['result_buffers'])
1209 1216 else:
1210 1217 buffers = []
1211 1218
1212 1219 return content, buffers
1213 1220
1214 1221 def get_results(self, client_id, msg):
1215 1222 """Get the result of 1 or more messages."""
1216 1223 content = msg['content']
1217 1224 msg_ids = sorted(set(content['msg_ids']))
1218 1225 statusonly = content.get('status_only', False)
1219 1226 pending = []
1220 1227 completed = []
1221 1228 content = dict(status='ok')
1222 1229 content['pending'] = pending
1223 1230 content['completed'] = completed
1224 1231 buffers = []
1225 1232 if not statusonly:
1226 1233 try:
1227 1234 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1228 1235 # turn match list into dict, for faster lookup
1229 1236 records = {}
1230 1237 for rec in matches:
1231 1238 records[rec['msg_id']] = rec
1232 1239 except Exception:
1233 1240 content = error.wrap_exception()
1234 1241 self.session.send(self.query, "result_reply", content=content,
1235 1242 parent=msg, ident=client_id)
1236 1243 return
1237 1244 else:
1238 1245 records = {}
1239 1246 for msg_id in msg_ids:
1240 1247 if msg_id in self.pending:
1241 1248 pending.append(msg_id)
1242 1249 elif msg_id in self.all_completed:
1243 1250 completed.append(msg_id)
1244 1251 if not statusonly:
1245 1252 c,bufs = self._extract_record(records[msg_id])
1246 1253 content[msg_id] = c
1247 1254 buffers.extend(bufs)
1248 1255 elif msg_id in records:
1249 1256 if rec['completed']:
1250 1257 completed.append(msg_id)
1251 1258 c,bufs = self._extract_record(records[msg_id])
1252 1259 content[msg_id] = c
1253 1260 buffers.extend(bufs)
1254 1261 else:
1255 1262 pending.append(msg_id)
1256 1263 else:
1257 1264 try:
1258 1265 raise KeyError('No such message: '+msg_id)
1259 1266 except:
1260 1267 content = error.wrap_exception()
1261 1268 break
1262 1269 self.session.send(self.query, "result_reply", content=content,
1263 1270 parent=msg, ident=client_id,
1264 1271 buffers=buffers)
1265 1272
1266 1273 def get_history(self, client_id, msg):
1267 1274 """Get a list of all msg_ids in our DB records"""
1268 1275 try:
1269 1276 msg_ids = self.db.get_history()
1270 1277 except Exception as e:
1271 1278 content = error.wrap_exception()
1272 1279 else:
1273 1280 content = dict(status='ok', history=msg_ids)
1274 1281
1275 1282 self.session.send(self.query, "history_reply", content=content,
1276 1283 parent=msg, ident=client_id)
1277 1284
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database.

        Buffers cannot travel inside the json content, so they are popped
        out of each record and sent as raw message buffers, with
        buffer_lens/result_buffer_lens telling the client how to re-slice
        them per record.
        """
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = None
                result_buffer_lens = None

            for rec in records:
                # buffers may be None, so double check
                b = rec.pop('buffers', empty) or empty
                if buffer_lens is not None:
                    buffer_lens.append(len(b))
                buffers.extend(b)
                rb = rec.pop('result_buffers', empty) or empty
                if result_buffer_lens is not None:
                    result_buffer_lens.append(len(rb))
                buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                    result_buffer_lens=result_buffer_lens)
        # self.log.debug (content)
        self.session.send(self.query, "db_reply", content=content,
                                    parent=msg, ident=client_id,
                                    buffers=buffers)
1314 1321
@@ -1,436 +1,455
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import time
22 22 from datetime import datetime
23 23 from tempfile import mktemp
24 24
25 25 import zmq
26 26
27 27 from IPython import parallel
28 28 from IPython.parallel.client import client as clientmod
29 29 from IPython.parallel import error
30 30 from IPython.parallel import AsyncResult, AsyncHubResult
31 31 from IPython.parallel import LoadBalancedView, DirectView
32 32
33 33 from clienttest import ClusterTestCase, segfault, wait, add_engines
34 34
def setup():
    # module-level nose setup: ensure at least 4 engines are running
    # before any test in this module executes.
    add_engines(4, total=True)
37 37
38 38 class TestClient(ClusterTestCase):
39 39
    def test_ids(self):
        """adding engines should grow client.ids"""
        n = len(self.client.ids)
        self.add_engines(2)
        self.assertEquals(len(self.client.ids), n+2)

    def test_view_indexing(self):
        """test index access for views"""
        self.minimum_engines(4)
        targets = self.client._build_targets('all')[-1]
        v = self.client[:]
        self.assertEquals(v.targets, targets)
        t = self.client.ids[2]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        v = self.client[::2]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[::2])
        v = self.client[1::3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[:-3])
        v = self.client[-1]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[-1])
        # non-int/slice/list indices are rejected
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_lbview_targets(self):
        """test load_balanced_view targets"""
        v = self.client.load_balanced_view()
        self.assertEquals(v.targets, None)
        v = self.client.load_balanced_view(-1)
        self.assertEquals(v.targets, [self.client.ids[-1]])
        v = self.client.load_balanced_view('all')
        self.assertEquals(v.targets, None)

    def test_dview_targets(self):
        """test direct_view targets"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view('all')
        self.assertEquals(v.targets, 'all')
        v = self.client.direct_view(-1)
        self.assertEquals(v.targets, self.client.ids[-1])

    def test_lazy_all_targets(self):
        """test lazy evaluation of rc.direct_view('all')"""
        v = self.client.direct_view()
        self.assertEquals(v.targets, 'all')

        def double(x):
            return x*2
        seq = range(100)
        ref = [ double(x) for x in seq ]

        # add some engines, which should be used
        self.add_engines(1)
        n1 = len(self.client.ids)

        # simple apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n1)

        # map goes through remotefunction
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

        # add a couple more engines, and try again
        self.add_engines(2)
        n2 = len(self.client.ids)
        self.assertNotEquals(n2, n1)

        # apply
        r = v.apply_sync(lambda : 1)
        self.assertEquals(r, [1] * n2)

        # map
        r = v.map_sync(double, seq)
        self.assertEquals(r, ref)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        idents,targets = build(None)
        self.assertEquals(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        self.minimum_engines(2)
        v = self.client[:]
        v.block=True
        v.push(dict(a=5))
        v.pull('a')
        id0 = self.client.ids[-1]
        # clearing one engine's namespace should not affect the others
        self.client.clear(targets=id0, block=True)
        a = self.client[:-1].get('a')
        self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
        self.client.clear(block=True)
        for i in self.client.ids:
            self.assertRaisesRemote(NameError, self.client[i].get, 'a')

    def test_get_result(self):
        """test getting results from the Hub."""
        c = clientmod.Client(profile='iptest')
        t = c.ids[-1]
        ar = c[t].apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        # a second client must fetch via the Hub -> AsyncHubResult
        ahr = self.client.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # once the result is local, get_result returns a plain AsyncResult
        ar2 = self.client.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()
161 161
162 def test_get_execute_result(self):
163 """test getting execute results from the Hub."""
164 c = clientmod.Client(profile='iptest')
165 t = c.ids[-1]
166 cell = '\n'.join([
167 'import time',
168 'time.sleep(0.25)',
169 '5'
170 ])
171 ar = c[t].execute("import time; time.sleep(1)", silent=False)
172 # give the monitor time to notice the message
173 time.sleep(.25)
174 ahr = self.client.get_result(ar.msg_ids)
175 self.assertTrue(isinstance(ahr, AsyncHubResult))
176 self.assertEquals(ahr.get().pyout, ar.get().pyout)
177 ar2 = self.client.get_result(ar.msg_ids)
178 self.assertFalse(isinstance(ar2, AsyncHubResult))
179 c.close()
180
    def test_ids_list(self):
        """test client.ids"""
        ids = self.client.ids
        # .ids must be a defensive copy of the private list
        self.assertEquals(ids, self.client._ids)
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEquals(ids, self.client._ids)

    def test_queue_status(self):
        """queue_status returns per-engine dicts plus 'unassigned'"""
        ids = self.client.ids
        id0 = ids[0]
        qs = self.client.queue_status(targets=id0)
        self.assertTrue(isinstance(qs, dict))
        self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
        allqs = self.client.queue_status()
        self.assertTrue(isinstance(allqs, dict))
        intkeys = list(allqs.keys())
        intkeys.remove('unassigned')
        self.assertEquals(sorted(intkeys), sorted(self.client.ids))
        unassigned = allqs.pop('unassigned')
        for eid,qs in allqs.items():
            self.assertTrue(isinstance(qs, dict))
            self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])

    def test_shutdown(self):
        """shutting down an engine removes it from client.ids"""
        ids = self.client.ids
        id0 = ids[0]
        self.client.shutdown(id0, block=True)
        while id0 in self.client.ids:
            time.sleep(0.1)
            self.client.spin()

        self.assertRaises(IndexError, lambda : self.client[id0])

    def test_result_status(self):
        pass
        # to be written

    def test_db_query_dt(self):
        """test db query by date"""
        hist = self.client.hub_history()
        middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
        tic = middle['submitted']
        before = self.client.db_query({'submitted' : {'$lt' : tic}})
        after = self.client.db_query({'submitted' : {'$gte' : tic}})
        self.assertEquals(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.client.db_query({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_db_query_keys(self):
        """test extracting subset of record keys"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_db_query_default_keys(self):
        """default db_query excludes buffers"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        for rec in found:
            keys = set(rec.keys())
            self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
            self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)

    def test_db_query_msg_id(self):
        """ensure msg_id is always in db queries"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_db_query_get_result(self):
        """pop in db_query shouldn't pop from result itself"""
        self.client[:].apply_sync(lambda : 1)
        found = self.client.db_query({'msg_id': {'$ne' : ''}})
        rc2 = clientmod.Client(profile='iptest')
        # If this bug is not fixed, this call will hang:
        ar = rc2.get_result(self.client.history[-1])
        ar.wait(2)
        self.assertTrue(ar.ready())
        ar.get()
        rc2.close()

    def test_db_query_in(self):
        """test db query with '$in','$nin' operators"""
        hist = self.client.hub_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_hub_history(self):
        """hub_history is ordered by submission time and tracks new tasks"""
        hist = self.client.hub_history()
        recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
        recdict = {}
        for rec in recs:
            recdict[rec['msg_id']] = rec

        # verify history is sorted by 'submitted' timestamp
        latest = datetime(1984,1,1)
        for msg_id in hist:
            rec = recdict[msg_id]
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        ar = self.client[-1].apply_async(lambda : 1)
        ar.get()
        time.sleep(0.25)
        self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)

    def _wait_for_idle(self):
        """wait for an engine to become idle, according to the Hub"""
        rc = self.client

        # timeout 5s, polling every 100ms
        qs = rc.queue_status()
        for i in range(50):
            if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
                time.sleep(0.1)
                qs = rc.queue_status()
            else:
                break

        # ensure Hub up to date:
        self.assertEquals(qs['unassigned'], 0)
        for eid in rc.ids:
            self.assertEquals(qs[eid]['tasks'], 0)


    def test_resubmit(self):
        """resubmitting a task re-runs it (fresh random result)"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)
        self.assertFalse(r1 == r2)
315 334
316 335 def test_resubmit_chain(self):
317 336 """resubmit resubmitted tasks"""
318 337 v = self.client.load_balanced_view()
319 338 ar = v.apply_async(lambda x: x, 'x'*1024)
320 339 ar.get()
321 340 self._wait_for_idle()
322 341 ars = [ar]
323 342
324 343 for i in range(10):
325 344 ar = ars[-1]
326 345 ar2 = self.client.resubmit(ar.msg_ids)
327 346
328 347 [ ar.get() for ar in ars ]
329 348
    def test_resubmit_header(self):
        """resubmit shouldn't clobber the whole header"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        v.retries = 1
        ar = v.apply_async(f)
        r1 = ar.get(1)
        # give the Hub a chance to notice:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        ahr.get(1)
        time.sleep(0.5)
        records = self.client.db_query({'msg_id': {'$in': ar.msg_ids + ahr.msg_ids}}, keys='header')
        h1,h2 = [ r['header'] for r in records ]
        # only msg_id and date may differ between original and resubmission
        for key in set(h1.keys()).union(set(h2.keys())):
            if key in ('msg_id', 'date'):
                self.assertNotEquals(h1[key], h2[key])
            else:
                self.assertEquals(h1[key], h2[key])

    def test_resubmit_aborted(self):
        """an aborted task can still be resubmitted"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        # restrict to one engine, so we can put a sleep
        # ahead of the task, so it will get aborted
        eid = self.client.ids[-1]
        v.targets = [eid]
        sleep = v.apply_async(time.sleep, 0.5)
        ar = v.apply_async(f)
        ar.abort()
        self.assertRaises(error.TaskAborted, ar.get)
        # Give the Hub a chance to get up to date:
        self._wait_for_idle()
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)

    def test_resubmit_inflight(self):
        """resubmit of inflight task"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep,1)
        # give the message a chance to arrive
        time.sleep(0.2)
        ahr = self.client.resubmit(ar.msg_ids)
        ar.get(2)
        ahr.get(2)

    def test_resubmit_badkey(self):
        """ensure KeyError on resubmit of nonexistant task"""
        self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])

    def test_purge_results(self):
        """purging a single result removes it from hub history"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        # Wait for the Hub to realise the result is done:
        # This prevents a race condition, where we
        # might purge a result the Hub still thinks is pending.
        time.sleep(0.1)
        rc2 = clientmod.Client(profile='iptest')
        hist = self.client.hub_history()
        ahr = rc2.get_result([hist[-1]])
        ahr.wait(10)
        self.client.purge_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEquals(len(newhist)+1,len(hist))
        rc2.spin()
        rc2.close()

    def test_purge_all_results(self):
        """purge_results('all') empties the hub history"""
        self.client.purge_results('all')
        hist = self.client.hub_history()
        self.assertEquals(len(hist), 0)

    def test_spin_thread(self):
        """a background spin thread keeps wall_time accurate"""
        self.client.spin_thread(0.01)
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.1)
        self.assertTrue(ar.wall_time < 0.1,
            "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
        )

    def test_stop_spin_thread(self):
        """stopping the spin thread stops background updates"""
        self.client.spin_thread(0.01)
        self.client.stop_spin_thread()
        ar = self.client[-1].apply_async(lambda : 1)
        time.sleep(0.15)
        self.assertTrue(ar.wall_time > 0.1,
            "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
        )
423 442
424 443 def test_activate(self):
425 444 ip = get_ipython()
426 445 magics = ip.magics_manager.magics
427 446 self.assertTrue('px' in magics['line'])
428 447 self.assertTrue('px' in magics['cell'])
429 448 v0 = self.client.activate(-1, '0')
430 449 self.assertTrue('px0' in magics['line'])
431 450 self.assertTrue('px0' in magics['cell'])
432 451 self.assertEquals(v0.targets, self.client.ids[-1])
433 452 v0 = self.client.activate('all', 'all')
434 453 self.assertTrue('pxall' in magics['line'])
435 454 self.assertTrue('pxall' in magics['cell'])
436 455 self.assertEquals(v0.targets, 'all')
General Comments 0
You need to be logged in to leave comments. Login now