##// END OF EJS Templates
support iterators in view.map...
MinRK -
Show More
@@ -1,165 +1,171 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7 7
8 8
9 9 Authors:
10 10
11 11 * Brian Granger
12 12 * MinRK
13 13
14 14 """
15 15
16 16 #-------------------------------------------------------------------------------
17 17 # Copyright (C) 2008-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-------------------------------------------------------------------------------
22 22
23 23 #-------------------------------------------------------------------------------
24 24 # Imports
25 25 #-------------------------------------------------------------------------------
26 26
27 27 from __future__ import division
28 28
29 29 import types
30 from itertools import islice
30 31
31 32 from IPython.utils.data import flatten as utils_flatten
32 33
33 34 #-------------------------------------------------------------------------------
34 35 # Figure out which array packages are present and their array types
35 36 #-------------------------------------------------------------------------------
36 37
37 38 arrayModules = []
38 39 try:
39 40 import Numeric
40 41 except ImportError:
41 42 pass
42 43 else:
43 44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 45 try:
45 46 import numpy
46 47 except ImportError:
47 48 pass
48 49 else:
49 50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 51 try:
51 52 import numarray
52 53 except ImportError:
53 54 pass
54 55 else:
55 56 arrayModules.append({'module':numarray,
56 57 'type':numarray.numarraycore.NumArray})
57 58
58 59 class Map:
59 60 """A class for partitioning a sequence using a map."""
60 61
61 62 def getPartition(self, seq, p, q):
62 63 """Returns the pth partition of q partitions of seq."""
63 64
64 65 # Test for error conditions here
65 66 if p<0 or p>=q:
66 67 print "No partition exists."
67 68 return
68 69
69 70 remainder = len(seq)%q
70 71 basesize = len(seq)//q
71 72 hi = []
72 73 lo = []
73 74 for n in range(q):
74 75 if n < remainder:
75 76 lo.append(n * (basesize + 1))
76 77 hi.append(lo[-1] + basesize + 1)
77 78 else:
78 79 lo.append(n*basesize + remainder)
79 80 hi.append(lo[-1] + basesize)
80
81 81
82 result = seq[lo[p]:hi[p]]
82 try:
83 result = seq[lo[p]:hi[p]]
84 except TypeError:
85 # some objects (iterators) can't be sliced,
86 # use islice:
87 result = list(islice(seq, lo[p], hi[p]))
88
83 89 return result
84 90
85 91 def joinPartitions(self, listOfPartitions):
86 92 return self.concatenate(listOfPartitions)
87 93
88 94 def concatenate(self, listOfPartitions):
89 95 testObject = listOfPartitions[0]
90 96 # First see if we have a known array type
91 97 for m in arrayModules:
92 98 #print m
93 99 if isinstance(testObject, m['type']):
94 100 return m['module'].concatenate(listOfPartitions)
95 101 # Next try for Python sequence types
96 102 if isinstance(testObject, (types.ListType, types.TupleType)):
97 103 return utils_flatten(listOfPartitions)
98 104 # If we have scalars, just return listOfPartitions
99 105 return listOfPartitions
100 106
101 107 class RoundRobinMap(Map):
102 108 """Partitions a sequence in a roun robin fashion.
103 109
104 110 This currently does not work!
105 111 """
106 112
107 113 def getPartition(self, seq, p, q):
108 114 # if not isinstance(seq,(list,tuple)):
109 115 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
110 116 return seq[p:len(seq):q]
111 117 #result = []
112 118 #for i in range(p,len(seq),q):
113 119 # result.append(seq[i])
114 120 #return result
115 121
116 122 def joinPartitions(self, listOfPartitions):
117 123 testObject = listOfPartitions[0]
118 124 # First see if we have a known array type
119 125 for m in arrayModules:
120 126 #print m
121 127 if isinstance(testObject, m['type']):
122 128 return self.flatten_array(m['type'], listOfPartitions)
123 129 if isinstance(testObject, (types.ListType, types.TupleType)):
124 130 return self.flatten_list(listOfPartitions)
125 131 return listOfPartitions
126 132
127 133 def flatten_array(self, klass, listOfPartitions):
128 134 test = listOfPartitions[0]
129 135 shape = list(test.shape)
130 136 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
131 137 A = klass(shape)
132 138 N = shape[0]
133 139 q = len(listOfPartitions)
134 140 for p,part in enumerate(listOfPartitions):
135 141 A[p:N:q] = part
136 142 return A
137 143
138 144 def flatten_list(self, listOfPartitions):
139 145 flat = []
140 146 for i in range(len(listOfPartitions[0])):
141 147 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
142 148 return flat
143 149 #lengths = [len(x) for x in listOfPartitions]
144 150 #maxPartitionLength = len(listOfPartitions[0])
145 151 #numberOfPartitions = len(listOfPartitions)
146 152 #concat = self.concatenate(listOfPartitions)
147 153 #totalLength = len(concat)
148 154 #result = []
149 155 #for i in range(maxPartitionLength):
150 156 # result.append(concat[i:totalLength:maxPartitionLength])
151 157 # return self.concatenate(listOfPartitions)
152 158
153 159 def mappable(obj):
154 160 """return whether an object is mappable or not."""
155 161 if isinstance(obj, (tuple,list)):
156 162 return True
157 163 for m in arrayModules:
158 164 if isinstance(obj,m['type']):
159 165 return True
160 166 return False
161 167
162 168 dists = {'b':Map,'r':RoundRobinMap}
163 169
164 170
165 171
@@ -1,167 +1,178 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test LoadBalancedView objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from nose import SkipTest
24 24
25 25 from IPython import parallel as pmod
26 26 from IPython.parallel import error
27 27
28 28 from IPython.parallel.tests import add_engines
29 29
30 30 from .clienttest import ClusterTestCase, crash, wait, skip_without
31 31
32 32 def setup():
33 33 add_engines(3)
34 34
35 35 class TestLoadBalancedView(ClusterTestCase):
36 36
37 37 def setUp(self):
38 38 ClusterTestCase.setUp(self)
39 39 self.view = self.client.load_balanced_view()
40 40
41 41 def test_z_crash_task(self):
42 42 """test graceful handling of engine death (balanced)"""
43 43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
44 44 # self.add_engines(1)
45 45 ar = self.view.apply_async(crash)
46 46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 47 eid = ar.engine_id
48 48 tic = time.time()
49 49 while eid in self.client.ids and time.time()-tic < 5:
50 50 time.sleep(.01)
51 51 self.client.spin()
52 52 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 53
54 54 def test_map(self):
55 55 def f(x):
56 56 return x**2
57 57 data = range(16)
58 58 r = self.view.map_sync(f, data)
59 59 self.assertEquals(r, map(f, data))
60 60
61 61 def test_map_unordered(self):
62 62 def f(x):
63 63 return x**2
64 64 def slow_f(x):
65 65 import time
66 66 time.sleep(0.05*x)
67 67 return x**2
68 68 data = range(16,0,-1)
69 69 reference = map(f, data)
70 70
71 71 amr = self.view.map_async(slow_f, data, ordered=False)
72 72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
73 73 # check individual elements, retrieved as they come
74 74 # list comprehension uses __iter__
75 75 astheycame = [ r for r in amr ]
76 76 # Ensure that at least one result came out of order:
77 77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
78 78 self.assertEquals(sorted(astheycame, reverse=True), reference, "result corrupted")
79 79
80 80 def test_map_ordered(self):
81 81 def f(x):
82 82 return x**2
83 83 def slow_f(x):
84 84 import time
85 85 time.sleep(0.05*x)
86 86 return x**2
87 87 data = range(16,0,-1)
88 88 reference = map(f, data)
89 89
90 90 amr = self.view.map_async(slow_f, data)
91 91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
92 92 # check individual elements, retrieved as they come
93 93 # list(amr) uses __iter__
94 94 astheycame = list(amr)
95 95 # Ensure that results came in order
96 96 self.assertEquals(astheycame, reference)
97 97 self.assertEquals(amr.result, reference)
98
99 def test_map_iterable(self):
100 """test map on iterables (balanced)"""
101 view = self.view
102 # 101 is prime, so it won't be evenly distributed
103 arr = range(101)
104 # so that it will be an iterator, even in Python 3
105 it = iter(arr)
106 r = view.map_sync(lambda x:x, arr)
107 self.assertEquals(r, list(arr))
108
98 109
99 110 def test_abort(self):
100 111 view = self.view
101 112 ar = self.client[:].apply_async(time.sleep, .5)
102 113 ar = self.client[:].apply_async(time.sleep, .5)
103 114 time.sleep(0.2)
104 115 ar2 = view.apply_async(lambda : 2)
105 116 ar3 = view.apply_async(lambda : 3)
106 117 view.abort(ar2)
107 118 view.abort(ar3.msg_ids)
108 119 self.assertRaises(error.TaskAborted, ar2.get)
109 120 self.assertRaises(error.TaskAborted, ar3.get)
110 121
111 122 def test_retries(self):
112 123 add_engines(3)
113 124 view = self.view
114 125 view.timeout = 1 # prevent hang if this doesn't behave
115 126 def fail():
116 127 assert False
117 128 for r in range(len(self.client)-1):
118 129 with view.temp_flags(retries=r):
119 130 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
120 131
121 132 with view.temp_flags(retries=len(self.client), timeout=0.25):
122 133 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
123 134
124 135 def test_invalid_dependency(self):
125 136 view = self.view
126 137 with view.temp_flags(after='12345'):
127 138 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
128 139
129 140 def test_impossible_dependency(self):
130 141 if len(self.client) < 2:
131 142 add_engines(2)
132 143 view = self.client.load_balanced_view()
133 144 ar1 = view.apply_async(lambda : 1)
134 145 ar1.get()
135 146 e1 = ar1.engine_id
136 147 e2 = e1
137 148 while e2 == e1:
138 149 ar2 = view.apply_async(lambda : 1)
139 150 ar2.get()
140 151 e2 = ar2.engine_id
141 152
142 153 with view.temp_flags(follow=[ar1, ar2]):
143 154 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
144 155
145 156
146 157 def test_follow(self):
147 158 ar = self.view.apply_async(lambda : 1)
148 159 ar.get()
149 160 ars = []
150 161 first_id = ar.engine_id
151 162
152 163 self.view.follow = ar
153 164 for i in range(5):
154 165 ars.append(self.view.apply_async(lambda : 1))
155 166 self.view.wait(ars)
156 167 for ar in ars:
157 168 self.assertEquals(ar.engine_id, first_id)
158 169
159 170 def test_after(self):
160 171 view = self.view
161 172 ar = view.apply_async(time.sleep, 0.5)
162 173 with view.temp_flags(after=ar):
163 174 ar2 = view.apply_async(lambda : 1)
164 175
165 176 ar.wait()
166 177 ar2.wait()
167 178 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,451 +1,460 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 27 from IPython import parallel as pmod
28 28 from IPython.parallel import error
29 29 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
30 30 from IPython.parallel import DirectView
31 31 from IPython.parallel.util import interactive
32 32
33 33 from IPython.parallel.tests import add_engines
34 34
35 35 from .clienttest import ClusterTestCase, crash, wait, skip_without
36 36
37 37 def setup():
38 38 add_engines(3)
39 39
40 40 class TestView(ClusterTestCase):
41 41
42 42 def test_z_crash_mux(self):
43 43 """test graceful handling of engine death (direct)"""
44 44 raise SkipTest("crash tests disabled, due to undesirable crash reports")
45 45 # self.add_engines(1)
46 46 eid = self.client.ids[-1]
47 47 ar = self.client[eid].apply_async(crash)
48 48 self.assertRaisesRemote(error.EngineError, ar.get, 10)
49 49 eid = ar.engine_id
50 50 tic = time.time()
51 51 while eid in self.client.ids and time.time()-tic < 5:
52 52 time.sleep(.01)
53 53 self.client.spin()
54 54 self.assertFalse(eid in self.client.ids, "Engine should have died")
55 55
56 56 def test_push_pull(self):
57 57 """test pushing and pulling"""
58 58 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
59 59 t = self.client.ids[-1]
60 60 v = self.client[t]
61 61 push = v.push
62 62 pull = v.pull
63 63 v.block=True
64 64 nengines = len(self.client)
65 65 push({'data':data})
66 66 d = pull('data')
67 67 self.assertEquals(d, data)
68 68 self.client[:].push({'data':data})
69 69 d = self.client[:].pull('data', block=True)
70 70 self.assertEquals(d, nengines*[data])
71 71 ar = push({'data':data}, block=False)
72 72 self.assertTrue(isinstance(ar, AsyncResult))
73 73 r = ar.get()
74 74 ar = self.client[:].pull('data', block=False)
75 75 self.assertTrue(isinstance(ar, AsyncResult))
76 76 r = ar.get()
77 77 self.assertEquals(r, nengines*[data])
78 78 self.client[:].push(dict(a=10,b=20))
79 79 r = self.client[:].pull(('a','b'), block=True)
80 80 self.assertEquals(r, nengines*[[10,20]])
81 81
82 82 def test_push_pull_function(self):
83 83 "test pushing and pulling functions"
84 84 def testf(x):
85 85 return 2.0*x
86 86
87 87 t = self.client.ids[-1]
88 88 v = self.client[t]
89 89 v.block=True
90 90 push = v.push
91 91 pull = v.pull
92 92 execute = v.execute
93 93 push({'testf':testf})
94 94 r = pull('testf')
95 95 self.assertEqual(r(1.0), testf(1.0))
96 96 execute('r = testf(10)')
97 97 r = pull('r')
98 98 self.assertEquals(r, testf(10))
99 99 ar = self.client[:].push({'testf':testf}, block=False)
100 100 ar.get()
101 101 ar = self.client[:].pull('testf', block=False)
102 102 rlist = ar.get()
103 103 for r in rlist:
104 104 self.assertEqual(r(1.0), testf(1.0))
105 105 execute("def g(x): return x*x")
106 106 r = pull(('testf','g'))
107 107 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
108 108
109 109 def test_push_function_globals(self):
110 110 """test that pushed functions have access to globals"""
111 111 @interactive
112 112 def geta():
113 113 return a
114 114 # self.add_engines(1)
115 115 v = self.client[-1]
116 116 v.block=True
117 117 v['f'] = geta
118 118 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
119 119 v.execute('a=5')
120 120 v.execute('b=f()')
121 121 self.assertEquals(v['b'], 5)
122 122
123 123 def test_push_function_defaults(self):
124 124 """test that pushed functions preserve default args"""
125 125 def echo(a=10):
126 126 return a
127 127 v = self.client[-1]
128 128 v.block=True
129 129 v['f'] = echo
130 130 v.execute('b=f()')
131 131 self.assertEquals(v['b'], 10)
132 132
133 133 def test_get_result(self):
134 134 """test getting results from the Hub."""
135 135 c = pmod.Client(profile='iptest')
136 136 # self.add_engines(1)
137 137 t = c.ids[-1]
138 138 v = c[t]
139 139 v2 = self.client[t]
140 140 ar = v.apply_async(wait, 1)
141 141 # give the monitor time to notice the message
142 142 time.sleep(.25)
143 143 ahr = v2.get_result(ar.msg_ids)
144 144 self.assertTrue(isinstance(ahr, AsyncHubResult))
145 145 self.assertEquals(ahr.get(), ar.get())
146 146 ar2 = v2.get_result(ar.msg_ids)
147 147 self.assertFalse(isinstance(ar2, AsyncHubResult))
148 148 c.spin()
149 149 c.close()
150 150
151 151 def test_run_newline(self):
152 152 """test that run appends newline to files"""
153 153 tmpfile = mktemp()
154 154 with open(tmpfile, 'w') as f:
155 155 f.write("""def g():
156 156 return 5
157 157 """)
158 158 v = self.client[-1]
159 159 v.run(tmpfile, block=True)
160 160 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
161 161
162 162 def test_apply_tracked(self):
163 163 """test tracking for apply"""
164 164 # self.add_engines(1)
165 165 t = self.client.ids[-1]
166 166 v = self.client[t]
167 167 v.block=False
168 168 def echo(n=1024*1024, **kwargs):
169 169 with v.temp_flags(**kwargs):
170 170 return v.apply(lambda x: x, 'x'*n)
171 171 ar = echo(1, track=False)
172 172 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
173 173 self.assertTrue(ar.sent)
174 174 ar = echo(track=True)
175 175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
176 176 self.assertEquals(ar.sent, ar._tracker.done)
177 177 ar._tracker.wait()
178 178 self.assertTrue(ar.sent)
179 179
180 180 def test_push_tracked(self):
181 181 t = self.client.ids[-1]
182 182 ns = dict(x='x'*1024*1024)
183 183 v = self.client[t]
184 184 ar = v.push(ns, block=False, track=False)
185 185 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 186 self.assertTrue(ar.sent)
187 187
188 188 ar = v.push(ns, block=False, track=True)
189 189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 190 ar._tracker.wait()
191 191 self.assertEquals(ar.sent, ar._tracker.done)
192 192 self.assertTrue(ar.sent)
193 193 ar.get()
194 194
195 195 def test_scatter_tracked(self):
196 196 t = self.client.ids
197 197 x='x'*1024*1024
198 198 ar = self.client[t].scatter('x', x, block=False, track=False)
199 199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 200 self.assertTrue(ar.sent)
201 201
202 202 ar = self.client[t].scatter('x', x, block=False, track=True)
203 203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 204 self.assertEquals(ar.sent, ar._tracker.done)
205 205 ar._tracker.wait()
206 206 self.assertTrue(ar.sent)
207 207 ar.get()
208 208
209 209 def test_remote_reference(self):
210 210 v = self.client[-1]
211 211 v['a'] = 123
212 212 ra = pmod.Reference('a')
213 213 b = v.apply_sync(lambda x: x, ra)
214 214 self.assertEquals(b, 123)
215 215
216 216
217 217 def test_scatter_gather(self):
218 218 view = self.client[:]
219 219 seq1 = range(16)
220 220 view.scatter('a', seq1)
221 221 seq2 = view.gather('a', block=True)
222 222 self.assertEquals(seq2, seq1)
223 223 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
224 224
225 225 @skip_without('numpy')
226 226 def test_scatter_gather_numpy(self):
227 227 import numpy
228 228 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
229 229 view = self.client[:]
230 230 a = numpy.arange(64)
231 231 view.scatter('a', a)
232 232 b = view.gather('a', block=True)
233 233 assert_array_equal(b, a)
234 234
235 235 def test_map(self):
236 236 view = self.client[:]
237 237 def f(x):
238 238 return x**2
239 239 data = range(16)
240 240 r = view.map_sync(f, data)
241 241 self.assertEquals(r, map(f, data))
242 242
243 def test_map_iterable(self):
244 """test map on iterables (direct)"""
245 view = self.client[:]
246 # 101 is prime, so it won't be evenly distributed
247 arr = range(101)
248 # ensure it will be an iterator, even in Python 3
249 it = iter(arr)
250 r = view.map_sync(lambda x:x, arr)
251 self.assertEquals(r, list(arr))
252
243 253 def test_scatterGatherNonblocking(self):
244 254 data = range(16)
245 255 view = self.client[:]
246 256 view.scatter('a', data, block=False)
247 257 ar = view.gather('a', block=False)
248 258 self.assertEquals(ar.get(), data)
249 259
250 260 @skip_without('numpy')
251 261 def test_scatter_gather_numpy_nonblocking(self):
252 262 import numpy
253 263 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
254 264 a = numpy.arange(64)
255 265 view = self.client[:]
256 266 ar = view.scatter('a', a, block=False)
257 267 self.assertTrue(isinstance(ar, AsyncResult))
258 268 amr = view.gather('a', block=False)
259 269 self.assertTrue(isinstance(amr, AsyncMapResult))
260 270 assert_array_equal(amr.get(), a)
261 271
262 272 def test_execute(self):
263 273 view = self.client[:]
264 274 # self.client.debug=True
265 275 execute = view.execute
266 276 ar = execute('c=30', block=False)
267 277 self.assertTrue(isinstance(ar, AsyncResult))
268 278 ar = execute('d=[0,1,2]', block=False)
269 279 self.client.wait(ar, 1)
270 280 self.assertEquals(len(ar.get()), len(self.client))
271 281 for c in view['c']:
272 282 self.assertEquals(c, 30)
273 283
274 284 def test_abort(self):
275 285 view = self.client[-1]
276 286 ar = view.execute('import time; time.sleep(1)', block=False)
277 287 ar2 = view.apply_async(lambda : 2)
278 288 ar3 = view.apply_async(lambda : 3)
279 289 view.abort(ar2)
280 290 view.abort(ar3.msg_ids)
281 291 self.assertRaises(error.TaskAborted, ar2.get)
282 292 self.assertRaises(error.TaskAborted, ar3.get)
283 293
284 294 def test_temp_flags(self):
285 295 view = self.client[-1]
286 296 view.block=True
287 297 with view.temp_flags(block=False):
288 298 self.assertFalse(view.block)
289 299 self.assertTrue(view.block)
290 300
291 301 def test_importer(self):
292 302 view = self.client[-1]
293 303 view.clear(block=True)
294 304 with view.importer:
295 305 import re
296 306
297 307 @interactive
298 308 def findall(pat, s):
299 309 # this globals() step isn't necessary in real code
300 310 # only to prevent a closure in the test
301 311 re = globals()['re']
302 312 return re.findall(pat, s)
303 313
304 314 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
305 315
306 316 # parallel magic tests
307 317
308 318 def test_magic_px_blocking(self):
309 319 ip = get_ipython()
310 320 v = self.client[-1]
311 321 v.activate()
312 322 v.block=True
313 323
314 324 ip.magic_px('a=5')
315 325 self.assertEquals(v['a'], 5)
316 326 ip.magic_px('a=10')
317 327 self.assertEquals(v['a'], 10)
318 328 sio = StringIO()
319 329 savestdout = sys.stdout
320 330 sys.stdout = sio
321 331 # just 'print a' worst ~99% of the time, but this ensures that
322 332 # the stdout message has arrived when the result is finished:
323 333 ip.magic_px('import sys,time;print a; sys.stdout.flush();time.sleep(0.2)')
324 334 sys.stdout = savestdout
325 335 buf = sio.getvalue()
326 336 self.assertTrue('[stdout:' in buf, buf)
327 337 self.assertTrue(buf.rstrip().endswith('10'))
328 338 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
329 339
330 340 def test_magic_px_nonblocking(self):
331 341 ip = get_ipython()
332 342 v = self.client[-1]
333 343 v.activate()
334 344 v.block=False
335 345
336 346 ip.magic_px('a=5')
337 347 self.assertEquals(v['a'], 5)
338 348 ip.magic_px('a=10')
339 349 self.assertEquals(v['a'], 10)
340 350 sio = StringIO()
341 351 savestdout = sys.stdout
342 352 sys.stdout = sio
343 353 ip.magic_px('print a')
344 354 sys.stdout = savestdout
345 355 buf = sio.getvalue()
346 356 self.assertFalse('[stdout:%i]'%v.targets in buf)
347 357 ip.magic_px('1/0')
348 358 ar = v.get_result(-1)
349 359 self.assertRaisesRemote(ZeroDivisionError, ar.get)
350 360
351 361 def test_magic_autopx_blocking(self):
352 362 ip = get_ipython()
353 363 v = self.client[-1]
354 364 v.activate()
355 365 v.block=True
356 366
357 367 sio = StringIO()
358 368 savestdout = sys.stdout
359 369 sys.stdout = sio
360 370 ip.magic_autopx()
361 371 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
362 372 ip.run_cell('print b')
363 373 ip.run_cell("b/c")
364 374 ip.run_code(compile('b*=2', '', 'single'))
365 375 ip.magic_autopx()
366 376 sys.stdout = savestdout
367 377 output = sio.getvalue().strip()
368 378 self.assertTrue(output.startswith('%autopx enabled'))
369 379 self.assertTrue(output.endswith('%autopx disabled'))
370 380 self.assertTrue('RemoteError: ZeroDivisionError' in output)
371 381 ar = v.get_result(-2)
372 382 self.assertEquals(v['a'], 5)
373 383 self.assertEquals(v['b'], 20)
374 384 self.assertRaisesRemote(ZeroDivisionError, ar.get)
375 385
376 386 def test_magic_autopx_nonblocking(self):
377 387 ip = get_ipython()
378 388 v = self.client[-1]
379 389 v.activate()
380 390 v.block=False
381 391
382 392 sio = StringIO()
383 393 savestdout = sys.stdout
384 394 sys.stdout = sio
385 395 ip.magic_autopx()
386 396 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
387 397 ip.run_cell('print b')
388 398 ip.run_cell("b/c")
389 399 ip.run_code(compile('b*=2', '', 'single'))
390 400 ip.magic_autopx()
391 401 sys.stdout = savestdout
392 402 output = sio.getvalue().strip()
393 403 self.assertTrue(output.startswith('%autopx enabled'))
394 404 self.assertTrue(output.endswith('%autopx disabled'))
395 405 self.assertFalse('ZeroDivisionError' in output)
396 406 ar = v.get_result(-2)
397 407 self.assertEquals(v['a'], 5)
398 408 self.assertEquals(v['b'], 20)
399 409 self.assertRaisesRemote(ZeroDivisionError, ar.get)
400 410
401 411 def test_magic_result(self):
402 412 ip = get_ipython()
403 413 v = self.client[-1]
404 414 v.activate()
405 415 v['a'] = 111
406 416 ra = v['a']
407 417
408 418 ar = ip.magic_result()
409 419 self.assertEquals(ar.msg_ids, [v.history[-1]])
410 420 self.assertEquals(ar.get(), 111)
411 421 ar = ip.magic_result('-2')
412 422 self.assertEquals(ar.msg_ids, [v.history[-2]])
413 423
414 424 def test_unicode_execute(self):
415 425 """test executing unicode strings"""
416 426 v = self.client[-1]
417 427 v.block=True
418 428 if sys.version_info[0] >= 3:
419 429 code="a='é'"
420 430 else:
421 431 code=u"a=u'é'"
422 432 v.execute(code)
423 433 self.assertEquals(v['a'], u'é')
424 434
425 435 def test_unicode_apply_result(self):
426 436 """test unicode apply results"""
427 437 v = self.client[-1]
428 438 r = v.apply_sync(lambda : u'é')
429 439 self.assertEquals(r, u'é')
430 440
431 441 def test_unicode_apply_arg(self):
432 442 """test passing unicode arguments to apply"""
433 443 v = self.client[-1]
434 444
435 445 @interactive
436 446 def check_unicode(a, check):
437 447 assert isinstance(a, unicode), "%r is not unicode"%a
438 448 assert isinstance(check, bytes), "%r is not bytes"%check
439 449 assert a.encode('utf8') == check, "%s != %s"%(a,check)
440 450
441 451 for s in [ u'é', u'ßø®∫',u'asdf' ]:
442 452 try:
443 453 v.apply_sync(check_unicode, s, s.encode('utf8'))
444 454 except error.RemoteError as e:
445 455 if e.ename == 'AssertionError':
446 456 self.fail(e.evalue)
447 457 else:
448 458 raise e
449
450
459
451 460
General Comments 0
You need to be logged in to leave comments. Login now