support use_cloudpickle in addition to use_dill
James Porter
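For context, a minimal usage sketch of the feature this commit adds, assuming a running `ipcluster` and the PiCloud `cloud` package installed (the view name and the closure are illustrative, not part of the change):

```python
# Hedged sketch: enable cloudpickle-based serialization everywhere.
from IPython.parallel import Client

rc = Client()         # connect to a running ipcluster (assumed)
dv = rc[:]            # DirectView over all engines

dv.use_cloudpickle()  # swap pickle for cloudpickle locally and on engines

# interactively defined functions and closures now ship to the engines:
offset = 10
print(dv.apply_sync(lambda x: x + offset, 32))  # -> [42, 42, ...]
```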
@@ -1,1125 +1,1131 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import imp
20 20 import sys
21 21 import warnings
22 22 from contextlib import contextmanager
23 23 from types import ModuleType
24 24
25 25 import zmq
26 26
27 27 from IPython.testing.skipdoctest import skip_doctest
28 28 from IPython.utils import pickleutil
29 29 from IPython.utils.traitlets import (
30 30 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
31 31 )
32 32 from IPython.external.decorator import decorator
33 33
34 34 from IPython.parallel import util
35 35 from IPython.parallel.controller.dependency import Dependency, dependent
36 36 from IPython.utils.py3compat import string_types, iteritems, PY3
37 37
38 38 from . import map as Map
39 39 from .asyncresult import AsyncResult, AsyncMapResult
40 40 from .remotefunction import ParallelFunction, parallel, remote, getname
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Decorators
44 44 #-----------------------------------------------------------------------------
45 45
46 46 @decorator
47 47 def save_ids(f, self, *args, **kwargs):
48 48 """Keep our history and outstanding attributes up to date after a method call."""
49 49 n_previous = len(self.client.history)
50 50 try:
51 51 ret = f(self, *args, **kwargs)
52 52 finally:
53 53 nmsgs = len(self.client.history) - n_previous
54 54 msg_ids = self.client.history[-nmsgs:]
55 55 self.history.extend(msg_ids)
56 56 self.outstanding.update(msg_ids)
57 57 return ret
58 58
59 59 @decorator
60 60 def sync_results(f, self, *args, **kwargs):
61 61 """sync relevant results from self.client to our results attribute."""
62 62 if self._in_sync_results:
63 63 return f(self, *args, **kwargs)
64 64 self._in_sync_results = True
65 65 try:
66 66 ret = f(self, *args, **kwargs)
67 67 finally:
68 68 self._in_sync_results = False
69 69 self._sync_results()
70 70 return ret
71 71
72 72 @decorator
73 73 def spin_after(f, self, *args, **kwargs):
74 74 """call spin after the method."""
75 75 ret = f(self, *args, **kwargs)
76 76 self.spin()
77 77 return ret
78 78
79 79 #-----------------------------------------------------------------------------
80 80 # Classes
81 81 #-----------------------------------------------------------------------------
82 82
83 83 @skip_doctest
84 84 class View(HasTraits):
85 85 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
86 86
87 87 Don't use this class, use subclasses.
88 88
89 89 Methods
90 90 -------
91 91
92 92 spin
93 93 flushes incoming results and registration state changes
94 94 control methods also spin, and requesting `ids` ensures the state is up to date
95 95
96 96 wait
97 97 wait on one or more msg_ids
98 98
99 99 execution methods
100 100 apply
101 101 legacy: execute, run
102 102
103 103 data movement
104 104 push, pull, scatter, gather
105 105
106 106 query methods
107 107 get_result, queue_status, purge_results, result_status
108 108
109 109 control methods
110 110 abort, shutdown
111 111
112 112 """
113 113 # flags
114 114 block=Bool(False)
115 115 track=Bool(True)
116 116 targets = Any()
117 117
118 118 history=List()
119 119 outstanding = Set()
120 120 results = Dict()
121 121 client = Instance('IPython.parallel.Client')
122 122
123 123 _socket = Instance('zmq.Socket')
124 124 _flag_names = List(['targets', 'block', 'track'])
125 125 _in_sync_results = Bool(False)
126 126 _targets = Any()
127 127 _idents = Any()
128 128
129 129 def __init__(self, client=None, socket=None, **flags):
130 130 super(View, self).__init__(client=client, _socket=socket)
131 131 self.results = client.results
132 132 self.block = client.block
133 133
134 134 self.set_flags(**flags)
135 135
136 136 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
137 137
138 138 def __repr__(self):
139 139 strtargets = str(self.targets)
140 140 if len(strtargets) > 16:
141 141 strtargets = strtargets[:12]+'...]'
142 142 return "<%s %s>"%(self.__class__.__name__, strtargets)
143 143
144 144 def __len__(self):
145 145 if isinstance(self.targets, list):
146 146 return len(self.targets)
147 147 elif isinstance(self.targets, int):
148 148 return 1
149 149 else:
150 150 return len(self.client)
151 151
152 152 def set_flags(self, **kwargs):
153 153 """set my attribute flags by keyword.
154 154
155 155 Views determine behavior with a few attributes (`block`, `track`, etc.).
156 156 These attributes can be set all at once by name with this method.
157 157
158 158 Parameters
159 159 ----------
160 160
161 161 block : bool
162 162 whether to wait for results
163 163 track : bool
164 164 whether to create a MessageTracker to allow the user to
165 165 safely edit arrays and buffers after non-copying
166 166 sends.
167 167 """
168 168 for name, value in iteritems(kwargs):
169 169 if name not in self._flag_names:
170 170 raise KeyError("Invalid name: %r"%name)
171 171 else:
172 172 setattr(self, name, value)
173 173
174 174 @contextmanager
175 175 def temp_flags(self, **kwargs):
176 176 """temporarily set flags, for use in `with` statements.
177 177
178 178 See set_flags for permanent setting of flags
179 179
180 180 Examples
181 181 --------
182 182
183 183 >>> view.track=False
184 184 ...
185 185 >>> with view.temp_flags(track=True):
186 186 ... ar = view.apply(dostuff, my_big_array)
187 187 ... ar.tracker.wait() # wait for send to finish
188 188 >>> view.track
189 189 False
190 190
191 191 """
192 192 # preflight: save flags, and set temporaries
193 193 saved_flags = {}
194 194 for f in self._flag_names:
195 195 saved_flags[f] = getattr(self, f)
196 196 self.set_flags(**kwargs)
197 197 # yield to the with-statement block
198 198 try:
199 199 yield
200 200 finally:
201 201 # postflight: restore saved flags
202 202 self.set_flags(**saved_flags)
203 203
204 204
205 205 #----------------------------------------------------------------
206 206 # apply
207 207 #----------------------------------------------------------------
208 208
209 209 def _sync_results(self):
210 210 """to be called by @sync_results decorator
211 211
212 212 after submitting any tasks.
213 213 """
214 214 delta = self.outstanding.difference(self.client.outstanding)
215 215 completed = self.outstanding.intersection(delta)
216 216 self.outstanding = self.outstanding.difference(completed)
217 217
218 218 @sync_results
219 219 @save_ids
220 220 def _really_apply(self, f, args, kwargs, block=None, **options):
221 221 """wrapper for client.send_apply_request"""
222 222 raise NotImplementedError("Implement in subclasses")
223 223
224 224 def apply(self, f, *args, **kwargs):
225 225 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
226 226
227 227 This method sets all apply flags via this View's attributes.
228 228
229 229 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
230 230 instance if ``self.block`` is False, otherwise the return value of
231 231 ``f(*args, **kwargs)``.
232 232 """
233 233 return self._really_apply(f, args, kwargs)
234 234
235 235 def apply_async(self, f, *args, **kwargs):
236 236 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
237 237
238 238 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
239 239 """
240 240 return self._really_apply(f, args, kwargs, block=False)
241 241
242 242 @spin_after
243 243 def apply_sync(self, f, *args, **kwargs):
244 244 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
245 245 returning the result.
246 246 """
247 247 return self._really_apply(f, args, kwargs, block=True)
248 248
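To make the three flavors above concrete, a small hedged sketch (the view `dv` and the function are assumptions):

```python
# Sketch of the apply family, assuming `dv` is an existing DirectView.
def slow_square(x):
    import time
    time.sleep(1)
    return x * x

r = dv.apply(slow_square, 4)         # obeys dv.block: AsyncResult or value
ar = dv.apply_async(slow_square, 4)  # always returns an AsyncResult
print(ar.get())                      # block until the result arrives
print(dv.apply_sync(slow_square, 4)) # always blocks, returns the value(s)
```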
249 249 #----------------------------------------------------------------
250 250 # wrappers for client and control methods
251 251 #----------------------------------------------------------------
252 252 @sync_results
253 253 def spin(self):
254 254 """spin the client, and sync"""
255 255 self.client.spin()
256 256
257 257 @sync_results
258 258 def wait(self, jobs=None, timeout=-1):
259 259 """waits on one or more `jobs`, for up to `timeout` seconds.
260 260
261 261 Parameters
262 262 ----------
263 263
264 264 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
265 265 ints are indices to self.history
266 266 strs are msg_ids
267 267 default: wait on all outstanding messages
268 268 timeout : float
269 269 a time in seconds, after which to give up.
270 270 default is -1, which means no timeout
271 271
272 272 Returns
273 273 -------
274 274
275 275 True : when all msg_ids are done
276 276 False : timeout reached, some msg_ids still outstanding
277 277 """
278 278 if jobs is None:
279 279 jobs = self.history
280 280 return self.client.wait(jobs, timeout)
281 281
282 282 def abort(self, jobs=None, targets=None, block=None):
283 283 """Abort jobs on my engines.
284 284
285 285 Parameters
286 286 ----------
287 287
288 288 jobs : None, str, list of strs, optional
289 289 if None: abort all jobs.
290 290 else: abort specific msg_id(s).
291 291 """
292 292 block = block if block is not None else self.block
293 293 targets = targets if targets is not None else self.targets
294 294 jobs = jobs if jobs is not None else list(self.outstanding)
295 295
296 296 return self.client.abort(jobs=jobs, targets=targets, block=block)
297 297
298 298 def queue_status(self, targets=None, verbose=False):
299 299 """Fetch the Queue status of my engines"""
300 300 targets = targets if targets is not None else self.targets
301 301 return self.client.queue_status(targets=targets, verbose=verbose)
302 302
303 303 def purge_results(self, jobs=[], targets=[]):
304 304 """Instruct the controller to forget specific results."""
305 305 if targets is None or targets == 'all':
306 306 targets = self.targets
307 307 return self.client.purge_results(jobs=jobs, targets=targets)
308 308
309 309 def shutdown(self, targets=None, restart=False, hub=False, block=None):
310 310 """Terminates one or more engine processes, optionally including the hub.
311 311 """
312 312 block = self.block if block is None else block
313 313 if targets is None or targets == 'all':
314 314 targets = self.targets
315 315 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
316 316
317 317 @spin_after
318 318 def get_result(self, indices_or_msg_ids=None):
319 319 """return one or more results, specified by history index or msg_id.
320 320
321 321 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
322 322 """
323 323
324 324 if indices_or_msg_ids is None:
325 325 indices_or_msg_ids = -1
326 326 if isinstance(indices_or_msg_ids, int):
327 327 indices_or_msg_ids = self.history[indices_or_msg_ids]
328 328 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
329 329 indices_or_msg_ids = list(indices_or_msg_ids)
330 330 for i,index in enumerate(indices_or_msg_ids):
331 331 if isinstance(index, int):
332 332 indices_or_msg_ids[i] = self.history[index]
333 333 return self.client.get_result(indices_or_msg_ids)
334 334
335 335 #-------------------------------------------------------------------
336 336 # Map
337 337 #-------------------------------------------------------------------
338 338
339 339 @sync_results
340 340 def map(self, f, *sequences, **kwargs):
341 341 """override in subclasses"""
342 342 raise NotImplementedError
343 343
344 344 def map_async(self, f, *sequences, **kwargs):
345 345 """Parallel version of builtin :func:`python:map`, using this view's engines.
346 346
347 347 This is equivalent to ``map(..., block=False)``.
348 348
349 349 See `self.map` for details.
350 350 """
351 351 if 'block' in kwargs:
352 352 raise TypeError("map_async doesn't take a `block` keyword argument.")
353 353 kwargs['block'] = False
354 354 return self.map(f,*sequences,**kwargs)
355 355
356 356 def map_sync(self, f, *sequences, **kwargs):
357 357 """Parallel version of builtin :func:`python:map`, using this view's engines.
358 358
359 359 This is equivalent to ``map(..., block=True)``.
360 360
361 361 See `self.map` for details.
362 362 """
363 363 if 'block' in kwargs:
364 364 raise TypeError("map_sync doesn't take a `block` keyword argument.")
365 365 kwargs['block'] = True
366 366 return self.map(f,*sequences,**kwargs)
367 367
368 368 def imap(self, f, *sequences, **kwargs):
369 369 """Parallel version of :func:`itertools.imap`.
370 370
371 371 See `self.map` for details.
372 372
373 373 """
374 374
375 375 return iter(self.map_async(f,*sequences, **kwargs))
376 376
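A hedged sketch of the map family on an assumed view `dv`, mirroring the builtin:

```python
# map_sync blocks and returns a plain list; map_async and imap do not.
squares = dv.map_sync(lambda x: x ** 2, range(8))   # [0, 1, 4, ...]

amr = dv.map_async(lambda x: x ** 2, range(8))      # AsyncMapResult
for value in amr:                                    # iterate as chunks land
    print(value)

it = dv.imap(lambda x: x ** 2, range(8))            # iterator over results
```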
377 377 #-------------------------------------------------------------------
378 378 # Decorators
379 379 #-------------------------------------------------------------------
380 380
381 381 def remote(self, block=None, **flags):
382 382 """Decorator for making a RemoteFunction"""
383 383 block = self.block if block is None else block
384 384 return remote(self, block=block, **flags)
385 385
386 386 def parallel(self, dist='b', block=None, **flags):
387 387 """Decorator for making a ParallelFunction"""
388 388 block = self.block if block is None else block
389 389 return parallel(self, dist=dist, block=block, **flags)
390 390
391 391 @skip_doctest
392 392 class DirectView(View):
393 393 """Direct Multiplexer View of one or more engines.
394 394
395 395 These are created via indexed access to a client:
396 396
397 397 >>> dv_1 = client[1]
398 398 >>> dv_all = client[:]
399 399 >>> dv_even = client[::2]
400 400 >>> dv_some = client[1:3]
401 401
402 402 This object provides dictionary access to engine namespaces:
403 403
404 404 # push a=5:
405 405 >>> dv['a'] = 5
406 406 # pull 'foo':
407 407 >>> dv['foo']
408 408
409 409 """
410 410
411 411 def __init__(self, client=None, socket=None, targets=None):
412 412 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
413 413
414 414 @property
415 415 def importer(self):
416 416 """sync_imports(local=True) as a property.
417 417
418 418 See sync_imports for details.
419 419
420 420 """
421 421 return self.sync_imports(True)
422 422
423 423 @contextmanager
424 424 def sync_imports(self, local=True, quiet=False):
425 425 """Context Manager for performing simultaneous local and remote imports.
426 426
427 427 'import x as y' will *not* work. The 'as y' part will simply be ignored.
428 428
429 429 If `local=True`, then the package will also be imported locally.
430 430
431 431 If `quiet=True`, no output will be produced when attempting remote
432 432 imports.
433 433
434 434 Note that remote-only (`local=False`) imports have not been implemented.
435 435
436 436 >>> with view.sync_imports():
437 437 ... from numpy import recarray
438 438 importing recarray from numpy on engine(s)
439 439
440 440 """
441 441 from IPython.utils.py3compat import builtin_mod
442 442 local_import = builtin_mod.__import__
443 443 modules = set()
444 444 results = []
445 445 @util.interactive
446 446 def remote_import(name, fromlist, level):
447 447 """the function to be passed to apply, that actually performs the import
448 448 on the engine, and loads up the user namespace.
449 449 """
450 450 import sys
451 451 user_ns = globals()
452 452 mod = __import__(name, fromlist=fromlist, level=level)
453 453 if fromlist:
454 454 for key in fromlist:
455 455 user_ns[key] = getattr(mod, key)
456 456 else:
457 457 user_ns[name] = sys.modules[name]
458 458
459 459 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
460 460 """the drop-in replacement for __import__, that optionally imports
461 461 locally as well.
462 462 """
463 463 # don't override nested imports
464 464 save_import = builtin_mod.__import__
465 465 builtin_mod.__import__ = local_import
466 466
467 467 if imp.lock_held():
468 468 # this is a side-effect import, don't do it remotely, or even
469 469 # ignore the local effects
470 470 return local_import(name, globals, locals, fromlist, level)
471 471
472 472 imp.acquire_lock()
473 473 if local:
474 474 mod = local_import(name, globals, locals, fromlist, level)
475 475 else:
476 476 raise NotImplementedError("remote-only imports not yet implemented")
477 477 imp.release_lock()
478 478
479 479 key = name+':'+','.join(fromlist or [])
480 480 if level <= 0 and key not in modules:
481 481 modules.add(key)
482 482 if not quiet:
483 483 if fromlist:
484 484 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
485 485 else:
486 486 print("importing %s on engine(s)"%name)
487 487 results.append(self.apply_async(remote_import, name, fromlist, level))
488 488 # restore override
489 489 builtin_mod.__import__ = save_import
490 490
491 491 return mod
492 492
493 493 # override __import__
494 494 builtin_mod.__import__ = view_import
495 495 try:
496 496 # enter the block
497 497 yield
498 498 except ImportError:
499 499 if local:
500 500 raise
501 501 else:
502 502 # ignore import errors if not doing local imports
503 503 pass
504 504 finally:
505 505 # always restore __import__
506 506 builtin_mod.__import__ = local_import
507 507
508 508 for r in results:
509 509 # raise possible remote ImportErrors here
510 510 r.get()
511 511
512 512 def use_dill(self):
513 513 """Expand serialization support with dill
514 514
515 515 adds support for closures, etc.
516 516
517 517 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
518 518 """
519 519 pickleutil.use_dill()
520 520 return self.apply(pickleutil.use_dill)
521 521
522 def use_cloudpickle(self):
523 """Expand serialization support with cloudpickle.
524 """
525 pickleutil.use_cloudpickle()
526 return self.apply(pickleutil.use_cloudpickle)
527
522 528
523 529 @sync_results
524 530 @save_ids
525 531 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
526 532 """calls f(*args, **kwargs) on remote engines, returning the result.
527 533
528 534 This method sets all of `apply`'s flags via this View's attributes.
529 535
530 536 Parameters
531 537 ----------
532 538
533 539 f : callable
534 540
535 541 args : list [default: empty]
536 542
537 543 kwargs : dict [default: empty]
538 544
539 545 targets : target list [default: self.targets]
540 546 where to run
541 547 block : bool [default: self.block]
542 548 whether to block
543 549 track : bool [default: self.track]
544 550 whether to ask zmq to track the message, for safe non-copying sends
545 551
546 552 Returns
547 553 -------
548 554
549 555 if self.block is False:
550 556 returns AsyncResult
551 557 else:
552 558 returns actual result of f(*args, **kwargs) on the engine(s)
553 559 This will be a list if self.targets is also a list (even length 1), or
554 560 the single result if self.targets is an integer engine id
555 561 """
556 562 args = [] if args is None else args
557 563 kwargs = {} if kwargs is None else kwargs
558 564 block = self.block if block is None else block
559 565 track = self.track if track is None else track
560 566 targets = self.targets if targets is None else targets
561 567
562 568 _idents, _targets = self.client._build_targets(targets)
563 569 msg_ids = []
564 570 trackers = []
565 571 for ident in _idents:
566 572 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
567 573 ident=ident)
568 574 if track:
569 575 trackers.append(msg['tracker'])
570 576 msg_ids.append(msg['header']['msg_id'])
571 577 if isinstance(targets, int):
572 578 msg_ids = msg_ids[0]
573 579 tracker = None if track is False else zmq.MessageTracker(*trackers)
574 580 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, tracker=tracker)
575 581 if block:
576 582 try:
577 583 return ar.get()
578 584 except KeyboardInterrupt:
579 585 pass
580 586 return ar
581 587
582 588
583 589 @sync_results
584 590 def map(self, f, *sequences, **kwargs):
585 591 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
586 592
587 593 Parallel version of builtin `map`, using this View's `targets`.
588 594
589 595 There will be one task per target, so work will be chunked
590 596 if the sequences are longer than the number of `targets`.
591 597
592 598 Results can be iterated as they are ready, but will become available in chunks.
593 599
594 600 Parameters
595 601 ----------
596 602
597 603 f : callable
598 604 function to be mapped
599 605 *sequences: one or more sequences of matching length
600 606 the sequences to be distributed and passed to `f`
601 607 block : bool
602 608 whether to wait for the result or not [default self.block]
603 609
604 610 Returns
605 611 -------
606 612
607 613
608 614 If block=False
609 615 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
610 616 An object like AsyncResult, but which reassembles the sequence of results
611 617 into a single list. AsyncMapResults can be iterated through before all
612 618 results are complete.
613 619 else
614 620 A list, the result of ``map(f,*sequences)``
615 621 """
616 622
617 623 block = kwargs.pop('block', self.block)
618 624 for k in kwargs.keys():
619 625 if k not in ['block', 'track']:
620 626 raise TypeError("invalid keyword arg, %r"%k)
621 627
622 628 assert len(sequences) > 0, "must have some sequences to map onto!"
623 629 pf = ParallelFunction(self, f, block=block, **kwargs)
624 630 return pf.map(*sequences)
625 631
626 632 @sync_results
627 633 @save_ids
628 634 def execute(self, code, silent=True, targets=None, block=None):
629 635 """Executes `code` on `targets` in blocking or nonblocking manner.
630 636
631 637 ``execute`` is always `bound` (affects engine namespace)
632 638
633 639 Parameters
634 640 ----------
635 641
636 642 code : str
637 643 the code string to be executed
638 644 block : bool
639 645 whether or not to wait until done to return
640 646 default: self.block
641 647 """
642 648 block = self.block if block is None else block
643 649 targets = self.targets if targets is None else targets
644 650
645 651 _idents, _targets = self.client._build_targets(targets)
646 652 msg_ids = []
647 653 trackers = []
648 654 for ident in _idents:
649 655 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
650 656 msg_ids.append(msg['header']['msg_id'])
651 657 if isinstance(targets, int):
652 658 msg_ids = msg_ids[0]
653 659 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets)
654 660 if block:
655 661 try:
656 662 ar.get()
657 663 except KeyboardInterrupt:
658 664 pass
659 665 return ar
660 666
661 667 def run(self, filename, targets=None, block=None):
662 668 """Execute contents of `filename` on my engine(s).
663 669
664 670 This simply reads the contents of the file and calls `execute`.
665 671
666 672 Parameters
667 673 ----------
668 674
669 675 filename : str
670 676 The path to the file
671 677 targets : int/str/list of ints/strs
672 678 the engines on which to execute
673 679 default : all
674 680 block : bool
675 681 whether or not to wait until done
676 682 default: self.block
677 683
678 684 """
679 685 with open(filename, 'r') as f:
680 686 # add newline in case of trailing indented whitespace
681 687 # which will cause SyntaxError
682 688 code = f.read()+'\n'
683 689 return self.execute(code, block=block, targets=targets)
684 690
685 691 def update(self, ns):
686 692 """update remote namespace with dict `ns`
687 693
688 694 See `push` for details.
689 695 """
690 696 return self.push(ns, block=self.block, track=self.track)
691 697
692 698 def push(self, ns, targets=None, block=None, track=None):
693 699 """update remote namespace with dict `ns`
694 700
695 701 Parameters
696 702 ----------
697 703
698 704 ns : dict
699 705 dict of keys with which to update engine namespace(s)
700 706 block : bool [default : self.block]
701 707 whether to wait to be notified of engine receipt
702 708
703 709 """
704 710
705 711 block = block if block is not None else self.block
706 712 track = track if track is not None else self.track
707 713 targets = targets if targets is not None else self.targets
708 714 # applier = self.apply_sync if block else self.apply_async
709 715 if not isinstance(ns, dict):
710 716 raise TypeError("Must be a dict, not %s"%type(ns))
711 717 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
712 718
713 719 def get(self, key_s):
714 720 """get object(s) by `key_s` from remote namespace
715 721
716 722 see `pull` for details.
717 723 """
718 724 # block = block if block is not None else self.block
719 725 return self.pull(key_s, block=True)
720 726
721 727 def pull(self, names, targets=None, block=None):
722 728 """get object(s) by `name` from remote namespace
723 729
724 730 will return one object if it is a key.
725 731 can also take a list of keys, in which case it will return a list of objects.
726 732 """
727 733 block = block if block is not None else self.block
728 734 targets = targets if targets is not None else self.targets
729 735 applier = self.apply_sync if block else self.apply_async
730 736 if isinstance(names, string_types):
731 737 pass
732 738 elif isinstance(names, (list,tuple,set)):
733 739 for key in names:
734 740 if not isinstance(key, string_types):
735 741 raise TypeError("keys must be str, not type %r"%type(key))
736 742 else:
737 743 raise TypeError("names must be strs, not %r"%names)
738 744 return self._really_apply(util._pull, (names,), block=block, targets=targets)
739 745
740 746 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
741 747 """
742 748 Partition a Python sequence and send the partitions to a set of engines.
743 749 """
744 750 block = block if block is not None else self.block
745 751 track = track if track is not None else self.track
746 752 targets = targets if targets is not None else self.targets
747 753
748 754 # construct integer ID list:
749 755 targets = self.client._build_targets(targets)[1]
750 756
751 757 mapObject = Map.dists[dist]()
752 758 nparts = len(targets)
753 759 msg_ids = []
754 760 trackers = []
755 761 for index, engineid in enumerate(targets):
756 762 partition = mapObject.getPartition(seq, index, nparts)
757 763 if flatten and len(partition) == 1:
758 764 ns = {key: partition[0]}
759 765 else:
760 766 ns = {key: partition}
761 767 r = self.push(ns, block=False, track=track, targets=engineid)
762 768 msg_ids.extend(r.msg_ids)
763 769 if track:
764 770 trackers.append(r._tracker)
765 771
766 772 if track:
767 773 tracker = zmq.MessageTracker(*trackers)
768 774 else:
769 775 tracker = None
770 776
771 777 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
772 778 if block:
773 779 r.wait()
774 780 else:
775 781 return r
776 782
777 783 @sync_results
778 784 @save_ids
779 785 def gather(self, key, dist='b', targets=None, block=None):
780 786 """
781 787 Gather a partitioned sequence on a set of engines as a single local seq.
782 788 """
783 789 block = block if block is not None else self.block
784 790 targets = targets if targets is not None else self.targets
785 791 mapObject = Map.dists[dist]()
786 792 msg_ids = []
787 793
788 794 # construct integer ID list:
789 795 targets = self.client._build_targets(targets)[1]
790 796
791 797 for index, engineid in enumerate(targets):
792 798 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
793 799
794 800 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
795 801
796 802 if block:
797 803 try:
798 804 return r.get()
799 805 except KeyboardInterrupt:
800 806 pass
801 807 return r
802 808
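A brief hedged sketch of round-tripping a sequence with scatter/gather (the names `chunk`/`partial` and the view `dv` are illustrative):

```python
# Partition a list across engines, transform the pieces remotely,
# then reassemble them into one flat local list.
dv.scatter('chunk', list(range(16)), block=True)       # one slice per engine
dv.execute('partial = [x * 2 for x in chunk]', block=True)
doubled = dv.gather('partial', block=True)             # [0, 2, 4, ..., 30]
```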
803 809 def __getitem__(self, key):
804 810 return self.get(key)
805 811
806 812 def __setitem__(self,key, value):
807 813 self.update({key:value})
808 814
809 815 def clear(self, targets=None, block=None):
810 816 """Clear the remote namespaces on my engines."""
811 817 block = block if block is not None else self.block
812 818 targets = targets if targets is not None else self.targets
813 819 return self.client.clear(targets=targets, block=block)
814 820
815 821 #----------------------------------------
816 822 # activate for %px, %autopx, etc. magics
817 823 #----------------------------------------
818 824
819 825 def activate(self, suffix=''):
820 826 """Activate IPython magics associated with this View
821 827
822 828 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
823 829
824 830 Parameters
825 831 ----------
826 832
827 833 suffix: str [default: '']
828 834 The suffix, if any, for the magics. This allows you to have
829 835 multiple views associated with parallel magics at the same time.
830 836
831 837 e.g. ``rc[::2].activate(suffix='_even')`` will give you
832 838 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
833 839 on the even engines.
834 840 """
835 841
836 842 from IPython.parallel.client.magics import ParallelMagics
837 843
838 844 try:
839 845 # This is injected into __builtins__.
840 846 ip = get_ipython()
841 847 except NameError:
842 848 print("The IPython parallel magics (%px, etc.) only work within IPython.")
843 849 return
844 850
845 851 M = ParallelMagics(ip, self, suffix)
846 852 ip.magics_manager.register(M)
847 853
848 854
849 855 @skip_doctest
850 856 class LoadBalancedView(View):
851 857 """An load-balancing View that only executes via the Task scheduler.
852 858
853 859 Load-balanced views can be created with the client's `view` method:
854 860
855 861 >>> v = client.load_balanced_view()
856 862
857 863 or targets can be specified, to restrict the potential destinations:
858 864
859 865 >>> v = client.load_balanced_view([1,3])
860 866
861 867 which would restrict load-balancing to engines 1 and 3.
862 868
863 869 """
864 870
865 871 follow=Any()
866 872 after=Any()
867 873 timeout=CFloat()
868 874 retries = Integer(0)
869 875
870 876 _task_scheme = Any()
871 877 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
872 878
873 879 def __init__(self, client=None, socket=None, **flags):
874 880 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
875 881 self._task_scheme=client._task_scheme
876 882
877 883 def _validate_dependency(self, dep):
878 884 """validate a dependency.
879 885
880 886 For use in `set_flags`.
881 887 """
882 888 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
883 889 return True
884 890 elif isinstance(dep, (list,set, tuple)):
885 891 for d in dep:
886 892 if not isinstance(d, string_types + (AsyncResult,)):
887 893 return False
888 894 elif isinstance(dep, dict):
889 895 if set(dep.keys()) != set(Dependency().as_dict().keys()):
890 896 return False
891 897 if not isinstance(dep['msg_ids'], list):
892 898 return False
893 899 for d in dep['msg_ids']:
894 900 if not isinstance(d, string_types):
895 901 return False
896 902 else:
897 903 return False
898 904
899 905 return True
900 906
901 907 def _render_dependency(self, dep):
902 908 """helper for building jsonable dependencies from various input forms."""
903 909 if isinstance(dep, Dependency):
904 910 return dep.as_dict()
905 911 elif isinstance(dep, AsyncResult):
906 912 return dep.msg_ids
907 913 elif dep is None:
908 914 return []
909 915 else:
910 916 # pass to Dependency constructor
911 917 return list(Dependency(dep))
912 918
913 919 def set_flags(self, **kwargs):
914 920 """set my attribute flags by keyword.
915 921
916 922 A View is a wrapper for the Client's apply method, but with attributes
917 923 that specify keyword arguments, those attributes can be set by keyword
918 924 argument with this method.
919 925
920 926 Parameters
921 927 ----------
922 928
923 929 block : bool
924 930 whether to wait for results
925 931 track : bool
926 932 whether to create a MessageTracker to allow the user to
927 933 safely edit arrays and buffers after non-copying
928 934 sends.
929 935
930 936 after : Dependency or collection of msg_ids
931 937 Only for load-balanced execution (targets=None)
932 938 Specify a list of msg_ids as a time-based dependency.
933 939 This job will only be run *after* the dependencies
934 940 have been met.
935 941
936 942 follow : Dependency or collection of msg_ids
937 943 Only for load-balanced execution (targets=None)
938 944 Specify a list of msg_ids as a location-based dependency.
939 945 This job will only be run on an engine where this dependency
940 946 is met.
941 947
942 948 timeout : float/int or None
943 949 Only for load-balanced execution (targets=None)
944 950 Specify an amount of time (in seconds) for the scheduler to
945 951 wait for dependencies to be met before failing with a
946 952 DependencyTimeout.
947 953
948 954 retries : int
949 955 Number of times a task will be retried on failure.
950 956 """
951 957
952 958 super(LoadBalancedView, self).set_flags(**kwargs)
953 959 for name in ('follow', 'after'):
954 960 if name in kwargs:
955 961 value = kwargs[name]
956 962 if self._validate_dependency(value):
957 963 setattr(self, name, value)
958 964 else:
959 965 raise ValueError("Invalid dependency: %r"%value)
960 966 if 'timeout' in kwargs:
961 967 t = kwargs['timeout']
962 968 if not isinstance(t, (int, float, type(None))):
963 969 if (not PY3) and (not isinstance(t, long)):
964 970 raise TypeError("Invalid type for timeout: %r"%type(t))
965 971 if t is not None:
966 972 if t < 0:
967 973 raise ValueError("Invalid timeout: %s"%t)
968 974 self.timeout = t
969 975
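As a hedged illustration of these flags on a load-balanced view `lv` (the AsyncResults `ar1`/`ar2` stand in for earlier submissions):

```python
# Run a task only after two earlier tasks finish, with retries and a
# dependency timeout; then a one-off location dependency via temp_flags.
lv.set_flags(after=[ar1, ar2], retries=3, timeout=60)
ar3 = lv.apply_async(lambda: 'runs after ar1 and ar2')

with lv.temp_flags(follow=ar1):
    ar4 = lv.apply_async(lambda: 'runs on the engine that ran ar1')
```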
970 976 @sync_results
971 977 @save_ids
972 978 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
973 979 after=None, follow=None, timeout=None,
974 980 targets=None, retries=None):
975 981 """calls f(*args, **kwargs) on a remote engine, returning the result.
976 982
977 983 This method temporarily sets all of `apply`'s flags for a single call.
978 984
979 985 Parameters
980 986 ----------
981 987
982 988 f : callable
983 989
984 990 args : list [default: empty]
985 991
986 992 kwargs : dict [default: empty]
987 993
988 994 block : bool [default: self.block]
989 995 whether to block
990 996 track : bool [default: self.track]
991 997 whether to ask zmq to track the message, for safe non-copying sends
992 998
993 999 after, follow, timeout, targets, retries : see set_flags; these override the view's defaults for this call
994 1000
995 1001 Returns
996 1002 -------
997 1003
998 1004 if self.block is False:
999 1005 returns AsyncResult
1000 1006 else:
1001 1007 returns actual result of f(*args, **kwargs) on the engine(s)
1002 1008 This will be a list if self.targets is also a list (even length 1), or
1003 1009 the single result if self.targets is an integer engine id
1004 1010 """
1005 1011
1006 1012 # validate whether we can run
1007 1013 if self._socket.closed:
1008 1014 msg = "Task farming is disabled"
1009 1015 if self._task_scheme == 'pure':
1010 1016 msg += " because the pure ZMQ scheduler cannot handle"
1011 1017 msg += " disappearing engines."
1012 1018 raise RuntimeError(msg)
1013 1019
1014 1020 if self._task_scheme == 'pure':
1015 1021 # pure zmq scheme doesn't support extra features
1016 1022 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1017 1023 "follow, after, retries, targets, timeout"
1018 1024 if (follow or after or retries or targets or timeout):
1019 1025 # hard fail on Scheduler flags
1020 1026 raise RuntimeError(msg)
1021 1027 if isinstance(f, dependent):
1022 1028 # soft warn on functional dependencies
1023 1029 warnings.warn(msg, RuntimeWarning)
1024 1030
1025 1031 # build args
1026 1032 args = [] if args is None else args
1027 1033 kwargs = {} if kwargs is None else kwargs
1028 1034 block = self.block if block is None else block
1029 1035 track = self.track if track is None else track
1030 1036 after = self.after if after is None else after
1031 1037 retries = self.retries if retries is None else retries
1032 1038 follow = self.follow if follow is None else follow
1033 1039 timeout = self.timeout if timeout is None else timeout
1034 1040 targets = self.targets if targets is None else targets
1035 1041
1036 1042 if not isinstance(retries, int):
1037 1043 raise TypeError('retries must be int, not %r'%type(retries))
1038 1044
1039 1045 if targets is None:
1040 1046 idents = []
1041 1047 else:
1042 1048 idents = self.client._build_targets(targets)[0]
1043 1049 # ensure *not* bytes
1044 1050 idents = [ ident.decode() for ident in idents ]
1045 1051
1046 1052 after = self._render_dependency(after)
1047 1053 follow = self._render_dependency(follow)
1048 1054 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1049 1055
1050 1056 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1051 1057 metadata=metadata)
1052 1058 tracker = None if track is False else msg['tracker']
1053 1059
1054 1060 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1055 1061
1056 1062 if block:
1057 1063 try:
1058 1064 return ar.get()
1059 1065 except KeyboardInterrupt:
1060 1066 pass
1061 1067 return ar
1062 1068
1063 1069 @sync_results
1064 1070 @save_ids
1065 1071 def map(self, f, *sequences, **kwargs):
1066 1072 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067 1073
1068 1074 Parallel version of builtin `map`, load-balanced by this View.
1069 1075
1070 1076 `block`, `chunksize`, and `ordered` can be specified by keyword only.
1071 1077
1072 1078 Each `chunksize` elements will be a separate task, and will be
1073 1079 load-balanced. This lets individual elements be available for iteration
1074 1080 as soon as they arrive.
1075 1081
1076 1082 Parameters
1077 1083 ----------
1078 1084
1079 1085 f : callable
1080 1086 function to be mapped
1081 1087 *sequences: one or more sequences of matching length
1082 1088 the sequences to be distributed and passed to `f`
1083 1089 block : bool [default self.block]
1084 1090 whether to wait for the result or not
1085 1091 track : bool
1086 1092 whether to create a MessageTracker to allow the user to
1087 1093 safely edit arrays and buffers after non-copying
1088 1094 sends.
1089 1095 chunksize : int [default 1]
1090 1096 how many elements should be in each task.
1091 1097 ordered : bool [default True]
1092 1098 Whether the results should be gathered as they arrive, or enforce
1093 1099 the order of submission.
1094 1100
1095 1101 Only applies when iterating through AsyncMapResult as results arrive.
1096 1102 Has no effect when block=True.
1097 1103
1098 1104 Returns
1099 1105 -------
1100 1106
1101 1107 if block=False
1102 1108 An :class:`~IPython.parallel.client.asyncresult.AsyncMapResult` instance.
1103 1109 An object like AsyncResult, but which reassembles the sequence of results
1104 1110 into a single list. AsyncMapResults can be iterated through before all
1105 1111 results are complete.
1106 1112 else
1107 1113 A list, the result of ``map(f,*sequences)``
1108 1114 """
1109 1115
1110 1116 # default
1111 1117 block = kwargs.get('block', self.block)
1112 1118 chunksize = kwargs.get('chunksize', 1)
1113 1119 ordered = kwargs.get('ordered', True)
1114 1120
1115 1121 keyset = set(kwargs.keys())
1116 1122 extra_keys = keyset.difference(set(['block', 'chunksize', 'ordered']))
1117 1123 if extra_keys:
1118 1124 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119 1125
1120 1126 assert len(sequences) > 0, "must have some sequences to map onto!"
1121 1127
1122 1128 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 1129 return pf.map(*sequences)
1124 1130
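A hedged sketch of chunked, unordered load-balanced mapping (`lv` assumed): 100 elements become 25 four-element tasks, consumable as they complete:

```python
amr = lv.map(lambda x: x * x, range(100),
             chunksize=4, ordered=False, block=False)
for result in amr:   # results stream back in completion order
    print(result)
```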
1125 1131 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,390 +1,410 b''
1 1 # encoding: utf-8
2 2
3 3 """Pickle related utilities. Perhaps this should be called 'can'."""
4 4
5 5 __docformat__ = "restructuredtext en"
6 6
7 7 #-------------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-------------------------------------------------------------------------------
13 13
14 14 #-------------------------------------------------------------------------------
15 15 # Imports
16 16 #-------------------------------------------------------------------------------
17 17
18 18 import copy
19 19 import logging
20 20 import sys
21 21 from types import FunctionType
22 22
23 23 try:
24 24 import cPickle as pickle
25 25 except ImportError:
26 26 import pickle
27 27
28 28 from . import codeutil # This registers a hook when it's imported
29 29 from . import py3compat
30 30 from .importstring import import_item
31 31 from .py3compat import string_types, iteritems
32 32
33 33 from IPython.config import Application
34 34
35 35 if py3compat.PY3:
36 36 buffer = memoryview
37 37 class_type = type
38 38 else:
39 39 from types import ClassType
40 40 class_type = (type, ClassType)
41 41
42 42 #-------------------------------------------------------------------------------
43 43 # Functions
44 44 #-------------------------------------------------------------------------------
45 45
46 46
47 47 def use_dill():
48 48 """use dill to expand serialization support
49 49
50 50 adds support for object methods and closures to serialization.
51 51 """
52 52 # import dill causes most of the magic
53 53 import dill
54 54
55 55 # dill doesn't work with cPickle,
56 56 # tell the two relevant modules to use plain pickle
57 57
58 58 global pickle
59 59 pickle = dill
60 60
61 61 try:
62 62 from IPython.kernel.zmq import serialize
63 63 except ImportError:
64 64 pass
65 65 else:
66 66 serialize.pickle = dill
67 67
68 68 # disable special function handling, let dill take care of it
69 69 can_map.pop(FunctionType, None)
70 70
71 def use_cloudpickle():
72 """use cloudpickle to expand serialization support
73
74 adds support for object methods and closures to serialization.
75 """
76 from cloud.serialization import cloudpickle
77
78 global pickle
79 pickle = cloudpickle
80
81 try:
82 from IPython.kernel.zmq import serialize
83 except ImportError:
84 pass
85 else:
86 serialize.pickle = cloudpickle
87
88 # disable special function handling, let cloudpickle take care of it
89 can_map.pop(FunctionType, None)
90
71 91
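A hedged sketch of what the module-level swap buys (requires the PiCloud `cloud` package; plain pickle would reject the closure below):

```python
from IPython.utils import pickleutil
import pickle as std_pickle

pickleutil.use_cloudpickle()      # rebinds pickleutil.pickle to cloudpickle

def make_adder(n):
    return lambda x: x + n        # a real closure

blob = pickleutil.pickle.dumps(make_adder(5))
assert std_pickle.loads(blob)(1) == 6  # output is ordinary pickle data
```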
72 92 #-------------------------------------------------------------------------------
73 93 # Classes
74 94 #-------------------------------------------------------------------------------
75 95
76 96
77 97 class CannedObject(object):
78 98 def __init__(self, obj, keys=[], hook=None):
79 99 """can an object for safe pickling
80 100
81 101 Parameters
82 102 ==========
83 103
84 104 obj:
85 105 The object to be canned
86 106 keys: list (optional)
87 107 list of attribute names that will be explicitly canned / uncanned
88 108 hook: callable (optional)
89 109 An optional extra callable,
90 110 which can do additional processing of the uncanned object.
91 111
92 112 large data may be offloaded into the buffers list,
93 113 used for zero-copy transfers.
94 114 """
95 115 self.keys = keys
96 116 self.obj = copy.copy(obj)
97 117 self.hook = can(hook)
98 118 for key in keys:
99 119 setattr(self.obj, key, can(getattr(obj, key)))
100 120
101 121 self.buffers = []
102 122
103 123 def get_object(self, g=None):
104 124 if g is None:
105 125 g = {}
106 126 obj = self.obj
107 127 for key in self.keys:
108 128 setattr(obj, key, uncan(getattr(obj, key), g))
109 129
110 130 if self.hook:
111 131 self.hook = uncan(self.hook, g)
112 132 self.hook(obj, g)
113 133 return self.obj
114 134
115 135
116 136 class Reference(CannedObject):
117 137 """object for wrapping a remote reference by name."""
118 138 def __init__(self, name):
119 139 if not isinstance(name, string_types):
120 140 raise TypeError("illegal name: %r"%name)
121 141 self.name = name
122 142 self.buffers = []
123 143
124 144 def __repr__(self):
125 145 return "<Reference: %r>"%self.name
126 146
127 147 def get_object(self, g=None):
128 148 if g is None:
129 149 g = {}
130 150
131 151 return eval(self.name, g)
132 152
133 153
134 154 class CannedFunction(CannedObject):
135 155
136 156 def __init__(self, f):
137 157 self._check_type(f)
138 158 self.code = f.__code__
139 159 if f.__defaults__:
140 160 self.defaults = [ can(fd) for fd in f.__defaults__ ]
141 161 else:
142 162 self.defaults = None
143 163 self.module = f.__module__ or '__main__'
144 164 self.__name__ = f.__name__
145 165 self.buffers = []
146 166
147 167 def _check_type(self, obj):
148 168 assert isinstance(obj, FunctionType), "Not a function type"
149 169
150 170 def get_object(self, g=None):
151 171 # try to load function back into its module:
152 172 if not self.module.startswith('__'):
153 173 __import__(self.module)
154 174 g = sys.modules[self.module].__dict__
155 175
156 176 if g is None:
157 177 g = {}
158 178 if self.defaults:
159 179 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
160 180 else:
161 181 defaults = None
162 182 newFunc = FunctionType(self.code, g, self.__name__, defaults)
163 183 return newFunc
164 184
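A small hedged round-trip of CannedFunction through can/uncan (run from `__main__`; the function name is illustrative):

```python
from IPython.utils.pickleutil import can, uncan
import pickle

def double(x):
    return 2 * x

canned = can(double)             # FunctionType -> CannedFunction via can_map
wire = pickle.dumps(canned, -1)  # code objects pickle via the codeutil hook
restored = uncan(pickle.loads(wire))
assert restored(21) == 42
```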
165 185 class CannedClass(CannedObject):
166 186
167 187 def __init__(self, cls):
168 188 self._check_type(cls)
169 189 self.name = cls.__name__
170 190 self.old_style = not isinstance(cls, type)
171 191 self._canned_dict = {}
172 192 for k,v in cls.__dict__.items():
173 193 if k not in ('__weakref__', '__dict__'):
174 194 self._canned_dict[k] = can(v)
175 195 if self.old_style:
176 196 mro = []
177 197 else:
178 198 mro = cls.mro()
179 199
180 200 self.parents = [ can(c) for c in mro[1:] ]
181 201 self.buffers = []
182 202
183 203 def _check_type(self, obj):
184 204 assert isinstance(obj, class_type), "Not a class type"
185 205
186 206 def get_object(self, g=None):
187 207 parents = tuple(uncan(p, g) for p in self.parents)
188 208 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
189 209
190 210 class CannedArray(CannedObject):
191 211 def __init__(self, obj):
192 212 from numpy import ascontiguousarray
193 213 self.shape = obj.shape
194 214 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
195 215 self.pickled = False
196 216 if sum(obj.shape) == 0:
197 217 self.pickled = True
198 218 elif obj.dtype == 'O':
199 219 # can't handle object dtype with buffer approach
200 220 self.pickled = True
201 221 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
202 222 self.pickled = True
203 223 if self.pickled:
204 224 # just pickle it
205 225 self.buffers = [pickle.dumps(obj, -1)]
206 226 else:
207 227 # ensure contiguous
208 228 obj = ascontiguousarray(obj, dtype=None)
209 229 self.buffers = [buffer(obj)]
210 230
211 231 def get_object(self, g=None):
212 232 from numpy import frombuffer
213 233 data = self.buffers[0]
214 234 if self.pickled:
215 235 # no shape, we just pickled it
216 236 return pickle.loads(data)
217 237 else:
218 238 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
219 239
220 240
221 241 class CannedBytes(CannedObject):
222 242 wrap = bytes
223 243 def __init__(self, obj):
224 244 self.buffers = [obj]
225 245
226 246 def get_object(self, g=None):
227 247 data = self.buffers[0]
228 248 return self.wrap(data)
229 249
230 250 class CannedBuffer(CannedBytes):
231 251 wrap = buffer
232 252
233 253 #-------------------------------------------------------------------------------
234 254 # Functions
235 255 #-------------------------------------------------------------------------------
236 256
237 257 def _logger():
238 258 """get the logger for the current Application
239 259
240 260 the root logger will be used if no Application is running
241 261 """
242 262 if Application.initialized():
243 263 logger = Application.instance().log
244 264 else:
245 265 logger = logging.getLogger()
246 266 if not logger.handlers:
247 267 logging.basicConfig()
248 268
249 269 return logger
250 270
251 271 def _import_mapping(mapping, original=None):
252 272 """import any string-keys in a type mapping
253 273
254 274 """
255 275 log = _logger()
256 276 log.debug("Importing canning map")
257 277 for key,value in list(mapping.items()):
258 278 if isinstance(key, string_types):
259 279 try:
260 280 cls = import_item(key)
261 281 except Exception:
262 282 if original and key not in original:
263 283 # only message on user-added classes
264 284 log.error("canning class not importable: %r", key, exc_info=True)
265 285 mapping.pop(key)
266 286 else:
267 287 mapping[cls] = mapping.pop(key)
268 288
269 289 def istype(obj, check):
270 290 """like isinstance(obj, check), but strict
271 291
272 292 This won't catch subclasses.
273 293 """
274 294 if isinstance(check, tuple):
275 295 for cls in check:
276 296 if type(obj) is cls:
277 297 return True
278 298 return False
279 299 else:
280 300 return type(obj) is check
281 301
282 302 def can(obj):
283 303 """prepare an object for pickling"""
284 304
285 305 import_needed = False
286 306
287 307 for cls,canner in iteritems(can_map):
288 308 if isinstance(cls, string_types):
289 309 import_needed = True
290 310 break
291 311 elif istype(obj, cls):
292 312 return canner(obj)
293 313
294 314 if import_needed:
295 315 # perform can_map imports, then try again
296 316 # this will usually only happen once
297 317 _import_mapping(can_map, _original_can_map)
298 318 return can(obj)
299 319
300 320 return obj
301 321
302 322 def can_class(obj):
303 323 if isinstance(obj, class_type) and obj.__module__ == '__main__':
304 324 return CannedClass(obj)
305 325 else:
306 326 return obj
307 327
308 328 def can_dict(obj):
309 329 """can the *values* of a dict"""
310 330 if istype(obj, dict):
311 331 newobj = {}
312 332 for k, v in iteritems(obj):
313 333 newobj[k] = can(v)
314 334 return newobj
315 335 else:
316 336 return obj
317 337
318 338 sequence_types = (list, tuple, set)
319 339
320 340 def can_sequence(obj):
321 341 """can the elements of a sequence"""
322 342 if istype(obj, sequence_types):
323 343 t = type(obj)
324 344 return t([can(i) for i in obj])
325 345 else:
326 346 return obj
327 347
328 348 def uncan(obj, g=None):
329 349 """invert canning"""
330 350
331 351 import_needed = False
332 352 for cls,uncanner in iteritems(uncan_map):
333 353 if isinstance(cls, string_types):
334 354 import_needed = True
335 355 break
336 356 elif isinstance(obj, cls):
337 357 return uncanner(obj, g)
338 358
339 359 if import_needed:
340 360 # perform uncan_map imports, then try again
341 361 # this will usually only happen once
342 362 _import_mapping(uncan_map, _original_uncan_map)
343 363 return uncan(obj, g)
344 364
345 365 return obj
346 366
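Containers are canned element-wise; a hedged example of the dict helpers (values are illustrative):

```python
from IPython.utils.pickleutil import can_dict, uncan_dict

ns = {'n': 3, 'f': lambda x: x + 1}  # mixed plain and cannable values
canned = can_dict(ns)                # 'f' becomes a CannedFunction
restored = uncan_dict(canned)
assert restored['f'](restored['n']) == 4
```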
347 367 def uncan_dict(obj, g=None):
348 368 if istype(obj, dict):
349 369 newobj = {}
350 370 for k, v in iteritems(obj):
351 371 newobj[k] = uncan(v,g)
352 372 return newobj
353 373 else:
354 374 return obj
355 375
356 376 def uncan_sequence(obj, g=None):
357 377 if istype(obj, sequence_types):
358 378 t = type(obj)
359 379 return t([uncan(i,g) for i in obj])
360 380 else:
361 381 return obj
362 382
363 383 def _uncan_dependent_hook(dep, g=None):
364 384 dep.check_dependency()
365 385
366 386 def can_dependent(obj):
367 387 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
368 388
369 389 #-------------------------------------------------------------------------------
370 390 # API dictionaries
371 391 #-------------------------------------------------------------------------------
372 392
373 393 # These dicts can be extended for custom serialization of new objects
374 394
375 395 can_map = {
376 396 'IPython.parallel.dependent' : can_dependent,
377 397 'numpy.ndarray' : CannedArray,
378 398 FunctionType : CannedFunction,
379 399 bytes : CannedBytes,
380 400 buffer : CannedBuffer,
381 401 class_type : can_class,
382 402 }
383 403
384 404 uncan_map = {
385 405 CannedObject : lambda obj, g: obj.get_object(g),
386 406 }
387 407
388 408 # for use in _import_mapping:
389 409 _original_can_map = can_map.copy()
390 410 _original_uncan_map = uncan_map.copy()