##// END OF EJS Templates
update pickleutil imports
Min RK -
Show More
@@ -1,425 +1,426 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Pickle related utilities. Perhaps this should be called 'can'."""
2 """Pickle related utilities. Perhaps this should be called 'can'."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import copy
7 import copy
8 import logging
8 import logging
9 import sys
9 import sys
10 from types import FunctionType
10 from types import FunctionType
11
11
12 try:
12 try:
13 import cPickle as pickle
13 import cPickle as pickle
14 except ImportError:
14 except ImportError:
15 import pickle
15 import pickle
16
16
17 from IPython.utils import py3compat
18 from IPython.utils.importstring import import_item
19 from IPython.utils.py3compat import string_types, iteritems, buffer_to_bytes_py2
20
17 from . import codeutil # This registers a hook when it's imported
21 from . import codeutil # This registers a hook when it's imported
18 from . import py3compat
19 from .importstring import import_item
20 from .py3compat import string_types, iteritems, buffer_to_bytes_py2
21
22
22 from IPython.config import Application
23 from IPython.config import Application
23 from IPython.utils.log import get_logger
24 from IPython.utils.log import get_logger
24
25
25 if py3compat.PY3:
26 if py3compat.PY3:
26 buffer = memoryview
27 buffer = memoryview
27 class_type = type
28 class_type = type
28 else:
29 else:
29 from types import ClassType
30 from types import ClassType
30 class_type = (type, ClassType)
31 class_type = (type, ClassType)
31
32
32 try:
33 try:
33 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
34 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
34 except AttributeError:
35 except AttributeError:
35 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
36 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
36
37
37 def _get_cell_type(a=None):
38 def _get_cell_type(a=None):
38 """the type of a closure cell doesn't seem to be importable,
39 """the type of a closure cell doesn't seem to be importable,
39 so just create one
40 so just create one
40 """
41 """
41 def inner():
42 def inner():
42 return a
43 return a
43 return type(py3compat.get_closure(inner)[0])
44 return type(py3compat.get_closure(inner)[0])
44
45
45 cell_type = _get_cell_type()
46 cell_type = _get_cell_type()
46
47
47 #-------------------------------------------------------------------------------
48 #-------------------------------------------------------------------------------
48 # Functions
49 # Functions
49 #-------------------------------------------------------------------------------
50 #-------------------------------------------------------------------------------
50
51
51
52
52 def use_dill():
53 def use_dill():
53 """use dill to expand serialization support
54 """use dill to expand serialization support
54
55
55 adds support for object methods and closures to serialization.
56 adds support for object methods and closures to serialization.
56 """
57 """
57 # import dill causes most of the magic
58 # import dill causes most of the magic
58 import dill
59 import dill
59
60
60 # dill doesn't work with cPickle,
61 # dill doesn't work with cPickle,
61 # tell the two relevant modules to use plain pickle
62 # tell the two relevant modules to use plain pickle
62
63
63 global pickle
64 global pickle
64 pickle = dill
65 pickle = dill
65
66
66 try:
67 try:
67 from IPython.kernel.zmq import serialize
68 from IPython.kernel.zmq import serialize
68 except ImportError:
69 except ImportError:
69 pass
70 pass
70 else:
71 else:
71 serialize.pickle = dill
72 serialize.pickle = dill
72
73
73 # disable special function handling, let dill take care of it
74 # disable special function handling, let dill take care of it
74 can_map.pop(FunctionType, None)
75 can_map.pop(FunctionType, None)
75
76
76 def use_cloudpickle():
77 def use_cloudpickle():
77 """use cloudpickle to expand serialization support
78 """use cloudpickle to expand serialization support
78
79
79 adds support for object methods and closures to serialization.
80 adds support for object methods and closures to serialization.
80 """
81 """
81 from cloud.serialization import cloudpickle
82 from cloud.serialization import cloudpickle
82
83
83 global pickle
84 global pickle
84 pickle = cloudpickle
85 pickle = cloudpickle
85
86
86 try:
87 try:
87 from IPython.kernel.zmq import serialize
88 from IPython.kernel.zmq import serialize
88 except ImportError:
89 except ImportError:
89 pass
90 pass
90 else:
91 else:
91 serialize.pickle = cloudpickle
92 serialize.pickle = cloudpickle
92
93
93 # disable special function handling, let cloudpickle take care of it
94 # disable special function handling, let cloudpickle take care of it
94 can_map.pop(FunctionType, None)
95 can_map.pop(FunctionType, None)
95
96
96
97
97 #-------------------------------------------------------------------------------
98 #-------------------------------------------------------------------------------
98 # Classes
99 # Classes
99 #-------------------------------------------------------------------------------
100 #-------------------------------------------------------------------------------
100
101
101
102
102 class CannedObject(object):
103 class CannedObject(object):
103 def __init__(self, obj, keys=[], hook=None):
104 def __init__(self, obj, keys=[], hook=None):
104 """can an object for safe pickling
105 """can an object for safe pickling
105
106
106 Parameters
107 Parameters
107 ==========
108 ==========
108
109
109 obj:
110 obj:
110 The object to be canned
111 The object to be canned
111 keys: list (optional)
112 keys: list (optional)
112 list of attribute names that will be explicitly canned / uncanned
113 list of attribute names that will be explicitly canned / uncanned
113 hook: callable (optional)
114 hook: callable (optional)
114 An optional extra callable,
115 An optional extra callable,
115 which can do additional processing of the uncanned object.
116 which can do additional processing of the uncanned object.
116
117
117 large data may be offloaded into the buffers list,
118 large data may be offloaded into the buffers list,
118 used for zero-copy transfers.
119 used for zero-copy transfers.
119 """
120 """
120 self.keys = keys
121 self.keys = keys
121 self.obj = copy.copy(obj)
122 self.obj = copy.copy(obj)
122 self.hook = can(hook)
123 self.hook = can(hook)
123 for key in keys:
124 for key in keys:
124 setattr(self.obj, key, can(getattr(obj, key)))
125 setattr(self.obj, key, can(getattr(obj, key)))
125
126
126 self.buffers = []
127 self.buffers = []
127
128
128 def get_object(self, g=None):
129 def get_object(self, g=None):
129 if g is None:
130 if g is None:
130 g = {}
131 g = {}
131 obj = self.obj
132 obj = self.obj
132 for key in self.keys:
133 for key in self.keys:
133 setattr(obj, key, uncan(getattr(obj, key), g))
134 setattr(obj, key, uncan(getattr(obj, key), g))
134
135
135 if self.hook:
136 if self.hook:
136 self.hook = uncan(self.hook, g)
137 self.hook = uncan(self.hook, g)
137 self.hook(obj, g)
138 self.hook(obj, g)
138 return self.obj
139 return self.obj
139
140
140
141
141 class Reference(CannedObject):
142 class Reference(CannedObject):
142 """object for wrapping a remote reference by name."""
143 """object for wrapping a remote reference by name."""
143 def __init__(self, name):
144 def __init__(self, name):
144 if not isinstance(name, string_types):
145 if not isinstance(name, string_types):
145 raise TypeError("illegal name: %r"%name)
146 raise TypeError("illegal name: %r"%name)
146 self.name = name
147 self.name = name
147 self.buffers = []
148 self.buffers = []
148
149
149 def __repr__(self):
150 def __repr__(self):
150 return "<Reference: %r>"%self.name
151 return "<Reference: %r>"%self.name
151
152
152 def get_object(self, g=None):
153 def get_object(self, g=None):
153 if g is None:
154 if g is None:
154 g = {}
155 g = {}
155
156
156 return eval(self.name, g)
157 return eval(self.name, g)
157
158
158
159
159 class CannedCell(CannedObject):
160 class CannedCell(CannedObject):
160 """Can a closure cell"""
161 """Can a closure cell"""
161 def __init__(self, cell):
162 def __init__(self, cell):
162 self.cell_contents = can(cell.cell_contents)
163 self.cell_contents = can(cell.cell_contents)
163
164
164 def get_object(self, g=None):
165 def get_object(self, g=None):
165 cell_contents = uncan(self.cell_contents, g)
166 cell_contents = uncan(self.cell_contents, g)
166 def inner():
167 def inner():
167 return cell_contents
168 return cell_contents
168 return py3compat.get_closure(inner)[0]
169 return py3compat.get_closure(inner)[0]
169
170
170
171
171 class CannedFunction(CannedObject):
172 class CannedFunction(CannedObject):
172
173
173 def __init__(self, f):
174 def __init__(self, f):
174 self._check_type(f)
175 self._check_type(f)
175 self.code = f.__code__
176 self.code = f.__code__
176 if f.__defaults__:
177 if f.__defaults__:
177 self.defaults = [ can(fd) for fd in f.__defaults__ ]
178 self.defaults = [ can(fd) for fd in f.__defaults__ ]
178 else:
179 else:
179 self.defaults = None
180 self.defaults = None
180
181
181 closure = py3compat.get_closure(f)
182 closure = py3compat.get_closure(f)
182 if closure:
183 if closure:
183 self.closure = tuple( can(cell) for cell in closure )
184 self.closure = tuple( can(cell) for cell in closure )
184 else:
185 else:
185 self.closure = None
186 self.closure = None
186
187
187 self.module = f.__module__ or '__main__'
188 self.module = f.__module__ or '__main__'
188 self.__name__ = f.__name__
189 self.__name__ = f.__name__
189 self.buffers = []
190 self.buffers = []
190
191
191 def _check_type(self, obj):
192 def _check_type(self, obj):
192 assert isinstance(obj, FunctionType), "Not a function type"
193 assert isinstance(obj, FunctionType), "Not a function type"
193
194
194 def get_object(self, g=None):
195 def get_object(self, g=None):
195 # try to load function back into its module:
196 # try to load function back into its module:
196 if not self.module.startswith('__'):
197 if not self.module.startswith('__'):
197 __import__(self.module)
198 __import__(self.module)
198 g = sys.modules[self.module].__dict__
199 g = sys.modules[self.module].__dict__
199
200
200 if g is None:
201 if g is None:
201 g = {}
202 g = {}
202 if self.defaults:
203 if self.defaults:
203 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
204 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
204 else:
205 else:
205 defaults = None
206 defaults = None
206 if self.closure:
207 if self.closure:
207 closure = tuple(uncan(cell, g) for cell in self.closure)
208 closure = tuple(uncan(cell, g) for cell in self.closure)
208 else:
209 else:
209 closure = None
210 closure = None
210 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
211 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
211 return newFunc
212 return newFunc
212
213
213 class CannedClass(CannedObject):
214 class CannedClass(CannedObject):
214
215
215 def __init__(self, cls):
216 def __init__(self, cls):
216 self._check_type(cls)
217 self._check_type(cls)
217 self.name = cls.__name__
218 self.name = cls.__name__
218 self.old_style = not isinstance(cls, type)
219 self.old_style = not isinstance(cls, type)
219 self._canned_dict = {}
220 self._canned_dict = {}
220 for k,v in cls.__dict__.items():
221 for k,v in cls.__dict__.items():
221 if k not in ('__weakref__', '__dict__'):
222 if k not in ('__weakref__', '__dict__'):
222 self._canned_dict[k] = can(v)
223 self._canned_dict[k] = can(v)
223 if self.old_style:
224 if self.old_style:
224 mro = []
225 mro = []
225 else:
226 else:
226 mro = cls.mro()
227 mro = cls.mro()
227
228
228 self.parents = [ can(c) for c in mro[1:] ]
229 self.parents = [ can(c) for c in mro[1:] ]
229 self.buffers = []
230 self.buffers = []
230
231
231 def _check_type(self, obj):
232 def _check_type(self, obj):
232 assert isinstance(obj, class_type), "Not a class type"
233 assert isinstance(obj, class_type), "Not a class type"
233
234
234 def get_object(self, g=None):
235 def get_object(self, g=None):
235 parents = tuple(uncan(p, g) for p in self.parents)
236 parents = tuple(uncan(p, g) for p in self.parents)
236 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
237 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
237
238
238 class CannedArray(CannedObject):
239 class CannedArray(CannedObject):
239 def __init__(self, obj):
240 def __init__(self, obj):
240 from numpy import ascontiguousarray
241 from numpy import ascontiguousarray
241 self.shape = obj.shape
242 self.shape = obj.shape
242 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
243 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
243 self.pickled = False
244 self.pickled = False
244 if sum(obj.shape) == 0:
245 if sum(obj.shape) == 0:
245 self.pickled = True
246 self.pickled = True
246 elif obj.dtype == 'O':
247 elif obj.dtype == 'O':
247 # can't handle object dtype with buffer approach
248 # can't handle object dtype with buffer approach
248 self.pickled = True
249 self.pickled = True
249 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
250 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
250 self.pickled = True
251 self.pickled = True
251 if self.pickled:
252 if self.pickled:
252 # just pickle it
253 # just pickle it
253 self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
254 self.buffers = [pickle.dumps(obj, PICKLE_PROTOCOL)]
254 else:
255 else:
255 # ensure contiguous
256 # ensure contiguous
256 obj = ascontiguousarray(obj, dtype=None)
257 obj = ascontiguousarray(obj, dtype=None)
257 self.buffers = [buffer(obj)]
258 self.buffers = [buffer(obj)]
258
259
259 def get_object(self, g=None):
260 def get_object(self, g=None):
260 from numpy import frombuffer
261 from numpy import frombuffer
261 data = self.buffers[0]
262 data = self.buffers[0]
262 if self.pickled:
263 if self.pickled:
263 # we just pickled it
264 # we just pickled it
264 return pickle.loads(buffer_to_bytes_py2(data))
265 return pickle.loads(buffer_to_bytes_py2(data))
265 else:
266 else:
266 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
267 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
267
268
268
269
269 class CannedBytes(CannedObject):
270 class CannedBytes(CannedObject):
270 wrap = bytes
271 wrap = bytes
271 def __init__(self, obj):
272 def __init__(self, obj):
272 self.buffers = [obj]
273 self.buffers = [obj]
273
274
274 def get_object(self, g=None):
275 def get_object(self, g=None):
275 data = self.buffers[0]
276 data = self.buffers[0]
276 return self.wrap(data)
277 return self.wrap(data)
277
278
278 def CannedBuffer(CannedBytes):
279 def CannedBuffer(CannedBytes):
279 wrap = buffer
280 wrap = buffer
280
281
281 #-------------------------------------------------------------------------------
282 #-------------------------------------------------------------------------------
282 # Functions
283 # Functions
283 #-------------------------------------------------------------------------------
284 #-------------------------------------------------------------------------------
284
285
285 def _import_mapping(mapping, original=None):
286 def _import_mapping(mapping, original=None):
286 """import any string-keys in a type mapping
287 """import any string-keys in a type mapping
287
288
288 """
289 """
289 log = get_logger()
290 log = get_logger()
290 log.debug("Importing canning map")
291 log.debug("Importing canning map")
291 for key,value in list(mapping.items()):
292 for key,value in list(mapping.items()):
292 if isinstance(key, string_types):
293 if isinstance(key, string_types):
293 try:
294 try:
294 cls = import_item(key)
295 cls = import_item(key)
295 except Exception:
296 except Exception:
296 if original and key not in original:
297 if original and key not in original:
297 # only message on user-added classes
298 # only message on user-added classes
298 log.error("canning class not importable: %r", key, exc_info=True)
299 log.error("canning class not importable: %r", key, exc_info=True)
299 mapping.pop(key)
300 mapping.pop(key)
300 else:
301 else:
301 mapping[cls] = mapping.pop(key)
302 mapping[cls] = mapping.pop(key)
302
303
303 def istype(obj, check):
304 def istype(obj, check):
304 """like isinstance(obj, check), but strict
305 """like isinstance(obj, check), but strict
305
306
306 This won't catch subclasses.
307 This won't catch subclasses.
307 """
308 """
308 if isinstance(check, tuple):
309 if isinstance(check, tuple):
309 for cls in check:
310 for cls in check:
310 if type(obj) is cls:
311 if type(obj) is cls:
311 return True
312 return True
312 return False
313 return False
313 else:
314 else:
314 return type(obj) is check
315 return type(obj) is check
315
316
316 def can(obj):
317 def can(obj):
317 """prepare an object for pickling"""
318 """prepare an object for pickling"""
318
319
319 import_needed = False
320 import_needed = False
320
321
321 for cls,canner in iteritems(can_map):
322 for cls,canner in iteritems(can_map):
322 if isinstance(cls, string_types):
323 if isinstance(cls, string_types):
323 import_needed = True
324 import_needed = True
324 break
325 break
325 elif istype(obj, cls):
326 elif istype(obj, cls):
326 return canner(obj)
327 return canner(obj)
327
328
328 if import_needed:
329 if import_needed:
329 # perform can_map imports, then try again
330 # perform can_map imports, then try again
330 # this will usually only happen once
331 # this will usually only happen once
331 _import_mapping(can_map, _original_can_map)
332 _import_mapping(can_map, _original_can_map)
332 return can(obj)
333 return can(obj)
333
334
334 return obj
335 return obj
335
336
336 def can_class(obj):
337 def can_class(obj):
337 if isinstance(obj, class_type) and obj.__module__ == '__main__':
338 if isinstance(obj, class_type) and obj.__module__ == '__main__':
338 return CannedClass(obj)
339 return CannedClass(obj)
339 else:
340 else:
340 return obj
341 return obj
341
342
342 def can_dict(obj):
343 def can_dict(obj):
343 """can the *values* of a dict"""
344 """can the *values* of a dict"""
344 if istype(obj, dict):
345 if istype(obj, dict):
345 newobj = {}
346 newobj = {}
346 for k, v in iteritems(obj):
347 for k, v in iteritems(obj):
347 newobj[k] = can(v)
348 newobj[k] = can(v)
348 return newobj
349 return newobj
349 else:
350 else:
350 return obj
351 return obj
351
352
352 sequence_types = (list, tuple, set)
353 sequence_types = (list, tuple, set)
353
354
354 def can_sequence(obj):
355 def can_sequence(obj):
355 """can the elements of a sequence"""
356 """can the elements of a sequence"""
356 if istype(obj, sequence_types):
357 if istype(obj, sequence_types):
357 t = type(obj)
358 t = type(obj)
358 return t([can(i) for i in obj])
359 return t([can(i) for i in obj])
359 else:
360 else:
360 return obj
361 return obj
361
362
362 def uncan(obj, g=None):
363 def uncan(obj, g=None):
363 """invert canning"""
364 """invert canning"""
364
365
365 import_needed = False
366 import_needed = False
366 for cls,uncanner in iteritems(uncan_map):
367 for cls,uncanner in iteritems(uncan_map):
367 if isinstance(cls, string_types):
368 if isinstance(cls, string_types):
368 import_needed = True
369 import_needed = True
369 break
370 break
370 elif isinstance(obj, cls):
371 elif isinstance(obj, cls):
371 return uncanner(obj, g)
372 return uncanner(obj, g)
372
373
373 if import_needed:
374 if import_needed:
374 # perform uncan_map imports, then try again
375 # perform uncan_map imports, then try again
375 # this will usually only happen once
376 # this will usually only happen once
376 _import_mapping(uncan_map, _original_uncan_map)
377 _import_mapping(uncan_map, _original_uncan_map)
377 return uncan(obj, g)
378 return uncan(obj, g)
378
379
379 return obj
380 return obj
380
381
381 def uncan_dict(obj, g=None):
382 def uncan_dict(obj, g=None):
382 if istype(obj, dict):
383 if istype(obj, dict):
383 newobj = {}
384 newobj = {}
384 for k, v in iteritems(obj):
385 for k, v in iteritems(obj):
385 newobj[k] = uncan(v,g)
386 newobj[k] = uncan(v,g)
386 return newobj
387 return newobj
387 else:
388 else:
388 return obj
389 return obj
389
390
390 def uncan_sequence(obj, g=None):
391 def uncan_sequence(obj, g=None):
391 if istype(obj, sequence_types):
392 if istype(obj, sequence_types):
392 t = type(obj)
393 t = type(obj)
393 return t([uncan(i,g) for i in obj])
394 return t([uncan(i,g) for i in obj])
394 else:
395 else:
395 return obj
396 return obj
396
397
397 def _uncan_dependent_hook(dep, g=None):
398 def _uncan_dependent_hook(dep, g=None):
398 dep.check_dependency()
399 dep.check_dependency()
399
400
400 def can_dependent(obj):
401 def can_dependent(obj):
401 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
402 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
402
403
403 #-------------------------------------------------------------------------------
404 #-------------------------------------------------------------------------------
404 # API dictionaries
405 # API dictionaries
405 #-------------------------------------------------------------------------------
406 #-------------------------------------------------------------------------------
406
407
407 # These dicts can be extended for custom serialization of new objects
408 # These dicts can be extended for custom serialization of new objects
408
409
409 can_map = {
410 can_map = {
410 'IPython.parallel.dependent' : can_dependent,
411 'IPython.parallel.dependent' : can_dependent,
411 'numpy.ndarray' : CannedArray,
412 'numpy.ndarray' : CannedArray,
412 FunctionType : CannedFunction,
413 FunctionType : CannedFunction,
413 bytes : CannedBytes,
414 bytes : CannedBytes,
414 buffer : CannedBuffer,
415 buffer : CannedBuffer,
415 cell_type : CannedCell,
416 cell_type : CannedCell,
416 class_type : can_class,
417 class_type : can_class,
417 }
418 }
418
419
419 uncan_map = {
420 uncan_map = {
420 CannedObject : lambda obj, g: obj.get_object(g),
421 CannedObject : lambda obj, g: obj.get_object(g),
421 }
422 }
422
423
423 # for use in _import_mapping:
424 # for use in _import_mapping:
424 _original_can_map = can_map.copy()
425 _original_can_map = can_map.copy()
425 _original_uncan_map = uncan_map.copy()
426 _original_uncan_map = uncan_map.copy()
@@ -1,179 +1,179 b''
1 """serialization utilities for apply messages"""
1 """serialization utilities for apply messages"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 try:
6 try:
7 import cPickle
7 import cPickle
8 pickle = cPickle
8 pickle = cPickle
9 except:
9 except:
10 cPickle = None
10 cPickle = None
11 import pickle
11 import pickle
12
12
13 # IPython imports
13 # IPython imports
14 from IPython.utils.py3compat import PY3, buffer_to_bytes_py2
14 from IPython.utils.py3compat import PY3, buffer_to_bytes_py2
15 from IPython.utils.data import flatten
15 from IPython.utils.data import flatten
16 from IPython.utils.pickleutil import (
16 from ipython_kernel.pickleutil import (
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
18 istype, sequence_types, PICKLE_PROTOCOL,
18 istype, sequence_types, PICKLE_PROTOCOL,
19 )
19 )
20 from jupyter_client.session import MAX_ITEMS, MAX_BYTES
20 from jupyter_client.session import MAX_ITEMS, MAX_BYTES
21
21
22
22
23 if PY3:
23 if PY3:
24 buffer = memoryview
24 buffer = memoryview
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Serialization Functions
27 # Serialization Functions
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30
30
31 def _extract_buffers(obj, threshold=MAX_BYTES):
31 def _extract_buffers(obj, threshold=MAX_BYTES):
32 """extract buffers larger than a certain threshold"""
32 """extract buffers larger than a certain threshold"""
33 buffers = []
33 buffers = []
34 if isinstance(obj, CannedObject) and obj.buffers:
34 if isinstance(obj, CannedObject) and obj.buffers:
35 for i,buf in enumerate(obj.buffers):
35 for i,buf in enumerate(obj.buffers):
36 if len(buf) > threshold:
36 if len(buf) > threshold:
37 # buffer larger than threshold, prevent pickling
37 # buffer larger than threshold, prevent pickling
38 obj.buffers[i] = None
38 obj.buffers[i] = None
39 buffers.append(buf)
39 buffers.append(buf)
40 elif isinstance(buf, buffer):
40 elif isinstance(buf, buffer):
41 # buffer too small for separate send, coerce to bytes
41 # buffer too small for separate send, coerce to bytes
42 # because pickling buffer objects just results in broken pointers
42 # because pickling buffer objects just results in broken pointers
43 obj.buffers[i] = bytes(buf)
43 obj.buffers[i] = bytes(buf)
44 return buffers
44 return buffers
45
45
46 def _restore_buffers(obj, buffers):
46 def _restore_buffers(obj, buffers):
47 """restore buffers extracted by """
47 """restore buffers extracted by """
48 if isinstance(obj, CannedObject) and obj.buffers:
48 if isinstance(obj, CannedObject) and obj.buffers:
49 for i,buf in enumerate(obj.buffers):
49 for i,buf in enumerate(obj.buffers):
50 if buf is None:
50 if buf is None:
51 obj.buffers[i] = buffers.pop(0)
51 obj.buffers[i] = buffers.pop(0)
52
52
53 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
53 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
54 """Serialize an object into a list of sendable buffers.
54 """Serialize an object into a list of sendable buffers.
55
55
56 Parameters
56 Parameters
57 ----------
57 ----------
58
58
59 obj : object
59 obj : object
60 The object to be serialized
60 The object to be serialized
61 buffer_threshold : int
61 buffer_threshold : int
62 The threshold (in bytes) for pulling out data buffers
62 The threshold (in bytes) for pulling out data buffers
63 to avoid pickling them.
63 to avoid pickling them.
64 item_threshold : int
64 item_threshold : int
65 The maximum number of items over which canning will iterate.
65 The maximum number of items over which canning will iterate.
66 Containers (lists, dicts) larger than this will be pickled without
66 Containers (lists, dicts) larger than this will be pickled without
67 introspection.
67 introspection.
68
68
69 Returns
69 Returns
70 -------
70 -------
71 [bufs] : list of buffers representing the serialized object.
71 [bufs] : list of buffers representing the serialized object.
72 """
72 """
73 buffers = []
73 buffers = []
74 if istype(obj, sequence_types) and len(obj) < item_threshold:
74 if istype(obj, sequence_types) and len(obj) < item_threshold:
75 cobj = can_sequence(obj)
75 cobj = can_sequence(obj)
76 for c in cobj:
76 for c in cobj:
77 buffers.extend(_extract_buffers(c, buffer_threshold))
77 buffers.extend(_extract_buffers(c, buffer_threshold))
78 elif istype(obj, dict) and len(obj) < item_threshold:
78 elif istype(obj, dict) and len(obj) < item_threshold:
79 cobj = {}
79 cobj = {}
80 for k in sorted(obj):
80 for k in sorted(obj):
81 c = can(obj[k])
81 c = can(obj[k])
82 buffers.extend(_extract_buffers(c, buffer_threshold))
82 buffers.extend(_extract_buffers(c, buffer_threshold))
83 cobj[k] = c
83 cobj[k] = c
84 else:
84 else:
85 cobj = can(obj)
85 cobj = can(obj)
86 buffers.extend(_extract_buffers(cobj, buffer_threshold))
86 buffers.extend(_extract_buffers(cobj, buffer_threshold))
87
87
88 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
88 buffers.insert(0, pickle.dumps(cobj, PICKLE_PROTOCOL))
89 return buffers
89 return buffers
90
90
91 def deserialize_object(buffers, g=None):
91 def deserialize_object(buffers, g=None):
92 """reconstruct an object serialized by serialize_object from data buffers.
92 """reconstruct an object serialized by serialize_object from data buffers.
93
93
94 Parameters
94 Parameters
95 ----------
95 ----------
96
96
97 bufs : list of buffers/bytes
97 bufs : list of buffers/bytes
98
98
99 g : globals to be used when uncanning
99 g : globals to be used when uncanning
100
100
101 Returns
101 Returns
102 -------
102 -------
103
103
104 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
104 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
105 """
105 """
106 bufs = list(buffers)
106 bufs = list(buffers)
107 pobj = buffer_to_bytes_py2(bufs.pop(0))
107 pobj = buffer_to_bytes_py2(bufs.pop(0))
108 canned = pickle.loads(pobj)
108 canned = pickle.loads(pobj)
109 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
109 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
110 for c in canned:
110 for c in canned:
111 _restore_buffers(c, bufs)
111 _restore_buffers(c, bufs)
112 newobj = uncan_sequence(canned, g)
112 newobj = uncan_sequence(canned, g)
113 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
113 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
114 newobj = {}
114 newobj = {}
115 for k in sorted(canned):
115 for k in sorted(canned):
116 c = canned[k]
116 c = canned[k]
117 _restore_buffers(c, bufs)
117 _restore_buffers(c, bufs)
118 newobj[k] = uncan(c, g)
118 newobj[k] = uncan(c, g)
119 else:
119 else:
120 _restore_buffers(canned, bufs)
120 _restore_buffers(canned, bufs)
121 newobj = uncan(canned, g)
121 newobj = uncan(canned, g)
122
122
123 return newobj, bufs
123 return newobj, bufs
124
124
125 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
125 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
126 """pack up a function, args, and kwargs to be sent over the wire
126 """pack up a function, args, and kwargs to be sent over the wire
127
127
128 Each element of args/kwargs will be canned for special treatment,
128 Each element of args/kwargs will be canned for special treatment,
129 but inspection will not go any deeper than that.
129 but inspection will not go any deeper than that.
130
130
131 Any object whose data is larger than `threshold` will not have their data copied
131 Any object whose data is larger than `threshold` will not have their data copied
132 (only numpy arrays and bytes/buffers support zero-copy)
132 (only numpy arrays and bytes/buffers support zero-copy)
133
133
134 Message will be a list of bytes/buffers of the format:
134 Message will be a list of bytes/buffers of the format:
135
135
136 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
136 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
137
137
138 With length at least two + len(args) + len(kwargs)
138 With length at least two + len(args) + len(kwargs)
139 """
139 """
140
140
141 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
141 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
142
142
143 kw_keys = sorted(kwargs.keys())
143 kw_keys = sorted(kwargs.keys())
144 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
144 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
145
145
146 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
146 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
147
147
148 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
148 msg = [pickle.dumps(can(f), PICKLE_PROTOCOL)]
149 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
149 msg.append(pickle.dumps(info, PICKLE_PROTOCOL))
150 msg.extend(arg_bufs)
150 msg.extend(arg_bufs)
151 msg.extend(kwarg_bufs)
151 msg.extend(kwarg_bufs)
152
152
153 return msg
153 return msg
154
154
155 def unpack_apply_message(bufs, g=None, copy=True):
155 def unpack_apply_message(bufs, g=None, copy=True):
156 """unpack f,args,kwargs from buffers packed by pack_apply_message()
156 """unpack f,args,kwargs from buffers packed by pack_apply_message()
157 Returns: original f,args,kwargs"""
157 Returns: original f,args,kwargs"""
158 bufs = list(bufs) # allow us to pop
158 bufs = list(bufs) # allow us to pop
159 assert len(bufs) >= 2, "not enough buffers!"
159 assert len(bufs) >= 2, "not enough buffers!"
160 pf = buffer_to_bytes_py2(bufs.pop(0))
160 pf = buffer_to_bytes_py2(bufs.pop(0))
161 f = uncan(pickle.loads(pf), g)
161 f = uncan(pickle.loads(pf), g)
162 pinfo = buffer_to_bytes_py2(bufs.pop(0))
162 pinfo = buffer_to_bytes_py2(bufs.pop(0))
163 info = pickle.loads(pinfo)
163 info = pickle.loads(pinfo)
164 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
164 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
165
165
166 args = []
166 args = []
167 for i in range(info['nargs']):
167 for i in range(info['nargs']):
168 arg, arg_bufs = deserialize_object(arg_bufs, g)
168 arg, arg_bufs = deserialize_object(arg_bufs, g)
169 args.append(arg)
169 args.append(arg)
170 args = tuple(args)
170 args = tuple(args)
171 assert not arg_bufs, "Shouldn't be any arg bufs left over"
171 assert not arg_bufs, "Shouldn't be any arg bufs left over"
172
172
173 kwargs = {}
173 kwargs = {}
174 for key in info['kw_keys']:
174 for key in info['kw_keys']:
175 kwarg, kwarg_bufs = deserialize_object(kwarg_bufs, g)
175 kwarg, kwarg_bufs = deserialize_object(kwarg_bufs, g)
176 kwargs[key] = kwarg
176 kwargs[key] = kwarg
177 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
177 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
178
178
179 return f,args,kwargs
179 return f,args,kwargs
@@ -1,62 +1,62 b''
1
1
2 import pickle
2 import pickle
3
3
4 import nose.tools as nt
4 import nose.tools as nt
5 from IPython.utils import codeutil
5 from ipython_kernel import codeutil
6 from IPython.utils.pickleutil import can, uncan
6 from ipython_kernel.pickleutil import can, uncan
7
7
8 def interactive(f):
8 def interactive(f):
9 f.__module__ = '__main__'
9 f.__module__ = '__main__'
10 return f
10 return f
11
11
12 def dumps(obj):
12 def dumps(obj):
13 return pickle.dumps(can(obj))
13 return pickle.dumps(can(obj))
14
14
15 def loads(obj):
15 def loads(obj):
16 return uncan(pickle.loads(obj))
16 return uncan(pickle.loads(obj))
17
17
18 def test_no_closure():
18 def test_no_closure():
19 @interactive
19 @interactive
20 def foo():
20 def foo():
21 a = 5
21 a = 5
22 return a
22 return a
23
23
24 pfoo = dumps(foo)
24 pfoo = dumps(foo)
25 bar = loads(pfoo)
25 bar = loads(pfoo)
26 nt.assert_equal(foo(), bar())
26 nt.assert_equal(foo(), bar())
27
27
28 def test_generator_closure():
28 def test_generator_closure():
29 # this only creates a closure on Python 3
29 # this only creates a closure on Python 3
30 @interactive
30 @interactive
31 def foo():
31 def foo():
32 i = 'i'
32 i = 'i'
33 r = [ i for j in (1,2) ]
33 r = [ i for j in (1,2) ]
34 return r
34 return r
35
35
36 pfoo = dumps(foo)
36 pfoo = dumps(foo)
37 bar = loads(pfoo)
37 bar = loads(pfoo)
38 nt.assert_equal(foo(), bar())
38 nt.assert_equal(foo(), bar())
39
39
40 def test_nested_closure():
40 def test_nested_closure():
41 @interactive
41 @interactive
42 def foo():
42 def foo():
43 i = 'i'
43 i = 'i'
44 def g():
44 def g():
45 return i
45 return i
46 return g()
46 return g()
47
47
48 pfoo = dumps(foo)
48 pfoo = dumps(foo)
49 bar = loads(pfoo)
49 bar = loads(pfoo)
50 nt.assert_equal(foo(), bar())
50 nt.assert_equal(foo(), bar())
51
51
52 def test_closure():
52 def test_closure():
53 i = 'i'
53 i = 'i'
54 @interactive
54 @interactive
55 def foo():
55 def foo():
56 return i
56 return i
57
57
58 pfoo = dumps(foo)
58 pfoo = dumps(foo)
59 bar = loads(pfoo)
59 bar = loads(pfoo)
60 nt.assert_equal(foo(), bar())
60 nt.assert_equal(foo(), bar())
61
61
62 No newline at end of file
62
@@ -1,208 +1,208 b''
1 """test serialization tools"""
1 """test serialization tools"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import pickle
6 import pickle
7 from collections import namedtuple
7 from collections import namedtuple
8
8
9 import nose.tools as nt
9 import nose.tools as nt
10
10
11 # from unittest import TestCaes
11 # from unittest import TestCaes
12 from ipython_kernel.serialize import serialize_object, deserialize_object
12 from ipython_kernel.serialize import serialize_object, deserialize_object
13 from IPython.testing import decorators as dec
13 from IPython.testing import decorators as dec
14 from IPython.utils.pickleutil import CannedArray, CannedClass
14 from ipython_kernel.pickleutil import CannedArray, CannedClass
15 from IPython.utils.py3compat import iteritems
15 from IPython.utils.py3compat import iteritems
16 from IPython.parallel import interactive
16 from IPython.parallel import interactive
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Globals and Utilities
19 # Globals and Utilities
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21
21
22 def roundtrip(obj):
22 def roundtrip(obj):
23 """roundtrip an object through serialization"""
23 """roundtrip an object through serialization"""
24 bufs = serialize_object(obj)
24 bufs = serialize_object(obj)
25 obj2, remainder = deserialize_object(bufs)
25 obj2, remainder = deserialize_object(bufs)
26 nt.assert_equals(remainder, [])
26 nt.assert_equals(remainder, [])
27 return obj2
27 return obj2
28
28
29 class C(object):
29 class C(object):
30 """dummy class for """
30 """dummy class for """
31
31
32 def __init__(self, **kwargs):
32 def __init__(self, **kwargs):
33 for key,value in iteritems(kwargs):
33 for key,value in iteritems(kwargs):
34 setattr(self, key, value)
34 setattr(self, key, value)
35
35
36 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
36 SHAPES = ((100,), (1024,10), (10,8,6,5), (), (0,))
37 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
37 DTYPES = ('uint8', 'float64', 'int32', [('g', 'float32')], '|S10')
38
38
39 #-------------------------------------------------------------------------------
39 #-------------------------------------------------------------------------------
40 # Tests
40 # Tests
41 #-------------------------------------------------------------------------------
41 #-------------------------------------------------------------------------------
42
42
43 def new_array(shape, dtype):
43 def new_array(shape, dtype):
44 import numpy
44 import numpy
45 return numpy.random.random(shape).astype(dtype)
45 return numpy.random.random(shape).astype(dtype)
46
46
47 def test_roundtrip_simple():
47 def test_roundtrip_simple():
48 for obj in [
48 for obj in [
49 'hello',
49 'hello',
50 dict(a='b', b=10),
50 dict(a='b', b=10),
51 [1,2,'hi'],
51 [1,2,'hi'],
52 (b'123', 'hello'),
52 (b'123', 'hello'),
53 ]:
53 ]:
54 obj2 = roundtrip(obj)
54 obj2 = roundtrip(obj)
55 nt.assert_equal(obj, obj2)
55 nt.assert_equal(obj, obj2)
56
56
57 def test_roundtrip_nested():
57 def test_roundtrip_nested():
58 for obj in [
58 for obj in [
59 dict(a=range(5), b={1:b'hello'}),
59 dict(a=range(5), b={1:b'hello'}),
60 [range(5),[range(3),(1,[b'whoda'])]],
60 [range(5),[range(3),(1,[b'whoda'])]],
61 ]:
61 ]:
62 obj2 = roundtrip(obj)
62 obj2 = roundtrip(obj)
63 nt.assert_equal(obj, obj2)
63 nt.assert_equal(obj, obj2)
64
64
65 def test_roundtrip_buffered():
65 def test_roundtrip_buffered():
66 for obj in [
66 for obj in [
67 dict(a=b"x"*1025),
67 dict(a=b"x"*1025),
68 b"hello"*500,
68 b"hello"*500,
69 [b"hello"*501, 1,2,3]
69 [b"hello"*501, 1,2,3]
70 ]:
70 ]:
71 bufs = serialize_object(obj)
71 bufs = serialize_object(obj)
72 nt.assert_equal(len(bufs), 2)
72 nt.assert_equal(len(bufs), 2)
73 obj2, remainder = deserialize_object(bufs)
73 obj2, remainder = deserialize_object(bufs)
74 nt.assert_equal(remainder, [])
74 nt.assert_equal(remainder, [])
75 nt.assert_equal(obj, obj2)
75 nt.assert_equal(obj, obj2)
76
76
77 @dec.skip_without('numpy')
77 @dec.skip_without('numpy')
78 def test_numpy():
78 def test_numpy():
79 import numpy
79 import numpy
80 from numpy.testing.utils import assert_array_equal
80 from numpy.testing.utils import assert_array_equal
81 for shape in SHAPES:
81 for shape in SHAPES:
82 for dtype in DTYPES:
82 for dtype in DTYPES:
83 A = new_array(shape, dtype=dtype)
83 A = new_array(shape, dtype=dtype)
84 bufs = serialize_object(A)
84 bufs = serialize_object(A)
85 B, r = deserialize_object(bufs)
85 B, r = deserialize_object(bufs)
86 nt.assert_equal(r, [])
86 nt.assert_equal(r, [])
87 nt.assert_equal(A.shape, B.shape)
87 nt.assert_equal(A.shape, B.shape)
88 nt.assert_equal(A.dtype, B.dtype)
88 nt.assert_equal(A.dtype, B.dtype)
89 assert_array_equal(A,B)
89 assert_array_equal(A,B)
90
90
91 @dec.skip_without('numpy')
91 @dec.skip_without('numpy')
92 def test_recarray():
92 def test_recarray():
93 import numpy
93 import numpy
94 from numpy.testing.utils import assert_array_equal
94 from numpy.testing.utils import assert_array_equal
95 for shape in SHAPES:
95 for shape in SHAPES:
96 for dtype in [
96 for dtype in [
97 [('f', float), ('s', '|S10')],
97 [('f', float), ('s', '|S10')],
98 [('n', int), ('s', '|S1'), ('u', 'uint32')],
98 [('n', int), ('s', '|S1'), ('u', 'uint32')],
99 ]:
99 ]:
100 A = new_array(shape, dtype=dtype)
100 A = new_array(shape, dtype=dtype)
101
101
102 bufs = serialize_object(A)
102 bufs = serialize_object(A)
103 B, r = deserialize_object(bufs)
103 B, r = deserialize_object(bufs)
104 nt.assert_equal(r, [])
104 nt.assert_equal(r, [])
105 nt.assert_equal(A.shape, B.shape)
105 nt.assert_equal(A.shape, B.shape)
106 nt.assert_equal(A.dtype, B.dtype)
106 nt.assert_equal(A.dtype, B.dtype)
107 assert_array_equal(A,B)
107 assert_array_equal(A,B)
108
108
109 @dec.skip_without('numpy')
109 @dec.skip_without('numpy')
110 def test_numpy_in_seq():
110 def test_numpy_in_seq():
111 import numpy
111 import numpy
112 from numpy.testing.utils import assert_array_equal
112 from numpy.testing.utils import assert_array_equal
113 for shape in SHAPES:
113 for shape in SHAPES:
114 for dtype in DTYPES:
114 for dtype in DTYPES:
115 A = new_array(shape, dtype=dtype)
115 A = new_array(shape, dtype=dtype)
116 bufs = serialize_object((A,1,2,b'hello'))
116 bufs = serialize_object((A,1,2,b'hello'))
117 canned = pickle.loads(bufs[0])
117 canned = pickle.loads(bufs[0])
118 nt.assert_is_instance(canned[0], CannedArray)
118 nt.assert_is_instance(canned[0], CannedArray)
119 tup, r = deserialize_object(bufs)
119 tup, r = deserialize_object(bufs)
120 B = tup[0]
120 B = tup[0]
121 nt.assert_equal(r, [])
121 nt.assert_equal(r, [])
122 nt.assert_equal(A.shape, B.shape)
122 nt.assert_equal(A.shape, B.shape)
123 nt.assert_equal(A.dtype, B.dtype)
123 nt.assert_equal(A.dtype, B.dtype)
124 assert_array_equal(A,B)
124 assert_array_equal(A,B)
125
125
126 @dec.skip_without('numpy')
126 @dec.skip_without('numpy')
127 def test_numpy_in_dict():
127 def test_numpy_in_dict():
128 import numpy
128 import numpy
129 from numpy.testing.utils import assert_array_equal
129 from numpy.testing.utils import assert_array_equal
130 for shape in SHAPES:
130 for shape in SHAPES:
131 for dtype in DTYPES:
131 for dtype in DTYPES:
132 A = new_array(shape, dtype=dtype)
132 A = new_array(shape, dtype=dtype)
133 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
133 bufs = serialize_object(dict(a=A,b=1,c=range(20)))
134 canned = pickle.loads(bufs[0])
134 canned = pickle.loads(bufs[0])
135 nt.assert_is_instance(canned['a'], CannedArray)
135 nt.assert_is_instance(canned['a'], CannedArray)
136 d, r = deserialize_object(bufs)
136 d, r = deserialize_object(bufs)
137 B = d['a']
137 B = d['a']
138 nt.assert_equal(r, [])
138 nt.assert_equal(r, [])
139 nt.assert_equal(A.shape, B.shape)
139 nt.assert_equal(A.shape, B.shape)
140 nt.assert_equal(A.dtype, B.dtype)
140 nt.assert_equal(A.dtype, B.dtype)
141 assert_array_equal(A,B)
141 assert_array_equal(A,B)
142
142
143 def test_class():
143 def test_class():
144 @interactive
144 @interactive
145 class C(object):
145 class C(object):
146 a=5
146 a=5
147 bufs = serialize_object(dict(C=C))
147 bufs = serialize_object(dict(C=C))
148 canned = pickle.loads(bufs[0])
148 canned = pickle.loads(bufs[0])
149 nt.assert_is_instance(canned['C'], CannedClass)
149 nt.assert_is_instance(canned['C'], CannedClass)
150 d, r = deserialize_object(bufs)
150 d, r = deserialize_object(bufs)
151 C2 = d['C']
151 C2 = d['C']
152 nt.assert_equal(C2.a, C.a)
152 nt.assert_equal(C2.a, C.a)
153
153
154 def test_class_oldstyle():
154 def test_class_oldstyle():
155 @interactive
155 @interactive
156 class C:
156 class C:
157 a=5
157 a=5
158
158
159 bufs = serialize_object(dict(C=C))
159 bufs = serialize_object(dict(C=C))
160 canned = pickle.loads(bufs[0])
160 canned = pickle.loads(bufs[0])
161 nt.assert_is_instance(canned['C'], CannedClass)
161 nt.assert_is_instance(canned['C'], CannedClass)
162 d, r = deserialize_object(bufs)
162 d, r = deserialize_object(bufs)
163 C2 = d['C']
163 C2 = d['C']
164 nt.assert_equal(C2.a, C.a)
164 nt.assert_equal(C2.a, C.a)
165
165
166 def test_tuple():
166 def test_tuple():
167 tup = (lambda x:x, 1)
167 tup = (lambda x:x, 1)
168 bufs = serialize_object(tup)
168 bufs = serialize_object(tup)
169 canned = pickle.loads(bufs[0])
169 canned = pickle.loads(bufs[0])
170 nt.assert_is_instance(canned, tuple)
170 nt.assert_is_instance(canned, tuple)
171 t2, r = deserialize_object(bufs)
171 t2, r = deserialize_object(bufs)
172 nt.assert_equal(t2[0](t2[1]), tup[0](tup[1]))
172 nt.assert_equal(t2[0](t2[1]), tup[0](tup[1]))
173
173
174 point = namedtuple('point', 'x y')
174 point = namedtuple('point', 'x y')
175
175
176 def test_namedtuple():
176 def test_namedtuple():
177 p = point(1,2)
177 p = point(1,2)
178 bufs = serialize_object(p)
178 bufs = serialize_object(p)
179 canned = pickle.loads(bufs[0])
179 canned = pickle.loads(bufs[0])
180 nt.assert_is_instance(canned, point)
180 nt.assert_is_instance(canned, point)
181 p2, r = deserialize_object(bufs, globals())
181 p2, r = deserialize_object(bufs, globals())
182 nt.assert_equal(p2.x, p.x)
182 nt.assert_equal(p2.x, p.x)
183 nt.assert_equal(p2.y, p.y)
183 nt.assert_equal(p2.y, p.y)
184
184
185 def test_list():
185 def test_list():
186 lis = [lambda x:x, 1]
186 lis = [lambda x:x, 1]
187 bufs = serialize_object(lis)
187 bufs = serialize_object(lis)
188 canned = pickle.loads(bufs[0])
188 canned = pickle.loads(bufs[0])
189 nt.assert_is_instance(canned, list)
189 nt.assert_is_instance(canned, list)
190 l2, r = deserialize_object(bufs)
190 l2, r = deserialize_object(bufs)
191 nt.assert_equal(l2[0](l2[1]), lis[0](lis[1]))
191 nt.assert_equal(l2[0](l2[1]), lis[0](lis[1]))
192
192
193 def test_class_inheritance():
193 def test_class_inheritance():
194 @interactive
194 @interactive
195 class C(object):
195 class C(object):
196 a=5
196 a=5
197
197
198 @interactive
198 @interactive
199 class D(C):
199 class D(C):
200 b=10
200 b=10
201
201
202 bufs = serialize_object(dict(D=D))
202 bufs = serialize_object(dict(D=D))
203 canned = pickle.loads(bufs[0])
203 canned = pickle.loads(bufs[0])
204 nt.assert_is_instance(canned['D'], CannedClass)
204 nt.assert_is_instance(canned['D'], CannedClass)
205 d, r = deserialize_object(bufs)
205 d, r = deserialize_object(bufs)
206 D2 = d['D']
206 D2 = d['D']
207 nt.assert_equal(D2.a, D.a)
207 nt.assert_equal(D2.a, D.a)
208 nt.assert_equal(D2.b, D.b)
208 nt.assert_equal(D2.b, D.b)
@@ -1,72 +1,72 b''
1 """The IPython ZMQ-based parallel computing interface.
1 """The IPython ZMQ-based parallel computing interface.
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import warnings
19 import warnings
20
20
21 import zmq
21 import zmq
22
22
23 from IPython.config.configurable import MultipleInstanceError
23 from IPython.config.configurable import MultipleInstanceError
24 from IPython.utils.zmqrelated import check_for_zmq
24 from IPython.utils.zmqrelated import check_for_zmq
25
25
26 min_pyzmq = '2.1.11'
26 min_pyzmq = '2.1.11'
27
27
28 check_for_zmq(min_pyzmq, 'ipython_parallel')
28 check_for_zmq(min_pyzmq, 'ipython_parallel')
29
29
30 from IPython.utils.pickleutil import Reference
30 from ipython_kernel.pickleutil import Reference
31
31
32 from .client.asyncresult import *
32 from .client.asyncresult import *
33 from .client.client import Client
33 from .client.client import Client
34 from .client.remotefunction import *
34 from .client.remotefunction import *
35 from .client.view import *
35 from .client.view import *
36 from .controller.dependency import *
36 from .controller.dependency import *
37 from .error import *
37 from .error import *
38 from .util import interactive
38 from .util import interactive
39
39
40 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
41 # Functions
41 # Functions
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44
44
45 def bind_kernel(**kwargs):
45 def bind_kernel(**kwargs):
46 """Bind an Engine's Kernel to be used as a full IPython kernel.
46 """Bind an Engine's Kernel to be used as a full IPython kernel.
47
47
48 This allows a running Engine to be used simultaneously as a full IPython kernel
48 This allows a running Engine to be used simultaneously as a full IPython kernel
49 with the QtConsole or other frontends.
49 with the QtConsole or other frontends.
50
50
51 This function returns immediately.
51 This function returns immediately.
52 """
52 """
53 from IPython.kernel.zmq.kernelapp import IPKernelApp
53 from IPython.kernel.zmq.kernelapp import IPKernelApp
54 from ipython_parallel.apps.ipengineapp import IPEngineApp
54 from ipython_parallel.apps.ipengineapp import IPEngineApp
55
55
56 # first check for IPKernelApp, in which case this should be a no-op
56 # first check for IPKernelApp, in which case this should be a no-op
57 # because there is already a bound kernel
57 # because there is already a bound kernel
58 if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp):
58 if IPKernelApp.initialized() and isinstance(IPKernelApp._instance, IPKernelApp):
59 return
59 return
60
60
61 if IPEngineApp.initialized():
61 if IPEngineApp.initialized():
62 try:
62 try:
63 app = IPEngineApp.instance()
63 app = IPEngineApp.instance()
64 except MultipleInstanceError:
64 except MultipleInstanceError:
65 pass
65 pass
66 else:
66 else:
67 return app.bind_kernel(**kwargs)
67 return app.bind_kernel(**kwargs)
68
68
69 raise RuntimeError("bind_kernel be called from an IPEngineApp instance")
69 raise RuntimeError("bind_kernel be called from an IPEngineApp instance")
70
70
71
71
72
72
@@ -1,1125 +1,1125 b''
1 """Views of remote engines."""
1 """Views of remote engines."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 from __future__ import print_function
6 from __future__ import print_function
7
7
8 import imp
8 import imp
9 import sys
9 import sys
10 import warnings
10 import warnings
11 from contextlib import contextmanager
11 from contextlib import contextmanager
12 from types import ModuleType
12 from types import ModuleType
13
13
14 import zmq
14 import zmq
15
15
16 from IPython.testing.skipdoctest import skip_doctest
16 from IPython.testing.skipdoctest import skip_doctest
17 from IPython.utils import pickleutil
17 from IPython.utils import pickleutil
18 from IPython.utils.traitlets import (
18 from IPython.utils.traitlets import (
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
19 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
20 )
20 )
21 from decorator import decorator
21 from decorator import decorator
22
22
23 from ipython_parallel import util
23 from ipython_parallel import util
24 from ipython_parallel.controller.dependency import Dependency, dependent
24 from ipython_parallel.controller.dependency import Dependency, dependent
25 from IPython.utils.py3compat import string_types, iteritems, PY3
25 from IPython.utils.py3compat import string_types, iteritems, PY3
26
26
27 from . import map as Map
27 from . import map as Map
28 from .asyncresult import AsyncResult, AsyncMapResult
28 from .asyncresult import AsyncResult, AsyncMapResult
29 from .remotefunction import ParallelFunction, parallel, remote, getname
29 from .remotefunction import ParallelFunction, parallel, remote, getname
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Decorators
32 # Decorators
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 @decorator
35 @decorator
36 def save_ids(f, self, *args, **kwargs):
36 def save_ids(f, self, *args, **kwargs):
37 """Keep our history and outstanding attributes up to date after a method call."""
37 """Keep our history and outstanding attributes up to date after a method call."""
38 n_previous = len(self.client.history)
38 n_previous = len(self.client.history)
39 try:
39 try:
40 ret = f(self, *args, **kwargs)
40 ret = f(self, *args, **kwargs)
41 finally:
41 finally:
42 nmsgs = len(self.client.history) - n_previous
42 nmsgs = len(self.client.history) - n_previous
43 msg_ids = self.client.history[-nmsgs:]
43 msg_ids = self.client.history[-nmsgs:]
44 self.history.extend(msg_ids)
44 self.history.extend(msg_ids)
45 self.outstanding.update(msg_ids)
45 self.outstanding.update(msg_ids)
46 return ret
46 return ret
47
47
48 @decorator
48 @decorator
49 def sync_results(f, self, *args, **kwargs):
49 def sync_results(f, self, *args, **kwargs):
50 """sync relevant results from self.client to our results attribute."""
50 """sync relevant results from self.client to our results attribute."""
51 if self._in_sync_results:
51 if self._in_sync_results:
52 return f(self, *args, **kwargs)
52 return f(self, *args, **kwargs)
53 self._in_sync_results = True
53 self._in_sync_results = True
54 try:
54 try:
55 ret = f(self, *args, **kwargs)
55 ret = f(self, *args, **kwargs)
56 finally:
56 finally:
57 self._in_sync_results = False
57 self._in_sync_results = False
58 self._sync_results()
58 self._sync_results()
59 return ret
59 return ret
60
60
61 @decorator
61 @decorator
62 def spin_after(f, self, *args, **kwargs):
62 def spin_after(f, self, *args, **kwargs):
63 """call spin after the method."""
63 """call spin after the method."""
64 ret = f(self, *args, **kwargs)
64 ret = f(self, *args, **kwargs)
65 self.spin()
65 self.spin()
66 return ret
66 return ret
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72 @skip_doctest
72 @skip_doctest
73 class View(HasTraits):
73 class View(HasTraits):
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
74 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
75
75
76 Don't use this class, use subclasses.
76 Don't use this class, use subclasses.
77
77
78 Methods
78 Methods
79 -------
79 -------
80
80
81 spin
81 spin
82 flushes incoming results and registration state changes
82 flushes incoming results and registration state changes
83 control methods spin, and requesting `ids` also ensures up to date
83 control methods spin, and requesting `ids` also ensures up to date
84
84
85 wait
85 wait
86 wait on one or more msg_ids
86 wait on one or more msg_ids
87
87
88 execution methods
88 execution methods
89 apply
89 apply
90 legacy: execute, run
90 legacy: execute, run
91
91
92 data movement
92 data movement
93 push, pull, scatter, gather
93 push, pull, scatter, gather
94
94
95 query methods
95 query methods
96 get_result, queue_status, purge_results, result_status
96 get_result, queue_status, purge_results, result_status
97
97
98 control methods
98 control methods
99 abort, shutdown
99 abort, shutdown
100
100
101 """
101 """
102 # flags
102 # flags
103 block=Bool(False)
103 block=Bool(False)
104 track=Bool(True)
104 track=Bool(True)
105 targets = Any()
105 targets = Any()
106
106
107 history=List()
107 history=List()
108 outstanding = Set()
108 outstanding = Set()
109 results = Dict()
109 results = Dict()
110 client = Instance('ipython_parallel.Client', allow_none=True)
110 client = Instance('ipython_parallel.Client', allow_none=True)
111
111
112 _socket = Instance('zmq.Socket', allow_none=True)
112 _socket = Instance('zmq.Socket', allow_none=True)
113 _flag_names = List(['targets', 'block', 'track'])
113 _flag_names = List(['targets', 'block', 'track'])
114 _in_sync_results = Bool(False)
114 _in_sync_results = Bool(False)
115 _targets = Any()
115 _targets = Any()
116 _idents = Any()
116 _idents = Any()
117
117
118 def __init__(self, client=None, socket=None, **flags):
118 def __init__(self, client=None, socket=None, **flags):
119 super(View, self).__init__(client=client, _socket=socket)
119 super(View, self).__init__(client=client, _socket=socket)
120 self.results = client.results
120 self.results = client.results
121 self.block = client.block
121 self.block = client.block
122
122
123 self.set_flags(**flags)
123 self.set_flags(**flags)
124
124
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
125 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
126
126
127 def __repr__(self):
127 def __repr__(self):
128 strtargets = str(self.targets)
128 strtargets = str(self.targets)
129 if len(strtargets) > 16:
129 if len(strtargets) > 16:
130 strtargets = strtargets[:12]+'...]'
130 strtargets = strtargets[:12]+'...]'
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
131 return "<%s %s>"%(self.__class__.__name__, strtargets)
132
132
133 def __len__(self):
133 def __len__(self):
134 if isinstance(self.targets, list):
134 if isinstance(self.targets, list):
135 return len(self.targets)
135 return len(self.targets)
136 elif isinstance(self.targets, int):
136 elif isinstance(self.targets, int):
137 return 1
137 return 1
138 else:
138 else:
139 return len(self.client)
139 return len(self.client)
140
140
141 def set_flags(self, **kwargs):
141 def set_flags(self, **kwargs):
142 """set my attribute flags by keyword.
142 """set my attribute flags by keyword.
143
143
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
144 Views determine behavior with a few attributes (`block`, `track`, etc.).
145 These attributes can be set all at once by name with this method.
145 These attributes can be set all at once by name with this method.
146
146
147 Parameters
147 Parameters
148 ----------
148 ----------
149
149
150 block : bool
150 block : bool
151 whether to wait for results
151 whether to wait for results
152 track : bool
152 track : bool
153 whether to create a MessageTracker to allow the user to
153 whether to create a MessageTracker to allow the user to
154 safely edit after arrays and buffers during non-copying
154 safely edit after arrays and buffers during non-copying
155 sends.
155 sends.
156 """
156 """
157 for name, value in iteritems(kwargs):
157 for name, value in iteritems(kwargs):
158 if name not in self._flag_names:
158 if name not in self._flag_names:
159 raise KeyError("Invalid name: %r"%name)
159 raise KeyError("Invalid name: %r"%name)
160 else:
160 else:
161 setattr(self, name, value)
161 setattr(self, name, value)
162
162
163 @contextmanager
163 @contextmanager
164 def temp_flags(self, **kwargs):
164 def temp_flags(self, **kwargs):
165 """temporarily set flags, for use in `with` statements.
165 """temporarily set flags, for use in `with` statements.
166
166
167 See set_flags for permanent setting of flags
167 See set_flags for permanent setting of flags
168
168
169 Examples
169 Examples
170 --------
170 --------
171
171
172 >>> view.track=False
172 >>> view.track=False
173 ...
173 ...
174 >>> with view.temp_flags(track=True):
174 >>> with view.temp_flags(track=True):
175 ... ar = view.apply(dostuff, my_big_array)
175 ... ar = view.apply(dostuff, my_big_array)
176 ... ar.tracker.wait() # wait for send to finish
176 ... ar.tracker.wait() # wait for send to finish
177 >>> view.track
177 >>> view.track
178 False
178 False
179
179
180 """
180 """
181 # preflight: save flags, and set temporaries
181 # preflight: save flags, and set temporaries
182 saved_flags = {}
182 saved_flags = {}
183 for f in self._flag_names:
183 for f in self._flag_names:
184 saved_flags[f] = getattr(self, f)
184 saved_flags[f] = getattr(self, f)
185 self.set_flags(**kwargs)
185 self.set_flags(**kwargs)
186 # yield to the with-statement block
186 # yield to the with-statement block
187 try:
187 try:
188 yield
188 yield
189 finally:
189 finally:
190 # postflight: restore saved flags
190 # postflight: restore saved flags
191 self.set_flags(**saved_flags)
191 self.set_flags(**saved_flags)
192
192
193
193
194 #----------------------------------------------------------------
194 #----------------------------------------------------------------
195 # apply
195 # apply
196 #----------------------------------------------------------------
196 #----------------------------------------------------------------
197
197
198 def _sync_results(self):
198 def _sync_results(self):
199 """to be called by @sync_results decorator
199 """to be called by @sync_results decorator
200
200
201 after submitting any tasks.
201 after submitting any tasks.
202 """
202 """
203 delta = self.outstanding.difference(self.client.outstanding)
203 delta = self.outstanding.difference(self.client.outstanding)
204 completed = self.outstanding.intersection(delta)
204 completed = self.outstanding.intersection(delta)
205 self.outstanding = self.outstanding.difference(completed)
205 self.outstanding = self.outstanding.difference(completed)
206
206
207 @sync_results
207 @sync_results
208 @save_ids
208 @save_ids
209 def _really_apply(self, f, args, kwargs, block=None, **options):
209 def _really_apply(self, f, args, kwargs, block=None, **options):
210 """wrapper for client.send_apply_request"""
210 """wrapper for client.send_apply_request"""
211 raise NotImplementedError("Implement in subclasses")
211 raise NotImplementedError("Implement in subclasses")
212
212
213 def apply(self, f, *args, **kwargs):
213 def apply(self, f, *args, **kwargs):
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
214 """calls ``f(*args, **kwargs)`` on remote engines, returning the result.
215
215
216 This method sets all apply flags via this View's attributes.
216 This method sets all apply flags via this View's attributes.
217
217
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
218 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult`
219 instance if ``self.block`` is False, otherwise the return value of
219 instance if ``self.block`` is False, otherwise the return value of
220 ``f(*args, **kwargs)``.
220 ``f(*args, **kwargs)``.
221 """
221 """
222 return self._really_apply(f, args, kwargs)
222 return self._really_apply(f, args, kwargs)
223
223
224 def apply_async(self, f, *args, **kwargs):
224 def apply_async(self, f, *args, **kwargs):
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
225 """calls ``f(*args, **kwargs)`` on remote engines in a nonblocking manner.
226
226
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
227 Returns :class:`~IPython.parallel.client.asyncresult.AsyncResult` instance.
228 """
228 """
229 return self._really_apply(f, args, kwargs, block=False)
229 return self._really_apply(f, args, kwargs, block=False)
230
230
231 @spin_after
231 @spin_after
232 def apply_sync(self, f, *args, **kwargs):
232 def apply_sync(self, f, *args, **kwargs):
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
233 """calls ``f(*args, **kwargs)`` on remote engines in a blocking manner,
234 returning the result.
234 returning the result.
235 """
235 """
236 return self._really_apply(f, args, kwargs, block=True)
236 return self._really_apply(f, args, kwargs, block=True)
237
237
238 #----------------------------------------------------------------
238 #----------------------------------------------------------------
239 # wrappers for client and control methods
239 # wrappers for client and control methods
240 #----------------------------------------------------------------
240 #----------------------------------------------------------------
241 @sync_results
241 @sync_results
242 def spin(self):
242 def spin(self):
243 """spin the client, and sync"""
243 """spin the client, and sync"""
244 self.client.spin()
244 self.client.spin()
245
245
246 @sync_results
246 @sync_results
247 def wait(self, jobs=None, timeout=-1):
247 def wait(self, jobs=None, timeout=-1):
248 """waits on one or more `jobs`, for up to `timeout` seconds.
248 """waits on one or more `jobs`, for up to `timeout` seconds.
249
249
250 Parameters
250 Parameters
251 ----------
251 ----------
252
252
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
253 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
254 ints are indices to self.history
254 ints are indices to self.history
255 strs are msg_ids
255 strs are msg_ids
256 default: wait on all outstanding messages
256 default: wait on all outstanding messages
257 timeout : float
257 timeout : float
258 a time in seconds, after which to give up.
258 a time in seconds, after which to give up.
259 default is -1, which means no timeout
259 default is -1, which means no timeout
260
260
261 Returns
261 Returns
262 -------
262 -------
263
263
264 True : when all msg_ids are done
264 True : when all msg_ids are done
265 False : timeout reached, some msg_ids still outstanding
265 False : timeout reached, some msg_ids still outstanding
266 """
266 """
267 if jobs is None:
267 if jobs is None:
268 jobs = self.history
268 jobs = self.history
269 return self.client.wait(jobs, timeout)
269 return self.client.wait(jobs, timeout)
270
270
271 def abort(self, jobs=None, targets=None, block=None):
271 def abort(self, jobs=None, targets=None, block=None):
272 """Abort jobs on my engines.
272 """Abort jobs on my engines.
273
273
274 Parameters
274 Parameters
275 ----------
275 ----------
276
276
277 jobs : None, str, list of strs, optional
277 jobs : None, str, list of strs, optional
278 if None: abort all jobs.
278 if None: abort all jobs.
279 else: abort specific msg_id(s).
279 else: abort specific msg_id(s).
280 """
280 """
281 block = block if block is not None else self.block
281 block = block if block is not None else self.block
282 targets = targets if targets is not None else self.targets
282 targets = targets if targets is not None else self.targets
283 jobs = jobs if jobs is not None else list(self.outstanding)
283 jobs = jobs if jobs is not None else list(self.outstanding)
284
284
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
285 return self.client.abort(jobs=jobs, targets=targets, block=block)
286
286
287 def queue_status(self, targets=None, verbose=False):
287 def queue_status(self, targets=None, verbose=False):
288 """Fetch the Queue status of my engines"""
288 """Fetch the Queue status of my engines"""
289 targets = targets if targets is not None else self.targets
289 targets = targets if targets is not None else self.targets
290 return self.client.queue_status(targets=targets, verbose=verbose)
290 return self.client.queue_status(targets=targets, verbose=verbose)
291
291
292 def purge_results(self, jobs=[], targets=[]):
292 def purge_results(self, jobs=[], targets=[]):
293 """Instruct the controller to forget specific results."""
293 """Instruct the controller to forget specific results."""
294 if targets is None or targets == 'all':
294 if targets is None or targets == 'all':
295 targets = self.targets
295 targets = self.targets
296 return self.client.purge_results(jobs=jobs, targets=targets)
296 return self.client.purge_results(jobs=jobs, targets=targets)
297
297
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
298 def shutdown(self, targets=None, restart=False, hub=False, block=None):
299 """Terminates one or more engine processes, optionally including the hub.
299 """Terminates one or more engine processes, optionally including the hub.
300 """
300 """
301 block = self.block if block is None else block
301 block = self.block if block is None else block
302 if targets is None or targets == 'all':
302 if targets is None or targets == 'all':
303 targets = self.targets
303 targets = self.targets
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
304 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
305
305
306 @spin_after
306 @spin_after
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
307 def get_result(self, indices_or_msg_ids=None, block=None, owner=True):
308 """return one or more results, specified by history index or msg_id.
308 """return one or more results, specified by history index or msg_id.
309
309
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
310 See :meth:`IPython.parallel.client.client.Client.get_result` for details.
311 """
311 """
312
312
313 if indices_or_msg_ids is None:
313 if indices_or_msg_ids is None:
314 indices_or_msg_ids = -1
314 indices_or_msg_ids = -1
315 if isinstance(indices_or_msg_ids, int):
315 if isinstance(indices_or_msg_ids, int):
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
316 indices_or_msg_ids = self.history[indices_or_msg_ids]
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
317 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
318 indices_or_msg_ids = list(indices_or_msg_ids)
318 indices_or_msg_ids = list(indices_or_msg_ids)
319 for i,index in enumerate(indices_or_msg_ids):
319 for i,index in enumerate(indices_or_msg_ids):
320 if isinstance(index, int):
320 if isinstance(index, int):
321 indices_or_msg_ids[i] = self.history[index]
321 indices_or_msg_ids[i] = self.history[index]
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
322 return self.client.get_result(indices_or_msg_ids, block=block, owner=owner)
323
323
324 #-------------------------------------------------------------------
324 #-------------------------------------------------------------------
325 # Map
325 # Map
326 #-------------------------------------------------------------------
326 #-------------------------------------------------------------------
327
327
328 @sync_results
328 @sync_results
329 def map(self, f, *sequences, **kwargs):
329 def map(self, f, *sequences, **kwargs):
330 """override in subclasses"""
330 """override in subclasses"""
331 raise NotImplementedError
331 raise NotImplementedError
332
332
333 def map_async(self, f, *sequences, **kwargs):
333 def map_async(self, f, *sequences, **kwargs):
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
334 """Parallel version of builtin :func:`python:map`, using this view's engines.
335
335
336 This is equivalent to ``map(...block=False)``.
336 This is equivalent to ``map(...block=False)``.
337
337
338 See `self.map` for details.
338 See `self.map` for details.
339 """
339 """
340 if 'block' in kwargs:
340 if 'block' in kwargs:
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
341 raise TypeError("map_async doesn't take a `block` keyword argument.")
342 kwargs['block'] = False
342 kwargs['block'] = False
343 return self.map(f,*sequences,**kwargs)
343 return self.map(f,*sequences,**kwargs)
344
344
345 def map_sync(self, f, *sequences, **kwargs):
345 def map_sync(self, f, *sequences, **kwargs):
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
346 """Parallel version of builtin :func:`python:map`, using this view's engines.
347
347
348 This is equivalent to ``map(...block=True)``.
348 This is equivalent to ``map(...block=True)``.
349
349
350 See `self.map` for details.
350 See `self.map` for details.
351 """
351 """
352 if 'block' in kwargs:
352 if 'block' in kwargs:
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
353 raise TypeError("map_sync doesn't take a `block` keyword argument.")
354 kwargs['block'] = True
354 kwargs['block'] = True
355 return self.map(f,*sequences,**kwargs)
355 return self.map(f,*sequences,**kwargs)
356
356
357 def imap(self, f, *sequences, **kwargs):
357 def imap(self, f, *sequences, **kwargs):
358 """Parallel version of :func:`itertools.imap`.
358 """Parallel version of :func:`itertools.imap`.
359
359
360 See `self.map` for details.
360 See `self.map` for details.
361
361
362 """
362 """
363
363
364 return iter(self.map_async(f,*sequences, **kwargs))
364 return iter(self.map_async(f,*sequences, **kwargs))
365
365
366 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
367 # Decorators
367 # Decorators
368 #-------------------------------------------------------------------
368 #-------------------------------------------------------------------
369
369
370 def remote(self, block=None, **flags):
370 def remote(self, block=None, **flags):
371 """Decorator for making a RemoteFunction"""
371 """Decorator for making a RemoteFunction"""
372 block = self.block if block is None else block
372 block = self.block if block is None else block
373 return remote(self, block=block, **flags)
373 return remote(self, block=block, **flags)
374
374
375 def parallel(self, dist='b', block=None, **flags):
375 def parallel(self, dist='b', block=None, **flags):
376 """Decorator for making a ParallelFunction"""
376 """Decorator for making a ParallelFunction"""
377 block = self.block if block is None else block
377 block = self.block if block is None else block
378 return parallel(self, dist=dist, block=block, **flags)
378 return parallel(self, dist=dist, block=block, **flags)
379
379
380 @skip_doctest
380 @skip_doctest
381 class DirectView(View):
381 class DirectView(View):
382 """Direct Multiplexer View of one or more engines.
382 """Direct Multiplexer View of one or more engines.
383
383
384 These are created via indexed access to a client:
384 These are created via indexed access to a client:
385
385
386 >>> dv_1 = client[1]
386 >>> dv_1 = client[1]
387 >>> dv_all = client[:]
387 >>> dv_all = client[:]
388 >>> dv_even = client[::2]
388 >>> dv_even = client[::2]
389 >>> dv_some = client[1:3]
389 >>> dv_some = client[1:3]
390
390
391 This object provides dictionary access to engine namespaces:
391 This object provides dictionary access to engine namespaces:
392
392
393 # push a=5:
393 # push a=5:
394 >>> dv['a'] = 5
394 >>> dv['a'] = 5
395 # pull 'foo':
395 # pull 'foo':
396 >>> dv['foo']
396 >>> dv['foo']
397
397
398 """
398 """
399
399
400 def __init__(self, client=None, socket=None, targets=None):
400 def __init__(self, client=None, socket=None, targets=None):
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
401 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
402
402
403 @property
403 @property
404 def importer(self):
404 def importer(self):
405 """sync_imports(local=True) as a property.
405 """sync_imports(local=True) as a property.
406
406
407 See sync_imports for details.
407 See sync_imports for details.
408
408
409 """
409 """
410 return self.sync_imports(True)
410 return self.sync_imports(True)
411
411
412 @contextmanager
412 @contextmanager
413 def sync_imports(self, local=True, quiet=False):
413 def sync_imports(self, local=True, quiet=False):
414 """Context Manager for performing simultaneous local and remote imports.
414 """Context Manager for performing simultaneous local and remote imports.
415
415
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
416 'import x as y' will *not* work. The 'as y' part will simply be ignored.
417
417
418 If `local=True`, then the package will also be imported locally.
418 If `local=True`, then the package will also be imported locally.
419
419
420 If `quiet=True`, no output will be produced when attempting remote
420 If `quiet=True`, no output will be produced when attempting remote
421 imports.
421 imports.
422
422
423 Note that remote-only (`local=False`) imports have not been implemented.
423 Note that remote-only (`local=False`) imports have not been implemented.
424
424
425 >>> with view.sync_imports():
425 >>> with view.sync_imports():
426 ... from numpy import recarray
426 ... from numpy import recarray
427 importing recarray from numpy on engine(s)
427 importing recarray from numpy on engine(s)
428
428
429 """
429 """
430 from IPython.utils.py3compat import builtin_mod
430 from IPython.utils.py3compat import builtin_mod
431 local_import = builtin_mod.__import__
431 local_import = builtin_mod.__import__
432 modules = set()
432 modules = set()
433 results = []
433 results = []
434 @util.interactive
434 @util.interactive
435 def remote_import(name, fromlist, level):
435 def remote_import(name, fromlist, level):
436 """the function to be passed to apply, that actually performs the import
436 """the function to be passed to apply, that actually performs the import
437 on the engine, and loads up the user namespace.
437 on the engine, and loads up the user namespace.
438 """
438 """
439 import sys
439 import sys
440 user_ns = globals()
440 user_ns = globals()
441 mod = __import__(name, fromlist=fromlist, level=level)
441 mod = __import__(name, fromlist=fromlist, level=level)
442 if fromlist:
442 if fromlist:
443 for key in fromlist:
443 for key in fromlist:
444 user_ns[key] = getattr(mod, key)
444 user_ns[key] = getattr(mod, key)
445 else:
445 else:
446 user_ns[name] = sys.modules[name]
446 user_ns[name] = sys.modules[name]
447
447
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
448 def view_import(name, globals={}, locals={}, fromlist=[], level=0):
449 """the drop-in replacement for __import__, that optionally imports
449 """the drop-in replacement for __import__, that optionally imports
450 locally as well.
450 locally as well.
451 """
451 """
452 # don't override nested imports
452 # don't override nested imports
453 save_import = builtin_mod.__import__
453 save_import = builtin_mod.__import__
454 builtin_mod.__import__ = local_import
454 builtin_mod.__import__ = local_import
455
455
456 if imp.lock_held():
456 if imp.lock_held():
457 # this is a side-effect import, don't do it remotely, or even
457 # this is a side-effect import, don't do it remotely, or even
458 # ignore the local effects
458 # ignore the local effects
459 return local_import(name, globals, locals, fromlist, level)
459 return local_import(name, globals, locals, fromlist, level)
460
460
461 imp.acquire_lock()
461 imp.acquire_lock()
462 if local:
462 if local:
463 mod = local_import(name, globals, locals, fromlist, level)
463 mod = local_import(name, globals, locals, fromlist, level)
464 else:
464 else:
465 raise NotImplementedError("remote-only imports not yet implemented")
465 raise NotImplementedError("remote-only imports not yet implemented")
466 imp.release_lock()
466 imp.release_lock()
467
467
468 key = name+':'+','.join(fromlist or [])
468 key = name+':'+','.join(fromlist or [])
469 if level <= 0 and key not in modules:
469 if level <= 0 and key not in modules:
470 modules.add(key)
470 modules.add(key)
471 if not quiet:
471 if not quiet:
472 if fromlist:
472 if fromlist:
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
473 print("importing %s from %s on engine(s)"%(','.join(fromlist), name))
474 else:
474 else:
475 print("importing %s on engine(s)"%name)
475 print("importing %s on engine(s)"%name)
476 results.append(self.apply_async(remote_import, name, fromlist, level))
476 results.append(self.apply_async(remote_import, name, fromlist, level))
477 # restore override
477 # restore override
478 builtin_mod.__import__ = save_import
478 builtin_mod.__import__ = save_import
479
479
480 return mod
480 return mod
481
481
482 # override __import__
482 # override __import__
483 builtin_mod.__import__ = view_import
483 builtin_mod.__import__ = view_import
484 try:
484 try:
485 # enter the block
485 # enter the block
486 yield
486 yield
487 except ImportError:
487 except ImportError:
488 if local:
488 if local:
489 raise
489 raise
490 else:
490 else:
491 # ignore import errors if not doing local imports
491 # ignore import errors if not doing local imports
492 pass
492 pass
493 finally:
493 finally:
494 # always restore __import__
494 # always restore __import__
495 builtin_mod.__import__ = local_import
495 builtin_mod.__import__ = local_import
496
496
497 for r in results:
497 for r in results:
498 # raise possible remote ImportErrors here
498 # raise possible remote ImportErrors here
499 r.get()
499 r.get()
500
500
501 def use_dill(self):
501 def use_dill(self):
502 """Expand serialization support with dill
502 """Expand serialization support with dill
503
503
504 adds support for closures, etc.
504 adds support for closures, etc.
505
505
506 This calls IPython.utils.pickleutil.use_dill() here and on each engine.
506 This calls ipython_kernel.pickleutil.use_dill() here and on each engine.
507 """
507 """
508 pickleutil.use_dill()
508 pickleutil.use_dill()
509 return self.apply(pickleutil.use_dill)
509 return self.apply(pickleutil.use_dill)
510
510
511 def use_cloudpickle(self):
511 def use_cloudpickle(self):
512 """Expand serialization support with cloudpickle.
512 """Expand serialization support with cloudpickle.
513 """
513 """
514 pickleutil.use_cloudpickle()
514 pickleutil.use_cloudpickle()
515 return self.apply(pickleutil.use_cloudpickle)
515 return self.apply(pickleutil.use_cloudpickle)
516
516
517
517
518 @sync_results
518 @sync_results
519 @save_ids
519 @save_ids
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
520 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
521 """calls f(*args, **kwargs) on remote engines, returning the result.
521 """calls f(*args, **kwargs) on remote engines, returning the result.
522
522
523 This method sets all of `apply`'s flags via this View's attributes.
523 This method sets all of `apply`'s flags via this View's attributes.
524
524
525 Parameters
525 Parameters
526 ----------
526 ----------
527
527
528 f : callable
528 f : callable
529
529
530 args : list [default: empty]
530 args : list [default: empty]
531
531
532 kwargs : dict [default: empty]
532 kwargs : dict [default: empty]
533
533
534 targets : target list [default: self.targets]
534 targets : target list [default: self.targets]
535 where to run
535 where to run
536 block : bool [default: self.block]
536 block : bool [default: self.block]
537 whether to block
537 whether to block
538 track : bool [default: self.track]
538 track : bool [default: self.track]
539 whether to ask zmq to track the message, for safe non-copying sends
539 whether to ask zmq to track the message, for safe non-copying sends
540
540
541 Returns
541 Returns
542 -------
542 -------
543
543
544 if self.block is False:
544 if self.block is False:
545 returns AsyncResult
545 returns AsyncResult
546 else:
546 else:
547 returns actual result of f(*args, **kwargs) on the engine(s)
547 returns actual result of f(*args, **kwargs) on the engine(s)
548 This will be a list of self.targets is also a list (even length 1), or
548 This will be a list of self.targets is also a list (even length 1), or
549 the single result if self.targets is an integer engine id
549 the single result if self.targets is an integer engine id
550 """
550 """
551 args = [] if args is None else args
551 args = [] if args is None else args
552 kwargs = {} if kwargs is None else kwargs
552 kwargs = {} if kwargs is None else kwargs
553 block = self.block if block is None else block
553 block = self.block if block is None else block
554 track = self.track if track is None else track
554 track = self.track if track is None else track
555 targets = self.targets if targets is None else targets
555 targets = self.targets if targets is None else targets
556
556
557 _idents, _targets = self.client._build_targets(targets)
557 _idents, _targets = self.client._build_targets(targets)
558 msg_ids = []
558 msg_ids = []
559 trackers = []
559 trackers = []
560 for ident in _idents:
560 for ident in _idents:
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
561 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
562 ident=ident)
562 ident=ident)
563 if track:
563 if track:
564 trackers.append(msg['tracker'])
564 trackers.append(msg['tracker'])
565 msg_ids.append(msg['header']['msg_id'])
565 msg_ids.append(msg['header']['msg_id'])
566 if isinstance(targets, int):
566 if isinstance(targets, int):
567 msg_ids = msg_ids[0]
567 msg_ids = msg_ids[0]
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
568 tracker = None if track is False else zmq.MessageTracker(*trackers)
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
569 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=_targets,
570 tracker=tracker, owner=True,
570 tracker=tracker, owner=True,
571 )
571 )
572 if block:
572 if block:
573 try:
573 try:
574 return ar.get()
574 return ar.get()
575 except KeyboardInterrupt:
575 except KeyboardInterrupt:
576 pass
576 pass
577 return ar
577 return ar
578
578
579
579
580 @sync_results
580 @sync_results
581 def map(self, f, *sequences, **kwargs):
581 def map(self, f, *sequences, **kwargs):
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
582 """``view.map(f, *sequences, block=self.block)`` => list|AsyncMapResult
583
583
584 Parallel version of builtin `map`, using this View's `targets`.
584 Parallel version of builtin `map`, using this View's `targets`.
585
585
586 There will be one task per target, so work will be chunked
586 There will be one task per target, so work will be chunked
587 if the sequences are longer than `targets`.
587 if the sequences are longer than `targets`.
588
588
589 Results can be iterated as they are ready, but will become available in chunks.
589 Results can be iterated as they are ready, but will become available in chunks.
590
590
591 Parameters
591 Parameters
592 ----------
592 ----------
593
593
594 f : callable
594 f : callable
595 function to be mapped
595 function to be mapped
596 *sequences: one or more sequences of matching length
596 *sequences: one or more sequences of matching length
597 the sequences to be distributed and passed to `f`
597 the sequences to be distributed and passed to `f`
598 block : bool
598 block : bool
599 whether to wait for the result or not [default self.block]
599 whether to wait for the result or not [default self.block]
600
600
601 Returns
601 Returns
602 -------
602 -------
603
603
604
604
605 If block=False
605 If block=False
606 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
606 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
607 An object like AsyncResult, but which reassembles the sequence of results
607 An object like AsyncResult, but which reassembles the sequence of results
608 into a single list. AsyncMapResults can be iterated through before all
608 into a single list. AsyncMapResults can be iterated through before all
609 results are complete.
609 results are complete.
610 else
610 else
611 A list, the result of ``map(f,*sequences)``
611 A list, the result of ``map(f,*sequences)``
612 """
612 """
613
613
614 block = kwargs.pop('block', self.block)
614 block = kwargs.pop('block', self.block)
615 for k in kwargs.keys():
615 for k in kwargs.keys():
616 if k not in ['block', 'track']:
616 if k not in ['block', 'track']:
617 raise TypeError("invalid keyword arg, %r"%k)
617 raise TypeError("invalid keyword arg, %r"%k)
618
618
619 assert len(sequences) > 0, "must have some sequences to map onto!"
619 assert len(sequences) > 0, "must have some sequences to map onto!"
620 pf = ParallelFunction(self, f, block=block, **kwargs)
620 pf = ParallelFunction(self, f, block=block, **kwargs)
621 return pf.map(*sequences)
621 return pf.map(*sequences)
622
622
623 @sync_results
623 @sync_results
624 @save_ids
624 @save_ids
625 def execute(self, code, silent=True, targets=None, block=None):
625 def execute(self, code, silent=True, targets=None, block=None):
626 """Executes `code` on `targets` in blocking or nonblocking manner.
626 """Executes `code` on `targets` in blocking or nonblocking manner.
627
627
628 ``execute`` is always `bound` (affects engine namespace)
628 ``execute`` is always `bound` (affects engine namespace)
629
629
630 Parameters
630 Parameters
631 ----------
631 ----------
632
632
633 code : str
633 code : str
634 the code string to be executed
634 the code string to be executed
635 block : bool
635 block : bool
636 whether or not to wait until done to return
636 whether or not to wait until done to return
637 default: self.block
637 default: self.block
638 """
638 """
639 block = self.block if block is None else block
639 block = self.block if block is None else block
640 targets = self.targets if targets is None else targets
640 targets = self.targets if targets is None else targets
641
641
642 _idents, _targets = self.client._build_targets(targets)
642 _idents, _targets = self.client._build_targets(targets)
643 msg_ids = []
643 msg_ids = []
644 trackers = []
644 trackers = []
645 for ident in _idents:
645 for ident in _idents:
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
646 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
647 msg_ids.append(msg['header']['msg_id'])
647 msg_ids.append(msg['header']['msg_id'])
648 if isinstance(targets, int):
648 if isinstance(targets, int):
649 msg_ids = msg_ids[0]
649 msg_ids = msg_ids[0]
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
650 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=_targets, owner=True)
651 if block:
651 if block:
652 try:
652 try:
653 ar.get()
653 ar.get()
654 except KeyboardInterrupt:
654 except KeyboardInterrupt:
655 pass
655 pass
656 return ar
656 return ar
657
657
658 def run(self, filename, targets=None, block=None):
658 def run(self, filename, targets=None, block=None):
659 """Execute contents of `filename` on my engine(s).
659 """Execute contents of `filename` on my engine(s).
660
660
661 This simply reads the contents of the file and calls `execute`.
661 This simply reads the contents of the file and calls `execute`.
662
662
663 Parameters
663 Parameters
664 ----------
664 ----------
665
665
666 filename : str
666 filename : str
667 The path to the file
667 The path to the file
668 targets : int/str/list of ints/strs
668 targets : int/str/list of ints/strs
669 the engines on which to execute
669 the engines on which to execute
670 default : all
670 default : all
671 block : bool
671 block : bool
672 whether or not to wait until done
672 whether or not to wait until done
673 default: self.block
673 default: self.block
674
674
675 """
675 """
676 with open(filename, 'r') as f:
676 with open(filename, 'r') as f:
677 # add newline in case of trailing indented whitespace
677 # add newline in case of trailing indented whitespace
678 # which will cause SyntaxError
678 # which will cause SyntaxError
679 code = f.read()+'\n'
679 code = f.read()+'\n'
680 return self.execute(code, block=block, targets=targets)
680 return self.execute(code, block=block, targets=targets)
681
681
682 def update(self, ns):
682 def update(self, ns):
683 """update remote namespace with dict `ns`
683 """update remote namespace with dict `ns`
684
684
685 See `push` for details.
685 See `push` for details.
686 """
686 """
687 return self.push(ns, block=self.block, track=self.track)
687 return self.push(ns, block=self.block, track=self.track)
688
688
689 def push(self, ns, targets=None, block=None, track=None):
689 def push(self, ns, targets=None, block=None, track=None):
690 """update remote namespace with dict `ns`
690 """update remote namespace with dict `ns`
691
691
692 Parameters
692 Parameters
693 ----------
693 ----------
694
694
695 ns : dict
695 ns : dict
696 dict of keys with which to update engine namespace(s)
696 dict of keys with which to update engine namespace(s)
697 block : bool [default : self.block]
697 block : bool [default : self.block]
698 whether to wait to be notified of engine receipt
698 whether to wait to be notified of engine receipt
699
699
700 """
700 """
701
701
702 block = block if block is not None else self.block
702 block = block if block is not None else self.block
703 track = track if track is not None else self.track
703 track = track if track is not None else self.track
704 targets = targets if targets is not None else self.targets
704 targets = targets if targets is not None else self.targets
705 # applier = self.apply_sync if block else self.apply_async
705 # applier = self.apply_sync if block else self.apply_async
706 if not isinstance(ns, dict):
706 if not isinstance(ns, dict):
707 raise TypeError("Must be a dict, not %s"%type(ns))
707 raise TypeError("Must be a dict, not %s"%type(ns))
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
708 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
709
709
710 def get(self, key_s):
710 def get(self, key_s):
711 """get object(s) by `key_s` from remote namespace
711 """get object(s) by `key_s` from remote namespace
712
712
713 see `pull` for details.
713 see `pull` for details.
714 """
714 """
715 # block = block if block is not None else self.block
715 # block = block if block is not None else self.block
716 return self.pull(key_s, block=True)
716 return self.pull(key_s, block=True)
717
717
718 def pull(self, names, targets=None, block=None):
718 def pull(self, names, targets=None, block=None):
719 """get object(s) by `name` from remote namespace
719 """get object(s) by `name` from remote namespace
720
720
721 will return one object if it is a key.
721 will return one object if it is a key.
722 can also take a list of keys, in which case it will return a list of objects.
722 can also take a list of keys, in which case it will return a list of objects.
723 """
723 """
724 block = block if block is not None else self.block
724 block = block if block is not None else self.block
725 targets = targets if targets is not None else self.targets
725 targets = targets if targets is not None else self.targets
726 applier = self.apply_sync if block else self.apply_async
726 applier = self.apply_sync if block else self.apply_async
727 if isinstance(names, string_types):
727 if isinstance(names, string_types):
728 pass
728 pass
729 elif isinstance(names, (list,tuple,set)):
729 elif isinstance(names, (list,tuple,set)):
730 for key in names:
730 for key in names:
731 if not isinstance(key, string_types):
731 if not isinstance(key, string_types):
732 raise TypeError("keys must be str, not type %r"%type(key))
732 raise TypeError("keys must be str, not type %r"%type(key))
733 else:
733 else:
734 raise TypeError("names must be strs, not %r"%names)
734 raise TypeError("names must be strs, not %r"%names)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
735 return self._really_apply(util._pull, (names,), block=block, targets=targets)
736
736
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
737 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
738 """
738 """
739 Partition a Python sequence and send the partitions to a set of engines.
739 Partition a Python sequence and send the partitions to a set of engines.
740 """
740 """
741 block = block if block is not None else self.block
741 block = block if block is not None else self.block
742 track = track if track is not None else self.track
742 track = track if track is not None else self.track
743 targets = targets if targets is not None else self.targets
743 targets = targets if targets is not None else self.targets
744
744
745 # construct integer ID list:
745 # construct integer ID list:
746 targets = self.client._build_targets(targets)[1]
746 targets = self.client._build_targets(targets)[1]
747
747
748 mapObject = Map.dists[dist]()
748 mapObject = Map.dists[dist]()
749 nparts = len(targets)
749 nparts = len(targets)
750 msg_ids = []
750 msg_ids = []
751 trackers = []
751 trackers = []
752 for index, engineid in enumerate(targets):
752 for index, engineid in enumerate(targets):
753 partition = mapObject.getPartition(seq, index, nparts)
753 partition = mapObject.getPartition(seq, index, nparts)
754 if flatten and len(partition) == 1:
754 if flatten and len(partition) == 1:
755 ns = {key: partition[0]}
755 ns = {key: partition[0]}
756 else:
756 else:
757 ns = {key: partition}
757 ns = {key: partition}
758 r = self.push(ns, block=False, track=track, targets=engineid)
758 r = self.push(ns, block=False, track=track, targets=engineid)
759 msg_ids.extend(r.msg_ids)
759 msg_ids.extend(r.msg_ids)
760 if track:
760 if track:
761 trackers.append(r._tracker)
761 trackers.append(r._tracker)
762
762
763 if track:
763 if track:
764 tracker = zmq.MessageTracker(*trackers)
764 tracker = zmq.MessageTracker(*trackers)
765 else:
765 else:
766 tracker = None
766 tracker = None
767
767
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
768 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets,
769 tracker=tracker, owner=True,
769 tracker=tracker, owner=True,
770 )
770 )
771 if block:
771 if block:
772 r.wait()
772 r.wait()
773 else:
773 else:
774 return r
774 return r
775
775
776 @sync_results
776 @sync_results
777 @save_ids
777 @save_ids
778 def gather(self, key, dist='b', targets=None, block=None):
778 def gather(self, key, dist='b', targets=None, block=None):
779 """
779 """
780 Gather a partitioned sequence on a set of engines as a single local seq.
780 Gather a partitioned sequence on a set of engines as a single local seq.
781 """
781 """
782 block = block if block is not None else self.block
782 block = block if block is not None else self.block
783 targets = targets if targets is not None else self.targets
783 targets = targets if targets is not None else self.targets
784 mapObject = Map.dists[dist]()
784 mapObject = Map.dists[dist]()
785 msg_ids = []
785 msg_ids = []
786
786
787 # construct integer ID list:
787 # construct integer ID list:
788 targets = self.client._build_targets(targets)[1]
788 targets = self.client._build_targets(targets)[1]
789
789
790 for index, engineid in enumerate(targets):
790 for index, engineid in enumerate(targets):
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
791 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
792
792
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
793 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
794
794
795 if block:
795 if block:
796 try:
796 try:
797 return r.get()
797 return r.get()
798 except KeyboardInterrupt:
798 except KeyboardInterrupt:
799 pass
799 pass
800 return r
800 return r
801
801
802 def __getitem__(self, key):
802 def __getitem__(self, key):
803 return self.get(key)
803 return self.get(key)
804
804
805 def __setitem__(self,key, value):
805 def __setitem__(self,key, value):
806 self.update({key:value})
806 self.update({key:value})
807
807
808 def clear(self, targets=None, block=None):
808 def clear(self, targets=None, block=None):
809 """Clear the remote namespaces on my engines."""
809 """Clear the remote namespaces on my engines."""
810 block = block if block is not None else self.block
810 block = block if block is not None else self.block
811 targets = targets if targets is not None else self.targets
811 targets = targets if targets is not None else self.targets
812 return self.client.clear(targets=targets, block=block)
812 return self.client.clear(targets=targets, block=block)
813
813
814 #----------------------------------------
814 #----------------------------------------
815 # activate for %px, %autopx, etc. magics
815 # activate for %px, %autopx, etc. magics
816 #----------------------------------------
816 #----------------------------------------
817
817
818 def activate(self, suffix=''):
818 def activate(self, suffix=''):
819 """Activate IPython magics associated with this View
819 """Activate IPython magics associated with this View
820
820
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
821 Defines the magics `%px, %autopx, %pxresult, %%px, %pxconfig`
822
822
823 Parameters
823 Parameters
824 ----------
824 ----------
825
825
826 suffix: str [default: '']
826 suffix: str [default: '']
827 The suffix, if any, for the magics. This allows you to have
827 The suffix, if any, for the magics. This allows you to have
828 multiple views associated with parallel magics at the same time.
828 multiple views associated with parallel magics at the same time.
829
829
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
830 e.g. ``rc[::2].activate(suffix='_even')`` will give you
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
831 the magics ``%px_even``, ``%pxresult_even``, etc. for running magics
832 on the even engines.
832 on the even engines.
833 """
833 """
834
834
835 from IPython.parallel.client.magics import ParallelMagics
835 from IPython.parallel.client.magics import ParallelMagics
836
836
837 try:
837 try:
838 # This is injected into __builtins__.
838 # This is injected into __builtins__.
839 ip = get_ipython()
839 ip = get_ipython()
840 except NameError:
840 except NameError:
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
841 print("The IPython parallel magics (%px, etc.) only work within IPython.")
842 return
842 return
843
843
844 M = ParallelMagics(ip, self, suffix)
844 M = ParallelMagics(ip, self, suffix)
845 ip.magics_manager.register(M)
845 ip.magics_manager.register(M)
846
846
847
847
848 @skip_doctest
848 @skip_doctest
849 class LoadBalancedView(View):
849 class LoadBalancedView(View):
850 """An load-balancing View that only executes via the Task scheduler.
850 """An load-balancing View that only executes via the Task scheduler.
851
851
852 Load-balanced views can be created with the client's `view` method:
852 Load-balanced views can be created with the client's `view` method:
853
853
854 >>> v = client.load_balanced_view()
854 >>> v = client.load_balanced_view()
855
855
856 or targets can be specified, to restrict the potential destinations:
856 or targets can be specified, to restrict the potential destinations:
857
857
858 >>> v = client.load_balanced_view([1,3])
858 >>> v = client.load_balanced_view([1,3])
859
859
860 which would restrict loadbalancing to between engines 1 and 3.
860 which would restrict loadbalancing to between engines 1 and 3.
861
861
862 """
862 """
863
863
864 follow=Any()
864 follow=Any()
865 after=Any()
865 after=Any()
866 timeout=CFloat()
866 timeout=CFloat()
867 retries = Integer(0)
867 retries = Integer(0)
868
868
869 _task_scheme = Any()
869 _task_scheme = Any()
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
870 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
871
871
872 def __init__(self, client=None, socket=None, **flags):
872 def __init__(self, client=None, socket=None, **flags):
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
873 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
874 self._task_scheme=client._task_scheme
874 self._task_scheme=client._task_scheme
875
875
876 def _validate_dependency(self, dep):
876 def _validate_dependency(self, dep):
877 """validate a dependency.
877 """validate a dependency.
878
878
879 For use in `set_flags`.
879 For use in `set_flags`.
880 """
880 """
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
881 if dep is None or isinstance(dep, string_types + (AsyncResult, Dependency)):
882 return True
882 return True
883 elif isinstance(dep, (list,set, tuple)):
883 elif isinstance(dep, (list,set, tuple)):
884 for d in dep:
884 for d in dep:
885 if not isinstance(d, string_types + (AsyncResult,)):
885 if not isinstance(d, string_types + (AsyncResult,)):
886 return False
886 return False
887 elif isinstance(dep, dict):
887 elif isinstance(dep, dict):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
888 if set(dep.keys()) != set(Dependency().as_dict().keys()):
889 return False
889 return False
890 if not isinstance(dep['msg_ids'], list):
890 if not isinstance(dep['msg_ids'], list):
891 return False
891 return False
892 for d in dep['msg_ids']:
892 for d in dep['msg_ids']:
893 if not isinstance(d, string_types):
893 if not isinstance(d, string_types):
894 return False
894 return False
895 else:
895 else:
896 return False
896 return False
897
897
898 return True
898 return True
899
899
900 def _render_dependency(self, dep):
900 def _render_dependency(self, dep):
901 """helper for building jsonable dependencies from various input forms."""
901 """helper for building jsonable dependencies from various input forms."""
902 if isinstance(dep, Dependency):
902 if isinstance(dep, Dependency):
903 return dep.as_dict()
903 return dep.as_dict()
904 elif isinstance(dep, AsyncResult):
904 elif isinstance(dep, AsyncResult):
905 return dep.msg_ids
905 return dep.msg_ids
906 elif dep is None:
906 elif dep is None:
907 return []
907 return []
908 else:
908 else:
909 # pass to Dependency constructor
909 # pass to Dependency constructor
910 return list(Dependency(dep))
910 return list(Dependency(dep))
911
911
912 def set_flags(self, **kwargs):
912 def set_flags(self, **kwargs):
913 """set my attribute flags by keyword.
913 """set my attribute flags by keyword.
914
914
915 A View is a wrapper for the Client's apply method, but with attributes
915 A View is a wrapper for the Client's apply method, but with attributes
916 that specify keyword arguments, those attributes can be set by keyword
916 that specify keyword arguments, those attributes can be set by keyword
917 argument with this method.
917 argument with this method.
918
918
919 Parameters
919 Parameters
920 ----------
920 ----------
921
921
922 block : bool
922 block : bool
923 whether to wait for results
923 whether to wait for results
924 track : bool
924 track : bool
925 whether to create a MessageTracker to allow the user to
925 whether to create a MessageTracker to allow the user to
926 safely edit after arrays and buffers during non-copying
926 safely edit after arrays and buffers during non-copying
927 sends.
927 sends.
928
928
929 after : Dependency or collection of msg_ids
929 after : Dependency or collection of msg_ids
930 Only for load-balanced execution (targets=None)
930 Only for load-balanced execution (targets=None)
931 Specify a list of msg_ids as a time-based dependency.
931 Specify a list of msg_ids as a time-based dependency.
932 This job will only be run *after* the dependencies
932 This job will only be run *after* the dependencies
933 have been met.
933 have been met.
934
934
935 follow : Dependency or collection of msg_ids
935 follow : Dependency or collection of msg_ids
936 Only for load-balanced execution (targets=None)
936 Only for load-balanced execution (targets=None)
937 Specify a list of msg_ids as a location-based dependency.
937 Specify a list of msg_ids as a location-based dependency.
938 This job will only be run on an engine where this dependency
938 This job will only be run on an engine where this dependency
939 is met.
939 is met.
940
940
941 timeout : float/int or None
941 timeout : float/int or None
942 Only for load-balanced execution (targets=None)
942 Only for load-balanced execution (targets=None)
943 Specify an amount of time (in seconds) for the scheduler to
943 Specify an amount of time (in seconds) for the scheduler to
944 wait for dependencies to be met before failing with a
944 wait for dependencies to be met before failing with a
945 DependencyTimeout.
945 DependencyTimeout.
946
946
947 retries : int
947 retries : int
948 Number of times a task will be retried on failure.
948 Number of times a task will be retried on failure.
949 """
949 """
950
950
951 super(LoadBalancedView, self).set_flags(**kwargs)
951 super(LoadBalancedView, self).set_flags(**kwargs)
952 for name in ('follow', 'after'):
952 for name in ('follow', 'after'):
953 if name in kwargs:
953 if name in kwargs:
954 value = kwargs[name]
954 value = kwargs[name]
955 if self._validate_dependency(value):
955 if self._validate_dependency(value):
956 setattr(self, name, value)
956 setattr(self, name, value)
957 else:
957 else:
958 raise ValueError("Invalid dependency: %r"%value)
958 raise ValueError("Invalid dependency: %r"%value)
959 if 'timeout' in kwargs:
959 if 'timeout' in kwargs:
960 t = kwargs['timeout']
960 t = kwargs['timeout']
961 if not isinstance(t, (int, float, type(None))):
961 if not isinstance(t, (int, float, type(None))):
962 if (not PY3) and (not isinstance(t, long)):
962 if (not PY3) and (not isinstance(t, long)):
963 raise TypeError("Invalid type for timeout: %r"%type(t))
963 raise TypeError("Invalid type for timeout: %r"%type(t))
964 if t is not None:
964 if t is not None:
965 if t < 0:
965 if t < 0:
966 raise ValueError("Invalid timeout: %s"%t)
966 raise ValueError("Invalid timeout: %s"%t)
967 self.timeout = t
967 self.timeout = t
968
968
969 @sync_results
969 @sync_results
970 @save_ids
970 @save_ids
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
971 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
972 after=None, follow=None, timeout=None,
972 after=None, follow=None, timeout=None,
973 targets=None, retries=None):
973 targets=None, retries=None):
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
974 """calls f(*args, **kwargs) on a remote engine, returning the result.
975
975
976 This method temporarily sets all of `apply`'s flags for a single call.
976 This method temporarily sets all of `apply`'s flags for a single call.
977
977
978 Parameters
978 Parameters
979 ----------
979 ----------
980
980
981 f : callable
981 f : callable
982
982
983 args : list [default: empty]
983 args : list [default: empty]
984
984
985 kwargs : dict [default: empty]
985 kwargs : dict [default: empty]
986
986
987 block : bool [default: self.block]
987 block : bool [default: self.block]
988 whether to block
988 whether to block
989 track : bool [default: self.track]
989 track : bool [default: self.track]
990 whether to ask zmq to track the message, for safe non-copying sends
990 whether to ask zmq to track the message, for safe non-copying sends
991
991
992 !!!!!! TODO: THE REST HERE !!!!
992 !!!!!! TODO: THE REST HERE !!!!
993
993
994 Returns
994 Returns
995 -------
995 -------
996
996
997 if self.block is False:
997 if self.block is False:
998 returns AsyncResult
998 returns AsyncResult
999 else:
999 else:
1000 returns actual result of f(*args, **kwargs) on the engine(s)
1000 returns actual result of f(*args, **kwargs) on the engine(s)
1001 This will be a list of self.targets is also a list (even length 1), or
1001 This will be a list of self.targets is also a list (even length 1), or
1002 the single result if self.targets is an integer engine id
1002 the single result if self.targets is an integer engine id
1003 """
1003 """
1004
1004
1005 # validate whether we can run
1005 # validate whether we can run
1006 if self._socket.closed:
1006 if self._socket.closed:
1007 msg = "Task farming is disabled"
1007 msg = "Task farming is disabled"
1008 if self._task_scheme == 'pure':
1008 if self._task_scheme == 'pure':
1009 msg += " because the pure ZMQ scheduler cannot handle"
1009 msg += " because the pure ZMQ scheduler cannot handle"
1010 msg += " disappearing engines."
1010 msg += " disappearing engines."
1011 raise RuntimeError(msg)
1011 raise RuntimeError(msg)
1012
1012
1013 if self._task_scheme == 'pure':
1013 if self._task_scheme == 'pure':
1014 # pure zmq scheme doesn't support extra features
1014 # pure zmq scheme doesn't support extra features
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1015 msg = "Pure ZMQ scheduler doesn't support the following flags:"
1016 "follow, after, retries, targets, timeout"
1016 "follow, after, retries, targets, timeout"
1017 if (follow or after or retries or targets or timeout):
1017 if (follow or after or retries or targets or timeout):
1018 # hard fail on Scheduler flags
1018 # hard fail on Scheduler flags
1019 raise RuntimeError(msg)
1019 raise RuntimeError(msg)
1020 if isinstance(f, dependent):
1020 if isinstance(f, dependent):
1021 # soft warn on functional dependencies
1021 # soft warn on functional dependencies
1022 warnings.warn(msg, RuntimeWarning)
1022 warnings.warn(msg, RuntimeWarning)
1023
1023
1024 # build args
1024 # build args
1025 args = [] if args is None else args
1025 args = [] if args is None else args
1026 kwargs = {} if kwargs is None else kwargs
1026 kwargs = {} if kwargs is None else kwargs
1027 block = self.block if block is None else block
1027 block = self.block if block is None else block
1028 track = self.track if track is None else track
1028 track = self.track if track is None else track
1029 after = self.after if after is None else after
1029 after = self.after if after is None else after
1030 retries = self.retries if retries is None else retries
1030 retries = self.retries if retries is None else retries
1031 follow = self.follow if follow is None else follow
1031 follow = self.follow if follow is None else follow
1032 timeout = self.timeout if timeout is None else timeout
1032 timeout = self.timeout if timeout is None else timeout
1033 targets = self.targets if targets is None else targets
1033 targets = self.targets if targets is None else targets
1034
1034
1035 if not isinstance(retries, int):
1035 if not isinstance(retries, int):
1036 raise TypeError('retries must be int, not %r'%type(retries))
1036 raise TypeError('retries must be int, not %r'%type(retries))
1037
1037
1038 if targets is None:
1038 if targets is None:
1039 idents = []
1039 idents = []
1040 else:
1040 else:
1041 idents = self.client._build_targets(targets)[0]
1041 idents = self.client._build_targets(targets)[0]
1042 # ensure *not* bytes
1042 # ensure *not* bytes
1043 idents = [ ident.decode() for ident in idents ]
1043 idents = [ ident.decode() for ident in idents ]
1044
1044
1045 after = self._render_dependency(after)
1045 after = self._render_dependency(after)
1046 follow = self._render_dependency(follow)
1046 follow = self._render_dependency(follow)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1047 metadata = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1048
1048
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1049 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1050 metadata=metadata)
1050 metadata=metadata)
1051 tracker = None if track is False else msg['tracker']
1051 tracker = None if track is False else msg['tracker']
1052
1052
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1053 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f),
1054 targets=None, tracker=tracker, owner=True,
1054 targets=None, tracker=tracker, owner=True,
1055 )
1055 )
1056 if block:
1056 if block:
1057 try:
1057 try:
1058 return ar.get()
1058 return ar.get()
1059 except KeyboardInterrupt:
1059 except KeyboardInterrupt:
1060 pass
1060 pass
1061 return ar
1061 return ar
1062
1062
1063 @sync_results
1063 @sync_results
1064 @save_ids
1064 @save_ids
1065 def map(self, f, *sequences, **kwargs):
1065 def map(self, f, *sequences, **kwargs):
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1066 """``view.map(f, *sequences, block=self.block, chunksize=1, ordered=True)`` => list|AsyncMapResult
1067
1067
1068 Parallel version of builtin `map`, load-balanced by this View.
1068 Parallel version of builtin `map`, load-balanced by this View.
1069
1069
1070 `block`, and `chunksize` can be specified by keyword only.
1070 `block`, and `chunksize` can be specified by keyword only.
1071
1071
1072 Each `chunksize` elements will be a separate task, and will be
1072 Each `chunksize` elements will be a separate task, and will be
1073 load-balanced. This lets individual elements be available for iteration
1073 load-balanced. This lets individual elements be available for iteration
1074 as soon as they arrive.
1074 as soon as they arrive.
1075
1075
1076 Parameters
1076 Parameters
1077 ----------
1077 ----------
1078
1078
1079 f : callable
1079 f : callable
1080 function to be mapped
1080 function to be mapped
1081 *sequences: one or more sequences of matching length
1081 *sequences: one or more sequences of matching length
1082 the sequences to be distributed and passed to `f`
1082 the sequences to be distributed and passed to `f`
1083 block : bool [default self.block]
1083 block : bool [default self.block]
1084 whether to wait for the result or not
1084 whether to wait for the result or not
1085 track : bool
1085 track : bool
1086 whether to create a MessageTracker to allow the user to
1086 whether to create a MessageTracker to allow the user to
1087 safely edit after arrays and buffers during non-copying
1087 safely edit after arrays and buffers during non-copying
1088 sends.
1088 sends.
1089 chunksize : int [default 1]
1089 chunksize : int [default 1]
1090 how many elements should be in each task.
1090 how many elements should be in each task.
1091 ordered : bool [default True]
1091 ordered : bool [default True]
1092 Whether the results should be gathered as they arrive, or enforce
1092 Whether the results should be gathered as they arrive, or enforce
1093 the order of submission.
1093 the order of submission.
1094
1094
1095 Only applies when iterating through AsyncMapResult as results arrive.
1095 Only applies when iterating through AsyncMapResult as results arrive.
1096 Has no effect when block=True.
1096 Has no effect when block=True.
1097
1097
1098 Returns
1098 Returns
1099 -------
1099 -------
1100
1100
1101 if block=False
1101 if block=False
1102 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
1102 An :class:`~ipython_parallel.client.asyncresult.AsyncMapResult` instance.
1103 An object like AsyncResult, but which reassembles the sequence of results
1103 An object like AsyncResult, but which reassembles the sequence of results
1104 into a single list. AsyncMapResults can be iterated through before all
1104 into a single list. AsyncMapResults can be iterated through before all
1105 results are complete.
1105 results are complete.
1106 else
1106 else
1107 A list, the result of ``map(f,*sequences)``
1107 A list, the result of ``map(f,*sequences)``
1108 """
1108 """
1109
1109
1110 # default
1110 # default
1111 block = kwargs.get('block', self.block)
1111 block = kwargs.get('block', self.block)
1112 chunksize = kwargs.get('chunksize', 1)
1112 chunksize = kwargs.get('chunksize', 1)
1113 ordered = kwargs.get('ordered', True)
1113 ordered = kwargs.get('ordered', True)
1114
1114
1115 keyset = set(kwargs.keys())
1115 keyset = set(kwargs.keys())
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1116 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1117 if extra_keys:
1117 if extra_keys:
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1118 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1119
1119
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1120 assert len(sequences) > 0, "must have some sequences to map onto!"
1121
1121
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1122 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1123 return pf.map(*sequences)
1123 return pf.map(*sequences)
1124
1124
1125 __all__ = ['LoadBalancedView', 'DirectView']
1125 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,229 +1,229 b''
1 """Dependency utilities
1 """Dependency utilities
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2013 The IPython Development Team
8 # Copyright (C) 2013 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 from types import ModuleType
14 from types import ModuleType
15
15
16 from ipython_parallel.client.asyncresult import AsyncResult
16 from ipython_parallel.client.asyncresult import AsyncResult
17 from ipython_parallel.error import UnmetDependency
17 from ipython_parallel.error import UnmetDependency
18 from ipython_parallel.util import interactive
18 from ipython_parallel.util import interactive
19 from IPython.utils import py3compat
19 from IPython.utils import py3compat
20 from IPython.utils.py3compat import string_types
20 from IPython.utils.py3compat import string_types
21 from IPython.utils.pickleutil import can, uncan
21 from ipython_kernel.pickleutil import can, uncan
22
22
23 class depend(object):
23 class depend(object):
24 """Dependency decorator, for use with tasks.
24 """Dependency decorator, for use with tasks.
25
25
26 `@depend` lets you define a function for engine dependencies
26 `@depend` lets you define a function for engine dependencies
27 just like you use `apply` for tasks.
27 just like you use `apply` for tasks.
28
28
29
29
30 Examples
30 Examples
31 --------
31 --------
32 ::
32 ::
33
33
34 @depend(df, a,b, c=5)
34 @depend(df, a,b, c=5)
35 def f(m,n,p)
35 def f(m,n,p)
36
36
37 view.apply(f, 1,2,3)
37 view.apply(f, 1,2,3)
38
38
39 will call df(a,b,c=5) on the engine, and if it returns False or
39 will call df(a,b,c=5) on the engine, and if it returns False or
40 raises an UnmetDependency error, then the task will not be run
40 raises an UnmetDependency error, then the task will not be run
41 and another engine will be tried.
41 and another engine will be tried.
42 """
42 """
43 def __init__(self, _wrapped_f, *args, **kwargs):
43 def __init__(self, _wrapped_f, *args, **kwargs):
44 self.f = _wrapped_f
44 self.f = _wrapped_f
45 self.args = args
45 self.args = args
46 self.kwargs = kwargs
46 self.kwargs = kwargs
47
47
48 def __call__(self, f):
48 def __call__(self, f):
49 return dependent(f, self.f, *self.args, **self.kwargs)
49 return dependent(f, self.f, *self.args, **self.kwargs)
50
50
51 class dependent(object):
51 class dependent(object):
52 """A function that depends on another function.
52 """A function that depends on another function.
53 This is an object to prevent the closure used
53 This is an object to prevent the closure used
54 in traditional decorators, which are not picklable.
54 in traditional decorators, which are not picklable.
55 """
55 """
56
56
57 def __init__(self, _wrapped_f, _wrapped_df, *dargs, **dkwargs):
57 def __init__(self, _wrapped_f, _wrapped_df, *dargs, **dkwargs):
58 self.f = _wrapped_f
58 self.f = _wrapped_f
59 name = getattr(_wrapped_f, '__name__', 'f')
59 name = getattr(_wrapped_f, '__name__', 'f')
60 if py3compat.PY3:
60 if py3compat.PY3:
61 self.__name__ = name
61 self.__name__ = name
62 else:
62 else:
63 self.func_name = name
63 self.func_name = name
64 self.df = _wrapped_df
64 self.df = _wrapped_df
65 self.dargs = dargs
65 self.dargs = dargs
66 self.dkwargs = dkwargs
66 self.dkwargs = dkwargs
67
67
68 def check_dependency(self):
68 def check_dependency(self):
69 if self.df(*self.dargs, **self.dkwargs) is False:
69 if self.df(*self.dargs, **self.dkwargs) is False:
70 raise UnmetDependency()
70 raise UnmetDependency()
71
71
72 def __call__(self, *args, **kwargs):
72 def __call__(self, *args, **kwargs):
73 return self.f(*args, **kwargs)
73 return self.f(*args, **kwargs)
74
74
75 if not py3compat.PY3:
75 if not py3compat.PY3:
76 @property
76 @property
77 def __name__(self):
77 def __name__(self):
78 return self.func_name
78 return self.func_name
79
79
80 @interactive
80 @interactive
81 def _require(*modules, **mapping):
81 def _require(*modules, **mapping):
82 """Helper for @require decorator."""
82 """Helper for @require decorator."""
83 from ipython_parallel.error import UnmetDependency
83 from ipython_parallel.error import UnmetDependency
84 from IPython.utils.pickleutil import uncan
84 from ipython_kernel.pickleutil import uncan
85 user_ns = globals()
85 user_ns = globals()
86 for name in modules:
86 for name in modules:
87 try:
87 try:
88 exec('import %s' % name, user_ns)
88 exec('import %s' % name, user_ns)
89 except ImportError:
89 except ImportError:
90 raise UnmetDependency(name)
90 raise UnmetDependency(name)
91
91
92 for name, cobj in mapping.items():
92 for name, cobj in mapping.items():
93 user_ns[name] = uncan(cobj, user_ns)
93 user_ns[name] = uncan(cobj, user_ns)
94 return True
94 return True
95
95
def require(*objects, **mapping):
    """Simple decorator for requiring local objects and modules to be available
    when the decorated function is called on the engine.

    Modules specified by name or passed directly will be imported
    prior to calling the decorated function.

    Objects other than modules will be pushed as a part of the task.
    Functions can be passed positionally,
    and will be pushed to the engine with their __name__.
    Other objects can be passed by keyword arg.

    Examples::

        In [1]: @require('numpy')
           ...: def norm(a):
           ...:     return numpy.linalg.norm(a,2)

        In [2]: foo = lambda x: x*x
        In [3]: @require(foo)
           ...: def bar(a):
           ...:     return foo(1-a)
    """
    module_names = []
    for item in objects:
        # a module object stands in for its own import name
        if isinstance(item, ModuleType):
            item = item.__name__

        if isinstance(item, string_types):
            module_names.append(item)
        elif hasattr(item, '__name__'):
            # functions (and anything else with a __name__) are pushed
            # into the engine namespace under that name
            mapping[item.__name__] = item
        else:
            raise TypeError("Objects other than modules and functions "
                "must be passed by kwarg, but got: %s" % type(item)
            )

    # can every pushed object so it survives the trip to the engine
    for key in list(mapping):
        mapping[key] = can(mapping[key])
    return depend(_require, *module_names, **mapping)
136
136
class Dependency(set):
    """An object for representing a set of msg_id dependencies.

    Subclassed from set().

    Parameters
    ----------
    dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
        The msg_ids to depend on
    all : bool [default True]
        Whether the dependency should be considered met when *all* depending tasks have completed
        or only when *any* have been completed.
    success : bool [default True]
        Whether to consider successes as fulfilling dependencies.
    failure : bool [default False]
        Whether to consider failures as fulfilling dependencies.

    If `all=success=True` and `failure=False`, then the task will fail with an ImpossibleDependency
    as soon as the first depended-upon task fails.
    """

    # class-level defaults, kept in sync with __init__ and the docstring
    # (previously `failure = True` here, contradicting the documented
    # default of False)
    all = True
    success = True
    failure = False

    def __init__(self, dependencies=None, all=True, success=True, failure=False):
        # avoid the mutable-default-argument pitfall; None means "no deps"
        if dependencies is None:
            dependencies = []
        if isinstance(dependencies, dict):
            # load from dict (e.g. output of as_dict()); fall back to the
            # caller-supplied keyword values for any missing keys
            # (previously 'all' fell back to a literal True instead of the
            # `all` parameter, inconsistent with success/failure)
            all = dependencies.get('all', all)
            success = dependencies.get('success', success)
            failure = dependencies.get('failure', failure)
            dependencies = dependencies.get('dependencies', [])
        ids = []

        # extract ids from various sources:
        if isinstance(dependencies, string_types + (AsyncResult,)):
            dependencies = [dependencies]
        for d in dependencies:
            if isinstance(d, string_types):
                ids.append(d)
            elif isinstance(d, AsyncResult):
                ids.extend(d.msg_ids)
            else:
                raise TypeError("invalid dependency type: %r"%type(d))

        set.__init__(self, ids)
        self.all = all
        if not (success or failure):
            raise ValueError("Must depend on at least one of successes or failures!")
        self.success = success
        self.failure = failure

    def check(self, completed, failed=None):
        """check whether our dependencies have been met."""
        if len(self) == 0:
            # empty dependency set is trivially met
            return True
        against = set()
        if self.success:
            against = completed
        if failed is not None and self.failure:
            against = against.union(failed)
        if self.all:
            # every dependency must appear in the fulfilling set
            return self.issubset(against)
        else:
            # any single fulfilled dependency suffices
            return not self.isdisjoint(against)

    def unreachable(self, completed, failed=None):
        """return whether this dependency has become impossible."""
        if len(self) == 0:
            return False
        # here `against` is the set of *disqualifying* outcomes:
        # outcomes we do NOT accept as fulfilling
        against = set()
        if not self.success:
            against = completed
        if failed is not None and not self.failure:
            against = against.union(failed)
        if self.all:
            # all-mode: one disqualified dependency makes the whole set impossible
            return not self.isdisjoint(against)
        else:
            # any-mode: impossible only when every dependency is disqualified
            return self.issubset(against)


    def as_dict(self):
        """Represent this dependency as a dict. For json compatibility."""
        return dict(
            dependencies=list(self),
            all=self.all,
            success=self.success,
            failure=self.failure
        )
226
226
227
227
# public API of this module
__all__ = ['depend', 'require', 'dependent', 'Dependency']
229
229
@@ -1,136 +1,136 b''
1 """Tests for dependency.py
1 """Tests for dependency.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 __docformat__ = "restructuredtext en"
8 __docformat__ = "restructuredtext en"
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Copyright (C) 2011 The IPython Development Team
11 # Copyright (C) 2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16
16
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-------------------------------------------------------------------------------
19 #-------------------------------------------------------------------------------
20
20
21 # import
21 # import
22 import os
22 import os
23
23
24 from IPython.utils.pickleutil import can, uncan
24 from ipython_kernel.pickleutil import can, uncan
25
25
26 import ipython_parallel as pmod
26 import ipython_parallel as pmod
27 from ipython_parallel.util import interactive
27 from ipython_parallel.util import interactive
28
28
29 from ipython_parallel.tests import add_engines
29 from ipython_parallel.tests import add_engines
30 from .clienttest import ClusterTestCase
30 from .clienttest import ClusterTestCase
31
31
def setup():
    # module-level test fixture: make sure an engine is available before
    # any test here runs (total=True presumably counts engines across the
    # whole cluster -- confirm against add_engines)
    add_engines(1, total=True)
34
34
@pmod.require('time')
def wait(n):
    # NOTE: `time` is not imported in this module -- the @require
    # decorator imports it into the engine namespace, so this function
    # only works when executed remotely.
    time.sleep(n)
    return n
39
39
@pmod.interactive
def func(x):
    # simple square, used below as the object pushed via @require(func)
    # and @require(foo=func)
    return x*x
43
43
# msg_id fixtures for the Dependency tests: ten string ids, with the
# even half treated as "completed" and the odd half as "failed"
mixed = list(map(str, range(10)))
completed = list(map(str, range(0,10,2)))
failed = list(map(str, range(1,10,2)))
47
47
class DependencyTest(ClusterTestCase):
    """Integration tests for Dependency and the @require decorator.

    Runs against a live cluster (started by setup() above). The
    succeeded/failed sets built in setUp simulate completed and failed
    msg_ids for exercising Dependency.check()/unreachable().
    """

    def setUp(self):
        ClusterTestCase.setUp(self)
        # namespace into which canned functions are uncanned, mimicking
        # an engine's user namespace
        self.user_ns = {'__builtins__' : __builtins__}
        self.view = self.client.load_balanced_view()
        self.dview = self.client[-1]
        # disjoint msg_id sets: evens succeeded, odds failed
        self.succeeded = set(map(str, range(0,25,2)))
        self.failed = set(map(str, range(1,25,2)))

    def assertMet(self, dep):
        # dep.check(...) True means the dependency is satisfied
        self.assertTrue(dep.check(self.succeeded, self.failed), "Dependency should be met")

    def assertUnmet(self, dep):
        self.assertFalse(dep.check(self.succeeded, self.failed), "Dependency should not be met")

    def assertUnreachable(self, dep):
        # dep.unreachable(...) True means it can never be satisfied
        self.assertTrue(dep.unreachable(self.succeeded, self.failed), "Dependency should be unreachable")

    def assertReachable(self, dep):
        self.assertFalse(dep.unreachable(self.succeeded, self.failed), "Dependency should be reachable")

    def cancan(self, f):
        """decorator to pass through canning into self.user_ns"""
        return uncan(can(f), self.user_ns)

    def test_require_imports(self):
        """test that @require imports names"""
        @self.cancan
        @pmod.require('base64')
        @interactive
        def encode(arg):
            # `base64` is only available because @require injects it
            return base64.b64encode(arg)
        # must pass through canning to properly connect namespaces
        self.assertEqual(encode(b'foo'), b'Zm9v')

    def test_success_only(self):
        # mixed ids: unmet/unreachable in all-mode, met in any-mode
        dep = pmod.Dependency(mixed, success=True, failure=False)
        self.assertUnmet(dep)
        self.assertUnreachable(dep)
        dep.all=False
        self.assertMet(dep)
        self.assertReachable(dep)
        # all-completed ids: met in both modes
        dep = pmod.Dependency(completed, success=True, failure=False)
        self.assertMet(dep)
        self.assertReachable(dep)
        dep.all=False
        self.assertMet(dep)
        self.assertReachable(dep)

    def test_failure_only(self):
        # mirror of test_success_only with failures fulfilling instead
        dep = pmod.Dependency(mixed, success=False, failure=True)
        self.assertUnmet(dep)
        self.assertUnreachable(dep)
        dep.all=False
        self.assertMet(dep)
        self.assertReachable(dep)
        dep = pmod.Dependency(completed, success=False, failure=True)
        self.assertUnmet(dep)
        self.assertUnreachable(dep)
        dep.all=False
        self.assertUnmet(dep)
        self.assertUnreachable(dep)

    def test_require_function(self):
        # bar references `func` without requiring it -> remote NameError;
        # bar2 requires it -> succeeds

        @pmod.interactive
        def bar(a):
            return func(a)

        @pmod.require(func)
        @pmod.interactive
        def bar2(a):
            return func(a)

        self.client[:].clear()
        self.assertRaisesRemote(NameError, self.view.apply_sync, bar, 5)
        ar = self.view.apply_async(bar2, 5)
        self.assertEqual(ar.get(5), func(5))

    def test_require_object(self):
        # non-positional requirement: push func under the name `foo`

        @pmod.require(foo=func)
        @pmod.interactive
        def bar(a):
            return foo(a)

        ar = self.view.apply_async(bar, 5)
        self.assertEqual(ar.get(5), func(5))
@@ -1,884 +1,889 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
24 try:
24 try:
25 import cPickle
25 import cPickle
26 pickle = cPickle
26 pickle = cPickle
27 except:
27 except:
28 cPickle = None
28 cPickle = None
29 import pickle
29 import pickle
30
30
31 try:
31 try:
32 # py3
33 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
34 except AttributeError:
35 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
36
37 try:
32 # We are using compare_digest to limit the surface of timing attacks
38 # We are using compare_digest to limit the surface of timing attacks
33 from hmac import compare_digest
39 from hmac import compare_digest
34 except ImportError:
40 except ImportError:
35 # Python < 2.7.7: When digests don't match no feedback is provided,
41 # Python < 2.7.7: When digests don't match no feedback is provided,
36 # limiting the surface of attack
42 # limiting the surface of attack
37 def compare_digest(a,b): return a == b
43 def compare_digest(a,b): return a == b
38
44
39 import zmq
45 import zmq
40 from zmq.utils import jsonapi
46 from zmq.utils import jsonapi
41 from zmq.eventloop.ioloop import IOLoop
47 from zmq.eventloop.ioloop import IOLoop
42 from zmq.eventloop.zmqstream import ZMQStream
48 from zmq.eventloop.zmqstream import ZMQStream
43
49
44 from IPython.core.release import kernel_protocol_version
50 from IPython.core.release import kernel_protocol_version
45 from IPython.config.configurable import Configurable, LoggingConfigurable
51 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.utils import io
52 from IPython.utils import io
47 from IPython.utils.importstring import import_item
53 from IPython.utils.importstring import import_item
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
54 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
55 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
50 iteritems)
56 iteritems)
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
57 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode, Dict, Integer,
58 DottedObjectName, CUnicode, Dict, Integer,
53 TraitError,
59 TraitError,
54 )
60 )
55 from IPython.utils.pickleutil import PICKLE_PROTOCOL
56 from jupyter_client.adapter import adapt
61 from jupyter_client.adapter import adapt
57
62
58 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
59 # utility functions
64 # utility functions
60 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
61
66
def squash_unicode(obj):
    """Recursively coerce unicode back to utf8-encoded bytestrings.

    Dicts and lists are modified in place (keys are squashed too);
    a bare unicode string is returned re-encoded. Anything else is
    returned unchanged.
    """
    if isinstance(obj, dict):
        # iterate over a snapshot of the keys: the loop body may pop a
        # unicode key and insert its squashed replacement, and mutating
        # a dict while iterating its live keys() view is an error on py3
        for key in list(obj.keys()):
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode_type):
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        for i,v in enumerate(obj):
            obj[i] = squash_unicode(v)
    elif isinstance(obj, unicode_type):
        obj = obj.encode('utf8')
    return obj
75
80
#-----------------------------------------------------------------------------
# globals and defaults
#-----------------------------------------------------------------------------

# default values for the buffer-extraction thresholds:
MAX_ITEMS = 64
MAX_BYTES = 1024

# JSON packer/unpacker pair:
# ISO8601-ify datetime objects (via date_default)
# allow unicode (ensure_ascii=False)
# disallow nan, because it's not actually valid JSON
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
    ensure_ascii=False, allow_nan=False,
)
json_unpacker = lambda s: jsonapi.loads(s)

# pickle packer/unpacker pair; dates are squashed to strings first
pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
pickle_unpacker = pickle.loads

# JSON is the default wire format
default_packer = json_packer
default_unpacker = json_unpacker

# delimiter separating zmq routing identities from message frames
DELIM = b"<IDS|MSG>"
# singleton dummy tracker, which will always report as done
DONE = zmq.MessageTracker()

#-----------------------------------------------------------------------------
# Mixin tools for apps that use Sessions
#-----------------------------------------------------------------------------

# command-line aliases mapping short option names to Session config traits
session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)

# command-line flags toggling HMAC message authentication
session_flags  = {
    'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
121
126
def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    warnings.warn("default_secure is deprecated", DeprecationWarning)
    # respect an explicit key/keyfile if the user already configured one
    already_keyed = 'Session' in cfg and (
        'key' in cfg.Session or 'keyfile' in cfg.Session
    )
    if already_keyed:
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
134
139
135
140
136 #-----------------------------------------------------------------------------
141 #-----------------------------------------------------------------------------
137 # Classes
142 # Classes
138 #-----------------------------------------------------------------------------
143 #-----------------------------------------------------------------------------
139
144
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    logname = Unicode('')
    def _logname_changed(self, name, old, new):
        # keep self.log pointed at the logger matching the current logname
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # share the process-global zmq Context by default
        return zmq.Context.instance()

    session = Instance('jupyter_client.session.Session',
                       allow_none=True)

    loop = Instance('zmq.eventloop.ioloop.IOLoop')
    def _loop_default(self):
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session, forwarding our kwargs
            self.session = Session(**kwargs)
167
172
168
173
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        # recursively promote nested dicts so attribute access works
        # at every level of the message
        for key, value in iteritems(dict(msg_dict)):
            self.__dict__[key] = Message(value) if isinstance(value, dict) else value

    def __iter__(self):
        # yielding (key, value) pairs lets dict(msg_obj) work out of the box
        return iter(iteritems(self.__dict__))

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
197
202
198
203
def msg_header(msg_id, msg_type, username, session):
    """Build a message header dict from the given identifiers.

    Returns the four arguments under their own names, plus a fresh
    ``date`` timestamp and the kernel protocol ``version``.
    """
    # explicit dict instead of the old `return locals()`, which fragilely
    # depended on the body never introducing another local name
    return {
        'msg_id': msg_id,
        'msg_type': msg_type,
        'username': username,
        'session': session,
        'date': datetime.now(),
        'version': kernel_protocol_version,
    }
203
208
def extract_header(msg_or_header):
    """Given a message or header, return the header."""
    if not msg_or_header:
        # None / empty input -> empty header
        return {}
    if 'header' in msg_or_header:
        # a full message: pull out its header
        header = msg_or_header['header']
    elif 'msg_id' in msg_or_header:
        # already just a header
        header = msg_or_header
    else:
        # neither a message nor a header
        raise KeyError('msg_id')
    # normalize header-like objects to a plain dict
    return header if isinstance(header, dict) else dict(header)
222
227
223 class Session(Configurable):
228 class Session(Configurable):
224 """Object for handling serialization and sending of messages.
229 """Object for handling serialization and sending of messages.
225
230
226 The Session object handles building messages and sending them
231 The Session object handles building messages and sending them
227 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
232 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
228 other over the network via Session objects, and only need to work with the
233 other over the network via Session objects, and only need to work with the
229 dict-based IPython message spec. The Session will handle
234 dict-based IPython message spec. The Session will handle
230 serialization/deserialization, security, and metadata.
235 serialization/deserialization, security, and metadata.
231
236
232 Sessions support configurable serialization via packer/unpacker traits,
237 Sessions support configurable serialization via packer/unpacker traits,
233 and signing with HMAC digests via the key/keyfile traits.
238 and signing with HMAC digests via the key/keyfile traits.
234
239
235 Parameters
240 Parameters
236 ----------
241 ----------
237
242
238 debug : bool
243 debug : bool
239 whether to trigger extra debugging statements
244 whether to trigger extra debugging statements
240 packer/unpacker : str : 'json', 'pickle' or import_string
245 packer/unpacker : str : 'json', 'pickle' or import_string
241 importstrings for methods to serialize message parts. If just
246 importstrings for methods to serialize message parts. If just
242 'json' or 'pickle', predefined JSON and pickle packers will be used.
247 'json' or 'pickle', predefined JSON and pickle packers will be used.
243 Otherwise, the entire importstring must be used.
248 Otherwise, the entire importstring must be used.
244
249
245 The functions must accept at least valid JSON input, and output *bytes*.
250 The functions must accept at least valid JSON input, and output *bytes*.
246
251
247 For example, to use msgpack:
252 For example, to use msgpack:
248 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
253 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
249 pack/unpack : callables
254 pack/unpack : callables
250 You can also set the pack/unpack callables for serialization directly.
255 You can also set the pack/unpack callables for serialization directly.
251 session : bytes
256 session : bytes
252 the ID of this Session object. The default is to generate a new UUID.
257 the ID of this Session object. The default is to generate a new UUID.
253 username : unicode
258 username : unicode
254 username added to message headers. The default is to ask the OS.
259 username added to message headers. The default is to ask the OS.
255 key : bytes
260 key : bytes
256 The key used to initialize an HMAC signature. If unset, messages
261 The key used to initialize an HMAC signature. If unset, messages
257 will not be signed or checked.
262 will not be signed or checked.
258 keyfile : filepath
263 keyfile : filepath
259 The file containing a key. If this is set, `key` will be initialized
264 The file containing a key. If this is set, `key` will be initialized
260 to the contents of the file.
265 to the contents of the file.
261
266
262 """
267 """
263
268
264 debug=Bool(False, config=True, help="""Debug output in the Session""")
269 debug=Bool(False, config=True, help="""Debug output in the Session""")
265
270
266 packer = DottedObjectName('json',config=True,
271 packer = DottedObjectName('json',config=True,
267 help="""The name of the packer for serializing messages.
272 help="""The name of the packer for serializing messages.
268 Should be one of 'json', 'pickle', or an import name
273 Should be one of 'json', 'pickle', or an import name
269 for a custom callable serializer.""")
274 for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        # Keep pack/unpack callables in sync with the `packer` trait.
        # 'json'/'pickle' select the predefined pairs and mirror the
        # choice onto the `unpacker` trait; any other value is treated
        # as an import string for a custom pack callable. NOTE: in the
        # custom case only self.pack is set here -- the matching unpack
        # must be supplied via the `unpacker` trait.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
            self.unpacker = new
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
            self.unpacker = new
        else:
            self.pack = import_item(str(new))
281
286
    # Dotted import name for the deserializer; only consulted when packer is
    # a custom callable ('json'/'pickle' set both directions at once).
    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
285 def _unpacker_changed(self, name, old, new):
290 def _unpacker_changed(self, name, old, new):
286 if new.lower() == 'json':
291 if new.lower() == 'json':
287 self.pack = json_packer
292 self.pack = json_packer
288 self.unpack = json_unpacker
293 self.unpack = json_unpacker
289 self.packer = new
294 self.packer = new
290 elif new.lower() == 'pickle':
295 elif new.lower() == 'pickle':
291 self.pack = pickle_packer
296 self.pack = pickle_packer
292 self.unpack = pickle_unpacker
297 self.unpack = pickle_unpacker
293 self.packer = new
298 self.packer = new
294 else:
299 else:
295 self.unpack = import_item(str(new))
300 self.unpack = import_item(str(new))
296
301
    # Unicode session id; the bytes mirror lives in ``bsession``.
    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        # Generate a fresh UUID and eagerly cache its ascii-bytes form,
        # so bsession is always defined once session has been read.
        u = unicode_type(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u
303
308
    def _session_changed(self, name, old, new):
        # Keep the bytes mirror in sync whenever the unicode id is reassigned.
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')
309
314
    # Defaults to $USER from the environment (falls back to 'username').
    username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
        help="""Username for the Session. Default is your system username.""",
        config=True)

    # Copied into every message built by self.msg(); per-call metadata is
    # merged on top of this dict.
    metadata = Dict({}, config=True,
        help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

    # if 0, no adapting to do.
    adapt_version = Integer(0)
319
324
    # message signature related traits:

    key = CBytes(config=True,
        help="""execution key, for signing messages.""")
    def _key_default(self):
        # A random per-process key by default, so signing is on out of the box.
        return str_to_bytes(str(uuid.uuid4()))

    def _key_changed(self):
        # Rebuild the cached HMAC object whenever the key changes.
        self._new_auth()
329
334
    signature_scheme = Unicode('hmac-sha256', config=True,
        help="""The digest scheme used to construct the message signatures.
        Must have the form 'hmac-HASH'.""")
    def _signature_scheme_changed(self, name, old, new):
        # Validate the 'hmac-HASH' form and resolve HASH against hashlib,
        # then rebuild the HMAC object with the new digest.
        if not new.startswith('hmac-'):
            raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
        hash_name = new.split('-', 1)[1]
        try:
            self.digest_mod = getattr(hashlib, hash_name)
        except AttributeError:
            raise TraitError("hashlib has no such attribute: %s" % hash_name)
        self._new_auth()
342
347
    # Hash constructor used for HMAC signing (set from signature_scheme).
    digest_mod = Any()
    def _digest_mod_default(self):
        # matches the default 'hmac-sha256' signature_scheme
        return hashlib.sha256

    # Cached HMAC object; None disables signing (see sign()).
    auth = Instance(hmac.HMAC, allow_none=True)
348
353
349 def _new_auth(self):
354 def _new_auth(self):
350 if self.key:
355 if self.key:
351 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
356 self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
352 else:
357 else:
353 self.auth = None
358 self.auth = None
354
359
    # Seen signatures, used to reject replayed messages (see _add_digest).
    digest_history = Set()
    digest_history_size = Integer(2**16, config=True,
        help="""The maximum number of digests to remember.

        The digest history will be culled when it exceeds this value.
        """
    )
362
367
    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        # Load the signing key from the file; whitespace/newline is stripped.
        with open(new, 'rb') as f:
            self.key = f.read().strip()

    # for protecting against sends from forks
    pid = Integer()
371
376
    # serialization traits:

    pack = Any(default_packer) # the actual packer function
375 def _pack_changed(self, name, old, new):
380 def _pack_changed(self, name, old, new):
376 if not callable(new):
381 if not callable(new):
377 raise TypeError("packer must be callable, not %s"%type(new))
382 raise TypeError("packer must be callable, not %s"%type(new))
378
383
    unpack = Any(default_unpacker) # the actual unpacker function
380 def _unpack_changed(self, name, old, new):
385 def _unpack_changed(self, name, old, new):
381 # unpacker is not checked - it is assumed to be
386 # unpacker is not checked - it is assumed to be
382 if not callable(new):
387 if not callable(new):
383 raise TypeError("unpacker must be callable, not %s"%type(new))
388 raise TypeError("unpacker must be callable, not %s"%type(new))
384
389
    # thresholds:
    copy_threshold = Integer(2**16, config=True,
        help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
    buffer_threshold = Integer(MAX_BYTES, config=True,
        help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
    item_threshold = Integer(MAX_ITEMS, config=True,
        help="""The maximum number of items for a container to be introspected for custom serialization.
        Containers larger than this are pickled outright.
        """
    )
395
400
396
401
    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        signature_scheme : str
            The message digest scheme. Currently must be of the form 'hmac-HASH',
            where 'HASH' is a hashing function available in Python's hashlib.
            The default is 'hmac-sha256'.
            This is ignored if 'key' is empty.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        # validate pack/unpack once up front (may wrap them for datetime support)
        self._check_packers()
        # pre-packed empty dict, reused when a message has no content
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
        # remember the creating process, so send() can refuse sends from forks
        self.pid = os.getpid()
        self._new_auth()
444
449
445 @property
450 @property
446 def msg_id(self):
451 def msg_id(self):
447 """always return new uuid"""
452 """always return new uuid"""
448 return str(uuid.uuid4())
453 return str(uuid.uuid4())
449
454
450 def _check_packers(self):
455 def _check_packers(self):
451 """check packers for datetime support."""
456 """check packers for datetime support."""
452 pack = self.pack
457 pack = self.pack
453 unpack = self.unpack
458 unpack = self.unpack
454
459
455 # check simple serialization
460 # check simple serialization
456 msg = dict(a=[1,'hi'])
461 msg = dict(a=[1,'hi'])
457 try:
462 try:
458 packed = pack(msg)
463 packed = pack(msg)
459 except Exception as e:
464 except Exception as e:
460 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
465 msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
461 if self.packer == 'json':
466 if self.packer == 'json':
462 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
467 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
463 else:
468 else:
464 jsonmsg = ""
469 jsonmsg = ""
465 raise ValueError(
470 raise ValueError(
466 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
471 msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
467 )
472 )
468
473
469 # ensure packed message is bytes
474 # ensure packed message is bytes
470 if not isinstance(packed, bytes):
475 if not isinstance(packed, bytes):
471 raise ValueError("message packed to %r, but bytes are required"%type(packed))
476 raise ValueError("message packed to %r, but bytes are required"%type(packed))
472
477
473 # check that unpack is pack's inverse
478 # check that unpack is pack's inverse
474 try:
479 try:
475 unpacked = unpack(packed)
480 unpacked = unpack(packed)
476 assert unpacked == msg
481 assert unpacked == msg
477 except Exception as e:
482 except Exception as e:
478 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
483 msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
479 if self.packer == 'json':
484 if self.packer == 'json':
480 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
485 jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
481 else:
486 else:
482 jsonmsg = ""
487 jsonmsg = ""
483 raise ValueError(
488 raise ValueError(
484 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
489 msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
485 )
490 )
486
491
487 # check datetime support
492 # check datetime support
488 msg = dict(t=datetime.now())
493 msg = dict(t=datetime.now())
489 try:
494 try:
490 unpacked = unpack(pack(msg))
495 unpacked = unpack(pack(msg))
491 if isinstance(unpacked['t'], datetime):
496 if isinstance(unpacked['t'], datetime):
492 raise ValueError("Shouldn't deserialize to datetime")
497 raise ValueError("Shouldn't deserialize to datetime")
493 except Exception:
498 except Exception:
494 self.pack = lambda o: pack(squash_dates(o))
499 self.pack = lambda o: pack(squash_dates(o))
495 self.unpack = lambda s: unpack(s)
500 self.unpack = lambda s: unpack(s)
496
501
    def msg_header(self, msg_type):
        """Build a fresh header dict (new msg_id) for a message of *msg_type*,
        delegating to the module-level msg_header helper."""
        return msg_header(self.msg_id, msg_type, self.username, self.session)
499
504
500 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
505 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
501 """Return the nested message dict.
506 """Return the nested message dict.
502
507
503 This format is different from what is sent over the wire. The
508 This format is different from what is sent over the wire. The
504 serialize/deserialize methods converts this nested message dict to the wire
509 serialize/deserialize methods converts this nested message dict to the wire
505 format, which is a list of message parts.
510 format, which is a list of message parts.
506 """
511 """
507 msg = {}
512 msg = {}
508 header = self.msg_header(msg_type) if header is None else header
513 header = self.msg_header(msg_type) if header is None else header
509 msg['header'] = header
514 msg['header'] = header
510 msg['msg_id'] = header['msg_id']
515 msg['msg_id'] = header['msg_id']
511 msg['msg_type'] = header['msg_type']
516 msg['msg_type'] = header['msg_type']
512 msg['parent_header'] = {} if parent is None else extract_header(parent)
517 msg['parent_header'] = {} if parent is None else extract_header(parent)
513 msg['content'] = {} if content is None else content
518 msg['content'] = {} if content is None else content
514 msg['metadata'] = self.metadata.copy()
519 msg['metadata'] = self.metadata.copy()
515 if metadata is not None:
520 if metadata is not None:
516 msg['metadata'].update(metadata)
521 msg['metadata'].update(metadata)
517 return msg
522 return msg
518
523
519 def sign(self, msg_list):
524 def sign(self, msg_list):
520 """Sign a message with HMAC digest. If no auth, return b''.
525 """Sign a message with HMAC digest. If no auth, return b''.
521
526
522 Parameters
527 Parameters
523 ----------
528 ----------
524 msg_list : list
529 msg_list : list
525 The [p_header,p_parent,p_content] part of the message list.
530 The [p_header,p_parent,p_content] part of the message list.
526 """
531 """
527 if self.auth is None:
532 if self.auth is None:
528 return b''
533 return b''
529 h = self.auth.copy()
534 h = self.auth.copy()
530 for m in msg_list:
535 for m in msg_list:
531 h.update(m)
536 h.update(m)
532 return str_to_bytes(h.hexdigest())
537 return str_to_bytes(h.hexdigest())
533
538
534 def serialize(self, msg, ident=None):
539 def serialize(self, msg, ident=None):
535 """Serialize the message components to bytes.
540 """Serialize the message components to bytes.
536
541
537 This is roughly the inverse of deserialize. The serialize/deserialize
542 This is roughly the inverse of deserialize. The serialize/deserialize
538 methods work with full message lists, whereas pack/unpack work with
543 methods work with full message lists, whereas pack/unpack work with
539 the individual message parts in the message list.
544 the individual message parts in the message list.
540
545
541 Parameters
546 Parameters
542 ----------
547 ----------
543 msg : dict or Message
548 msg : dict or Message
544 The next message dict as returned by the self.msg method.
549 The next message dict as returned by the self.msg method.
545
550
546 Returns
551 Returns
547 -------
552 -------
548 msg_list : list
553 msg_list : list
549 The list of bytes objects to be sent with the format::
554 The list of bytes objects to be sent with the format::
550
555
551 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
556 [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
552 p_metadata, p_content, buffer1, buffer2, ...]
557 p_metadata, p_content, buffer1, buffer2, ...]
553
558
554 In this list, the ``p_*`` entities are the packed or serialized
559 In this list, the ``p_*`` entities are the packed or serialized
555 versions, so if JSON is used, these are utf8 encoded JSON strings.
560 versions, so if JSON is used, these are utf8 encoded JSON strings.
556 """
561 """
557 content = msg.get('content', {})
562 content = msg.get('content', {})
558 if content is None:
563 if content is None:
559 content = self.none
564 content = self.none
560 elif isinstance(content, dict):
565 elif isinstance(content, dict):
561 content = self.pack(content)
566 content = self.pack(content)
562 elif isinstance(content, bytes):
567 elif isinstance(content, bytes):
563 # content is already packed, as in a relayed message
568 # content is already packed, as in a relayed message
564 pass
569 pass
565 elif isinstance(content, unicode_type):
570 elif isinstance(content, unicode_type):
566 # should be bytes, but JSON often spits out unicode
571 # should be bytes, but JSON often spits out unicode
567 content = content.encode('utf8')
572 content = content.encode('utf8')
568 else:
573 else:
569 raise TypeError("Content incorrect type: %s"%type(content))
574 raise TypeError("Content incorrect type: %s"%type(content))
570
575
571 real_message = [self.pack(msg['header']),
576 real_message = [self.pack(msg['header']),
572 self.pack(msg['parent_header']),
577 self.pack(msg['parent_header']),
573 self.pack(msg['metadata']),
578 self.pack(msg['metadata']),
574 content,
579 content,
575 ]
580 ]
576
581
577 to_send = []
582 to_send = []
578
583
579 if isinstance(ident, list):
584 if isinstance(ident, list):
580 # accept list of idents
585 # accept list of idents
581 to_send.extend(ident)
586 to_send.extend(ident)
582 elif ident is not None:
587 elif ident is not None:
583 to_send.append(ident)
588 to_send.append(ident)
584 to_send.append(DELIM)
589 to_send.append(DELIM)
585
590
586 signature = self.sign(real_message)
591 signature = self.sign(real_message)
587 to_send.append(signature)
592 to_send.append(signature)
588
593
589 to_send.extend(real_message)
594 to_send.extend(real_message)
590
595
591 return to_send
596 return to_send
592
597
    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, track=False, header=None, metadata=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/deserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_to_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        metadata : dict or None
            The metadata describing the message
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track.  Only for use with Sockets, because ZMQStream
            objects cannot track messages.


        Returns
        -------
        msg : dict
            The constructed message.
        """
        if not isinstance(stream, zmq.Socket):
            # ZMQStreams and dummy sockets do not support tracking.
            track = False

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
            buffers = buffers or msg.get('buffers', [])
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           header=header, metadata=metadata)
        if not os.getpid() == self.pid:
            # refuse to send from a forked child: its sockets are shared
            # with the parent and sends would interleave/corrupt streams
            io.rprint("WARNING: attempted to send message from fork")
            io.rprint(msg)
            return
        buffers = [] if buffers is None else buffers
        if self.adapt_version:
            msg = adapt(msg, self.adapt_version)
        to_send = self.serialize(msg, ident)
        to_send.extend(buffers)
        # small messages are cheaper to copy than to track zero-copy sends
        longest = max([ len(s) for s in to_send ])
        copy = (longest < self.copy_threshold)

        if buffers and track and not copy:
            # only really track when we are doing zero-copy buffers
            tracker = stream.send_multipart(to_send, copy=False, track=True)
        else:
            # use dummy tracker, which will be done immediately
            tracker = DONE
            stream.send_multipart(to_send, copy=copy)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg
678
683
679 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
684 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
680 """Send a raw message via ident path.
685 """Send a raw message via ident path.
681
686
682 This method is used to send a already serialized message.
687 This method is used to send a already serialized message.
683
688
684 Parameters
689 Parameters
685 ----------
690 ----------
686 stream : ZMQStream or Socket
691 stream : ZMQStream or Socket
687 The ZMQ stream or socket to use for sending the message.
692 The ZMQ stream or socket to use for sending the message.
688 msg_list : list
693 msg_list : list
689 The serialized list of messages to send. This only includes the
694 The serialized list of messages to send. This only includes the
690 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
695 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
691 the message.
696 the message.
692 ident : ident or list
697 ident : ident or list
693 A single ident or a list of idents to use in sending.
698 A single ident or a list of idents to use in sending.
694 """
699 """
695 to_send = []
700 to_send = []
696 if isinstance(ident, bytes):
701 if isinstance(ident, bytes):
697 ident = [ident]
702 ident = [ident]
698 if ident is not None:
703 if ident is not None:
699 to_send.extend(ident)
704 to_send.extend(ident)
700
705
701 to_send.append(DELIM)
706 to_send.append(DELIM)
702 to_send.append(self.sign(msg_list))
707 to_send.append(self.sign(msg_list))
703 to_send.extend(msg_list)
708 to_send.extend(msg_list)
704 stream.send_multipart(to_send, flags, copy=copy)
709 stream.send_multipart(to_send, flags, copy=copy)
705
710
706 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
711 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
707 """Receive and unpack a message.
712 """Receive and unpack a message.
708
713
709 Parameters
714 Parameters
710 ----------
715 ----------
711 socket : ZMQStream or Socket
716 socket : ZMQStream or Socket
712 The socket or stream to use in receiving.
717 The socket or stream to use in receiving.
713
718
714 Returns
719 Returns
715 -------
720 -------
716 [idents], msg
721 [idents], msg
717 [idents] is a list of idents and msg is a nested message dict of
722 [idents] is a list of idents and msg is a nested message dict of
718 same format as self.msg returns.
723 same format as self.msg returns.
719 """
724 """
720 if isinstance(socket, ZMQStream):
725 if isinstance(socket, ZMQStream):
721 socket = socket.socket
726 socket = socket.socket
722 try:
727 try:
723 msg_list = socket.recv_multipart(mode, copy=copy)
728 msg_list = socket.recv_multipart(mode, copy=copy)
724 except zmq.ZMQError as e:
729 except zmq.ZMQError as e:
725 if e.errno == zmq.EAGAIN:
730 if e.errno == zmq.EAGAIN:
726 # We can convert EAGAIN to None as we know in this case
731 # We can convert EAGAIN to None as we know in this case
727 # recv_multipart won't return None.
732 # recv_multipart won't return None.
728 return None,None
733 return None,None
729 else:
734 else:
730 raise
735 raise
731 # split multipart message into identity list and message dict
736 # split multipart message into identity list and message dict
732 # invalid large messages can cause very expensive string comparisons
737 # invalid large messages can cause very expensive string comparisons
733 idents, msg_list = self.feed_identities(msg_list, copy)
738 idents, msg_list = self.feed_identities(msg_list, copy)
734 try:
739 try:
735 return idents, self.deserialize(msg_list, content=content, copy=copy)
740 return idents, self.deserialize(msg_list, content=content, copy=copy)
736 except Exception as e:
741 except Exception as e:
737 # TODO: handle it
742 # TODO: handle it
738 raise e
743 raise e
739
744
740 def feed_identities(self, msg_list, copy=True):
745 def feed_identities(self, msg_list, copy=True):
741 """Split the identities from the rest of the message.
746 """Split the identities from the rest of the message.
742
747
743 Feed until DELIM is reached, then return the prefix as idents and
748 Feed until DELIM is reached, then return the prefix as idents and
744 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
749 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
745 but that would be silly.
750 but that would be silly.
746
751
747 Parameters
752 Parameters
748 ----------
753 ----------
749 msg_list : a list of Message or bytes objects
754 msg_list : a list of Message or bytes objects
750 The message to be split.
755 The message to be split.
751 copy : bool
756 copy : bool
752 flag determining whether the arguments are bytes or Messages
757 flag determining whether the arguments are bytes or Messages
753
758
754 Returns
759 Returns
755 -------
760 -------
756 (idents, msg_list) : two lists
761 (idents, msg_list) : two lists
757 idents will always be a list of bytes, each of which is a ZMQ
762 idents will always be a list of bytes, each of which is a ZMQ
758 identity. msg_list will be a list of bytes or zmq.Messages of the
763 identity. msg_list will be a list of bytes or zmq.Messages of the
759 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
764 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
760 should be unpackable/unserializable via self.deserialize at this
765 should be unpackable/unserializable via self.deserialize at this
761 point.
766 point.
762 """
767 """
763 if copy:
768 if copy:
764 idx = msg_list.index(DELIM)
769 idx = msg_list.index(DELIM)
765 return msg_list[:idx], msg_list[idx+1:]
770 return msg_list[:idx], msg_list[idx+1:]
766 else:
771 else:
767 failed = True
772 failed = True
768 for idx,m in enumerate(msg_list):
773 for idx,m in enumerate(msg_list):
769 if m.bytes == DELIM:
774 if m.bytes == DELIM:
770 failed = False
775 failed = False
771 break
776 break
772 if failed:
777 if failed:
773 raise ValueError("DELIM not in msg_list")
778 raise ValueError("DELIM not in msg_list")
774 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
779 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
775 return [m.bytes for m in idents], msg_list
780 return [m.bytes for m in idents], msg_list
776
781
777 def _add_digest(self, signature):
782 def _add_digest(self, signature):
778 """add a digest to history to protect against replay attacks"""
783 """add a digest to history to protect against replay attacks"""
779 if self.digest_history_size == 0:
784 if self.digest_history_size == 0:
780 # no history, never add digests
785 # no history, never add digests
781 return
786 return
782
787
783 self.digest_history.add(signature)
788 self.digest_history.add(signature)
784 if len(self.digest_history) > self.digest_history_size:
789 if len(self.digest_history) > self.digest_history_size:
785 # threshold reached, cull 10%
790 # threshold reached, cull 10%
786 self._cull_digest_history()
791 self._cull_digest_history()
787
792
788 def _cull_digest_history(self):
793 def _cull_digest_history(self):
789 """cull the digest history
794 """cull the digest history
790
795
791 Removes a randomly selected 10% of the digest history
796 Removes a randomly selected 10% of the digest history
792 """
797 """
793 current = len(self.digest_history)
798 current = len(self.digest_history)
794 n_to_cull = max(int(current // 10), current - self.digest_history_size)
799 n_to_cull = max(int(current // 10), current - self.digest_history_size)
795 if n_to_cull >= current:
800 if n_to_cull >= current:
796 self.digest_history = set()
801 self.digest_history = set()
797 return
802 return
798 to_cull = random.sample(self.digest_history, n_to_cull)
803 to_cull = random.sample(self.digest_history, n_to_cull)
799 self.digest_history.difference_update(to_cull)
804 self.digest_history.difference_update(to_cull)
800
805
    def deserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/deserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_metadata,p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether msg_list contains bytes (True) or the non-copying Message
            objects in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].  The buffers are returned as memoryviews.

        Raises
        ------
        ValueError
            If signing is enabled and the message is unsigned, a replay
            (signature seen before), or carries an invalid signature.
        TypeError
            If fewer than 5 message parts are present.
        """
        # HMAC + header + parent_header + metadata + content
        minlen = 5
        message = {}
        if not copy:
            # pyzmq didn't copy the first parts of the message, so we'll do it
            # NOTE: mutates msg_list in place, replacing Messages with bytes
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            # reject replays: a given signature may only ever be accepted once
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            self._add_digest(signature)
            # recompute the HMAC over header/parent/metadata/content frames
            check = self.sign(msg_list[1:5])
            if not compare_digest(signature, check):
                raise ValueError("Invalid Signature: %r" % signature)
        if not len(msg_list) >= minlen:
            raise TypeError("malformed message, must have at least %i elements"%minlen)
        header = self.unpack(msg_list[1])
        # extract_dates re-hydrates date fields in the unpacked dicts
        message['header'] = extract_dates(header)
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
        message['metadata'] = self.unpack(msg_list[3])
        if content:
            message['content'] = self.unpack(msg_list[4])
        else:
            # leave the content frame packed (cheap for routing/forwarding)
            message['content'] = msg_list[4]
        buffers = [memoryview(b) for b in msg_list[5:]]
        if buffers and buffers[0].shape is None:
            # force copy to workaround pyzmq #646
            buffers = [memoryview(b.bytes) for b in msg_list[5:]]
        message['buffers'] = buffers
        # adapt to the current version
        return adapt(message)
861
866
862 def unserialize(self, *args, **kwargs):
867 def unserialize(self, *args, **kwargs):
863 warnings.warn(
868 warnings.warn(
864 "Session.unserialize is deprecated. Use Session.deserialize.",
869 "Session.unserialize is deprecated. Use Session.deserialize.",
865 DeprecationWarning,
870 DeprecationWarning,
866 )
871 )
867 return self.deserialize(*args, **kwargs)
872 return self.deserialize(*args, **kwargs)
868
873
869
874
def test_msg2obj():
    """Smoke test: Message exposes dict keys as attributes, recursively."""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    outer, inner = 'y', 'z'
    assert msg[outer][inner] == source[outer][inner]

    # converting back to a plain dict preserves the values
    roundtrip = dict(msg)
    assert roundtrip['x'] == source['x']
    assert roundtrip['y']['z'] == source['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now