##// END OF EJS Templates
fix purge_results for args other than specified msg_id...
MinRK -
Show More
@@ -1,1418 +1,1422
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import time
21 21 import warnings
22 22 from datetime import datetime
23 23 from getpass import getpass
24 24 from pprint import pprint
25 25
26 26 pjoin = os.path.join
27 27
28 28 import zmq
29 29 # from zmq.eventloop import ioloop, zmqstream
30 30
31 31 from IPython.config.configurable import MultipleInstanceError
32 32 from IPython.core.application import BaseIPythonApplication
33 33
34 34 from IPython.utils.jsonutil import rekey
35 35 from IPython.utils.localinterfaces import LOCAL_IPS
36 36 from IPython.utils.path import get_ipython_dir
37 37 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 38 Dict, List, Bool, Set)
39 39 from IPython.external.decorator import decorator
40 40 from IPython.external.ssh import tunnel
41 41
42 42 from IPython.parallel import error
43 43 from IPython.parallel import util
44 44
45 45 from IPython.zmq.session import Session, Message
46 46
47 47 from .asyncresult import AsyncResult, AsyncHubResult
48 48 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 49 from .view import DirectView, LoadBalancedView
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Decorators for Client methods
53 53 #--------------------------------------------------------------------------
54 54
@decorator
def spin_first(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method.

    Decorator for Client methods: flushes any pending ZMQ messages
    (results, registration notifications) so the wrapped method sees
    up-to-date client state before it runs.
    """
    self.spin()
    return f(self, *args, **kwargs)
60 60
61 61
62 62 #--------------------------------------------------------------------------
63 63 # Classes
64 64 #--------------------------------------------------------------------------
65 65
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full, fixed set of allowed keys, with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE: dict.update bypasses our strict __setitem__, so seed values
        # passed to the constructor are not key-checked (original behavior)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` is O(1); scanning iterkeys() was O(n) and py3-incompatible
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
116 116
117 117
118 118 class Client(HasTraits):
119 119 """A semi-synchronous client to the IPython ZMQ cluster
120 120
121 121 Parameters
122 122 ----------
123 123
124 124 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
125 125 Connection information for the Hub's registration. If a json connector
126 126 file is given, then likely no further configuration is necessary.
127 127 [Default: use profile]
128 128 profile : bytes
129 129 The name of the Cluster profile to be used to find connector information.
130 130 If run from an IPython application, the default profile will be the same
131 131 as the running application, otherwise it will be 'default'.
132 132 context : zmq.Context
133 133 Pass an existing zmq.Context instance, otherwise the client will create its own.
134 134 debug : bool
135 135 flag for lots of message printing for debug purposes
136 136 timeout : int/float
137 137 time (in seconds) to wait for connection replies from the Hub
138 138 [Default: 10]
139 139
140 140 #-------------- session related args ----------------
141 141
142 142 config : Config object
143 143 If specified, this will be relayed to the Session for configuration
144 144 username : str
145 145 set username for the session object
146 146 packer : str (import_string) or callable
147 147 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
148 148 function to serialize messages. Must support same input as
149 149 JSON, and output must be bytes.
150 150 You can pass a callable directly as `pack`
151 151 unpacker : str (import_string) or callable
152 152 The inverse of packer. Only necessary if packer is specified as *not* one
153 153 of 'json' or 'pickle'.
154 154
155 155 #-------------- ssh related args ----------------
156 156 # These are args for configuring the ssh tunnel to be used
157 157 # credentials are used to forward connections over ssh to the Controller
158 158 # Note that the ip given in `addr` needs to be relative to sshserver
159 159 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
160 160 # and set sshserver as the same machine the Controller is on. However,
161 161 # the only requirement is that sshserver is able to see the Controller
162 162 # (i.e. is within the same trusted network).
163 163
164 164 sshserver : str
165 165 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
166 166 If keyfile or password is specified, and this is not, it will default to
167 167 the ip given in addr.
168 168 sshkey : str; path to public ssh key file
169 169 This specifies a key to be used in ssh login, default None.
170 170 Regular default ssh keys will be used without specifying this argument.
171 171 password : str
172 172 Your ssh password to sshserver. Note that if this is left None,
173 173 you will be prompted for it if passwordless key based login is unavailable.
174 174 paramiko : bool
175 175 flag for whether to use paramiko instead of shell ssh for tunneling.
176 176 [default: True on win32, False else]
177 177
178 178 ------- exec authentication args -------
179 179 If even localhost is untrusted, you can have some protection against
180 180 unauthorized execution by signing messages with HMAC digests.
181 181 Messages are still sent as cleartext, so if someone can snoop your
182 182 loopback traffic this will not protect your privacy, but will prevent
183 183 unauthorized execution.
184 184
185 185 exec_key : str
186 186 an authentication key or file containing a key
187 187 default: None
188 188
189 189
190 190 Attributes
191 191 ----------
192 192
193 193 ids : list of int engine IDs
194 194 requesting the ids attribute always synchronizes
195 195 the registration state. To request ids without synchronization,
196 196 use semi-private _ids attributes.
197 197
198 198 history : list of msg_ids
199 199 a list of msg_ids, keeping track of all the execution
200 200 messages you have submitted in order.
201 201
202 202 outstanding : set of msg_ids
203 203 a set of msg_ids that have been submitted, but whose
204 204 results have not yet been received.
205 205
206 206 results : dict
207 207 a dict of all our results, keyed by msg_id
208 208
209 209 block : bool
210 210 determines default behavior when block not specified
211 211 in execution methods
212 212
213 213 Methods
214 214 -------
215 215
216 216 spin
217 217 flushes incoming results and registration state changes
218 218 control methods spin, and requesting `ids` also ensures up to date
219 219
220 220 wait
221 221 wait on one or more msg_ids
222 222
223 223 execution methods
224 224 apply
225 225 legacy: execute, run
226 226
227 227 data movement
228 228 push, pull, scatter, gather
229 229
230 230 query methods
231 231 queue_status, get_result, purge, result_status
232 232
233 233 control methods
234 234 abort, shutdown
235 235
236 236 """
237 237
238 238
    # default block-behavior for execution methods when `block` is not given
    block = Bool(False)
    # msg_ids submitted but whose results have not yet been received
    outstanding = Set()
    # all received results, keyed by msg_id
    results = Instance('collections.defaultdict', (dict,))
    # per-msg_id Metadata records, keyed by msg_id
    metadata = Instance('collections.defaultdict', (Metadata,))
    # all msg_ids submitted, in order
    history = List()
    # if True, pprint every message sent and received
    debug = Bool(False)

    # name of the cluster profile to load connection info from
    profile=Unicode()
247 247 def _profile_default(self):
248 248 if BaseIPythonApplication.initialized():
249 249 # an IPython app *might* be running, try to get its profile
250 250 try:
251 251 return BaseIPythonApplication.instance().profile
252 252 except (AttributeError, MultipleInstanceError):
253 253 # could be a *different* subclass of config.Application,
254 254 # which would raise one of these two errors.
255 255 return u'default'
256 256 else:
257 257 return u'default'
258 258
259 259
    # outstanding msg_ids grouped by engine uuid (for stranded-msg handling)
    _outstanding_dict = Instance('collections.defaultdict', (set,))
    # registered engine ids; use the `ids` property for an up-to-date copy
    _ids = List()
    _connected=Bool(False)
    # whether connections are tunnelled over ssh
    _ssh=Bool(False)
    _context = Instance('zmq.Context')
    # connection/registration info (from the JSON connector file and Hub reply)
    _config = Dict()
    # bidirectional engine_id <-> uuid mapping
    _engines=Instance(util.ReverseDict, (), {})
    # _hub_socket=Instance('zmq.Socket')
    _query_socket=Instance('zmq.Socket')
    _control_socket=Instance('zmq.Socket')
    _iopub_socket=Instance('zmq.Socket')
    _notification_socket=Instance('zmq.Socket')
    _mux_socket=Instance('zmq.Socket')
    _task_socket=Instance('zmq.Socket')
    # scheduling scheme reported by the Hub; 'pure' means pure-ZMQ scheduler
    _task_scheme=Unicode()
    _closed = False
    # counts of replies we expect but deliberately ignore (non-blocking calls)
    _ignored_control_replies=Int(0)
    _ignored_hub_replies=Int(0)
    def __new__(self, *args, **kw):
        # don't raise on positional args
        # NOTE(review): `self` here is really the class; positional args are
        # accepted but dropped so HasTraits.__new__ only sees keywords.
        return HasTraits.__new__(self, **kw)
282 282
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Connect to the Hub.  See the class docstring for parameter details.

        Resolves connection info (from args, the profile dir, or a JSON
        connector file), sets up SSH tunnelling if needed, builds the
        Session, and connects all sockets via self._connect().
        """
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            # context may be shared; we never terminate it ourselves
            context = zmq.Context.instance()
        self._context = context

        # locate the profile dir (sets self._cd, possibly None)
        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        # url_or_file may be a zmq url or a path to a JSON connector file;
        # bare filenames are resolved relative to the profile's security dir
        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
            RuntimeWarning)

        self._config = cfg

        # any of sshserver/sshkey/password implies tunnelling
        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # exec_key is a path to a keyfile
                extra_args['keyfile'] = exec_key
            else:
                # exec_key is the key itself; Session wants bytes
                if isinstance(exec_key, unicode):
                    exec_key = exec_key.encode('ascii')
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        # connect the registration/query socket first, to talk to the Hub
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
380 380
    def __del__(self):
        """cleanup sockets, but _not_ context.

        The zmq Context may be shared (zmq.Context.instance()), so only our
        own sockets are closed here; close() is idempotent.
        """
        self.close()
384 384
385 385 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
386 386 if ipython_dir is None:
387 387 ipython_dir = get_ipython_dir()
388 388 if profile_dir is not None:
389 389 try:
390 390 self._cd = ProfileDir.find_profile_dir(profile_dir)
391 391 return
392 392 except ProfileDirError:
393 393 pass
394 394 elif profile is not None:
395 395 try:
396 396 self._cd = ProfileDir.find_profile_dir_by_name(
397 397 ipython_dir, profile)
398 398 return
399 399 except ProfileDirError:
400 400 pass
401 401 self._cd = None
402 402
403 403 def _update_engines(self, engines):
404 404 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
405 405 for k,v in engines.iteritems():
406 406 eid = int(k)
407 407 self._engines[eid] = bytes(v) # force not unicode
408 408 self._ids.append(eid)
409 409 self._ids = sorted(self._ids)
410 410 if sorted(self._engines.keys()) != range(len(self._engines)) and \
411 411 self._task_scheme == 'pure' and self._task_socket:
412 412 self._stop_scheduling_tasks()
413 413
414 414 def _stop_scheduling_tasks(self):
415 415 """Stop scheduling tasks because an engine has been unregistered
416 416 from a pure ZMQ scheduler.
417 417 """
418 418 self._task_socket.close()
419 419 self._task_socket = None
420 420 msg = "An engine has been unregistered, and we are using pure " +\
421 421 "ZMQ task scheduling. Task farming will be disabled."
422 422 if self.outstanding:
423 423 msg += " If you were running tasks when this happened, " +\
424 424 "some `outstanding` msg_ids may never resolve."
425 425 warnings.warn(msg, RuntimeWarning)
426 426
427 427 def _build_targets(self, targets):
428 428 """Turn valid target IDs or 'all' into two lists:
429 429 (int_ids, uuids).
430 430 """
431 431 if not self._ids:
432 432 # flush notification socket if no engines yet, just in case
433 433 if not self.ids:
434 434 raise error.NoEnginesRegistered("Can't build targets without any engines")
435 435
436 436 if targets is None:
437 437 targets = self._ids
438 438 elif isinstance(targets, str):
439 439 if targets.lower() == 'all':
440 440 targets = self._ids
441 441 else:
442 442 raise TypeError("%r not valid str target, must be 'all'"%(targets))
443 443 elif isinstance(targets, int):
444 444 if targets < 0:
445 445 targets = self.ids[targets]
446 446 if targets not in self._ids:
447 447 raise IndexError("No such engine: %i"%targets)
448 448 targets = [targets]
449 449
450 450 if isinstance(targets, slice):
451 451 indices = range(len(self._ids))[targets]
452 452 ids = self.ids
453 453 targets = [ ids[i] for i in indices ]
454 454
455 455 if not isinstance(targets, (tuple, list, xrange)):
456 456 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
457 457
458 458 return [self._engines[t] for t in targets], list(targets)
459 459
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, waits
        up to `timeout` seconds for the reply, then connects one socket per
        channel (mux, task, notification, control, iopub) advertised in it.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve 0.0.0.0/'*' addresses, and tunnel if ssh is configured
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # content.task is a (scheme, address) pair
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # subscribe to everything on the notification channel
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
521 521
522 522 #--------------------------------------------------------------------------
523 523 # handlers and callbacks for incoming messages
524 524 #--------------------------------------------------------------------------
525 525
526 526 def _unwrap_exception(self, content):
527 527 """unwrap exception, and remap engine_id to int."""
528 528 e = error.unwrap_exception(content)
529 529 # print e.traceback
530 530 if e.engine_info:
531 531 e_uuid = e.engine_info['engine_uuid']
532 532 eid = self._engines[e_uuid]
533 533 e.engine_info['engine_id'] = eid
534 534 return e
535 535
536 536 def _extract_metadata(self, header, parent, content):
537 537 md = {'msg_id' : parent['msg_id'],
538 538 'received' : datetime.now(),
539 539 'engine_uuid' : header.get('engine', None),
540 540 'follow' : parent.get('follow', []),
541 541 'after' : parent.get('after', []),
542 542 'status' : content['status'],
543 543 }
544 544
545 545 if md['engine_uuid'] is not None:
546 546 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
547 547
548 548 if 'date' in parent:
549 549 md['submitted'] = parent['date']
550 550 if 'started' in header:
551 551 md['started'] = header['started']
552 552 if 'date' in header:
553 553 md['completed'] = header['date']
554 554 return md
555 555
556 556 def _register_engine(self, msg):
557 557 """Register a new engine, and update our connection info."""
558 558 content = msg['content']
559 559 eid = content['id']
560 560 d = {eid : content['queue']}
561 561 self._update_engines(d)
562 562
563 563 def _unregister_engine(self, msg):
564 564 """Unregister an engine that has died."""
565 565 content = msg['content']
566 566 eid = int(content['id'])
567 567 if eid in self._ids:
568 568 self._ids.remove(eid)
569 569 uuid = self._engines.pop(eid)
570 570
571 571 self._handle_stranded_msgs(eid, uuid)
572 572
573 573 if self._task_socket and self._task_scheme == 'pure':
574 574 self._stop_scheduling_tasks()
575 575
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.

        Parameters
        ----------
        eid : int
            integer id of the engine that died
        uuid : bytes
            queue uuid of the engine that died
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result - nothing to record
                continue
            try:
                # raise (rather than constructing the record directly) so
                # wrap_exception can capture a genuine traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            # deliver through the normal reply path so bookkeeping is shared
            self._handle_apply_reply(msg)
602 602
603 603 def _handle_execute_reply(self, msg):
604 604 """Save the reply to an execute_request into our results.
605 605
606 606 execute messages are never actually used. apply is used instead.
607 607 """
608 608
609 609 parent = msg['parent_header']
610 610 msg_id = parent['msg_id']
611 611 if msg_id not in self.outstanding:
612 612 if msg_id in self.history:
613 613 print ("got stale result: %s"%msg_id)
614 614 else:
615 615 print ("got unknown result: %s"%msg_id)
616 616 else:
617 617 self.outstanding.remove(msg_id)
618 618 self.results[msg_id] = self._unwrap_exception(msg['content'])
619 619
620 620 def _handle_apply_reply(self, msg):
621 621 """Save the reply to an apply_request into our results."""
622 622 parent = msg['parent_header']
623 623 msg_id = parent['msg_id']
624 624 if msg_id not in self.outstanding:
625 625 if msg_id in self.history:
626 626 print ("got stale result: %s"%msg_id)
627 627 print self.results[msg_id]
628 628 print msg
629 629 else:
630 630 print ("got unknown result: %s"%msg_id)
631 631 else:
632 632 self.outstanding.remove(msg_id)
633 633 content = msg['content']
634 634 header = msg['header']
635 635
636 636 # construct metadata:
637 637 md = self.metadata[msg_id]
638 638 md.update(self._extract_metadata(header, parent, content))
639 639 # is this redundant?
640 640 self.metadata[msg_id] = md
641 641
642 642 e_outstanding = self._outstanding_dict[md['engine_uuid']]
643 643 if msg_id in e_outstanding:
644 644 e_outstanding.remove(msg_id)
645 645
646 646 # construct result:
647 647 if content['status'] == 'ok':
648 648 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
649 649 elif content['status'] == 'aborted':
650 650 self.results[msg_id] = error.TaskAborted(msg_id)
651 651 elif content['status'] == 'resubmitted':
652 652 # TODO: handle resubmission
653 653 pass
654 654 else:
655 655 self.results[msg_id] = self._unwrap_exception(content)
656 656
657 657 def _flush_notifications(self):
658 658 """Flush notifications of engine registrations waiting
659 659 in ZMQ queue."""
660 660 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
661 661 while msg is not None:
662 662 if self.debug:
663 663 pprint(msg)
664 664 msg_type = msg['msg_type']
665 665 handler = self._notification_handlers.get(msg_type, None)
666 666 if handler is None:
667 667 raise Exception("Unhandled message type: %s"%msg.msg_type)
668 668 else:
669 669 handler(msg)
670 670 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
671 671
672 672 def _flush_results(self, sock):
673 673 """Flush task or queue results waiting in ZMQ queue."""
674 674 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
675 675 while msg is not None:
676 676 if self.debug:
677 677 pprint(msg)
678 678 msg_type = msg['msg_type']
679 679 handler = self._queue_handlers.get(msg_type, None)
680 680 if handler is None:
681 681 raise Exception("Unhandled message type: %s"%msg.msg_type)
682 682 else:
683 683 handler(msg)
684 684 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
685 685
686 686 def _flush_control(self, sock):
687 687 """Flush replies from the control channel waiting
688 688 in the ZMQ queue.
689 689
690 690 Currently: ignore them."""
691 691 if self._ignored_control_replies <= 0:
692 692 return
693 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 694 while msg is not None:
695 695 self._ignored_control_replies -= 1
696 696 if self.debug:
697 697 pprint(msg)
698 698 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
699 699
700 700 def _flush_ignored_control(self):
701 701 """flush ignored control replies"""
702 702 while self._ignored_control_replies > 0:
703 703 self.session.recv(self._control_socket)
704 704 self._ignored_control_replies -= 1
705 705
706 706 def _flush_ignored_hub_replies(self):
707 707 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
708 708 while msg is not None:
709 709 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
710 710
711 711 def _flush_iopub(self, sock):
712 712 """Flush replies from the iopub channel waiting
713 713 in the ZMQ queue.
714 714 """
715 715 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
716 716 while msg is not None:
717 717 if self.debug:
718 718 pprint(msg)
719 719 parent = msg['parent_header']
720 720 msg_id = parent['msg_id']
721 721 content = msg['content']
722 722 header = msg['header']
723 723 msg_type = msg['msg_type']
724 724
725 725 # init metadata:
726 726 md = self.metadata[msg_id]
727 727
728 728 if msg_type == 'stream':
729 729 name = content['name']
730 730 s = md[name] or ''
731 731 md[name] = s + content['data']
732 732 elif msg_type == 'pyerr':
733 733 md.update({'pyerr' : self._unwrap_exception(content)})
734 734 elif msg_type == 'pyin':
735 735 md.update({'pyin' : content['code']})
736 736 else:
737 737 md.update({msg_type : content.get('data', '')})
738 738
739 739 # reduntant?
740 740 self.metadata[msg_id] = md
741 741
742 742 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
743 743
744 744 #--------------------------------------------------------------------------
745 745 # len, getitem
746 746 #--------------------------------------------------------------------------
747 747
748 748 def __len__(self):
749 749 """len(client) returns # of engines."""
750 750 return len(self.ids)
751 751
752 752 def __getitem__(self, key):
753 753 """index access returns DirectView multiplexer objects
754 754
755 755 Must be int, slice, or list/tuple/xrange of ints"""
756 756 if not isinstance(key, (int, slice, tuple, list, xrange)):
757 757 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
758 758 else:
759 759 return self.direct_view(key)
760 760
761 761 #--------------------------------------------------------------------------
762 762 # Begin public methods
763 763 #--------------------------------------------------------------------------
764 764
765 765 @property
766 766 def ids(self):
767 767 """Always up-to-date ids property."""
768 768 self._flush_notifications()
769 769 # always copy:
770 770 return list(self._ids)
771 771
772 772 def close(self):
773 773 if self._closed:
774 774 return
775 775 snames = filter(lambda n: n.endswith('socket'), dir(self))
776 776 for socket in map(lambda name: getattr(self, name), snames):
777 777 if isinstance(socket, zmq.Socket) and not socket.closed:
778 778 socket.close()
779 779 self._closed = True
780 780
781 781 def spin(self):
782 782 """Flush any registration notifications and execution results
783 783 waiting in the ZMQ queue.
784 784 """
785 785 if self._notification_socket:
786 786 self._flush_notifications()
787 787 if self._mux_socket:
788 788 self._flush_results(self._mux_socket)
789 789 if self._task_socket:
790 790 self._flush_results(self._task_socket)
791 791 if self._control_socket:
792 792 self._flush_control(self._control_socket)
793 793 if self._iopub_socket:
794 794 self._flush_iopub(self._iopub_socket)
795 795 if self._query_socket:
796 796 self._flush_ignored_hub_replies()
797 797
798 798 def wait(self, jobs=None, timeout=-1):
799 799 """waits on one or more `jobs`, for up to `timeout` seconds.
800 800
801 801 Parameters
802 802 ----------
803 803
804 804 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
805 805 ints are indices to self.history
806 806 strs are msg_ids
807 807 default: wait on all outstanding messages
808 808 timeout : float
809 809 a time in seconds, after which to give up.
810 810 default is -1, which means no timeout
811 811
812 812 Returns
813 813 -------
814 814
815 815 True : when all msg_ids are done
816 816 False : timeout reached, some msg_ids still outstanding
817 817 """
818 818 tic = time.time()
819 819 if jobs is None:
820 820 theids = self.outstanding
821 821 else:
822 822 if isinstance(jobs, (int, str, AsyncResult)):
823 823 jobs = [jobs]
824 824 theids = set()
825 825 for job in jobs:
826 826 if isinstance(job, int):
827 827 # index access
828 828 job = self.history[job]
829 829 elif isinstance(job, AsyncResult):
830 830 map(theids.add, job.msg_ids)
831 831 continue
832 832 theids.add(job)
833 833 if not theids.intersection(self.outstanding):
834 834 return True
835 835 self.spin()
836 836 while theids.intersection(self.outstanding):
837 837 if timeout >= 0 and ( time.time()-tic ) > timeout:
838 838 break
839 839 time.sleep(1e-3)
840 840 self.spin()
841 841 return len(theids.intersection(self.outstanding)) == 0
842 842
843 843 #--------------------------------------------------------------------------
844 844 # Control methods
845 845 #--------------------------------------------------------------------------
846 846
847 847 @spin_first
848 848 def clear(self, targets=None, block=None):
849 849 """Clear the namespace in target(s)."""
850 850 block = self.block if block is None else block
851 851 targets = self._build_targets(targets)[0]
852 852 for t in targets:
853 853 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
854 854 error = False
855 855 if block:
856 856 self._flush_ignored_control()
857 857 for i in range(len(targets)):
858 858 idents,msg = self.session.recv(self._control_socket,0)
859 859 if self.debug:
860 860 pprint(msg)
861 861 if msg['content']['status'] != 'ok':
862 862 error = self._unwrap_exception(msg['content'])
863 863 else:
864 864 self._ignored_control_replies += len(targets)
865 865 if error:
866 866 raise error
867 867
868 868
869 869 @spin_first
870 870 def abort(self, jobs=None, targets=None, block=None):
871 871 """Abort specific jobs from the execution queues of target(s).
872 872
873 873 This is a mechanism to prevent jobs that have already been submitted
874 874 from executing.
875 875
876 876 Parameters
877 877 ----------
878 878
879 879 jobs : msg_id, list of msg_ids, or AsyncResult
880 880 The jobs to be aborted
881 881
882 882
883 883 """
884 884 block = self.block if block is None else block
885 885 targets = self._build_targets(targets)[0]
886 886 msg_ids = []
887 887 if isinstance(jobs, (basestring,AsyncResult)):
888 888 jobs = [jobs]
889 889 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
890 890 if bad_ids:
891 891 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
892 892 for j in jobs:
893 893 if isinstance(j, AsyncResult):
894 894 msg_ids.extend(j.msg_ids)
895 895 else:
896 896 msg_ids.append(j)
897 897 content = dict(msg_ids=msg_ids)
898 898 for t in targets:
899 899 self.session.send(self._control_socket, 'abort_request',
900 900 content=content, ident=t)
901 901 error = False
902 902 if block:
903 903 self._flush_ignored_control()
904 904 for i in range(len(targets)):
905 905 idents,msg = self.session.recv(self._control_socket,0)
906 906 if self.debug:
907 907 pprint(msg)
908 908 if msg['content']['status'] != 'ok':
909 909 error = self._unwrap_exception(msg['content'])
910 910 else:
911 911 self._ignored_control_replies += len(targets)
912 912 if error:
913 913 raise error
914 914
915 915 @spin_first
916 916 def shutdown(self, targets=None, restart=False, hub=False, block=None):
917 917 """Terminates one or more engine processes, optionally including the hub."""
918 918 block = self.block if block is None else block
919 919 if hub:
920 920 targets = 'all'
921 921 targets = self._build_targets(targets)[0]
922 922 for t in targets:
923 923 self.session.send(self._control_socket, 'shutdown_request',
924 924 content={'restart':restart},ident=t)
925 925 error = False
926 926 if block or hub:
927 927 self._flush_ignored_control()
928 928 for i in range(len(targets)):
929 929 idents,msg = self.session.recv(self._control_socket, 0)
930 930 if self.debug:
931 931 pprint(msg)
932 932 if msg['content']['status'] != 'ok':
933 933 error = self._unwrap_exception(msg['content'])
934 934 else:
935 935 self._ignored_control_replies += len(targets)
936 936
937 937 if hub:
938 938 time.sleep(0.25)
939 939 self.session.send(self._query_socket, 'shutdown_request')
940 940 idents,msg = self.session.recv(self._query_socket, 0)
941 941 if self.debug:
942 942 pprint(msg)
943 943 if msg['content']['status'] != 'ok':
944 944 error = self._unwrap_exception(msg['content'])
945 945
946 946 if error:
947 947 raise error
948 948
949 949 #--------------------------------------------------------------------------
950 950 # Execution related methods
951 951 #--------------------------------------------------------------------------
952 952
953 953 def _maybe_raise(self, result):
954 954 """wrapper for maybe raising an exception if apply failed."""
955 955 if isinstance(result, error.RemoteError):
956 956 raise result
957 957
958 958 return result
959 959
960 960 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
961 961 ident=None):
962 962 """construct and send an apply message via a socket.
963 963
964 964 This is the principal method with which all engine execution is performed by views.
965 965 """
966 966
967 967 assert not self._closed, "cannot use me anymore, I'm closed!"
968 968 # defaults:
969 969 args = args if args is not None else []
970 970 kwargs = kwargs if kwargs is not None else {}
971 971 subheader = subheader if subheader is not None else {}
972 972
973 973 # validate arguments
974 974 if not callable(f):
975 975 raise TypeError("f must be callable, not %s"%type(f))
976 976 if not isinstance(args, (tuple, list)):
977 977 raise TypeError("args must be tuple or list, not %s"%type(args))
978 978 if not isinstance(kwargs, dict):
979 979 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
980 980 if not isinstance(subheader, dict):
981 981 raise TypeError("subheader must be dict, not %s"%type(subheader))
982 982
983 983 bufs = util.pack_apply_message(f,args,kwargs)
984 984
985 985 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
986 986 subheader=subheader, track=track)
987 987
988 988 msg_id = msg['msg_id']
989 989 self.outstanding.add(msg_id)
990 990 if ident:
991 991 # possibly routed to a specific engine
992 992 if isinstance(ident, list):
993 993 ident = ident[-1]
994 994 if ident in self._engines.values():
995 995 # save for later, in case of engine death
996 996 self._outstanding_dict[ident].add(msg_id)
997 997 self.history.append(msg_id)
998 998 self.metadata[msg_id]['submitted'] = datetime.now()
999 999
1000 1000 return msg
1001 1001
1002 1002 #--------------------------------------------------------------------------
1003 1003 # construct a View object
1004 1004 #--------------------------------------------------------------------------
1005 1005
1006 1006 def load_balanced_view(self, targets=None):
1007 1007 """construct a DirectView object.
1008 1008
1009 1009 If no arguments are specified, create a LoadBalancedView
1010 1010 using all engines.
1011 1011
1012 1012 Parameters
1013 1013 ----------
1014 1014
1015 1015 targets: list,slice,int,etc. [default: use all engines]
1016 1016 The subset of engines across which to load-balance
1017 1017 """
1018 1018 if targets is not None:
1019 1019 targets = self._build_targets(targets)[1]
1020 1020 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1021 1021
1022 1022 def direct_view(self, targets='all'):
1023 1023 """construct a DirectView object.
1024 1024
1025 1025 If no targets are specified, create a DirectView
1026 1026 using all engines.
1027 1027
1028 1028 Parameters
1029 1029 ----------
1030 1030
1031 1031 targets: list,slice,int,etc. [default: use all engines]
1032 1032 The engines to use for the View
1033 1033 """
1034 1034 single = isinstance(targets, int)
1035 1035 targets = self._build_targets(targets)[1]
1036 1036 if single:
1037 1037 targets = targets[0]
1038 1038 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1039 1039
1040 1040 #--------------------------------------------------------------------------
1041 1041 # Query methods
1042 1042 #--------------------------------------------------------------------------
1043 1043
1044 1044 @spin_first
1045 1045 def get_result(self, indices_or_msg_ids=None, block=None):
1046 1046 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1047 1047
1048 1048 If the client already has the results, no request to the Hub will be made.
1049 1049
1050 1050 This is a convenient way to construct AsyncResult objects, which are wrappers
1051 1051 that include metadata about execution, and allow for awaiting results that
1052 1052 were not submitted by this Client.
1053 1053
1054 1054 It can also be a convenient way to retrieve the metadata associated with
1055 1055 blocking execution, since it always retrieves
1056 1056
1057 1057 Examples
1058 1058 --------
1059 1059 ::
1060 1060
1061 1061 In [10]: r = client.apply()
1062 1062
1063 1063 Parameters
1064 1064 ----------
1065 1065
1066 1066 indices_or_msg_ids : integer history index, str msg_id, or list of either
1067 1067 The indices or msg_ids of indices to be retrieved
1068 1068
1069 1069 block : bool
1070 1070 Whether to wait for the result to be done
1071 1071
1072 1072 Returns
1073 1073 -------
1074 1074
1075 1075 AsyncResult
1076 1076 A single AsyncResult object will always be returned.
1077 1077
1078 1078 AsyncHubResult
1079 1079 A subclass of AsyncResult that retrieves results from the Hub
1080 1080
1081 1081 """
1082 1082 block = self.block if block is None else block
1083 1083 if indices_or_msg_ids is None:
1084 1084 indices_or_msg_ids = -1
1085 1085
1086 1086 if not isinstance(indices_or_msg_ids, (list,tuple)):
1087 1087 indices_or_msg_ids = [indices_or_msg_ids]
1088 1088
1089 1089 theids = []
1090 1090 for id in indices_or_msg_ids:
1091 1091 if isinstance(id, int):
1092 1092 id = self.history[id]
1093 1093 if not isinstance(id, str):
1094 1094 raise TypeError("indices must be str or int, not %r"%id)
1095 1095 theids.append(id)
1096 1096
1097 1097 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1098 1098 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1099 1099
1100 1100 if remote_ids:
1101 1101 ar = AsyncHubResult(self, msg_ids=theids)
1102 1102 else:
1103 1103 ar = AsyncResult(self, msg_ids=theids)
1104 1104
1105 1105 if block:
1106 1106 ar.wait()
1107 1107
1108 1108 return ar
1109 1109
1110 1110 @spin_first
1111 1111 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1112 1112 """Resubmit one or more tasks.
1113 1113
1114 1114 in-flight tasks may not be resubmitted.
1115 1115
1116 1116 Parameters
1117 1117 ----------
1118 1118
1119 1119 indices_or_msg_ids : integer history index, str msg_id, or list of either
1120 1120 The indices or msg_ids of indices to be retrieved
1121 1121
1122 1122 block : bool
1123 1123 Whether to wait for the result to be done
1124 1124
1125 1125 Returns
1126 1126 -------
1127 1127
1128 1128 AsyncHubResult
1129 1129 A subclass of AsyncResult that retrieves results from the Hub
1130 1130
1131 1131 """
1132 1132 block = self.block if block is None else block
1133 1133 if indices_or_msg_ids is None:
1134 1134 indices_or_msg_ids = -1
1135 1135
1136 1136 if not isinstance(indices_or_msg_ids, (list,tuple)):
1137 1137 indices_or_msg_ids = [indices_or_msg_ids]
1138 1138
1139 1139 theids = []
1140 1140 for id in indices_or_msg_ids:
1141 1141 if isinstance(id, int):
1142 1142 id = self.history[id]
1143 1143 if not isinstance(id, str):
1144 1144 raise TypeError("indices must be str or int, not %r"%id)
1145 1145 theids.append(id)
1146 1146
1147 1147 for msg_id in theids:
1148 1148 self.outstanding.discard(msg_id)
1149 1149 if msg_id in self.history:
1150 1150 self.history.remove(msg_id)
1151 1151 self.results.pop(msg_id, None)
1152 1152 self.metadata.pop(msg_id, None)
1153 1153 content = dict(msg_ids = theids)
1154 1154
1155 1155 self.session.send(self._query_socket, 'resubmit_request', content)
1156 1156
1157 1157 zmq.select([self._query_socket], [], [])
1158 1158 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1159 1159 if self.debug:
1160 1160 pprint(msg)
1161 1161 content = msg['content']
1162 1162 if content['status'] != 'ok':
1163 1163 raise self._unwrap_exception(content)
1164 1164
1165 1165 ar = AsyncHubResult(self, msg_ids=theids)
1166 1166
1167 1167 if block:
1168 1168 ar.wait()
1169 1169
1170 1170 return ar
1171 1171
1172 1172 @spin_first
1173 1173 def result_status(self, msg_ids, status_only=True):
1174 1174 """Check on the status of the result(s) of the apply request with `msg_ids`.
1175 1175
1176 1176 If status_only is False, then the actual results will be retrieved, else
1177 1177 only the status of the results will be checked.
1178 1178
1179 1179 Parameters
1180 1180 ----------
1181 1181
1182 1182 msg_ids : list of msg_ids
1183 1183 if int:
1184 1184 Passed as index to self.history for convenience.
1185 1185 status_only : bool (default: True)
1186 1186 if False:
1187 1187 Retrieve the actual results of completed tasks.
1188 1188
1189 1189 Returns
1190 1190 -------
1191 1191
1192 1192 results : dict
1193 1193 There will always be the keys 'pending' and 'completed', which will
1194 1194 be lists of msg_ids that are incomplete or complete. If `status_only`
1195 1195 is False, then completed results will be keyed by their `msg_id`.
1196 1196 """
1197 1197 if not isinstance(msg_ids, (list,tuple)):
1198 1198 msg_ids = [msg_ids]
1199 1199
1200 1200 theids = []
1201 1201 for msg_id in msg_ids:
1202 1202 if isinstance(msg_id, int):
1203 1203 msg_id = self.history[msg_id]
1204 1204 if not isinstance(msg_id, basestring):
1205 1205 raise TypeError("msg_ids must be str, not %r"%msg_id)
1206 1206 theids.append(msg_id)
1207 1207
1208 1208 completed = []
1209 1209 local_results = {}
1210 1210
1211 1211 # comment this block out to temporarily disable local shortcut:
1212 1212 for msg_id in theids:
1213 1213 if msg_id in self.results:
1214 1214 completed.append(msg_id)
1215 1215 local_results[msg_id] = self.results[msg_id]
1216 1216 theids.remove(msg_id)
1217 1217
1218 1218 if theids: # some not locally cached
1219 1219 content = dict(msg_ids=theids, status_only=status_only)
1220 1220 msg = self.session.send(self._query_socket, "result_request", content=content)
1221 1221 zmq.select([self._query_socket], [], [])
1222 1222 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1223 1223 if self.debug:
1224 1224 pprint(msg)
1225 1225 content = msg['content']
1226 1226 if content['status'] != 'ok':
1227 1227 raise self._unwrap_exception(content)
1228 1228 buffers = msg['buffers']
1229 1229 else:
1230 1230 content = dict(completed=[],pending=[])
1231 1231
1232 1232 content['completed'].extend(completed)
1233 1233
1234 1234 if status_only:
1235 1235 return content
1236 1236
1237 1237 failures = []
1238 1238 # load cached results into result:
1239 1239 content.update(local_results)
1240 1240
1241 1241 # update cache with results:
1242 1242 for msg_id in sorted(theids):
1243 1243 if msg_id in content['completed']:
1244 1244 rec = content[msg_id]
1245 1245 parent = rec['header']
1246 1246 header = rec['result_header']
1247 1247 rcontent = rec['result_content']
1248 1248 iodict = rec['io']
1249 1249 if isinstance(rcontent, str):
1250 1250 rcontent = self.session.unpack(rcontent)
1251 1251
1252 1252 md = self.metadata[msg_id]
1253 1253 md.update(self._extract_metadata(header, parent, rcontent))
1254 1254 md.update(iodict)
1255 1255
1256 1256 if rcontent['status'] == 'ok':
1257 1257 res,buffers = util.unserialize_object(buffers)
1258 1258 else:
1259 1259 print rcontent
1260 1260 res = self._unwrap_exception(rcontent)
1261 1261 failures.append(res)
1262 1262
1263 1263 self.results[msg_id] = res
1264 1264 content[msg_id] = res
1265 1265
1266 1266 if len(theids) == 1 and failures:
1267 1267 raise failures[0]
1268 1268
1269 1269 error.collect_exceptions(failures, "result_status")
1270 1270 return content
1271 1271
1272 1272 @spin_first
1273 1273 def queue_status(self, targets='all', verbose=False):
1274 1274 """Fetch the status of engine queues.
1275 1275
1276 1276 Parameters
1277 1277 ----------
1278 1278
1279 1279 targets : int/str/list of ints/strs
1280 1280 the engines whose states are to be queried.
1281 1281 default : all
1282 1282 verbose : bool
1283 1283 Whether to return lengths only, or lists of ids for each element
1284 1284 """
1285 1285 engine_ids = self._build_targets(targets)[1]
1286 1286 content = dict(targets=engine_ids, verbose=verbose)
1287 1287 self.session.send(self._query_socket, "queue_request", content=content)
1288 1288 idents,msg = self.session.recv(self._query_socket, 0)
1289 1289 if self.debug:
1290 1290 pprint(msg)
1291 1291 content = msg['content']
1292 1292 status = content.pop('status')
1293 1293 if status != 'ok':
1294 1294 raise self._unwrap_exception(content)
1295 1295 content = rekey(content)
1296 1296 if isinstance(targets, int):
1297 1297 return content[targets]
1298 1298 else:
1299 1299 return content
1300 1300
1301 1301 @spin_first
1302 1302 def purge_results(self, jobs=[], targets=[]):
1303 1303 """Tell the Hub to forget results.
1304 1304
1305 1305 Individual results can be purged by msg_id, or the entire
1306 1306 history of specific targets can be purged.
1307 1307
1308 Use `purge_results('all')` to scrub everything from the Hub's db.
1309
1308 1310 Parameters
1309 1311 ----------
1310 1312
1311 1313 jobs : str or list of str or AsyncResult objects
1312 1314 the msg_ids whose results should be forgotten.
1313 1315 targets : int/str/list of ints/strs
1314 The targets, by uuid or int_id, whose entire history is to be purged.
1315 Use `targets='all'` to scrub everything from the Hub's memory.
1316 The targets, by int_id, whose entire history is to be purged.
1316 1317
1317 1318 default : None
1318 1319 """
1319 1320 if not targets and not jobs:
1320 1321 raise ValueError("Must specify at least one of `targets` and `jobs`")
1321 1322 if targets:
1322 1323 targets = self._build_targets(targets)[1]
1323 1324
1324 1325 # construct msg_ids from jobs
1325 msg_ids = []
1326 if isinstance(jobs, (basestring,AsyncResult)):
1327 jobs = [jobs]
1328 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1329 if bad_ids:
1330 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1331 for j in jobs:
1332 if isinstance(j, AsyncResult):
1333 msg_ids.extend(j.msg_ids)
1334 else:
1335 msg_ids.append(j)
1336
1337 content = dict(targets=targets, msg_ids=msg_ids)
1326 if jobs == 'all':
1327 msg_ids = jobs
1328 else:
1329 msg_ids = []
1330 if isinstance(jobs, (basestring,AsyncResult)):
1331 jobs = [jobs]
1332 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1333 if bad_ids:
1334 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1335 for j in jobs:
1336 if isinstance(j, AsyncResult):
1337 msg_ids.extend(j.msg_ids)
1338 else:
1339 msg_ids.append(j)
1340
1341 content = dict(engine_ids=targets, msg_ids=msg_ids)
1338 1342 self.session.send(self._query_socket, "purge_request", content=content)
1339 1343 idents, msg = self.session.recv(self._query_socket, 0)
1340 1344 if self.debug:
1341 1345 pprint(msg)
1342 1346 content = msg['content']
1343 1347 if content['status'] != 'ok':
1344 1348 raise self._unwrap_exception(content)
1345 1349
1346 1350 @spin_first
1347 1351 def hub_history(self):
1348 1352 """Get the Hub's history
1349 1353
1350 1354 Just like the Client, the Hub has a history, which is a list of msg_ids.
1351 1355 This will contain the history of all clients, and, depending on configuration,
1352 1356 may contain history across multiple cluster sessions.
1353 1357
1354 1358 Any msg_id returned here is a valid argument to `get_result`.
1355 1359
1356 1360 Returns
1357 1361 -------
1358 1362
1359 1363 msg_ids : list of strs
1360 1364 list of all msg_ids, ordered by task submission time.
1361 1365 """
1362 1366
1363 1367 self.session.send(self._query_socket, "history_request", content={})
1364 1368 idents, msg = self.session.recv(self._query_socket, 0)
1365 1369
1366 1370 if self.debug:
1367 1371 pprint(msg)
1368 1372 content = msg['content']
1369 1373 if content['status'] != 'ok':
1370 1374 raise self._unwrap_exception(content)
1371 1375 else:
1372 1376 return content['history']
1373 1377
1374 1378 @spin_first
1375 1379 def db_query(self, query, keys=None):
1376 1380 """Query the Hub's TaskRecord database
1377 1381
1378 1382 This will return a list of task record dicts that match `query`
1379 1383
1380 1384 Parameters
1381 1385 ----------
1382 1386
1383 1387 query : mongodb query dict
1384 1388 The search dict. See mongodb query docs for details.
1385 1389 keys : list of strs [optional]
1386 1390 The subset of keys to be returned. The default is to fetch everything but buffers.
1387 1391 'msg_id' will *always* be included.
1388 1392 """
1389 1393 if isinstance(keys, basestring):
1390 1394 keys = [keys]
1391 1395 content = dict(query=query, keys=keys)
1392 1396 self.session.send(self._query_socket, "db_request", content=content)
1393 1397 idents, msg = self.session.recv(self._query_socket, 0)
1394 1398 if self.debug:
1395 1399 pprint(msg)
1396 1400 content = msg['content']
1397 1401 if content['status'] != 'ok':
1398 1402 raise self._unwrap_exception(content)
1399 1403
1400 1404 records = content['records']
1401 1405
1402 1406 buffer_lens = content['buffer_lens']
1403 1407 result_buffer_lens = content['result_buffer_lens']
1404 1408 buffers = msg['buffers']
1405 1409 has_bufs = buffer_lens is not None
1406 1410 has_rbufs = result_buffer_lens is not None
1407 1411 for i,rec in enumerate(records):
1408 1412 # relink buffers
1409 1413 if has_bufs:
1410 1414 blen = buffer_lens[i]
1411 1415 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1412 1416 if has_rbufs:
1413 1417 blen = result_buffer_lens[i]
1414 1418 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1415 1419
1416 1420 return records
1417 1421
1418 1422 __all__ = [ 'Client' ]
@@ -1,1288 +1,1288
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20 from __future__ import print_function
21 21
22 22 import sys
23 23 import time
24 24 from datetime import datetime
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop
28 28 from zmq.eventloop.zmqstream import ZMQStream
29 29
30 30 # internal:
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
def empty_record():
    """Return an empty dict with all record keys."""
    # every field defaults to None except the output streams, which
    # accumulate text and therefore start as empty strings
    record = dict.fromkeys([
        'msg_id', 'header', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed',
        'resubmitted', 'result_header', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    ])
    record['stdout'] = ''
    record['stderr'] = ''
    return record
77 77
def init_record(msg):
    """Initialize a TaskRecord based on a request message."""
    header = msg['header']
    # fields populated directly from the request
    record = {
        'msg_id' : header['msg_id'],
        'header' : header,
        'content': msg['content'],
        'buffers': msg['buffers'],
        'submitted': header['date'],
        'stdout': '',
        'stderr': '',
    }
    # everything else is unknown until the task is routed/executed
    for key in ('client_uuid', 'engine_uuid', 'started', 'completed',
                'resubmitted', 'result_header', 'result_content',
                'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr'):
        record[key] = None
    return record
102 102
103 103
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    # engine ID assigned by the Hub at registration time
    id=Int(0)
    # zmq identity of the engine's queue (XREQ) socket
    queue=CBytes()
    # zmq identity of the engine's control socket
    control=CBytes()
    # zmq identity of the engine's registration socket
    registration=CBytes()
    # zmq identity of the engine's heartbeat socket
    heartbeat=CBytes()
    # msg_ids of requests currently pending on this engine
    pending=Set()
119 119
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub."""

    # port-pairs for monitoredqueues:
    hb = Tuple(Int,Int,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        # pick two free ports if none were configured
        return tuple(util.select_random_ports(2))

    mux = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Int(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Int(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived URL, kept in sync by _update_monitor_url()
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # a single `ip` setting fans out to all three interfaces
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # recompute the monitor URL from its transport/ip/port components
        self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # a single `transport` setting fans out to all three transports
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        # build the Hub and all of its sockets (no-op beyond init_hub)
        self.init_hub()

    def start(self):
        # begin pinging engines; the Hub itself is driven by the IOLoop
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """construct"""
        # interface templates, with a trailing %i placeholder for the port
        client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.XREP), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
        if self.client_ip != self.engine_ip:
            # engines connect on a different interface; bind that too
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat: PUB pings out, XREP collects pongs
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribes to all queue traffic
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        # brief pause to let the DB backend finish connecting
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # no configured scheme: fall back to the scheduler's default
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream: connects back to the task queue as a client,
        # using the session id as its zmq identity
        r = ZMQStream(ctx.socket(zmq.XREQ), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.session)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
300 300
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
    NotImplemented
    engine_info: dict of zmq connection information for engines to connect
    to the queues.
    client_info: dict of zmq connection information for engines to connect
    to the queues.
    """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict() # engine id -> uuid (zmq identity)
    by_ident=Dict() # uuid (zmq identity) -> engine id
    engines=Dict() # engine id -> EngineConnector
    clients=Dict() # connected clients
    hearts=Dict() # heart identity -> engine id
    pending=Set() # all pending msg_ids
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of every completed msg_id (union across engines)
    dead_engines=Set() # uuids of engines that have unregistered or died
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict() # heart -> (eid, queue_id, reg identity, purge callback)
    registration_timeout=Int() # ms to wait for a first heartbeat before purging
    _idcounter=Int(0) # monotonic counter backing the _next_id property

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
347 347
348 348
    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # wait at least 5s (or two heartbeat periods) before purging a registration
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # 'task' entry is a (scheme, url) pair; only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # monitor topic -> handler; control traffic is observed but not recorded
        self.monitor_handlers = { 'in' : self.save_queue_request,
                                'out': self.save_queue_result,
                                'intask': self.save_task_request,
                                'outtask': self.save_task_result,
                                'tracktask': self.save_task_destination,
                                'incontrol': _passer,
                                'outcontrol': _passer,
                                'iopub': self.save_iopub_message,
                                }

        # client msg_type -> handler for the query (XREP) channel
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
410 410
411 411 @property
412 412 def _next_id(self):
413 413 """gemerate a new ID.
414 414
415 415 No longer reuse old ids, just count from 0."""
416 416 newid = self._idcounter
417 417 self._idcounter += 1
418 418 return newid
419 419 # newid = 0
420 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 421 # # print newid, self.ids, self.incoming_registrations
422 422 # while newid in self.ids or newid in incoming:
423 423 # newid += 1
424 424 # return newid
425 425
426 426 #-----------------------------------------------------------------------------
427 427 # message validation
428 428 #-----------------------------------------------------------------------------
429 429
430 430 def _validate_targets(self, targets):
431 431 """turn any valid targets argument into a list of integer ids"""
432 432 if targets is None:
433 433 # default to all
434 434 targets = self.ids
435 435
436 436 if isinstance(targets, (int,str,unicode)):
437 437 # only one target specified
438 438 targets = [targets]
439 439 _targets = []
440 440 for t in targets:
441 441 # map raw identities to ids
442 442 if isinstance(t, (str,unicode)):
443 443 t = self.by_ident.get(t, t)
444 444 _targets.append(t)
445 445 targets = _targets
446 446 bad_targets = [ t for t in targets if t not in self.ids ]
447 447 if bad_targets:
448 448 raise IndexError("No Such Engine: %r"%bad_targets)
449 449 if not targets:
450 450 raise IndexError("No Engines Registered")
451 451 return targets
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # dispatch methods (1 per stream)
455 455 #-----------------------------------------------------------------------------
456 456
457 457
458 458 def dispatch_monitor_traffic(self, msg):
459 459 """all ME and Task queue messages come through here, as well as
460 460 IOPub traffic."""
461 461 self.log.debug("monitor traffic: %r"%msg[:2])
462 462 switch = msg[0]
463 463 try:
464 464 idents, msg = self.session.feed_identities(msg[1:])
465 465 except ValueError:
466 466 idents=[]
467 467 if not idents:
468 468 self.log.error("Bad Monitor Message: %r"%msg)
469 469 return
470 470 handler = self.monitor_handlers.get(switch, None)
471 471 if handler is not None:
472 472 handler(idents, msg)
473 473 else:
474 474 self.log.error("Invalid monitor topic: %r"%switch)
475 475
476 476
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r"%msg)
            return
        # first identity frame is the requesting client
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            # report the malformed message back to the client
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r"%msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %r requested %r"%(client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            # assert-in-try so wrap_exception can capture a real traceback
            assert handler is not None, "Bad Message Type: %r"%msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            handler(idents, msg)
511 511
512 512 def dispatch_db(self, msg):
513 513 """"""
514 514 raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538 eid = self.hearts.get(heart, None)
539 539 queue = self.engines[eid].queue
540 540 if eid is None:
541 541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 542 else:
543 543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
    def save_queue_request(self, idents, msg):
        """Record a request submitted on the MUX (direct) queue."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return
        # identity prefix is (destination engine, submitting client)
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered"%queue_id)
            self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        record['engine_uuid'] = queue_id
        record['client_uuid'] = client_id
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # merge with the existing record, warning on conflicts
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            # no prior record: create one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
591 591
    def save_queue_result(self, idents, msg):
        """Record the result of a request that ran on the MUX (direct) queue."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return

        # identity prefix is (requesting client, replying engine)
        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r"%(
                    queue_id,client_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # a reply with no parent cannot be matched to a request
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            # move from pending to completed bookkeeping
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r"%msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
640 640
641 641
642 642 #--------------------- Task Queue Traffic ------------------------------
643 643
    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r"%(
                    client_id, msg), exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        # tasks start life unassigned; save_task_destination assigns them
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

            # merge with the existing record, warning on conflicts
            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            # no prior record: create one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
694 694
    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::invalid task result message send to %r: %r"%(
                    client_id, msg), exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!"%msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        # the scheduler stamps the executing engine's uuid into the header
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                # engine may have unregistered already, in which case eid is None
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished"%msg_id)
743 743
744 744 def save_task_destination(self, idents, msg):
745 745 try:
746 746 msg = self.session.unpack_message(msg, content=True)
747 747 except Exception:
748 748 self.log.error("task::invalid task tracking message", exc_info=True)
749 749 return
750 750 content = msg['content']
751 751 # print (content)
752 752 msg_id = content['msg_id']
753 753 engine_uuid = content['engine_id']
754 754 eid = self.by_ident[engine_uuid]
755 755
756 756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
757 757 if msg_id in self.unassigned:
758 758 self.unassigned.remove(msg_id)
759 759 # else:
760 760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
761 761
762 762 self.tasks[eid].append(msg_id)
763 763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
764 764 try:
765 765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
766 766 except Exception:
767 767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
768 768
769 769
770 770 def mia_task_request(self, idents, msg):
771 771 raise NotImplementedError
772 772 client_id = idents[0]
773 773 # content = dict(mia=self.mia,status='ok')
774 774 # self.session.send('mia_reply', content=content, idents=client_id)
775 775
776 776
777 777 #--------------------- IOPub Traffic ------------------------------
778 778
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %r"%msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            # iopub can arrive before the request itself; start an empty record
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append to any stream output already recorded under this name
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
821 821
822 822
823 823
824 824 #-------------------------------------------------------------------------
825 825 # Registration requests
826 826 #-------------------------------------------------------------------------
827 827
828 828 def connection_request(self, client_id, msg):
829 829 """Reply with connection addresses for clients."""
830 830 self.log.info("client::client %r connected"%client_id)
831 831 content = dict(status='ok')
832 832 content.update(self.client_info)
833 833 jsonable = {}
834 834 for k,v in self.keytable.iteritems():
835 835 if v not in self.dead_engines:
836 836 jsonable[str(k)] = v
837 837 content['engines'] = jsonable
838 838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
839 839
    def register_engine(self, reg, msg):
        """Register a new engine.

        Replies with connection info, then waits for the engine's heart to
        start beating (or reuses an already-beating heart) before finishing
        registration.  Returns the engine's new id, or the id allocated for
        the failed attempt.
        """
        content = msg['content']
        try:
            queue = content['queue']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        """register a new engine, and create the socket(s) necessary"""
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            # raise-in-try so wrap_exception captures a real traceback
            try:
                raise KeyError("queue_id %r in use"%queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use"%queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use"%heart)
            except:
                self.log.error("heart_id %r in use"%heart, exc_info=True)
                content = error.wrap_exception()
        else:
            # also guard against collisions with registrations still in flight
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use"%heart)
                    except:
                        self.log.error("heart_id %r in use"%heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use"%queue)
                    except:
                        self.log.error("queue_id %r in use"%queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                # purge the half-finished registration if the heart never beats
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
        return eid
904 904
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave.

        The engine is only marked dead here; its outstanding messages are
        dealt with later by _handle_stranded_msgs, after a grace period.
        """
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)"%eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        # delay handling of stranded msgs, in case results are still in flight
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
931 931
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        # msg_ids still queued on this engine when it went away
        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            # raise-in-try so wrap_exception captures a real traceback
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
960 960
961 961
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
        if purge is not None:
            # cancel the stalled-registration purge callback
            purge.stop()
        control = queue
        # populate all the lookup tables for the new engine
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue)
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i"%eid)
987 987
988 988 def _purge_stalled_registration(self, heart):
989 989 if heart in self.incoming_registrations:
990 990 eid = self.incoming_registrations.pop(heart)[0]
991 991 self.log.info("registration::purging stalled registration: %i"%eid)
992 992 else:
993 993 pass
994 994
995 995 #-------------------------------------------------------------------------
996 996 # Client Requests
997 997 #-------------------------------------------------------------------------
998 998
999 999 def shutdown_request(self, client_id, msg):
1000 1000 """handle shutdown request."""
1001 1001 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1002 1002 # also notify other clients of shutdown
1003 1003 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1004 1004 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1005 1005 dc.start()
1006 1006
1007 1007 def _shutdown(self):
1008 1008 self.log.info("hub::hub shutting down.")
1009 1009 time.sleep(0.1)
1010 1010 sys.exit(0)
1011 1011
1012 1012
1013 1013 def check_load(self, client_id, msg):
1014 1014 content = msg['content']
1015 1015 try:
1016 1016 targets = content['targets']
1017 1017 targets = self._validate_targets(targets)
1018 1018 except:
1019 1019 content = error.wrap_exception()
1020 1020 self.session.send(self.query, "hub_error",
1021 1021 content=content, ident=client_id)
1022 1022 return
1023 1023
1024 1024 content = dict(status='ok')
1025 1025 # loads = {}
1026 1026 for t in targets:
1027 1027 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1028 1028 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1029 1029
1030 1030
1031 1031 def queue_status(self, client_id, msg):
1032 1032 """Return the Queue status of one or more targets.
1033 1033 if verbose: return the msg_ids
1034 1034 else: return len of each type.
1035 1035 keys: queue (pending MUX jobs)
1036 1036 tasks (pending Task jobs)
1037 1037 completed (finished jobs from both queues)"""
1038 1038 content = msg['content']
1039 1039 targets = content['targets']
1040 1040 try:
1041 1041 targets = self._validate_targets(targets)
1042 1042 except:
1043 1043 content = error.wrap_exception()
1044 1044 self.session.send(self.query, "hub_error",
1045 1045 content=content, ident=client_id)
1046 1046 return
1047 1047 verbose = content.get('verbose', False)
1048 1048 content = dict(status='ok')
1049 1049 for t in targets:
1050 1050 queue = self.queues[t]
1051 1051 completed = self.completed[t]
1052 1052 tasks = self.tasks[t]
1053 1053 if not verbose:
1054 1054 queue = len(queue)
1055 1055 completed = len(completed)
1056 1056 tasks = len(tasks)
1057 1057 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1058 1058 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1059 1059
1060 1060 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1061 1061
1062 1062 def purge_results(self, client_id, msg):
1063 1063 """Purge results from memory. This method is more valuable before we move
1064 1064 to a DB based message storage mechanism."""
1065 1065 content = msg['content']
1066 self.log.info("Dropping records with %s", content)
1066 1067 msg_ids = content.get('msg_ids', [])
1067 1068 reply = dict(status='ok')
1068 1069 if msg_ids == 'all':
1069 1070 try:
1070 1071 self.db.drop_matching_records(dict(completed={'$ne':None}))
1071 1072 except Exception:
1072 1073 reply = error.wrap_exception()
1073 1074 else:
1074 1075 pending = filter(lambda m: m in self.pending, msg_ids)
1075 1076 if pending:
1076 1077 try:
1077 1078 raise IndexError("msg pending: %r"%pending[0])
1078 1079 except:
1079 1080 reply = error.wrap_exception()
1080 1081 else:
1081 1082 try:
1082 1083 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1083 1084 except Exception:
1084 1085 reply = error.wrap_exception()
1085 1086
1086 1087 if reply['status'] == 'ok':
1087 1088 eids = content.get('engine_ids', [])
1088 1089 for eid in eids:
1089 1090 if eid not in self.engines:
1090 1091 try:
1091 1092 raise IndexError("No such engine: %i"%eid)
1092 1093 except:
1093 1094 reply = error.wrap_exception()
1094 1095 break
1095 msg_ids = self.completed.pop(eid)
1096 1096 uid = self.engines[eid].queue
1097 1097 try:
1098 1098 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1099 1099 except Exception:
1100 1100 reply = error.wrap_exception()
1101 1101 break
1102 1102
1103 1103 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1104 1104
    def resubmit_task(self, client_id, msg):
        """Resubmit one or more tasks.

        Looks the tasks up in the db, validates that all of them exist and
        none are still in flight, resets their records, and re-sends them
        to the task scheduler via the resubmit stream.
        """
        def finish(reply):
            # single exit point: send the reply to the requesting client
            self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

        content = msg['content']
        msg_ids = content['msg_ids']
        reply = dict(status='ok')
        try:
            records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
                'header', 'content', 'buffers'])
        except Exception:
            self.log.error('db::db error finding tasks to resubmit', exc_info=True)
            return finish(error.wrap_exception())

        # validate msg_ids
        found_ids = [ rec['msg_id'] for rec in records ]
        invalid_ids = filter(lambda m: m in self.pending, found_ids)
        if len(records) > len(msg_ids):
            try:
                raise RuntimeError("DB appears to be in an inconsistent state."
                    "More matching records were found than should exist")
            except Exception:
                return finish(error.wrap_exception())
        elif len(records) < len(msg_ids):
            missing = [ m for m in msg_ids if m not in found_ids ]
            try:
                raise KeyError("No such msg(s): %r"%missing)
            except KeyError:
                return finish(error.wrap_exception())
        elif invalid_ids:
            # refuse to resubmit tasks that are still pending
            msg_id = invalid_ids[0]
            try:
                raise ValueError("Task %r appears to be inflight"%(msg_id))
            except Exception:
                return finish(error.wrap_exception())

        # clear the existing records
        now = datetime.now()
        rec = empty_record()
        map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
        rec['resubmitted'] = now
        rec['queue'] = 'task'
        rec['client_uuid'] = client_id[0]
        try:
            for msg_id in msg_ids:
                self.all_completed.discard(msg_id)
                self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error('db::db error upating record', exc_info=True)
            reply = error.wrap_exception()
        else:
            # send the messages
            for rec in records:
                header = rec['header']
                # include resubmitted in header to prevent digest collision
                header['resubmitted'] = now
                msg = self.session.msg(header['msg_type'])
                msg['content'] = rec['content']
                msg['header'] = header
                msg['msg_id'] = rec['msg_id']
                self.session.send(self.resubmit, msg, buffers=rec['buffers'])

        finish(dict(status='ok'))
1169 1169
1170 1170
1171 1171 def _extract_record(self, rec):
1172 1172 """decompose a TaskRecord dict into subsection of reply for get_result"""
1173 1173 io_dict = {}
1174 1174 for key in 'pyin pyout pyerr stdout stderr'.split():
1175 1175 io_dict[key] = rec[key]
1176 1176 content = { 'result_content': rec['result_content'],
1177 1177 'header': rec['header'],
1178 1178 'result_header' : rec['result_header'],
1179 1179 'io' : io_dict,
1180 1180 }
1181 1181 if rec['result_buffers']:
1182 1182 buffers = map(str, rec['result_buffers'])
1183 1183 else:
1184 1184 buffers = []
1185 1185
1186 1186 return content, buffers
1187 1187
1188 1188 def get_results(self, client_id, msg):
1189 1189 """Get the result of 1 or more messages."""
1190 1190 content = msg['content']
1191 1191 msg_ids = sorted(set(content['msg_ids']))
1192 1192 statusonly = content.get('status_only', False)
1193 1193 pending = []
1194 1194 completed = []
1195 1195 content = dict(status='ok')
1196 1196 content['pending'] = pending
1197 1197 content['completed'] = completed
1198 1198 buffers = []
1199 1199 if not statusonly:
1200 1200 try:
1201 1201 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1202 1202 # turn match list into dict, for faster lookup
1203 1203 records = {}
1204 1204 for rec in matches:
1205 1205 records[rec['msg_id']] = rec
1206 1206 except Exception:
1207 1207 content = error.wrap_exception()
1208 1208 self.session.send(self.query, "result_reply", content=content,
1209 1209 parent=msg, ident=client_id)
1210 1210 return
1211 1211 else:
1212 1212 records = {}
1213 1213 for msg_id in msg_ids:
1214 1214 if msg_id in self.pending:
1215 1215 pending.append(msg_id)
1216 1216 elif msg_id in self.all_completed:
1217 1217 completed.append(msg_id)
1218 1218 if not statusonly:
1219 1219 c,bufs = self._extract_record(records[msg_id])
1220 1220 content[msg_id] = c
1221 1221 buffers.extend(bufs)
1222 1222 elif msg_id in records:
1223 1223 if rec['completed']:
1224 1224 completed.append(msg_id)
1225 1225 c,bufs = self._extract_record(records[msg_id])
1226 1226 content[msg_id] = c
1227 1227 buffers.extend(bufs)
1228 1228 else:
1229 1229 pending.append(msg_id)
1230 1230 else:
1231 1231 try:
1232 1232 raise KeyError('No such message: '+msg_id)
1233 1233 except:
1234 1234 content = error.wrap_exception()
1235 1235 break
1236 1236 self.session.send(self.query, "result_reply", content=content,
1237 1237 parent=msg, ident=client_id,
1238 1238 buffers=buffers)
1239 1239
1240 1240 def get_history(self, client_id, msg):
1241 1241 """Get a list of all msg_ids in our DB records"""
1242 1242 try:
1243 1243 msg_ids = self.db.get_history()
1244 1244 except Exception as e:
1245 1245 content = error.wrap_exception()
1246 1246 else:
1247 1247 content = dict(status='ok', history=msg_ids)
1248 1248
1249 1249 self.session.send(self.query, "history_reply", content=content,
1250 1250 parent=msg, ident=client_id)
1251 1251
1252 1252 def db_query(self, client_id, msg):
1253 1253 """Perform a raw query on the task record database."""
1254 1254 content = msg['content']
1255 1255 query = content.get('query', {})
1256 1256 keys = content.get('keys', None)
1257 1257 buffers = []
1258 1258 empty = list()
1259 1259 try:
1260 1260 records = self.db.find_records(query, keys)
1261 1261 except Exception as e:
1262 1262 content = error.wrap_exception()
1263 1263 else:
1264 1264 # extract buffers from reply content:
1265 1265 if keys is not None:
1266 1266 buffer_lens = [] if 'buffers' in keys else None
1267 1267 result_buffer_lens = [] if 'result_buffers' in keys else None
1268 1268 else:
1269 1269 buffer_lens = []
1270 1270 result_buffer_lens = []
1271 1271
1272 1272 for rec in records:
1273 1273 # buffers may be None, so double check
1274 1274 if buffer_lens is not None:
1275 1275 b = rec.pop('buffers', empty) or empty
1276 1276 buffer_lens.append(len(b))
1277 1277 buffers.extend(b)
1278 1278 if result_buffer_lens is not None:
1279 1279 rb = rec.pop('result_buffers', empty) or empty
1280 1280 result_buffer_lens.append(len(rb))
1281 1281 buffers.extend(rb)
1282 1282 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1283 1283 result_buffer_lens=result_buffer_lens)
1284 1284
1285 1285 self.session.send(self.query, "db_reply", content=content,
1286 1286 parent=msg, ident=client_id,
1287 1287 buffers=buffers)
1288 1288
@@ -1,249 +1,257
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import time
20 20 from datetime import datetime
21 21 from tempfile import mktemp
22 22
23 23 import zmq
24 24
25 25 from IPython.parallel.client import client as clientmod
26 26 from IPython.parallel import error
27 27 from IPython.parallel import AsyncResult, AsyncHubResult
28 28 from IPython.parallel import LoadBalancedView, DirectView
29 29
30 30 from clienttest import ClusterTestCase, segfault, wait, add_engines
31 31
def setup():
    # module-level test setup hook: ensure at least 4 engines are running
    # before any tests in this module execute
    add_engines(4)
34 34
class TestClient(ClusterTestCase):

    def test_ids(self):
        """client.ids grows when engines are added"""
        n = len(self.client.ids)
        self.add_engines(3)
        self.assertEquals(len(self.client.ids), n+3)

    def test_view_indexing(self):
        """test index access for views"""
        self.add_engines(2)
        targets = self.client._build_targets('all')[-1]
        v = self.client[:]
        self.assertEquals(v.targets, targets)
        t = self.client.ids[2]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        t = self.client.ids[2:4]
        v = self.client[t]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, t)
        v = self.client[::2]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[::2])
        v = self.client[1::3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[1::3])
        v = self.client[:-3]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[:-3])
        v = self.client[-1]
        self.assert_(isinstance(v, DirectView))
        self.assertEquals(v.targets, targets[-1])
        # non-int/slice indices are rejected
        self.assertRaises(TypeError, lambda : self.client[None])

    def test_lbview_targets(self):
        """test load_balanced_view targets"""
        v = self.client.load_balanced_view()
        self.assertEquals(v.targets, None)
        v = self.client.load_balanced_view(-1)
        self.assertEquals(v.targets, [self.client.ids[-1]])
        v = self.client.load_balanced_view('all')
        self.assertEquals(v.targets, self.client.ids)

    def test_targets(self):
        """test various valid targets arguments"""
        build = self.client._build_targets
        ids = self.client.ids
        idents,targets = build(None)
        self.assertEquals(ids, targets)

    def test_clear(self):
        """test clear behavior"""
        # self.add_engines(2)
        v = self.client[:]
        v.block=True
        v.push(dict(a=5))
        v.pull('a')
        id0 = self.client.ids[-1]
        # clearing one engine removes its namespace, leaves the others
        self.client.clear(targets=id0, block=True)
        a = self.client[:-1].get('a')
        self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
        self.client.clear(block=True)
        for i in self.client.ids:
            # print i
            self.assertRaisesRemote(NameError, self.client[i].get, 'a')

    def test_get_result(self):
        """test getting results from the Hub."""
        c = clientmod.Client(profile='iptest')
        # self.add_engines(1)
        t = c.ids[-1]
        ar = c[t].apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        # a different client must fetch via the Hub -> AsyncHubResult
        ahr = self.client.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # second fetch by the same client hits local results -> AsyncResult
        ar2 = self.client.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.close()

    def test_ids_list(self):
        """test client.ids"""
        # self.add_engines(2)
        ids = self.client.ids
        # .ids returns a copy, not the internal list
        self.assertEquals(ids, self.client._ids)
        self.assertFalse(ids is self.client._ids)
        ids.remove(ids[-1])
        self.assertNotEquals(ids, self.client._ids)

    def test_queue_status(self):
        """queue_status returns per-engine dicts plus an 'unassigned' count"""
        # self.addEngine(4)
        ids = self.client.ids
        id0 = ids[0]
        qs = self.client.queue_status(targets=id0)
        self.assertTrue(isinstance(qs, dict))
        self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
        allqs = self.client.queue_status()
        self.assertTrue(isinstance(allqs, dict))
        self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
        unassigned = allqs.pop('unassigned')
        for eid,qs in allqs.items():
            self.assertTrue(isinstance(qs, dict))
            self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])

    def test_shutdown(self):
        """shutting down an engine removes it from client.ids"""
        # self.addEngine(4)
        ids = self.client.ids
        id0 = ids[0]
        self.client.shutdown(id0, block=True)
        # wait for the Hub to notice the engine is gone
        while id0 in self.client.ids:
            time.sleep(0.1)
            self.client.spin()

        self.assertRaises(IndexError, lambda : self.client[id0])

    def test_result_status(self):
        pass
        # to be written

    def test_db_query_dt(self):
        """test db query by date"""
        hist = self.client.hub_history()
        middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
        tic = middle['submitted']
        before = self.client.db_query({'submitted' : {'$lt' : tic}})
        after = self.client.db_query({'submitted' : {'$gte' : tic}})
        # $lt and $gte partition the full history
        self.assertEquals(len(before)+len(after),len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.client.db_query({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_db_query_keys(self):
        """test extracting subset of record keys"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_db_query_msg_id(self):
        """ensure msg_id is always in db queries"""
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())
        found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
        for rec in found:
            self.assertTrue('msg_id' in rec.keys())

    def test_db_query_in(self):
        """test db query with '$in','$nin' operators"""
        hist = self.client.hub_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_hub_history(self):
        """hub_history is ordered by submission time, newest last"""
        hist = self.client.hub_history()
        recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
        recdict = {}
        for rec in recs:
            recdict[rec['msg_id']] = rec

        latest = datetime(1984,1,1)
        for msg_id in hist:
            rec = recdict[msg_id]
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        ar = self.client[-1].apply_async(lambda : 1)
        ar.get()
        # give the monitor time to notice the message
        time.sleep(0.25)
        self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)

    def test_resubmit(self):
        """a resubmitted task re-executes (random result differs)"""
        def f():
            import random
            return random.random()
        v = self.client.load_balanced_view()
        ar = v.apply_async(f)
        r1 = ar.get(1)
        ahr = self.client.resubmit(ar.msg_ids)
        r2 = ahr.get(1)
        self.assertFalse(r1 == r2)

    def test_resubmit_inflight(self):
        """ensure ValueError on resubmit of inflight task"""
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep,1)
        # give the message a chance to arrive
        time.sleep(0.2)
        self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
        ar.get(2)

    def test_resubmit_badkey(self):
        """ensure KeyError on resubmit of nonexistant task"""
        self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])

    def test_purge_results(self):
        """purging a single msg_id removes exactly that record"""
        # ensure there are some tasks
        for i in range(5):
            self.client[:].apply_sync(lambda : 1)
        hist = self.client.hub_history()
        self.client.purge_results(hist[-1])
        newhist = self.client.hub_history()
        self.assertEquals(len(newhist)+1,len(hist))

    def test_purge_all_results(self):
        """purge_results('all') empties the hub history"""
        self.client.purge_results('all')
        hist = self.client.hub_history()
        self.assertEquals(len(hist), 0)
249 257
General Comments 0
You need to be logged in to leave comments. Login now