##// END OF EJS Templates
Rework atomic_writing with tests & docstring
Thomas Kluyver -
Show More
@@ -1,532 +1,532 b''
1 1 """A contents manager that uses the local file system for storage."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import base64
7 7 import io
8 8 import os
9 9 import glob
10 10 import shutil
11 11
12 12 from tornado import web
13 13
14 14 from .manager import ContentsManager
15 15 from IPython.nbformat import current
16 16 from IPython.utils.io import atomic_writing
17 17 from IPython.utils.path import ensure_dir_exists
18 18 from IPython.utils.traitlets import Unicode, Bool, TraitError
19 19 from IPython.utils.py3compat import getcwd
20 20 from IPython.utils import tz
21 21 from IPython.html.utils import is_hidden, to_os_path, url_path_join
22 22
23 23
24 24 class FileContentsManager(ContentsManager):
25 25
26 26 root_dir = Unicode(getcwd(), config=True)
27 27
28 28 save_script = Bool(False, config=True, help='DEPRECATED, IGNORED')
29 29 def _save_script_changed(self):
30 30 self.log.warn("""
31 31 Automatically saving notebooks as scripts has been removed.
32 32 Use `ipython nbconvert --to python [notebook]` instead.
33 33 """)
34 34
35 35 def _root_dir_changed(self, name, old, new):
36 36 """Do a bit of validation of the root_dir."""
37 37 if not os.path.isabs(new):
38 38 # If we receive a non-absolute path, make it absolute.
39 39 self.root_dir = os.path.abspath(new)
40 40 return
41 41 if not os.path.isdir(new):
42 42 raise TraitError("%r is not a directory" % new)
43 43
44 44 checkpoint_dir = Unicode('.ipynb_checkpoints', config=True,
45 45 help="""The directory name in which to keep file checkpoints
46 46
47 47 This is a path relative to the file's own directory.
48 48
49 49 By default, it is .ipynb_checkpoints
50 50 """
51 51 )
52 52
53 53 def _copy(self, src, dest):
54 54 """copy src to dest
55 55
56 56 like shutil.copy2, but log errors in copystat
57 57 """
58 58 shutil.copyfile(src, dest)
59 59 try:
60 60 shutil.copystat(src, dest)
61 61 except OSError as e:
62 62 self.log.debug("copystat on %s failed", dest, exc_info=True)
63 63
64 64 def _get_os_path(self, name=None, path=''):
65 65 """Given a filename and API path, return its file system
66 66 path.
67 67
68 68 Parameters
69 69 ----------
70 70 name : string
71 71 A filename
72 72 path : string
73 73 The relative API path to the named file.
74 74
75 75 Returns
76 76 -------
77 77 path : string
78 78 API path to be evaluated relative to root_dir.
79 79 """
80 80 if name is not None:
81 81 path = url_path_join(path, name)
82 82 return to_os_path(path, self.root_dir)
83 83
84 84 def path_exists(self, path):
85 85 """Does the API-style path refer to an extant directory?
86 86
87 87 API-style wrapper for os.path.isdir
88 88
89 89 Parameters
90 90 ----------
91 91 path : string
92 92 The path to check. This is an API path (`/` separated,
93 93 relative to root_dir).
94 94
95 95 Returns
96 96 -------
97 97 exists : bool
98 98 Whether the path is indeed a directory.
99 99 """
100 100 path = path.strip('/')
101 101 os_path = self._get_os_path(path=path)
102 102 return os.path.isdir(os_path)
103 103
104 104 def is_hidden(self, path):
105 105 """Does the API style path correspond to a hidden directory or file?
106 106
107 107 Parameters
108 108 ----------
109 109 path : string
110 110 The path to check. This is an API path (`/` separated,
111 111 relative to root_dir).
112 112
113 113 Returns
114 114 -------
115 115 exists : bool
116 116 Whether the path is hidden.
117 117
118 118 """
119 119 path = path.strip('/')
120 120 os_path = self._get_os_path(path=path)
121 121 return is_hidden(os_path, self.root_dir)
122 122
123 123 def file_exists(self, name, path=''):
124 124 """Returns True if the file exists, else returns False.
125 125
126 126 API-style wrapper for os.path.isfile
127 127
128 128 Parameters
129 129 ----------
130 130 name : string
131 131 The name of the file you are checking.
132 132 path : string
133 133 The relative path to the file's directory (with '/' as separator)
134 134
135 135 Returns
136 136 -------
137 137 exists : bool
138 138 Whether the file exists.
139 139 """
140 140 path = path.strip('/')
141 141 nbpath = self._get_os_path(name, path=path)
142 142 return os.path.isfile(nbpath)
143 143
144 144 def exists(self, name=None, path=''):
145 145 """Returns True if the path [and name] exists, else returns False.
146 146
147 147 API-style wrapper for os.path.exists
148 148
149 149 Parameters
150 150 ----------
151 151 name : string
152 152 The name of the file you are checking.
153 153 path : string
154 154 The relative path to the file's directory (with '/' as separator)
155 155
156 156 Returns
157 157 -------
158 158 exists : bool
159 159 Whether the target exists.
160 160 """
161 161 path = path.strip('/')
162 162 os_path = self._get_os_path(name, path=path)
163 163 return os.path.exists(os_path)
164 164
165 165 def _base_model(self, name, path=''):
166 166 """Build the common base of a contents model"""
167 167 os_path = self._get_os_path(name, path)
168 168 info = os.stat(os_path)
169 169 last_modified = tz.utcfromtimestamp(info.st_mtime)
170 170 created = tz.utcfromtimestamp(info.st_ctime)
171 171 # Create the base model.
172 172 model = {}
173 173 model['name'] = name
174 174 model['path'] = path
175 175 model['last_modified'] = last_modified
176 176 model['created'] = created
177 177 model['content'] = None
178 178 model['format'] = None
179 179 return model
180 180
181 181 def _dir_model(self, name, path='', content=True):
182 182 """Build a model for a directory
183 183
184 184 if content is requested, will include a listing of the directory
185 185 """
186 186 os_path = self._get_os_path(name, path)
187 187
188 188 four_o_four = u'directory does not exist: %r' % os_path
189 189
190 190 if not os.path.isdir(os_path):
191 191 raise web.HTTPError(404, four_o_four)
192 192 elif is_hidden(os_path, self.root_dir):
193 193 self.log.info("Refusing to serve hidden directory %r, via 404 Error",
194 194 os_path
195 195 )
196 196 raise web.HTTPError(404, four_o_four)
197 197
198 198 if name is None:
199 199 if '/' in path:
200 200 path, name = path.rsplit('/', 1)
201 201 else:
202 202 name = ''
203 203 model = self._base_model(name, path)
204 204 model['type'] = 'directory'
205 205 dir_path = u'{}/{}'.format(path, name)
206 206 if content:
207 207 model['content'] = contents = []
208 208 for os_path in glob.glob(self._get_os_path('*', dir_path)):
209 209 name = os.path.basename(os_path)
210 210 if self.should_list(name) and not is_hidden(os_path, self.root_dir):
211 211 contents.append(self.get_model(name=name, path=dir_path, content=False))
212 212
213 213 model['format'] = 'json'
214 214
215 215 return model
216 216
217 217 def _file_model(self, name, path='', content=True):
218 218 """Build a model for a file
219 219
220 220 if content is requested, include the file contents.
221 221 UTF-8 text files will be unicode, binary files will be base64-encoded.
222 222 """
223 223 model = self._base_model(name, path)
224 224 model['type'] = 'file'
225 225 if content:
226 226 os_path = self._get_os_path(name, path)
227 227 with io.open(os_path, 'rb') as f:
228 228 bcontent = f.read()
229 229 try:
230 230 model['content'] = bcontent.decode('utf8')
231 231 except UnicodeError as e:
232 232 model['content'] = base64.encodestring(bcontent).decode('ascii')
233 233 model['format'] = 'base64'
234 234 else:
235 235 model['format'] = 'text'
236 236 return model
237 237
238 238
239 239 def _notebook_model(self, name, path='', content=True):
240 240 """Build a notebook model
241 241
242 242 if content is requested, the notebook content will be populated
243 243 as a JSON structure (not double-serialized)
244 244 """
245 245 model = self._base_model(name, path)
246 246 model['type'] = 'notebook'
247 247 if content:
248 248 os_path = self._get_os_path(name, path)
249 249 with io.open(os_path, 'r', encoding='utf-8') as f:
250 250 try:
251 251 nb = current.read(f, u'json')
252 252 except Exception as e:
253 253 raise web.HTTPError(400, u"Unreadable Notebook: %s %s" % (os_path, e))
254 254 self.mark_trusted_cells(nb, name, path)
255 255 model['content'] = nb
256 256 model['format'] = 'json'
257 257 return model
258 258
259 259 def get_model(self, name, path='', content=True):
260 260 """ Takes a path and name for an entity and returns its model
261 261
262 262 Parameters
263 263 ----------
264 264 name : str
265 265 the name of the target
266 266 path : str
267 267 the API path that describes the relative path for the target
268 268
269 269 Returns
270 270 -------
271 271 model : dict
272 272 the contents model. If content=True, returns the contents
273 273 of the file or directory as well.
274 274 """
275 275 path = path.strip('/')
276 276
277 277 if not self.exists(name=name, path=path):
278 278 raise web.HTTPError(404, u'No such file or directory: %s/%s' % (path, name))
279 279
280 280 os_path = self._get_os_path(name, path)
281 281 if os.path.isdir(os_path):
282 282 model = self._dir_model(name, path, content)
283 283 elif name.endswith('.ipynb'):
284 284 model = self._notebook_model(name, path, content)
285 285 else:
286 286 model = self._file_model(name, path, content)
287 287 return model
288 288
289 289 def _save_notebook(self, os_path, model, name='', path=''):
290 290 """save a notebook file"""
291 291 # Save the notebook file
292 292 nb = current.to_notebook_json(model['content'])
293 293
294 294 self.check_and_sign(nb, name, path)
295 295
296 296 if 'name' in nb['metadata']:
297 297 nb['metadata']['name'] = u''
298 298
299 299 with atomic_writing(os_path, encoding='utf-8') as f:
300 300 current.write(nb, f, u'json')
301 301
302 302 def _save_file(self, os_path, model, name='', path=''):
303 303 """save a non-notebook file"""
304 304 fmt = model.get('format', None)
305 305 if fmt not in {'text', 'base64'}:
306 306 raise web.HTTPError(400, "Must specify format of file contents as 'text' or 'base64'")
307 307 try:
308 308 content = model['content']
309 309 if fmt == 'text':
310 310 bcontent = content.encode('utf8')
311 311 else:
312 312 b64_bytes = content.encode('ascii')
313 313 bcontent = base64.decodestring(b64_bytes)
314 314 except Exception as e:
315 315 raise web.HTTPError(400, u'Encoding error saving %s: %s' % (os_path, e))
316 with atomic_writing(os_path, 'wb') as f:
316 with atomic_writing(os_path, text=False) as f:
317 317 f.write(bcontent)
318 318
319 319 def _save_directory(self, os_path, model, name='', path=''):
320 320 """create a directory"""
321 321 if is_hidden(os_path, self.root_dir):
322 322 raise web.HTTPError(400, u'Cannot create hidden directory %r' % os_path)
323 323 if not os.path.exists(os_path):
324 324 os.mkdir(os_path)
325 325 elif not os.path.isdir(os_path):
326 326 raise web.HTTPError(400, u'Not a directory: %s' % (os_path))
327 327 else:
328 328 self.log.debug("Directory %r already exists", os_path)
329 329
330 330 def save(self, model, name='', path=''):
331 331 """Save the file model and return the model with no content."""
332 332 path = path.strip('/')
333 333
334 334 if 'type' not in model:
335 335 raise web.HTTPError(400, u'No file type provided')
336 336 if 'content' not in model and model['type'] != 'directory':
337 337 raise web.HTTPError(400, u'No file content provided')
338 338
339 339 # One checkpoint should always exist
340 340 if self.file_exists(name, path) and not self.list_checkpoints(name, path):
341 341 self.create_checkpoint(name, path)
342 342
343 343 new_path = model.get('path', path).strip('/')
344 344 new_name = model.get('name', name)
345 345
346 346 if path != new_path or name != new_name:
347 347 self.rename(name, path, new_name, new_path)
348 348
349 349 os_path = self._get_os_path(new_name, new_path)
350 350 self.log.debug("Saving %s", os_path)
351 351 try:
352 352 if model['type'] == 'notebook':
353 353 self._save_notebook(os_path, model, new_name, new_path)
354 354 elif model['type'] == 'file':
355 355 self._save_file(os_path, model, new_name, new_path)
356 356 elif model['type'] == 'directory':
357 357 self._save_directory(os_path, model, new_name, new_path)
358 358 else:
359 359 raise web.HTTPError(400, "Unhandled contents type: %s" % model['type'])
360 360 except web.HTTPError:
361 361 raise
362 362 except Exception as e:
363 363 raise web.HTTPError(400, u'Unexpected error while saving file: %s %s' % (os_path, e))
364 364
365 365 model = self.get_model(new_name, new_path, content=False)
366 366 return model
367 367
368 368 def update(self, model, name, path=''):
369 369 """Update the file's path and/or name
370 370
371 371 For use in PATCH requests, to enable renaming a file without
372 372 re-uploading its contents. Only used for renaming at the moment.
373 373 """
374 374 path = path.strip('/')
375 375 new_name = model.get('name', name)
376 376 new_path = model.get('path', path).strip('/')
377 377 if path != new_path or name != new_name:
378 378 self.rename(name, path, new_name, new_path)
379 379 model = self.get_model(new_name, new_path, content=False)
380 380 return model
381 381
382 382 def delete(self, name, path=''):
383 383 """Delete file by name and path."""
384 384 path = path.strip('/')
385 385 os_path = self._get_os_path(name, path)
386 386 rm = os.unlink
387 387 if os.path.isdir(os_path):
388 388 listing = os.listdir(os_path)
389 389 # don't delete non-empty directories (checkpoints dir doesn't count)
390 390 if listing and listing != [self.checkpoint_dir]:
391 391 raise web.HTTPError(400, u'Directory %s not empty' % os_path)
392 392 elif not os.path.isfile(os_path):
393 393 raise web.HTTPError(404, u'File does not exist: %s' % os_path)
394 394
395 395 # clear checkpoints
396 396 for checkpoint in self.list_checkpoints(name, path):
397 397 checkpoint_id = checkpoint['id']
398 398 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
399 399 if os.path.isfile(cp_path):
400 400 self.log.debug("Unlinking checkpoint %s", cp_path)
401 401 os.unlink(cp_path)
402 402
403 403 if os.path.isdir(os_path):
404 404 self.log.debug("Removing directory %s", os_path)
405 405 shutil.rmtree(os_path)
406 406 else:
407 407 self.log.debug("Unlinking file %s", os_path)
408 408 rm(os_path)
409 409
410 410 def rename(self, old_name, old_path, new_name, new_path):
411 411 """Rename a file."""
412 412 old_path = old_path.strip('/')
413 413 new_path = new_path.strip('/')
414 414 if new_name == old_name and new_path == old_path:
415 415 return
416 416
417 417 new_os_path = self._get_os_path(new_name, new_path)
418 418 old_os_path = self._get_os_path(old_name, old_path)
419 419
420 420 # Should we proceed with the move?
421 421 if os.path.isfile(new_os_path):
422 422 raise web.HTTPError(409, u'File with name already exists: %s' % new_os_path)
423 423
424 424 # Move the file
425 425 try:
426 426 shutil.move(old_os_path, new_os_path)
427 427 except Exception as e:
428 428 raise web.HTTPError(500, u'Unknown error renaming file: %s %s' % (old_os_path, e))
429 429
430 430 # Move the checkpoints
431 431 old_checkpoints = self.list_checkpoints(old_name, old_path)
432 432 for cp in old_checkpoints:
433 433 checkpoint_id = cp['id']
434 434 old_cp_path = self.get_checkpoint_path(checkpoint_id, old_name, old_path)
435 435 new_cp_path = self.get_checkpoint_path(checkpoint_id, new_name, new_path)
436 436 if os.path.isfile(old_cp_path):
437 437 self.log.debug("Renaming checkpoint %s -> %s", old_cp_path, new_cp_path)
438 438 shutil.move(old_cp_path, new_cp_path)
439 439
440 440 # Checkpoint-related utilities
441 441
442 442 def get_checkpoint_path(self, checkpoint_id, name, path=''):
443 443 """find the path to a checkpoint"""
444 444 path = path.strip('/')
445 445 basename, ext = os.path.splitext(name)
446 446 filename = u"{name}-{checkpoint_id}{ext}".format(
447 447 name=basename,
448 448 checkpoint_id=checkpoint_id,
449 449 ext=ext,
450 450 )
451 451 os_path = self._get_os_path(path=path)
452 452 cp_dir = os.path.join(os_path, self.checkpoint_dir)
453 453 ensure_dir_exists(cp_dir)
454 454 cp_path = os.path.join(cp_dir, filename)
455 455 return cp_path
456 456
457 457 def get_checkpoint_model(self, checkpoint_id, name, path=''):
458 458 """construct the info dict for a given checkpoint"""
459 459 path = path.strip('/')
460 460 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
461 461 stats = os.stat(cp_path)
462 462 last_modified = tz.utcfromtimestamp(stats.st_mtime)
463 463 info = dict(
464 464 id = checkpoint_id,
465 465 last_modified = last_modified,
466 466 )
467 467 return info
468 468
469 469 # public checkpoint API
470 470
471 471 def create_checkpoint(self, name, path=''):
472 472 """Create a checkpoint from the current state of a file"""
473 473 path = path.strip('/')
474 474 src_path = self._get_os_path(name, path)
475 475 # only the one checkpoint ID:
476 476 checkpoint_id = u"checkpoint"
477 477 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
478 478 self.log.debug("creating checkpoint for %s", name)
479 479 self._copy(src_path, cp_path)
480 480
481 481 # return the checkpoint info
482 482 return self.get_checkpoint_model(checkpoint_id, name, path)
483 483
484 484 def list_checkpoints(self, name, path=''):
485 485 """list the checkpoints for a given file
486 486
487 487 This contents manager currently only supports one checkpoint per file.
488 488 """
489 489 path = path.strip('/')
490 490 checkpoint_id = "checkpoint"
491 491 os_path = self.get_checkpoint_path(checkpoint_id, name, path)
492 492 if not os.path.exists(os_path):
493 493 return []
494 494 else:
495 495 return [self.get_checkpoint_model(checkpoint_id, name, path)]
496 496
497 497
498 498 def restore_checkpoint(self, checkpoint_id, name, path=''):
499 499 """restore a file to a checkpointed state"""
500 500 path = path.strip('/')
501 501 self.log.info("restoring %s from checkpoint %s", name, checkpoint_id)
502 502 nb_path = self._get_os_path(name, path)
503 503 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
504 504 if not os.path.isfile(cp_path):
505 505 self.log.debug("checkpoint file does not exist: %s", cp_path)
506 506 raise web.HTTPError(404,
507 507 u'checkpoint does not exist: %s-%s' % (name, checkpoint_id)
508 508 )
509 509 # ensure notebook is readable (never restore from an unreadable notebook)
510 510 if cp_path.endswith('.ipynb'):
511 511 with io.open(cp_path, 'r', encoding='utf-8') as f:
512 512 current.read(f, u'json')
513 513 self._copy(cp_path, nb_path)
514 514 self.log.debug("copying %s -> %s", cp_path, nb_path)
515 515
516 516 def delete_checkpoint(self, checkpoint_id, name, path=''):
517 517 """delete a file's checkpoint"""
518 518 path = path.strip('/')
519 519 cp_path = self.get_checkpoint_path(checkpoint_id, name, path)
520 520 if not os.path.isfile(cp_path):
521 521 raise web.HTTPError(404,
522 522 u'Checkpoint does not exist: %s%s-%s' % (path, name, checkpoint_id)
523 523 )
524 524 self.log.debug("unlinking %s", cp_path)
525 525 os.unlink(cp_path)
526 526
527 527 def info_string(self):
528 528 return "Serving notebooks from local directory: %s" % self.root_dir
529 529
530 530 def get_kernel_path(self, name, path='', model=None):
531 531 """Return the initial working dir a kernel associated with a given notebook"""
532 532 return os.path.join(self.root_dir, path)
@@ -1,280 +1,315 b''
1 1 # encoding: utf-8
2 2 """
3 3 IO related utilities.
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2008-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12 from __future__ import print_function
13 13 from __future__ import absolute_import
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18 import codecs
19 19 from contextlib import contextmanager
20 import io
20 21 import os
21 22 import sys
22 23 import tempfile
23 24 from .capture import CapturedIO, capture_output
24 25 from .py3compat import string_types, input, PY3
25 26
26 27 #-----------------------------------------------------------------------------
27 28 # Code
28 29 #-----------------------------------------------------------------------------
29 30
30 31
31 32 class IOStream:
32 33
33 34 def __init__(self,stream, fallback=None):
34 35 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
35 36 if fallback is not None:
36 37 stream = fallback
37 38 else:
38 39 raise ValueError("fallback required, but not specified")
39 40 self.stream = stream
40 41 self._swrite = stream.write
41 42
42 43 # clone all methods not overridden:
43 44 def clone(meth):
44 45 return not hasattr(self, meth) and not meth.startswith('_')
45 46 for meth in filter(clone, dir(stream)):
46 47 setattr(self, meth, getattr(stream, meth))
47 48
48 49 def __repr__(self):
49 50 cls = self.__class__
50 51 tpl = '{mod}.{cls}({args})'
51 52 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
52 53
53 54 def write(self,data):
54 55 try:
55 56 self._swrite(data)
56 57 except:
57 58 try:
58 59 # print handles some unicode issues which may trip a plain
59 60 # write() call. Emulate write() by using an empty end
60 61 # argument.
61 62 print(data, end='', file=self.stream)
62 63 except:
63 64 # if we get here, something is seriously broken.
64 65 print('ERROR - failed to write data to stream:', self.stream,
65 66 file=sys.stderr)
66 67
67 68 def writelines(self, lines):
68 69 if isinstance(lines, string_types):
69 70 lines = [lines]
70 71 for line in lines:
71 72 self.write(line)
72 73
73 74 # This class used to have a writeln method, but regular files and streams
74 75 # in Python don't have this method. We need to keep this completely
75 76 # compatible so we removed it.
76 77
77 78 @property
78 79 def closed(self):
79 80 return self.stream.closed
80 81
81 82 def close(self):
82 83 pass
83 84
84 85 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
85 86 devnull = open(os.devnull, 'w')
86 87 stdin = IOStream(sys.stdin, fallback=devnull)
87 88 stdout = IOStream(sys.stdout, fallback=devnull)
88 89 stderr = IOStream(sys.stderr, fallback=devnull)
89 90
90 91 class IOTerm:
91 92 """ Term holds the file or file-like objects for handling I/O operations.
92 93
93 94 These are normally just sys.stdin, sys.stdout and sys.stderr but for
94 95 Windows they can can replaced to allow editing the strings before they are
95 96 displayed."""
96 97
97 98 # In the future, having IPython channel all its I/O operations through
98 99 # this class will make it easier to embed it into other environments which
99 100 # are not a normal terminal (such as a GUI-based shell)
100 101 def __init__(self, stdin=None, stdout=None, stderr=None):
101 102 mymodule = sys.modules[__name__]
102 103 self.stdin = IOStream(stdin, mymodule.stdin)
103 104 self.stdout = IOStream(stdout, mymodule.stdout)
104 105 self.stderr = IOStream(stderr, mymodule.stderr)
105 106
106 107
107 108 class Tee(object):
108 109 """A class to duplicate an output stream to stdout/err.
109 110
110 111 This works in a manner very similar to the Unix 'tee' command.
111 112
112 113 When the object is closed or deleted, it closes the original file given to
113 114 it for duplication.
114 115 """
115 116 # Inspired by:
116 117 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
117 118
118 119 def __init__(self, file_or_name, mode="w", channel='stdout'):
119 120 """Construct a new Tee object.
120 121
121 122 Parameters
122 123 ----------
123 124 file_or_name : filename or open filehandle (writable)
124 125 File that will be duplicated
125 126
126 127 mode : optional, valid mode for open().
127 128 If a filename was give, open with this mode.
128 129
129 130 channel : str, one of ['stdout', 'stderr']
130 131 """
131 132 if channel not in ['stdout', 'stderr']:
132 133 raise ValueError('Invalid channel spec %s' % channel)
133 134
134 135 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
135 136 self.file = file_or_name
136 137 else:
137 138 self.file = open(file_or_name, mode)
138 139 self.channel = channel
139 140 self.ostream = getattr(sys, channel)
140 141 setattr(sys, channel, self)
141 142 self._closed = False
142 143
143 144 def close(self):
144 145 """Close the file and restore the channel."""
145 146 self.flush()
146 147 setattr(sys, self.channel, self.ostream)
147 148 self.file.close()
148 149 self._closed = True
149 150
150 151 def write(self, data):
151 152 """Write data to both channels."""
152 153 self.file.write(data)
153 154 self.ostream.write(data)
154 155 self.ostream.flush()
155 156
156 157 def flush(self):
157 158 """Flush both channels."""
158 159 self.file.flush()
159 160 self.ostream.flush()
160 161
161 162 def __del__(self):
162 163 if not self._closed:
163 164 self.close()
164 165
165 166
166 167 def ask_yes_no(prompt, default=None, interrupt=None):
167 168 """Asks a question and returns a boolean (y/n) answer.
168 169
169 170 If default is given (one of 'y','n'), it is used if the user input is
170 171 empty. If interrupt is given (one of 'y','n'), it is used if the user
171 172 presses Ctrl-C. Otherwise the question is repeated until an answer is
172 173 given.
173 174
174 175 An EOF is treated as the default answer. If there is no default, an
175 176 exception is raised to prevent infinite loops.
176 177
177 178 Valid answers are: y/yes/n/no (match is not case sensitive)."""
178 179
179 180 answers = {'y':True,'n':False,'yes':True,'no':False}
180 181 ans = None
181 182 while ans not in answers.keys():
182 183 try:
183 184 ans = input(prompt+' ').lower()
184 185 if not ans: # response was an empty string
185 186 ans = default
186 187 except KeyboardInterrupt:
187 188 if interrupt:
188 189 ans = interrupt
189 190 except EOFError:
190 191 if default in answers.keys():
191 192 ans = default
192 193 print()
193 194 else:
194 195 raise
195 196
196 197 return answers[ans]
197 198
198 199
199 200 def temp_pyfile(src, ext='.py'):
200 201 """Make a temporary python file, return filename and filehandle.
201 202
202 203 Parameters
203 204 ----------
204 205 src : string or list of strings (no need for ending newlines if list)
205 206 Source code to be written to the file.
206 207
207 208 ext : optional, string
208 209 Extension for the generated file.
209 210
210 211 Returns
211 212 -------
212 213 (filename, open filehandle)
213 214 It is the caller's responsibility to close the open file and unlink it.
214 215 """
215 216 fname = tempfile.mkstemp(ext)[1]
216 217 f = open(fname,'w')
217 218 f.write(src)
218 219 f.flush()
219 220 return fname, f
220 221
221 222 @contextmanager
222 def atomic_writing(path, mode='w', encoding='utf-8', **kwargs):
223 tmp_file = path + '.tmp-write'
224 if 'b' in mode:
225 encoding = None
223 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
224 """Context manager to write to a file only if the entire write is successful.
226 225
227 with open(tmp_file, mode, encoding=encoding, **kwargs) as f:
228 yield f
226 This works by creating a temporary file in the same directory, and renaming
227 it over the old file if the context is exited without an error. If the
228 target file is a symlink or a hardlink, this will not be preserved: it will
229 be replaced by a new regular file.
229 230
231 On Windows, there is a small chink in the atomicity: the target file is
232 deleted before renaming the temporary file over it. This appears to be
233 unavoidable.
234
235 Parameters
236 ----------
237 path : str
238 The target file to write to.
239
240 text : bool, optional
241 Whether to open the file in text mode (i.e. to write unicode). Default is
242 True.
243
244 encoding : str, optional
245 The encoding to use for files opened in text mode. Default is UTF-8.
246
247 **kwargs
248 Passed to :func:`io.open`.
249 """
250 dirname, basename = os.path.split(path)
251 handle, tmp_path = tempfile.mkstemp(prefix=basename, dir=dirname, text=text)
252 if text:
253 fileobj = io.open(handle, 'w', encoding=encoding, **kwargs)
254 else:
255 fileobj = io.open(handle, 'wb', **kwargs)
256
257 try:
258 yield fileobj
259 except:
260 fileobj.close()
261 os.remove(tmp_path)
262 raise
263 else:
230 264 # Written successfully, now rename it
265 fileobj.close()
231 266
232 267 if os.name == 'nt' and os.path.exists(path):
233 268 # Rename over existing file doesn't work on Windows
234 269 os.remove(path)
235 270
236 os.rename(tmp_file, path)
271 os.rename(tmp_path, path)
237 272
238 273
239 274 def raw_print(*args, **kw):
240 275 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
241 276
242 277 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
243 278 file=sys.__stdout__)
244 279 sys.__stdout__.flush()
245 280
246 281
247 282 def raw_print_err(*args, **kw):
248 283 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
249 284
250 285 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
251 286 file=sys.__stderr__)
252 287 sys.__stderr__.flush()
253 288
254 289
255 290 # Short aliases for quick debugging, do NOT use these in production code.
256 291 rprint = raw_print
257 292 rprinte = raw_print_err
258 293
259 294 def unicode_std_stream(stream='stdout'):
260 295 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
261 296
262 297 This ignores environment variables and default encodings, to reliably write
263 298 unicode to stdout or stderr.
264 299
265 300 ::
266 301
267 302 unicode_std_stream().write(u'ł@e¶ŧ←')
268 303 """
269 304 assert stream in ('stdout', 'stderr')
270 305 stream = getattr(sys, stream)
271 306 if PY3:
272 307 try:
273 308 stream_b = stream.buffer
274 309 except AttributeError:
275 310 # sys.stdout has been replaced - use it directly
276 311 return stream
277 312 else:
278 313 stream_b = stream
279 314
280 315 return codecs.getwriter('utf-8')(stream_b)
@@ -1,124 +1,151 b''
1 1 # encoding: utf-8
2 2 """Tests for io.py"""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Copyright (C) 2008-2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-----------------------------------------------------------------------------
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Imports
13 13 #-----------------------------------------------------------------------------
14 14 from __future__ import print_function
15 15 from __future__ import absolute_import
16 16
17 17 import io as stdlib_io
18 import os.path
18 19 import sys
19 20
20 21 from subprocess import Popen, PIPE
21 22 import unittest
22 23
23 24 import nose.tools as nt
24 25
25 26 from IPython.testing.decorators import skipif
26 from IPython.utils.io import Tee, capture_output, unicode_std_stream
27 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
28 atomic_writing,
29 )
27 30 from IPython.utils.py3compat import doctest_refactor_print, PY3
31 from IPython.utils.tempdir import TemporaryDirectory
28 32
29 33 if PY3:
30 34 from io import StringIO
31 35 else:
32 36 from StringIO import StringIO
33 37
34 38 #-----------------------------------------------------------------------------
35 39 # Tests
36 40 #-----------------------------------------------------------------------------
37 41
38 42
39 43 def test_tee_simple():
40 44 "Very simple check with stdout only"
41 45 chan = StringIO()
42 46 text = 'Hello'
43 47 tee = Tee(chan, channel='stdout')
44 48 print(text, file=chan)
45 49 nt.assert_equal(chan.getvalue(), text+"\n")
46 50
47 51
48 52 class TeeTestCase(unittest.TestCase):
49 53
50 54 def tchan(self, channel, check='close'):
51 55 trap = StringIO()
52 56 chan = StringIO()
53 57 text = 'Hello'
54 58
55 59 std_ori = getattr(sys, channel)
56 60 setattr(sys, channel, trap)
57 61
58 62 tee = Tee(chan, channel=channel)
59 63 print(text, end='', file=chan)
60 64 setattr(sys, channel, std_ori)
61 65 trap_val = trap.getvalue()
62 66 nt.assert_equal(chan.getvalue(), text)
63 67 if check=='close':
64 68 tee.close()
65 69 else:
66 70 del tee
67 71
68 72 def test(self):
69 73 for chan in ['stdout', 'stderr']:
70 74 for check in ['close', 'del']:
71 75 self.tchan(chan, check)
72 76
73 77 def test_io_init():
74 78 """Test that io.stdin/out/err exist at startup"""
75 79 for name in ('stdin', 'stdout', 'stderr'):
76 80 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
77 81 p = Popen([sys.executable, '-c', cmd],
78 82 stdout=PIPE)
79 83 p.wait()
80 84 classname = p.stdout.read().strip().decode('ascii')
81 85 # __class__ is a reference to the class object in Python 3, so we can't
82 86 # just test for string equality.
83 87 assert 'IPython.utils.io.IOStream' in classname, classname
84 88
85 89 def test_capture_output():
86 90 """capture_output() context works"""
87 91
88 92 with capture_output() as io:
89 93 print('hi, stdout')
90 94 print('hi, stderr', file=sys.stderr)
91 95
92 96 nt.assert_equal(io.stdout, 'hi, stdout\n')
93 97 nt.assert_equal(io.stderr, 'hi, stderr\n')
94 98
95 99 def test_UnicodeStdStream():
96 100 # Test wrapping a bytes-level stdout
97 101 if PY3:
98 102 stdoutb = stdlib_io.BytesIO()
99 103 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
100 104 else:
101 105 stdout = stdoutb = stdlib_io.BytesIO()
102 106
103 107 orig_stdout = sys.stdout
104 108 sys.stdout = stdout
105 109 try:
106 110 sample = u"@łe¶ŧ←"
107 111 unicode_std_stream().write(sample)
108 112
109 113 output = stdoutb.getvalue().decode('utf-8')
110 114 nt.assert_equal(output, sample)
111 115 assert not stdout.closed
112 116 finally:
113 117 sys.stdout = orig_stdout
114 118
115 119 @skipif(not PY3, "Not applicable on Python 2")
116 120 def test_UnicodeStdStream_nowrap():
117 121 # If we replace stdout with a StringIO, it shouldn't get wrapped.
118 122 orig_stdout = sys.stdout
119 123 sys.stdout = StringIO()
120 124 try:
121 125 nt.assert_is(unicode_std_stream(), sys.stdout)
122 126 assert not sys.stdout.closed
123 127 finally:
124 128 sys.stdout = orig_stdout
129
130 def test_atomic_writing():
131 class CustomExc(Exception): pass
132
133 with TemporaryDirectory() as td:
134 f1 = os.path.join(td, 'penguin')
135 with stdlib_io.open(f1, 'w') as f:
136 f.write(u'Before')
137
138 with nt.assert_raises(CustomExc):
139 with atomic_writing(f1) as f:
140 f.write(u'Failing write')
141 raise CustomExc
142
143 # Because of the exception, the file should not have been modified
144 with stdlib_io.open(f1, 'r') as f:
145 nt.assert_equal(f.read(), u'Before')
146
147 with atomic_writing(f1) as f:
148 f.write(u'Overwritten')
149
150 with stdlib_io.open(f1, 'r') as f:
151 nt.assert_equal(f.read(), u'Overwritten')
General Comments 0
You need to be logged in to leave comments. Login now