##// END OF EJS Templates
cleanup pass
MinRK -
Show More
@@ -0,0 +1,18 b''
1 """The IPython ZMQ-based parallel computing interface."""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
8
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
12
13 from .asyncresult import *
14 from .client import Client
15 from .dependency import *
16 from .remotefunction import *
17 from .view import *
18
@@ -30,7 +30,7 b' def check_ready(f, self, *args, **kwargs):'
30 30 class AsyncResult(object):
31 31 """Class for representing results of non-blocking calls.
32 32
33 Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
33 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
34 34 """
35 35
36 36 msg_ids = None
@@ -53,7 +53,8 b' class AsyncResult(object):'
53 53
54 54
55 55 def _reconstruct_result(self, res):
56 """
56 """Reconstruct our result from actual result list (always a list)
57
57 58 Override me in subclasses for turning a list of results
58 59 into the expected form.
59 60 """
@@ -68,7 +69,7 b' class AsyncResult(object):'
68 69 If `timeout` is not ``None`` and the result does not arrive within
69 70 `timeout` seconds then ``TimeoutError`` is raised. If the
70 71 remote call raised an exception then that exception will be reraised
71 by get().
72 by get() inside a `RemoteError`.
72 73 """
73 74 if not self.ready():
74 75 self.wait(timeout)
@@ -89,6 +90,8 b' class AsyncResult(object):'
89 90
90 91 def wait(self, timeout=-1):
91 92 """Wait until the result is available or until `timeout` seconds pass.
93
94 This method always returns None.
92 95 """
93 96 if self._ready:
94 97 return
@@ -118,7 +121,7 b' class AsyncResult(object):'
118 121
119 122 Will raise ``AssertionError`` if the result is not ready.
120 123 """
121 assert self._ready
124 assert self.ready()
122 125 return self._success
123 126
124 127 #----------------------------------------------------------------
@@ -126,7 +129,11 b' class AsyncResult(object):'
126 129 #----------------------------------------------------------------
127 130
128 131 def get_dict(self, timeout=-1):
129 """Get the results as a dict, keyed by engine_id."""
132 """Get the results as a dict, keyed by engine_id.
133
134 timeout behavior is described in `get()`.
135 """
136
130 137 results = self.get(timeout)
131 138 engine_ids = [ md['engine_id'] for md in self._metadata ]
132 139 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
@@ -140,7 +147,7 b' class AsyncResult(object):'
140 147 @property
141 148 @check_ready
142 149 def result(self):
143 """result property."""
150 """result property wrapper for `get(timeout=0)`."""
144 151 return self._result
145 152
146 153 # abbreviated alias:
@@ -149,7 +156,7 b' class AsyncResult(object):'
149 156 @property
150 157 @check_ready
151 158 def metadata(self):
152 """metadata property."""
159 """property for accessing execution metadata."""
153 160 if self._single_result:
154 161 return self._metadata[0]
155 162 else:
@@ -186,7 +193,7 b' class AsyncResult(object):'
186 193
187 194 @check_ready
188 195 def __getattr__(self, key):
189 """getattr maps to getitem for convenient access to metadata."""
196 """getattr maps to getitem for convenient attr access to metadata."""
190 197 if key not in self._metadata[0].keys():
191 198 raise AttributeError("%r object has no attribute %r"%(
192 199 self.__class__.__name__, key))
@@ -249,7 +256,11 b' class AsyncMapResult(AsyncResult):'
249 256
250 257
251 258 class AsyncHubResult(AsyncResult):
252 """Class to wrap pending results that must be requested from the Hub"""
259 """Class to wrap pending results that must be requested from the Hub.
260
261 Note that waiting/polling on these objects requires polling the Hubover the network,
262 so use `AsyncHubResult.wait()` sparingly.
263 """
253 264
254 265 def wait(self, timeout=-1):
255 266 """wait for result to complete."""
@@ -32,12 +32,13 b' from IPython.external.ssh import tunnel'
32 32
33 33 from . import error
34 34 from . import map as Map
35 from . import util
35 36 from . import streamsession as ss
36 37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 38 from .clusterdir import ClusterDir, ClusterDirError
38 39 from .dependency import Dependency, depend, require, dependent
39 40 from .remotefunction import remote,parallel,ParallelFunction,RemoteFunction
40 from .util import ReverseDict, disambiguate_url, validate_url
41 from .util import ReverseDict, validate_url, disambiguate_url
41 42 from .view import DirectView, LoadBalancedView
42 43
43 44 #--------------------------------------------------------------------------
@@ -489,7 +490,7 b' class Client(HasTraits):'
489 490
490 491 def _unwrap_exception(self, content):
491 492 """unwrap exception, and remap engineid to int."""
492 e = ss.unwrap_exception(content)
493 e = error.unwrap_exception(content)
493 494 if e.engine_info:
494 495 e_uuid = e.engine_info['engine_uuid']
495 496 eid = self._engines[e_uuid]
@@ -526,11 +527,11 b' class Client(HasTraits):'
526 527 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
527 528
528 529 if 'date' in parent:
529 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
530 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
530 531 if 'started' in header:
531 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
532 md['started'] = datetime.strptime(header['started'], util.ISO8601)
532 533 if 'date' in header:
533 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
534 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
534 535 return md
535 536
536 537 def _handle_execute_reply(self, msg):
@@ -573,7 +574,7 b' class Client(HasTraits):'
573 574
574 575 # construct result:
575 576 if content['status'] == 'ok':
576 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
577 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
577 578 elif content['status'] == 'aborted':
578 579 self.results[msg_id] = error.AbortedTask(msg_id)
579 580 elif content['status'] == 'resubmitted':
@@ -1055,7 +1056,7 b' class Client(HasTraits):'
1055 1056 after = self._build_dependency(after)
1056 1057 follow = self._build_dependency(follow)
1057 1058 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1058 bufs = ss.pack_apply_message(f,args,kwargs)
1059 bufs = util.pack_apply_message(f,args,kwargs)
1059 1060 content = dict(bound=bound)
1060 1061
1061 1062 msg = self.session.send(self._task_socket, "apply_request",
@@ -1087,7 +1088,7 b' class Client(HasTraits):'
1087 1088
1088 1089 subheader = {}
1089 1090 content = dict(bound=bound)
1090 bufs = ss.pack_apply_message(f,args,kwargs)
1091 bufs = util.pack_apply_message(f,args,kwargs)
1091 1092
1092 1093 msg_ids = []
1093 1094 for ident in idents:
@@ -1399,7 +1400,7 b' class Client(HasTraits):'
1399 1400 md.update(iodict)
1400 1401
1401 1402 if rcontent['status'] == 'ok':
1402 res,buffers = ss.unserialize_object(buffers)
1403 res,buffers = util.unserialize_object(buffers)
1403 1404 else:
1404 1405 print rcontent
1405 1406 res = self._unwrap_exception(rcontent)
@@ -1437,7 +1438,7 b' class Client(HasTraits):'
1437 1438 status = content.pop('status')
1438 1439 if status != 'ok':
1439 1440 raise self._unwrap_exception(content)
1440 return ss.rekey(content)
1441 return util.rekey(content)
1441 1442
1442 1443 @spinfirst
1443 1444 def purge_results(self, jobs=[], targets=[]):
@@ -1495,5 +1496,6 b" __all__ = [ 'Client',"
1495 1496 'DirectView',
1496 1497 'LoadBalancedView',
1497 1498 'AsyncResult',
1498 'AsyncMapResult'
1499 'AsyncMapResult',
1500 'Reference'
1499 1501 ]
@@ -22,7 +22,6 b' import logging'
22 22 import re
23 23 import shutil
24 24 import sys
25 import warnings
26 25
27 26 from IPython.config.loader import PyFileConfigLoader
28 27 from IPython.config.configurable import Configurable
@@ -21,7 +21,7 b' import zmq'
21 21 from zmq.devices import ProcessMonitoredQueue
22 22 # internal:
23 23 from IPython.utils.importstring import import_item
24 from IPython.utils.traitlets import Int, Str, Instance, List, Bool
24 from IPython.utils.traitlets import Int, CStr, Instance, List, Bool
25 25
26 26 from .entry_point import signal_children
27 27 from .hub import Hub, HubFactory
@@ -41,7 +41,7 b' class ControllerFactory(HubFactory):'
41 41
42 42 # internal
43 43 children = List()
44 mq_class = Str('zmq.devices.ProcessMonitoredQueue')
44 mq_class = CStr('zmq.devices.ProcessMonitoredQueue')
45 45
46 46 def _usethreads_changed(self, name, old, new):
47 47 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
@@ -7,7 +7,25 b' from .error import UnmetDependency'
7 7
8 8
9 9 class depend(object):
10 """Dependency decorator, for use with tasks."""
10 """Dependency decorator, for use with tasks.
11
12 `@depend` lets you define a function for engine dependencies
13 just like you use `apply` for tasks.
14
15
16 Examples
17 --------
18 ::
19
20 @depend(df, a,b, c=5)
21 def f(m,n,p)
22
23 view.apply(f, 1,2,3)
24
25 will call df(a,b,c=5) on the engine, and if it returns False or
26 raises an UnmetDependency error, then the task will not be run
27 and another engine will be tried.
28 """
11 29 def __init__(self, f, *args, **kwargs):
12 30 self.f = f
13 31 self.args = args
@@ -39,6 +57,7 b' class dependent(object):'
39 57 return self.func_name
40 58
41 59 def _require(*names):
60 """Helper for @require decorator."""
42 61 for name in names:
43 62 try:
44 63 __import__(name)
@@ -47,12 +66,35 b' def _require(*names):'
47 66 return True
48 67
49 68 def require(*names):
69 """Simple decorator for requiring names to be importable.
70
71 Examples
72 --------
73
74 In [1]: @require('numpy')
75 ...: def norm(a):
76 ...: import numpy
77 ...: return numpy.linalg.norm(a,2)
78 """
50 79 return depend(_require, *names)
51 80
52 81 class Dependency(set):
53 82 """An object for representing a set of msg_id dependencies.
54 83
55 Subclassed from set()."""
84 Subclassed from set().
85
86 Parameters
87 ----------
88 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
89 The msg_ids to depend on
90 all : bool [default True]
91 Whether the dependency should be considered met when *all* depending tasks have completed
92 or only when *any* have been completed.
93 success_only : bool [default True]
94 Whether to consider only successes for Dependencies, or consider failures as well.
95 If `all=success_only=True`, then this task will fail with an ImpossibleDependency
96 as soon as the first depended-upon task fails.
97 """
56 98
57 99 all=True
58 100 success_only=True
@@ -45,12 +45,12 b' We support a subset of mongodb operators:'
45 45 from datetime import datetime
46 46
47 47 filters = {
48 '$eq' : lambda a,b: a==b,
49 48 '$lt' : lambda a,b: a < b,
50 49 '$gt' : lambda a,b: b > a,
50 '$eq' : lambda a,b: a == b,
51 '$ne' : lambda a,b: a != b,
51 52 '$lte': lambda a,b: a <= b,
52 53 '$gte': lambda a,b: a >= b,
53 '$ne' : lambda a,b: not a==b,
54 54 '$in' : lambda a,b: a in b,
55 55 '$nin': lambda a,b: a not in b,
56 56 '$all' : lambda a,b: all([ a in bb for bb in b ]),
@@ -1,21 +1,17 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 connected to the Controller's queue(s).
4 connected to the Controller's Schedulers.
5 5 """
6 6 from __future__ import print_function
7 7
8 import logging
9 8 import sys
10 9 import time
11 import uuid
12 from pprint import pprint
13 10
14 11 import zmq
15 12 from zmq.eventloop import ioloop, zmqstream
16 13
17 14 # internal
18 from IPython.config.configurable import Configurable
19 15 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
20 16 # from IPython.utils.localinterfaces import LOCALHOST
21 17
@@ -25,10 +21,6 b' from .streamkernel import Kernel'
25 21 from .streamsession import Message
26 22 from .util import disambiguate_url
27 23
28 def printer(*msg):
29 # print (self.log.handlers, file=sys.__stdout__)
30 self.log.info(str(msg))
31
32 24 class EngineFactory(RegistrationFactory):
33 25 """IPython engine"""
34 26
@@ -3,6 +3,9 b''
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 import sys
7 import traceback
8
6 9 __docformat__ = "restructuredtext en"
7 10
8 11 # Tell nose to skip this module
@@ -290,3 +293,21 b" def collect_exceptions(rdict_or_list, method='unspecified'):"
290 293 except CompositeError as e:
291 294 raise e
292 295
296 def wrap_exception(engine_info={}):
297 etype, evalue, tb = sys.exc_info()
298 stb = traceback.format_exception(etype, evalue, tb)
299 exc_content = {
300 'status' : 'error',
301 'traceback' : stb,
302 'ename' : unicode(etype.__name__),
303 'evalue' : unicode(evalue),
304 'engine_info' : engine_info
305 }
306 return exc_content
307
308 def unwrap_exception(content):
309 err = RemoteError(content['ename'], content['evalue'],
310 ''.join(content['traceback']),
311 content.get('engine_info', {}))
312 return err
313
@@ -31,7 +31,7 b' from IPython.zmq.parallel.entry_point import select_random_ports'
31 31 class LoggingFactory(Configurable):
32 32 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
33 33 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
34 logname = CStr('ZMQ')
34 logname = CUnicode('ZMQ')
35 35 def _logname_changed(self, name, old, new):
36 36 self.log = logging.getLogger(new)
37 37
@@ -44,8 +44,8 b' class SessionFactory(LoggingFactory):'
44 44 ident = CStr('',config=True)
45 45 def _ident_default(self):
46 46 return str(uuid.uuid4())
47 username = Str(os.environ.get('USER','username'),config=True)
48 exec_key = CStr('',config=True)
47 username = CUnicode(os.environ.get('USER','username'),config=True)
48 exec_key = CUnicode('',config=True)
49 49 # not configurable:
50 50 context = Instance('zmq.Context', (), {})
51 51 session = Instance('IPython.zmq.parallel.streamsession.StreamSession')
@@ -15,7 +15,6 b' and monitors traffic through the various queues.'
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 import logging
19 18 import sys
20 19 import time
21 20 from datetime import datetime
@@ -25,16 +24,15 b' from zmq.eventloop import ioloop'
25 24 from zmq.eventloop.zmqstream import ZMQStream
26 25
27 26 # internal:
28 from IPython.config.configurable import Configurable
29 27 from IPython.utils.importstring import import_item
30 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
31 29
32 30 from .entry_point import select_random_ports
33 31 from .factory import RegistrationFactory, LoggingFactory
34 32
33 from . import error
35 34 from .heartmonitor import HeartMonitor
36 from .streamsession import Message, wrap_exception, ISO8601
37 from .util import validate_url_container
35 from .util import validate_url_container, ISO8601
38 36
39 37 try:
40 38 from pymongo.binary import Binary
@@ -491,7 +489,7 b' class Hub(LoggingFactory):'
491 489 try:
492 490 msg = self.session.unpack_message(msg, content=True)
493 491 except:
494 content = wrap_exception()
492 content = error.wrap_exception()
495 493 self.log.error("Bad Client Message: %s"%msg, exc_info=True)
496 494 self.session.send(self.clientele, "hub_error", ident=client_id,
497 495 content=content)
@@ -505,7 +503,7 b' class Hub(LoggingFactory):'
505 503 try:
506 504 assert handler is not None, "Bad Message Type: %s"%msg_type
507 505 except:
508 content = wrap_exception()
506 content = error.wrap_exception()
509 507 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
510 508 self.session.send(self.clientele, "hub_error", ident=client_id,
511 509 content=content)
@@ -802,14 +800,14 b' class Hub(LoggingFactory):'
802 800 try:
803 801 raise KeyError("queue_id %r in use"%queue)
804 802 except:
805 content = wrap_exception()
803 content = error.wrap_exception()
806 804 self.log.error("queue_id %r in use"%queue, exc_info=True)
807 805 elif heart in self.hearts: # need to check unique hearts?
808 806 try:
809 807 raise KeyError("heart_id %r in use"%heart)
810 808 except:
811 809 self.log.error("heart_id %r in use"%heart, exc_info=True)
812 content = wrap_exception()
810 content = error.wrap_exception()
813 811 else:
814 812 for h, pack in self.incoming_registrations.iteritems():
815 813 if heart == h:
@@ -817,14 +815,14 b' class Hub(LoggingFactory):'
817 815 raise KeyError("heart_id %r in use"%heart)
818 816 except:
819 817 self.log.error("heart_id %r in use"%heart, exc_info=True)
820 content = wrap_exception()
818 content = error.wrap_exception()
821 819 break
822 820 elif queue == pack[1]:
823 821 try:
824 822 raise KeyError("queue_id %r in use"%queue)
825 823 except:
826 824 self.log.error("queue_id %r in use"%queue, exc_info=True)
827 content = wrap_exception()
825 content = error.wrap_exception()
828 826 break
829 827
830 828 msg = self.session.send(self.registrar, "registration_reply",
@@ -928,7 +926,7 b' class Hub(LoggingFactory):'
928 926 targets = content['targets']
929 927 targets = self._validate_targets(targets)
930 928 except:
931 content = wrap_exception()
929 content = error.wrap_exception()
932 930 self.session.send(self.clientele, "hub_error",
933 931 content=content, ident=client_id)
934 932 return
@@ -952,7 +950,7 b' class Hub(LoggingFactory):'
952 950 try:
953 951 targets = self._validate_targets(targets)
954 952 except:
955 content = wrap_exception()
953 content = error.wrap_exception()
956 954 self.session.send(self.clientele, "hub_error",
957 955 content=content, ident=client_id)
958 956 return
@@ -987,12 +985,12 b' class Hub(LoggingFactory):'
987 985 try:
988 986 raise IndexError("msg pending: %r"%msg_id)
989 987 except:
990 reply = wrap_exception()
988 reply = error.wrap_exception()
991 989 else:
992 990 try:
993 991 raise IndexError("No such msg: %r"%msg_id)
994 992 except:
995 reply = wrap_exception()
993 reply = error.wrap_exception()
996 994 break
997 995 eids = content.get('engine_ids', [])
998 996 for eid in eids:
@@ -1000,7 +998,7 b' class Hub(LoggingFactory):'
1000 998 try:
1001 999 raise IndexError("No such engine: %i"%eid)
1002 1000 except:
1003 reply = wrap_exception()
1001 reply = error.wrap_exception()
1004 1002 break
1005 1003 msg_ids = self.completed.pop(eid)
1006 1004 uid = self.engines[eid].queue
@@ -1046,7 +1044,7 b' class Hub(LoggingFactory):'
1046 1044 try:
1047 1045 raise KeyError('No such message: '+msg_id)
1048 1046 except:
1049 content = wrap_exception()
1047 content = error.wrap_exception()
1050 1048 break
1051 1049 self.session.send(self.clientele, "result_reply", content=content,
1052 1050 parent=msg, ident=client_id,
@@ -102,7 +102,31 b' class RemoteFunction(object):'
102 102
103 103
104 104 class ParallelFunction(RemoteFunction):
105 """Class for mapping a function to sequences."""
105 """Class for mapping a function to sequences.
106
107 This will distribute the sequences according the a mapper, and call
108 the function on each sub-sequence. If called via map, then the function
109 will be called once on each element, rather that each sub-sequence.
110
111 Parameters
112 ----------
113
114 client : Client instance
115 The client to be used to connect to engines
116 f : callable
117 The function to be wrapped into a remote function
118 bound : bool [default: False]
119 Whether the affect the remote namespace when called
120 block : bool [default: None]
121 Whether to wait for results or not. The default behavior is
122 to use the current `block` attribute of `client`
123 targets : valid target list [default: all]
124 The targets on which to execute.
125 balanced : bool
126 Whether to load-balance with the Task scheduler or not
127 chunk_size : int or None
128 The size of chunk to use when breaking up sequences in a load-balanced manner
129 """
106 130 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
107 131 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
108 132 self.chunk_size = chunk_size
@@ -164,7 +188,11 b' class ParallelFunction(RemoteFunction):'
164 188 return r
165 189
166 190 def map(self, *sequences):
167 """call a function on each element of a sequence remotely."""
191 """call a function on each element of a sequence remotely.
192 This should behave very much like the builtin map, but return an AsyncMapResult
193 if self.block is False.
194 """
195 # set _map as a flag for use inside self.__call__
168 196 self._map = True
169 197 try:
170 198 ret = self.__call__(*sequences)
@@ -172,3 +200,4 b' class ParallelFunction(RemoteFunction):'
172 200 del self._map
173 201 return ret
174 202
203 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction'] No newline at end of file
@@ -31,7 +31,6 b' from IPython.external.decorator import decorator'
31 31 from IPython.utils.traitlets import Instance, Dict, List, Set
32 32
33 33 from . import error
34 from . import streamsession as ss
35 34 from .dependency import Dependency
36 35 from .entry_point import connect_logger, local_logger
37 36 from .factory import SessionFactory
@@ -237,7 +236,7 b' class TaskScheduler(SessionFactory):'
237 236 try:
238 237 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
239 238 except:
240 content = ss.wrap_exception()
239 content = error.wrap_exception()
241 240 msg = self.session.send(self.client_stream, 'apply_reply', content,
242 241 parent=parent, ident=idents)
243 242 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
@@ -340,7 +339,7 b' class TaskScheduler(SessionFactory):'
340 339 try:
341 340 raise why()
342 341 except:
343 content = ss.wrap_exception()
342 content = error.wrap_exception()
344 343
345 344 self.all_done.add(msg_id)
346 345 self.all_failed.add(msg_id)
@@ -9,13 +9,9 b' Kernel adapted from kernel.py to use ZMQ Streams'
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 import __builtin__
13 12
14 import logging
15 import os
16 13 import sys
17 14 import time
18 import traceback
19 15
20 16 from code import CommandCompiler
21 17 from datetime import datetime
@@ -28,16 +24,17 b' from zmq.eventloop import ioloop, zmqstream'
28 24
29 25 # Local imports.
30 26 from IPython.core import ultratb
31 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
27 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Str
32 28 from IPython.zmq.completer import KernelCompleter
33 29 from IPython.zmq.iostream import OutStream
34 30 from IPython.zmq.displayhook import DisplayHook
35 31
36 32 from . import heartmonitor
37 33 from .client import Client
34 from .error import wrap_exception
38 35 from .factory import SessionFactory
39 from .streamsession import StreamSession, Message, extract_header, serialize_object,\
40 unpack_apply_message, ISO8601, wrap_exception
36 from .streamsession import StreamSession
37 from .util import serialize_object, unpack_apply_message, ISO8601
41 38
42 39 def printer(*args):
43 40 pprint(args, stream=sys.__stdout__)
@@ -5,8 +5,6 b''
5 5
6 6 import os
7 7 import pprint
8 import sys
9 import traceback
10 8 import uuid
11 9 from datetime import datetime
12 10
@@ -21,10 +19,7 b' import zmq'
21 19 from zmq.utils import jsonapi
22 20 from zmq.eventloop.zmqstream import ZMQStream
23 21
24 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
25 from IPython.utils.newserialized import serialize, unserialize
26
27 from .error import RemoteError
22 from .util import ISO8601
28 23
29 24 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
30 25 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
@@ -66,26 +61,6 b' else:'
66 61
67 62
68 63 DELIM="<IDS|MSG>"
69 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
70
71 def wrap_exception(engine_info={}):
72 etype, evalue, tb = sys.exc_info()
73 stb = traceback.format_exception(etype, evalue, tb)
74 exc_content = {
75 'status' : 'error',
76 'traceback' : stb,
77 'ename' : unicode(etype.__name__),
78 'evalue' : unicode(evalue),
79 'engine_info' : engine_info
80 }
81 return exc_content
82
83 def unwrap_exception(content):
84 err = RemoteError(content['ename'], content['evalue'],
85 ''.join(content['traceback']),
86 content.get('engine_info', {}))
87 return err
88
89 64
90 65 class Message(object):
91 66 """A simple message object that maps dict keys to attributes.
@@ -140,146 +115,6 b' def extract_header(msg_or_header):'
140 115 h = dict(h)
141 116 return h
142 117
143 def rekey(dikt):
144 """Rekey a dict that has been forced to use str keys where there should be
145 ints by json. This belongs in the jsonutil added by fperez."""
146 for k in dikt.iterkeys():
147 if isinstance(k, str):
148 ik=fk=None
149 try:
150 ik = int(k)
151 except ValueError:
152 try:
153 fk = float(k)
154 except ValueError:
155 continue
156 if ik is not None:
157 nk = ik
158 else:
159 nk = fk
160 if nk in dikt:
161 raise KeyError("already have key %r"%nk)
162 dikt[nk] = dikt.pop(k)
163 return dikt
164
165 def serialize_object(obj, threshold=64e-6):
166 """Serialize an object into a list of sendable buffers.
167
168 Parameters
169 ----------
170
171 obj : object
172 The object to be serialized
173 threshold : float
174 The threshold for not double-pickling the content.
175
176
177 Returns
178 -------
179 ('pmd', [bufs]) :
180 where pmd is the pickled metadata wrapper,
181 bufs is a list of data buffers
182 """
183 databuffers = []
184 if isinstance(obj, (list, tuple)):
185 clist = canSequence(obj)
186 slist = map(serialize, clist)
187 for s in slist:
188 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
189 databuffers.append(s.getData())
190 s.data = None
191 return pickle.dumps(slist,-1), databuffers
192 elif isinstance(obj, dict):
193 sobj = {}
194 for k in sorted(obj.iterkeys()):
195 s = serialize(can(obj[k]))
196 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
197 databuffers.append(s.getData())
198 s.data = None
199 sobj[k] = s
200 return pickle.dumps(sobj,-1),databuffers
201 else:
202 s = serialize(can(obj))
203 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
204 databuffers.append(s.getData())
205 s.data = None
206 return pickle.dumps(s,-1),databuffers
207
208
209 def unserialize_object(bufs):
210 """reconstruct an object serialized by serialize_object from data buffers."""
211 bufs = list(bufs)
212 sobj = pickle.loads(bufs.pop(0))
213 if isinstance(sobj, (list, tuple)):
214 for s in sobj:
215 if s.data is None:
216 s.data = bufs.pop(0)
217 return uncanSequence(map(unserialize, sobj)), bufs
218 elif isinstance(sobj, dict):
219 newobj = {}
220 for k in sorted(sobj.iterkeys()):
221 s = sobj[k]
222 if s.data is None:
223 s.data = bufs.pop(0)
224 newobj[k] = uncan(unserialize(s))
225 return newobj, bufs
226 else:
227 if sobj.data is None:
228 sobj.data = bufs.pop(0)
229 return uncan(unserialize(sobj)), bufs
230
231 def pack_apply_message(f, args, kwargs, threshold=64e-6):
232 """pack up a function, args, and kwargs to be sent over the wire
233 as a series of buffers. Any object whose data is larger than `threshold`
234 will not have their data copied (currently only numpy arrays support zero-copy)"""
235 msg = [pickle.dumps(can(f),-1)]
236 databuffers = [] # for large objects
237 sargs, bufs = serialize_object(args,threshold)
238 msg.append(sargs)
239 databuffers.extend(bufs)
240 skwargs, bufs = serialize_object(kwargs,threshold)
241 msg.append(skwargs)
242 databuffers.extend(bufs)
243 msg.extend(databuffers)
244 return msg
245
246 def unpack_apply_message(bufs, g=None, copy=True):
247 """unpack f,args,kwargs from buffers packed by pack_apply_message()
248 Returns: original f,args,kwargs"""
249 bufs = list(bufs) # allow us to pop
250 assert len(bufs) >= 3, "not enough buffers!"
251 if not copy:
252 for i in range(3):
253 bufs[i] = bufs[i].bytes
254 cf = pickle.loads(bufs.pop(0))
255 sargs = list(pickle.loads(bufs.pop(0)))
256 skwargs = dict(pickle.loads(bufs.pop(0)))
257 # print sargs, skwargs
258 f = uncan(cf, g)
259 for sa in sargs:
260 if sa.data is None:
261 m = bufs.pop(0)
262 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
263 if copy:
264 sa.data = buffer(m)
265 else:
266 sa.data = m.buffer
267 else:
268 if copy:
269 sa.data = m
270 else:
271 sa.data = m.bytes
272
273 args = uncanSequence(map(unserialize, sargs), g)
274 kwargs = {}
275 for k in sorted(skwargs.iterkeys()):
276 sa = skwargs[k]
277 if sa.data is None:
278 sa.data = bufs.pop(0)
279 kwargs[k] = uncan(unserialize(sa), g)
280
281 return f,args,kwargs
282
283 118 class StreamSession(object):
284 119 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
285 120 debug=False
@@ -47,24 +47,24 b' class TestSession(SessionTestCase):'
47 47 self.assertEquals(s.username, 'carrot')
48 48
49 49
50 def test_rekey(self):
51 """rekeying dict around json str keys"""
52 d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 self.assertRaises(KeyError, ss.rekey, d)
54
55 d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 rd = ss.rekey(d)
58 self.assertEquals(d2,rd)
59
60 d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 d2 = {1.5:d['1.5'],1:d['1']}
62 rd = ss.rekey(d)
63 self.assertEquals(d2,rd)
64
65 d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 self.assertRaises(KeyError, ss.rekey, d)
67
50 # def test_rekey(self):
51 # """rekeying dict around json str keys"""
52 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 # self.assertRaises(KeyError, ss.rekey, d)
54 #
55 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 # rd = ss.rekey(d)
58 # self.assertEquals(d2,rd)
59 #
60 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 # d2 = {1.5:d['1.5'],1:d['1']}
62 # rd = ss.rekey(d)
63 # self.assertEquals(d2,rd)
64 #
65 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 # self.assertRaises(KeyError, ss.rekey, d)
67 #
68 68 def test_unique_msg_ids(self):
69 69 """test that messages receive unique ids"""
70 70 ids = set()
@@ -1,7 +1,20 b''
1 """some generic utilities"""
1 """some generic utilities for dealing with classes, urls, and serialization"""
2 2 import re
3 3 import socket
4 4
5 try:
6 import cPickle
7 pickle = cPickle
8 except:
9 cPickle = None
10 import pickle
11
12
13 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
14 from IPython.utils.newserialized import serialize, unserialize
15
16 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
17
5 18 class ReverseDict(dict):
6 19 """simple double-keyed subset of dict methods."""
7 20
@@ -34,7 +47,6 b' class ReverseDict(dict):'
34 47 except KeyError:
35 48 return default
36 49
37
38 50 def validate_url(url):
39 51 """validate a url for zeromq"""
40 52 if not isinstance(url, basestring):
@@ -117,3 +129,143 b' def disambiguate_url(url, location=None):'
117 129 return "%s://%s:%s"%(proto,ip,port)
118 130
119 131
132 def rekey(dikt):
133 """Rekey a dict that has been forced to use str keys where there should be
134 ints by json. This belongs in the jsonutil added by fperez."""
135 for k in dikt.iterkeys():
136 if isinstance(k, str):
137 ik=fk=None
138 try:
139 ik = int(k)
140 except ValueError:
141 try:
142 fk = float(k)
143 except ValueError:
144 continue
145 if ik is not None:
146 nk = ik
147 else:
148 nk = fk
149 if nk in dikt:
150 raise KeyError("already have key %r"%nk)
151 dikt[nk] = dikt.pop(k)
152 return dikt
153
154 def serialize_object(obj, threshold=64e-6):
155 """Serialize an object into a list of sendable buffers.
156
157 Parameters
158 ----------
159
160 obj : object
161 The object to be serialized
162 threshold : float
163 The threshold for not double-pickling the content.
164
165
166 Returns
167 -------
168 ('pmd', [bufs]) :
169 where pmd is the pickled metadata wrapper,
170 bufs is a list of data buffers
171 """
172 databuffers = []
173 if isinstance(obj, (list, tuple)):
174 clist = canSequence(obj)
175 slist = map(serialize, clist)
176 for s in slist:
177 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
178 databuffers.append(s.getData())
179 s.data = None
180 return pickle.dumps(slist,-1), databuffers
181 elif isinstance(obj, dict):
182 sobj = {}
183 for k in sorted(obj.iterkeys()):
184 s = serialize(can(obj[k]))
185 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
186 databuffers.append(s.getData())
187 s.data = None
188 sobj[k] = s
189 return pickle.dumps(sobj,-1),databuffers
190 else:
191 s = serialize(can(obj))
192 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
193 databuffers.append(s.getData())
194 s.data = None
195 return pickle.dumps(s,-1),databuffers
196
197
198 def unserialize_object(bufs):
199 """reconstruct an object serialized by serialize_object from data buffers."""
200 bufs = list(bufs)
201 sobj = pickle.loads(bufs.pop(0))
202 if isinstance(sobj, (list, tuple)):
203 for s in sobj:
204 if s.data is None:
205 s.data = bufs.pop(0)
206 return uncanSequence(map(unserialize, sobj)), bufs
207 elif isinstance(sobj, dict):
208 newobj = {}
209 for k in sorted(sobj.iterkeys()):
210 s = sobj[k]
211 if s.data is None:
212 s.data = bufs.pop(0)
213 newobj[k] = uncan(unserialize(s))
214 return newobj, bufs
215 else:
216 if sobj.data is None:
217 sobj.data = bufs.pop(0)
218 return uncan(unserialize(sobj)), bufs
219
220 def pack_apply_message(f, args, kwargs, threshold=64e-6):
221 """pack up a function, args, and kwargs to be sent over the wire
222 as a series of buffers. Any object whose data is larger than `threshold`
223 will not have their data copied (currently only numpy arrays support zero-copy)"""
224 msg = [pickle.dumps(can(f),-1)]
225 databuffers = [] # for large objects
226 sargs, bufs = serialize_object(args,threshold)
227 msg.append(sargs)
228 databuffers.extend(bufs)
229 skwargs, bufs = serialize_object(kwargs,threshold)
230 msg.append(skwargs)
231 databuffers.extend(bufs)
232 msg.extend(databuffers)
233 return msg
234
235 def unpack_apply_message(bufs, g=None, copy=True):
236 """unpack f,args,kwargs from buffers packed by pack_apply_message()
237 Returns: original f,args,kwargs"""
238 bufs = list(bufs) # allow us to pop
239 assert len(bufs) >= 3, "not enough buffers!"
240 if not copy:
241 for i in range(3):
242 bufs[i] = bufs[i].bytes
243 cf = pickle.loads(bufs.pop(0))
244 sargs = list(pickle.loads(bufs.pop(0)))
245 skwargs = dict(pickle.loads(bufs.pop(0)))
246 # print sargs, skwargs
247 f = uncan(cf, g)
248 for sa in sargs:
249 if sa.data is None:
250 m = bufs.pop(0)
251 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
252 if copy:
253 sa.data = buffer(m)
254 else:
255 sa.data = m.buffer
256 else:
257 if copy:
258 sa.data = m
259 else:
260 sa.data = m.bytes
261
262 args = uncanSequence(map(unserialize, sargs), g)
263 kwargs = {}
264 for k in sorted(skwargs.iterkeys()):
265 sa = skwargs[k]
266 if sa.data is None:
267 sa.data = bufs.pop(0)
268 kwargs[k] = uncan(unserialize(sa), g)
269
270 return f,args,kwargs
271
@@ -1,4 +1,4 b''
1 """Views of remote engines"""
1 """Views of remote engines."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
@@ -11,7 +11,7 b''
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from IPython.testing import decorators as testdec
14 from IPython.utils.traitlets import HasTraits, Bool, List, Dict, Set, Int, Instance
14 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance
15 15
16 16 from IPython.external.decorator import decorator
17 17
@@ -82,7 +82,7 b' class View(HasTraits):'
82 82 _ntargets = Int(1)
83 83 _balanced = Bool(False)
84 84 _default_names = List(['block', 'bound'])
85 _targets = None
85 _targets = Any()
86 86
87 87 def __init__(self, client=None, targets=None):
88 88 super(View, self).__init__(client=client)
@@ -655,3 +655,4 b' class LoadBalancedView(View):'
655 655 chunk_size=chunk_size)
656 656 return pf.map(*sequences)
657 657
658 __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file
General Comments 0
You need to be logged in to leave comments. Login now