##// END OF EJS Templates
move capture_output util from parallel tests to utils.io
MinRK -
Show More
@@ -1,213 +1,183 b''
1 1 """base class for parallel client tests
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 import sys
16 16 import tempfile
17 17 import time
18 18 from StringIO import StringIO
19 19
20 20 from nose import SkipTest
21 21
22 22 import zmq
23 23 from zmq.tests import BaseZMQTestCase
24 24
25 25 from IPython.external.decorator import decorator
26 26
27 27 from IPython.parallel import error
28 28 from IPython.parallel import Client
29 29
30 30 from IPython.parallel.tests import launchers, add_engines
31 31
32 32 # simple tasks for use in apply tests
33 33
34 34 def segfault():
35 35 """this will segfault"""
36 36 import ctypes
37 37 ctypes.memset(-1,0,1)
38 38
39 39 def crash():
40 40 """from stdlib crashers in the test suite"""
41 41 import types
42 42 if sys.platform.startswith('win'):
43 43 import ctypes
44 44 ctypes.windll.kernel32.SetErrorMode(0x0002);
45 45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 46 if sys.version_info[0] >= 3:
47 47 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 48 args.insert(1, 0)
49 49
50 50 co = types.CodeType(*args)
51 51 exec(co)
52 52
53 53 def wait(n):
54 54 """sleep for a time"""
55 55 import time
56 56 time.sleep(n)
57 57 return n
58 58
59 59 def raiser(eclass):
60 60 """raise an exception"""
61 61 raise eclass()
62 62
63 63 def generate_output():
64 64 """function for testing output
65 65
66 66 publishes two outputs of each type, and returns
67 67 a rich displayable object.
68 68 """
69 69
70 70 import sys
71 71 from IPython.core.display import display, HTML, Math
72 72
73 73 print "stdout"
74 74 print >> sys.stderr, "stderr"
75 75
76 76 display(HTML("<b>HTML</b>"))
77 77
78 78 print "stdout2"
79 79 print >> sys.stderr, "stderr2"
80 80
81 81 display(Math(r"\alpha=\beta"))
82 82
83 83 return Math("42")
84 84
85 85 # test decorator for skipping tests when libraries are unavailable
86 86 def skip_without(*names):
87 87 """skip a test if some names are not importable"""
88 88 @decorator
89 89 def skip_without_names(f, *args, **kwargs):
90 90 """decorator to skip tests in the absence of numpy."""
91 91 for name in names:
92 92 try:
93 93 __import__(name)
94 94 except ImportError:
95 95 raise SkipTest
96 96 return f(*args, **kwargs)
97 97 return skip_without_names
98 98
99 99 #-------------------------------------------------------------------------------
100 100 # Classes
101 101 #-------------------------------------------------------------------------------
102 102
103 class CapturedIO(object):
104 """Simple object for containing captured stdout/err StringIO objects"""
105
106 def __init__(self, stdout, stderr):
107 self.stdout_io = stdout
108 self.stderr_io = stderr
109
110 @property
111 def stdout(self):
112 return self.stdout_io.getvalue()
113
114 @property
115 def stderr(self):
116 return self.stderr_io.getvalue()
117
118
119 class capture_output(object):
120 """context manager for capturing stdout/err"""
121
122 def __enter__(self):
123 self.sys_stdout = sys.stdout
124 self.sys_stderr = sys.stderr
125 stdout = sys.stdout = StringIO()
126 stderr = sys.stderr = StringIO()
127 return CapturedIO(stdout, stderr)
128
129 def __exit__(self, exc_type, exc_value, traceback):
130 sys.stdout = self.sys_stdout
131 sys.stderr = self.sys_stderr
132
133 103
134 104 class ClusterTestCase(BaseZMQTestCase):
135 105
136 106 def add_engines(self, n=1, block=True):
137 107 """add multiple engines to our cluster"""
138 108 self.engines.extend(add_engines(n))
139 109 if block:
140 110 self.wait_on_engines()
141 111
142 112 def minimum_engines(self, n=1, block=True):
143 113 """add engines until there are at least n connected"""
144 114 self.engines.extend(add_engines(n, total=True))
145 115 if block:
146 116 self.wait_on_engines()
147 117
148 118
149 119 def wait_on_engines(self, timeout=5):
150 120 """wait for our engines to connect."""
151 121 n = len(self.engines)+self.base_engine_count
152 122 tic = time.time()
153 123 while time.time()-tic < timeout and len(self.client.ids) < n:
154 124 time.sleep(0.1)
155 125
156 126 assert not len(self.client.ids) < n, "waiting for engines timed out"
157 127
158 128 def connect_client(self):
159 129 """connect a client with my Context, and track its sockets for cleanup"""
160 130 c = Client(profile='iptest', context=self.context)
161 131 for name in filter(lambda n:n.endswith('socket'), dir(c)):
162 132 s = getattr(c, name)
163 133 s.setsockopt(zmq.LINGER, 0)
164 134 self.sockets.append(s)
165 135 return c
166 136
167 137 def assertRaisesRemote(self, etype, f, *args, **kwargs):
168 138 try:
169 139 try:
170 140 f(*args, **kwargs)
171 141 except error.CompositeError as e:
172 142 e.raise_exception()
173 143 except error.RemoteError as e:
174 144 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
175 145 else:
176 146 self.fail("should have raised a RemoteError")
177 147
178 148 def _wait_for(self, f, timeout=10):
179 149 """wait for a condition"""
180 150 tic = time.time()
181 151 while time.time() <= tic + timeout:
182 152 if f():
183 153 return
184 154 time.sleep(0.1)
185 155 self.client.spin()
186 156 if not f():
187 157 print "Warning: Awaited condition never arrived"
188 158
189 159 def setUp(self):
190 160 BaseZMQTestCase.setUp(self)
191 161 self.client = self.connect_client()
192 162 # start every test with clean engine namespaces:
193 163 self.client.clear(block=True)
194 164 self.base_engine_count=len(self.client.ids)
195 165 self.engines=[]
196 166
197 167 def tearDown(self):
198 168 # self.client.clear(block=True)
199 169 # close fds:
200 170 for e in filter(lambda e: e.poll() is not None, launchers):
201 171 launchers.remove(e)
202 172
203 173 # allow flushing of incoming messages to prevent crash on socket close
204 174 self.client.wait(timeout=2)
205 175 # time.sleep(2)
206 176 self.client.spin()
207 177 self.client.close()
208 178 BaseZMQTestCase.tearDown(self)
209 179 # this will be redundant when pyzmq merges PR #88
210 180 # self.context.term()
211 181 # print tempfile.TemporaryFile().fileno(),
212 182 # sys.stdout.flush()
213 183 No newline at end of file
@@ -1,266 +1,267 b''
1 1 """Tests for asyncresult.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import time
20 20
21 from IPython.parallel.error import TimeoutError
21 from IPython.utils.io import capture_output
22 22
23 from IPython.parallel.error import TimeoutError
23 24 from IPython.parallel import error, Client
24 25 from IPython.parallel.tests import add_engines
25 from .clienttest import ClusterTestCase, capture_output
26 from .clienttest import ClusterTestCase
26 27
27 28 def setup():
28 29 add_engines(2, total=True)
29 30
30 31 def wait(n):
31 32 import time
32 33 time.sleep(n)
33 34 return n
34 35
35 36 class AsyncResultTest(ClusterTestCase):
36 37
37 38 def test_single_result_view(self):
38 39 """various one-target views get the right value for single_result"""
39 40 eid = self.client.ids[-1]
40 41 ar = self.client[eid].apply_async(lambda : 42)
41 42 self.assertEquals(ar.get(), 42)
42 43 ar = self.client[[eid]].apply_async(lambda : 42)
43 44 self.assertEquals(ar.get(), [42])
44 45 ar = self.client[-1:].apply_async(lambda : 42)
45 46 self.assertEquals(ar.get(), [42])
46 47
47 48 def test_get_after_done(self):
48 49 ar = self.client[-1].apply_async(lambda : 42)
49 50 ar.wait()
50 51 self.assertTrue(ar.ready())
51 52 self.assertEquals(ar.get(), 42)
52 53 self.assertEquals(ar.get(), 42)
53 54
54 55 def test_get_before_done(self):
55 56 ar = self.client[-1].apply_async(wait, 0.1)
56 57 self.assertRaises(TimeoutError, ar.get, 0)
57 58 ar.wait(0)
58 59 self.assertFalse(ar.ready())
59 60 self.assertEquals(ar.get(), 0.1)
60 61
61 62 def test_get_after_error(self):
62 63 ar = self.client[-1].apply_async(lambda : 1/0)
63 64 ar.wait(10)
64 65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
65 66 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 67 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
67 68
68 69 def test_get_dict(self):
69 70 n = len(self.client)
70 71 ar = self.client[:].apply_async(lambda : 5)
71 72 self.assertEquals(ar.get(), [5]*n)
72 73 d = ar.get_dict()
73 74 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
74 75 for eid,r in d.iteritems():
75 76 self.assertEquals(r, 5)
76 77
77 78 def test_list_amr(self):
78 79 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
79 80 rlist = list(ar)
80 81
81 82 def test_getattr(self):
82 83 ar = self.client[:].apply_async(wait, 0.5)
83 84 self.assertRaises(AttributeError, lambda : ar._foo)
84 85 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
85 86 self.assertRaises(AttributeError, lambda : ar.foo)
86 87 self.assertRaises(AttributeError, lambda : ar.engine_id)
87 88 self.assertFalse(hasattr(ar, '__length_hint__'))
88 89 self.assertFalse(hasattr(ar, 'foo'))
89 90 self.assertFalse(hasattr(ar, 'engine_id'))
90 91 ar.get(5)
91 92 self.assertRaises(AttributeError, lambda : ar._foo)
92 93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 94 self.assertRaises(AttributeError, lambda : ar.foo)
94 95 self.assertTrue(isinstance(ar.engine_id, list))
95 96 self.assertEquals(ar.engine_id, ar['engine_id'])
96 97 self.assertFalse(hasattr(ar, '__length_hint__'))
97 98 self.assertFalse(hasattr(ar, 'foo'))
98 99 self.assertTrue(hasattr(ar, 'engine_id'))
99 100
100 101 def test_getitem(self):
101 102 ar = self.client[:].apply_async(wait, 0.5)
102 103 self.assertRaises(TimeoutError, lambda : ar['foo'])
103 104 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
104 105 ar.get(5)
105 106 self.assertRaises(KeyError, lambda : ar['foo'])
106 107 self.assertTrue(isinstance(ar['engine_id'], list))
107 108 self.assertEquals(ar.engine_id, ar['engine_id'])
108 109
109 110 def test_single_result(self):
110 111 ar = self.client[-1].apply_async(wait, 0.5)
111 112 self.assertRaises(TimeoutError, lambda : ar['foo'])
112 113 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
113 114 self.assertTrue(ar.get(5) == 0.5)
114 115 self.assertTrue(isinstance(ar['engine_id'], int))
115 116 self.assertTrue(isinstance(ar.engine_id, int))
116 117 self.assertEquals(ar.engine_id, ar['engine_id'])
117 118
118 119 def test_abort(self):
119 120 e = self.client[-1]
120 121 ar = e.execute('import time; time.sleep(1)', block=False)
121 122 ar2 = e.apply_async(lambda : 2)
122 123 ar2.abort()
123 124 self.assertRaises(error.TaskAborted, ar2.get)
124 125 ar.get()
125 126
126 127 def test_len(self):
127 128 v = self.client.load_balanced_view()
128 129 ar = v.map_async(lambda x: x, range(10))
129 130 self.assertEquals(len(ar), 10)
130 131 ar = v.apply_async(lambda x: x, range(10))
131 132 self.assertEquals(len(ar), 1)
132 133 ar = self.client[:].apply_async(lambda x: x, range(10))
133 134 self.assertEquals(len(ar), len(self.client.ids))
134 135
135 136 def test_wall_time_single(self):
136 137 v = self.client.load_balanced_view()
137 138 ar = v.apply_async(time.sleep, 0.25)
138 139 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
139 140 ar.get(2)
140 141 self.assertTrue(ar.wall_time < 1.)
141 142 self.assertTrue(ar.wall_time > 0.2)
142 143
143 144 def test_wall_time_multi(self):
144 145 self.minimum_engines(4)
145 146 v = self.client[:]
146 147 ar = v.apply_async(time.sleep, 0.25)
147 148 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
148 149 ar.get(2)
149 150 self.assertTrue(ar.wall_time < 1.)
150 151 self.assertTrue(ar.wall_time > 0.2)
151 152
152 153 def test_serial_time_single(self):
153 154 v = self.client.load_balanced_view()
154 155 ar = v.apply_async(time.sleep, 0.25)
155 156 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
156 157 ar.get(2)
157 158 self.assertTrue(ar.serial_time < 1.)
158 159 self.assertTrue(ar.serial_time > 0.2)
159 160
160 161 def test_serial_time_multi(self):
161 162 self.minimum_engines(4)
162 163 v = self.client[:]
163 164 ar = v.apply_async(time.sleep, 0.25)
164 165 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
165 166 ar.get(2)
166 167 self.assertTrue(ar.serial_time < 2.)
167 168 self.assertTrue(ar.serial_time > 0.8)
168 169
169 170 def test_elapsed_single(self):
170 171 v = self.client.load_balanced_view()
171 172 ar = v.apply_async(time.sleep, 0.25)
172 173 while not ar.ready():
173 174 time.sleep(0.01)
174 175 self.assertTrue(ar.elapsed < 1)
175 176 self.assertTrue(ar.elapsed < 1)
176 177 ar.get(2)
177 178
178 179 def test_elapsed_multi(self):
179 180 v = self.client[:]
180 181 ar = v.apply_async(time.sleep, 0.25)
181 182 while not ar.ready():
182 183 time.sleep(0.01)
183 184 self.assertTrue(ar.elapsed < 1)
184 185 self.assertTrue(ar.elapsed < 1)
185 186 ar.get(2)
186 187
187 188 def test_hubresult_timestamps(self):
188 189 self.minimum_engines(4)
189 190 v = self.client[:]
190 191 ar = v.apply_async(time.sleep, 0.25)
191 192 ar.get(2)
192 193 rc2 = Client(profile='iptest')
193 194 # must have try/finally to close second Client, otherwise
194 195 # will have dangling sockets causing problems
195 196 try:
196 197 time.sleep(0.25)
197 198 hr = rc2.get_result(ar.msg_ids)
198 199 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
199 200 hr.get(1)
200 201 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
201 202 self.assertEquals(hr.serial_time, ar.serial_time)
202 203 finally:
203 204 rc2.close()
204 205
205 206 def test_display_empty_streams_single(self):
206 207 """empty stdout/err are not displayed (single result)"""
207 208 self.minimum_engines(1)
208 209
209 210 v = self.client[-1]
210 211 ar = v.execute("print (5555)")
211 212 ar.get(5)
212 213 with capture_output() as io:
213 214 ar.display_outputs()
214 215 self.assertEquals(io.stderr, '')
215 216 self.assertEquals('5555\n', io.stdout)
216 217
217 218 ar = v.execute("a=5")
218 219 ar.get(5)
219 220 with capture_output() as io:
220 221 ar.display_outputs()
221 222 self.assertEquals(io.stderr, '')
222 223 self.assertEquals(io.stdout, '')
223 224
224 225 def test_display_empty_streams_type(self):
225 226 """empty stdout/err are not displayed (groupby type)"""
226 227 self.minimum_engines(1)
227 228
228 229 v = self.client[:]
229 230 ar = v.execute("print (5555)")
230 231 ar.get(5)
231 232 with capture_output() as io:
232 233 ar.display_outputs()
233 234 self.assertEquals(io.stderr, '')
234 235 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
235 236 self.assertFalse('\n\n' in io.stdout, io.stdout)
236 237 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
237 238
238 239 ar = v.execute("a=5")
239 240 ar.get(5)
240 241 with capture_output() as io:
241 242 ar.display_outputs()
242 243 self.assertEquals(io.stderr, '')
243 244 self.assertEquals(io.stdout, '')
244 245
245 246 def test_display_empty_streams_engine(self):
246 247 """empty stdout/err are not displayed (groupby engine)"""
247 248 self.minimum_engines(1)
248 249
249 250 v = self.client[:]
250 251 ar = v.execute("print (5555)")
251 252 ar.get(5)
252 253 with capture_output() as io:
253 254 ar.display_outputs('engine')
254 255 self.assertEquals(io.stderr, '')
255 256 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
256 257 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 258 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
258 259
259 260 ar = v.execute("a=5")
260 261 ar.get(5)
261 262 with capture_output() as io:
262 263 ar.display_outputs('engine')
263 264 self.assertEquals(io.stderr, '')
264 265 self.assertEquals(io.stdout, '')
265 266
266 267
@@ -1,339 +1,340 b''
1 1 # -*- coding: utf-8 -*-
2 2 """Test Parallel magics
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import re
20 20 import sys
21 21 import time
22 22
23 23 import zmq
24 24 from nose import SkipTest
25 25
26 26 from IPython.testing import decorators as dec
27 27 from IPython.testing.ipunittest import ParametricTestCase
28 from IPython.utils.io import capture_output
28 29
29 30 from IPython import parallel as pmod
30 31 from IPython.parallel import error
31 32 from IPython.parallel import AsyncResult
32 33 from IPython.parallel.util import interactive
33 34
34 35 from IPython.parallel.tests import add_engines
35 36
36 from .clienttest import ClusterTestCase, capture_output, generate_output
37 from .clienttest import ClusterTestCase, generate_output
37 38
38 39 def setup():
39 40 add_engines(3, total=True)
40 41
41 42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
42 43
43 44 def test_px_blocking(self):
44 45 ip = get_ipython()
45 46 v = self.client[-1:]
46 47 v.activate()
47 48 v.block=True
48 49
49 50 ip.magic('px a=5')
50 51 self.assertEquals(v['a'], [5])
51 52 ip.magic('px a=10')
52 53 self.assertEquals(v['a'], [10])
53 54 # just 'print a' works ~99% of the time, but this ensures that
54 55 # the stdout message has arrived when the result is finished:
55 56 with capture_output() as io:
56 57 ip.magic(
57 58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
58 59 )
59 60 out = io.stdout
60 61 self.assertTrue('[stdout:' in out, out)
61 62 self.assertFalse('\n\n' in out)
62 63 self.assertTrue(out.rstrip().endswith('10'))
63 64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
64 65
65 66 def _check_generated_stderr(self, stderr, n):
66 67 expected = [
67 68 r'\[stderr:\d+\]',
68 69 '^stderr$',
69 70 '^stderr2$',
70 71 ] * n
71 72
72 73 self.assertFalse('\n\n' in stderr, stderr)
73 74 lines = stderr.splitlines()
74 75 self.assertEquals(len(lines), len(expected), stderr)
75 76 for line,expect in zip(lines, expected):
76 77 if isinstance(expect, str):
77 78 expect = [expect]
78 79 for ex in expect:
79 80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
80 81
81 82 def test_cellpx_block_args(self):
82 83 """%%px --[no]block flags work"""
83 84 ip = get_ipython()
84 85 v = self.client[-1:]
85 86 v.activate()
86 87 v.block=False
87 88
88 89 for block in (True, False):
89 90 v.block = block
90 91
91 92 with capture_output() as io:
92 93 ip.run_cell_magic("px", "", "1")
93 94 if block:
94 95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
95 96 else:
96 97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
97 98
98 99 with capture_output() as io:
99 100 ip.run_cell_magic("px", "--block", "1")
100 101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
101 102
102 103 with capture_output() as io:
103 104 ip.run_cell_magic("px", "--noblock", "1")
104 105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
105 106
106 107 def test_cellpx_groupby_engine(self):
107 108 """%%px --group-outputs=engine"""
108 109 ip = get_ipython()
109 110 v = self.client[:]
110 111 v.block = True
111 112 v.activate()
112 113
113 114 v['generate_output'] = generate_output
114 115
115 116 with capture_output() as io:
116 117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
117 118
118 119 self.assertFalse('\n\n' in io.stdout)
119 120 lines = io.stdout.splitlines()[1:]
120 121 expected = [
121 122 r'\[stdout:\d+\]',
122 123 'stdout',
123 124 'stdout2',
124 125 r'\[output:\d+\]',
125 126 r'IPython\.core\.display\.HTML',
126 127 r'IPython\.core\.display\.Math',
127 128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
128 129 ] * len(v)
129 130
130 131 self.assertEquals(len(lines), len(expected), io.stdout)
131 132 for line,expect in zip(lines, expected):
132 133 if isinstance(expect, str):
133 134 expect = [expect]
134 135 for ex in expect:
135 136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
136 137
137 138 self._check_generated_stderr(io.stderr, len(v))
138 139
139 140
140 141 def test_cellpx_groupby_order(self):
141 142 """%%px --group-outputs=order"""
142 143 ip = get_ipython()
143 144 v = self.client[:]
144 145 v.block = True
145 146 v.activate()
146 147
147 148 v['generate_output'] = generate_output
148 149
149 150 with capture_output() as io:
150 151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
151 152
152 153 self.assertFalse('\n\n' in io.stdout)
153 154 lines = io.stdout.splitlines()[1:]
154 155 expected = []
155 156 expected.extend([
156 157 r'\[stdout:\d+\]',
157 158 'stdout',
158 159 'stdout2',
159 160 ] * len(v))
160 161 expected.extend([
161 162 r'\[output:\d+\]',
162 163 'IPython.core.display.HTML',
163 164 ] * len(v))
164 165 expected.extend([
165 166 r'\[output:\d+\]',
166 167 'IPython.core.display.Math',
167 168 ] * len(v))
168 169 expected.extend([
169 170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
170 171 ] * len(v))
171 172
172 173 self.assertEquals(len(lines), len(expected), io.stdout)
173 174 for line,expect in zip(lines, expected):
174 175 if isinstance(expect, str):
175 176 expect = [expect]
176 177 for ex in expect:
177 178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
178 179
179 180 self._check_generated_stderr(io.stderr, len(v))
180 181
181 182 def test_cellpx_groupby_type(self):
182 183 """%%px --group-outputs=type"""
183 184 ip = get_ipython()
184 185 v = self.client[:]
185 186 v.block = True
186 187 v.activate()
187 188
188 189 v['generate_output'] = generate_output
189 190
190 191 with capture_output() as io:
191 192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
192 193
193 194 self.assertFalse('\n\n' in io.stdout)
194 195 lines = io.stdout.splitlines()[1:]
195 196
196 197 expected = []
197 198 expected.extend([
198 199 r'\[stdout:\d+\]',
199 200 'stdout',
200 201 'stdout2',
201 202 ] * len(v))
202 203 expected.extend([
203 204 r'\[output:\d+\]',
204 205 r'IPython\.core\.display\.HTML',
205 206 r'IPython\.core\.display\.Math',
206 207 ] * len(v))
207 208 expected.extend([
208 209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
209 210 ] * len(v))
210 211
211 212 self.assertEquals(len(lines), len(expected), io.stdout)
212 213 for line,expect in zip(lines, expected):
213 214 if isinstance(expect, str):
214 215 expect = [expect]
215 216 for ex in expect:
216 217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
217 218
218 219 self._check_generated_stderr(io.stderr, len(v))
219 220
220 221
221 222 def test_px_nonblocking(self):
222 223 ip = get_ipython()
223 224 v = self.client[-1:]
224 225 v.activate()
225 226 v.block=False
226 227
227 228 ip.magic('px a=5')
228 229 self.assertEquals(v['a'], [5])
229 230 ip.magic('px a=10')
230 231 self.assertEquals(v['a'], [10])
231 232 with capture_output() as io:
232 233 ar = ip.magic('px print (a)')
233 234 self.assertTrue(isinstance(ar, AsyncResult))
234 235 self.assertTrue('Async' in io.stdout)
235 236 self.assertFalse('[stdout:' in io.stdout)
236 237 self.assertFalse('\n\n' in io.stdout)
237 238
238 239 ar = ip.magic('px 1/0')
239 240 self.assertRaisesRemote(ZeroDivisionError, ar.get)
240 241
241 242 def test_autopx_blocking(self):
242 243 ip = get_ipython()
243 244 v = self.client[-1]
244 245 v.activate()
245 246 v.block=True
246 247
247 248 with capture_output() as io:
248 249 ip.magic('autopx')
249 250 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
250 251 ip.run_cell('b*=2')
251 252 ip.run_cell('print (b)')
252 253 ip.run_cell('b')
253 254 ip.run_cell("b/c")
254 255 ip.magic('autopx')
255 256
256 257 output = io.stdout
257 258
258 259 self.assertTrue(output.startswith('%autopx enabled'), output)
259 260 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
260 261 self.assertTrue('RemoteError: ZeroDivisionError' in output, output)
261 262 self.assertTrue('\nOut[' in output, output)
262 263 self.assertTrue(': 24690' in output, output)
263 264 ar = v.get_result(-1)
264 265 self.assertEquals(v['a'], 5)
265 266 self.assertEquals(v['b'], 24690)
266 267 self.assertRaisesRemote(ZeroDivisionError, ar.get)
267 268
268 269 def test_autopx_nonblocking(self):
269 270 ip = get_ipython()
270 271 v = self.client[-1]
271 272 v.activate()
272 273 v.block=False
273 274
274 275 with capture_output() as io:
275 276 ip.magic('autopx')
276 277 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
277 278 ip.run_cell('print (b)')
278 279 ip.run_cell('import time; time.sleep(0.1)')
279 280 ip.run_cell("b/c")
280 281 ip.run_cell('b*=2')
281 282 ip.magic('autopx')
282 283
283 284 output = io.stdout.rstrip()
284 285
285 286 self.assertTrue(output.startswith('%autopx enabled'))
286 287 self.assertTrue(output.endswith('%autopx disabled'))
287 288 self.assertFalse('ZeroDivisionError' in output)
288 289 ar = v.get_result(-2)
289 290 self.assertRaisesRemote(ZeroDivisionError, ar.get)
290 291 # prevent TaskAborted on pulls, due to ZeroDivisionError
291 292 time.sleep(0.5)
292 293 self.assertEquals(v['a'], 5)
293 294 # b*=2 will not fire, due to abort
294 295 self.assertEquals(v['b'], 10)
295 296
296 297 def test_result(self):
297 298 ip = get_ipython()
298 299 v = self.client[-1]
299 300 v.activate()
300 301 data = dict(a=111,b=222)
301 302 v.push(data, block=True)
302 303
303 304 ip.magic('px a')
304 305 ip.magic('px b')
305 306 for idx, name in [
306 307 ('', 'b'),
307 308 ('-1', 'b'),
308 309 ('2', 'b'),
309 310 ('1', 'a'),
310 311 ('-2', 'a'),
311 312 ]:
312 313 with capture_output() as io:
313 314 ip.magic('result ' + idx)
314 315 output = io.stdout
315 316 msg = "expected %s output to include %s, but got: %s" % \
316 317 ('%result '+idx, str(data[name]), output)
317 318 self.assertTrue(str(data[name]) in output, msg)
318 319
319 320 @dec.skipif_not_matplotlib
320 321 def test_px_pylab(self):
321 322 """%pylab works on engines"""
322 323 ip = get_ipython()
323 324 v = self.client[-1]
324 325 v.block = True
325 326 v.activate()
326 327
327 328 with capture_output() as io:
328 329 ip.magic("px %pylab inline")
329 330
330 331 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
331 332 self.assertTrue("backend_inline" in io.stdout, io.stdout)
332 333
333 334 with capture_output() as io:
334 335 ip.magic("px plot(rand(100))")
335 336
336 337 self.assertTrue('Out[' in io.stdout, io.stdout)
337 338 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
338 339
339 340
@@ -1,323 +1,357 b''
1 1 # encoding: utf-8
2 2 """
3 3 IO related utilities.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12 from __future__ import print_function
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17 import os
18 18 import sys
19 19 import tempfile
20 from StringIO import StringIO
20 21
21 22 #-----------------------------------------------------------------------------
22 23 # Code
23 24 #-----------------------------------------------------------------------------
24 25
25 26
26 27 class IOStream:
27 28
28 29 def __init__(self,stream, fallback=None):
29 30 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
30 31 if fallback is not None:
31 32 stream = fallback
32 33 else:
33 34 raise ValueError("fallback required, but not specified")
34 35 self.stream = stream
35 36 self._swrite = stream.write
36 37
37 38 # clone all methods not overridden:
38 39 def clone(meth):
39 40 return not hasattr(self, meth) and not meth.startswith('_')
40 41 for meth in filter(clone, dir(stream)):
41 42 setattr(self, meth, getattr(stream, meth))
42 43
43 44 def write(self,data):
44 45 try:
45 46 self._swrite(data)
46 47 except:
47 48 try:
48 49 # print handles some unicode issues which may trip a plain
49 50 # write() call. Emulate write() by using an empty end
50 51 # argument.
51 52 print(data, end='', file=self.stream)
52 53 except:
53 54 # if we get here, something is seriously broken.
54 55 print('ERROR - failed to write data to stream:', self.stream,
55 56 file=sys.stderr)
56 57
57 58 def writelines(self, lines):
58 59 if isinstance(lines, basestring):
59 60 lines = [lines]
60 61 for line in lines:
61 62 self.write(line)
62 63
63 64 # This class used to have a writeln method, but regular files and streams
64 65 # in Python don't have this method. We need to keep this completely
65 66 # compatible so we removed it.
66 67
67 68 @property
68 69 def closed(self):
69 70 return self.stream.closed
70 71
71 72 def close(self):
72 73 pass
73 74
74 75 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
75 76 devnull = open(os.devnull, 'a')
76 77 stdin = IOStream(sys.stdin, fallback=devnull)
77 78 stdout = IOStream(sys.stdout, fallback=devnull)
78 79 stderr = IOStream(sys.stderr, fallback=devnull)
79 80
80 81 class IOTerm:
81 82 """ Term holds the file or file-like objects for handling I/O operations.
82 83
83 84 These are normally just sys.stdin, sys.stdout and sys.stderr but for
84 85 Windows they can can replaced to allow editing the strings before they are
85 86 displayed."""
86 87
87 88 # In the future, having IPython channel all its I/O operations through
88 89 # this class will make it easier to embed it into other environments which
89 90 # are not a normal terminal (such as a GUI-based shell)
90 91 def __init__(self, stdin=None, stdout=None, stderr=None):
91 92 mymodule = sys.modules[__name__]
92 93 self.stdin = IOStream(stdin, mymodule.stdin)
93 94 self.stdout = IOStream(stdout, mymodule.stdout)
94 95 self.stderr = IOStream(stderr, mymodule.stderr)
95 96
96 97
97 98 class Tee(object):
98 99 """A class to duplicate an output stream to stdout/err.
99 100
100 101 This works in a manner very similar to the Unix 'tee' command.
101 102
102 103 When the object is closed or deleted, it closes the original file given to
103 104 it for duplication.
104 105 """
105 106 # Inspired by:
106 107 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
107 108
108 109 def __init__(self, file_or_name, mode="w", channel='stdout'):
109 110 """Construct a new Tee object.
110 111
111 112 Parameters
112 113 ----------
113 114 file_or_name : filename or open filehandle (writable)
114 115 File that will be duplicated
115 116
116 117 mode : optional, valid mode for open().
117 118 If a filename was give, open with this mode.
118 119
119 120 channel : str, one of ['stdout', 'stderr']
120 121 """
121 122 if channel not in ['stdout', 'stderr']:
122 123 raise ValueError('Invalid channel spec %s' % channel)
123 124
124 125 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
125 126 self.file = file_or_name
126 127 else:
127 128 self.file = open(file_or_name, mode)
128 129 self.channel = channel
129 130 self.ostream = getattr(sys, channel)
130 131 setattr(sys, channel, self)
131 132 self._closed = False
132 133
133 134 def close(self):
134 135 """Close the file and restore the channel."""
135 136 self.flush()
136 137 setattr(sys, self.channel, self.ostream)
137 138 self.file.close()
138 139 self._closed = True
139 140
140 141 def write(self, data):
141 142 """Write data to both channels."""
142 143 self.file.write(data)
143 144 self.ostream.write(data)
144 145 self.ostream.flush()
145 146
146 147 def flush(self):
147 148 """Flush both channels."""
148 149 self.file.flush()
149 150 self.ostream.flush()
150 151
151 152 def __del__(self):
152 153 if not self._closed:
153 154 self.close()
154 155
155 156
156 157 def file_read(filename):
157 158 """Read a file and close it. Returns the file source."""
158 159 fobj = open(filename,'r');
159 160 source = fobj.read();
160 161 fobj.close()
161 162 return source
162 163
163 164
164 165 def file_readlines(filename):
165 166 """Read a file and close it. Returns the file source using readlines()."""
166 167 fobj = open(filename,'r');
167 168 lines = fobj.readlines();
168 169 fobj.close()
169 170 return lines
170 171
171 172
172 173 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
173 174 """Take multiple lines of input.
174 175
175 176 A list with each line of input as a separate element is returned when a
176 177 termination string is entered (defaults to a single '.'). Input can also
177 178 terminate via EOF (^D in Unix, ^Z-RET in Windows).
178 179
179 180 Lines of input which end in \\ are joined into single entries (and a
180 181 secondary continuation prompt is issued as long as the user terminates
181 182 lines with \\). This allows entering very long strings which are still
182 183 meant to be treated as single entities.
183 184 """
184 185
185 186 try:
186 187 if header:
187 188 header += '\n'
188 189 lines = [raw_input(header + ps1)]
189 190 except EOFError:
190 191 return []
191 192 terminate = [terminate_str]
192 193 try:
193 194 while lines[-1:] != terminate:
194 195 new_line = raw_input(ps1)
195 196 while new_line.endswith('\\'):
196 197 new_line = new_line[:-1] + raw_input(ps2)
197 198 lines.append(new_line)
198 199
199 200 return lines[:-1] # don't return the termination command
200 201 except EOFError:
201 202 print()
202 203 return lines
203 204
204 205
205 206 def raw_input_ext(prompt='', ps2='... '):
206 207 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
207 208
208 209 line = raw_input(prompt)
209 210 while line.endswith('\\'):
210 211 line = line[:-1] + raw_input(ps2)
211 212 return line
212 213
213 214
214 215 def ask_yes_no(prompt,default=None):
215 216 """Asks a question and returns a boolean (y/n) answer.
216 217
217 218 If default is given (one of 'y','n'), it is used if the user input is
218 219 empty. Otherwise the question is repeated until an answer is given.
219 220
220 221 An EOF is treated as the default answer. If there is no default, an
221 222 exception is raised to prevent infinite loops.
222 223
223 224 Valid answers are: y/yes/n/no (match is not case sensitive)."""
224 225
225 226 answers = {'y':True,'n':False,'yes':True,'no':False}
226 227 ans = None
227 228 while ans not in answers.keys():
228 229 try:
229 230 ans = raw_input(prompt+' ').lower()
230 231 if not ans: # response was an empty string
231 232 ans = default
232 233 except KeyboardInterrupt:
233 234 pass
234 235 except EOFError:
235 236 if default in answers.keys():
236 237 ans = default
237 238 print()
238 239 else:
239 240 raise
240 241
241 242 return answers[ans]
242 243
243 244
244 245 class NLprinter:
245 246 """Print an arbitrarily nested list, indicating index numbers.
246 247
247 248 An instance of this class called nlprint is available and callable as a
248 249 function.
249 250
250 251 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
251 252 and using 'sep' to separate the index from the value. """
252 253
253 254 def __init__(self):
254 255 self.depth = 0
255 256
256 257 def __call__(self,lst,pos='',**kw):
257 258 """Prints the nested list numbering levels."""
258 259 kw.setdefault('indent',' ')
259 260 kw.setdefault('sep',': ')
260 261 kw.setdefault('start',0)
261 262 kw.setdefault('stop',len(lst))
262 263 # we need to remove start and stop from kw so they don't propagate
263 264 # into a recursive call for a nested list.
264 265 start = kw['start']; del kw['start']
265 266 stop = kw['stop']; del kw['stop']
266 267 if self.depth == 0 and 'header' in kw.keys():
267 268 print(kw['header'])
268 269
269 270 for idx in range(start,stop):
270 271 elem = lst[idx]
271 272 newpos = pos + str(idx)
272 273 if type(elem)==type([]):
273 274 self.depth += 1
274 275 self.__call__(elem, newpos+",", **kw)
275 276 self.depth -= 1
276 277 else:
277 278 print(kw['indent']*self.depth + newpos + kw["sep"] + repr(elem))
278 279
279 280 nlprint = NLprinter()
280 281
281 282
282 283 def temp_pyfile(src, ext='.py'):
283 284 """Make a temporary python file, return filename and filehandle.
284 285
285 286 Parameters
286 287 ----------
287 288 src : string or list of strings (no need for ending newlines if list)
288 289 Source code to be written to the file.
289 290
290 291 ext : optional, string
291 292 Extension for the generated file.
292 293
293 294 Returns
294 295 -------
295 296 (filename, open filehandle)
296 297 It is the caller's responsibility to close the open file and unlink it.
297 298 """
298 299 fname = tempfile.mkstemp(ext)[1]
299 300 f = open(fname,'w')
300 301 f.write(src)
301 302 f.flush()
302 303 return fname, f
303 304
304 305
305 306 def raw_print(*args, **kw):
306 307 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
307 308
308 309 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
309 310 file=sys.__stdout__)
310 311 sys.__stdout__.flush()
311 312
312 313
313 314 def raw_print_err(*args, **kw):
314 315 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
315 316
316 317 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
317 318 file=sys.__stderr__)
318 319 sys.__stderr__.flush()
319 320
320 321
321 322 # Short aliases for quick debugging, do NOT use these in production code.
322 323 rprint = raw_print
323 324 rprinte = raw_print_err
325
326
327 class CapturedIO(object):
328 """Simple object for containing captured stdout/err StringIO objects"""
329
330 def __init__(self, stdout, stderr):
331 self.stdout_io = stdout
332 self.stderr_io = stderr
333
334 @property
335 def stdout(self):
336 return self.stdout_io.getvalue()
337
338 @property
339 def stderr(self):
340 return self.stderr_io.getvalue()
341
342
343 class capture_output(object):
344 """context manager for capturing stdout/err"""
345
346 def __enter__(self):
347 self.sys_stdout = sys.stdout
348 self.sys_stderr = sys.stderr
349 stdout = sys.stdout = StringIO()
350 stderr = sys.stderr = StringIO()
351 return CapturedIO(stdout, stderr)
352
353 def __exit__(self, exc_type, exc_value, traceback):
354 sys.stdout = self.sys_stdout
355 sys.stderr = self.sys_stderr
356
357
@@ -1,75 +1,85 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14
15 15 import sys
16 16
17 17 from StringIO import StringIO
18 18 from subprocess import Popen, PIPE
19 19
20 20 import nose.tools as nt
21 21
22 22 from IPython.testing import decorators as dec
23 from IPython.utils.io import Tee
23 from IPython.utils.io import Tee, capture_output
24 24 from IPython.utils.py3compat import doctest_refactor_print
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # Tests
28 28 #-----------------------------------------------------------------------------
29 29
30 30
31 31 def test_tee_simple():
32 32 "Very simple check with stdout only"
33 33 chan = StringIO()
34 34 text = 'Hello'
35 35 tee = Tee(chan, channel='stdout')
36 36 print >> chan, text
37 37 nt.assert_equal(chan.getvalue(), text+"\n")
38 38
39 39
40 40 class TeeTestCase(dec.ParametricTestCase):
41 41
42 42 def tchan(self, channel, check='close'):
43 43 trap = StringIO()
44 44 chan = StringIO()
45 45 text = 'Hello'
46 46
47 47 std_ori = getattr(sys, channel)
48 48 setattr(sys, channel, trap)
49 49
50 50 tee = Tee(chan, channel=channel)
51 51 print >> chan, text,
52 52 setattr(sys, channel, std_ori)
53 53 trap_val = trap.getvalue()
54 54 nt.assert_equals(chan.getvalue(), text)
55 55 if check=='close':
56 56 tee.close()
57 57 else:
58 58 del tee
59 59
60 60 def test(self):
61 61 for chan in ['stdout', 'stderr']:
62 62 for check in ['close', 'del']:
63 63 yield self.tchan(chan, check)
64 64
65 65 def test_io_init():
66 66 """Test that io.stdin/out/err exist at startup"""
67 67 for name in ('stdin', 'stdout', 'stderr'):
68 68 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
69 69 p = Popen([sys.executable, '-c', cmd],
70 70 stdout=PIPE)
71 71 p.wait()
72 72 classname = p.stdout.read().strip().decode('ascii')
73 73 # __class__ is a reference to the class object in Python 3, so we can't
74 74 # just test for string equality.
75 75 assert 'IPython.utils.io.IOStream' in classname, classname
76
77 def test_capture_output():
78 """capture_output() context works"""
79
80 with capture_output() as io:
81 print 'hi, stdout'
82 print >> sys.stderr, 'hi, stderr'
83
84 nt.assert_equals(io.stdout, 'hi, stdout\n')
85 nt.assert_equals(io.stderr, 'hi, stderr\n')
General Comments 0
You need to be logged in to leave comments. Login now