##// END OF EJS Templates
preserve umask in atomic_writing...
Min RK -
Show More
@@ -1,345 +1,347 b''
1 1 # encoding: utf-8
2 2 """
3 3 IO related utilities.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12 from __future__ import print_function
13 13 from __future__ import absolute_import
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 import codecs
19 19 from contextlib import contextmanager
20 20 import io
21 21 import os
22 22 import shutil
23 23 import stat
24 24 import sys
25 25 import tempfile
26 26 from .capture import CapturedIO, capture_output
27 27 from .py3compat import string_types, input, PY3
28 28
29 29 #-----------------------------------------------------------------------------
30 30 # Code
31 31 #-----------------------------------------------------------------------------
32 32
33 33
34 34 class IOStream:
35 35
36 36 def __init__(self,stream, fallback=None):
37 37 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
38 38 if fallback is not None:
39 39 stream = fallback
40 40 else:
41 41 raise ValueError("fallback required, but not specified")
42 42 self.stream = stream
43 43 self._swrite = stream.write
44 44
45 45 # clone all methods not overridden:
46 46 def clone(meth):
47 47 return not hasattr(self, meth) and not meth.startswith('_')
48 48 for meth in filter(clone, dir(stream)):
49 49 setattr(self, meth, getattr(stream, meth))
50 50
51 51 def __repr__(self):
52 52 cls = self.__class__
53 53 tpl = '{mod}.{cls}({args})'
54 54 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
55 55
56 56 def write(self,data):
57 57 try:
58 58 self._swrite(data)
59 59 except:
60 60 try:
61 61 # print handles some unicode issues which may trip a plain
62 62 # write() call. Emulate write() by using an empty end
63 63 # argument.
64 64 print(data, end='', file=self.stream)
65 65 except:
66 66 # if we get here, something is seriously broken.
67 67 print('ERROR - failed to write data to stream:', self.stream,
68 68 file=sys.stderr)
69 69
70 70 def writelines(self, lines):
71 71 if isinstance(lines, string_types):
72 72 lines = [lines]
73 73 for line in lines:
74 74 self.write(line)
75 75
76 76 # This class used to have a writeln method, but regular files and streams
77 77 # in Python don't have this method. We need to keep this completely
78 78 # compatible so we removed it.
79 79
80 80 @property
81 81 def closed(self):
82 82 return self.stream.closed
83 83
84 84 def close(self):
85 85 pass
86 86
87 87 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
88 88 devnull = open(os.devnull, 'w')
89 89 stdin = IOStream(sys.stdin, fallback=devnull)
90 90 stdout = IOStream(sys.stdout, fallback=devnull)
91 91 stderr = IOStream(sys.stderr, fallback=devnull)
92 92
93 93 class IOTerm:
94 94 """ Term holds the file or file-like objects for handling I/O operations.
95 95
96 96 These are normally just sys.stdin, sys.stdout and sys.stderr but for
97 97 Windows they can can replaced to allow editing the strings before they are
98 98 displayed."""
99 99
100 100 # In the future, having IPython channel all its I/O operations through
101 101 # this class will make it easier to embed it into other environments which
102 102 # are not a normal terminal (such as a GUI-based shell)
103 103 def __init__(self, stdin=None, stdout=None, stderr=None):
104 104 mymodule = sys.modules[__name__]
105 105 self.stdin = IOStream(stdin, mymodule.stdin)
106 106 self.stdout = IOStream(stdout, mymodule.stdout)
107 107 self.stderr = IOStream(stderr, mymodule.stderr)
108 108
109 109
110 110 class Tee(object):
111 111 """A class to duplicate an output stream to stdout/err.
112 112
113 113 This works in a manner very similar to the Unix 'tee' command.
114 114
115 115 When the object is closed or deleted, it closes the original file given to
116 116 it for duplication.
117 117 """
118 118 # Inspired by:
119 119 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
120 120
121 121 def __init__(self, file_or_name, mode="w", channel='stdout'):
122 122 """Construct a new Tee object.
123 123
124 124 Parameters
125 125 ----------
126 126 file_or_name : filename or open filehandle (writable)
127 127 File that will be duplicated
128 128
129 129 mode : optional, valid mode for open().
130 130 If a filename was give, open with this mode.
131 131
132 132 channel : str, one of ['stdout', 'stderr']
133 133 """
134 134 if channel not in ['stdout', 'stderr']:
135 135 raise ValueError('Invalid channel spec %s' % channel)
136 136
137 137 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
138 138 self.file = file_or_name
139 139 else:
140 140 self.file = open(file_or_name, mode)
141 141 self.channel = channel
142 142 self.ostream = getattr(sys, channel)
143 143 setattr(sys, channel, self)
144 144 self._closed = False
145 145
146 146 def close(self):
147 147 """Close the file and restore the channel."""
148 148 self.flush()
149 149 setattr(sys, self.channel, self.ostream)
150 150 self.file.close()
151 151 self._closed = True
152 152
153 153 def write(self, data):
154 154 """Write data to both channels."""
155 155 self.file.write(data)
156 156 self.ostream.write(data)
157 157 self.ostream.flush()
158 158
159 159 def flush(self):
160 160 """Flush both channels."""
161 161 self.file.flush()
162 162 self.ostream.flush()
163 163
164 164 def __del__(self):
165 165 if not self._closed:
166 166 self.close()
167 167
168 168
169 169 def ask_yes_no(prompt, default=None, interrupt=None):
170 170 """Asks a question and returns a boolean (y/n) answer.
171 171
172 172 If default is given (one of 'y','n'), it is used if the user input is
173 173 empty. If interrupt is given (one of 'y','n'), it is used if the user
174 174 presses Ctrl-C. Otherwise the question is repeated until an answer is
175 175 given.
176 176
177 177 An EOF is treated as the default answer. If there is no default, an
178 178 exception is raised to prevent infinite loops.
179 179
180 180 Valid answers are: y/yes/n/no (match is not case sensitive)."""
181 181
182 182 answers = {'y':True,'n':False,'yes':True,'no':False}
183 183 ans = None
184 184 while ans not in answers.keys():
185 185 try:
186 186 ans = input(prompt+' ').lower()
187 187 if not ans: # response was an empty string
188 188 ans = default
189 189 except KeyboardInterrupt:
190 190 if interrupt:
191 191 ans = interrupt
192 192 except EOFError:
193 193 if default in answers.keys():
194 194 ans = default
195 195 print()
196 196 else:
197 197 raise
198 198
199 199 return answers[ans]
200 200
201 201
202 202 def temp_pyfile(src, ext='.py'):
203 203 """Make a temporary python file, return filename and filehandle.
204 204
205 205 Parameters
206 206 ----------
207 207 src : string or list of strings (no need for ending newlines if list)
208 208 Source code to be written to the file.
209 209
210 210 ext : optional, string
211 211 Extension for the generated file.
212 212
213 213 Returns
214 214 -------
215 215 (filename, open filehandle)
216 216 It is the caller's responsibility to close the open file and unlink it.
217 217 """
218 218 fname = tempfile.mkstemp(ext)[1]
219 219 f = open(fname,'w')
220 220 f.write(src)
221 221 f.flush()
222 222 return fname, f
223 223
224 224 def _copy_metadata(src, dst):
225 225 """Copy the set of metadata we want for atomic_writing.
226 226
227 227 Permission bits and flags. We'd like to copy file ownership as well, but we
228 228 can't do that.
229 229 """
230 230 shutil.copymode(src, dst)
231 231 st = os.stat(src)
232 232 if hasattr(os, 'chflags') and hasattr(st, 'st_flags'):
233 233 os.chflags(dst, st.st_flags)
234 234
235 235 @contextmanager
236 236 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
237 237 """Context manager to write to a file only if the entire write is successful.
238 238
239 239 This works by creating a temporary file in the same directory, and renaming
240 240 it over the old file if the context is exited without an error. If other
241 241 file names are hard linked to the target file, this relationship will not be
242 242 preserved.
243 243
244 244 On Windows, there is a small chink in the atomicity: the target file is
245 245 deleted before renaming the temporary file over it. This appears to be
246 246 unavoidable.
247 247
248 248 Parameters
249 249 ----------
250 250 path : str
251 251 The target file to write to.
252 252
253 253 text : bool, optional
254 254 Whether to open the file in text mode (i.e. to write unicode). Default is
255 255 True.
256 256
257 257 encoding : str, optional
258 258 The encoding to use for files opened in text mode. Default is UTF-8.
259 259
260 260 **kwargs
261 261 Passed to :func:`io.open`.
262 262 """
263 263 # realpath doesn't work on Windows: http://bugs.python.org/issue9949
264 264 # Luckily, we only need to resolve the file itself being a symlink, not
265 265 # any of its directories, so this will suffice:
266 266 if os.path.islink(path):
267 267 path = os.path.join(os.path.dirname(path), os.readlink(path))
268 268
269 269 dirname, basename = os.path.split(path)
270 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname)
270 tmp_dir = tempfile.mkdtemp(prefix=basename, dir=dirname)
271 tmp_path = os.path.join(tmp_dir, basename)
271 272 if text:
272 fileobj = io.open(handle, 'w', encoding=encoding, **kwargs)
273 fileobj = io.open(tmp_path, 'w', encoding=encoding, **kwargs)
273 274 else:
274 fileobj = io.open(handle, 'wb', **kwargs)
275 fileobj = io.open(tmp_path, 'wb', **kwargs)
275 276
276 277 try:
277 278 yield fileobj
278 279 except:
279 280 fileobj.close()
280 os.remove(tmp_path)
281 shutil.rmtree(tmp_dir)
281 282 raise
282 283
283 284 # Flush to disk
284 285 fileobj.flush()
285 286 os.fsync(fileobj.fileno())
286 287
287 288 # Written successfully, now rename it
288 289 fileobj.close()
289 290
290 291 # Copy permission bits, access time, etc.
291 292 try:
292 293 _copy_metadata(path, tmp_path)
293 294 except OSError:
294 295 # e.g. the file didn't already exist. Ignore any failure to copy metadata
295 296 pass
296 297
297 298 if os.name == 'nt' and os.path.exists(path):
298 299 # Rename over existing file doesn't work on Windows
299 300 os.remove(path)
300 301
301 302 os.rename(tmp_path, path)
303 shutil.rmtree(tmp_dir)
302 304
303 305
304 306 def raw_print(*args, **kw):
305 307 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
306 308
307 309 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
308 310 file=sys.__stdout__)
309 311 sys.__stdout__.flush()
310 312
311 313
312 314 def raw_print_err(*args, **kw):
313 315 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
314 316
315 317 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
316 318 file=sys.__stderr__)
317 319 sys.__stderr__.flush()
318 320
319 321
320 322 # Short aliases for quick debugging, do NOT use these in production code.
321 323 rprint = raw_print
322 324 rprinte = raw_print_err
323 325
324 326 def unicode_std_stream(stream='stdout'):
325 327 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
326 328
327 329 This ignores environment variables and default encodings, to reliably write
328 330 unicode to stdout or stderr.
329 331
330 332 ::
331 333
332 334 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
333 335 """
334 336 assert stream in ('stdout', 'stderr')
335 337 stream = getattr(sys, stream)
336 338 if PY3:
337 339 try:
338 340 stream_b = stream.buffer
339 341 except AttributeError:
340 342 # sys.stdout has been replaced - use it directly
341 343 return stream
342 344 else:
343 345 stream_b = stream
344 346
345 347 return codecs.getwriter('utf-8')(stream_b)
@@ -1,215 +1,231 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
10
11 #-----------------------------------------------------------------------------
12 # Imports
13 #-----------------------------------------------------------------------------
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
6
14 7 from __future__ import print_function
15 8 from __future__ import absolute_import
16 9
17 10 import io as stdlib_io
18 11 import os.path
19 12 import stat
20 13 import sys
21 14
22 15 from subprocess import Popen, PIPE
23 16 import unittest
24 17
25 18 import nose.tools as nt
26 19
27 from IPython.testing.decorators import skipif
20 from IPython.testing.decorators import skipif, skip_win32
28 21 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
29 22 atomic_writing,
30 23 )
31 24 from IPython.utils.py3compat import doctest_refactor_print, PY3
32 25 from IPython.utils.tempdir import TemporaryDirectory
33 26
34 27 if PY3:
35 28 from io import StringIO
36 29 else:
37 30 from StringIO import StringIO
38 31
39 #-----------------------------------------------------------------------------
40 # Tests
41 #-----------------------------------------------------------------------------
42
43 32
44 33 def test_tee_simple():
45 34 "Very simple check with stdout only"
46 35 chan = StringIO()
47 36 text = 'Hello'
48 37 tee = Tee(chan, channel='stdout')
49 38 print(text, file=chan)
50 39 nt.assert_equal(chan.getvalue(), text+"\n")
51 40
52 41
53 42 class TeeTestCase(unittest.TestCase):
54 43
55 44 def tchan(self, channel, check='close'):
56 45 trap = StringIO()
57 46 chan = StringIO()
58 47 text = 'Hello'
59 48
60 49 std_ori = getattr(sys, channel)
61 50 setattr(sys, channel, trap)
62 51
63 52 tee = Tee(chan, channel=channel)
64 53 print(text, end='', file=chan)
65 54 setattr(sys, channel, std_ori)
66 55 trap_val = trap.getvalue()
67 56 nt.assert_equal(chan.getvalue(), text)
68 57 if check=='close':
69 58 tee.close()
70 59 else:
71 60 del tee
72 61
73 62 def test(self):
74 63 for chan in ['stdout', 'stderr']:
75 64 for check in ['close', 'del']:
76 65 self.tchan(chan, check)
77 66
78 67 def test_io_init():
79 68 """Test that io.stdin/out/err exist at startup"""
80 69 for name in ('stdin', 'stdout', 'stderr'):
81 70 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
82 71 p = Popen([sys.executable, '-c', cmd],
83 72 stdout=PIPE)
84 73 p.wait()
85 74 classname = p.stdout.read().strip().decode('ascii')
86 75 # __class__ is a reference to the class object in Python 3, so we can't
87 76 # just test for string equality.
88 77 assert 'IPython.utils.io.IOStream' in classname, classname
89 78
90 79 def test_capture_output():
91 80 """capture_output() context works"""
92 81
93 82 with capture_output() as io:
94 83 print('hi, stdout')
95 84 print('hi, stderr', file=sys.stderr)
96 85
97 86 nt.assert_equal(io.stdout, 'hi, stdout\n')
98 87 nt.assert_equal(io.stderr, 'hi, stderr\n')
99 88
100 89 def test_UnicodeStdStream():
101 90 # Test wrapping a bytes-level stdout
102 91 if PY3:
103 92 stdoutb = stdlib_io.BytesIO()
104 93 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
105 94 else:
106 95 stdout = stdoutb = stdlib_io.BytesIO()
107 96
108 97 orig_stdout = sys.stdout
109 98 sys.stdout = stdout
110 99 try:
111 100 sample = u"@Ε‚e¢ŧ←"
112 101 unicode_std_stream().write(sample)
113 102
114 103 output = stdoutb.getvalue().decode('utf-8')
115 104 nt.assert_equal(output, sample)
116 105 assert not stdout.closed
117 106 finally:
118 107 sys.stdout = orig_stdout
119 108
120 109 @skipif(not PY3, "Not applicable on Python 2")
121 110 def test_UnicodeStdStream_nowrap():
122 111 # If we replace stdout with a StringIO, it shouldn't get wrapped.
123 112 orig_stdout = sys.stdout
124 113 sys.stdout = StringIO()
125 114 try:
126 115 nt.assert_is(unicode_std_stream(), sys.stdout)
127 116 assert not sys.stdout.closed
128 117 finally:
129 118 sys.stdout = orig_stdout
130 119
131 120 def test_atomic_writing():
132 121 class CustomExc(Exception): pass
133 122
134 123 with TemporaryDirectory() as td:
135 124 f1 = os.path.join(td, 'penguin')
136 125 with stdlib_io.open(f1, 'w') as f:
137 126 f.write(u'Before')
138 127
139 128 if os.name != 'nt':
140 129 os.chmod(f1, 0o701)
141 130 orig_mode = stat.S_IMODE(os.stat(f1).st_mode)
142 131
143 132 f2 = os.path.join(td, 'flamingo')
144 133 try:
145 134 os.symlink(f1, f2)
146 135 have_symlink = True
147 136 except (AttributeError, NotImplementedError, OSError):
148 137 # AttributeError: Python doesn't support it
149 138 # NotImplementedError: The system doesn't support it
150 139 # OSError: The user lacks the privilege (Windows)
151 140 have_symlink = False
152 141
153 142 with nt.assert_raises(CustomExc):
154 143 with atomic_writing(f1) as f:
155 144 f.write(u'Failing write')
156 145 raise CustomExc
157 146
158 147 # Because of the exception, the file should not have been modified
159 148 with stdlib_io.open(f1, 'r') as f:
160 149 nt.assert_equal(f.read(), u'Before')
161 150
162 151 with atomic_writing(f1) as f:
163 152 f.write(u'Overwritten')
164 153
165 154 with stdlib_io.open(f1, 'r') as f:
166 155 nt.assert_equal(f.read(), u'Overwritten')
167 156
168 157 if os.name != 'nt':
169 158 mode = stat.S_IMODE(os.stat(f1).st_mode)
170 159 nt.assert_equal(mode, orig_mode)
171 160
172 161 if have_symlink:
173 162 # Check that writing over a file preserves a symlink
174 163 with atomic_writing(f2) as f:
175 164 f.write(u'written from symlink')
176 165
177 166 with stdlib_io.open(f1, 'r') as f:
178 167 nt.assert_equal(f.read(), u'written from symlink')
179 168
169 def _save_umask():
170 global umask
171 umask = os.umask(0)
172 os.umask(umask)
173
174 def _restore_umask():
175 os.umask(umask)
176
177 @skip_win32
178 @nt.with_setup(_save_umask, _restore_umask)
179 def test_atomic_writing_umask():
180 with TemporaryDirectory() as td:
181 os.umask(0o022)
182 f1 = os.path.join(td, '1')
183 with atomic_writing(f1) as f:
184 f.write(u'1')
185 mode = stat.S_IMODE(os.stat(f1).st_mode)
186 nt.assert_equal(mode, 0o644, '{:o} != 644'.format(mode))
187
188 os.umask(0o057)
189 f2 = os.path.join(td, '2')
190 with atomic_writing(f2) as f:
191 f.write(u'2')
192 mode = stat.S_IMODE(os.stat(f2).st_mode)
193 nt.assert_equal(mode, 0o620, '{:o} != 620'.format(mode))
194
195
180 196 def test_atomic_writing_newlines():
181 197 with TemporaryDirectory() as td:
182 198 path = os.path.join(td, 'testfile')
183 199
184 200 lf = u'a\nb\nc\n'
185 201 plat = lf.replace(u'\n', os.linesep)
186 202 crlf = lf.replace(u'\n', u'\r\n')
187 203
188 204 # test default
189 205 with stdlib_io.open(path, 'w') as f:
190 206 f.write(lf)
191 207 with stdlib_io.open(path, 'r', newline='') as f:
192 208 read = f.read()
193 209 nt.assert_equal(read, plat)
194 210
195 211 # test newline=LF
196 212 with stdlib_io.open(path, 'w', newline='\n') as f:
197 213 f.write(lf)
198 214 with stdlib_io.open(path, 'r', newline='') as f:
199 215 read = f.read()
200 216 nt.assert_equal(read, lf)
201 217
202 218 # test newline=CRLF
203 219 with atomic_writing(path, newline='\r\n') as f:
204 220 f.write(lf)
205 221 with stdlib_io.open(path, 'r', newline='') as f:
206 222 read = f.read()
207 223 nt.assert_equal(read, crlf)
208 224
209 225 # test newline=no convert
210 226 text = u'crlf\r\ncr\rlf\n'
211 227 with atomic_writing(path, newline='') as f:
212 228 f.write(text)
213 229 with stdlib_io.open(path, 'r', newline='') as f:
214 230 read = f.read()
215 231 nt.assert_equal(read, text)
General Comments 0
You need to be logged in to leave comments. Login now