pyzmq-2.1.3 related testing adjustments
MinRK
@@ -1,18 +1,23 b''
1 1 """The IPython ZMQ-based parallel computing interface."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 # from .asyncresult import *
14 14 # from .client import Client
15 15 # from .dependency import *
16 16 # from .remotefunction import *
17 17 # from .view import *
18 18
19 import zmq
20
21 if zmq.__version__ < '2.1.3':
22     raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3; you appear to have %s"%zmq.__version__)
23
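One caveat about the check above: comparing version strings lexicographically happens to work for these particular releases, but breaks once any component reaches two digits (as strings, '2.10.0' sorts before '2.2.0'). A minimal, more robust sketch using numeric tuples; `_version_tuple` is illustrative, not part of the module:

    def _version_tuple(vs):
        """'2.1.3' -> (2, 1, 3); a suffix like '2.1.3dev' is tolerated."""
        parts = []
        for part in vs.split('.'):
            digits = ''.join(ch for ch in part if ch.isdigit())
            if not digits:
                break
            parts.append(int(digits))
        return tuple(parts)

    if _version_tuple(zmq.__version__) < (2, 1, 3):
        raise ImportError("IPython.zmq.parallel requires pyzmq/0MQ >= 2.1.3; "
                          "you appear to have %s" % zmq.__version__)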
@@ -1,1584 +1,1591 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.pickleutil import Reference
28 28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 29 Dict, List, Bool, Str, Set)
30 30 from IPython.external.decorator import decorator
31 31 from IPython.external.ssh import tunnel
32 32
33 33 from . import error
34 34 from . import map as Map
35 35 from . import util
36 36 from . import streamsession as ss
37 37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
38 38 from .clusterdir import ClusterDir, ClusterDirError
39 39 from .dependency import Dependency, depend, require, dependent
40 40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 41 from .util import ReverseDict, validate_url, disambiguate_url
42 42 from .view import DirectView, LoadBalancedView
43 43
44 44 #--------------------------------------------------------------------------
45 45 # helpers for implementing old MEC API via client.apply
46 46 #--------------------------------------------------------------------------
47 47
48 48 def _push(user_ns, **ns):
49 49 """helper method for implementing `client.push` via `client.apply`"""
50 50 user_ns.update(ns)
51 51
52 52 def _pull(user_ns, keys):
53 53 """helper method for implementing `client.pull` via `client.apply`"""
54 54 if isinstance(keys, (list,tuple, set)):
55 55 for key in keys:
56 56 if not user_ns.has_key(key):
57 57 raise NameError("name '%s' is not defined"%key)
58 58 return map(user_ns.get, keys)
59 59 else:
60 60 if not user_ns.has_key(keys):
61 61 raise NameError("name '%s' is not defined"%keys)
62 62 return user_ns.get(keys)
63 63
64 64 def _clear(user_ns):
65 65 """helper method for implementing `client.clear` via `client.apply`"""
66 66 user_ns.clear()
67 67
68 68 def _execute(user_ns, code):
69 69 """helper method for implementing `client.execute` via `client.apply`"""
70 70 exec code in user_ns
71 71
72 72
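These helpers only make sense with `bound=True`, which passes the engine's namespace as the first argument. A hypothetical round trip through `client.apply` (assuming a connected `client` with engine 0 registered):

    # push two names into engine 0's namespace
    client.apply(_push, kwargs=dict(a=1, b=2), targets=0, bound=True,
                 balanced=False, block=True)
    # pull them back by name
    a, b = client.apply(_pull, (['a', 'b'],), targets=0, bound=True,
                        balanced=False, block=True)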
73 73 #--------------------------------------------------------------------------
74 74 # Decorators for Client methods
75 75 #--------------------------------------------------------------------------
76 76
77 77 @decorator
78 78 def spinfirst(f, self, *args, **kwargs):
79 79 """Call spin() to sync state prior to calling the method."""
80 80 self.spin()
81 81 return f(self, *args, **kwargs)
82 82
83 83 @decorator
84 84 def defaultblock(f, self, *args, **kwargs):
85 85 """Default to self.block; preserve self.block."""
86 86 block = kwargs.get('block',None)
87 87 block = self.block if block is None else block
88 88 saveblock = self.block
89 89 self.block = block
90 90 try:
91 91 ret = f(self, *args, **kwargs)
92 92 finally:
93 93 self.block = saveblock
94 94 return ret
95 95
96 96
97 97 #--------------------------------------------------------------------------
98 98 # Classes
99 99 #--------------------------------------------------------------------------
100 100
101 101 class Metadata(dict):
102 102 """Subclass of dict for initializing metadata values.
103 103
104 104 Attribute access works on keys.
105 105
106 106 These objects have a strict set of keys - errors will raise if you try
107 107 to add new keys.
108 108 """
109 109 def __init__(self, *args, **kwargs):
110 110 dict.__init__(self)
111 111 md = {'msg_id' : None,
112 112 'submitted' : None,
113 113 'started' : None,
114 114 'completed' : None,
115 115 'received' : None,
116 116 'engine_uuid' : None,
117 117 'engine_id' : None,
118 118 'follow' : None,
119 119 'after' : None,
120 120 'status' : None,
121 121
122 122 'pyin' : None,
123 123 'pyout' : None,
124 124 'pyerr' : None,
125 125 'stdout' : '',
126 126 'stderr' : '',
127 127 }
128 128 self.update(md)
129 129 self.update(dict(*args, **kwargs))
130 130
131 131 def __getattr__(self, key):
132 132 """getattr aliased to getitem"""
133 133 if key in self.iterkeys():
134 134 return self[key]
135 135 else:
136 136 raise AttributeError(key)
137 137
138 138 def __setattr__(self, key, value):
139 139 """setattr aliased to setitem, with strict"""
140 140 if key in self.iterkeys():
141 141 self[key] = value
142 142 else:
143 143 raise AttributeError(key)
144 144
145 145 def __setitem__(self, key, value):
146 146 """strict static key enforcement"""
147 147 if key in self.iterkeys():
148 148 dict.__setitem__(self, key, value)
149 149 else:
150 150 raise KeyError(key)
151 151
152 152
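A short illustration of the strict-key behavior (hypothetical session):

    md = Metadata(status='ok')
    md.engine_id = 0      # attribute access maps onto the fixed keys
    print md['status']    # prints: ok
    md['bogus'] = 1       # raises KeyError: new keys are rejected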
153 153 class Client(HasTraits):
154 154 """A semi-synchronous client to the IPython ZMQ controller
155 155
156 156 Parameters
157 157 ----------
158 158
159 159 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 160 Connection information for the Hub's registration. If a json connector
161 161 file is given, then likely no further configuration is necessary.
162 162 [Default: use profile]
163 163 profile : bytes
164 164 The name of the Cluster profile to be used to find connector information.
165 165 [Default: 'default']
166 166 context : zmq.Context
167 167 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 168 username : bytes
169 169 set username to be passed to the Session object
170 170 debug : bool
171 171 flag for lots of message printing for debug purposes
172 172
173 173 #-------------- ssh related args ----------------
174 174 # These are args for configuring the ssh tunnel to be used
175 175 # credentials are used to forward connections over ssh to the Controller
176 176 # Note that the ip given in `addr` needs to be relative to sshserver
177 177 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 178 # and set sshserver as the same machine the Controller is on. However,
179 179 # the only requirement is that sshserver is able to see the Controller
180 180 # (i.e. is within the same trusted network).
181 181
182 182 sshserver : str
183 183 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 184 If keyfile or password is specified, and this is not, it will default to
185 185 the ip given in addr.
186 186 sshkey : str; path to public ssh key file
187 187 This specifies a key to be used in ssh login, default None.
188 188 Regular default ssh keys will be used without specifying this argument.
189 189 password : str
190 190 Your ssh password to sshserver. Note that if this is left None,
191 191 you will be prompted for it if passwordless key based login is unavailable.
192 192 paramiko : bool
193 193 flag for whether to use paramiko instead of shell ssh for tunneling.
194 194 [default: True on win32, False else]
195 195
196 196 #------- exec authentication args -------
197 197 # If even localhost is untrusted, you can have some protection against
198 198 # unauthorized execution by using a key. Messages are still sent
199 199 # as cleartext, so if someone can snoop your loopback traffic this will
200 200 # not help against malicious attacks.
201 201
202 202 exec_key : str
203 203 an authentication key or file containing a key
204 204 default: None
205 205
206 206
207 207 Attributes
208 208 ----------
209 209
210 210 ids : set of int engine IDs
211 211 requesting the ids attribute always synchronizes
212 212 the registration state. To request ids without synchronization,
213 213         use the semi-private _ids attribute.
214 214
215 215 history : list of msg_ids
216 216 a list of msg_ids, keeping track of all the execution
217 217 messages you have submitted in order.
218 218
219 219 outstanding : set of msg_ids
220 220 a set of msg_ids that have been submitted, but whose
221 221 results have not yet been received.
222 222
223 223 results : dict
224 224 a dict of all our results, keyed by msg_id
225 225
226 226 block : bool
227 227 determines default behavior when block not specified
228 228 in execution methods
229 229
230 230 Methods
231 231 -------
232 232
233 233 spin
234 234 flushes incoming results and registration state changes
235 235         control methods spin, and requesting `ids` also ensures the state is up to date
236 236
237 237 barrier
238 238 wait on one or more msg_ids
239 239
240 240 execution methods
241 241 apply
242 242 legacy: execute, run
243 243
244 244 query methods
245 245 queue_status, get_result, purge
246 246
247 247 control methods
248 248 abort, shutdown
249 249
250 250 """
251 251
252 252
253 253 block = Bool(False)
254 254 outstanding = Set()
255 255 results = Instance('collections.defaultdict', (dict,))
256 256 metadata = Instance('collections.defaultdict', (Metadata,))
257 257 history = List()
258 258 debug = Bool(False)
259 259 profile=CUnicode('default')
260 260
261 261 _outstanding_dict = Instance('collections.defaultdict', (set,))
262 262 _ids = List()
263 263 _connected=Bool(False)
264 264 _ssh=Bool(False)
265 265 _context = Instance('zmq.Context')
266 266 _config = Dict()
267 267 _engines=Instance(ReverseDict, (), {})
268 268 # _hub_socket=Instance('zmq.Socket')
269 269 _query_socket=Instance('zmq.Socket')
270 270 _control_socket=Instance('zmq.Socket')
271 271 _iopub_socket=Instance('zmq.Socket')
272 272 _notification_socket=Instance('zmq.Socket')
273 273 _apply_socket=Instance('zmq.Socket')
274 274 _mux_ident=Str()
275 275 _task_ident=Str()
276 276 _task_scheme=Str()
277 277 _balanced_views=Dict()
278 278 _direct_views=Dict()
279 279 _closed = False
280 280
281 281 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
282 282 context=None, username=None, debug=False, exec_key=None,
283 283 sshserver=None, sshkey=None, password=None, paramiko=None,
284 284 ):
285 285 super(Client, self).__init__(debug=debug, profile=profile)
286 286 if context is None:
287 context = zmq.Context()
287 context = zmq.Context.instance()
288 288 self._context = context
289 289
290 290
291 291 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
292 292 if self._cd is not None:
293 293 if url_or_file is None:
294 294 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
295 295 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
296 296 " Please specify at least one of url_or_file or profile."
297 297
298 298 try:
299 299 validate_url(url_or_file)
300 300 except AssertionError:
301 301 if not os.path.exists(url_or_file):
302 302 if self._cd:
303 303 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
304 304 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
305 305 with open(url_or_file) as f:
306 306 cfg = json.loads(f.read())
307 307 else:
308 308 cfg = {'url':url_or_file}
309 309
310 310 # sync defaults from args, json:
311 311 if sshserver:
312 312 cfg['ssh'] = sshserver
313 313 if exec_key:
314 314 cfg['exec_key'] = exec_key
315 315 exec_key = cfg['exec_key']
316 316 sshserver=cfg['ssh']
317 317 url = cfg['url']
318 318 location = cfg.setdefault('location', None)
319 319 cfg['url'] = disambiguate_url(cfg['url'], location)
320 320 url = cfg['url']
321 321
322 322 self._config = cfg
323 323
324 324 self._ssh = bool(sshserver or sshkey or password)
325 325 if self._ssh and sshserver is None:
326 326 # default to ssh via localhost
327 327 sshserver = url.split('://')[1].split(':')[0]
328 328 if self._ssh and password is None:
329 329 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
330 330 password=False
331 331 else:
332 332 password = getpass("SSH Password for %s: "%sshserver)
333 333 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
334 334 if exec_key is not None and os.path.isfile(exec_key):
335 335 arg = 'keyfile'
336 336 else:
337 337 arg = 'key'
338 338 key_arg = {arg:exec_key}
339 339 if username is None:
340 340 self.session = ss.StreamSession(**key_arg)
341 341 else:
342 342 self.session = ss.StreamSession(username, **key_arg)
343 343 self._query_socket = self._context.socket(zmq.XREQ)
344 344 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
345 345 if self._ssh:
346 346 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
347 347 else:
348 348 self._query_socket.connect(url)
349 349
350 350 self.session.debug = self.debug
351 351
352 352 self._notification_handlers = {'registration_notification' : self._register_engine,
353 353 'unregistration_notification' : self._unregister_engine,
354 354 }
355 355 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
356 356 'apply_reply' : self._handle_apply_reply}
357 357 self._connect(sshserver, ssh_kwargs)
358 358
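For reference, the two usual ways to construct a client implied by this constructor (the path below is illustrative):

    from IPython.zmq.parallel.client import Client

    c = Client('/path/to/security/ipcontroller-client.json')  # explicit connector file
    c = Client(profile='default')  # or locate the file via a cluster profile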
359 359 def __del__(self):
360 360 """cleanup sockets, but _not_ context."""
361 361 self.close()
362 362
363 363 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
364 364 if ipython_dir is None:
365 365 ipython_dir = get_ipython_dir()
366 366 if cluster_dir is not None:
367 367 try:
368 368 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
369 369 return
370 370 except ClusterDirError:
371 371 pass
372 372 elif profile is not None:
373 373 try:
374 374 self._cd = ClusterDir.find_cluster_dir_by_profile(
375 375 ipython_dir, profile)
376 376 return
377 377 except ClusterDirError:
378 378 pass
379 379 self._cd = None
380 380
381 381 @property
382 382 def ids(self):
383 383 """Always up-to-date ids property."""
384 384 self._flush_notifications()
385 385 # always copy:
386 386 return list(self._ids)
387 387
388 388 def close(self):
389 389 if self._closed:
390 390 return
391 391 snames = filter(lambda n: n.endswith('socket'), dir(self))
392 392 for socket in map(lambda name: getattr(self, name), snames):
393 393 if isinstance(socket, zmq.Socket) and not socket.closed:
394 394 socket.close()
395 395 self._closed = True
396 396
397 397 def _update_engines(self, engines):
398 398 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
399 399 for k,v in engines.iteritems():
400 400 eid = int(k)
401 401 self._engines[eid] = bytes(v) # force not unicode
402 402 self._ids.append(eid)
403 403 self._ids = sorted(self._ids)
404 404 if sorted(self._engines.keys()) != range(len(self._engines)) and \
405 405 self._task_scheme == 'pure' and self._task_ident:
406 406 self._stop_scheduling_tasks()
407 407
408 408 def _stop_scheduling_tasks(self):
409 409 """Stop scheduling tasks because an engine has been unregistered
410 410 from a pure ZMQ scheduler.
411 411 """
412 412 self._task_ident = ''
413 413 # self._task_socket.close()
414 414 # self._task_socket = None
415 415 msg = "An engine has been unregistered, and we are using pure " +\
416 416 "ZMQ task scheduling. Task farming will be disabled."
417 417 if self.outstanding:
418 418 msg += " If you were running tasks when this happened, " +\
419 419 "some `outstanding` msg_ids may never resolve."
420 420 warnings.warn(msg, RuntimeWarning)
421 421
422 422 def _build_targets(self, targets):
423 423 """Turn valid target IDs or 'all' into two lists:
424 424 (int_ids, uuids).
425 425 """
426 426 if targets is None:
427 427 targets = self._ids
428 428 elif isinstance(targets, str):
429 429 if targets.lower() == 'all':
430 430 targets = self._ids
431 431 else:
432 432 raise TypeError("%r not valid str target, must be 'all'"%(targets))
433 433 elif isinstance(targets, int):
434 434 targets = [targets]
435 435 return [self._engines[t] for t in targets], list(targets)
436 436
437 437 def _connect(self, sshserver, ssh_kwargs):
438 438         """Set up all our socket connections to the controller. This is called from
439 439 __init__."""
440 440
441 441 # Maybe allow reconnecting?
442 442 if self._connected:
443 443 return
444 444 self._connected=True
445 445
446 446 def connect_socket(s, url):
447 447 url = disambiguate_url(url, self._config['location'])
448 448 if self._ssh:
449 449 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
450 450 else:
451 451 return s.connect(url)
452 452
453 453 self.session.send(self._query_socket, 'connection_request')
454 454 idents,msg = self.session.recv(self._query_socket,mode=0)
455 455 if self.debug:
456 456 pprint(msg)
457 457 msg = ss.Message(msg)
458 458 content = msg.content
459 459 self._config['registration'] = dict(content)
460 460 if content.status == 'ok':
461 461 self._apply_socket = self._context.socket(zmq.XREP)
462 462 self._apply_socket.setsockopt(zmq.IDENTITY, self.session.session)
463 463 if content.mux:
464 464 # self._mux_socket = self._context.socket(zmq.XREQ)
465 465 self._mux_ident = 'mux'
466 466 connect_socket(self._apply_socket, content.mux)
467 467 if content.task:
468 468 self._task_scheme, task_addr = content.task
469 469 # self._task_socket = self._context.socket(zmq.XREQ)
470 470 # self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 471 connect_socket(self._apply_socket, task_addr)
472 472 self._task_ident = 'task'
473 473 if content.notification:
474 474 self._notification_socket = self._context.socket(zmq.SUB)
475 475 connect_socket(self._notification_socket, content.notification)
476 476 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
477 477 # if content.query:
478 478 # self._query_socket = self._context.socket(zmq.XREQ)
479 479 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
480 480 # connect_socket(self._query_socket, content.query)
481 481 if content.control:
482 482 self._control_socket = self._context.socket(zmq.XREQ)
483 483 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
484 484 connect_socket(self._control_socket, content.control)
485 485 if content.iopub:
486 486 self._iopub_socket = self._context.socket(zmq.SUB)
487 487 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
488 488 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
489 489 connect_socket(self._iopub_socket, content.iopub)
490 490 self._update_engines(dict(content.engines))
491 491 # give XREP apply_socket some time to connect
492 492 time.sleep(0.25)
493 493 else:
494 494 self._connected = False
495 495 raise Exception("Failed to connect!")
496 496
497 497 #--------------------------------------------------------------------------
498 498 # handlers and callbacks for incoming messages
499 499 #--------------------------------------------------------------------------
500 500
501 501 def _unwrap_exception(self, content):
502 502 """unwrap exception, and remap engineid to int."""
503 503 e = error.unwrap_exception(content)
504 504 # print e.traceback
505 505 if e.engine_info:
506 506 e_uuid = e.engine_info['engine_uuid']
507 507 eid = self._engines[e_uuid]
508 508 e.engine_info['engine_id'] = eid
509 509 return e
510 510
511 511 def _extract_metadata(self, header, parent, content):
512 512 md = {'msg_id' : parent['msg_id'],
513 513 'received' : datetime.now(),
514 514 'engine_uuid' : header.get('engine', None),
515 515 'follow' : parent.get('follow', []),
516 516 'after' : parent.get('after', []),
517 517 'status' : content['status'],
518 518 }
519 519
520 520 if md['engine_uuid'] is not None:
521 521 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
522 522
523 523 if 'date' in parent:
524 524 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
525 525 if 'started' in header:
526 526 md['started'] = datetime.strptime(header['started'], util.ISO8601)
527 527 if 'date' in header:
528 528 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
529 529 return md
530 530
531 531 def _register_engine(self, msg):
532 532 """Register a new engine, and update our connection info."""
533 533 content = msg['content']
534 534 eid = content['id']
535 535 d = {eid : content['queue']}
536 536 self._update_engines(d)
537 537
538 538 def _unregister_engine(self, msg):
539 539 """Unregister an engine that has died."""
540 540 content = msg['content']
541 541 eid = int(content['id'])
542 542 if eid in self._ids:
543 543 self._ids.remove(eid)
544 544 uuid = self._engines.pop(eid)
545 545
546 546 self._handle_stranded_msgs(eid, uuid)
547 547
548 548 if self._task_ident and self._task_scheme == 'pure':
549 549 self._stop_scheduling_tasks()
550 550
551 551 def _handle_stranded_msgs(self, eid, uuid):
552 552 """Handle messages known to be on an engine when the engine unregisters.
553 553
554 554 It is possible that this will fire prematurely - that is, an engine will
555 555 go down after completing a result, and the client will be notified
556 556 of the unregistration and later receive the successful result.
557 557 """
558 558
559 559 outstanding = self._outstanding_dict[uuid]
560 560
561 561 for msg_id in list(outstanding):
562 562 if msg_id in self.results:
563 563                 # we already have this result
564 564 continue
565 565 try:
566 566 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
567 567 except:
568 568 content = error.wrap_exception()
569 569 # build a fake message:
570 570 parent = {}
571 571 header = {}
572 572 parent['msg_id'] = msg_id
573 573 header['engine'] = uuid
574 574 header['date'] = datetime.now().strftime(util.ISO8601)
575 575 msg = dict(parent_header=parent, header=header, content=content)
576 576 self._handle_apply_reply(msg)
577 577
578 578 def _handle_execute_reply(self, msg):
579 579 """Save the reply to an execute_request into our results.
580 580
581 581 execute messages are never actually used. apply is used instead.
582 582 """
583 583
584 584 parent = msg['parent_header']
585 585 msg_id = parent['msg_id']
586 586 if msg_id not in self.outstanding:
587 587 if msg_id in self.history:
588 588 print ("got stale result: %s"%msg_id)
589 589 else:
590 590 print ("got unknown result: %s"%msg_id)
591 591 else:
592 592 self.outstanding.remove(msg_id)
593 593 self.results[msg_id] = self._unwrap_exception(msg['content'])
594 594
595 595 def _handle_apply_reply(self, msg):
596 596 """Save the reply to an apply_request into our results."""
597 597 parent = msg['parent_header']
598 598 msg_id = parent['msg_id']
599 599 if msg_id not in self.outstanding:
600 600 if msg_id in self.history:
601 601 print ("got stale result: %s"%msg_id)
602 602 print self.results[msg_id]
603 603 print msg
604 604 else:
605 605 print ("got unknown result: %s"%msg_id)
606 606 else:
607 607 self.outstanding.remove(msg_id)
608 608 content = msg['content']
609 609 header = msg['header']
610 610
611 611 # construct metadata:
612 612 md = self.metadata[msg_id]
613 613 md.update(self._extract_metadata(header, parent, content))
614 614 # is this redundant?
615 615 self.metadata[msg_id] = md
616 616
617 617 e_outstanding = self._outstanding_dict[md['engine_uuid']]
618 618 if msg_id in e_outstanding:
619 619 e_outstanding.remove(msg_id)
620 620
621 621 # construct result:
622 622 if content['status'] == 'ok':
623 623 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
624 624 elif content['status'] == 'aborted':
625 625 self.results[msg_id] = error.AbortedTask(msg_id)
626 626 elif content['status'] == 'resubmitted':
627 627 # TODO: handle resubmission
628 628 pass
629 629 else:
630 630 self.results[msg_id] = self._unwrap_exception(content)
631 631
632 632 def _flush_notifications(self):
633 633 """Flush notifications of engine registrations waiting
634 634 in ZMQ queue."""
635 635 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
636 636 while msg is not None:
637 637 if self.debug:
638 638 pprint(msg)
639 639 msg = msg[-1]
640 640 msg_type = msg['msg_type']
641 641 handler = self._notification_handlers.get(msg_type, None)
642 642 if handler is None:
643 643                 raise Exception("Unhandled message type: %s"%msg_type)
644 644 else:
645 645 handler(msg)
646 646 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
647 647
648 648 def _flush_results(self, sock):
649 649 """Flush task or queue results waiting in ZMQ queue."""
650 650 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
651 651 while msg is not None:
652 652 if self.debug:
653 653 pprint(msg)
654 654 msg = msg[-1]
655 655 msg_type = msg['msg_type']
656 656 handler = self._queue_handlers.get(msg_type, None)
657 657 if handler is None:
658 658                 raise Exception("Unhandled message type: %s"%msg_type)
659 659 else:
660 660 handler(msg)
661 661 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
662 662
663 663 def _flush_control(self, sock):
664 664 """Flush replies from the control channel waiting
665 665 in the ZMQ queue.
666 666
667 667 Currently: ignore them."""
668 668 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
669 669 while msg is not None:
670 670 if self.debug:
671 671 pprint(msg)
672 672 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
673 673
674 674 def _flush_iopub(self, sock):
675 675 """Flush replies from the iopub channel waiting
676 676 in the ZMQ queue.
677 677 """
678 678 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
679 679 while msg is not None:
680 680 if self.debug:
681 681 pprint(msg)
682 682 msg = msg[-1]
683 683 parent = msg['parent_header']
684 684 msg_id = parent['msg_id']
685 685 content = msg['content']
686 686 header = msg['header']
687 687 msg_type = msg['msg_type']
688 688
689 689 # init metadata:
690 690 md = self.metadata[msg_id]
691 691
692 692 if msg_type == 'stream':
693 693 name = content['name']
694 694 s = md[name] or ''
695 695 md[name] = s + content['data']
696 696 elif msg_type == 'pyerr':
697 697 md.update({'pyerr' : self._unwrap_exception(content)})
698 698 else:
699 699 md.update({msg_type : content['data']})
700 700
701 701             # redundant?
702 702 self.metadata[msg_id] = md
703 703
704 704 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
705 705
706 706 #--------------------------------------------------------------------------
707 707 # len, getitem
708 708 #--------------------------------------------------------------------------
709 709
710 710 def __len__(self):
711 711 """len(client) returns # of engines."""
712 712 return len(self.ids)
713 713
714 714 def __getitem__(self, key):
715 715 """index access returns DirectView multiplexer objects
716 716
717 717 Must be int, slice, or list/tuple/xrange of ints"""
718 718 if not isinstance(key, (int, slice, tuple, list, xrange)):
719 719 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
720 720 else:
721 721 return self.view(key, balanced=False)
722 722
723 723 #--------------------------------------------------------------------------
724 724 # Begin public methods
725 725 #--------------------------------------------------------------------------
726 726
727 727 def spin(self):
728 728 """Flush any registration notifications and execution results
729 729 waiting in the ZMQ queue.
730 730 """
731 731 if self._notification_socket:
732 732 self._flush_notifications()
733 733 if self._apply_socket:
734 734 self._flush_results(self._apply_socket)
735 735 if self._control_socket:
736 736 self._flush_control(self._control_socket)
737 737 if self._iopub_socket:
738 738 self._flush_iopub(self._iopub_socket)
739 739
740 740 def barrier(self, jobs=None, timeout=-1):
741 741 """waits on one or more `jobs`, for up to `timeout` seconds.
742 742
743 743 Parameters
744 744 ----------
745 745
746 746 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
747 747 ints are indices to self.history
748 748 strs are msg_ids
749 749 default: wait on all outstanding messages
750 750 timeout : float
751 751 a time in seconds, after which to give up.
752 752 default is -1, which means no timeout
753 753
754 754 Returns
755 755 -------
756 756
757 757 True : when all msg_ids are done
758 758 False : timeout reached, some msg_ids still outstanding
759 759 """
760 760 tic = time.time()
761 761 if jobs is None:
762 762 theids = self.outstanding
763 763 else:
764 764 if isinstance(jobs, (int, str, AsyncResult)):
765 765 jobs = [jobs]
766 766 theids = set()
767 767 for job in jobs:
768 768 if isinstance(job, int):
769 769 # index access
770 770 job = self.history[job]
771 771 elif isinstance(job, AsyncResult):
772 772 map(theids.add, job.msg_ids)
773 773 continue
774 774 theids.add(job)
775 775 if not theids.intersection(self.outstanding):
776 776 return True
777 777 self.spin()
778 778 while theids.intersection(self.outstanding):
779 779 if timeout >= 0 and ( time.time()-tic ) > timeout:
780 780 break
781 781 time.sleep(1e-3)
782 782 self.spin()
783 783 return len(theids.intersection(self.outstanding)) == 0
784 784
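As a usage sketch (assuming a connected `client` with registered engines; `f` is any picklable function):

    ar = client.apply(f, block=False)      # submit without waiting
    done = client.barrier(ar, timeout=5)   # wait up to 5s for just this job
    client.barrier()                       # block until all outstanding work finishes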
785 785 #--------------------------------------------------------------------------
786 786 # Control methods
787 787 #--------------------------------------------------------------------------
788 788
789 789 @spinfirst
790 790 @defaultblock
791 791 def clear(self, targets=None, block=None):
792 792 """Clear the namespace in target(s)."""
793 793 targets = self._build_targets(targets)[0]
794 794 for t in targets:
795 795 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
796 796 error = False
797 797 if self.block:
798 798 for i in range(len(targets)):
799 799 idents,msg = self.session.recv(self._control_socket,0)
800 800 if self.debug:
801 801 pprint(msg)
802 802 if msg['content']['status'] != 'ok':
803 803 error = self._unwrap_exception(msg['content'])
804 804 if error:
805 805 raise error
806 806
807 807
808 808 @spinfirst
809 809 @defaultblock
810 810 def abort(self, jobs=None, targets=None, block=None):
811 811 """Abort specific jobs from the execution queues of target(s).
812 812
813 813 This is a mechanism to prevent jobs that have already been submitted
814 814 from executing.
815 815
816 816 Parameters
817 817 ----------
818 818
819 819 jobs : msg_id, list of msg_ids, or AsyncResult
820 820 The jobs to be aborted
821 821
822 822
823 823 """
824 824 targets = self._build_targets(targets)[0]
825 825 msg_ids = []
826 826 if isinstance(jobs, (basestring,AsyncResult)):
827 827 jobs = [jobs]
828 828 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
829 829 if bad_ids:
830 830 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
831 831 for j in jobs:
832 832 if isinstance(j, AsyncResult):
833 833 msg_ids.extend(j.msg_ids)
834 834 else:
835 835 msg_ids.append(j)
836 836 content = dict(msg_ids=msg_ids)
837 837 for t in targets:
838 838 self.session.send(self._control_socket, 'abort_request',
839 839 content=content, ident=t)
840 840 error = False
841 841 if self.block:
842 842 for i in range(len(targets)):
843 843 idents,msg = self.session.recv(self._control_socket,0)
844 844 if self.debug:
845 845 pprint(msg)
846 846 if msg['content']['status'] != 'ok':
847 847 error = self._unwrap_exception(msg['content'])
848 848 if error:
849 849 raise error
850 850
851 851 @spinfirst
852 852 @defaultblock
853 853 def shutdown(self, targets=None, restart=False, controller=False, block=None):
854 854 """Terminates one or more engine processes, optionally including the controller."""
855 855 if controller:
856 856 targets = 'all'
857 857 targets = self._build_targets(targets)[0]
858 858 for t in targets:
859 859 self.session.send(self._control_socket, 'shutdown_request',
860 860 content={'restart':restart},ident=t)
861 861 error = False
862 862 if block or controller:
863 863 for i in range(len(targets)):
864 864 idents,msg = self.session.recv(self._control_socket,0)
865 865 if self.debug:
866 866 pprint(msg)
867 867 if msg['content']['status'] != 'ok':
868 868 error = self._unwrap_exception(msg['content'])
869 869
870 870 if controller:
871 871 time.sleep(0.25)
872 872 self.session.send(self._query_socket, 'shutdown_request')
873 873 idents,msg = self.session.recv(self._query_socket, 0)
874 874 if self.debug:
875 875 pprint(msg)
876 876 if msg['content']['status'] != 'ok':
877 877 error = self._unwrap_exception(msg['content'])
878 878
879 879 if error:
880 880 raise error
881 881
882 882 #--------------------------------------------------------------------------
883 883 # Execution methods
884 884 #--------------------------------------------------------------------------
885 885
886 886 @defaultblock
887 887 def execute(self, code, targets='all', block=None):
888 888 """Executes `code` on `targets` in blocking or nonblocking manner.
889 889
890 890 ``execute`` is always `bound` (affects engine namespace)
891 891
892 892 Parameters
893 893 ----------
894 894
895 895 code : str
896 896 the code string to be executed
897 897 targets : int/str/list of ints/strs
898 898 the engines on which to execute
899 899 default : all
900 900 block : bool
901 901 whether or not to wait until done to return
902 902 default: self.block
903 903 """
904 904 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
905 905 if not block:
906 906 return result
907 907
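For example (hypothetical session):

    client.execute('import numpy', block=True)               # runs on all engines
    client.execute('x = numpy.arange(10)', targets=[0, 1])   # only on engines 0 and 1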
908 908 def run(self, filename, targets='all', block=None):
909 909 """Execute contents of `filename` on engine(s).
910 910
911 911 This simply reads the contents of the file and calls `execute`.
912 912
913 913 Parameters
914 914 ----------
915 915
916 916 filename : str
917 917 The path to the file
918 918 targets : int/str/list of ints/strs
919 919 the engines on which to execute
920 920 default : all
921 921 block : bool
922 922 whether or not to wait until done
923 923 default: self.block
924 924
925 925 """
926 926 with open(filename, 'r') as f:
927 927 # add newline in case of trailing indented whitespace
928 928 # which will cause SyntaxError
929 929 code = f.read()+'\n'
930 930 return self.execute(code, targets=targets, block=block)
931 931
932 932 def _maybe_raise(self, result):
933 933 """wrapper for maybe raising an exception if apply failed."""
934 934 if isinstance(result, error.RemoteError):
935 935 raise result
936 936
937 937 return result
938 938
939 939 def _build_dependency(self, dep):
940 940 """helper for building jsonable dependencies from various input forms"""
941 941 if isinstance(dep, Dependency):
942 942 return dep.as_dict()
943 943 elif isinstance(dep, AsyncResult):
944 944 return dep.msg_ids
945 945 elif dep is None:
946 946 return []
947 947 else:
948 948 # pass to Dependency constructor
949 949 return list(Dependency(dep))
950 950
951 951 @defaultblock
952 952 def apply(self, f, args=None, kwargs=None, bound=False, block=None,
953 953 targets=None, balanced=None,
954 954 after=None, follow=None, timeout=None,
955 955 track=False):
956 956 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
957 957
958 958 This is the central execution command for the client.
959 959
960 960 Parameters
961 961 ----------
962 962
963 963 f : function
964 964             The function to be called remotely
965 965 args : tuple/list
966 966 The positional arguments passed to `f`
967 967 kwargs : dict
968 968 The keyword arguments passed to `f`
969 969 bound : bool (default: False)
970 970 Whether to pass the Engine(s) Namespace as the first argument to `f`.
971 971 block : bool (default: self.block)
972 972 Whether to wait for the result, or return immediately.
973 973 False:
974 974 returns AsyncResult
975 975 True:
976 976 returns actual result(s) of f(*args, **kwargs)
977 977 if multiple targets:
978 978 list of results, matching `targets`
979 track : bool
980 whether to track non-copying sends.
981 [default False]
982
979 983 targets : int,list of ints, 'all', None
980 984 Specify the destination of the job.
981 985 if None:
982 986 Submit via Task queue for load-balancing.
983 987 if 'all':
984 988 Run on all active engines
985 989 if list:
986 990 Run on each specified engine
987 991 if int:
988 992 Run on single engine
989
993 Note:
    994             if `balanced=True` and `targets` is specified,
995 then the load-balancing will be limited to balancing
996 among `targets`.
997
990 998 balanced : bool, default None
991 999 whether to load-balance. This will default to True
992 1000 if targets is unspecified, or False if targets is specified.
993
994 The following arguments are only used when balanced is True:
1001
1002 If `balanced` and `targets` are both specified, the task will
    1003             be assigned to *one* of the targets by the scheduler.
1004
1005 The following arguments are only used when balanced is True:
1006
995 1007 after : Dependency or collection of msg_ids
996 1008 Only for load-balanced execution (targets=None)
997 1009 Specify a list of msg_ids as a time-based dependency.
998 1010 This job will only be run *after* the dependencies
999 1011 have been met.
1000
1012
1001 1013 follow : Dependency or collection of msg_ids
1002 1014 Only for load-balanced execution (targets=None)
1003 1015 Specify a list of msg_ids as a location-based dependency.
1004 1016 This job will only be run on an engine where this dependency
1005 1017 is met.
1006
1018
1007 1019 timeout : float/int or None
1008 1020 Only for load-balanced execution (targets=None)
1009 1021 Specify an amount of time (in seconds) for the scheduler to
1010 1022 wait for dependencies to be met before failing with a
1011 1023 DependencyTimeout.
1012 track : bool
1013 whether to track non-copying sends.
1014 [default False]
1015
1016 after,follow,timeout only used if `balanced=True`.
1017 1024
1018 1025 Returns
1019 1026 -------
1020 1027
1021 1028 if block is False:
1022 1029 return AsyncResult wrapping msg_ids
1023 1030 output of AsyncResult.get() is identical to that of `apply(...block=True)`
1024 1031 else:
1025 if single target:
1032 if single target (or balanced):
1026 1033 return result of `f(*args, **kwargs)`
1027 1034 else:
1028 1035 return list of results, matching `targets`
1029 1036 """
1030 1037 assert not self._closed, "cannot use me anymore, I'm closed!"
1031 1038 # defaults:
1032 1039 block = block if block is not None else self.block
1033 1040 args = args if args is not None else []
1034 1041 kwargs = kwargs if kwargs is not None else {}
1035 1042
1036 1043 if not self._ids:
1037 1044 # flush notification socket if no engines yet
1038 1045 any_ids = self.ids
1039 1046 if not any_ids:
1040 1047 raise error.NoEnginesRegistered("Can't execute without any connected engines.")
1041 1048
1042 1049 if balanced is None:
1043 1050 if targets is None:
1044 1051 # default to balanced if targets unspecified
1045 1052 balanced = True
1046 1053 else:
1047 1054 # otherwise default to multiplexing
1048 1055 balanced = False
1049 1056
1050 1057 if targets is None and balanced is False:
1051 1058 # default to all if *not* balanced, and targets is unspecified
1052 1059 targets = 'all'
1053 1060
1054 1061         # enforce types of f, args, kwargs
1055 1062 if not callable(f):
1056 1063 raise TypeError("f must be callable, not %s"%type(f))
1057 1064 if not isinstance(args, (tuple, list)):
1058 1065 raise TypeError("args must be tuple or list, not %s"%type(args))
1059 1066 if not isinstance(kwargs, dict):
1060 1067 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1061 1068
1062 1069 options = dict(bound=bound, block=block, targets=targets, track=track)
1063 1070
1064 1071 if balanced:
1065 1072 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1066 1073 after=after, follow=follow, **options)
1067 1074 elif follow or after or timeout:
1068 1075 msg = "follow, after, and timeout args are only used for"
1069 1076 msg += " load-balanced execution."
1070 1077 raise ValueError(msg)
1071 1078 else:
1072 1079 return self._apply_direct(f, args, kwargs, **options)
1073 1080
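The dispatch above means one call site can reach either code path depending on `targets` and `balanced`; a sketch with an illustrative function:

    def double(x):
        return 2 * x

    ar = client.apply(double, (21,))                             # targets=None: load-balanced, AsyncResult
    r  = client.apply(double, (21,), targets=0, block=True)      # single engine: direct, returns 42
    rs = client.apply(double, (21,), targets='all', block=True)  # list of results, one per engine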
1074 1081 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1075 1082 after=None, follow=None, timeout=None, track=None):
1076 1083 """call f(*args, **kwargs) remotely in a load-balanced manner.
1077 1084
1078 1085 This is a private method, see `apply` for details.
1079 1086 Not to be called directly!
1080 1087 """
1081 1088
1082 1089 loc = locals()
1083 1090 for name in ('bound', 'block', 'track'):
1084 1091 assert loc[name] is not None, "kwarg %r must be specified!"%name
1085 1092
1086 1093 if not self._task_ident:
1087 1094 msg = "Task farming is disabled"
1088 1095 if self._task_scheme == 'pure':
1089 1096 msg += " because the pure ZMQ scheduler cannot handle"
1090 1097 msg += " disappearing engines."
1091 1098 raise RuntimeError(msg)
1092 1099
1093 1100 if self._task_scheme == 'pure':
1094 1101 # pure zmq scheme doesn't support dependencies
1095 1102 msg = "Pure ZMQ scheduler doesn't support dependencies"
1096 1103 if (follow or after):
1097 1104 # hard fail on DAG dependencies
1098 1105 raise RuntimeError(msg)
1099 1106 if isinstance(f, dependent):
1100 1107 # soft warn on functional dependencies
1101 1108 warnings.warn(msg, RuntimeWarning)
1102 1109
1103 1110 # defaults:
1104 1111 args = args if args is not None else []
1105 1112 kwargs = kwargs if kwargs is not None else {}
1106 1113
1107 1114 if targets:
1108 1115 idents,_ = self._build_targets(targets)
1109 1116 else:
1110 1117 idents = []
1111 1118
1112 1119 after = self._build_dependency(after)
1113 1120 follow = self._build_dependency(follow)
1114 1121 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1115 1122 bufs = util.pack_apply_message(f,args,kwargs)
1116 1123 content = dict(bound=bound)
1117 1124
1118 1125 msg = self.session.send(self._apply_socket, "apply_request", ident=self._task_ident,
1119 1126 content=content, buffers=bufs, subheader=subheader, track=track)
1120 1127 msg_id = msg['msg_id']
1121 1128 self.outstanding.add(msg_id)
1122 1129 self.history.append(msg_id)
1123 1130 self.metadata[msg_id]['submitted'] = datetime.now()
1124 1131 tracker = None if track is False else msg['tracker']
1125 1132 ar = AsyncResult(self, [msg_id], fname=f.__name__, targets=targets, tracker=tracker)
1126 1133 if block:
1127 1134 try:
1128 1135 return ar.get()
1129 1136 except KeyboardInterrupt:
1130 1137 return ar
1131 1138 else:
1132 1139 return ar
1133 1140
1134 1141 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None,
1135 1142 track=None):
1136 1143         """The underlying method for applying functions to specific engines
1137 1144 via the MUX queue.
1138 1145
1139 1146 This is a private method, see `apply` for details.
1140 1147 Not to be called directly!
1141 1148 """
1142 1149
1143 1150 if not self._mux_ident:
1144 1151 msg = "Multiplexing is disabled"
1145 1152 raise RuntimeError(msg)
1146 1153
1147 1154 loc = locals()
1148 1155 for name in ('bound', 'block', 'targets', 'track'):
1149 1156 assert loc[name] is not None, "kwarg %r must be specified!"%name
1150 1157
1151 1158 idents,targets = self._build_targets(targets)
1152 1159
1153 1160 subheader = {}
1154 1161 content = dict(bound=bound)
1155 1162 bufs = util.pack_apply_message(f,args,kwargs)
1156 1163
1157 1164 msg_ids = []
1158 1165 trackers = []
1159 1166 for ident in idents:
1160 1167 msg = self.session.send(self._apply_socket, "apply_request",
1161 1168 content=content, buffers=bufs, ident=[self._mux_ident, ident], subheader=subheader,
1162 1169 track=track)
1163 1170 if track:
1164 1171 trackers.append(msg['tracker'])
1165 1172 msg_id = msg['msg_id']
1166 1173 self.outstanding.add(msg_id)
1167 1174 self._outstanding_dict[ident].add(msg_id)
1168 1175 self.history.append(msg_id)
1169 1176 msg_ids.append(msg_id)
1170 1177
1171 1178 tracker = None if track is False else zmq.MessageTracker(*trackers)
1172 1179 ar = AsyncResult(self, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
1173 1180
1174 1181 if block:
1175 1182 try:
1176 1183 return ar.get()
1177 1184 except KeyboardInterrupt:
1178 1185 return ar
1179 1186 else:
1180 1187 return ar
1181 1188
1182 1189 #--------------------------------------------------------------------------
1183 1190 # construct a View object
1184 1191 #--------------------------------------------------------------------------
1185 1192
1186 1193 @defaultblock
1187 1194 def remote(self, bound=False, block=None, targets=None, balanced=None):
1188 1195 """Decorator for making a RemoteFunction"""
1189 1196 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1190 1197
1191 1198 @defaultblock
1192 1199 def parallel(self, dist='b', bound=False, block=None, targets=None, balanced=None):
1193 1200 """Decorator for making a ParallelFunction"""
1194 1201 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1195 1202
1196 1203 def _cache_view(self, targets, balanced):
1197 1204 """save views, so subsequent requests don't create new objects."""
1198 1205 if balanced:
1199 1206 view_class = LoadBalancedView
1200 1207 view_cache = self._balanced_views
1201 1208 else:
1202 1209 view_class = DirectView
1203 1210 view_cache = self._direct_views
1204 1211
1205 1212 # use str, since often targets will be a list
1206 1213 key = str(targets)
1207 1214 if key not in view_cache:
1208 1215 view_cache[key] = view_class(client=self, targets=targets)
1209 1216
1210 1217 return view_cache[key]
1211 1218
1212 1219 def view(self, targets=None, balanced=None):
1213 1220 """Method for constructing View objects.
1214 1221
1215 1222 If no arguments are specified, create a LoadBalancedView
1216 1223 using all engines. If only `targets` specified, it will
1217 1224 be a DirectView. This method is the underlying implementation
1218 1225 of ``client.__getitem__``.
1219 1226
1220 1227 Parameters
1221 1228 ----------
1222 1229
1223 1230 targets: list,slice,int,etc. [default: use all engines]
1224 1231 The engines to use for the View
1225 1232 balanced : bool [default: False if targets specified, True else]
1226 1233 whether to build a LoadBalancedView or a DirectView
1227 1234
1228 1235 """
1229 1236
1230 1237 balanced = (targets is None) if balanced is None else balanced
1231 1238
1232 1239 if targets is None:
1233 1240 if balanced:
1234 1241 return self._cache_view(None,True)
1235 1242 else:
1236 1243 targets = slice(None)
1237 1244
1238 1245 if isinstance(targets, int):
1239 1246 if targets < 0:
1240 1247 targets = self.ids[targets]
1241 1248 if targets not in self.ids:
1242 1249 raise IndexError("No such engine: %i"%targets)
1243 1250 return self._cache_view(targets, balanced)
1244 1251
1245 1252 if isinstance(targets, slice):
1246 1253 indices = range(len(self.ids))[targets]
1247 1254 ids = sorted(self._ids)
1248 1255 targets = [ ids[i] for i in indices ]
1249 1256
1250 1257 if isinstance(targets, (tuple, list, xrange)):
1251 1258 _,targets = self._build_targets(list(targets))
1252 1259 return self._cache_view(targets, balanced)
1253 1260 else:
1254 1261 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1255 1262
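A sketch of the resulting view types (engine ids illustrative):

    dv  = client[0]            # DirectView on engine 0 (__getitem__ delegates here)
    dv2 = client.view([0, 1])  # DirectView on engines 0 and 1
    lbv = client.view()        # LoadBalancedView over all engines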
1256 1263 #--------------------------------------------------------------------------
1257 1264 # Data movement
1258 1265 #--------------------------------------------------------------------------
1259 1266
1260 1267 @defaultblock
1261 1268 def push(self, ns, targets='all', block=None, track=False):
1262 1269 """Push the contents of `ns` into the namespace on `target`"""
1263 1270 if not isinstance(ns, dict):
1264 1271 raise TypeError("Must be a dict, not %s"%type(ns))
1265 1272 result = self.apply(_push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1266 1273 if not block:
1267 1274 return result
1268 1275
1269 1276 @defaultblock
1270 1277 def pull(self, keys, targets='all', block=None):
1271 1278 """Pull objects from `target`'s namespace by `keys`"""
1272 1279 if isinstance(keys, basestring):
1273 1280 pass
1274 1281 elif isinstance(keys, (list,tuple,set)):
1275 1282 for key in keys:
1276 1283 if not isinstance(key, basestring):
1277 1284 raise TypeError("keys must be str, not type %r"%type(key))
1278 1285 else:
1279 1286 raise TypeError("keys must be strs, not %r"%keys)
1280 1287 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1281 1288 return result
1282 1289
1283 1290 @defaultblock
1284 1291 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None, track=False):
1285 1292 """
1286 1293 Partition a Python sequence and send the partitions to a set of engines.
1287 1294 """
1288 1295 targets = self._build_targets(targets)[-1]
1289 1296 mapObject = Map.dists[dist]()
1290 1297 nparts = len(targets)
1291 1298 msg_ids = []
1292 1299 trackers = []
1293 1300 for index, engineid in enumerate(targets):
1294 1301 partition = mapObject.getPartition(seq, index, nparts)
1295 1302 if flatten and len(partition) == 1:
1296 1303 r = self.push({key: partition[0]}, targets=engineid, block=False, track=track)
1297 1304 else:
1298 1305 r = self.push({key: partition}, targets=engineid, block=False, track=track)
1299 1306 msg_ids.extend(r.msg_ids)
1300 1307 if track:
1301 1308 trackers.append(r._tracker)
1302 1309
1303 1310 if track:
1304 1311 tracker = zmq.MessageTracker(*trackers)
1305 1312 else:
1306 1313 tracker = None
1307 1314
1308 1315 r = AsyncResult(self, msg_ids, fname='scatter', targets=targets, tracker=tracker)
1309 1316 if block:
1310 1317 r.wait()
1311 1318 else:
1312 1319 return r
1313 1320
1314 1321 @defaultblock
1315 1322 def gather(self, key, dist='b', targets='all', block=None):
1316 1323 """
1317 1324 Gather a partitioned sequence on a set of engines as a single local seq.
1318 1325 """
1319 1326
1320 1327 targets = self._build_targets(targets)[-1]
1321 1328 mapObject = Map.dists[dist]()
1322 1329 msg_ids = []
1323 1330 for index, engineid in enumerate(targets):
1324 1331 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1325 1332
1326 1333 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1327 1334 if block:
1328 1335 return r.get()
1329 1336 else:
1330 1337 return r
1331 1338
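Together these implement the usual partition/collect round trip; a sketch assuming four registered engines:

    client.scatter('a', range(16), block=True)  # each engine receives a 4-element slice as 'a'
    a = client.gather('a', block=True)          # reassemble the full sequence locally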
1332 1339 #--------------------------------------------------------------------------
1333 1340 # Query methods
1334 1341 #--------------------------------------------------------------------------
1335 1342
1336 1343 @spinfirst
1337 1344 @defaultblock
1338 1345 def get_result(self, indices_or_msg_ids=None, block=None):
1339 1346 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1340 1347
1341 1348 If the client already has the results, no request to the Hub will be made.
1342 1349
1343 1350 This is a convenient way to construct AsyncResult objects, which are wrappers
1344 1351 that include metadata about execution, and allow for awaiting results that
1345 1352 were not submitted by this Client.
1346 1353
1347 1354 It can also be a convenient way to retrieve the metadata associated with
1348 1355 blocking execution, since it always retrieves
1349 1356         blocking execution, since it always retrieves the metadata as well as the result.
1350 1357 Examples
1351 1358 --------
1352 1359 ::
1353 1360
1354 1361 In [10]: r = client.apply()
1355 1362
1356 1363 Parameters
1357 1364 ----------
1358 1365
1359 1366 indices_or_msg_ids : integer history index, str msg_id, or list of either
1360 1367         indices_or_msg_ids : integer history index, str msg_id, or list of either
1361 1368
1362 1369 block : bool
1363 1370 Whether to wait for the result to be done
1364 1371
1365 1372 Returns
1366 1373 -------
1367 1374
1368 1375 AsyncResult
1369 1376 A single AsyncResult object will always be returned.
1370 1377
1371 1378 AsyncHubResult
1372 1379 A subclass of AsyncResult that retrieves results from the Hub
1373 1380
1374 1381 """
1375 1382 if indices_or_msg_ids is None:
1376 1383 indices_or_msg_ids = -1
1377 1384
1378 1385 if not isinstance(indices_or_msg_ids, (list,tuple)):
1379 1386 indices_or_msg_ids = [indices_or_msg_ids]
1380 1387
1381 1388 theids = []
1382 1389 for id in indices_or_msg_ids:
1383 1390 if isinstance(id, int):
1384 1391 id = self.history[id]
1385 1392 if not isinstance(id, str):
1386 1393 raise TypeError("indices must be str or int, not %r"%id)
1387 1394 theids.append(id)
1388 1395
1389 1396 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1390 1397 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1391 1398
1392 1399 if remote_ids:
1393 1400 ar = AsyncHubResult(self, msg_ids=theids)
1394 1401 else:
1395 1402 ar = AsyncResult(self, msg_ids=theids)
1396 1403
1397 1404 if block:
1398 1405 ar.wait()
1399 1406
1400 1407 return ar
1401 1408
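For instance (hypothetical session):

    ar = client.get_result(-1)   # wrap the most recent entry in client.history
    result = ar.get()            # block until done, then return the result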
1402 1409 @spinfirst
1403 1410 def result_status(self, msg_ids, status_only=True):
1404 1411 """Check on the status of the result(s) of the apply request with `msg_ids`.
1405 1412
1406 1413 If status_only is False, then the actual results will be retrieved, else
1407 1414 only the status of the results will be checked.
1408 1415
1409 1416 Parameters
1410 1417 ----------
1411 1418
1412 1419 msg_ids : list of msg_ids
1413 1420 if int:
1414 1421 Passed as index to self.history for convenience.
1415 1422 status_only : bool (default: True)
1416 1423 if False:
1417 1424 Retrieve the actual results of completed tasks.
1418 1425
1419 1426 Returns
1420 1427 -------
1421 1428
1422 1429 results : dict
1423 1430 There will always be the keys 'pending' and 'completed', which will
1424 1431 be lists of msg_ids that are incomplete or complete. If `status_only`
1425 1432 is False, then completed results will be keyed by their `msg_id`.
1426 1433 """
1427 1434 if not isinstance(msg_ids, (list,tuple)):
1428 1435 msg_ids = [msg_ids]
1429 1436
1430 1437 theids = []
1431 1438 for msg_id in msg_ids:
1432 1439 if isinstance(msg_id, int):
1433 1440 msg_id = self.history[msg_id]
1434 1441 if not isinstance(msg_id, basestring):
1435 1442 raise TypeError("msg_ids must be str, not %r"%msg_id)
1436 1443 theids.append(msg_id)
1437 1444
1438 1445 completed = []
1439 1446 local_results = {}
1440 1447
1441 1448 # comment this block out to temporarily disable local shortcut:
1442 1449 for msg_id in theids:
1443 1450 if msg_id in self.results:
1444 1451 completed.append(msg_id)
1445 1452 local_results[msg_id] = self.results[msg_id]
1446 1453 theids.remove(msg_id)
1447 1454
1448 1455 if theids: # some not locally cached
1449 1456 content = dict(msg_ids=theids, status_only=status_only)
1450 1457 msg = self.session.send(self._query_socket, "result_request", content=content)
1451 1458 zmq.select([self._query_socket], [], [])
1452 1459 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1453 1460 if self.debug:
1454 1461 pprint(msg)
1455 1462 content = msg['content']
1456 1463 if content['status'] != 'ok':
1457 1464 raise self._unwrap_exception(content)
1458 1465 buffers = msg['buffers']
1459 1466 else:
1460 1467 content = dict(completed=[],pending=[])
1461 1468
1462 1469 content['completed'].extend(completed)
1463 1470
1464 1471 if status_only:
1465 1472 return content
1466 1473
1467 1474 failures = []
1468 1475 # load cached results into result:
1469 1476 content.update(local_results)
1470 1477 # update cache with results:
1471 1478 for msg_id in sorted(theids):
1472 1479 if msg_id in content['completed']:
1473 1480 rec = content[msg_id]
1474 1481 parent = rec['header']
1475 1482 header = rec['result_header']
1476 1483 rcontent = rec['result_content']
1477 1484 iodict = rec['io']
1478 1485 if isinstance(rcontent, str):
1479 1486 rcontent = self.session.unpack(rcontent)
1480 1487
1481 1488 md = self.metadata[msg_id]
1482 1489 md.update(self._extract_metadata(header, parent, rcontent))
1483 1490 md.update(iodict)
1484 1491
1485 1492 if rcontent['status'] == 'ok':
1486 1493 res,buffers = util.unserialize_object(buffers)
1487 1494 else:
1488 1495 print rcontent
1489 1496 res = self._unwrap_exception(rcontent)
1490 1497 failures.append(res)
1491 1498
1492 1499 self.results[msg_id] = res
1493 1500 content[msg_id] = res
1494 1501
1495 1502 if len(theids) == 1 and failures:
1496 1503 raise failures[0]
1497 1504
1498 1505 error.collect_exceptions(failures, "result_status")
1499 1506 return content
1500 1507
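A minimal usage sketch of result_status, not part of the diff: it assumes a running cluster, and the profile name 'iptest' and client variable `rc` are illustrative only.

    from IPython.zmq.parallel.client import Client

    rc = Client(profile='iptest')                    # assumes a running cluster
    ar = rc.apply(lambda x: 2*x, args=(21,), block=False)
    # status_only=True (the default): just partition our msg_ids
    status = rc.result_status(ar.msg_ids)
    print status['pending'], status['completed']
    # status_only=False: completed results are included, keyed by msg_id
    full = rc.result_status(ar.msg_ids, status_only=False)
    for msg_id in full['completed']:
        print msg_id, full[msg_id]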
1501 1508 @spinfirst
1502 1509 def queue_status(self, targets='all', verbose=False):
1503 1510 """Fetch the status of engine queues.
1504 1511
1505 1512 Parameters
1506 1513 ----------
1507 1514
1508 1515 targets : int/str/list of ints/strs
1509 1516 The engines whose states are to be queried.
1510 1517 default : 'all'
1511 1518 verbose : bool
1512 1519 If False, return queue lengths only; if True, return lists of msg_ids.
1513 1520 """
1514 1521 targets = self._build_targets(targets)[1]
1515 1522 content = dict(targets=targets, verbose=verbose)
1516 1523 self.session.send(self._query_socket, "queue_request", content=content)
1517 1524 idents,msg = self.session.recv(self._query_socket, 0)
1518 1525 if self.debug:
1519 1526 pprint(msg)
1520 1527 content = msg['content']
1521 1528 status = content.pop('status')
1522 1529 if status != 'ok':
1523 1530 raise self._unwrap_exception(content)
1524 1531 return util.rekey(content)
1525 1532
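A corresponding sketch for queue_status, reusing the hypothetical client `rc` from the example above; the per-engine keys depend on the controller, so the comments are indicative only.

    counts = rc.queue_status()              # {engine_id: {<queue name>: length, ...}, ...}
    first = rc.ids[0]
    detail = rc.queue_status(targets=first, verbose=True)  # lists of msg_ids, not lengths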
1526 1533 @spinfirst
1527 1534 def purge_results(self, jobs=[], targets=[]):
1528 1535 """Tell the controller to forget results.
1529 1536
1530 1537 Individual results can be purged by msg_id, or the entire
1531 1538 history of specific targets can be purged.
1532 1539
1533 1540 Parameters
1534 1541 ----------
1535 1542
1536 1543 jobs : str or list of strs or AsyncResult objects
1537 1544 The msg_ids whose results should be forgotten.
1538 1545 targets : int/str/list of ints/strs
1539 1546 The targets, by uuid or int_id, whose entire history is to be purged.
1540 1547 Use `targets='all'` to scrub everything from the controller's memory.
1541 1548
1542 1549 default : [] (nothing is purged unless jobs and/or targets is given)
1543 1550 """
1544 1551 if not targets and not jobs:
1545 1552 raise ValueError("Must specify at least one of `targets` and `jobs`")
1546 1553 if targets:
1547 1554 targets = self._build_targets(targets)[1]
1548 1555
1549 1556 # construct msg_ids from jobs
1550 1557 msg_ids = []
1551 1558 if isinstance(jobs, (basestring,AsyncResult)):
1552 1559 jobs = [jobs]
1553 1560 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1554 1561 if bad_ids:
1555 1562 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1556 1563 for j in jobs:
1557 1564 if isinstance(j, AsyncResult):
1558 1565 msg_ids.extend(j.msg_ids)
1559 1566 else:
1560 1567 msg_ids.append(j)
1561 1568
1562 1569 content = dict(targets=targets, msg_ids=msg_ids)
1563 1570 self.session.send(self._query_socket, "purge_request", content=content)
1564 1571 idents, msg = self.session.recv(self._query_socket, 0)
1565 1572 if self.debug:
1566 1573 pprint(msg)
1567 1574 content = msg['content']
1568 1575 if content['status'] != 'ok':
1569 1576 raise self._unwrap_exception(content)
1570 1577
1571 1578
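And a sketch of purging, again with the hypothetical `rc`, plus a finished AsyncResult `ar`:

    rc.purge_results(jobs=ar)              # forget the results behind these msg_ids
    rc.purge_results(targets=rc.ids[-1])   # drop one engine's entire history
    rc.purge_results(targets='all')        # scrub everything from the controller's memory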
1572 1579 __all__ = [ 'Client',
1573 1580 'depend',
1574 1581 'require',
1575 1582 'remote',
1576 1583 'parallel',
1577 1584 'RemoteFunction',
1578 1585 'ParallelFunction',
1579 1586 'DirectView',
1580 1587 'LoadBalancedView',
1581 1588 'AsyncResult',
1582 1589 'AsyncMapResult',
1583 1590 'Reference'
1584 1591 ]
@@ -1,105 +1,106 b''
1 1 import sys
2 2 import tempfile
3 3 import time
4 4 from signal import SIGINT
5 5 from multiprocessing import Process
6 6
7 7 from nose import SkipTest
8 8
9 9 from zmq.tests import BaseZMQTestCase
10 10
11 11 from IPython.external.decorator import decorator
12 12
13 13 from IPython.zmq.parallel import error
14 14 from IPython.zmq.parallel.client import Client
15 15 from IPython.zmq.parallel.ipcluster import launch_process
16 16 from IPython.zmq.parallel.entry_point import select_random_ports
17 17 from IPython.zmq.parallel.tests import processes,add_engine
18 18
19 19 # simple tasks for use in apply tests
20 20
21 21 def segfault():
22 22 """this will segfault"""
23 23 import ctypes
24 24 ctypes.memset(-1,0,1)
25 25
26 26 def wait(n):
27 27 """sleep for a time"""
28 28 import time
29 29 time.sleep(n)
30 30 return n
31 31
32 32 def raiser(eclass):
33 33 """raise an exception"""
34 34 raise eclass()
35 35
36 36 # test decorator for skipping tests when libraries are unavailable
37 37 def skip_without(*names):
38 38 """skip a test if some names are not importable"""
39 39 @decorator
40 40 def skip_without_names(f, *args, **kwargs):
41 41 """decorator to skip tests in the absence of numpy."""
42 42 for name in names:
43 43 try:
44 44 __import__(name)
45 45 except ImportError:
46 46 raise SkipTest
47 47 return f(*args, **kwargs)
48 48 return skip_without_names
49 49
50 50
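How skip_without reads at a use site, as a small illustrative sketch (the test body is made up):

    @skip_without('numpy')
    def test_with_numpy():
        import numpy
        assert numpy.arange(3).sum() == 3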
51 51 class ClusterTestCase(BaseZMQTestCase):
52 52
53 53 def add_engines(self, n=1, block=True):
54 54 """add multiple engines to our cluster"""
55 55 for i in range(n):
56 56 self.engines.append(add_engine())
57 57 if block:
58 58 self.wait_on_engines()
59 59
60 60 def wait_on_engines(self, timeout=5):
61 61 """wait for our engines to connect."""
62 62 n = len(self.engines)+self.base_engine_count
63 63 tic = time.time()
64 64 while time.time()-tic < timeout and len(self.client.ids) < n:
65 65 time.sleep(0.1)
66 66
67 67 assert len(self.client.ids) >= n, "waiting for engines timed out"
68 68
69 69 def connect_client(self):
70 70 """connect a client with my Context, and track its sockets for cleanup"""
71 71 c = Client(profile='iptest',context=self.context)
72 for name in filter(lambda n:n.endswith('socket'), dir(c)):
73 self.sockets.append(getattr(c, name))
72
73 # for name in filter(lambda n:n.endswith('socket'), dir(c)):
74 # self.sockets.append(getattr(c, name))
74 75 return c
75 76
76 77 def assertRaisesRemote(self, etype, f, *args, **kwargs):
77 78 try:
78 79 try:
79 80 f(*args, **kwargs)
80 81 except error.CompositeError as e:
81 82 e.raise_exception()
82 83 except error.RemoteError as e:
83 84 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
84 85 else:
85 86 self.fail("should have raised a RemoteError")
86 87
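At a use site this helper reads as below (a sketch mirroring calls that appear later in these tests): the remote exception is unwrapped from the CompositeError and its ename is matched against the expected type.

    # expect a NameError raised on the engine, not locally:
    self.assertRaisesRemote(NameError, self.client.pull, 'not_defined')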
87 88 def setUp(self):
88 89 BaseZMQTestCase.setUp(self)
89 90 self.client = self.connect_client()
90 91 self.base_engine_count=len(self.client.ids)
91 92 self.engines=[]
92 93
93 94 def tearDown(self):
94 95
95 96 # drop tracked processes that have already exited:
96 97 for e in filter(lambda e: e.poll() is not None, processes):
97 98 processes.remove(e)
98 99
99 100 self.client.close()
100 101 BaseZMQTestCase.tearDown(self)
101 102 # this will be superfluous when pyzmq merges PR #88
102 103 self.context.term()
103 print tempfile.TemporaryFile().fileno(),
104 sys.stdout.flush()
104 # print tempfile.TemporaryFile().fileno(),
105 # sys.stdout.flush()
105 106 No newline at end of file
@@ -1,262 +1,262 b''
1 1 import time
2 2 from tempfile import mktemp
3 3
4 import nose.tools as nt
5 4 import zmq
6 5
7 6 from IPython.zmq.parallel import client as clientmod
8 7 from IPython.zmq.parallel import error
9 8 from IPython.zmq.parallel.asyncresult import AsyncResult, AsyncHubResult
10 9 from IPython.zmq.parallel.view import LoadBalancedView, DirectView
11 10
12 11 from clienttest import ClusterTestCase, segfault, wait
13 12
14 13 class TestClient(ClusterTestCase):
15 14
16 15 def test_ids(self):
17 16 n = len(self.client.ids)
18 17 self.add_engines(3)
19 18 self.assertEquals(len(self.client.ids), n+3)
20 19
21 20 def test_segfault_task(self):
22 21 """test graceful handling of engine death (balanced)"""
23 22 self.add_engines(1)
24 23 ar = self.client.apply(segfault, block=False)
25 24 self.assertRaisesRemote(error.EngineError, ar.get)
26 25 eid = ar.engine_id
27 26 while eid in self.client.ids:
28 27 time.sleep(.01)
29 28 self.client.spin()
30 29
31 30 def test_segfault_mux(self):
32 31 """test graceful handling of engine death (direct)"""
33 32 self.add_engines(1)
34 33 eid = self.client.ids[-1]
35 34 ar = self.client[eid].apply_async(segfault)
36 35 self.assertRaisesRemote(error.EngineError, ar.get)
37 36 eid = ar.engine_id
38 37 while eid in self.client.ids:
39 38 time.sleep(.01)
40 39 self.client.spin()
41 40
42 41 def test_view_indexing(self):
43 42 """test index access for views"""
44 43 self.add_engines(2)
45 44 targets = self.client._build_targets('all')[-1]
46 45 v = self.client[:]
47 46 self.assertEquals(v.targets, targets)
48 47 t = self.client.ids[2]
49 48 v = self.client[t]
50 49 self.assert_(isinstance(v, DirectView))
51 50 self.assertEquals(v.targets, t)
52 51 t = self.client.ids[2:4]
53 52 v = self.client[t]
54 53 self.assert_(isinstance(v, DirectView))
55 54 self.assertEquals(v.targets, t)
56 55 v = self.client[::2]
57 56 self.assert_(isinstance(v, DirectView))
58 57 self.assertEquals(v.targets, targets[::2])
59 58 v = self.client[1::3]
60 59 self.assert_(isinstance(v, DirectView))
61 60 self.assertEquals(v.targets, targets[1::3])
62 61 v = self.client[:-3]
63 62 self.assert_(isinstance(v, DirectView))
64 63 self.assertEquals(v.targets, targets[:-3])
65 64 v = self.client[-1]
66 65 self.assert_(isinstance(v, DirectView))
67 66 self.assertEquals(v.targets, targets[-1])
68 nt.assert_raises(TypeError, lambda : self.client[None])
67 self.assertRaises(TypeError, lambda : self.client[None])
69 68
70 69 def test_view_cache(self):
71 70 """test that multiple view requests return the same object"""
72 71 v = self.client[:2]
73 72 v2 = self.client[:2]
74 73 self.assertTrue(v is v2)
75 74 v = self.client.view()
76 75 v2 = self.client.view(balanced=True)
77 76 self.assertTrue(v is v2)
78 77
79 78 def test_targets(self):
80 79 """test various valid targets arguments"""
81 80 build = self.client._build_targets
82 81 ids = self.client.ids
83 82 idents,targets = build(None)
84 83 self.assertEquals(ids, targets)
85 84
86 85 def test_clear(self):
87 86 """test clear behavior"""
88 87 self.add_engines(2)
89 88 self.client.block=True
90 89 self.client.push(dict(a=5))
91 90 self.client.pull('a')
92 91 id0 = self.client.ids[-1]
93 92 self.client.clear(targets=id0)
94 93 self.client.pull('a', targets=self.client.ids[:-1])
95 94 self.assertRaisesRemote(NameError, self.client.pull, 'a')
96 95 self.client.clear()
97 96 for i in self.client.ids:
98 97 self.assertRaisesRemote(NameError, self.client.pull, 'a', targets=i)
99 98
100 99
101 100 def test_push_pull(self):
102 101 """test pushing and pulling"""
103 102 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
104 103 t = self.client.ids[-1]
105 104 self.add_engines(2)
106 105 push = self.client.push
107 106 pull = self.client.pull
108 107 self.client.block=True
109 108 nengines = len(self.client)
110 109 push({'data':data}, targets=t)
111 110 d = pull('data', targets=t)
112 111 self.assertEquals(d, data)
113 112 push({'data':data})
114 113 d = pull('data')
115 114 self.assertEquals(d, nengines*[data])
116 115 ar = push({'data':data}, block=False)
117 116 self.assertTrue(isinstance(ar, AsyncResult))
118 117 r = ar.get()
119 118 ar = pull('data', block=False)
120 119 self.assertTrue(isinstance(ar, AsyncResult))
121 120 r = ar.get()
122 121 self.assertEquals(r, nengines*[data])
123 122 push(dict(a=10,b=20))
124 123 r = pull(('a','b'))
125 124 self.assertEquals(r, nengines*[[10,20]])
126 125
127 126 def test_push_pull_function(self):
128 127 "test pushing and pulling functions"
129 128 def testf(x):
130 129 return 2.0*x
131 130
132 131 self.add_engines(4)
133 132 t = self.client.ids[-1]
134 133 self.client.block=True
135 134 push = self.client.push
136 135 pull = self.client.pull
137 136 execute = self.client.execute
138 137 push({'testf':testf}, targets=t)
139 138 r = pull('testf', targets=t)
140 139 self.assertEqual(r(1.0), testf(1.0))
141 140 execute('r = testf(10)', targets=t)
142 141 r = pull('r', targets=t)
143 142 self.assertEquals(r, testf(10))
144 143 ar = push({'testf':testf}, block=False)
145 144 ar.get()
146 145 ar = pull('testf', block=False)
147 146 rlist = ar.get()
148 147 for r in rlist:
149 148 self.assertEqual(r(1.0), testf(1.0))
150 149 execute("def g(x): return x*x", targets=t)
151 150 r = pull(('testf','g'),targets=t)
152 151 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
153 152
154 153 def test_push_function_globals(self):
155 154 """test that pushed functions have access to globals"""
156 155 def geta():
157 156 return a
158 157 self.add_engines(1)
159 158 v = self.client[-1]
160 159 v.block=True
161 160 v['f'] = geta
162 161 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
163 162 v.execute('a=5')
164 163 v.execute('b=f()')
165 164 self.assertEquals(v['b'], 5)
166 165
167 166 def test_push_function_defaults(self):
168 167 """test that pushed functions preserve default args"""
169 168 def echo(a=10):
170 169 return a
171 170 self.add_engines(1)
172 171 v = self.client[-1]
173 172 v.block=True
174 173 v['f'] = echo
175 174 v.execute('b=f()')
176 175 self.assertEquals(v['b'], 10)
177 176
178 177 def test_get_result(self):
179 178 """test getting results from the Hub."""
180 179 c = clientmod.Client(profile='iptest')
181 180 self.add_engines(1)
181 t = c.ids[-1]
182 182 ar = c.apply(wait, (1,), block=False, targets=t)
183 183 # give the monitor time to notice the message
184 184 time.sleep(.25)
185 185 ahr = self.client.get_result(ar.msg_ids)
186 186 self.assertTrue(isinstance(ahr, AsyncHubResult))
187 187 self.assertEquals(ahr.get(), ar.get())
188 188 ar2 = self.client.get_result(ar.msg_ids)
189 189 self.assertFalse(isinstance(ar2, AsyncHubResult))
190 190
191 191 def test_ids_list(self):
192 192 """test client.ids"""
193 193 self.add_engines(2)
194 194 ids = self.client.ids
195 195 self.assertEquals(ids, self.client._ids)
196 196 self.assertFalse(ids is self.client._ids)
197 197 ids.remove(ids[-1])
198 198 self.assertNotEquals(ids, self.client._ids)
199 199
200 200 def test_run_newline(self):
201 201 """test that run appends newline to files"""
202 202 tmpfile = mktemp()
203 203 with open(tmpfile, 'w') as f:
204 204 f.write("""def g():
205 205 return 5
206 206 """)
207 207 v = self.client[-1]
208 208 v.run(tmpfile, block=True)
209 209 self.assertEquals(v.apply_sync(lambda : g()), 5)
210 210
211 211 def test_apply_tracked(self):
212 212 """test tracking for apply"""
213 213 # self.add_engines(1)
214 214 t = self.client.ids[-1]
215 215 self.client.block=False
216 216 def echo(n=1024*1024, **kwargs):
217 217 return self.client.apply(lambda x: x, args=('x'*n,), targets=t, **kwargs)
218 218 ar = echo(1)
219 219 self.assertTrue(ar._tracker is None)
220 220 self.assertTrue(ar.sent)
221 221 ar = echo(track=True)
222 222 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
223 223 self.assertEquals(ar.sent, ar._tracker.done)
224 224 ar._tracker.wait()
225 225 self.assertTrue(ar.sent)
226 226
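The tracking pattern exercised above, distilled into a hedged standalone sketch; it assumes a connected client `rc` with at least one engine, and uses the same private `_tracker` attribute these tests rely on:

    big = 'x' * 1024 * 1024
    ar = rc.apply(lambda x: len(x), args=(big,), targets=rc.ids[-1],
                  block=False, track=True)
    ar._tracker.wait()      # block until 0MQ is done with the message buffer
    assert ar.sent          # only now is it safe to mutate or free `big`
    print ar.get()          # 1048576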
227 227 def test_push_tracked(self):
228 228 t = self.client.ids[-1]
229 229 ns = dict(x='x'*1024*1024)
230 230 ar = self.client.push(ns, targets=t, block=False)
231 231 self.assertTrue(ar._tracker is None)
232 232 self.assertTrue(ar.sent)
233 233
234 234 ar = self.client.push(ns, targets=t, block=False, track=True)
235 235 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
236 236 self.assertEquals(ar.sent, ar._tracker.done)
237 237 ar._tracker.wait()
238 238 self.assertTrue(ar.sent)
239 239 ar.get()
240 240
241 241 def test_scatter_tracked(self):
242 242 t = self.client.ids
243 243 x='x'*1024*1024
244 244 ar = self.client.scatter('x', x, targets=t, block=False)
245 245 self.assertTrue(ar._tracker is None)
246 246 self.assertTrue(ar.sent)
247 247
248 248 ar = self.client.scatter('x', x, targets=t, block=False, track=True)
249 249 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
250 250 self.assertEquals(ar.sent, ar._tracker.done)
251 251 ar._tracker.wait()
252 252 self.assertTrue(ar.sent)
253 253 ar.get()
254 254
255 255 def test_remote_reference(self):
256 256 v = self.client[-1]
257 257 v['a'] = 123
258 258 ra = clientmod.Reference('a')
259 259 b = v.apply_sync(lambda x: x, ra)
260 260 self.assertEquals(b, 123)
261 261
262 262
@@ -1,89 +1,87 b''
1 1 """test serialization with newserialized"""
2 2
3 3 from unittest import TestCase
4 4
5 import nose.tools as nt
6
7 5 from IPython.testing.parametric import parametric
8 6 from IPython.utils import newserialized as ns
9 7 from IPython.utils.pickleutil import can, uncan, CannedObject, CannedFunction
10 8 from IPython.zmq.parallel.tests.clienttest import skip_without
11 9
12 10
13 11 class CanningTestCase(TestCase):
14 12 def test_canning(self):
15 13 d = dict(a=5,b=6)
16 14 cd = can(d)
17 nt.assert_true(isinstance(cd, dict))
15 self.assertTrue(isinstance(cd, dict))
18 16
19 17 def test_canned_function(self):
20 18 f = lambda : 7
21 19 cf = can(f)
22 nt.assert_true(isinstance(cf, CannedFunction))
20 self.assertTrue(isinstance(cf, CannedFunction))
23 21
24 22 @parametric
25 23 def test_can_roundtrip(cls):
26 24 objs = [
27 25 dict(),
28 26 set(),
29 27 list(),
30 28 ['a',1,['a',1],u'e'],
31 29 ]
32 30 return map(cls.run_roundtrip, objs)
33 31
34 32 @classmethod
35 def run_roundtrip(cls, obj):
33 def run_roundtrip(self, obj):
36 34 o = uncan(can(obj))
37 nt.assert_equals(obj, o)
35 assert o == obj, "roundtrip failed: %r != %r"%(o,obj)
38 36
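The canning round-trip these tests exercise, as a short standalone sketch:

    from IPython.utils.pickleutil import can, uncan, CannedFunction

    f = lambda x: x + 1
    cf = can(f)                        # functions are wrapped as CannedFunction
    assert isinstance(cf, CannedFunction)
    g = uncan(cf)                      # restores an equivalent callable
    assert g(1) == 2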
39 37 def test_serialized_interfaces(self):
40 38
41 39 us = {'a':10, 'b':range(10)}
42 40 s = ns.serialize(us)
43 41 uus = ns.unserialize(s)
44 nt.assert_true(isinstance(s, ns.SerializeIt))
45 nt.assert_equals(uus, us)
42 self.assertTrue(isinstance(s, ns.SerializeIt))
43 self.assertEquals(uus, us)
46 44
47 45 def test_pickle_serialized(self):
48 46 obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L}
49 47 original = ns.UnSerialized(obj)
50 48 originalSer = ns.SerializeIt(original)
51 49 firstData = originalSer.getData()
52 50 firstTD = originalSer.getTypeDescriptor()
53 51 firstMD = originalSer.getMetadata()
54 nt.assert_equals(firstTD, 'pickle')
55 nt.assert_equals(firstMD, {})
52 self.assertEquals(firstTD, 'pickle')
53 self.assertEquals(firstMD, {})
56 54 unSerialized = ns.UnSerializeIt(originalSer)
57 55 secondObj = unSerialized.getObject()
58 56 for k, v in secondObj.iteritems():
59 nt.assert_equals(obj[k], v)
57 self.assertEquals(obj[k], v)
60 58 secondSer = ns.SerializeIt(ns.UnSerialized(secondObj))
61 nt.assert_equals(firstData, secondSer.getData())
62 nt.assert_equals(firstTD, secondSer.getTypeDescriptor() )
63 nt.assert_equals(firstMD, secondSer.getMetadata())
59 self.assertEquals(firstData, secondSer.getData())
60 self.assertEquals(firstTD, secondSer.getTypeDescriptor() )
61 self.assertEquals(firstMD, secondSer.getMetadata())
64 62
65 63 @skip_without('numpy')
66 64 def test_ndarray_serialized(self):
67 65 import numpy
68 66 a = numpy.linspace(0.0, 1.0, 1000)
69 67 unSer1 = ns.UnSerialized(a)
70 68 ser1 = ns.SerializeIt(unSer1)
71 69 td = ser1.getTypeDescriptor()
72 nt.assert_equals(td, 'ndarray')
70 self.assertEquals(td, 'ndarray')
73 71 md = ser1.getMetadata()
74 nt.assert_equals(md['shape'], a.shape)
75 nt.assert_equals(md['dtype'], a.dtype.str)
72 self.assertEquals(md['shape'], a.shape)
73 self.assertEquals(md['dtype'], a.dtype.str)
76 74 buff = ser1.getData()
77 nt.assert_equals(buff, numpy.getbuffer(a))
75 self.assertEquals(buff, numpy.getbuffer(a))
78 76 s = ns.Serialized(buff, td, md)
79 77 final = ns.unserialize(s)
80 nt.assert_equals(numpy.getbuffer(a), numpy.getbuffer(final))
81 nt.assert_true((a==final).all())
82 nt.assert_equals(a.dtype.str, final.dtype.str)
83 nt.assert_equals(a.shape, final.shape)
78 self.assertEquals(numpy.getbuffer(a), numpy.getbuffer(final))
79 self.assertTrue((a==final).all())
80 self.assertEquals(a.dtype.str, final.dtype.str)
81 self.assertEquals(a.shape, final.shape)
84 82 # test non-copying:
85 83 a[2] = 1e9
86 nt.assert_true((a==final).all())
84 self.assertTrue((a==final).all())
87 85
88 86
89 87 No newline at end of file