##// END OF EJS Templates
support iterators in view.map...
MinRK -
Show More
@@ -1,165 +1,171 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes used in scattering and gathering sequences.
3 """Classes used in scattering and gathering sequences.
4
4
5 Scattering consists of partitioning a sequence and sending the various
5 Scattering consists of partitioning a sequence and sending the various
6 pieces to individual nodes in a cluster.
6 pieces to individual nodes in a cluster.
7
7
8
8
9 Authors:
9 Authors:
10
10
11 * Brian Granger
11 * Brian Granger
12 * MinRK
12 * MinRK
13
13
14 """
14 """
15
15
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17 # Copyright (C) 2008-2011 The IPython Development Team
17 # Copyright (C) 2008-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-------------------------------------------------------------------------------
21 #-------------------------------------------------------------------------------
22
22
23 #-------------------------------------------------------------------------------
23 #-------------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-------------------------------------------------------------------------------
25 #-------------------------------------------------------------------------------
26
26
27 from __future__ import division
27 from __future__ import division
28
28
29 import types
29 import types
30 from itertools import islice
30
31
31 from IPython.utils.data import flatten as utils_flatten
32 from IPython.utils.data import flatten as utils_flatten
32
33
33 #-------------------------------------------------------------------------------
34 #-------------------------------------------------------------------------------
34 # Figure out which array packages are present and their array types
35 # Figure out which array packages are present and their array types
35 #-------------------------------------------------------------------------------
36 #-------------------------------------------------------------------------------
36
37
37 arrayModules = []
38 arrayModules = []
38 try:
39 try:
39 import Numeric
40 import Numeric
40 except ImportError:
41 except ImportError:
41 pass
42 pass
42 else:
43 else:
43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 try:
45 try:
45 import numpy
46 import numpy
46 except ImportError:
47 except ImportError:
47 pass
48 pass
48 else:
49 else:
49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 try:
51 try:
51 import numarray
52 import numarray
52 except ImportError:
53 except ImportError:
53 pass
54 pass
54 else:
55 else:
55 arrayModules.append({'module':numarray,
56 arrayModules.append({'module':numarray,
56 'type':numarray.numarraycore.NumArray})
57 'type':numarray.numarraycore.NumArray})
57
58
58 class Map:
59 class Map:
59 """A class for partitioning a sequence using a map."""
60 """A class for partitioning a sequence using a map."""
60
61
61 def getPartition(self, seq, p, q):
62 def getPartition(self, seq, p, q):
62 """Returns the pth partition of q partitions of seq."""
63 """Returns the pth partition of q partitions of seq."""
63
64
64 # Test for error conditions here
65 # Test for error conditions here
65 if p<0 or p>=q:
66 if p<0 or p>=q:
66 print "No partition exists."
67 print "No partition exists."
67 return
68 return
68
69
69 remainder = len(seq)%q
70 remainder = len(seq)%q
70 basesize = len(seq)//q
71 basesize = len(seq)//q
71 hi = []
72 hi = []
72 lo = []
73 lo = []
73 for n in range(q):
74 for n in range(q):
74 if n < remainder:
75 if n < remainder:
75 lo.append(n * (basesize + 1))
76 lo.append(n * (basesize + 1))
76 hi.append(lo[-1] + basesize + 1)
77 hi.append(lo[-1] + basesize + 1)
77 else:
78 else:
78 lo.append(n*basesize + remainder)
79 lo.append(n*basesize + remainder)
79 hi.append(lo[-1] + basesize)
80 hi.append(lo[-1] + basesize)
80
81
81
82 result = seq[lo[p]:hi[p]]
82 try:
83 result = seq[lo[p]:hi[p]]
84 except TypeError:
85 # some objects (iterators) can't be sliced,
86 # use islice:
87 result = list(islice(seq, lo[p], hi[p]))
88
83 return result
89 return result
84
90
85 def joinPartitions(self, listOfPartitions):
91 def joinPartitions(self, listOfPartitions):
86 return self.concatenate(listOfPartitions)
92 return self.concatenate(listOfPartitions)
87
93
88 def concatenate(self, listOfPartitions):
94 def concatenate(self, listOfPartitions):
89 testObject = listOfPartitions[0]
95 testObject = listOfPartitions[0]
90 # First see if we have a known array type
96 # First see if we have a known array type
91 for m in arrayModules:
97 for m in arrayModules:
92 #print m
98 #print m
93 if isinstance(testObject, m['type']):
99 if isinstance(testObject, m['type']):
94 return m['module'].concatenate(listOfPartitions)
100 return m['module'].concatenate(listOfPartitions)
95 # Next try for Python sequence types
101 # Next try for Python sequence types
96 if isinstance(testObject, (types.ListType, types.TupleType)):
102 if isinstance(testObject, (types.ListType, types.TupleType)):
97 return utils_flatten(listOfPartitions)
103 return utils_flatten(listOfPartitions)
98 # If we have scalars, just return listOfPartitions
104 # If we have scalars, just return listOfPartitions
99 return listOfPartitions
105 return listOfPartitions
100
106
101 class RoundRobinMap(Map):
107 class RoundRobinMap(Map):
102 """Partitions a sequence in a roun robin fashion.
108 """Partitions a sequence in a roun robin fashion.
103
109
104 This currently does not work!
110 This currently does not work!
105 """
111 """
106
112
107 def getPartition(self, seq, p, q):
113 def getPartition(self, seq, p, q):
108 # if not isinstance(seq,(list,tuple)):
114 # if not isinstance(seq,(list,tuple)):
109 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
115 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
110 return seq[p:len(seq):q]
116 return seq[p:len(seq):q]
111 #result = []
117 #result = []
112 #for i in range(p,len(seq),q):
118 #for i in range(p,len(seq),q):
113 # result.append(seq[i])
119 # result.append(seq[i])
114 #return result
120 #return result
115
121
116 def joinPartitions(self, listOfPartitions):
122 def joinPartitions(self, listOfPartitions):
117 testObject = listOfPartitions[0]
123 testObject = listOfPartitions[0]
118 # First see if we have a known array type
124 # First see if we have a known array type
119 for m in arrayModules:
125 for m in arrayModules:
120 #print m
126 #print m
121 if isinstance(testObject, m['type']):
127 if isinstance(testObject, m['type']):
122 return self.flatten_array(m['type'], listOfPartitions)
128 return self.flatten_array(m['type'], listOfPartitions)
123 if isinstance(testObject, (types.ListType, types.TupleType)):
129 if isinstance(testObject, (types.ListType, types.TupleType)):
124 return self.flatten_list(listOfPartitions)
130 return self.flatten_list(listOfPartitions)
125 return listOfPartitions
131 return listOfPartitions
126
132
127 def flatten_array(self, klass, listOfPartitions):
133 def flatten_array(self, klass, listOfPartitions):
128 test = listOfPartitions[0]
134 test = listOfPartitions[0]
129 shape = list(test.shape)
135 shape = list(test.shape)
130 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
136 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
131 A = klass(shape)
137 A = klass(shape)
132 N = shape[0]
138 N = shape[0]
133 q = len(listOfPartitions)
139 q = len(listOfPartitions)
134 for p,part in enumerate(listOfPartitions):
140 for p,part in enumerate(listOfPartitions):
135 A[p:N:q] = part
141 A[p:N:q] = part
136 return A
142 return A
137
143
138 def flatten_list(self, listOfPartitions):
144 def flatten_list(self, listOfPartitions):
139 flat = []
145 flat = []
140 for i in range(len(listOfPartitions[0])):
146 for i in range(len(listOfPartitions[0])):
141 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
147 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
142 return flat
148 return flat
143 #lengths = [len(x) for x in listOfPartitions]
149 #lengths = [len(x) for x in listOfPartitions]
144 #maxPartitionLength = len(listOfPartitions[0])
150 #maxPartitionLength = len(listOfPartitions[0])
145 #numberOfPartitions = len(listOfPartitions)
151 #numberOfPartitions = len(listOfPartitions)
146 #concat = self.concatenate(listOfPartitions)
152 #concat = self.concatenate(listOfPartitions)
147 #totalLength = len(concat)
153 #totalLength = len(concat)
148 #result = []
154 #result = []
149 #for i in range(maxPartitionLength):
155 #for i in range(maxPartitionLength):
150 # result.append(concat[i:totalLength:maxPartitionLength])
156 # result.append(concat[i:totalLength:maxPartitionLength])
151 # return self.concatenate(listOfPartitions)
157 # return self.concatenate(listOfPartitions)
152
158
153 def mappable(obj):
159 def mappable(obj):
154 """return whether an object is mappable or not."""
160 """return whether an object is mappable or not."""
155 if isinstance(obj, (tuple,list)):
161 if isinstance(obj, (tuple,list)):
156 return True
162 return True
157 for m in arrayModules:
163 for m in arrayModules:
158 if isinstance(obj,m['type']):
164 if isinstance(obj,m['type']):
159 return True
165 return True
160 return False
166 return False
161
167
162 dists = {'b':Map,'r':RoundRobinMap}
168 dists = {'b':Map,'r':RoundRobinMap}
163
169
164
170
165
171
@@ -1,167 +1,178 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test LoadBalancedView objects
2 """test LoadBalancedView objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from nose import SkipTest
23 from nose import SkipTest
24
24
25 from IPython import parallel as pmod
25 from IPython import parallel as pmod
26 from IPython.parallel import error
26 from IPython.parallel import error
27
27
28 from IPython.parallel.tests import add_engines
28 from IPython.parallel.tests import add_engines
29
29
30 from .clienttest import ClusterTestCase, crash, wait, skip_without
30 from .clienttest import ClusterTestCase, crash, wait, skip_without
31
31
32 def setup():
32 def setup():
33 add_engines(3)
33 add_engines(3)
34
34
35 class TestLoadBalancedView(ClusterTestCase):
35 class TestLoadBalancedView(ClusterTestCase):
36
36
37 def setUp(self):
37 def setUp(self):
38 ClusterTestCase.setUp(self)
38 ClusterTestCase.setUp(self)
39 self.view = self.client.load_balanced_view()
39 self.view = self.client.load_balanced_view()
40
40
41 def test_z_crash_task(self):
41 def test_z_crash_task(self):
42 """test graceful handling of engine death (balanced)"""
42 """test graceful handling of engine death (balanced)"""
43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
43 raise SkipTest("crash tests disabled, due to undesirable crash reports")
44 # self.add_engines(1)
44 # self.add_engines(1)
45 ar = self.view.apply_async(crash)
45 ar = self.view.apply_async(crash)
46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 eid = ar.engine_id
47 eid = ar.engine_id
48 tic = time.time()
48 tic = time.time()
49 while eid in self.client.ids and time.time()-tic < 5:
49 while eid in self.client.ids and time.time()-tic < 5:
50 time.sleep(.01)
50 time.sleep(.01)
51 self.client.spin()
51 self.client.spin()
52 self.assertFalse(eid in self.client.ids, "Engine should have died")
52 self.assertFalse(eid in self.client.ids, "Engine should have died")
53
53
54 def test_map(self):
54 def test_map(self):
55 def f(x):
55 def f(x):
56 return x**2
56 return x**2
57 data = range(16)
57 data = range(16)
58 r = self.view.map_sync(f, data)
58 r = self.view.map_sync(f, data)
59 self.assertEquals(r, map(f, data))
59 self.assertEquals(r, map(f, data))
60
60
61 def test_map_unordered(self):
61 def test_map_unordered(self):
62 def f(x):
62 def f(x):
63 return x**2
63 return x**2
64 def slow_f(x):
64 def slow_f(x):
65 import time
65 import time
66 time.sleep(0.05*x)
66 time.sleep(0.05*x)
67 return x**2
67 return x**2
68 data = range(16,0,-1)
68 data = range(16,0,-1)
69 reference = map(f, data)
69 reference = map(f, data)
70
70
71 amr = self.view.map_async(slow_f, data, ordered=False)
71 amr = self.view.map_async(slow_f, data, ordered=False)
72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
72 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
73 # check individual elements, retrieved as they come
73 # check individual elements, retrieved as they come
74 # list comprehension uses __iter__
74 # list comprehension uses __iter__
75 astheycame = [ r for r in amr ]
75 astheycame = [ r for r in amr ]
76 # Ensure that at least one result came out of order:
76 # Ensure that at least one result came out of order:
77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
77 self.assertNotEquals(astheycame, reference, "should not have preserved order")
78 self.assertEquals(sorted(astheycame, reverse=True), reference, "result corrupted")
78 self.assertEquals(sorted(astheycame, reverse=True), reference, "result corrupted")
79
79
80 def test_map_ordered(self):
80 def test_map_ordered(self):
81 def f(x):
81 def f(x):
82 return x**2
82 return x**2
83 def slow_f(x):
83 def slow_f(x):
84 import time
84 import time
85 time.sleep(0.05*x)
85 time.sleep(0.05*x)
86 return x**2
86 return x**2
87 data = range(16,0,-1)
87 data = range(16,0,-1)
88 reference = map(f, data)
88 reference = map(f, data)
89
89
90 amr = self.view.map_async(slow_f, data)
90 amr = self.view.map_async(slow_f, data)
91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
91 self.assertTrue(isinstance(amr, pmod.AsyncMapResult))
92 # check individual elements, retrieved as they come
92 # check individual elements, retrieved as they come
93 # list(amr) uses __iter__
93 # list(amr) uses __iter__
94 astheycame = list(amr)
94 astheycame = list(amr)
95 # Ensure that results came in order
95 # Ensure that results came in order
96 self.assertEquals(astheycame, reference)
96 self.assertEquals(astheycame, reference)
97 self.assertEquals(amr.result, reference)
97 self.assertEquals(amr.result, reference)
98
99 def test_map_iterable(self):
100 """test map on iterables (balanced)"""
101 view = self.view
102 # 101 is prime, so it won't be evenly distributed
103 arr = range(101)
104 # so that it will be an iterator, even in Python 3
105 it = iter(arr)
106 r = view.map_sync(lambda x:x, arr)
107 self.assertEquals(r, list(arr))
108
98
109
99 def test_abort(self):
110 def test_abort(self):
100 view = self.view
111 view = self.view
101 ar = self.client[:].apply_async(time.sleep, .5)
112 ar = self.client[:].apply_async(time.sleep, .5)
102 ar = self.client[:].apply_async(time.sleep, .5)
113 ar = self.client[:].apply_async(time.sleep, .5)
103 time.sleep(0.2)
114 time.sleep(0.2)
104 ar2 = view.apply_async(lambda : 2)
115 ar2 = view.apply_async(lambda : 2)
105 ar3 = view.apply_async(lambda : 3)
116 ar3 = view.apply_async(lambda : 3)
106 view.abort(ar2)
117 view.abort(ar2)
107 view.abort(ar3.msg_ids)
118 view.abort(ar3.msg_ids)
108 self.assertRaises(error.TaskAborted, ar2.get)
119 self.assertRaises(error.TaskAborted, ar2.get)
109 self.assertRaises(error.TaskAborted, ar3.get)
120 self.assertRaises(error.TaskAborted, ar3.get)
110
121
111 def test_retries(self):
122 def test_retries(self):
112 add_engines(3)
123 add_engines(3)
113 view = self.view
124 view = self.view
114 view.timeout = 1 # prevent hang if this doesn't behave
125 view.timeout = 1 # prevent hang if this doesn't behave
115 def fail():
126 def fail():
116 assert False
127 assert False
117 for r in range(len(self.client)-1):
128 for r in range(len(self.client)-1):
118 with view.temp_flags(retries=r):
129 with view.temp_flags(retries=r):
119 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
130 self.assertRaisesRemote(AssertionError, view.apply_sync, fail)
120
131
121 with view.temp_flags(retries=len(self.client), timeout=0.25):
132 with view.temp_flags(retries=len(self.client), timeout=0.25):
122 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
133 self.assertRaisesRemote(error.TaskTimeout, view.apply_sync, fail)
123
134
124 def test_invalid_dependency(self):
135 def test_invalid_dependency(self):
125 view = self.view
136 view = self.view
126 with view.temp_flags(after='12345'):
137 with view.temp_flags(after='12345'):
127 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
138 self.assertRaisesRemote(error.InvalidDependency, view.apply_sync, lambda : 1)
128
139
129 def test_impossible_dependency(self):
140 def test_impossible_dependency(self):
130 if len(self.client) < 2:
141 if len(self.client) < 2:
131 add_engines(2)
142 add_engines(2)
132 view = self.client.load_balanced_view()
143 view = self.client.load_balanced_view()
133 ar1 = view.apply_async(lambda : 1)
144 ar1 = view.apply_async(lambda : 1)
134 ar1.get()
145 ar1.get()
135 e1 = ar1.engine_id
146 e1 = ar1.engine_id
136 e2 = e1
147 e2 = e1
137 while e2 == e1:
148 while e2 == e1:
138 ar2 = view.apply_async(lambda : 1)
149 ar2 = view.apply_async(lambda : 1)
139 ar2.get()
150 ar2.get()
140 e2 = ar2.engine_id
151 e2 = ar2.engine_id
141
152
142 with view.temp_flags(follow=[ar1, ar2]):
153 with view.temp_flags(follow=[ar1, ar2]):
143 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
154 self.assertRaisesRemote(error.ImpossibleDependency, view.apply_sync, lambda : 1)
144
155
145
156
146 def test_follow(self):
157 def test_follow(self):
147 ar = self.view.apply_async(lambda : 1)
158 ar = self.view.apply_async(lambda : 1)
148 ar.get()
159 ar.get()
149 ars = []
160 ars = []
150 first_id = ar.engine_id
161 first_id = ar.engine_id
151
162
152 self.view.follow = ar
163 self.view.follow = ar
153 for i in range(5):
164 for i in range(5):
154 ars.append(self.view.apply_async(lambda : 1))
165 ars.append(self.view.apply_async(lambda : 1))
155 self.view.wait(ars)
166 self.view.wait(ars)
156 for ar in ars:
167 for ar in ars:
157 self.assertEquals(ar.engine_id, first_id)
168 self.assertEquals(ar.engine_id, first_id)
158
169
159 def test_after(self):
170 def test_after(self):
160 view = self.view
171 view = self.view
161 ar = view.apply_async(time.sleep, 0.5)
172 ar = view.apply_async(time.sleep, 0.5)
162 with view.temp_flags(after=ar):
173 with view.temp_flags(after=ar):
163 ar2 = view.apply_async(lambda : 1)
174 ar2 = view.apply_async(lambda : 1)
164
175
165 ar.wait()
176 ar.wait()
166 ar2.wait()
177 ar2.wait()
167 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
178 self.assertTrue(ar2.started >= ar.completed, "%s not >= %s"%(ar.started, ar.completed))
@@ -1,451 +1,460 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython import parallel as pmod
27 from IPython import parallel as pmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
29 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
30 from IPython.parallel import DirectView
30 from IPython.parallel import DirectView
31 from IPython.parallel.util import interactive
31 from IPython.parallel.util import interactive
32
32
33 from IPython.parallel.tests import add_engines
33 from IPython.parallel.tests import add_engines
34
34
35 from .clienttest import ClusterTestCase, crash, wait, skip_without
35 from .clienttest import ClusterTestCase, crash, wait, skip_without
36
36
37 def setup():
37 def setup():
38 add_engines(3)
38 add_engines(3)
39
39
40 class TestView(ClusterTestCase):
40 class TestView(ClusterTestCase):
41
41
42 def test_z_crash_mux(self):
42 def test_z_crash_mux(self):
43 """test graceful handling of engine death (direct)"""
43 """test graceful handling of engine death (direct)"""
44 raise SkipTest("crash tests disabled, due to undesirable crash reports")
44 raise SkipTest("crash tests disabled, due to undesirable crash reports")
45 # self.add_engines(1)
45 # self.add_engines(1)
46 eid = self.client.ids[-1]
46 eid = self.client.ids[-1]
47 ar = self.client[eid].apply_async(crash)
47 ar = self.client[eid].apply_async(crash)
48 self.assertRaisesRemote(error.EngineError, ar.get, 10)
48 self.assertRaisesRemote(error.EngineError, ar.get, 10)
49 eid = ar.engine_id
49 eid = ar.engine_id
50 tic = time.time()
50 tic = time.time()
51 while eid in self.client.ids and time.time()-tic < 5:
51 while eid in self.client.ids and time.time()-tic < 5:
52 time.sleep(.01)
52 time.sleep(.01)
53 self.client.spin()
53 self.client.spin()
54 self.assertFalse(eid in self.client.ids, "Engine should have died")
54 self.assertFalse(eid in self.client.ids, "Engine should have died")
55
55
56 def test_push_pull(self):
56 def test_push_pull(self):
57 """test pushing and pulling"""
57 """test pushing and pulling"""
58 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
58 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
59 t = self.client.ids[-1]
59 t = self.client.ids[-1]
60 v = self.client[t]
60 v = self.client[t]
61 push = v.push
61 push = v.push
62 pull = v.pull
62 pull = v.pull
63 v.block=True
63 v.block=True
64 nengines = len(self.client)
64 nengines = len(self.client)
65 push({'data':data})
65 push({'data':data})
66 d = pull('data')
66 d = pull('data')
67 self.assertEquals(d, data)
67 self.assertEquals(d, data)
68 self.client[:].push({'data':data})
68 self.client[:].push({'data':data})
69 d = self.client[:].pull('data', block=True)
69 d = self.client[:].pull('data', block=True)
70 self.assertEquals(d, nengines*[data])
70 self.assertEquals(d, nengines*[data])
71 ar = push({'data':data}, block=False)
71 ar = push({'data':data}, block=False)
72 self.assertTrue(isinstance(ar, AsyncResult))
72 self.assertTrue(isinstance(ar, AsyncResult))
73 r = ar.get()
73 r = ar.get()
74 ar = self.client[:].pull('data', block=False)
74 ar = self.client[:].pull('data', block=False)
75 self.assertTrue(isinstance(ar, AsyncResult))
75 self.assertTrue(isinstance(ar, AsyncResult))
76 r = ar.get()
76 r = ar.get()
77 self.assertEquals(r, nengines*[data])
77 self.assertEquals(r, nengines*[data])
78 self.client[:].push(dict(a=10,b=20))
78 self.client[:].push(dict(a=10,b=20))
79 r = self.client[:].pull(('a','b'), block=True)
79 r = self.client[:].pull(('a','b'), block=True)
80 self.assertEquals(r, nengines*[[10,20]])
80 self.assertEquals(r, nengines*[[10,20]])
81
81
82 def test_push_pull_function(self):
82 def test_push_pull_function(self):
83 "test pushing and pulling functions"
83 "test pushing and pulling functions"
84 def testf(x):
84 def testf(x):
85 return 2.0*x
85 return 2.0*x
86
86
87 t = self.client.ids[-1]
87 t = self.client.ids[-1]
88 v = self.client[t]
88 v = self.client[t]
89 v.block=True
89 v.block=True
90 push = v.push
90 push = v.push
91 pull = v.pull
91 pull = v.pull
92 execute = v.execute
92 execute = v.execute
93 push({'testf':testf})
93 push({'testf':testf})
94 r = pull('testf')
94 r = pull('testf')
95 self.assertEqual(r(1.0), testf(1.0))
95 self.assertEqual(r(1.0), testf(1.0))
96 execute('r = testf(10)')
96 execute('r = testf(10)')
97 r = pull('r')
97 r = pull('r')
98 self.assertEquals(r, testf(10))
98 self.assertEquals(r, testf(10))
99 ar = self.client[:].push({'testf':testf}, block=False)
99 ar = self.client[:].push({'testf':testf}, block=False)
100 ar.get()
100 ar.get()
101 ar = self.client[:].pull('testf', block=False)
101 ar = self.client[:].pull('testf', block=False)
102 rlist = ar.get()
102 rlist = ar.get()
103 for r in rlist:
103 for r in rlist:
104 self.assertEqual(r(1.0), testf(1.0))
104 self.assertEqual(r(1.0), testf(1.0))
105 execute("def g(x): return x*x")
105 execute("def g(x): return x*x")
106 r = pull(('testf','g'))
106 r = pull(('testf','g'))
107 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
107 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
108
108
109 def test_push_function_globals(self):
109 def test_push_function_globals(self):
110 """test that pushed functions have access to globals"""
110 """test that pushed functions have access to globals"""
111 @interactive
111 @interactive
112 def geta():
112 def geta():
113 return a
113 return a
114 # self.add_engines(1)
114 # self.add_engines(1)
115 v = self.client[-1]
115 v = self.client[-1]
116 v.block=True
116 v.block=True
117 v['f'] = geta
117 v['f'] = geta
118 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
118 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
119 v.execute('a=5')
119 v.execute('a=5')
120 v.execute('b=f()')
120 v.execute('b=f()')
121 self.assertEquals(v['b'], 5)
121 self.assertEquals(v['b'], 5)
122
122
123 def test_push_function_defaults(self):
123 def test_push_function_defaults(self):
124 """test that pushed functions preserve default args"""
124 """test that pushed functions preserve default args"""
125 def echo(a=10):
125 def echo(a=10):
126 return a
126 return a
127 v = self.client[-1]
127 v = self.client[-1]
128 v.block=True
128 v.block=True
129 v['f'] = echo
129 v['f'] = echo
130 v.execute('b=f()')
130 v.execute('b=f()')
131 self.assertEquals(v['b'], 10)
131 self.assertEquals(v['b'], 10)
132
132
133 def test_get_result(self):
133 def test_get_result(self):
134 """test getting results from the Hub."""
134 """test getting results from the Hub."""
135 c = pmod.Client(profile='iptest')
135 c = pmod.Client(profile='iptest')
136 # self.add_engines(1)
136 # self.add_engines(1)
137 t = c.ids[-1]
137 t = c.ids[-1]
138 v = c[t]
138 v = c[t]
139 v2 = self.client[t]
139 v2 = self.client[t]
140 ar = v.apply_async(wait, 1)
140 ar = v.apply_async(wait, 1)
141 # give the monitor time to notice the message
141 # give the monitor time to notice the message
142 time.sleep(.25)
142 time.sleep(.25)
143 ahr = v2.get_result(ar.msg_ids)
143 ahr = v2.get_result(ar.msg_ids)
144 self.assertTrue(isinstance(ahr, AsyncHubResult))
144 self.assertTrue(isinstance(ahr, AsyncHubResult))
145 self.assertEquals(ahr.get(), ar.get())
145 self.assertEquals(ahr.get(), ar.get())
146 ar2 = v2.get_result(ar.msg_ids)
146 ar2 = v2.get_result(ar.msg_ids)
147 self.assertFalse(isinstance(ar2, AsyncHubResult))
147 self.assertFalse(isinstance(ar2, AsyncHubResult))
148 c.spin()
148 c.spin()
149 c.close()
149 c.close()
150
150
151 def test_run_newline(self):
151 def test_run_newline(self):
152 """test that run appends newline to files"""
152 """test that run appends newline to files"""
153 tmpfile = mktemp()
153 tmpfile = mktemp()
154 with open(tmpfile, 'w') as f:
154 with open(tmpfile, 'w') as f:
155 f.write("""def g():
155 f.write("""def g():
156 return 5
156 return 5
157 """)
157 """)
158 v = self.client[-1]
158 v = self.client[-1]
159 v.run(tmpfile, block=True)
159 v.run(tmpfile, block=True)
160 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
160 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
161
161
162 def test_apply_tracked(self):
162 def test_apply_tracked(self):
163 """test tracking for apply"""
163 """test tracking for apply"""
164 # self.add_engines(1)
164 # self.add_engines(1)
165 t = self.client.ids[-1]
165 t = self.client.ids[-1]
166 v = self.client[t]
166 v = self.client[t]
167 v.block=False
167 v.block=False
168 def echo(n=1024*1024, **kwargs):
168 def echo(n=1024*1024, **kwargs):
169 with v.temp_flags(**kwargs):
169 with v.temp_flags(**kwargs):
170 return v.apply(lambda x: x, 'x'*n)
170 return v.apply(lambda x: x, 'x'*n)
171 ar = echo(1, track=False)
171 ar = echo(1, track=False)
172 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
172 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
173 self.assertTrue(ar.sent)
173 self.assertTrue(ar.sent)
174 ar = echo(track=True)
174 ar = echo(track=True)
175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
176 self.assertEquals(ar.sent, ar._tracker.done)
176 self.assertEquals(ar.sent, ar._tracker.done)
177 ar._tracker.wait()
177 ar._tracker.wait()
178 self.assertTrue(ar.sent)
178 self.assertTrue(ar.sent)
179
179
180 def test_push_tracked(self):
180 def test_push_tracked(self):
181 t = self.client.ids[-1]
181 t = self.client.ids[-1]
182 ns = dict(x='x'*1024*1024)
182 ns = dict(x='x'*1024*1024)
183 v = self.client[t]
183 v = self.client[t]
184 ar = v.push(ns, block=False, track=False)
184 ar = v.push(ns, block=False, track=False)
185 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
185 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
186 self.assertTrue(ar.sent)
186 self.assertTrue(ar.sent)
187
187
188 ar = v.push(ns, block=False, track=True)
188 ar = v.push(ns, block=False, track=True)
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
189 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
190 ar._tracker.wait()
190 ar._tracker.wait()
191 self.assertEquals(ar.sent, ar._tracker.done)
191 self.assertEquals(ar.sent, ar._tracker.done)
192 self.assertTrue(ar.sent)
192 self.assertTrue(ar.sent)
193 ar.get()
193 ar.get()
194
194
195 def test_scatter_tracked(self):
195 def test_scatter_tracked(self):
196 t = self.client.ids
196 t = self.client.ids
197 x='x'*1024*1024
197 x='x'*1024*1024
198 ar = self.client[t].scatter('x', x, block=False, track=False)
198 ar = self.client[t].scatter('x', x, block=False, track=False)
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
199 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
200 self.assertTrue(ar.sent)
200 self.assertTrue(ar.sent)
201
201
202 ar = self.client[t].scatter('x', x, block=False, track=True)
202 ar = self.client[t].scatter('x', x, block=False, track=True)
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
203 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
204 self.assertEquals(ar.sent, ar._tracker.done)
204 self.assertEquals(ar.sent, ar._tracker.done)
205 ar._tracker.wait()
205 ar._tracker.wait()
206 self.assertTrue(ar.sent)
206 self.assertTrue(ar.sent)
207 ar.get()
207 ar.get()
208
208
209 def test_remote_reference(self):
209 def test_remote_reference(self):
210 v = self.client[-1]
210 v = self.client[-1]
211 v['a'] = 123
211 v['a'] = 123
212 ra = pmod.Reference('a')
212 ra = pmod.Reference('a')
213 b = v.apply_sync(lambda x: x, ra)
213 b = v.apply_sync(lambda x: x, ra)
214 self.assertEquals(b, 123)
214 self.assertEquals(b, 123)
215
215
216
216
217 def test_scatter_gather(self):
217 def test_scatter_gather(self):
218 view = self.client[:]
218 view = self.client[:]
219 seq1 = range(16)
219 seq1 = range(16)
220 view.scatter('a', seq1)
220 view.scatter('a', seq1)
221 seq2 = view.gather('a', block=True)
221 seq2 = view.gather('a', block=True)
222 self.assertEquals(seq2, seq1)
222 self.assertEquals(seq2, seq1)
223 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
223 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
224
224
225 @skip_without('numpy')
225 @skip_without('numpy')
226 def test_scatter_gather_numpy(self):
226 def test_scatter_gather_numpy(self):
227 import numpy
227 import numpy
228 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
228 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
229 view = self.client[:]
229 view = self.client[:]
230 a = numpy.arange(64)
230 a = numpy.arange(64)
231 view.scatter('a', a)
231 view.scatter('a', a)
232 b = view.gather('a', block=True)
232 b = view.gather('a', block=True)
233 assert_array_equal(b, a)
233 assert_array_equal(b, a)
234
234
235 def test_map(self):
235 def test_map(self):
236 view = self.client[:]
236 view = self.client[:]
237 def f(x):
237 def f(x):
238 return x**2
238 return x**2
239 data = range(16)
239 data = range(16)
240 r = view.map_sync(f, data)
240 r = view.map_sync(f, data)
241 self.assertEquals(r, map(f, data))
241 self.assertEquals(r, map(f, data))
242
242
243 def test_map_iterable(self):
244 """test map on iterables (direct)"""
245 view = self.client[:]
246 # 101 is prime, so it won't be evenly distributed
247 arr = range(101)
248 # ensure it will be an iterator, even in Python 3
249 it = iter(arr)
250 r = view.map_sync(lambda x:x, arr)
251 self.assertEquals(r, list(arr))
252
243 def test_scatterGatherNonblocking(self):
253 def test_scatterGatherNonblocking(self):
244 data = range(16)
254 data = range(16)
245 view = self.client[:]
255 view = self.client[:]
246 view.scatter('a', data, block=False)
256 view.scatter('a', data, block=False)
247 ar = view.gather('a', block=False)
257 ar = view.gather('a', block=False)
248 self.assertEquals(ar.get(), data)
258 self.assertEquals(ar.get(), data)
249
259
250 @skip_without('numpy')
260 @skip_without('numpy')
251 def test_scatter_gather_numpy_nonblocking(self):
261 def test_scatter_gather_numpy_nonblocking(self):
252 import numpy
262 import numpy
253 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
263 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
254 a = numpy.arange(64)
264 a = numpy.arange(64)
255 view = self.client[:]
265 view = self.client[:]
256 ar = view.scatter('a', a, block=False)
266 ar = view.scatter('a', a, block=False)
257 self.assertTrue(isinstance(ar, AsyncResult))
267 self.assertTrue(isinstance(ar, AsyncResult))
258 amr = view.gather('a', block=False)
268 amr = view.gather('a', block=False)
259 self.assertTrue(isinstance(amr, AsyncMapResult))
269 self.assertTrue(isinstance(amr, AsyncMapResult))
260 assert_array_equal(amr.get(), a)
270 assert_array_equal(amr.get(), a)
261
271
262 def test_execute(self):
272 def test_execute(self):
263 view = self.client[:]
273 view = self.client[:]
264 # self.client.debug=True
274 # self.client.debug=True
265 execute = view.execute
275 execute = view.execute
266 ar = execute('c=30', block=False)
276 ar = execute('c=30', block=False)
267 self.assertTrue(isinstance(ar, AsyncResult))
277 self.assertTrue(isinstance(ar, AsyncResult))
268 ar = execute('d=[0,1,2]', block=False)
278 ar = execute('d=[0,1,2]', block=False)
269 self.client.wait(ar, 1)
279 self.client.wait(ar, 1)
270 self.assertEquals(len(ar.get()), len(self.client))
280 self.assertEquals(len(ar.get()), len(self.client))
271 for c in view['c']:
281 for c in view['c']:
272 self.assertEquals(c, 30)
282 self.assertEquals(c, 30)
273
283
274 def test_abort(self):
284 def test_abort(self):
275 view = self.client[-1]
285 view = self.client[-1]
276 ar = view.execute('import time; time.sleep(1)', block=False)
286 ar = view.execute('import time; time.sleep(1)', block=False)
277 ar2 = view.apply_async(lambda : 2)
287 ar2 = view.apply_async(lambda : 2)
278 ar3 = view.apply_async(lambda : 3)
288 ar3 = view.apply_async(lambda : 3)
279 view.abort(ar2)
289 view.abort(ar2)
280 view.abort(ar3.msg_ids)
290 view.abort(ar3.msg_ids)
281 self.assertRaises(error.TaskAborted, ar2.get)
291 self.assertRaises(error.TaskAborted, ar2.get)
282 self.assertRaises(error.TaskAborted, ar3.get)
292 self.assertRaises(error.TaskAborted, ar3.get)
283
293
284 def test_temp_flags(self):
294 def test_temp_flags(self):
285 view = self.client[-1]
295 view = self.client[-1]
286 view.block=True
296 view.block=True
287 with view.temp_flags(block=False):
297 with view.temp_flags(block=False):
288 self.assertFalse(view.block)
298 self.assertFalse(view.block)
289 self.assertTrue(view.block)
299 self.assertTrue(view.block)
290
300
291 def test_importer(self):
301 def test_importer(self):
292 view = self.client[-1]
302 view = self.client[-1]
293 view.clear(block=True)
303 view.clear(block=True)
294 with view.importer:
304 with view.importer:
295 import re
305 import re
296
306
297 @interactive
307 @interactive
298 def findall(pat, s):
308 def findall(pat, s):
299 # this globals() step isn't necessary in real code
309 # this globals() step isn't necessary in real code
300 # only to prevent a closure in the test
310 # only to prevent a closure in the test
301 re = globals()['re']
311 re = globals()['re']
302 return re.findall(pat, s)
312 return re.findall(pat, s)
303
313
304 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
314 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
305
315
306 # parallel magic tests
316 # parallel magic tests
307
317
308 def test_magic_px_blocking(self):
318 def test_magic_px_blocking(self):
309 ip = get_ipython()
319 ip = get_ipython()
310 v = self.client[-1]
320 v = self.client[-1]
311 v.activate()
321 v.activate()
312 v.block=True
322 v.block=True
313
323
314 ip.magic_px('a=5')
324 ip.magic_px('a=5')
315 self.assertEquals(v['a'], 5)
325 self.assertEquals(v['a'], 5)
316 ip.magic_px('a=10')
326 ip.magic_px('a=10')
317 self.assertEquals(v['a'], 10)
327 self.assertEquals(v['a'], 10)
318 sio = StringIO()
328 sio = StringIO()
319 savestdout = sys.stdout
329 savestdout = sys.stdout
320 sys.stdout = sio
330 sys.stdout = sio
321 # just 'print a' worst ~99% of the time, but this ensures that
331 # just 'print a' worst ~99% of the time, but this ensures that
322 # the stdout message has arrived when the result is finished:
332 # the stdout message has arrived when the result is finished:
323 ip.magic_px('import sys,time;print a; sys.stdout.flush();time.sleep(0.2)')
333 ip.magic_px('import sys,time;print a; sys.stdout.flush();time.sleep(0.2)')
324 sys.stdout = savestdout
334 sys.stdout = savestdout
325 buf = sio.getvalue()
335 buf = sio.getvalue()
326 self.assertTrue('[stdout:' in buf, buf)
336 self.assertTrue('[stdout:' in buf, buf)
327 self.assertTrue(buf.rstrip().endswith('10'))
337 self.assertTrue(buf.rstrip().endswith('10'))
328 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
338 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
329
339
330 def test_magic_px_nonblocking(self):
340 def test_magic_px_nonblocking(self):
331 ip = get_ipython()
341 ip = get_ipython()
332 v = self.client[-1]
342 v = self.client[-1]
333 v.activate()
343 v.activate()
334 v.block=False
344 v.block=False
335
345
336 ip.magic_px('a=5')
346 ip.magic_px('a=5')
337 self.assertEquals(v['a'], 5)
347 self.assertEquals(v['a'], 5)
338 ip.magic_px('a=10')
348 ip.magic_px('a=10')
339 self.assertEquals(v['a'], 10)
349 self.assertEquals(v['a'], 10)
340 sio = StringIO()
350 sio = StringIO()
341 savestdout = sys.stdout
351 savestdout = sys.stdout
342 sys.stdout = sio
352 sys.stdout = sio
343 ip.magic_px('print a')
353 ip.magic_px('print a')
344 sys.stdout = savestdout
354 sys.stdout = savestdout
345 buf = sio.getvalue()
355 buf = sio.getvalue()
346 self.assertFalse('[stdout:%i]'%v.targets in buf)
356 self.assertFalse('[stdout:%i]'%v.targets in buf)
347 ip.magic_px('1/0')
357 ip.magic_px('1/0')
348 ar = v.get_result(-1)
358 ar = v.get_result(-1)
349 self.assertRaisesRemote(ZeroDivisionError, ar.get)
359 self.assertRaisesRemote(ZeroDivisionError, ar.get)
350
360
351 def test_magic_autopx_blocking(self):
361 def test_magic_autopx_blocking(self):
352 ip = get_ipython()
362 ip = get_ipython()
353 v = self.client[-1]
363 v = self.client[-1]
354 v.activate()
364 v.activate()
355 v.block=True
365 v.block=True
356
366
357 sio = StringIO()
367 sio = StringIO()
358 savestdout = sys.stdout
368 savestdout = sys.stdout
359 sys.stdout = sio
369 sys.stdout = sio
360 ip.magic_autopx()
370 ip.magic_autopx()
361 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
371 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
362 ip.run_cell('print b')
372 ip.run_cell('print b')
363 ip.run_cell("b/c")
373 ip.run_cell("b/c")
364 ip.run_code(compile('b*=2', '', 'single'))
374 ip.run_code(compile('b*=2', '', 'single'))
365 ip.magic_autopx()
375 ip.magic_autopx()
366 sys.stdout = savestdout
376 sys.stdout = savestdout
367 output = sio.getvalue().strip()
377 output = sio.getvalue().strip()
368 self.assertTrue(output.startswith('%autopx enabled'))
378 self.assertTrue(output.startswith('%autopx enabled'))
369 self.assertTrue(output.endswith('%autopx disabled'))
379 self.assertTrue(output.endswith('%autopx disabled'))
370 self.assertTrue('RemoteError: ZeroDivisionError' in output)
380 self.assertTrue('RemoteError: ZeroDivisionError' in output)
371 ar = v.get_result(-2)
381 ar = v.get_result(-2)
372 self.assertEquals(v['a'], 5)
382 self.assertEquals(v['a'], 5)
373 self.assertEquals(v['b'], 20)
383 self.assertEquals(v['b'], 20)
374 self.assertRaisesRemote(ZeroDivisionError, ar.get)
384 self.assertRaisesRemote(ZeroDivisionError, ar.get)
375
385
376 def test_magic_autopx_nonblocking(self):
386 def test_magic_autopx_nonblocking(self):
377 ip = get_ipython()
387 ip = get_ipython()
378 v = self.client[-1]
388 v = self.client[-1]
379 v.activate()
389 v.activate()
380 v.block=False
390 v.block=False
381
391
382 sio = StringIO()
392 sio = StringIO()
383 savestdout = sys.stdout
393 savestdout = sys.stdout
384 sys.stdout = sio
394 sys.stdout = sio
385 ip.magic_autopx()
395 ip.magic_autopx()
386 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
396 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
387 ip.run_cell('print b')
397 ip.run_cell('print b')
388 ip.run_cell("b/c")
398 ip.run_cell("b/c")
389 ip.run_code(compile('b*=2', '', 'single'))
399 ip.run_code(compile('b*=2', '', 'single'))
390 ip.magic_autopx()
400 ip.magic_autopx()
391 sys.stdout = savestdout
401 sys.stdout = savestdout
392 output = sio.getvalue().strip()
402 output = sio.getvalue().strip()
393 self.assertTrue(output.startswith('%autopx enabled'))
403 self.assertTrue(output.startswith('%autopx enabled'))
394 self.assertTrue(output.endswith('%autopx disabled'))
404 self.assertTrue(output.endswith('%autopx disabled'))
395 self.assertFalse('ZeroDivisionError' in output)
405 self.assertFalse('ZeroDivisionError' in output)
396 ar = v.get_result(-2)
406 ar = v.get_result(-2)
397 self.assertEquals(v['a'], 5)
407 self.assertEquals(v['a'], 5)
398 self.assertEquals(v['b'], 20)
408 self.assertEquals(v['b'], 20)
399 self.assertRaisesRemote(ZeroDivisionError, ar.get)
409 self.assertRaisesRemote(ZeroDivisionError, ar.get)
400
410
401 def test_magic_result(self):
411 def test_magic_result(self):
402 ip = get_ipython()
412 ip = get_ipython()
403 v = self.client[-1]
413 v = self.client[-1]
404 v.activate()
414 v.activate()
405 v['a'] = 111
415 v['a'] = 111
406 ra = v['a']
416 ra = v['a']
407
417
408 ar = ip.magic_result()
418 ar = ip.magic_result()
409 self.assertEquals(ar.msg_ids, [v.history[-1]])
419 self.assertEquals(ar.msg_ids, [v.history[-1]])
410 self.assertEquals(ar.get(), 111)
420 self.assertEquals(ar.get(), 111)
411 ar = ip.magic_result('-2')
421 ar = ip.magic_result('-2')
412 self.assertEquals(ar.msg_ids, [v.history[-2]])
422 self.assertEquals(ar.msg_ids, [v.history[-2]])
413
423
414 def test_unicode_execute(self):
424 def test_unicode_execute(self):
415 """test executing unicode strings"""
425 """test executing unicode strings"""
416 v = self.client[-1]
426 v = self.client[-1]
417 v.block=True
427 v.block=True
418 if sys.version_info[0] >= 3:
428 if sys.version_info[0] >= 3:
419 code="a='é'"
429 code="a='é'"
420 else:
430 else:
421 code=u"a=u'é'"
431 code=u"a=u'é'"
422 v.execute(code)
432 v.execute(code)
423 self.assertEquals(v['a'], u'é')
433 self.assertEquals(v['a'], u'é')
424
434
425 def test_unicode_apply_result(self):
435 def test_unicode_apply_result(self):
426 """test unicode apply results"""
436 """test unicode apply results"""
427 v = self.client[-1]
437 v = self.client[-1]
428 r = v.apply_sync(lambda : u'é')
438 r = v.apply_sync(lambda : u'é')
429 self.assertEquals(r, u'é')
439 self.assertEquals(r, u'é')
430
440
431 def test_unicode_apply_arg(self):
441 def test_unicode_apply_arg(self):
432 """test passing unicode arguments to apply"""
442 """test passing unicode arguments to apply"""
433 v = self.client[-1]
443 v = self.client[-1]
434
444
435 @interactive
445 @interactive
436 def check_unicode(a, check):
446 def check_unicode(a, check):
437 assert isinstance(a, unicode), "%r is not unicode"%a
447 assert isinstance(a, unicode), "%r is not unicode"%a
438 assert isinstance(check, bytes), "%r is not bytes"%check
448 assert isinstance(check, bytes), "%r is not bytes"%check
439 assert a.encode('utf8') == check, "%s != %s"%(a,check)
449 assert a.encode('utf8') == check, "%s != %s"%(a,check)
440
450
441 for s in [ u'é', u'ßø®∫',u'asdf' ]:
451 for s in [ u'é', u'ßø®∫',u'asdf' ]:
442 try:
452 try:
443 v.apply_sync(check_unicode, s, s.encode('utf8'))
453 v.apply_sync(check_unicode, s, s.encode('utf8'))
444 except error.RemoteError as e:
454 except error.RemoteError as e:
445 if e.ename == 'AssertionError':
455 if e.ename == 'AssertionError':
446 self.fail(e.evalue)
456 self.fail(e.evalue)
447 else:
457 else:
448 raise e
458 raise e
449
459
450
451
460
General Comments 0
You need to be logged in to leave comments. Login now