##// END OF EJS Templates
specify sshkey is *private*
MinRK -
Show More
@@ -1,1435 +1,1435 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
if sys.version_info[0] >= 3:
    # ``xrange`` is referenced in a couple of isinstance checks below;
    # alias it to ``range`` so those checks also work on Python 3.
    xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
@decorator
def spin_first(f, self, *args, **kwargs):
    """Flush pending messages (via spin()) before invoking the wrapped method."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed set of allowed keys, with their initial values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` works on both Python 2 and 3, unlike
        # dict.iterkeys(), which no longer exists on py3
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
122 122
123 123
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 166 # and set sshserver as the same machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to ssh private key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False else]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use semi-private _ids attributes.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
244 244
245 245 block = Bool(False)
246 246 outstanding = Set()
247 247 results = Instance('collections.defaultdict', (dict,))
248 248 metadata = Instance('collections.defaultdict', (Metadata,))
249 249 history = List()
250 250 debug = Bool(False)
251 251
252 252 profile=Unicode()
253 253 def _profile_default(self):
254 254 if BaseIPythonApplication.initialized():
255 255 # an IPython app *might* be running, try to get its profile
256 256 try:
257 257 return BaseIPythonApplication.instance().profile
258 258 except (AttributeError, MultipleInstanceError):
259 259 # could be a *different* subclass of config.Application,
260 260 # which would raise one of these two errors.
261 261 return u'default'
262 262 else:
263 263 return u'default'
264 264
265 265
266 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 267 _ids = List()
268 268 _connected=Bool(False)
269 269 _ssh=Bool(False)
270 270 _context = Instance('zmq.Context')
271 271 _config = Dict()
272 272 _engines=Instance(util.ReverseDict, (), {})
273 273 # _hub_socket=Instance('zmq.Socket')
274 274 _query_socket=Instance('zmq.Socket')
275 275 _control_socket=Instance('zmq.Socket')
276 276 _iopub_socket=Instance('zmq.Socket')
277 277 _notification_socket=Instance('zmq.Socket')
278 278 _mux_socket=Instance('zmq.Socket')
279 279 _task_socket=Instance('zmq.Socket')
280 280 _task_scheme=Unicode()
281 281 _closed = False
282 282 _ignored_control_replies=Int(0)
283 283 _ignored_hub_replies=Int(0)
284 284
285 285 def __new__(self, *args, **kw):
286 286 # don't raise on positional args
287 287 return HasTraits.__new__(self, **kw)
288 288
    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        """Connect to the Hub.

        Loads connection info (from a url or a JSON connection file),
        decides whether SSH tunneling is needed, constructs the Session,
        and connects the query socket.  See the class docstring for the
        meaning of each parameter; leftover **extra_args are relayed to
        the Session constructor.
        """
        # `profile` is a trait with a dynamic default; only pass it through
        # when explicitly given, so the default logic can run otherwise
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            # default to the connection file written by ipcontroller
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            # not a url; treat it as a path to a JSON connection file,
            # possibly relative to the profile's security dir
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # explicit arguments take precedence over the JSON file
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto,addr,port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver=cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP."""%location,
                RuntimeWarning)
        elif not sshserver:
            # otherwise sync with cfg
            sshserver = cfg['ssh']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                # exec_key is a path to a file containing the key
                extra_args['keyfile'] = exec_key
            else:
                # exec_key is the key itself
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming messages, keyed by msg_type
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    'shutdown_notification' : lambda msg: self.close(),
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)
388 388
    def __del__(self):
        """cleanup sockets, but _not_ context.

        close() is idempotent, so repeated finalization is harmless.
        """
        self.close()
392 392
393 393 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
394 394 if ipython_dir is None:
395 395 ipython_dir = get_ipython_dir()
396 396 if profile_dir is not None:
397 397 try:
398 398 self._cd = ProfileDir.find_profile_dir(profile_dir)
399 399 return
400 400 except ProfileDirError:
401 401 pass
402 402 elif profile is not None:
403 403 try:
404 404 self._cd = ProfileDir.find_profile_dir_by_name(
405 405 ipython_dir, profile)
406 406 return
407 407 except ProfileDirError:
408 408 pass
409 409 self._cd = None
410 410
411 411 def _update_engines(self, engines):
412 412 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
413 413 for k,v in engines.iteritems():
414 414 eid = int(k)
415 415 self._engines[eid] = v
416 416 self._ids.append(eid)
417 417 self._ids = sorted(self._ids)
418 418 if sorted(self._engines.keys()) != range(len(self._engines)) and \
419 419 self._task_scheme == 'pure' and self._task_socket:
420 420 self._stop_scheduling_tasks()
421 421
422 422 def _stop_scheduling_tasks(self):
423 423 """Stop scheduling tasks because an engine has been unregistered
424 424 from a pure ZMQ scheduler.
425 425 """
426 426 self._task_socket.close()
427 427 self._task_socket = None
428 428 msg = "An engine has been unregistered, and we are using pure " +\
429 429 "ZMQ task scheduling. Task farming will be disabled."
430 430 if self.outstanding:
431 431 msg += " If you were running tasks when this happened, " +\
432 432 "some `outstanding` msg_ids may never resolve."
433 433 warnings.warn(msg, RuntimeWarning)
434 434
    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).

        Parameters
        ----------
        targets : None, 'all', int, slice, or collection of ints
            None and 'all' mean every registered engine; a negative int
            indexes from the end of the id list like normal indexing.

        Raises
        ------
        error.NoEnginesRegistered : if no engines are registered at all
        TypeError / IndexError : for invalid target specs
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, basestring):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'"%(targets))
        elif isinstance(targets, int):
            if targets < 0:
                # negative index counts from the end of the id list
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i"%targets)
            targets = [targets]

        if isinstance(targets, slice):
            # slice the list of positions, then map back to engine ids
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))

        return [util.asbytes(self._engines[t]) for t in targets], list(targets)
467 467
    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__.

        Sends a connection_request to the Hub over the query socket, then
        connects the mux/task/notification/control/iopub sockets based on
        the reply.  Raises error.TimeoutError if the Hub does not answer
        within `timeout` seconds.
        """

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve the url relative to the controller's location,
            # tunneling over ssh when configured
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # each socket connects with the session id as its zmq IDENTITY
            ident = util.asbytes(self.session.session)
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                # SUB socket subscribed to everything
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")
530 530
531 531 #--------------------------------------------------------------------------
532 532 # handlers and callbacks for incoming messages
533 533 #--------------------------------------------------------------------------
534 534
535 535 def _unwrap_exception(self, content):
536 536 """unwrap exception, and remap engine_id to int."""
537 537 e = error.unwrap_exception(content)
538 538 # print e.traceback
539 539 if e.engine_info:
540 540 e_uuid = e.engine_info['engine_uuid']
541 541 eid = self._engines[e_uuid]
542 542 e.engine_info['engine_id'] = eid
543 543 return e
544 544
545 545 def _extract_metadata(self, header, parent, content):
546 546 md = {'msg_id' : parent['msg_id'],
547 547 'received' : datetime.now(),
548 548 'engine_uuid' : header.get('engine', None),
549 549 'follow' : parent.get('follow', []),
550 550 'after' : parent.get('after', []),
551 551 'status' : content['status'],
552 552 }
553 553
554 554 if md['engine_uuid'] is not None:
555 555 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
556 556
557 557 if 'date' in parent:
558 558 md['submitted'] = parent['date']
559 559 if 'started' in header:
560 560 md['started'] = header['started']
561 561 if 'date' in header:
562 562 md['completed'] = header['date']
563 563 return md
564 564
565 565 def _register_engine(self, msg):
566 566 """Register a new engine, and update our connection info."""
567 567 content = msg['content']
568 568 eid = content['id']
569 569 d = {eid : content['queue']}
570 570 self._update_engines(d)
571 571
572 572 def _unregister_engine(self, msg):
573 573 """Unregister an engine that has died."""
574 574 content = msg['content']
575 575 eid = int(content['id'])
576 576 if eid in self._ids:
577 577 self._ids.remove(eid)
578 578 uuid = self._engines.pop(eid)
579 579
580 580 self._handle_stranded_msgs(eid, uuid)
581 581
582 582 if self._task_socket and self._task_scheme == 'pure':
583 583 self._stop_scheduling_tasks()
584 584
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.

        Parameters
        ----------
        eid : int
            the integer id of the dead engine
        uuid : str
            the engine's uuid, used as the key into _outstanding_dict
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result; nothing was stranded
                continue
            try:
                # raise-and-catch so wrap_exception can capture a traceback
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            # deliver the synthetic failure through the normal reply path
            self._handle_apply_reply(msg)
611 611
612 612 def _handle_execute_reply(self, msg):
613 613 """Save the reply to an execute_request into our results.
614 614
615 615 execute messages are never actually used. apply is used instead.
616 616 """
617 617
618 618 parent = msg['parent_header']
619 619 msg_id = parent['msg_id']
620 620 if msg_id not in self.outstanding:
621 621 if msg_id in self.history:
622 622 print ("got stale result: %s"%msg_id)
623 623 else:
624 624 print ("got unknown result: %s"%msg_id)
625 625 else:
626 626 self.outstanding.remove(msg_id)
627 627 self.results[msg_id] = self._unwrap_exception(msg['content'])
628 628
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results.

        Updates the metadata entry for the request, removes the msg_id from
        the per-engine outstanding set, and stores either the deserialized
        result or the wrapped exception in self.results.
        """
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant? (metadata is a defaultdict, so md should
        # already be stored under msg_id)
        self.metadata[msg_id] = md

        # the engine is no longer responsible for this msg_id
        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)
665 665
666 666 def _flush_notifications(self):
667 667 """Flush notifications of engine registrations waiting
668 668 in ZMQ queue."""
669 669 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
670 670 while msg is not None:
671 671 if self.debug:
672 672 pprint(msg)
673 673 msg_type = msg['header']['msg_type']
674 674 handler = self._notification_handlers.get(msg_type, None)
675 675 if handler is None:
676 676 raise Exception("Unhandled message type: %s"%msg.msg_type)
677 677 else:
678 678 handler(msg)
679 679 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
680 680
681 681 def _flush_results(self, sock):
682 682 """Flush task or queue results waiting in ZMQ queue."""
683 683 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 684 while msg is not None:
685 685 if self.debug:
686 686 pprint(msg)
687 687 msg_type = msg['header']['msg_type']
688 688 handler = self._queue_handlers.get(msg_type, None)
689 689 if handler is None:
690 690 raise Exception("Unhandled message type: %s"%msg.msg_type)
691 691 else:
692 692 handler(msg)
693 693 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 694
695 695 def _flush_control(self, sock):
696 696 """Flush replies from the control channel waiting
697 697 in the ZMQ queue.
698 698
699 699 Currently: ignore them."""
700 700 if self._ignored_control_replies <= 0:
701 701 return
702 702 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
703 703 while msg is not None:
704 704 self._ignored_control_replies -= 1
705 705 if self.debug:
706 706 pprint(msg)
707 707 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
708 708
    def _flush_ignored_control(self):
        """flush ignored control replies

        Blocking recv: these replies are known to be in flight, so wait
        for each one rather than polling.
        """
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1
714 714
715 715 def _flush_ignored_hub_replies(self):
716 716 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
717 717 while msg is not None:
718 718 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
719 719
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.

        Accumulates stdout/stderr streams, pyin, pyerr, and other output
        into the metadata entry of the originating request.
        """
        idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['header']['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # append stream output to whatever was captured before
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant? (metadata is a defaultdict, so md is already stored)
            self.metadata[msg_id] = md

            idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
752 752
753 753 #--------------------------------------------------------------------------
754 754 # len, getitem
755 755 #--------------------------------------------------------------------------
756 756
757 757 def __len__(self):
758 758 """len(client) returns # of engines."""
759 759 return len(self.ids)
760 760
761 761 def __getitem__(self, key):
762 762 """index access returns DirectView multiplexer objects
763 763
764 764 Must be int, slice, or list/tuple/xrange of ints"""
765 765 if not isinstance(key, (int, slice, tuple, list, xrange)):
766 766 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
767 767 else:
768 768 return self.direct_view(key)
769 769
770 770 #--------------------------------------------------------------------------
771 771 # Begin public methods
772 772 #--------------------------------------------------------------------------
773 773
774 774 @property
775 775 def ids(self):
776 776 """Always up-to-date ids property."""
777 777 self._flush_notifications()
778 778 # always copy:
779 779 return list(self._ids)
780 780
781 781 def close(self):
782 782 if self._closed:
783 783 return
784 784 snames = filter(lambda n: n.endswith('socket'), dir(self))
785 785 for socket in map(lambda name: getattr(self, name), snames):
786 786 if isinstance(socket, zmq.Socket) and not socket.closed:
787 787 socket.close()
788 788 self._closed = True
789 789
790 790 def spin(self):
791 791 """Flush any registration notifications and execution results
792 792 waiting in the ZMQ queue.
793 793 """
794 794 if self._notification_socket:
795 795 self._flush_notifications()
796 796 if self._mux_socket:
797 797 self._flush_results(self._mux_socket)
798 798 if self._task_socket:
799 799 self._flush_results(self._task_socket)
800 800 if self._control_socket:
801 801 self._flush_control(self._control_socket)
802 802 if self._iopub_socket:
803 803 self._flush_iopub(self._iopub_socket)
804 804 if self._query_socket:
805 805 self._flush_ignored_hub_replies()
806 806
807 807 def wait(self, jobs=None, timeout=-1):
808 808 """waits on one or more `jobs`, for up to `timeout` seconds.
809 809
810 810 Parameters
811 811 ----------
812 812
813 813 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
814 814 ints are indices to self.history
815 815 strs are msg_ids
816 816 default: wait on all outstanding messages
817 817 timeout : float
818 818 a time in seconds, after which to give up.
819 819 default is -1, which means no timeout
820 820
821 821 Returns
822 822 -------
823 823
824 824 True : when all msg_ids are done
825 825 False : timeout reached, some msg_ids still outstanding
826 826 """
827 827 tic = time.time()
828 828 if jobs is None:
829 829 theids = self.outstanding
830 830 else:
831 831 if isinstance(jobs, (int, basestring, AsyncResult)):
832 832 jobs = [jobs]
833 833 theids = set()
834 834 for job in jobs:
835 835 if isinstance(job, int):
836 836 # index access
837 837 job = self.history[job]
838 838 elif isinstance(job, AsyncResult):
839 839 map(theids.add, job.msg_ids)
840 840 continue
841 841 theids.add(job)
842 842 if not theids.intersection(self.outstanding):
843 843 return True
844 844 self.spin()
845 845 while theids.intersection(self.outstanding):
846 846 if timeout >= 0 and ( time.time()-tic ) > timeout:
847 847 break
848 848 time.sleep(1e-3)
849 849 self.spin()
850 850 return len(theids.intersection(self.outstanding)) == 0
851 851
852 852 #--------------------------------------------------------------------------
853 853 # Control methods
854 854 #--------------------------------------------------------------------------
855 855
856 856 @spin_first
857 857 def clear(self, targets=None, block=None):
858 858 """Clear the namespace in target(s)."""
859 859 block = self.block if block is None else block
860 860 targets = self._build_targets(targets)[0]
861 861 for t in targets:
862 862 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
863 863 error = False
864 864 if block:
865 865 self._flush_ignored_control()
866 866 for i in range(len(targets)):
867 867 idents,msg = self.session.recv(self._control_socket,0)
868 868 if self.debug:
869 869 pprint(msg)
870 870 if msg['content']['status'] != 'ok':
871 871 error = self._unwrap_exception(msg['content'])
872 872 else:
873 873 self._ignored_control_replies += len(targets)
874 874 if error:
875 875 raise error
876 876
877 877
878 878 @spin_first
879 879 def abort(self, jobs=None, targets=None, block=None):
880 880 """Abort specific jobs from the execution queues of target(s).
881 881
882 882 This is a mechanism to prevent jobs that have already been submitted
883 883 from executing.
884 884
885 885 Parameters
886 886 ----------
887 887
888 888 jobs : msg_id, list of msg_ids, or AsyncResult
889 889 The jobs to be aborted
890 890
891 891
892 892 """
893 893 block = self.block if block is None else block
894 894 targets = self._build_targets(targets)[0]
895 895 msg_ids = []
896 896 if isinstance(jobs, (basestring,AsyncResult)):
897 897 jobs = [jobs]
898 898 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
899 899 if bad_ids:
900 900 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
901 901 for j in jobs:
902 902 if isinstance(j, AsyncResult):
903 903 msg_ids.extend(j.msg_ids)
904 904 else:
905 905 msg_ids.append(j)
906 906 content = dict(msg_ids=msg_ids)
907 907 for t in targets:
908 908 self.session.send(self._control_socket, 'abort_request',
909 909 content=content, ident=t)
910 910 error = False
911 911 if block:
912 912 self._flush_ignored_control()
913 913 for i in range(len(targets)):
914 914 idents,msg = self.session.recv(self._control_socket,0)
915 915 if self.debug:
916 916 pprint(msg)
917 917 if msg['content']['status'] != 'ok':
918 918 error = self._unwrap_exception(msg['content'])
919 919 else:
920 920 self._ignored_control_replies += len(targets)
921 921 if error:
922 922 raise error
923 923
924 924 @spin_first
925 925 def shutdown(self, targets=None, restart=False, hub=False, block=None):
926 926 """Terminates one or more engine processes, optionally including the hub."""
927 927 block = self.block if block is None else block
928 928 if hub:
929 929 targets = 'all'
930 930 targets = self._build_targets(targets)[0]
931 931 for t in targets:
932 932 self.session.send(self._control_socket, 'shutdown_request',
933 933 content={'restart':restart},ident=t)
934 934 error = False
935 935 if block or hub:
936 936 self._flush_ignored_control()
937 937 for i in range(len(targets)):
938 938 idents,msg = self.session.recv(self._control_socket, 0)
939 939 if self.debug:
940 940 pprint(msg)
941 941 if msg['content']['status'] != 'ok':
942 942 error = self._unwrap_exception(msg['content'])
943 943 else:
944 944 self._ignored_control_replies += len(targets)
945 945
946 946 if hub:
947 947 time.sleep(0.25)
948 948 self.session.send(self._query_socket, 'shutdown_request')
949 949 idents,msg = self.session.recv(self._query_socket, 0)
950 950 if self.debug:
951 951 pprint(msg)
952 952 if msg['content']['status'] != 'ok':
953 953 error = self._unwrap_exception(msg['content'])
954 954
955 955 if error:
956 956 raise error
957 957
958 958 #--------------------------------------------------------------------------
959 959 # Execution related methods
960 960 #--------------------------------------------------------------------------
961 961
962 962 def _maybe_raise(self, result):
963 963 """wrapper for maybe raising an exception if apply failed."""
964 964 if isinstance(result, error.RemoteError):
965 965 raise result
966 966
967 967 return result
968 968
969 969 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
970 970 ident=None):
971 971 """construct and send an apply message via a socket.
972 972
973 973 This is the principal method with which all engine execution is performed by views.
974 974 """
975 975
976 976 assert not self._closed, "cannot use me anymore, I'm closed!"
977 977 # defaults:
978 978 args = args if args is not None else []
979 979 kwargs = kwargs if kwargs is not None else {}
980 980 subheader = subheader if subheader is not None else {}
981 981
982 982 # validate arguments
983 983 if not callable(f):
984 984 raise TypeError("f must be callable, not %s"%type(f))
985 985 if not isinstance(args, (tuple, list)):
986 986 raise TypeError("args must be tuple or list, not %s"%type(args))
987 987 if not isinstance(kwargs, dict):
988 988 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
989 989 if not isinstance(subheader, dict):
990 990 raise TypeError("subheader must be dict, not %s"%type(subheader))
991 991
992 992 bufs = util.pack_apply_message(f,args,kwargs)
993 993
994 994 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
995 995 subheader=subheader, track=track)
996 996
997 997 msg_id = msg['header']['msg_id']
998 998 self.outstanding.add(msg_id)
999 999 if ident:
1000 1000 # possibly routed to a specific engine
1001 1001 if isinstance(ident, list):
1002 1002 ident = ident[-1]
1003 1003 if ident in self._engines.values():
1004 1004 # save for later, in case of engine death
1005 1005 self._outstanding_dict[ident].add(msg_id)
1006 1006 self.history.append(msg_id)
1007 1007 self.metadata[msg_id]['submitted'] = datetime.now()
1008 1008
1009 1009 return msg
1010 1010
1011 1011 #--------------------------------------------------------------------------
1012 1012 # construct a View object
1013 1013 #--------------------------------------------------------------------------
1014 1014
1015 1015 def load_balanced_view(self, targets=None):
1016 1016 """construct a DirectView object.
1017 1017
1018 1018 If no arguments are specified, create a LoadBalancedView
1019 1019 using all engines.
1020 1020
1021 1021 Parameters
1022 1022 ----------
1023 1023
1024 1024 targets: list,slice,int,etc. [default: use all engines]
1025 1025 The subset of engines across which to load-balance
1026 1026 """
1027 1027 if targets == 'all':
1028 1028 targets = None
1029 1029 if targets is not None:
1030 1030 targets = self._build_targets(targets)[1]
1031 1031 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1032 1032
1033 1033 def direct_view(self, targets='all'):
1034 1034 """construct a DirectView object.
1035 1035
1036 1036 If no targets are specified, create a DirectView
1037 1037 using all engines.
1038 1038
1039 1039 Parameters
1040 1040 ----------
1041 1041
1042 1042 targets: list,slice,int,etc. [default: use all engines]
1043 1043 The engines to use for the View
1044 1044 """
1045 1045 single = isinstance(targets, int)
1046 1046 # allow 'all' to be lazily evaluated at each execution
1047 1047 if targets != 'all':
1048 1048 targets = self._build_targets(targets)[1]
1049 1049 if single:
1050 1050 targets = targets[0]
1051 1051 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1052 1052
1053 1053 #--------------------------------------------------------------------------
1054 1054 # Query methods
1055 1055 #--------------------------------------------------------------------------
1056 1056
1057 1057 @spin_first
1058 1058 def get_result(self, indices_or_msg_ids=None, block=None):
1059 1059 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1060 1060
1061 1061 If the client already has the results, no request to the Hub will be made.
1062 1062
1063 1063 This is a convenient way to construct AsyncResult objects, which are wrappers
1064 1064 that include metadata about execution, and allow for awaiting results that
1065 1065 were not submitted by this Client.
1066 1066
1067 1067 It can also be a convenient way to retrieve the metadata associated with
1068 1068 blocking execution, since it always retrieves
1069 1069
1070 1070 Examples
1071 1071 --------
1072 1072 ::
1073 1073
1074 1074 In [10]: r = client.apply()
1075 1075
1076 1076 Parameters
1077 1077 ----------
1078 1078
1079 1079 indices_or_msg_ids : integer history index, str msg_id, or list of either
1080 1080 The indices or msg_ids of indices to be retrieved
1081 1081
1082 1082 block : bool
1083 1083 Whether to wait for the result to be done
1084 1084
1085 1085 Returns
1086 1086 -------
1087 1087
1088 1088 AsyncResult
1089 1089 A single AsyncResult object will always be returned.
1090 1090
1091 1091 AsyncHubResult
1092 1092 A subclass of AsyncResult that retrieves results from the Hub
1093 1093
1094 1094 """
1095 1095 block = self.block if block is None else block
1096 1096 if indices_or_msg_ids is None:
1097 1097 indices_or_msg_ids = -1
1098 1098
1099 1099 if not isinstance(indices_or_msg_ids, (list,tuple)):
1100 1100 indices_or_msg_ids = [indices_or_msg_ids]
1101 1101
1102 1102 theids = []
1103 1103 for id in indices_or_msg_ids:
1104 1104 if isinstance(id, int):
1105 1105 id = self.history[id]
1106 1106 if not isinstance(id, basestring):
1107 1107 raise TypeError("indices must be str or int, not %r"%id)
1108 1108 theids.append(id)
1109 1109
1110 1110 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1111 1111 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1112 1112
1113 1113 if remote_ids:
1114 1114 ar = AsyncHubResult(self, msg_ids=theids)
1115 1115 else:
1116 1116 ar = AsyncResult(self, msg_ids=theids)
1117 1117
1118 1118 if block:
1119 1119 ar.wait()
1120 1120
1121 1121 return ar
1122 1122
1123 1123 @spin_first
1124 1124 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1125 1125 """Resubmit one or more tasks.
1126 1126
1127 1127 in-flight tasks may not be resubmitted.
1128 1128
1129 1129 Parameters
1130 1130 ----------
1131 1131
1132 1132 indices_or_msg_ids : integer history index, str msg_id, or list of either
1133 1133 The indices or msg_ids of indices to be retrieved
1134 1134
1135 1135 block : bool
1136 1136 Whether to wait for the result to be done
1137 1137
1138 1138 Returns
1139 1139 -------
1140 1140
1141 1141 AsyncHubResult
1142 1142 A subclass of AsyncResult that retrieves results from the Hub
1143 1143
1144 1144 """
1145 1145 block = self.block if block is None else block
1146 1146 if indices_or_msg_ids is None:
1147 1147 indices_or_msg_ids = -1
1148 1148
1149 1149 if not isinstance(indices_or_msg_ids, (list,tuple)):
1150 1150 indices_or_msg_ids = [indices_or_msg_ids]
1151 1151
1152 1152 theids = []
1153 1153 for id in indices_or_msg_ids:
1154 1154 if isinstance(id, int):
1155 1155 id = self.history[id]
1156 1156 if not isinstance(id, basestring):
1157 1157 raise TypeError("indices must be str or int, not %r"%id)
1158 1158 theids.append(id)
1159 1159
1160 1160 for msg_id in theids:
1161 1161 self.outstanding.discard(msg_id)
1162 1162 if msg_id in self.history:
1163 1163 self.history.remove(msg_id)
1164 1164 self.results.pop(msg_id, None)
1165 1165 self.metadata.pop(msg_id, None)
1166 1166 content = dict(msg_ids = theids)
1167 1167
1168 1168 self.session.send(self._query_socket, 'resubmit_request', content)
1169 1169
1170 1170 zmq.select([self._query_socket], [], [])
1171 1171 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1172 1172 if self.debug:
1173 1173 pprint(msg)
1174 1174 content = msg['content']
1175 1175 if content['status'] != 'ok':
1176 1176 raise self._unwrap_exception(content)
1177 1177
1178 1178 ar = AsyncHubResult(self, msg_ids=theids)
1179 1179
1180 1180 if block:
1181 1181 ar.wait()
1182 1182
1183 1183 return ar
1184 1184
1185 1185 @spin_first
1186 1186 def result_status(self, msg_ids, status_only=True):
1187 1187 """Check on the status of the result(s) of the apply request with `msg_ids`.
1188 1188
1189 1189 If status_only is False, then the actual results will be retrieved, else
1190 1190 only the status of the results will be checked.
1191 1191
1192 1192 Parameters
1193 1193 ----------
1194 1194
1195 1195 msg_ids : list of msg_ids
1196 1196 if int:
1197 1197 Passed as index to self.history for convenience.
1198 1198 status_only : bool (default: True)
1199 1199 if False:
1200 1200 Retrieve the actual results of completed tasks.
1201 1201
1202 1202 Returns
1203 1203 -------
1204 1204
1205 1205 results : dict
1206 1206 There will always be the keys 'pending' and 'completed', which will
1207 1207 be lists of msg_ids that are incomplete or complete. If `status_only`
1208 1208 is False, then completed results will be keyed by their `msg_id`.
1209 1209 """
1210 1210 if not isinstance(msg_ids, (list,tuple)):
1211 1211 msg_ids = [msg_ids]
1212 1212
1213 1213 theids = []
1214 1214 for msg_id in msg_ids:
1215 1215 if isinstance(msg_id, int):
1216 1216 msg_id = self.history[msg_id]
1217 1217 if not isinstance(msg_id, basestring):
1218 1218 raise TypeError("msg_ids must be str, not %r"%msg_id)
1219 1219 theids.append(msg_id)
1220 1220
1221 1221 completed = []
1222 1222 local_results = {}
1223 1223
1224 1224 # comment this block out to temporarily disable local shortcut:
1225 1225 for msg_id in theids:
1226 1226 if msg_id in self.results:
1227 1227 completed.append(msg_id)
1228 1228 local_results[msg_id] = self.results[msg_id]
1229 1229 theids.remove(msg_id)
1230 1230
1231 1231 if theids: # some not locally cached
1232 1232 content = dict(msg_ids=theids, status_only=status_only)
1233 1233 msg = self.session.send(self._query_socket, "result_request", content=content)
1234 1234 zmq.select([self._query_socket], [], [])
1235 1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1236 1236 if self.debug:
1237 1237 pprint(msg)
1238 1238 content = msg['content']
1239 1239 if content['status'] != 'ok':
1240 1240 raise self._unwrap_exception(content)
1241 1241 buffers = msg['buffers']
1242 1242 else:
1243 1243 content = dict(completed=[],pending=[])
1244 1244
1245 1245 content['completed'].extend(completed)
1246 1246
1247 1247 if status_only:
1248 1248 return content
1249 1249
1250 1250 failures = []
1251 1251 # load cached results into result:
1252 1252 content.update(local_results)
1253 1253
1254 1254 # update cache with results:
1255 1255 for msg_id in sorted(theids):
1256 1256 if msg_id in content['completed']:
1257 1257 rec = content[msg_id]
1258 1258 parent = rec['header']
1259 1259 header = rec['result_header']
1260 1260 rcontent = rec['result_content']
1261 1261 iodict = rec['io']
1262 1262 if isinstance(rcontent, str):
1263 1263 rcontent = self.session.unpack(rcontent)
1264 1264
1265 1265 md = self.metadata[msg_id]
1266 1266 md.update(self._extract_metadata(header, parent, rcontent))
1267 1267 md.update(iodict)
1268 1268
1269 1269 if rcontent['status'] == 'ok':
1270 1270 res,buffers = util.unserialize_object(buffers)
1271 1271 else:
1272 1272 print rcontent
1273 1273 res = self._unwrap_exception(rcontent)
1274 1274 failures.append(res)
1275 1275
1276 1276 self.results[msg_id] = res
1277 1277 content[msg_id] = res
1278 1278
1279 1279 if len(theids) == 1 and failures:
1280 1280 raise failures[0]
1281 1281
1282 1282 error.collect_exceptions(failures, "result_status")
1283 1283 return content
1284 1284
1285 1285 @spin_first
1286 1286 def queue_status(self, targets='all', verbose=False):
1287 1287 """Fetch the status of engine queues.
1288 1288
1289 1289 Parameters
1290 1290 ----------
1291 1291
1292 1292 targets : int/str/list of ints/strs
1293 1293 the engines whose states are to be queried.
1294 1294 default : all
1295 1295 verbose : bool
1296 1296 Whether to return lengths only, or lists of ids for each element
1297 1297 """
1298 1298 engine_ids = self._build_targets(targets)[1]
1299 1299 content = dict(targets=engine_ids, verbose=verbose)
1300 1300 self.session.send(self._query_socket, "queue_request", content=content)
1301 1301 idents,msg = self.session.recv(self._query_socket, 0)
1302 1302 if self.debug:
1303 1303 pprint(msg)
1304 1304 content = msg['content']
1305 1305 status = content.pop('status')
1306 1306 if status != 'ok':
1307 1307 raise self._unwrap_exception(content)
1308 1308 content = rekey(content)
1309 1309 if isinstance(targets, int):
1310 1310 return content[targets]
1311 1311 else:
1312 1312 return content
1313 1313
1314 1314 @spin_first
1315 1315 def purge_results(self, jobs=[], targets=[]):
1316 1316 """Tell the Hub to forget results.
1317 1317
1318 1318 Individual results can be purged by msg_id, or the entire
1319 1319 history of specific targets can be purged.
1320 1320
1321 1321 Use `purge_results('all')` to scrub everything from the Hub's db.
1322 1322
1323 1323 Parameters
1324 1324 ----------
1325 1325
1326 1326 jobs : str or list of str or AsyncResult objects
1327 1327 the msg_ids whose results should be forgotten.
1328 1328 targets : int/str/list of ints/strs
1329 1329 The targets, by int_id, whose entire history is to be purged.
1330 1330
1331 1331 default : None
1332 1332 """
1333 1333 if not targets and not jobs:
1334 1334 raise ValueError("Must specify at least one of `targets` and `jobs`")
1335 1335 if targets:
1336 1336 targets = self._build_targets(targets)[1]
1337 1337
1338 1338 # construct msg_ids from jobs
1339 1339 if jobs == 'all':
1340 1340 msg_ids = jobs
1341 1341 else:
1342 1342 msg_ids = []
1343 1343 if isinstance(jobs, (basestring,AsyncResult)):
1344 1344 jobs = [jobs]
1345 1345 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1346 1346 if bad_ids:
1347 1347 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1348 1348 for j in jobs:
1349 1349 if isinstance(j, AsyncResult):
1350 1350 msg_ids.extend(j.msg_ids)
1351 1351 else:
1352 1352 msg_ids.append(j)
1353 1353
1354 1354 content = dict(engine_ids=targets, msg_ids=msg_ids)
1355 1355 self.session.send(self._query_socket, "purge_request", content=content)
1356 1356 idents, msg = self.session.recv(self._query_socket, 0)
1357 1357 if self.debug:
1358 1358 pprint(msg)
1359 1359 content = msg['content']
1360 1360 if content['status'] != 'ok':
1361 1361 raise self._unwrap_exception(content)
1362 1362
1363 1363 @spin_first
1364 1364 def hub_history(self):
1365 1365 """Get the Hub's history
1366 1366
1367 1367 Just like the Client, the Hub has a history, which is a list of msg_ids.
1368 1368 This will contain the history of all clients, and, depending on configuration,
1369 1369 may contain history across multiple cluster sessions.
1370 1370
1371 1371 Any msg_id returned here is a valid argument to `get_result`.
1372 1372
1373 1373 Returns
1374 1374 -------
1375 1375
1376 1376 msg_ids : list of strs
1377 1377 list of all msg_ids, ordered by task submission time.
1378 1378 """
1379 1379
1380 1380 self.session.send(self._query_socket, "history_request", content={})
1381 1381 idents, msg = self.session.recv(self._query_socket, 0)
1382 1382
1383 1383 if self.debug:
1384 1384 pprint(msg)
1385 1385 content = msg['content']
1386 1386 if content['status'] != 'ok':
1387 1387 raise self._unwrap_exception(content)
1388 1388 else:
1389 1389 return content['history']
1390 1390
1391 1391 @spin_first
1392 1392 def db_query(self, query, keys=None):
1393 1393 """Query the Hub's TaskRecord database
1394 1394
1395 1395 This will return a list of task record dicts that match `query`
1396 1396
1397 1397 Parameters
1398 1398 ----------
1399 1399
1400 1400 query : mongodb query dict
1401 1401 The search dict. See mongodb query docs for details.
1402 1402 keys : list of strs [optional]
1403 1403 The subset of keys to be returned. The default is to fetch everything but buffers.
1404 1404 'msg_id' will *always* be included.
1405 1405 """
1406 1406 if isinstance(keys, basestring):
1407 1407 keys = [keys]
1408 1408 content = dict(query=query, keys=keys)
1409 1409 self.session.send(self._query_socket, "db_request", content=content)
1410 1410 idents, msg = self.session.recv(self._query_socket, 0)
1411 1411 if self.debug:
1412 1412 pprint(msg)
1413 1413 content = msg['content']
1414 1414 if content['status'] != 'ok':
1415 1415 raise self._unwrap_exception(content)
1416 1416
1417 1417 records = content['records']
1418 1418
1419 1419 buffer_lens = content['buffer_lens']
1420 1420 result_buffer_lens = content['result_buffer_lens']
1421 1421 buffers = msg['buffers']
1422 1422 has_bufs = buffer_lens is not None
1423 1423 has_rbufs = result_buffer_lens is not None
1424 1424 for i,rec in enumerate(records):
1425 1425 # relink buffers
1426 1426 if has_bufs:
1427 1427 blen = buffer_lens[i]
1428 1428 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1429 1429 if has_rbufs:
1430 1430 blen = result_buffer_lens[i]
1431 1431 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1432 1432
1433 1433 return records
1434 1434
# public API of this module: only the Client class is exported via `import *`
__all__ = [ 'Client' ]
@@ -1,226 +1,226 b''
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 it handles registration, etc. and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from getpass import getpass
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.external.ssh import tunnel
26 26 # internal
27 27 from IPython.utils.traitlets import (
28 28 Instance, Dict, Int, Type, CFloat, Unicode, CBytes, Bool
29 29 )
30 30 # from IPython.utils.localinterfaces import LOCALHOST
31 31
32 32 from IPython.parallel.controller.heartmonitor import Heart
33 33 from IPython.parallel.factory import RegistrationFactory
34 34 from IPython.parallel.util import disambiguate_url, asbytes
35 35
36 36 from IPython.zmq.session import Message
37 37
38 38 from .streamkernel import Kernel
39 39
class EngineFactory(RegistrationFactory):
    """IPython engine

    Registers with a controller, sets up heartbeat/shell/control/iopub
    streams (tunneling them over SSH when configured), and launches a
    Kernel connected to the controller's schedulers.
    """

    # configurables:
    out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'IPython.zmq.iostream.OutStream'""")
    display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
        help="""The class for handling displayhook.
        Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
    location=Unicode(config=True,
        help="""The location (an IP address) of the controller. This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout=CFloat(2,config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    sshserver=Unicode(config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""")
    # note: this is the *private* key file, not the public one
    sshkey=Unicode(config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""")
    paramiko=Bool(sys.platform == 'win32', config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    # not configurable:
    user_ns=Dict()  # namespace the kernel executes user code in
    id=Int(allow_none=True)  # engine id assigned by the Hub on registration
    registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel=Instance(Kernel)

    bident = CBytes()  # bytes mirror of `ident`, kept in sync below
    ident = Unicode()
    def _ident_changed(self, name, old, new):
        # keep the bytes identity in sync with the unicode session id
        self.bident = asbytes(new)
    using_ssh=Bool(False)


    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        # the zmq identity is derived from the session id
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
                password=False
            else:
                # key-based login failed; ask the user once and reuse it
                password = getpass("SSH Password for %s: "%self.sshserver)
        else:
            password = False

        def connect(s, url):
            # connect socket `s` to `url`, tunneling over SSH when enabled
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
                return tunnel.tunnel_connection(s, url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
                url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            return url
        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s"%self.url)
        ctx = self.context
        connect, maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.XREQ)
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)


        content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
        # the reply is handled asynchronously in complete_registration
        self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar, "registration_request",content=content)

    def complete_registration(self, msg, connect, maybe_tunnel):
        """handle the registration reply: wire up all streams and start the kernel."""
        # a reply arrived before the timeout, so cancel the pending abort
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents,msg = self.session.feed_identities(msg)
        msg = Message(self.session.unserialize(msg))

        if msg.content.status == 'ok':
            self.id = int(msg.content.id)

            # launch heartbeat
            hb_addrs = msg.content.heartbeat

            # possibly forward hb ports with tunnels
            hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
            heart = Heart(*map(str, hb_addrs), heart_id=identity)
            heart.start()

            # create Shell Streams (MUX, Task, etc.):
            queue_addr = msg.content.mux
            shell_addrs = [ str(queue_addr) ]
            task_addr = msg.content.task
            if task_addr:
                shell_addrs.append(str(task_addr))

            # Uncomment this to go back to two-socket model
            # shell_streams = []
            # for addr in shell_addrs:
            #     stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            #     stream.setsockopt(zmq.IDENTITY, identity)
            #     stream.connect(disambiguate_url(addr, self.location))
            #     shell_streams.append(stream)

            # Now use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)
            # end single stream-socket

            # control stream:
            control_addr = str(msg.content.control)
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # create iopub stream:
            iopub_addr = msg.content.iopub
            iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
            iopub_stream.setsockopt(zmq.IDENTITY, identity)
            connect(iopub_stream, iopub_addr)

            # # Redirect input streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
                sys.stdout.topic = 'engine.%i.stdout'%self.id
                sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
                sys.stderr.topic = 'engine.%i.stderr'%self.id
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
                sys.displayhook.topic = 'engine.%i.pyout'%self.id

            self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
                    control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
                    loop=loop, user_ns = self.user_ns, log=self.log)
            self.kernel.start()


        else:
            self.log.fatal("Registration Failed: %s"%msg)
            raise Exception("Registration Failed: %s"%msg)

        self.log.info("Completed registration with id %i"%self.id)


    def abort(self):
        """give up on registration after `timeout` seconds and exit."""
        self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
        self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def start(self):
        """register with the controller, aborting if it does not answer in time."""
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
        self._abort_dc.start()
226 226
General Comments 0
You need to be logged in to leave comments. Login now