##// END OF EJS Templates
move capture_output util from parallel tests to utils.io
MinRK -
Show More
@@ -1,213 +1,183 b''
1 """base class for parallel client tests
1 """base class for parallel client tests
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 import sys
15 import sys
16 import tempfile
16 import tempfile
17 import time
17 import time
18 from StringIO import StringIO
18 from StringIO import StringIO
19
19
20 from nose import SkipTest
20 from nose import SkipTest
21
21
22 import zmq
22 import zmq
23 from zmq.tests import BaseZMQTestCase
23 from zmq.tests import BaseZMQTestCase
24
24
25 from IPython.external.decorator import decorator
25 from IPython.external.decorator import decorator
26
26
27 from IPython.parallel import error
27 from IPython.parallel import error
28 from IPython.parallel import Client
28 from IPython.parallel import Client
29
29
30 from IPython.parallel.tests import launchers, add_engines
30 from IPython.parallel.tests import launchers, add_engines
31
31
32 # simple tasks for use in apply tests
32 # simple tasks for use in apply tests
33
33
34 def segfault():
34 def segfault():
35 """this will segfault"""
35 """this will segfault"""
36 import ctypes
36 import ctypes
37 ctypes.memset(-1,0,1)
37 ctypes.memset(-1,0,1)
38
38
39 def crash():
39 def crash():
40 """from stdlib crashers in the test suite"""
40 """from stdlib crashers in the test suite"""
41 import types
41 import types
42 if sys.platform.startswith('win'):
42 if sys.platform.startswith('win'):
43 import ctypes
43 import ctypes
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
44 ctypes.windll.kernel32.SetErrorMode(0x0002);
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
46 if sys.version_info[0] >= 3:
46 if sys.version_info[0] >= 3:
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
47 # Python3 adds 'kwonlyargcount' as the second argument to Code
48 args.insert(1, 0)
48 args.insert(1, 0)
49
49
50 co = types.CodeType(*args)
50 co = types.CodeType(*args)
51 exec(co)
51 exec(co)
52
52
53 def wait(n):
53 def wait(n):
54 """sleep for a time"""
54 """sleep for a time"""
55 import time
55 import time
56 time.sleep(n)
56 time.sleep(n)
57 return n
57 return n
58
58
59 def raiser(eclass):
59 def raiser(eclass):
60 """raise an exception"""
60 """raise an exception"""
61 raise eclass()
61 raise eclass()
62
62
63 def generate_output():
63 def generate_output():
64 """function for testing output
64 """function for testing output
65
65
66 publishes two outputs of each type, and returns
66 publishes two outputs of each type, and returns
67 a rich displayable object.
67 a rich displayable object.
68 """
68 """
69
69
70 import sys
70 import sys
71 from IPython.core.display import display, HTML, Math
71 from IPython.core.display import display, HTML, Math
72
72
73 print "stdout"
73 print "stdout"
74 print >> sys.stderr, "stderr"
74 print >> sys.stderr, "stderr"
75
75
76 display(HTML("<b>HTML</b>"))
76 display(HTML("<b>HTML</b>"))
77
77
78 print "stdout2"
78 print "stdout2"
79 print >> sys.stderr, "stderr2"
79 print >> sys.stderr, "stderr2"
80
80
81 display(Math(r"\alpha=\beta"))
81 display(Math(r"\alpha=\beta"))
82
82
83 return Math("42")
83 return Math("42")
84
84
85 # test decorator for skipping tests when libraries are unavailable
85 # test decorator for skipping tests when libraries are unavailable
86 def skip_without(*names):
86 def skip_without(*names):
87 """skip a test if some names are not importable"""
87 """skip a test if some names are not importable"""
88 @decorator
88 @decorator
89 def skip_without_names(f, *args, **kwargs):
89 def skip_without_names(f, *args, **kwargs):
90 """decorator to skip tests in the absence of numpy."""
90 """decorator to skip tests in the absence of numpy."""
91 for name in names:
91 for name in names:
92 try:
92 try:
93 __import__(name)
93 __import__(name)
94 except ImportError:
94 except ImportError:
95 raise SkipTest
95 raise SkipTest
96 return f(*args, **kwargs)
96 return f(*args, **kwargs)
97 return skip_without_names
97 return skip_without_names
98
98
99 #-------------------------------------------------------------------------------
99 #-------------------------------------------------------------------------------
100 # Classes
100 # Classes
101 #-------------------------------------------------------------------------------
101 #-------------------------------------------------------------------------------
102
102
103 class CapturedIO(object):
104 """Simple object for containing captured stdout/err StringIO objects"""
105
106 def __init__(self, stdout, stderr):
107 self.stdout_io = stdout
108 self.stderr_io = stderr
109
110 @property
111 def stdout(self):
112 return self.stdout_io.getvalue()
113
114 @property
115 def stderr(self):
116 return self.stderr_io.getvalue()
117
118
119 class capture_output(object):
120 """context manager for capturing stdout/err"""
121
122 def __enter__(self):
123 self.sys_stdout = sys.stdout
124 self.sys_stderr = sys.stderr
125 stdout = sys.stdout = StringIO()
126 stderr = sys.stderr = StringIO()
127 return CapturedIO(stdout, stderr)
128
129 def __exit__(self, exc_type, exc_value, traceback):
130 sys.stdout = self.sys_stdout
131 sys.stderr = self.sys_stderr
132
133
103
134 class ClusterTestCase(BaseZMQTestCase):
104 class ClusterTestCase(BaseZMQTestCase):
135
105
136 def add_engines(self, n=1, block=True):
106 def add_engines(self, n=1, block=True):
137 """add multiple engines to our cluster"""
107 """add multiple engines to our cluster"""
138 self.engines.extend(add_engines(n))
108 self.engines.extend(add_engines(n))
139 if block:
109 if block:
140 self.wait_on_engines()
110 self.wait_on_engines()
141
111
142 def minimum_engines(self, n=1, block=True):
112 def minimum_engines(self, n=1, block=True):
143 """add engines until there are at least n connected"""
113 """add engines until there are at least n connected"""
144 self.engines.extend(add_engines(n, total=True))
114 self.engines.extend(add_engines(n, total=True))
145 if block:
115 if block:
146 self.wait_on_engines()
116 self.wait_on_engines()
147
117
148
118
149 def wait_on_engines(self, timeout=5):
119 def wait_on_engines(self, timeout=5):
150 """wait for our engines to connect."""
120 """wait for our engines to connect."""
151 n = len(self.engines)+self.base_engine_count
121 n = len(self.engines)+self.base_engine_count
152 tic = time.time()
122 tic = time.time()
153 while time.time()-tic < timeout and len(self.client.ids) < n:
123 while time.time()-tic < timeout and len(self.client.ids) < n:
154 time.sleep(0.1)
124 time.sleep(0.1)
155
125
156 assert not len(self.client.ids) < n, "waiting for engines timed out"
126 assert not len(self.client.ids) < n, "waiting for engines timed out"
157
127
158 def connect_client(self):
128 def connect_client(self):
159 """connect a client with my Context, and track its sockets for cleanup"""
129 """connect a client with my Context, and track its sockets for cleanup"""
160 c = Client(profile='iptest', context=self.context)
130 c = Client(profile='iptest', context=self.context)
161 for name in filter(lambda n:n.endswith('socket'), dir(c)):
131 for name in filter(lambda n:n.endswith('socket'), dir(c)):
162 s = getattr(c, name)
132 s = getattr(c, name)
163 s.setsockopt(zmq.LINGER, 0)
133 s.setsockopt(zmq.LINGER, 0)
164 self.sockets.append(s)
134 self.sockets.append(s)
165 return c
135 return c
166
136
167 def assertRaisesRemote(self, etype, f, *args, **kwargs):
137 def assertRaisesRemote(self, etype, f, *args, **kwargs):
168 try:
138 try:
169 try:
139 try:
170 f(*args, **kwargs)
140 f(*args, **kwargs)
171 except error.CompositeError as e:
141 except error.CompositeError as e:
172 e.raise_exception()
142 e.raise_exception()
173 except error.RemoteError as e:
143 except error.RemoteError as e:
174 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
144 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
175 else:
145 else:
176 self.fail("should have raised a RemoteError")
146 self.fail("should have raised a RemoteError")
177
147
178 def _wait_for(self, f, timeout=10):
148 def _wait_for(self, f, timeout=10):
179 """wait for a condition"""
149 """wait for a condition"""
180 tic = time.time()
150 tic = time.time()
181 while time.time() <= tic + timeout:
151 while time.time() <= tic + timeout:
182 if f():
152 if f():
183 return
153 return
184 time.sleep(0.1)
154 time.sleep(0.1)
185 self.client.spin()
155 self.client.spin()
186 if not f():
156 if not f():
187 print "Warning: Awaited condition never arrived"
157 print "Warning: Awaited condition never arrived"
188
158
189 def setUp(self):
159 def setUp(self):
190 BaseZMQTestCase.setUp(self)
160 BaseZMQTestCase.setUp(self)
191 self.client = self.connect_client()
161 self.client = self.connect_client()
192 # start every test with clean engine namespaces:
162 # start every test with clean engine namespaces:
193 self.client.clear(block=True)
163 self.client.clear(block=True)
194 self.base_engine_count=len(self.client.ids)
164 self.base_engine_count=len(self.client.ids)
195 self.engines=[]
165 self.engines=[]
196
166
197 def tearDown(self):
167 def tearDown(self):
198 # self.client.clear(block=True)
168 # self.client.clear(block=True)
199 # close fds:
169 # close fds:
200 for e in filter(lambda e: e.poll() is not None, launchers):
170 for e in filter(lambda e: e.poll() is not None, launchers):
201 launchers.remove(e)
171 launchers.remove(e)
202
172
203 # allow flushing of incoming messages to prevent crash on socket close
173 # allow flushing of incoming messages to prevent crash on socket close
204 self.client.wait(timeout=2)
174 self.client.wait(timeout=2)
205 # time.sleep(2)
175 # time.sleep(2)
206 self.client.spin()
176 self.client.spin()
207 self.client.close()
177 self.client.close()
208 BaseZMQTestCase.tearDown(self)
178 BaseZMQTestCase.tearDown(self)
209 # this will be redundant when pyzmq merges PR #88
179 # this will be redundant when pyzmq merges PR #88
210 # self.context.term()
180 # self.context.term()
211 # print tempfile.TemporaryFile().fileno(),
181 # print tempfile.TemporaryFile().fileno(),
212 # sys.stdout.flush()
182 # sys.stdout.flush()
213 No newline at end of file
183
@@ -1,266 +1,267 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import time
19 import time
20
20
21 from IPython.parallel.error import TimeoutError
21 from IPython.utils.io import capture_output
22
22
23 from IPython.parallel.error import TimeoutError
23 from IPython.parallel import error, Client
24 from IPython.parallel import error, Client
24 from IPython.parallel.tests import add_engines
25 from IPython.parallel.tests import add_engines
25 from .clienttest import ClusterTestCase, capture_output
26 from .clienttest import ClusterTestCase
26
27
27 def setup():
28 def setup():
28 add_engines(2, total=True)
29 add_engines(2, total=True)
29
30
30 def wait(n):
31 def wait(n):
31 import time
32 import time
32 time.sleep(n)
33 time.sleep(n)
33 return n
34 return n
34
35
35 class AsyncResultTest(ClusterTestCase):
36 class AsyncResultTest(ClusterTestCase):
36
37
37 def test_single_result_view(self):
38 def test_single_result_view(self):
38 """various one-target views get the right value for single_result"""
39 """various one-target views get the right value for single_result"""
39 eid = self.client.ids[-1]
40 eid = self.client.ids[-1]
40 ar = self.client[eid].apply_async(lambda : 42)
41 ar = self.client[eid].apply_async(lambda : 42)
41 self.assertEquals(ar.get(), 42)
42 self.assertEquals(ar.get(), 42)
42 ar = self.client[[eid]].apply_async(lambda : 42)
43 ar = self.client[[eid]].apply_async(lambda : 42)
43 self.assertEquals(ar.get(), [42])
44 self.assertEquals(ar.get(), [42])
44 ar = self.client[-1:].apply_async(lambda : 42)
45 ar = self.client[-1:].apply_async(lambda : 42)
45 self.assertEquals(ar.get(), [42])
46 self.assertEquals(ar.get(), [42])
46
47
47 def test_get_after_done(self):
48 def test_get_after_done(self):
48 ar = self.client[-1].apply_async(lambda : 42)
49 ar = self.client[-1].apply_async(lambda : 42)
49 ar.wait()
50 ar.wait()
50 self.assertTrue(ar.ready())
51 self.assertTrue(ar.ready())
51 self.assertEquals(ar.get(), 42)
52 self.assertEquals(ar.get(), 42)
52 self.assertEquals(ar.get(), 42)
53 self.assertEquals(ar.get(), 42)
53
54
54 def test_get_before_done(self):
55 def test_get_before_done(self):
55 ar = self.client[-1].apply_async(wait, 0.1)
56 ar = self.client[-1].apply_async(wait, 0.1)
56 self.assertRaises(TimeoutError, ar.get, 0)
57 self.assertRaises(TimeoutError, ar.get, 0)
57 ar.wait(0)
58 ar.wait(0)
58 self.assertFalse(ar.ready())
59 self.assertFalse(ar.ready())
59 self.assertEquals(ar.get(), 0.1)
60 self.assertEquals(ar.get(), 0.1)
60
61
61 def test_get_after_error(self):
62 def test_get_after_error(self):
62 ar = self.client[-1].apply_async(lambda : 1/0)
63 ar = self.client[-1].apply_async(lambda : 1/0)
63 ar.wait(10)
64 ar.wait(10)
64 self.assertRaisesRemote(ZeroDivisionError, ar.get)
65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
65 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 self.assertRaisesRemote(ZeroDivisionError, ar.get)
66 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
67 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
67
68
68 def test_get_dict(self):
69 def test_get_dict(self):
69 n = len(self.client)
70 n = len(self.client)
70 ar = self.client[:].apply_async(lambda : 5)
71 ar = self.client[:].apply_async(lambda : 5)
71 self.assertEquals(ar.get(), [5]*n)
72 self.assertEquals(ar.get(), [5]*n)
72 d = ar.get_dict()
73 d = ar.get_dict()
73 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
74 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
74 for eid,r in d.iteritems():
75 for eid,r in d.iteritems():
75 self.assertEquals(r, 5)
76 self.assertEquals(r, 5)
76
77
77 def test_list_amr(self):
78 def test_list_amr(self):
78 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
79 ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
79 rlist = list(ar)
80 rlist = list(ar)
80
81
81 def test_getattr(self):
82 def test_getattr(self):
82 ar = self.client[:].apply_async(wait, 0.5)
83 ar = self.client[:].apply_async(wait, 0.5)
83 self.assertRaises(AttributeError, lambda : ar._foo)
84 self.assertRaises(AttributeError, lambda : ar._foo)
84 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
85 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
85 self.assertRaises(AttributeError, lambda : ar.foo)
86 self.assertRaises(AttributeError, lambda : ar.foo)
86 self.assertRaises(AttributeError, lambda : ar.engine_id)
87 self.assertRaises(AttributeError, lambda : ar.engine_id)
87 self.assertFalse(hasattr(ar, '__length_hint__'))
88 self.assertFalse(hasattr(ar, '__length_hint__'))
88 self.assertFalse(hasattr(ar, 'foo'))
89 self.assertFalse(hasattr(ar, 'foo'))
89 self.assertFalse(hasattr(ar, 'engine_id'))
90 self.assertFalse(hasattr(ar, 'engine_id'))
90 ar.get(5)
91 ar.get(5)
91 self.assertRaises(AttributeError, lambda : ar._foo)
92 self.assertRaises(AttributeError, lambda : ar._foo)
92 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 self.assertRaises(AttributeError, lambda : ar.__length_hint__())
93 self.assertRaises(AttributeError, lambda : ar.foo)
94 self.assertRaises(AttributeError, lambda : ar.foo)
94 self.assertTrue(isinstance(ar.engine_id, list))
95 self.assertTrue(isinstance(ar.engine_id, list))
95 self.assertEquals(ar.engine_id, ar['engine_id'])
96 self.assertEquals(ar.engine_id, ar['engine_id'])
96 self.assertFalse(hasattr(ar, '__length_hint__'))
97 self.assertFalse(hasattr(ar, '__length_hint__'))
97 self.assertFalse(hasattr(ar, 'foo'))
98 self.assertFalse(hasattr(ar, 'foo'))
98 self.assertTrue(hasattr(ar, 'engine_id'))
99 self.assertTrue(hasattr(ar, 'engine_id'))
99
100
100 def test_getitem(self):
101 def test_getitem(self):
101 ar = self.client[:].apply_async(wait, 0.5)
102 ar = self.client[:].apply_async(wait, 0.5)
102 self.assertRaises(TimeoutError, lambda : ar['foo'])
103 self.assertRaises(TimeoutError, lambda : ar['foo'])
103 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
104 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
104 ar.get(5)
105 ar.get(5)
105 self.assertRaises(KeyError, lambda : ar['foo'])
106 self.assertRaises(KeyError, lambda : ar['foo'])
106 self.assertTrue(isinstance(ar['engine_id'], list))
107 self.assertTrue(isinstance(ar['engine_id'], list))
107 self.assertEquals(ar.engine_id, ar['engine_id'])
108 self.assertEquals(ar.engine_id, ar['engine_id'])
108
109
109 def test_single_result(self):
110 def test_single_result(self):
110 ar = self.client[-1].apply_async(wait, 0.5)
111 ar = self.client[-1].apply_async(wait, 0.5)
111 self.assertRaises(TimeoutError, lambda : ar['foo'])
112 self.assertRaises(TimeoutError, lambda : ar['foo'])
112 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
113 self.assertRaises(TimeoutError, lambda : ar['engine_id'])
113 self.assertTrue(ar.get(5) == 0.5)
114 self.assertTrue(ar.get(5) == 0.5)
114 self.assertTrue(isinstance(ar['engine_id'], int))
115 self.assertTrue(isinstance(ar['engine_id'], int))
115 self.assertTrue(isinstance(ar.engine_id, int))
116 self.assertTrue(isinstance(ar.engine_id, int))
116 self.assertEquals(ar.engine_id, ar['engine_id'])
117 self.assertEquals(ar.engine_id, ar['engine_id'])
117
118
118 def test_abort(self):
119 def test_abort(self):
119 e = self.client[-1]
120 e = self.client[-1]
120 ar = e.execute('import time; time.sleep(1)', block=False)
121 ar = e.execute('import time; time.sleep(1)', block=False)
121 ar2 = e.apply_async(lambda : 2)
122 ar2 = e.apply_async(lambda : 2)
122 ar2.abort()
123 ar2.abort()
123 self.assertRaises(error.TaskAborted, ar2.get)
124 self.assertRaises(error.TaskAborted, ar2.get)
124 ar.get()
125 ar.get()
125
126
126 def test_len(self):
127 def test_len(self):
127 v = self.client.load_balanced_view()
128 v = self.client.load_balanced_view()
128 ar = v.map_async(lambda x: x, range(10))
129 ar = v.map_async(lambda x: x, range(10))
129 self.assertEquals(len(ar), 10)
130 self.assertEquals(len(ar), 10)
130 ar = v.apply_async(lambda x: x, range(10))
131 ar = v.apply_async(lambda x: x, range(10))
131 self.assertEquals(len(ar), 1)
132 self.assertEquals(len(ar), 1)
132 ar = self.client[:].apply_async(lambda x: x, range(10))
133 ar = self.client[:].apply_async(lambda x: x, range(10))
133 self.assertEquals(len(ar), len(self.client.ids))
134 self.assertEquals(len(ar), len(self.client.ids))
134
135
135 def test_wall_time_single(self):
136 def test_wall_time_single(self):
136 v = self.client.load_balanced_view()
137 v = self.client.load_balanced_view()
137 ar = v.apply_async(time.sleep, 0.25)
138 ar = v.apply_async(time.sleep, 0.25)
138 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
139 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
139 ar.get(2)
140 ar.get(2)
140 self.assertTrue(ar.wall_time < 1.)
141 self.assertTrue(ar.wall_time < 1.)
141 self.assertTrue(ar.wall_time > 0.2)
142 self.assertTrue(ar.wall_time > 0.2)
142
143
143 def test_wall_time_multi(self):
144 def test_wall_time_multi(self):
144 self.minimum_engines(4)
145 self.minimum_engines(4)
145 v = self.client[:]
146 v = self.client[:]
146 ar = v.apply_async(time.sleep, 0.25)
147 ar = v.apply_async(time.sleep, 0.25)
147 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
148 self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
148 ar.get(2)
149 ar.get(2)
149 self.assertTrue(ar.wall_time < 1.)
150 self.assertTrue(ar.wall_time < 1.)
150 self.assertTrue(ar.wall_time > 0.2)
151 self.assertTrue(ar.wall_time > 0.2)
151
152
152 def test_serial_time_single(self):
153 def test_serial_time_single(self):
153 v = self.client.load_balanced_view()
154 v = self.client.load_balanced_view()
154 ar = v.apply_async(time.sleep, 0.25)
155 ar = v.apply_async(time.sleep, 0.25)
155 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
156 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
156 ar.get(2)
157 ar.get(2)
157 self.assertTrue(ar.serial_time < 1.)
158 self.assertTrue(ar.serial_time < 1.)
158 self.assertTrue(ar.serial_time > 0.2)
159 self.assertTrue(ar.serial_time > 0.2)
159
160
160 def test_serial_time_multi(self):
161 def test_serial_time_multi(self):
161 self.minimum_engines(4)
162 self.minimum_engines(4)
162 v = self.client[:]
163 v = self.client[:]
163 ar = v.apply_async(time.sleep, 0.25)
164 ar = v.apply_async(time.sleep, 0.25)
164 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
165 self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
165 ar.get(2)
166 ar.get(2)
166 self.assertTrue(ar.serial_time < 2.)
167 self.assertTrue(ar.serial_time < 2.)
167 self.assertTrue(ar.serial_time > 0.8)
168 self.assertTrue(ar.serial_time > 0.8)
168
169
169 def test_elapsed_single(self):
170 def test_elapsed_single(self):
170 v = self.client.load_balanced_view()
171 v = self.client.load_balanced_view()
171 ar = v.apply_async(time.sleep, 0.25)
172 ar = v.apply_async(time.sleep, 0.25)
172 while not ar.ready():
173 while not ar.ready():
173 time.sleep(0.01)
174 time.sleep(0.01)
174 self.assertTrue(ar.elapsed < 1)
175 self.assertTrue(ar.elapsed < 1)
175 self.assertTrue(ar.elapsed < 1)
176 self.assertTrue(ar.elapsed < 1)
176 ar.get(2)
177 ar.get(2)
177
178
178 def test_elapsed_multi(self):
179 def test_elapsed_multi(self):
179 v = self.client[:]
180 v = self.client[:]
180 ar = v.apply_async(time.sleep, 0.25)
181 ar = v.apply_async(time.sleep, 0.25)
181 while not ar.ready():
182 while not ar.ready():
182 time.sleep(0.01)
183 time.sleep(0.01)
183 self.assertTrue(ar.elapsed < 1)
184 self.assertTrue(ar.elapsed < 1)
184 self.assertTrue(ar.elapsed < 1)
185 self.assertTrue(ar.elapsed < 1)
185 ar.get(2)
186 ar.get(2)
186
187
187 def test_hubresult_timestamps(self):
188 def test_hubresult_timestamps(self):
188 self.minimum_engines(4)
189 self.minimum_engines(4)
189 v = self.client[:]
190 v = self.client[:]
190 ar = v.apply_async(time.sleep, 0.25)
191 ar = v.apply_async(time.sleep, 0.25)
191 ar.get(2)
192 ar.get(2)
192 rc2 = Client(profile='iptest')
193 rc2 = Client(profile='iptest')
193 # must have try/finally to close second Client, otherwise
194 # must have try/finally to close second Client, otherwise
194 # will have dangling sockets causing problems
195 # will have dangling sockets causing problems
195 try:
196 try:
196 time.sleep(0.25)
197 time.sleep(0.25)
197 hr = rc2.get_result(ar.msg_ids)
198 hr = rc2.get_result(ar.msg_ids)
198 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
199 self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
199 hr.get(1)
200 hr.get(1)
200 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
201 self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
201 self.assertEquals(hr.serial_time, ar.serial_time)
202 self.assertEquals(hr.serial_time, ar.serial_time)
202 finally:
203 finally:
203 rc2.close()
204 rc2.close()
204
205
205 def test_display_empty_streams_single(self):
206 def test_display_empty_streams_single(self):
206 """empty stdout/err are not displayed (single result)"""
207 """empty stdout/err are not displayed (single result)"""
207 self.minimum_engines(1)
208 self.minimum_engines(1)
208
209
209 v = self.client[-1]
210 v = self.client[-1]
210 ar = v.execute("print (5555)")
211 ar = v.execute("print (5555)")
211 ar.get(5)
212 ar.get(5)
212 with capture_output() as io:
213 with capture_output() as io:
213 ar.display_outputs()
214 ar.display_outputs()
214 self.assertEquals(io.stderr, '')
215 self.assertEquals(io.stderr, '')
215 self.assertEquals('5555\n', io.stdout)
216 self.assertEquals('5555\n', io.stdout)
216
217
217 ar = v.execute("a=5")
218 ar = v.execute("a=5")
218 ar.get(5)
219 ar.get(5)
219 with capture_output() as io:
220 with capture_output() as io:
220 ar.display_outputs()
221 ar.display_outputs()
221 self.assertEquals(io.stderr, '')
222 self.assertEquals(io.stderr, '')
222 self.assertEquals(io.stdout, '')
223 self.assertEquals(io.stdout, '')
223
224
224 def test_display_empty_streams_type(self):
225 def test_display_empty_streams_type(self):
225 """empty stdout/err are not displayed (groupby type)"""
226 """empty stdout/err are not displayed (groupby type)"""
226 self.minimum_engines(1)
227 self.minimum_engines(1)
227
228
228 v = self.client[:]
229 v = self.client[:]
229 ar = v.execute("print (5555)")
230 ar = v.execute("print (5555)")
230 ar.get(5)
231 ar.get(5)
231 with capture_output() as io:
232 with capture_output() as io:
232 ar.display_outputs()
233 ar.display_outputs()
233 self.assertEquals(io.stderr, '')
234 self.assertEquals(io.stderr, '')
234 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
235 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
235 self.assertFalse('\n\n' in io.stdout, io.stdout)
236 self.assertFalse('\n\n' in io.stdout, io.stdout)
236 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
237 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
237
238
238 ar = v.execute("a=5")
239 ar = v.execute("a=5")
239 ar.get(5)
240 ar.get(5)
240 with capture_output() as io:
241 with capture_output() as io:
241 ar.display_outputs()
242 ar.display_outputs()
242 self.assertEquals(io.stderr, '')
243 self.assertEquals(io.stderr, '')
243 self.assertEquals(io.stdout, '')
244 self.assertEquals(io.stdout, '')
244
245
245 def test_display_empty_streams_engine(self):
246 def test_display_empty_streams_engine(self):
246 """empty stdout/err are not displayed (groupby engine)"""
247 """empty stdout/err are not displayed (groupby engine)"""
247 self.minimum_engines(1)
248 self.minimum_engines(1)
248
249
249 v = self.client[:]
250 v = self.client[:]
250 ar = v.execute("print (5555)")
251 ar = v.execute("print (5555)")
251 ar.get(5)
252 ar.get(5)
252 with capture_output() as io:
253 with capture_output() as io:
253 ar.display_outputs('engine')
254 ar.display_outputs('engine')
254 self.assertEquals(io.stderr, '')
255 self.assertEquals(io.stderr, '')
255 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
256 self.assertEquals(io.stdout.count('5555'), len(v), io.stdout)
256 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 self.assertFalse('\n\n' in io.stdout, io.stdout)
257 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
258 self.assertEquals(io.stdout.count('[stdout:'), len(v), io.stdout)
258
259
259 ar = v.execute("a=5")
260 ar = v.execute("a=5")
260 ar.get(5)
261 ar.get(5)
261 with capture_output() as io:
262 with capture_output() as io:
262 ar.display_outputs('engine')
263 ar.display_outputs('engine')
263 self.assertEquals(io.stderr, '')
264 self.assertEquals(io.stderr, '')
264 self.assertEquals(io.stdout, '')
265 self.assertEquals(io.stdout, '')
265
266
266
267
@@ -1,339 +1,340 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Test Parallel magics
2 """Test Parallel magics
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import re
19 import re
20 import sys
20 import sys
21 import time
21 import time
22
22
23 import zmq
23 import zmq
24 from nose import SkipTest
24 from nose import SkipTest
25
25
26 from IPython.testing import decorators as dec
26 from IPython.testing import decorators as dec
27 from IPython.testing.ipunittest import ParametricTestCase
27 from IPython.testing.ipunittest import ParametricTestCase
28 from IPython.utils.io import capture_output
28
29
29 from IPython import parallel as pmod
30 from IPython import parallel as pmod
30 from IPython.parallel import error
31 from IPython.parallel import error
31 from IPython.parallel import AsyncResult
32 from IPython.parallel import AsyncResult
32 from IPython.parallel.util import interactive
33 from IPython.parallel.util import interactive
33
34
34 from IPython.parallel.tests import add_engines
35 from IPython.parallel.tests import add_engines
35
36
36 from .clienttest import ClusterTestCase, capture_output, generate_output
37 from .clienttest import ClusterTestCase, generate_output
37
38
38 def setup():
39 def setup():
39 add_engines(3, total=True)
40 add_engines(3, total=True)
40
41
41 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
42 class TestParallelMagics(ClusterTestCase, ParametricTestCase):
42
43
43 def test_px_blocking(self):
44 def test_px_blocking(self):
44 ip = get_ipython()
45 ip = get_ipython()
45 v = self.client[-1:]
46 v = self.client[-1:]
46 v.activate()
47 v.activate()
47 v.block=True
48 v.block=True
48
49
49 ip.magic('px a=5')
50 ip.magic('px a=5')
50 self.assertEquals(v['a'], [5])
51 self.assertEquals(v['a'], [5])
51 ip.magic('px a=10')
52 ip.magic('px a=10')
52 self.assertEquals(v['a'], [10])
53 self.assertEquals(v['a'], [10])
53 # just 'print a' works ~99% of the time, but this ensures that
54 # just 'print a' works ~99% of the time, but this ensures that
54 # the stdout message has arrived when the result is finished:
55 # the stdout message has arrived when the result is finished:
55 with capture_output() as io:
56 with capture_output() as io:
56 ip.magic(
57 ip.magic(
57 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
58 'px import sys,time;print(a);sys.stdout.flush();time.sleep(0.2)'
58 )
59 )
59 out = io.stdout
60 out = io.stdout
60 self.assertTrue('[stdout:' in out, out)
61 self.assertTrue('[stdout:' in out, out)
61 self.assertFalse('\n\n' in out)
62 self.assertFalse('\n\n' in out)
62 self.assertTrue(out.rstrip().endswith('10'))
63 self.assertTrue(out.rstrip().endswith('10'))
63 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
64 self.assertRaisesRemote(ZeroDivisionError, ip.magic, 'px 1/0')
64
65
65 def _check_generated_stderr(self, stderr, n):
66 def _check_generated_stderr(self, stderr, n):
66 expected = [
67 expected = [
67 r'\[stderr:\d+\]',
68 r'\[stderr:\d+\]',
68 '^stderr$',
69 '^stderr$',
69 '^stderr2$',
70 '^stderr2$',
70 ] * n
71 ] * n
71
72
72 self.assertFalse('\n\n' in stderr, stderr)
73 self.assertFalse('\n\n' in stderr, stderr)
73 lines = stderr.splitlines()
74 lines = stderr.splitlines()
74 self.assertEquals(len(lines), len(expected), stderr)
75 self.assertEquals(len(lines), len(expected), stderr)
75 for line,expect in zip(lines, expected):
76 for line,expect in zip(lines, expected):
76 if isinstance(expect, str):
77 if isinstance(expect, str):
77 expect = [expect]
78 expect = [expect]
78 for ex in expect:
79 for ex in expect:
79 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
80 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
80
81
81 def test_cellpx_block_args(self):
82 def test_cellpx_block_args(self):
82 """%%px --[no]block flags work"""
83 """%%px --[no]block flags work"""
83 ip = get_ipython()
84 ip = get_ipython()
84 v = self.client[-1:]
85 v = self.client[-1:]
85 v.activate()
86 v.activate()
86 v.block=False
87 v.block=False
87
88
88 for block in (True, False):
89 for block in (True, False):
89 v.block = block
90 v.block = block
90
91
91 with capture_output() as io:
92 with capture_output() as io:
92 ip.run_cell_magic("px", "", "1")
93 ip.run_cell_magic("px", "", "1")
93 if block:
94 if block:
94 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
95 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
95 else:
96 else:
96 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
97 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
97
98
98 with capture_output() as io:
99 with capture_output() as io:
99 ip.run_cell_magic("px", "--block", "1")
100 ip.run_cell_magic("px", "--block", "1")
100 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
101 self.assertTrue(io.stdout.startswith("Parallel"), io.stdout)
101
102
102 with capture_output() as io:
103 with capture_output() as io:
103 ip.run_cell_magic("px", "--noblock", "1")
104 ip.run_cell_magic("px", "--noblock", "1")
104 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
105 self.assertTrue(io.stdout.startswith("Async"), io.stdout)
105
106
106 def test_cellpx_groupby_engine(self):
107 def test_cellpx_groupby_engine(self):
107 """%%px --group-outputs=engine"""
108 """%%px --group-outputs=engine"""
108 ip = get_ipython()
109 ip = get_ipython()
109 v = self.client[:]
110 v = self.client[:]
110 v.block = True
111 v.block = True
111 v.activate()
112 v.activate()
112
113
113 v['generate_output'] = generate_output
114 v['generate_output'] = generate_output
114
115
115 with capture_output() as io:
116 with capture_output() as io:
116 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
117 ip.run_cell_magic('px', '--group-outputs=engine', 'generate_output()')
117
118
118 self.assertFalse('\n\n' in io.stdout)
119 self.assertFalse('\n\n' in io.stdout)
119 lines = io.stdout.splitlines()[1:]
120 lines = io.stdout.splitlines()[1:]
120 expected = [
121 expected = [
121 r'\[stdout:\d+\]',
122 r'\[stdout:\d+\]',
122 'stdout',
123 'stdout',
123 'stdout2',
124 'stdout2',
124 r'\[output:\d+\]',
125 r'\[output:\d+\]',
125 r'IPython\.core\.display\.HTML',
126 r'IPython\.core\.display\.HTML',
126 r'IPython\.core\.display\.Math',
127 r'IPython\.core\.display\.Math',
127 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
128 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math',
128 ] * len(v)
129 ] * len(v)
129
130
130 self.assertEquals(len(lines), len(expected), io.stdout)
131 self.assertEquals(len(lines), len(expected), io.stdout)
131 for line,expect in zip(lines, expected):
132 for line,expect in zip(lines, expected):
132 if isinstance(expect, str):
133 if isinstance(expect, str):
133 expect = [expect]
134 expect = [expect]
134 for ex in expect:
135 for ex in expect:
135 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
136 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
136
137
137 self._check_generated_stderr(io.stderr, len(v))
138 self._check_generated_stderr(io.stderr, len(v))
138
139
139
140
140 def test_cellpx_groupby_order(self):
141 def test_cellpx_groupby_order(self):
141 """%%px --group-outputs=order"""
142 """%%px --group-outputs=order"""
142 ip = get_ipython()
143 ip = get_ipython()
143 v = self.client[:]
144 v = self.client[:]
144 v.block = True
145 v.block = True
145 v.activate()
146 v.activate()
146
147
147 v['generate_output'] = generate_output
148 v['generate_output'] = generate_output
148
149
149 with capture_output() as io:
150 with capture_output() as io:
150 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
151 ip.run_cell_magic('px', '--group-outputs=order', 'generate_output()')
151
152
152 self.assertFalse('\n\n' in io.stdout)
153 self.assertFalse('\n\n' in io.stdout)
153 lines = io.stdout.splitlines()[1:]
154 lines = io.stdout.splitlines()[1:]
154 expected = []
155 expected = []
155 expected.extend([
156 expected.extend([
156 r'\[stdout:\d+\]',
157 r'\[stdout:\d+\]',
157 'stdout',
158 'stdout',
158 'stdout2',
159 'stdout2',
159 ] * len(v))
160 ] * len(v))
160 expected.extend([
161 expected.extend([
161 r'\[output:\d+\]',
162 r'\[output:\d+\]',
162 'IPython.core.display.HTML',
163 'IPython.core.display.HTML',
163 ] * len(v))
164 ] * len(v))
164 expected.extend([
165 expected.extend([
165 r'\[output:\d+\]',
166 r'\[output:\d+\]',
166 'IPython.core.display.Math',
167 'IPython.core.display.Math',
167 ] * len(v))
168 ] * len(v))
168 expected.extend([
169 expected.extend([
169 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
170 r'Out\[\d+:\d+\]:.*IPython\.core\.display\.Math'
170 ] * len(v))
171 ] * len(v))
171
172
172 self.assertEquals(len(lines), len(expected), io.stdout)
173 self.assertEquals(len(lines), len(expected), io.stdout)
173 for line,expect in zip(lines, expected):
174 for line,expect in zip(lines, expected):
174 if isinstance(expect, str):
175 if isinstance(expect, str):
175 expect = [expect]
176 expect = [expect]
176 for ex in expect:
177 for ex in expect:
177 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
178 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
178
179
179 self._check_generated_stderr(io.stderr, len(v))
180 self._check_generated_stderr(io.stderr, len(v))
180
181
181 def test_cellpx_groupby_type(self):
182 def test_cellpx_groupby_type(self):
182 """%%px --group-outputs=type"""
183 """%%px --group-outputs=type"""
183 ip = get_ipython()
184 ip = get_ipython()
184 v = self.client[:]
185 v = self.client[:]
185 v.block = True
186 v.block = True
186 v.activate()
187 v.activate()
187
188
188 v['generate_output'] = generate_output
189 v['generate_output'] = generate_output
189
190
190 with capture_output() as io:
191 with capture_output() as io:
191 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
192 ip.run_cell_magic('px', '--group-outputs=type', 'generate_output()')
192
193
193 self.assertFalse('\n\n' in io.stdout)
194 self.assertFalse('\n\n' in io.stdout)
194 lines = io.stdout.splitlines()[1:]
195 lines = io.stdout.splitlines()[1:]
195
196
196 expected = []
197 expected = []
197 expected.extend([
198 expected.extend([
198 r'\[stdout:\d+\]',
199 r'\[stdout:\d+\]',
199 'stdout',
200 'stdout',
200 'stdout2',
201 'stdout2',
201 ] * len(v))
202 ] * len(v))
202 expected.extend([
203 expected.extend([
203 r'\[output:\d+\]',
204 r'\[output:\d+\]',
204 r'IPython\.core\.display\.HTML',
205 r'IPython\.core\.display\.HTML',
205 r'IPython\.core\.display\.Math',
206 r'IPython\.core\.display\.Math',
206 ] * len(v))
207 ] * len(v))
207 expected.extend([
208 expected.extend([
208 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
209 (r'Out\[\d+:\d+\]', r'IPython\.core\.display\.Math')
209 ] * len(v))
210 ] * len(v))
210
211
211 self.assertEquals(len(lines), len(expected), io.stdout)
212 self.assertEquals(len(lines), len(expected), io.stdout)
212 for line,expect in zip(lines, expected):
213 for line,expect in zip(lines, expected):
213 if isinstance(expect, str):
214 if isinstance(expect, str):
214 expect = [expect]
215 expect = [expect]
215 for ex in expect:
216 for ex in expect:
216 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
217 self.assertTrue(re.search(ex, line) is not None, "Expected %r in %r" % (ex, line))
217
218
218 self._check_generated_stderr(io.stderr, len(v))
219 self._check_generated_stderr(io.stderr, len(v))
219
220
220
221
221 def test_px_nonblocking(self):
222 def test_px_nonblocking(self):
222 ip = get_ipython()
223 ip = get_ipython()
223 v = self.client[-1:]
224 v = self.client[-1:]
224 v.activate()
225 v.activate()
225 v.block=False
226 v.block=False
226
227
227 ip.magic('px a=5')
228 ip.magic('px a=5')
228 self.assertEquals(v['a'], [5])
229 self.assertEquals(v['a'], [5])
229 ip.magic('px a=10')
230 ip.magic('px a=10')
230 self.assertEquals(v['a'], [10])
231 self.assertEquals(v['a'], [10])
231 with capture_output() as io:
232 with capture_output() as io:
232 ar = ip.magic('px print (a)')
233 ar = ip.magic('px print (a)')
233 self.assertTrue(isinstance(ar, AsyncResult))
234 self.assertTrue(isinstance(ar, AsyncResult))
234 self.assertTrue('Async' in io.stdout)
235 self.assertTrue('Async' in io.stdout)
235 self.assertFalse('[stdout:' in io.stdout)
236 self.assertFalse('[stdout:' in io.stdout)
236 self.assertFalse('\n\n' in io.stdout)
237 self.assertFalse('\n\n' in io.stdout)
237
238
238 ar = ip.magic('px 1/0')
239 ar = ip.magic('px 1/0')
239 self.assertRaisesRemote(ZeroDivisionError, ar.get)
240 self.assertRaisesRemote(ZeroDivisionError, ar.get)
240
241
241 def test_autopx_blocking(self):
242 def test_autopx_blocking(self):
242 ip = get_ipython()
243 ip = get_ipython()
243 v = self.client[-1]
244 v = self.client[-1]
244 v.activate()
245 v.activate()
245 v.block=True
246 v.block=True
246
247
247 with capture_output() as io:
248 with capture_output() as io:
248 ip.magic('autopx')
249 ip.magic('autopx')
249 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
250 ip.run_cell('\n'.join(('a=5','b=12345','c=0')))
250 ip.run_cell('b*=2')
251 ip.run_cell('b*=2')
251 ip.run_cell('print (b)')
252 ip.run_cell('print (b)')
252 ip.run_cell('b')
253 ip.run_cell('b')
253 ip.run_cell("b/c")
254 ip.run_cell("b/c")
254 ip.magic('autopx')
255 ip.magic('autopx')
255
256
256 output = io.stdout
257 output = io.stdout
257
258
258 self.assertTrue(output.startswith('%autopx enabled'), output)
259 self.assertTrue(output.startswith('%autopx enabled'), output)
259 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
260 self.assertTrue(output.rstrip().endswith('%autopx disabled'), output)
260 self.assertTrue('RemoteError: ZeroDivisionError' in output, output)
261 self.assertTrue('RemoteError: ZeroDivisionError' in output, output)
261 self.assertTrue('\nOut[' in output, output)
262 self.assertTrue('\nOut[' in output, output)
262 self.assertTrue(': 24690' in output, output)
263 self.assertTrue(': 24690' in output, output)
263 ar = v.get_result(-1)
264 ar = v.get_result(-1)
264 self.assertEquals(v['a'], 5)
265 self.assertEquals(v['a'], 5)
265 self.assertEquals(v['b'], 24690)
266 self.assertEquals(v['b'], 24690)
266 self.assertRaisesRemote(ZeroDivisionError, ar.get)
267 self.assertRaisesRemote(ZeroDivisionError, ar.get)
267
268
268 def test_autopx_nonblocking(self):
269 def test_autopx_nonblocking(self):
269 ip = get_ipython()
270 ip = get_ipython()
270 v = self.client[-1]
271 v = self.client[-1]
271 v.activate()
272 v.activate()
272 v.block=False
273 v.block=False
273
274
274 with capture_output() as io:
275 with capture_output() as io:
275 ip.magic('autopx')
276 ip.magic('autopx')
276 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
277 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
277 ip.run_cell('print (b)')
278 ip.run_cell('print (b)')
278 ip.run_cell('import time; time.sleep(0.1)')
279 ip.run_cell('import time; time.sleep(0.1)')
279 ip.run_cell("b/c")
280 ip.run_cell("b/c")
280 ip.run_cell('b*=2')
281 ip.run_cell('b*=2')
281 ip.magic('autopx')
282 ip.magic('autopx')
282
283
283 output = io.stdout.rstrip()
284 output = io.stdout.rstrip()
284
285
285 self.assertTrue(output.startswith('%autopx enabled'))
286 self.assertTrue(output.startswith('%autopx enabled'))
286 self.assertTrue(output.endswith('%autopx disabled'))
287 self.assertTrue(output.endswith('%autopx disabled'))
287 self.assertFalse('ZeroDivisionError' in output)
288 self.assertFalse('ZeroDivisionError' in output)
288 ar = v.get_result(-2)
289 ar = v.get_result(-2)
289 self.assertRaisesRemote(ZeroDivisionError, ar.get)
290 self.assertRaisesRemote(ZeroDivisionError, ar.get)
290 # prevent TaskAborted on pulls, due to ZeroDivisionError
291 # prevent TaskAborted on pulls, due to ZeroDivisionError
291 time.sleep(0.5)
292 time.sleep(0.5)
292 self.assertEquals(v['a'], 5)
293 self.assertEquals(v['a'], 5)
293 # b*=2 will not fire, due to abort
294 # b*=2 will not fire, due to abort
294 self.assertEquals(v['b'], 10)
295 self.assertEquals(v['b'], 10)
295
296
296 def test_result(self):
297 def test_result(self):
297 ip = get_ipython()
298 ip = get_ipython()
298 v = self.client[-1]
299 v = self.client[-1]
299 v.activate()
300 v.activate()
300 data = dict(a=111,b=222)
301 data = dict(a=111,b=222)
301 v.push(data, block=True)
302 v.push(data, block=True)
302
303
303 ip.magic('px a')
304 ip.magic('px a')
304 ip.magic('px b')
305 ip.magic('px b')
305 for idx, name in [
306 for idx, name in [
306 ('', 'b'),
307 ('', 'b'),
307 ('-1', 'b'),
308 ('-1', 'b'),
308 ('2', 'b'),
309 ('2', 'b'),
309 ('1', 'a'),
310 ('1', 'a'),
310 ('-2', 'a'),
311 ('-2', 'a'),
311 ]:
312 ]:
312 with capture_output() as io:
313 with capture_output() as io:
313 ip.magic('result ' + idx)
314 ip.magic('result ' + idx)
314 output = io.stdout
315 output = io.stdout
315 msg = "expected %s output to include %s, but got: %s" % \
316 msg = "expected %s output to include %s, but got: %s" % \
316 ('%result '+idx, str(data[name]), output)
317 ('%result '+idx, str(data[name]), output)
317 self.assertTrue(str(data[name]) in output, msg)
318 self.assertTrue(str(data[name]) in output, msg)
318
319
319 @dec.skipif_not_matplotlib
320 @dec.skipif_not_matplotlib
320 def test_px_pylab(self):
321 def test_px_pylab(self):
321 """%pylab works on engines"""
322 """%pylab works on engines"""
322 ip = get_ipython()
323 ip = get_ipython()
323 v = self.client[-1]
324 v = self.client[-1]
324 v.block = True
325 v.block = True
325 v.activate()
326 v.activate()
326
327
327 with capture_output() as io:
328 with capture_output() as io:
328 ip.magic("px %pylab inline")
329 ip.magic("px %pylab inline")
329
330
330 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
331 self.assertTrue("Welcome to pylab" in io.stdout, io.stdout)
331 self.assertTrue("backend_inline" in io.stdout, io.stdout)
332 self.assertTrue("backend_inline" in io.stdout, io.stdout)
332
333
333 with capture_output() as io:
334 with capture_output() as io:
334 ip.magic("px plot(rand(100))")
335 ip.magic("px plot(rand(100))")
335
336
336 self.assertTrue('Out[' in io.stdout, io.stdout)
337 self.assertTrue('Out[' in io.stdout, io.stdout)
337 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
338 self.assertTrue('matplotlib.lines' in io.stdout, io.stdout)
338
339
339
340
@@ -1,323 +1,357 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 IO related utilities.
3 IO related utilities.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 from __future__ import print_function
12 from __future__ import print_function
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 import os
17 import os
18 import sys
18 import sys
19 import tempfile
19 import tempfile
20 from StringIO import StringIO
20
21
21 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
22 # Code
23 # Code
23 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
24
25
25
26
26 class IOStream:
27 class IOStream:
27
28
28 def __init__(self,stream, fallback=None):
29 def __init__(self,stream, fallback=None):
29 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
30 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
30 if fallback is not None:
31 if fallback is not None:
31 stream = fallback
32 stream = fallback
32 else:
33 else:
33 raise ValueError("fallback required, but not specified")
34 raise ValueError("fallback required, but not specified")
34 self.stream = stream
35 self.stream = stream
35 self._swrite = stream.write
36 self._swrite = stream.write
36
37
37 # clone all methods not overridden:
38 # clone all methods not overridden:
38 def clone(meth):
39 def clone(meth):
39 return not hasattr(self, meth) and not meth.startswith('_')
40 return not hasattr(self, meth) and not meth.startswith('_')
40 for meth in filter(clone, dir(stream)):
41 for meth in filter(clone, dir(stream)):
41 setattr(self, meth, getattr(stream, meth))
42 setattr(self, meth, getattr(stream, meth))
42
43
43 def write(self,data):
44 def write(self,data):
44 try:
45 try:
45 self._swrite(data)
46 self._swrite(data)
46 except:
47 except:
47 try:
48 try:
48 # print handles some unicode issues which may trip a plain
49 # print handles some unicode issues which may trip a plain
49 # write() call. Emulate write() by using an empty end
50 # write() call. Emulate write() by using an empty end
50 # argument.
51 # argument.
51 print(data, end='', file=self.stream)
52 print(data, end='', file=self.stream)
52 except:
53 except:
53 # if we get here, something is seriously broken.
54 # if we get here, something is seriously broken.
54 print('ERROR - failed to write data to stream:', self.stream,
55 print('ERROR - failed to write data to stream:', self.stream,
55 file=sys.stderr)
56 file=sys.stderr)
56
57
57 def writelines(self, lines):
58 def writelines(self, lines):
58 if isinstance(lines, basestring):
59 if isinstance(lines, basestring):
59 lines = [lines]
60 lines = [lines]
60 for line in lines:
61 for line in lines:
61 self.write(line)
62 self.write(line)
62
63
63 # This class used to have a writeln method, but regular files and streams
64 # This class used to have a writeln method, but regular files and streams
64 # in Python don't have this method. We need to keep this completely
65 # in Python don't have this method. We need to keep this completely
65 # compatible so we removed it.
66 # compatible so we removed it.
66
67
67 @property
68 @property
68 def closed(self):
69 def closed(self):
69 return self.stream.closed
70 return self.stream.closed
70
71
71 def close(self):
72 def close(self):
72 pass
73 pass
73
74
74 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
75 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
75 devnull = open(os.devnull, 'a')
76 devnull = open(os.devnull, 'a')
76 stdin = IOStream(sys.stdin, fallback=devnull)
77 stdin = IOStream(sys.stdin, fallback=devnull)
77 stdout = IOStream(sys.stdout, fallback=devnull)
78 stdout = IOStream(sys.stdout, fallback=devnull)
78 stderr = IOStream(sys.stderr, fallback=devnull)
79 stderr = IOStream(sys.stderr, fallback=devnull)
79
80
80 class IOTerm:
81 class IOTerm:
81 """ Term holds the file or file-like objects for handling I/O operations.
82 """ Term holds the file or file-like objects for handling I/O operations.
82
83
83 These are normally just sys.stdin, sys.stdout and sys.stderr but for
84 These are normally just sys.stdin, sys.stdout and sys.stderr but for
84 Windows they can can replaced to allow editing the strings before they are
85 Windows they can can replaced to allow editing the strings before they are
85 displayed."""
86 displayed."""
86
87
87 # In the future, having IPython channel all its I/O operations through
88 # In the future, having IPython channel all its I/O operations through
88 # this class will make it easier to embed it into other environments which
89 # this class will make it easier to embed it into other environments which
89 # are not a normal terminal (such as a GUI-based shell)
90 # are not a normal terminal (such as a GUI-based shell)
90 def __init__(self, stdin=None, stdout=None, stderr=None):
91 def __init__(self, stdin=None, stdout=None, stderr=None):
91 mymodule = sys.modules[__name__]
92 mymodule = sys.modules[__name__]
92 self.stdin = IOStream(stdin, mymodule.stdin)
93 self.stdin = IOStream(stdin, mymodule.stdin)
93 self.stdout = IOStream(stdout, mymodule.stdout)
94 self.stdout = IOStream(stdout, mymodule.stdout)
94 self.stderr = IOStream(stderr, mymodule.stderr)
95 self.stderr = IOStream(stderr, mymodule.stderr)
95
96
96
97
97 class Tee(object):
98 class Tee(object):
98 """A class to duplicate an output stream to stdout/err.
99 """A class to duplicate an output stream to stdout/err.
99
100
100 This works in a manner very similar to the Unix 'tee' command.
101 This works in a manner very similar to the Unix 'tee' command.
101
102
102 When the object is closed or deleted, it closes the original file given to
103 When the object is closed or deleted, it closes the original file given to
103 it for duplication.
104 it for duplication.
104 """
105 """
105 # Inspired by:
106 # Inspired by:
106 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
107 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
107
108
108 def __init__(self, file_or_name, mode="w", channel='stdout'):
109 def __init__(self, file_or_name, mode="w", channel='stdout'):
109 """Construct a new Tee object.
110 """Construct a new Tee object.
110
111
111 Parameters
112 Parameters
112 ----------
113 ----------
113 file_or_name : filename or open filehandle (writable)
114 file_or_name : filename or open filehandle (writable)
114 File that will be duplicated
115 File that will be duplicated
115
116
116 mode : optional, valid mode for open().
117 mode : optional, valid mode for open().
117 If a filename was give, open with this mode.
118 If a filename was give, open with this mode.
118
119
119 channel : str, one of ['stdout', 'stderr']
120 channel : str, one of ['stdout', 'stderr']
120 """
121 """
121 if channel not in ['stdout', 'stderr']:
122 if channel not in ['stdout', 'stderr']:
122 raise ValueError('Invalid channel spec %s' % channel)
123 raise ValueError('Invalid channel spec %s' % channel)
123
124
124 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
125 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
125 self.file = file_or_name
126 self.file = file_or_name
126 else:
127 else:
127 self.file = open(file_or_name, mode)
128 self.file = open(file_or_name, mode)
128 self.channel = channel
129 self.channel = channel
129 self.ostream = getattr(sys, channel)
130 self.ostream = getattr(sys, channel)
130 setattr(sys, channel, self)
131 setattr(sys, channel, self)
131 self._closed = False
132 self._closed = False
132
133
133 def close(self):
134 def close(self):
134 """Close the file and restore the channel."""
135 """Close the file and restore the channel."""
135 self.flush()
136 self.flush()
136 setattr(sys, self.channel, self.ostream)
137 setattr(sys, self.channel, self.ostream)
137 self.file.close()
138 self.file.close()
138 self._closed = True
139 self._closed = True
139
140
140 def write(self, data):
141 def write(self, data):
141 """Write data to both channels."""
142 """Write data to both channels."""
142 self.file.write(data)
143 self.file.write(data)
143 self.ostream.write(data)
144 self.ostream.write(data)
144 self.ostream.flush()
145 self.ostream.flush()
145
146
146 def flush(self):
147 def flush(self):
147 """Flush both channels."""
148 """Flush both channels."""
148 self.file.flush()
149 self.file.flush()
149 self.ostream.flush()
150 self.ostream.flush()
150
151
151 def __del__(self):
152 def __del__(self):
152 if not self._closed:
153 if not self._closed:
153 self.close()
154 self.close()
154
155
155
156
156 def file_read(filename):
157 def file_read(filename):
157 """Read a file and close it. Returns the file source."""
158 """Read a file and close it. Returns the file source."""
158 fobj = open(filename,'r');
159 fobj = open(filename,'r');
159 source = fobj.read();
160 source = fobj.read();
160 fobj.close()
161 fobj.close()
161 return source
162 return source
162
163
163
164
164 def file_readlines(filename):
165 def file_readlines(filename):
165 """Read a file and close it. Returns the file source using readlines()."""
166 """Read a file and close it. Returns the file source using readlines()."""
166 fobj = open(filename,'r');
167 fobj = open(filename,'r');
167 lines = fobj.readlines();
168 lines = fobj.readlines();
168 fobj.close()
169 fobj.close()
169 return lines
170 return lines
170
171
171
172
172 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
173 def raw_input_multi(header='', ps1='==> ', ps2='..> ',terminate_str = '.'):
173 """Take multiple lines of input.
174 """Take multiple lines of input.
174
175
175 A list with each line of input as a separate element is returned when a
176 A list with each line of input as a separate element is returned when a
176 termination string is entered (defaults to a single '.'). Input can also
177 termination string is entered (defaults to a single '.'). Input can also
177 terminate via EOF (^D in Unix, ^Z-RET in Windows).
178 terminate via EOF (^D in Unix, ^Z-RET in Windows).
178
179
179 Lines of input which end in \\ are joined into single entries (and a
180 Lines of input which end in \\ are joined into single entries (and a
180 secondary continuation prompt is issued as long as the user terminates
181 secondary continuation prompt is issued as long as the user terminates
181 lines with \\). This allows entering very long strings which are still
182 lines with \\). This allows entering very long strings which are still
182 meant to be treated as single entities.
183 meant to be treated as single entities.
183 """
184 """
184
185
185 try:
186 try:
186 if header:
187 if header:
187 header += '\n'
188 header += '\n'
188 lines = [raw_input(header + ps1)]
189 lines = [raw_input(header + ps1)]
189 except EOFError:
190 except EOFError:
190 return []
191 return []
191 terminate = [terminate_str]
192 terminate = [terminate_str]
192 try:
193 try:
193 while lines[-1:] != terminate:
194 while lines[-1:] != terminate:
194 new_line = raw_input(ps1)
195 new_line = raw_input(ps1)
195 while new_line.endswith('\\'):
196 while new_line.endswith('\\'):
196 new_line = new_line[:-1] + raw_input(ps2)
197 new_line = new_line[:-1] + raw_input(ps2)
197 lines.append(new_line)
198 lines.append(new_line)
198
199
199 return lines[:-1] # don't return the termination command
200 return lines[:-1] # don't return the termination command
200 except EOFError:
201 except EOFError:
201 print()
202 print()
202 return lines
203 return lines
203
204
204
205
205 def raw_input_ext(prompt='', ps2='... '):
206 def raw_input_ext(prompt='', ps2='... '):
206 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
207 """Similar to raw_input(), but accepts extended lines if input ends with \\."""
207
208
208 line = raw_input(prompt)
209 line = raw_input(prompt)
209 while line.endswith('\\'):
210 while line.endswith('\\'):
210 line = line[:-1] + raw_input(ps2)
211 line = line[:-1] + raw_input(ps2)
211 return line
212 return line
212
213
213
214
214 def ask_yes_no(prompt,default=None):
215 def ask_yes_no(prompt,default=None):
215 """Asks a question and returns a boolean (y/n) answer.
216 """Asks a question and returns a boolean (y/n) answer.
216
217
217 If default is given (one of 'y','n'), it is used if the user input is
218 If default is given (one of 'y','n'), it is used if the user input is
218 empty. Otherwise the question is repeated until an answer is given.
219 empty. Otherwise the question is repeated until an answer is given.
219
220
220 An EOF is treated as the default answer. If there is no default, an
221 An EOF is treated as the default answer. If there is no default, an
221 exception is raised to prevent infinite loops.
222 exception is raised to prevent infinite loops.
222
223
223 Valid answers are: y/yes/n/no (match is not case sensitive)."""
224 Valid answers are: y/yes/n/no (match is not case sensitive)."""
224
225
225 answers = {'y':True,'n':False,'yes':True,'no':False}
226 answers = {'y':True,'n':False,'yes':True,'no':False}
226 ans = None
227 ans = None
227 while ans not in answers.keys():
228 while ans not in answers.keys():
228 try:
229 try:
229 ans = raw_input(prompt+' ').lower()
230 ans = raw_input(prompt+' ').lower()
230 if not ans: # response was an empty string
231 if not ans: # response was an empty string
231 ans = default
232 ans = default
232 except KeyboardInterrupt:
233 except KeyboardInterrupt:
233 pass
234 pass
234 except EOFError:
235 except EOFError:
235 if default in answers.keys():
236 if default in answers.keys():
236 ans = default
237 ans = default
237 print()
238 print()
238 else:
239 else:
239 raise
240 raise
240
241
241 return answers[ans]
242 return answers[ans]
242
243
243
244
244 class NLprinter:
245 class NLprinter:
245 """Print an arbitrarily nested list, indicating index numbers.
246 """Print an arbitrarily nested list, indicating index numbers.
246
247
247 An instance of this class called nlprint is available and callable as a
248 An instance of this class called nlprint is available and callable as a
248 function.
249 function.
249
250
250 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
251 nlprint(list,indent=' ',sep=': ') -> prints indenting each level by 'indent'
251 and using 'sep' to separate the index from the value. """
252 and using 'sep' to separate the index from the value. """
252
253
253 def __init__(self):
254 def __init__(self):
254 self.depth = 0
255 self.depth = 0
255
256
256 def __call__(self,lst,pos='',**kw):
257 def __call__(self,lst,pos='',**kw):
257 """Prints the nested list numbering levels."""
258 """Prints the nested list numbering levels."""
258 kw.setdefault('indent',' ')
259 kw.setdefault('indent',' ')
259 kw.setdefault('sep',': ')
260 kw.setdefault('sep',': ')
260 kw.setdefault('start',0)
261 kw.setdefault('start',0)
261 kw.setdefault('stop',len(lst))
262 kw.setdefault('stop',len(lst))
262 # we need to remove start and stop from kw so they don't propagate
263 # we need to remove start and stop from kw so they don't propagate
263 # into a recursive call for a nested list.
264 # into a recursive call for a nested list.
264 start = kw['start']; del kw['start']
265 start = kw['start']; del kw['start']
265 stop = kw['stop']; del kw['stop']
266 stop = kw['stop']; del kw['stop']
266 if self.depth == 0 and 'header' in kw.keys():
267 if self.depth == 0 and 'header' in kw.keys():
267 print(kw['header'])
268 print(kw['header'])
268
269
269 for idx in range(start,stop):
270 for idx in range(start,stop):
270 elem = lst[idx]
271 elem = lst[idx]
271 newpos = pos + str(idx)
272 newpos = pos + str(idx)
272 if type(elem)==type([]):
273 if type(elem)==type([]):
273 self.depth += 1
274 self.depth += 1
274 self.__call__(elem, newpos+",", **kw)
275 self.__call__(elem, newpos+",", **kw)
275 self.depth -= 1
276 self.depth -= 1
276 else:
277 else:
277 print(kw['indent']*self.depth + newpos + kw["sep"] + repr(elem))
278 print(kw['indent']*self.depth + newpos + kw["sep"] + repr(elem))
278
279
279 nlprint = NLprinter()
280 nlprint = NLprinter()
280
281
281
282
282 def temp_pyfile(src, ext='.py'):
283 def temp_pyfile(src, ext='.py'):
283 """Make a temporary python file, return filename and filehandle.
284 """Make a temporary python file, return filename and filehandle.
284
285
285 Parameters
286 Parameters
286 ----------
287 ----------
287 src : string or list of strings (no need for ending newlines if list)
288 src : string or list of strings (no need for ending newlines if list)
288 Source code to be written to the file.
289 Source code to be written to the file.
289
290
290 ext : optional, string
291 ext : optional, string
291 Extension for the generated file.
292 Extension for the generated file.
292
293
293 Returns
294 Returns
294 -------
295 -------
295 (filename, open filehandle)
296 (filename, open filehandle)
296 It is the caller's responsibility to close the open file and unlink it.
297 It is the caller's responsibility to close the open file and unlink it.
297 """
298 """
298 fname = tempfile.mkstemp(ext)[1]
299 fname = tempfile.mkstemp(ext)[1]
299 f = open(fname,'w')
300 f = open(fname,'w')
300 f.write(src)
301 f.write(src)
301 f.flush()
302 f.flush()
302 return fname, f
303 return fname, f
303
304
304
305
305 def raw_print(*args, **kw):
306 def raw_print(*args, **kw):
306 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
307 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
307
308
308 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
309 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
309 file=sys.__stdout__)
310 file=sys.__stdout__)
310 sys.__stdout__.flush()
311 sys.__stdout__.flush()
311
312
312
313
313 def raw_print_err(*args, **kw):
314 def raw_print_err(*args, **kw):
314 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
315 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
315
316
316 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
317 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
317 file=sys.__stderr__)
318 file=sys.__stderr__)
318 sys.__stderr__.flush()
319 sys.__stderr__.flush()
319
320
320
321
321 # Short aliases for quick debugging, do NOT use these in production code.
322 # Short aliases for quick debugging, do NOT use these in production code.
322 rprint = raw_print
323 rprint = raw_print
323 rprinte = raw_print_err
324 rprinte = raw_print_err
325
326
327 class CapturedIO(object):
328 """Simple object for containing captured stdout/err StringIO objects"""
329
330 def __init__(self, stdout, stderr):
331 self.stdout_io = stdout
332 self.stderr_io = stderr
333
334 @property
335 def stdout(self):
336 return self.stdout_io.getvalue()
337
338 @property
339 def stderr(self):
340 return self.stderr_io.getvalue()
341
342
343 class capture_output(object):
344 """context manager for capturing stdout/err"""
345
346 def __enter__(self):
347 self.sys_stdout = sys.stdout
348 self.sys_stderr = sys.stderr
349 stdout = sys.stdout = StringIO()
350 stderr = sys.stderr = StringIO()
351 return CapturedIO(stdout, stderr)
352
353 def __exit__(self, exc_type, exc_value, traceback):
354 sys.stdout = self.sys_stdout
355 sys.stderr = self.sys_stderr
356
357
@@ -1,75 +1,85 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for io.py"""
2 """Tests for io.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import sys
15 import sys
16
16
17 from StringIO import StringIO
17 from StringIO import StringIO
18 from subprocess import Popen, PIPE
18 from subprocess import Popen, PIPE
19
19
20 import nose.tools as nt
20 import nose.tools as nt
21
21
22 from IPython.testing import decorators as dec
22 from IPython.testing import decorators as dec
23 from IPython.utils.io import Tee
23 from IPython.utils.io import Tee, capture_output
24 from IPython.utils.py3compat import doctest_refactor_print
24 from IPython.utils.py3compat import doctest_refactor_print
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Tests
27 # Tests
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29
29
30
30
31 def test_tee_simple():
31 def test_tee_simple():
32 "Very simple check with stdout only"
32 "Very simple check with stdout only"
33 chan = StringIO()
33 chan = StringIO()
34 text = 'Hello'
34 text = 'Hello'
35 tee = Tee(chan, channel='stdout')
35 tee = Tee(chan, channel='stdout')
36 print >> chan, text
36 print >> chan, text
37 nt.assert_equal(chan.getvalue(), text+"\n")
37 nt.assert_equal(chan.getvalue(), text+"\n")
38
38
39
39
40 class TeeTestCase(dec.ParametricTestCase):
40 class TeeTestCase(dec.ParametricTestCase):
41
41
42 def tchan(self, channel, check='close'):
42 def tchan(self, channel, check='close'):
43 trap = StringIO()
43 trap = StringIO()
44 chan = StringIO()
44 chan = StringIO()
45 text = 'Hello'
45 text = 'Hello'
46
46
47 std_ori = getattr(sys, channel)
47 std_ori = getattr(sys, channel)
48 setattr(sys, channel, trap)
48 setattr(sys, channel, trap)
49
49
50 tee = Tee(chan, channel=channel)
50 tee = Tee(chan, channel=channel)
51 print >> chan, text,
51 print >> chan, text,
52 setattr(sys, channel, std_ori)
52 setattr(sys, channel, std_ori)
53 trap_val = trap.getvalue()
53 trap_val = trap.getvalue()
54 nt.assert_equals(chan.getvalue(), text)
54 nt.assert_equals(chan.getvalue(), text)
55 if check=='close':
55 if check=='close':
56 tee.close()
56 tee.close()
57 else:
57 else:
58 del tee
58 del tee
59
59
60 def test(self):
60 def test(self):
61 for chan in ['stdout', 'stderr']:
61 for chan in ['stdout', 'stderr']:
62 for check in ['close', 'del']:
62 for check in ['close', 'del']:
63 yield self.tchan(chan, check)
63 yield self.tchan(chan, check)
64
64
65 def test_io_init():
65 def test_io_init():
66 """Test that io.stdin/out/err exist at startup"""
66 """Test that io.stdin/out/err exist at startup"""
67 for name in ('stdin', 'stdout', 'stderr'):
67 for name in ('stdin', 'stdout', 'stderr'):
68 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
68 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
69 p = Popen([sys.executable, '-c', cmd],
69 p = Popen([sys.executable, '-c', cmd],
70 stdout=PIPE)
70 stdout=PIPE)
71 p.wait()
71 p.wait()
72 classname = p.stdout.read().strip().decode('ascii')
72 classname = p.stdout.read().strip().decode('ascii')
73 # __class__ is a reference to the class object in Python 3, so we can't
73 # __class__ is a reference to the class object in Python 3, so we can't
74 # just test for string equality.
74 # just test for string equality.
75 assert 'IPython.utils.io.IOStream' in classname, classname
75 assert 'IPython.utils.io.IOStream' in classname, classname
76
77 def test_capture_output():
78 """capture_output() context works"""
79
80 with capture_output() as io:
81 print 'hi, stdout'
82 print >> sys.stderr, 'hi, stderr'
83
84 nt.assert_equals(io.stdout, 'hi, stdout\n')
85 nt.assert_equals(io.stderr, 'hi, stderr\n')
General Comments 0
You need to be logged in to leave comments. Login now