@@ -1,425 +1,426 @@
 # encoding: utf-8
 """Pickle related utilities. Perhaps this should be called 'can'."""

 # Copyright (c) IPython Development Team.
 # Distributed under the terms of the Modified BSD License.

 import copy
 import logging
 import sys
 from types import FunctionType

 try:
     import cPickle as pickle
 except ImportError:
     import pickle

+from IPython.utils import py3compat
+from IPython.utils.importstring import import_item
+from IPython.utils.py3compat import string_types, iteritems, buffer_to_bytes_py2
+
 from . import codeutil # This registers a hook when it's imported
-from . import py3compat
-from .importstring import import_item
-from .py3compat import string_types, iteritems, buffer_to_bytes_py2

 from IPython.config import Application
 from IPython.utils.log import get_logger

 if py3compat.PY3:
     buffer = memoryview
     class_type = type
 else:
     from types import ClassType
     class_type = (type, ClassType)

 try:
     PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
 except AttributeError:
     PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL

 def _get_cell_type(a=None):
     """the type of a closure cell doesn't seem to be importable,
     so just create one
     """
     def inner():
         return a
     return type(py3compat.get_closure(inner)[0])

 cell_type = _get_cell_type()

 #-------------------------------------------------------------------------------
 # Functions
 #-------------------------------------------------------------------------------


 def use_dill():
     """use dill to expand serialization support

     adds support for object methods and closures to serialization.
     """
     # import dill causes most of the magic
     import dill

     # dill doesn't work with cPickle,
     # tell the two relevant modules to use plain pickle

     global pickle
     pickle = dill

     try:
         from IPython.kernel.zmq import serialize
     except ImportError:
         pass
     else:
         serialize.pickle = dill

     # disable special function handling, let dill take care of it
     can_map.pop(FunctionType, None)

 def use_cloudpickle():
     """use cloudpickle to expand serialization support

     adds support for object methods and closures to serialization.
     """
     from cloud.serialization import cloudpickle

     global pickle
     pickle = cloudpickle

     try:
         from IPython.kernel.zmq import serialize
     except ImportError:
         pass
     else:
         serialize.pickle = cloudpickle

     # disable special function handling, let cloudpickle take care of it
     can_map.pop(FunctionType, None)


 #-------------------------------------------------------------------------------
 # Classes
 #-------------------------------------------------------------------------------


 class CannedObject(object):
     def __init__(self, obj, keys=[], hook=None):
         """can an object for safe pickling

         Parameters
         ==========

         obj:
             The object to be canned
         keys: list (optional)
             list of attribute names that will be explicitly canned / uncanned
         hook: callable (optional)
             An optional extra callable,
             which can do additional processing of the uncanned object.

         large data may be offloaded into the buffers list,
         used for zero-copy transfers.
         """
         self.keys = keys
         self.obj = copy.copy(obj)
         self.hook = can(hook)
         for key in keys:
             setattr(self.obj, key, can(getattr(obj, key)))

         self.buffers = []

     def get_object(self, g=None):
         if g is None:
             g = {}
         obj = self.obj
         for key in self.keys:
             setattr(obj, key, uncan(getattr(obj, key), g))

         if self.hook:
             self.hook = uncan(self.hook, g)
             self.hook(obj, g)
         return self.obj


 class Reference(CannedObject):
     """object for wrapping a remote reference by name."""
     def __init__(self, name):
         if not isinstance(name, string_types):
             raise TypeError("illegal name: %r"%name)
         self.name = name
         self.buffers = []

     def __repr__(self):
         return "<Reference: %r>"%self.name

     def get_object(self, g=None):
         if g is None:
             g = {}

         return eval(self.name, g)


 class CannedCell(CannedObject):
     """Can a closure cell"""
     def __init__(self, cell):
         self.cell_contents = can(cell.cell_contents)

     def get_object(self, g=None):
         cell_contents = uncan(self.cell_contents, g)
         def inner():
             return cell_contents
         return py3compat.get_closure(inner)[0]


 class CannedFunction(CannedObject):

     def __init__(self, f):
         self._check_type(f)
         self.code = f.__code__
         if f.__defaults__:
             self.defaults = [ can(fd) for fd in f.__defaults__ ]
         else:
             self.defaults = None

         closure = py3compat.get_closure(f)
         if closure:
             self.closure = tuple( can(cell) for cell in closure )
         else:
             self.closure = None

         self.module = f.__module__ or '__main__'
         self.__name__ = f.__name__
         self.buffers = []

     def _check_type(self, obj):
         assert isinstance(obj, FunctionType), "Not a function type"

     def get_object(self, g=None):
         # try to load function back into its module:
         if not self.module.startswith('__'):
             __import__(self.module)
             g = sys.modules[self.module].__dict__

         if g is None:
             g = {}
         if self.defaults:
             defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
         else:
             defaults = None
         if self.closure:
             closure = tuple(uncan(cell, g) for cell in self.closure)
         else:
             closure = None
         newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
         return newFunc

 class CannedClass(CannedObject):

     def __init__(self, cls):
         self._check_type(cls)
         self.name = cls.__name__
         self.old_style = not isinstance(cls, type)
         self._canned_dict = {}
         for k,v in cls.__dict__.items():
             if k not in ('__weakref__', '__dict__'):
                 self._canned_dict[k] = can(v)
         if self.old_style:
             mro = []
         else:
             mro = cls.mro()

         self.parents = [ can(c) for c in mro[1:] ]
         self.buffers = []

     def _check_type(self, obj):
         assert isinstance(obj, class_type), "Not a class type"

     def get_object(self, g=None):
         parents = tuple(uncan(p, g) for p in self.parents)
         return type(self.name, parents, uncan_dict(self._canned_dict, g=g))

 class CannedArray(CannedObject):
     def __init__(self, obj):
         from numpy import ascontiguousarray
         self.shape = obj.shape
         self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
         self.pickled = False
         if sum(obj.shape) == 0:
             self.pickled = True
         elif obj.dtype == 'O':
             # can't handle object dtype with buffer approach
             self.pickled = True
         elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
             self.pickled = True
         if self.pickled:
             # just pickle it
             self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
         else:
             # ensure contiguous
             obj = ascontiguousarray(obj, dtype=None)
             self.buffers = [buffer(obj)]

     def get_object(self, g=None):
         from numpy import frombuffer
         data = self.buffers[0]
         if self.pickled:
             # we just pickled it
             return pickle.loads(buffer_to_bytes_py2(data))
         else:
             return frombuffer(data, dtype=self.dtype).reshape(self.shape)


 class CannedBytes(CannedObject):
     wrap = bytes
     def __init__(self, obj):
         self.buffers = [obj]

     def get_object(self, g=None):
         data = self.buffers[0]
         return self.wrap(data)

 class CannedBuffer(CannedBytes):
     wrap = buffer

 #-------------------------------------------------------------------------------
 # Functions
 #-------------------------------------------------------------------------------

 def _import_mapping(mapping, original=None):
     """import any string-keys in a type mapping

     """
     log = get_logger()
     log.debug("Importing canning map")
     for key,value in list(mapping.items()):
         if isinstance(key, string_types):
             try:
                 cls = import_item(key)
             except Exception:
                 if original and key not in original:
                     # only message on user-added classes
                     log.error("canning class not importable: %r", key, exc_info=True)
                 mapping.pop(key)
             else:
                 mapping[cls] = mapping.pop(key)

 def istype(obj, check):
     """like isinstance(obj, check), but strict

     This won't catch subclasses.
     """
     if isinstance(check, tuple):
         for cls in check:
             if type(obj) is cls:
                 return True
         return False
     else:
         return type(obj) is check

 def can(obj):
     """prepare an object for pickling"""

     import_needed = False

     for cls,canner in iteritems(can_map):
         if isinstance(cls, string_types):
             import_needed = True
             break
         elif istype(obj, cls):
             return canner(obj)

     if import_needed:
         # perform can_map imports, then try again
         # this will usually only happen once
         _import_mapping(can_map, _original_can_map)
         return can(obj)

     return obj

 def can_class(obj):
     if isinstance(obj, class_type) and obj.__module__ == '__main__':
         return CannedClass(obj)
     else:
         return obj

 def can_dict(obj):
     """can the *values* of a dict"""
     if istype(obj, dict):
         newobj = {}
         for k, v in iteritems(obj):
             newobj[k] = can(v)
         return newobj
     else:
         return obj

 sequence_types = (list, tuple, set)

 def can_sequence(obj):
     """can the elements of a sequence"""
     if istype(obj, sequence_types):
         t = type(obj)
         return t([can(i) for i in obj])
     else:
         return obj

 def uncan(obj, g=None):
     """invert canning"""

     import_needed = False
     for cls,uncanner in iteritems(uncan_map):
         if isinstance(cls, string_types):
             import_needed = True
             break
         elif isinstance(obj, cls):
             return uncanner(obj, g)

     if import_needed:
         # perform uncan_map imports, then try again
         # this will usually only happen once
         _import_mapping(uncan_map, _original_uncan_map)
         return uncan(obj, g)

     return obj

 def uncan_dict(obj, g=None):
     if istype(obj, dict):
         newobj = {}
         for k, v in iteritems(obj):
             newobj[k] = uncan(v,g)
         return newobj
     else:
         return obj

 def uncan_sequence(obj, g=None):
     if istype(obj, sequence_types):
         t = type(obj)
         return t([uncan(i,g) for i in obj])
     else:
         return obj

 def _uncan_dependent_hook(dep, g=None):
     dep.check_dependency()

 def can_dependent(obj):
     return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)

 #-------------------------------------------------------------------------------
 # API dictionaries
 #-------------------------------------------------------------------------------

 # These dicts can be extended for custom serialization of new objects

 can_map = {
     'IPython.parallel.dependent' : can_dependent,
     'numpy.ndarray' : CannedArray,
     FunctionType : CannedFunction,
     bytes : CannedBytes,
     buffer : CannedBuffer,
     cell_type : CannedCell,
     class_type : can_class,
 }

 uncan_map = {
     CannedObject : lambda obj, g: obj.get_object(g),
 }

 # for use in _import_mapping:
 _original_can_map = can_map.copy()
 _original_uncan_map = uncan_map.copy()
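
The can_map / uncan_map dictionaries above are the extension point for custom serialization. A minimal sketch of registering a custom canner follows (not part of the patch; Point and CannedPoint are hypothetical names, and the ipython_kernel.pickleutil import path is assumed from the other files in this change):

from ipython_kernel.pickleutil import can, uncan, can_map, CannedObject

class Point(object):
    # hypothetical user-defined type
    def __init__(self, x, y):
        self.x, self.y = x, y

class CannedPoint(CannedObject):
    # custom canner: store only what is needed to rebuild the object
    def __init__(self, obj):
        self.x, self.y = obj.x, obj.y
        self.buffers = []

    def get_object(self, g=None):
        return Point(self.x, self.y)

can_map[Point] = CannedPoint   # can() now routes Point instances through CannedPoint
# uncan_map already maps CannedObject to get_object(), which covers CannedPoint

p = uncan(can(Point(1, 2)))
assert (p.x, p.y) == (1, 2)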
@@ -1,179 +1,179 @@
 """serialization utilities for apply messages"""

 # Copyright (c) IPython Development Team.
 # Distributed under the terms of the Modified BSD License.

 try:
     import cPickle
     pickle = cPickle
 except:
     cPickle = None
     import pickle

 # IPython imports
 from IPython.utils.py3compat import PY3, buffer_to_bytes_py2
 from IPython.utils.data import flatten
-from
+from ipython_kernel.pickleutil import (
     can, uncan, can_sequence, uncan_sequence, CannedObject,
     istype, sequence_types, PICKLE_PROTOCOL,
 )
 from jupyter_client.session import MAX_ITEMS, MAX_BYTES


 if PY3:
     buffer = memoryview

 #-----------------------------------------------------------------------------
 # Serialization Functions
 #-----------------------------------------------------------------------------


 def _extract_buffers(obj, threshold=MAX_BYTES):
     """extract buffers larger than a certain threshold"""
     buffers = []
     if isinstance(obj, CannedObject) and obj.buffers:
         for i,buf in enumerate(obj.buffers):
             if len(buf) > threshold:
                 # buffer larger than threshold, prevent pickling
                 obj.buffers[i] = None
                 buffers.append(buf)
             elif isinstance(buf, buffer):
                 # buffer too small for separate send, coerce to bytes
                 # because pickling buffer objects just results in broken pointers
                 obj.buffers[i] = bytes(buf)
     return buffers

 def _restore_buffers(obj, buffers):
     """restore buffers extracted by """
     if isinstance(obj, CannedObject) and obj.buffers:
         for i,buf in enumerate(obj.buffers):
             if buf is None:
                 obj.buffers[i] = buffers.pop(0)

 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
     """Serialize an object into a list of sendable buffers.

     Parameters
     ----------

     obj : object
         The object to be serialized
     buffer_threshold : int
         The threshold (in bytes) for pulling out data buffers
         to avoid pickling them.
     item_threshold : int
         The maximum number of items over which canning will iterate.
         Containers (lists, dicts) larger than this will be pickled without
         introspection.

     Returns
     -------
     [bufs] : list of buffers representing the serialized object.
     """
     buffers = []
     if istype(obj, sequence_types) and len(obj) < item_threshold:
         cobj = can_sequence(obj)
         for c in cobj:
             buffers.extend(_extract_buffers(c, buffer_threshold))
     elif istype(obj, dict) and len(obj) < item_threshold:
         cobj = {}
         for k in sorted(obj):
             c = can(obj[k])
             buffers.extend(_extract_buffers(c, buffer_threshold))
             cobj[k] = c
     else:
         cobj = can(obj)
         buffers.extend(_extract_buffers(cobj, buffer_threshold))

     buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
     return buffers

 def deserialize_object(buffers, g=None):
     """reconstruct an object serialized by serialize_object from data buffers.

     Parameters
     ----------

     bufs : list of buffers/bytes

     g : globals to be used when uncanning

     Returns
     -------

     (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
     """
     bufs = list(buffers)
     pobj = buffer_to_bytes_py2(bufs.pop(0))
     canned = pickle.loads(pobj)
     if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
         for c in canned:
             _restore_buffers(c, bufs)
         newobj = uncan_sequence(canned, g)
     elif istype(canned, dict) and len(canned) < MAX_ITEMS:
         newobj = {}
         for k in sorted(canned):
             c = canned[k]
             _restore_buffers(c, bufs)
             newobj[k] = uncan(c, g)
     else:
         _restore_buffers(canned, bufs)
         newobj = uncan(canned, g)

     return newobj, bufs

 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
     """pack up a function, args, and kwargs to be sent over the wire

     Each element of args/kwargs will be canned for special treatment,
     but inspection will not go any deeper than that.

     Any object whose data is larger than `threshold` will not have their data copied
     (only numpy arrays and bytes/buffers support zero-copy)

     Message will be a list of bytes/buffers of the format:

     [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]

     With length at least two + len(args) + len(kwargs)
     """

     arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)

     kw_keys = sorted(kwargs.keys())
     kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)

     info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)

     msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
     msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
     msg.extend(arg_bufs)
     msg.extend(kwarg_bufs)

     return msg

 def unpack_apply_message(bufs, g=None, copy=True):
     """unpack f,args,kwargs from buffers packed by pack_apply_message()
     Returns: original f,args,kwargs"""
     bufs = list(bufs) # allow us to pop
     assert len(bufs) >= 2, "not enough buffers!"
     pf = buffer_to_bytes_py2(bufs.pop(0))
     f = uncan(pickle.loads(pf), g)
     pinfo = buffer_to_bytes_py2(bufs.pop(0))
     info = pickle.loads(pinfo)
     arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]

     args = []
     for i in range(info['nargs']):
         arg, arg_bufs = deserialize_object(arg_bufs, g)
         args.append(arg)
     args = tuple(args)
     assert not arg_bufs, "Shouldn't be any arg bufs left over"

     kwargs = {}
     for key in info['kw_keys']:
         kwarg, kwarg_bufs = deserialize_object(kwarg_bufs, g)
         kwargs[key] = kwarg
     assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"

     return f,args,kwargs
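
A minimal round-trip sketch of the apply-message helpers defined above (not part of the patch; it assumes the module is importable as ipython_kernel.serialize, as in the tests below):

from ipython_kernel.serialize import pack_apply_message, unpack_apply_message

def add(a, b=1):
    return a + b

# msg is a list of buffers: [pickled canned f, pickled info dict, arg buffers, kwarg buffers]
msg = pack_apply_message(add, (5,), {'b': 2})

f, args, kwargs = unpack_apply_message(msg)
assert f(*args, **kwargs) == 7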
@@ -1,62 +1,62 @@

 import pickle

 import nose.tools as nt
-from
+from ipython_kernel import codeutil
-from
+from ipython_kernel.pickleutil import can, uncan

 def interactive(f):
     f.__module__ = '__main__'
     return f

 def dumps(obj):
     return pickle.dumps(can(obj))

 def loads(obj):
     return uncan(pickle.loads(obj))

 def test_no_closure():
     @interactive
     def foo():
         a = 5
         return a

     pfoo = dumps(foo)
     bar = loads(pfoo)
     nt.assert_equal(foo(), bar())

 def test_generator_closure():
     # this only creates a closure on Python 3
     @interactive
     def foo():
         i = 'i'
         r = [ i for j in (1,2) ]
         return r

     pfoo = dumps(foo)
     bar = loads(pfoo)
     nt.assert_equal(foo(), bar())

 def test_nested_closure():
     @interactive
     def foo():
         i = 'i'
         def g():
             return i
         return g()

     pfoo = dumps(foo)
     bar = loads(pfoo)
     nt.assert_equal(foo(), bar())

 def test_closure():
     i = 'i'
     @interactive
     def foo():
         return i

     pfoo = dumps(foo)
     bar = loads(pfoo)
     nt.assert_equal(foo(), bar())

\ No newline at end of file
@@ -1,208 +1,208 b'' | |||||
1 | """test serialization tools""" |
|
1 | """test serialization tools""" | |
2 |
|
2 | |||
3 | # Copyright (c) IPython Development Team. |
|
3 | # Copyright (c) IPython Development Team. | |
4 | # Distributed under the terms of the Modified BSD License. |
|
4 | # Distributed under the terms of the Modified BSD License. | |
5 |
|
5 | |||
6 | import pickle |
|
6 | import pickle | |
7 | from collections import namedtuple |
|
7 | from collections import namedtuple | |
8 |
|
8 | |||
9 | import nose.tools as nt |
|
9 | import nose.tools as nt | |
10 |
|
10 | |||
11 | # from unittest import TestCaes |
|
11 | # from unittest import TestCaes | |
12 | from ipython_kernel.serialize import serialize_object, deserialize_object |
|
12 | from ipython_kernel.serialize import serialize_object, deserialize_object | |
13 | from IPython.testing import decorators as dec |
|
13 | from IPython.testing import decorators as dec | |
14 |
from |
|
14 | from ipython_kernel.pickleutil import CannedArray, CannedClass | |
15 | from IPython.utils.py3compat import iteritems |
|
15 | from IPython.utils.py3compat import iteritems | |
16 | from IPython.parallel import interactive |
|
16 | from IPython.parallel import interactive | |
17 |
|
17 | |||
18 | #------------------------------------------------------------------------------- |
|
18 | #------------------------------------------------------------------------------- | |
19 | # Globals and Utilities |
|
19 | # Globals and Utilities | |
20 | #------------------------------------------------------------------------------- |
|
20 | #------------------------------------------------------------------------------- | |
21 |
|
21 | |||
22 | def roundtrip(obj): |
|
22 | def roundtrip(obj): | |
23 | """roundtrip an object through serialization""" |
|
23 | """roundtrip an object through serialization""" | |
24 | bufs = serialize_object(obj) |
|
24 | bufs = serialize_object(obj) | |
25 | obj2, remainder = deserialize_object(bufs) |
|
25 | obj2, remainder = deserialize_object(bufs) | |
26 | nt.assert_equals(remainder, []) |
|
26 | nt.assert_equals(remainder, []) | |
27 | return obj2 |
|
27 | return obj2 | |
28 |
|
28 | |||
29 | class C(object): |
|
29 | class C(object): | |
30 | """dummy class for """ |
|
30 | """dummy class for """ | |
31 |
|
31 | |||
32 | def __init__(self, **kwargs): |
|
32 | def __init__(self, **kwargs): | |
33 | for key,value in iteritems(kwargs): |
|
33 | for key,value in iteritems(kwargs): | |
34 | setattr(self, key, value) |
|
34 | setattr(self, key, value) | |
35 |
|
35 | |||
36 | SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,)) |
|
36 | SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,)) | |
37 | DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10') |
|
37 | DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10') | |
38 |
|
38 | |||
39 | #------------------------------------------------------------------------------- |
|
39 | #------------------------------------------------------------------------------- | |
40 | # Tests |
|
40 | # Tests | |
41 | #------------------------------------------------------------------------------- |
|
41 | #------------------------------------------------------------------------------- | |
42 |
|
42 | |||
43 | def new_array(shape, dtype): |
|
43 | def new_array(shape, dtype): | |
44 | import numpy |
|
44 | import numpy | |
45 | return numpy.random.random(shape).astype(dtype) |
|
45 | return numpy.random.random(shape).astype(dtype) | |
46 |
|
46 | |||
47 | def test_roundtrip_simple(): |
|
47 | def test_roundtrip_simple(): | |
48 | for obj in [ |
|
48 | for obj in [ | |
49 | 'hello', |
|
49 | 'hello', | |
50 | dict(a='b', b=10), |
|
50 | dict(a='b', b=10), | |
51 | [1,2,'hi'], |
|
51 | [1,2,'hi'], | |
52 | (b'123', 'hello'), |
|
52 | (b'123', 'hello'), | |
53 | ]: |
|
53 | ]: | |
54 | obj2 = roundtrip(obj) |
|
54 | obj2 = roundtrip(obj) | |
55 | nt.assert_equal(obj, obj2) |
|
55 | nt.assert_equal(obj, obj2) | |
56 |
|
56 | |||
57 | def test_roundtrip_nested(): |
|
57 | def test_roundtrip_nested(): | |
58 | for obj in [ |
|
58 | for obj in [ | |
59 | dict(a=range(5), b={1:b'hello'}), |
|
59 | dict(a=range(5), b={1:b'hello'}), | |
60 | [range(5),[range(3),(1,[b'whoda'])]], |
|
60 | [range(5),[range(3),(1,[b'whoda'])]], | |
61 | ]: |
|
61 | ]: | |
62 | obj2 = roundtrip(obj) |
|
62 | obj2 = roundtrip(obj) | |
63 | nt.assert_equal(obj, obj2) |
|
63 | nt.assert_equal(obj, obj2) | |
64 |
|
64 | |||
65 | def test_roundtrip_buffered(): |
|
65 | def test_roundtrip_buffered(): | |
66 | for obj in [ |
|
66 | for obj in [ | |
67 | dict(a=b"x"*1025), |
|
67 | dict(a=b"x"*1025), | |
68 | b"hello"*500, |
|
68 | b"hello"*500, | |
69 | [b"hello"*501, 1,2,3] |
|
69 | [b"hello"*501, 1,2,3] | |
70 | ]: |
|
70 | ]: | |
71 | bufs = serialize_object(obj) |
|
71 | bufs = serialize_object(obj) | |
72 | nt.assert_equal(len(bufs), 2) |
|
72 | nt.assert_equal(len(bufs), 2) | |
73 | obj2, remainder = deserialize_object(bufs) |
|
73 | obj2, remainder = deserialize_object(bufs) | |
74 | nt.assert_equal(remainder, []) |
|
74 | nt.assert_equal(remainder, []) | |
75 | nt.assert_equal(obj, obj2) |
|
75 | nt.assert_equal(obj, obj2) | |
76 |
|
76 | |||
77 | @dec.skip_without('numpy') |
|
77 | @dec.skip_without('numpy') | |
78 | def test_numpy(): |
|
78 | def test_numpy(): | |
79 | import numpy |
|
79 | import numpy | |
80 | from numpy.testing.utils import assert_array_equal |
|
80 | from numpy.testing.utils import assert_array_equal | |
81 | for shape in SHAPES: |
|
81 | for shape in SHAPES: | |
82 | for dtype in DTYPES: |
|
82 | for dtype in DTYPES: | |
83 | A = new_array(shape, dtype=dtype) |
|
83 | A = new_array(shape, dtype=dtype) | |
84 | bufs = serialize_object(A) |
|
84 | bufs = serialize_object(A) | |
85 | B, r = deserialize_object(bufs) |
|
85 | B, r = deserialize_object(bufs) | |
86 | nt.assert_equal(r, []) |
|
86 | nt.assert_equal(r, []) | |
87 | nt.assert_equal(A.shape, B.shape) |
|
87 | nt.assert_equal(A.shape, B.shape) | |
88 | nt.assert_equal(A.dtype, B.dtype) |
|
88 | nt.assert_equal(A.dtype, B.dtype) | |
89 | assert_array_equal(A,B) |
|
89 | assert_array_equal(A,B) | |
90 |
|
90 | |||
91 | @dec.skip_without('numpy') |
|
91 | @dec.skip_without('numpy') | |
92 | def test_recarray(): |
|
92 | def test_recarray(): | |
93 | import numpy |
|
93 | import numpy | |
94 | from numpy.testing.utils import assert_array_equal |
|
94 | from numpy.testing.utils import assert_array_equal | |
95 | for shape in SHAPES: |
|
95 | for shape in SHAPES: | |
96 | for dtype in [ |
|
96 | for dtype in [ | |
97 | [('f', float), ('s', '|S10')], |
|
97 | [('f', float), ('s', '|S10')], | |
98 | [('n', int), ('s', '|S1'), ('u', 'uint32')], |
|
98 | [('n', int), ('s', '|S1'), ('u', 'uint32')], | |
99 | ]: |
|
99 | ]: | |
100 | A = new_array(shape, dtype=dtype) |
|
100 | A = new_array(shape, dtype=dtype) | |
101 |
|
101 | |||
102 | bufs = serialize_object(A) |
|
102 | bufs = serialize_object(A) | |
103 | B, r = deserialize_object(bufs) |
|
103 | B, r = deserialize_object(bufs) | |
104 | nt.assert_equal(r, []) |
|
104 | nt.assert_equal(r, []) | |
105 | nt.assert_equal(A.shape, B.shape) |
|
105 | nt.assert_equal(A.shape, B.shape) | |
106 | nt.assert_equal(A.dtype, B.dtype) |
|
106 | nt.assert_equal(A.dtype, B.dtype) | |
107 | assert_array_equal(A,B) |
|
107 | assert_array_equal(A,B) | |
108 |
|
108 | |||
109 | @dec.skip_without('numpy') |
|
109 | @dec.skip_without('numpy') | |
110 | def test_numpy_in_seq(): |
|
110 | def test_numpy_in_seq(): | |
111 | import numpy |
|
111 | import numpy | |
112 | from numpy.testing.utils import assert_array_equal |
|
112 | from numpy.testing.utils import assert_array_equal | |
113 | for shape in SHAPES: |
|
113 | for shape in SHAPES: | |
114 | for dtype in DTYPES: |
|
114 | for dtype in DTYPES: | |
115 | A = new_array(shape, dtype=dtype) |
|
115 | A = new_array(shape, dtype=dtype) | |
116 | bufs = serialize_object((A,1,2,b'hello')) |
|
116 | bufs = serialize_object((A,1,2,b'hello')) | |
117 | canned = pickle.loads(bufs[0]) |
|
117 | canned = pickle.loads(bufs[0]) | |
118 | nt.assert_is_instance(canned[0], CannedArray) |
|
118 | nt.assert_is_instance(canned[0], CannedArray) | |
119 | tup, r = deserialize_object(bufs) |
|
119 | tup, r = deserialize_object(bufs) | |
120 | B = tup[0] |
|
120 | B = tup[0] | |
121 | nt.assert_equal(r, []) |
|
121 | nt.assert_equal(r, []) | |
122 | nt.assert_equal(A.shape, B.shape) |
|
122 | nt.assert_equal(A.shape, B.shape) | |
123 | nt.assert_equal(A.dtype, B.dtype) |
|
123 | nt.assert_equal(A.dtype, B.dtype) | |
124 | assert_array_equal(A,B) |
|
124 | assert_array_equal(A,B) | |
125 |
|
125 | |||
126 | @dec.skip_without('numpy') |
|
126 | @dec.skip_without('numpy') | |
127 | def test_numpy_in_dict(): |
|
127 | def test_numpy_in_dict(): | |
128 | import numpy |
|
128 | import numpy | |
129 | from numpy.testing.utils import assert_array_equal |
|
129 | from numpy.testing.utils import assert_array_equal | |
130 | for shape in SHAPES: |
|
130 | for shape in SHAPES: | |
131 | for dtype in DTYPES: |
|
131 | for dtype in DTYPES: | |
132 | A = new_array(shape, dtype=dtype) |
|
132 | A = new_array(shape, dtype=dtype) | |
133 | bufs = serialize_object(dict(a=A,b=1,c=range(20))) |
|
133 | bufs = serialize_object(dict(a=A,b=1,c=range(20))) | |
134 | canned = pickle.loads(bufs[0]) |
|
134 | canned = pickle.loads(bufs[0]) | |
135 | nt.assert_is_instance(canned['a'], CannedArray) |
|
135 | nt.assert_is_instance(canned['a'], CannedArray) | |
136 | d, r = deserialize_object(bufs) |
|
136 | d, r = deserialize_object(bufs) | |
137 | B = d['a'] |
|
137 | B = d['a'] | |
138 | nt.assert_equal(r, []) |
|
138 | nt.assert_equal(r, []) | |
139 | nt.assert_equal(A.shape, B.shape) |
|
139 | nt.assert_equal(A.shape, B.shape) | |
140 | nt.assert_equal(A.dtype, B.dtype) |
|
140 | nt.assert_equal(A.dtype, B.dtype) | |
141 | assert_array_equal(A,B) |
|
141 | assert_array_equal(A,B) | |
142 |
|
142 | |||
143 | def test_class(): |
|
143 | def test_class(): | |
144 | @interactive |
|
144 | @interactive | |
145 | class C(object): |
|
145 | class C(object): | |
146 | a=5 |
|
146 | a=5 | |
147 | bufs = serialize_object(dict(C=C)) |
|
147 | bufs = serialize_object(dict(C=C)) | |
148 | canned = pickle.loads(bufs[0]) |
|
148 | canned = pickle.loads(bufs[0]) | |
149 | nt.assert_is_instance(canned['C'], CannedClass) |
|
149 | nt.assert_is_instance(canned['C'], CannedClass) | |
150 | d, r = deserialize_object(bufs) |
|
150 | d, r = deserialize_object(bufs) | |
151 | C2 = d['C'] |
|
151 | C2 = d['C'] | |
152 | nt.assert_equal(C2.a, C.a) |
|
152 | nt.assert_equal(C2.a, C.a) | |
153 |
|
153 | |||
154 | def test_class_oldstyle(): |
|
154 | def test_class_oldstyle(): | |
155 | @interactive |
|
155 | @interactive | |
156 | class C: |
|
156 | class C: | |
157 | a=5 |
|
157 | a=5 | |
158 |
|
158 | |||
159 | bufs = serialize_object(dict(C=C)) |
|
159 | bufs = serialize_object(dict(C=C)) | |
160 | canned = pickle.loads(bufs[0]) |
|
160 | canned = pickle.loads(bufs[0]) | |
161 | nt.assert_is_instance(canned['C'], CannedClass) |
|
161 | nt.assert_is_instance(canned['C'], CannedClass) | |
162 | d, r = deserialize_object(bufs) |
|
162 | d, r = deserialize_object(bufs) | |
163 | C2 = d['C'] |
|
163 | C2 = d['C'] | |
164 | nt.assert_equal(C2.a, C.a) |
|
164 | nt.assert_equal(C2.a, C.a) | |
165 |
|
165 | |||
166 | def test_tuple(): |
|
166 | def test_tuple(): | |
167 | tup = (lambda x:x, 1) |
|
167 | tup = (lambda x:x, 1) | |
168 | bufs = serialize_object(tup) |
|
168 | bufs = serialize_object(tup) | |
169 | canned = pickle.loads(bufs[0]) |
|
169 | canned = pickle.loads(bufs[0]) | |
170 | nt.assert_is_instance(canned, tuple) |
|
170 | nt.assert_is_instance(canned, tuple) | |
171 | t2, r = deserialize_object(bufs) |
|
171 | t2, r = deserialize_object(bufs) | |
172 | nt.assert_equal(t2[0](t2[1]), tup[0](tup[1])) |
|
172 | nt.assert_equal(t2[0](t2[1]), tup[0](tup[1])) | |
173 |
|
173 | |||
174 | point = namedtuple('point', 'x y') |
|
174 | point = namedtuple('point', 'x y') | |
175 |
|
175 | |||
176 | def test_namedtuple(): |
|
176 | def test_namedtuple(): | |
177 | p = point(1,2) |
|
177 | p = point(1,2) | |
178 | bufs = serialize_object(p) |
|
178 | bufs = serialize_object(p) | |
179 | canned = pickle.loads(bufs[0]) |
|
179 | canned = pickle.loads(bufs[0]) | |
180 | nt.assert_is_instance(canned, point) |
|
180 | nt.assert_is_instance(canned, point) | |
181 | p2, r = deserialize_object(bufs, globals()) |
|
181 | p2, r = deserialize_object(bufs, globals()) | |
182 | nt.assert_equal(p2.x, p.x) |
|
182 | nt.assert_equal(p2.x, p.x) | |
183 | nt.assert_equal(p2.y, p.y) |
|
183 | nt.assert_equal(p2.y, p.y) | |
184 |
|
184 | |||
185 | def test_list(): |
|
185 | def test_list(): | |
186 | lis = [lambda x:x, 1] |
|
186 | lis = [lambda x:x, 1] | |
187 | bufs = serialize_object(lis) |
|
187 | bufs = serialize_object(lis) | |
188 | canned = pickle.loads(bufs[0]) |
|
188 | canned = pickle.loads(bufs[0]) | |
189 | nt.assert_is_instance(canned, list) |
|
189 | nt.assert_is_instance(canned, list) | |
190 | l2, r = deserialize_object(bufs) |
|
190 | l2, r = deserialize_object(bufs) | |
191 | nt.assert_equal(l2[0](l2[1]), lis[0](lis[1])) |
|
191 | nt.assert_equal(l2[0](l2[1]), lis[0](lis[1])) | |
192 |
|
192 | |||
193 | def test_class_inheritance(): |
|
193 | def test_class_inheritance(): | |
194 | @interactive |
|
194 | @interactive | |
195 | class C(object): |
|
195 | class C(object): | |
196 | a=5 |
|
196 | a=5 | |
197 |
|
197 | |||
198 | @interactive |
|
198 | @interactive | |
199 | class D(C): |
|
199 | class D(C): | |
200 | b=10 |
|
200 | b=10 | |
201 |
|
201 | |||
202 | bufs = serialize_object(dict(D=D)) |
|
202 | bufs = serialize_object(dict(D=D)) | |
203 | canned = pickle.loads(bufs[0]) |
|
203 | canned = pickle.loads(bufs[0]) | |
204 | nt.assert_is_instance(canned['D'], CannedClass) |
|
204 | nt.assert_is_instance(canned['D'], CannedClass) | |
205 | d, r = deserialize_object(bufs) |
|
205 | d, r = deserialize_object(bufs) | |
206 | D2 = d['D'] |
|
206 | D2 = d['D'] | |
207 | nt.assert_equal(D2.a, D.a) |
|
207 | nt.assert_equal(D2.a, D.a) | |
208 | nt.assert_equal(D2.b, D.b) |
|
208 | nt.assert_equal(D2.b, D.b) |
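The tests above all exercise the same canning round trip. A minimal sketch of that pattern, assuming serialize_object, deserialize_object and pickle are importable exactly as in the test module (its imports are not shown in this hunk):

    bufs = serialize_object(dict(a=5))          # list of buffers; bufs[0] is the pickled canned dict
    canned = pickle.loads(bufs[0])              # inspect the canned form, as the tests do
    obj, remainder = deserialize_object(bufs)   # rebuild the original object
    assert obj['a'] == 5 and remainder == []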
@@ -1,72 +1,72 b'' | |||||
1 | """The IPython ZMQ-based parallel computing interface. |
|
1 | """The IPython ZMQ-based parallel computing interface. | |
2 |
|
2 | |||
3 | Authors: |
|
3 | Authors: | |
4 |
|
4 | |||
5 | * MinRK |
|
5 | * MinRK | |
6 | """ |
|
6 | """ | |
7 | #----------------------------------------------------------------------------- |
|
7 | #----------------------------------------------------------------------------- | |
8 | # Copyright (C) 2011 The IPython Development Team |
|
8 | # Copyright (C) 2011 The IPython Development Team | |
9 | # |
|
9 | # | |
10 | # Distributed under the terms of the BSD License. The full license is in |
|
10 | # Distributed under the terms of the BSD License. The full license is in | |
11 | # the file COPYING, distributed as part of this software. |
|
11 | # the file COPYING, distributed as part of this software. | |
12 | #----------------------------------------------------------------------------- |
|
12 | #----------------------------------------------------------------------------- | |
13 |
|
13 | |||
14 | #----------------------------------------------------------------------------- |
|
14 | #----------------------------------------------------------------------------- | |
15 | # Imports |
|
15 | # Imports | |
16 | #----------------------------------------------------------------------------- |
|
16 | #----------------------------------------------------------------------------- | |
17 |
|
17 | |||
18 | import os |
|
18 | import os | |
19 | import warnings |
|
19 | import warnings | |
20 |
|
20 | |||
21 | import zmq |
|
21 | import zmq | |
22 |
|
22 | |||
23 | from IPython.config.configurable import MultipleInstanceError |
|
23 | from IPython.config.configurable import MultipleInstanceError | |
24 | from IPython.utils.zmqrelated import check_for_zmq |
|
24 | from IPython.utils.zmqrelated import check_for_zmq | |
25 |
|
25 | |||
26 | min_pyzmq = '2.1.11' |
|
26 | min_pyzmq = '2.1.11' | |
27 |
|
27 | |||
28 | check_for_zmq(min_pyzmq, 'ipython_parallel') |
|
28 | check_for_zmq(min_pyzmq, 'ipython_parallel') | |
29 |
|
29 | |||
30 | from |
|
30 | from ipython_kernel.pickleutil import Reference | |
31 |
|
31 | |||
32 | from .client.asyncresult import * |
|
32 | from .client.asyncresult import * | |
33 | from .client.client import Client |
|
33 | from .client.client import Client | |
34 | from .client.remotefunction import * |
|
34 | from .client.remotefunction import * | |
35 | from .client.view import * |
|
35 | from .client.view import * | |
36 | from .controller.dependency import * |
|
36 | from .controller.dependency import * | |
37 | from .error import * |
|
37 | from .error import * | |
38 | from .util import interactive |
|
38 | from .util import interactive | |
39 |
|
39 | |||
40 | #----------------------------------------------------------------------------- |
|
40 | #----------------------------------------------------------------------------- | |
41 | # Functions |
|
41 | # Functions | |
42 | #----------------------------------------------------------------------------- |
|
42 | #----------------------------------------------------------------------------- | |
43 |
|
43 | |||
44 |
|
44 | |||
45 | def bind_kernel(**kwargs): |
|
45 | def bind_kernel(**kwargs): | |
46 | """Bind an Engine's Kernel to be used as a full IPython kernel. |
|
46 | """Bind an Engine's Kernel to be used as a full IPython kernel. | |
47 |
|
47 | |||
48 | This allows a running Engine to be used simultaneously as a full IPython kernel |
|
48 | This allows a running Engine to be used simultaneously as a full IPython kernel | |
49 | with the QtConsole or other frontends. |
|
49 | with the QtConsole or other frontends. | |
50 |
|
50 | |||
51 | This function returns immediately. |
|
51 | This function returns immediately. | |
52 | """ |
|
52 | """ | |
53 | from IPython.kernel.zmq.kernelapp import IPKernelApp |
|
53 | from IPython.kernel.zmq.kernelapp import IPKernelApp | |
54 | from ipython_parallel.apps.ipengineapp import IPEngineApp |
|
54 | from ipython_parallel.apps.ipengineapp import IPEngineApp | |
55 |
|
55 | |||
56 | # first check for IPKernelApp, in which case this should be a no-op |
|
56 | # first check for IPKernelApp, in which case this should be a no-op | |
57 | # because there is already a bound kernel |
|
57 | # because there is already a bound kernel | |
58 | if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp): |
|
58 | if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp): | |
59 | return |
|
59 | return | |
60 |
|
60 | |||
61 | if IPEngineApp.initialized(): |
|
61 | if IPEngineApp.initialized(): | |
62 | try: |
|
62 | try: | |
63 | app = IPEngineApp.instance() |
|
63 | app = IPEngineApp.instance() | |
64 | except MultipleInstanceError: |
|
64 | except MultipleInstanceError: | |
65 | pass |
|
65 | pass | |
66 | else: |
|
66 | else: | |
67 | return app.bind_kernel(**kwargs) |
|
67 | return app.bind_kernel(**kwargs) | |
68 |
|
68 | |||
69 | raise RuntimeError("bind_kernel be called from an IPEngineApp instance") |
|
69 | raise RuntimeError("bind_kernel be called from an IPEngineApp instance") | |
70 |
|
70 | |||
71 |
|
71 | |||
72 |
|
72 |
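For context, one plausible way to call bind_kernel from a client session (a hedged sketch, not part of this changeset; it assumes a running cluster and that invoking the function via apply inside the engine process is acceptable):

    from ipython_parallel import Client, bind_kernel
    rc = Client()
    rc[0].apply_sync(bind_kernel)   # promote engine 0 to a full kernel so a frontend can attach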
@@ -1,1125 +1,1125 b'' | |||||
1 | """Views of remote engines.""" |
|
1 | """Views of remote engines.""" | |
2 |
|
2 | |||
3 | # Copyright (c) IPython Development Team. |
|
3 | # Copyright (c) IPython Development Team. | |
4 | # Distributed under the terms of the Modified BSD License. |
|
4 | # Distributed under the terms of the Modified BSD License. | |
5 |
|
5 | |||
6 | from __future__ import print_function |
|
6 | from __future__ import print_function | |
7 |
|
7 | |||
8 | import imp |
|
8 | import imp | |
9 | import sys |
|
9 | import sys | |
10 | import warnings |
|
10 | import warnings | |
11 | from contextlib import contextmanager |
|
11 | from contextlib import contextmanager | |
12 | from types import ModuleType |
|
12 | from types import ModuleType | |
13 |
|
13 | |||
14 | import zmq |
|
14 | import zmq | |
15 |
|
15 | |||
16 | from IPython.testing.skipdoctest import skip_doctest |
|
16 | from IPython.testing.skipdoctest import skip_doctest | |
17 | from IPython.utils import pickleutil |
|
17 | from IPython.utils import pickleutil | |
18 | from IPython.utils.traitlets import ( |
|
18 | from IPython.utils.traitlets import ( | |
19 | HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer |
|
19 | HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer | |
20 | ) |
|
20 | ) | |
21 | from decorator import decorator |
|
21 | from decorator import decorator | |
22 |
|
22 | |||
23 | from ipython_parallel import util |
|
23 | from ipython_parallel import util | |
24 | from ipython_parallel.controller.dependency import Dependency, dependent |
|
24 | from ipython_parallel.controller.dependency import Dependency, dependent | |
25 | from IPython.utils.py3compat import string_types, iteritems, PY3 |
|
25 | from IPython.utils.py3compat import string_types, iteritems, PY3 | |
26 |
|
26 | |||
27 | from . import map as Map |
|
27 | from . import map as Map | |
28 | from .asyncresult import AsyncResult, AsyncMapResult |
|
28 | from .asyncresult import AsyncResult, AsyncMapResult | |
29 | from .remotefunction import ParallelFunction, parallel, remote, getname |
|
29 | from .remotefunction import ParallelFunction, parallel, remote, getname | |
30 |
|
30 | |||
31 | #----------------------------------------------------------------------------- |
|
31 | #----------------------------------------------------------------------------- | |
32 | # Decorators |
|
32 | # Decorators | |
33 | #----------------------------------------------------------------------------- |
|
33 | #----------------------------------------------------------------------------- | |
34 |
|
34 | |||
35 | @decorator |
|
35 | @decorator | |
36 | def save_ids(f, self, *args, **kwargs): |
|
36 | def save_ids(f, self, *args, **kwargs): | |
37 | """Keep our history and outstanding attributes up to date after a method call.""" |
|
37 | """Keep our history and outstanding attributes up to date after a method call.""" | |
38 | n_previous = len(self.client.history) |
|
38 | n_previous = len(self.client.history) | |
39 | try: |
|
39 | try: | |
40 | ret = f(self, *args, **kwargs) |
|
40 | ret = f(self, *args, **kwargs) | |
41 | finally: |
|
41 | finally: | |
42 | nmsgs = len(self.client.history) - n_previous |
|
42 | nmsgs = len(self.client.history) - n_previous | |
43 | msg_ids = self.client.history[-nmsgs:] |
|
43 | msg_ids = self.client.history[-nmsgs:] | |
44 | self.history.extend(msg_ids) |
|
44 | self.history.extend(msg_ids) | |
45 | self.outstanding.update(msg_ids) |
|
45 | self.outstanding.update(msg_ids) | |
46 | return ret |
|
46 | return ret | |
47 |
|
47 | |||
48 | @decorator |
|
48 | @decorator | |
49 | def sync_results(f, self, *args, **kwargs): |
|
49 | def sync_results(f, self, *args, **kwargs): | |
50 | """sync relevant results from self.client to our results attribute.""" |
|
50 | """sync relevant results from self.client to our results attribute.""" | |
51 | if self._in_sync_results: |
|
51 | if self._in_sync_results: | |
52 | return f(self, *args, **kwargs) |
|
52 | return f(self, *args, **kwargs) | |
53 | self._in_sync_results = True |
|
53 | self._in_sync_results = True | |
54 | try: |
|
54 | try: | |
55 | ret = f(self, *args, **kwargs) |
|
55 | ret = f(self, *args, **kwargs) | |
56 | finally: |
|
56 | finally: | |
57 | self._in_sync_results = False |
|
57 | self._in_sync_results = False | |
58 | self._sync_results() |
|
58 | self._sync_results() | |
59 | return ret |
|
59 | return ret | |
60 |
|
60 | |||
61 | @decorator |
|
61 | @decorator | |
62 | def spin_after(f, self, *args, **kwargs): |
|
62 | def spin_after(f, self, *args, **kwargs): | |
63 | """call spin after the method.""" |
|
63 | """call spin after the method.""" | |
64 | ret = f(self, *args, **kwargs) |
|
64 | ret = f(self, *args, **kwargs) | |
65 | self.spin() |
|
65 | self.spin() | |
66 | return ret |
|
66 | return ret | |
67 |
|
67 | |||
68 | #----------------------------------------------------------------------------- |
|
68 | #----------------------------------------------------------------------------- | |
69 | # Classes |
|
69 | # Classes | |
70 | #----------------------------------------------------------------------------- |
|
70 | #----------------------------------------------------------------------------- | |
71 |
|
71 | |||
72 | @skip_doctest |
|
72 | @skip_doctest | |
73 | class View(HasTraits): |
|
73 | class View(HasTraits): | |
74 | """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes. |
|
74 | """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes. | |
75 |
|
75 | |||
76 | Don't use this class, use subclasses. |
|
76 | Don't use this class, use subclasses. | |
77 |
|
77 | |||
78 | Methods |
|
78 | Methods | |
79 | ------- |
|
79 | ------- | |
80 |
|
80 | |||
81 | spin |
|
81 | spin | |
82 | flushes incoming results and registration state changes |
|
82 | flushes incoming results and registration state changes | |
83 | control methods spin, and requesting `ids` also ensures the state is up to date |
|
83 | control methods spin, and requesting `ids` also ensures the state is up to date | |
84 |
|
84 | |||
85 | wait |
|
85 | wait | |
86 | wait on one or more msg_ids |
|
86 | wait on one or more msg_ids | |
87 |
|
87 | |||
88 | execution methods |
|
88 | execution methods | |
89 | apply |
|
89 | apply | |
90 | legacy: execute, run |
|
90 | legacy: execute, run | |
91 |
|
91 | |||
92 | data movement |
|
92 | data movement | |
93 | push, pull, scatter, gather |
|
93 | push, pull, scatter, gather | |
94 |
|
94 | |||
95 | query methods |
|
95 | query methods | |
96 | get_result, queue_status, purge_results, result_status |
|
96 | get_result, queue_status, purge_results, result_status | |
97 |
|
97 | |||
98 | control methods |
|
98 | control methods | |
99 | abort, shutdown |
|
99 | abort, shutdown | |
100 |
|
100 | |||
101 | """ |
|
101 | """ | |
102 | # flags |
|
102 | # flags | |
103 | block=Bool(False) |
|
103 | block=Bool(False) | |
104 | track=Bool(True) |
|
104 | track=Bool(True) | |
105 | targets = Any() |
|
105 | targets = Any() | |
106 |
|
106 | |||
107 | history=List() |
|
107 | history=List() | |
108 | outstanding = Set() |
|
108 | outstanding = Set() | |
109 | results = Dict() |
|
109 | results = Dict() | |
110 | client = Instance('ipython_parallel.Client', allow_none=True) |
|
110 | client = Instance('ipython_parallel.Client', allow_none=True) | |
111 |
|
111 | |||
112 | _socket = Instance('zmq.Socket', allow_none=True) |
|
112 | _socket = Instance('zmq.Socket', allow_none=True) | |
113 | _flag_names = List(['targets', 'block', 'track']) |
|
113 | _flag_names = List(['targets', 'block', 'track']) | |
114 | _in_sync_results = Bool(False) |
|
114 | _in_sync_results = Bool(False) | |
115 | _targets = Any() |
|
115 | _targets = Any() | |
116 | _idents = Any() |
|
116 | _idents = Any() | |
117 |
|
117 | |||
118 | def __init__(self, client=None, socket=None, **flags): |
|
118 | def __init__(self, client=None, socket=None, **flags): | |
119 | super(View, self).__init__(client=client, _socket=socket) |
|
119 | super(View, self).__init__(client=client, _socket=socket) | |
120 | self.results = client.results |
|
120 | self.results = client.results | |
121 | self.block = client.block |
|
121 | self.block = client.block | |
122 |
|
122 | |||
123 | self.set_flags(**flags) |
|
123 | self.set_flags(**flags) | |
124 |
|
124 | |||
125 | assert not self.__class__ is View, "Don't use base View objects, use subclasses" |
|
125 | assert not self.__class__ is View, "Don't use base View objects, use subclasses" | |
126 |
|
126 | |||
127 | def __repr__(self): |
|
127 | def __repr__(self): | |
128 | strtargets = str(self.targets) |
|
128 | strtargets = str(self.targets) | |
129 | if len(strtargets) > 16: |
|
129 | if len(strtargets) > 16: | |
130 | strtargets = strtargets[:12]+'...]' |
|
130 | strtargets = strtargets[:12]+'...]' | |
131 | return "<%s %s>"%(self.__class__.__name__, strtargets) |
|
131 | return "<%s %s>"%(self.__class__.__name__, strtargets) | |
132 |
|
132 | |||
133 | def __len__(self): |
|
133 | def __len__(self): | |
134 | if isinstance(self.targets, list): |
|
134 | if isinstance(self.targets, list): | |
135 | return len(self.targets) |
|
135 | return len(self.targets) | |
136 | elif isinstance(self.targets, int): |
|
136 | elif isinstance(self.targets, int): | |
137 | return 1 |
|
137 | return 1 | |
138 | else: |
|
138 | else: | |
139 | return len(self.client) |
|
139 | return len(self.client) | |
140 |
|
140 | |||
141 | def set_flags(self, **kwargs): |
|
141 | def set_flags(self, **kwargs): | |
142 | """set my attribute flags by keyword. |
|
142 | """set my attribute flags by keyword. | |
143 |
|
143 | |||
144 | Views determine behavior with a few attributes (`block`, `track`, etc.). |
|
144 | Views determine behavior with a few attributes (`block`, `track`, etc.). | |
145 | These attributes can be set all at once by name with this method. |
|
145 | These attributes can be set all at once by name with this method. | |
146 |
|
146 | |||
147 | Parameters |
|
147 | Parameters | |
148 | ---------- |
|
148 | ---------- | |
149 |
|
149 | |||
150 | block : bool |
|
150 | block : bool | |
151 | whether to wait for results |
|
151 | whether to wait for results | |
152 | track : bool |
|
152 | track : bool | |
153 | whether to create a MessageTracker to allow the user to |
|
153 | whether to create a MessageTracker to allow the user to | |
154 | safely edit arrays and buffers after non-copying |
|
154 | safely edit arrays and buffers after non-copying | |
155 | sends. |
|
155 | sends. | |
156 | """ |
|
156 | """ | |
157 | for name, value in iteritems(kwargs): |
|
157 | for name, value in iteritems(kwargs): | |
158 | if name not in self._flag_names: |
|
158 | if name not in self._flag_names: | |
159 | raise KeyError("Invalid name: %r"%name) |
|
159 | raise KeyError("Invalid name: %r"%name) | |
160 | else: |
|
160 | else: | |
161 | setattr(self, name, value) |
|
161 | setattr(self, name, value) | |
162 |
|
162 | |||
163 | @contextmanager |
|
163 | @contextmanager | |
164 | def temp_flags(self, **kwargs): |
|
164 | def temp_flags(self, **kwargs): | |
165 | """temporarily set flags, for use in `with` statements. |
|
165 | """temporarily set flags, for use in `with` statements. | |
166 |
|
166 | |||
167 | See set_flags for permanent setting of flags |
|
167 | See set_flags for permanent setting of flags | |
168 |
|
168 | |||
169 | Examples |
|
169 | Examples | |
170 | -------- |
|
170 | -------- | |
171 |
|
171 | |||
172 | >>> view.track=False |
|
172 | >>> view.track=False | |
173 | ... |
|
173 | ... | |
174 | >>> with view.temp_flags(track=True): |
|
174 | >>> with view.temp_flags(track=True): | |
175 | ... ar = view.apply(dostuff, my_big_array) |
|
175 | ... ar = view.apply(dostuff, my_big_array) | |
176 | ... ar.tracker.wait() # wait for send to finish |
|
176 | ... ar.tracker.wait() # wait for send to finish | |
177 | >>> view.track |
|
177 | >>> view.track | |
178 | False |
|
178 | False | |
179 |
|
179 | |||
180 | """ |
|
180 | """ | |
181 | # preflight: save flags, and set temporaries |
|
181 | # preflight: save flags, and set temporaries | |
182 | saved_flags = {} |
|
182 | saved_flags = {} | |
183 | for f in self._flag_names: |
|
183 | for f in self._flag_names: | |
184 | saved_flags[f] = getattr(self, f) |
|
184 | saved_flags[f] = getattr(self, f) | |
185 | self.set_flags(**kwargs) |
|
185 | self.set_flags(**kwargs) | |
186 | # yield to the with-statement block |
|
186 | # yield to the with-statement block | |
187 | try: |
|
187 | try: | |
188 | yield |
|
188 | yield | |
189 | finally: |
|
189 | finally: | |
190 | # postflight: restore saved flags |
|
190 | # postflight: restore saved flags | |
191 | self.set_flags(**saved_flags) |
|
191 | self.set_flags(**saved_flags) | |
192 |
|
192 | |||
193 |
|
193 | |||
194 | #---------------------------------------------------------------- |
|
194 | #---------------------------------------------------------------- | |
195 | # apply |
|
195 | # apply | |
196 | #---------------------------------------------------------------- |
|
196 | #---------------------------------------------------------------- | |
197 |
|
197 | |||
198 | def _sync_results(self): |
|
198 | def _sync_results(self): | |
199 | """to be called by @sync_results decorator |
|
199 | """to be called by @sync_results decorator | |
200 |
|
200 | |||
201 | after submitting any tasks. |
|
201 | after submitting any tasks. | |
202 | """ |
|
202 | """ | |
203 | delta = self.outstanding.difference(self.client.outstanding) |
|
203 | delta = self.outstanding.difference(self.client.outstanding) | |
204 | completed = self.outstanding.intersection(delta) |
|
204 | completed = self.outstanding.intersection(delta) | |
205 | self.outstanding = self.outstanding.difference(completed) |
|
205 | self.outstanding = self.outstanding.difference(completed) | |
206 |
|
206 | |||
207 | @sync_results |
|
207 | @sync_results | |
208 | @save_ids |
|
208 | @save_ids | |
209 | def _really_apply(self, f, args, kwargs, block=None, **options): |
|
209 | def _really_apply(self, f, args, kwargs, block=None, **options): | |
210 | """wrapper for client.send_apply_request""" |
|
210 | """wrapper for client.send_apply_request""" | |
211 | raise NotImplementedError("Implement in subclasses") |
|
211 | raise NotImplementedError("Implement in subclasses") | |
212 |
|
212 | |||
213 | def apply(self, f, *args, **kwargs): |
|
213 | def apply(self, f, *args, **kwargs): | |
214 | """calls ``f(*args, **kwargs)`` on remote engines, returning the result. |
|
214 | """calls ``f(*args, **kwargs)`` on remote engines, returning the result. | |
215 |
|
215 | |||
216 | This method sets all apply flags via this View's attributes. |
|
216 | This method sets all apply flags via this View's attributes. | |
217 |
|
217 | |||
218 | Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` |
|
218 | Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` | |
219 | instance if ``self.block`` is False, otherwise the return value of |
|
219 | instance if ``self.block`` is False, otherwise the return value of | |
220 | ``f(*args, **kwargs)``. |
|
220 | ``f(*args, **kwargs)``. | |
221 | """ |
|
221 | """ | |
222 | return self._really_apply(f, args, kwargs) |
|
222 | return self._really_apply(f, args, kwargs) | |
223 |
|
223 | |||
224 | def apply_async(self, f, *args, **kwargs): |
|
224 | def apply_async(self, f, *args, **kwargs): | |
225 | """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner. |
|
225 | """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner. | |
226 |
|
226 | |||
227 | Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance. |
|
227 | Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance. | |
228 | """ |
|
228 | """ | |
229 | return self._really_apply(f, args, kwargs, block=False) |
|
229 | return self._really_apply(f, args, kwargs, block=False) | |
230 |
|
230 | |||
231 | @spin_after |
|
231 | @spin_after | |
232 | def apply_sync(self, f, *args, **kwargs): |
|
232 | def apply_sync(self, f, *args, **kwargs): | |
233 | """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner, |
|
233 | """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner, | |
234 | returning the result. |
|
234 | returning the result. | |
235 | """ |
|
235 | """ | |
236 | return self._really_apply(f, args, kwargs, block=True) |
|
236 | return self._really_apply(f, args, kwargs, block=True) | |
237 |
|
237 | |||
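A hedged sketch of the three apply flavors defined above, assuming a running cluster reachable with the default Client() profile:

    from ipython_parallel import Client
    rc = Client()
    view = rc[:]                        # DirectView of all engines
    ar = view.apply_async(pow, 2, 10)   # returns an AsyncResult immediately
    ar.get()                            # -> [1024, ...], one result per engine
    view.apply_sync(pow, 2, 10)         # same call, but blocks until done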
238 | #---------------------------------------------------------------- |
|
238 | #---------------------------------------------------------------- | |
239 | # wrappers for client and control methods |
|
239 | # wrappers for client and control methods | |
240 | #---------------------------------------------------------------- |
|
240 | #---------------------------------------------------------------- | |
241 | @sync_results |
|
241 | @sync_results | |
242 | def spin(self): |
|
242 | def spin(self): | |
243 | """spin the client, and sync""" |
|
243 | """spin the client, and sync""" | |
244 | self.client.spin() |
|
244 | self.client.spin() | |
245 |
|
245 | |||
246 | @sync_results |
|
246 | @sync_results | |
247 | def wait(self, jobs=None, timeout=-1): |
|
247 | def wait(self, jobs=None, timeout=-1): | |
248 | """waits on one or more `jobs`, for up to `timeout` seconds. |
|
248 | """waits on one or more `jobs`, for up to `timeout` seconds. | |
249 |
|
249 | |||
250 | Parameters |
|
250 | Parameters | |
251 | ---------- |
|
251 | ---------- | |
252 |
|
252 | |||
253 | jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects |
|
253 | jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects | |
254 | ints are indices to self.history |
|
254 | ints are indices to self.history | |
255 | strs are msg_ids |
|
255 | strs are msg_ids | |
256 | default: wait on all outstanding messages |
|
256 | default: wait on all outstanding messages | |
257 | timeout : float |
|
257 | timeout : float | |
258 | a time in seconds, after which to give up. |
|
258 | a time in seconds, after which to give up. | |
259 | default is -1, which means no timeout |
|
259 | default is -1, which means no timeout | |
260 |
|
260 | |||
261 | Returns |
|
261 | Returns | |
262 | ------- |
|
262 | ------- | |
263 |
|
263 | |||
264 | True : when all msg_ids are done |
|
264 | True : when all msg_ids are done | |
265 | False : timeout reached, some msg_ids still outstanding |
|
265 | False : timeout reached, some msg_ids still outstanding | |
266 | """ |
|
266 | """ | |
267 | if jobs is None: |
|
267 | if jobs is None: | |
268 | jobs = self.history |
|
268 | jobs = self.history | |
269 | return self.client.wait(jobs, timeout) |
|
269 | return self.client.wait(jobs, timeout) | |
270 |
|
270 | |||
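A short sketch of wait as documented above (timeout is in seconds; the default -1 means wait forever):

    view.apply_async(sum, range(10))
    finished = view.wait(timeout=5)   # True if everything outstanding completed within 5 seconds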
271 | def abort(self, jobs=None, targets=None, block=None): |
|
271 | def abort(self, jobs=None, targets=None, block=None): | |
272 | """Abort jobs on my engines. |
|
272 | """Abort jobs on my engines. | |
273 |
|
273 | |||
274 | Parameters |
|
274 | Parameters | |
275 | ---------- |
|
275 | ---------- | |
276 |
|
276 | |||
277 | jobs : None, str, list of strs, optional |
|
277 | jobs : None, str, list of strs, optional | |
278 | if None: abort all jobs. |
|
278 | if None: abort all jobs. | |
279 | else: abort specific msg_id(s). |
|
279 | else: abort specific msg_id(s). | |
280 | """ |
|
280 | """ | |
281 | block = block if block is not None else self.block |
|
281 | block = block if block is not None else self.block | |
282 | targets = targets if targets is not None else self.targets |
|
282 | targets = targets if targets is not None else self.targets | |
283 | jobs = jobs if jobs is not None else list(self.outstanding) |
|
283 | jobs = jobs if jobs is not None else list(self.outstanding) | |
284 |
|
284 | |||
285 | return self.client.abort(jobs=jobs, targets=targets, block=block) |
|
285 | return self.client.abort(jobs=jobs, targets=targets, block=block) | |
286 |
|
286 | |||
287 | def queue_status(self, targets=None, verbose=False): |
|
287 | def queue_status(self, targets=None, verbose=False): | |
288 | """Fetch the Queue status of my engines""" |
|
288 | """Fetch the Queue status of my engines""" | |
289 | targets = targets if targets is not None else self.targets |
|
289 | targets = targets if targets is not None else self.targets | |
290 | return self.client.queue_status(targets=targets, verbose=verbose) |
|
290 | return self.client.queue_status(targets=targets, verbose=verbose) | |
291 |
|
291 | |||
292 | def purge_results(self, jobs=[], targets=[]): |
|
292 | def purge_results(self, jobs=[], targets=[]): | |
293 | """Instruct the controller to forget specific results.""" |
|
293 | """Instruct the controller to forget specific results.""" | |
294 | if targets is None or targets == 'all': |
|
294 | if targets is None or targets == 'all': | |
295 | targets = self.targets |
|
295 | targets = self.targets | |
296 | return self.client.purge_results(jobs=jobs, targets=targets) |
|
296 | return self.client.purge_results(jobs=jobs, targets=targets) | |
297 |
|
297 | |||
298 | def shutdown(self, targets=None, restart=False, hub=False, block=None): |
|
298 | def shutdown(self, targets=None, restart=False, hub=False, block=None): | |
299 | """Terminates one or more engine processes, optionally including the hub. |
|
299 | """Terminates one or more engine processes, optionally including the hub. | |
300 | """ |
|
300 | """ | |
301 | block = self.block if block is None else block |
|
301 | block = self.block if block is None else block | |
302 | if targets is None or targets == 'all': |
|
302 | if targets is None or targets == 'all': | |
303 | targets = self.targets |
|
303 | targets = self.targets | |
304 | return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block) |
|
304 | return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block) | |
305 |
|
305 | |||
306 | @spin_after |
|
306 | @spin_after | |
307 | def get_result(self, indices_or_msg_ids=None, block=None, owner=True): |
|
307 | def get_result(self, indices_or_msg_ids=None, block=None, owner=True): | |
308 | """return one or more results, specified by history index or msg_id. |
|
308 | """return one or more results, specified by history index or msg_id. | |
309 |
|
309 | |||
310 | See :meth:`IPython.parallel.client.client.Client.get_result` for details. |
|
310 | See :meth:`IPython.parallel.client.client.Client.get_result` for details. | |
311 | """ |
|
311 | """ | |
312 |
|
312 | |||
313 | if indices_or_msg_ids is None: |
|
313 | if indices_or_msg_ids is None: | |
314 | indices_or_msg_ids = -1 |
|
314 | indices_or_msg_ids = -1 | |
315 | if isinstance(indices_or_msg_ids, int): |
|
315 | if isinstance(indices_or_msg_ids, int): | |
316 | indices_or_msg_ids = self.history[indices_or_msg_ids] |
|
316 | indices_or_msg_ids = self.history[indices_or_msg_ids] | |
317 | elif isinstance(indices_or_msg_ids, (list,tuple,set)): |
|
317 | elif isinstance(indices_or_msg_ids, (list,tuple,set)): | |
318 | indices_or_msg_ids = list(indices_or_msg_ids) |
|
318 | indices_or_msg_ids = list(indices_or_msg_ids) | |
319 | for i,index in enumerate(indices_or_msg_ids): |
|
319 | for i,index in enumerate(indices_or_msg_ids): | |
320 | if isinstance(index, int): |
|
320 | if isinstance(index, int): | |
321 | indices_or_msg_ids[i] = self.history[index] |
|
321 | indices_or_msg_ids[i] = self.history[index] | |
322 | return self.client.get_result(indices_or_msg_ids, block=block, owner=owner) |
|
322 | return self.client.get_result(indices_or_msg_ids, block=block, owner=owner) | |
323 |
|
323 | |||
324 | #------------------------------------------------------------------- |
|
324 | #------------------------------------------------------------------- | |
325 | # Map |
|
325 | # Map | |
326 | #------------------------------------------------------------------- |
|
326 | #------------------------------------------------------------------- | |
327 |
|
327 | |||
328 | @sync_results |
|
328 | @sync_results | |
329 | def map(self, f, *sequences, **kwargs): |
|
329 | def map(self, f, *sequences, **kwargs): | |
330 | """override in subclasses""" |
|
330 | """override in subclasses""" | |
331 | raise NotImplementedError |
|
331 | raise NotImplementedError | |
332 |
|
332 | |||
333 | def map_async(self, f, *sequences, **kwargs): |
|
333 | def map_async(self, f, *sequences, **kwargs): | |
334 | """Parallel version of builtin :func:`python:map`, using this view's engines. |
|
334 | """Parallel version of builtin :func:`python:map`, using this view's engines. | |
335 |
|
335 | |||
336 | This is equivalent to ``map(...block=False)``. |
|
336 | This is equivalent to ``map(...block=False)``. | |
337 |
|
337 | |||
338 | See `self.map` for details. |
|
338 | See `self.map` for details. | |
339 | """ |
|
339 | """ | |
340 | if 'block' in kwargs: |
|
340 | if 'block' in kwargs: | |
341 | raise TypeError("map_async doesn't take a `block` keyword argument.") |
|
341 | raise TypeError("map_async doesn't take a `block` keyword argument.") | |
342 | kwargs['block'] = False |
|
342 | kwargs['block'] = False | |
343 | return self.map(f,*sequences,**kwargs) |
|
343 | return self.map(f,*sequences,**kwargs) | |
344 |
|
344 | |||
345 | def map_sync(self, f, *sequences, **kwargs): |
|
345 | def map_sync(self, f, *sequences, **kwargs): | |
346 | """Parallel version of builtin :func:`python:map`, using this view's engines. |
|
346 | """Parallel version of builtin :func:`python:map`, using this view's engines. | |
347 |
|
347 | |||
348 | This is equivalent to ``map(...block=True)``. |
|
348 | This is equivalent to ``map(...block=True)``. | |
349 |
|
349 | |||
350 | See `self.map` for details. |
|
350 | See `self.map` for details. | |
351 | """ |
|
351 | """ | |
352 | if 'block' in kwargs: |
|
352 | if 'block' in kwargs: | |
353 | raise TypeError("map_sync doesn't take a `block` keyword argument.") |
|
353 | raise TypeError("map_sync doesn't take a `block` keyword argument.") | |
354 | kwargs['block'] = True |
|
354 | kwargs['block'] = True | |
355 | return self.map(f,*sequences,**kwargs) |
|
355 | return self.map(f,*sequences,**kwargs) | |
356 |
|
356 | |||
357 | def imap(self, f, *sequences, **kwargs): |
|
357 | def imap(self, f, *sequences, **kwargs): | |
358 | """Parallel version of :func:`itertools.imap`. |
|
358 | """Parallel version of :func:`itertools.imap`. | |
359 |
|
359 | |||
360 | See `self.map` for details. |
|
360 | See `self.map` for details. | |
361 |
|
361 | |||
362 | """ |
|
362 | """ | |
363 |
|
363 | |||
364 | return iter(self.map_async(f,*sequences, **kwargs)) |
|
364 | return iter(self.map_async(f,*sequences, **kwargs)) | |
365 |
|
365 | |||
366 | #------------------------------------------------------------------- |
|
366 | #------------------------------------------------------------------- | |
367 | # Decorators |
|
367 | # Decorators | |
368 | #------------------------------------------------------------------- |
|
368 | #------------------------------------------------------------------- | |
369 |
|
369 | |||
370 | def remote(self, block=None, **flags): |
|
370 | def remote(self, block=None, **flags): | |
371 | """Decorator for making a RemoteFunction""" |
|
371 | """Decorator for making a RemoteFunction""" | |
372 | block = self.block if block is None else block |
|
372 | block = self.block if block is None else block | |
373 | return remote(self, block=block, **flags) |
|
373 | return remote(self, block=block, **flags) | |
374 |
|
374 | |||
375 | def parallel(self, dist='b', block=None, **flags): |
|
375 | def parallel(self, dist='b', block=None, **flags): | |
376 | """Decorator for making a ParallelFunction""" |
|
376 | """Decorator for making a ParallelFunction""" | |
377 | block = self.block if block is None else block |
|
377 | block = self.block if block is None else block | |
378 | return parallel(self, dist=dist, block=block, **flags) |
|
378 | return parallel(self, dist=dist, block=block, **flags) | |
379 |
|
379 | |||
380 | @skip_doctest |
|
380 | @skip_doctest | |
381 | class DirectView(View): |
|
381 | class DirectView(View): | |
382 | """Direct Multiplexer View of one or more engines. |
|
382 | """Direct Multiplexer View of one or more engines. | |
383 |
|
383 | |||
384 | These are created via indexed access to a client: |
|
384 | These are created via indexed access to a client: | |
385 |
|
385 | |||
386 | >>> dv_1 = client[1] |
|
386 | >>> dv_1 = client[1] | |
387 | >>> dv_all = client[:] |
|
387 | >>> dv_all = client[:] | |
388 | >>> dv_even = client[::2] |
|
388 | >>> dv_even = client[::2] | |
389 | >>> dv_some = client[1:3] |
|
389 | >>> dv_some = client[1:3] | |
390 |
|
390 | |||
391 | This object provides dictionary access to engine namespaces: |
|
391 | This object provides dictionary access to engine namespaces: | |
392 |
|
392 | |||
393 | # push a=5: |
|
393 | # push a=5: | |
394 | >>> dv['a'] = 5 |
|
394 | >>> dv['a'] = 5 | |
395 | # pull 'foo': |
|
395 | # pull 'foo': | |
396 | >>> dv['foo'] |
|
396 | >>> dv['foo'] | |
397 |
|
397 | |||
398 | """ |
|
398 | """ | |
399 |
|
399 | |||
400 | def __init__(self, client=None, socket=None, targets=None): |
|
400 | def __init__(self, client=None, socket=None, targets=None): | |
401 | super(DirectView, self).__init__(client=client, socket=socket, targets=targets) |
|
401 | super(DirectView, self).__init__(client=client, socket=socket, targets=targets) | |
402 |
|
402 | |||
403 | @property |
|
403 | @property | |
404 | def importer(self): |
|
404 | def importer(self): | |
405 | """sync_imports(local=True) as a property. |
|
405 | """sync_imports(local=True) as a property. | |
406 |
|
406 | |||
407 | See sync_imports for details. |
|
407 | See sync_imports for details. | |
408 |
|
408 | |||
409 | """ |
|
409 | """ | |
410 | return self.sync_imports(True) |
|
410 | return self.sync_imports(True) | |
411 |
|
411 | |||
412 | @contextmanager |
|
412 | @contextmanager | |
413 | def sync_imports(self, local=True, quiet=False): |
|
413 | def sync_imports(self, local=True, quiet=False): | |
414 | """Context Manager for performing simultaneous local and remote imports. |
|
414 | """Context Manager for performing simultaneous local and remote imports. | |
415 |
|
415 | |||
416 | 'import x as y' will *not* work. The 'as y' part will simply be ignored. |
|
416 | 'import x as y' will *not* work. The 'as y' part will simply be ignored. | |
417 |
|
417 | |||
418 | If `local=True`, then the package will also be imported locally. |
|
418 | If `local=True`, then the package will also be imported locally. | |
419 |
|
419 | |||
420 | If `quiet=True`, no output will be produced when attempting remote |
|
420 | If `quiet=True`, no output will be produced when attempting remote | |
421 | imports. |
|
421 | imports. | |
422 |
|
422 | |||
423 | Note that remote-only (`local=False`) imports have not been implemented. |
|
423 | Note that remote-only (`local=False`) imports have not been implemented. | |
424 |
|
424 | |||
425 | >>> with view.sync_imports(): |
|
425 | >>> with view.sync_imports(): | |
426 | ... from numpy import recarray |
|
426 | ... from numpy import recarray | |
427 | importing recarray from numpy on engine(s) |
|
427 | importing recarray from numpy on engine(s) | |
428 |
|
428 | |||
429 | """ |
|
429 | """ | |
430 | from IPython.utils.py3compat import builtin_mod |
|
430 | from IPython.utils.py3compat import builtin_mod | |
431 | local_import = builtin_mod.__import__ |
|
431 | local_import = builtin_mod.__import__ | |
432 | modules = set() |
|
432 | modules = set() | |
433 | results = [] |
|
433 | results = [] | |
434 | @util.interactive |
|
434 | @util.interactive | |
435 | def remote_import(name, fromlist, level): |
|
435 | def remote_import(name, fromlist, level): | |
436 | """the function to be passed to apply, that actually performs the import |
|
436 | """the function to be passed to apply, that actually performs the import | |
437 | on the engine, and loads up the user namespace. |
|
437 | on the engine, and loads up the user namespace. | |
438 | """ |
|
438 | """ | |
439 | import sys |
|
439 | import sys | |
440 | user_ns = globals() |
|
440 | user_ns = globals() | |
441 | mod = __import__(name, fromlist=fromlist, level=level) |
|
441 | mod = __import__(name, fromlist=fromlist, level=level) | |
442 | if fromlist: |
|
442 | if fromlist: | |
443 | for key in fromlist: |
|
443 | for key in fromlist: | |
444 | user_ns[key] = getattr(mod, key) |
|
444 | user_ns[key] = getattr(mod, key) | |
445 | else: |
|
445 | else: | |
446 | user_ns[name] = sys.modules[name] |
|
446 | user_ns[name] = sys.modules[name] | |
447 |
|
447 | |||
448 | def view_import(name, globals={}, locals={}, fromlist=[], level=0): |
|
448 | def view_import(name, globals={}, locals={}, fromlist=[], level=0): | |
449 | """the drop-in replacement for __import__, that optionally imports |
|
449 | """the drop-in replacement for __import__, that optionally imports | |
450 | locally as well. |
|
450 | locally as well. | |
451 | """ |
|
451 | """ | |
452 | # don't override nested imports |
|
452 | # don't override nested imports | |
453 | save_import = builtin_mod.__import__ |
|
453 | save_import = builtin_mod.__import__ | |
454 | builtin_mod.__import__ = local_import |
|
454 | builtin_mod.__import__ = local_import | |
455 |
|
455 | |||
456 | if imp.lock_held(): |
|
456 | if imp.lock_held(): | |
457 | # this is a side-effect import, don't do it remotely, or even |
|
457 | # this is a side-effect import, don't do it remotely, or even | |
458 | # ignore the local effects |
|
458 | # ignore the local effects | |
459 | return local_import(name, globals, locals, fromlist, level) |
|
459 | return local_import(name, globals, locals, fromlist, level) | |
460 |
|
460 | |||
461 | imp.acquire_lock() |
|
461 | imp.acquire_lock() | |
462 | if local: |
|
462 | if local: | |
463 | mod = local_import(name, globals, locals, fromlist, level) |
|
463 | mod = local_import(name, globals, locals, fromlist, level) | |
464 | else: |
|
464 | else: | |
465 | raise NotImplementedError("remote-only imports not yet implemented") |
|
465 | raise NotImplementedError("remote-only imports not yet implemented") | |
466 | imp.release_lock() |
|
466 | imp.release_lock() | |
467 |
|
467 | |||
468 | key = name+':'+','.join(fromlist or []) |
|
468 | key = name+':'+','.join(fromlist or []) | |
469 | if level <= 0 and key not in modules: |
|
469 | if level <= 0 and key not in modules: | |
470 | modules.add(key) |
|
470 | modules.add(key) | |
471 | if not quiet: |
|
471 | if not quiet: | |
472 | if fromlist: |
|
472 | if fromlist: | |
473 | print("importing %s from %s on engine(s)"%(','.join(fromlist), name)) |
|
473 | print("importing %s from %s on engine(s)"%(','.join(fromlist), name)) | |
474 | else: |
|
474 | else: | |
475 | print("importing %s on engine(s)"%name) |
|
475 | print("importing %s on engine(s)"%name) | |
476 | results.append(self.apply_async(remote_import, name, fromlist, level)) |
|
476 | results.append(self.apply_async(remote_import, name, fromlist, level)) | |
477 | # restore override |
|
477 | # restore override | |
478 | builtin_mod.__import__ = save_import |
|
478 | builtin_mod.__import__ = save_import | |
479 |
|
479 | |||
480 | return mod |
|
480 | return mod | |
481 |
|
481 | |||
482 | # override __import__ |
|
482 | # override __import__ | |
483 | builtin_mod.__import__ = view_import |
|
483 | builtin_mod.__import__ = view_import | |
484 | try: |
|
484 | try: | |
485 | # enter the block |
|
485 | # enter the block | |
486 | yield |
|
486 | yield | |
487 | except ImportError: |
|
487 | except ImportError: | |
488 | if local: |
|
488 | if local: | |
489 | raise |
|
489 | raise | |
490 | else: |
|
490 | else: | |
491 | # ignore import errors if not doing local imports |
|
491 | # ignore import errors if not doing local imports | |
492 | pass |
|
492 | pass | |
493 | finally: |
|
493 | finally: | |
494 | # always restore __import__ |
|
494 | # always restore __import__ | |
495 | builtin_mod.__import__ = local_import |
|
495 | builtin_mod.__import__ = local_import | |
496 |
|
496 | |||
497 | for r in results: |
|
497 | for r in results: | |
498 | # raise possible remote ImportErrors here |
|
498 | # raise possible remote ImportErrors here | |
499 | r.get() |
|
499 | r.get() | |
500 |
|
500 | |||
501 | def use_dill(self): |
|
501 | def use_dill(self): | |
502 | """Expand serialization support with dill |
|
502 | """Expand serialization support with dill | |
503 |
|
503 | |||
504 | adds support for closures, etc. |
|
504 | adds support for closures, etc. | |
505 |
|
505 | |||
506 | This calls |
|
506 | This calls ipython_kernel.pickleutil.use_dill() here and on each engine. | |
507 | """ |
|
507 | """ | |
508 | pickleutil.use_dill() |
|
508 | pickleutil.use_dill() | |
509 | return self.apply(pickleutil.use_dill) |
|
509 | return self.apply(pickleutil.use_dill) | |
510 |
|
510 | |||
511 | def use_cloudpickle(self): |
|
511 | def use_cloudpickle(self): | |
512 | """Expand serialization support with cloudpickle. |
|
512 | """Expand serialization support with cloudpickle. | |
513 | """ |
|
513 | """ | |
514 | pickleutil.use_cloudpickle() |
|
514 | pickleutil.use_cloudpickle() | |
515 | return self.apply(pickleutil.use_cloudpickle) |
|
515 | return self.apply(pickleutil.use_cloudpickle) | |
516 |
|
516 | |||
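A minimal sketch of enabling the extended picklers, assuming dill (or cloudpickle) is installed both locally and on every engine:

    view.use_dill()              # pickleutil.use_dill() locally, then the same call on each engine
    # view.use_cloudpickle()     # analogous, using cloudpickle instead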
517 |
|
517 | |||
518 | @sync_results |
|
518 | @sync_results | |
519 | @save_ids |
|
519 | @save_ids | |
520 | def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None): |
|
520 | def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None): | |
521 | """calls f(*args, **kwargs) on remote engines, returning the result. |
|
521 | """calls f(*args, **kwargs) on remote engines, returning the result. | |
522 |
|
522 | |||
523 | This method sets all of `apply`'s flags via this View's attributes. |
|
523 | This method sets all of `apply`'s flags via this View's attributes. | |
524 |
|
524 | |||
525 | Parameters |
|
525 | Parameters | |
526 | ---------- |
|
526 | ---------- | |
527 |
|
527 | |||
528 | f : callable |
|
528 | f : callable | |
529 |
|
529 | |||
530 | args : list [default: empty] |
|
530 | args : list [default: empty] | |
531 |
|
531 | |||
532 | kwargs : dict [default: empty] |
|
532 | kwargs : dict [default: empty] | |
533 |
|
533 | |||
534 | targets : target list [default: self.targets] |
|
534 | targets : target list [default: self.targets] | |
535 | where to run |
|
535 | where to run | |
536 | block : bool [default: self.block] |
|
536 | block : bool [default: self.block] | |
537 | whether to block |
|
537 | whether to block | |
538 | track : bool [default: self.track] |
|
538 | track : bool [default: self.track] | |
539 | whether to ask zmq to track the message, for safe non-copying sends |
|
539 | whether to ask zmq to track the message, for safe non-copying sends | |
540 |
|
540 | |||
541 | Returns |
|
541 | Returns | |
542 | ------- |
|
542 | ------- | |
543 |
|
543 | |||
544 | if self.block is False: |
|
544 | if self.block is False: | |
545 | returns AsyncResult |
|
545 | returns AsyncResult | |
546 | else: |
|
546 | else: | |
547 | returns actual result of f(*args, **kwargs) on the engine(s) |
|
547 | returns actual result of f(*args, **kwargs) on the engine(s) | |
548 | This will be a list if self.targets is also a list (even length 1), or |
|
548 | This will be a list if self.targets is also a list (even length 1), or | |
549 | the single result if self.targets is an integer engine id |
|
549 | the single result if self.targets is an integer engine id | |
550 | """ |
|
550 | """ | |
551 | args = [] if args is None else args |
|
551 | args = [] if args is None else args | |
552 | kwargs = {} if kwargs is None else kwargs |
|
552 | kwargs = {} if kwargs is None else kwargs | |
553 | block = self.block if block is None else block |
|
553 | block = self.block if block is None else block | |
554 | track = self.track if track is None else track |
|
554 | track = self.track if track is None else track | |
555 | targets = self.targets if targets is None else targets |
|
555 | targets = self.targets if targets is None else targets | |
556 |
|
556 | |||
557 | _idents, _targets = self.client._build_targets(targets) |
|
557 | _idents, _targets = self.client._build_targets(targets) | |
558 | msg_ids = [] |
|
558 | msg_ids = [] | |
559 | trackers = [] |
|
559 | trackers = [] | |
560 | for ident in _idents: |
|
560 | for ident in _idents: | |
561 | msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track, |
|
561 | msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track, | |
562 | ident=ident) |
|
562 | ident=ident) | |
563 | if track: |
|
563 | if track: | |
564 | trackers.append(msg['tracker']) |
|
564 | trackers.append(msg['tracker']) | |
565 | msg_ids.append(msg['header']['msg_id']) |
|
565 | msg_ids.append(msg['header']['msg_id']) | |
566 | if isinstance(targets, int): |
|
566 | if isinstance(targets, int): | |
567 | msg_ids = msg_ids[0] |
|
567 | msg_ids = msg_ids[0] | |
568 | tracker = None if track is False else zmq.MessageTracker(*trackers) |
|
568 | tracker = None if track is False else zmq.MessageTracker(*trackers) | |
569 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, |
|
569 | ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets, | |
570 | tracker=tracker, owner=True, |
|
570 | tracker=tracker, owner=True, | |
571 | ) |
|
571 | ) | |
572 | if block: |
|
572 | if block: | |
573 | try: |
|
573 | try: | |
574 | return ar.get() |
|
574 | return ar.get() | |
575 | except KeyboardInterrupt: |
|
575 | except KeyboardInterrupt: | |
576 | pass |
|
576 | pass | |
577 | return ar |
|
577 | return ar | |
578 |
|
578 | |||
579 |
|
579 | |||
580 | @sync_results |
|
580 | @sync_results | |
581 | def map(self, f, *sequences, **kwargs): |
|
581 | def map(self, f, *sequences, **kwargs): | |
582 | """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult |
|
582 | """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult | |
583 |
|
583 | |||
584 | Parallel version of builtin `map`, using this View's `targets`. |
|
584 | Parallel version of builtin `map`, using this View's `targets`. | |
585 |
|
585 | |||
586 | There will be one task per target, so work will be chunked |
|
586 | There will be one task per target, so work will be chunked | |
587 | if the sequences are longer than `targets`. |
|
587 | if the sequences are longer than `targets`. | |
588 |
|
588 | |||
589 | Results can be iterated as they are ready, but will become available in chunks. |
|
589 | Results can be iterated as they are ready, but will become available in chunks. | |
590 |
|
590 | |||
591 | Parameters |
|
591 | Parameters | |
592 | ---------- |
|
592 | ---------- | |
593 |
|
593 | |||
594 | f : callable |
|
594 | f : callable | |
595 | function to be mapped |
|
595 | function to be mapped | |
596 | *sequences: one or more sequences of matching length |
|
596 | *sequences: one or more sequences of matching length | |
597 | the sequences to be distributed and passed to `f` |
|
597 | the sequences to be distributed and passed to `f` | |
598 | block : bool |
|
598 | block : bool | |
599 | whether to wait for the result or not [default self.block] |
|
599 | whether to wait for the result or not [default self.block] | |
600 |
|
600 | |||
601 | Returns |
|
601 | Returns | |
602 | ------- |
|
602 | ------- | |
603 |
|
603 | |||
604 |
|
604 | |||
605 | If block=False |
|
605 | If block=False | |
606 | An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance. |
|
606 | An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance. | |
607 | An object like AsyncResult, but which reassembles the sequence of results |
|
607 | An object like AsyncResult, but which reassembles the sequence of results | |
608 | into a single list. AsyncMapResults can be iterated through before all |
|
608 | into a single list. AsyncMapResults can be iterated through before all | |
609 | results are complete. |
|
609 | results are complete. | |
610 | else |
|
610 | else | |
611 | A list, the result of ``map(f,*sequences)`` |
|
611 | A list, the result of ``map(f,*sequences)`` | |
612 | """ |
|
612 | """ | |
613 |
|
613 | |||
614 | block = kwargs.pop('block', self.block) |
|
614 | block = kwargs.pop('block', self.block) | |
615 | for k in kwargs.keys(): |
|
615 | for k in kwargs.keys(): | |
616 | if k not in ['block', 'track']: |
|
616 | if k not in ['block', 'track']: | |
617 | raise TypeError("invalid keyword arg, %r"%k) |
|
617 | raise TypeError("invalid keyword arg, %r"%k) | |
618 |
|
618 | |||
619 | assert len(sequences) > 0, "must have some sequences to map onto!" |
|
619 | assert len(sequences) > 0, "must have some sequences to map onto!" | |
620 | pf = ParallelFunction(self, f, block=block, **kwargs) |
|
620 | pf = ParallelFunction(self, f, block=block, **kwargs) | |
621 | return pf.map(*sequences) |
|
621 | return pf.map(*sequences) | |
622 |
|
622 | |||
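A hedged sketch of the chunked map described above; results become available in chunks while iterating:

    amr = view.map_async(lambda x: x * x, range(32))      # one chunk of work per engine
    for result in amr:                                     # iterate results as chunks complete
        pass
    squares = view.map_sync(lambda x: x * x, range(32))   # blocking variant, returns a flat list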
623 | @sync_results |
|
623 | @sync_results | |
624 | @save_ids |
|
624 | @save_ids | |
625 | def execute(self, code, silent=True, targets=None, block=None): |
|
625 | def execute(self, code, silent=True, targets=None, block=None): | |
626 | """Executes `code` on `targets` in blocking or nonblocking manner. |
|
626 | """Executes `code` on `targets` in blocking or nonblocking manner. | |
627 |
|
627 | |||
628 | ``execute`` is always `bound` (affects engine namespace) |
|
628 | ``execute`` is always `bound` (affects engine namespace) | |
629 |
|
629 | |||
630 | Parameters |
|
630 | Parameters | |
631 | ---------- |
|
631 | ---------- | |
632 |
|
632 | |||
633 | code : str |
|
633 | code : str | |
634 | the code string to be executed |
|
634 | the code string to be executed | |
635 | block : bool |
|
635 | block : bool | |
636 | whether or not to wait until done to return |
|
636 | whether or not to wait until done to return | |
637 | default: self.block |
|
637 | default: self.block | |
638 | """ |
|
638 | """ | |
639 | block = self.block if block is None else block |
|
639 | block = self.block if block is None else block | |
640 | targets = self.targets if targets is None else targets |
|
640 | targets = self.targets if targets is None else targets | |
641 |
|
641 | |||
642 | _idents, _targets = self.client._build_targets(targets) |
|
642 | _idents, _targets = self.client._build_targets(targets) | |
643 | msg_ids = [] |
|
643 | msg_ids = [] | |
644 | trackers = [] |
|
644 | trackers = [] | |
645 | for ident in _idents: |
|
645 | for ident in _idents: | |
646 | msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident) |
|
646 | msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident) | |
647 | msg_ids.append(msg['header']['msg_id']) |
|
647 | msg_ids.append(msg['header']['msg_id']) | |
648 | if isinstance(targets, int): |
|
648 | if isinstance(targets, int): | |
649 | msg_ids = msg_ids[0] |
|
649 | msg_ids = msg_ids[0] | |
650 | ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True) |
|
650 | ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True) | |
651 | if block: |
|
651 | if block: | |
652 | try: |
|
652 | try: | |
653 | ar.get() |
|
653 | ar.get() | |
654 | except KeyboardInterrupt: |
|
654 | except KeyboardInterrupt: | |
655 | pass |
|
655 | pass | |
656 | return ar |
|
656 | return ar | |
657 |
|
657 | |||
658 | def run(self, filename, targets=None, block=None): |
|
658 | def run(self, filename, targets=None, block=None): | |
659 | """Execute contents of `filename` on my engine(s). |
|
659 | """Execute contents of `filename` on my engine(s). | |
660 |
|
660 | |||
661 | This simply reads the contents of the file and calls `execute`. |
|
661 | This simply reads the contents of the file and calls `execute`. | |
662 |
|
662 | |||
663 | Parameters |
|
663 | Parameters | |
664 | ---------- |
|
664 | ---------- | |
665 |
|
665 | |||
666 | filename : str |
|
666 | filename : str | |
667 | The path to the file |
|
667 | The path to the file | |
668 | targets : int/str/list of ints/strs |
|
668 | targets : int/str/list of ints/strs | |
669 | the engines on which to execute |
|
669 | the engines on which to execute | |
670 | default : all |
|
670 | default : all | |
671 | block : bool |
|
671 | block : bool | |
672 | whether or not to wait until done |
|
672 | whether or not to wait until done | |
673 | default: self.block |
|
673 | default: self.block | |
674 |
|
674 | |||
675 | """ |
|
675 | """ | |
676 | with open(filename, 'r') as f: |
|
676 | with open(filename, 'r') as f: | |
677 | # add newline in case of trailing indented whitespace |
|
677 | # add newline in case of trailing indented whitespace | |
678 | # which will cause SyntaxError |
|
678 | # which will cause SyntaxError | |
679 | code = f.read()+'\n' |
|
679 | code = f.read()+'\n' | |
680 | return self.execute(code, block=block, targets=targets) |
|
680 | return self.execute(code, block=block, targets=targets) | |
681 |
|
681 | |||
682 | def update(self, ns): |
|
682 | def update(self, ns): | |
683 | """update remote namespace with dict `ns` |
|
683 | """update remote namespace with dict `ns` | |
684 |
|
684 | |||
685 | See `push` for details. |
|
685 | See `push` for details. | |
686 | """ |
|
686 | """ | |
687 | return self.push(ns, block=self.block, track=self.track) |
|
687 | return self.push(ns, block=self.block, track=self.track) | |
688 |
|
688 | |||
689 | def push(self, ns, targets=None, block=None, track=None): |
|
689 | def push(self, ns, targets=None, block=None, track=None): | |
690 | """update remote namespace with dict `ns` |
|
690 | """update remote namespace with dict `ns` | |
691 |
|
691 | |||
692 | Parameters |
|
692 | Parameters | |
693 | ---------- |
|
693 | ---------- | |
694 |
|
694 | |||
695 | ns : dict |
|
695 | ns : dict | |
696 | dict of keys with which to update engine namespace(s) |
|
696 | dict of keys with which to update engine namespace(s) | |
697 | block : bool [default : self.block] |
|
697 | block : bool [default : self.block] | |
698 | whether to wait to be notified of engine receipt |
|
698 | whether to wait to be notified of engine receipt | |
699 |
|
699 | |||
700 | """ |
|
700 | """ | |
701 |
|
701 | |||
702 | block = block if block is not None else self.block |
|
702 | block = block if block is not None else self.block | |
703 | track = track if track is not None else self.track |
|
703 | track = track if track is not None else self.track | |
704 | targets = targets if targets is not None else self.targets |
|
704 | targets = targets if targets is not None else self.targets | |
705 | # applier = self.apply_sync if block else self.apply_async |
|
705 | # applier = self.apply_sync if block else self.apply_async | |
706 | if not isinstance(ns, dict): |
|
706 | if not isinstance(ns, dict): | |
707 | raise TypeError("Must be a dict, not %s"%type(ns)) |
|
707 | raise TypeError("Must be a dict, not %s"%type(ns)) | |
708 | return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets) |
|
708 | return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets) | |
709 |
|
709 | |||
710 | def get(self, key_s): |
|
710 | def get(self, key_s): | |
711 | """get object(s) by `key_s` from remote namespace |
|
711 | """get object(s) by `key_s` from remote namespace | |
712 |
|
712 | |||
713 | see `pull` for details. |
|
713 | see `pull` for details. | |
714 | """ |
|
714 | """ | |
715 | # block = block if block is not None else self.block |
|
715 | # block = block if block is not None else self.block | |
716 | return self.pull(key_s, block=True) |
|
716 | return self.pull(key_s, block=True) | |
717 |
|
717 | |||
718 | def pull(self, names, targets=None, block=None): |
|
718 | def pull(self, names, targets=None, block=None): | |
719 | """get object(s) by `name` from remote namespace |
|
719 | """get object(s) by `name` from remote namespace | |
720 |
|
720 | |||
721 | will return one object if it is a key. |
|
721 | will return one object if it is a key. | |
722 | can also take a list of keys, in which case it will return a list of objects. |
|
722 | can also take a list of keys, in which case it will return a list of objects. | |
723 | """ |
|
723 | """ | |
724 | block = block if block is not None else self.block |
|
724 | block = block if block is not None else self.block | |
725 | targets = targets if targets is not None else self.targets |
|
725 | targets = targets if targets is not None else self.targets | |
726 | applier = self.apply_sync if block else self.apply_async |
|
726 | applier = self.apply_sync if block else self.apply_async | |
727 | if isinstance(names, string_types): |
|
727 | if isinstance(names, string_types): | |
728 | pass |
|
728 | pass | |
729 | elif isinstance(names, (list,tuple,set)): |
|
729 | elif isinstance(names, (list,tuple,set)): | |
730 | for key in names: |
|
730 | for key in names: | |
731 | if not isinstance(key, string_types): |
|
731 | if not isinstance(key, string_types): | |
732 | raise TypeError("keys must be str, not type %r"%type(key)) |
|
732 | raise TypeError("keys must be str, not type %r"%type(key)) | |
733 | else: |
|
733 | else: | |
734 | raise TypeError("names must be strs, not %r"%names) |
|
734 | raise TypeError("names must be strs, not %r"%names) | |
735 | return self._really_apply(util._pull, (names,), block=block, targets=targets) |
|
735 | return self._really_apply(util._pull, (names,), block=block, targets=targets) | |
736 |
|
736 | |||
737 | def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None): |
|
737 | def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None): | |
738 | """ |
|
738 | """ | |
739 | Partition a Python sequence and send the partitions to a set of engines. |
|
739 | Partition a Python sequence and send the partitions to a set of engines. | |
740 | """ |
|
740 | """ | |
741 | block = block if block is not None else self.block |
|
741 | block = block if block is not None else self.block | |
742 | track = track if track is not None else self.track |
|
742 | track = track if track is not None else self.track | |
743 | targets = targets if targets is not None else self.targets |
|
743 | targets = targets if targets is not None else self.targets | |
744 |
|
744 | |||
745 | # construct integer ID list: |
|
745 | # construct integer ID list: | |
746 | targets = self.client._build_targets(targets)[1] |
|
746 | targets = self.client._build_targets(targets)[1] | |
747 |
|
747 | |||
748 | mapObject = Map.dists[dist]() |
|
748 | mapObject = Map.dists[dist]() | |
749 | nparts = len(targets) |
|
749 | nparts = len(targets) | |
750 | msg_ids = [] |
|
750 | msg_ids = [] | |
751 | trackers = [] |
|
751 | trackers = [] | |
752 | for index, engineid in enumerate(targets): |
|
752 | for index, engineid in enumerate(targets): | |
753 | partition = mapObject.getPartition(seq, index, nparts) |
|
753 | partition = mapObject.getPartition(seq, index, nparts) | |
754 | if flatten and len(partition) == 1: |
|
754 | if flatten and len(partition) == 1: | |
755 | ns = {key: partition[0]} |
|
755 | ns = {key: partition[0]} | |
756 | else: |
|
756 | else: | |
757 | ns = {key: partition} |
|
757 | ns = {key: partition} | |
758 | r = self.push(ns, block=False, track=track, targets=engineid) |
|
758 | r = self.push(ns, block=False, track=track, targets=engineid) | |
759 | msg_ids.extend(r.msg_ids) |
|
759 | msg_ids.extend(r.msg_ids) | |
760 | if track: |
|
760 | if track: | |
761 | trackers.append(r._tracker) |
|
761 | trackers.append(r._tracker) | |
762 |
|
762 | |||
763 | if track: |
|
763 | if track: | |
764 | tracker = zmq.MessageTracker(*trackers) |
|
764 | tracker = zmq.MessageTracker(*trackers) | |
765 | else: |
|
765 | else: | |
766 | tracker = None |
|
766 | tracker = None | |
767 |
|
767 | |||
768 | r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, |
|
768 | r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, | |
769 | tracker=tracker, owner=True, |
|
769 | tracker=tracker, owner=True, | |
770 | ) |
|
770 | ) | |
771 | if block: |
|
771 | if block: | |
772 | r.wait() |
|
772 | r.wait() | |
773 | else: |
|
773 | else: | |
774 | return r |
|
774 | return r | |
775 |
|
775 | |||
776 | @sync_results |
|
776 | @sync_results | |
777 | @save_ids |
|
777 | @save_ids | |
778 | def gather(self, key, dist='b', targets=None, block=None): |
|
778 | def gather(self, key, dist='b', targets=None, block=None): | |
779 | """ |
|
779 | """ | |
780 | Gather a partitioned sequence on a set of engines as a single local seq. |
|
780 | Gather a partitioned sequence on a set of engines as a single local seq. | |
781 | """ |
|
781 | """ | |
782 | block = block if block is not None else self.block |
|
782 | block = block if block is not None else self.block | |
783 | targets = targets if targets is not None else self.targets |
|
783 | targets = targets if targets is not None else self.targets | |
784 | mapObject = Map.dists[dist]() |
|
784 | mapObject = Map.dists[dist]() | |
785 | msg_ids = [] |
|
785 | msg_ids = [] | |
786 |
|
786 | |||
787 | # construct integer ID list: |
|
787 | # construct integer ID list: | |
788 | targets = self.client._build_targets(targets)[1] |
|
788 | targets = self.client._build_targets(targets)[1] | |
789 |
|
789 | |||
790 | for index, engineid in enumerate(targets): |
|
790 | for index, engineid in enumerate(targets): | |
791 | msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids) |
|
791 | msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids) | |
792 |
|
792 | |||
793 | r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') |
|
793 | r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather') | |
794 |
|
794 | |||
795 | if block: |
|
795 | if block: | |
796 | try: |
|
796 | try: | |
797 | return r.get() |
|
797 | return r.get() | |
798 | except KeyboardInterrupt: |
|
798 | except KeyboardInterrupt: | |
799 | pass |
|
799 | pass | |
800 | return r |
|
800 | return r | |
801 |
|
801 | |||
802 | def __getitem__(self, key): |
|
802 | def __getitem__(self, key): | |
803 | return self.get(key) |
|
803 | return self.get(key) | |
804 |
|
804 | |||
805 | def __setitem__(self,key, value): |
|
805 | def __setitem__(self,key, value): | |
806 | self.update({key:value}) |
|
806 | self.update({key:value}) | |
807 |
|
807 | |||
808 | def clear(self, targets=None, block=None): |
|
808 | def clear(self, targets=None, block=None): | |
809 | """Clear the remote namespaces on my engines.""" |
|
809 | """Clear the remote namespaces on my engines.""" | |
810 | block = block if block is not None else self.block |
|
810 | block = block if block is not None else self.block | |
811 | targets = targets if targets is not None else self.targets |
|
811 | targets = targets if targets is not None else self.targets | |
812 | return self.client.clear(targets=targets, block=block) |
|
812 | return self.client.clear(targets=targets, block=block) | |
813 |
|
813 | |||
814 | #---------------------------------------- |
|
814 | #---------------------------------------- | |
815 | # activate for %px, %autopx, etc. magics |
|
815 | # activate for %px, %autopx, etc. magics | |
816 | #---------------------------------------- |
|
816 | #---------------------------------------- | |
817 |
|
817 | |||
818 | def activate(self, suffix=''): |
|
818 | def activate(self, suffix=''): | |
819 | """Activate IPython magics associated with this View |
|
819 | """Activate IPython magics associated with this View | |
820 |
|
820 | |||
821 | Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig` |
|
821 | Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig` | |
822 |
|
822 | |||
823 | Parameters |
|
823 | Parameters | |
824 | ---------- |
|
824 | ---------- | |
825 |
|
825 | |||
826 | suffix: str [default: ''] |
|
826 | suffix: str [default: ''] | |
827 | The suffix, if any, for the magics. This allows you to have |
|
827 | The suffix, if any, for the magics. This allows you to have | |
828 | multiple views associated with parallel magics at the same time. |
|
828 | multiple views associated with parallel magics at the same time. | |
829 |
|
829 | |||
830 | e.g. ``rc[::2].activate(suffix='_even')`` will give you |
|
830 | e.g. ``rc[::2].activate(suffix='_even')`` will give you | |
831 | the magics ``%px_even``, ``%pxresult_even``, etc. for running magics |
|
831 | the magics ``%px_even``, ``%pxresult_even``, etc. for running magics | |
832 | on the even engines. |
|
832 | on the even engines. | |
833 | """ |
|
833 | """ | |
834 |
|
834 | |||
835 | from IPython.parallel.client.magics import ParallelMagics |
|
835 | from IPython.parallel.client.magics import ParallelMagics | |
836 |
|
836 | |||
837 | try: |
|
837 | try: | |
838 | # This is injected into __builtins__. |
|
838 | # This is injected into __builtins__. | |
839 | ip = get_ipython() |
|
839 | ip = get_ipython() | |
840 | except NameError: |
|
840 | except NameError: | |
841 | print("The IPython parallel magics (%px, etc.) only work within IPython.") |
|
841 | print("The IPython parallel magics (%px, etc.) only work within IPython.") | |
842 | return |
|
842 | return | |
843 |
|
843 | |||
844 | M = ParallelMagics(ip, self, suffix) |
|
844 | M = ParallelMagics(ip, self, suffix) | |
845 | ip.magics_manager.register(M) |
|
845 | ip.magics_manager.register(M) | |
846 |
|
846 | |||
847 |
|
847 | |||
848 | @skip_doctest |
|
848 | @skip_doctest | |
849 | class LoadBalancedView(View): |
|
849 | class LoadBalancedView(View): | |
850 | """An load-balancing View that only executes via the Task scheduler. |
|
850 | """An load-balancing View that only executes via the Task scheduler. | |
851 |
|
851 | |||
852 | Load-balanced views can be created with the client's `load_balanced_view` method: |
|
852 | Load-balanced views can be created with the client's `load_balanced_view` method: | |
853 |
|
853 | |||
854 | >>> v = client.load_balanced_view() |
|
854 | >>> v = client.load_balanced_view() | |
855 |
|
855 | |||
856 | or targets can be specified, to restrict the potential destinations: |
|
856 | or targets can be specified, to restrict the potential destinations: | |
857 |
|
857 | |||
858 | >>> v = client.load_balanced_view([1,3]) |
|
858 | >>> v = client.load_balanced_view([1,3]) | |
859 |
|
859 | |||
860 | which would restrict load-balancing to engines 1 and 3. |
|
860 | which would restrict load-balancing to engines 1 and 3. | |
861 |
|
861 | |||
862 | """ |
|
862 | """ | |
863 |
|
863 | |||
864 | follow=Any() |
|
864 | follow=Any() | |
865 | after=Any() |
|
865 | after=Any() | |
866 | timeout=CFloat() |
|
866 | timeout=CFloat() | |
867 | retries = Integer(0) |
|
867 | retries = Integer(0) | |
868 |
|
868 | |||
869 | _task_scheme = Any() |
|
869 | _task_scheme = Any() | |
870 | _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries']) |
|
870 | _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries']) | |
871 |
|
871 | |||
872 | def __init__(self, client=None, socket=None, **flags): |
|
872 | def __init__(self, client=None, socket=None, **flags): | |
873 | super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags) |
|
873 | super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags) | |
874 | self._task_scheme=client._task_scheme |
|
874 | self._task_scheme=client._task_scheme | |
875 |
|
875 | |||
876 | def _validate_dependency(self, dep): |
|
876 | def _validate_dependency(self, dep): | |
877 | """validate a dependency. |
|
877 | """validate a dependency. | |
878 |
|
878 | |||
879 | For use in `set_flags`. |
|
879 | For use in `set_flags`. | |
880 | """ |
|
880 | """ | |
881 | if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)): |
|
881 | if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)): | |
882 | return True |
|
882 | return True | |
883 | elif isinstance(dep, (list,set, tuple)): |
|
883 | elif isinstance(dep, (list,set, tuple)): | |
884 | for d in dep: |
|
884 | for d in dep: | |
885 | if not isinstance(d, string_types + (AsyncResult,)): |
|
885 | if not isinstance(d, string_types + (AsyncResult,)): | |
886 | return False |
|
886 | return False | |
887 | elif isinstance(dep, dict): |
|
887 | elif isinstance(dep, dict): | |
888 | if set(dep.keys()) != set(Dependency().as_dict().keys()): |
|
888 | if set(dep.keys()) != set(Dependency().as_dict().keys()): | |
889 | return False |
|
889 | return False | |
890 | if not isinstance(dep['msg_ids'], list): |
|
890 | if not isinstance(dep['msg_ids'], list): | |
891 | return False |
|
891 | return False | |
892 | for d in dep['msg_ids']: |
|
892 | for d in dep['msg_ids']: | |
893 | if not isinstance(d, string_types): |
|
893 | if not isinstance(d, string_types): | |
894 | return False |
|
894 | return False | |
895 | else: |
|
895 | else: | |
896 | return False |
|
896 | return False | |
897 |
|
897 | |||
898 | return True |
|
898 | return True | |
899 |
|
899 | |||
900 | def _render_dependency(self, dep): |
|
900 | def _render_dependency(self, dep): | |
901 | """helper for building jsonable dependencies from various input forms.""" |
|
901 | """helper for building jsonable dependencies from various input forms.""" | |
902 | if isinstance(dep, Dependency): |
|
902 | if isinstance(dep, Dependency): | |
903 | return dep.as_dict() |
|
903 | return dep.as_dict() | |
904 | elif isinstance(dep, AsyncResult): |
|
904 | elif isinstance(dep, AsyncResult): | |
905 | return dep.msg_ids |
|
905 | return dep.msg_ids | |
906 | elif dep is None: |
|
906 | elif dep is None: | |
907 | return [] |
|
907 | return [] | |
908 | else: |
|
908 | else: | |
909 | # pass to Dependency constructor |
|
909 | # pass to Dependency constructor | |
910 | return list(Dependency(dep)) |
|
910 | return list(Dependency(dep)) | |
911 |
|
911 | |||
912 | def set_flags(self, **kwargs): |
|
912 | def set_flags(self, **kwargs): | |
913 | """set my attribute flags by keyword. |
|
913 | """set my attribute flags by keyword. | |
914 |
|
914 | |||
915 | A View is a wrapper for the Client's apply method, but with attributes |
|
915 | A View is a wrapper for the Client's apply method, but with attributes | |
916 | that specify keyword arguments. Those attributes can be set by keyword |
|
916 | that specify keyword arguments. Those attributes can be set by keyword | |
917 | argument with this method. |
|
917 | argument with this method. | |
918 |
|
918 | |||
919 | Parameters |
|
919 | Parameters | |
920 | ---------- |
|
920 | ---------- | |
921 |
|
921 | |||
922 | block : bool |
|
922 | block : bool | |
923 | whether to wait for results |
|
923 | whether to wait for results | |
924 | track : bool |
|
924 | track : bool | |
925 | whether to create a MessageTracker to allow the user to |
|
925 | whether to create a MessageTracker to allow the user to | |
926 | safely edit arrays and buffers after non-copying |
|
926 | safely edit arrays and buffers after non-copying | |
927 | sends. |
|
927 | sends. | |
928 |
|
928 | |||
929 | after : Dependency or collection of msg_ids |
|
929 | after : Dependency or collection of msg_ids | |
930 | Only for load-balanced execution (targets=None) |
|
930 | Only for load-balanced execution (targets=None) | |
931 | Specify a list of msg_ids as a time-based dependency. |
|
931 | Specify a list of msg_ids as a time-based dependency. | |
932 | This job will only be run *after* the dependencies |
|
932 | This job will only be run *after* the dependencies | |
933 | have been met. |
|
933 | have been met. | |
934 |
|
934 | |||
935 | follow : Dependency or collection of msg_ids |
|
935 | follow : Dependency or collection of msg_ids | |
936 | Only for load-balanced execution (targets=None) |
|
936 | Only for load-balanced execution (targets=None) | |
937 | Specify a list of msg_ids as a location-based dependency. |
|
937 | Specify a list of msg_ids as a location-based dependency. | |
938 | This job will only be run on an engine where this dependency |
|
938 | This job will only be run on an engine where this dependency | |
939 | is met. |
|
939 | is met. | |
940 |
|
940 | |||
941 | timeout : float/int or None |
|
941 | timeout : float/int or None | |
942 | Only for load-balanced execution (targets=None) |
|
942 | Only for load-balanced execution (targets=None) | |
943 | Specify an amount of time (in seconds) for the scheduler to |
|
943 | Specify an amount of time (in seconds) for the scheduler to | |
944 | wait for dependencies to be met before failing with a |
|
944 | wait for dependencies to be met before failing with a | |
945 | DependencyTimeout. |
|
945 | DependencyTimeout. | |
946 |
|
946 | |||
947 | retries : int |
|
947 | retries : int | |
948 | Number of times a task will be retried on failure. |
|
948 | Number of times a task will be retried on failure. | |
949 | """ |
|
949 | """ | |
950 |
|
950 | |||
951 | super(LoadBalancedView, self).set_flags(**kwargs) |
|
951 | super(LoadBalancedView, self).set_flags(**kwargs) | |
952 | for name in ('follow', 'after'): |
|
952 | for name in ('follow', 'after'): | |
953 | if name in kwargs: |
|
953 | if name in kwargs: | |
954 | value = kwargs[name] |
|
954 | value = kwargs[name] | |
955 | if self._validate_dependency(value): |
|
955 | if self._validate_dependency(value): | |
956 | setattr(self, name, value) |
|
956 | setattr(self, name, value) | |
957 | else: |
|
957 | else: | |
958 | raise ValueError("Invalid dependency: %r"%value) |
|
958 | raise ValueError("Invalid dependency: %r"%value) | |
959 | if 'timeout' in kwargs: |
|
959 | if 'timeout' in kwargs: | |
960 | t = kwargs['timeout'] |
|
960 | t = kwargs['timeout'] | |
961 | if not isinstance(t, (int, float, type(None))): |
|
961 | if not isinstance(t, (int, float, type(None))): | |
962 | if PY3 or not isinstance(t, long): |
|
962 | if PY3 or not isinstance(t, long): | |
963 | raise TypeError("Invalid type for timeout: %r"%type(t)) |
|
963 | raise TypeError("Invalid type for timeout: %r"%type(t)) | |
964 | if t is not None: |
|
964 | if t is not None: | |
965 | if t < 0: |
|
965 | if t < 0: | |
966 | raise ValueError("Invalid timeout: %s"%t) |
|
966 | raise ValueError("Invalid timeout: %s"%t) | |
967 | self.timeout = t |
|
967 | self.timeout = t | |
968 |
|
968 | |||
969 | @sync_results |
|
969 | @sync_results | |
970 | @save_ids |
|
970 | @save_ids | |
971 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, |
|
971 | def _really_apply(self, f, args=None, kwargs=None, block=None, track=None, | |
972 | after=None, follow=None, timeout=None, |
|
972 | after=None, follow=None, timeout=None, | |
973 | targets=None, retries=None): |
|
973 | targets=None, retries=None): | |
974 | """calls f(*args, **kwargs) on a remote engine, returning the result. |
|
974 | """calls f(*args, **kwargs) on a remote engine, returning the result. | |
975 |
|
975 | |||
976 | This method temporarily sets all of `apply`'s flags for a single call. |
|
976 | This method temporarily sets all of `apply`'s flags for a single call. | |
977 |
|
977 | |||
978 | Parameters |
|
978 | Parameters | |
979 | ---------- |
|
979 | ---------- | |
980 |
|
980 | |||
981 | f : callable |
|
981 | f : callable | |
982 |
|
982 | |||
983 | args : list [default: empty] |
|
983 | args : list [default: empty] | |
984 |
|
984 | |||
985 | kwargs : dict [default: empty] |
|
985 | kwargs : dict [default: empty] | |
986 |
|
986 | |||
987 | block : bool [default: self.block] |
|
987 | block : bool [default: self.block] | |
988 | whether to block |
|
988 | whether to block | |
989 | track : bool [default: self.track] |
|
989 | track : bool [default: self.track] | |
990 | whether to ask zmq to track the message, for safe non-copying sends |
|
990 | whether to ask zmq to track the message, for safe non-copying sends | |
991 |
|
991 | |||
992 | !!!!!! TODO: THE REST HERE !!!! |
|
992 | !!!!!! TODO: THE REST HERE !!!! | |
993 |
|
993 | |||
994 | Returns |
|
994 | Returns | |
995 | ------- |
|
995 | ------- | |
996 |
|
996 | |||
997 | if self.block is False: |
|
997 | if self.block is False: | |
998 | returns AsyncResult |
|
998 | returns AsyncResult | |
999 | else: |
|
999 | else: | |
1000 | returns actual result of f(*args, **kwargs) on the engine(s) |
|
1000 | returns actual result of f(*args, **kwargs) on the engine(s) | |
1001 | This will be a list if self.targets is also a list (even length 1), or |
|
1001 | This will be a list if self.targets is also a list (even length 1), or | |
1002 | the single result if self.targets is an integer engine id |
|
1002 | the single result if self.targets is an integer engine id | |
1003 | """ |
|
1003 | """ | |
1004 |
|
1004 | |||
1005 | # validate whether we can run |
|
1005 | # validate whether we can run | |
1006 | if self._socket.closed: |
|
1006 | if self._socket.closed: | |
1007 | msg = "Task farming is disabled" |
|
1007 | msg = "Task farming is disabled" | |
1008 | if self._task_scheme == 'pure': |
|
1008 | if self._task_scheme == 'pure': | |
1009 | msg += " because the pure ZMQ scheduler cannot handle" |
|
1009 | msg += " because the pure ZMQ scheduler cannot handle" | |
1010 | msg += " disappearing engines." |
|
1010 | msg += " disappearing engines." | |
1011 | raise RuntimeError(msg) |
|
1011 | raise RuntimeError(msg) | |
1012 |
|
1012 | |||
1013 | if self._task_scheme == 'pure': |
|
1013 | if self._task_scheme == 'pure': | |
1014 | # pure zmq scheme doesn't support extra features |
|
1014 | # pure zmq scheme doesn't support extra features | |
1015 | msg = "Pure ZMQ scheduler doesn't support the following flags:" |
|
1015 | msg = "Pure ZMQ scheduler doesn't support the following flags:" | |
1016 | "follow, after, retries, targets, timeout" |
|
1016 | "follow, after, retries, targets, timeout" | |
1017 | if (follow or after or retries or targets or timeout): |
|
1017 | if (follow or after or retries or targets or timeout): | |
1018 | # hard fail on Scheduler flags |
|
1018 | # hard fail on Scheduler flags | |
1019 | raise RuntimeError(msg) |
|
1019 | raise RuntimeError(msg) | |
1020 | if isinstance(f, dependent): |
|
1020 | if isinstance(f, dependent): | |
1021 | # soft warn on functional dependencies |
|
1021 | # soft warn on functional dependencies | |
1022 | warnings.warn(msg, RuntimeWarning) |
|
1022 | warnings.warn(msg, RuntimeWarning) | |
1023 |
|
1023 | |||
1024 | # build args |
|
1024 | # build args | |
1025 | args = [] if args is None else args |
|
1025 | args = [] if args is None else args | |
1026 | kwargs = {} if kwargs is None else kwargs |
|
1026 | kwargs = {} if kwargs is None else kwargs | |
1027 | block = self.block if block is None else block |
|
1027 | block = self.block if block is None else block | |
1028 | track = self.track if track is None else track |
|
1028 | track = self.track if track is None else track | |
1029 | after = self.after if after is None else after |
|
1029 | after = self.after if after is None else after | |
1030 | retries = self.retries if retries is None else retries |
|
1030 | retries = self.retries if retries is None else retries | |
1031 | follow = self.follow if follow is None else follow |
|
1031 | follow = self.follow if follow is None else follow | |
1032 | timeout = self.timeout if timeout is None else timeout |
|
1032 | timeout = self.timeout if timeout is None else timeout | |
1033 | targets = self.targets if targets is None else targets |
|
1033 | targets = self.targets if targets is None else targets | |
1034 |
|
1034 | |||
1035 | if not isinstance(retries, int): |
|
1035 | if not isinstance(retries, int): | |
1036 | raise TypeError('retries must be int, not %r'%type(retries)) |
|
1036 | raise TypeError('retries must be int, not %r'%type(retries)) | |
1037 |
|
1037 | |||
1038 | if targets is None: |
|
1038 | if targets is None: | |
1039 | idents = [] |
|
1039 | idents = [] | |
1040 | else: |
|
1040 | else: | |
1041 | idents = self.client._build_targets(targets)[0] |
|
1041 | idents = self.client._build_targets(targets)[0] | |
1042 | # ensure *not* bytes |
|
1042 | # ensure *not* bytes | |
1043 | idents = [ ident.decode() for ident in idents ] |
|
1043 | idents = [ ident.decode() for ident in idents ] | |
1044 |
|
1044 | |||
1045 | after = self._render_dependency(after) |
|
1045 | after = self._render_dependency(after) | |
1046 | follow = self._render_dependency(follow) |
|
1046 | follow = self._render_dependency(follow) | |
1047 | metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries) |
|
1047 | metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries) | |
1048 |
|
1048 | |||
1049 | msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track, |
|
1049 | msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track, | |
1050 | metadata=metadata) |
|
1050 | metadata=metadata) | |
1051 | tracker = None if track is False else msg['tracker'] |
|
1051 | tracker = None if track is False else msg['tracker'] | |
1052 |
|
1052 | |||
1053 | ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), |
|
1053 | ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), | |
1054 | targets=None, tracker=tracker, owner=True, |
|
1054 | targets=None, tracker=tracker, owner=True, | |
1055 | ) |
|
1055 | ) | |
1056 | if block: |
|
1056 | if block: | |
1057 | try: |
|
1057 | try: | |
1058 | return ar.get() |
|
1058 | return ar.get() | |
1059 | except KeyboardInterrupt: |
|
1059 | except KeyboardInterrupt: | |
1060 | pass |
|
1060 | pass | |
1061 | return ar |
|
1061 | return ar | |
1062 |
|
1062 | |||
1063 | @sync_results |
|
1063 | @sync_results | |
1064 | @save_ids |
|
1064 | @save_ids | |
1065 | def map(self, f, *sequences, **kwargs): |
|
1065 | def map(self, f, *sequences, **kwargs): | |
1066 | """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult |
|
1066 | """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult | |
1067 |
|
1067 | |||
1068 | Parallel version of builtin `map`, load-balanced by this View. |
|
1068 | Parallel version of builtin `map`, load-balanced by this View. | |
1069 |
|
1069 | |||
1070 | `block`, `chunksize`, and `ordered` can be specified by keyword only. |
|
1070 | `block`, `chunksize`, and `ordered` can be specified by keyword only. | |
1071 |
|
1071 | |||
1072 | Each `chunksize` elements will be a separate task, and will be |
|
1072 | Each `chunksize` elements will be a separate task, and will be | |
1073 | load-balanced. This lets individual elements be available for iteration |
|
1073 | load-balanced. This lets individual elements be available for iteration | |
1074 | as soon as they arrive. |
|
1074 | as soon as they arrive. | |
1075 |
|
1075 | |||
1076 | Parameters |
|
1076 | Parameters | |
1077 | ---------- |
|
1077 | ---------- | |
1078 |
|
1078 | |||
1079 | f : callable |
|
1079 | f : callable | |
1080 | function to be mapped |
|
1080 | function to be mapped | |
1081 | *sequences: one or more sequences of matching length |
|
1081 | *sequences: one or more sequences of matching length | |
1082 | the sequences to be distributed and passed to `f` |
|
1082 | the sequences to be distributed and passed to `f` | |
1083 | block : bool [default self.block] |
|
1083 | block : bool [default self.block] | |
1084 | whether to wait for the result or not |
|
1084 | whether to wait for the result or not | |
1085 | track : bool |
|
1085 | track : bool | |
1086 | whether to create a MessageTracker to allow the user to |
|
1086 | whether to create a MessageTracker to allow the user to | |
1087 | safely edit arrays and buffers after non-copying |
|
1087 | safely edit arrays and buffers after non-copying | |
1088 | sends. |
|
1088 | sends. | |
1089 | chunksize : int [default 1] |
|
1089 | chunksize : int [default 1] | |
1090 | how many elements should be in each task. |
|
1090 | how many elements should be in each task. | |
1091 | ordered : bool [default True] |
|
1091 | ordered : bool [default True] | |
1092 | Whether the results should be gathered as they arrive, or enforce |
|
1092 | Whether the results should be gathered as they arrive, or enforce | |
1093 | the order of submission. |
|
1093 | the order of submission. | |
1094 |
|
1094 | |||
1095 | Only applies when iterating through AsyncMapResult as results arrive. |
|
1095 | Only applies when iterating through AsyncMapResult as results arrive. | |
1096 | Has no effect when block=True. |
|
1096 | Has no effect when block=True. | |
1097 |
|
1097 | |||
1098 | Returns |
|
1098 | Returns | |
1099 | ------- |
|
1099 | ------- | |
1100 |
|
1100 | |||
1101 | if block=False |
|
1101 | if block=False | |
1102 | An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance. |
|
1102 | An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance. | |
1103 | An object like AsyncResult, but which reassembles the sequence of results |
|
1103 | An object like AsyncResult, but which reassembles the sequence of results | |
1104 | into a single list. AsyncMapResults can be iterated through before all |
|
1104 | into a single list. AsyncMapResults can be iterated through before all | |
1105 | results are complete. |
|
1105 | results are complete. | |
1106 | else |
|
1106 | else | |
1107 | A list, the result of ``map(f,*sequences)`` |
|
1107 | A list, the result of ``map(f,*sequences)`` | |
1108 | """ |
|
1108 | """ | |
1109 |
|
1109 | |||
1110 | # default |
|
1110 | # default | |
1111 | block = kwargs.get('block', self.block) |
|
1111 | block = kwargs.get('block', self.block) | |
1112 | chunksize = kwargs.get('chunksize', 1) |
|
1112 | chunksize = kwargs.get('chunksize', 1) | |
1113 | ordered = kwargs.get('ordered', True) |
|
1113 | ordered = kwargs.get('ordered', True) | |
1114 |
|
1114 | |||
1115 | keyset = set(kwargs.keys()) |
|
1115 | keyset = set(kwargs.keys()) | |
1116 | extra_keys = keyset.difference(['block', 'chunksize', 'ordered']) |
|
1116 | extra_keys = keyset.difference(['block', 'chunksize', 'ordered']) | |
1117 | if extra_keys: |
|
1117 | if extra_keys: | |
1118 | raise TypeError("Invalid kwargs: %s"%list(extra_keys)) |
|
1118 | raise TypeError("Invalid kwargs: %s"%list(extra_keys)) | |
1119 |
|
1119 | |||
1120 | assert len(sequences) > 0, "must have some sequences to map onto!" |
|
1120 | assert len(sequences) > 0, "must have some sequences to map onto!" | |
1121 |
|
1121 | |||
1122 | pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered) |
|
1122 | pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered) | |
1123 | return pf.map(*sequences) |
|
1123 | return pf.map(*sequences) | |
1124 |
|
1124 | |||
1125 | __all__ = ['LoadBalancedView', 'DirectView'] |
|
1125 | __all__ = ['LoadBalancedView', 'DirectView'] |
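
The DirectView and LoadBalancedView methods diffed above (execute, push/pull, scatter/gather, and the load-balanced map) are normally driven from a client session. A minimal usage sketch follows; it assumes an IPython parallel cluster is already running and that the client class is importable as ``ipython_parallel.Client`` on this branch (the import path is an assumption, since the package is being renamed in this changeset)::

    from ipython_parallel import Client   # assumed import path for this branch

    rc = Client()
    dview = rc[:]                                     # DirectView over all engines
    dview.push({'a': 10}, block=True)                 # update every remote namespace
    dview.scatter('x', list(range(8)), block=True)    # partition a sequence across engines
    dview.execute('y = [a * i for i in x]', block=True)
    y = dview.gather('y', block=True)                 # reassemble the partitions locally

    lview = rc.load_balanced_view()                   # routes via the Task scheduler
    squares = lview.map(lambda i: i * i, range(16), block=True, chunksize=4)

The ``block=True`` arguments mirror the ``block`` flag documented in the docstrings above; omitting them returns AsyncResult objects that can be waited on later.
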
@@ -1,229 +1,229 b'' | |||||
1 | """Dependency utilities |
|
1 | """Dependency utilities | |
2 |
|
2 | |||
3 | Authors: |
|
3 | Authors: | |
4 |
|
4 | |||
5 | * Min RK |
|
5 | * Min RK | |
6 | """ |
|
6 | """ | |
7 | #----------------------------------------------------------------------------- |
|
7 | #----------------------------------------------------------------------------- | |
8 | # Copyright (C) 2013 The IPython Development Team |
|
8 | # Copyright (C) 2013 The IPython Development Team | |
9 | # |
|
9 | # | |
10 | # Distributed under the terms of the BSD License. The full license is in |
|
10 | # Distributed under the terms of the BSD License. The full license is in | |
11 | # the file COPYING, distributed as part of this software. |
|
11 | # the file COPYING, distributed as part of this software. | |
12 | #----------------------------------------------------------------------------- |
|
12 | #----------------------------------------------------------------------------- | |
13 |
|
13 | |||
14 | from types import ModuleType |
|
14 | from types import ModuleType | |
15 |
|
15 | |||
16 | from ipython_parallel.client.asyncresult import AsyncResult |
|
16 | from ipython_parallel.client.asyncresult import AsyncResult | |
17 | from ipython_parallel.error import UnmetDependency |
|
17 | from ipython_parallel.error import UnmetDependency | |
18 | from ipython_parallel.util import interactive |
|
18 | from ipython_parallel.util import interactive | |
19 | from IPython.utils import py3compat |
|
19 | from IPython.utils import py3compat | |
20 | from IPython.utils.py3compat import string_types |
|
20 | from IPython.utils.py3compat import string_types | |
21 |
from |
|
21 | from ipython_kernel.pickleutil import can, uncan | |
22 |
|
22 | |||
23 | class depend(object): |
|
23 | class depend(object): | |
24 | """Dependency decorator, for use with tasks. |
|
24 | """Dependency decorator, for use with tasks. | |
25 |
|
25 | |||
26 | `@depend` lets you define a function for engine dependencies |
|
26 | `@depend` lets you define a function for engine dependencies | |
27 | just like you use `apply` for tasks. |
|
27 | just like you use `apply` for tasks. | |
28 |
|
28 | |||
29 |
|
29 | |||
30 | Examples |
|
30 | Examples | |
31 | -------- |
|
31 | -------- | |
32 | :: |
|
32 | :: | |
33 |
|
33 | |||
34 | @depend(df, a,b, c=5) |
|
34 | @depend(df, a,b, c=5) | |
35 | def f(m,n,p): |
|
35 | def f(m,n,p): | |
36 |
|
36 | |||
37 | view.apply(f, 1,2,3) |
|
37 | view.apply(f, 1,2,3) | |
38 |
|
38 | |||
39 | will call df(a,b,c=5) on the engine, and if it returns False or |
|
39 | will call df(a,b,c=5) on the engine, and if it returns False or | |
40 | raises an UnmetDependency error, then the task will not be run |
|
40 | raises an UnmetDependency error, then the task will not be run | |
41 | and another engine will be tried. |
|
41 | and another engine will be tried. | |
42 | """ |
|
42 | """ | |
43 | def __init__(self, _wrapped_f, *args, **kwargs): |
|
43 | def __init__(self, _wrapped_f, *args, **kwargs): | |
44 | self.f = _wrapped_f |
|
44 | self.f = _wrapped_f | |
45 | self.args = args |
|
45 | self.args = args | |
46 | self.kwargs = kwargs |
|
46 | self.kwargs = kwargs | |
47 |
|
47 | |||
48 | def __call__(self, f): |
|
48 | def __call__(self, f): | |
49 | return dependent(f, self.f, *self.args, **self.kwargs) |
|
49 | return dependent(f, self.f, *self.args, **self.kwargs) | |
50 |
|
50 | |||
51 | class dependent(object): |
|
51 | class dependent(object): | |
52 | """A function that depends on another function. |
|
52 | """A function that depends on another function. | |
53 | This is an object to avoid the closure used |
|
53 | This is an object to avoid the closure used | |
54 | in traditional decorators, since closures are not picklable. |
|
54 | in traditional decorators, since closures are not picklable. | |
55 | """ |
|
55 | """ | |
56 |
|
56 | |||
57 | def __init__(self, _wrapped_f, _wrapped_df, *dargs, **dkwargs): |
|
57 | def __init__(self, _wrapped_f, _wrapped_df, *dargs, **dkwargs): | |
58 | self.f = _wrapped_f |
|
58 | self.f = _wrapped_f | |
59 | name = getattr(_wrapped_f, '__name__', 'f') |
|
59 | name = getattr(_wrapped_f, '__name__', 'f') | |
60 | if py3compat.PY3: |
|
60 | if py3compat.PY3: | |
61 | self.__name__ = name |
|
61 | self.__name__ = name | |
62 | else: |
|
62 | else: | |
63 | self.func_name = name |
|
63 | self.func_name = name | |
64 | self.df = _wrapped_df |
|
64 | self.df = _wrapped_df | |
65 | self.dargs = dargs |
|
65 | self.dargs = dargs | |
66 | self.dkwargs = dkwargs |
|
66 | self.dkwargs = dkwargs | |
67 |
|
67 | |||
68 | def check_dependency(self): |
|
68 | def check_dependency(self): | |
69 | if self.df(*self.dargs, **self.dkwargs) is False: |
|
69 | if self.df(*self.dargs, **self.dkwargs) is False: | |
70 | raise UnmetDependency() |
|
70 | raise UnmetDependency() | |
71 |
|
71 | |||
72 | def __call__(self, *args, **kwargs): |
|
72 | def __call__(self, *args, **kwargs): | |
73 | return self.f(*args, **kwargs) |
|
73 | return self.f(*args, **kwargs) | |
74 |
|
74 | |||
75 | if not py3compat.PY3: |
|
75 | if not py3compat.PY3: | |
76 | @property |
|
76 | @property | |
77 | def __name__(self): |
|
77 | def __name__(self): | |
78 | return self.func_name |
|
78 | return self.func_name | |
79 |
|
79 | |||
80 | @interactive |
|
80 | @interactive | |
81 | def _require(*modules, **mapping): |
|
81 | def _require(*modules, **mapping): | |
82 | """Helper for @require decorator.""" |
|
82 | """Helper for @require decorator.""" | |
83 | from ipython_parallel.error import UnmetDependency |
|
83 | from ipython_parallel.error import UnmetDependency | |
84 |
from |
|
84 | from ipython_kernel.pickleutil import uncan | |
85 | user_ns = globals() |
|
85 | user_ns = globals() | |
86 | for name in modules: |
|
86 | for name in modules: | |
87 | try: |
|
87 | try: | |
88 | exec('import %s' % name, user_ns) |
|
88 | exec('import %s' % name, user_ns) | |
89 | except ImportError: |
|
89 | except ImportError: | |
90 | raise UnmetDependency(name) |
|
90 | raise UnmetDependency(name) | |
91 |
|
91 | |||
92 | for name, cobj in mapping.items(): |
|
92 | for name, cobj in mapping.items(): | |
93 | user_ns[name] = uncan(cobj, user_ns) |
|
93 | user_ns[name] = uncan(cobj, user_ns) | |
94 | return True |
|
94 | return True | |
95 |
|
95 | |||
96 | def require(*objects, **mapping): |
|
96 | def require(*objects, **mapping): | |
97 | """Simple decorator for requiring local objects and modules to be available |
|
97 | """Simple decorator for requiring local objects and modules to be available | |
98 | when the decorated function is called on the engine. |
|
98 | when the decorated function is called on the engine. | |
99 |
|
99 | |||
100 | Modules specified by name or passed directly will be imported |
|
100 | Modules specified by name or passed directly will be imported | |
101 | prior to calling the decorated function. |
|
101 | prior to calling the decorated function. | |
102 |
|
102 | |||
103 | Objects other than modules will be pushed as a part of the task. |
|
103 | Objects other than modules will be pushed as a part of the task. | |
104 | Functions can be passed positionally, |
|
104 | Functions can be passed positionally, | |
105 | and will be pushed to the engine with their __name__. |
|
105 | and will be pushed to the engine with their __name__. | |
106 | Other objects can be passed by keyword arg. |
|
106 | Other objects can be passed by keyword arg. | |
107 |
|
107 | |||
108 | Examples:: |
|
108 | Examples:: | |
109 |
|
109 | |||
110 | In [1]: @require('numpy') |
|
110 | In [1]: @require('numpy') | |
111 | ...: def norm(a): |
|
111 | ...: def norm(a): | |
112 | ...: return numpy.linalg.norm(a,2) |
|
112 | ...: return numpy.linalg.norm(a,2) | |
113 |
|
113 | |||
114 | In [2]: foo = lambda x: x*x |
|
114 | In [2]: foo = lambda x: x*x | |
115 | In [3]: @require(foo) |
|
115 | In [3]: @require(foo) | |
116 | ...: def bar(a): |
|
116 | ...: def bar(a): | |
117 | ...: return foo(1-a) |
|
117 | ...: return foo(1-a) | |
118 | """ |
|
118 | """ | |
119 | names = [] |
|
119 | names = [] | |
120 | for obj in objects: |
|
120 | for obj in objects: | |
121 | if isinstance(obj, ModuleType): |
|
121 | if isinstance(obj, ModuleType): | |
122 | obj = obj.__name__ |
|
122 | obj = obj.__name__ | |
123 |
|
123 | |||
124 | if isinstance(obj, string_types): |
|
124 | if isinstance(obj, string_types): | |
125 | names.append(obj) |
|
125 | names.append(obj) | |
126 | elif hasattr(obj, '__name__'): |
|
126 | elif hasattr(obj, '__name__'): | |
127 | mapping[obj.__name__] = obj |
|
127 | mapping[obj.__name__] = obj | |
128 | else: |
|
128 | else: | |
129 | raise TypeError("Objects other than modules and functions " |
|
129 | raise TypeError("Objects other than modules and functions " | |
130 | "must be passed by kwarg, but got: %s" % type(obj) |
|
130 | "must be passed by kwarg, but got: %s" % type(obj) | |
131 | ) |
|
131 | ) | |
132 |
|
132 | |||
133 | for name, obj in mapping.items(): |
|
133 | for name, obj in mapping.items(): | |
134 | mapping[name] = can(obj) |
|
134 | mapping[name] = can(obj) | |
135 | return depend(_require, *names, **mapping) |
|
135 | return depend(_require, *names, **mapping) | |
136 |
|
136 | |||
137 | class Dependency(set): |
|
137 | class Dependency(set): | |
138 | """An object for representing a set of msg_id dependencies. |
|
138 | """An object for representing a set of msg_id dependencies. | |
139 |
|
139 | |||
140 | Subclassed from set(). |
|
140 | Subclassed from set(). | |
141 |
|
141 | |||
142 | Parameters |
|
142 | Parameters | |
143 | ---------- |
|
143 | ---------- | |
144 | dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict() |
|
144 | dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict() | |
145 | The msg_ids to depend on |
|
145 | The msg_ids to depend on | |
146 | all : bool [default True] |
|
146 | all : bool [default True] | |
147 | Whether the dependency should be considered met when *all* depended-upon tasks have completed |
|
147 | Whether the dependency should be considered met when *all* depended-upon tasks have completed | |
148 | or only when *any* have been completed. |
|
148 | or only when *any* have been completed. | |
149 | success : bool [default True] |
|
149 | success : bool [default True] | |
150 | Whether to consider successes as fulfilling dependencies. |
|
150 | Whether to consider successes as fulfilling dependencies. | |
151 | failure : bool [default False] |
|
151 | failure : bool [default False] | |
152 | Whether to consider failures as fulfilling dependencies. |
|
152 | Whether to consider failures as fulfilling dependencies. | |
153 |
|
153 | |||
154 | If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency |
|
154 | If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency | |
155 | as soon as the first depended-upon task fails. |
|
155 | as soon as the first depended-upon task fails. | |
156 | """ |
|
156 | """ | |
157 |
|
157 | |||
158 | all=True |
|
158 | all=True | |
159 | success=True |
|
159 | success=True | |
160 | failure=True |
|
160 | failure=True | |
161 |
|
161 | |||
162 | def __init__(self, dependencies=[], all=True, success=True, failure=False): |
|
162 | def __init__(self, dependencies=[], all=True, success=True, failure=False): | |
163 | if isinstance(dependencies, dict): |
|
163 | if isinstance(dependencies, dict): | |
164 | # load from dict |
|
164 | # load from dict | |
165 | all = dependencies.get('all', True) |
|
165 | all = dependencies.get('all', True) | |
166 | success = dependencies.get('success', success) |
|
166 | success = dependencies.get('success', success) | |
167 | failure = dependencies.get('failure', failure) |
|
167 | failure = dependencies.get('failure', failure) | |
168 | dependencies = dependencies.get('dependencies', []) |
|
168 | dependencies = dependencies.get('dependencies', []) | |
169 | ids = [] |
|
169 | ids = [] | |
170 |
|
170 | |||
171 | # extract ids from various sources: |
|
171 | # extract ids from various sources: | |
172 | if isinstance(dependencies, string_types + (AsyncResult,)): |
|
172 | if isinstance(dependencies, string_types + (AsyncResult,)): | |
173 | dependencies = [dependencies] |
|
173 | dependencies = [dependencies] | |
174 | for d in dependencies: |
|
174 | for d in dependencies: | |
175 | if isinstance(d, string_types): |
|
175 | if isinstance(d, string_types): | |
176 | ids.append(d) |
|
176 | ids.append(d) | |
177 | elif isinstance(d, AsyncResult): |
|
177 | elif isinstance(d, AsyncResult): | |
178 | ids.extend(d.msg_ids) |
|
178 | ids.extend(d.msg_ids) | |
179 | else: |
|
179 | else: | |
180 | raise TypeError("invalid dependency type: %r"%type(d)) |
|
180 | raise TypeError("invalid dependency type: %r"%type(d)) | |
181 |
|
181 | |||
182 | set.__init__(self, ids) |
|
182 | set.__init__(self, ids) | |
183 | self.all = all |
|
183 | self.all = all | |
184 | if not (success or failure): |
|
184 | if not (success or failure): | |
185 | raise ValueError("Must depend on at least one of successes or failures!") |
|
185 | raise ValueError("Must depend on at least one of successes or failures!") | |
186 | self.success=success |
|
186 | self.success=success | |
187 | self.failure = failure |
|
187 | self.failure = failure | |
188 |
|
188 | |||
189 | def check(self, completed, failed=None): |
|
189 | def check(self, completed, failed=None): | |
190 | """check whether our dependencies have been met.""" |
|
190 | """check whether our dependencies have been met.""" | |
191 | if len(self) == 0: |
|
191 | if len(self) == 0: | |
192 | return True |
|
192 | return True | |
193 | against = set() |
|
193 | against = set() | |
194 | if self.success: |
|
194 | if self.success: | |
195 | against = completed |
|
195 | against = completed | |
196 | if failed is not None and self.failure: |
|
196 | if failed is not None and self.failure: | |
197 | against = against.union(failed) |
|
197 | against = against.union(failed) | |
198 | if self.all: |
|
198 | if self.all: | |
199 | return self.issubset(against) |
|
199 | return self.issubset(against) | |
200 | else: |
|
200 | else: | |
201 | return not self.isdisjoint(against) |
|
201 | return not self.isdisjoint(against) | |
202 |
|
202 | |||
203 | def unreachable(self, completed, failed=None): |
|
203 | def unreachable(self, completed, failed=None): | |
204 | """return whether this dependency has become impossible.""" |
|
204 | """return whether this dependency has become impossible.""" | |
205 | if len(self) == 0: |
|
205 | if len(self) == 0: | |
206 | return False |
|
206 | return False | |
207 | against = set() |
|
207 | against = set() | |
208 | if not self.success: |
|
208 | if not self.success: | |
209 | against = completed |
|
209 | against = completed | |
210 | if failed is not None and not self.failure: |
|
210 | if failed is not None and not self.failure: | |
211 | against = against.union(failed) |
|
211 | against = against.union(failed) | |
212 | if self.all: |
|
212 | if self.all: | |
213 | return not self.isdisjoint(against) |
|
213 | return not self.isdisjoint(against) | |
214 | else: |
|
214 | else: | |
215 | return self.issubset(against) |
|
215 | return self.issubset(against) | |
216 |
|
216 | |||
217 |
|
217 | |||
218 | def as_dict(self): |
|
218 | def as_dict(self): | |
219 | """Represent this dependency as a dict. For json compatibility.""" |
|
219 | """Represent this dependency as a dict. For json compatibility.""" | |
220 | return dict( |
|
220 | return dict( | |
221 | dependencies=list(self), |
|
221 | dependencies=list(self), | |
222 | all=self.all, |
|
222 | all=self.all, | |
223 | success=self.success, |
|
223 | success=self.success, | |
224 | failure=self.failure |
|
224 | failure=self.failure | |
225 | ) |
|
225 | ) | |
226 |
|
226 | |||
227 |
|
227 | |||
228 | __all__ = ['depend', 'require', 'dependent', 'Dependency'] |
|
228 | __all__ = ['depend', 'require', 'dependent', 'Dependency'] | |
229 |
|
229 |
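
To make the decorator semantics above concrete, here is a hedged sketch of how ``@require`` and ``@depend`` combine with a load-balanced view. It assumes the package is importable as ``ipython_parallel`` (the name used by the tests that follow) and that a cluster with at least one engine is running; ``engine_has`` is a hypothetical helper defined only for this example::

    import ipython_parallel as pmod    # assumed package name, as in the tests below

    rc = pmod.Client()
    view = rc.load_balanced_view()

    @pmod.require('numpy')             # imported on the engine before the call
    def norm(a):
        return numpy.linalg.norm(a, 2)

    def engine_has(mod):
        # hypothetical dependency check, evaluated on the engine
        try:
            __import__(mod)
            return True
        except ImportError:
            return False

    @pmod.depend(engine_has, 'os')     # task only runs where engine_has('os') is True
    def list_cwd():
        import os
        return os.listdir('.')

    first = view.apply_async(norm, [3, 4])
    view.set_flags(after=first)        # msg_id dependency, see set_flags above
    second = view.apply_async(list_cwd)
    print(first.get(), second.get())

Functions passed through ``@require`` and ``@depend`` are canned like any other task, so the dependency function itself must be picklable (a plain module-level function or an ``@interactive`` one, as in the tests).
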
@@ -1,136 +1,136 b'' | |||||
1 | """Tests for dependency.py |
|
1 | """Tests for dependency.py | |
2 |
|
2 | |||
3 | Authors: |
|
3 | Authors: | |
4 |
|
4 | |||
5 | * Min RK |
|
5 | * Min RK | |
6 | """ |
|
6 | """ | |
7 |
|
7 | |||
8 | __docformat__ = "restructuredtext en" |
|
8 | __docformat__ = "restructuredtext en" | |
9 |
|
9 | |||
10 | #------------------------------------------------------------------------------- |
|
10 | #------------------------------------------------------------------------------- | |
11 | # Copyright (C) 2011 The IPython Development Team |
|
11 | # Copyright (C) 2011 The IPython Development Team | |
12 | # |
|
12 | # | |
13 | # Distributed under the terms of the BSD License. The full license is in |
|
13 | # Distributed under the terms of the BSD License. The full license is in | |
14 | # the file COPYING, distributed as part of this software. |
|
14 | # the file COPYING, distributed as part of this software. | |
15 | #------------------------------------------------------------------------------- |
|
15 | #------------------------------------------------------------------------------- | |
16 |
|
16 | |||
17 | #------------------------------------------------------------------------------- |
|
17 | #------------------------------------------------------------------------------- | |
18 | # Imports |
|
18 | # Imports | |
19 | #------------------------------------------------------------------------------- |
|
19 | #------------------------------------------------------------------------------- | |
20 |
|
20 | |||
21 | # import |
|
21 | # import | |
22 | import os |
|
22 | import os | |
23 |
|
23 | |||
24 |
from |
|
24 | from ipython_kernel.pickleutil import can, uncan | |
25 |
|
25 | |||
26 | import ipython_parallel as pmod |
|
26 | import ipython_parallel as pmod | |
27 | from ipython_parallel.util import interactive |
|
27 | from ipython_parallel.util import interactive | |
28 |
|
28 | |||
29 | from ipython_parallel.tests import add_engines |
|
29 | from ipython_parallel.tests import add_engines | |
30 | from .clienttest import ClusterTestCase |
|
30 | from .clienttest import ClusterTestCase | |
31 |
|
31 | |||
32 | def setup(): |
|
32 | def setup(): | |
33 | add_engines(1, total=True) |
|
33 | add_engines(1, total=True) | |
34 |
|
34 | |||
35 | @pmod.require('time') |
|
35 | @pmod.require('time') | |
36 | def wait(n): |
|
36 | def wait(n): | |
37 | time.sleep(n) |
|
37 | time.sleep(n) | |
38 | return n |
|
38 | return n | |
39 |
|
39 | |||
40 | @pmod.interactive |
|
40 | @pmod.interactive | |
41 | def func(x): |
|
41 | def func(x): | |
42 | return x*x |
|
42 | return x*x | |
43 |
|
43 | |||
44 | mixed = list(map(str, range(10))) |
|
44 | mixed = list(map(str, range(10))) | |
45 | completed = list(map(str, range(0,10,2))) |
|
45 | completed = list(map(str, range(0,10,2))) | |
46 | failed = list(map(str, range(1,10,2))) |
|
46 | failed = list(map(str, range(1,10,2))) | |
47 |
|
47 | |||
48 | class DependencyTest(ClusterTestCase): |
|
48 | class DependencyTest(ClusterTestCase): | |
49 |
|
49 | |||
50 | def setUp(self): |
|
50 | def setUp(self): | |
51 | ClusterTestCase.setUp(self) |
|
51 | ClusterTestCase.setUp(self) | |
52 | self.user_ns = {'__builtins__' : __builtins__} |
|
52 | self.user_ns = {'__builtins__' : __builtins__} | |
53 | self.view = self.client.load_balanced_view() |
|
53 | self.view = self.client.load_balanced_view() | |
54 | self.dview = self.client[-1] |
|
54 | self.dview = self.client[-1] | |
55 | self.succeeded = set(map(str, range(0,25,2))) |
|
55 | self.succeeded = set(map(str, range(0,25,2))) | |
56 | self.failed = set(map(str, range(1,25,2))) |
|
56 | self.failed = set(map(str, range(1,25,2))) | |
57 |
|
57 | |||
58 | def assertMet(self, dep): |
|
58 | def assertMet(self, dep): | |
59 | self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met") |
|
59 | self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met") | |
60 |
|
60 | |||
61 | def assertUnmet(self, dep): |
|
61 | def assertUnmet(self, dep): | |
62 | self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met") |
|
62 | self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met") | |
63 |
|
63 | |||
64 | def assertUnreachable(self, dep): |
|
64 | def assertUnreachable(self, dep): | |
65 | self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable") |
|
65 | self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable") | |
66 |
|
66 | |||
67 | def assertReachable(self, dep): |
|
67 | def assertReachable(self, dep): | |
68 | self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable") |
|
68 | self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable") | |
69 |
|
69 | |||
70 | def cancan(self, f): |
|
70 | def cancan(self, f): | |
71 | """decorator to pass through canning into self.user_ns""" |
|
71 | """decorator to pass through canning into self.user_ns""" | |
72 | return uncan(can(f), self.user_ns) |
|
72 | return uncan(can(f), self.user_ns) | |
73 |
|
73 | |||
74 | def test_require_imports(self): |
|
74 | def test_require_imports(self): | |
75 | """test that @require imports names""" |
|
75 | """test that @require imports names""" | |
76 | @self.cancan |
|
76 | @self.cancan | |
77 | @pmod.require('base64') |
|
77 | @pmod.require('base64') | |
78 | @interactive |
|
78 | @interactive | |
79 | def encode(arg): |
|
79 | def encode(arg): | |
80 | return base64.b64encode(arg) |
|
80 | return base64.b64encode(arg) | |
81 | # must pass through canning to properly connect namespaces |
|
81 | # must pass through canning to properly connect namespaces | |
82 | self.assertEqual(encode(b'foo'), b'Zm9v') |
|
82 | self.assertEqual(encode(b'foo'), b'Zm9v') | |
83 |
|
83 | |||
84 | def test_success_only(self): |
|
84 | def test_success_only(self): | |
85 | dep = pmod.Dependency(mixed, success=True, failure=False) |
|
85 | dep = pmod.Dependency(mixed, success=True, failure=False) | |
86 | self.assertUnmet(dep) |
|
86 | self.assertUnmet(dep) | |
87 | self.assertUnreachable(dep) |
|
87 | self.assertUnreachable(dep) | |
88 | dep.all=False |
|
88 | dep.all=False | |
89 | self.assertMet(dep) |
|
89 | self.assertMet(dep) | |
90 | self.assertReachable(dep) |
|
90 | self.assertReachable(dep) | |
91 | dep = pmod.Dependency(completed, success=True, failure=False) |
|
91 | dep = pmod.Dependency(completed, success=True, failure=False) | |
92 | self.assertMet(dep) |
|
92 | self.assertMet(dep) | |
93 | self.assertReachable(dep) |
|
93 | self.assertReachable(dep) | |
94 | dep.all=False |
|
94 | dep.all=False | |
95 | self.assertMet(dep) |
|
95 | self.assertMet(dep) | |
96 | self.assertReachable(dep) |
|
96 | self.assertReachable(dep) | |
97 |
|
97 | |||
98 | def test_failure_only(self): |
|
98 | def test_failure_only(self): | |
99 | dep = pmod.Dependency(mixed, success=False, failure=True) |
|
99 | dep = pmod.Dependency(mixed, success=False, failure=True) | |
100 | self.assertUnmet(dep) |
|
100 | self.assertUnmet(dep) | |
101 | self.assertUnreachable(dep) |
|
101 | self.assertUnreachable(dep) | |
102 | dep.all=False |
|
102 | dep.all=False | |
103 | self.assertMet(dep) |
|
103 | self.assertMet(dep) | |
104 | self.assertReachable(dep) |
|
104 | self.assertReachable(dep) | |
105 | dep = pmod.Dependency(completed, success=False, failure=True) |
|
105 | dep = pmod.Dependency(completed, success=False, failure=True) | |
106 | self.assertUnmet(dep) |
|
106 | self.assertUnmet(dep) | |
107 | self.assertUnreachable(dep) |
|
107 | self.assertUnreachable(dep) | |
108 | dep.all=False |
|
108 | dep.all=False | |
109 | self.assertUnmet(dep) |
|
109 | self.assertUnmet(dep) | |
110 | self.assertUnreachable(dep) |
|
110 | self.assertUnreachable(dep) | |
111 |
|
111 | |||
112 | def test_require_function(self): |
|
112 | def test_require_function(self): | |
113 |
|
113 | |||
114 | @pmod.interactive |
|
114 | @pmod.interactive | |
115 | def bar(a): |
|
115 | def bar(a): | |
116 | return func(a) |
|
116 | return func(a) | |
117 |
|
117 | |||
118 | @pmod.require(func) |
|
118 | @pmod.require(func) | |
119 | @pmod.interactive |
|
119 | @pmod.interactive | |
120 | def bar2(a): |
|
120 | def bar2(a): | |
121 | return func(a) |
|
121 | return func(a) | |
122 |
|
122 | |||
123 | self.client[:].clear() |
|
123 | self.client[:].clear() | |
124 | self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5) |
|
124 | self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5) | |
125 | ar = self.view.apply_async(bar2, 5) |
|
125 | ar = self.view.apply_async(bar2, 5) | |
126 | self.assertEqual(ar.get(5), func(5)) |
|
126 | self.assertEqual(ar.get(5), func(5)) | |
127 |
|
127 | |||
128 | def test_require_object(self): |
|
128 | def test_require_object(self): | |
129 |
|
129 | |||
130 | @pmod.require(foo=func) |
|
130 | @pmod.require(foo=func) | |
131 | @pmod.interactive |
|
131 | @pmod.interactive | |
132 | def bar(a): |
|
132 | def bar(a): | |
133 | return foo(a) |
|
133 | return foo(a) | |
134 |
|
134 | |||
135 | ar = self.view.apply_async(bar, 5) |
|
135 | ar = self.view.apply_async(bar, 5) | |
136 | self.assertEqual(ar.get(5), func(5)) |
|
136 | self.assertEqual(ar.get(5), func(5)) |
@@ -1,884 +1,889 b'' | |||||
1 | """Session object for building, serializing, sending, and receiving messages in |
|
1 | """Session object for building, serializing, sending, and receiving messages in | |
2 | IPython. The Session object supports serialization, HMAC signatures, and |
|
2 | IPython. The Session object supports serialization, HMAC signatures, and | |
3 | metadata on messages. |
|
3 | metadata on messages. | |
4 |
|
4 | |||
5 | Also defined here are utilities for working with Sessions: |
|
5 | Also defined here are utilities for working with Sessions: | |
6 | * A SessionFactory to be used as a base class for configurables that work with |
|
6 | * A SessionFactory to be used as a base class for configurables that work with | |
7 | Sessions. |
|
7 | Sessions. | |
8 | * A Message object for convenience that allows attribute-access to the msg dict. |
|
8 | * A Message object for convenience that allows attribute-access to the msg dict. | |
9 | """ |
|
9 | """ | |
10 |
|
10 | |||
11 | # Copyright (c) IPython Development Team. |
|
11 | # Copyright (c) IPython Development Team. | |
12 | # Distributed under the terms of the Modified BSD License. |
|
12 | # Distributed under the terms of the Modified BSD License. | |
13 |
|
13 | |||
14 | import hashlib |
|
14 | import hashlib | |
15 | import hmac |
|
15 | import hmac | |
16 | import logging |
|
16 | import logging | |
17 | import os |
|
17 | import os | |
18 | import pprint |
|
18 | import pprint | |
19 | import random |
|
19 | import random | |
20 | import uuid |
|
20 | import uuid | |
21 | import warnings |
|
21 | import warnings | |
22 | from datetime import datetime |
|
22 | from datetime import datetime | |
23 |
|
23 | |||
24 | try: |
|
24 | try: | |
25 | import cPickle |
|
25 | import cPickle | |
26 | pickle = cPickle |
|
26 | pickle = cPickle | |
27 | except: |
|
27 | except: | |
28 | cPickle = None |
|
28 | cPickle = None | |
29 | import pickle |
|
29 | import pickle | |
30 |
|
30 | |||
31 | try: |
|
31 | try: | |
|
32 | # py3 | |||
|
33 | PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL | |||
|
34 | except AttributeError: | |||
|
35 | PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL | |||
|
36 | ||||
|
37 | try: | |||
32 | # We are using compare_digest to limit the surface of timing attacks |
|
38 | # We are using compare_digest to limit the surface of timing attacks | |
33 | from hmac import compare_digest |
|
39 | from hmac import compare_digest | |
34 | except ImportError: |
|
40 | except ImportError: | |
35 | # Python < 2.7.7: When digests don't match, no feedback is provided, |
|
41 | # Python < 2.7.7: When digests don't match, no feedback is provided, | |
36 | # limiting the surface of attack |
|
42 | # limiting the surface of attack | |
37 | def compare_digest(a,b): return a == b |
|
43 | def compare_digest(a,b): return a == b | |
38 |
|
44 | |||
39 | import zmq |
|
45 | import zmq | |
40 | from zmq.utils import jsonapi |
|
46 | from zmq.utils import jsonapi | |
41 | from zmq.eventloop.ioloop import IOLoop |
|
47 | from zmq.eventloop.ioloop import IOLoop | |
42 | from zmq.eventloop.zmqstream import ZMQStream |
|
48 | from zmq.eventloop.zmqstream import ZMQStream | |
43 |
|
49 | |||
44 | from IPython.core.release import kernel_protocol_version |
|
50 | from IPython.core.release import kernel_protocol_version | |
45 | from IPython.config.configurable import Configurable, LoggingConfigurable |
|
51 | from IPython.config.configurable import Configurable, LoggingConfigurable | |
46 | from IPython.utils import io |
|
52 | from IPython.utils import io | |
47 | from IPython.utils.importstring import import_item |
|
53 | from IPython.utils.importstring import import_item | |
48 | from IPython.utils.jsonutil import extract_dates, squash_dates, date_default |
|
54 | from IPython.utils.jsonutil import extract_dates, squash_dates, date_default | |
49 | from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type, |
|
55 | from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type, | |
50 | iteritems) |
|
56 | iteritems) | |
51 | from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set, |
|
57 | from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set, | |
52 | DottedObjectName, CUnicode, Dict, Integer, |
|
58 | DottedObjectName, CUnicode, Dict, Integer, | |
53 | TraitError, |
|
59 | TraitError, | |
54 | ) |
|
60 | ) | |
55 | from IPython.utils.pickleutil import PICKLE_PROTOCOL |
|
|||
56 | from jupyter_client.adapter import adapt |
|
61 | from jupyter_client.adapter import adapt | |
57 |
|
62 | |||
58 | #----------------------------------------------------------------------------- |
|
63 | #----------------------------------------------------------------------------- | |
59 | # utility functions |
|
64 | # utility functions | |
60 | #----------------------------------------------------------------------------- |
|
65 | #----------------------------------------------------------------------------- | |
61 |
|
66 | |||
62 | def squash_unicode(obj): |
|
67 | def squash_unicode(obj): | |
63 | """coerce unicode back to bytestrings.""" |
|
68 | """coerce unicode back to bytestrings.""" | |
64 | if isinstance(obj,dict): |
|
69 | if isinstance(obj,dict): | |
65 | for key in obj.keys(): |
|
70 | for key in obj.keys(): | |
66 | obj[key] = squash_unicode(obj[key]) |
|
71 | obj[key] = squash_unicode(obj[key]) | |
67 | if isinstance(key, unicode_type): |
|
72 | if isinstance(key, unicode_type): | |
68 | obj[squash_unicode(key)] = obj.pop(key) |
|
73 | obj[squash_unicode(key)] = obj.pop(key) | |
69 | elif isinstance(obj, list): |
|
74 | elif isinstance(obj, list): | |
70 | for i,v in enumerate(obj): |
|
75 | for i,v in enumerate(obj): | |
71 | obj[i] = squash_unicode(v) |
|
76 | obj[i] = squash_unicode(v) | |
72 | elif isinstance(obj, unicode_type): |
|
77 | elif isinstance(obj, unicode_type): | |
73 | obj = obj.encode('utf8') |
|
78 | obj = obj.encode('utf8') | |
74 | return obj |
|
79 | return obj | |
75 |
|
80 | |||
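# Illustrative usage sketch for squash_unicode above (standalone, not part of the
# module; assumes it is importable as jupyter_client.session, as the traits below
# suggest). Unicode strings, including dict keys, are recursively re-encoded as
# utf-8 bytestrings.
from jupyter_client.session import squash_unicode

assert squash_unicode({u'key': [u'value', 1]}) == {b'key': [b'value', 1]}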
76 | #----------------------------------------------------------------------------- |
|
81 | #----------------------------------------------------------------------------- | |
77 | # globals and defaults |
|
82 | # globals and defaults | |
78 | #----------------------------------------------------------------------------- |
|
83 | #----------------------------------------------------------------------------- | |
79 |
|
84 | |||
80 | # default values for the thresholds: |
|
85 | # default values for the thresholds: | |
81 | MAX_ITEMS = 64 |
|
86 | MAX_ITEMS = 64 | |
82 | MAX_BYTES = 1024 |
|
87 | MAX_BYTES = 1024 | |
83 |
|
88 | |||
84 | # ISO8601-ify datetime objects |
|
89 | # ISO8601-ify datetime objects | |
85 | # allow unicode |
|
90 | # allow unicode | |
86 | # disallow nan, because it's not actually valid JSON |
|
91 | # disallow nan, because it's not actually valid JSON | |
87 | json_packer = lambda obj: jsonapi.dumps(obj, default=date_default, |
|
92 | json_packer = lambda obj: jsonapi.dumps(obj, default=date_default, | |
88 | ensure_ascii=False, allow_nan=False, |
|
93 | ensure_ascii=False, allow_nan=False, | |
89 | ) |
|
94 | ) | |
90 | json_unpacker = lambda s: jsonapi.loads(s) |
|
95 | json_unpacker = lambda s: jsonapi.loads(s) | |
91 |
|
96 | |||
92 | pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) |
|
97 | pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL) | |
93 | pickle_unpacker = pickle.loads |
|
98 | pickle_unpacker = pickle.loads | |
94 |
|
99 | |||
95 | default_packer = json_packer |
|
100 | default_packer = json_packer | |
96 | default_unpacker = json_unpacker |
|
101 | default_unpacker = json_unpacker | |
97 |
|
102 | |||
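# Illustrative sketch of the default JSON packer/unpacker defined above (standalone;
# assumes pyzmq and jupyter_client are installed). Datetimes are squashed to ISO8601
# strings by date_default on the way out and are NOT revived on the way back in.
from datetime import datetime
from jupyter_client.session import json_packer, json_unpacker

packed = json_packer({'stamp': datetime(2015, 1, 1, 12, 0)})
assert isinstance(packed, bytes)        # packers must emit bytes
roundtripped = json_unpacker(packed)
assert not isinstance(roundtripped['stamp'], datetime)   # an ISO8601 string now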
98 | DELIM = b"<IDS|MSG>" |
|
103 | DELIM = b"<IDS|MSG>" | |
99 | # singleton dummy tracker, which will always report as done |
|
104 | # singleton dummy tracker, which will always report as done | |
100 | DONE = zmq.MessageTracker() |
|
105 | DONE = zmq.MessageTracker() | |
101 |
|
106 | |||
102 | #----------------------------------------------------------------------------- |
|
107 | #----------------------------------------------------------------------------- | |
103 | # Mixin tools for apps that use Sessions |
|
108 | # Mixin tools for apps that use Sessions | |
104 | #----------------------------------------------------------------------------- |
|
109 | #----------------------------------------------------------------------------- | |
105 |
|
110 | |||
106 | session_aliases = dict( |
|
111 | session_aliases = dict( | |
107 | ident = 'Session.session', |
|
112 | ident = 'Session.session', | |
108 | user = 'Session.username', |
|
113 | user = 'Session.username', | |
109 | keyfile = 'Session.keyfile', |
|
114 | keyfile = 'Session.keyfile', | |
110 | ) |
|
115 | ) | |
111 |
|
116 | |||
112 | session_flags = { |
|
117 | session_flags = { | |
113 | 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())), |
|
118 | 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())), | |
114 | 'keyfile' : '' }}, |
|
119 | 'keyfile' : '' }}, | |
115 | """Use HMAC digests for authentication of messages. |
|
120 | """Use HMAC digests for authentication of messages. | |
116 | Setting this flag will generate a new UUID to use as the HMAC key. |
|
121 | Setting this flag will generate a new UUID to use as the HMAC key. | |
117 | """), |
|
122 | """), | |
118 | 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }}, |
|
123 | 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }}, | |
119 | """Don't authenticate messages."""), |
|
124 | """Don't authenticate messages."""), | |
120 | } |
|
125 | } | |
121 |
|
126 | |||
122 | def default_secure(cfg): |
|
127 | def default_secure(cfg): | |
123 | """Set the default behavior for a config environment to be secure. |
|
128 | """Set the default behavior for a config environment to be secure. | |
124 |
|
129 | |||
125 | If Session.key/keyfile have not been set, set Session.key to |
|
130 | If Session.key/keyfile have not been set, set Session.key to | |
126 | a new random UUID. |
|
131 | a new random UUID. | |
127 | """ |
|
132 | """ | |
128 | warnings.warn("default_secure is deprecated", DeprecationWarning) |
|
133 | warnings.warn("default_secure is deprecated", DeprecationWarning) | |
129 | if 'Session' in cfg: |
|
134 | if 'Session' in cfg: | |
130 | if 'key' in cfg.Session or 'keyfile' in cfg.Session: |
|
135 | if 'key' in cfg.Session or 'keyfile' in cfg.Session: | |
131 | return |
|
136 | return | |
132 | # key/keyfile not specified, generate new UUID: |
|
137 | # key/keyfile not specified, generate new UUID: | |
133 | cfg.Session.key = str_to_bytes(str(uuid.uuid4())) |
|
138 | cfg.Session.key = str_to_bytes(str(uuid.uuid4())) | |
134 |
|
139 | |||
135 |
|
140 | |||
136 | #----------------------------------------------------------------------------- |
|
141 | #----------------------------------------------------------------------------- | |
137 | # Classes |
|
142 | # Classes | |
138 | #----------------------------------------------------------------------------- |
|
143 | #----------------------------------------------------------------------------- | |
139 |
|
144 | |||
140 | class SessionFactory(LoggingConfigurable): |
|
145 | class SessionFactory(LoggingConfigurable): | |
141 | """The Base class for configurables that have a Session, Context, logger, |
|
146 | """The Base class for configurables that have a Session, Context, logger, | |
142 | and IOLoop. |
|
147 | and IOLoop. | |
143 | """ |
|
148 | """ | |
144 |
|
149 | |||
145 | logname = Unicode('') |
|
150 | logname = Unicode('') | |
146 | def _logname_changed(self, name, old, new): |
|
151 | def _logname_changed(self, name, old, new): | |
147 | self.log = logging.getLogger(new) |
|
152 | self.log = logging.getLogger(new) | |
148 |
|
153 | |||
149 | # not configurable: |
|
154 | # not configurable: | |
150 | context = Instance('zmq.Context') |
|
155 | context = Instance('zmq.Context') | |
151 | def _context_default(self): |
|
156 | def _context_default(self): | |
152 | return zmq.Context.instance() |
|
157 | return zmq.Context.instance() | |
153 |
|
158 | |||
154 | session = Instance('jupyter_client.session.Session', |
|
159 | session = Instance('jupyter_client.session.Session', | |
155 | allow_none=True) |
|
160 | allow_none=True) | |
156 |
|
161 | |||
157 | loop = Instance('zmq.eventloop.ioloop.IOLoop') |
|
162 | loop = Instance('zmq.eventloop.ioloop.IOLoop') | |
158 | def _loop_default(self): |
|
163 | def _loop_default(self): | |
159 | return IOLoop.instance() |
|
164 | return IOLoop.instance() | |
160 |
|
165 | |||
161 | def __init__(self, **kwargs): |
|
166 | def __init__(self, **kwargs): | |
162 | super(SessionFactory, self).__init__(**kwargs) |
|
167 | super(SessionFactory, self).__init__(**kwargs) | |
163 |
|
168 | |||
164 | if self.session is None: |
|
169 | if self.session is None: | |
165 | # construct the session |
|
170 | # construct the session | |
166 | self.session = Session(**kwargs) |
|
171 | self.session = Session(**kwargs) | |
167 |
|
172 | |||
168 |
|
173 | |||
169 | class Message(object): |
|
174 | class Message(object): | |
170 | """A simple message object that maps dict keys to attributes. |
|
175 | """A simple message object that maps dict keys to attributes. | |
171 |
|
176 | |||
172 | A Message can be created from a dict and a dict from a Message instance |
|
177 | A Message can be created from a dict and a dict from a Message instance | |
173 | simply by calling dict(msg_obj).""" |
|
178 | simply by calling dict(msg_obj).""" | |
174 |
|
179 | |||
175 | def __init__(self, msg_dict): |
|
180 | def __init__(self, msg_dict): | |
176 | dct = self.__dict__ |
|
181 | dct = self.__dict__ | |
177 | for k, v in iteritems(dict(msg_dict)): |
|
182 | for k, v in iteritems(dict(msg_dict)): | |
178 | if isinstance(v, dict): |
|
183 | if isinstance(v, dict): | |
179 | v = Message(v) |
|
184 | v = Message(v) | |
180 | dct[k] = v |
|
185 | dct[k] = v | |
181 |
|
186 | |||
182 | # Having this iterator lets dict(msg_obj) work out of the box. |
|
187 | # Having this iterator lets dict(msg_obj) work out of the box. | |
183 | def __iter__(self): |
|
188 | def __iter__(self): | |
184 | return iter(iteritems(self.__dict__)) |
|
189 | return iter(iteritems(self.__dict__)) | |
185 |
|
190 | |||
186 | def __repr__(self): |
|
191 | def __repr__(self): | |
187 | return repr(self.__dict__) |
|
192 | return repr(self.__dict__) | |
188 |
|
193 | |||
189 | def __str__(self): |
|
194 | def __str__(self): | |
190 | return pprint.pformat(self.__dict__) |
|
195 | return pprint.pformat(self.__dict__) | |
191 |
|
196 | |||
192 | def __contains__(self, k): |
|
197 | def __contains__(self, k): | |
193 | return k in self.__dict__ |
|
198 | return k in self.__dict__ | |
194 |
|
199 | |||
195 | def __getitem__(self, k): |
|
200 | def __getitem__(self, k): | |
196 | return self.__dict__[k] |
|
201 | return self.__dict__[k] | |
197 |
|
202 | |||
198 |
|
203 | |||
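# Illustrative usage sketch for the Message wrapper above (standalone; assumes
# jupyter_client is importable). Nested dicts become nested Message attributes,
# and dict() recovers a plain dict at the top level.
from jupyter_client.session import Message

m = Message({'header': {'msg_id': 'abc'}, 'content': {'code': 'print(1)'}})
assert m.header.msg_id == 'abc'
assert 'content' in m
assert set(dict(m)) == {'header', 'content'}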
199 | def msg_header(msg_id, msg_type, username, session): |
|
204 | def msg_header(msg_id, msg_type, username, session): | |
200 | date = datetime.now() |
|
205 | date = datetime.now() | |
201 | version = kernel_protocol_version |
|
206 | version = kernel_protocol_version | |
202 | return locals() |
|
207 | return locals() | |
203 |
|
208 | |||
204 | def extract_header(msg_or_header): |
|
209 | def extract_header(msg_or_header): | |
205 | """Given a message or header, return the header.""" |
|
210 | """Given a message or header, return the header.""" | |
206 | if not msg_or_header: |
|
211 | if not msg_or_header: | |
207 | return {} |
|
212 | return {} | |
208 | try: |
|
213 | try: | |
209 | # See if msg_or_header is the entire message. |
|
214 | # See if msg_or_header is the entire message. | |
210 | h = msg_or_header['header'] |
|
215 | h = msg_or_header['header'] | |
211 | except KeyError: |
|
216 | except KeyError: | |
212 | try: |
|
217 | try: | |
213 | # See if msg_or_header is just the header |
|
218 | # See if msg_or_header is just the header | |
214 | h = msg_or_header['msg_id'] |
|
219 | h = msg_or_header['msg_id'] | |
215 | except KeyError: |
|
220 | except KeyError: | |
216 | raise |
|
221 | raise | |
217 | else: |
|
222 | else: | |
218 | h = msg_or_header |
|
223 | h = msg_or_header | |
219 | if not isinstance(h, dict): |
|
224 | if not isinstance(h, dict): | |
220 | h = dict(h) |
|
225 | h = dict(h) | |
221 | return h |
|
226 | return h | |
222 |
|
227 | |||
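# Illustrative usage sketch for extract_header above (standalone; assumes
# jupyter_client is importable): it accepts either a full message dict or a bare
# header, and always hands back the header dict (or {} for empty input).
from jupyter_client.session import extract_header

header = {'msg_id': 'abc', 'msg_type': 'execute_request'}
assert extract_header({'header': header, 'content': {}}) == header  # full message
assert extract_header(header) == header                             # already a header
assert extract_header(None) == {}                                    # nothing to extract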
223 | class Session(Configurable): |
|
228 | class Session(Configurable): | |
224 | """Object for handling serialization and sending of messages. |
|
229 | """Object for handling serialization and sending of messages. | |
225 |
|
230 | |||
226 | The Session object handles building messages and sending them |
|
231 | The Session object handles building messages and sending them | |
227 | with ZMQ sockets or ZMQStream objects. Objects can communicate with each |
|
232 | with ZMQ sockets or ZMQStream objects. Objects can communicate with each | |
228 | other over the network via Session objects, and only need to work with the |
|
233 | other over the network via Session objects, and only need to work with the | |
229 | dict-based IPython message spec. The Session will handle |
|
234 | dict-based IPython message spec. The Session will handle | |
230 | serialization/deserialization, security, and metadata. |
|
235 | serialization/deserialization, security, and metadata. | |
231 |
|
236 | |||
232 | Sessions support configurable serialization via packer/unpacker traits, |
|
237 | Sessions support configurable serialization via packer/unpacker traits, | |
233 | and signing with HMAC digests via the key/keyfile traits. |
|
238 | and signing with HMAC digests via the key/keyfile traits. | |
234 |
|
239 | |||
235 | Parameters |
|
240 | Parameters | |
236 | ---------- |
|
241 | ---------- | |
237 |
|
242 | |||
238 | debug : bool |
|
243 | debug : bool | |
239 | whether to trigger extra debugging statements |
|
244 | whether to trigger extra debugging statements | |
240 | packer/unpacker : str : 'json', 'pickle' or import_string |
|
245 | packer/unpacker : str : 'json', 'pickle' or import_string | |
241 | importstrings for methods to serialize message parts. If just |
|
246 | importstrings for methods to serialize message parts. If just | |
242 | 'json' or 'pickle', predefined JSON and pickle packers will be used. |
|
247 | 'json' or 'pickle', predefined JSON and pickle packers will be used. | |
243 | Otherwise, the entire importstring must be used. |
|
248 | Otherwise, the entire importstring must be used. | |
244 |
|
249 | |||
245 | The functions must accept at least valid JSON input, and output *bytes*. |
|
250 | The functions must accept at least valid JSON input, and output *bytes*. | |
246 |
|
251 | |||
247 | For example, to use msgpack: |
|
252 | For example, to use msgpack: | |
248 | packer = 'msgpack.packb', unpacker='msgpack.unpackb' |
|
253 | packer = 'msgpack.packb', unpacker='msgpack.unpackb' | |
249 | pack/unpack : callables |
|
254 | pack/unpack : callables | |
250 | You can also set the pack/unpack callables for serialization directly. |
|
255 | You can also set the pack/unpack callables for serialization directly. | |
251 | session : bytes |
|
256 | session : bytes | |
252 | the ID of this Session object. The default is to generate a new UUID. |
|
257 | the ID of this Session object. The default is to generate a new UUID. | |
253 | username : unicode |
|
258 | username : unicode | |
254 | username added to message headers. The default is to ask the OS. |
|
259 | username added to message headers. The default is to ask the OS. | |
255 | key : bytes |
|
260 | key : bytes | |
256 | The key used to initialize an HMAC signature. If unset, messages |
|
261 | The key used to initialize an HMAC signature. If unset, messages | |
257 | will not be signed or checked. |
|
262 | will not be signed or checked. | |
258 | keyfile : filepath |
|
263 | keyfile : filepath | |
259 | The file containing a key. If this is set, `key` will be initialized |
|
264 | The file containing a key. If this is set, `key` will be initialized | |
260 | to the contents of the file. |
|
265 | to the contents of the file. | |
261 |
|
266 | |||
262 | """ |
|
267 | """ | |
263 |
|
268 | |||
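# Illustrative sketch of the packer/unpacker configuration described in the docstring
# above (standalone, not part of the class body; msgpack is only an example and may
# need encoding options depending on its version).
from jupyter_client.session import Session

s_json = Session()                                      # default: JSON packer
s_pickle = Session(packer='pickle', unpacker='pickle')  # builtin pickle packer
s_custom = Session(packer='msgpack.packb',              # any import string works, as
                   unpacker='msgpack.unpackb')          # long as pack() returns bytes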
264 | debug=Bool(False, config=True, help="""Debug output in the Session""") |
|
269 | debug=Bool(False, config=True, help="""Debug output in the Session""") | |
265 |
|
270 | |||
266 | packer = DottedObjectName('json',config=True, |
|
271 | packer = DottedObjectName('json',config=True, | |
267 | help="""The name of the packer for serializing messages. |
|
272 | help="""The name of the packer for serializing messages. | |
268 | Should be one of 'json', 'pickle', or an import name |
|
273 | Should be one of 'json', 'pickle', or an import name | |
269 | for a custom callable serializer.""") |
|
274 | for a custom callable serializer.""") | |
270 | def _packer_changed(self, name, old, new): |
|
275 | def _packer_changed(self, name, old, new): | |
271 | if new.lower() == 'json': |
|
276 | if new.lower() == 'json': | |
272 | self.pack = json_packer |
|
277 | self.pack = json_packer | |
273 | self.unpack = json_unpacker |
|
278 | self.unpack = json_unpacker | |
274 | self.unpacker = new |
|
279 | self.unpacker = new | |
275 | elif new.lower() == 'pickle': |
|
280 | elif new.lower() == 'pickle': | |
276 | self.pack = pickle_packer |
|
281 | self.pack = pickle_packer | |
277 | self.unpack = pickle_unpacker |
|
282 | self.unpack = pickle_unpacker | |
278 | self.unpacker = new |
|
283 | self.unpacker = new | |
279 | else: |
|
284 | else: | |
280 | self.pack = import_item(str(new)) |
|
285 | self.pack = import_item(str(new)) | |
281 |
|
286 | |||
282 | unpacker = DottedObjectName('json', config=True, |
|
287 | unpacker = DottedObjectName('json', config=True, | |
283 | help="""The name of the unpacker for unserializing messages. |
|
288 | help="""The name of the unpacker for unserializing messages. | |
284 | Only used with custom functions for `packer`.""") |
|
289 | Only used with custom functions for `packer`.""") | |
285 | def _unpacker_changed(self, name, old, new): |
|
290 | def _unpacker_changed(self, name, old, new): | |
286 | if new.lower() == 'json': |
|
291 | if new.lower() == 'json': | |
287 | self.pack = json_packer |
|
292 | self.pack = json_packer | |
288 | self.unpack = json_unpacker |
|
293 | self.unpack = json_unpacker | |
289 | self.packer = new |
|
294 | self.packer = new | |
290 | elif new.lower() == 'pickle': |
|
295 | elif new.lower() == 'pickle': | |
291 | self.pack = pickle_packer |
|
296 | self.pack = pickle_packer | |
292 | self.unpack = pickle_unpacker |
|
297 | self.unpack = pickle_unpacker | |
293 | self.packer = new |
|
298 | self.packer = new | |
294 | else: |
|
299 | else: | |
295 | self.unpack = import_item(str(new)) |
|
300 | self.unpack = import_item(str(new)) | |
296 |
|
301 | |||
297 | session = CUnicode(u'', config=True, |
|
302 | session = CUnicode(u'', config=True, | |
298 | help="""The UUID identifying this session.""") |
|
303 | help="""The UUID identifying this session.""") | |
299 | def _session_default(self): |
|
304 | def _session_default(self): | |
300 | u = unicode_type(uuid.uuid4()) |
|
305 | u = unicode_type(uuid.uuid4()) | |
301 | self.bsession = u.encode('ascii') |
|
306 | self.bsession = u.encode('ascii') | |
302 | return u |
|
307 | return u | |
303 |
|
308 | |||
304 | def _session_changed(self, name, old, new): |
|
309 | def _session_changed(self, name, old, new): | |
305 | self.bsession = self.session.encode('ascii') |
|
310 | self.bsession = self.session.encode('ascii') | |
306 |
|
311 | |||
307 | # bsession is the session as bytes |
|
312 | # bsession is the session as bytes | |
308 | bsession = CBytes(b'') |
|
313 | bsession = CBytes(b'') | |
309 |
|
314 | |||
310 | username = Unicode(str_to_unicode(os.environ.get('USER', 'username')), |
|
315 | username = Unicode(str_to_unicode(os.environ.get('USER', 'username')), | |
311 | help="""Username for the Session. Default is your system username.""", |
|
316 | help="""Username for the Session. Default is your system username.""", | |
312 | config=True) |
|
317 | config=True) | |
313 |
|
318 | |||
314 | metadata = Dict({}, config=True, |
|
319 | metadata = Dict({}, config=True, | |
315 | help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""") |
|
320 | help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""") | |
316 |
|
321 | |||
317 | # if 0, no adapting to do. |
|
322 | # if 0, no adapting to do. | |
318 | adapt_version = Integer(0) |
|
323 | adapt_version = Integer(0) | |
319 |
|
324 | |||
320 | # message signature related traits: |
|
325 | # message signature related traits: | |
321 |
|
326 | |||
322 | key = CBytes(config=True, |
|
327 | key = CBytes(config=True, | |
323 | help="""execution key, for signing messages.""") |
|
328 | help="""execution key, for signing messages.""") | |
324 | def _key_default(self): |
|
329 | def _key_default(self): | |
325 | return str_to_bytes(str(uuid.uuid4())) |
|
330 | return str_to_bytes(str(uuid.uuid4())) | |
326 |
|
331 | |||
327 | def _key_changed(self): |
|
332 | def _key_changed(self): | |
328 | self._new_auth() |
|
333 | self._new_auth() | |
329 |
|
334 | |||
330 | signature_scheme = Unicode('hmac-sha256', config=True, |
|
335 | signature_scheme = Unicode('hmac-sha256', config=True, | |
331 | help="""The digest scheme used to construct the message signatures. |
|
336 | help="""The digest scheme used to construct the message signatures. | |
332 | Must have the form 'hmac-HASH'.""") |
|
337 | Must have the form 'hmac-HASH'.""") | |
333 | def _signature_scheme_changed(self, name, old, new): |
|
338 | def _signature_scheme_changed(self, name, old, new): | |
334 | if not new.startswith('hmac-'): |
|
339 | if not new.startswith('hmac-'): | |
335 | raise TraitError("signature_scheme must start with 'hmac-', got %r" % new) |
|
340 | raise TraitError("signature_scheme must start with 'hmac-', got %r" % new) | |
336 | hash_name = new.split('-', 1)[1] |
|
341 | hash_name = new.split('-', 1)[1] | |
337 | try: |
|
342 | try: | |
338 | self.digest_mod = getattr(hashlib, hash_name) |
|
343 | self.digest_mod = getattr(hashlib, hash_name) | |
339 | except AttributeError: |
|
344 | except AttributeError: | |
340 | raise TraitError("hashlib has no such attribute: %s" % hash_name) |
|
345 | raise TraitError("hashlib has no such attribute: %s" % hash_name) | |
341 | self._new_auth() |
|
346 | self._new_auth() | |
342 |
|
347 | |||
343 | digest_mod = Any() |
|
348 | digest_mod = Any() | |
344 | def _digest_mod_default(self): |
|
349 | def _digest_mod_default(self): | |
345 | return hashlib.sha256 |
|
350 | return hashlib.sha256 | |
346 |
|
351 | |||
347 | auth = Instance(hmac.HMAC, allow_none=True) |
|
352 | auth = Instance(hmac.HMAC, allow_none=True) | |
348 |
|
353 | |||
349 | def _new_auth(self): |
|
354 | def _new_auth(self): | |
350 | if self.key: |
|
355 | if self.key: | |
351 | self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) |
|
356 | self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod) | |
352 | else: |
|
357 | else: | |
353 | self.auth = None |
|
358 | self.auth = None | |
354 |
|
359 | |||
355 | digest_history = Set() |
|
360 | digest_history = Set() | |
356 | digest_history_size = Integer(2**16, config=True, |
|
361 | digest_history_size = Integer(2**16, config=True, | |
357 | help="""The maximum number of digests to remember. |
|
362 | help="""The maximum number of digests to remember. | |
358 |
|
363 | |||
359 | The digest history will be culled when it exceeds this value. |
|
364 | The digest history will be culled when it exceeds this value. | |
360 | """ |
|
365 | """ | |
361 | ) |
|
366 | ) | |
362 |
|
367 | |||
363 | keyfile = Unicode('', config=True, |
|
368 | keyfile = Unicode('', config=True, | |
364 | help="""path to file containing execution key.""") |
|
369 | help="""path to file containing execution key.""") | |
365 | def _keyfile_changed(self, name, old, new): |
|
370 | def _keyfile_changed(self, name, old, new): | |
366 | with open(new, 'rb') as f: |
|
371 | with open(new, 'rb') as f: | |
367 | self.key = f.read().strip() |
|
372 | self.key = f.read().strip() | |
368 |
|
373 | |||
369 | # for protecting against sends from forks |
|
374 | # for protecting against sends from forks | |
370 | pid = Integer() |
|
375 | pid = Integer() | |
371 |
|
376 | |||
372 | # serialization traits: |
|
377 | # serialization traits: | |
373 |
|
378 | |||
374 | pack = Any(default_packer) # the actual packer function |
|
379 | pack = Any(default_packer) # the actual packer function | |
375 | def _pack_changed(self, name, old, new): |
|
380 | def _pack_changed(self, name, old, new): | |
376 | if not callable(new): |
|
381 | if not callable(new): | |
377 | raise TypeError("packer must be callable, not %s"%type(new)) |
|
382 | raise TypeError("packer must be callable, not %s"%type(new)) | |
378 |
|
383 | |||
379 | unpack = Any(default_unpacker) # the actual unpacker function |
|
384 | unpack = Any(default_unpacker) # the actual unpacker function | |
380 | def _unpack_changed(self, name, old, new): |
|
385 | def _unpack_changed(self, name, old, new): | |
381 | # unpacker is not checked against the packer - it is assumed to be its inverse |
|
386 | # unpacker is not checked against the packer - it is assumed to be its inverse | |
382 | if not callable(new): |
|
387 | if not callable(new): | |
383 | raise TypeError("unpacker must be callable, not %s"%type(new)) |
|
388 | raise TypeError("unpacker must be callable, not %s"%type(new)) | |
384 |
|
389 | |||
385 | # thresholds: |
|
390 | # thresholds: | |
386 | copy_threshold = Integer(2**16, config=True, |
|
391 | copy_threshold = Integer(2**16, config=True, | |
387 | help="Threshold (in bytes) beyond which a buffer should be sent without copying.") |
|
392 | help="Threshold (in bytes) beyond which a buffer should be sent without copying.") | |
388 | buffer_threshold = Integer(MAX_BYTES, config=True, |
|
393 | buffer_threshold = Integer(MAX_BYTES, config=True, | |
389 | help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.") |
|
394 | help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.") | |
390 | item_threshold = Integer(MAX_ITEMS, config=True, |
|
395 | item_threshold = Integer(MAX_ITEMS, config=True, | |
391 | help="""The maximum number of items for a container to be introspected for custom serialization. |
|
396 | help="""The maximum number of items for a container to be introspected for custom serialization. | |
392 | Containers larger than this are pickled outright. |
|
397 | Containers larger than this are pickled outright. | |
393 | """ |
|
398 | """ | |
394 | ) |
|
399 | ) | |
395 |
|
400 | |||
396 |
|
401 | |||
397 | def __init__(self, **kwargs): |
|
402 | def __init__(self, **kwargs): | |
398 | """create a Session object |
|
403 | """create a Session object | |
399 |
|
404 | |||
400 | Parameters |
|
405 | Parameters | |
401 | ---------- |
|
406 | ---------- | |
402 |
|
407 | |||
403 | debug : bool |
|
408 | debug : bool | |
404 | whether to trigger extra debugging statements |
|
409 | whether to trigger extra debugging statements | |
405 | packer/unpacker : str : 'json', 'pickle' or import_string |
|
410 | packer/unpacker : str : 'json', 'pickle' or import_string | |
406 | importstrings for methods to serialize message parts. If just |
|
411 | importstrings for methods to serialize message parts. If just | |
407 | 'json' or 'pickle', predefined JSON and pickle packers will be used. |
|
412 | 'json' or 'pickle', predefined JSON and pickle packers will be used. | |
408 | Otherwise, the entire importstring must be used. |
|
413 | Otherwise, the entire importstring must be used. | |
409 |
|
414 | |||
410 | The functions must accept at least valid JSON input, and output |
|
415 | The functions must accept at least valid JSON input, and output | |
411 | *bytes*. |
|
416 | *bytes*. | |
412 |
|
417 | |||
413 | For example, to use msgpack: |
|
418 | For example, to use msgpack: | |
414 | packer = 'msgpack.packb', unpacker='msgpack.unpackb' |
|
419 | packer = 'msgpack.packb', unpacker='msgpack.unpackb' | |
415 | pack/unpack : callables |
|
420 | pack/unpack : callables | |
416 | You can also set the pack/unpack callables for serialization |
|
421 | You can also set the pack/unpack callables for serialization | |
417 | directly. |
|
422 | directly. | |
418 | session : unicode (must be ascii) |
|
423 | session : unicode (must be ascii) | |
419 | the ID of this Session object. The default is to generate a new |
|
424 | the ID of this Session object. The default is to generate a new | |
420 | UUID. |
|
425 | UUID. | |
421 | bsession : bytes |
|
426 | bsession : bytes | |
422 | The session as bytes |
|
427 | The session as bytes | |
423 | username : unicode |
|
428 | username : unicode | |
424 | username added to message headers. The default is to ask the OS. |
|
429 | username added to message headers. The default is to ask the OS. | |
425 | key : bytes |
|
430 | key : bytes | |
426 | The key used to initialize an HMAC signature. If unset, messages |
|
431 | The key used to initialize an HMAC signature. If unset, messages | |
427 | will not be signed or checked. |
|
432 | will not be signed or checked. | |
428 | signature_scheme : str |
|
433 | signature_scheme : str | |
429 | The message digest scheme. Currently must be of the form 'hmac-HASH', |
|
434 | The message digest scheme. Currently must be of the form 'hmac-HASH', | |
430 | where 'HASH' is a hashing function available in Python's hashlib. |
|
435 | where 'HASH' is a hashing function available in Python's hashlib. | |
431 | The default is 'hmac-sha256'. |
|
436 | The default is 'hmac-sha256'. | |
432 | This is ignored if 'key' is empty. |
|
437 | This is ignored if 'key' is empty. | |
433 | keyfile : filepath |
|
438 | keyfile : filepath | |
434 | The file containing a key. If this is set, `key` will be |
|
439 | The file containing a key. If this is set, `key` will be | |
435 | initialized to the contents of the file. |
|
440 | initialized to the contents of the file. | |
436 | """ |
|
441 | """ | |
437 | super(Session, self).__init__(**kwargs) |
|
442 | super(Session, self).__init__(**kwargs) | |
438 | self._check_packers() |
|
443 | self._check_packers() | |
439 | self.none = self.pack({}) |
|
444 | self.none = self.pack({}) | |
440 | # ensure self._session_default() if necessary, so bsession is defined: |
|
445 | # ensure self._session_default() if necessary, so bsession is defined: | |
441 | self.session |
|
446 | self.session | |
442 | self.pid = os.getpid() |
|
447 | self.pid = os.getpid() | |
443 | self._new_auth() |
|
448 | self._new_auth() | |
444 |
|
449 | |||
445 | @property |
|
450 | @property | |
446 | def msg_id(self): |
|
451 | def msg_id(self): | |
447 | """always return new uuid""" |
|
452 | """always return new uuid""" | |
448 | return str(uuid.uuid4()) |
|
453 | return str(uuid.uuid4()) | |
449 |
|
454 | |||
450 | def _check_packers(self): |
|
455 | def _check_packers(self): | |
451 | """check packers for datetime support.""" |
|
456 | """check packers for datetime support.""" | |
452 | pack = self.pack |
|
457 | pack = self.pack | |
453 | unpack = self.unpack |
|
458 | unpack = self.unpack | |
454 |
|
459 | |||
455 | # check simple serialization |
|
460 | # check simple serialization | |
456 | msg = dict(a=[1,'hi']) |
|
461 | msg = dict(a=[1,'hi']) | |
457 | try: |
|
462 | try: | |
458 | packed = pack(msg) |
|
463 | packed = pack(msg) | |
459 | except Exception as e: |
|
464 | except Exception as e: | |
460 | msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" |
|
465 | msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}" | |
461 | if self.packer == 'json': |
|
466 | if self.packer == 'json': | |
462 | jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod |
|
467 | jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod | |
463 | else: |
|
468 | else: | |
464 | jsonmsg = "" |
|
469 | jsonmsg = "" | |
465 | raise ValueError( |
|
470 | raise ValueError( | |
466 | msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) |
|
471 | msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg) | |
467 | ) |
|
472 | ) | |
468 |
|
473 | |||
469 | # ensure packed message is bytes |
|
474 | # ensure packed message is bytes | |
470 | if not isinstance(packed, bytes): |
|
475 | if not isinstance(packed, bytes): | |
471 | raise ValueError("message packed to %r, but bytes are required"%type(packed)) |
|
476 | raise ValueError("message packed to %r, but bytes are required"%type(packed)) | |
472 |
|
477 | |||
473 | # check that unpack is pack's inverse |
|
478 | # check that unpack is pack's inverse | |
474 | try: |
|
479 | try: | |
475 | unpacked = unpack(packed) |
|
480 | unpacked = unpack(packed) | |
476 | assert unpacked == msg |
|
481 | assert unpacked == msg | |
477 | except Exception as e: |
|
482 | except Exception as e: | |
478 | msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" |
|
483 | msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}" | |
479 | if self.packer == 'json': |
|
484 | if self.packer == 'json': | |
480 | jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod |
|
485 | jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod | |
481 | else: |
|
486 | else: | |
482 | jsonmsg = "" |
|
487 | jsonmsg = "" | |
483 | raise ValueError( |
|
488 | raise ValueError( | |
484 | msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) |
|
489 | msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg) | |
485 | ) |
|
490 | ) | |
486 |
|
491 | |||
487 | # check datetime support |
|
492 | # check datetime support | |
488 | msg = dict(t=datetime.now()) |
|
493 | msg = dict(t=datetime.now()) | |
489 | try: |
|
494 | try: | |
490 | unpacked = unpack(pack(msg)) |
|
495 | unpacked = unpack(pack(msg)) | |
491 | if isinstance(unpacked['t'], datetime): |
|
496 | if isinstance(unpacked['t'], datetime): | |
492 | raise ValueError("Shouldn't deserialize to datetime") |
|
497 | raise ValueError("Shouldn't deserialize to datetime") | |
493 | except Exception: |
|
498 | except Exception: | |
494 | self.pack = lambda o: pack(squash_dates(o)) |
|
499 | self.pack = lambda o: pack(squash_dates(o)) | |
495 | self.unpack = lambda s: unpack(s) |
|
500 | self.unpack = lambda s: unpack(s) | |
496 |
|
501 | |||
497 | def msg_header(self, msg_type): |
|
502 | def msg_header(self, msg_type): | |
498 | return msg_header(self.msg_id, msg_type, self.username, self.session) |
|
503 | return msg_header(self.msg_id, msg_type, self.username, self.session) | |
499 |
|
504 | |||
500 | def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): |
|
505 | def msg(self, msg_type, content=None, parent=None, header=None, metadata=None): | |
501 | """Return the nested message dict. |
|
506 | """Return the nested message dict. | |
502 |
|
507 | |||
503 | This format is different from what is sent over the wire. The |
|
508 | This format is different from what is sent over the wire. The | |
504 | serialize/deserialize methods convert this nested message dict to the wire |
|
509 | serialize/deserialize methods convert this nested message dict to the wire | |
505 | format, which is a list of message parts. |
|
510 | format, which is a list of message parts. | |
506 | """ |
|
511 | """ | |
507 | msg = {} |
|
512 | msg = {} | |
508 | header = self.msg_header(msg_type) if header is None else header |
|
513 | header = self.msg_header(msg_type) if header is None else header | |
509 | msg['header'] = header |
|
514 | msg['header'] = header | |
510 | msg['msg_id'] = header['msg_id'] |
|
515 | msg['msg_id'] = header['msg_id'] | |
511 | msg['msg_type'] = header['msg_type'] |
|
516 | msg['msg_type'] = header['msg_type'] | |
512 | msg['parent_header'] = {} if parent is None else extract_header(parent) |
|
517 | msg['parent_header'] = {} if parent is None else extract_header(parent) | |
513 | msg['content'] = {} if content is None else content |
|
518 | msg['content'] = {} if content is None else content | |
514 | msg['metadata'] = self.metadata.copy() |
|
519 | msg['metadata'] = self.metadata.copy() | |
515 | if metadata is not None: |
|
520 | if metadata is not None: | |
516 | msg['metadata'].update(metadata) |
|
521 | msg['metadata'].update(metadata) | |
517 | return msg |
|
522 | return msg | |
518 |
|
523 | |||
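# Illustrative sketch of the nested (pre-wire) message dict built by Session.msg
# (standalone; assumes jupyter_client is importable).
from jupyter_client.session import Session

session = Session(key=b'secret')
msg = session.msg('execute_request', content={'code': 'print(1)'})
assert sorted(msg) == ['content', 'header', 'metadata', 'msg_id', 'msg_type', 'parent_header']
assert msg['header']['msg_type'] == 'execute_request'
assert msg['parent_header'] == {}     # no parent was supplied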
519 | def sign(self, msg_list): |
|
524 | def sign(self, msg_list): | |
520 | """Sign a message with HMAC digest. If no auth, return b''. |
|
525 | """Sign a message with HMAC digest. If no auth, return b''. | |
521 |
|
526 | |||
522 | Parameters |
|
527 | Parameters | |
523 | ---------- |
|
528 | ---------- | |
524 | msg_list : list |
|
529 | msg_list : list | |
525 | The [p_header,p_parent,p_content] part of the message list. |
|
530 | The [p_header,p_parent,p_content] part of the message list. | |
526 | """ |
|
531 | """ | |
527 | if self.auth is None: |
|
532 | if self.auth is None: | |
528 | return b'' |
|
533 | return b'' | |
529 | h = self.auth.copy() |
|
534 | h = self.auth.copy() | |
530 | for m in msg_list: |
|
535 | for m in msg_list: | |
531 | h.update(m) |
|
536 | h.update(m) | |
532 | return str_to_bytes(h.hexdigest()) |
|
537 | return str_to_bytes(h.hexdigest()) | |
533 |
|
538 | |||
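# Illustrative sketch of the HMAC signing scheme implemented by sign() above
# (standalone; assumes jupyter_client is importable). The signature is the hex digest
# of the key-seeded HMAC over the serialized message parts, or b'' when unsigned.
from jupyter_client.session import Session

parts = [b'{"msg_id": "abc"}', b'{}', b'{}', b'{}']
signed = Session(key=b'secret', signature_scheme='hmac-sha256').sign(parts)
assert isinstance(signed, bytes) and len(signed) == 64   # sha256 hex digest
assert Session(key=b'').sign(parts) == b''               # no key, no signature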
534 | def serialize(self, msg, ident=None): |
|
539 | def serialize(self, msg, ident=None): | |
535 | """Serialize the message components to bytes. |
|
540 | """Serialize the message components to bytes. | |
536 |
|
541 | |||
537 | This is roughly the inverse of deserialize. The serialize/deserialize |
|
542 | This is roughly the inverse of deserialize. The serialize/deserialize | |
538 | methods work with full message lists, whereas pack/unpack work with |
|
543 | methods work with full message lists, whereas pack/unpack work with | |
539 | the individual message parts in the message list. |
|
544 | the individual message parts in the message list. | |
540 |
|
545 | |||
541 | Parameters |
|
546 | Parameters | |
542 | ---------- |
|
547 | ---------- | |
543 | msg : dict or Message |
|
548 | msg : dict or Message | |
544 | The nested message dict as returned by the self.msg method. |
|
549 | The nested message dict as returned by the self.msg method. | |
545 |
|
550 | |||
546 | Returns |
|
551 | Returns | |
547 | ------- |
|
552 | ------- | |
548 | msg_list : list |
|
553 | msg_list : list | |
549 | The list of bytes objects to be sent with the format:: |
|
554 | The list of bytes objects to be sent with the format:: | |
550 |
|
555 | |||
551 | [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent, |
|
556 | [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent, | |
552 | p_metadata, p_content, buffer1, buffer2, ...] |
|
557 | p_metadata, p_content, buffer1, buffer2, ...] | |
553 |
|
558 | |||
554 | In this list, the ``p_*`` entities are the packed or serialized |
|
559 | In this list, the ``p_*`` entities are the packed or serialized | |
555 | versions, so if JSON is used, these are utf8 encoded JSON strings. |
|
560 | versions, so if JSON is used, these are utf8 encoded JSON strings. | |
556 | """ |
|
561 | """ | |
557 | content = msg.get('content', {}) |
|
562 | content = msg.get('content', {}) | |
558 | if content is None: |
|
563 | if content is None: | |
559 | content = self.none |
|
564 | content = self.none | |
560 | elif isinstance(content, dict): |
|
565 | elif isinstance(content, dict): | |
561 | content = self.pack(content) |
|
566 | content = self.pack(content) | |
562 | elif isinstance(content, bytes): |
|
567 | elif isinstance(content, bytes): | |
563 | # content is already packed, as in a relayed message |
|
568 | # content is already packed, as in a relayed message | |
564 | pass |
|
569 | pass | |
565 | elif isinstance(content, unicode_type): |
|
570 | elif isinstance(content, unicode_type): | |
566 | # should be bytes, but JSON often spits out unicode |
|
571 | # should be bytes, but JSON often spits out unicode | |
567 | content = content.encode('utf8') |
|
572 | content = content.encode('utf8') | |
568 | else: |
|
573 | else: | |
569 | raise TypeError("Content incorrect type: %s"%type(content)) |
|
574 | raise TypeError("Content incorrect type: %s"%type(content)) | |
570 |
|
575 | |||
571 | real_message = [self.pack(msg['header']), |
|
576 | real_message = [self.pack(msg['header']), | |
572 | self.pack(msg['parent_header']), |
|
577 | self.pack(msg['parent_header']), | |
573 | self.pack(msg['metadata']), |
|
578 | self.pack(msg['metadata']), | |
574 | content, |
|
579 | content, | |
575 | ] |
|
580 | ] | |
576 |
|
581 | |||
577 | to_send = [] |
|
582 | to_send = [] | |
578 |
|
583 | |||
579 | if isinstance(ident, list): |
|
584 | if isinstance(ident, list): | |
580 | # accept list of idents |
|
585 | # accept list of idents | |
581 | to_send.extend(ident) |
|
586 | to_send.extend(ident) | |
582 | elif ident is not None: |
|
587 | elif ident is not None: | |
583 | to_send.append(ident) |
|
588 | to_send.append(ident) | |
584 | to_send.append(DELIM) |
|
589 | to_send.append(DELIM) | |
585 |
|
590 | |||
586 | signature = self.sign(real_message) |
|
591 | signature = self.sign(real_message) | |
587 | to_send.append(signature) |
|
592 | to_send.append(signature) | |
588 |
|
593 | |||
589 | to_send.extend(real_message) |
|
594 | to_send.extend(real_message) | |
590 |
|
595 | |||
591 | return to_send |
|
596 | return to_send | |
592 |
|
597 | |||
593 | def send(self, stream, msg_or_type, content=None, parent=None, ident=None, |
|
598 | def send(self, stream, msg_or_type, content=None, parent=None, ident=None, | |
594 | buffers=None, track=False, header=None, metadata=None): |
|
599 | buffers=None, track=False, header=None, metadata=None): | |
595 | """Build and send a message via stream or socket. |
|
600 | """Build and send a message via stream or socket. | |
596 |
|
601 | |||
597 | The message format used by this function internally is as follows: |
|
602 | The message format used by this function internally is as follows: | |
598 |
|
603 | |||
599 | [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, |
|
604 | [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content, | |
600 | buffer1,buffer2,...] |
|
605 | buffer1,buffer2,...] | |
601 |
|
606 | |||
602 | The serialize/deserialize methods convert the nested message dict into this |
|
607 | The serialize/deserialize methods convert the nested message dict into this | |
603 | format. |
|
608 | format. | |
604 |
|
609 | |||
605 | Parameters |
|
610 | Parameters | |
606 | ---------- |
|
611 | ---------- | |
607 |
|
612 | |||
608 | stream : zmq.Socket or ZMQStream |
|
613 | stream : zmq.Socket or ZMQStream | |
609 | The socket-like object used to send the data. |
|
614 | The socket-like object used to send the data. | |
610 | msg_or_type : str or Message/dict |
|
615 | msg_or_type : str or Message/dict | |
611 | Normally, msg_or_type will be a msg_type unless a message is being |
|
616 | Normally, msg_or_type will be a msg_type unless a message is being | |
612 | sent more than once. If a header is supplied, this can be set to |
|
617 | sent more than once. If a header is supplied, this can be set to | |
613 | None and the msg_type will be pulled from the header. |
|
618 | None and the msg_type will be pulled from the header. | |
614 |
|
619 | |||
615 | content : dict or None |
|
620 | content : dict or None | |
616 | The content of the message (ignored if msg_or_type is a message). |
|
621 | The content of the message (ignored if msg_or_type is a message). | |
617 | header : dict or None |
|
622 | header : dict or None | |
618 | The header dict for the message (ignored if msg_or_type is a message). |
|
623 | The header dict for the message (ignored if msg_or_type is a message). | |
619 | parent : Message or dict or None |
|
624 | parent : Message or dict or None | |
620 | The parent or parent header describing the parent of this message |
|
625 | The parent or parent header describing the parent of this message | |
621 | (ignored if msg_or_type is a message). |
|
626 | (ignored if msg_or_type is a message). | |
622 | ident : bytes or list of bytes |
|
627 | ident : bytes or list of bytes | |
623 | The zmq.IDENTITY routing path. |
|
628 | The zmq.IDENTITY routing path. | |
624 | metadata : dict or None |
|
629 | metadata : dict or None | |
625 | The metadata describing the message |
|
630 | The metadata describing the message | |
626 | buffers : list or None |
|
631 | buffers : list or None | |
627 | The already-serialized buffers to be appended to the message. |
|
632 | The already-serialized buffers to be appended to the message. | |
628 | track : bool |
|
633 | track : bool | |
629 | Whether to track. Only for use with Sockets, because ZMQStream |
|
634 | Whether to track. Only for use with Sockets, because ZMQStream | |
630 | objects cannot track messages. |
|
635 | objects cannot track messages. | |
631 |
|
636 | |||
632 |
|
637 | |||
633 | Returns |
|
638 | Returns | |
634 | ------- |
|
639 | ------- | |
635 | msg : dict |
|
640 | msg : dict | |
636 | The constructed message. |
|
641 | The constructed message. | |
637 | """ |
|
642 | """ | |
638 | if not isinstance(stream, zmq.Socket): |
|
643 | if not isinstance(stream, zmq.Socket): | |
639 | # ZMQStreams and dummy sockets do not support tracking. |
|
644 | # ZMQStreams and dummy sockets do not support tracking. | |
640 | track = False |
|
645 | track = False | |
641 |
|
646 | |||
642 | if isinstance(msg_or_type, (Message, dict)): |
|
647 | if isinstance(msg_or_type, (Message, dict)): | |
643 | # We got a Message or message dict, not a msg_type so don't |
|
648 | # We got a Message or message dict, not a msg_type so don't | |
644 | # build a new Message. |
|
649 | # build a new Message. | |
645 | msg = msg_or_type |
|
650 | msg = msg_or_type | |
646 | buffers = buffers or msg.get('buffers', []) |
|
651 | buffers = buffers or msg.get('buffers', []) | |
647 | else: |
|
652 | else: | |
648 | msg = self.msg(msg_or_type, content=content, parent=parent, |
|
653 | msg = self.msg(msg_or_type, content=content, parent=parent, | |
649 | header=header, metadata=metadata) |
|
654 | header=header, metadata=metadata) | |
650 | if not os.getpid() == self.pid: |
|
655 | if not os.getpid() == self.pid: | |
651 | io.rprint("WARNING: attempted to send message from fork") |
|
656 | io.rprint("WARNING: attempted to send message from fork") | |
652 | io.rprint(msg) |
|
657 | io.rprint(msg) | |
653 | return |
|
658 | return | |
654 | buffers = [] if buffers is None else buffers |
|
659 | buffers = [] if buffers is None else buffers | |
655 | if self.adapt_version: |
|
660 | if self.adapt_version: | |
656 | msg = adapt(msg, self.adapt_version) |
|
661 | msg = adapt(msg, self.adapt_version) | |
657 | to_send = self.serialize(msg, ident) |
|
662 | to_send = self.serialize(msg, ident) | |
658 | to_send.extend(buffers) |
|
663 | to_send.extend(buffers) | |
659 | longest = max([ len(s) for s in to_send ]) |
|
664 | longest = max([ len(s) for s in to_send ]) | |
660 | copy = (longest < self.copy_threshold) |
|
665 | copy = (longest < self.copy_threshold) | |
661 |
|
666 | |||
662 | if buffers and track and not copy: |
|
667 | if buffers and track and not copy: | |
663 | # only really track when we are doing zero-copy buffers |
|
668 | # only really track when we are doing zero-copy buffers | |
664 | tracker = stream.send_multipart(to_send, copy=False, track=True) |
|
669 | tracker = stream.send_multipart(to_send, copy=False, track=True) | |
665 | else: |
|
670 | else: | |
666 | # use dummy tracker, which will be done immediately |
|
671 | # use dummy tracker, which will be done immediately | |
667 | tracker = DONE |
|
672 | tracker = DONE | |
668 | stream.send_multipart(to_send, copy=copy) |
|
673 | stream.send_multipart(to_send, copy=copy) | |
669 |
|
674 | |||
670 | if self.debug: |
|
675 | if self.debug: | |
671 | pprint.pprint(msg) |
|
676 | pprint.pprint(msg) | |
672 | pprint.pprint(to_send) |
|
677 | pprint.pprint(to_send) | |
673 | pprint.pprint(buffers) |
|
678 | pprint.pprint(buffers) | |
674 |
|
679 | |||
675 | msg['tracker'] = tracker |
|
680 | msg['tracker'] = tracker | |
676 |
|
681 | |||
677 | return msg |
|
682 | return msg | |
678 |
|
683 | |||
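# Illustrative end-to-end sketch of Session.send/Session.recv over a local PAIR pair
# (standalone; assumes pyzmq and jupyter_client are installed; the inproc endpoint
# name is arbitrary).
import zmq
from jupyter_client.session import Session

ctx = zmq.Context.instance()
a, b = ctx.socket(zmq.PAIR), ctx.socket(zmq.PAIR)
a.bind('inproc://session-demo')
b.connect('inproc://session-demo')

session = Session(key=b'secret')
session.send(a, 'execute_request', content={'code': 'print(1)'})
idents, msg = session.recv(b, mode=0)        # mode=0 blocks until the message arrives
assert msg['content'] == {'code': 'print(1)'}
assert msg['header']['msg_type'] == 'execute_request'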
679 | def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): |
|
684 | def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None): | |
680 | """Send a raw message via ident path. |
|
685 | """Send a raw message via ident path. | |
681 |
|
686 | |||
682 | This method is used to send an already serialized message. |
|
687 | This method is used to send an already serialized message. | |
683 |
|
688 | |||
684 | Parameters |
|
689 | Parameters | |
685 | ---------- |
|
690 | ---------- | |
686 | stream : ZMQStream or Socket |
|
691 | stream : ZMQStream or Socket | |
687 | The ZMQ stream or socket to use for sending the message. |
|
692 | The ZMQ stream or socket to use for sending the message. | |
688 | msg_list : list |
|
693 | msg_list : list | |
689 | The serialized list of messages to send. This only includes the |
|
694 | The serialized list of messages to send. This only includes the | |
690 | [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of |
|
695 | [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of | |
691 | the message. |
|
696 | the message. | |
692 | ident : ident or list |
|
697 | ident : ident or list | |
693 | A single ident or a list of idents to use in sending. |
|
698 | A single ident or a list of idents to use in sending. | |
694 | """ |
|
699 | """ | |
695 | to_send = [] |
|
700 | to_send = [] | |
696 | if isinstance(ident, bytes): |
|
701 | if isinstance(ident, bytes): | |
697 | ident = [ident] |
|
702 | ident = [ident] | |
698 | if ident is not None: |
|
703 | if ident is not None: | |
699 | to_send.extend(ident) |
|
704 | to_send.extend(ident) | |
700 |
|
705 | |||
701 | to_send.append(DELIM) |
|
706 | to_send.append(DELIM) | |
702 | to_send.append(self.sign(msg_list)) |
|
707 | to_send.append(self.sign(msg_list)) | |
703 | to_send.extend(msg_list) |
|
708 | to_send.extend(msg_list) | |
704 | stream.send_multipart(to_send, flags, copy=copy) |
|
709 | stream.send_multipart(to_send, flags, copy=copy) | |
705 |
|
710 | |||
706 | def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): |
|
711 | def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True): | |
707 | """Receive and unpack a message. |
|
712 | """Receive and unpack a message. | |
708 |
|
713 | |||
709 | Parameters |
|
714 | Parameters | |
710 | ---------- |
|
715 | ---------- | |
711 | socket : ZMQStream or Socket |
|
716 | socket : ZMQStream or Socket | |
712 | The socket or stream to use in receiving. |
|
717 | The socket or stream to use in receiving. | |
713 |
|
718 | |||
714 | Returns |
|
719 | Returns | |
715 | ------- |
|
720 | ------- | |
716 | [idents], msg |
|
721 | [idents], msg | |
717 | [idents] is a list of idents and msg is a nested message dict of |
|
722 | [idents] is a list of idents and msg is a nested message dict of | |
718 | same format as self.msg returns. |
|
723 | same format as self.msg returns. | |
719 | """ |
|
724 | """ | |
720 | if isinstance(socket, ZMQStream): |
|
725 | if isinstance(socket, ZMQStream): | |
721 | socket = socket.socket |
|
726 | socket = socket.socket | |
722 | try: |
|
727 | try: | |
723 | msg_list = socket.recv_multipart(mode, copy=copy) |
|
728 | msg_list = socket.recv_multipart(mode, copy=copy) | |
724 | except zmq.ZMQError as e: |
|
729 | except zmq.ZMQError as e: | |
725 | if e.errno == zmq.EAGAIN: |
|
730 | if e.errno == zmq.EAGAIN: | |
726 | # We can convert EAGAIN to None as we know in this case |
|
731 | # We can convert EAGAIN to None as we know in this case | |
727 | # recv_multipart won't return None. |
|
732 | # recv_multipart won't return None. | |
728 | return None,None |
|
733 | return None,None | |
729 | else: |
|
734 | else: | |
730 | raise |
|
735 | raise | |
731 | # split multipart message into identity list and message dict |
|
736 | # split multipart message into identity list and message dict | |
732 | # invalid large messages can cause very expensive string comparisons |
|
737 | # invalid large messages can cause very expensive string comparisons | |
733 | idents, msg_list = self.feed_identities(msg_list, copy) |
|
738 | idents, msg_list = self.feed_identities(msg_list, copy) | |
734 | try: |
|
739 | try: | |
735 | return idents, self.deserialize(msg_list, content=content, copy=copy) |
|
740 | return idents, self.deserialize(msg_list, content=content, copy=copy) | |
736 | except Exception as e: |
|
741 | except Exception as e: | |
737 | # TODO: handle it |
|
742 | # TODO: handle it | |
738 | raise e |
|
743 | raise e | |
739 |
|
744 | |||
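# ---- usage sketch (assumptions noted) ---------------------------------------
# recv() pairs with send(): it returns (idents, msg) with msg already unpacked
# into a nested dict, or (None, None) when zmq.NOBLOCK finds nothing queued.
# The ROUTER endpoint below is invented; the import path is an assumption.
import zmq
from IPython.kernel.zmq.session import Session

ctx = zmq.Context.instance()
router = ctx.socket(zmq.ROUTER)
router.bind('tcp://127.0.0.1:5555')           # hypothetical bind address

session = Session(key=b'secret')
idents, msg = session.recv(router, mode=zmq.NOBLOCK)
if msg is None:
    pass                                      # nothing waiting yet
else:
    print(msg['msg_type'], msg['content'])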
740 | def feed_identities(self, msg_list, copy=True): |
|
745 | def feed_identities(self, msg_list, copy=True): | |
741 | """Split the identities from the rest of the message. |
|
746 | """Split the identities from the rest of the message. | |
742 |
|
747 | |||
743 | Feed until DELIM is reached, then return the prefix as idents and |
|
748 | Feed until DELIM is reached, then return the prefix as idents and | |
744 | remainder as msg_list. This is easily broken by setting an IDENT to DELIM, |
|
749 | remainder as msg_list. This is easily broken by setting an IDENT to DELIM, | |
745 | but that would be silly. |
|
750 | but that would be silly. | |
746 |
|
751 | |||
747 | Parameters |
|
752 | Parameters | |
748 | ---------- |
|
753 | ---------- | |
749 | msg_list : a list of Message or bytes objects |
|
754 | msg_list : a list of Message or bytes objects | |
750 | The message to be split. |
|
755 | The message to be split. | |
751 | copy : bool |
|
756 | copy : bool | |
752 | flag determining whether the arguments are bytes or Messages |
|
757 | flag determining whether the arguments are bytes or Messages | |
753 |
|
758 | |||
754 | Returns |
|
759 | Returns | |
755 | ------- |
|
760 | ------- | |
756 | (idents, msg_list) : two lists |
|
761 | (idents, msg_list) : two lists | |
757 | idents will always be a list of bytes, each of which is a ZMQ |
|
762 | idents will always be a list of bytes, each of which is a ZMQ | |
758 | identity. msg_list will be a list of bytes or zmq.Messages of the |
|
763 | identity. msg_list will be a list of bytes or zmq.Messages of the | |
759 | form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and |
|
764 | form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and | |
760 | should be unpackable/unserializable via self.deserialize at this |
|
765 | should be unpackable/unserializable via self.deserialize at this | |
761 | point. |
|
766 | point. | |
762 | """ |
|
767 | """ | |
763 | if copy: |
|
768 | if copy: | |
764 | idx = msg_list.index(DELIM) |
|
769 | idx = msg_list.index(DELIM) | |
765 | return msg_list[:idx], msg_list[idx+1:] |
|
770 | return msg_list[:idx], msg_list[idx+1:] | |
766 | else: |
|
771 | else: | |
767 | failed = True |
|
772 | failed = True | |
768 | for idx,m in enumerate(msg_list): |
|
773 | for idx,m in enumerate(msg_list): | |
769 | if m.bytes == DELIM: |
|
774 | if m.bytes == DELIM: | |
770 | failed = False |
|
775 | failed = False | |
771 | break |
|
776 | break | |
772 | if failed: |
|
777 | if failed: | |
773 | raise ValueError("DELIM not in msg_list") |
|
778 | raise ValueError("DELIM not in msg_list") | |
774 | idents, msg_list = msg_list[:idx], msg_list[idx+1:] |
|
779 | idents, msg_list = msg_list[:idx], msg_list[idx+1:] | |
775 | return [m.bytes for m in idents], msg_list |
|
780 | return [m.bytes for m in idents], msg_list | |
776 |
|
781 | |||
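# ---- usage sketch (assumptions noted) ---------------------------------------
# feed_identities() just splits on the b'<IDS|MSG>' delimiter: the prefix is
# the list of zmq routing idents, the remainder is [HMAC, p_header, ...] ready
# for deserialize().  The byte strings below are stand-ins for real wire frames.
from IPython.kernel.zmq.session import Session

session = Session()
wire = [b'engine-0', b'<IDS|MSG>',
        b'hmac', b'header', b'parent', b'metadata', b'content']
idents, msg_list = session.feed_identities(wire)
assert idents == [b'engine-0']
assert msg_list[0] == b'hmac'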
777 | def _add_digest(self, signature): |
|
782 | def _add_digest(self, signature): | |
778 | """add a digest to history to protect against replay attacks""" |
|
783 | """add a digest to history to protect against replay attacks""" | |
779 | if self.digest_history_size == 0: |
|
784 | if self.digest_history_size == 0: | |
780 | # no history, never add digests |
|
785 | # no history, never add digests | |
781 | return |
|
786 | return | |
782 |
|
787 | |||
783 | self.digest_history.add(signature) |
|
788 | self.digest_history.add(signature) | |
784 | if len(self.digest_history) > self.digest_history_size: |
|
789 | if len(self.digest_history) > self.digest_history_size: | |
785 | # threshold reached, cull 10% |
|
790 | # threshold reached, cull 10% | |
786 | self._cull_digest_history() |
|
791 | self._cull_digest_history() | |
787 |
|
792 | |||
788 | def _cull_digest_history(self): |
|
793 | def _cull_digest_history(self): | |
789 | """cull the digest history |
|
794 | """cull the digest history | |
790 |
|
795 | |||
791 | Removes a randomly selected 10% of the digest history |
|
796 | Removes a randomly selected 10% of the digest history | |
792 | """ |
|
797 | """ | |
793 | current = len(self.digest_history) |
|
798 | current = len(self.digest_history) | |
794 | n_to_cull = max(int(current // 10), current - self.digest_history_size) |
|
799 | n_to_cull = max(int(current // 10), current - self.digest_history_size) | |
795 | if n_to_cull >= current: |
|
800 | if n_to_cull >= current: | |
796 | self.digest_history = set() |
|
801 | self.digest_history = set() | |
797 | return |
|
802 | return | |
798 | to_cull = random.sample(self.digest_history, n_to_cull) |
|
803 | to_cull = random.sample(self.digest_history, n_to_cull) | |
799 | self.digest_history.difference_update(to_cull) |
|
804 | self.digest_history.difference_update(to_cull) | |
800 |
|
805 | |||
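# ---- behaviour sketch (assumptions noted) -----------------------------------
# _add_digest()/_cull_digest_history() implement the replay protection: once
# the history exceeds digest_history_size, roughly 10% of the stored signatures
# are dropped at random.  These are internal helpers; the tiny history size is
# only there to make the cull observable.
from IPython.kernel.zmq.session import Session

session = Session(digest_history_size=100)
for i in range(100):
    session._add_digest(('sig-%03i' % i).encode('ascii'))
before = len(session.digest_history)          # 100: exactly at the threshold
session._add_digest(b'sig-overflow')          # 101 > 100 triggers the cull
assert len(session.digest_history) < before   # roughly 10% culled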
801 | def deserialize(self, msg_list, content=True, copy=True): |
|
806 | def deserialize(self, msg_list, content=True, copy=True): | |
802 | """Unserialize a msg_list to a nested message dict. |
|
807 | """Unserialize a msg_list to a nested message dict. | |
803 |
|
808 | |||
804 | This is roughly the inverse of serialize. The serialize/deserialize |
|
809 | This is roughly the inverse of serialize. The serialize/deserialize | |
805 | methods work with full message lists, whereas pack/unpack work with |
|
810 | methods work with full message lists, whereas pack/unpack work with | |
806 | the individual message parts in the message list. |
|
811 | the individual message parts in the message list. | |
807 |
|
812 | |||
808 | Parameters |
|
813 | Parameters | |
809 | ---------- |
|
814 | ---------- | |
810 | msg_list : list of bytes or Message objects |
|
815 | msg_list : list of bytes or Message objects | |
811 | The list of message parts of the form [HMAC,p_header,p_parent, |
|
816 | The list of message parts of the form [HMAC,p_header,p_parent, | |
812 | p_metadata,p_content,buffer1,buffer2,...]. |
|
817 | p_metadata,p_content,buffer1,buffer2,...]. | |
813 | content : bool (True) |
|
818 | content : bool (True) | |
814 | Whether to unpack the content dict (True), or leave it packed |
|
819 | Whether to unpack the content dict (True), or leave it packed | |
815 | (False). |
|
820 | (False). | |
816 | copy : bool (True) |
|
821 | copy : bool (True) | |
817 | Whether msg_list contains bytes (True) or the non-copying Message |
|
822 | Whether msg_list contains bytes (True) or the non-copying Message | |
818 | objects in each place (False). |
|
823 | objects in each place (False). | |
819 |
|
824 | |||
820 | Returns |
|
825 | Returns | |
821 | ------- |
|
826 | ------- | |
822 | msg : dict |
|
827 | msg : dict | |
823 | The nested message dict with top-level keys [header, parent_header, |
|
828 | The nested message dict with top-level keys [header, parent_header, | |
824 | content, buffers]. The buffers are returned as memoryviews. |
|
829 | content, buffers]. The buffers are returned as memoryviews. | |
825 | """ |
|
830 | """ | |
826 | minlen = 5 |
|
831 | minlen = 5 | |
827 | message = {} |
|
832 | message = {} | |
828 | if not copy: |
|
833 | if not copy: | |
829 | # pyzmq didn't copy the first parts of the message, so we'll do it |
|
834 | # pyzmq didn't copy the first parts of the message, so we'll do it | |
830 | for i in range(minlen): |
|
835 | for i in range(minlen): | |
831 | msg_list[i] = msg_list[i].bytes |
|
836 | msg_list[i] = msg_list[i].bytes | |
832 | if self.auth is not None: |
|
837 | if self.auth is not None: | |
833 | signature = msg_list[0] |
|
838 | signature = msg_list[0] | |
834 | if not signature: |
|
839 | if not signature: | |
835 | raise ValueError("Unsigned Message") |
|
840 | raise ValueError("Unsigned Message") | |
836 | if signature in self.digest_history: |
|
841 | if signature in self.digest_history: | |
837 | raise ValueError("Duplicate Signature: %r" % signature) |
|
842 | raise ValueError("Duplicate Signature: %r" % signature) | |
838 | self._add_digest(signature) |
|
843 | self._add_digest(signature) | |
839 | check = self.sign(msg_list[1:5]) |
|
844 | check = self.sign(msg_list[1:5]) | |
840 | if not compare_digest(signature, check): |
|
845 | if not compare_digest(signature, check): | |
841 | raise ValueError("Invalid Signature: %r" % signature) |
|
846 | raise ValueError("Invalid Signature: %r" % signature) | |
842 | if not len(msg_list) >= minlen: |
|
847 | if not len(msg_list) >= minlen: | |
843 | raise TypeError("malformed message, must have at least %i elements"%minlen) |
|
848 | raise TypeError("malformed message, must have at least %i elements"%minlen) | |
844 | header = self.unpack(msg_list[1]) |
|
849 | header = self.unpack(msg_list[1]) | |
845 | message['header'] = extract_dates(header) |
|
850 | message['header'] = extract_dates(header) | |
846 | message['msg_id'] = header['msg_id'] |
|
851 | message['msg_id'] = header['msg_id'] | |
847 | message['msg_type'] = header['msg_type'] |
|
852 | message['msg_type'] = header['msg_type'] | |
848 | message['parent_header'] = extract_dates(self.unpack(msg_list[2])) |
|
853 | message['parent_header'] = extract_dates(self.unpack(msg_list[2])) | |
849 | message['metadata'] = self.unpack(msg_list[3]) |
|
854 | message['metadata'] = self.unpack(msg_list[3]) | |
850 | if content: |
|
855 | if content: | |
851 | message['content'] = self.unpack(msg_list[4]) |
|
856 | message['content'] = self.unpack(msg_list[4]) | |
852 | else: |
|
857 | else: | |
853 | message['content'] = msg_list[4] |
|
858 | message['content'] = msg_list[4] | |
854 | buffers = [memoryview(b) for b in msg_list[5:]] |
|
859 | buffers = [memoryview(b) for b in msg_list[5:]] | |
855 | if buffers and buffers[0].shape is None: |
|
860 | if buffers and buffers[0].shape is None: | |
856 | # force copy to workaround pyzmq #646 |
|
861 | # force copy to workaround pyzmq #646 | |
857 | buffers = [memoryview(b.bytes) for b in msg_list[5:]] |
|
862 | buffers = [memoryview(b.bytes) for b in msg_list[5:]] | |
858 | message['buffers'] = buffers |
|
863 | message['buffers'] = buffers | |
859 | # adapt to the current version |
|
864 | # adapt to the current version | |
860 | return adapt(message) |
|
865 | return adapt(message) | |
861 |
|
866 | |||
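# ---- round-trip sketch (assumptions noted) ----------------------------------
# serialize() and deserialize() are inverses around the wire format: the HMAC
# is recomputed and compared, duplicates are rejected via the digest history,
# and header dates come back as datetimes.  Everything below runs in-process,
# no sockets involved; the import path is an assumption.
from IPython.kernel.zmq.session import Session

session = Session(key=b'secret')
msg = session.msg('stream', content={'name': 'stdout', 'text': 'hi'})
wire = session.serialize(msg)                 # [DELIM, HMAC, p_header, ...]
idents, msg_list = session.feed_identities(wire)
restored = session.deserialize(msg_list)
assert restored['msg_type'] == 'stream'
assert restored['content']['text'] == 'hi'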
862 | def unserialize(self, *args, **kwargs): |
|
867 | def unserialize(self, *args, **kwargs): | |
863 | warnings.warn( |
|
868 | warnings.warn( | |
864 | "Session.unserialize is deprecated. Use Session.deserialize.", |
|
869 | "Session.unserialize is deprecated. Use Session.deserialize.", | |
865 | DeprecationWarning, |
|
870 | DeprecationWarning, | |
866 | ) |
|
871 | ) | |
867 | return self.deserialize(*args, **kwargs) |
|
872 | return self.deserialize(*args, **kwargs) | |
868 |
|
873 | |||
869 |
|
874 | |||
870 | def test_msg2obj(): |
|
875 | def test_msg2obj(): | |
871 | am = dict(x=1) |
|
876 | am = dict(x=1) | |
872 | ao = Message(am) |
|
877 | ao = Message(am) | |
873 | assert ao.x == am['x'] |
|
878 | assert ao.x == am['x'] | |
874 |
|
879 | |||
875 | am['y'] = dict(z=1) |
|
880 | am['y'] = dict(z=1) | |
876 | ao = Message(am) |
|
881 | ao = Message(am) | |
877 | assert ao.y.z == am['y']['z'] |
|
882 | assert ao.y.z == am['y']['z'] | |
878 |
|
883 | |||
879 | k1, k2 = 'y', 'z' |
|
884 | k1, k2 = 'y', 'z' | |
880 | assert ao[k1][k2] == am[k1][k2] |
|
885 | assert ao[k1][k2] == am[k1][k2] | |
881 |
|
886 | |||
882 | am2 = dict(ao) |
|
887 | am2 = dict(ao) | |
883 | assert am['x'] == am2['x'] |
|
888 | assert am['x'] == am2['x'] | |
884 | assert am['y']['z'] == am2['y']['z'] |
|
889 | assert am['y']['z'] == am2['y']['z'] |