Merge pull request #6146 from minrk/parallel-handle-status...
Thomas Kluyver
r17267:48531b99 merge
@@ -1,1870 +1,1874 @@
1 1 """A semi-synchronous Client for IPython parallel"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 from __future__ import print_function
7 7
8 8 import os
9 9 import json
10 10 import sys
11 11 from threading import Thread, Event
12 12 import time
13 13 import warnings
14 14 from datetime import datetime
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17
18 18 pjoin = os.path.join
19 19
20 20 import zmq
21 21
22 22 from IPython.config.configurable import MultipleInstanceError
23 23 from IPython.core.application import BaseIPythonApplication
24 24 from IPython.core.profiledir import ProfileDir, ProfileDirError
25 25
26 26 from IPython.utils.capture import RichOutput
27 27 from IPython.utils.coloransi import TermColors
28 28 from IPython.utils.jsonutil import rekey, extract_dates, parse_date
29 29 from IPython.utils.localinterfaces import localhost, is_local_ip
30 30 from IPython.utils.path import get_ipython_dir
31 31 from IPython.utils.py3compat import cast_bytes, string_types, xrange, iteritems
32 32 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
33 33 Dict, List, Bool, Set, Any)
34 34 from IPython.external.decorator import decorator
35 35
36 36 from IPython.parallel import Reference
37 37 from IPython.parallel import error
38 38 from IPython.parallel import util
39 39
40 40 from IPython.kernel.zmq.session import Session, Message
41 41 from IPython.kernel.zmq import serialize
42 42
43 43 from .asyncresult import AsyncResult, AsyncHubResult
44 44 from .view import DirectView, LoadBalancedView
45 45
46 46 #--------------------------------------------------------------------------
47 47 # Decorators for Client methods
48 48 #--------------------------------------------------------------------------
49 49
50 50 @decorator
51 51 def spin_first(f, self, *args, **kwargs):
52 52 """Call spin() to sync state prior to calling the method."""
53 53 self.spin()
54 54 return f(self, *args, **kwargs)
55 55
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Classes
59 59 #--------------------------------------------------------------------------
60 60
61 61
62 62 class ExecuteReply(RichOutput):
63 63 """wrapper for finished Execute results"""
64 64 def __init__(self, msg_id, content, metadata):
65 65 self.msg_id = msg_id
66 66 self._content = content
67 67 self.execution_count = content['execution_count']
68 68 self.metadata = metadata
69 69
70 70 # RichOutput overrides
71 71
72 72 @property
73 73 def source(self):
74 74 execute_result = self.metadata['execute_result']
75 75 if execute_result:
76 76 return execute_result.get('source', '')
77 77
78 78 @property
79 79 def data(self):
80 80 execute_result = self.metadata['execute_result']
81 81 if execute_result:
82 82 return execute_result.get('data', {})
83 83
84 84 @property
85 85 def _metadata(self):
86 86 execute_result = self.metadata['execute_result']
87 87 if execute_result:
88 88 return execute_result.get('metadata', {})
89 89
90 90 def display(self):
91 91 from IPython.display import publish_display_data
92 92 publish_display_data(self.data, self.metadata)
93 93
94 94 def _repr_mime_(self, mime):
95 95 if mime not in self.data:
96 96 return
97 97 data = self.data[mime]
98 98 if mime in self._metadata:
99 99 return data, self._metadata[mime]
100 100 else:
101 101 return data
102 102
103 103 def __getitem__(self, key):
104 104 return self.metadata[key]
105 105
106 106 def __getattr__(self, key):
107 107 if key not in self.metadata:
108 108 raise AttributeError(key)
109 109 return self.metadata[key]
110 110
111 111 def __repr__(self):
112 112 execute_result = self.metadata['execute_result'] or {'data':{}}
113 113 text_out = execute_result['data'].get('text/plain', '')
114 114 if len(text_out) > 32:
115 115 text_out = text_out[:29] + '...'
116 116
117 117 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
118 118
119 119 def _repr_pretty_(self, p, cycle):
120 120 execute_result = self.metadata['execute_result'] or {'data':{}}
121 121 text_out = execute_result['data'].get('text/plain', '')
122 122
123 123 if not text_out:
124 124 return
125 125
126 126 try:
127 127 ip = get_ipython()
128 128 except NameError:
129 129 colors = "NoColor"
130 130 else:
131 131 colors = ip.colors
132 132
133 133 if colors == "NoColor":
134 134 out = normal = ""
135 135 else:
136 136 out = TermColors.Red
137 137 normal = TermColors.Normal
138 138
139 139 if '\n' in text_out and not text_out.startswith('\n'):
140 140 # add newline for multiline reprs
141 141 text_out = '\n' + text_out
142 142
143 143 p.text(
144 144 out + u'Out[%i:%i]: ' % (
145 145 self.metadata['engine_id'], self.execution_count
146 146 ) + normal + text_out
147 147 )
148 148
149 149
150 150 class Metadata(dict):
151 151 """Subclass of dict for initializing metadata values.
152 152
153 153 Attribute access works on keys.
154 154
155 155 These objects have a strict set of keys - errors will raise if you try
156 156 to add new keys.
157 157 """
158 158 def __init__(self, *args, **kwargs):
159 159 dict.__init__(self)
160 160 md = {'msg_id' : None,
161 161 'submitted' : None,
162 162 'started' : None,
163 163 'completed' : None,
164 164 'received' : None,
165 165 'engine_uuid' : None,
166 166 'engine_id' : None,
167 167 'follow' : None,
168 168 'after' : None,
169 169 'status' : None,
170 170
171 171 'execute_input' : None,
172 172 'execute_result' : None,
173 173 'error' : None,
174 174 'stdout' : '',
175 175 'stderr' : '',
176 176 'outputs' : [],
177 177 'data': {},
178 178 'outputs_ready' : False,
179 179 }
180 180 self.update(md)
181 181 self.update(dict(*args, **kwargs))
182 182
183 183 def __getattr__(self, key):
184 184 """getattr aliased to getitem"""
185 185 if key in self:
186 186 return self[key]
187 187 else:
188 188 raise AttributeError(key)
189 189
190 190 def __setattr__(self, key, value):
191 191 """setattr aliased to setitem, with strict"""
192 192 if key in self:
193 193 self[key] = value
194 194 else:
195 195 raise AttributeError(key)
196 196
197 197 def __setitem__(self, key, value):
198 198 """strict static key enforcement"""
199 199 if key in self:
200 200 dict.__setitem__(self, key, value)
201 201 else:
202 202 raise KeyError(key)
203 203
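A sketch of the strict-key behavior described above::

    md = Metadata(status='ok')
    md.stdout              # '' -- attribute access aliases item access
    md['engine_id'] = 0    # assigning an existing key works
    md['bogus'] = 1        # raises KeyError: new keys are rejected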
204 204
205 205 class Client(HasTraits):
206 206 """A semi-synchronous client to the IPython ZMQ cluster
207 207
208 208 Parameters
209 209 ----------
210 210
211 211 url_file : str/unicode; path to ipcontroller-client.json
212 212 This JSON file should contain all the information needed to connect to a cluster,
213 213 and is likely the only argument needed. It provides the connection
214 214 information for the Hub's registration; if it is given, no further
215 215 configuration is usually necessary.
216 216 [Default: use profile]
217 217 profile : str
218 218 The name of the Cluster profile to be used to find connector information.
219 219 If run from an IPython application, the default profile will be the same
220 220 as the running application, otherwise it will be 'default'.
221 221 cluster_id : str
222 222 String id added to runtime files, to prevent name collisions when using
223 223 multiple clusters with a single profile simultaneously.
224 224 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
225 225 Since this is text inserted into filenames, typical recommendations apply:
226 226 Simple character strings are ideal, and spaces are not recommended (but
227 227 should generally work)
228 228 context : zmq.Context
229 229 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 230 debug : bool
231 231 flag for lots of message printing for debug purposes
232 232 timeout : int/float
233 233 time (in seconds) to wait for connection replies from the Hub
234 234 [Default: 10]
235 235
236 236 #-------------- session related args ----------------
237 237
238 238 config : Config object
239 239 If specified, this will be relayed to the Session for configuration
240 240 username : str
241 241 set username for the session object
242 242
243 243 #-------------- ssh related args ----------------
244 244 # These are args for configuring the ssh tunnel to be used
245 245 # credentials are used to forward connections over ssh to the Controller
246 246 # Note that the ip given in `addr` needs to be relative to sshserver
247 247 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
248 248 # and set sshserver as the same machine the Controller is on. However,
249 249 # the only requirement is that sshserver is able to see the Controller
250 250 # (i.e. is within the same trusted network).
251 251
252 252 sshserver : str
253 253 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
254 254 If keyfile or password is specified, and this is not, it will default to
255 255 the ip given in addr.
256 256 sshkey : str; path to ssh private key file
257 257 This specifies a key to be used in ssh login, default None.
258 258 Regular default ssh keys will be used without specifying this argument.
259 259 password : str
260 260 Your ssh password to sshserver. Note that if this is left None,
261 261 you will be prompted for it if passwordless key based login is unavailable.
262 262 paramiko : bool
263 263 flag for whether to use paramiko instead of shell ssh for tunneling.
264 264 [default: True on win32, False otherwise]
265 265
266 266
267 267 Attributes
268 268 ----------
269 269
270 270 ids : list of int engine IDs
271 271 requesting the ids attribute always synchronizes
272 272 the registration state. To request ids without synchronization,
273 273 use the semi-private _ids attribute.
274 274
275 275 history : list of msg_ids
276 276 a list of msg_ids, keeping track of all the execution
277 277 messages you have submitted in order.
278 278
279 279 outstanding : set of msg_ids
280 280 a set of msg_ids that have been submitted, but whose
281 281 results have not yet been received.
282 282
283 283 results : dict
284 284 a dict of all our results, keyed by msg_id
285 285
286 286 block : bool
287 287 determines default behavior when block not specified
288 288 in execution methods
289 289
290 290 Methods
291 291 -------
292 292
293 293 spin
294 294 flushes incoming results and registration state changes
295 295 control methods also spin, and requesting `ids` ensures the state is up to date
296 296
297 297 wait
298 298 wait on one or more msg_ids
299 299
300 300 execution methods
301 301 apply
302 302 legacy: execute, run
303 303
304 304 data movement
305 305 push, pull, scatter, gather
306 306
307 307 query methods
308 308 queue_status, get_result, purge, result_status
309 309
310 310 control methods
311 311 abort, shutdown
312 312
313 313 """
314 314
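A minimal connection sketch, assuming a cluster has already been started
(e.g. with ``ipcluster start``) so the default profile's
ipcontroller-client.json exists; the ssh host below is hypothetical::

    from IPython.parallel import Client

    rc = Client()                               # read url_file from the default profile
    # rc = Client(sshserver='user@server.tld')  # or tunnel to a remote controller
    print(rc.ids)                               # engine IDs currently registered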
315 315
316 316 block = Bool(False)
317 317 outstanding = Set()
318 318 results = Instance('collections.defaultdict', (dict,))
319 319 metadata = Instance('collections.defaultdict', (Metadata,))
320 320 history = List()
321 321 debug = Bool(False)
322 322 _spin_thread = Any()
323 323 _stop_spinning = Any()
324 324
325 325 profile=Unicode()
326 326 def _profile_default(self):
327 327 if BaseIPythonApplication.initialized():
328 328 # an IPython app *might* be running, try to get its profile
329 329 try:
330 330 return BaseIPythonApplication.instance().profile
331 331 except (AttributeError, MultipleInstanceError):
332 332 # could be a *different* subclass of config.Application,
333 333 # which would raise one of these two errors.
334 334 return u'default'
335 335 else:
336 336 return u'default'
337 337
338 338
339 339 _outstanding_dict = Instance('collections.defaultdict', (set,))
340 340 _ids = List()
341 341 _connected=Bool(False)
342 342 _ssh=Bool(False)
343 343 _context = Instance('zmq.Context')
344 344 _config = Dict()
345 345 _engines=Instance(util.ReverseDict, (), {})
346 346 # _hub_socket=Instance('zmq.Socket')
347 347 _query_socket=Instance('zmq.Socket')
348 348 _control_socket=Instance('zmq.Socket')
349 349 _iopub_socket=Instance('zmq.Socket')
350 350 _notification_socket=Instance('zmq.Socket')
351 351 _mux_socket=Instance('zmq.Socket')
352 352 _task_socket=Instance('zmq.Socket')
353 353 _task_scheme=Unicode()
354 354 _closed = False
355 355 _ignored_control_replies=Integer(0)
356 356 _ignored_hub_replies=Integer(0)
357 357
358 358 def __new__(self, *args, **kw):
359 359 # don't raise on positional args
360 360 return HasTraits.__new__(self, **kw)
361 361
362 362 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
363 363 context=None, debug=False,
364 364 sshserver=None, sshkey=None, password=None, paramiko=None,
365 365 timeout=10, cluster_id=None, **extra_args
366 366 ):
367 367 if profile:
368 368 super(Client, self).__init__(debug=debug, profile=profile)
369 369 else:
370 370 super(Client, self).__init__(debug=debug)
371 371 if context is None:
372 372 context = zmq.Context.instance()
373 373 self._context = context
374 374 self._stop_spinning = Event()
375 375
376 376 if 'url_or_file' in extra_args:
377 377 url_file = extra_args['url_or_file']
378 378 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
379 379
380 380 if url_file and util.is_url(url_file):
381 381 raise ValueError("single urls cannot be specified, url-files must be used.")
382 382
383 383 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
384 384
385 385 if self._cd is not None:
386 386 if url_file is None:
387 387 if not cluster_id:
388 388 client_json = 'ipcontroller-client.json'
389 389 else:
390 390 client_json = 'ipcontroller-%s-client.json' % cluster_id
391 391 url_file = pjoin(self._cd.security_dir, client_json)
392 392 if url_file is None:
393 393 raise ValueError(
394 394 "I can't find enough information to connect to a hub!"
395 395 " Please specify at least one of url_file or profile."
396 396 )
397 397
398 398 with open(url_file) as f:
399 399 cfg = json.load(f)
400 400
401 401 self._task_scheme = cfg['task_scheme']
402 402
403 403 # sync defaults from args, json:
404 404 if sshserver:
405 405 cfg['ssh'] = sshserver
406 406
407 407 location = cfg.setdefault('location', None)
408 408
409 409 proto,addr = cfg['interface'].split('://')
410 410 addr = util.disambiguate_ip_address(addr, location)
411 411 cfg['interface'] = "%s://%s" % (proto, addr)
412 412
413 413 # turn interface,port into full urls:
414 414 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
415 415 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
416 416
417 417 url = cfg['registration']
418 418
419 419 if location is not None and addr == localhost():
420 420 # location specified, and connection is expected to be local
421 421 if not is_local_ip(location) and not sshserver:
422 422 # load ssh from JSON *only* if the controller is not on
423 423 # this machine
424 424 sshserver=cfg['ssh']
425 425 if not is_local_ip(location) and not sshserver:
426 426 # warn if no ssh specified, but SSH is probably needed
427 427 # This is only a warning, because the most likely cause
428 428 # is a local Controller on a laptop whose IP is dynamic
429 429 warnings.warn("""
430 430 Controller appears to be listening on localhost, but not on this machine.
431 431 If this is true, you should specify Client(...,sshserver='you@%s')
432 432 or instruct your controller to listen on an external IP."""%location,
433 433 RuntimeWarning)
434 434 elif not sshserver:
435 435 # otherwise sync with cfg
436 436 sshserver = cfg['ssh']
437 437
438 438 self._config = cfg
439 439
440 440 self._ssh = bool(sshserver or sshkey or password)
441 441 if self._ssh and sshserver is None:
442 442 # default to ssh via localhost
443 443 sshserver = addr
444 444 if self._ssh and password is None:
445 445 from zmq.ssh import tunnel
446 446 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
447 447 password=False
448 448 else:
449 449 password = getpass("SSH Password for %s: "%sshserver)
450 450 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
451 451
452 452 # configure and construct the session
453 453 try:
454 454 extra_args['packer'] = cfg['pack']
455 455 extra_args['unpacker'] = cfg['unpack']
456 456 extra_args['key'] = cast_bytes(cfg['key'])
457 457 extra_args['signature_scheme'] = cfg['signature_scheme']
458 458 except KeyError as exc:
459 459 msg = '\n'.join([
460 460 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
461 461 "If you are reusing connection files, remove them and start ipcontroller again."
462 462 ])
463 463 raise ValueError(msg.format(exc.args[0]))
464 464
465 465 self.session = Session(**extra_args)
466 466
467 467 self._query_socket = self._context.socket(zmq.DEALER)
468 468
469 469 if self._ssh:
470 470 from zmq.ssh import tunnel
471 471 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
472 472 else:
473 473 self._query_socket.connect(cfg['registration'])
474 474
475 475 self.session.debug = self.debug
476 476
477 477 self._notification_handlers = {'registration_notification' : self._register_engine,
478 478 'unregistration_notification' : self._unregister_engine,
479 479 'shutdown_notification' : lambda msg: self.close(),
480 480 }
481 481 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
482 482 'apply_reply' : self._handle_apply_reply}
483 483
484 484 try:
485 485 self._connect(sshserver, ssh_kwargs, timeout)
486 486 except:
487 487 self.close(linger=0)
488 488 raise
489 489
490 490 # last step: setup magics, if we are in IPython:
491 491
492 492 try:
493 493 ip = get_ipython()
494 494 except NameError:
495 495 return
496 496 else:
497 497 if 'px' not in ip.magics_manager.magics:
498 498 # in IPython but we are the first Client.
499 499 # activate a default view for parallel magics.
500 500 self.activate()
501 501
502 502 def __del__(self):
503 503 """cleanup sockets, but _not_ context."""
504 504 self.close()
505 505
506 506 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
507 507 if ipython_dir is None:
508 508 ipython_dir = get_ipython_dir()
509 509 if profile_dir is not None:
510 510 try:
511 511 self._cd = ProfileDir.find_profile_dir(profile_dir)
512 512 return
513 513 except ProfileDirError:
514 514 pass
515 515 elif profile is not None:
516 516 try:
517 517 self._cd = ProfileDir.find_profile_dir_by_name(
518 518 ipython_dir, profile)
519 519 return
520 520 except ProfileDirError:
521 521 pass
522 522 self._cd = None
523 523
524 524 def _update_engines(self, engines):
525 525 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
526 526 for k,v in iteritems(engines):
527 527 eid = int(k)
528 528 if eid not in self._engines:
529 529 self._ids.append(eid)
530 530 self._engines[eid] = v
531 531 self._ids = sorted(self._ids)
532 532 if sorted(self._engines.keys()) != list(range(len(self._engines))) and \
533 533 self._task_scheme == 'pure' and self._task_socket:
534 534 self._stop_scheduling_tasks()
535 535
536 536 def _stop_scheduling_tasks(self):
537 537 """Stop scheduling tasks because an engine has been unregistered
538 538 from a pure ZMQ scheduler.
539 539 """
540 540 self._task_socket.close()
541 541 self._task_socket = None
542 542 msg = "An engine has been unregistered, and we are using pure " +\
543 543 "ZMQ task scheduling. Task farming will be disabled."
544 544 if self.outstanding:
545 545 msg += " If you were running tasks when this happened, " +\
546 546 "some `outstanding` msg_ids may never resolve."
547 547 warnings.warn(msg, RuntimeWarning)
548 548
549 549 def _build_targets(self, targets):
550 550 """Turn valid target IDs or 'all' into two lists:
551 551 (int_ids, uuids).
552 552 """
553 553 if not self._ids:
554 554 # flush notification socket if no engines yet, just in case
555 555 if not self.ids:
556 556 raise error.NoEnginesRegistered("Can't build targets without any engines")
557 557
558 558 if targets is None:
559 559 targets = self._ids
560 560 elif isinstance(targets, string_types):
561 561 if targets.lower() == 'all':
562 562 targets = self._ids
563 563 else:
564 564 raise TypeError("%r not valid str target, must be 'all'"%(targets))
565 565 elif isinstance(targets, int):
566 566 if targets < 0:
567 567 targets = self.ids[targets]
568 568 if targets not in self._ids:
569 569 raise IndexError("No such engine: %i"%targets)
570 570 targets = [targets]
571 571
572 572 if isinstance(targets, slice):
573 573 indices = list(range(len(self._ids))[targets])
574 574 ids = self.ids
575 575 targets = [ ids[i] for i in indices ]
576 576
577 577 if not isinstance(targets, (tuple, list, xrange)):
578 578 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
579 579
580 580 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
581 581
582 582 def _connect(self, sshserver, ssh_kwargs, timeout):
583 583 """setup all our socket connections to the cluster. This is called from
584 584 __init__."""
585 585
586 586 # Maybe allow reconnecting?
587 587 if self._connected:
588 588 return
589 589 self._connected=True
590 590
591 591 def connect_socket(s, url):
592 592 if self._ssh:
593 593 from zmq.ssh import tunnel
594 594 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
595 595 else:
596 596 return s.connect(url)
597 597
598 598 self.session.send(self._query_socket, 'connection_request')
599 599 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
600 600 poller = zmq.Poller()
601 601 poller.register(self._query_socket, zmq.POLLIN)
602 602 # poll expects milliseconds, timeout is seconds
603 603 evts = poller.poll(timeout*1000)
604 604 if not evts:
605 605 raise error.TimeoutError("Hub connection request timed out")
606 606 idents,msg = self.session.recv(self._query_socket,mode=0)
607 607 if self.debug:
608 608 pprint(msg)
609 609 content = msg['content']
610 610 # self._config['registration'] = dict(content)
611 611 cfg = self._config
612 612 if content['status'] == 'ok':
613 613 self._mux_socket = self._context.socket(zmq.DEALER)
614 614 connect_socket(self._mux_socket, cfg['mux'])
615 615
616 616 self._task_socket = self._context.socket(zmq.DEALER)
617 617 connect_socket(self._task_socket, cfg['task'])
618 618
619 619 self._notification_socket = self._context.socket(zmq.SUB)
620 620 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
621 621 connect_socket(self._notification_socket, cfg['notification'])
622 622
623 623 self._control_socket = self._context.socket(zmq.DEALER)
624 624 connect_socket(self._control_socket, cfg['control'])
625 625
626 626 self._iopub_socket = self._context.socket(zmq.SUB)
627 627 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
628 628 connect_socket(self._iopub_socket, cfg['iopub'])
629 629
630 630 self._update_engines(dict(content['engines']))
631 631 else:
632 632 self._connected = False
633 633 raise Exception("Failed to connect!")
634 634
635 635 #--------------------------------------------------------------------------
636 636 # handlers and callbacks for incoming messages
637 637 #--------------------------------------------------------------------------
638 638
639 639 def _unwrap_exception(self, content):
640 640 """unwrap exception, and remap engine_id to int."""
641 641 e = error.unwrap_exception(content)
642 642 # print e.traceback
643 643 if e.engine_info:
644 644 e_uuid = e.engine_info['engine_uuid']
645 645 eid = self._engines[e_uuid]
646 646 e.engine_info['engine_id'] = eid
647 647 return e
648 648
649 649 def _extract_metadata(self, msg):
650 650 header = msg['header']
651 651 parent = msg['parent_header']
652 652 msg_meta = msg['metadata']
653 653 content = msg['content']
654 654 md = {'msg_id' : parent['msg_id'],
655 655 'received' : datetime.now(),
656 656 'engine_uuid' : msg_meta.get('engine', None),
657 657 'follow' : msg_meta.get('follow', []),
658 658 'after' : msg_meta.get('after', []),
659 659 'status' : content['status'],
660 660 }
661 661
662 662 if md['engine_uuid'] is not None:
663 663 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
664 664
665 665 if 'date' in parent:
666 666 md['submitted'] = parent['date']
667 667 if 'started' in msg_meta:
668 668 md['started'] = parse_date(msg_meta['started'])
669 669 if 'date' in header:
670 670 md['completed'] = header['date']
671 671 return md
672 672
673 673 def _register_engine(self, msg):
674 674 """Register a new engine, and update our connection info."""
675 675 content = msg['content']
676 676 eid = content['id']
677 677 d = {eid : content['uuid']}
678 678 self._update_engines(d)
679 679
680 680 def _unregister_engine(self, msg):
681 681 """Unregister an engine that has died."""
682 682 content = msg['content']
683 683 eid = int(content['id'])
684 684 if eid in self._ids:
685 685 self._ids.remove(eid)
686 686 uuid = self._engines.pop(eid)
687 687
688 688 self._handle_stranded_msgs(eid, uuid)
689 689
690 690 if self._task_socket and self._task_scheme == 'pure':
691 691 self._stop_scheduling_tasks()
692 692
693 693 def _handle_stranded_msgs(self, eid, uuid):
694 694 """Handle messages known to be on an engine when the engine unregisters.
695 695
696 696 It is possible that this will fire prematurely - that is, an engine will
697 697 go down after completing a result, and the client will be notified
698 698 of the unregistration and later receive the successful result.
699 699 """
700 700
701 701 outstanding = self._outstanding_dict[uuid]
702 702
703 703 for msg_id in list(outstanding):
704 704 if msg_id in self.results:
705 705 # we already have this result
706 706 continue
707 707 try:
708 708 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
709 709 except:
710 710 content = error.wrap_exception()
711 711 # build a fake message:
712 712 msg = self.session.msg('apply_reply', content=content)
713 713 msg['parent_header']['msg_id'] = msg_id
714 714 msg['metadata']['engine'] = uuid
715 715 self._handle_apply_reply(msg)
716 716
717 717 def _handle_execute_reply(self, msg):
718 718 """Save the reply to an execute_request into our results.
719 719
720 720 execute replies are handled much like apply replies, but are wrapped in ExecuteReply objects.
721 721 """
722 722
723 723 parent = msg['parent_header']
724 724 msg_id = parent['msg_id']
725 725 if msg_id not in self.outstanding:
726 726 if msg_id in self.history:
727 727 print("got stale result: %s"%msg_id)
728 728 else:
729 729 print("got unknown result: %s"%msg_id)
730 730 else:
731 731 self.outstanding.remove(msg_id)
732 732
733 733 content = msg['content']
734 734 header = msg['header']
735 735
736 736 # construct metadata:
737 737 md = self.metadata[msg_id]
738 738 md.update(self._extract_metadata(msg))
739 739 # is this redundant?
740 740 self.metadata[msg_id] = md
741 741
742 742 e_outstanding = self._outstanding_dict[md['engine_uuid']]
743 743 if msg_id in e_outstanding:
744 744 e_outstanding.remove(msg_id)
745 745
746 746 # construct result:
747 747 if content['status'] == 'ok':
748 748 self.results[msg_id] = ExecuteReply(msg_id, content, md)
749 749 elif content['status'] == 'aborted':
750 750 self.results[msg_id] = error.TaskAborted(msg_id)
751 751 elif content['status'] == 'resubmitted':
752 752 # TODO: handle resubmission
753 753 pass
754 754 else:
755 755 self.results[msg_id] = self._unwrap_exception(content)
756 756
757 757 def _handle_apply_reply(self, msg):
758 758 """Save the reply to an apply_request into our results."""
759 759 parent = msg['parent_header']
760 760 msg_id = parent['msg_id']
761 761 if msg_id not in self.outstanding:
762 762 if msg_id in self.history:
763 763 print("got stale result: %s"%msg_id)
764 764 print(self.results[msg_id])
765 765 print(msg)
766 766 else:
767 767 print("got unknown result: %s"%msg_id)
768 768 else:
769 769 self.outstanding.remove(msg_id)
770 770 content = msg['content']
771 771 header = msg['header']
772 772
773 773 # construct metadata:
774 774 md = self.metadata[msg_id]
775 775 md.update(self._extract_metadata(msg))
776 776 # is this redundant?
777 777 self.metadata[msg_id] = md
778 778
779 779 e_outstanding = self._outstanding_dict[md['engine_uuid']]
780 780 if msg_id in e_outstanding:
781 781 e_outstanding.remove(msg_id)
782 782
783 783 # construct result:
784 784 if content['status'] == 'ok':
785 785 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
786 786 elif content['status'] == 'aborted':
787 787 self.results[msg_id] = error.TaskAborted(msg_id)
788 788 elif content['status'] == 'resubmitted':
789 789 # TODO: handle resubmission
790 790 pass
791 791 else:
792 792 self.results[msg_id] = self._unwrap_exception(content)
793 793
794 794 def _flush_notifications(self):
795 795 """Flush notifications of engine registrations waiting
796 796 in ZMQ queue."""
797 797 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
798 798 while msg is not None:
799 799 if self.debug:
800 800 pprint(msg)
801 801 msg_type = msg['header']['msg_type']
802 802 handler = self._notification_handlers.get(msg_type, None)
803 803 if handler is None:
804 804 raise Exception("Unhandled message type: %s" % msg_type)
805 805 else:
806 806 handler(msg)
807 807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 808
809 809 def _flush_results(self, sock):
810 810 """Flush task or queue results waiting in ZMQ queue."""
811 811 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
812 812 while msg is not None:
813 813 if self.debug:
814 814 pprint(msg)
815 815 msg_type = msg['header']['msg_type']
816 816 handler = self._queue_handlers.get(msg_type, None)
817 817 if handler is None:
818 818 raise Exception("Unhandled message type: %s" % msg_type)
819 819 else:
820 820 handler(msg)
821 821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 822
823 823 def _flush_control(self, sock):
824 824 """Flush replies from the control channel waiting
825 825 in the ZMQ queue.
826 826
827 827 Currently: ignore them."""
828 828 if self._ignored_control_replies <= 0:
829 829 return
830 830 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 831 while msg is not None:
832 832 self._ignored_control_replies -= 1
833 833 if self.debug:
834 834 pprint(msg)
835 835 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
836 836
837 837 def _flush_ignored_control(self):
838 838 """flush ignored control replies"""
839 839 while self._ignored_control_replies > 0:
840 840 self.session.recv(self._control_socket)
841 841 self._ignored_control_replies -= 1
842 842
843 843 def _flush_ignored_hub_replies(self):
844 844 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
845 845 while msg is not None:
846 846 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
847 847
848 848 def _flush_iopub(self, sock):
849 849 """Flush replies from the iopub channel waiting
850 850 in the ZMQ queue.
851 851 """
852 852 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
853 853 while msg is not None:
854 854 if self.debug:
855 855 pprint(msg)
856 856 parent = msg['parent_header']
857 # ignore IOPub messages with no parent.
858 # Caused by print statements or warnings from before the first execution.
859 if not parent:
857 if not parent or parent['session'] != self.session.session:
858 # ignore IOPub messages not from here
860 859 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
861 860 continue
862 861 msg_id = parent['msg_id']
863 862 content = msg['content']
864 863 header = msg['header']
865 864 msg_type = msg['header']['msg_type']
865
866 if msg_type == 'status' and msg_id not in self.metadata:
867 # ignore status messages if they aren't mine
868 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
869 continue
866 870
867 871 # init metadata:
868 872 md = self.metadata[msg_id]
869 873
870 874 if msg_type == 'stream':
871 875 name = content['name']
872 876 s = md[name] or ''
873 877 md[name] = s + content['data']
874 878 elif msg_type == 'error':
875 879 md.update({'error' : self._unwrap_exception(content)})
876 880 elif msg_type == 'execute_input':
877 881 md.update({'execute_input' : content['code']})
878 882 elif msg_type == 'display_data':
879 883 md['outputs'].append(content)
880 884 elif msg_type == 'execute_result':
881 885 md['execute_result'] = content
882 886 elif msg_type == 'data_message':
883 887 data, remainder = serialize.unserialize_object(msg['buffers'])
884 888 md['data'].update(data)
885 889 elif msg_type == 'status':
886 890 # idle message comes after all outputs
887 891 if content['execution_state'] == 'idle':
888 892 md['outputs_ready'] = True
889 893 else:
890 894 # unhandled msg_type (status, etc.)
891 895 pass
892 896
893 897 # redundant?
894 898 self.metadata[msg_id] = md
895 899
896 900 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
897 901
898 902 #--------------------------------------------------------------------------
899 903 # len, getitem
900 904 #--------------------------------------------------------------------------
901 905
902 906 def __len__(self):
903 907 """len(client) returns # of engines."""
904 908 return len(self.ids)
905 909
906 910 def __getitem__(self, key):
907 911 """index access returns DirectView multiplexer objects
908 912
909 913 Must be int, slice, or list/tuple/xrange of ints"""
910 914 if not isinstance(key, (int, slice, tuple, list, xrange)):
911 915 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
912 916 else:
913 917 return self.direct_view(key)
914 918
915 919 def __iter__(self):
916 920 """Since we define getitem, Client is iterable
917 921
918 922 but without also defining __iter__, iteration would only work correctly
919 923 if engine IDs start at zero and are contiguous.
920 924 """
921 925 for eid in self.ids:
922 926 yield self.direct_view(eid)
923 927
924 928 #--------------------------------------------------------------------------
925 929 # Begin public methods
926 930 #--------------------------------------------------------------------------
927 931
928 932 @property
929 933 def ids(self):
930 934 """Always up-to-date ids property."""
931 935 self._flush_notifications()
932 936 # always copy:
933 937 return list(self._ids)
934 938
935 939 def activate(self, targets='all', suffix=''):
936 940 """Create a DirectView and register it with IPython magics
937 941
938 942 Defines the magics `%px, %autopx, %pxresult, %%px`
939 943
940 944 Parameters
941 945 ----------
942 946
943 947 targets: int, list of ints, or 'all'
944 948 The engines on which the view's magics will run
945 949 suffix: str [default: '']
946 950 The suffix, if any, for the magics. This allows you to have
947 951 multiple views associated with parallel magics at the same time.
948 952
949 953 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
950 954 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
951 955 on engine 0.
952 956 """
953 957 view = self.direct_view(targets)
954 958 view.block = True
955 959 view.activate(suffix)
956 960 return view
957 961
958 962 def close(self, linger=None):
959 963 """Close my zmq Sockets
960 964
961 965 If `linger`, set the zmq LINGER socket option,
962 966 which allows discarding of messages.
963 967 """
964 968 if self._closed:
965 969 return
966 970 self.stop_spin_thread()
967 971 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
968 972 for name in snames:
969 973 socket = getattr(self, name)
970 974 if socket is not None and not socket.closed:
971 975 if linger is not None:
972 976 socket.close(linger=linger)
973 977 else:
974 978 socket.close()
975 979 self._closed = True
976 980
977 981 def _spin_every(self, interval=1):
978 982 """target func for use in spin_thread"""
979 983 while True:
980 984 if self._stop_spinning.is_set():
981 985 return
982 986 time.sleep(interval)
983 987 self.spin()
984 988
985 989 def spin_thread(self, interval=1):
986 990 """call Client.spin() in a background thread on some regular interval
987 991
988 992 This helps ensure that messages don't pile up too much in the zmq queue
989 993 while you are working on other things, or just leaving an idle terminal.
990 994
991 995 It also helps bound the lag in the `received` timestamp
992 996 on AsyncResult objects, which is used for timings.
993 997
994 998 Parameters
995 999 ----------
996 1000
997 1001 interval : float, optional
998 1002 The interval on which to spin the client in the background thread
999 1003 (simply passed to time.sleep).
1000 1004
1001 1005 Notes
1002 1006 -----
1003 1007
1004 1008 For precision timing, you may want to use this method to put a bound
1005 1009 on the jitter (in seconds) in `received` timestamps used
1006 1010 in AsyncResult.wall_time.
1007 1011
1008 1012 """
1009 1013 if self._spin_thread is not None:
1010 1014 self.stop_spin_thread()
1011 1015 self._stop_spinning.clear()
1012 1016 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1013 1017 self._spin_thread.daemon = True
1014 1018 self._spin_thread.start()
1015 1019
1016 1020 def stop_spin_thread(self):
1017 1021 """stop background spin_thread, if any"""
1018 1022 if self._spin_thread is not None:
1019 1023 self._stop_spinning.set()
1020 1024 self._spin_thread.join()
1021 1025 self._spin_thread = None
1022 1026
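A usage sketch, bounding `received`-timestamp jitter to roughly 0.1s while
other work happens::

    rc.spin_thread(interval=0.1)   # flush incoming messages in the background
    # ... submit tasks, compute locally ...
    rc.stop_spin_thread()          # stop the background thread when done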
1023 1027 def spin(self):
1024 1028 """Flush any registration notifications and execution results
1025 1029 waiting in the ZMQ queue.
1026 1030 """
1027 1031 if self._notification_socket:
1028 1032 self._flush_notifications()
1029 1033 if self._iopub_socket:
1030 1034 self._flush_iopub(self._iopub_socket)
1031 1035 if self._mux_socket:
1032 1036 self._flush_results(self._mux_socket)
1033 1037 if self._task_socket:
1034 1038 self._flush_results(self._task_socket)
1035 1039 if self._control_socket:
1036 1040 self._flush_control(self._control_socket)
1037 1041 if self._query_socket:
1038 1042 self._flush_ignored_hub_replies()
1039 1043
1040 1044 def wait(self, jobs=None, timeout=-1):
1041 1045 """waits on one or more `jobs`, for up to `timeout` seconds.
1042 1046
1043 1047 Parameters
1044 1048 ----------
1045 1049
1046 1050 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1047 1051 ints are indices to self.history
1048 1052 strs are msg_ids
1049 1053 default: wait on all outstanding messages
1050 1054 timeout : float
1051 1055 a time in seconds, after which to give up.
1052 1056 default is -1, which means no timeout
1053 1057
1054 1058 Returns
1055 1059 -------
1056 1060
1057 1061 True : when all msg_ids are done
1058 1062 False : timeout reached, some msg_ids still outstanding
1059 1063 """
1060 1064 tic = time.time()
1061 1065 if jobs is None:
1062 1066 theids = self.outstanding
1063 1067 else:
1064 1068 if isinstance(jobs, string_types + (int, AsyncResult)):
1065 1069 jobs = [jobs]
1066 1070 theids = set()
1067 1071 for job in jobs:
1068 1072 if isinstance(job, int):
1069 1073 # index access
1070 1074 job = self.history[job]
1071 1075 elif isinstance(job, AsyncResult):
1072 1076 theids.update(job.msg_ids)
1073 1077 continue
1074 1078 theids.add(job)
1075 1079 if not theids.intersection(self.outstanding):
1076 1080 return True
1077 1081 self.spin()
1078 1082 while theids.intersection(self.outstanding):
1079 1083 if timeout >= 0 and ( time.time()-tic ) > timeout:
1080 1084 break
1081 1085 time.sleep(1e-3)
1082 1086 self.spin()
1083 1087 return len(theids.intersection(self.outstanding)) == 0
1084 1088
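A usage sketch::

    done = rc.wait(timeout=5.0)    # block up to 5 seconds for all outstanding jobs
    if not done:
        print(len(rc.outstanding), "jobs still pending")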
1085 1089 #--------------------------------------------------------------------------
1086 1090 # Control methods
1087 1091 #--------------------------------------------------------------------------
1088 1092
1089 1093 @spin_first
1090 1094 def clear(self, targets=None, block=None):
1091 1095 """Clear the namespace in target(s)."""
1092 1096 block = self.block if block is None else block
1093 1097 targets = self._build_targets(targets)[0]
1094 1098 for t in targets:
1095 1099 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1096 1100 error = False
1097 1101 if block:
1098 1102 self._flush_ignored_control()
1099 1103 for i in range(len(targets)):
1100 1104 idents,msg = self.session.recv(self._control_socket,0)
1101 1105 if self.debug:
1102 1106 pprint(msg)
1103 1107 if msg['content']['status'] != 'ok':
1104 1108 error = self._unwrap_exception(msg['content'])
1105 1109 else:
1106 1110 self._ignored_control_replies += len(targets)
1107 1111 if error:
1108 1112 raise error
1109 1113
1110 1114
1111 1115 @spin_first
1112 1116 def abort(self, jobs=None, targets=None, block=None):
1113 1117 """Abort specific jobs from the execution queues of target(s).
1114 1118
1115 1119 This is a mechanism to prevent jobs that have already been submitted
1116 1120 from executing.
1117 1121
1118 1122 Parameters
1119 1123 ----------
1120 1124
1121 1125 jobs : msg_id, list of msg_ids, or AsyncResult
1122 1126 The jobs to be aborted
1123 1127
1124 1128 If unspecified/None: abort all outstanding jobs.
1125 1129
1126 1130 """
1127 1131 block = self.block if block is None else block
1128 1132 jobs = jobs if jobs is not None else list(self.outstanding)
1129 1133 targets = self._build_targets(targets)[0]
1130 1134
1131 1135 msg_ids = []
1132 1136 if isinstance(jobs, string_types + (AsyncResult,)):
1133 1137 jobs = [jobs]
1134 1138 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1135 1139 if bad_ids:
1136 1140 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1137 1141 for j in jobs:
1138 1142 if isinstance(j, AsyncResult):
1139 1143 msg_ids.extend(j.msg_ids)
1140 1144 else:
1141 1145 msg_ids.append(j)
1142 1146 content = dict(msg_ids=msg_ids)
1143 1147 for t in targets:
1144 1148 self.session.send(self._control_socket, 'abort_request',
1145 1149 content=content, ident=t)
1146 1150 error = False
1147 1151 if block:
1148 1152 self._flush_ignored_control()
1149 1153 for i in range(len(targets)):
1150 1154 idents,msg = self.session.recv(self._control_socket,0)
1151 1155 if self.debug:
1152 1156 pprint(msg)
1153 1157 if msg['content']['status'] != 'ok':
1154 1158 error = self._unwrap_exception(msg['content'])
1155 1159 else:
1156 1160 self._ignored_control_replies += len(targets)
1157 1161 if error:
1158 1162 raise error
1159 1163
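A sketch, assuming ``rc`` is this client and ``v`` is a view obtained from it::

    import time
    ar = v.apply_async(time.sleep, 60)   # a long-running task
    rc.abort(ar)                         # drop it from the queue if it has not started
    rc.abort()                           # or abort all outstanding jobs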
1160 1164 @spin_first
1161 1165 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1162 1166 """Terminates one or more engine processes, optionally including the hub.
1163 1167
1164 1168 Parameters
1165 1169 ----------
1166 1170
1167 1171 targets: list of ints or 'all' [default: all]
1168 1172 Which engines to shutdown.
1169 1173 hub: bool [default: False]
1170 1174 Whether to include the Hub. hub=True implies targets='all'.
1171 1175 block: bool [default: self.block]
1172 1176 Whether to wait for clean shutdown replies or not.
1173 1177 restart: bool [default: False]
1174 1178 NOT IMPLEMENTED
1175 1179 whether to restart engines after shutting them down.
1176 1180 """
1177 1181 from IPython.parallel.error import NoEnginesRegistered
1178 1182 if restart:
1179 1183 raise NotImplementedError("Engine restart is not yet implemented")
1180 1184
1181 1185 block = self.block if block is None else block
1182 1186 if hub:
1183 1187 targets = 'all'
1184 1188 try:
1185 1189 targets = self._build_targets(targets)[0]
1186 1190 except NoEnginesRegistered:
1187 1191 targets = []
1188 1192 for t in targets:
1189 1193 self.session.send(self._control_socket, 'shutdown_request',
1190 1194 content={'restart':restart},ident=t)
1191 1195 error = False
1192 1196 if block or hub:
1193 1197 self._flush_ignored_control()
1194 1198 for i in range(len(targets)):
1195 1199 idents,msg = self.session.recv(self._control_socket, 0)
1196 1200 if self.debug:
1197 1201 pprint(msg)
1198 1202 if msg['content']['status'] != 'ok':
1199 1203 error = self._unwrap_exception(msg['content'])
1200 1204 else:
1201 1205 self._ignored_control_replies += len(targets)
1202 1206
1203 1207 if hub:
1204 1208 time.sleep(0.25)
1205 1209 self.session.send(self._query_socket, 'shutdown_request')
1206 1210 idents,msg = self.session.recv(self._query_socket, 0)
1207 1211 if self.debug:
1208 1212 pprint(msg)
1209 1213 if msg['content']['status'] != 'ok':
1210 1214 error = self._unwrap_exception(msg['content'])
1211 1215
1212 1216 if error:
1213 1217 raise error
1214 1218
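A sketch::

    rc.shutdown()            # ask all engines to exit
    # rc.shutdown(hub=True)  # alternatively, stop the Hub too (implies targets='all')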
1215 1219 #--------------------------------------------------------------------------
1216 1220 # Execution related methods
1217 1221 #--------------------------------------------------------------------------
1218 1222
1219 1223 def _maybe_raise(self, result):
1220 1224 """wrapper for maybe raising an exception if apply failed."""
1221 1225 if isinstance(result, error.RemoteError):
1222 1226 raise result
1223 1227
1224 1228 return result
1225 1229
1226 1230 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1227 1231 ident=None):
1228 1232 """construct and send an apply message via a socket.
1229 1233
1230 1234 This is the principal method with which all engine execution is performed by views.
1231 1235 """
1232 1236
1233 1237 if self._closed:
1234 1238 raise RuntimeError("Client cannot be used after its sockets have been closed")
1235 1239
1236 1240 # defaults:
1237 1241 args = args if args is not None else []
1238 1242 kwargs = kwargs if kwargs is not None else {}
1239 1243 metadata = metadata if metadata is not None else {}
1240 1244
1241 1245 # validate arguments
1242 1246 if not callable(f) and not isinstance(f, Reference):
1243 1247 raise TypeError("f must be callable, not %s"%type(f))
1244 1248 if not isinstance(args, (tuple, list)):
1245 1249 raise TypeError("args must be tuple or list, not %s"%type(args))
1246 1250 if not isinstance(kwargs, dict):
1247 1251 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1248 1252 if not isinstance(metadata, dict):
1249 1253 raise TypeError("metadata must be dict, not %s"%type(metadata))
1250 1254
1251 1255 bufs = serialize.pack_apply_message(f, args, kwargs,
1252 1256 buffer_threshold=self.session.buffer_threshold,
1253 1257 item_threshold=self.session.item_threshold,
1254 1258 )
1255 1259
1256 1260 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1257 1261 metadata=metadata, track=track)
1258 1262
1259 1263 msg_id = msg['header']['msg_id']
1260 1264 self.outstanding.add(msg_id)
1261 1265 if ident:
1262 1266 # possibly routed to a specific engine
1263 1267 if isinstance(ident, list):
1264 1268 ident = ident[-1]
1265 1269 if ident in self._engines.values():
1266 1270 # save for later, in case of engine death
1267 1271 self._outstanding_dict[ident].add(msg_id)
1268 1272 self.history.append(msg_id)
1269 1273 self.metadata[msg_id]['submitted'] = datetime.now()
1270 1274
1271 1275 return msg
1272 1276
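A sketch of what a view does with this method; ``_task_socket`` is a private
attribute, so this is illustrative rather than a public API::

    def double(x):
        return 2 * x

    msg = rc.send_apply_request(rc._task_socket, double, args=(21,))
    ar = rc.get_result(msg['header']['msg_id'])   # AsyncResult for the pending reply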
1273 1277 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1274 1278 """construct and send an execute request via a socket.
1275 1279
1276 1280 """
1277 1281
1278 1282 if self._closed:
1279 1283 raise RuntimeError("Client cannot be used after its sockets have been closed")
1280 1284
1281 1285 # defaults:
1282 1286 metadata = metadata if metadata is not None else {}
1283 1287
1284 1288 # validate arguments
1285 1289 if not isinstance(code, string_types):
1286 1290 raise TypeError("code must be text, not %s" % type(code))
1287 1291 if not isinstance(metadata, dict):
1288 1292 raise TypeError("metadata must be dict, not %s" % type(metadata))
1289 1293
1290 1294 content = dict(code=code, silent=bool(silent), user_expressions={})
1291 1295
1292 1296
1293 1297 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1294 1298 metadata=metadata)
1295 1299
1296 1300 msg_id = msg['header']['msg_id']
1297 1301 self.outstanding.add(msg_id)
1298 1302 if ident:
1299 1303 # possibly routed to a specific engine
1300 1304 if isinstance(ident, list):
1301 1305 ident = ident[-1]
1302 1306 if ident in self._engines.values():
1303 1307 # save for later, in case of engine death
1304 1308 self._outstanding_dict[ident].add(msg_id)
1305 1309 self.history.append(msg_id)
1306 1310 self.metadata[msg_id]['submitted'] = datetime.now()
1307 1311
1308 1312 return msg
1309 1313
1310 1314 #--------------------------------------------------------------------------
1311 1315 # construct a View object
1312 1316 #--------------------------------------------------------------------------
1313 1317
1314 1318 def load_balanced_view(self, targets=None):
1315 1319 """construct a DirectView object.
1316 1320
1317 1321 If no arguments are specified, create a LoadBalancedView
1318 1322 using all engines.
1319 1323
1320 1324 Parameters
1321 1325 ----------
1322 1326
1323 1327 targets: list,slice,int,etc. [default: use all engines]
1324 1328 The subset of engines across which to load-balance
1325 1329 """
1326 1330 if targets == 'all':
1327 1331 targets = None
1328 1332 if targets is not None:
1329 1333 targets = self._build_targets(targets)[1]
1330 1334 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1331 1335
1332 1336 def direct_view(self, targets='all'):
1333 1337 """construct a DirectView object.
1334 1338
1335 1339 If no targets are specified, create a DirectView using all engines.
1336 1340
1337 1341 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1338 1342 evaluate the target engines at each execution, whereas rc[:] will connect to
1339 1343 all *current* engines, and that list will not change.
1340 1344
1341 1345 That is, 'all' will always use all engines, whereas rc[:] will not use
1342 1346 engines added after the DirectView is constructed.
1343 1347
1344 1348 Parameters
1345 1349 ----------
1346 1350
1347 1351 targets: list,slice,int,etc. [default: use all engines]
1348 1352 The engines to use for the View
1349 1353 """
1350 1354 single = isinstance(targets, int)
1351 1355 # allow 'all' to be lazily evaluated at each execution
1352 1356 if targets != 'all':
1353 1357 targets = self._build_targets(targets)[1]
1354 1358 if single:
1355 1359 targets = targets[0]
1356 1360 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1357 1361
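How the view constructors differ in practice (sketch)::

    dv_all = rc.direct_view()         # 'all': re-evaluates the engine list at each execution
    dv_now = rc[:]                    # fixed to the engines registered right now
    e0 = rc[0]                        # a single-engine DirectView
    lbv = rc.load_balanced_view()     # load-balance over all engines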
1358 1362 #--------------------------------------------------------------------------
1359 1363 # Query methods
1360 1364 #--------------------------------------------------------------------------
1361 1365
1362 1366 @spin_first
1363 1367 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
1364 1368 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1365 1369
1366 1370 If the client already has the results, no request to the Hub will be made.
1367 1371
1368 1372 This is a convenient way to construct AsyncResult objects, which are wrappers
1369 1373 that include metadata about execution, and allow for awaiting results that
1370 1374 were not submitted by this Client.
1371 1375
1372 1376 It can also be a convenient way to retrieve the metadata associated with
1373 1377 blocking execution, since it always retrieves the full metadata for each request.
1374 1378
1375 1379 Examples
1376 1380 --------
1377 1381 ::
1378 1382
1379 1383 In [10]: ar = client.get_result(-1)  # most recent request, as an AsyncResult
1380 1384
1381 1385 Parameters
1382 1386 ----------
1383 1387
1384 1388 indices_or_msg_ids : integer history index, str msg_id, or list of either
1385 1389 The history indices or msg_ids of the results to be retrieved
1386 1390
1387 1391 block : bool
1388 1392 Whether to wait for the result to be done
1389 1393 owner : bool [default: True]
1390 1394 Whether this AsyncResult should own the result.
1391 1395 If so, calling `ar.get()` will remove data from the
1392 1396 client's result and metadata cache.
1393 1397 There should only be one owner of any given msg_id.
1394 1398
1395 1399 Returns
1396 1400 -------
1397 1401
1398 1402 AsyncResult
1399 1403 A single AsyncResult object will always be returned.
1400 1404
1401 1405 AsyncHubResult
1402 1406 A subclass of AsyncResult that retrieves results from the Hub
1403 1407
1404 1408 """
1405 1409 block = self.block if block is None else block
1406 1410 if indices_or_msg_ids is None:
1407 1411 indices_or_msg_ids = -1
1408 1412
1409 1413 single_result = False
1410 1414 if not isinstance(indices_or_msg_ids, (list,tuple)):
1411 1415 indices_or_msg_ids = [indices_or_msg_ids]
1412 1416 single_result = True
1413 1417
1414 1418 theids = []
1415 1419 for id in indices_or_msg_ids:
1416 1420 if isinstance(id, int):
1417 1421 id = self.history[id]
1418 1422 if not isinstance(id, string_types):
1419 1423 raise TypeError("indices must be str or int, not %r"%id)
1420 1424 theids.append(id)
1421 1425
1422 1426 local_ids = [msg_id for msg_id in theids if (msg_id in self.outstanding or msg_id in self.results)]
1423 1427 remote_ids = [msg_id for msg_id in theids if msg_id not in local_ids]
1424 1428
1425 1429 # given a single msg_id initially, get_result should return the result itself,
1426 1430 # not a length-one list
1427 1431 if single_result:
1428 1432 theids = theids[0]
1429 1433
1430 1434 if remote_ids:
1431 1435 ar = AsyncHubResult(self, msg_ids=theids, owner=owner)
1432 1436 else:
1433 1437 ar = AsyncResult(self, msg_ids=theids, owner=owner)
1434 1438
1435 1439 if block:
1436 1440 ar.wait()
1437 1441
1438 1442 return ar
1439 1443
1440 1444 @spin_first
1441 1445 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1442 1446 """Resubmit one or more tasks.
1443 1447
1444 1448 in-flight tasks may not be resubmitted.
1445 1449
1446 1450 Parameters
1447 1451 ----------
1448 1452
1449 1453 indices_or_msg_ids : integer history index, str msg_id, or list of either
1450 1454 The history indices or msg_ids of the tasks to be resubmitted
1451 1455
1452 1456 block : bool
1453 1457 Whether to wait for the result to be done
1454 1458
1455 1459 Returns
1456 1460 -------
1457 1461
1458 1462 AsyncHubResult
1459 1463 A subclass of AsyncResult that retrieves results from the Hub
1460 1464
1461 1465 """
1462 1466 block = self.block if block is None else block
1463 1467 if indices_or_msg_ids is None:
1464 1468 indices_or_msg_ids = -1
1465 1469
1466 1470 if not isinstance(indices_or_msg_ids, (list,tuple)):
1467 1471 indices_or_msg_ids = [indices_or_msg_ids]
1468 1472
1469 1473 theids = []
1470 1474 for id in indices_or_msg_ids:
1471 1475 if isinstance(id, int):
1472 1476 id = self.history[id]
1473 1477 if not isinstance(id, string_types):
1474 1478 raise TypeError("indices must be str or int, not %r"%id)
1475 1479 theids.append(id)
1476 1480
1477 1481 content = dict(msg_ids = theids)
1478 1482
1479 1483 self.session.send(self._query_socket, 'resubmit_request', content)
1480 1484
1481 1485 zmq.select([self._query_socket], [], [])
1482 1486 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1483 1487 if self.debug:
1484 1488 pprint(msg)
1485 1489 content = msg['content']
1486 1490 if content['status'] != 'ok':
1487 1491 raise self._unwrap_exception(content)
1488 1492 mapping = content['resubmitted']
1489 1493 new_ids = [ mapping[msg_id] for msg_id in theids ]
1490 1494
1491 1495 ar = AsyncHubResult(self, msg_ids=new_ids)
1492 1496
1493 1497 if block:
1494 1498 ar.wait()
1495 1499
1496 1500 return ar
1497 1501
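A sketch; only tasks that are no longer in flight can be resubmitted::

    ar = rc.resubmit()   # defaults to the most recent msg_id in self.history
    ar.wait()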
1498 1502 @spin_first
1499 1503 def result_status(self, msg_ids, status_only=True):
1500 1504 """Check on the status of the result(s) of the apply request with `msg_ids`.
1501 1505
1502 1506 If status_only is False, then the actual results will be retrieved, else
1503 1507 only the status of the results will be checked.
1504 1508
1505 1509 Parameters
1506 1510 ----------
1507 1511
1508 1512 msg_ids : list of msg_ids
1509 1513 if int:
1510 1514 Passed as index to self.history for convenience.
1511 1515 status_only : bool (default: True)
1512 1516 if False:
1513 1517 Retrieve the actual results of completed tasks.
1514 1518
1515 1519 Returns
1516 1520 -------
1517 1521
1518 1522 results : dict
1519 1523 There will always be the keys 'pending' and 'completed', which will
1520 1524 be lists of msg_ids that are incomplete or complete. If `status_only`
1521 1525 is False, then completed results will be keyed by their `msg_id`.
1522 1526 """
1523 1527 if not isinstance(msg_ids, (list,tuple)):
1524 1528 msg_ids = [msg_ids]
1525 1529
1526 1530 theids = []
1527 1531 for msg_id in msg_ids:
1528 1532 if isinstance(msg_id, int):
1529 1533 msg_id = self.history[msg_id]
1530 1534 if not isinstance(msg_id, string_types):
1531 1535 raise TypeError("msg_ids must be str, not %r"%msg_id)
1532 1536 theids.append(msg_id)
1533 1537
1534 1538 completed = []
1535 1539 local_results = {}
1536 1540
1537 1541 # comment this block out to temporarily disable local shortcut:
1538 1542 for msg_id in theids:
1539 1543 if msg_id in self.results:
1540 1544 completed.append(msg_id)
1541 1545 local_results[msg_id] = self.results[msg_id]
1542 1546 theids.remove(msg_id)
1543 1547
1544 1548 if theids: # some not locally cached
1545 1549 content = dict(msg_ids=theids, status_only=status_only)
1546 1550 msg = self.session.send(self._query_socket, "result_request", content=content)
1547 1551 zmq.select([self._query_socket], [], [])
1548 1552 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1549 1553 if self.debug:
1550 1554 pprint(msg)
1551 1555 content = msg['content']
1552 1556 if content['status'] != 'ok':
1553 1557 raise self._unwrap_exception(content)
1554 1558 buffers = msg['buffers']
1555 1559 else:
1556 1560 content = dict(completed=[],pending=[])
1557 1561
1558 1562 content['completed'].extend(completed)
1559 1563
1560 1564 if status_only:
1561 1565 return content
1562 1566
1563 1567 failures = []
1564 1568 # load cached results into result:
1565 1569 content.update(local_results)
1566 1570
1567 1571 # update cache with results:
1568 1572 for msg_id in sorted(theids):
1569 1573 if msg_id in content['completed']:
1570 1574 rec = content[msg_id]
1571 1575 parent = extract_dates(rec['header'])
1572 1576 header = extract_dates(rec['result_header'])
1573 1577 rcontent = rec['result_content']
1574 1578 iodict = rec['io']
1575 1579 if isinstance(rcontent, str):
1576 1580 rcontent = self.session.unpack(rcontent)
1577 1581
1578 1582 md = self.metadata[msg_id]
1579 1583 md_msg = dict(
1580 1584 content=rcontent,
1581 1585 parent_header=parent,
1582 1586 header=header,
1583 1587 metadata=rec['result_metadata'],
1584 1588 )
1585 1589 md.update(self._extract_metadata(md_msg))
1586 1590 if rec.get('received'):
1587 1591 md['received'] = parse_date(rec['received'])
1588 1592 md.update(iodict)
1589 1593
1590 1594 if rcontent['status'] == 'ok':
1591 1595 if header['msg_type'] == 'apply_reply':
1592 1596 res,buffers = serialize.unserialize_object(buffers)
1593 1597 elif header['msg_type'] == 'execute_reply':
1594 1598 res = ExecuteReply(msg_id, rcontent, md)
1595 1599 else:
1596 1600 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1597 1601 else:
1598 1602 res = self._unwrap_exception(rcontent)
1599 1603 failures.append(res)
1600 1604
1601 1605 self.results[msg_id] = res
1602 1606 content[msg_id] = res
1603 1607
1604 1608 if len(theids) == 1 and failures:
1605 1609 raise failures[0]
1606 1610
1607 1611 error.collect_exceptions(failures, "result_status")
1608 1612 return content
1609 1613
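A sketch of polling task state through the Hub without fetching results (continuing the illustrative `rc`/`ar2` names from above):

    status = rc.result_status(ar2.msg_ids, status_only=True)
    print(status['pending'], status['completed'])
    # with status_only=False, each completed result is additionally keyed by its msg_id:
    results = rc.result_status(ar2.msg_ids, status_only=False)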
1610 1614 @spin_first
1611 1615 def queue_status(self, targets='all', verbose=False):
1612 1616 """Fetch the status of engine queues.
1613 1617
1614 1618 Parameters
1615 1619 ----------
1616 1620
1617 1621 targets : int/str/list of ints/strs
1618 1622 the engines whose states are to be queried.
1619 1623 default : all
1620 1624 verbose : bool
1621 1625 Whether to return lengths only, or lists of ids for each element
1622 1626 """
1623 1627 if targets == 'all':
1624 1628 # allow 'all' to be evaluated on the engine
1625 1629 engine_ids = None
1626 1630 else:
1627 1631 engine_ids = self._build_targets(targets)[1]
1628 1632 content = dict(targets=engine_ids, verbose=verbose)
1629 1633 self.session.send(self._query_socket, "queue_request", content=content)
1630 1634 idents,msg = self.session.recv(self._query_socket, 0)
1631 1635 if self.debug:
1632 1636 pprint(msg)
1633 1637 content = msg['content']
1634 1638 status = content.pop('status')
1635 1639 if status != 'ok':
1636 1640 raise self._unwrap_exception(content)
1637 1641 content = rekey(content)
1638 1642 if isinstance(targets, int):
1639 1643 return content[targets]
1640 1644 else:
1641 1645 return content
1642 1646
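A sketch of queue introspection (the counts shown are illustrative):

    rc.queue_status()                 # e.g. {0: {'queue': 0, 'completed': 3, 'tasks': 1}, ..., 'unassigned': 0}
    rc.queue_status(0, verbose=True)  # one engine's dict, with lists of msg_ids instead of counts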
1643 1647 def _build_msgids_from_target(self, targets=None):
1644 1648 """Build a list of msg_ids from the list of engine targets"""
1645 1649 if not targets: # needed as _build_targets otherwise uses all engines
1646 1650 return []
1647 1651 target_ids = self._build_targets(targets)[0]
1648 1652 return [md_id for md_id in self.metadata if self.metadata[md_id]["engine_uuid"] in target_ids]
1649 1653
1650 1654 def _build_msgids_from_jobs(self, jobs=None):
1651 1655 """Build a list of msg_ids from "jobs" """
1652 1656 if not jobs:
1653 1657 return []
1654 1658 msg_ids = []
1655 1659 if isinstance(jobs, string_types + (AsyncResult,)):
1656 1660 jobs = [jobs]
1657 1661 bad_ids = [obj for obj in jobs if not isinstance(obj, string_types + (AsyncResult,))]
1658 1662 if bad_ids:
1659 1663 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1660 1664 for j in jobs:
1661 1665 if isinstance(j, AsyncResult):
1662 1666 msg_ids.extend(j.msg_ids)
1663 1667 else:
1664 1668 msg_ids.append(j)
1665 1669 return msg_ids
1666 1670
1667 1671 def purge_local_results(self, jobs=[], targets=[]):
1668 1672 """Clears the client caches of results and their metadata.
1669 1673
1670 1674 Individual results can be purged by msg_id, or the entire
1671 1675 history of specific targets can be purged.
1672 1676
1673 1677 Use `purge_local_results('all')` to scrub everything from the Client's
1674 1678 results and metadata caches.
1675 1679
1676 1680 After this call all `AsyncResults` are invalid and should be discarded.
1677 1681
1678 1682 If you must "reget" the results, you can still do so by using
1679 1683 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1680 1684 redownload the results from the hub if they are still available
1681 1685 (i.e. `client.purge_hub_results(...)` has not been called).
1682 1686
1683 1687 Parameters
1684 1688 ----------
1685 1689
1686 1690 jobs : str or list of str or AsyncResult objects
1687 1691 the msg_ids whose results should be purged.
1688 1692 targets : int/list of ints
1689 1693 The engines, by integer ID, whose entire result histories are to be purged.
1690 1694
1691 1695 Raises
1692 1696 ------
1693 1697
1694 1698 RuntimeError : if any of the tasks to be purged are still outstanding.
1695 1699
1696 1700 """
1697 1701 if not targets and not jobs:
1698 1702 raise ValueError("Must specify at least one of `targets` and `jobs`")
1699 1703
1700 1704 if jobs == 'all':
1701 1705 if self.outstanding:
1702 1706 raise RuntimeError("Can't purge outstanding tasks: %s" % self.outstanding)
1703 1707 self.results.clear()
1704 1708 self.metadata.clear()
1705 1709 else:
1706 1710 msg_ids = set()
1707 1711 msg_ids.update(self._build_msgids_from_target(targets))
1708 1712 msg_ids.update(self._build_msgids_from_jobs(jobs))
1709 1713 still_outstanding = self.outstanding.intersection(msg_ids)
1710 1714 if still_outstanding:
1711 1715 raise RuntimeError("Can't purge outstanding tasks: %s" % still_outstanding)
1712 1716 for mid in msg_ids:
1713 1717 self.results.pop(mid, None)
1714 1718 self.metadata.pop(mid, None)
1715 1719
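A sketch of trimming the client-side caches (each line is an independent, illustrative call):

    rc.purge_local_results(ar2)           # drop one AsyncResult's cached results
    rc.purge_local_results(targets=[0])   # drop everything produced by engine 0
    rc.purge_local_results('all')         # scrub the entire local cache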
1716 1720
1717 1721 @spin_first
1718 1722 def purge_hub_results(self, jobs=[], targets=[]):
1719 1723 """Tell the Hub to forget results.
1720 1724
1721 1725 Individual results can be purged by msg_id, or the entire
1722 1726 history of specific targets can be purged.
1723 1727
1724 1728 Use `purge_results('all')` to scrub everything from the Hub's db.
1725 1729
1726 1730 Parameters
1727 1731 ----------
1728 1732
1729 1733 jobs : str or list of str or AsyncResult objects
1730 1734 the msg_ids whose results should be forgotten.
1731 1735 targets : int/str/list of ints/strs
1732 1736 The targets, by int_id, whose entire history is to be purged.
1733 1737
1734 1738 default : None
1735 1739 """
1736 1740 if not targets and not jobs:
1737 1741 raise ValueError("Must specify at least one of `targets` and `jobs`")
1738 1742 if targets:
1739 1743 targets = self._build_targets(targets)[1]
1740 1744
1741 1745 # construct msg_ids from jobs
1742 1746 if jobs == 'all':
1743 1747 msg_ids = jobs
1744 1748 else:
1745 1749 msg_ids = self._build_msgids_from_jobs(jobs)
1746 1750
1747 1751 content = dict(engine_ids=targets, msg_ids=msg_ids)
1748 1752 self.session.send(self._query_socket, "purge_request", content=content)
1749 1753 idents, msg = self.session.recv(self._query_socket, 0)
1750 1754 if self.debug:
1751 1755 pprint(msg)
1752 1756 content = msg['content']
1753 1757 if content['status'] != 'ok':
1754 1758 raise self._unwrap_exception(content)
1755 1759
1756 1760 def purge_results(self, jobs=[], targets=[]):
1757 1761 """Clears the cached results from both the hub and the local client
1758 1762
1759 1763 Individual results can be purged by msg_id, or the entire
1760 1764 history of specific targets can be purged.
1761 1765
1762 1766 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1763 1767 the Client's db.
1764 1768
1765 1769 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1766 1770 the same arguments.
1767 1771
1768 1772 Parameters
1769 1773 ----------
1770 1774
1771 1775 jobs : str or list of str or AsyncResult objects
1772 1776 the msg_ids whose results should be forgotten.
1773 1777 targets : int/str/list of ints/strs
1774 1778 The targets, by int_id, whose entire history is to be purged.
1775 1779
1776 1780 default : None
1777 1781 """
1778 1782 self.purge_local_results(jobs=jobs, targets=targets)
1779 1783 self.purge_hub_results(jobs=jobs, targets=targets)
1780 1784
1781 1785 def purge_everything(self):
1782 1786 """Clears all content from previous Tasks from both the hub and the local client
1783 1787
1784 1788 In addition to calling `purge_results("all")` it also deletes the history and
1785 1789 other bookkeeping lists.
1786 1790 """
1787 1791 self.purge_results("all")
1788 1792 self.history = []
1789 1793 self.session.digest_history.clear()
1790 1794
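A sketch of the hub-side and combined purges, mirroring the local variant above (independent, illustrative calls):

    rc.purge_results(jobs=ar2)    # drop one task's results from both hub and client
    rc.purge_hub_results('all')   # forget every result stored in the Hub's db
    rc.purge_everything()         # additionally clears history and digest bookkeeping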
1791 1795 @spin_first
1792 1796 def hub_history(self):
1793 1797 """Get the Hub's history
1794 1798
1795 1799 Just like the Client, the Hub has a history, which is a list of msg_ids.
1796 1800 This will contain the history of all clients, and, depending on configuration,
1797 1801 may contain history across multiple cluster sessions.
1798 1802
1799 1803 Any msg_id returned here is a valid argument to `get_result`.
1800 1804
1801 1805 Returns
1802 1806 -------
1803 1807
1804 1808 msg_ids : list of strs
1805 1809 list of all msg_ids, ordered by task submission time.
1806 1810 """
1807 1811
1808 1812 self.session.send(self._query_socket, "history_request", content={})
1809 1813 idents, msg = self.session.recv(self._query_socket, 0)
1810 1814
1811 1815 if self.debug:
1812 1816 pprint(msg)
1813 1817 content = msg['content']
1814 1818 if content['status'] != 'ok':
1815 1819 raise self._unwrap_exception(content)
1816 1820 else:
1817 1821 return content['history']
1818 1822
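A sketch of replaying results recorded by the Hub (assumes they have not been purged):

    for msg_id in rc.hub_history()[-5:]:
        print(rc.get_result(msg_id).get())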
1819 1823 @spin_first
1820 1824 def db_query(self, query, keys=None):
1821 1825 """Query the Hub's TaskRecord database
1822 1826
1823 1827 This will return a list of task record dicts that match `query`
1824 1828
1825 1829 Parameters
1826 1830 ----------
1827 1831
1828 1832 query : mongodb query dict
1829 1833 The search dict. See mongodb query docs for details.
1830 1834 keys : list of strs [optional]
1831 1835 The subset of keys to be returned. The default is to fetch everything but buffers.
1832 1836 'msg_id' will *always* be included.
1833 1837 """
1834 1838 if isinstance(keys, string_types):
1835 1839 keys = [keys]
1836 1840 content = dict(query=query, keys=keys)
1837 1841 self.session.send(self._query_socket, "db_request", content=content)
1838 1842 idents, msg = self.session.recv(self._query_socket, 0)
1839 1843 if self.debug:
1840 1844 pprint(msg)
1841 1845 content = msg['content']
1842 1846 if content['status'] != 'ok':
1843 1847 raise self._unwrap_exception(content)
1844 1848
1845 1849 records = content['records']
1846 1850
1847 1851 buffer_lens = content['buffer_lens']
1848 1852 result_buffer_lens = content['result_buffer_lens']
1849 1853 buffers = msg['buffers']
1850 1854 has_bufs = buffer_lens is not None
1851 1855 has_rbufs = result_buffer_lens is not None
1852 1856 for i,rec in enumerate(records):
1853 1857 # unpack datetime objects
1854 1858 for hkey in ('header', 'result_header'):
1855 1859 if hkey in rec:
1856 1860 rec[hkey] = extract_dates(rec[hkey])
1857 1861 for dtkey in ('submitted', 'started', 'completed', 'received'):
1858 1862 if dtkey in rec:
1859 1863 rec[dtkey] = parse_date(rec[dtkey])
1860 1864 # relink buffers
1861 1865 if has_bufs:
1862 1866 blen = buffer_lens[i]
1863 1867 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1864 1868 if has_rbufs:
1865 1869 blen = result_buffer_lens[i]
1866 1870 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1867 1871
1868 1872 return records
1869 1873
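A sketch of a TaskRecord query, using the mongodb-style operators the docstring refers to (requires a db backend that actually stores records, i.e. not NoDB):

    recs = rc.db_query({'completed': {'$ne': None}}, keys=['msg_id', 'completed'])
    for rec in recs:
        print(rec['msg_id'], rec['completed'])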
1870 1874 __all__ = [ 'Client' ]
@@ -1,1440 +1,1449 b''
1 1 """The IPython Controller Hub with 0MQ
2 2
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6
7 7 # Copyright (c) IPython Development Team.
8 8 # Distributed under the terms of the Modified BSD License.
9 9
10 10 from __future__ import print_function
11 11
12 12 import json
13 13 import os
14 14 import sys
15 15 import time
16 16 from datetime import datetime
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop
20 20 from zmq.eventloop.zmqstream import ZMQStream
21 21
22 22 # internal:
23 23 from IPython.utils.importstring import import_item
24 24 from IPython.utils.jsonutil import extract_dates
25 25 from IPython.utils.localinterfaces import localhost
26 26 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
27 27 from IPython.utils.traitlets import (
28 28 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
29 29 )
30 30
31 31 from IPython.parallel import error, util
32 32 from IPython.parallel.factory import RegistrationFactory
33 33
34 34 from IPython.kernel.zmq.session import SessionFactory
35 35
36 36 from .heartmonitor import HeartMonitor
37 37
38 38 #-----------------------------------------------------------------------------
39 39 # Code
40 40 #-----------------------------------------------------------------------------
41 41
42 42 def _passer(*args, **kwargs):
43 43 return
44 44
45 45 def _printer(*args, **kwargs):
46 46 print (args)
47 47 print (kwargs)
48 48
49 49 def empty_record():
50 50 """Return an empty dict with all record keys."""
51 51 return {
52 52 'msg_id' : None,
53 53 'header' : None,
54 54 'metadata' : None,
55 55 'content': None,
56 56 'buffers': None,
57 57 'submitted': None,
58 58 'client_uuid' : None,
59 59 'engine_uuid' : None,
60 60 'started': None,
61 61 'completed': None,
62 62 'resubmitted': None,
63 63 'received': None,
64 64 'result_header' : None,
65 65 'result_metadata' : None,
66 66 'result_content' : None,
67 67 'result_buffers' : None,
68 68 'queue' : None,
69 69 'execute_input' : None,
70 70 'execute_result': None,
71 71 'error': None,
72 72 'stdout': '',
73 73 'stderr': '',
74 74 }
75 75
76 76 def init_record(msg):
77 77 """Initialize a TaskRecord based on a request."""
78 78 header = msg['header']
79 79 return {
80 80 'msg_id' : header['msg_id'],
81 81 'header' : header,
82 82 'content': msg['content'],
83 83 'metadata': msg['metadata'],
84 84 'buffers': msg['buffers'],
85 85 'submitted': header['date'],
86 86 'client_uuid' : None,
87 87 'engine_uuid' : None,
88 88 'started': None,
89 89 'completed': None,
90 90 'resubmitted': None,
91 91 'received': None,
92 92 'result_header' : None,
93 93 'result_metadata': None,
94 94 'result_content' : None,
95 95 'result_buffers' : None,
96 96 'queue' : None,
97 97 'execute_input' : None,
98 98 'execute_result': None,
99 99 'error': None,
100 100 'stdout': '',
101 101 'stderr': '',
102 102 }
103 103
104 104
105 105 class EngineConnector(HasTraits):
106 106 """A simple object for accessing the various zmq connections of an object.
107 107 Attributes are:
108 108 id (int): engine ID
109 109 uuid (unicode): engine UUID
110 110 pending: set of msg_ids
111 111 stallback: DelayedCallback for stalled registration
112 112 """
113 113
114 114 id = Integer(0)
115 115 uuid = Unicode()
116 116 pending = Set()
117 117 stallback = Instance(ioloop.DelayedCallback)
118 118
119 119
120 120 _db_shortcuts = {
121 121 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
122 122 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
123 123 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
124 124 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
125 125 }
126 126
127 127 class HubFactory(RegistrationFactory):
128 128 """The Configurable for setting up a Hub."""
129 129
130 130 # port-pairs for monitoredqueues:
131 131 hb = Tuple(Integer,Integer,config=True,
132 132 help="""PUB/ROUTER Port pair for Engine heartbeats""")
133 133 def _hb_default(self):
134 134 return tuple(util.select_random_ports(2))
135 135
136 136 mux = Tuple(Integer,Integer,config=True,
137 137 help="""Client/Engine Port pair for MUX queue""")
138 138
139 139 def _mux_default(self):
140 140 return tuple(util.select_random_ports(2))
141 141
142 142 task = Tuple(Integer,Integer,config=True,
143 143 help="""Client/Engine Port pair for Task queue""")
144 144 def _task_default(self):
145 145 return tuple(util.select_random_ports(2))
146 146
147 147 control = Tuple(Integer,Integer,config=True,
148 148 help="""Client/Engine Port pair for Control queue""")
149 149
150 150 def _control_default(self):
151 151 return tuple(util.select_random_ports(2))
152 152
153 153 iopub = Tuple(Integer,Integer,config=True,
154 154 help="""Client/Engine Port pair for IOPub relay""")
155 155
156 156 def _iopub_default(self):
157 157 return tuple(util.select_random_ports(2))
158 158
159 159 # single ports:
160 160 mon_port = Integer(config=True,
161 161 help="""Monitor (SUB) port for queue traffic""")
162 162
163 163 def _mon_port_default(self):
164 164 return util.select_random_ports(1)[0]
165 165
166 166 notifier_port = Integer(config=True,
167 167 help="""PUB port for sending engine status notifications""")
168 168
169 169 def _notifier_port_default(self):
170 170 return util.select_random_ports(1)[0]
171 171
172 172 engine_ip = Unicode(config=True,
173 173 help="IP on which to listen for engine connections. [default: loopback]")
174 174 def _engine_ip_default(self):
175 175 return localhost()
176 176 engine_transport = Unicode('tcp', config=True,
177 177 help="0MQ transport for engine connections. [default: tcp]")
178 178
179 179 client_ip = Unicode(config=True,
180 180 help="IP on which to listen for client connections. [default: loopback]")
181 181 client_transport = Unicode('tcp', config=True,
182 182 help="0MQ transport for client connections. [default : tcp]")
183 183
184 184 monitor_ip = Unicode(config=True,
185 185 help="IP on which to listen for monitor messages. [default: loopback]")
186 186 monitor_transport = Unicode('tcp', config=True,
187 187 help="0MQ transport for monitor messages. [default : tcp]")
188 188
189 189 _client_ip_default = _monitor_ip_default = _engine_ip_default
190 190
191 191
192 192 monitor_url = Unicode('')
193 193
194 194 db_class = DottedObjectName('NoDB',
195 195 config=True, help="""The class to use for the DB backend
196 196
197 197 Options include:
198 198
199 199 SQLiteDB: SQLite
200 200 MongoDB : use MongoDB
201 201 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
202 202 NoDB : disable database altogether (default)
203 203
204 204 """)
205 205
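Since the NoDB default stores nothing, Hub-side queries (result_request, db_request, resubmit) need one of the other backends. A configuration sketch, assuming the standard IPython config-file and command-line syntax:

    # in ipcontroller_config.py:
    c.HubFactory.db_class = 'sqlitedb'

    # or equivalently on the command line:
    #   ipcontroller --HubFactory.db_class=sqlitedb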
206 206 registration_timeout = Integer(0, config=True,
207 207 help="Engine registration timeout in seconds [default: max(30,"
208 208 "10*heartmonitor.period)]" )
209 209
210 210 def _registration_timeout_default(self):
211 211 if self.heartmonitor is None:
212 212 # early initialization, this value will be ignored
213 213 return 0
214 214 # heartmonitor period is in milliseconds, so 10x the period in seconds is .01 * period
215 215 return max(30, int(.01 * self.heartmonitor.period))
216 216
217 217 # not configurable
218 218 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
219 219 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
220 220
221 221 def _ip_changed(self, name, old, new):
222 222 self.engine_ip = new
223 223 self.client_ip = new
224 224 self.monitor_ip = new
225 225 self._update_monitor_url()
226 226
227 227 def _update_monitor_url(self):
228 228 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
229 229
230 230 def _transport_changed(self, name, old, new):
231 231 self.engine_transport = new
232 232 self.client_transport = new
233 233 self.monitor_transport = new
234 234 self._update_monitor_url()
235 235
236 236 def __init__(self, **kwargs):
237 237 super(HubFactory, self).__init__(**kwargs)
238 238 self._update_monitor_url()
239 239
240 240
241 241 def construct(self):
242 242 self.init_hub()
243 243
244 244 def start(self):
245 245 self.heartmonitor.start()
246 246 self.log.info("Heartmonitor started")
247 247
248 248 def client_url(self, channel):
249 249 """return full zmq url for a named client channel"""
250 250 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
251 251
252 252 def engine_url(self, channel):
253 253 """return full zmq url for a named engine channel"""
254 254 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
255 255
256 256 def init_hub(self):
257 257 """construct Hub object"""
258 258
259 259 ctx = self.context
260 260 loop = self.loop
261 261 if 'TaskScheduler.scheme_name' in self.config:
262 262 scheme = self.config.TaskScheduler.scheme_name
263 263 else:
264 264 from .scheduler import TaskScheduler
265 265 scheme = TaskScheduler.scheme_name.get_default_value()
266 266
267 267 # build connection dicts
268 268 engine = self.engine_info = {
269 269 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
270 270 'registration' : self.regport,
271 271 'control' : self.control[1],
272 272 'mux' : self.mux[1],
273 273 'hb_ping' : self.hb[0],
274 274 'hb_pong' : self.hb[1],
275 275 'task' : self.task[1],
276 276 'iopub' : self.iopub[1],
277 277 }
278 278
279 279 client = self.client_info = {
280 280 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
281 281 'registration' : self.regport,
282 282 'control' : self.control[0],
283 283 'mux' : self.mux[0],
284 284 'task' : self.task[0],
285 285 'task_scheme' : scheme,
286 286 'iopub' : self.iopub[0],
287 287 'notification' : self.notifier_port,
288 288 }
289 289
290 290 self.log.debug("Hub engine addrs: %s", self.engine_info)
291 291 self.log.debug("Hub client addrs: %s", self.client_info)
292 292
293 293 # Registrar socket
294 294 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
295 295 util.set_hwm(q, 0)
296 296 q.bind(self.client_url('registration'))
297 297 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
298 298 if self.client_ip != self.engine_ip:
299 299 q.bind(self.engine_url('registration'))
300 300 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
301 301
302 302 ### Engine connections ###
303 303
304 304 # heartbeat
305 305 hpub = ctx.socket(zmq.PUB)
306 306 hpub.bind(self.engine_url('hb_ping'))
307 307 hrep = ctx.socket(zmq.ROUTER)
308 308 util.set_hwm(hrep, 0)
309 309 hrep.bind(self.engine_url('hb_pong'))
310 310 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
311 311 pingstream=ZMQStream(hpub,loop),
312 312 pongstream=ZMQStream(hrep,loop)
313 313 )
314 314
315 315 ### Client connections ###
316 316
317 317 # Notifier socket
318 318 n = ZMQStream(ctx.socket(zmq.PUB), loop)
319 319 n.bind(self.client_url('notification'))
320 320
321 321 ### build and launch the queues ###
322 322
323 323 # monitor socket
324 324 sub = ctx.socket(zmq.SUB)
325 325 sub.setsockopt(zmq.SUBSCRIBE, b"")
326 326 sub.bind(self.monitor_url)
327 327 sub.bind('inproc://monitor')
328 328 sub = ZMQStream(sub, loop)
329 329
330 330 # connect the db
331 331 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
332 332 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
333 333 self.db = import_item(str(db_class))(session=self.session.session,
334 334 parent=self, log=self.log)
335 335 time.sleep(.25)
336 336
337 337 # resubmit stream
338 338 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
339 339 url = util.disambiguate_url(self.client_url('task'))
340 340 r.connect(url)
341 341
342 342 # convert seconds to msec
343 343 registration_timeout = 1000*self.registration_timeout
344 344
345 345 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
346 346 query=q, notifier=n, resubmit=r, db=self.db,
347 347 engine_info=self.engine_info, client_info=self.client_info,
348 348 log=self.log, registration_timeout=registration_timeout)
349 349
350 350
351 351 class Hub(SessionFactory):
352 352 """The IPython Controller Hub with 0MQ connections
353 353
354 354 Parameters
355 355 ==========
356 356 loop: zmq IOLoop instance
357 357 session: Session object
358 358 <removed> context: zmq context for creating new connections (?)
359 359 queue: ZMQStream for monitoring the command queue (SUB)
360 360 query: ZMQStream for engine registration and client queries requests (ROUTER)
361 361 heartbeat: HeartMonitor object checking the pulse of the engines
362 362 notifier: ZMQStream for broadcasting engine registration changes (PUB)
363 363 db: connection to db for out of memory logging of commands
364 364 NotImplemented
365 365 engine_info: dict of zmq connection information for engines to connect
366 366 to the queues.
367 367 client_info: dict of zmq connection information for engines to connect
368 368 to the queues.
369 369 """
370 370
371 371 engine_state_file = Unicode()
372 372
373 373 # internal data structures:
374 374 ids=Set() # engine IDs
375 375 keytable=Dict()
376 376 by_ident=Dict()
377 377 engines=Dict()
378 378 clients=Dict()
379 379 hearts=Dict()
380 380 pending=Set()
381 381 queues=Dict() # pending msg_ids keyed by engine_id
382 382 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
383 383 completed=Dict() # completed msg_ids keyed by engine_id
384 384 all_completed=Set() # set of all completed msg_ids
385 385 dead_engines=Set() # set of uuids of dead engines
386 386 unassigned=Set() # set of task msg_ids not yet assigned a destination
387 387 incoming_registrations=Dict()
388 388 registration_timeout=Integer()
389 389 _idcounter=Integer(0)
390 390
391 391 # objects from constructor:
392 392 query=Instance(ZMQStream)
393 393 monitor=Instance(ZMQStream)
394 394 notifier=Instance(ZMQStream)
395 395 resubmit=Instance(ZMQStream)
396 396 heartmonitor=Instance(HeartMonitor)
397 397 db=Instance(object)
398 398 client_info=Dict()
399 399 engine_info=Dict()
400 400
401 401
402 402 def __init__(self, **kwargs):
403 403 """
404 404 # universal:
405 405 loop: IOLoop for creating future connections
406 406 session: streamsession for sending serialized data
407 407 # engine:
408 408 queue: ZMQStream for monitoring queue messages
409 409 query: ZMQStream for engine+client registration and client requests
410 410 heartbeat: HeartMonitor object for tracking engines
411 411 # extra:
412 412 db: ZMQStream for db connection (NotImplemented)
413 413 engine_info: zmq address/protocol dict for engine connections
414 414 client_info: zmq address/protocol dict for client connections
415 415 """
416 416
417 417 super(Hub, self).__init__(**kwargs)
418 418
419 419 # register our callbacks
420 420 self.query.on_recv(self.dispatch_query)
421 421 self.monitor.on_recv(self.dispatch_monitor_traffic)
422 422
423 423 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
424 424 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
425 425
426 426 self.monitor_handlers = {b'in' : self.save_queue_request,
427 427 b'out': self.save_queue_result,
428 428 b'intask': self.save_task_request,
429 429 b'outtask': self.save_task_result,
430 430 b'tracktask': self.save_task_destination,
431 431 b'incontrol': _passer,
432 432 b'outcontrol': _passer,
433 433 b'iopub': self.save_iopub_message,
434 434 }
435 435
436 436 self.query_handlers = {'queue_request': self.queue_status,
437 437 'result_request': self.get_results,
438 438 'history_request': self.get_history,
439 439 'db_request': self.db_query,
440 440 'purge_request': self.purge_results,
441 441 'load_request': self.check_load,
442 442 'resubmit_request': self.resubmit_task,
443 443 'shutdown_request': self.shutdown_request,
444 444 'registration_request' : self.register_engine,
445 445 'unregistration_request' : self.unregister_engine,
446 446 'connection_request': self.connection_request,
447 447 }
448 448
449 449 # ignore resubmit replies
450 450 self.resubmit.on_recv(lambda msg: None, copy=False)
451 451
452 452 self.log.info("hub::created hub")
453 453
454 454 @property
455 455 def _next_id(self):
456 456 """gemerate a new ID.
457 457
458 458 No longer reuse old ids, just count from 0."""
459 459 newid = self._idcounter
460 460 self._idcounter += 1
461 461 return newid
462 462 # newid = 0
463 463 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
464 464 # # print newid, self.ids, self.incoming_registrations
465 465 # while newid in self.ids or newid in incoming:
466 466 # newid += 1
467 467 # return newid
468 468
469 469 #-----------------------------------------------------------------------------
470 470 # message validation
471 471 #-----------------------------------------------------------------------------
472 472
473 473 def _validate_targets(self, targets):
474 474 """turn any valid targets argument into a list of integer ids"""
475 475 if targets is None:
476 476 # default to all
477 477 return self.ids
478 478
479 479 if isinstance(targets, (int,str,unicode_type)):
480 480 # only one target specified
481 481 targets = [targets]
482 482 _targets = []
483 483 for t in targets:
484 484 # map raw identities to ids
485 485 if isinstance(t, (str,unicode_type)):
486 486 t = self.by_ident.get(cast_bytes(t), t)
487 487 _targets.append(t)
488 488 targets = _targets
489 489 bad_targets = [ t for t in targets if t not in self.ids ]
490 490 if bad_targets:
491 491 raise IndexError("No Such Engine: %r" % bad_targets)
492 492 if not targets:
493 493 raise IndexError("No Engines Registered")
494 494 return targets
495 495
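For reference, the mapping this helper performs (values illustrative):

    # _validate_targets(None)          -> self.ids (all registered engines)
    # _validate_targets(0)             -> [0]
    # _validate_targets('engine-uuid') -> [id looked up via by_ident]
    # _validate_targets([0, 1])        -> [0, 1]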
496 496 #-----------------------------------------------------------------------------
497 497 # dispatch methods (1 per stream)
498 498 #-----------------------------------------------------------------------------
499 499
500 500
501 501 @util.log_errors
502 502 def dispatch_monitor_traffic(self, msg):
503 503 """all ME and Task queue messages come through here, as well as
504 504 IOPub traffic."""
505 505 self.log.debug("monitor traffic: %r", msg[0])
506 506 switch = msg[0]
507 507 try:
508 508 idents, msg = self.session.feed_identities(msg[1:])
509 509 except ValueError:
510 510 idents=[]
511 511 if not idents:
512 512 self.log.error("Monitor message without topic: %r", msg)
513 513 return
514 514 handler = self.monitor_handlers.get(switch, None)
515 515 if handler is not None:
516 516 handler(idents, msg)
517 517 else:
518 518 self.log.error("Unrecognized monitor topic: %r", switch)
519 519
520 520
521 521 @util.log_errors
522 522 def dispatch_query(self, msg):
523 523 """Route registration requests and queries from clients."""
524 524 try:
525 525 idents, msg = self.session.feed_identities(msg)
526 526 except ValueError:
527 527 idents = []
528 528 if not idents:
529 529 self.log.error("Bad Query Message: %r", msg)
530 530 return
531 531 client_id = idents[0]
532 532 try:
533 533 msg = self.session.unserialize(msg, content=True)
534 534 except Exception:
535 535 content = error.wrap_exception()
536 536 self.log.error("Bad Query Message: %r", msg, exc_info=True)
537 537 self.session.send(self.query, "hub_error", ident=client_id,
538 538 content=content)
539 539 return
540 540 # print client_id, header, parent, content
541 541 #switch on message type:
542 542 msg_type = msg['header']['msg_type']
543 543 self.log.info("client::client %r requested %r", client_id, msg_type)
544 544 handler = self.query_handlers.get(msg_type, None)
545 545 try:
546 546 assert handler is not None, "Bad Message Type: %r" % msg_type
547 547 except:
548 548 content = error.wrap_exception()
549 549 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
550 550 self.session.send(self.query, "hub_error", ident=client_id,
551 551 content=content)
552 552 return
553 553
554 554 else:
555 555 handler(idents, msg)
556 556
557 557 def dispatch_db(self, msg):
558 558 """"""
559 559 raise NotImplementedError
560 560
561 561 #---------------------------------------------------------------------------
562 562 # handler methods (1 per event)
563 563 #---------------------------------------------------------------------------
564 564
565 565 #----------------------- Heartbeat --------------------------------------
566 566
567 567 def handle_new_heart(self, heart):
568 568 """handler to attach to heartbeater.
569 569 Called when a new heart starts to beat.
570 570 Triggers completion of registration."""
571 571 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
572 572 if heart not in self.incoming_registrations:
573 573 self.log.info("heartbeat::ignoring new heart: %r", heart)
574 574 else:
575 575 self.finish_registration(heart)
576 576
577 577
578 578 def handle_heart_failure(self, heart):
579 579 """handler to attach to heartbeater.
580 580 called when a previously registered heart fails to respond to beat request.
581 581 triggers unregistration"""
582 582 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
583 583 eid = self.hearts.get(heart, None)
584 584 uuid = self.engines[eid].uuid
585 585 if eid is None or self.keytable[eid] in self.dead_engines:
586 586 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
587 587 else:
588 588 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
589 589
590 590 #----------------------- MUX Queue Traffic ------------------------------
591 591
592 592 def save_queue_request(self, idents, msg):
593 593 if len(idents) < 2:
594 594 self.log.error("invalid identity prefix: %r", idents)
595 595 return
596 596 queue_id, client_id = idents[:2]
597 597 try:
598 598 msg = self.session.unserialize(msg)
599 599 except Exception:
600 600 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
601 601 return
602 602
603 603 eid = self.by_ident.get(queue_id, None)
604 604 if eid is None:
605 605 self.log.error("queue::target %r not registered", queue_id)
606 606 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
607 607 return
608 608 record = init_record(msg)
609 609 msg_id = record['msg_id']
610 610 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
611 611 # Unicode in records
612 612 record['engine_uuid'] = queue_id.decode('ascii')
613 613 record['client_uuid'] = msg['header']['session']
614 614 record['queue'] = 'mux'
615 615
616 616 try:
617 617 # it's possible iopub arrived first:
618 618 existing = self.db.get_record(msg_id)
619 619 for key,evalue in iteritems(existing):
620 620 rvalue = record.get(key, None)
621 621 if evalue and rvalue and evalue != rvalue:
622 622 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
623 623 elif evalue and not rvalue:
624 624 record[key] = evalue
625 625 try:
626 626 self.db.update_record(msg_id, record)
627 627 except Exception:
628 628 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
629 629 except KeyError:
630 630 try:
631 631 self.db.add_record(msg_id, record)
632 632 except Exception:
633 633 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
634 634
635 635
636 636 self.pending.add(msg_id)
637 637 self.queues[eid].append(msg_id)
638 638
639 639 def save_queue_result(self, idents, msg):
640 640 if len(idents) < 2:
641 641 self.log.error("invalid identity prefix: %r", idents)
642 642 return
643 643
644 644 client_id, queue_id = idents[:2]
645 645 try:
646 646 msg = self.session.unserialize(msg)
647 647 except Exception:
648 648 self.log.error("queue::engine %r sent invalid message to %r: %r",
649 649 queue_id, client_id, msg, exc_info=True)
650 650 return
651 651
652 652 eid = self.by_ident.get(queue_id, None)
653 653 if eid is None:
654 654 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
655 655 return
656 656
657 657 parent = msg['parent_header']
658 658 if not parent:
659 659 return
660 660 msg_id = parent['msg_id']
661 661 if msg_id in self.pending:
662 662 self.pending.remove(msg_id)
663 663 self.all_completed.add(msg_id)
664 664 self.queues[eid].remove(msg_id)
665 665 self.completed[eid].append(msg_id)
666 666 self.log.info("queue::request %r completed on %s", msg_id, eid)
667 667 elif msg_id not in self.all_completed:
668 668 # it could be a result from a dead engine that died before delivering the
669 669 # result
670 670 self.log.warn("queue:: unknown msg finished %r", msg_id)
671 671 return
672 672 # update record anyway, because the unregistration could have been premature
673 673 rheader = msg['header']
674 674 md = msg['metadata']
675 675 completed = rheader['date']
676 676 started = extract_dates(md.get('started', None))
677 677 result = {
678 678 'result_header' : rheader,
679 679 'result_metadata': md,
680 680 'result_content': msg['content'],
681 681 'received': datetime.now(),
682 682 'started' : started,
683 683 'completed' : completed
684 684 }
685 685
686 686 result['result_buffers'] = msg['buffers']
687 687 try:
688 688 self.db.update_record(msg_id, result)
689 689 except Exception:
690 690 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
691 691
692 692
693 693 #--------------------- Task Queue Traffic ------------------------------
694 694
695 695 def save_task_request(self, idents, msg):
696 696 """Save the submission of a task."""
697 697 client_id = idents[0]
698 698
699 699 try:
700 700 msg = self.session.unserialize(msg)
701 701 except Exception:
702 702 self.log.error("task::client %r sent invalid task message: %r",
703 703 client_id, msg, exc_info=True)
704 704 return
705 705 record = init_record(msg)
706 706
707 707 record['client_uuid'] = msg['header']['session']
708 708 record['queue'] = 'task'
709 709 header = msg['header']
710 710 msg_id = header['msg_id']
711 711 self.pending.add(msg_id)
712 712 self.unassigned.add(msg_id)
713 713 try:
714 714 # it's possible iopub arrived first:
715 715 existing = self.db.get_record(msg_id)
716 716 if existing['resubmitted']:
717 717 for key in ('submitted', 'client_uuid', 'buffers'):
718 718 # don't clobber these keys on resubmit
719 719 # submitted and client_uuid should be different
720 720 # and buffers might be big, and shouldn't have changed
721 721 record.pop(key)
722 722 # still check content and header, which should not change
723 723 # and are not as expensive to compare as buffers
724 724
725 725 for key,evalue in iteritems(existing):
726 726 if key.endswith('buffers'):
727 727 # don't compare buffers
728 728 continue
729 729 rvalue = record.get(key, None)
730 730 if evalue and rvalue and evalue != rvalue:
731 731 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
732 732 elif evalue and not rvalue:
733 733 record[key] = evalue
734 734 try:
735 735 self.db.update_record(msg_id, record)
736 736 except Exception:
737 737 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
738 738 except KeyError:
739 739 try:
740 740 self.db.add_record(msg_id, record)
741 741 except Exception:
742 742 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
743 743 except Exception:
744 744 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
745 745
746 746 def save_task_result(self, idents, msg):
747 747 """save the result of a completed task."""
748 748 client_id = idents[0]
749 749 try:
750 750 msg = self.session.unserialize(msg)
751 751 except Exception:
752 752 self.log.error("task::invalid task result message send to %r: %r",
753 753 client_id, msg, exc_info=True)
754 754 return
755 755
756 756 parent = msg['parent_header']
757 757 if not parent:
758 758 # print msg
759 759 self.log.warn("Task %r had no parent!", msg)
760 760 return
761 761 msg_id = parent['msg_id']
762 762 if msg_id in self.unassigned:
763 763 self.unassigned.remove(msg_id)
764 764
765 765 header = msg['header']
766 766 md = msg['metadata']
767 767 engine_uuid = md.get('engine', u'')
768 768 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
769 769
770 770 status = md.get('status', None)
771 771
772 772 if msg_id in self.pending:
773 773 self.log.info("task::task %r finished on %s", msg_id, eid)
774 774 self.pending.remove(msg_id)
775 775 self.all_completed.add(msg_id)
776 776 if eid is not None:
777 777 if status != 'aborted':
778 778 self.completed[eid].append(msg_id)
779 779 if msg_id in self.tasks[eid]:
780 780 self.tasks[eid].remove(msg_id)
781 781 completed = header['date']
782 782 started = extract_dates(md.get('started', None))
783 783 result = {
784 784 'result_header' : header,
785 785 'result_metadata': msg['metadata'],
786 786 'result_content': msg['content'],
787 787 'started' : started,
788 788 'completed' : completed,
789 789 'received' : datetime.now(),
790 790 'engine_uuid': engine_uuid,
791 791 }
792 792
793 793 result['result_buffers'] = msg['buffers']
794 794 try:
795 795 self.db.update_record(msg_id, result)
796 796 except Exception:
797 797 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
798 798
799 799 else:
800 800 self.log.debug("task::unknown task %r finished", msg_id)
801 801
802 802 def save_task_destination(self, idents, msg):
803 803 try:
804 804 msg = self.session.unserialize(msg, content=True)
805 805 except Exception:
806 806 self.log.error("task::invalid task tracking message", exc_info=True)
807 807 return
808 808 content = msg['content']
809 809 # print (content)
810 810 msg_id = content['msg_id']
811 811 engine_uuid = content['engine_id']
812 812 eid = self.by_ident[cast_bytes(engine_uuid)]
813 813
814 814 self.log.info("task::task %r arrived on %r", msg_id, eid)
815 815 if msg_id in self.unassigned:
816 816 self.unassigned.remove(msg_id)
817 817 # else:
818 818 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
819 819
820 820 self.tasks[eid].append(msg_id)
821 821 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
822 822 try:
823 823 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
824 824 except Exception:
825 825 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
826 826
827 827
828 828 def mia_task_request(self, idents, msg):
829 829 raise NotImplementedError
830 830 client_id = idents[0]
831 831 # content = dict(mia=self.mia,status='ok')
832 832 # self.session.send('mia_reply', content=content, idents=client_id)
833 833
834 834
835 835 #--------------------- IOPub Traffic ------------------------------
836 836
837 837 def save_iopub_message(self, topics, msg):
838 838 """save an iopub message into the db"""
839 839 # print (topics)
840 840 try:
841 841 msg = self.session.unserialize(msg, content=True)
842 842 except Exception:
843 843 self.log.error("iopub::invalid IOPub message", exc_info=True)
844 844 return
845 845
846 846 parent = msg['parent_header']
847 847 if not parent:
848 848 self.log.debug("iopub::IOPub message lacks parent: %r", msg)
849 849 return
850 850 msg_id = parent['msg_id']
851 851 msg_type = msg['header']['msg_type']
852 852 content = msg['content']
853
853
854 854 # ensure msg_id is in db
855 855 try:
856 856 rec = self.db.get_record(msg_id)
857 857 except KeyError:
858 rec = empty_record()
859 rec['msg_id'] = msg_id
860 self.db.add_record(msg_id, rec)
858 rec = None
859
861 860 # stream
862 861 d = {}
863 862 if msg_type == 'stream':
864 863 name = content['name']
865 s = rec[name] or ''
864 s = '' if rec is None else rec[name]
866 865 d[name] = s + content['data']
867 866
868 867 elif msg_type == 'error':
869 868 d['error'] = content
870 869 elif msg_type == 'execute_input':
871 870 d['execute_input'] = content['code']
872 871 elif msg_type in ('display_data', 'execute_result'):
873 872 d[msg_type] = content
874 873 elif msg_type == 'status':
875 874 pass
876 875 elif msg_type == 'data_pub':
877 876 self.log.info("ignored data_pub message for %s" % msg_id)
878 877 else:
879 878 self.log.warn("unhandled iopub msg_type: %r", msg_type)
880 879
881 880 if not d:
882 881 return
883
882
883 if rec is None:
884 # new record
885 rec = empty_record()
886 rec['msg_id'] = msg_id
887 rec.update(d)
888 d = rec
889 update_record = self.db.add_record
890 else:
891 update_record = self.db.update_record
892
884 893 try:
885 self.db.update_record(msg_id, d)
894 update_record(msg_id, d)
886 895 except Exception:
887 896 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
888 897
889 898
890 899
891 900 #-------------------------------------------------------------------------
892 901 # Registration requests
893 902 #-------------------------------------------------------------------------
894 903
895 904 def connection_request(self, client_id, msg):
896 905 """Reply with connection addresses for clients."""
897 906 self.log.info("client::client %r connected", client_id)
898 907 content = dict(status='ok')
899 908 jsonable = {}
900 909 for k,v in iteritems(self.keytable):
901 910 if v not in self.dead_engines:
902 911 jsonable[str(k)] = v
903 912 content['engines'] = jsonable
904 913 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
905 914
906 915 def register_engine(self, reg, msg):
907 916 """Register a new engine."""
908 917 content = msg['content']
909 918 try:
910 919 uuid = content['uuid']
911 920 except KeyError:
912 921 self.log.error("registration::queue not specified", exc_info=True)
913 922 return
914 923
915 924 eid = self._next_id
916 925
917 926 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
918 927
919 928 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
920 929 # check if requesting available IDs:
921 930 if cast_bytes(uuid) in self.by_ident:
922 931 try:
923 932 raise KeyError("uuid %r in use" % uuid)
924 933 except:
925 934 content = error.wrap_exception()
926 935 self.log.error("uuid %r in use", uuid, exc_info=True)
927 936 else:
928 937 for h, ec in iteritems(self.incoming_registrations):
929 938 if uuid == h:
930 939 try:
931 940 raise KeyError("heart_id %r in use" % uuid)
932 941 except:
933 942 self.log.error("heart_id %r in use", uuid, exc_info=True)
934 943 content = error.wrap_exception()
935 944 break
936 945 elif uuid == ec.uuid:
937 946 try:
938 947 raise KeyError("uuid %r in use" % uuid)
939 948 except:
940 949 self.log.error("uuid %r in use", uuid, exc_info=True)
941 950 content = error.wrap_exception()
942 951 break
943 952
944 953 msg = self.session.send(self.query, "registration_reply",
945 954 content=content,
946 955 ident=reg)
947 956
948 957 heart = cast_bytes(uuid)
949 958
950 959 if content['status'] == 'ok':
951 960 if heart in self.heartmonitor.hearts:
952 961 # already beating
953 962 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
954 963 self.finish_registration(heart)
955 964 else:
956 965 purge = lambda : self._purge_stalled_registration(heart)
957 966 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
958 967 dc.start()
959 968 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
960 969 else:
961 970 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
962 971
963 972 return eid
964 973
965 974 def unregister_engine(self, ident, msg):
966 975 """Unregister an engine that explicitly requested to leave."""
967 976 try:
968 977 eid = msg['content']['id']
969 978 except:
970 979 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
971 980 return
972 981 self.log.info("registration::unregister_engine(%r)", eid)
973 982 # print (eid)
974 983 uuid = self.keytable[eid]
975 984 content=dict(id=eid, uuid=uuid)
976 985 self.dead_engines.add(uuid)
977 986 # self.ids.remove(eid)
978 987 # uuid = self.keytable.pop(eid)
979 988 #
980 989 # ec = self.engines.pop(eid)
981 990 # self.hearts.pop(ec.heartbeat)
982 991 # self.by_ident.pop(ec.queue)
983 992 # self.completed.pop(eid)
984 993 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
985 994 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
986 995 dc.start()
987 996 ############## TODO: HANDLE IT ################
988 997
989 998 self._save_engine_state()
990 999
991 1000 if self.notifier:
992 1001 self.session.send(self.notifier, "unregistration_notification", content=content)
993 1002
994 1003 def _handle_stranded_msgs(self, eid, uuid):
995 1004 """Handle messages known to be on an engine when the engine unregisters.
996 1005
997 1006 It is possible that this will fire prematurely - that is, an engine will
998 1007 go down after completing a result, and the client will be notified
999 1008 that the result failed and later receive the actual result.
1000 1009 """
1001 1010
1002 1011 outstanding = self.queues[eid]
1003 1012
1004 1013 for msg_id in outstanding:
1005 1014 self.pending.remove(msg_id)
1006 1015 self.all_completed.add(msg_id)
1007 1016 try:
1008 1017 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1009 1018 except:
1010 1019 content = error.wrap_exception()
1011 1020 # build a fake header:
1012 1021 header = {}
1013 1022 header['engine'] = uuid
1014 1023 header['date'] = datetime.now()
1015 1024 rec = dict(result_content=content, result_header=header, result_buffers=[])
1016 1025 rec['completed'] = header['date']
1017 1026 rec['engine_uuid'] = uuid
1018 1027 try:
1019 1028 self.db.update_record(msg_id, rec)
1020 1029 except Exception:
1021 1030 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1022 1031
1023 1032
1024 1033 def finish_registration(self, heart):
1025 1034 """Second half of engine registration, called after our HeartMonitor
1026 1035 has received a beat from the Engine's Heart."""
1027 1036 try:
1028 1037 ec = self.incoming_registrations.pop(heart)
1029 1038 except KeyError:
1030 1039 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
1031 1040 return
1032 1041 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1033 1042 if ec.stallback is not None:
1034 1043 ec.stallback.stop()
1035 1044 eid = ec.id
1036 1045 self.ids.add(eid)
1037 1046 self.keytable[eid] = ec.uuid
1038 1047 self.engines[eid] = ec
1039 1048 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1040 1049 self.queues[eid] = list()
1041 1050 self.tasks[eid] = list()
1042 1051 self.completed[eid] = list()
1043 1052 self.hearts[heart] = eid
1044 1053 content = dict(id=eid, uuid=self.engines[eid].uuid)
1045 1054 if self.notifier:
1046 1055 self.session.send(self.notifier, "registration_notification", content=content)
1047 1056 self.log.info("engine::Engine Connected: %i", eid)
1048 1057
1049 1058 self._save_engine_state()
1050 1059
1051 1060 def _purge_stalled_registration(self, heart):
1052 1061 if heart in self.incoming_registrations:
1053 1062 ec = self.incoming_registrations.pop(heart)
1054 1063 self.log.info("registration::purging stalled registration: %i", ec.id)
1055 1064 else:
1056 1065 pass
1057 1066
1058 1067 #-------------------------------------------------------------------------
1059 1068 # Engine State
1060 1069 #-------------------------------------------------------------------------
1061 1070
1062 1071
1063 1072 def _cleanup_engine_state_file(self):
1064 1073 """cleanup engine state mapping"""
1065 1074
1066 1075 if os.path.exists(self.engine_state_file):
1067 1076 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1068 1077 try:
1069 1078 os.remove(self.engine_state_file)
1070 1079 except IOError:
1071 1080 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1072 1081
1073 1082
1074 1083 def _save_engine_state(self):
1075 1084 """save engine mapping to JSON file"""
1076 1085 if not self.engine_state_file:
1077 1086 return
1078 1087 self.log.debug("save engine state to %s" % self.engine_state_file)
1079 1088 state = {}
1080 1089 engines = {}
1081 1090 for eid, ec in iteritems(self.engines):
1082 1091 if ec.uuid not in self.dead_engines:
1083 1092 engines[eid] = ec.uuid
1084 1093
1085 1094 state['engines'] = engines
1086 1095
1087 1096 state['next_id'] = self._idcounter
1088 1097
1089 1098 with open(self.engine_state_file, 'w') as f:
1090 1099 json.dump(state, f)
1091 1100
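For reference, the resulting JSON file has this shape (uuids illustrative; json.dump stringifies the integer engine ids, which _load_engine_state converts back with int()):

    {"engines": {"0": "5b4c...", "1": "90f3..."}, "next_id": 2}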
1092 1101
1093 1102 def _load_engine_state(self):
1094 1103 """load engine mapping from JSON file"""
1095 1104 if not os.path.exists(self.engine_state_file):
1096 1105 return
1097 1106
1098 1107 self.log.info("loading engine state from %s" % self.engine_state_file)
1099 1108
1100 1109 with open(self.engine_state_file) as f:
1101 1110 state = json.load(f)
1102 1111
1103 1112 save_notifier = self.notifier
1104 1113 self.notifier = None
1105 1114 for eid, uuid in iteritems(state['engines']):
1106 1115 heart = uuid.encode('ascii')
1107 1116 # start with this heart as current and beating:
1108 1117 self.heartmonitor.responses.add(heart)
1109 1118 self.heartmonitor.hearts.add(heart)
1110 1119
1111 1120 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1112 1121 self.finish_registration(heart)
1113 1122
1114 1123 self.notifier = save_notifier
1115 1124
1116 1125 self._idcounter = state['next_id']
1117 1126
1118 1127 #-------------------------------------------------------------------------
1119 1128 # Client Requests
1120 1129 #-------------------------------------------------------------------------
1121 1130
1122 1131 def shutdown_request(self, client_id, msg):
1123 1132 """handle shutdown request."""
1124 1133 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1125 1134 # also notify other clients of shutdown
1126 1135 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1127 1136 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1128 1137 dc.start()
1129 1138
1130 1139 def _shutdown(self):
1131 1140 self.log.info("hub::hub shutting down.")
1132 1141 time.sleep(0.1)
1133 1142 sys.exit(0)
1134 1143
1135 1144
1136 1145 def check_load(self, client_id, msg):
1137 1146 content = msg['content']
1138 1147 try:
1139 1148 targets = content['targets']
1140 1149 targets = self._validate_targets(targets)
1141 1150 except:
1142 1151 content = error.wrap_exception()
1143 1152 self.session.send(self.query, "hub_error",
1144 1153 content=content, ident=client_id)
1145 1154 return
1146 1155
1147 1156 content = dict(status='ok')
1148 1157 # loads = {}
1149 1158 for t in targets:
1150 1159 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1151 1160 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1152 1161
1153 1162
1154 1163 def queue_status(self, client_id, msg):
1155 1164 """Return the Queue status of one or more targets.
1156 1165
1157 1166 If verbose, return the msg_ids, else return len of each type.
1158 1167
1159 1168 Keys:
1160 1169
1161 1170 * queue (pending MUX jobs)
1162 1171 * tasks (pending Task jobs)
1163 1172 * completed (finished jobs from both queues)
1164 1173 """
1165 1174 content = msg['content']
1166 1175 targets = content['targets']
1167 1176 try:
1168 1177 targets = self._validate_targets(targets)
1169 1178 except:
1170 1179 content = error.wrap_exception()
1171 1180 self.session.send(self.query, "hub_error",
1172 1181 content=content, ident=client_id)
1173 1182 return
1174 1183 verbose = content.get('verbose', False)
1175 1184 content = dict(status='ok')
1176 1185 for t in targets:
1177 1186 queue = self.queues[t]
1178 1187 completed = self.completed[t]
1179 1188 tasks = self.tasks[t]
1180 1189 if not verbose:
1181 1190 queue = len(queue)
1182 1191 completed = len(completed)
1183 1192 tasks = len(tasks)
1184 1193 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1185 1194 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1186 1195 # print (content)
1187 1196 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1188 1197
    def purge_results(self, client_id, msg):
        """Purge results from memory.

        This method is most valuable until we move to a DB-based
        message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
                self.log.exception("Error dropping records")
        else:
            pending = [m for m in msg_ids if (m in self.pending)]
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except Exception:
                    reply = error.wrap_exception()
                    self.log.exception("Error dropping records")
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()
                    self.log.exception("Error dropping records")

        if reply['status'] == 'ok':
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except Exception:
                        reply = error.wrap_exception()
                        self.log.exception("Error dropping records")
                    break
                uid = self.engines[eid].uuid
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    self.log.exception("Error dropping records")
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)

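    # Editor's sketch (not part of the original source) of the request shapes
    # this handler accepts: either the literal string 'all', or explicit
    # msg_ids and/or engine_ids (the ids below are hypothetical):
    #
    #     {'msg_ids': 'all'}
    #     {'msg_ids': ['<msg-id-1>', '<msg-id-2>'], 'engine_ids': [0, 1]}
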
    def resubmit_task(self, client_id, msg):
        """Resubmit one or more tasks."""
        def finish(reply):
            self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

        content = msg['content']
        msg_ids = content['msg_ids']
        reply = dict(status='ok')
        try:
            records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
                'header', 'content', 'buffers'])
        except Exception:
            self.log.error('db::db error finding tasks to resubmit', exc_info=True)
            return finish(error.wrap_exception())

        # validate msg_ids
        found_ids = [ rec['msg_id'] for rec in records ]
        pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
        if len(records) > len(msg_ids):
            try:
                raise RuntimeError("DB appears to be in an inconsistent state. "
                    "More matching records were found than should exist")
            except Exception:
                self.log.exception("Failed to resubmit task")
                return finish(error.wrap_exception())
        elif len(records) < len(msg_ids):
            missing = [ m for m in msg_ids if m not in found_ids ]
            try:
                raise KeyError("No such msg(s): %r" % missing)
            except KeyError:
                self.log.exception("Failed to resubmit task")
                return finish(error.wrap_exception())
        elif pending_ids:
            pass
            # no need to raise on resubmit of a pending task, now that we
            # resubmit under a new ID, but do we want to raise anyway?

        # mapping of original IDs to resubmitted IDs
        resubmitted = {}

        # send the messages
        for rec in records:
            header = rec['header']
            msg = self.session.msg(header['msg_type'], parent=header)
            msg_id = msg['msg_id']
            msg['content'] = rec['content']

            # use the old header, but update msg_id and timestamp
            fresh = msg['header']
            header['msg_id'] = fresh['msg_id']
            header['date'] = fresh['date']
            msg['header'] = header

            self.session.send(self.resubmit, msg, buffers=rec['buffers'])

            resubmitted[rec['msg_id']] = msg_id
            self.pending.add(msg_id)
            msg['buffers'] = rec['buffers']
            try:
                self.db.add_record(msg_id, init_record(msg))
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
                return finish(error.wrap_exception())

        finish(dict(status='ok', resubmitted=resubmitted))

        # store the new IDs in the Task DB
        for msg_id, resubmit_id in iteritems(resubmitted):
            try:
                self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
            except Exception:
                self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)


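    # Editor's sketch (not part of the original source): a successful
    # 'resubmit_reply' maps each original msg_id to the new id it was
    # resubmitted under (the ids below are hypothetical):
    #
    #     {'status': 'ok',
    #      'resubmitted': {'<old-msg-id>': '<new-msg-id>'}}
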
    def _extract_record(self, rec):
        """Decompose a TaskRecord dict into the subsection of a reply for get_result."""
        io_dict = {}
        for key in ('execute_input', 'execute_result', 'error', 'stdout', 'stderr'):
            io_dict[key] = rec[key]
        content = {
            'header': rec['header'],
            'metadata': rec['metadata'],
            'result_metadata': rec['result_metadata'],
            'result_header' : rec['result_header'],
            'result_content': rec['result_content'],
            'received' : rec['received'],
            'io' : io_dict,
        }
        if rec['result_buffers']:
            buffers = list(map(bytes, rec['result_buffers']))
        else:
            buffers = []

        return content, buffers

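    # Editor's note (not in the original file): the (content, buffers) pair
    # built above separates the JSON-safe record fields from the raw result
    # buffers, so the buffers can be sent as extra zmq frames.  The content
    # dict has this shape (ellipses stand for the record's values):
    #
    #     {'header': ..., 'metadata': ..., 'result_metadata': ...,
    #      'result_header': ..., 'result_content': ..., 'received': ...,
    #      'io': {'execute_input': ..., 'execute_result': ..., 'error': ...,
    #             'stdout': ..., 'stderr': ...}}
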
    def get_results(self, client_id, msg):
        """Get the result of one or more messages."""
        content = msg['content']
        msg_ids = sorted(set(content['msg_ids']))
        statusonly = content.get('status_only', False)
        pending = []
        completed = []
        content = dict(status='ok')
        content['pending'] = pending
        content['completed'] = completed
        buffers = []
        if not statusonly:
            try:
                matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
                # turn match list into dict, for faster lookup
                records = {}
                for rec in matches:
                    records[rec['msg_id']] = rec
            except Exception:
                content = error.wrap_exception()
                self.log.exception("Failed to get results")
                self.session.send(self.query, "result_reply", content=content,
                                  parent=msg, ident=client_id)
                return
        else:
            records = {}
        for msg_id in msg_ids:
            if msg_id in self.pending:
                pending.append(msg_id)
            elif msg_id in self.all_completed:
                completed.append(msg_id)
                if not statusonly:
                    c, bufs = self._extract_record(records[msg_id])
                    content[msg_id] = c
                    buffers.extend(bufs)
            elif msg_id in records:
                # look up this message's record, not the leftover loop variable `rec`
                if records[msg_id]['completed']:
                    completed.append(msg_id)
                    c, bufs = self._extract_record(records[msg_id])
                    content[msg_id] = c
                    buffers.extend(bufs)
                else:
                    pending.append(msg_id)
            else:
                try:
                    raise KeyError('No such message: ' + msg_id)
                except Exception:
                    content = error.wrap_exception()
                    break
        self.session.send(self.query, "result_reply", content=content,
                          parent=msg, ident=client_id,
                          buffers=buffers)

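    # Editor's sketch (not part of the original source): a 'result_reply'
    # where one task is done and one is still in flight.  Per-message records
    # are only included when status_only is False (the ids are hypothetical):
    #
    #     {'status': 'ok',
    #      'completed': ['<done-id>'],
    #      'pending': ['<inflight-id>'],
    #      '<done-id>': {'header': ..., 'result_content': ..., 'io': ...}}
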
    def get_history(self, client_id, msg):
        """Get a list of all msg_ids in our DB records."""
        try:
            msg_ids = self.db.get_history()
        except Exception:
            content = error.wrap_exception()
            self.log.exception("Failed to get history")
        else:
            content = dict(status='ok', history=msg_ids)

        self.session.send(self.query, "history_reply", content=content,
                          parent=msg, ident=client_id)

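    # Editor's sketch (not in the original file): a successful 'history_reply'
    # simply lists every msg_id the backend knows about:
    #
    #     {'status': 'ok', 'history': ['<msg-id-1>', '<msg-id-2>', ...]}
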
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database."""
        content = msg['content']
        query = extract_dates(content.get('query', {}))
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception:
            content = error.wrap_exception()
            self.log.exception("DB query failed")
        else:
            # extract buffers from reply content:
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = None
                result_buffer_lens = None

            for rec in records:
                # buffers may be None, so double check
                b = rec.pop('buffers', empty) or empty
                if buffer_lens is not None:
                    buffer_lens.append(len(b))
                buffers.extend(b)
                rb = rec.pop('result_buffers', empty) or empty
                if result_buffer_lens is not None:
                    result_buffer_lens.append(len(rb))
                buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                           result_buffer_lens=result_buffer_lens)
        self.session.send(self.query, "db_reply", content=content,
                          parent=msg, ident=client_id,
                          buffers=buffers)

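    # Editor's sketch (not part of the original source): a db_query request
    # and its reply.  buffer_lens records how many of the flat `buffers`
    # frames belong to each record, so the client can split them back up;
    # result_buffer_lens stays None because 'result_buffers' was not among
    # the requested keys (field names follow the code above; the values are
    # hypothetical):
    #
    #     request:  {'query': {'completed': {'$gt': '2014-01-01T00:00:00'}},
    #                'keys': ['msg_id', 'completed', 'buffers']}
    #     reply:    {'status': 'ok', 'records': [...],
    #                'buffer_lens': [2, 0], 'result_buffer_lens': None}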