##// END OF EJS Templates
update connections and diagrams for reduced sockets
MinRK -
Show More

The requested changes are too big and content was truncated. Show full diff

@@ -1,1570 +1,1584 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.pickleutil import Reference
28 28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 29 Dict, List, Bool, Str, Set)
30 30 from IPython.external.decorator import decorator
31 31 from IPython.external.ssh import tunnel
32 32
33 33 from . import error
34 34 from . import map as Map
35 35 from . import util
36 36 from . import streamsession as ss
37 37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 38 from .clusterdir import ClusterDir, ClusterDirError
39 39 from .dependency import Dependency, depend, require, dependent
40 40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 41 from .util import ReverseDict, validate_url, disambiguate_url
42 42 from .view import DirectView, LoadBalancedView
43 43
44 44 #--------------------------------------------------------------------------
45 45 # helpers for implementing old MEC API via client.apply
46 46 #--------------------------------------------------------------------------
47 47
48 48 def _push(user_ns, **ns):
49 49 """helper method for implementing `client.push` via `client.apply`"""
50 50 user_ns.update(ns)
51 51
52 52 def _pull(user_ns, keys):
53 53 """helper method for implementing `client.pull` via `client.apply`"""
54 54 if isinstance(keys, (list,tuple, set)):
55 55 for key in keys:
56 56 if not user_ns.has_key(key):
57 57 raise NameError("name '%s' is not defined"%key)
58 58 return map(user_ns.get, keys)
59 59 else:
60 60 if not user_ns.has_key(keys):
61 61 raise NameError("name '%s' is not defined"%keys)
62 62 return user_ns.get(keys)
63 63
64 64 def _clear(user_ns):
65 65 """helper method for implementing `client.clear` via `client.apply`"""
66 66 user_ns.clear()
67 67
68 68 def _execute(user_ns, code):
69 69 """helper method for implementing `client.execute` via `client.apply`"""
70 70 exec code in user_ns
71 71
72 72
73 73 #--------------------------------------------------------------------------
74 74 # Decorators for Client methods
75 75 #--------------------------------------------------------------------------
76 76
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Decorator: synchronize client state with a spin() call, then run the method."""
    self.spin()
    result = f(self, *args, **kwargs)
    return result
82 82
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Run *f* with self.block temporarily set from the `block` keyword.

    A `block` kwarg of None (or absent) falls back to the current
    self.block; the previous value is always restored afterwards.
    """
    requested = kwargs.get('block', None)
    previous = self.block
    self.block = previous if requested is None else requested
    try:
        return f(self, *args, **kwargs)
    finally:
        self.block = previous
95 95
96 96
97 97 #--------------------------------------------------------------------------
98 98 # Classes
99 99 #--------------------------------------------------------------------------
100 100
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the fixed key schema carried by every Metadata instance
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE(review): dict.update bypasses the strict __setitem__ below,
        # so unknown keys passed to the constructor are accepted silently —
        # confirm whether constructor args should be validated too.
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` replaces the Python-2-only `self.iterkeys()` scan:
        # same semantics, works on Py2 and Py3, and is a constant-time lookup
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict static key enforcement"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
151 151
152 152
153 153 class Client(HasTraits):
154 154 """A semi-synchronous client to the IPython ZMQ controller
155 155
156 156 Parameters
157 157 ----------
158 158
159 159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 160 Connection information for the Hub's registration. If a json connector
161 161 file is given, then likely no further configuration is necessary.
162 162 [Default: use profile]
163 163 profile : bytes
164 164 The name of the Cluster profile to be used to find connector information.
165 165 [Default: 'default']
166 166 context : zmq.Context
167 167 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 168 username : bytes
169 169 set username to be passed to the Session object
170 170 debug : bool
171 171 flag for lots of message printing for debug purposes
172 172
173 173 #-------------- ssh related args ----------------
174 174 # These are args for configuring the ssh tunnel to be used
175 175 # credentials are used to forward connections over ssh to the Controller
176 176 # Note that the ip given in `addr` needs to be relative to sshserver
177 177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 178 # and set sshserver as the same machine the Controller is on. However,
179 179 # the only requirement is that sshserver is able to see the Controller
180 180 # (i.e. is within the same trusted network).
181 181
182 182 sshserver : str
183 183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 184 If keyfile or password is specified, and this is not, it will default to
185 185 the ip given in addr.
186 186 sshkey : str; path to public ssh key file
187 187 This specifies a key to be used in ssh login, default None.
188 188 Regular default ssh keys will be used without specifying this argument.
189 189 password : str
190 190 Your ssh password to sshserver. Note that if this is left None,
191 191 you will be prompted for it if passwordless key based login is unavailable.
192 192 paramiko : bool
193 193 flag for whether to use paramiko instead of shell ssh for tunneling.
194 194 [default: True on win32, False else]
195 195
196 196 #------- exec authentication args -------
197 197 # If even localhost is untrusted, you can have some protection against
198 198 # unauthorized execution by using a key. Messages are still sent
199 199 # as cleartext, so if someone can snoop your loopback traffic this will
200 200 # not help against malicious attacks.
201 201
202 202 exec_key : str
203 203 an authentication key or file containing a key
204 204 default: None
205 205
206 206
207 207 Attributes
208 208 ----------
209 209
210 210 ids : set of int engine IDs
211 211 requesting the ids attribute always synchronizes
212 212 the registration state. To request ids without synchronization,
213 213 use semi-private _ids attributes.
214 214
215 215 history : list of msg_ids
216 216 a list of msg_ids, keeping track of all the execution
217 217 messages you have submitted in order.
218 218
219 219 outstanding : set of msg_ids
220 220 a set of msg_ids that have been submitted, but whose
221 221 results have not yet been received.
222 222
223 223 results : dict
224 224 a dict of all our results, keyed by msg_id
225 225
226 226 block : bool
227 227 determines default behavior when block not specified
228 228 in execution methods
229 229
230 230 Methods
231 231 -------
232 232
233 233 spin
234 234 flushes incoming results and registration state changes
235 235 control methods spin, and requesting `ids` also ensures up to date
236 236
237 237 barrier
238 238 wait on one or more msg_ids
239 239
240 240 execution methods
241 241 apply
242 242 legacy: execute, run
243 243
244 244 query methods
245 245 queue_status, get_result, purge
246 246
247 247 control methods
248 248 abort, shutdown
249 249
250 250 """
251 251
252 252
253 253 block = Bool(False)
254 254 outstanding = Set()
255 255 results = Instance('collections.defaultdict', (dict,))
256 256 metadata = Instance('collections.defaultdict', (Metadata,))
257 257 history = List()
258 258 debug = Bool(False)
259 259 profile=CUnicode('default')
260 260
261 261 _outstanding_dict = Instance('collections.defaultdict', (set,))
262 262 _ids = List()
263 263 _connected=Bool(False)
264 264 _ssh=Bool(False)
265 265 _context = Instance('zmq.Context')
266 266 _config = Dict()
267 267 _engines=Instance(ReverseDict, (), {})
268 268 # _hub_socket=Instance('zmq.Socket')
269 269 _query_socket=Instance('zmq.Socket')
270 270 _control_socket=Instance('zmq.Socket')
271 271 _iopub_socket=Instance('zmq.Socket')
272 272 _notification_socket=Instance('zmq.Socket')
273 _mux_socket=Instance('zmq.Socket')
274 _task_socket=Instance('zmq.Socket')
273 _apply_socket=Instance('zmq.Socket')
274 _mux_ident=Str()
275 _task_ident=Str()
275 276 _task_scheme=Str()
276 277 _balanced_views=Dict()
277 278 _direct_views=Dict()
278 279 _closed = False
279 280
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            ):
        """Resolve connection info (json file, profile dir, or raw url),
        build the StreamSession, connect the query socket (optionally over
        an ssh tunnel), and hand off to _connect() for the remaining sockets.
        """
        super(Client, self).__init__(debug=debug, profile=profile)
        if context is None:
            context = zmq.Context()
        self._context = context

        # locate the cluster dir; a profile/cluster_dir provides the default
        # location of the ipcontroller-client.json connection file
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
            " Please specify at least one of url_or_file or profile."

        # url_or_file may be either a zmq url or a path to a json file
        try:
            validate_url(url_or_file)
        except AssertionError:
            # not a url: treat it as a connection file; relative paths are
            # resolved against the cluster security dir
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        # NOTE(review): these lookups raise KeyError when cfg came from a
        # bare url and no exec_key/sshserver argument was given — confirm
        # whether bare-url connections are expected to supply both.
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be the key itself, or a path to a keyfile
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        # dispatch tables for incoming notification / queue messages
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs)
357 358
358 359 def __del__(self):
359 360 """cleanup sockets, but _not_ context."""
360 361 self.close()
361 362
    def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
        """Locate the ClusterDir for this connection and store it on self._cd.

        An explicit cluster_dir takes precedence over a profile lookup under
        ipython_dir.  On failure self._cd is set to None and the caller must
        supply an explicit url or connection file.
        """
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if cluster_dir is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir(cluster_dir)
                return
            except ClusterDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ClusterDir.find_cluster_dir_by_profile(
                    ipython_dir, profile)
                return
            except ClusterDirError:
                pass
        self._cd = None
379 380
380 381 @property
381 382 def ids(self):
382 383 """Always up-to-date ids property."""
383 384 self._flush_notifications()
384 385 # always copy:
385 386 return list(self._ids)
386 387
387 388 def close(self):
388 389 if self._closed:
389 390 return
390 391 snames = filter(lambda n: n.endswith('socket'), dir(self))
391 392 for socket in map(lambda name: getattr(self, name), snames):
392 393 if isinstance(socket, zmq.Socket) and not socket.closed:
393 394 socket.close()
394 395 self._closed = True
395 396
    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        # a gap in the id sequence means an engine has died; the pure ZMQ
        # scheduler cannot cope with that, so stop farming out tasks
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_ident:
            self._stop_scheduling_tasks()
406 407
    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        # clearing the ident disables task submission; the shared apply
        # socket itself stays open for the remaining (mux) traffic
        self._task_ident = ''
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)
420 421
421 422 def _build_targets(self, targets):
422 423 """Turn valid target IDs or 'all' into two lists:
423 424 (int_ids, uuids).
424 425 """
425 426 if targets is None:
426 427 targets = self._ids
427 428 elif isinstance(targets, str):
428 429 if targets.lower() == 'all':
429 430 targets = self._ids
430 431 else:
431 432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
432 433 elif isinstance(targets, int):
433 434 targets = [targets]
434 435 return [self._engines[t] for t in targets], list(targets)
435 436
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve the url relative to the controller's location, then
            # connect either directly or through the ssh tunnel
            url = disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        idents,msg = self.session.recv(self._query_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # a single XREP socket now carries both mux and task traffic;
            # the target scheduler is selected per-message by ident
            self._apply_socket = self._context.socket(zmq.XREP)
            self._apply_socket.setsockopt(zmq.IDENTITY, self.session.session)
            if content.mux:
                self._mux_ident = 'mux'
                connect_socket(self._apply_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                connect_socket(self._apply_socket, task_addr)
                self._task_ident = 'task'
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
            # give XREP apply_socket some time to connect
            time.sleep(0.25)
        else:
            self._connected = False
            raise Exception("Failed to connect!")
491 496
492 497 #--------------------------------------------------------------------------
493 498 # handlers and callbacks for incoming messages
494 499 #--------------------------------------------------------------------------
495 500
    def _unwrap_exception(self, content):
        """unwrap exception, and remap engineid to int."""
        e = error.unwrap_exception(content)
        # the wire format carries the engine uuid; translate it to the
        # integer id users actually see
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e
505 510
    def _extract_metadata(self, header, parent, content):
        """Build a metadata dict for one reply from its header/parent/content."""
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
            }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        # timestamps arrive as ISO8601 strings; parse whichever are present
        if 'date' in parent:
            md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
        if 'started' in header:
            md['started'] = datetime.strptime(header['started'], util.ISO8601)
        if 'date' in header:
            md['completed'] = datetime.strptime(header['date'], util.ISO8601)
        return md
525 530
526 531 def _register_engine(self, msg):
527 532 """Register a new engine, and update our connection info."""
528 533 content = msg['content']
529 534 eid = content['id']
530 535 d = {eid : content['queue']}
531 536 self._update_engines(d)
532 537
    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            # any messages still on that engine will never complete normally
            self._handle_stranded_msgs(eid, uuid)

        # a dead engine poisons pure-ZMQ round-robin scheduling, so task
        # farming must be disabled entirely
        if self._task_ident and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()
545 550
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have this result; nothing was stranded
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message so the normal apply-reply path records
            # the EngineError as this msg_id's result:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now().strftime(util.ISO8601)
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)
572 577
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # a reply we were not waiting on: stale (already resolved) or unknown
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            self.results[msg_id] = self._unwrap_exception(msg['content'])
589 594
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # a reply we were not waiting on: stale (already resolved) or unknown
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            content = msg['content']
            header = msg['header']

            # construct metadata:
            md = self.metadata[msg_id]
            md.update(self._extract_metadata(header, parent, content))
            # is this redundant?
            self.metadata[msg_id] = md

            # also drop the id from the per-engine outstanding set
            e_outstanding = self._outstanding_dict[md['engine_uuid']]
            if msg_id in e_outstanding:
                e_outstanding.remove(msg_id)

            # construct result:
            if content['status'] == 'ok':
                self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
            elif content['status'] == 'aborted':
                self.results[msg_id] = error.AbortedTask(msg_id)
            elif content['status'] == 'resubmitted':
                # TODO: handle resubmission
                pass
            else:
                self.results[msg_id] = self._unwrap_exception(content)
626 631
    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        # drain without blocking: recv returns None once the queue is empty
        msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
642 647
    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        # drain without blocking: recv returns None once the queue is empty
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s"%msg.msg_type)
            else:
                handler(msg)
            msg = self.session.recv(sock, mode=zmq.NOBLOCK)
657 662
658 663 def _flush_control(self, sock):
659 664 """Flush replies from the control channel waiting
660 665 in the ZMQ queue.
661 666
662 667 Currently: ignore them."""
663 668 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
664 669 while msg is not None:
665 670 if self.debug:
666 671 pprint(msg)
667 672 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
668 673
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                # accumulate stream output (stdout/stderr) for this msg_id
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            else:
                # pyin/pyout and friends: store the payload under the type
                md.update({msg_type : content['data']})

            # redundant?
            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)
700 705
701 706 #--------------------------------------------------------------------------
702 707 # len, getitem
703 708 #--------------------------------------------------------------------------
704 709
705 710 def __len__(self):
706 711 """len(client) returns # of engines."""
707 712 return len(self.ids)
708 713
709 714 def __getitem__(self, key):
710 715 """index access returns DirectView multiplexer objects
711 716
712 717 Must be int, slice, or list/tuple/xrange of ints"""
713 718 if not isinstance(key, (int, slice, tuple, list, xrange)):
714 719 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
715 720 else:
716 721 return self.view(key, balanced=False)
717 722
718 723 #--------------------------------------------------------------------------
719 724 # Begin public methods
720 725 #--------------------------------------------------------------------------
721 726
    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        # mux and task replies both arrive on the shared apply socket
        if self._apply_socket:
            self._flush_results(self._apply_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
736 739
    def barrier(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            # normalize jobs into a set of msg_id strings
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
            if not theids.intersection(self.outstanding):
                # nothing to wait for
                return True
        self.spin()
        # poll at ~1ms granularity until done or timed out
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
781 784
782 785 #--------------------------------------------------------------------------
783 786 # Control methods
784 787 #--------------------------------------------------------------------------
785 788
786 789 @spinfirst
787 790 @defaultblock
788 791 def clear(self, targets=None, block=None):
789 792 """Clear the namespace in target(s)."""
790 793 targets = self._build_targets(targets)[0]
791 794 for t in targets:
792 795 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
793 796 error = False
794 797 if self.block:
795 798 for i in range(len(targets)):
796 799 idents,msg = self.session.recv(self._control_socket,0)
797 800 if self.debug:
798 801 pprint(msg)
799 802 if msg['content']['status'] != 'ok':
800 803 error = self._unwrap_exception(msg['content'])
801 804 if error:
802 805 raise error
803 806
804 807
805 808 @spinfirst
806 809 @defaultblock
807 810 def abort(self, jobs=None, targets=None, block=None):
808 811 """Abort specific jobs from the execution queues of target(s).
809 812
810 813 This is a mechanism to prevent jobs that have already been submitted
811 814 from executing.
812 815
813 816 Parameters
814 817 ----------
815 818
816 819 jobs : msg_id, list of msg_ids, or AsyncResult
817 820 The jobs to be aborted
818 821
819 822
820 823 """
821 824 targets = self._build_targets(targets)[0]
822 825 msg_ids = []
823 826 if isinstance(jobs, (basestring,AsyncResult)):
824 827 jobs = [jobs]
825 828 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
826 829 if bad_ids:
827 830 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
828 831 for j in jobs:
829 832 if isinstance(j, AsyncResult):
830 833 msg_ids.extend(j.msg_ids)
831 834 else:
832 835 msg_ids.append(j)
833 836 content = dict(msg_ids=msg_ids)
834 837 for t in targets:
835 838 self.session.send(self._control_socket, 'abort_request',
836 839 content=content, ident=t)
837 840 error = False
838 841 if self.block:
839 842 for i in range(len(targets)):
840 843 idents,msg = self.session.recv(self._control_socket,0)
841 844 if self.debug:
842 845 pprint(msg)
843 846 if msg['content']['status'] != 'ok':
844 847 error = self._unwrap_exception(msg['content'])
845 848 if error:
846 849 raise error
847 850
848 851 @spinfirst
849 852 @defaultblock
850 853 def shutdown(self, targets=None, restart=False, controller=False, block=None):
851 854 """Terminates one or more engine processes, optionally including the controller."""
852 855 if controller:
853 856 targets = 'all'
854 857 targets = self._build_targets(targets)[0]
855 858 for t in targets:
856 859 self.session.send(self._control_socket, 'shutdown_request',
857 860 content={'restart':restart},ident=t)
858 861 error = False
859 862 if block or controller:
860 863 for i in range(len(targets)):
861 864 idents,msg = self.session.recv(self._control_socket,0)
862 865 if self.debug:
863 866 pprint(msg)
864 867 if msg['content']['status'] != 'ok':
865 868 error = self._unwrap_exception(msg['content'])
866 869
867 870 if controller:
868 871 time.sleep(0.25)
869 872 self.session.send(self._query_socket, 'shutdown_request')
870 873 idents,msg = self.session.recv(self._query_socket, 0)
871 874 if self.debug:
872 875 pprint(msg)
873 876 if msg['content']['status'] != 'ok':
874 877 error = self._unwrap_exception(msg['content'])
875 878
876 879 if error:
877 880 raise error
878 881
879 882 #--------------------------------------------------------------------------
880 883 # Execution methods
881 884 #--------------------------------------------------------------------------
882 885
883 886 @defaultblock
884 887 def execute(self, code, targets='all', block=None):
885 888 """Executes `code` on `targets` in blocking or nonblocking manner.
886 889
887 890 ``execute`` is always `bound` (affects engine namespace)
888 891
889 892 Parameters
890 893 ----------
891 894
892 895 code : str
893 896 the code string to be executed
894 897 targets : int/str/list of ints/strs
895 898 the engines on which to execute
896 899 default : all
897 900 block : bool
898 901 whether or not to wait until done to return
899 902 default: self.block
900 903 """
901 904 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
902 905 if not block:
903 906 return result
904 907
905 908 def run(self, filename, targets='all', block=None):
906 909 """Execute contents of `filename` on engine(s).
907 910
908 911 This simply reads the contents of the file and calls `execute`.
909 912
910 913 Parameters
911 914 ----------
912 915
913 916 filename : str
914 917 The path to the file
915 918 targets : int/str/list of ints/strs
916 919 the engines on which to execute
917 920 default : all
918 921 block : bool
919 922 whether or not to wait until done
920 923 default: self.block
921 924
922 925 """
923 926 with open(filename, 'r') as f:
924 927 # add newline in case of trailing indented whitespace
925 928 # which will cause SyntaxError
926 929 code = f.read()+'\n'
927 930 return self.execute(code, targets=targets, block=block)
928 931
    def _maybe_raise(self, result):
        """Raise `result` if it is a RemoteError, otherwise return it unchanged.

        Used by blocking call paths to surface remote failures locally.
        """
        if isinstance(result, error.RemoteError):
            raise result

        return result
935 938
936 939 def _build_dependency(self, dep):
937 940 """helper for building jsonable dependencies from various input forms"""
938 941 if isinstance(dep, Dependency):
939 942 return dep.as_dict()
940 943 elif isinstance(dep, AsyncResult):
941 944 return dep.msg_ids
942 945 elif dep is None:
943 946 return []
944 947 else:
945 948 # pass to Dependency constructor
946 949 return list(Dependency(dep))
947 950
    @defaultblock
    def apply(self, f, args=None, kwargs=None, bound=False, block=None,
                        targets=None, balanced=None,
                        after=None, follow=None, timeout=None,
                        track=False):
        """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.

        This is the central execution command for the client.

        Parameters
        ----------

        f : function
            The function to be called remotely
        args : tuple/list
            The positional arguments passed to `f`
        kwargs : dict
            The keyword arguments passed to `f`
        bound : bool (default: False)
            Whether to pass the Engine(s) Namespace as the first argument to `f`.
        block : bool (default: self.block)
            Whether to wait for the result, or return immediately.
            False:
                returns AsyncResult
            True:
                returns actual result(s) of f(*args, **kwargs)
                if multiple targets:
                    list of results, matching `targets`
        targets : int,list of ints, 'all', None
            Specify the destination of the job.
            if None:
                Submit via Task queue for load-balancing.
            if 'all':
                Run on all active engines
            if list:
                Run on each specified engine
            if int:
                Run on single engine

        balanced : bool, default None
            whether to load-balance.  This will default to True
            if targets is unspecified, or False if targets is specified.

        The following arguments are only used when balanced is True:
        after : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a time-based dependency.
            This job will only be run *after* the dependencies
            have been met.

        follow : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a location-based dependency.
            This job will only be run on an engine where this dependency
            is met.

        timeout : float/int or None
            Only for load-balanced execution (targets=None)
            Specify an amount of time (in seconds) for the scheduler to
            wait for dependencies to be met before failing with a
            DependencyTimeout.
        track : bool
            whether to track non-copying sends.
            [default False]

        after,follow,timeout only used if `balanced=True`.

        Returns
        -------

        if block is False:
            return AsyncResult wrapping msg_ids
            output of AsyncResult.get() is identical to that of `apply(...block=True)`
        else:
            if single target:
                return result of `f(*args, **kwargs)`
            else:
                return list of results, matching `targets`
        """
        assert not self._closed, "cannot use me anymore, I'm closed!"
        # defaults:
        block = block if block is not None else self.block
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}

        if not self._ids:
            # flush notification socket if no engines yet
            any_ids = self.ids
            if not any_ids:
                raise error.NoEnginesRegistered("Can't execute without any connected engines.")

        if balanced is None:
            if targets is None:
                # default to balanced if targets unspecified
                balanced = True
            else:
                # otherwise default to multiplexing
                balanced = False

        if targets is None and balanced is False:
            # default to all if *not* balanced, and targets is unspecified
            targets = 'all'

        # enforce types of f,args,kwargs
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))

        options = dict(bound=bound, block=block, targets=targets, track=track)

        if balanced:
            # route through the Task scheduler for load-balancing
            return self._apply_balanced(f, args, kwargs, timeout=timeout,
                                    after=after, follow=follow, **options)
        elif follow or after or timeout:
            # dependency args are meaningless without the Task scheduler
            msg = "follow, after, and timeout args are only used for"
            msg += " load-balanced execution."
            raise ValueError(msg)
        else:
            # route directly to the specified engine(s) via the MUX queue
            return self._apply_direct(f, args, kwargs, **options)
1064 1073
    def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
                            after=None, follow=None, timeout=None, track=None):
        """call f(*args, **kwargs) remotely in a load-balanced manner.

        This is a private method, see `apply` for details.
        Not to be called directly!
        """

        # `apply` must have filled in these kwargs explicitly
        loc = locals()
        for name in ('bound', 'block', 'track'):
            assert loc[name] is not None, "kwarg %r must be specified!"%name

        if not self._task_ident:
            # no task-scheduler identity means task farming is unavailable
            msg = "Task farming is disabled"
            if self._task_scheme == 'pure':
                msg += " because the pure ZMQ scheduler cannot handle"
                msg += " disappearing engines."
            raise RuntimeError(msg)

        if self._task_scheme == 'pure':
            # pure zmq scheme doesn't support dependencies
            msg = "Pure ZMQ scheduler doesn't support dependencies"
            if (follow or after):
                # hard fail on DAG dependencies
                raise RuntimeError(msg)
            if isinstance(f, dependent):
                # soft warn on functional dependencies
                warnings.warn(msg, RuntimeWarning)

        # defaults:
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}

        if targets:
            # restrict the scheduler's choice to these engines
            idents,_ = self._build_targets(targets)
        else:
            idents = []

        after = self._build_dependency(after)
        follow = self._build_dependency(follow)
        subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
        bufs = util.pack_apply_message(f,args,kwargs)
        content = dict(bound=bound)

        # single shared apply socket; the task scheduler is addressed by identity
        msg = self.session.send(self._apply_socket, "apply_request", ident=self._task_ident,
                            content=content, buffers=bufs, subheader=subheader, track=track)
        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        self.history.append(msg_id)
        self.metadata[msg_id]['submitted'] = datetime.now()
        tracker = None if track is False else msg['tracker']
        ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # interrupted wait: hand back the AsyncResult instead of raising
                return ar
        else:
            return ar
1124 1133
    def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
                        track=None):
        """The underlying method for applying functions to specific engines
        via the MUX queue.

        This is a private method, see `apply` for details.
        Not to be called directly!
        """

        if not self._mux_ident:
            # no MUX identity means direct (multiplexed) execution is unavailable
            msg = "Multiplexing is disabled"
            raise RuntimeError(msg)

        # `apply` must have filled in these kwargs explicitly
        loc = locals()
        for name in ('bound', 'block', 'targets', 'track'):
            assert loc[name] is not None, "kwarg %r must be specified!"%name

        idents,targets = self._build_targets(targets)

        subheader = {}
        content = dict(bound=bound)
        bufs = util.pack_apply_message(f,args,kwargs)

        msg_ids = []
        trackers = []
        for ident in idents:
            # route via the shared apply socket: first the MUX queue's
            # identity, then the destination engine's identity
            msg = self.session.send(self._apply_socket, "apply_request",
                    content=content, buffers=bufs, ident=[self._mux_ident, ident], subheader=subheader,
                    track=track)
            if track:
                trackers.append(msg['tracker'])
            msg_id = msg['msg_id']
            self.outstanding.add(msg_id)
            self._outstanding_dict[ident].add(msg_id)
            self.history.append(msg_id)
            msg_ids.append(msg_id)

        tracker = None if track is False else zmq.MessageTracker(*trackers)
        ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)

        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # interrupted wait: hand back the AsyncResult instead of raising
                return ar
        else:
            return ar
1167 1181
1168 1182 #--------------------------------------------------------------------------
1169 1183 # construct a View object
1170 1184 #--------------------------------------------------------------------------
1171 1185
    @defaultblock
    def remote(self, bound=False, block=None, targets=None, balanced=None):
        """Decorator for making a RemoteFunction bound to this client.

        Delegates to the module-level `remote` factory (imported from
        .remotefunction), which the bare name resolves to inside a method body.
        """
        return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1176 1190
1177 1191 @defaultblock
1178 1192 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1179 1193 """Decorator for making a ParallelFunction"""
1180 1194 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1181 1195
1182 1196 def _cache_view(self, targets, balanced):
1183 1197 """save views, so subsequent requests don't create new objects."""
1184 1198 if balanced:
1185 1199 view_class = LoadBalancedView
1186 1200 view_cache = self._balanced_views
1187 1201 else:
1188 1202 view_class = DirectView
1189 1203 view_cache = self._direct_views
1190 1204
1191 1205 # use str, since often targets will be a list
1192 1206 key = str(targets)
1193 1207 if key not in view_cache:
1194 1208 view_cache[key] = view_class(client=self, targets=targets)
1195 1209
1196 1210 return view_cache[key]
1197 1211
1198 1212 def view(self, targets=None, balanced=None):
1199 1213 """Method for constructing View objects.
1200 1214
1201 1215 If no arguments are specified, create a LoadBalancedView
1202 1216 using all engines. If only `targets` specified, it will
1203 1217 be a DirectView. This method is the underlying implementation
1204 1218 of ``client.__getitem__``.
1205 1219
1206 1220 Parameters
1207 1221 ----------
1208 1222
1209 1223 targets: list,slice,int,etc. [default: use all engines]
1210 1224 The engines to use for the View
1211 1225 balanced : bool [default: False if targets specified, True else]
1212 1226 whether to build a LoadBalancedView or a DirectView
1213 1227
1214 1228 """
1215 1229
1216 1230 balanced = (targets is None) if balanced is None else balanced
1217 1231
1218 1232 if targets is None:
1219 1233 if balanced:
1220 1234 return self._cache_view(None,True)
1221 1235 else:
1222 1236 targets = slice(None)
1223 1237
1224 1238 if isinstance(targets, int):
1225 1239 if targets < 0:
1226 1240 targets = self.ids[targets]
1227 1241 if targets not in self.ids:
1228 1242 raise IndexError("No such engine: %i"%targets)
1229 1243 return self._cache_view(targets, balanced)
1230 1244
1231 1245 if isinstance(targets, slice):
1232 1246 indices = range(len(self.ids))[targets]
1233 1247 ids = sorted(self._ids)
1234 1248 targets = [ ids[i] for i in indices ]
1235 1249
1236 1250 if isinstance(targets, (tuple, list, xrange)):
1237 1251 _,targets = self._build_targets(list(targets))
1238 1252 return self._cache_view(targets, balanced)
1239 1253 else:
1240 1254 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1241 1255
1242 1256 #--------------------------------------------------------------------------
1243 1257 # Data movement
1244 1258 #--------------------------------------------------------------------------
1245 1259
1246 1260 @defaultblock
1247 1261 def push(self, ns, targets='all', block=None, track=False):
1248 1262 """Push the contents of `ns` into the namespace on `target`"""
1249 1263 if not isinstance(ns, dict):
1250 1264 raise TypeError("Must be a dict, not %s"%type(ns))
1251 1265 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1252 1266 if not block:
1253 1267 return result
1254 1268
1255 1269 @defaultblock
1256 1270 def pull(self, keys, targets='all', block=None):
1257 1271 """Pull objects from `target`'s namespace by `keys`"""
1258 1272 if isinstance(keys, basestring):
1259 1273 pass
1260 1274 elif isinstance(keys, (list,tuple,set)):
1261 1275 for key in keys:
1262 1276 if not isinstance(key, basestring):
1263 1277 raise TypeError("keys must be str, not type %r"%type(key))
1264 1278 else:
1265 1279 raise TypeError("keys must be strs, not %r"%keys)
1266 1280 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1267 1281 return result
1268 1282
1269 1283 @defaultblock
1270 1284 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1271 1285 """
1272 1286 Partition a Python sequence and send the partitions to a set of engines.
1273 1287 """
1274 1288 targets = self._build_targets(targets)[-1]
1275 1289 mapObject = Map.dists[dist]()
1276 1290 nparts = len(targets)
1277 1291 msg_ids = []
1278 1292 trackers = []
1279 1293 for index, engineid in enumerate(targets):
1280 1294 partition = mapObject.getPartition(seq, index, nparts)
1281 1295 if flatten and len(partition) == 1:
1282 1296 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1283 1297 else:
1284 1298 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1285 1299 msg_ids.extend(r.msg_ids)
1286 1300 if track:
1287 1301 trackers.append(r._tracker)
1288 1302
1289 1303 if track:
1290 1304 tracker = zmq.MessageTracker(*trackers)
1291 1305 else:
1292 1306 tracker = None
1293 1307
1294 1308 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1295 1309 if block:
1296 1310 r.wait()
1297 1311 else:
1298 1312 return r
1299 1313
1300 1314 @defaultblock
1301 1315 def gather(self, key, dist='b', targets='all', block=None):
1302 1316 """
1303 1317 Gather a partitioned sequence on a set of engines as a single local seq.
1304 1318 """
1305 1319
1306 1320 targets = self._build_targets(targets)[-1]
1307 1321 mapObject = Map.dists[dist]()
1308 1322 msg_ids = []
1309 1323 for index, engineid in enumerate(targets):
1310 1324 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1311 1325
1312 1326 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1313 1327 if block:
1314 1328 return r.get()
1315 1329 else:
1316 1330 return r
1317 1331
1318 1332 #--------------------------------------------------------------------------
1319 1333 # Query methods
1320 1334 #--------------------------------------------------------------------------
1321 1335
1322 1336 @spinfirst
1323 1337 @defaultblock
1324 1338 def get_result(self, indices_or_msg_ids=None, block=None):
1325 1339 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1326 1340
1327 1341 If the client already has the results, no request to the Hub will be made.
1328 1342
1329 1343 This is a convenient way to construct AsyncResult objects, which are wrappers
1330 1344 that include metadata about execution, and allow for awaiting results that
1331 1345 were not submitted by this Client.
1332 1346
1333 1347 It can also be a convenient way to retrieve the metadata associated with
1334 1348 blocking execution, since it always retrieves
1335 1349
1336 1350 Examples
1337 1351 --------
1338 1352 ::
1339 1353
1340 1354 In [10]: r = client.apply()
1341 1355
1342 1356 Parameters
1343 1357 ----------
1344 1358
1345 1359 indices_or_msg_ids : integer history index, str msg_id, or list of either
1346 1360 The indices or msg_ids of indices to be retrieved
1347 1361
1348 1362 block : bool
1349 1363 Whether to wait for the result to be done
1350 1364
1351 1365 Returns
1352 1366 -------
1353 1367
1354 1368 AsyncResult
1355 1369 A single AsyncResult object will always be returned.
1356 1370
1357 1371 AsyncHubResult
1358 1372 A subclass of AsyncResult that retrieves results from the Hub
1359 1373
1360 1374 """
1361 1375 if indices_or_msg_ids is None:
1362 1376 indices_or_msg_ids = -1
1363 1377
1364 1378 if not isinstance(indices_or_msg_ids, (list,tuple)):
1365 1379 indices_or_msg_ids = [indices_or_msg_ids]
1366 1380
1367 1381 theids = []
1368 1382 for id in indices_or_msg_ids:
1369 1383 if isinstance(id, int):
1370 1384 id = self.history[id]
1371 1385 if not isinstance(id, str):
1372 1386 raise TypeError("indices must be str or int, not %r"%id)
1373 1387 theids.append(id)
1374 1388
1375 1389 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1376 1390 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1377 1391
1378 1392 if remote_ids:
1379 1393 ar = AsyncHubResult(self, msg_ids=theids)
1380 1394 else:
1381 1395 ar = AsyncResult(self, msg_ids=theids)
1382 1396
1383 1397 if block:
1384 1398 ar.wait()
1385 1399
1386 1400 return ar
1387 1401
1388 1402 @spinfirst
1389 1403 def result_status(self, msg_ids, status_only=True):
1390 1404 """Check on the status of the result(s) of the apply request with `msg_ids`.
1391 1405
1392 1406 If status_only is False, then the actual results will be retrieved, else
1393 1407 only the status of the results will be checked.
1394 1408
1395 1409 Parameters
1396 1410 ----------
1397 1411
1398 1412 msg_ids : list of msg_ids
1399 1413 if int:
1400 1414 Passed as index to self.history for convenience.
1401 1415 status_only : bool (default: True)
1402 1416 if False:
1403 1417 Retrieve the actual results of completed tasks.
1404 1418
1405 1419 Returns
1406 1420 -------
1407 1421
1408 1422 results : dict
1409 1423 There will always be the keys 'pending' and 'completed', which will
1410 1424 be lists of msg_ids that are incomplete or complete. If `status_only`
1411 1425 is False, then completed results will be keyed by their `msg_id`.
1412 1426 """
1413 1427 if not isinstance(msg_ids, (list,tuple)):
1414 1428 msg_ids = [msg_ids]
1415 1429
1416 1430 theids = []
1417 1431 for msg_id in msg_ids:
1418 1432 if isinstance(msg_id, int):
1419 1433 msg_id = self.history[msg_id]
1420 1434 if not isinstance(msg_id, basestring):
1421 1435 raise TypeError("msg_ids must be str, not %r"%msg_id)
1422 1436 theids.append(msg_id)
1423 1437
1424 1438 completed = []
1425 1439 local_results = {}
1426 1440
1427 1441 # comment this block out to temporarily disable local shortcut:
1428 1442 for msg_id in theids:
1429 1443 if msg_id in self.results:
1430 1444 completed.append(msg_id)
1431 1445 local_results[msg_id] = self.results[msg_id]
1432 1446 theids.remove(msg_id)
1433 1447
1434 1448 if theids: # some not locally cached
1435 1449 content = dict(msg_ids=theids, status_only=status_only)
1436 1450 msg = self.session.send(self._query_socket, "result_request", content=content)
1437 1451 zmq.select([self._query_socket], [], [])
1438 1452 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1439 1453 if self.debug:
1440 1454 pprint(msg)
1441 1455 content = msg['content']
1442 1456 if content['status'] != 'ok':
1443 1457 raise self._unwrap_exception(content)
1444 1458 buffers = msg['buffers']
1445 1459 else:
1446 1460 content = dict(completed=[],pending=[])
1447 1461
1448 1462 content['completed'].extend(completed)
1449 1463
1450 1464 if status_only:
1451 1465 return content
1452 1466
1453 1467 failures = []
1454 1468 # load cached results into result:
1455 1469 content.update(local_results)
1456 1470 # update cache with results:
1457 1471 for msg_id in sorted(theids):
1458 1472 if msg_id in content['completed']:
1459 1473 rec = content[msg_id]
1460 1474 parent = rec['header']
1461 1475 header = rec['result_header']
1462 1476 rcontent = rec['result_content']
1463 1477 iodict = rec['io']
1464 1478 if isinstance(rcontent, str):
1465 1479 rcontent = self.session.unpack(rcontent)
1466 1480
1467 1481 md = self.metadata[msg_id]
1468 1482 md.update(self._extract_metadata(header, parent, rcontent))
1469 1483 md.update(iodict)
1470 1484
1471 1485 if rcontent['status'] == 'ok':
1472 1486 res,buffers = util.unserialize_object(buffers)
1473 1487 else:
1474 1488 print rcontent
1475 1489 res = self._unwrap_exception(rcontent)
1476 1490 failures.append(res)
1477 1491
1478 1492 self.results[msg_id] = res
1479 1493 content[msg_id] = res
1480 1494
1481 1495 if len(theids) == 1 and failures:
1482 1496 raise failures[0]
1483 1497
1484 1498 error.collect_exceptions(failures, "result_status")
1485 1499 return content
1486 1500
1487 1501 @spinfirst
1488 1502 def queue_status(self, targets='all', verbose=False):
1489 1503 """Fetch the status of engine queues.
1490 1504
1491 1505 Parameters
1492 1506 ----------
1493 1507
1494 1508 targets : int/str/list of ints/strs
1495 1509 the engines whose states are to be queried.
1496 1510 default : all
1497 1511 verbose : bool
1498 1512 Whether to return lengths only, or lists of ids for each element
1499 1513 """
1500 1514 targets = self._build_targets(targets)[1]
1501 1515 content = dict(targets=targets, verbose=verbose)
1502 1516 self.session.send(self._query_socket, "queue_request", content=content)
1503 1517 idents,msg = self.session.recv(self._query_socket, 0)
1504 1518 if self.debug:
1505 1519 pprint(msg)
1506 1520 content = msg['content']
1507 1521 status = content.pop('status')
1508 1522 if status != 'ok':
1509 1523 raise self._unwrap_exception(content)
1510 1524 return util.rekey(content)
1511 1525
1512 1526 @spinfirst
1513 1527 def purge_results(self, jobs=[], targets=[]):
1514 1528 """Tell the controller to forget results.
1515 1529
1516 1530 Individual results can be purged by msg_id, or the entire
1517 1531 history of specific targets can be purged.
1518 1532
1519 1533 Parameters
1520 1534 ----------
1521 1535
1522 1536 jobs : str or list of strs or AsyncResult objects
1523 1537 the msg_ids whose results should be forgotten.
1524 1538 targets : int/str/list of ints/strs
1525 1539 The targets, by uuid or int_id, whose entire history is to be purged.
1526 1540 Use `targets='all'` to scrub everything from the controller's memory.
1527 1541
1528 1542 default : None
1529 1543 """
1530 1544 if not targets and not jobs:
1531 1545 raise ValueError("Must specify at least one of `targets` and `jobs`")
1532 1546 if targets:
1533 1547 targets = self._build_targets(targets)[1]
1534 1548
1535 1549 # construct msg_ids from jobs
1536 1550 msg_ids = []
1537 1551 if isinstance(jobs, (basestring,AsyncResult)):
1538 1552 jobs = [jobs]
1539 1553 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1540 1554 if bad_ids:
1541 1555 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1542 1556 for j in jobs:
1543 1557 if isinstance(j, AsyncResult):
1544 1558 msg_ids.extend(j.msg_ids)
1545 1559 else:
1546 1560 msg_ids.append(j)
1547 1561
1548 1562 content = dict(targets=targets, msg_ids=msg_ids)
1549 1563 self.session.send(self._query_socket, "purge_request", content=content)
1550 1564 idents, msg = self.session.recv(self._query_socket, 0)
1551 1565 if self.debug:
1552 1566 pprint(msg)
1553 1567 content = msg['content']
1554 1568 if content['status'] != 'ok':
1555 1569 raise self._unwrap_exception(content)
1556 1570
1557 1571
# public API of this module, re-exported by the package
__all__ = [ 'Client',
            'depend',
            'require',
            'remote',
            'parallel',
            'RemoteFunction',
            'ParallelFunction',
            'DirectView',
            'LoadBalancedView',
            'AsyncResult',
            'AsyncMapResult',
            'Reference'
            ]
@@ -1,115 +1,118 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller with 0MQ
3 3 This is a collection of one Hub and several Schedulers.
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15 from __future__ import print_function
16 16
17 17 import logging
18 18 from multiprocessing import Process
19 19
20 20 import zmq
21 21 from zmq.devices import ProcessMonitoredQueue
22 22 # internal:
23 23 from IPython.utils.importstring import import_item
24 24 from IPython.utils.traitlets import Int, CStr, Instance, List, Bool
25 25
26 26 from .entry_point import signal_children
27 27 from .hub import Hub, HubFactory
28 28 from .scheduler import launch_scheduler
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Configurable
32 32 #-----------------------------------------------------------------------------
33 33
34 34
class ControllerFactory(HubFactory):
    """Configurable for setting up a Hub and Schedulers."""

    # run the schedulers as threads in this process instead of subprocesses
    usethreads = Bool(False, config=True)
    # pure-zmq downstream HWM
    hwm = Int(0, config=True)

    # internal
    children = List()
    mq_class = CStr('zmq.devices.ProcessMonitoredQueue')

    def _usethreads_changed(self, name, old, new):
        # keep the MonitoredQueue flavor in sync with the usethreads flag
        self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')

    def __init__(self, **kwargs):
        super(ControllerFactory, self).__init__(**kwargs)
        self.subconstructors.append(self.construct_schedulers)

    def start(self):
        """Start the Hub, then every scheduler child; register process
        children so signals are forwarded to them."""
        super(ControllerFactory, self).start()
        child_procs = []
        for child in self.children:
            child.start()
            if isinstance(child, ProcessMonitoredQueue):
                child_procs.append(child.launcher)
            elif isinstance(child, Process):
                child_procs.append(child)
        if child_procs:
            signal_children(child_procs)


    def construct_schedulers(self):
        """Build the IOPub relay, MUX, Control, and Task scheduler children.

        Each MonitoredQueue binds a client-facing socket and an engine-facing
        socket, and reports traffic on the monitor socket.  The client-facing
        sockets are given fixed IDENTITYs ('mux', 'control', 'task') —
        presumably so clients can address every queue through a shared
        connection (the "reduced sockets" layout); confirm against client.py.
        """
        children = self.children
        mq = import_item(self.mq_class)

        # inproc monitor transport only works when schedulers are threads
        maybe_inproc = 'inproc://monitor' if self.usethreads else self.monitor_url
        # IOPub relay (in a Process)
        q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
        q.bind_in(self.client_info['iopub'])
        q.bind_out(self.engine_info['iopub'])
        q.setsockopt_out(zmq.SUBSCRIBE, '')
        q.connect_mon(maybe_inproc)
        q.daemon=True
        children.append(q)

        # Multiplexer Queue (in a Process)
        q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
        q.bind_in(self.client_info['mux'])
        q.setsockopt_in(zmq.IDENTITY, 'mux')
        q.bind_out(self.engine_info['mux'])
        q.connect_mon(maybe_inproc)
        q.daemon=True
        children.append(q)

        # Control Queue (in a Process)
        q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
        q.bind_in(self.client_info['control'])
        q.setsockopt_in(zmq.IDENTITY, 'control')
        q.bind_out(self.engine_info['control'])
        q.connect_mon(maybe_inproc)
        q.daemon=True
        children.append(q)
        # Task Queue (in a Process)
        if self.scheme == 'pure':
            self.log.warn("task::using pure XREQ Task scheduler")
            q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
            # limit queued messages per engine via the downstream HWM
            q.setsockopt_out(zmq.HWM, self.hwm)
            q.bind_in(self.client_info['task'][1])
            q.setsockopt_in(zmq.IDENTITY, 'task')
            q.bind_out(self.engine_info['task'])
            q.connect_mon(maybe_inproc)
            q.daemon=True
            children.append(q)
        elif self.scheme == 'none':
            self.log.warn("task::using no Task scheduler")

        else:
            # Python scheduler runs in its own process regardless of usethreads
            self.log.info("task::using Python %s Task scheduler"%self.scheme)
            sargs = (self.client_info['task'][1], self.engine_info['task'], self.monitor_url, self.client_info['notification'])
            kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level, config=self.config)
            q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
            q.daemon=True
            children.append(q)
115 118
@@ -1,580 +1,584 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 #----------------------------------------------------------------------
9 9 # Imports
10 10 #----------------------------------------------------------------------
11 11
12 12 from __future__ import print_function
13 13
14 14 import logging
15 15 import sys
16 16
17 17 from datetime import datetime, timedelta
18 18 from random import randint, random
19 19 from types import FunctionType
20 20
21 21 try:
22 22 import numpy
23 23 except ImportError:
24 24 numpy = None
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop, zmqstream
28 28
29 29 # local imports
30 30 from IPython.external.decorator import decorator
31 31 from IPython.utils.traitlets import Instance, Dict, List, Set
32 32
33 33 from . import error
34 34 from .dependency import Dependency
35 35 from .entry_point import connect_logger, local_logger
36 36 from .factory import SessionFactory
37 37
38 38
39 39 @decorator
40 40 def logged(f,self,*args,**kwargs):
41 41 # print ("#--------------------")
42 42 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
43 43 # print ("#--")
44 44 return f(self,*args, **kwargs)
45 45
46 46 #----------------------------------------------------------------------
47 47 # Chooser functions
48 48 #----------------------------------------------------------------------
49 49
50 50 def plainrandom(loads):
51 51 """Plain random pick."""
52 52 n = len(loads)
53 53 return randint(0,n-1)
54 54
55 55 def lru(loads):
56 56 """Always pick the front of the line.
57 57
58 58 The content of `loads` is ignored.
59 59
60 60 Assumes LRU ordering of loads, with oldest first.
61 61 """
62 62 return 0
63 63
64 64 def twobin(loads):
65 65 """Pick two at random, use the LRU of the two.
66 66
67 67 The content of loads is ignored.
68 68
69 69 Assumes LRU ordering of loads, with oldest first.
70 70 """
71 71 n = len(loads)
72 72 a = randint(0,n-1)
73 73 b = randint(0,n-1)
74 74 return min(a,b)
75 75
76 76 def weighted(loads):
77 77 """Pick two at random using inverse load as weight.
78 78
79 79 Return the less loaded of the two.
80 80 """
81 81 # weight 0 a million times more than 1:
82 82 weights = 1./(1e-6+numpy.array(loads))
83 83 sums = weights.cumsum()
84 84 t = sums[-1]
85 85 x = random()*t
86 86 y = random()*t
87 87 idx = 0
88 88 idy = 0
89 89 while sums[idx] < x:
90 90 idx += 1
91 91 while sums[idy] < y:
92 92 idy += 1
93 93 if weights[idy] > weights[idx]:
94 94 return idy
95 95 else:
96 96 return idx
97 97
98 98 def leastload(loads):
99 99 """Always choose the lowest load.
100 100
101 101 If the lowest load occurs more than once, the first
102 102 occurance will be used. If loads has LRU ordering, this means
103 103 the LRU of those with the lowest load is chosen.
104 104 """
105 105 return loads.index(min(loads))
106 106
107 107 #---------------------------------------------------------------------
108 108 # Classes
109 109 #---------------------------------------------------------------------
110 110 # store empty default dependency:
111 111 MET = Dependency([])
112 112
113 113 class TaskScheduler(SessionFactory):
114 114 """Python TaskScheduler object.
115 115
116 116 This is the simplest object that supports msg_id based
117 117 DAG dependencies. *Only* task msg_ids are checked, not
118 118 msg_ids of jobs submitted via the MUX queue.
119 119
120 120 """
121 121
122 122 # input arguments:
123 123 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
124 124 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
125 125 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
126 126 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
127 127 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
128 128
129 129 # internals:
130 130 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
131 131 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
132 132 pending = Dict() # dict by engine_uuid of submitted tasks
133 133 completed = Dict() # dict by engine_uuid of completed tasks
134 134 failed = Dict() # dict by engine_uuid of failed tasks
135 135 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
136 136 clients = Dict() # dict by msg_id for who submitted the task
137 137 targets = List() # list of target IDENTs
138 138 loads = List() # list of engine loads
139 139 all_completed = Set() # set of all completed tasks
140 140 all_failed = Set() # set of all failed tasks
141 141 all_done = Set() # set of all finished tasks=union(completed,failed)
142 142 all_ids = Set() # set of all submitted task IDs
143 143 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
144 144 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
145 145
146 146
147 147 def start(self):
148 148 self.engine_stream.on_recv(self.dispatch_result, copy=False)
149 149 self._notification_handlers = dict(
150 150 registration_notification = self._register_engine,
151 151 unregistration_notification = self._unregister_engine
152 152 )
153 153 self.notifier_stream.on_recv(self.dispatch_notification)
154 154 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
155 155 self.auditor.start()
156 156 self.log.info("Scheduler started...%r"%self)
157 157
158 158 def resume_receiving(self):
159 159 """Resume accepting jobs."""
160 160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
161 161
162 162 def stop_receiving(self):
163 163 """Stop accepting jobs while there are no engines.
164 164 Leave them in the ZMQ queue."""
165 165 self.client_stream.on_recv(None)
166 166
167 167 #-----------------------------------------------------------------------
168 168 # [Un]Registration Handling
169 169 #-----------------------------------------------------------------------
170 170
171 171 def dispatch_notification(self, msg):
172 172 """dispatch register/unregister events."""
173 173 idents,msg = self.session.feed_identities(msg)
174 174 msg = self.session.unpack_message(msg)
175 175 msg_type = msg['msg_type']
176 176 handler = self._notification_handlers.get(msg_type, None)
177 177 if handler is None:
178 178 raise Exception("Unhandled message type: %s"%msg_type)
179 179 else:
180 180 try:
181 181 handler(str(msg['content']['queue']))
182 182 except KeyError:
183 183 self.log.error("task::Invalid notification msg: %s"%msg)
184 184
185 185 @logged
186 186 def _register_engine(self, uid):
187 187 """New engine with ident `uid` became available."""
188 188 # head of the line:
189 189 self.targets.insert(0,uid)
190 190 self.loads.insert(0,0)
191 191 # initialize sets
192 192 self.completed[uid] = set()
193 193 self.failed[uid] = set()
194 194 self.pending[uid] = {}
195 195 if len(self.targets) == 1:
196 196 self.resume_receiving()
197 197
198 198 def _unregister_engine(self, uid):
199 199 """Existing engine with ident `uid` became unavailable."""
200 200 if len(self.targets) == 1:
201 201 # this was our only engine
202 202 self.stop_receiving()
203 203
204 204 # handle any potentially finished tasks:
205 205 self.engine_stream.flush()
206 206
207 207 self.completed.pop(uid)
208 208 self.failed.pop(uid)
209 209 # don't pop destinations, because it might be used later
210 210 # map(self.destinations.pop, self.completed.pop(uid))
211 211 # map(self.destinations.pop, self.failed.pop(uid))
212 212
213 213 idx = self.targets.index(uid)
214 214 self.targets.pop(idx)
215 215 self.loads.pop(idx)
216 216
217 217 # wait 5 seconds before cleaning up pending jobs, since the results might
218 218 # still be incoming
219 219 if self.pending[uid]:
220 220 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
221 221 dc.start()
222 222
223 223 @logged
224 224 def handle_stranded_tasks(self, engine):
225 225 """Deal with jobs resident in an engine that died."""
226 226 lost = self.pending.pop(engine)
227 227
228 228 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
229 229 self.all_failed.add(msg_id)
230 230 self.all_done.add(msg_id)
231 231 idents,msg = self.session.feed_identities(raw_msg, copy=False)
232 232 msg = self.session.unpack_message(msg, copy=False, content=False)
233 233 parent = msg['header']
234 234 idents = [idents[0],engine]+idents[1:]
235 235 print (idents)
236 236 try:
237 237 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
238 238 except:
239 239 content = error.wrap_exception()
240 240 msg = self.session.send(self.client_stream, 'apply_reply', content,
241 241 parent=parent, ident=idents)
242 242 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
243 243 self.update_graph(msg_id)
244 244
245 245
246 246 #-----------------------------------------------------------------------
247 247 # Job Submission
248 248 #-----------------------------------------------------------------------
249 249 @logged
250 250 def dispatch_submission(self, raw_msg):
251 251 """Dispatch job submission to appropriate handlers."""
252 252 # ensure targets up to date:
253 253 self.notifier_stream.flush()
254 254 try:
255 255 idents, msg = self.session.feed_identities(raw_msg, copy=False)
256 256 msg = self.session.unpack_message(msg, content=False, copy=False)
257 257 except:
258 258 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
259 259 return
260 260
261 261 # send to monitor
262 262 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
263 263
264 264 header = msg['header']
265 265 msg_id = header['msg_id']
266 266 self.all_ids.add(msg_id)
267 267
268 268 # targets
269 269 targets = set(header.get('targets', []))
270 270
271 271 # time dependencies
272 272 after = Dependency(header.get('after', []))
273 273 if after.all:
274 274 after.difference_update(self.all_completed)
275 275 if not after.success_only:
276 276 after.difference_update(self.all_failed)
277 277 if after.check(self.all_completed, self.all_failed):
278 278 # recast as empty set, if `after` already met,
279 279 # to prevent unnecessary set comparisons
280 280 after = MET
281 281
282 282 # location dependencies
283 283 follow = Dependency(header.get('follow', []))
284 284
285 285 # turn timeouts into datetime objects:
286 286 timeout = header.get('timeout', None)
287 287 if timeout:
288 288 timeout = datetime.now() + timedelta(0,timeout,0)
289 289
290 290 args = [raw_msg, targets, after, follow, timeout]
291 291
292 292 # validate and reduce dependencies:
293 293 for dep in after,follow:
294 294 # check valid:
295 295 if msg_id in dep or dep.difference(self.all_ids):
296 296 self.depending[msg_id] = args
297 297 return self.fail_unreachable(msg_id, error.InvalidDependency)
298 298 # check if unreachable:
299 299 if dep.unreachable(self.all_failed):
300 300 self.depending[msg_id] = args
301 301 return self.fail_unreachable(msg_id)
302 302
303 303 if after.check(self.all_completed, self.all_failed):
304 304 # time deps already met, try to run
305 305 if not self.maybe_run(msg_id, *args):
306 306 # can't run yet
307 307 self.save_unmet(msg_id, *args)
308 308 else:
309 309 self.save_unmet(msg_id, *args)
310 310
311 311 # @logged
312 312 def audit_timeouts(self):
313 313 """Audit all waiting tasks for expired timeouts."""
314 314 now = datetime.now()
315 315 for msg_id in self.depending.keys():
316 316 # must recheck, in case one failure cascaded to another:
317 317 if msg_id in self.depending:
318 318 raw,after,targets,follow,timeout = self.depending[msg_id]
319 319 if timeout and timeout < now:
320 320 self.fail_unreachable(msg_id, timeout=True)
321 321
322 322 @logged
323 323 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
324 324 """a task has become unreachable, send a reply with an ImpossibleDependency
325 325 error."""
326 326 if msg_id not in self.depending:
327 327 self.log.error("msg %r already failed!"%msg_id)
328 328 return
329 329 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
330 330 for mid in follow.union(after):
331 331 if mid in self.graph:
332 332 self.graph[mid].remove(msg_id)
333 333
334 334 # FIXME: unpacking a message I've already unpacked, but didn't save:
335 335 idents,msg = self.session.feed_identities(raw_msg, copy=False)
336 336 msg = self.session.unpack_message(msg, copy=False, content=False)
337 337 header = msg['header']
338 338
339 339 try:
340 340 raise why()
341 341 except:
342 342 content = error.wrap_exception()
343 343
344 344 self.all_done.add(msg_id)
345 345 self.all_failed.add(msg_id)
346 346
347 347 msg = self.session.send(self.client_stream, 'apply_reply', content,
348 348 parent=header, ident=idents)
349 349 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
350 350
351 351 self.update_graph(msg_id, success=False)
352 352
353 353 @logged
354 354 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
355 355 """check location dependencies, and run if they are met."""
356 356 blacklist = self.blacklist.setdefault(msg_id, set())
357 357 if follow or targets or blacklist:
358 358 # we need a can_run filter
359 359 def can_run(idx):
360 360 target = self.targets[idx]
361 361 # check targets
362 362 if targets and target not in targets:
363 363 return False
364 364 # check blacklist
365 365 if target in blacklist:
366 366 return False
367 367 # check follow
368 368 return follow.check(self.completed[target], self.failed[target])
369 369
370 370 indices = filter(can_run, range(len(self.targets)))
371 371 if not indices:
372 372 # couldn't run
373 373 if follow.all:
374 374 # check follow for impossibility
375 375 dests = set()
376 376 relevant = self.all_completed if follow.success_only else self.all_done
377 377 for m in follow.intersection(relevant):
378 378 dests.add(self.destinations[m])
379 379 if len(dests) > 1:
380 380 self.fail_unreachable(msg_id)
381 381 return False
382 382 if targets:
383 383 # check blacklist+targets for impossibility
384 384 targets.difference_update(blacklist)
385 385 if not targets or not targets.intersection(self.targets):
386 386 self.fail_unreachable(msg_id)
387 387 return False
388 388 return False
389 389 else:
390 390 indices = None
391 391
392 392 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
393 393 return True
394 394
395 395 @logged
396 396 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
397 397 """Save a message for later submission when its dependencies are met."""
398 398 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
399 399 # track the ids in follow or after, but not those already finished
400 400 for dep_id in after.union(follow).difference(self.all_done):
401 401 if dep_id not in self.graph:
402 402 self.graph[dep_id] = set()
403 403 self.graph[dep_id].add(msg_id)
404 404
405 405 @logged
406 406 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
407 407 """Submit a task to any of a subset of our targets."""
408 408 if indices:
409 409 loads = [self.loads[i] for i in indices]
410 410 else:
411 411 loads = self.loads
412 412 idx = self.scheme(loads)
413 413 if indices:
414 414 idx = indices[idx]
415 415 target = self.targets[idx]
416 416 # print (target, map(str, msg[:3]))
417 417 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
418 418 self.engine_stream.send_multipart(raw_msg, copy=False)
419 419 self.add_job(idx)
420 420 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
421 421 content = dict(msg_id=msg_id, engine_id=target)
422 422 self.session.send(self.mon_stream, 'task_destination', content=content,
423 423 ident=['tracktask',self.session.session])
424 424
425 425 #-----------------------------------------------------------------------
426 426 # Result Handling
427 427 #-----------------------------------------------------------------------
428 428 @logged
429 429 def dispatch_result(self, raw_msg):
430 430 """dispatch method for result replies"""
431 431 try:
432 432 idents,msg = self.session.feed_identities(raw_msg, copy=False)
433 433 msg = self.session.unpack_message(msg, content=False, copy=False)
434 434 except:
435 435 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
436 436 return
437 437
438 438 header = msg['header']
439 439 if header.get('dependencies_met', True):
440 440 success = (header['status'] == 'ok')
441 441 self.handle_result(idents, msg['parent_header'], raw_msg, success)
442 442 # send to Hub monitor
443 443 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
444 444 else:
445 445 self.handle_unmet_dependency(idents, msg['parent_header'])
446 446
447 447 @logged
448 448 def handle_result(self, idents, parent, raw_msg, success=True):
449 449 """handle a real task result, either success or failure"""
450 450 # first, relay result to client
451 451 engine = idents[0]
452 452 client = idents[1]
453 453 # swap_ids for XREP-XREP mirror
454 454 raw_msg[:2] = [client,engine]
455 455 # print (map(str, raw_msg[:4]))
456 456 self.client_stream.send_multipart(raw_msg, copy=False)
457 457 # now, update our data structures
458 458 msg_id = parent['msg_id']
459 459 self.blacklist.pop(msg_id, None)
460 460 self.pending[engine].pop(msg_id)
461 461 if success:
462 462 self.completed[engine].add(msg_id)
463 463 self.all_completed.add(msg_id)
464 464 else:
465 465 self.failed[engine].add(msg_id)
466 466 self.all_failed.add(msg_id)
467 467 self.all_done.add(msg_id)
468 468 self.destinations[msg_id] = engine
469 469
470 470 self.update_graph(msg_id, success)
471 471
472 472 @logged
473 473 def handle_unmet_dependency(self, idents, parent):
474 474 """handle an unmet dependency"""
475 475 engine = idents[0]
476 476 msg_id = parent['msg_id']
477 477
478 478 if msg_id not in self.blacklist:
479 479 self.blacklist[msg_id] = set()
480 480 self.blacklist[msg_id].add(engine)
481 481
482 482 args = self.pending[engine].pop(msg_id)
483 483 raw,targets,after,follow,timeout = args
484 484
485 485 if self.blacklist[msg_id] == targets:
486 486 self.depending[msg_id] = args
487 487 return self.fail_unreachable(msg_id)
488 488
489 489 elif not self.maybe_run(msg_id, *args):
490 490 # resubmit failed, put it back in our dependency tree
491 491 self.save_unmet(msg_id, *args)
492 492
493 493
494 494 @logged
495 495 def update_graph(self, dep_id, success=True):
496 496 """dep_id just finished. Update our dependency
497 497 graph and submit any jobs that just became runable."""
498 498 # print ("\n\n***********")
499 499 # pprint (dep_id)
500 500 # pprint (self.graph)
501 501 # pprint (self.depending)
502 502 # pprint (self.all_completed)
503 503 # pprint (self.all_failed)
504 504 # print ("\n\n***********\n\n")
505 505 if dep_id not in self.graph:
506 506 return
507 507 jobs = self.graph.pop(dep_id)
508 508
509 509 for msg_id in jobs:
510 510 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
511 511 # if dep_id in after:
512 512 # if after.all and (success or not after.success_only):
513 513 # after.remove(dep_id)
514 514
515 515 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
516 516 self.fail_unreachable(msg_id)
517 517
518 518 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
519 519 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
520 520
521 521 self.depending.pop(msg_id)
522 522 for mid in follow.union(after):
523 523 if mid in self.graph:
524 524 self.graph[mid].remove(msg_id)
525 525
526 526 #----------------------------------------------------------------------
527 527 # methods to be overridden by subclasses
528 528 #----------------------------------------------------------------------
529 529
530 530 def add_job(self, idx):
531 531 """Called after self.targets[idx] just got the job with header.
532 532 Override with subclasses. The default ordering is simple LRU.
533 533 The default loads are the number of outstanding jobs."""
534 534 self.loads[idx] += 1
535 535 for lis in (self.targets, self.loads):
536 536 lis.append(lis.pop(idx))
537 537
538 538
539 539 def finish_job(self, idx):
540 540 """Called after self.targets[idx] just finished a job.
541 541 Override with subclasses."""
542 542 self.loads[idx] -= 1
543 543
544 544
545 545
546 546 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
547 log_addr=None, loglevel=logging.DEBUG, scheme='lru'):
547 log_addr=None, loglevel=logging.DEBUG, scheme='lru',
548 identity=b'task'):
548 549 from zmq.eventloop import ioloop
549 550 from zmq.eventloop.zmqstream import ZMQStream
550 551
551 552 ctx = zmq.Context()
552 553 loop = ioloop.IOLoop()
553 554 print (in_addr, out_addr, mon_addr, not_addr)
554 555 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
556 ins.setsockopt(zmq.IDENTITY, identity)
555 557 ins.bind(in_addr)
558
556 559 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
560 outs.setsockopt(zmq.IDENTITY, identity)
557 561 outs.bind(out_addr)
558 562 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
559 563 mons.connect(mon_addr)
560 564 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
561 565 nots.setsockopt(zmq.SUBSCRIBE, '')
562 566 nots.connect(not_addr)
563 567
564 568 scheme = globals().get(scheme, None)
565 569 # setup logging
566 570 if log_addr:
567 571 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
568 572 else:
569 573 local_logger(logname, loglevel)
570 574
571 575 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
572 576 mon_stream=mons, notifier_stream=nots,
573 577 scheme=scheme, loop=loop, logname=logname,
574 578 config=config)
575 579 scheduler.start()
576 580 try:
577 581 loop.start()
578 582 except KeyboardInterrupt:
579 583 print ("interrupted, exiting...", file=sys.__stderr__)
580 584
@@ -1,48 +1,48 b''
1 1 """toplevel setup/teardown for parallel tests."""
2 2
3 3 import tempfile
4 4 import time
5 5 from subprocess import Popen, PIPE, STDOUT
6 6
7 7 from IPython.zmq.parallel import client
8 8
9 9 processes = []
10 10 blackhole = tempfile.TemporaryFile()
11 11
12 12 # nose setup/teardown
13 13
14 14 def setup():
15 cp = Popen('ipcontrollerz --profile iptest -r --log-level 40'.split(), stdout=blackhole, stderr=STDOUT)
15 cp = Popen('ipcontrollerz --profile iptest -r --log-level 10 --log-to-file'.split(), stdout=blackhole, stderr=STDOUT)
16 16 processes.append(cp)
17 17 time.sleep(.5)
18 18 add_engine()
19 19 c = client.Client(profile='iptest')
20 20 while not c.ids:
21 21 time.sleep(.1)
22 22 c.spin()
23 23
24 24 def add_engine(profile='iptest'):
25 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '40'], stdout=blackhole, stderr=STDOUT)
25 ep = Popen(['ipenginez']+ ['--profile', profile, '--log-level', '10', '--log-to-file'], stdout=blackhole, stderr=STDOUT)
26 26 # ep.start()
27 27 processes.append(ep)
28 28 return ep
29 29
30 30 def teardown():
31 31 time.sleep(1)
32 32 while processes:
33 33 p = processes.pop()
34 34 if p.poll() is None:
35 35 try:
36 36 p.terminate()
37 37 except Exception, e:
38 38 print e
39 39 pass
40 40 if p.poll() is None:
41 41 time.sleep(.25)
42 42 if p.poll() is None:
43 43 try:
44 44 print 'killing'
45 45 p.kill()
46 46 except:
47 47 print "couldn't shutdown process: ", p
48 48
@@ -1,100 +1,105 b''
1 import sys
2 import tempfile
1 3 import time
2 4 from signal import SIGINT
3 5 from multiprocessing import Process
4 6
5 7 from nose import SkipTest
6 8
7 9 from zmq.tests import BaseZMQTestCase
8 10
9 11 from IPython.external.decorator import decorator
10 12
11 13 from IPython.zmq.parallel import error
12 14 from IPython.zmq.parallel.client import Client
13 15 from IPython.zmq.parallel.ipcluster import launch_process
14 16 from IPython.zmq.parallel.entry_point import select_random_ports
15 17 from IPython.zmq.parallel.tests import processes,add_engine
16 18
17 19 # simple tasks for use in apply tests
18 20
19 21 def segfault():
20 22 """this will segfault"""
21 23 import ctypes
22 24 ctypes.memset(-1,0,1)
23 25
24 26 def wait(n):
25 27 """sleep for a time"""
26 28 import time
27 29 time.sleep(n)
28 30 return n
29 31
30 32 def raiser(eclass):
31 33 """raise an exception"""
32 34 raise eclass()
33 35
34 36 # test decorator for skipping tests when libraries are unavailable
35 37 def skip_without(*names):
36 38 """skip a test if some names are not importable"""
37 39 @decorator
38 40 def skip_without_names(f, *args, **kwargs):
39 41 """decorator to skip tests in the absence of numpy."""
40 42 for name in names:
41 43 try:
42 44 __import__(name)
43 45 except ImportError:
44 46 raise SkipTest
45 47 return f(*args, **kwargs)
46 48 return skip_without_names
47 49
48 50
49 51 class ClusterTestCase(BaseZMQTestCase):
50 52
51 53 def add_engines(self, n=1, block=True):
52 54 """add multiple engines to our cluster"""
53 55 for i in range(n):
54 56 self.engines.append(add_engine())
55 57 if block:
56 58 self.wait_on_engines()
57 59
58 60 def wait_on_engines(self, timeout=5):
59 61 """wait for our engines to connect."""
60 62 n = len(self.engines)+self.base_engine_count
61 63 tic = time.time()
62 64 while time.time()-tic < timeout and len(self.client.ids) < n:
63 65 time.sleep(0.1)
64 66
65 assert not self.client.ids < n, "waiting for engines timed out"
67 assert not len(self.client.ids) < n, "waiting for engines timed out"
66 68
67 69 def connect_client(self):
68 70 """connect a client with my Context, and track its sockets for cleanup"""
69 71 c = Client(profile='iptest',context=self.context)
70 72 for name in filter(lambda n:n.endswith('socket'), dir(c)):
71 73 self.sockets.append(getattr(c, name))
72 74 return c
73 75
74 76 def assertRaisesRemote(self, etype, f, *args, **kwargs):
75 77 try:
76 78 try:
77 79 f(*args, **kwargs)
78 80 except error.CompositeError as e:
79 81 e.raise_exception()
80 82 except error.RemoteError as e:
81 83 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(e.ename, etype.__name__))
82 84 else:
83 85 self.fail("should have raised a RemoteError")
84 86
85 87 def setUp(self):
86 88 BaseZMQTestCase.setUp(self)
87 89 self.client = self.connect_client()
88 90 self.base_engine_count=len(self.client.ids)
89 91 self.engines=[]
90 92
91 93 def tearDown(self):
94
95 # close fds:
96 for e in filter(lambda e: e.poll() is not None, processes):
97 processes.remove(e)
98
92 99 self.client.close()
93 100 BaseZMQTestCase.tearDown(self)
94 # [ e.terminate() for e in filter(lambda e: e.poll() is None, self.engines) ]
95 # [ e.wait() for e in self.engines ]
96 # while len(self.client.ids) > self.base_engine_count:
97 # time.sleep(.1)
98 # del self.engines
99 # BaseZMQTestCase.tearDown(self)
101 # this will be superfluous when pyzmq merges PR #88
102 self.context.term()
103 print tempfile.TemporaryFile().fileno(),
104 sys.stdout.flush()
100 105 No newline at end of file
@@ -1,262 +1,262 b''
1 1 import time
2 2 from tempfile import mktemp
3 3
4 4 import nose.tools as nt
5 5 import zmq
6 6
7 7 from IPython.zmq.parallel import client as clientmod
8 8 from IPython.zmq.parallel import error
9 9 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
10 10 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
11 11
12 12 from clienttest import ClusterTestCase, segfault, wait
13 13
14 14 class TestClient(ClusterTestCase):
15 15
16 16 def test_ids(self):
17 17 n = len(self.client.ids)
18 18 self.add_engines(3)
19 19 self.assertEquals(len(self.client.ids), n+3)
20 self.assertTrue
21 20
22 21 def test_segfault_task(self):
23 22 """test graceful handling of engine death (balanced)"""
24 23 self.add_engines(1)
25 24 ar = self.client.apply(segfault, block=False)
26 25 self.assertRaisesRemote(error.EngineError, ar.get)
27 26 eid = ar.engine_id
28 27 while eid in self.client.ids:
29 28 time.sleep(.01)
30 29 self.client.spin()
31 30
32 31 def test_segfault_mux(self):
33 32 """test graceful handling of engine death (direct)"""
34 33 self.add_engines(1)
35 34 eid = self.client.ids[-1]
36 35 ar = self.client[eid].apply_async(segfault)
37 36 self.assertRaisesRemote(error.EngineError, ar.get)
38 37 eid = ar.engine_id
39 38 while eid in self.client.ids:
40 39 time.sleep(.01)
41 40 self.client.spin()
42 41
43 42 def test_view_indexing(self):
44 43 """test index access for views"""
45 44 self.add_engines(2)
46 45 targets = self.client._build_targets('all')[-1]
47 46 v = self.client[:]
48 47 self.assertEquals(v.targets, targets)
49 48 t = self.client.ids[2]
50 49 v = self.client[t]
51 50 self.assert_(isinstance(v, DirectView))
52 51 self.assertEquals(v.targets, t)
53 52 t = self.client.ids[2:4]
54 53 v = self.client[t]
55 54 self.assert_(isinstance(v, DirectView))
56 55 self.assertEquals(v.targets, t)
57 56 v = self.client[::2]
58 57 self.assert_(isinstance(v, DirectView))
59 58 self.assertEquals(v.targets, targets[::2])
60 59 v = self.client[1::3]
61 60 self.assert_(isinstance(v, DirectView))
62 61 self.assertEquals(v.targets, targets[1::3])
63 62 v = self.client[:-3]
64 63 self.assert_(isinstance(v, DirectView))
65 64 self.assertEquals(v.targets, targets[:-3])
66 65 v = self.client[-1]
67 66 self.assert_(isinstance(v, DirectView))
68 67 self.assertEquals(v.targets, targets[-1])
69 68 nt.assert_raises(TypeError, lambda : self.client[None])
70 69
71 70 def test_view_cache(self):
72 71 """test that multiple view requests return the same object"""
73 72 v = self.client[:2]
74 73 v2 =self.client[:2]
75 74 self.assertTrue(v is v2)
76 75 v = self.client.view()
77 76 v2 = self.client.view(balanced=True)
78 77 self.assertTrue(v is v2)
79 78
80 79 def test_targets(self):
81 80 """test various valid targets arguments"""
82 81 build = self.client._build_targets
83 82 ids = self.client.ids
84 83 idents,targets = build(None)
85 84 self.assertEquals(ids, targets)
86 85
87 86 def test_clear(self):
88 87 """test clear behavior"""
89 88 self.add_engines(2)
90 89 self.client.block=True
91 90 self.client.push(dict(a=5))
92 91 self.client.pull('a')
93 92 id0 = self.client.ids[-1]
94 93 self.client.clear(targets=id0)
95 94 self.client.pull('a', targets=self.client.ids[:-1])
96 95 self.assertRaisesRemote(NameError, self.client.pull, 'a')
97 96 self.client.clear()
98 97 for i in self.client.ids:
99 98 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
100 99
101 100
102 101 def test_push_pull(self):
103 102 """test pushing and pulling"""
104 103 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
105 104 t = self.client.ids[-1]
106 105 self.add_engines(2)
107 106 push = self.client.push
108 107 pull = self.client.pull
109 108 self.client.block=True
110 109 nengines = len(self.client)
111 110 push({'data':data}, targets=t)
112 111 d = pull('data', targets=t)
113 112 self.assertEquals(d, data)
114 113 push({'data':data})
115 114 d = pull('data')
116 115 self.assertEquals(d, nengines*[data])
117 116 ar = push({'data':data}, block=False)
118 117 self.assertTrue(isinstance(ar, AsyncResult))
119 118 r = ar.get()
120 119 ar = pull('data', block=False)
121 120 self.assertTrue(isinstance(ar, AsyncResult))
122 121 r = ar.get()
123 122 self.assertEquals(r, nengines*[data])
124 123 push(dict(a=10,b=20))
125 124 r = pull(('a','b'))
126 125 self.assertEquals(r, nengines*[[10,20]])
127 126
128 127 def test_push_pull_function(self):
129 128 "test pushing and pulling functions"
130 129 def testf(x):
131 130 return 2.0*x
132 131
133 132 self.add_engines(4)
134 133 t = self.client.ids[-1]
135 134 self.client.block=True
136 135 push = self.client.push
137 136 pull = self.client.pull
138 137 execute = self.client.execute
139 138 push({'testf':testf}, targets=t)
140 139 r = pull('testf', targets=t)
141 140 self.assertEqual(r(1.0), testf(1.0))
142 141 execute('r = testf(10)', targets=t)
143 142 r = pull('r', targets=t)
144 143 self.assertEquals(r, testf(10))
145 144 ar = push({'testf':testf}, block=False)
146 145 ar.get()
147 146 ar = pull('testf', block=False)
148 147 rlist = ar.get()
149 148 for r in rlist:
150 149 self.assertEqual(r(1.0), testf(1.0))
151 150 execute("def g(x): return x*x", targets=t)
152 151 r = pull(('testf','g'),targets=t)
153 152 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
154 153
155 154 def test_push_function_globals(self):
156 155 """test that pushed functions have access to globals"""
157 156 def geta():
158 157 return a
159 158 self.add_engines(1)
160 159 v = self.client[-1]
161 160 v.block=True
162 161 v['f'] = geta
163 162 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
164 163 v.execute('a=5')
165 164 v.execute('b=f()')
166 165 self.assertEquals(v['b'], 5)
167 166
168 167 def test_push_function_defaults(self):
169 168 """test that pushed functions preserve default args"""
170 169 def echo(a=10):
171 170 return a
172 171 self.add_engines(1)
173 172 v = self.client[-1]
174 173 v.block=True
175 174 v['f'] = echo
176 175 v.execute('b=f()')
177 176 self.assertEquals(v['b'], 10)
178 177
179 178 def test_get_result(self):
180 179 """test getting results from the Hub."""
181 180 c = clientmod.Client(profile='iptest')
182 t = self.client.ids[-1]
181 self.add_engines(1)
183 182 ar = c.apply(wait, (1,), block=False, targets=t)
183 # give the monitor time to notice the message
184 184 time.sleep(.25)
185 185 ahr = self.client.get_result(ar.msg_ids)
186 186 self.assertTrue(isinstance(ahr, AsyncHubResult))
187 187 self.assertEquals(ahr.get(), ar.get())
188 188 ar2 = self.client.get_result(ar.msg_ids)
189 189 self.assertFalse(isinstance(ar2, AsyncHubResult))
190 190
191 191 def test_ids_list(self):
192 192 """test client.ids"""
193 193 self.add_engines(2)
194 194 ids = self.client.ids
195 195 self.assertEquals(ids, self.client._ids)
196 196 self.assertFalse(ids is self.client._ids)
197 197 ids.remove(ids[-1])
198 198 self.assertNotEquals(ids, self.client._ids)
199 199
200 200 def test_run_newline(self):
201 201 """test that run appends newline to files"""
202 202 tmpfile = mktemp()
203 203 with open(tmpfile, 'w') as f:
204 204 f.write("""def g():
205 205 return 5
206 206 """)
207 207 v = self.client[-1]
208 208 v.run(tmpfile, block=True)
209 209 self.assertEquals(v.apply_sync(lambda : g()), 5)
210 210
211 211 def test_apply_tracked(self):
212 212 """test tracking for apply"""
213 213 # self.add_engines(1)
214 214 t = self.client.ids[-1]
215 215 self.client.block=False
216 216 def echo(n=1024*1024, **kwargs):
217 217 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
218 218 ar = echo(1)
219 219 self.assertTrue(ar._tracker is None)
220 220 self.assertTrue(ar.sent)
221 221 ar = echo(track=True)
222 222 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
223 223 self.assertEquals(ar.sent, ar._tracker.done)
224 224 ar._tracker.wait()
225 225 self.assertTrue(ar.sent)
226 226
227 227 def test_push_tracked(self):
228 228 t = self.client.ids[-1]
229 229 ns = dict(x='x'*1024*1024)
230 230 ar = self.client.push(ns, targets=t, block=False)
231 231 self.assertTrue(ar._tracker is None)
232 232 self.assertTrue(ar.sent)
233 233
234 234 ar = self.client.push(ns, targets=t, block=False, track=True)
235 235 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
236 236 self.assertEquals(ar.sent, ar._tracker.done)
237 237 ar._tracker.wait()
238 238 self.assertTrue(ar.sent)
239 239 ar.get()
240 240
241 241 def test_scatter_tracked(self):
242 242 t = self.client.ids
243 243 x='x'*1024*1024
244 244 ar = self.client.scatter('x', x, targets=t, block=False)
245 245 self.assertTrue(ar._tracker is None)
246 246 self.assertTrue(ar.sent)
247 247
248 248 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
249 249 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
250 250 self.assertEquals(ar.sent, ar._tracker.done)
251 251 ar._tracker.wait()
252 252 self.assertTrue(ar.sent)
253 253 ar.get()
254 254
255 255 def test_remote_reference(self):
256 256 v = self.client[-1]
257 257 v['a'] = 123
258 258 ra = clientmod.Reference('a')
259 259 b = v.apply_sync(lambda x: x, ra)
260 260 self.assertEquals(b, 123)
261 261
262 262
1 NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file, binary diff hidden
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: modified file
The requested commit or file is too big and content was truncated. Show full diff
1 NO CONTENT: file was removed, binary diff hidden
1 NO CONTENT: file was removed, binary diff hidden
General Comments 0
You need to be logged in to leave comments. Login now