##// END OF EJS Templates
Backport PR #6911: don't use text mode in mkstemp...
Min RK -
Show More
@@ -1,345 +1,345 b''
1 1 # encoding: utf-8
2 2 """
3 3 IO related utilities.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12 from __future__ import print_function
13 13 from __future__ import absolute_import
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 import codecs
19 19 from contextlib import contextmanager
20 20 import io
21 21 import os
22 22 import shutil
23 23 import stat
24 24 import sys
25 25 import tempfile
26 26 from .capture import CapturedIO, capture_output
27 27 from .py3compat import string_types, input, PY3
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Code
31 31 #-----------------------------------------------------------------------------
32 32
33 33
34 34 class IOStream:
35 35
36 36 def __init__(self,stream, fallback=None):
37 37 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
38 38 if fallback is not None:
39 39 stream = fallback
40 40 else:
41 41 raise ValueError("fallback required, but not specified")
42 42 self.stream = stream
43 43 self._swrite = stream.write
44 44
45 45 # clone all methods not overridden:
46 46 def clone(meth):
47 47 return not hasattr(self, meth) and not meth.startswith('_')
48 48 for meth in filter(clone, dir(stream)):
49 49 setattr(self, meth, getattr(stream, meth))
50 50
51 51 def __repr__(self):
52 52 cls = self.__class__
53 53 tpl = '{mod}.{cls}({args})'
54 54 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
55 55
56 56 def write(self,data):
57 57 try:
58 58 self._swrite(data)
59 59 except:
60 60 try:
61 61 # print handles some unicode issues which may trip a plain
62 62 # write() call. Emulate write() by using an empty end
63 63 # argument.
64 64 print(data, end='', file=self.stream)
65 65 except:
66 66 # if we get here, something is seriously broken.
67 67 print('ERROR - failed to write data to stream:', self.stream,
68 68 file=sys.stderr)
69 69
70 70 def writelines(self, lines):
71 71 if isinstance(lines, string_types):
72 72 lines = [lines]
73 73 for line in lines:
74 74 self.write(line)
75 75
76 76 # This class used to have a writeln method, but regular files and streams
77 77 # in Python don't have this method. We need to keep this completely
78 78 # compatible so we removed it.
79 79
80 80 @property
81 81 def closed(self):
82 82 return self.stream.closed
83 83
84 84 def close(self):
85 85 pass
86 86
87 87 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
88 88 devnull = open(os.devnull, 'w')
89 89 stdin = IOStream(sys.stdin, fallback=devnull)
90 90 stdout = IOStream(sys.stdout, fallback=devnull)
91 91 stderr = IOStream(sys.stderr, fallback=devnull)
92 92
93 93 class IOTerm:
94 94 """ Term holds the file or file-like objects for handling I/O operations.
95 95
96 96 These are normally just sys.stdin, sys.stdout and sys.stderr but for
97 97 Windows they can can replaced to allow editing the strings before they are
98 98 displayed."""
99 99
100 100 # In the future, having IPython channel all its I/O operations through
101 101 # this class will make it easier to embed it into other environments which
102 102 # are not a normal terminal (such as a GUI-based shell)
103 103 def __init__(self, stdin=None, stdout=None, stderr=None):
104 104 mymodule = sys.modules[__name__]
105 105 self.stdin = IOStream(stdin, mymodule.stdin)
106 106 self.stdout = IOStream(stdout, mymodule.stdout)
107 107 self.stderr = IOStream(stderr, mymodule.stderr)
108 108
109 109
110 110 class Tee(object):
111 111 """A class to duplicate an output stream to stdout/err.
112 112
113 113 This works in a manner very similar to the Unix 'tee' command.
114 114
115 115 When the object is closed or deleted, it closes the original file given to
116 116 it for duplication.
117 117 """
118 118 # Inspired by:
119 119 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
120 120
121 121 def __init__(self, file_or_name, mode="w", channel='stdout'):
122 122 """Construct a new Tee object.
123 123
124 124 Parameters
125 125 ----------
126 126 file_or_name : filename or open filehandle (writable)
127 127 File that will be duplicated
128 128
129 129 mode : optional, valid mode for open().
130 130 If a filename was give, open with this mode.
131 131
132 132 channel : str, one of ['stdout', 'stderr']
133 133 """
134 134 if channel not in ['stdout', 'stderr']:
135 135 raise ValueError('Invalid channel spec %s' % channel)
136 136
137 137 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
138 138 self.file = file_or_name
139 139 else:
140 140 self.file = open(file_or_name, mode)
141 141 self.channel = channel
142 142 self.ostream = getattr(sys, channel)
143 143 setattr(sys, channel, self)
144 144 self._closed = False
145 145
146 146 def close(self):
147 147 """Close the file and restore the channel."""
148 148 self.flush()
149 149 setattr(sys, self.channel, self.ostream)
150 150 self.file.close()
151 151 self._closed = True
152 152
153 153 def write(self, data):
154 154 """Write data to both channels."""
155 155 self.file.write(data)
156 156 self.ostream.write(data)
157 157 self.ostream.flush()
158 158
159 159 def flush(self):
160 160 """Flush both channels."""
161 161 self.file.flush()
162 162 self.ostream.flush()
163 163
164 164 def __del__(self):
165 165 if not self._closed:
166 166 self.close()
167 167
168 168
169 169 def ask_yes_no(prompt, default=None, interrupt=None):
170 170 """Asks a question and returns a boolean (y/n) answer.
171 171
172 172 If default is given (one of 'y','n'), it is used if the user input is
173 173 empty. If interrupt is given (one of 'y','n'), it is used if the user
174 174 presses Ctrl-C. Otherwise the question is repeated until an answer is
175 175 given.
176 176
177 177 An EOF is treated as the default answer. If there is no default, an
178 178 exception is raised to prevent infinite loops.
179 179
180 180 Valid answers are: y/yes/n/no (match is not case sensitive)."""
181 181
182 182 answers = {'y':True,'n':False,'yes':True,'no':False}
183 183 ans = None
184 184 while ans not in answers.keys():
185 185 try:
186 186 ans = input(prompt+' ').lower()
187 187 if not ans: # response was an empty string
188 188 ans = default
189 189 except KeyboardInterrupt:
190 190 if interrupt:
191 191 ans = interrupt
192 192 except EOFError:
193 193 if default in answers.keys():
194 194 ans = default
195 195 print()
196 196 else:
197 197 raise
198 198
199 199 return answers[ans]
200 200
201 201
202 202 def temp_pyfile(src, ext='.py'):
203 203 """Make a temporary python file, return filename and filehandle.
204 204
205 205 Parameters
206 206 ----------
207 207 src : string or list of strings (no need for ending newlines if list)
208 208 Source code to be written to the file.
209 209
210 210 ext : optional, string
211 211 Extension for the generated file.
212 212
213 213 Returns
214 214 -------
215 215 (filename, open filehandle)
216 216 It is the caller's responsibility to close the open file and unlink it.
217 217 """
218 218 fname = tempfile.mkstemp(ext)[1]
219 219 f = open(fname,'w')
220 220 f.write(src)
221 221 f.flush()
222 222 return fname, f
223 223
224 224 def _copy_metadata(src, dst):
225 225 """Copy the set of metadata we want for atomic_writing.
226 226
227 227 Permission bits and flags. We'd like to copy file ownership as well, but we
228 228 can't do that.
229 229 """
230 230 shutil.copymode(src, dst)
231 231 st = os.stat(src)
232 232 if hasattr(os, 'chflags') and hasattr(st, 'st_flags'):
233 233 os.chflags(dst, st.st_flags)
234 234
235 235 @contextmanager
236 236 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
237 237 """Context manager to write to a file only if the entire write is successful.
238 238
239 239 This works by creating a temporary file in the same directory, and renaming
240 240 it over the old file if the context is exited without an error. If the
241 241 target file is a symlink or a hardlink, this will not be preserved: it will
242 242 be replaced by a new regular file.
243 243
244 244 On Windows, there is a small chink in the atomicity: the target file is
245 245 deleted before renaming the temporary file over it. This appears to be
246 246 unavoidable.
247 247
248 248 Parameters
249 249 ----------
250 250 path : str
251 251 The target file to write to.
252 252
253 253 text : bool, optional
254 254 Whether to open the file in text mode (i.e. to write unicode). Default is
255 255 True.
256 256
257 257 encoding : str, optional
258 258 The encoding to use for files opened in text mode. Default is UTF-8.
259 259
260 260 **kwargs
261 261 Passed to :func:`io.open`.
262 262 """
263 263 # realpath doesn't work on Windows: http://bugs.python.org/issue9949
264 264 # Luckily, we only need to resolve the file itself being a symlink, not
265 265 # any of its directories, so this will suffice:
266 266 if os.path.islink(path):
267 267 path = os.path.join(os.path.dirname(path), os.readlink(path))
268 268
269 269 dirname, basename = os.path.split(path)
270 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname, text=text)
270 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname)
271 271 if text:
272 272 fileobj = io.open(handle, 'w', encoding=encoding, **kwargs)
273 273 else:
274 274 fileobj = io.open(handle, 'wb', **kwargs)
275 275
276 276 try:
277 277 yield fileobj
278 278 except:
279 279 fileobj.close()
280 280 os.remove(tmp_path)
281 281 raise
282 282
283 283 # Flush to disk
284 284 fileobj.flush()
285 285 os.fsync(fileobj.fileno())
286 286
287 287 # Written successfully, now rename it
288 288 fileobj.close()
289 289
290 290 # Copy permission bits, access time, etc.
291 291 try:
292 292 _copy_metadata(path, tmp_path)
293 293 except OSError:
294 294 # e.g. the file didn't already exist. Ignore any failure to copy metadata
295 295 pass
296 296
297 297 if os.name == 'nt' and os.path.exists(path):
298 298 # Rename over existing file doesn't work on Windows
299 299 os.remove(path)
300 300
301 301 os.rename(tmp_path, path)
302 302
303 303
304 304 def raw_print(*args, **kw):
305 305 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
306 306
307 307 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
308 308 file=sys.__stdout__)
309 309 sys.__stdout__.flush()
310 310
311 311
312 312 def raw_print_err(*args, **kw):
313 313 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
314 314
315 315 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
316 316 file=sys.__stderr__)
317 317 sys.__stderr__.flush()
318 318
319 319
320 320 # Short aliases for quick debugging, do NOT use these in production code.
321 321 rprint = raw_print
322 322 rprinte = raw_print_err
323 323
324 324 def unicode_std_stream(stream='stdout'):
325 325 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
326 326
327 327 This ignores environment variables and default encodings, to reliably write
328 328 unicode to stdout or stderr.
329 329
330 330 ::
331 331
332 332 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
333 333 """
334 334 assert stream in ('stdout', 'stderr')
335 335 stream = getattr(sys, stream)
336 336 if PY3:
337 337 try:
338 338 stream_b = stream.buffer
339 339 except AttributeError:
340 340 # sys.stdout has been replaced - use it directly
341 341 return stream
342 342 else:
343 343 stream_b = stream
344 344
345 345 return codecs.getwriter('utf-8')(stream_b)
@@ -1,178 +1,215 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15 from __future__ import absolute_import
16 16
17 17 import io as stdlib_io
18 18 import os.path
19 19 import stat
20 20 import sys
21 21
22 22 from subprocess import Popen, PIPE
23 23 import unittest
24 24
25 25 import nose.tools as nt
26 26
27 27 from IPython.testing.decorators import skipif
28 28 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
29 29 atomic_writing,
30 30 )
31 31 from IPython.utils.py3compat import doctest_refactor_print, PY3
32 32 from IPython.utils.tempdir import TemporaryDirectory
33 33
34 34 if PY3:
35 35 from io import StringIO
36 36 else:
37 37 from StringIO import StringIO
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Tests
41 41 #-----------------------------------------------------------------------------
42 42
43 43
44 44 def test_tee_simple():
45 45 "Very simple check with stdout only"
46 46 chan = StringIO()
47 47 text = 'Hello'
48 48 tee = Tee(chan, channel='stdout')
49 49 print(text, file=chan)
50 50 nt.assert_equal(chan.getvalue(), text+"\n")
51 51
52 52
53 53 class TeeTestCase(unittest.TestCase):
54 54
55 55 def tchan(self, channel, check='close'):
56 56 trap = StringIO()
57 57 chan = StringIO()
58 58 text = 'Hello'
59 59
60 60 std_ori = getattr(sys, channel)
61 61 setattr(sys, channel, trap)
62 62
63 63 tee = Tee(chan, channel=channel)
64 64 print(text, end='', file=chan)
65 65 setattr(sys, channel, std_ori)
66 66 trap_val = trap.getvalue()
67 67 nt.assert_equal(chan.getvalue(), text)
68 68 if check=='close':
69 69 tee.close()
70 70 else:
71 71 del tee
72 72
73 73 def test(self):
74 74 for chan in ['stdout', 'stderr']:
75 75 for check in ['close', 'del']:
76 76 self.tchan(chan, check)
77 77
78 78 def test_io_init():
79 79 """Test that io.stdin/out/err exist at startup"""
80 80 for name in ('stdin', 'stdout', 'stderr'):
81 81 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
82 82 p = Popen([sys.executable, '-c', cmd],
83 83 stdout=PIPE)
84 84 p.wait()
85 85 classname = p.stdout.read().strip().decode('ascii')
86 86 # __class__ is a reference to the class object in Python 3, so we can't
87 87 # just test for string equality.
88 88 assert 'IPython.utils.io.IOStream' in classname, classname
89 89
90 90 def test_capture_output():
91 91 """capture_output() context works"""
92 92
93 93 with capture_output() as io:
94 94 print('hi, stdout')
95 95 print('hi, stderr', file=sys.stderr)
96 96
97 97 nt.assert_equal(io.stdout, 'hi, stdout\n')
98 98 nt.assert_equal(io.stderr, 'hi, stderr\n')
99 99
100 100 def test_UnicodeStdStream():
101 101 # Test wrapping a bytes-level stdout
102 102 if PY3:
103 103 stdoutb = stdlib_io.BytesIO()
104 104 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
105 105 else:
106 106 stdout = stdoutb = stdlib_io.BytesIO()
107 107
108 108 orig_stdout = sys.stdout
109 109 sys.stdout = stdout
110 110 try:
111 111 sample = u"@Ε‚e¢ŧ←"
112 112 unicode_std_stream().write(sample)
113 113
114 114 output = stdoutb.getvalue().decode('utf-8')
115 115 nt.assert_equal(output, sample)
116 116 assert not stdout.closed
117 117 finally:
118 118 sys.stdout = orig_stdout
119 119
120 120 @skipif(not PY3, "Not applicable on Python 2")
121 121 def test_UnicodeStdStream_nowrap():
122 122 # If we replace stdout with a StringIO, it shouldn't get wrapped.
123 123 orig_stdout = sys.stdout
124 124 sys.stdout = StringIO()
125 125 try:
126 126 nt.assert_is(unicode_std_stream(), sys.stdout)
127 127 assert not sys.stdout.closed
128 128 finally:
129 129 sys.stdout = orig_stdout
130 130
131 131 def test_atomic_writing():
132 132 class CustomExc(Exception): pass
133 133
134 134 with TemporaryDirectory() as td:
135 135 f1 = os.path.join(td, 'penguin')
136 136 with stdlib_io.open(f1, 'w') as f:
137 137 f.write(u'Before')
138 138
139 139 if os.name != 'nt':
140 140 os.chmod(f1, 0o701)
141 141 orig_mode = stat.S_IMODE(os.stat(f1).st_mode)
142 142
143 143 f2 = os.path.join(td, 'flamingo')
144 144 try:
145 145 os.symlink(f1, f2)
146 146 have_symlink = True
147 147 except (AttributeError, NotImplementedError, OSError):
148 148 # AttributeError: Python doesn't support it
149 149 # NotImplementedError: The system doesn't support it
150 150 # OSError: The user lacks the privilege (Windows)
151 151 have_symlink = False
152 152
153 153 with nt.assert_raises(CustomExc):
154 154 with atomic_writing(f1) as f:
155 155 f.write(u'Failing write')
156 156 raise CustomExc
157 157
158 158 # Because of the exception, the file should not have been modified
159 159 with stdlib_io.open(f1, 'r') as f:
160 160 nt.assert_equal(f.read(), u'Before')
161 161
162 162 with atomic_writing(f1) as f:
163 163 f.write(u'Overwritten')
164 164
165 165 with stdlib_io.open(f1, 'r') as f:
166 166 nt.assert_equal(f.read(), u'Overwritten')
167 167
168 168 if os.name != 'nt':
169 169 mode = stat.S_IMODE(os.stat(f1).st_mode)
170 170 nt.assert_equal(mode, orig_mode)
171 171
172 172 if have_symlink:
173 173 # Check that writing over a file preserves a symlink
174 174 with atomic_writing(f2) as f:
175 175 f.write(u'written from symlink')
176 176
177 177 with stdlib_io.open(f1, 'r') as f:
178 nt.assert_equal(f.read(), u'written from symlink') No newline at end of file
178 nt.assert_equal(f.read(), u'written from symlink')
179
180 def test_atomic_writing_newlines():
181 with TemporaryDirectory() as td:
182 path = os.path.join(td, 'testfile')
183
184 lf = u'a\nb\nc\n'
185 plat = lf.replace(u'\n', os.linesep)
186 crlf = lf.replace(u'\n', u'\r\n')
187
188 # test default
189 with stdlib_io.open(path, 'w') as f:
190 f.write(lf)
191 with stdlib_io.open(path, 'r', newline='') as f:
192 read = f.read()
193 nt.assert_equal(read, plat)
194
195 # test newline=LF
196 with stdlib_io.open(path, 'w', newline='\n') as f:
197 f.write(lf)
198 with stdlib_io.open(path, 'r', newline='') as f:
199 read = f.read()
200 nt.assert_equal(read, lf)
201
202 # test newline=CRLF
203 with atomic_writing(path, newline='\r\n') as f:
204 f.write(lf)
205 with stdlib_io.open(path, 'r', newline='') as f:
206 read = f.read()
207 nt.assert_equal(read, crlf)
208
209 # test newline=no convert
210 text = u'crlf\r\ncr\rlf\n'
211 with atomic_writing(path, newline='') as f:
212 f.write(text)
213 with stdlib_io.open(path, 'r', newline='') as f:
214 read = f.read()
215 nt.assert_equal(read, text)
General Comments 0
You need to be logged in to leave comments. Login now