##// END OF EJS Templates
Copy file metadata in atomic save...
Thomas Kluyver -
Show More
@@ -1,319 +1,339 b''
1 1 # encoding: utf-8
2 2 """
3 3 IO related utilities.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12 from __future__ import print_function
13 13 from __future__ import absolute_import
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 import codecs
19 19 from contextlib import contextmanager
20 20 import io
21 21 import os
22 import shutil
23 import stat
22 24 import sys
23 25 import tempfile
24 26 from .capture import CapturedIO, capture_output
25 27 from .py3compat import string_types, input, PY3
26 28
27 29 #-----------------------------------------------------------------------------
28 30 # Code
29 31 #-----------------------------------------------------------------------------
30 32
31 33
32 34 class IOStream:
33 35
34 36 def __init__(self,stream, fallback=None):
35 37 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
36 38 if fallback is not None:
37 39 stream = fallback
38 40 else:
39 41 raise ValueError("fallback required, but not specified")
40 42 self.stream = stream
41 43 self._swrite = stream.write
42 44
43 45 # clone all methods not overridden:
44 46 def clone(meth):
45 47 return not hasattr(self, meth) and not meth.startswith('_')
46 48 for meth in filter(clone, dir(stream)):
47 49 setattr(self, meth, getattr(stream, meth))
48 50
49 51 def __repr__(self):
50 52 cls = self.__class__
51 53 tpl = '{mod}.{cls}({args})'
52 54 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
53 55
54 56 def write(self,data):
55 57 try:
56 58 self._swrite(data)
57 59 except:
58 60 try:
59 61 # print handles some unicode issues which may trip a plain
60 62 # write() call. Emulate write() by using an empty end
61 63 # argument.
62 64 print(data, end='', file=self.stream)
63 65 except:
64 66 # if we get here, something is seriously broken.
65 67 print('ERROR - failed to write data to stream:', self.stream,
66 68 file=sys.stderr)
67 69
68 70 def writelines(self, lines):
69 71 if isinstance(lines, string_types):
70 72 lines = [lines]
71 73 for line in lines:
72 74 self.write(line)
73 75
74 76 # This class used to have a writeln method, but regular files and streams
75 77 # in Python don't have this method. We need to keep this completely
76 78 # compatible so we removed it.
77 79
78 80 @property
79 81 def closed(self):
80 82 return self.stream.closed
81 83
82 84 def close(self):
83 85 pass
84 86
85 87 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
86 88 devnull = open(os.devnull, 'w')
87 89 stdin = IOStream(sys.stdin, fallback=devnull)
88 90 stdout = IOStream(sys.stdout, fallback=devnull)
89 91 stderr = IOStream(sys.stderr, fallback=devnull)
90 92
91 93 class IOTerm:
92 94 """ Term holds the file or file-like objects for handling I/O operations.
93 95
94 96 These are normally just sys.stdin, sys.stdout and sys.stderr but for
95 97 Windows they can can replaced to allow editing the strings before they are
96 98 displayed."""
97 99
98 100 # In the future, having IPython channel all its I/O operations through
99 101 # this class will make it easier to embed it into other environments which
100 102 # are not a normal terminal (such as a GUI-based shell)
101 103 def __init__(self, stdin=None, stdout=None, stderr=None):
102 104 mymodule = sys.modules[__name__]
103 105 self.stdin = IOStream(stdin, mymodule.stdin)
104 106 self.stdout = IOStream(stdout, mymodule.stdout)
105 107 self.stderr = IOStream(stderr, mymodule.stderr)
106 108
107 109
108 110 class Tee(object):
109 111 """A class to duplicate an output stream to stdout/err.
110 112
111 113 This works in a manner very similar to the Unix 'tee' command.
112 114
113 115 When the object is closed or deleted, it closes the original file given to
114 116 it for duplication.
115 117 """
116 118 # Inspired by:
117 119 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
118 120
119 121 def __init__(self, file_or_name, mode="w", channel='stdout'):
120 122 """Construct a new Tee object.
121 123
122 124 Parameters
123 125 ----------
124 126 file_or_name : filename or open filehandle (writable)
125 127 File that will be duplicated
126 128
127 129 mode : optional, valid mode for open().
128 130 If a filename was give, open with this mode.
129 131
130 132 channel : str, one of ['stdout', 'stderr']
131 133 """
132 134 if channel not in ['stdout', 'stderr']:
133 135 raise ValueError('Invalid channel spec %s' % channel)
134 136
135 137 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
136 138 self.file = file_or_name
137 139 else:
138 140 self.file = open(file_or_name, mode)
139 141 self.channel = channel
140 142 self.ostream = getattr(sys, channel)
141 143 setattr(sys, channel, self)
142 144 self._closed = False
143 145
144 146 def close(self):
145 147 """Close the file and restore the channel."""
146 148 self.flush()
147 149 setattr(sys, self.channel, self.ostream)
148 150 self.file.close()
149 151 self._closed = True
150 152
151 153 def write(self, data):
152 154 """Write data to both channels."""
153 155 self.file.write(data)
154 156 self.ostream.write(data)
155 157 self.ostream.flush()
156 158
157 159 def flush(self):
158 160 """Flush both channels."""
159 161 self.file.flush()
160 162 self.ostream.flush()
161 163
162 164 def __del__(self):
163 165 if not self._closed:
164 166 self.close()
165 167
166 168
167 169 def ask_yes_no(prompt, default=None, interrupt=None):
168 170 """Asks a question and returns a boolean (y/n) answer.
169 171
170 172 If default is given (one of 'y','n'), it is used if the user input is
171 173 empty. If interrupt is given (one of 'y','n'), it is used if the user
172 174 presses Ctrl-C. Otherwise the question is repeated until an answer is
173 175 given.
174 176
175 177 An EOF is treated as the default answer. If there is no default, an
176 178 exception is raised to prevent infinite loops.
177 179
178 180 Valid answers are: y/yes/n/no (match is not case sensitive)."""
179 181
180 182 answers = {'y':True,'n':False,'yes':True,'no':False}
181 183 ans = None
182 184 while ans not in answers.keys():
183 185 try:
184 186 ans = input(prompt+' ').lower()
185 187 if not ans: # response was an empty string
186 188 ans = default
187 189 except KeyboardInterrupt:
188 190 if interrupt:
189 191 ans = interrupt
190 192 except EOFError:
191 193 if default in answers.keys():
192 194 ans = default
193 195 print()
194 196 else:
195 197 raise
196 198
197 199 return answers[ans]
198 200
199 201
200 202 def temp_pyfile(src, ext='.py'):
201 203 """Make a temporary python file, return filename and filehandle.
202 204
203 205 Parameters
204 206 ----------
205 207 src : string or list of strings (no need for ending newlines if list)
206 208 Source code to be written to the file.
207 209
208 210 ext : optional, string
209 211 Extension for the generated file.
210 212
211 213 Returns
212 214 -------
213 215 (filename, open filehandle)
214 216 It is the caller's responsibility to close the open file and unlink it.
215 217 """
216 218 fname = tempfile.mkstemp(ext)[1]
217 219 f = open(fname,'w')
218 220 f.write(src)
219 221 f.flush()
220 222 return fname, f
221 223
224 def _copy_metadata(src, dst):
225 """Copy the set of metadata we want for atomic_writing.
226
227 Permission bits and flags. We'd like to copy file ownership as well, but we
228 can't do that.
229 """
230 shutil.copymode(src, dst)
231 st = os.stat(src)
232 if hasattr(os, 'chflags') and hasattr(st, 'st_flags'):
233 os.chflags(st.st_flags)
234
222 235 @contextmanager
223 236 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
224 237 """Context manager to write to a file only if the entire write is successful.
225 238
226 239 This works by creating a temporary file in the same directory, and renaming
227 240 it over the old file if the context is exited without an error. If the
228 241 target file is a symlink or a hardlink, this will not be preserved: it will
229 242 be replaced by a new regular file.
230 243
231 244 On Windows, there is a small chink in the atomicity: the target file is
232 245 deleted before renaming the temporary file over it. This appears to be
233 246 unavoidable.
234 247
235 248 Parameters
236 249 ----------
237 250 path : str
238 251 The target file to write to.
239 252
240 253 text : bool, optional
241 254 Whether to open the file in text mode (i.e. to write unicode). Default is
242 255 True.
243 256
244 257 encoding : str, optional
245 258 The encoding to use for files opened in text mode. Default is UTF-8.
246 259
247 260 **kwargs
248 261 Passed to :func:`io.open`.
249 262 """
250 263 dirname, basename = os.path.split(path)
251 264 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname, text=text)
252 265 if text:
253 266 fileobj = io.open(handle, 'w', encoding=encoding, **kwargs)
254 267 else:
255 268 fileobj = io.open(handle, 'wb', **kwargs)
256 269
257 270 try:
258 271 yield fileobj
259 272 except:
260 273 fileobj.close()
261 274 os.remove(tmp_path)
262 275 raise
263 276
264 277 # Flush to disk
265 278 fileobj.flush()
266 279 os.fsync(fileobj.fileno())
267 280
268 281 # Written successfully, now rename it
269 282 fileobj.close()
270 283
284 # Copy permission bits, access time, etc.
285 try:
286 _copy_metadata(path, tmp_path)
287 except OSError:
288 # e.g. the file didn't already exist. Ignore any failure to copy metadata
289 pass
290
271 291 if os.name == 'nt' and os.path.exists(path):
272 292 # Rename over existing file doesn't work on Windows
273 293 os.remove(path)
274 294
275 295 os.rename(tmp_path, path)
276 296
277 297
278 298 def raw_print(*args, **kw):
279 299 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
280 300
281 301 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
282 302 file=sys.__stdout__)
283 303 sys.__stdout__.flush()
284 304
285 305
286 306 def raw_print_err(*args, **kw):
287 307 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
288 308
289 309 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
290 310 file=sys.__stderr__)
291 311 sys.__stderr__.flush()
292 312
293 313
294 314 # Short aliases for quick debugging, do NOT use these in production code.
295 315 rprint = raw_print
296 316 rprinte = raw_print_err
297 317
298 318 def unicode_std_stream(stream='stdout'):
299 319 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
300 320
301 321 This ignores environment variables and default encodings, to reliably write
302 322 unicode to stdout or stderr.
303 323
304 324 ::
305 325
306 326 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
307 327 """
308 328 assert stream in ('stdout', 'stderr')
309 329 stream = getattr(sys, stream)
310 330 if PY3:
311 331 try:
312 332 stream_b = stream.buffer
313 333 except AttributeError:
314 334 # sys.stdout has been replaced - use it directly
315 335 return stream
316 336 else:
317 337 stream_b = stream
318 338
319 339 return codecs.getwriter('utf-8')(stream_b)
@@ -1,151 +1,160 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15 from __future__ import absolute_import
16 16
17 17 import io as stdlib_io
18 18 import os.path
19 import stat
19 20 import sys
20 21
21 22 from subprocess import Popen, PIPE
22 23 import unittest
23 24
24 25 import nose.tools as nt
25 26
26 27 from IPython.testing.decorators import skipif
27 28 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
28 29 atomic_writing,
29 30 )
30 31 from IPython.utils.py3compat import doctest_refactor_print, PY3
31 32 from IPython.utils.tempdir import TemporaryDirectory
32 33
33 34 if PY3:
34 35 from io import StringIO
35 36 else:
36 37 from StringIO import StringIO
37 38
38 39 #-----------------------------------------------------------------------------
39 40 # Tests
40 41 #-----------------------------------------------------------------------------
41 42
42 43
43 44 def test_tee_simple():
44 45 "Very simple check with stdout only"
45 46 chan = StringIO()
46 47 text = 'Hello'
47 48 tee = Tee(chan, channel='stdout')
48 49 print(text, file=chan)
49 50 nt.assert_equal(chan.getvalue(), text+"\n")
50 51
51 52
52 53 class TeeTestCase(unittest.TestCase):
53 54
54 55 def tchan(self, channel, check='close'):
55 56 trap = StringIO()
56 57 chan = StringIO()
57 58 text = 'Hello'
58 59
59 60 std_ori = getattr(sys, channel)
60 61 setattr(sys, channel, trap)
61 62
62 63 tee = Tee(chan, channel=channel)
63 64 print(text, end='', file=chan)
64 65 setattr(sys, channel, std_ori)
65 66 trap_val = trap.getvalue()
66 67 nt.assert_equal(chan.getvalue(), text)
67 68 if check=='close':
68 69 tee.close()
69 70 else:
70 71 del tee
71 72
72 73 def test(self):
73 74 for chan in ['stdout', 'stderr']:
74 75 for check in ['close', 'del']:
75 76 self.tchan(chan, check)
76 77
77 78 def test_io_init():
78 79 """Test that io.stdin/out/err exist at startup"""
79 80 for name in ('stdin', 'stdout', 'stderr'):
80 81 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
81 82 p = Popen([sys.executable, '-c', cmd],
82 83 stdout=PIPE)
83 84 p.wait()
84 85 classname = p.stdout.read().strip().decode('ascii')
85 86 # __class__ is a reference to the class object in Python 3, so we can't
86 87 # just test for string equality.
87 88 assert 'IPython.utils.io.IOStream' in classname, classname
88 89
89 90 def test_capture_output():
90 91 """capture_output() context works"""
91 92
92 93 with capture_output() as io:
93 94 print('hi, stdout')
94 95 print('hi, stderr', file=sys.stderr)
95 96
96 97 nt.assert_equal(io.stdout, 'hi, stdout\n')
97 98 nt.assert_equal(io.stderr, 'hi, stderr\n')
98 99
99 100 def test_UnicodeStdStream():
100 101 # Test wrapping a bytes-level stdout
101 102 if PY3:
102 103 stdoutb = stdlib_io.BytesIO()
103 104 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
104 105 else:
105 106 stdout = stdoutb = stdlib_io.BytesIO()
106 107
107 108 orig_stdout = sys.stdout
108 109 sys.stdout = stdout
109 110 try:
110 111 sample = u"@Ε‚e¢ŧ←"
111 112 unicode_std_stream().write(sample)
112 113
113 114 output = stdoutb.getvalue().decode('utf-8')
114 115 nt.assert_equal(output, sample)
115 116 assert not stdout.closed
116 117 finally:
117 118 sys.stdout = orig_stdout
118 119
119 120 @skipif(not PY3, "Not applicable on Python 2")
120 121 def test_UnicodeStdStream_nowrap():
121 122 # If we replace stdout with a StringIO, it shouldn't get wrapped.
122 123 orig_stdout = sys.stdout
123 124 sys.stdout = StringIO()
124 125 try:
125 126 nt.assert_is(unicode_std_stream(), sys.stdout)
126 127 assert not sys.stdout.closed
127 128 finally:
128 129 sys.stdout = orig_stdout
129 130
130 131 def test_atomic_writing():
131 132 class CustomExc(Exception): pass
132 133
133 134 with TemporaryDirectory() as td:
134 135 f1 = os.path.join(td, 'penguin')
135 136 with stdlib_io.open(f1, 'w') as f:
136 137 f.write(u'Before')
138
139 if os.name != 'nt':
140 os.chmod(f1, 0o701)
141 orig_mode = stat.S_IMODE(os.stat(f1).st_mode)
137 142
138 143 with nt.assert_raises(CustomExc):
139 144 with atomic_writing(f1) as f:
140 145 f.write(u'Failing write')
141 146 raise CustomExc
142 147
143 148 # Because of the exception, the file should not have been modified
144 149 with stdlib_io.open(f1, 'r') as f:
145 150 nt.assert_equal(f.read(), u'Before')
146 151
147 152 with atomic_writing(f1) as f:
148 153 f.write(u'Overwritten')
149 154
150 155 with stdlib_io.open(f1, 'r') as f:
151 156 nt.assert_equal(f.read(), u'Overwritten')
157
158 if os.name != 'nt':
159 mode = stat.S_IMODE(os.stat(f1).st_mode)
160 nt.assert_equal(mode, orig_mode)
General Comments 0
You need to be logged in to leave comments. Login now