##// END OF EJS Templates
Rework atomic_writing with tests & docstring
Thomas Kluyver -
Show More
@@ -1,532 +1,532 b''
1 """A contents manager that uses the local file system for storage."""
1 """A contents manager that uses the local file system for storage."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import base64
6 import base64
7 import io
7 import io
8 import os
8 import os
9 import glob
9 import glob
10 import shutil
10 import shutil
11
11
12 from tornado import web
12 from tornado import web
13
13
14 from .manager import ContentsManager
14 from .manager import ContentsManager
15 from IPython.nbformat import current
15 from IPython.nbformat import current
16 from IPython.utils.io import atomic_writing
16 from IPython.utils.io import atomic_writing
17 from IPython.utils.path import ensure_dir_exists
17 from IPython.utils.path import ensure_dir_exists
18 from IPython.utils.traitlets import Unicode, Bool, TraitError
18 from IPython.utils.traitlets import Unicode, Bool, TraitError
19 from IPython.utils.py3compat import getcwd
19 from IPython.utils.py3compat import getcwd
20 from IPython.utils import tz
20 from IPython.utils import tz
21 from IPython.html.utils import is_hidden, to_os_path, url_path_join
21 from IPython.html.utils import is_hidden, to_os_path, url_path_join
22
22
23
23
24 class FileContentsManager(ContentsManager):
24 class FileContentsManager(ContentsManager):
25
25
26 root_dir = Unicode(getcwd(), config=True)
26 root_dir = Unicode(getcwd(), config=True)
27
27
28 save_script = Bool(False, config=True, help='DEPRECATED, IGNORED')
28 save_script = Bool(False, config=True, help='DEPRECATED, IGNORED')
29 def _save_script_changed(self):
29 def _save_script_changed(self):
30 self.log.warn("""
30 self.log.warn("""
31 Automatically saving notebooks as scripts has been removed.
31 Automatically saving notebooks as scripts has been removed.
32 Use `ipython nbconvert --to python [notebook]` instead.
32 Use `ipython nbconvert --to python [notebook]` instead.
33 """)
33 """)
34
34
35 def _root_dir_changed(self, name, old, new):
35 def _root_dir_changed(self, name, old, new):
36 """Do a bit of validation of the root_dir."""
36 """Do a bit of validation of the root_dir."""
37 if not os.path.isabs(new):
37 if not os.path.isabs(new):
38 # If we receive a non-absolute path, make it absolute.
38 # If we receive a non-absolute path, make it absolute.
39 self.root_dir = os.path.abspath(new)
39 self.root_dir = os.path.abspath(new)
40 return
40 return
41 if not os.path.isdir(new):
41 if not os.path.isdir(new):
42 raise TraitError("%r is not a directory" % new)
42 raise TraitError("%r is not a directory" % new)
43
43
44 checkpoint_dir = Unicode('.ipynb_checkpoints', config=True,
44 checkpoint_dir = Unicode('.ipynb_checkpoints', config=True,
45 help="""The directory name in which to keep file checkpoints
45 help="""The directory name in which to keep file checkpoints
46
46
47 This is a path relative to the file's own directory.
47 This is a path relative to the file's own directory.
48
48
49 By default, it is .ipynb_checkpoints
49 By default, it is .ipynb_checkpoints
50 """
50 """
51 )
51 )
52
52
53 def _copy(self, src, dest):
53 def _copy(self, src, dest):
54 """copy src to dest
54 """copy src to dest
55
55
56 like shutil.copy2, but log errors in copystat
56 like shutil.copy2, but log errors in copystat
57 """
57 """
58 shutil.copyfile(src, dest)
58 shutil.copyfile(src, dest)
59 try:
59 try:
60 shutil.copystat(src, dest)
60 shutil.copystat(src, dest)
61 except OSError as e:
61 except OSError as e:
62 self.log.debug("copystat on %s failed", dest, exc_info=True)
62 self.log.debug("copystat on %s failed", dest, exc_info=True)
63
63
64 def _get_os_path(self, name=None, path=''):
64 def _get_os_path(self, name=None, path=''):
65 """Given a filename and API path, return its file system
65 """Given a filename and API path, return its file system
66 path.
66 path.
67
67
68 Parameters
68 Parameters
69 ----------
69 ----------
70 name : string
70 name : string
71 A filename
71 A filename
72 path : string
72 path : string
73 The relative API path to the named file.
73 The relative API path to the named file.
74
74
75 Returns
75 Returns
76 -------
76 -------
77 path : string
77 path : string
78 API path to be evaluated relative to root_dir.
78 API path to be evaluated relative to root_dir.
79 """
79 """
80 if name is not None:
80 if name is not None:
81 path = url_path_join(path, name)
81 path = url_path_join(path, name)
82 return to_os_path(path, self.root_dir)
82 return to_os_path(path, self.root_dir)
83
83
84 def path_exists(self, path):
84 def path_exists(self, path):
85 """Does the API-style path refer to an extant directory?
85 """Does the API-style path refer to an extant directory?
86
86
87 API-style wrapper for os.path.isdir
87 API-style wrapper for os.path.isdir
88
88
89 Parameters
89 Parameters
90 ----------
90 ----------
91 path : string
91 path : string
92 The path to check. This is an API path (`/` separated,
92 The path to check. This is an API path (`/` separated,
93 relative to root_dir).
93 relative to root_dir).
94
94
95 Returns
95 Returns
96 -------
96 -------
97 exists : bool
97 exists : bool
98 Whether the path is indeed a directory.
98 Whether the path is indeed a directory.
99 """
99 """
100 path = path.strip('/')
100 path = path.strip('/')
101 os_path = self._get_os_path(path=path)
101 os_path = self._get_os_path(path=path)
102 return os.path.isdir(os_path)
102 return os.path.isdir(os_path)
103
103
104 def is_hidden(self, path):
104 def is_hidden(self, path):
105 """Does the API style path correspond to a hidden directory or file?
105 """Does the API style path correspond to a hidden directory or file?
106
106
107 Parameters
107 Parameters
108 ----------
108 ----------
109 path : string
109 path : string
110 The path to check. This is an API path (`/` separated,
110 The path to check. This is an API path (`/` separated,
111 relative to root_dir).
111 relative to root_dir).
112
112
113 Returns
113 Returns
114 -------
114 -------
115 exists : bool
115 exists : bool
116 Whether the path is hidden.
116 Whether the path is hidden.
117
117
118 """
118 """
119 path = path.strip('/')
119 path = path.strip('/')
120 os_path = self._get_os_path(path=path)
120 os_path = self._get_os_path(path=path)
121 return is_hidden(os_path, self.root_dir)
121 return is_hidden(os_path, self.root_dir)
122
122
123 def file_exists(self, name, path=''):
123 def file_exists(self, name, path=''):
124 """Returns True if the file exists, else returns False.
124 """Returns True if the file exists, else returns False.
125
125
126 API-style wrapper for os.path.isfile
126 API-style wrapper for os.path.isfile
127
127
128 Parameters
128 Parameters
129 ----------
129 ----------
130 name : string
130 name : string
131 The name of the file you are checking.
131 The name of the file you are checking.
132 path : string
132 path : string
133 The relative path to the file's directory (with '/' as separator)
133 The relative path to the file's directory (with '/' as separator)
134
134
135 Returns
135 Returns
136 -------
136 -------
137 exists : bool
137 exists : bool
138 Whether the file exists.
138 Whether the file exists.
139 """
139 """
140 path = path.strip('/')
140 path = path.strip('/')
141 nbpath = self._get_os_path(name, path=path)
141 nbpath = self._get_os_path(name, path=path)
142 return os.path.isfile(nbpath)
142 return os.path.isfile(nbpath)
143
143
144 def exists(self, name=None, path=''):
144 def exists(self, name=None, path=''):
145 """Returns True if the path [and name] exists, else returns False.
145 """Returns True if the path [and name] exists, else returns False.
146
146
147 API-style wrapper for os.path.exists
147 API-style wrapper for os.path.exists
148
148
149 Parameters
149 Parameters
150 ----------
150 ----------
151 name : string
151 name : string
152 The name of the file you are checking.
152 The name of the file you are checking.
153 path : string
153 path : string
154 The relative path to the file's directory (with '/' as separator)
154 The relative path to the file's directory (with '/' as separator)
155
155
156 Returns
156 Returns
157 -------
157 -------
158 exists : bool
158 exists : bool
159 Whether the target exists.
159 Whether the target exists.
160 """
160 """
161 path = path.strip('/')
161 path = path.strip('/')
162 os_path = self._get_os_path(name, path=path)
162 os_path = self._get_os_path(name, path=path)
163 return os.path.exists(os_path)
163 return os.path.exists(os_path)
164
164
165 def _base_model(self, name, path=''):
165 def _base_model(self, name, path=''):
166 """Build the common base of a contents model"""
166 """Build the common base of a contents model"""
167 os_path = self._get_os_path(name, path)
167 os_path = self._get_os_path(name, path)
168 info = os.stat(os_path)
168 info = os.stat(os_path)
169 last_modified = tz.utcfromtimestamp(info.st_mtime)
169 last_modified = tz.utcfromtimestamp(info.st_mtime)
170 created = tz.utcfromtimestamp(info.st_ctime)
170 created = tz.utcfromtimestamp(info.st_ctime)
171 # Create the base model.
171 # Create the base model.
172 model = {}
172 model = {}
173 model['name'] = name
173 model['name'] = name
174 model['path'] = path
174 model['path'] = path
175 model['last_modified'] = last_modified
175 model['last_modified'] = last_modified
176 model['created'] = created
176 model['created'] = created
177 model['content'] = None
177 model['content'] = None
178 model['format'] = None
178 model['format'] = None
179 return model
179 return model
180
180
181 def _dir_model(self, name, path='', content=True):
181 def _dir_model(self, name, path='', content=True):
182 """Build a model for a directory
182 """Build a model for a directory
183
183
184 if content is requested, will include a listing of the directory
184 if content is requested, will include a listing of the directory
185 """
185 """
186 os_path = self._get_os_path(name, path)
186 os_path = self._get_os_path(name, path)
187
187
188 four_o_four = u'directory does not exist: %r' % os_path
188 four_o_four = u'directory does not exist: %r' % os_path
189
189
190 if not os.path.isdir(os_path):
190 if not os.path.isdir(os_path):
191 raise web.HTTPError(404, four_o_four)
191 raise web.HTTPError(404, four_o_four)
192 elif is_hidden(os_path, self.root_dir):
192 elif is_hidden(os_path, self.root_dir):
193 self.log.info("Refusing to serve hidden directory %r, via 404 Error",
193 self.log.info("Refusing to serve hidden directory %r, via 404 Error",
194 os_path
194 os_path
195 )
195 )
196 raise web.HTTPError(404, four_o_four)
196 raise web.HTTPError(404, four_o_four)
197
197
198 if name is None:
198 if name is None:
199 if '/' in path:
199 if '/' in path:
200 path, name = path.rsplit('/', 1)
200 path, name = path.rsplit('/', 1)
201 else:
201 else:
202 name = ''
202 name = ''
203 model = self._base_model(name, path)
203 model = self._base_model(name, path)
204 model['type'] = 'directory'
204 model['type'] = 'directory'
205 dir_path = u'{}/{}'.format(path, name)
205 dir_path = u'{}/{}'.format(path, name)
206 if content:
206 if content:
207 model['content'] = contents = []
207 model['content'] = contents = []
208 for os_path in glob.glob(self._get_os_path('*', dir_path)):
208 for os_path in glob.glob(self._get_os_path('*', dir_path)):
209 name = os.path.basename(os_path)
209 name = os.path.basename(os_path)
210 if self.should_list(name) and not is_hidden(os_path, self.root_dir):
210 if self.should_list(name) and not is_hidden(os_path, self.root_dir):
211 contents.append(self.get_model(name=name, path=dir_path, content=False))
211 contents.append(self.get_model(name=name, path=dir_path, content=False))
212
212
213 model['format'] = 'json'
213 model['format'] = 'json'
214
214
215 return model
215 return model
216
216
217 def _file_model(self, name, path='', content=True):
217 def _file_model(self, name, path='', content=True):
218 """Build a model for a file
218 """Build a model for a file
219
219
220 if content is requested, include the file contents.
220 if content is requested, include the file contents.
221 UTF-8 text files will be unicode, binary files will be base64-encoded.
221 UTF-8 text files will be unicode, binary files will be base64-encoded.
222 """
222 """
223 model = self._base_model(name, path)
223 model = self._base_model(name, path)
224 model['type'] = 'file'
224 model['type'] = 'file'
225 if content:
225 if content:
226 os_path = self._get_os_path(name, path)
226 os_path = self._get_os_path(name, path)
227 with io.open(os_path, 'rb') as f:
227 with io.open(os_path, 'rb') as f:
228 bcontent = f.read()
228 bcontent = f.read()
229 try:
229 try:
230 model['content'] = bcontent.decode('utf8')
230 model['content'] = bcontent.decode('utf8')
231 except UnicodeError as e:
231 except UnicodeError as e:
232 model['content'] = base64.encodestring(bcontent).decode('ascii')
232 model['content'] = base64.encodestring(bcontent).decode('ascii')
233 model['format'] = 'base64'
233 model['format'] = 'base64'
234 else:
234 else:
235 model['format'] = 'text'
235 model['format'] = 'text'
236 return model
236 return model
237
237
238
238
239 def _notebook_model(self, name, path='', content=True):
239 def _notebook_model(self, name, path='', content=True):
240 """Build a notebook model
240 """Build a notebook model
241
241
242 if content is requested, the notebook content will be populated
242 if content is requested, the notebook content will be populated
243 as a JSON structure (not double-serialized)
243 as a JSON structure (not double-serialized)
244 """
244 """
245 model = self._base_model(name, path)
245 model = self._base_model(name, path)
246 model['type'] = 'notebook'
246 model['type'] = 'notebook'
247 if content:
247 if content:
248 os_path = self._get_os_path(name, path)
248 os_path = self._get_os_path(name, path)
249 with io.open(os_path, 'r', encoding='utf-8') as f:
249 with io.open(os_path, 'r', encoding='utf-8') as f:
250 try:
250 try:
251 nb = current.read(f, u'json')
251 nb = current.read(f, u'json')
252 except Exception as e:
252 except Exception as e:
253 raise web.HTTPError(400, u"Unreadable Notebook: %s %s" % (os_path, e))
253 raise web.HTTPError(400, u"Unreadable Notebook: %s %s" % (os_path, e))
254 self.mark_trusted_cells(nb, name, path)
254 self.mark_trusted_cells(nb, name, path)
255 model['content'] = nb
255 model['content'] = nb
256 model['format'] = 'json'
256 model['format'] = 'json'
257 return model
257 return model
258
258
259 def get_model(self, name, path='', content=True):
259 def get_model(self, name, path='', content=True):
260 """ Takes a path and name for an entity and returns its model
260 """ Takes a path and name for an entity and returns its model
261
261
262 Parameters
262 Parameters
263 ----------
263 ----------
264 name : str
264 name : str
265 the name of the target
265 the name of the target
266 path : str
266 path : str
267 the API path that describes the relative path for the target
267 the API path that describes the relative path for the target
268
268
269 Returns
269 Returns
270 -------
270 -------
271 model : dict
271 model : dict
272 the contents model. If content=True, returns the contents
272 the contents model. If content=True, returns the contents
273 of the file or directory as well.
273 of the file or directory as well.
274 """
274 """
275 path = path.strip('/')
275 path = path.strip('/')
276
276
277 if not self.exists(name=name, path=path):
277 if not self.exists(name=name, path=path):
278 raise web.HTTPError(404, u'No such file or directory: %s/%s' % (path, name))
278 raise web.HTTPError(404, u'No such file or directory: %s/%s' % (path, name))
279
279
280 os_path = self._get_os_path(name, path)
280 os_path = self._get_os_path(name, path)
281 if os.path.isdir(os_path):
281 if os.path.isdir(os_path):
282 model = self._dir_model(name, path, content)
282 model = self._dir_model(name, path, content)
283 elif name.endswith('.ipynb'):
283 elif name.endswith('.ipynb'):
284 model = self._notebook_model(name, path, content)
284 model = self._notebook_model(name, path, content)
285 else:
285 else:
286 model = self._file_model(name, path, content)
286 model = self._file_model(name, path, content)
287 return model
287 return model
288
288
289 def _save_notebook(self, os_path, model, name='', path=''):
289 def _save_notebook(self, os_path, model, name='', path=''):
290 """save a notebook file"""
290 """save a notebook file"""
291 # Save the notebook file
291 # Save the notebook file
292 nb = current.to_notebook_json(model['content'])
292 nb = current.to_notebook_json(model['content'])
293
293
294 self.check_and_sign(nb, name, path)
294 self.check_and_sign(nb, name, path)
295
295
296 if 'name' in nb['metadata']:
296 if 'name' in nb['metadata']:
297 nb['metadata']['name'] = u''
297 nb['metadata']['name'] = u''
298
298
299 with atomic_writing(os_path, encoding='utf-8') as f:
299 with atomic_writing(os_path, encoding='utf-8') as f:
300 current.write(nb, f, u'json')
300 current.write(nb, f, u'json')
301
301
302 def _save_file(self, os_path, model, name='', path=''):
302 def _save_file(self, os_path, model, name='', path=''):
303 """save a non-notebook file"""
303 """save a non-notebook file"""
304 fmt = model.get('format', None)
304 fmt = model.get('format', None)
305 if fmt not in {'text', 'base64'}:
305 if fmt not in {'text', 'base64'}:
306 raise web.HTTPError(400, "Must specify format of file contents as 'text' or 'base64'")
306 raise web.HTTPError(400, "Must specify format of file contents as 'text' or 'base64'")
307 try:
307 try:
308 content = model['content']
308 content = model['content']
309 if fmt == 'text':
309 if fmt == 'text':
310 bcontent = content.encode('utf8')
310 bcontent = content.encode('utf8')
311 else:
311 else:
312 b64_bytes = content.encode('ascii')
312 b64_bytes = content.encode('ascii')
313 bcontent = base64.decodestring(b64_bytes)
313 bcontent = base64.decodestring(b64_bytes)
314 except Exception as e:
314 except Exception as e:
315 raise web.HTTPError(400, u'Encoding error saving %s: %s' % (os_path, e))
315 raise web.HTTPError(400, u'Encoding error saving %s: %s' % (os_path, e))
316 with atomic_writing(os_path, 'wb') as f:
316 with atomic_writing(os_path, text=False) as f:
317 f.write(bcontent)
317 f.write(bcontent)
318
318
319 def _save_directory(self, os_path, model, name='', path=''):
319 def _save_directory(self, os_path, model, name='', path=''):
320 """create a directory"""
320 """create a directory"""
321 if is_hidden(os_path, self.root_dir):
321 if is_hidden(os_path, self.root_dir):
322 raise web.HTTPError(400, u'Cannot create hidden directory %r' % os_path)
322 raise web.HTTPError(400, u'Cannot create hidden directory %r' % os_path)
323 if not os.path.exists(os_path):
323 if not os.path.exists(os_path):
324 os.mkdir(os_path)
324 os.mkdir(os_path)
325 elif not os.path.isdir(os_path):
325 elif not os.path.isdir(os_path):
326 raise web.HTTPError(400, u'Not a directory: %s' % (os_path))
326 raise web.HTTPError(400, u'Not a directory: %s' % (os_path))
327 else:
327 else:
328 self.log.debug("Directory %r already exists", os_path)
328 self.log.debug("Directory %r already exists", os_path)
329
329
330 def save(self, model, name='', path=''):
330 def save(self, model, name='', path=''):
331 """Save the file model and return the model with no content."""
331 """Save the file model and return the model with no content."""
332 path = path.strip('/')
332 path = path.strip('/')
333
333
334 if 'type' not in model:
334 if 'type' not in model:
335 raise web.HTTPError(400, u'No file type provided')
335 raise web.HTTPError(400, u'No file type provided')
336 if 'content' not in model and model['type'] != 'directory':
336 if 'content' not in model and model['type'] != 'directory':
337 raise web.HTTPError(400, u'No file content provided')
337 raise web.HTTPError(400, u'No file content provided')
338
338
339 # One checkpoint should always exist
339 # One checkpoint should always exist
340 if self.file_exists(name, path) and not self.list_checkpoints(name, path):
340 if self.file_exists(name, path) and not self.list_checkpoints(name, path):
341 self.create_checkpoint(name, path)
341 self.create_checkpoint(name, path)
342
342
343 new_path = model.get('path', path).strip('/')
343 new_path = model.get('path', path).strip('/')
344 new_name = model.get('name', name)
344 new_name = model.get('name', name)
345
345
346 if path != new_path or name != new_name:
346 if path != new_path or name != new_name:
347 self.rename(name, path, new_name, new_path)
347 self.rename(name, path, new_name, new_path)
348
348
349 os_path = self._get_os_path(new_name, new_path)
349 os_path = self._get_os_path(new_name, new_path)
350 self.log.debug("Saving %s", os_path)
350 self.log.debug("Saving %s", os_path)
351 try:
351 try:
352 if model['type'] == 'notebook':
352 if model['type'] == 'notebook':
353 self._save_notebook(os_path, model, new_name, new_path)
353 self._save_notebook(os_path, model, new_name, new_path)
354 elif model['type'] == 'file':
354 elif model['type'] == 'file':
355 self._save_file(os_path, model, new_name, new_path)
355 self._save_file(os_path, model, new_name, new_path)
356 elif model['type'] == 'directory':
356 elif model['type'] == 'directory':
357 self._save_directory(os_path, model, new_name, new_path)
357 self._save_directory(os_path, model, new_name, new_path)
358 else:
358 else:
359 raise web.HTTPError(400, "Unhandled contents type: %s" % model['type'])
359 raise web.HTTPError(400, "Unhandled contents type: %s" % model['type'])
360 except web.HTTPError:
360 except web.HTTPError:
361 raise
361 raise
362 except Exception as e:
362 except Exception as e:
363 raise web.HTTPError(400, u'Unexpected error while saving file: %s %s' % (os_path, e))
363 raise web.HTTPError(400, u'Unexpected error while saving file: %s %s' % (os_path, e))
364
364
365 model = self.get_model(new_name, new_path, content=False)
365 model = self.get_model(new_name, new_path, content=False)
366 return model
366 return model
367
367
368 def update(self, model, name, path=''):
368 def update(self, model, name, path=''):
369 """Update the file's path and/or name
369 """Update the file's path and/or name
370
370
371 For use in PATCH requests, to enable renaming a file without
371 For use in PATCH requests, to enable renaming a file without
372 re-uploading its contents. Only used for renaming at the moment.
372 re-uploading its contents. Only used for renaming at the moment.
373 """
373 """
374 path = path.strip('/')
374 path = path.strip('/')
375 new_name = model.get('name', name)
375 new_name = model.get('name', name)
376 new_path = model.get('path', path).strip('/')
376 new_path = model.get('path', path).strip('/')
377 if path != new_path or name != new_name:
377 if path != new_path or name != new_name:
378 self.rename(name, path, new_name, new_path)
378 self.rename(name, path, new_name, new_path)
379 model = self.get_model(new_name, new_path, content=False)
379 model = self.get_model(new_name, new_path, content=False)
380 return model
380 return model
381
381
382 def delete(self, name, path=''):
382 def delete(self, name, path=''):
383 """Delete file by name and path."""
383 """Delete file by name and path."""
384 path = path.strip('/')
384 path = path.strip('/')
385 os_path = self._get_os_path(name, path)
385 os_path = self._get_os_path(name, path)
386 rm = os.unlink
386 rm = os.unlink
387 if os.path.isdir(os_path):
387 if os.path.isdir(os_path):
388 listing = os.listdir(os_path)
388 listing = os.listdir(os_path)
389 # don't delete non-empty directories (checkpoints dir doesn't count)
389 # don't delete non-empty directories (checkpoints dir doesn't count)
390 if listing and listing != [self.checkpoint_dir]:
390 if listing and listing != [self.checkpoint_dir]:
391 raise web.HTTPError(400, u'Directory %s not empty' % os_path)
391 raise web.HTTPError(400, u'Directory %s not empty' % os_path)
392 elif not os.path.isfile(os_path):
392 elif not os.path.isfile(os_path):
393 raise web.HTTPError(404, u'File does not exist: %s' % os_path)
393 raise web.HTTPError(404, u'File does not exist: %s' % os_path)
394
394
395 # clear checkpoints
395 # clear checkpoints
396 for checkpoint in self.list_checkpoints(name, path):
396 for checkpoint in self.list_checkpoints(name, path):
397 checkpoint_id = checkpoint['id']
397 checkpoint_id = checkpoint['id']
398 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
398 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
399 if os.path.isfile(cp_path):
399 if os.path.isfile(cp_path):
400 self.log.debug("Unlinking checkpoint %s", cp_path)
400 self.log.debug("Unlinking checkpoint %s", cp_path)
401 os.unlink(cp_path)
401 os.unlink(cp_path)
402
402
403 if os.path.isdir(os_path):
403 if os.path.isdir(os_path):
404 self.log.debug("Removing directory %s", os_path)
404 self.log.debug("Removing directory %s", os_path)
405 shutil.rmtree(os_path)
405 shutil.rmtree(os_path)
406 else:
406 else:
407 self.log.debug("Unlinking file %s", os_path)
407 self.log.debug("Unlinking file %s", os_path)
408 rm(os_path)
408 rm(os_path)
409
409
410 def rename(self, old_name, old_path, new_name, new_path):
410 def rename(self, old_name, old_path, new_name, new_path):
411 """Rename a file."""
411 """Rename a file."""
412 old_path = old_path.strip('/')
412 old_path = old_path.strip('/')
413 new_path = new_path.strip('/')
413 new_path = new_path.strip('/')
414 if new_name == old_name and new_path == old_path:
414 if new_name == old_name and new_path == old_path:
415 return
415 return
416
416
417 new_os_path = self._get_os_path(new_name, new_path)
417 new_os_path = self._get_os_path(new_name, new_path)
418 old_os_path = self._get_os_path(old_name, old_path)
418 old_os_path = self._get_os_path(old_name, old_path)
419
419
420 # Should we proceed with the move?
420 # Should we proceed with the move?
421 if os.path.isfile(new_os_path):
421 if os.path.isfile(new_os_path):
422 raise web.HTTPError(409, u'File with name already exists: %s' % new_os_path)
422 raise web.HTTPError(409, u'File with name already exists: %s' % new_os_path)
423
423
424 # Move the file
424 # Move the file
425 try:
425 try:
426 shutil.move(old_os_path, new_os_path)
426 shutil.move(old_os_path, new_os_path)
427 except Exception as e:
427 except Exception as e:
428 raise web.HTTPError(500, u'Unknown error renaming file: %s %s' % (old_os_path, e))
428 raise web.HTTPError(500, u'Unknown error renaming file: %s %s' % (old_os_path, e))
429
429
430 # Move the checkpoints
430 # Move the checkpoints
431 old_checkpoints = self.list_checkpoints(old_name, old_path)
431 old_checkpoints = self.list_checkpoints(old_name, old_path)
432 for cp in old_checkpoints:
432 for cp in old_checkpoints:
433 checkpoint_id = cp['id']
433 checkpoint_id = cp['id']
434 old_cp_path = self.get_checkpoint_path(checkpoint_id, old_name, old_path)
434 old_cp_path = self.get_checkpoint_path(checkpoint_id, old_name, old_path)
435 new_cp_path = self.get_checkpoint_path(checkpoint_id, new_name, new_path)
435 new_cp_path = self.get_checkpoint_path(checkpoint_id, new_name, new_path)
436 if os.path.isfile(old_cp_path):
436 if os.path.isfile(old_cp_path):
437 self.log.debug("Renaming checkpoint %s -> %s", old_cp_path, new_cp_path)
437 self.log.debug("Renaming checkpoint %s -> %s", old_cp_path, new_cp_path)
438 shutil.move(old_cp_path, new_cp_path)
438 shutil.move(old_cp_path, new_cp_path)
439
439
440 # Checkpoint-related utilities
440 # Checkpoint-related utilities
441
441
442 def get_checkpoint_path(self, checkpoint_id, name, path=''):
442 def get_checkpoint_path(self, checkpoint_id, name, path=''):
443 """find the path to a checkpoint"""
443 """find the path to a checkpoint"""
444 path = path.strip('/')
444 path = path.strip('/')
445 basename, ext = os.path.splitext(name)
445 basename, ext = os.path.splitext(name)
446 filename = u"{name}-{checkpoint_id}{ext}".format(
446 filename = u"{name}-{checkpoint_id}{ext}".format(
447 name=basename,
447 name=basename,
448 checkpoint_id=checkpoint_id,
448 checkpoint_id=checkpoint_id,
449 ext=ext,
449 ext=ext,
450 )
450 )
451 os_path = self._get_os_path(path=path)
451 os_path = self._get_os_path(path=path)
452 cp_dir = os.path.join(os_path, self.checkpoint_dir)
452 cp_dir = os.path.join(os_path, self.checkpoint_dir)
453 ensure_dir_exists(cp_dir)
453 ensure_dir_exists(cp_dir)
454 cp_path = os.path.join(cp_dir, filename)
454 cp_path = os.path.join(cp_dir, filename)
455 return cp_path
455 return cp_path
456
456
457 def get_checkpoint_model(self, checkpoint_id, name, path=''):
457 def get_checkpoint_model(self, checkpoint_id, name, path=''):
458 """construct the info dict for a given checkpoint"""
458 """construct the info dict for a given checkpoint"""
459 path = path.strip('/')
459 path = path.strip('/')
460 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
460 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
461 stats = os.stat(cp_path)
461 stats = os.stat(cp_path)
462 last_modified = tz.utcfromtimestamp(stats.st_mtime)
462 last_modified = tz.utcfromtimestamp(stats.st_mtime)
463 info = dict(
463 info = dict(
464 id = checkpoint_id,
464 id = checkpoint_id,
465 last_modified = last_modified,
465 last_modified = last_modified,
466 )
466 )
467 return info
467 return info
468
468
469 # public checkpoint API
469 # public checkpoint API
470
470
471 def create_checkpoint(self, name, path=''):
471 def create_checkpoint(self, name, path=''):
472 """Create a checkpoint from the current state of a file"""
472 """Create a checkpoint from the current state of a file"""
473 path = path.strip('/')
473 path = path.strip('/')
474 src_path = self._get_os_path(name, path)
474 src_path = self._get_os_path(name, path)
475 # only the one checkpoint ID:
475 # only the one checkpoint ID:
476 checkpoint_id = u"checkpoint"
476 checkpoint_id = u"checkpoint"
477 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
477 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
478 self.log.debug("creating checkpoint for %s", name)
478 self.log.debug("creating checkpoint for %s", name)
479 self._copy(src_path, cp_path)
479 self._copy(src_path, cp_path)
480
480
481 # return the checkpoint info
481 # return the checkpoint info
482 return self.get_checkpoint_model(checkpoint_id, name, path)
482 return self.get_checkpoint_model(checkpoint_id, name, path)
483
483
484 def list_checkpoints(self, name, path=''):
484 def list_checkpoints(self, name, path=''):
485 """list the checkpoints for a given file
485 """list the checkpoints for a given file
486
486
487 This contents manager currently only supports one checkpoint per file.
487 This contents manager currently only supports one checkpoint per file.
488 """
488 """
489 path = path.strip('/')
489 path = path.strip('/')
490 checkpoint_id = "checkpoint"
490 checkpoint_id = "checkpoint"
491 os_path = self.get_checkpoint_path(checkpoint_id, name, path)
491 os_path = self.get_checkpoint_path(checkpoint_id, name, path)
492 if not os.path.exists(os_path):
492 if not os.path.exists(os_path):
493 return []
493 return []
494 else:
494 else:
495 return [self.get_checkpoint_model(checkpoint_id, name, path)]
495 return [self.get_checkpoint_model(checkpoint_id, name, path)]
496
496
497
497
498 def restore_checkpoint(self, checkpoint_id, name, path=''):
498 def restore_checkpoint(self, checkpoint_id, name, path=''):
499 """restore a file to a checkpointed state"""
499 """restore a file to a checkpointed state"""
500 path = path.strip('/')
500 path = path.strip('/')
501 self.log.info("restoring %s from checkpoint %s", name, checkpoint_id)
501 self.log.info("restoring %s from checkpoint %s", name, checkpoint_id)
502 nb_path = self._get_os_path(name, path)
502 nb_path = self._get_os_path(name, path)
503 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
503 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
504 if not os.path.isfile(cp_path):
504 if not os.path.isfile(cp_path):
505 self.log.debug("checkpoint file does not exist: %s", cp_path)
505 self.log.debug("checkpoint file does not exist: %s", cp_path)
506 raise web.HTTPError(404,
506 raise web.HTTPError(404,
507 u'checkpoint does not exist: %s-%s' % (name, checkpoint_id)
507 u'checkpoint does not exist: %s-%s' % (name, checkpoint_id)
508 )
508 )
509 # ensure notebook is readable (never restore from an unreadable notebook)
509 # ensure notebook is readable (never restore from an unreadable notebook)
510 if cp_path.endswith('.ipynb'):
510 if cp_path.endswith('.ipynb'):
511 with io.open(cp_path, 'r', encoding='utf-8') as f:
511 with io.open(cp_path, 'r', encoding='utf-8') as f:
512 current.read(f, u'json')
512 current.read(f, u'json')
513 self._copy(cp_path, nb_path)
513 self._copy(cp_path, nb_path)
514 self.log.debug("copying %s -> %s", cp_path, nb_path)
514 self.log.debug("copying %s -> %s", cp_path, nb_path)
515
515
516 def delete_checkpoint(self, checkpoint_id, name, path=''):
516 def delete_checkpoint(self, checkpoint_id, name, path=''):
517 """delete a file's checkpoint"""
517 """delete a file's checkpoint"""
518 path = path.strip('/')
518 path = path.strip('/')
519 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
519 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
520 if not os.path.isfile(cp_path):
520 if not os.path.isfile(cp_path):
521 raise web.HTTPError(404,
521 raise web.HTTPError(404,
522 u'Checkpoint does not exist: %s%s-%s' % (path, name, checkpoint_id)
522 u'Checkpoint does not exist: %s%s-%s' % (path, name, checkpoint_id)
523 )
523 )
524 self.log.debug("unlinking %s", cp_path)
524 self.log.debug("unlinking %s", cp_path)
525 os.unlink(cp_path)
525 os.unlink(cp_path)
526
526
527 def info_string(self):
527 def info_string(self):
528 return "Serving notebooks from local directory: %s" % self.root_dir
528 return "Serving notebooks from local directory: %s" % self.root_dir
529
529
530 def get_kernel_path(self, name, path='', model=None):
530 def get_kernel_path(self, name, path='', model=None):
531 """Return the initial working dir a kernel associated with a given notebook"""
531 """Return the initial working dir a kernel associated with a given notebook"""
532 return os.path.join(self.root_dir, path)
532 return os.path.join(self.root_dir, path)
@@ -1,280 +1,315 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 IO related utilities.
3 IO related utilities.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 from __future__ import print_function
12 from __future__ import print_function
13 from __future__ import absolute_import
13 from __future__ import absolute_import
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 import codecs
18 import codecs
19 from contextlib import contextmanager
19 from contextlib import contextmanager
20 import io
20 import os
21 import os
21 import sys
22 import sys
22 import tempfile
23 import tempfile
23 from .capture import CapturedIO, capture_output
24 from .capture import CapturedIO, capture_output
24 from .py3compat import string_types, input, PY3
25 from .py3compat import string_types, input, PY3
25
26
26 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
27 # Code
28 # Code
28 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
29
30
30
31
31 class IOStream:
32 class IOStream:
32
33
33 def __init__(self,stream, fallback=None):
34 def __init__(self,stream, fallback=None):
34 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
35 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
35 if fallback is not None:
36 if fallback is not None:
36 stream = fallback
37 stream = fallback
37 else:
38 else:
38 raise ValueError("fallback required, but not specified")
39 raise ValueError("fallback required, but not specified")
39 self.stream = stream
40 self.stream = stream
40 self._swrite = stream.write
41 self._swrite = stream.write
41
42
42 # clone all methods not overridden:
43 # clone all methods not overridden:
43 def clone(meth):
44 def clone(meth):
44 return not hasattr(self, meth) and not meth.startswith('_')
45 return not hasattr(self, meth) and not meth.startswith('_')
45 for meth in filter(clone, dir(stream)):
46 for meth in filter(clone, dir(stream)):
46 setattr(self, meth, getattr(stream, meth))
47 setattr(self, meth, getattr(stream, meth))
47
48
48 def __repr__(self):
49 def __repr__(self):
49 cls = self.__class__
50 cls = self.__class__
50 tpl = '{mod}.{cls}({args})'
51 tpl = '{mod}.{cls}({args})'
51 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
52 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
52
53
53 def write(self,data):
54 def write(self,data):
54 try:
55 try:
55 self._swrite(data)
56 self._swrite(data)
56 except:
57 except:
57 try:
58 try:
58 # print handles some unicode issues which may trip a plain
59 # print handles some unicode issues which may trip a plain
59 # write() call. Emulate write() by using an empty end
60 # write() call. Emulate write() by using an empty end
60 # argument.
61 # argument.
61 print(data, end='', file=self.stream)
62 print(data, end='', file=self.stream)
62 except:
63 except:
63 # if we get here, something is seriously broken.
64 # if we get here, something is seriously broken.
64 print('ERROR - failed to write data to stream:', self.stream,
65 print('ERROR - failed to write data to stream:', self.stream,
65 file=sys.stderr)
66 file=sys.stderr)
66
67
67 def writelines(self, lines):
68 def writelines(self, lines):
68 if isinstance(lines, string_types):
69 if isinstance(lines, string_types):
69 lines = [lines]
70 lines = [lines]
70 for line in lines:
71 for line in lines:
71 self.write(line)
72 self.write(line)
72
73
73 # This class used to have a writeln method, but regular files and streams
74 # This class used to have a writeln method, but regular files and streams
74 # in Python don't have this method. We need to keep this completely
75 # in Python don't have this method. We need to keep this completely
75 # compatible so we removed it.
76 # compatible so we removed it.
76
77
77 @property
78 @property
78 def closed(self):
79 def closed(self):
79 return self.stream.closed
80 return self.stream.closed
80
81
81 def close(self):
82 def close(self):
82 pass
83 pass
83
84
84 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
85 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
85 devnull = open(os.devnull, 'w')
86 devnull = open(os.devnull, 'w')
86 stdin = IOStream(sys.stdin, fallback=devnull)
87 stdin = IOStream(sys.stdin, fallback=devnull)
87 stdout = IOStream(sys.stdout, fallback=devnull)
88 stdout = IOStream(sys.stdout, fallback=devnull)
88 stderr = IOStream(sys.stderr, fallback=devnull)
89 stderr = IOStream(sys.stderr, fallback=devnull)
89
90
90 class IOTerm:
91 class IOTerm:
91 """ Term holds the file or file-like objects for handling I/O operations.
92 """ Term holds the file or file-like objects for handling I/O operations.
92
93
93 These are normally just sys.stdin, sys.stdout and sys.stderr but for
94 These are normally just sys.stdin, sys.stdout and sys.stderr but for
94 Windows they can can replaced to allow editing the strings before they are
95 Windows they can can replaced to allow editing the strings before they are
95 displayed."""
96 displayed."""
96
97
97 # In the future, having IPython channel all its I/O operations through
98 # In the future, having IPython channel all its I/O operations through
98 # this class will make it easier to embed it into other environments which
99 # this class will make it easier to embed it into other environments which
99 # are not a normal terminal (such as a GUI-based shell)
100 # are not a normal terminal (such as a GUI-based shell)
100 def __init__(self, stdin=None, stdout=None, stderr=None):
101 def __init__(self, stdin=None, stdout=None, stderr=None):
101 mymodule = sys.modules[__name__]
102 mymodule = sys.modules[__name__]
102 self.stdin = IOStream(stdin, mymodule.stdin)
103 self.stdin = IOStream(stdin, mymodule.stdin)
103 self.stdout = IOStream(stdout, mymodule.stdout)
104 self.stdout = IOStream(stdout, mymodule.stdout)
104 self.stderr = IOStream(stderr, mymodule.stderr)
105 self.stderr = IOStream(stderr, mymodule.stderr)
105
106
106
107
107 class Tee(object):
108 class Tee(object):
108 """A class to duplicate an output stream to stdout/err.
109 """A class to duplicate an output stream to stdout/err.
109
110
110 This works in a manner very similar to the Unix 'tee' command.
111 This works in a manner very similar to the Unix 'tee' command.
111
112
112 When the object is closed or deleted, it closes the original file given to
113 When the object is closed or deleted, it closes the original file given to
113 it for duplication.
114 it for duplication.
114 """
115 """
115 # Inspired by:
116 # Inspired by:
116 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
117 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
117
118
118 def __init__(self, file_or_name, mode="w", channel='stdout'):
119 def __init__(self, file_or_name, mode="w", channel='stdout'):
119 """Construct a new Tee object.
120 """Construct a new Tee object.
120
121
121 Parameters
122 Parameters
122 ----------
123 ----------
123 file_or_name : filename or open filehandle (writable)
124 file_or_name : filename or open filehandle (writable)
124 File that will be duplicated
125 File that will be duplicated
125
126
126 mode : optional, valid mode for open().
127 mode : optional, valid mode for open().
127 If a filename was give, open with this mode.
128 If a filename was give, open with this mode.
128
129
129 channel : str, one of ['stdout', 'stderr']
130 channel : str, one of ['stdout', 'stderr']
130 """
131 """
131 if channel not in ['stdout', 'stderr']:
132 if channel not in ['stdout', 'stderr']:
132 raise ValueError('Invalid channel spec %s' % channel)
133 raise ValueError('Invalid channel spec %s' % channel)
133
134
134 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
135 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
135 self.file = file_or_name
136 self.file = file_or_name
136 else:
137 else:
137 self.file = open(file_or_name, mode)
138 self.file = open(file_or_name, mode)
138 self.channel = channel
139 self.channel = channel
139 self.ostream = getattr(sys, channel)
140 self.ostream = getattr(sys, channel)
140 setattr(sys, channel, self)
141 setattr(sys, channel, self)
141 self._closed = False
142 self._closed = False
142
143
143 def close(self):
144 def close(self):
144 """Close the file and restore the channel."""
145 """Close the file and restore the channel."""
145 self.flush()
146 self.flush()
146 setattr(sys, self.channel, self.ostream)
147 setattr(sys, self.channel, self.ostream)
147 self.file.close()
148 self.file.close()
148 self._closed = True
149 self._closed = True
149
150
150 def write(self, data):
151 def write(self, data):
151 """Write data to both channels."""
152 """Write data to both channels."""
152 self.file.write(data)
153 self.file.write(data)
153 self.ostream.write(data)
154 self.ostream.write(data)
154 self.ostream.flush()
155 self.ostream.flush()
155
156
156 def flush(self):
157 def flush(self):
157 """Flush both channels."""
158 """Flush both channels."""
158 self.file.flush()
159 self.file.flush()
159 self.ostream.flush()
160 self.ostream.flush()
160
161
161 def __del__(self):
162 def __del__(self):
162 if not self._closed:
163 if not self._closed:
163 self.close()
164 self.close()
164
165
165
166
166 def ask_yes_no(prompt, default=None, interrupt=None):
167 def ask_yes_no(prompt, default=None, interrupt=None):
167 """Asks a question and returns a boolean (y/n) answer.
168 """Asks a question and returns a boolean (y/n) answer.
168
169
169 If default is given (one of 'y','n'), it is used if the user input is
170 If default is given (one of 'y','n'), it is used if the user input is
170 empty. If interrupt is given (one of 'y','n'), it is used if the user
171 empty. If interrupt is given (one of 'y','n'), it is used if the user
171 presses Ctrl-C. Otherwise the question is repeated until an answer is
172 presses Ctrl-C. Otherwise the question is repeated until an answer is
172 given.
173 given.
173
174
174 An EOF is treated as the default answer. If there is no default, an
175 An EOF is treated as the default answer. If there is no default, an
175 exception is raised to prevent infinite loops.
176 exception is raised to prevent infinite loops.
176
177
177 Valid answers are: y/yes/n/no (match is not case sensitive)."""
178 Valid answers are: y/yes/n/no (match is not case sensitive)."""
178
179
179 answers = {'y':True,'n':False,'yes':True,'no':False}
180 answers = {'y':True,'n':False,'yes':True,'no':False}
180 ans = None
181 ans = None
181 while ans not in answers.keys():
182 while ans not in answers.keys():
182 try:
183 try:
183 ans = input(prompt+' ').lower()
184 ans = input(prompt+' ').lower()
184 if not ans: # response was an empty string
185 if not ans: # response was an empty string
185 ans = default
186 ans = default
186 except KeyboardInterrupt:
187 except KeyboardInterrupt:
187 if interrupt:
188 if interrupt:
188 ans = interrupt
189 ans = interrupt
189 except EOFError:
190 except EOFError:
190 if default in answers.keys():
191 if default in answers.keys():
191 ans = default
192 ans = default
192 print()
193 print()
193 else:
194 else:
194 raise
195 raise
195
196
196 return answers[ans]
197 return answers[ans]
197
198
198
199
199 def temp_pyfile(src, ext='.py'):
200 def temp_pyfile(src, ext='.py'):
200 """Make a temporary python file, return filename and filehandle.
201 """Make a temporary python file, return filename and filehandle.
201
202
202 Parameters
203 Parameters
203 ----------
204 ----------
204 src : string or list of strings (no need for ending newlines if list)
205 src : string or list of strings (no need for ending newlines if list)
205 Source code to be written to the file.
206 Source code to be written to the file.
206
207
207 ext : optional, string
208 ext : optional, string
208 Extension for the generated file.
209 Extension for the generated file.
209
210
210 Returns
211 Returns
211 -------
212 -------
212 (filename, open filehandle)
213 (filename, open filehandle)
213 It is the caller's responsibility to close the open file and unlink it.
214 It is the caller's responsibility to close the open file and unlink it.
214 """
215 """
215 fname = tempfile.mkstemp(ext)[1]
216 fname = tempfile.mkstemp(ext)[1]
216 f = open(fname,'w')
217 f = open(fname,'w')
217 f.write(src)
218 f.write(src)
218 f.flush()
219 f.flush()
219 return fname, f
220 return fname, f
220
221
221 @contextmanager
222 @contextmanager
222 def atomic_writing(path, mode='w', encoding='utf-8', **kwargs):
223 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
223 tmp_file = path + '.tmp-write'
224 """Context manager to write to a file only if the entire write is successful.
224 if 'b' in mode:
225
225 encoding = None
226 This works by creating a temporary file in the same directory, and renaming
226
227 it over the old file if the context is exited without an error. If the
227 with open(tmp_file, mode, encoding=encoding, **kwargs) as f:
228 target file is a symlink or a hardlink, this will not be preserved: it will
228 yield f
229 be replaced by a new regular file.
229
230
230 # Written successfully, now rename it
231 On Windows, there is a small chink in the atomicity: the target file is
232 deleted before renaming the temporary file over it. This appears to be
233 unavoidable.
234
235 Parameters
236 ----------
237 path : str
238 The target file to write to.
239
240 text : bool, optional
241 Whether to open the file in text mode (i.e. to write unicode). Default is
242 True.
243
244 encoding : str, optional
245 The encoding to use for files opened in text mode. Default is UTF-8.
246
247 **kwargs
248 Passed to :func:`io.open`.
249 """
250 dirname, basename = os.path.split(path)
251 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname, text=text)
252 if text:
253 fileobj = io.open(handle, 'w', encoding=encoding, **kwargs)
254 else:
255 fileobj = io.open(handle, 'wb', **kwargs)
256
257 try:
258 yield fileobj
259 except:
260 fileobj.close()
261 os.remove(tmp_path)
262 raise
263 else:
264 # Written successfully, now rename it
265 fileobj.close()
231
266
232 if os.name == 'nt' and os.path.exists(path):
267 if os.name == 'nt' and os.path.exists(path):
233 # Rename over existing file doesn't work on Windows
268 # Rename over existing file doesn't work on Windows
234 os.remove(path)
269 os.remove(path)
235
270
236 os.rename(tmp_file, path)
271 os.rename(tmp_path, path)
237
272
238
273
239 def raw_print(*args, **kw):
274 def raw_print(*args, **kw):
240 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
275 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
241
276
242 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
277 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
243 file=sys.__stdout__)
278 file=sys.__stdout__)
244 sys.__stdout__.flush()
279 sys.__stdout__.flush()
245
280
246
281
247 def raw_print_err(*args, **kw):
282 def raw_print_err(*args, **kw):
248 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
283 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
249
284
250 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
285 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
251 file=sys.__stderr__)
286 file=sys.__stderr__)
252 sys.__stderr__.flush()
287 sys.__stderr__.flush()
253
288
254
289
255 # Short aliases for quick debugging, do NOT use these in production code.
290 # Short aliases for quick debugging, do NOT use these in production code.
256 rprint = raw_print
291 rprint = raw_print
257 rprinte = raw_print_err
292 rprinte = raw_print_err
258
293
259 def unicode_std_stream(stream='stdout'):
294 def unicode_std_stream(stream='stdout'):
260 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
295 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
261
296
262 This ignores environment variables and default encodings, to reliably write
297 This ignores environment variables and default encodings, to reliably write
263 unicode to stdout or stderr.
298 unicode to stdout or stderr.
264
299
265 ::
300 ::
266
301
267 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
302 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
268 """
303 """
269 assert stream in ('stdout', 'stderr')
304 assert stream in ('stdout', 'stderr')
270 stream = getattr(sys, stream)
305 stream = getattr(sys, stream)
271 if PY3:
306 if PY3:
272 try:
307 try:
273 stream_b = stream.buffer
308 stream_b = stream.buffer
274 except AttributeError:
309 except AttributeError:
275 # sys.stdout has been replaced - use it directly
310 # sys.stdout has been replaced - use it directly
276 return stream
311 return stream
277 else:
312 else:
278 stream_b = stream
313 stream_b = stream
279
314
280 return codecs.getwriter('utf-8')(stream_b)
315 return codecs.getwriter('utf-8')(stream_b)
@@ -1,124 +1,151 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for io.py"""
2 """Tests for io.py"""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2008-2011 The IPython Development Team
5 # Copyright (C) 2008-2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 from __future__ import print_function
14 from __future__ import print_function
15 from __future__ import absolute_import
15 from __future__ import absolute_import
16
16
17 import io as stdlib_io
17 import io as stdlib_io
18 import os.path
18 import sys
19 import sys
19
20
20 from subprocess import Popen, PIPE
21 from subprocess import Popen, PIPE
21 import unittest
22 import unittest
22
23
23 import nose.tools as nt
24 import nose.tools as nt
24
25
25 from IPython.testing.decorators import skipif
26 from IPython.testing.decorators import skipif
26 from IPython.utils.io import Tee, capture_output, unicode_std_stream
27 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
28 atomic_writing,
29 )
27 from IPython.utils.py3compat import doctest_refactor_print, PY3
30 from IPython.utils.py3compat import doctest_refactor_print, PY3
31 from IPython.utils.tempdir import TemporaryDirectory
28
32
29 if PY3:
33 if PY3:
30 from io import StringIO
34 from io import StringIO
31 else:
35 else:
32 from StringIO import StringIO
36 from StringIO import StringIO
33
37
34 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
35 # Tests
39 # Tests
36 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
37
41
38
42
39 def test_tee_simple():
43 def test_tee_simple():
40 "Very simple check with stdout only"
44 "Very simple check with stdout only"
41 chan = StringIO()
45 chan = StringIO()
42 text = 'Hello'
46 text = 'Hello'
43 tee = Tee(chan, channel='stdout')
47 tee = Tee(chan, channel='stdout')
44 print(text, file=chan)
48 print(text, file=chan)
45 nt.assert_equal(chan.getvalue(), text+"\n")
49 nt.assert_equal(chan.getvalue(), text+"\n")
46
50
47
51
48 class TeeTestCase(unittest.TestCase):
52 class TeeTestCase(unittest.TestCase):
49
53
50 def tchan(self, channel, check='close'):
54 def tchan(self, channel, check='close'):
51 trap = StringIO()
55 trap = StringIO()
52 chan = StringIO()
56 chan = StringIO()
53 text = 'Hello'
57 text = 'Hello'
54
58
55 std_ori = getattr(sys, channel)
59 std_ori = getattr(sys, channel)
56 setattr(sys, channel, trap)
60 setattr(sys, channel, trap)
57
61
58 tee = Tee(chan, channel=channel)
62 tee = Tee(chan, channel=channel)
59 print(text, end='', file=chan)
63 print(text, end='', file=chan)
60 setattr(sys, channel, std_ori)
64 setattr(sys, channel, std_ori)
61 trap_val = trap.getvalue()
65 trap_val = trap.getvalue()
62 nt.assert_equal(chan.getvalue(), text)
66 nt.assert_equal(chan.getvalue(), text)
63 if check=='close':
67 if check=='close':
64 tee.close()
68 tee.close()
65 else:
69 else:
66 del tee
70 del tee
67
71
68 def test(self):
72 def test(self):
69 for chan in ['stdout', 'stderr']:
73 for chan in ['stdout', 'stderr']:
70 for check in ['close', 'del']:
74 for check in ['close', 'del']:
71 self.tchan(chan, check)
75 self.tchan(chan, check)
72
76
73 def test_io_init():
77 def test_io_init():
74 """Test that io.stdin/out/err exist at startup"""
78 """Test that io.stdin/out/err exist at startup"""
75 for name in ('stdin', 'stdout', 'stderr'):
79 for name in ('stdin', 'stdout', 'stderr'):
76 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
80 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
77 p = Popen([sys.executable, '-c', cmd],
81 p = Popen([sys.executable, '-c', cmd],
78 stdout=PIPE)
82 stdout=PIPE)
79 p.wait()
83 p.wait()
80 classname = p.stdout.read().strip().decode('ascii')
84 classname = p.stdout.read().strip().decode('ascii')
81 # __class__ is a reference to the class object in Python 3, so we can't
85 # __class__ is a reference to the class object in Python 3, so we can't
82 # just test for string equality.
86 # just test for string equality.
83 assert 'IPython.utils.io.IOStream' in classname, classname
87 assert 'IPython.utils.io.IOStream' in classname, classname
84
88
85 def test_capture_output():
89 def test_capture_output():
86 """capture_output() context works"""
90 """capture_output() context works"""
87
91
88 with capture_output() as io:
92 with capture_output() as io:
89 print('hi, stdout')
93 print('hi, stdout')
90 print('hi, stderr', file=sys.stderr)
94 print('hi, stderr', file=sys.stderr)
91
95
92 nt.assert_equal(io.stdout, 'hi, stdout\n')
96 nt.assert_equal(io.stdout, 'hi, stdout\n')
93 nt.assert_equal(io.stderr, 'hi, stderr\n')
97 nt.assert_equal(io.stderr, 'hi, stderr\n')
94
98
95 def test_UnicodeStdStream():
99 def test_UnicodeStdStream():
96 # Test wrapping a bytes-level stdout
100 # Test wrapping a bytes-level stdout
97 if PY3:
101 if PY3:
98 stdoutb = stdlib_io.BytesIO()
102 stdoutb = stdlib_io.BytesIO()
99 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
103 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
100 else:
104 else:
101 stdout = stdoutb = stdlib_io.BytesIO()
105 stdout = stdoutb = stdlib_io.BytesIO()
102
106
103 orig_stdout = sys.stdout
107 orig_stdout = sys.stdout
104 sys.stdout = stdout
108 sys.stdout = stdout
105 try:
109 try:
106 sample = u"@Ε‚e¢ŧ←"
110 sample = u"@Ε‚e¢ŧ←"
107 unicode_std_stream().write(sample)
111 unicode_std_stream().write(sample)
108
112
109 output = stdoutb.getvalue().decode('utf-8')
113 output = stdoutb.getvalue().decode('utf-8')
110 nt.assert_equal(output, sample)
114 nt.assert_equal(output, sample)
111 assert not stdout.closed
115 assert not stdout.closed
112 finally:
116 finally:
113 sys.stdout = orig_stdout
117 sys.stdout = orig_stdout
114
118
115 @skipif(not PY3, "Not applicable on Python 2")
119 @skipif(not PY3, "Not applicable on Python 2")
116 def test_UnicodeStdStream_nowrap():
120 def test_UnicodeStdStream_nowrap():
117 # If we replace stdout with a StringIO, it shouldn't get wrapped.
121 # If we replace stdout with a StringIO, it shouldn't get wrapped.
118 orig_stdout = sys.stdout
122 orig_stdout = sys.stdout
119 sys.stdout = StringIO()
123 sys.stdout = StringIO()
120 try:
124 try:
121 nt.assert_is(unicode_std_stream(), sys.stdout)
125 nt.assert_is(unicode_std_stream(), sys.stdout)
122 assert not sys.stdout.closed
126 assert not sys.stdout.closed
123 finally:
127 finally:
124 sys.stdout = orig_stdout No newline at end of file
128 sys.stdout = orig_stdout
129
130 def test_atomic_writing():
131 class CustomExc(Exception): pass
132
133 with TemporaryDirectory() as td:
134 f1 = os.path.join(td, 'penguin')
135 with stdlib_io.open(f1, 'w') as f:
136 f.write(u'Before')
137
138 with nt.assert_raises(CustomExc):
139 with atomic_writing(f1) as f:
140 f.write(u'Failing write')
141 raise CustomExc
142
143 # Because of the exception, the file should not have been modified
144 with stdlib_io.open(f1, 'r') as f:
145 nt.assert_equal(f.read(), u'Before')
146
147 with atomic_writing(f1) as f:
148 f.write(u'Overwritten')
149
150 with stdlib_io.open(f1, 'r') as f:
151 nt.assert_equal(f.read(), u'Overwritten')
General Comments 0
You need to be logged in to leave comments. Login now