add Client.resubmit for re-running tasks...
MinRK
@@ -1,1292 +1,1354 @@
1 1 """A semi-synchronous Client for the ZMQ cluster"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 28 Dict, List, Bool, Str, Set)
29 29 from IPython.external.decorator import decorator
30 30 from IPython.external.ssh import tunnel
31 31
32 32 from IPython.parallel import error
33 33 from IPython.parallel import streamsession as ss
34 34 from IPython.parallel import util
35 35
36 36 from .asyncresult import AsyncResult, AsyncHubResult
37 37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
38 38 from .view import DirectView, LoadBalancedView
39 39
40 40 #--------------------------------------------------------------------------
41 41 # Decorators for Client methods
42 42 #--------------------------------------------------------------------------
43 43
44 44 @decorator
45 45 def spin_first(f, self, *args, **kwargs):
46 46 """Call spin() to sync state prior to calling the method."""
47 47 self.spin()
48 48 return f(self, *args, **kwargs)
49 49
50 50
51 51 #--------------------------------------------------------------------------
52 52 # Classes
53 53 #--------------------------------------------------------------------------
54 54
55 55 class Metadata(dict):
56 56 """Subclass of dict for initializing metadata values.
57 57
58 58 Attribute access works on keys.
59 59
60 60 These objects have a strict set of keys - errors will raise if you try
61 61 to add new keys.
62 62 """
63 63 def __init__(self, *args, **kwargs):
64 64 dict.__init__(self)
65 65 md = {'msg_id' : None,
66 66 'submitted' : None,
67 67 'started' : None,
68 68 'completed' : None,
69 69 'received' : None,
70 70 'engine_uuid' : None,
71 71 'engine_id' : None,
72 72 'follow' : None,
73 73 'after' : None,
74 74 'status' : None,
75 75
76 76 'pyin' : None,
77 77 'pyout' : None,
78 78 'pyerr' : None,
79 79 'stdout' : '',
80 80 'stderr' : '',
81 81 }
82 82 self.update(md)
83 83 self.update(dict(*args, **kwargs))
84 84
85 85 def __getattr__(self, key):
86 86 """getattr aliased to getitem"""
87 87 if key in self.iterkeys():
88 88 return self[key]
89 89 else:
90 90 raise AttributeError(key)
91 91
92 92 def __setattr__(self, key, value):
93 93 """setattr aliased to setitem, with strict"""
94 94 if key in self.iterkeys():
95 95 self[key] = value
96 96 else:
97 97 raise AttributeError(key)
98 98
99 99 def __setitem__(self, key, value):
100 100 """strict static key enforcement"""
101 101 if key in self.iterkeys():
102 102 dict.__setitem__(self, key, value)
103 103 else:
104 104 raise KeyError(key)
105 105
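The strict key set makes Metadata behave like a record: attribute and item access are interchangeable, and unknown keys raise instead of being added silently. A minimal sketch of that behavior (hypothetical usage, not part of the diff)::

    md = Metadata()
    md.engine_id = 0       # attribute access is aliased to item access
    md['status'] = 'ok'    # existing keys may be set either way
    try:
        md.not_a_key = 1   # unknown keys raise AttributeError
    except AttributeError:
        pass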
106 106
107 107 class Client(HasTraits):
108 108 """A semi-synchronous client to the IPython ZMQ cluster
109 109
110 110 Parameters
111 111 ----------
112 112
113 113 url_or_file : bytes; zmq url or path to ipcontroller-client.json
114 114 Connection information for the Hub's registration. If a json connector
115 115 file is given, then likely no further configuration is necessary.
116 116 [Default: use profile]
117 117 profile : bytes
118 118 The name of the Cluster profile to be used to find connector information.
119 119 [Default: 'default']
120 120 context : zmq.Context
121 121 Pass an existing zmq.Context instance, otherwise the client will create its own.
122 122 username : bytes
123 123 set username to be passed to the Session object
124 124 debug : bool
125 125 flag for lots of message printing for debug purposes
126 126
127 127 #-------------- ssh related args ----------------
128 128 # These are args for configuring the ssh tunnel to be used
129 129 # credentials are used to forward connections over ssh to the Controller
130 130 # Note that the ip given in `addr` needs to be relative to sshserver
131 131 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
132 132 # and set sshserver as the same machine the Controller is on. However,
133 133 # the only requirement is that sshserver is able to see the Controller
134 134 # (i.e. is within the same trusted network).
135 135
136 136 sshserver : str
137 137 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
138 138 If keyfile or password is specified, and this is not, it will default to
139 139 the ip given in addr.
140 140 sshkey : str; path to private ssh key file
141 141 This specifies a key to be used in ssh login, default None.
142 142 If not specified, the regular default ssh keys will be used.
143 143 password : str
144 144 Your ssh password to sshserver. Note that if this is left None,
145 145 you will be prompted for it if passwordless key based login is unavailable.
146 146 paramiko : bool
147 147 flag for whether to use paramiko instead of shell ssh for tunneling.
148 148 [default: True on win32, False otherwise]
149 149
150 150 ------- exec authentication args -------
151 151 If even localhost is untrusted, you can have some protection against
152 152 unauthorized execution by using a key. Messages are still sent
153 153 as cleartext, so if someone can snoop your loopback traffic this will
154 154 not help against malicious attacks.
155 155
156 156 exec_key : str
157 157 an authentication key or file containing a key
158 158 default: None
159 159
160 160
161 161 Attributes
162 162 ----------
163 163
164 164 ids : list of int engine IDs
165 165 requesting the ids attribute always synchronizes
166 166 the registration state. To request ids without synchronization,
167 167 use the semi-private _ids attribute.
168 168
169 169 history : list of msg_ids
170 170 a list of msg_ids, keeping track of all the execution
171 171 messages you have submitted in order.
172 172
173 173 outstanding : set of msg_ids
174 174 a set of msg_ids that have been submitted, but whose
175 175 results have not yet been received.
176 176
177 177 results : dict
178 178 a dict of all our results, keyed by msg_id
179 179
180 180 block : bool
181 181 determines default behavior when block not specified
182 182 in execution methods
183 183
184 184 Methods
185 185 -------
186 186
187 187 spin
188 188 flushes incoming results and registration state changes
189 189 control methods spin as well, and requesting `ids` also ensures the state is up to date
190 190
191 191 wait
192 192 wait on one or more msg_ids
193 193
194 194 execution methods
195 195 apply
196 196 legacy: execute, run
197 197
198 198 data movement
199 199 push, pull, scatter, gather
200 200
201 201 query methods
202 202 queue_status, get_result, purge, result_status
203 203
204 204 control methods
205 205 abort, shutdown
206 206
207 207 """
208 208
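As a usage sketch of the parameters above, assuming an ipcontroller is already running with the default profile (names like `rc` are illustrative)::

    from IPython.parallel import Client

    rc = Client()                                # default profile's ipcontroller-client.json
    # rc = Client('/path/to/ipcontroller-client.json')
    # rc = Client(sshserver='user@server.tld')   # tunnel all connections over ssh
    print rc.ids                                 # synchronizes registration state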
209 209
210 210 block = Bool(False)
211 211 outstanding = Set()
212 212 results = Instance('collections.defaultdict', (dict,))
213 213 metadata = Instance('collections.defaultdict', (Metadata,))
214 214 history = List()
215 215 debug = Bool(False)
216 216 profile=CUnicode('default')
217 217
218 218 _outstanding_dict = Instance('collections.defaultdict', (set,))
219 219 _ids = List()
220 220 _connected=Bool(False)
221 221 _ssh=Bool(False)
222 222 _context = Instance('zmq.Context')
223 223 _config = Dict()
224 224 _engines=Instance(util.ReverseDict, (), {})
225 225 # _hub_socket=Instance('zmq.Socket')
226 226 _query_socket=Instance('zmq.Socket')
227 227 _control_socket=Instance('zmq.Socket')
228 228 _iopub_socket=Instance('zmq.Socket')
229 229 _notification_socket=Instance('zmq.Socket')
230 230 _mux_socket=Instance('zmq.Socket')
231 231 _task_socket=Instance('zmq.Socket')
232 232 _task_scheme=Str()
233 233 _closed = False
234 234 _ignored_control_replies=Int(0)
235 235 _ignored_hub_replies=Int(0)
236 236
237 237 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
238 238 context=None, username=None, debug=False, exec_key=None,
239 239 sshserver=None, sshkey=None, password=None, paramiko=None,
240 240 timeout=10
241 241 ):
242 242 super(Client, self).__init__(debug=debug, profile=profile)
243 243 if context is None:
244 244 context = zmq.Context.instance()
245 245 self._context = context
246 246
247 247
248 248 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
249 249 if self._cd is not None:
250 250 if url_or_file is None:
251 251 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
252 252 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
253 253 " Please specify at least one of url_or_file or profile."
254 254
255 255 try:
256 256 util.validate_url(url_or_file)
257 257 except AssertionError:
258 258 if not os.path.exists(url_or_file):
259 259 if self._cd:
260 260 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
261 261 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
262 262 with open(url_or_file) as f:
263 263 cfg = json.loads(f.read())
264 264 else:
265 265 cfg = {'url':url_or_file}
266 266
267 267 # sync defaults from args, json:
268 268 if sshserver:
269 269 cfg['ssh'] = sshserver
270 270 if exec_key:
271 271 cfg['exec_key'] = exec_key
272 272 exec_key = cfg['exec_key']
273 273 sshserver=cfg['ssh']
274 274 url = cfg['url']
275 275 location = cfg.setdefault('location', None)
276 276 cfg['url'] = util.disambiguate_url(cfg['url'], location)
277 277 url = cfg['url']
278 278
279 279 self._config = cfg
280 280
281 281 self._ssh = bool(sshserver or sshkey or password)
282 282 if self._ssh and sshserver is None:
283 283 # default to ssh via localhost
284 284 sshserver = url.split('://')[1].split(':')[0]
285 285 if self._ssh and password is None:
286 286 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
287 287 password=False
288 288 else:
289 289 password = getpass("SSH Password for %s: "%sshserver)
290 290 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
291 291 if exec_key is not None and os.path.isfile(exec_key):
292 292 arg = 'keyfile'
293 293 else:
294 294 arg = 'key'
295 295 key_arg = {arg:exec_key}
296 296 if username is None:
297 297 self.session = ss.StreamSession(**key_arg)
298 298 else:
299 299 self.session = ss.StreamSession(username, **key_arg)
300 300 self._query_socket = self._context.socket(zmq.XREQ)
301 301 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 302 if self._ssh:
303 303 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
304 304 else:
305 305 self._query_socket.connect(url)
306 306
307 307 self.session.debug = self.debug
308 308
309 309 self._notification_handlers = {'registration_notification' : self._register_engine,
310 310 'unregistration_notification' : self._unregister_engine,
311 311 'shutdown_notification' : lambda msg: self.close(),
312 312 }
313 313 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
314 314 'apply_reply' : self._handle_apply_reply}
315 315 self._connect(sshserver, ssh_kwargs, timeout)
316 316
317 317 def __del__(self):
318 318 """cleanup sockets, but _not_ context."""
319 319 self.close()
320 320
321 321 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
322 322 if ipython_dir is None:
323 323 ipython_dir = get_ipython_dir()
324 324 if cluster_dir is not None:
325 325 try:
326 326 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
327 327 return
328 328 except ClusterDirError:
329 329 pass
330 330 elif profile is not None:
331 331 try:
332 332 self._cd = ClusterDir.find_cluster_dir_by_profile(
333 333 ipython_dir, profile)
334 334 return
335 335 except ClusterDirError:
336 336 pass
337 337 self._cd = None
338 338
339 339 def _update_engines(self, engines):
340 340 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
341 341 for k,v in engines.iteritems():
342 342 eid = int(k)
343 343 self._engines[eid] = bytes(v) # force not unicode
344 344 self._ids.append(eid)
345 345 self._ids = sorted(self._ids)
346 346 if sorted(self._engines.keys()) != range(len(self._engines)) and \
347 347 self._task_scheme == 'pure' and self._task_socket:
348 348 self._stop_scheduling_tasks()
349 349
350 350 def _stop_scheduling_tasks(self):
351 351 """Stop scheduling tasks because an engine has been unregistered
352 352 from a pure ZMQ scheduler.
353 353 """
354 354 self._task_socket.close()
355 355 self._task_socket = None
356 356 msg = "An engine has been unregistered, and we are using pure " +\
357 357 "ZMQ task scheduling. Task farming will be disabled."
358 358 if self.outstanding:
359 359 msg += " If you were running tasks when this happened, " +\
360 360 "some `outstanding` msg_ids may never resolve."
361 361 warnings.warn(msg, RuntimeWarning)
362 362
363 363 def _build_targets(self, targets):
364 364 """Turn valid target IDs or 'all' into two lists:
365 365 (int_ids, uuids).
366 366 """
367 367 if not self._ids:
368 368 # flush notification socket if no engines yet, just in case
369 369 if not self.ids:
370 370 raise error.NoEnginesRegistered("Can't build targets without any engines")
371 371
372 372 if targets is None:
373 373 targets = self._ids
374 374 elif isinstance(targets, str):
375 375 if targets.lower() == 'all':
376 376 targets = self._ids
377 377 else:
378 378 raise TypeError("%r not valid str target, must be 'all'"%(targets))
379 379 elif isinstance(targets, int):
380 380 if targets < 0:
381 381 targets = self.ids[targets]
382 382 if targets not in self._ids:
383 383 raise IndexError("No such engine: %i"%targets)
384 384 targets = [targets]
385 385
386 386 if isinstance(targets, slice):
387 387 indices = range(len(self._ids))[targets]
388 388 ids = self.ids
389 389 targets = [ ids[i] for i in indices ]
390 390
391 391 if not isinstance(targets, (tuple, list, xrange)):
392 392 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
393 393
394 394 return [self._engines[t] for t in targets], list(targets)
395 395
396 396 def _connect(self, sshserver, ssh_kwargs, timeout):
397 397 """setup all our socket connections to the cluster. This is called from
398 398 __init__."""
399 399
400 400 # Maybe allow reconnecting?
401 401 if self._connected:
402 402 return
403 403 self._connected=True
404 404
405 405 def connect_socket(s, url):
406 406 url = util.disambiguate_url(url, self._config['location'])
407 407 if self._ssh:
408 408 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
409 409 else:
410 410 return s.connect(url)
411 411
412 412 self.session.send(self._query_socket, 'connection_request')
413 413 r,w,x = zmq.select([self._query_socket],[],[], timeout)
414 414 if not r:
415 415 raise error.TimeoutError("Hub connection request timed out")
416 416 idents,msg = self.session.recv(self._query_socket,mode=0)
417 417 if self.debug:
418 418 pprint(msg)
419 419 msg = ss.Message(msg)
420 420 content = msg.content
421 421 self._config['registration'] = dict(content)
422 422 if content.status == 'ok':
423 423 if content.mux:
424 424 self._mux_socket = self._context.socket(zmq.XREQ)
425 425 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
426 426 connect_socket(self._mux_socket, content.mux)
427 427 if content.task:
428 428 self._task_scheme, task_addr = content.task
429 429 self._task_socket = self._context.socket(zmq.XREQ)
430 430 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
431 431 connect_socket(self._task_socket, task_addr)
432 432 if content.notification:
433 433 self._notification_socket = self._context.socket(zmq.SUB)
434 434 connect_socket(self._notification_socket, content.notification)
435 435 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
436 436 # if content.query:
437 437 # self._query_socket = self._context.socket(zmq.XREQ)
438 438 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
439 439 # connect_socket(self._query_socket, content.query)
440 440 if content.control:
441 441 self._control_socket = self._context.socket(zmq.XREQ)
442 442 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
443 443 connect_socket(self._control_socket, content.control)
444 444 if content.iopub:
445 445 self._iopub_socket = self._context.socket(zmq.SUB)
446 446 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
447 447 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
448 448 connect_socket(self._iopub_socket, content.iopub)
449 449 self._update_engines(dict(content.engines))
450 450 else:
451 451 self._connected = False
452 452 raise Exception("Failed to connect!")
453 453
454 454 #--------------------------------------------------------------------------
455 455 # handlers and callbacks for incoming messages
456 456 #--------------------------------------------------------------------------
457 457
458 458 def _unwrap_exception(self, content):
459 459 """unwrap exception, and remap engine_id to int."""
460 460 e = error.unwrap_exception(content)
461 461 # print e.traceback
462 462 if e.engine_info:
463 463 e_uuid = e.engine_info['engine_uuid']
464 464 eid = self._engines[e_uuid]
465 465 e.engine_info['engine_id'] = eid
466 466 return e
467 467
468 468 def _extract_metadata(self, header, parent, content):
469 469 md = {'msg_id' : parent['msg_id'],
470 470 'received' : datetime.now(),
471 471 'engine_uuid' : header.get('engine', None),
472 472 'follow' : parent.get('follow', []),
473 473 'after' : parent.get('after', []),
474 474 'status' : content['status'],
475 475 }
476 476
477 477 if md['engine_uuid'] is not None:
478 478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479 479
480 480 if 'date' in parent:
481 481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
482 482 if 'started' in header:
483 483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
484 484 if 'date' in header:
485 485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
486 486 return md
487 487
488 488 def _register_engine(self, msg):
489 489 """Register a new engine, and update our connection info."""
490 490 content = msg['content']
491 491 eid = content['id']
492 492 d = {eid : content['queue']}
493 493 self._update_engines(d)
494 494
495 495 def _unregister_engine(self, msg):
496 496 """Unregister an engine that has died."""
497 497 content = msg['content']
498 498 eid = int(content['id'])
499 499 if eid in self._ids:
500 500 self._ids.remove(eid)
501 501 uuid = self._engines.pop(eid)
502 502
503 503 self._handle_stranded_msgs(eid, uuid)
504 504
505 505 if self._task_socket and self._task_scheme == 'pure':
506 506 self._stop_scheduling_tasks()
507 507
508 508 def _handle_stranded_msgs(self, eid, uuid):
509 509 """Handle messages known to be on an engine when the engine unregisters.
510 510
511 511 It is possible that this will fire prematurely - that is, an engine will
512 512 go down after completing a result, and the client will be notified
513 513 of the unregistration and later receive the successful result.
514 514 """
515 515
516 516 outstanding = self._outstanding_dict[uuid]
517 517
518 518 for msg_id in list(outstanding):
519 519 if msg_id in self.results:
520 520 # we already have the result for this msg_id
521 521 continue
522 522 try:
523 523 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
524 524 except:
525 525 content = error.wrap_exception()
526 526 # build a fake message:
527 527 parent = {}
528 528 header = {}
529 529 parent['msg_id'] = msg_id
530 530 header['engine'] = uuid
531 531 header['date'] = datetime.now().strftime(util.ISO8601)
532 532 msg = dict(parent_header=parent, header=header, content=content)
533 533 self._handle_apply_reply(msg)
534 534
535 535 def _handle_execute_reply(self, msg):
536 536 """Save the reply to an execute_request into our results.
537 537
538 538 execute messages are never actually used. apply is used instead.
539 539 """
540 540
541 541 parent = msg['parent_header']
542 542 msg_id = parent['msg_id']
543 543 if msg_id not in self.outstanding:
544 544 if msg_id in self.history:
545 545 print ("got stale result: %s"%msg_id)
546 546 else:
547 547 print ("got unknown result: %s"%msg_id)
548 548 else:
549 549 self.outstanding.remove(msg_id)
550 550 self.results[msg_id] = self._unwrap_exception(msg['content'])
551 551
552 552 def _handle_apply_reply(self, msg):
553 553 """Save the reply to an apply_request into our results."""
554 554 parent = msg['parent_header']
555 555 msg_id = parent['msg_id']
556 556 if msg_id not in self.outstanding:
557 557 if msg_id in self.history:
558 558 print ("got stale result: %s"%msg_id)
559 559 print self.results[msg_id]
560 560 print msg
561 561 else:
562 562 print ("got unknown result: %s"%msg_id)
563 563 else:
564 564 self.outstanding.remove(msg_id)
565 565 content = msg['content']
566 566 header = msg['header']
567 567
568 568 # construct metadata:
569 569 md = self.metadata[msg_id]
570 570 md.update(self._extract_metadata(header, parent, content))
571 571 # is this redundant?
572 572 self.metadata[msg_id] = md
573 573
574 574 e_outstanding = self._outstanding_dict[md['engine_uuid']]
575 575 if msg_id in e_outstanding:
576 576 e_outstanding.remove(msg_id)
577 577
578 578 # construct result:
579 579 if content['status'] == 'ok':
580 580 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
581 581 elif content['status'] == 'aborted':
582 582 self.results[msg_id] = error.TaskAborted(msg_id)
583 583 elif content['status'] == 'resubmitted':
584 584 # TODO: handle resubmission
585 585 pass
586 586 else:
587 587 self.results[msg_id] = self._unwrap_exception(content)
588 588
589 589 def _flush_notifications(self):
590 590 """Flush notifications of engine registrations waiting
591 591 in ZMQ queue."""
592 592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 593 while msg is not None:
594 594 if self.debug:
595 595 pprint(msg)
596 596 msg = msg[-1]
597 597 msg_type = msg['msg_type']
598 598 handler = self._notification_handlers.get(msg_type, None)
599 599 if handler is None:
600 600 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 601 else:
602 602 handler(msg)
603 603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 604
605 605 def _flush_results(self, sock):
606 606 """Flush task or queue results waiting in ZMQ queue."""
607 607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 608 while msg is not None:
609 609 if self.debug:
610 610 pprint(msg)
611 611 msg = msg[-1]
612 612 msg_type = msg['msg_type']
613 613 handler = self._queue_handlers.get(msg_type, None)
614 614 if handler is None:
615 615 raise Exception("Unhandled message type: %s"%msg.msg_type)
616 616 else:
617 617 handler(msg)
618 618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619 619
620 620 def _flush_control(self, sock):
621 621 """Flush replies from the control channel waiting
622 622 in the ZMQ queue.
623 623
624 624 Currently: ignore them."""
625 625 if self._ignored_control_replies <= 0:
626 626 return
627 627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 628 while msg is not None:
629 629 self._ignored_control_replies -= 1
630 630 if self.debug:
631 631 pprint(msg)
632 632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633 633
634 634 def _flush_ignored_control(self):
635 635 """flush ignored control replies"""
636 636 while self._ignored_control_replies > 0:
637 637 self.session.recv(self._control_socket)
638 638 self._ignored_control_replies -= 1
639 639
640 640 def _flush_ignored_hub_replies(self):
641 641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 642 while msg is not None:
643 643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644 644
645 645 def _flush_iopub(self, sock):
646 646 """Flush replies from the iopub channel waiting
647 647 in the ZMQ queue.
648 648 """
649 649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 650 while msg is not None:
651 651 if self.debug:
652 652 pprint(msg)
653 653 msg = msg[-1]
654 654 parent = msg['parent_header']
655 655 msg_id = parent['msg_id']
656 656 content = msg['content']
657 657 header = msg['header']
658 658 msg_type = msg['msg_type']
659 659
660 660 # init metadata:
661 661 md = self.metadata[msg_id]
662 662
663 663 if msg_type == 'stream':
664 664 name = content['name']
665 665 s = md[name] or ''
666 666 md[name] = s + content['data']
667 667 elif msg_type == 'pyerr':
668 668 md.update({'pyerr' : self._unwrap_exception(content)})
669 669 elif msg_type == 'pyin':
670 670 md.update({'pyin' : content['code']})
671 671 else:
672 672 md.update({msg_type : content.get('data', '')})
673 673
674 674 # redundant?
675 675 self.metadata[msg_id] = md
676 676
677 677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678 678
679 679 #--------------------------------------------------------------------------
680 680 # len, getitem
681 681 #--------------------------------------------------------------------------
682 682
683 683 def __len__(self):
684 684 """len(client) returns # of engines."""
685 685 return len(self.ids)
686 686
687 687 def __getitem__(self, key):
688 688 """index access returns DirectView multiplexer objects
689 689
690 690 Must be int, slice, or list/tuple/xrange of ints"""
691 691 if not isinstance(key, (int, slice, tuple, list, xrange)):
692 692 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
693 693 else:
694 694 return self.direct_view(key)
695 695
696 696 #--------------------------------------------------------------------------
697 697 # Begin public methods
698 698 #--------------------------------------------------------------------------
699 699
700 700 @property
701 701 def ids(self):
702 702 """Always up-to-date ids property."""
703 703 self._flush_notifications()
704 704 # always copy:
705 705 return list(self._ids)
706 706
707 707 def close(self):
708 708 if self._closed:
709 709 return
710 710 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 711 for socket in map(lambda name: getattr(self, name), snames):
712 712 if isinstance(socket, zmq.Socket) and not socket.closed:
713 713 socket.close()
714 714 self._closed = True
715 715
716 716 def spin(self):
717 717 """Flush any registration notifications and execution results
718 718 waiting in the ZMQ queue.
719 719 """
720 720 if self._notification_socket:
721 721 self._flush_notifications()
722 722 if self._mux_socket:
723 723 self._flush_results(self._mux_socket)
724 724 if self._task_socket:
725 725 self._flush_results(self._task_socket)
726 726 if self._control_socket:
727 727 self._flush_control(self._control_socket)
728 728 if self._iopub_socket:
729 729 self._flush_iopub(self._iopub_socket)
730 730 if self._query_socket:
731 731 self._flush_ignored_hub_replies()
732 732
733 733 def wait(self, jobs=None, timeout=-1):
734 734 """waits on one or more `jobs`, for up to `timeout` seconds.
735 735
736 736 Parameters
737 737 ----------
738 738
739 739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 740 ints are indices to self.history
741 741 strs are msg_ids
742 742 default: wait on all outstanding messages
743 743 timeout : float
744 744 a time in seconds, after which to give up.
745 745 default is -1, which means no timeout
746 746
747 747 Returns
748 748 -------
749 749
750 750 True : when all msg_ids are done
751 751 False : timeout reached, some msg_ids still outstanding
752 752 """
753 753 tic = time.time()
754 754 if jobs is None:
755 755 theids = self.outstanding
756 756 else:
757 757 if isinstance(jobs, (int, str, AsyncResult)):
758 758 jobs = [jobs]
759 759 theids = set()
760 760 for job in jobs:
761 761 if isinstance(job, int):
762 762 # index access
763 763 job = self.history[job]
764 764 elif isinstance(job, AsyncResult):
765 765 map(theids.add, job.msg_ids)
766 766 continue
767 767 theids.add(job)
768 768 if not theids.intersection(self.outstanding):
769 769 return True
770 770 self.spin()
771 771 while theids.intersection(self.outstanding):
772 772 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 773 break
774 774 time.sleep(1e-3)
775 775 self.spin()
776 776 return len(theids.intersection(self.outstanding)) == 0
777 777
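For example, assuming `rc` is a connected Client with at least one engine, `wait` accepts AsyncResults, msg_ids, or history indices interchangeably (a sketch, not part of the diff)::

    ar = rc[:].apply_async(lambda : 42)
    rc.wait(ar)                # wait on an AsyncResult
    rc.wait([-1], timeout=5)   # ints index into rc.history; give up after 5s
    done = rc.wait()           # True once nothing is outstanding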
778 778 #--------------------------------------------------------------------------
779 779 # Control methods
780 780 #--------------------------------------------------------------------------
781 781
782 782 @spin_first
783 783 def clear(self, targets=None, block=None):
784 784 """Clear the namespace in target(s)."""
785 785 block = self.block if block is None else block
786 786 targets = self._build_targets(targets)[0]
787 787 for t in targets:
788 788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 789 error = False
790 790 if block:
791 791 self._flush_ignored_control()
792 792 for i in range(len(targets)):
793 793 idents,msg = self.session.recv(self._control_socket,0)
794 794 if self.debug:
795 795 pprint(msg)
796 796 if msg['content']['status'] != 'ok':
797 797 error = self._unwrap_exception(msg['content'])
798 798 else:
799 799 self._ignored_control_replies += len(targets)
800 800 if error:
801 801 raise error
802 802
803 803
804 804 @spin_first
805 805 def abort(self, jobs=None, targets=None, block=None):
806 806 """Abort specific jobs from the execution queues of target(s).
807 807
808 808 This is a mechanism to prevent jobs that have already been submitted
809 809 from executing.
810 810
811 811 Parameters
812 812 ----------
813 813
814 814 jobs : msg_id, list of msg_ids, or AsyncResult
815 815 The jobs to be aborted
816 816
817 817
818 818 """
819 819 block = self.block if block is None else block
820 820 targets = self._build_targets(targets)[0]
821 821 msg_ids = []
822 822 if isinstance(jobs, (basestring,AsyncResult)):
823 823 jobs = [jobs]
824 824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 825 if bad_ids:
826 826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 827 for j in jobs:
828 828 if isinstance(j, AsyncResult):
829 829 msg_ids.extend(j.msg_ids)
830 830 else:
831 831 msg_ids.append(j)
832 832 content = dict(msg_ids=msg_ids)
833 833 for t in targets:
834 834 self.session.send(self._control_socket, 'abort_request',
835 835 content=content, ident=t)
836 836 error = False
837 837 if block:
838 838 self._flush_ignored_control()
839 839 for i in range(len(targets)):
840 840 idents,msg = self.session.recv(self._control_socket,0)
841 841 if self.debug:
842 842 pprint(msg)
843 843 if msg['content']['status'] != 'ok':
844 844 error = self._unwrap_exception(msg['content'])
845 845 else:
846 846 self._ignored_control_replies += len(targets)
847 847 if error:
848 848 raise error
849 849
850 850 @spin_first
851 851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 852 """Terminates one or more engine processes, optionally including the hub."""
853 853 block = self.block if block is None else block
854 854 if hub:
855 855 targets = 'all'
856 856 targets = self._build_targets(targets)[0]
857 857 for t in targets:
858 858 self.session.send(self._control_socket, 'shutdown_request',
859 859 content={'restart':restart},ident=t)
860 860 error = False
861 861 if block or hub:
862 862 self._flush_ignored_control()
863 863 for i in range(len(targets)):
864 864 idents,msg = self.session.recv(self._control_socket, 0)
865 865 if self.debug:
866 866 pprint(msg)
867 867 if msg['content']['status'] != 'ok':
868 868 error = self._unwrap_exception(msg['content'])
869 869 else:
870 870 self._ignored_control_replies += len(targets)
871 871
872 872 if hub:
873 873 time.sleep(0.25)
874 874 self.session.send(self._query_socket, 'shutdown_request')
875 875 idents,msg = self.session.recv(self._query_socket, 0)
876 876 if self.debug:
877 877 pprint(msg)
878 878 if msg['content']['status'] != 'ok':
879 879 error = self._unwrap_exception(msg['content'])
880 880
881 881 if error:
882 882 raise error
883 883
884 884 #--------------------------------------------------------------------------
885 885 # Execution related methods
886 886 #--------------------------------------------------------------------------
887 887
888 888 def _maybe_raise(self, result):
889 889 """wrapper for maybe raising an exception if apply failed."""
890 890 if isinstance(result, error.RemoteError):
891 891 raise result
892 892
893 893 return result
894 894
895 895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 896 ident=None):
897 897 """construct and send an apply message via a socket.
898 898
899 899 This is the principal method with which all engine execution is performed by views.
900 900 """
901 901
902 902 assert not self._closed, "cannot use me anymore, I'm closed!"
903 903 # defaults:
904 904 args = args if args is not None else []
905 905 kwargs = kwargs if kwargs is not None else {}
906 906 subheader = subheader if subheader is not None else {}
907 907
908 908 # validate arguments
909 909 if not callable(f):
910 910 raise TypeError("f must be callable, not %s"%type(f))
911 911 if not isinstance(args, (tuple, list)):
912 912 raise TypeError("args must be tuple or list, not %s"%type(args))
913 913 if not isinstance(kwargs, dict):
914 914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 915 if not isinstance(subheader, dict):
916 916 raise TypeError("subheader must be dict, not %s"%type(subheader))
917 917
918 918 bufs = util.pack_apply_message(f,args,kwargs)
919 919
920 920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 921 subheader=subheader, track=track)
922 922
923 923 msg_id = msg['msg_id']
924 924 self.outstanding.add(msg_id)
925 925 if ident:
926 926 # possibly routed to a specific engine
927 927 if isinstance(ident, list):
928 928 ident = ident[-1]
929 929 if ident in self._engines.values():
930 930 # save for later, in case of engine death
931 931 self._outstanding_dict[ident].add(msg_id)
932 932 self.history.append(msg_id)
933 933 self.metadata[msg_id]['submitted'] = datetime.now()
934 934
935 935 return msg
936 936
937 937 #--------------------------------------------------------------------------
938 938 # construct a View object
939 939 #--------------------------------------------------------------------------
940 940
941 941 def load_balanced_view(self, targets=None):
942 942 """construct a DirectView object.
943 943
944 944 If no arguments are specified, create a LoadBalancedView
945 945 using all engines.
946 946
947 947 Parameters
948 948 ----------
949 949
950 950 targets: list,slice,int,etc. [default: use all engines]
951 951 The subset of engines across which to load-balance
952 952 """
953 953 if targets is not None:
954 954 targets = self._build_targets(targets)[1]
955 955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956 956
957 957 def direct_view(self, targets='all'):
958 958 """construct a DirectView object.
959 959
960 960 If no targets are specified, create a DirectView
961 961 using all engines.
962 962
963 963 Parameters
964 964 ----------
965 965
966 966 targets: list,slice,int,etc. [default: use all engines]
967 967 The engines to use for the View
968 968 """
969 969 single = isinstance(targets, int)
970 970 targets = self._build_targets(targets)[1]
971 971 if single:
972 972 targets = targets[0]
973 973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974 974
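Typical view construction, assuming a connected Client `rc` with at least two engines (indexing goes through `__getitem__` above)::

    dv = rc[:]                       # DirectView on all engines
    e0 = rc[0]                       # DirectView on a single engine
    pair = rc.direct_view([0, 1])    # explicit targets
    lbv = rc.load_balanced_view()    # load-balance across all engines
    ar = lbv.apply_async(pow, 2, 10)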
975 975 #--------------------------------------------------------------------------
976 976 # Query methods
977 977 #--------------------------------------------------------------------------
978 978
979 979 @spin_first
980 980 def get_result(self, indices_or_msg_ids=None, block=None):
981 981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982 982
983 983 If the client already has the results, no request to the Hub will be made.
984 984
985 985 This is a convenient way to construct AsyncResult objects, which are wrappers
986 986 that include metadata about execution, and allow for awaiting results that
987 987 were not submitted by this Client.
988 988
989 989 It can also be a convenient way to retrieve the metadata associated with
990 990 blocking execution, since it always retrieves the metadata as well.
991 991
992 992 Examples
993 993 --------
994 994 ::
995 995
996 996 In [10]: ar = client.get_result(msg_id)
997 997
998 998 Parameters
999 999 ----------
1000 1000
1001 1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 1002 The history indices or msg_ids of the results to be retrieved
1003 1003
1004 1004 block : bool
1005 1005 Whether to wait for the result to be done
1006 1006
1007 1007 Returns
1008 1008 -------
1009 1009
1010 1010 AsyncResult
1011 1011 A single AsyncResult object will always be returned.
1012 1012
1013 1013 AsyncHubResult
1014 1014 A subclass of AsyncResult that retrieves results from the Hub
1015 1015
1016 1016 """
1017 1017 block = self.block if block is None else block
1018 1018 if indices_or_msg_ids is None:
1019 1019 indices_or_msg_ids = -1
1020 1020
1021 1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 1022 indices_or_msg_ids = [indices_or_msg_ids]
1023 1023
1024 1024 theids = []
1025 1025 for id in indices_or_msg_ids:
1026 1026 if isinstance(id, int):
1027 1027 id = self.history[id]
1028 1028 if not isinstance(id, str):
1029 1029 raise TypeError("indices must be str or int, not %r"%id)
1030 1030 theids.append(id)
1031 1031
1032 1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034 1034
1035 1035 if remote_ids:
1036 1036 ar = AsyncHubResult(self, msg_ids=theids)
1037 1037 else:
1038 1038 ar = AsyncResult(self, msg_ids=theids)
1039 1039
1040 1040 if block:
1041 1041 ar.wait()
1042 1042
1043 1043 return ar
1044
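A sketch of typical `get_result` usage, with `rc` a connected Client and `msg_id` any id from `rc.history` or the Hub's history::

    ar = rc.get_result(msg_id)             # AsyncResult, or AsyncHubResult if remote
    ar.wait()
    print ar.get()
    last = rc.get_result(-1, block=True)   # negative ints index rc.history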
1045 @spin_first
1046 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 """Resubmit one or more tasks.
1048
1049 In-flight tasks may not be resubmitted.
1050
1051 Parameters
1052 ----------
1053
1054 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 The history indices or msg_ids of the tasks to be resubmitted
1056
1057 block : bool
1058 Whether to wait for the result to be done
1059
1060 Returns
1061 -------
1062
1063 AsyncHubResult
1064 A subclass of AsyncResult that retrieves results from the Hub
1065
1066 """
1067 block = self.block if block is None else block
1068 if indices_or_msg_ids is None:
1069 indices_or_msg_ids = -1
1070
1071 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 indices_or_msg_ids = [indices_or_msg_ids]
1073
1074 theids = []
1075 for id in indices_or_msg_ids:
1076 if isinstance(id, int):
1077 id = self.history[id]
1078 if not isinstance(id, str):
1079 raise TypeError("indices must be str or int, not %r"%id)
1080 theids.append(id)
1081
1082 for msg_id in theids:
1083 self.outstanding.discard(msg_id)
1084 if msg_id in self.history:
1085 self.history.remove(msg_id)
1086 self.results.pop(msg_id, None)
1087 self.metadata.pop(msg_id, None)
1088 content = dict(msg_ids = theids)
1089
1090 self.session.send(self._query_socket, 'resubmit_request', content)
1091
1092 zmq.select([self._query_socket], [], [])
1093 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 if self.debug:
1095 pprint(msg)
1096 content = msg['content']
1097 if content['status'] != 'ok':
1098 raise self._unwrap_exception(content)
1099
1100 ar = AsyncHubResult(self, msg_ids=theids)
1101
1102 if block:
1103 ar.wait()
1104
1105 return ar
1044 1106
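Because resubmission is brokered by the Hub, the returned AsyncHubResult tracks the new execution, not the old one. A usage sketch, assuming `lbv` is a LoadBalancedView and the original task has already finished (in-flight tasks are refused)::

    ar = lbv.apply_async(time.time)
    ar.get()                         # must be complete before resubmitting
    ar2 = rc.resubmit(ar.msg_ids)    # re-run the same task via the Hub
    print ar2.get()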
1045 1107 @spin_first
1046 1108 def result_status(self, msg_ids, status_only=True):
1047 1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1048 1110
1049 1111 If status_only is False, then the actual results will be retrieved, else
1050 1112 only the status of the results will be checked.
1051 1113
1052 1114 Parameters
1053 1115 ----------
1054 1116
1055 1117 msg_ids : list of msg_ids
1056 1118 if int:
1057 1119 Passed as index to self.history for convenience.
1058 1120 status_only : bool (default: True)
1059 1121 if False:
1060 1122 Retrieve the actual results of completed tasks.
1061 1123
1062 1124 Returns
1063 1125 -------
1064 1126
1065 1127 results : dict
1066 1128 There will always be the keys 'pending' and 'completed', which will
1067 1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1068 1130 is False, then completed results will be keyed by their `msg_id`.
1069 1131 """
1070 1132 if not isinstance(msg_ids, (list,tuple)):
1071 1133 msg_ids = [msg_ids]
1072 1134
1073 1135 theids = []
1074 1136 for msg_id in msg_ids:
1075 1137 if isinstance(msg_id, int):
1076 1138 msg_id = self.history[msg_id]
1077 1139 if not isinstance(msg_id, basestring):
1078 1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1079 1141 theids.append(msg_id)
1080 1142
1081 1143 completed = []
1082 1144 local_results = {}
1083 1145
1084 1146 # comment this block out to temporarily disable local shortcut:
1085 1147 for msg_id in theids:
1086 1148 if msg_id in self.results:
1087 1149 completed.append(msg_id)
1088 1150 local_results[msg_id] = self.results[msg_id]
1089 1151 theids.remove(msg_id)
1090 1152
1091 1153 if theids: # some not locally cached
1092 1154 content = dict(msg_ids=theids, status_only=status_only)
1093 1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1094 1156 zmq.select([self._query_socket], [], [])
1095 1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1096 1158 if self.debug:
1097 1159 pprint(msg)
1098 1160 content = msg['content']
1099 1161 if content['status'] != 'ok':
1100 1162 raise self._unwrap_exception(content)
1101 1163 buffers = msg['buffers']
1102 1164 else:
1103 1165 content = dict(completed=[],pending=[])
1104 1166
1105 1167 content['completed'].extend(completed)
1106 1168
1107 1169 if status_only:
1108 1170 return content
1109 1171
1110 1172 failures = []
1111 1173 # load cached results into result:
1112 1174 content.update(local_results)
1113 1175 # update cache with results:
1114 1176 for msg_id in sorted(theids):
1115 1177 if msg_id in content['completed']:
1116 1178 rec = content[msg_id]
1117 1179 parent = rec['header']
1118 1180 header = rec['result_header']
1119 1181 rcontent = rec['result_content']
1120 1182 iodict = rec['io']
1121 1183 if isinstance(rcontent, str):
1122 1184 rcontent = self.session.unpack(rcontent)
1123 1185
1124 1186 md = self.metadata[msg_id]
1125 1187 md.update(self._extract_metadata(header, parent, rcontent))
1126 1188 md.update(iodict)
1127 1189
1128 1190 if rcontent['status'] == 'ok':
1129 1191 res,buffers = util.unserialize_object(buffers)
1130 1192 else:
1131 1193 print rcontent
1132 1194 res = self._unwrap_exception(rcontent)
1133 1195 failures.append(res)
1134 1196
1135 1197 self.results[msg_id] = res
1136 1198 content[msg_id] = res
1137 1199
1138 1200 if len(theids) == 1 and failures:
1139 1201 raise failures[0]
1140 1202
1141 1203 error.collect_exceptions(failures, "result_status")
1142 1204 return content
1143 1205
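For instance, checking on everything this client has submitted (assuming a connected `rc`)::

    status = rc.result_status(rc.history)
    print status['pending'], status['completed']
    # status_only=False additionally fetches completed results, keyed by msg_id:
    full = rc.result_status(rc.history, status_only=False)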
1144 1206 @spin_first
1145 1207 def queue_status(self, targets='all', verbose=False):
1146 1208 """Fetch the status of engine queues.
1147 1209
1148 1210 Parameters
1149 1211 ----------
1150 1212
1151 1213 targets : int/str/list of ints/strs
1152 1214 the engines whose states are to be queried.
1153 1215 default : all
1154 1216 verbose : bool
1155 1217 Whether to return lengths only, or lists of ids for each element
1156 1218 """
1157 1219 engine_ids = self._build_targets(targets)[1]
1158 1220 content = dict(targets=engine_ids, verbose=verbose)
1159 1221 self.session.send(self._query_socket, "queue_request", content=content)
1160 1222 idents,msg = self.session.recv(self._query_socket, 0)
1161 1223 if self.debug:
1162 1224 pprint(msg)
1163 1225 content = msg['content']
1164 1226 status = content.pop('status')
1165 1227 if status != 'ok':
1166 1228 raise self._unwrap_exception(content)
1167 1229 content = util.rekey(content)
1168 1230 if isinstance(targets, int):
1169 1231 return content[targets]
1170 1232 else:
1171 1233 return content
1172 1234
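Example queries, assuming a connected `rc`; with `verbose=True` the per-engine entries hold lists of msg_ids rather than lengths::

    print rc.queue_status()              # dict keyed by engine id
    print rc.queue_status(0)             # a single engine's counts
    print rc.queue_status(verbose=True)  # msg_id lists instead of counts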
1173 1235 @spin_first
1174 1236 def purge_results(self, jobs=[], targets=[]):
1175 1237 """Tell the Hub to forget results.
1176 1238
1177 1239 Individual results can be purged by msg_id, or the entire
1178 1240 history of specific targets can be purged.
1179 1241
1180 1242 Parameters
1181 1243 ----------
1182 1244
1183 1245 jobs : str or list of str or AsyncResult objects
1184 1246 the msg_ids whose results should be forgotten.
1185 1247 targets : int/str/list of ints/strs
1186 1248 The targets, by uuid or int_id, whose entire history is to be purged.
1187 1249 Use `targets='all'` to scrub everything from the Hub's memory.
1188 1250
1189 1251 default : None
1190 1252 """
1191 1253 if not targets and not jobs:
1192 1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1193 1255 if targets:
1194 1256 targets = self._build_targets(targets)[1]
1195 1257
1196 1258 # construct msg_ids from jobs
1197 1259 msg_ids = []
1198 1260 if isinstance(jobs, (basestring,AsyncResult)):
1199 1261 jobs = [jobs]
1200 1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1201 1263 if bad_ids:
1202 1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1203 1265 for j in jobs:
1204 1266 if isinstance(j, AsyncResult):
1205 1267 msg_ids.extend(j.msg_ids)
1206 1268 else:
1207 1269 msg_ids.append(j)
1208 1270
1209 1271 content = dict(targets=targets, msg_ids=msg_ids)
1210 1272 self.session.send(self._query_socket, "purge_request", content=content)
1211 1273 idents, msg = self.session.recv(self._query_socket, 0)
1212 1274 if self.debug:
1213 1275 pprint(msg)
1214 1276 content = msg['content']
1215 1277 if content['status'] != 'ok':
1216 1278 raise self._unwrap_exception(content)
1217 1279
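A usage sketch, with `ar` an AsyncResult from an earlier submission::

    rc.purge_results(jobs=ar)           # forget specific results (msg_ids or AsyncResults)
    rc.purge_results(targets=[0, 1])    # entire history of engines 0 and 1
    rc.purge_results(targets='all')     # scrub everything from the Hub's memory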
1218 1280 @spin_first
1219 1281 def hub_history(self):
1220 1282 """Get the Hub's history
1221 1283
1222 1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1223 1285 This will contain the history of all clients, and, depending on configuration,
1224 1286 may contain history across multiple cluster sessions.
1225 1287
1226 1288 Any msg_id returned here is a valid argument to `get_result`.
1227 1289
1228 1290 Returns
1229 1291 -------
1230 1292
1231 1293 msg_ids : list of strs
1232 1294 list of all msg_ids, ordered by task submission time.
1233 1295 """
1234 1296
1235 1297 self.session.send(self._query_socket, "history_request", content={})
1236 1298 idents, msg = self.session.recv(self._query_socket, 0)
1237 1299
1238 1300 if self.debug:
1239 1301 pprint(msg)
1240 1302 content = msg['content']
1241 1303 if content['status'] != 'ok':
1242 1304 raise self._unwrap_exception(content)
1243 1305 else:
1244 1306 return content['history']
1245 1307
1246 1308 @spin_first
1247 1309 def db_query(self, query, keys=None):
1248 1310 """Query the Hub's TaskRecord database
1249 1311
1250 1312 This will return a list of task record dicts that match `query`
1251 1313
1252 1314 Parameters
1253 1315 ----------
1254 1316
1255 1317 query : mongodb query dict
1256 1318 The search dict. See mongodb query docs for details.
1257 1319 keys : list of strs [optional]
1258 1320 The subset of keys to be returned. The default is to fetch everything.
1259 1321 'msg_id' will *always* be included.
1260 1322 """
1261 1323 content = dict(query=query, keys=keys)
1262 1324 self.session.send(self._query_socket, "db_request", content=content)
1263 1325 idents, msg = self.session.recv(self._query_socket, 0)
1264 1326 if self.debug:
1265 1327 pprint(msg)
1266 1328 content = msg['content']
1267 1329 if content['status'] != 'ok':
1268 1330 raise self._unwrap_exception(content)
1269 1331
1270 1332 records = content['records']
1271 1333 buffer_lens = content['buffer_lens']
1272 1334 result_buffer_lens = content['result_buffer_lens']
1273 1335 buffers = msg['buffers']
1274 1336 has_bufs = buffer_lens is not None
1275 1337 has_rbufs = result_buffer_lens is not None
1276 1338 for i,rec in enumerate(records):
1277 1339 # relink buffers
1278 1340 if has_bufs:
1279 1341 blen = buffer_lens[i]
1280 1342 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1281 1343 if has_rbufs:
1282 1344 blen = result_buffer_lens[i]
1283 1345 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1284 1346 # turn timestamps back into times
1285 1347 for key in 'submitted started completed resubmitted'.split():
1286 1348 maybedate = rec.get(key, None)
1287 1349 if maybedate and util.ISO8601_RE.match(maybedate):
1288 1350 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1289 1351
1290 1352 return records
1291 1353
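`hub_history` and `db_query` combine naturally: fetch ids from the Hub, then pull selected fields of their records. A sketch, assuming the configured db backend accepts mongodb-style operators such as `$in`::

    msg_ids = rc.hub_history()
    recs = rc.db_query({'msg_id' : {'$in' : msg_ids[-10:]}},
                       keys=['msg_id', 'completed', 'engine_uuid'])
    for rec in recs:
        print rec['msg_id'], rec['completed']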
1292 1354 __all__ = [ 'Client' ]
@@ -1,1193 +1,1284 @@
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29 29
30 30 from IPython.parallel import error, util
31 31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32 32
33 33 from .heartmonitor import HeartMonitor
34 34
35 35 #-----------------------------------------------------------------------------
36 36 # Code
37 37 #-----------------------------------------------------------------------------
38 38
39 39 def _passer(*args, **kwargs):
40 40 return
41 41
42 42 def _printer(*args, **kwargs):
43 43 print (args)
44 44 print (kwargs)
45 45
46 46 def empty_record():
47 47 """Return an empty dict with all record keys."""
48 48 return {
49 49 'msg_id' : None,
50 50 'header' : None,
51 51 'content': None,
52 52 'buffers': None,
53 53 'submitted': None,
54 54 'client_uuid' : None,
55 55 'engine_uuid' : None,
56 56 'started': None,
57 57 'completed': None,
58 58 'resubmitted': None,
59 59 'result_header' : None,
60 60 'result_content' : None,
61 61 'result_buffers' : None,
62 62 'queue' : None,
63 63 'pyin' : None,
64 64 'pyout': None,
65 65 'pyerr': None,
66 66 'stdout': '',
67 67 'stderr': '',
68 68 }
69 69
70 70 def init_record(msg):
71 71 """Initialize a TaskRecord based on a request."""
72 72 header = msg['header']
73 73 return {
74 74 'msg_id' : header['msg_id'],
75 75 'header' : header,
76 76 'content': msg['content'],
77 77 'buffers': msg['buffers'],
78 78 'submitted': datetime.strptime(header['date'], util.ISO8601),
79 79 'client_uuid' : None,
80 80 'engine_uuid' : None,
81 81 'started': None,
82 82 'completed': None,
83 83 'resubmitted': None,
84 84 'result_header' : None,
85 85 'result_content' : None,
86 86 'result_buffers' : None,
87 87 'queue' : None,
88 88 'pyin' : None,
89 89 'pyout': None,
90 90 'pyerr': None,
91 91 'stdout': '',
92 92 'stderr': '',
93 93 }
94 94
95 95
96 96 class EngineConnector(HasTraits):
97 97 """A simple object for accessing the various zmq connections of an object.
98 98 Attributes are:
99 99 id (int): engine ID
100 100 uuid (str): uuid (unused?)
101 101 queue (str): identity of queue's XREQ socket
102 102 registration (str): identity of registration XREQ socket
103 103 heartbeat (str): identity of heartbeat XREQ socket
104 104 """
105 105 id=Int(0)
106 106 queue=Str()
107 107 control=Str()
108 108 registration=Str()
109 109 heartbeat=Str()
110 110 pending=Set()
111 111
112 112 class HubFactory(RegistrationFactory):
113 113 """The Configurable for setting up a Hub."""
114 114
115 115 # name of a scheduler scheme
116 116 scheme = Str('leastload', config=True)
117 117
118 118 # port-pairs for monitoredqueues:
119 119 hb = Instance(list, config=True)
120 120 def _hb_default(self):
121 121 return util.select_random_ports(2)
122 122
123 123 mux = Instance(list, config=True)
124 124 def _mux_default(self):
125 125 return util.select_random_ports(2)
126 126
127 127 task = Instance(list, config=True)
128 128 def _task_default(self):
129 129 return util.select_random_ports(2)
130 130
131 131 control = Instance(list, config=True)
132 132 def _control_default(self):
133 133 return util.select_random_ports(2)
134 134
135 135 iopub = Instance(list, config=True)
136 136 def _iopub_default(self):
137 137 return util.select_random_ports(2)
138 138
139 139 # single ports:
140 140 mon_port = Instance(int, config=True)
141 141 def _mon_port_default(self):
142 142 return util.select_random_ports(1)[0]
143 143
144 144 notifier_port = Instance(int, config=True)
145 145 def _notifier_port_default(self):
146 146 return util.select_random_ports(1)[0]
147 147
148 148 ping = Int(1000, config=True) # ping frequency
149 149
150 150 engine_ip = CStr('127.0.0.1', config=True)
151 151 engine_transport = CStr('tcp', config=True)
152 152
153 153 client_ip = CStr('127.0.0.1', config=True)
154 154 client_transport = CStr('tcp', config=True)
155 155
156 156 monitor_ip = CStr('127.0.0.1', config=True)
157 157 monitor_transport = CStr('tcp', config=True)
158 158
159 159 monitor_url = CStr('')
160 160
161 161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
162 162
163 163 # not configurable
164 164 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
165 165 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
166 166 subconstructors = List()
167 167 _constructed = Bool(False)
168 168
169 169 def _ip_changed(self, name, old, new):
170 170 self.engine_ip = new
171 171 self.client_ip = new
172 172 self.monitor_ip = new
173 173 self._update_monitor_url()
174 174
175 175 def _update_monitor_url(self):
176 176 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
177 177
178 178 def _transport_changed(self, name, old, new):
179 179 self.engine_transport = new
180 180 self.client_transport = new
181 181 self.monitor_transport = new
182 182 self._update_monitor_url()
183 183
184 184 def __init__(self, **kwargs):
185 185 super(HubFactory, self).__init__(**kwargs)
186 186 self._update_monitor_url()
187 187 # self.on_trait_change(self._sync_ips, 'ip')
188 188 # self.on_trait_change(self._sync_transports, 'transport')
189 189 self.subconstructors.append(self.construct_hub)
190 190
191 191
192 192 def construct(self):
193 193 assert not self._constructed, "already constructed!"
194 194
195 195 for subc in self.subconstructors:
196 196 subc()
197 197
198 198 self._constructed = True
199 199
200 200
201 201 def start(self):
202 202 assert self._constructed, "must be constructed by self.construct() first!"
203 203 self.heartmonitor.start()
204 204 self.log.info("Heartmonitor started")
205 205
206 206 def construct_hub(self):
207 207 """construct"""
208 208 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
209 209 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
210 210
211 211 ctx = self.context
212 212 loop = self.loop
213 213
214 214 # Registrar socket
215 215 q = ZMQStream(ctx.socket(zmq.XREP), loop)
216 216 q.bind(client_iface % self.regport)
217 217 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
218 218 if self.client_ip != self.engine_ip:
219 219 q.bind(engine_iface % self.regport)
220 220 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
221 221
222 222 ### Engine connections ###
223 223
224 224 # heartbeat
225 225 hpub = ctx.socket(zmq.PUB)
226 226 hpub.bind(engine_iface % self.hb[0])
227 227 hrep = ctx.socket(zmq.XREP)
228 228 hrep.bind(engine_iface % self.hb[1])
229 229 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
230 230 period=self.ping, logname=self.log.name)
231 231
232 232 ### Client connections ###
233 233 # Notifier socket
234 234 n = ZMQStream(ctx.socket(zmq.PUB), loop)
235 235 n.bind(client_iface%self.notifier_port)
236 236
237 237 ### build and launch the queues ###
238 238
239 239 # monitor socket
240 240 sub = ctx.socket(zmq.SUB)
241 241 sub.setsockopt(zmq.SUBSCRIBE, "")
242 242 sub.bind(self.monitor_url)
243 243 sub.bind('inproc://monitor')
244 244 sub = ZMQStream(sub, loop)
245 245
246 246 # connect the db
247 247         self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
248 248 # cdir = self.config.Global.cluster_dir
249 249 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
250 250 time.sleep(.25)
251 251
252 252 # build connection dicts
253 253 self.engine_info = {
254 254 'control' : engine_iface%self.control[1],
255 255 'mux': engine_iface%self.mux[1],
256 256 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
257 257 'task' : engine_iface%self.task[1],
258 258 'iopub' : engine_iface%self.iopub[1],
259 259 # 'monitor' : engine_iface%self.mon_port,
260 260 }
261 261
262 262 self.client_info = {
263 263 'control' : client_iface%self.control[0],
264 264 'mux': client_iface%self.mux[0],
265 265 'task' : (self.scheme, client_iface%self.task[0]),
266 266 'iopub' : client_iface%self.iopub[0],
267 267 'notification': client_iface%self.notifier_port
268 268 }
269 269 self.log.debug("Hub engine addrs: %s"%self.engine_info)
270 270 self.log.debug("Hub client addrs: %s"%self.client_info)
271
272 # resubmit stream
273 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
274 url = util.disambiguate_url(self.client_info['task'][-1])
275 r.setsockopt(zmq.IDENTITY, self.session.session)
276 r.connect(url)
277
271 278 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
272 query=q, notifier=n, db=self.db,
279 query=q, notifier=n, resubmit=r, db=self.db,
273 280 engine_info=self.engine_info, client_info=self.client_info,
274 281 logname=self.log.name)
275 282
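# Note on the resubmit stream above: the XREQ socket connects to the same
# client-facing task scheduler address that ordinary clients use, so
# resubmitted messages re-enter the task queue like fresh submissions.
# util.disambiguate_url is needed because that address may be bound to a
# wildcard; a sketch of its effect (assuming the default localhost mapping):
#
#   util.disambiguate_url('tcp://*:10113')        # -> 'tcp://127.0.0.1:10113'
#   util.disambiguate_url('tcp://10.0.0.5:10113') # -> unchanged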
276 283
277 284 class Hub(LoggingFactory):
278 285 """The IPython Controller Hub with 0MQ connections
279 286
280 287 Parameters
281 288 ==========
282 289 loop: zmq IOLoop instance
283 290 session: StreamSession object
284 291 <removed> context: zmq context for creating new connections (?)
285 292 queue: ZMQStream for monitoring the command queue (SUB)
286 293     query: ZMQStream for engine registration and client query requests (XREP)
287 294 heartbeat: HeartMonitor object checking the pulse of the engines
288 295     notifier: ZMQStream for broadcasting engine registration changes (PUB)
    resubmit: ZMQStream for resubmitting tasks to the task scheduler (XREQ)
289 296 db: connection to db for out of memory logging of commands
290 297 NotImplemented
291 298 engine_info: dict of zmq connection information for engines to connect
292 299 to the queues.
293 300     client_info: dict of zmq connection information for clients to connect
294 301         to the queues.
295 302 """
296 303 # internal data structures:
297 304 ids=Set() # engine IDs
298 305 keytable=Dict()
299 306 by_ident=Dict()
300 307 engines=Dict()
301 308 clients=Dict()
302 309 hearts=Dict()
303 310 pending=Set()
304 311 queues=Dict() # pending msg_ids keyed by engine_id
305 312 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
306 313 completed=Dict() # completed msg_ids keyed by engine_id
307 315     all_completed=Set() # set of all completed msg_ids, over time
308 315     dead_engines=Set() # set of uuids of dead engines
309 316     unassigned=Set() # set of task msg_ids not yet assigned a destination
310 317 incoming_registrations=Dict()
311 318 registration_timeout=Int()
312 319 _idcounter=Int(0)
313 320
314 321 # objects from constructor:
315 322 loop=Instance(ioloop.IOLoop)
316 323 query=Instance(ZMQStream)
317 324 monitor=Instance(ZMQStream)
318 heartmonitor=Instance(HeartMonitor)
319 325 notifier=Instance(ZMQStream)
326 resubmit=Instance(ZMQStream)
327 heartmonitor=Instance(HeartMonitor)
320 328 db=Instance(object)
321 329 client_info=Dict()
322 330 engine_info=Dict()
323 331
324 332
325 333 def __init__(self, **kwargs):
326 334 """
327 335 # universal:
328 336 loop: IOLoop for creating future connections
329 337 session: streamsession for sending serialized data
330 338 # engine:
331 339 queue: ZMQStream for monitoring queue messages
332 340 query: ZMQStream for engine+client registration and client requests
333 341 heartbeat: HeartMonitor object for tracking engines
334 342 # extra:
335 343 db: ZMQStream for db connection (NotImplemented)
336 344 engine_info: zmq address/protocol dict for engine connections
337 345 client_info: zmq address/protocol dict for client connections
338 346 """
339 347
340 348 super(Hub, self).__init__(**kwargs)
341 349 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
342 350
343 351 # validate connection dicts:
344 352 for k,v in self.client_info.iteritems():
345 353 if k == 'task':
346 354 util.validate_url_container(v[1])
347 355 else:
348 356 util.validate_url_container(v)
349 357 # util.validate_url_container(self.client_info)
350 358 util.validate_url_container(self.engine_info)
351 359
352 360 # register our callbacks
353 361 self.query.on_recv(self.dispatch_query)
354 362 self.monitor.on_recv(self.dispatch_monitor_traffic)
355 363
356 364 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
357 365 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
358 366
359 367 self.monitor_handlers = { 'in' : self.save_queue_request,
360 368 'out': self.save_queue_result,
361 369 'intask': self.save_task_request,
362 370 'outtask': self.save_task_result,
363 371 'tracktask': self.save_task_destination,
364 372 'incontrol': _passer,
365 373 'outcontrol': _passer,
366 374 'iopub': self.save_iopub_message,
367 375 }
368 376
369 377 self.query_handlers = {'queue_request': self.queue_status,
370 378 'result_request': self.get_results,
371 379 'history_request': self.get_history,
372 380 'db_request': self.db_query,
373 381 'purge_request': self.purge_results,
374 382 'load_request': self.check_load,
375 383 'resubmit_request': self.resubmit_task,
376 384 'shutdown_request': self.shutdown_request,
377 385 'registration_request' : self.register_engine,
378 386 'unregistration_request' : self.unregister_engine,
379 387 'connection_request': self.connection_request,
380 388 }
381 389
390 # ignore resubmit replies
391 self.resubmit.on_recv(lambda msg: None, copy=False)
392
382 393 self.log.info("hub::created hub")
383 394
384 395 @property
385 396 def _next_id(self):
386 397         """generate a new ID.
387 398
388 399 No longer reuse old ids, just count from 0."""
389 400 newid = self._idcounter
390 401 self._idcounter += 1
391 402 return newid
392 403 # newid = 0
393 404 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
394 405 # # print newid, self.ids, self.incoming_registrations
395 406 # while newid in self.ids or newid in incoming:
396 407 # newid += 1
397 408 # return newid
398 409
399 410 #-----------------------------------------------------------------------------
400 411 # message validation
401 412 #-----------------------------------------------------------------------------
402 413
403 414 def _validate_targets(self, targets):
404 415 """turn any valid targets argument into a list of integer ids"""
405 416 if targets is None:
406 417 # default to all
407 418 targets = self.ids
408 419
409 420 if isinstance(targets, (int,str,unicode)):
410 421 # only one target specified
411 422 targets = [targets]
412 423 _targets = []
413 424 for t in targets:
414 425 # map raw identities to ids
415 426 if isinstance(t, (str,unicode)):
416 427 t = self.by_ident.get(t, t)
417 428 _targets.append(t)
418 429 targets = _targets
419 430 bad_targets = [ t for t in targets if t not in self.ids ]
420 431 if bad_targets:
421 432 raise IndexError("No Such Engine: %r"%bad_targets)
422 433 if not targets:
423 434 raise IndexError("No Engines Registered")
424 435 return targets
425 436
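    # Examples of the normalization above, as a sketch (assuming engine 0
    # registered with queue uuid 'abcd...'):
    #
    #   hub._validate_targets(None)      # -> all registered ids, e.g. [0, 1, 2]
    #   hub._validate_targets(1)         # -> [1]
    #   hub._validate_targets('abcd...') # uuid mapped via by_ident -> [0]
    #   hub._validate_targets([99])      # raises IndexError: No Such Engine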
426 437 #-----------------------------------------------------------------------------
427 438 # dispatch methods (1 per stream)
428 439 #-----------------------------------------------------------------------------
429 440
430 441 # def dispatch_registration_request(self, msg):
431 442 # """"""
432 443 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
433 444 # idents,msg = self.session.feed_identities(msg)
434 445 # if not idents:
435 446 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
436 447 # return
437 448 # try:
438 449 # msg = self.session.unpack_message(msg,content=True)
439 450 # except:
440 451 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
441 452 # return
442 453 #
443 454 # msg_type = msg['msg_type']
444 455 # content = msg['content']
445 456 #
446 457 # handler = self.query_handlers.get(msg_type, None)
447 458 # if handler is None:
448 459 # self.log.error("registration::got bad registration message: %s"%msg)
449 460 # else:
450 461 # handler(idents, msg)
451 462
452 463 def dispatch_monitor_traffic(self, msg):
453 464 """all ME and Task queue messages come through here, as well as
454 465 IOPub traffic."""
455 self.log.debug("monitor traffic: %s"%msg[:2])
466 self.log.debug("monitor traffic: %r"%msg[:2])
456 467 switch = msg[0]
457 468 idents, msg = self.session.feed_identities(msg[1:])
458 469 if not idents:
459 self.log.error("Bad Monitor Message: %s"%msg)
470 self.log.error("Bad Monitor Message: %r"%msg)
460 471 return
461 472 handler = self.monitor_handlers.get(switch, None)
462 473 if handler is not None:
463 474 handler(idents, msg)
464 475 else:
465 self.log.error("Invalid monitor topic: %s"%switch)
476 self.log.error("Invalid monitor topic: %r"%switch)
466 477
467 478
468 479 def dispatch_query(self, msg):
469 480 """Route registration requests and queries from clients."""
470 481 idents, msg = self.session.feed_identities(msg)
471 482 if not idents:
472 self.log.error("Bad Query Message: %s"%msg)
483 self.log.error("Bad Query Message: %r"%msg)
473 484 return
474 485 client_id = idents[0]
475 486 try:
476 487 msg = self.session.unpack_message(msg, content=True)
477 488 except:
478 489 content = error.wrap_exception()
479 self.log.error("Bad Query Message: %s"%msg, exc_info=True)
490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
480 491 self.session.send(self.query, "hub_error", ident=client_id,
481 492 content=content)
482 493 return
483 494
484 495 # print client_id, header, parent, content
485 496 #switch on message type:
486 497 msg_type = msg['msg_type']
487 self.log.info("client::client %s requested %s"%(client_id, msg_type))
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
488 499 handler = self.query_handlers.get(msg_type, None)
489 500 try:
490 assert handler is not None, "Bad Message Type: %s"%msg_type
501 assert handler is not None, "Bad Message Type: %r"%msg_type
491 502 except:
492 503 content = error.wrap_exception()
493 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
494 505 self.session.send(self.query, "hub_error", ident=client_id,
495 506 content=content)
496 507 return
508
497 509 else:
498 510 handler(idents, msg)
499 511
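    # The queries dispatched above are ordinary StreamSession messages;
    # a client's queue_status request looks roughly like this sketch,
    # with msg_type selecting the handler from self.query_handlers:
    #
    #   {'msg_type': 'queue_request',
    #    'header': {...}, 'parent_header': {},
    #    'content': {'targets': [0, 1], 'verbose': False}}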
500 512 def dispatch_db(self, msg):
501 513 """"""
502 514 raise NotImplementedError
503 515
504 516 #---------------------------------------------------------------------------
505 517 # handler methods (1 per event)
506 518 #---------------------------------------------------------------------------
507 519
508 520 #----------------------- Heartbeat --------------------------------------
509 521
510 522 def handle_new_heart(self, heart):
511 523 """handler to attach to heartbeater.
512 524 Called when a new heart starts to beat.
513 525 Triggers completion of registration."""
514 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
515 527 if heart not in self.incoming_registrations:
516 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
517 529 else:
518 530 self.finish_registration(heart)
519 531
520 532
521 533 def handle_heart_failure(self, heart):
522 534 """handler to attach to heartbeater.
523 535 called when a previously registered heart fails to respond to beat request.
524 536 triggers unregistration"""
525 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
526 538         eid = self.hearts.get(heart, None)
527 539         if eid is None:
528 540             self.log.info("heartbeat::ignoring heart failure %r"%heart)
529 541         else:
530 542             queue = self.engines[eid].queue
531 543             self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
532 544
533 545 #----------------------- MUX Queue Traffic ------------------------------
534 546
535 547 def save_queue_request(self, idents, msg):
536 548 if len(idents) < 2:
537 549 self.log.error("invalid identity prefix: %s"%idents)
538 550 return
539 551 queue_id, client_id = idents[:2]
540 552 try:
541 553 msg = self.session.unpack_message(msg, content=False)
542 554 except:
543 555 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
544 556 return
545 557
546 558 eid = self.by_ident.get(queue_id, None)
547 559 if eid is None:
548 560 self.log.error("queue::target %r not registered"%queue_id)
549 561 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
550 562 return
551 563
552 564 header = msg['header']
553 565 msg_id = header['msg_id']
554 566 record = init_record(msg)
555 567 record['engine_uuid'] = queue_id
556 568 record['client_uuid'] = client_id
557 569 record['queue'] = 'mux'
558 570
559 571 try:
560 572             # it's possible iopub arrived first:
561 573 existing = self.db.get_record(msg_id)
562 574 for key,evalue in existing.iteritems():
563 rvalue = record[key]
575 rvalue = record.get(key, None)
564 576 if evalue and rvalue and evalue != rvalue:
565 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
566 578 elif evalue and not rvalue:
567 579 record[key] = evalue
568 580 self.db.update_record(msg_id, record)
569 581 except KeyError:
570 582 self.db.add_record(msg_id, record)
571 583
572 584 self.pending.add(msg_id)
573 585 self.queues[eid].append(msg_id)
574 586
575 587 def save_queue_result(self, idents, msg):
576 588 if len(idents) < 2:
577 589 self.log.error("invalid identity prefix: %s"%idents)
578 590 return
579 591
580 592 client_id, queue_id = idents[:2]
581 593 try:
582 594 msg = self.session.unpack_message(msg, content=False)
583 595 except:
584 596 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
585 597 queue_id,client_id, msg), exc_info=True)
586 598 return
587 599
588 600 eid = self.by_ident.get(queue_id, None)
589 601 if eid is None:
590 602 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
591 603 # self.log.debug("queue:: %s"%msg[2:])
592 604 return
593 605
594 606 parent = msg['parent_header']
595 607 if not parent:
596 608 return
597 609 msg_id = parent['msg_id']
598 610 if msg_id in self.pending:
599 611 self.pending.remove(msg_id)
600 612 self.all_completed.add(msg_id)
601 613 self.queues[eid].remove(msg_id)
602 614 self.completed[eid].append(msg_id)
603 615 elif msg_id not in self.all_completed:
604 616 # it could be a result from a dead engine that died before delivering the
605 617 # result
606 618 self.log.warn("queue:: unknown msg finished %s"%msg_id)
607 619 return
608 620 # update record anyway, because the unregistration could have been premature
609 621 rheader = msg['header']
610 622 completed = datetime.strptime(rheader['date'], util.ISO8601)
611 623 started = rheader.get('started', None)
612 624 if started is not None:
613 625 started = datetime.strptime(started, util.ISO8601)
614 626 result = {
615 627 'result_header' : rheader,
616 628 'result_content': msg['content'],
617 629 'started' : started,
618 630 'completed' : completed
619 631 }
620 632
621 633 result['result_buffers'] = msg['buffers']
622 634 try:
623 635 self.db.update_record(msg_id, result)
624 636 except Exception:
625 637 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
626 638
627 639
628 640 #--------------------- Task Queue Traffic ------------------------------
629 641
630 642 def save_task_request(self, idents, msg):
631 643 """Save the submission of a task."""
632 644 client_id = idents[0]
633 645
634 646 try:
635 647 msg = self.session.unpack_message(msg, content=False)
636 648 except:
637 649 self.log.error("task::client %r sent invalid task message: %s"%(
638 650 client_id, msg), exc_info=True)
639 651 return
640 652 record = init_record(msg)
641 653
642 654 record['client_uuid'] = client_id
643 655 record['queue'] = 'task'
644 656 header = msg['header']
645 657 msg_id = header['msg_id']
646 658 self.pending.add(msg_id)
647 659 self.unassigned.add(msg_id)
648 660 try:
649 661             # it's possible iopub arrived first:
650 662 existing = self.db.get_record(msg_id)
663 if existing['resubmitted']:
664 for key in ('submitted', 'client_uuid', 'buffers'):
665 # don't clobber these keys on resubmit
666 # submitted and client_uuid should be different
667 # and buffers might be big, and shouldn't have changed
668 record.pop(key)
669 # still check content,header which should not change
670 # but are not expensive to compare as buffers
671
651 672 for key,evalue in existing.iteritems():
652 rvalue = record[key]
673 if key.endswith('buffers'):
674 # don't compare buffers
675 continue
676 rvalue = record.get(key, None)
653 677 if evalue and rvalue and evalue != rvalue:
654 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
678 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
655 679 elif evalue and not rvalue:
656 680 record[key] = evalue
657 681 self.db.update_record(msg_id, record)
658 682 except KeyError:
659 683 self.db.add_record(msg_id, record)
660 684 except Exception:
661 685 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
662 686
663 687 def save_task_result(self, idents, msg):
664 688 """save the result of a completed task."""
665 689 client_id = idents[0]
666 690 try:
667 691 msg = self.session.unpack_message(msg, content=False)
668 692 except:
669 693             self.log.error("task::invalid task result message sent to %r: %s"%(
670 694 client_id, msg), exc_info=True)
672 696 return
673 697
674 698 parent = msg['parent_header']
675 699 if not parent:
676 700 # print msg
677 701 self.log.warn("Task %r had no parent!"%msg)
678 702 return
679 703 msg_id = parent['msg_id']
680 704 if msg_id in self.unassigned:
681 705 self.unassigned.remove(msg_id)
682 706
683 707 header = msg['header']
684 708 engine_uuid = header.get('engine', None)
685 709 eid = self.by_ident.get(engine_uuid, None)
686 710
687 711 if msg_id in self.pending:
688 712 self.pending.remove(msg_id)
689 713 self.all_completed.add(msg_id)
690 714 if eid is not None:
691 715 self.completed[eid].append(msg_id)
692 716 if msg_id in self.tasks[eid]:
693 717 self.tasks[eid].remove(msg_id)
694 718 completed = datetime.strptime(header['date'], util.ISO8601)
695 719 started = header.get('started', None)
696 720 if started is not None:
697 721 started = datetime.strptime(started, util.ISO8601)
698 722 result = {
699 723 'result_header' : header,
700 724 'result_content': msg['content'],
701 725 'started' : started,
702 726 'completed' : completed,
703 727 'engine_uuid': engine_uuid
704 728 }
705 729
706 730 result['result_buffers'] = msg['buffers']
707 731 try:
708 732 self.db.update_record(msg_id, result)
709 733 except Exception:
710 734 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
711 735
712 736 else:
713 737 self.log.debug("task::unknown task %s finished"%msg_id)
714 738
715 739 def save_task_destination(self, idents, msg):
716 740 try:
717 741 msg = self.session.unpack_message(msg, content=True)
718 742 except:
719 743 self.log.error("task::invalid task tracking message", exc_info=True)
720 744 return
721 745 content = msg['content']
722 746 # print (content)
723 747 msg_id = content['msg_id']
724 748 engine_uuid = content['engine_id']
725 749 eid = self.by_ident[engine_uuid]
726 750
727 751 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
728 752 if msg_id in self.unassigned:
729 753 self.unassigned.remove(msg_id)
730 754 # else:
731 755 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
732 756
733 757 self.tasks[eid].append(msg_id)
734 758 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
735 759 try:
736 760 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
737 761 except Exception:
738 762 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
739 763
740 764
741 765 def mia_task_request(self, idents, msg):
742 766 raise NotImplementedError
743 767 client_id = idents[0]
744 768 # content = dict(mia=self.mia,status='ok')
745 769 # self.session.send('mia_reply', content=content, idents=client_id)
746 770
747 771
748 772 #--------------------- IOPub Traffic ------------------------------
749 773
750 774 def save_iopub_message(self, topics, msg):
751 775 """save an iopub message into the db"""
752 776 # print (topics)
753 777 try:
754 778 msg = self.session.unpack_message(msg, content=True)
755 779 except:
756 780 self.log.error("iopub::invalid IOPub message", exc_info=True)
757 781 return
758 782
759 783 parent = msg['parent_header']
760 784 if not parent:
761 785 self.log.error("iopub::invalid IOPub message: %s"%msg)
762 786 return
763 787 msg_id = parent['msg_id']
764 788 msg_type = msg['msg_type']
765 789 content = msg['content']
766 790
767 791 # ensure msg_id is in db
768 792 try:
769 793 rec = self.db.get_record(msg_id)
770 794 except KeyError:
771 795 rec = empty_record()
772 796 rec['msg_id'] = msg_id
773 797 self.db.add_record(msg_id, rec)
774 798 # stream
775 799 d = {}
776 800 if msg_type == 'stream':
777 801 name = content['name']
778 802 s = rec[name] or ''
779 803 d[name] = s + content['data']
780 804
781 805 elif msg_type == 'pyerr':
782 806 d['pyerr'] = content
783 807 elif msg_type == 'pyin':
784 808 d['pyin'] = content['code']
785 809 else:
786 810 d[msg_type] = content.get('data', '')
787 811
788 812 try:
789 813 self.db.update_record(msg_id, d)
790 814 except Exception:
791 815 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
792 816
793 817
794 818
795 819 #-------------------------------------------------------------------------
796 820 # Registration requests
797 821 #-------------------------------------------------------------------------
798 822
799 823 def connection_request(self, client_id, msg):
800 824 """Reply with connection addresses for clients."""
801 825 self.log.info("client::client %s connected"%client_id)
802 826 content = dict(status='ok')
803 827 content.update(self.client_info)
804 828 jsonable = {}
805 829 for k,v in self.keytable.iteritems():
806 830 if v not in self.dead_engines:
807 831 jsonable[str(k)] = v
808 832 content['engines'] = jsonable
809 833 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
810 834
811 835 def register_engine(self, reg, msg):
812 836 """Register a new engine."""
813 837 content = msg['content']
814 838 try:
815 839 queue = content['queue']
816 840 except KeyError:
817 841 self.log.error("registration::queue not specified", exc_info=True)
818 842 return
819 843 heart = content.get('heartbeat', None)
820 844 """register a new engine, and create the socket(s) necessary"""
821 845 eid = self._next_id
822 846 # print (eid, queue, reg, heart)
823 847
824 848 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
825 849
826 850 content = dict(id=eid,status='ok')
827 851 content.update(self.engine_info)
828 852 # check if requesting available IDs:
829 853 if queue in self.by_ident:
830 854 try:
831 855 raise KeyError("queue_id %r in use"%queue)
832 856 except:
833 857 content = error.wrap_exception()
834 858 self.log.error("queue_id %r in use"%queue, exc_info=True)
835 859 elif heart in self.hearts: # need to check unique hearts?
836 860 try:
837 861 raise KeyError("heart_id %r in use"%heart)
838 862 except:
839 863 self.log.error("heart_id %r in use"%heart, exc_info=True)
840 864 content = error.wrap_exception()
841 865 else:
842 866 for h, pack in self.incoming_registrations.iteritems():
843 867 if heart == h:
844 868 try:
845 869 raise KeyError("heart_id %r in use"%heart)
846 870 except:
847 871 self.log.error("heart_id %r in use"%heart, exc_info=True)
848 872 content = error.wrap_exception()
849 873 break
850 874 elif queue == pack[1]:
851 875 try:
852 876 raise KeyError("queue_id %r in use"%queue)
853 877 except:
854 878 self.log.error("queue_id %r in use"%queue, exc_info=True)
855 879 content = error.wrap_exception()
856 880 break
857 881
858 882 msg = self.session.send(self.query, "registration_reply",
859 883 content=content,
860 884 ident=reg)
861 885
862 886 if content['status'] == 'ok':
863 887 if heart in self.heartmonitor.hearts:
864 888 # already beating
865 889 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
866 890 self.finish_registration(heart)
867 891 else:
868 892 purge = lambda : self._purge_stalled_registration(heart)
869 893 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
870 894 dc.start()
871 895 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
872 896 else:
873 897 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
874 898 return eid
875 899
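    # Shape of a registration_request as handled above, a sketch:
    #
    #   content = {'queue': '<engine queue uuid>', 'heartbeat': '<heart ident>'}
    #
    # The 'ok' reply carries the new engine id plus self.engine_info (the
    # zmq URLs the engine must connect to); registration only completes
    # once the HeartMonitor sees a beat (finish_registration below).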
876 900 def unregister_engine(self, ident, msg):
877 901 """Unregister an engine that explicitly requested to leave."""
878 902 try:
879 903 eid = msg['content']['id']
880 904 except:
881 905 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
882 906 return
883 907 self.log.info("registration::unregister_engine(%s)"%eid)
884 908 # print (eid)
885 909 uuid = self.keytable[eid]
886 910 content=dict(id=eid, queue=uuid)
887 911 self.dead_engines.add(uuid)
888 912 # self.ids.remove(eid)
889 913 # uuid = self.keytable.pop(eid)
890 914 #
891 915 # ec = self.engines.pop(eid)
892 916 # self.hearts.pop(ec.heartbeat)
893 917 # self.by_ident.pop(ec.queue)
894 918 # self.completed.pop(eid)
895 919 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
896 920 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
897 921 dc.start()
898 922 ############## TODO: HANDLE IT ################
899 923
900 924 if self.notifier:
901 925 self.session.send(self.notifier, "unregistration_notification", content=content)
902 926
903 927 def _handle_stranded_msgs(self, eid, uuid):
904 928 """Handle messages known to be on an engine when the engine unregisters.
905 929
906 930 It is possible that this will fire prematurely - that is, an engine will
907 931 go down after completing a result, and the client will be notified
908 932 that the result failed and later receive the actual result.
909 933 """
910 934
911 935 outstanding = self.queues[eid]
912 936
913 937 for msg_id in outstanding:
914 938 self.pending.remove(msg_id)
915 939 self.all_completed.add(msg_id)
916 940 try:
917 941 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
918 942 except:
919 943 content = error.wrap_exception()
920 944 # build a fake header:
921 945 header = {}
922 946 header['engine'] = uuid
923 947 header['date'] = datetime.now()
924 948 rec = dict(result_content=content, result_header=header, result_buffers=[])
925 949 rec['completed'] = header['date']
926 950 rec['engine_uuid'] = uuid
927 951 try:
928 952 self.db.update_record(msg_id, rec)
929 953 except Exception:
930 954 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
931 955
932 956
933 957 def finish_registration(self, heart):
934 958 """Second half of engine registration, called after our HeartMonitor
935 959 has received a beat from the Engine's Heart."""
936 960 try:
937 961 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
938 962 except KeyError:
939 963             self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
940 964 return
941 965 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
942 966 if purge is not None:
943 967 purge.stop()
944 968 control = queue
945 969 self.ids.add(eid)
946 970 self.keytable[eid] = queue
947 971 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
948 972 control=control, heartbeat=heart)
949 973 self.by_ident[queue] = eid
950 974 self.queues[eid] = list()
951 975 self.tasks[eid] = list()
952 976 self.completed[eid] = list()
953 977 self.hearts[heart] = eid
954 978 content = dict(id=eid, queue=self.engines[eid].queue)
955 979 if self.notifier:
956 980 self.session.send(self.notifier, "registration_notification", content=content)
957 981 self.log.info("engine::Engine Connected: %i"%eid)
958 982
959 983 def _purge_stalled_registration(self, heart):
960 984 if heart in self.incoming_registrations:
961 985 eid = self.incoming_registrations.pop(heart)[0]
962 986 self.log.info("registration::purging stalled registration: %i"%eid)
963 987 else:
964 988 pass
965 989
966 990 #-------------------------------------------------------------------------
967 991 # Client Requests
968 992 #-------------------------------------------------------------------------
969 993
970 994 def shutdown_request(self, client_id, msg):
971 995 """handle shutdown request."""
972 996 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
973 997 # also notify other clients of shutdown
974 998 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
975 999 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
976 1000 dc.start()
977 1001
978 1002 def _shutdown(self):
979 1003 self.log.info("hub::hub shutting down.")
980 1004 time.sleep(0.1)
981 1005 sys.exit(0)
982 1006
983 1007
984 1008 def check_load(self, client_id, msg):
985 1009 content = msg['content']
986 1010 try:
987 1011 targets = content['targets']
988 1012 targets = self._validate_targets(targets)
989 1013 except:
990 1014 content = error.wrap_exception()
991 1015 self.session.send(self.query, "hub_error",
992 1016 content=content, ident=client_id)
993 1017 return
994 1018
995 1019 content = dict(status='ok')
996 1020 # loads = {}
997 1021 for t in targets:
998 1022 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
999 1023 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1000 1024
1001 1025
1002 1026 def queue_status(self, client_id, msg):
1003 1027 """Return the Queue status of one or more targets.
1004 1028 if verbose: return the msg_ids
1005 1029 else: return len of each type.
1006 1030 keys: queue (pending MUX jobs)
1007 1031 tasks (pending Task jobs)
1008 1032 completed (finished jobs from both queues)"""
1009 1033 content = msg['content']
1010 1034 targets = content['targets']
1011 1035 try:
1012 1036 targets = self._validate_targets(targets)
1013 1037 except:
1014 1038 content = error.wrap_exception()
1015 1039 self.session.send(self.query, "hub_error",
1016 1040 content=content, ident=client_id)
1017 1041 return
1018 1042 verbose = content.get('verbose', False)
1019 1043 content = dict(status='ok')
1020 1044 for t in targets:
1021 1045 queue = self.queues[t]
1022 1046 completed = self.completed[t]
1023 1047 tasks = self.tasks[t]
1024 1048 if not verbose:
1025 1049 queue = len(queue)
1026 1050 completed = len(completed)
1027 1051 tasks = len(tasks)
1028 1052             content[bytes(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1029 1053 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1030 1054
1031 1055 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1032 1056
1033 1057 def purge_results(self, client_id, msg):
1034 1058 """Purge results from memory. This method is more valuable before we move
1035 1059 to a DB based message storage mechanism."""
1036 1060 content = msg['content']
1037 1061 msg_ids = content.get('msg_ids', [])
1038 1062 reply = dict(status='ok')
1039 1063 if msg_ids == 'all':
1040 1064 try:
1041 1065 self.db.drop_matching_records(dict(completed={'$ne':None}))
1042 1066 except Exception:
1043 1067 reply = error.wrap_exception()
1044 1068 else:
1045 1069 for msg_id in msg_ids:
1046 1070 if msg_id in self.all_completed:
1047 1071 self.db.drop_record(msg_id)
1048 1072 else:
1049 1073 if msg_id in self.pending:
1050 1074 try:
1051 1075 raise IndexError("msg pending: %r"%msg_id)
1052 1076 except:
1053 1077 reply = error.wrap_exception()
1054 1078 else:
1055 1079 try:
1056 1080 raise IndexError("No such msg: %r"%msg_id)
1057 1081 except:
1058 1082 reply = error.wrap_exception()
1059 1083 break
1060 1084 eids = content.get('engine_ids', [])
1061 1085 for eid in eids:
1062 1086 if eid not in self.engines:
1063 1087 try:
1064 1088 raise IndexError("No such engine: %i"%eid)
1065 1089 except:
1066 1090 reply = error.wrap_exception()
1067 1091 break
1068 1092 msg_ids = self.completed.pop(eid)
1069 1093 uid = self.engines[eid].queue
1070 1094 try:
1071 1095 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1072 1096 except Exception:
1073 1097 reply = error.wrap_exception()
1074 1098 break
1075 1099
1076 1100 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1077 1101
1078 def resubmit_task(self, client_id, msg, buffers):
1079 """Resubmit a task."""
1080 raise NotImplementedError
1102 def resubmit_task(self, client_id, msg):
1103 """Resubmit one or more tasks."""
1104 def finish(reply):
1105 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1106
1107 content = msg['content']
1108 msg_ids = content['msg_ids']
1109 reply = dict(status='ok')
1110 try:
1111 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1112 'header', 'content', 'buffers'])
1113 except Exception:
1114 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1115 return finish(error.wrap_exception())
1116
1117 # validate msg_ids
1118 found_ids = [ rec['msg_id'] for rec in records ]
1119 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1120 if len(records) > len(msg_ids):
1121 try:
1122                 raise RuntimeError("DB appears to be in an inconsistent state. "
1123 "More matching records were found than should exist")
1124 except Exception:
1125 return finish(error.wrap_exception())
1126 elif len(records) < len(msg_ids):
1127 missing = [ m for m in msg_ids if m not in found_ids ]
1128 try:
1129 raise KeyError("No such msg(s): %s"%missing)
1130 except KeyError:
1131 return finish(error.wrap_exception())
1132 elif invalid_ids:
1133 msg_id = invalid_ids[0]
1134 try:
1135                 raise ValueError("Task %r appears to be in flight"%(msg_id))
1136 except Exception:
1137 return finish(error.wrap_exception())
1138
1139 # clear the existing records
1140 rec = empty_record()
1141 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 rec['resubmitted'] = datetime.now()
1143 rec['queue'] = 'task'
1144 rec['client_uuid'] = client_id[0]
1145 try:
1146 for msg_id in msg_ids:
1147 self.all_completed.discard(msg_id)
1148 self.db.update_record(msg_id, rec)
1149 except Exception:
1150             self.log.error('db::db error updating record', exc_info=True)
1151 reply = error.wrap_exception()
1152 else:
1153 # send the messages
1154 for rec in records:
1155 header = rec['header']
1156 msg = self.session.msg(header['msg_type'])
1157 msg['content'] = rec['content']
1158 msg['header'] = header
1159 msg['msg_id'] = rec['msg_id']
1160 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1161
1162 finish(dict(status='ok'))
1163
1081 1164
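    # Client-side counterpart of the handler above, sketched (this commit's
    # new Client.resubmit sends the 'resubmit_request'; the exact return
    # type is an AsyncResult-like handle):
    #
    #   rc = Client()
    #   v = rc.load_balanced_view()
    #   ar = v.apply_async(time.sleep, 1)
    #   ar.get()
    #   ahr = rc.resubmit(ar.msg_ids)  # same header/content/buffers, re-queued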
1082 1165 def _extract_record(self, rec):
1083 1166 """decompose a TaskRecord dict into subsection of reply for get_result"""
1084 1167 io_dict = {}
1085 1168 for key in 'pyin pyout pyerr stdout stderr'.split():
1086 1169 io_dict[key] = rec[key]
1087 1170 content = { 'result_content': rec['result_content'],
1088 1171 'header': rec['header'],
1089 1172 'result_header' : rec['result_header'],
1090 1173 'io' : io_dict,
1091 1174 }
1092 1175 if rec['result_buffers']:
1093 1176 buffers = map(str, rec['result_buffers'])
1094 1177 else:
1095 1178 buffers = []
1096 1179
1097 1180 return content, buffers
1098 1181
1099 1182 def get_results(self, client_id, msg):
1100 1183 """Get the result of 1 or more messages."""
1101 1184 content = msg['content']
1102 1185 msg_ids = sorted(set(content['msg_ids']))
1103 1186 statusonly = content.get('status_only', False)
1104 1187 pending = []
1105 1188 completed = []
1106 1189 content = dict(status='ok')
1107 1190 content['pending'] = pending
1108 1191 content['completed'] = completed
1109 1192 buffers = []
1110 1193 if not statusonly:
1111 1194 try:
1112 1195 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1113 1196 # turn match list into dict, for faster lookup
1114 1197 records = {}
1115 1198 for rec in matches:
1116 1199 records[rec['msg_id']] = rec
1117 1200 except Exception:
1118 1201 content = error.wrap_exception()
1119 1202 self.session.send(self.query, "result_reply", content=content,
1120 1203 parent=msg, ident=client_id)
1121 1204 return
1122 1205 else:
1123 1206 records = {}
1124 1207 for msg_id in msg_ids:
1125 1208 if msg_id in self.pending:
1126 1209 pending.append(msg_id)
1127 elif msg_id in self.all_completed or msg_id in records:
1210 elif msg_id in self.all_completed:
1128 1211 completed.append(msg_id)
1129 1212 if not statusonly:
1130 1213 c,bufs = self._extract_record(records[msg_id])
1131 1214 content[msg_id] = c
1132 1215 buffers.extend(bufs)
1216 elif msg_id in records:
1217                     if records[msg_id]['completed']:
1218 completed.append(msg_id)
1219 c,bufs = self._extract_record(records[msg_id])
1220 content[msg_id] = c
1221 buffers.extend(bufs)
1222 else:
1223 pending.append(msg_id)
1133 1224 else:
1134 1225 try:
1135 1226 raise KeyError('No such message: '+msg_id)
1136 1227 except:
1137 1228 content = error.wrap_exception()
1138 1229 break
1139 1230 self.session.send(self.query, "result_reply", content=content,
1140 1231 parent=msg, ident=client_id,
1141 1232 buffers=buffers)
1142 1233
1143 1234 def get_history(self, client_id, msg):
1144 1235 """Get a list of all msg_ids in our DB records"""
1145 1236 try:
1146 1237 msg_ids = self.db.get_history()
1147 1238 except Exception as e:
1148 1239 content = error.wrap_exception()
1149 1240 else:
1150 1241 content = dict(status='ok', history=msg_ids)
1151 1242
1152 1243 self.session.send(self.query, "history_reply", content=content,
1153 1244 parent=msg, ident=client_id)
1154 1245
1155 1246 def db_query(self, client_id, msg):
1156 1247 """Perform a raw query on the task record database."""
1157 1248 content = msg['content']
1158 1249 query = content.get('query', {})
1159 1250 keys = content.get('keys', None)
1160 1251 query = util.extract_dates(query)
1161 1252 buffers = []
1162 1253 empty = list()
1163 1254
1164 1255 try:
1165 1256 records = self.db.find_records(query, keys)
1166 1257 except Exception as e:
1167 1258 content = error.wrap_exception()
1168 1259 else:
1169 1260 # extract buffers from reply content:
1170 1261 if keys is not None:
1171 1262 buffer_lens = [] if 'buffers' in keys else None
1172 1263 result_buffer_lens = [] if 'result_buffers' in keys else None
1173 1264 else:
1174 1265 buffer_lens = []
1175 1266 result_buffer_lens = []
1176 1267
1177 1268 for rec in records:
1178 1269 # buffers may be None, so double check
1179 1270 if buffer_lens is not None:
1180 1271 b = rec.pop('buffers', empty) or empty
1181 1272 buffer_lens.append(len(b))
1182 1273 buffers.extend(b)
1183 1274 if result_buffer_lens is not None:
1184 1275 rb = rec.pop('result_buffers', empty) or empty
1185 1276 result_buffer_lens.append(len(rb))
1186 1277 buffers.extend(rb)
1187 1278 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1188 1279 result_buffer_lens=result_buffer_lens)
1189 1280
1190 1281 self.session.send(self.query, "db_reply", content=content,
1191 1282 parent=msg, ident=client_id,
1192 1283 buffers=buffers)
1193 1284
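    # Sketch of a raw query as issued from the client side (Client.db_query
    # wraps this handler; '$gt' etc. follow the Mongo-style operators the
    # DB backends accept):
    #
    #   rc.db_query({'completed': {'$gt': datetime(2011, 3, 1)}},
    #               keys=['msg_id', 'engine_uuid', 'completed'])
    #
    # Dates arrive as ISO8601 strings and are revived by util.extract_dates
    # before the query reaches the backend.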
@@ -1,416 +1,419 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2010-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11
12 12 import os
13 13 import pprint
14 14 import uuid
15 15 from datetime import datetime
16 16
17 17 try:
18 18 import cPickle
19 19 pickle = cPickle
20 20 except:
21 21 cPickle = None
22 22 import pickle
23 23
24 24 import zmq
25 25 from zmq.utils import jsonapi
26 26 from zmq.eventloop.zmqstream import ZMQStream
27 27
28 28 from .util import ISO8601
29 29
30 30 def squash_unicode(obj):
31 31 """coerce unicode back to bytestrings."""
32 32 if isinstance(obj,dict):
33 33 for key in obj.keys():
34 34 obj[key] = squash_unicode(obj[key])
35 35 if isinstance(key, unicode):
36 36 obj[squash_unicode(key)] = obj.pop(key)
37 37 elif isinstance(obj, list):
38 38 for i,v in enumerate(obj):
39 39 obj[i] = squash_unicode(v)
40 40 elif isinstance(obj, unicode):
41 41 obj = obj.encode('utf8')
42 42 return obj
43 43
44 44 def _date_default(obj):
45 45 if isinstance(obj, datetime):
46 46 return obj.strftime(ISO8601)
47 47 else:
48 48 raise TypeError("%r is not JSON serializable"%obj)
49 49
50 50 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
51 51 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:_date_default})
52 52 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
53 53
54 54 pickle_packer = lambda o: pickle.dumps(o,-1)
55 55 pickle_unpacker = pickle.loads
56 56
57 57 default_packer = json_packer
58 58 default_unpacker = json_unpacker
59 59
60 60
61 61 DELIM="<IDS|MSG>"
62 62
63 63 class Message(object):
64 64 """A simple message object that maps dict keys to attributes.
65 65
66 66 A Message can be created from a dict and a dict from a Message instance
67 67 simply by calling dict(msg_obj)."""
68 68
69 69 def __init__(self, msg_dict):
70 70 dct = self.__dict__
71 71 for k, v in dict(msg_dict).iteritems():
72 72 if isinstance(v, dict):
73 73 v = Message(v)
74 74 dct[k] = v
75 75
76 76 # Having this iterator lets dict(msg_obj) work out of the box.
77 77 def __iter__(self):
78 78 return iter(self.__dict__.iteritems())
79 79
80 80 def __repr__(self):
81 81 return repr(self.__dict__)
82 82
83 83 def __str__(self):
84 84 return pprint.pformat(self.__dict__)
85 85
86 86 def __contains__(self, k):
87 87 return k in self.__dict__
88 88
89 89 def __getitem__(self, k):
90 90 return self.__dict__[k]
91 91
92 92
93 93 def msg_header(msg_id, msg_type, username, session):
94 94 date=datetime.now().strftime(ISO8601)
95 95 return locals()
96 96
97 97 def extract_header(msg_or_header):
98 98 """Given a message or header, return the header."""
99 99 if not msg_or_header:
100 100 return {}
101 101 try:
102 102 # See if msg_or_header is the entire message.
103 103 h = msg_or_header['header']
104 104 except KeyError:
105 105 try:
106 106 # See if msg_or_header is just the header
107 107 h = msg_or_header['msg_id']
108 108 except KeyError:
109 109 raise
110 110 else:
111 111 h = msg_or_header
112 112 if not isinstance(h, dict):
113 113 h = dict(h)
114 114 return h
115 115
116 116 class StreamSession(object):
117 117 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
118 118 debug=False
119 119 key=None
120 120
121 121 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
122 122 if username is None:
123 123 username = os.environ.get('USER','username')
124 124 self.username = username
125 125 if session is None:
126 126 self.session = str(uuid.uuid4())
127 127 else:
128 128 self.session = session
129 129 self.msg_id = str(uuid.uuid4())
130 130 if packer is None:
131 131 self.pack = default_packer
132 132 else:
133 133 if not callable(packer):
134 134 raise TypeError("packer must be callable, not %s"%type(packer))
135 135 self.pack = packer
136 136
137 137 if unpacker is None:
138 138 self.unpack = default_unpacker
139 139 else:
140 140 if not callable(unpacker):
141 141 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
142 142 self.unpack = unpacker
143 143
144 144 if key is not None and keyfile is not None:
145 145 raise TypeError("Must specify key OR keyfile, not both")
146 146 if keyfile is not None:
147 147 with open(keyfile) as f:
148 148 self.key = f.read().strip()
149 149 else:
150 150 self.key = key
151 151 if isinstance(self.key, unicode):
152 152 self.key = self.key.encode('utf8')
153 153 # print key, keyfile, self.key
154 154 self.none = self.pack({})
155 155
156 156 def msg_header(self, msg_type):
157 157 h = msg_header(self.msg_id, msg_type, self.username, self.session)
158 158 self.msg_id = str(uuid.uuid4())
159 159 return h
160 160
161 161 def msg(self, msg_type, content=None, parent=None, subheader=None):
162 162 msg = {}
163 163 msg['header'] = self.msg_header(msg_type)
164 164 msg['msg_id'] = msg['header']['msg_id']
165 165 msg['parent_header'] = {} if parent is None else extract_header(parent)
166 166 msg['msg_type'] = msg_type
167 167 msg['content'] = {} if content is None else content
168 168 sub = {} if subheader is None else subheader
169 169 msg['header'].update(sub)
170 170 return msg
171 171
172 172 def check_key(self, msg_or_header):
173 173 """Check that a message's header has the right key"""
174 174 if self.key is None:
175 175 return True
176 176 header = extract_header(msg_or_header)
177 177 return header.get('key', None) == self.key
178 178
179 179
180 180 def serialize(self, msg, ident=None):
181 181 content = msg.get('content', {})
182 182 if content is None:
183 183 content = self.none
184 184 elif isinstance(content, dict):
185 185 content = self.pack(content)
186 186 elif isinstance(content, bytes):
187 187 # content is already packed, as in a relayed message
188 188 pass
189 elif isinstance(content, unicode):
190 # should be bytes, but JSON often spits out unicode
191 content = content.encode('utf8')
189 192 else:
190 193 raise TypeError("Content incorrect type: %s"%type(content))
191 194
192 195 to_send = []
193 196
194 197 if isinstance(ident, list):
195 198 # accept list of idents
196 199 to_send.extend(ident)
197 200 elif ident is not None:
198 201 to_send.append(ident)
199 202 to_send.append(DELIM)
200 203 if self.key is not None:
201 204 to_send.append(self.key)
202 205 to_send.append(self.pack(msg['header']))
203 206 to_send.append(self.pack(msg['parent_header']))
204 207 to_send.append(content)
205 208
206 209 return to_send
207 210
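    # Wire layout produced by serialize(), as a sketch (the key frame is
    # present only when self.key is set; buffers are appended by send()):
    #
    #   [ident, ..., DELIM, key?, pack(header), pack(parent_header), content]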
208 211 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None, track=False):
209 212 """Build and send a message via stream or socket.
210 213
211 214 Parameters
212 215 ----------
213 216
214 217 stream : zmq.Socket or ZMQStream
215 218 the socket-like object used to send the data
216 219 msg_or_type : str or Message/dict
217 220 Normally, msg_or_type will be a msg_type unless a message is being sent more
218 221 than once.
219 222
220 223 content : dict or None
221 224 the content of the message (ignored if msg_or_type is a message)
222 225 buffers : list or None
223 226 the already-serialized buffers to be appended to the message
224 227 parent : Message or dict or None
225 228 the parent or parent header describing the parent of this message
226 229 subheader : dict or None
227 230 extra header keys for this message's header
228 231 ident : bytes or list of bytes
229 232 the zmq.IDENTITY routing path
230 233 track : bool
231 234 whether to track. Only for use with Sockets, because ZMQStream objects cannot track messages.
232 235
233 236 Returns
234 237 -------
235 238 msg : message dict
236 239 the constructed message
237 240 (msg,tracker) : (message dict, MessageTracker)
238 241 if track=True, then a 2-tuple will be returned, the first element being the constructed
239 242 message, and the second being the MessageTracker
240 243
241 244 """
242 245
243 246 if not isinstance(stream, (zmq.Socket, ZMQStream)):
244 247 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
245 248 elif track and isinstance(stream, ZMQStream):
246 249 raise TypeError("ZMQStream cannot track messages")
247 250
248 251 if isinstance(msg_or_type, (Message, dict)):
249 252 # we got a Message, not a msg_type
250 253 # don't build a new Message
251 254 msg = msg_or_type
252 255 else:
253 256 msg = self.msg(msg_or_type, content, parent, subheader)
254 257
255 258 buffers = [] if buffers is None else buffers
256 259 to_send = self.serialize(msg, ident)
257 260 flag = 0
258 261 if buffers:
259 262 flag = zmq.SNDMORE
260 263 _track = False
261 264 else:
262 265 _track=track
263 266 if track:
264 267 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
265 268 else:
266 269 tracker = stream.send_multipart(to_send, flag, copy=False)
267 270 for b in buffers[:-1]:
268 271 stream.send(b, flag, copy=False)
269 272 if buffers:
270 273 if track:
271 274 tracker = stream.send(buffers[-1], copy=False, track=track)
272 275 else:
273 276 tracker = stream.send(buffers[-1], copy=False)
274 277
275 278 # omsg = Message(msg)
276 279 if self.debug:
277 280 pprint.pprint(msg)
278 281 pprint.pprint(to_send)
279 282 pprint.pprint(buffers)
280 283
281 284 msg['tracker'] = tracker
282 285
283 286 return msg
284 287
285 288 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
286 289 """Send a raw message via ident path.
287 290
288 291 Parameters
289 292 ----------
290 293 msg : list of sendable buffers"""
291 294 to_send = []
292 295 if isinstance(ident, bytes):
293 296 ident = [ident]
294 297 if ident is not None:
295 298 to_send.extend(ident)
296 299 to_send.append(DELIM)
297 300 if self.key is not None:
298 301 to_send.append(self.key)
299 302 to_send.extend(msg)
300 303         stream.send_multipart(to_send, flags, copy=copy)
301 304
302 305 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
303 306 """receives and unpacks a message
304 307 returns [idents], msg"""
305 308 if isinstance(socket, ZMQStream):
306 309 socket = socket.socket
307 310 try:
308 311 msg = socket.recv_multipart(mode)
309 312 except zmq.ZMQError as e:
310 313 if e.errno == zmq.EAGAIN:
311 314 # We can convert EAGAIN to None as we know in this case
312 315 # recv_multipart won't return None.
313 316 return None
314 317 else:
315 318 raise
316 319 # return an actual Message object
317 320 # determine the number of idents by trying to unpack them.
318 321 # this is terrible:
319 322 idents, msg = self.feed_identities(msg, copy)
320 323 try:
321 324 return idents, self.unpack_message(msg, content=content, copy=copy)
322 325 except Exception as e:
323 326 print (idents, msg)
324 327 # TODO: handle it
325 328 raise e
326 329
327 330 def feed_identities(self, msg, copy=True):
328 331 """feed until DELIM is reached, then return the prefix as idents and remainder as
329 332 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
330 333
331 334 Parameters
332 335 ----------
333 336 msg : a list of Message or bytes objects
334 337 the message to be split
335 338 copy : bool
336 339 flag determining whether the arguments are bytes or Messages
337 340
338 341 Returns
339 342 -------
340 343 (idents,msg) : two lists
341 344             idents will always be a list of bytes - the identity prefix
342 345 msg will be a list of bytes or Messages, unchanged from input
343 346 msg should be unpackable via self.unpack_message at this point.
344 347 """
345 348 ikey = int(self.key is not None)
346 349 minlen = 3 + ikey
347 350 msg = list(msg)
348 351 idents = []
349 352 while len(msg) > minlen:
350 353 if copy:
351 354 s = msg[0]
352 355 else:
353 356 s = msg[0].bytes
354 357 if s == DELIM:
355 358 msg.pop(0)
356 359 break
357 360 else:
358 361 idents.append(s)
359 362 msg.pop(0)
360 363
361 364 return idents, msg
362 365
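    # Example of the split performed above, a sketch with no key configured
    # (minlen == 3, so everything before DELIM becomes idents):
    #
    #   wire = ['client-uuid', DELIM, header_bytes, parent_bytes, content_bytes]
    #   idents, rest = session.feed_identities(wire)
    #   # idents == ['client-uuid']
    #   # rest   == [header_bytes, parent_bytes, content_bytes]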
363 366 def unpack_message(self, msg, content=True, copy=True):
364 367 """Return a message object from the format
365 368 sent by self.send.
366 369
367 370         Parameters
368 371         ----------
369 372
370 373 content : bool (True)
371 374 whether to unpack the content dict (True),
372 375 or leave it serialized (False)
373 376
374 377 copy : bool (True)
375 378 whether to return the bytes (True),
376 379 or the non-copying Message object in each place (False)
377 380
378 381 """
379 382 ikey = int(self.key is not None)
380 383 minlen = 3 + ikey
381 384 message = {}
382 385 if not copy:
383 386 for i in range(minlen):
384 387 msg[i] = msg[i].bytes
385 388 if ikey:
386 389             if self.key != msg[0]:
387 390 raise KeyError("Invalid Session Key: %s"%msg[0])
388 391 if not len(msg) >= minlen:
389 392 raise TypeError("malformed message, must have at least %i elements"%minlen)
390 393 message['header'] = self.unpack(msg[ikey+0])
391 394 message['msg_type'] = message['header']['msg_type']
392 395 message['parent_header'] = self.unpack(msg[ikey+1])
393 396 if content:
394 397 message['content'] = self.unpack(msg[ikey+2])
395 398 else:
396 399 message['content'] = msg[ikey+2]
397 400
398 401 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
399 402 return message
400 403
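    # Round-trip sketch tying msg/serialize/feed_identities/unpack_message
    # together (default json packer, no key configured):
    #
    #   s = StreamSession(username='ed')
    #   msg = s.msg('apply_request', content={'a': 1})
    #   wire = s.serialize(msg, ident='engine-0')
    #   idents, raw = s.feed_identities(wire)
    #   restored = s.unpack_message(raw)
    #   assert restored['content'] == {'a': 1}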
401 404
402 405 def test_msg2obj():
403 406 am = dict(x=1)
404 407 ao = Message(am)
405 408 assert ao.x == am['x']
406 409
407 410 am['y'] = dict(z=1)
408 411 ao = Message(am)
409 412 assert ao.y.z == am['y']['z']
410 413
411 414 k1, k2 = 'y', 'z'
412 415 assert ao[k1][k2] == am[k1][k2]
413 416
414 417 am2 = dict(ao)
415 418 assert am['x'] == am2['x']
416 419 assert am['y']['z'] == am2['y']['z']
@@ -1,214 +1,237 b''
1 1 """Tests for parallel client.py"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import time
15 15 from datetime import datetime
16 16 from tempfile import mktemp
17 17
18 18 import zmq
19 19
20 20 from IPython.parallel.client import client as clientmod
21 21 from IPython.parallel import error
22 22 from IPython.parallel import AsyncResult, AsyncHubResult
23 23 from IPython.parallel import LoadBalancedView, DirectView
24 24
25 25 from clienttest import ClusterTestCase, segfault, wait, add_engines
26 26
27 27 def setup():
28 28 add_engines(4)
29 29
30 30 class TestClient(ClusterTestCase):
31 31
32 32 def test_ids(self):
33 33 n = len(self.client.ids)
34 34 self.add_engines(3)
35 35 self.assertEquals(len(self.client.ids), n+3)
36 36
37 37 def test_view_indexing(self):
38 38 """test index access for views"""
39 39 self.add_engines(2)
40 40 targets = self.client._build_targets('all')[-1]
41 41 v = self.client[:]
42 42 self.assertEquals(v.targets, targets)
43 43 t = self.client.ids[2]
44 44 v = self.client[t]
45 45 self.assert_(isinstance(v, DirectView))
46 46 self.assertEquals(v.targets, t)
47 47 t = self.client.ids[2:4]
48 48 v = self.client[t]
49 49 self.assert_(isinstance(v, DirectView))
50 50 self.assertEquals(v.targets, t)
51 51 v = self.client[::2]
52 52 self.assert_(isinstance(v, DirectView))
53 53 self.assertEquals(v.targets, targets[::2])
54 54 v = self.client[1::3]
55 55 self.assert_(isinstance(v, DirectView))
56 56 self.assertEquals(v.targets, targets[1::3])
57 57 v = self.client[:-3]
58 58 self.assert_(isinstance(v, DirectView))
59 59 self.assertEquals(v.targets, targets[:-3])
60 60 v = self.client[-1]
61 61 self.assert_(isinstance(v, DirectView))
62 62 self.assertEquals(v.targets, targets[-1])
63 63 self.assertRaises(TypeError, lambda : self.client[None])
64 64
65 65 def test_lbview_targets(self):
66 66 """test load_balanced_view targets"""
67 67 v = self.client.load_balanced_view()
68 68 self.assertEquals(v.targets, None)
69 69 v = self.client.load_balanced_view(-1)
70 70 self.assertEquals(v.targets, [self.client.ids[-1]])
71 71 v = self.client.load_balanced_view('all')
72 72 self.assertEquals(v.targets, self.client.ids)
73 73
74 74 def test_targets(self):
75 75 """test various valid targets arguments"""
76 76 build = self.client._build_targets
77 77 ids = self.client.ids
78 78 idents,targets = build(None)
79 79 self.assertEquals(ids, targets)
80 80
81 81 def test_clear(self):
82 82 """test clear behavior"""
83 83 # self.add_engines(2)
84 84 v = self.client[:]
85 85 v.block=True
86 86 v.push(dict(a=5))
87 87 v.pull('a')
88 88 id0 = self.client.ids[-1]
89 89 self.client.clear(targets=id0, block=True)
90 90 a = self.client[:-1].get('a')
91 91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
92 92 self.client.clear(block=True)
93 93 for i in self.client.ids:
94 94 # print i
95 95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
96 96
97 97 def test_get_result(self):
98 98 """test getting results from the Hub."""
99 99 c = clientmod.Client(profile='iptest')
100 100 # self.add_engines(1)
101 101 t = c.ids[-1]
102 102 ar = c[t].apply_async(wait, 1)
103 103 # give the monitor time to notice the message
104 104 time.sleep(.25)
105 105 ahr = self.client.get_result(ar.msg_ids)
106 106 self.assertTrue(isinstance(ahr, AsyncHubResult))
107 107 self.assertEquals(ahr.get(), ar.get())
108 108 ar2 = self.client.get_result(ar.msg_ids)
109 109 self.assertFalse(isinstance(ar2, AsyncHubResult))
110 110 c.close()
111 111
112 112 def test_ids_list(self):
113 113 """test client.ids"""
114 114 # self.add_engines(2)
115 115 ids = self.client.ids
116 116 self.assertEquals(ids, self.client._ids)
117 117 self.assertFalse(ids is self.client._ids)
118 118 ids.remove(ids[-1])
119 119 self.assertNotEquals(ids, self.client._ids)
120 120
121 121 def test_queue_status(self):
122 122 # self.addEngine(4)
123 123 ids = self.client.ids
124 124 id0 = ids[0]
125 125 qs = self.client.queue_status(targets=id0)
126 126 self.assertTrue(isinstance(qs, dict))
127 127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
128 128 allqs = self.client.queue_status()
129 129 self.assertTrue(isinstance(allqs, dict))
130 130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
131 131 unassigned = allqs.pop('unassigned')
132 132 for eid,qs in allqs.items():
133 133 self.assertTrue(isinstance(qs, dict))
134 134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135 135
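
test_queue_status above pins down the shape of Client.queue_status: one dict per engine with 'completed', 'queue', and 'tasks' keys, plus an 'unassigned' entry when querying all targets. A minimal usage sketch, assuming the same running iptest cluster the tests use:

    from IPython.parallel import Client

    rc = Client(profile='iptest')      # assumes a running iptest cluster
    qs = rc.queue_status(targets=rc.ids[0])
    print(sorted(qs.keys()))           # -> ['completed', 'queue', 'tasks']

    allqs = rc.queue_status()          # one entry per engine, plus 'unassigned'
    unassigned = allqs.pop('unassigned')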
136 136 def test_shutdown(self):
137 137 # self.addEngine(4)
138 138 ids = self.client.ids
139 139 id0 = ids[0]
140 140 self.client.shutdown(id0, block=True)
141 141 while id0 in self.client.ids:
142 142 time.sleep(0.1)
143 143 self.client.spin()
144 144
145 145 self.assertRaises(IndexError, lambda : self.client[id0])
146 146
147 147 def test_result_status(self):
148 148 pass
149 149 # to be written
150 150
151 151 def test_db_query_dt(self):
152 152 """test db query by date"""
153 153 hist = self.client.hub_history()
154 154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
155 155 tic = middle['submitted']
156 156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
157 157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
158 158 self.assertEquals(len(before)+len(after),len(hist))
159 159 for b in before:
160 160 self.assertTrue(b['submitted'] < tic)
161 161 for a in after:
162 162 self.assertTrue(a['submitted'] >= tic)
163 163 same = self.client.db_query({'submitted' : tic})
164 164 for s in same:
165 165 self.assertTrue(s['submitted'] == tic)
166 166
167 167 def test_db_query_keys(self):
168 168 """test extracting subset of record keys"""
169 169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
170 170 for rec in found:
171 171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
172 172
173 173 def test_db_query_msg_id(self):
174 174 """ensure msg_id is always in db queries"""
175 175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
176 176 for rec in found:
177 177 self.assertTrue('msg_id' in rec.keys())
178 178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
179 179 for rec in found:
180 180 self.assertTrue('msg_id' in rec.keys())
181 181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
182 182 for rec in found:
183 183 self.assertTrue('msg_id' in rec.keys())
184 184
185 185 def test_db_query_in(self):
186 186 """test db query with '$in','$nin' operators"""
187 187 hist = self.client.hub_history()
188 188 even = hist[::2]
189 189 odd = hist[1::2]
190 190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
191 191 found = [ r['msg_id'] for r in recs ]
192 192 self.assertEquals(set(even), set(found))
193 193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
194 194 found = [ r['msg_id'] for r in recs ]
195 195 self.assertEquals(set(odd), set(found))
196 196
197 197 def test_hub_history(self):
198 198 hist = self.client.hub_history()
199 199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
200 200 recdict = {}
201 201 for rec in recs:
202 202 recdict[rec['msg_id']] = rec
203 203
204 204 latest = datetime(1984,1,1)
205 205 for msg_id in hist:
206 206 rec = recdict[msg_id]
207 207 newt = rec['submitted']
208 208 self.assertTrue(newt >= latest)
209 209 latest = newt
210 210 ar = self.client[-1].apply_async(lambda : 1)
211 211 ar.get()
212 212 time.sleep(0.25)
213 213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214 214
215 def test_resubmit(self):
216 def f():
217 import random
218 return random.random()
219 v = self.client.load_balanced_view()
220 ar = v.apply_async(f)
221 r1 = ar.get(1)
222 ahr = self.client.resubmit(ar.msg_ids)
223 r2 = ahr.get(1)
224 self.assertFalse(r1 == r2)
225
226 def test_resubmit_inflight(self):
227 """ensure ValueError on resubmit of inflight task"""
228 v = self.client.load_balanced_view()
229 ar = v.apply_async(time.sleep,1)
230 # give the message a chance to arrive
231 time.sleep(0.2)
232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 ar.get(2)
234
235 def test_resubmit_badkey(self):
236 """ensure KeyError on resubmit of nonexistant task"""
237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
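
The three resubmit tests above document the contract of the new Client.resubmit: it takes a list of msg_ids for completed tasks, re-runs them, and returns a result object whose get() yields fresh values; a still-in-flight task raises ValueError and an unknown msg_id raises KeyError. A minimal usage sketch, mirroring test_resubmit and assuming a running iptest cluster:

    from IPython.parallel import Client

    rc = Client(profile='iptest')      # assumes a running iptest cluster
    view = rc.load_balanced_view()

    def f():
        import random
        return random.random()

    ar = view.apply_async(f)
    first = ar.get()                   # task must complete before resubmission
    ahr = rc.resubmit(ar.msg_ids)      # re-run the same task
    second = ahr.get()
    assert first != second             # a genuinely fresh execution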
@@ -1,120 +1,120 b''
1 1 """test LoadBalancedView objects"""
2 2 # -*- coding: utf-8 -*-
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import sys
15 15 import time
16 16
17 17 import zmq
18 18
19 19 from IPython import parallel as pmod
20 20 from IPython.parallel import error
21 21
22 22 from IPython.parallel.tests import add_engines
23 23
24 24 from .clienttest import ClusterTestCase, crash, wait, skip_without
25 25
26 26 def setup():
27 27 add_engines(3)
28 28
29 29 class TestLoadBalancedView(ClusterTestCase):
30 30
31 31 def setUp(self):
32 32 ClusterTestCase.setUp(self)
33 33 self.view = self.client.load_balanced_view()
34 34
35 35 def test_z_crash_task(self):
36 36 """test graceful handling of engine death (balanced)"""
37 37 # self.add_engines(1)
38 38 ar = self.view.apply_async(crash)
39 self.assertRaisesRemote(error.EngineError, ar.get)
39 self.assertRaisesRemote(error.EngineError, ar.get, 10)
40 40 eid = ar.engine_id
41 41 tic = time.time()
42 42 while eid in self.client.ids and time.time()-tic < 5:
43 43 time.sleep(.01)
44 44 self.client.spin()
45 45 self.assertFalse(eid in self.client.ids, "Engine should have died")
46 46
47 47 def test_map(self):
48 48 def f(x):
49 49 return x**2
50 50 data = range(16)
51 51 r = self.view.map_sync(f, data)
52 52 self.assertEquals(r, map(f, data))
53 53
54 54 def test_abort(self):
55 55 view = self.view
56 56 ar = self.client[:].apply_async(time.sleep, .5)
57 57 ar2 = view.apply_async(lambda : 2)
58 58 ar3 = view.apply_async(lambda : 3)
59 59 view.abort(ar2)
60 60 view.abort(ar3.msg_ids)
61 61 self.assertRaises(error.TaskAborted, ar2.get)
62 62 self.assertRaises(error.TaskAborted, ar3.get)
63 63
64 64 def test_retries(self):
65 65 add_engines(3)
66 66 view = self.view
67 67 view.timeout = 1 # prevent hang if this doesn't behave
68 68 def fail():
69 69 assert False
70 70 for r in range(len(self.client)-1):
71 71 with view.temp_flags(retries=r):
72 72 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
73 73
74 74 with view.temp_flags(retries=len(self.client), timeout=0.25):
75 75 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
76 76
77 77 def test_invalid_dependency(self):
78 78 view = self.view
79 79 with view.temp_flags(after='12345'):
80 80 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
81 81
82 82 def test_impossible_dependency(self):
83 83 if len(self.client) < 2:
84 84 add_engines(2)
85 85 view = self.client.load_balanced_view()
86 86 ar1 = view.apply_async(lambda : 1)
87 87 ar1.get()
88 88 e1 = ar1.engine_id
89 89 e2 = e1
90 90 while e2 == e1:
91 91 ar2 = view.apply_async(lambda : 1)
92 92 ar2.get()
93 93 e2 = ar2.engine_id
94 94
95 95 with view.temp_flags(follow=[ar1, ar2]):
96 96 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
97 97
98 98
99 99 def test_follow(self):
100 100 ar = self.view.apply_async(lambda : 1)
101 101 ar.get()
102 102 ars = []
103 103 first_id = ar.engine_id
104 104
105 105 self.view.follow = ar
106 106 for i in range(5):
107 107 ars.append(self.view.apply_async(lambda : 1))
108 108 self.view.wait(ars)
109 109 for ar in ars:
110 110 self.assertEquals(ar.engine_id, first_id)
111 111
112 112 def test_after(self):
113 113 view = self.view
114 114 ar = view.apply_async(time.sleep, 0.5)
115 115 with view.temp_flags(after=ar):
116 116 ar2 = view.apply_async(lambda : 1)
117 117
118 118 ar.wait()
119 119 ar2.wait()
120 120 self.assertTrue(ar2.started > ar.completed)
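
The dependency tests above (test_retries, test_after, test_follow) show the scheduling flags a LoadBalancedView accepts through temp_flags. A short sketch combining two of them, again assuming a running iptest cluster:

    import time
    from IPython.parallel import Client

    rc = Client(profile='iptest')      # assumes a running iptest cluster
    view = rc.load_balanced_view()

    ar = view.apply_async(time.sleep, 0.5)
    # 'after' delays the next task until ar completes; 'retries' resubmits on failure
    with view.temp_flags(after=ar, retries=2):
        ar2 = view.apply_async(lambda: 'done')
    print(ar2.get())                   # runs only after ar has completed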