##// END OF EJS Templates
remove use of utils.flatten...
Min RK -
Show More
@@ -1,179 +1,181 b''
1 """serialization utilities for apply messages"""
1 """serialization utilities for apply messages"""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 try:
6 try:
7 import cPickle
7 import cPickle
8 pickle = cPickle
8 pickle = cPickle
9 except:
9 except:
10 cPickle = None
10 cPickle = None
11 import pickle
11 import pickle
12
12
13 # IPython imports
13 from itertools import chain
14
14 from IPython.utils.py3compat import PY3, buffer_to_bytes_py2
15 from IPython.utils.py3compat import PY3, buffer_to_bytes_py2
15 from IPython.utils.data import flatten
16 from ipython_kernel.pickleutil import (
16 from ipython_kernel.pickleutil import (
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
17 can, uncan, can_sequence, uncan_sequence, CannedObject,
18 istype, sequence_types, PICKLE_PROTOCOL,
18 istype, sequence_types, PICKLE_PROTOCOL,
19 )
19 )
20 from jupyter_client.session import MAX_ITEMS, MAX_BYTES
20 from jupyter_client.session import MAX_ITEMS, MAX_BYTES
21
21
22
22
23 if PY3:
23 if PY3:
24 buffer = memoryview
24 buffer = memoryview
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Serialization Functions
27 # Serialization Functions
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30
30
def _extract_buffers(obj, threshold=MAX_BYTES):
    """Pull out of *obj* every buffer larger than *threshold* bytes.

    Oversized buffers are replaced with ``None`` inside ``obj.buffers``
    (preventing them from being pickled) and returned, in order, as a
    list.  Small ``buffer`` views are coerced to ``bytes``, because
    pickling buffer objects just results in broken pointers.
    """
    extracted = []
    if not (isinstance(obj, CannedObject) and obj.buffers):
        return extracted
    for idx, buf in enumerate(obj.buffers):
        if len(buf) > threshold:
            # buffer larger than threshold: send separately, don't pickle
            obj.buffers[idx] = None
            extracted.append(buf)
        elif isinstance(buf, buffer):
            # too small for a separate send; coerce to bytes so it pickles
            obj.buffers[idx] = bytes(buf)
    return extracted
45
45
def _restore_buffers(obj, buffers):
    """Re-attach buffers previously removed by ``_extract_buffers``.

    ``None`` placeholders in ``obj.buffers`` are filled, in order, from
    the front of *buffers* (which is consumed via ``pop(0)``).
    """
    if not (isinstance(obj, CannedObject) and obj.buffers):
        return
    for idx, buf in enumerate(obj.buffers):
        if buf is None:
            obj.buffers[idx] = buffers.pop(0)
52
52
def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------
    obj : object
        The object to be serialized
    buffer_threshold : int
        The threshold (in bytes) for pulling out data buffers
        to avoid pickling them.
    item_threshold : int
        The maximum number of items over which canning will iterate.
        Containers (lists, dicts) larger than this will be pickled without
        introspection.

    Returns
    -------
    [bufs] : list of buffers representing the serialized object.
    """
    buffers = []
    if istype(obj, sequence_types) and len(obj) < item_threshold:
        # small sequence: can each element, harvesting oversized buffers
        canned = can_sequence(obj)
        for item in canned:
            buffers.extend(_extract_buffers(item, buffer_threshold))
    elif istype(obj, dict) and len(obj) < item_threshold:
        # small dict: can each value; keys sorted for a deterministic order
        canned = {}
        for key in sorted(obj):
            value = can(obj[key])
            buffers.extend(_extract_buffers(value, buffer_threshold))
            canned[key] = value
    else:
        # anything else (or an oversized container) is canned wholesale
        canned = can(obj)
        buffers.extend(_extract_buffers(canned, buffer_threshold))

    # the pickled canned object always travels as the first buffer
    buffers.insert(0, pickle.dumps(canned, PICKLE_PROTOCOL))
    return buffers
90
90
def deserialize_object(buffers, g=None):
    """Reconstruct an object serialized by serialize_object from data buffers.

    Parameters
    ----------
    bufs : list of buffers/bytes

    g : globals to be used when uncanning

    Returns
    -------
    (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
    """
    bufs = list(buffers)
    # the first buffer is always the pickled canned object
    canned = pickle.loads(buffer_to_bytes_py2(bufs.pop(0)))
    if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
        for item in canned:
            _restore_buffers(item, bufs)
        newobj = uncan_sequence(canned, g)
    elif istype(canned, dict) and len(canned) < MAX_ITEMS:
        newobj = {}
        # keys sorted to mirror the order used by serialize_object
        for key in sorted(canned):
            item = canned[key]
            _restore_buffers(item, bufs)
            newobj[key] = uncan(item, g)
    else:
        _restore_buffers(canned, bufs)
        newobj = uncan(canned, g)

    return newobj, bufs
124
124
def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
    """pack up a function, args, and kwargs to be sent over the wire

    Each element of args/kwargs will be canned for special treatment,
    but inspection will not go any deeper than that.

    Any object whose data is larger than `threshold` will not have their data copied
    (only numpy arrays and bytes/buffers support zero-copy)

    Message will be a list of bytes/buffers of the format:

    [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]

    With length at least two + len(args) + len(kwargs)
    """
    # serialize each positional argument, flattening the per-arg buffer lists
    arg_bufs = list(chain.from_iterable(
        serialize_object(arg, buffer_threshold, item_threshold)
        for arg in args
    ))

    # keyword arguments travel in sorted-key order so both ends agree
    kw_keys = sorted(kwargs.keys())
    kwarg_bufs = list(chain.from_iterable(
        serialize_object(kwargs[key], buffer_threshold, item_threshold)
        for key in kw_keys
    ))

    # header metadata lets the receiver split the buffer list back apart
    info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)

    msg = [
        pickle.dumps(can(f), PICKLE_PROTOCOL),
        pickle.dumps(info, PICKLE_PROTOCOL),
    ]
    msg.extend(arg_bufs)
    msg.extend(kwarg_bufs)
    return msg
154
156
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs)  # local copy, so popping doesn't mutate the caller's list
    assert len(bufs) >= 2, "not enough buffers!"
    # the first two buffers are the canned function and the info dict
    f = uncan(pickle.loads(buffer_to_bytes_py2(bufs.pop(0))), g)
    info = pickle.loads(buffer_to_bytes_py2(bufs.pop(0)))
    # remaining buffers split between positional and keyword arguments
    split = info['narg_bufs']
    arg_bufs, kwarg_bufs = bufs[:split], bufs[split:]

    args = []
    for _ in range(info['nargs']):
        arg, arg_bufs = deserialize_object(arg_bufs, g)
        args.append(arg)
    args = tuple(args)
    assert not arg_bufs, "Shouldn't be any arg bufs left over"

    kwargs = {}
    for key in info['kw_keys']:
        kwarg, kwarg_bufs = deserialize_object(kwarg_bufs, g)
        kwargs[key] = kwarg
    assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"

    return f,args,kwargs
@@ -1,129 +1,124 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 from __future__ import division
12 from __future__ import division
13
13
14 import sys
14 import sys
15 from itertools import islice
15 from itertools import islice, chain
16
17 from IPython.utils.data import flatten as utils_flatten
18
19
16
# Lazily-bound numpy module; stays None until is_array first needs it.
numpy = None

def is_array(obj):
    """Is an object a numpy array?

    Avoids importing numpy until it is requested
    """
    global numpy
    if 'numpy' not in sys.modules:
        # numpy was never imported anywhere, so obj cannot be an ndarray
        return False

    if numpy is None:
        # numpy is loaded but not yet bound here; bind it now (cheap re-import)
        import numpy
    return isinstance(obj, numpy.ndarray)
34
31
class Map(object):
    """A class for partitioning a sequence using a map."""

    def getPartition(self, seq, p, q, n=None):
        """Returns the pth partition of q partitions of seq.

        The length can be specified as `n`,
        otherwise it is the value of `len(seq)`
        """
        if n is None:
            n = len(seq)
        # Test for error conditions here
        if not (0 <= p < q):
            raise ValueError("must have 0 <= p <= q, but have p=%s,q=%s" % (p, q))

        remainder = n % q
        basesize = n // q

        # the first `remainder` partitions each receive one extra element
        if p < remainder:
            low = p * (basesize + 1)
            high = low + basesize + 1
        else:
            low = p * basesize + remainder
            high = low + basesize

        try:
            return seq[low:high]
        except TypeError:
            # some objects (iterators) can't be sliced; fall back to islice
            return list(islice(seq, low, high))

    def joinPartitions(self, listOfPartitions):
        return self.concatenate(listOfPartitions)

    def concatenate(self, listOfPartitions):
        probe = listOfPartitions[0]
        # First see if we have a known array type
        if is_array(probe):
            return numpy.concatenate(listOfPartitions)
        # Next try for Python sequence types
        if isinstance(probe, (list, tuple)):
            return list(chain.from_iterable(listOfPartitions))
        # If we have scalars, just return listOfPartitions
        return listOfPartitions
81
78
class RoundRobinMap(Map):
    """Partitions a sequence in a round robin fashion.

    This currently does not work!
    """

    def getPartition(self, seq, p, q, n=None):
        if n is None:
            n = len(seq)
        # stride through the sequence: elements p, p+q, p+2q, ...
        return seq[p:n:q]

    def joinPartitions(self, listOfPartitions):
        probe = listOfPartitions[0]
        # First see if we have a known array type
        if is_array(probe):
            return self.flatten_array(listOfPartitions)
        if isinstance(probe, (list, tuple)):
            return self.flatten_list(listOfPartitions)
        return listOfPartitions

    def flatten_array(self, listOfPartitions):
        first = listOfPartitions[0]
        shape = list(first.shape)
        shape[0] = sum(p.shape[0] for p in listOfPartitions)
        A = numpy.ndarray(shape)
        N = shape[0]
        q = len(listOfPartitions)
        # interleave: partition p supplies rows p, p+q, p+2q, ...
        for p, part in enumerate(listOfPartitions):
            A[p:N:q] = part
        return A

    def flatten_list(self, listOfPartitions):
        flat = []
        # take the i-th element of each partition in turn, skipping
        # partitions that have already run out of elements
        for i in range(len(listOfPartitions[0])):
            flat.extend(part[i] for part in listOfPartitions if len(part) > i)
        return flat
117
114
def mappable(obj):
    """return whether an object is mappable or not."""
    # plain sequences and numpy arrays are the only mappable things here
    return isinstance(obj, (tuple, list)) or is_array(obj)
125
122
# Registry mapping single-letter distribution codes to partitioning classes.
dists = {'b': Map, 'r': RoundRobinMap}
127
124
128
129
General Comments 0
You need to be logged in to leave comments. Login now