##// END OF EJS Templates
Merge pull request #8240 from minrk/split-io...
Thomas Kluyver -
r21119:a229520a merge
parent child Browse files
Show More
@@ -0,0 +1,131 b''
1 # encoding: utf-8
2 """Tests for file IO"""
3
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
6
7 import io as stdlib_io
8 import os.path
9 import stat
10
11 import nose.tools as nt
12
13 from IPython.testing.decorators import skip_win32
14 from ..fileio import atomic_writing
15
16 from IPython.utils.tempdir import TemporaryDirectory
17
18 umask = 0
19
20 def test_atomic_writing():
21 class CustomExc(Exception): pass
22
23 with TemporaryDirectory() as td:
24 f1 = os.path.join(td, 'penguin')
25 with stdlib_io.open(f1, 'w') as f:
26 f.write(u'Before')
27
28 if os.name != 'nt':
29 os.chmod(f1, 0o701)
30 orig_mode = stat.S_IMODE(os.stat(f1).st_mode)
31
32 f2 = os.path.join(td, 'flamingo')
33 try:
34 os.symlink(f1, f2)
35 have_symlink = True
36 except (AttributeError, NotImplementedError, OSError):
37 # AttributeError: Python doesn't support it
38 # NotImplementedError: The system doesn't support it
39 # OSError: The user lacks the privilege (Windows)
40 have_symlink = False
41
42 with nt.assert_raises(CustomExc):
43 with atomic_writing(f1) as f:
44 f.write(u'Failing write')
45 raise CustomExc
46
47 # Because of the exception, the file should not have been modified
48 with stdlib_io.open(f1, 'r') as f:
49 nt.assert_equal(f.read(), u'Before')
50
51 with atomic_writing(f1) as f:
52 f.write(u'Overwritten')
53
54 with stdlib_io.open(f1, 'r') as f:
55 nt.assert_equal(f.read(), u'Overwritten')
56
57 if os.name != 'nt':
58 mode = stat.S_IMODE(os.stat(f1).st_mode)
59 nt.assert_equal(mode, orig_mode)
60
61 if have_symlink:
62 # Check that writing over a file preserves a symlink
63 with atomic_writing(f2) as f:
64 f.write(u'written from symlink')
65
66 with stdlib_io.open(f1, 'r') as f:
67 nt.assert_equal(f.read(), u'written from symlink')
68
69 def _save_umask():
70 global umask
71 umask = os.umask(0)
72 os.umask(umask)
73
74 def _restore_umask():
75 os.umask(umask)
76
77 @skip_win32
78 @nt.with_setup(_save_umask, _restore_umask)
79 def test_atomic_writing_umask():
80 with TemporaryDirectory() as td:
81 os.umask(0o022)
82 f1 = os.path.join(td, '1')
83 with atomic_writing(f1) as f:
84 f.write(u'1')
85 mode = stat.S_IMODE(os.stat(f1).st_mode)
86 nt.assert_equal(mode, 0o644, '{:o} != 644'.format(mode))
87
88 os.umask(0o057)
89 f2 = os.path.join(td, '2')
90 with atomic_writing(f2) as f:
91 f.write(u'2')
92 mode = stat.S_IMODE(os.stat(f2).st_mode)
93 nt.assert_equal(mode, 0o620, '{:o} != 620'.format(mode))
94
95
96 def test_atomic_writing_newlines():
97 with TemporaryDirectory() as td:
98 path = os.path.join(td, 'testfile')
99
100 lf = u'a\nb\nc\n'
101 plat = lf.replace(u'\n', os.linesep)
102 crlf = lf.replace(u'\n', u'\r\n')
103
104 # test default
105 with stdlib_io.open(path, 'w') as f:
106 f.write(lf)
107 with stdlib_io.open(path, 'r', newline='') as f:
108 read = f.read()
109 nt.assert_equal(read, plat)
110
111 # test newline=LF
112 with stdlib_io.open(path, 'w', newline='\n') as f:
113 f.write(lf)
114 with stdlib_io.open(path, 'r', newline='') as f:
115 read = f.read()
116 nt.assert_equal(read, lf)
117
118 # test newline=CRLF
119 with atomic_writing(path, newline='\r\n') as f:
120 f.write(lf)
121 with stdlib_io.open(path, 'r', newline='') as f:
122 read = f.read()
123 nt.assert_equal(read, crlf)
124
125 # test newline=no convert
126 text = u'crlf\r\ncr\rlf\n'
127 with atomic_writing(path, newline='') as f:
128 f.write(text)
129 with stdlib_io.open(path, 'r', newline='') as f:
130 read = f.read()
131 nt.assert_equal(read, text)
@@ -0,0 +1,33 b''
1 # coding: utf-8
2 """io-related utilities"""
3
4 # Copyright (c) Jupyter Development Team.
5 # Distributed under the terms of the Modified BSD License.
6
7 import codecs
8 import sys
9 from IPython.utils.py3compat import PY3
10
11
12 def unicode_std_stream(stream='stdout'):
13 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
14
15 This ignores environment variables and default encodings, to reliably write
16 unicode to stdout or stderr.
17
18 ::
19
20 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
21 """
22 assert stream in ('stdout', 'stderr')
23 stream = getattr(sys, stream)
24 if PY3:
25 try:
26 stream_b = stream.buffer
27 except AttributeError:
28 # sys.stdout has been replaced - use it directly
29 return stream
30 else:
31 stream_b = stream
32
33 return codecs.getwriter('utf-8')(stream_b)
@@ -0,0 +1,50 b''
1 # encoding: utf-8
2 """Tests for utils.io"""
3
4 # Copyright (c) Jupyter Development Team.
5 # Distributed under the terms of the Modified BSD License.
6
7 import io as stdlib_io
8 import sys
9
10 import nose.tools as nt
11
12 from IPython.testing.decorators import skipif
13 from ..io import unicode_std_stream
14 from IPython.utils.py3compat import PY3
15
16 if PY3:
17 from io import StringIO
18 else:
19 from StringIO import StringIO
20
21 def test_UnicodeStdStream():
22 # Test wrapping a bytes-level stdout
23 if PY3:
24 stdoutb = stdlib_io.BytesIO()
25 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
26 else:
27 stdout = stdoutb = stdlib_io.BytesIO()
28
29 orig_stdout = sys.stdout
30 sys.stdout = stdout
31 try:
32 sample = u"@Ε‚e¢ŧ←"
33 unicode_std_stream().write(sample)
34
35 output = stdoutb.getvalue().decode('utf-8')
36 nt.assert_equal(output, sample)
37 assert not stdout.closed
38 finally:
39 sys.stdout = orig_stdout
40
41 @skipif(not PY3, "Not applicable on Python 2")
42 def test_UnicodeStdStream_nowrap():
43 # If we replace stdout with a StringIO, it shouldn't get wrapped.
44 orig_stdout = sys.stdout
45 sys.stdout = StringIO()
46 try:
47 nt.assert_is(unicode_std_stream(), sys.stdout)
48 assert not sys.stdout.closed
49 finally:
50 sys.stdout = orig_stdout
@@ -1,174 +1,256 b''
1 """
1 """
2 Utilities for file-based Contents/Checkpoints managers.
2 Utilities for file-based Contents/Checkpoints managers.
3 """
3 """
4
4
5 # Copyright (c) IPython Development Team.
5 # Copyright (c) IPython Development Team.
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7
7
8 import base64
8 import base64
9 from contextlib import contextmanager
9 from contextlib import contextmanager
10 import errno
10 import errno
11 import io
11 import io
12 import os
12 import os
13 import shutil
13 import shutil
14 import tempfile
14
15
15 from tornado.web import HTTPError
16 from tornado.web import HTTPError
16
17
17 from IPython.html.utils import (
18 from IPython.html.utils import (
18 to_api_path,
19 to_api_path,
19 to_os_path,
20 to_os_path,
20 )
21 )
21 from IPython import nbformat
22 from IPython import nbformat
22 from IPython.utils.io import atomic_writing
23 from IPython.utils.py3compat import str_to_unicode
23 from IPython.utils.py3compat import str_to_unicode
24
24
25
25
26 def _copy_metadata(src, dst):
27 """Copy the set of metadata we want for atomic_writing.
28
29 Permission bits and flags. We'd like to copy file ownership as well, but we
30 can't do that.
31 """
32 shutil.copymode(src, dst)
33 st = os.stat(src)
34 if hasattr(os, 'chflags') and hasattr(st, 'st_flags'):
35 os.chflags(dst, st.st_flags)
36
37 @contextmanager
38 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
39 """Context manager to write to a file only if the entire write is successful.
40
41 This works by creating a temporary file in the same directory, and renaming
42 it over the old file if the context is exited without an error. If other
43 file names are hard linked to the target file, this relationship will not be
44 preserved.
45
46 On Windows, there is a small chink in the atomicity: the target file is
47 deleted before renaming the temporary file over it. This appears to be
48 unavoidable.
49
50 Parameters
51 ----------
52 path : str
53 The target file to write to.
54
55 text : bool, optional
56 Whether to open the file in text mode (i.e. to write unicode). Default is
57 True.
58
59 encoding : str, optional
60 The encoding to use for files opened in text mode. Default is UTF-8.
61
62 **kwargs
63 Passed to :func:`io.open`.
64 """
65 # realpath doesn't work on Windows: http://bugs.python.org/issue9949
66 # Luckily, we only need to resolve the file itself being a symlink, not
67 # any of its directories, so this will suffice:
68 if os.path.islink(path):
69 path = os.path.join(os.path.dirname(path), os.readlink(path))
70
71 dirname, basename = os.path.split(path)
72 tmp_dir = tempfile.mkdtemp(prefix=basename, dir=dirname)
73 tmp_path = os.path.join(tmp_dir, basename)
74 if text:
75 fileobj = io.open(tmp_path, 'w', encoding=encoding, **kwargs)
76 else:
77 fileobj = io.open(tmp_path, 'wb', **kwargs)
78
79 try:
80 yield fileobj
81 except:
82 fileobj.close()
83 shutil.rmtree(tmp_dir)
84 raise
85
86 # Flush to disk
87 fileobj.flush()
88 os.fsync(fileobj.fileno())
89
90 # Written successfully, now rename it
91 fileobj.close()
92
93 # Copy permission bits, access time, etc.
94 try:
95 _copy_metadata(path, tmp_path)
96 except OSError:
97 # e.g. the file didn't already exist. Ignore any failure to copy metadata
98 pass
99
100 if os.name == 'nt' and os.path.exists(path):
101 # Rename over existing file doesn't work on Windows
102 os.remove(path)
103
104 os.rename(tmp_path, path)
105 shutil.rmtree(tmp_dir)
106
107
26 class FileManagerMixin(object):
108 class FileManagerMixin(object):
27 """
109 """
28 Mixin for ContentsAPI classes that interact with the filesystem.
110 Mixin for ContentsAPI classes that interact with the filesystem.
29
111
30 Provides facilities for reading, writing, and copying both notebooks and
112 Provides facilities for reading, writing, and copying both notebooks and
31 generic files.
113 generic files.
32
114
33 Shared by FileContentsManager and FileCheckpoints.
115 Shared by FileContentsManager and FileCheckpoints.
34
116
35 Note
117 Note
36 ----
118 ----
37 Classes using this mixin must provide the following attributes:
119 Classes using this mixin must provide the following attributes:
38
120
39 root_dir : unicode
121 root_dir : unicode
40 A directory against against which API-style paths are to be resolved.
122 A directory against against which API-style paths are to be resolved.
41
123
42 log : logging.Logger
124 log : logging.Logger
43 """
125 """
44
126
45 @contextmanager
127 @contextmanager
46 def open(self, os_path, *args, **kwargs):
128 def open(self, os_path, *args, **kwargs):
47 """wrapper around io.open that turns permission errors into 403"""
129 """wrapper around io.open that turns permission errors into 403"""
48 with self.perm_to_403(os_path):
130 with self.perm_to_403(os_path):
49 with io.open(os_path, *args, **kwargs) as f:
131 with io.open(os_path, *args, **kwargs) as f:
50 yield f
132 yield f
51
133
52 @contextmanager
134 @contextmanager
53 def atomic_writing(self, os_path, *args, **kwargs):
135 def atomic_writing(self, os_path, *args, **kwargs):
54 """wrapper around atomic_writing that turns permission errors to 403"""
136 """wrapper around atomic_writing that turns permission errors to 403"""
55 with self.perm_to_403(os_path):
137 with self.perm_to_403(os_path):
56 with atomic_writing(os_path, *args, **kwargs) as f:
138 with atomic_writing(os_path, *args, **kwargs) as f:
57 yield f
139 yield f
58
140
59 @contextmanager
141 @contextmanager
60 def perm_to_403(self, os_path=''):
142 def perm_to_403(self, os_path=''):
61 """context manager for turning permission errors into 403."""
143 """context manager for turning permission errors into 403."""
62 try:
144 try:
63 yield
145 yield
64 except (OSError, IOError) as e:
146 except (OSError, IOError) as e:
65 if e.errno in {errno.EPERM, errno.EACCES}:
147 if e.errno in {errno.EPERM, errno.EACCES}:
66 # make 403 error message without root prefix
148 # make 403 error message without root prefix
67 # this may not work perfectly on unicode paths on Python 2,
149 # this may not work perfectly on unicode paths on Python 2,
68 # but nobody should be doing that anyway.
150 # but nobody should be doing that anyway.
69 if not os_path:
151 if not os_path:
70 os_path = str_to_unicode(e.filename or 'unknown file')
152 os_path = str_to_unicode(e.filename or 'unknown file')
71 path = to_api_path(os_path, root=self.root_dir)
153 path = to_api_path(os_path, root=self.root_dir)
72 raise HTTPError(403, u'Permission denied: %s' % path)
154 raise HTTPError(403, u'Permission denied: %s' % path)
73 else:
155 else:
74 raise
156 raise
75
157
76 def _copy(self, src, dest):
158 def _copy(self, src, dest):
77 """copy src to dest
159 """copy src to dest
78
160
79 like shutil.copy2, but log errors in copystat
161 like shutil.copy2, but log errors in copystat
80 """
162 """
81 shutil.copyfile(src, dest)
163 shutil.copyfile(src, dest)
82 try:
164 try:
83 shutil.copystat(src, dest)
165 shutil.copystat(src, dest)
84 except OSError:
166 except OSError:
85 self.log.debug("copystat on %s failed", dest, exc_info=True)
167 self.log.debug("copystat on %s failed", dest, exc_info=True)
86
168
87 def _get_os_path(self, path):
169 def _get_os_path(self, path):
88 """Given an API path, return its file system path.
170 """Given an API path, return its file system path.
89
171
90 Parameters
172 Parameters
91 ----------
173 ----------
92 path : string
174 path : string
93 The relative API path to the named file.
175 The relative API path to the named file.
94
176
95 Returns
177 Returns
96 -------
178 -------
97 path : string
179 path : string
98 Native, absolute OS path to for a file.
180 Native, absolute OS path to for a file.
99
181
100 Raises
182 Raises
101 ------
183 ------
102 404: if path is outside root
184 404: if path is outside root
103 """
185 """
104 root = os.path.abspath(self.root_dir)
186 root = os.path.abspath(self.root_dir)
105 os_path = to_os_path(path, root)
187 os_path = to_os_path(path, root)
106 if not (os.path.abspath(os_path) + os.path.sep).startswith(root):
188 if not (os.path.abspath(os_path) + os.path.sep).startswith(root):
107 raise HTTPError(404, "%s is outside root contents directory" % path)
189 raise HTTPError(404, "%s is outside root contents directory" % path)
108 return os_path
190 return os_path
109
191
110 def _read_notebook(self, os_path, as_version=4):
192 def _read_notebook(self, os_path, as_version=4):
111 """Read a notebook from an os path."""
193 """Read a notebook from an os path."""
112 with self.open(os_path, 'r', encoding='utf-8') as f:
194 with self.open(os_path, 'r', encoding='utf-8') as f:
113 try:
195 try:
114 return nbformat.read(f, as_version=as_version)
196 return nbformat.read(f, as_version=as_version)
115 except Exception as e:
197 except Exception as e:
116 raise HTTPError(
198 raise HTTPError(
117 400,
199 400,
118 u"Unreadable Notebook: %s %r" % (os_path, e),
200 u"Unreadable Notebook: %s %r" % (os_path, e),
119 )
201 )
120
202
121 def _save_notebook(self, os_path, nb):
203 def _save_notebook(self, os_path, nb):
122 """Save a notebook to an os_path."""
204 """Save a notebook to an os_path."""
123 with self.atomic_writing(os_path, encoding='utf-8') as f:
205 with self.atomic_writing(os_path, encoding='utf-8') as f:
124 nbformat.write(nb, f, version=nbformat.NO_CONVERT)
206 nbformat.write(nb, f, version=nbformat.NO_CONVERT)
125
207
126 def _read_file(self, os_path, format):
208 def _read_file(self, os_path, format):
127 """Read a non-notebook file.
209 """Read a non-notebook file.
128
210
129 os_path: The path to be read.
211 os_path: The path to be read.
130 format:
212 format:
131 If 'text', the contents will be decoded as UTF-8.
213 If 'text', the contents will be decoded as UTF-8.
132 If 'base64', the raw bytes contents will be encoded as base64.
214 If 'base64', the raw bytes contents will be encoded as base64.
133 If not specified, try to decode as UTF-8, and fall back to base64
215 If not specified, try to decode as UTF-8, and fall back to base64
134 """
216 """
135 if not os.path.isfile(os_path):
217 if not os.path.isfile(os_path):
136 raise HTTPError(400, "Cannot read non-file %s" % os_path)
218 raise HTTPError(400, "Cannot read non-file %s" % os_path)
137
219
138 with self.open(os_path, 'rb') as f:
220 with self.open(os_path, 'rb') as f:
139 bcontent = f.read()
221 bcontent = f.read()
140
222
141 if format is None or format == 'text':
223 if format is None or format == 'text':
142 # Try to interpret as unicode if format is unknown or if unicode
224 # Try to interpret as unicode if format is unknown or if unicode
143 # was explicitly requested.
225 # was explicitly requested.
144 try:
226 try:
145 return bcontent.decode('utf8'), 'text'
227 return bcontent.decode('utf8'), 'text'
146 except UnicodeError:
228 except UnicodeError:
147 if format == 'text':
229 if format == 'text':
148 raise HTTPError(
230 raise HTTPError(
149 400,
231 400,
150 "%s is not UTF-8 encoded" % os_path,
232 "%s is not UTF-8 encoded" % os_path,
151 reason='bad format',
233 reason='bad format',
152 )
234 )
153 return base64.encodestring(bcontent).decode('ascii'), 'base64'
235 return base64.encodestring(bcontent).decode('ascii'), 'base64'
154
236
155 def _save_file(self, os_path, content, format):
237 def _save_file(self, os_path, content, format):
156 """Save content of a generic file."""
238 """Save content of a generic file."""
157 if format not in {'text', 'base64'}:
239 if format not in {'text', 'base64'}:
158 raise HTTPError(
240 raise HTTPError(
159 400,
241 400,
160 "Must specify format of file contents as 'text' or 'base64'",
242 "Must specify format of file contents as 'text' or 'base64'",
161 )
243 )
162 try:
244 try:
163 if format == 'text':
245 if format == 'text':
164 bcontent = content.encode('utf8')
246 bcontent = content.encode('utf8')
165 else:
247 else:
166 b64_bytes = content.encode('ascii')
248 b64_bytes = content.encode('ascii')
167 bcontent = base64.decodestring(b64_bytes)
249 bcontent = base64.decodestring(b64_bytes)
168 except Exception as e:
250 except Exception as e:
169 raise HTTPError(
251 raise HTTPError(
170 400, u'Encoding error saving %s: %s' % (os_path, e)
252 400, u'Encoding error saving %s: %s' % (os_path, e)
171 )
253 )
172
254
173 with self.atomic_writing(os_path, text=False) as f:
255 with self.atomic_writing(os_path, text=False) as f:
174 f.write(bcontent)
256 f.write(bcontent)
@@ -1,347 +1,246 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 IO related utilities.
3 IO related utilities.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 # Copyright (c) IPython Development Team.
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Distributed under the terms of the Modified BSD License.
8 #
8
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
12 from __future__ import print_function
9 from __future__ import print_function
13 from __future__ import absolute_import
10 from __future__ import absolute_import
14
11
15 #-----------------------------------------------------------------------------
12
16 # Imports
17 #-----------------------------------------------------------------------------
18 import codecs
13 import codecs
19 from contextlib import contextmanager
14 from contextlib import contextmanager
20 import io
15 import io
21 import os
16 import os
22 import shutil
17 import shutil
23 import stat
24 import sys
18 import sys
25 import tempfile
19 import tempfile
20 import warnings
26 from .capture import CapturedIO, capture_output
21 from .capture import CapturedIO, capture_output
27 from .py3compat import string_types, input, PY3
22 from .py3compat import string_types, input, PY3
28
23
29 #-----------------------------------------------------------------------------
30 # Code
31 #-----------------------------------------------------------------------------
32
33
24
34 class IOStream:
25 class IOStream:
35
26
36 def __init__(self,stream, fallback=None):
27 def __init__(self,stream, fallback=None):
37 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
28 if not hasattr(stream,'write') or not hasattr(stream,'flush'):
38 if fallback is not None:
29 if fallback is not None:
39 stream = fallback
30 stream = fallback
40 else:
31 else:
41 raise ValueError("fallback required, but not specified")
32 raise ValueError("fallback required, but not specified")
42 self.stream = stream
33 self.stream = stream
43 self._swrite = stream.write
34 self._swrite = stream.write
44
35
45 # clone all methods not overridden:
36 # clone all methods not overridden:
46 def clone(meth):
37 def clone(meth):
47 return not hasattr(self, meth) and not meth.startswith('_')
38 return not hasattr(self, meth) and not meth.startswith('_')
48 for meth in filter(clone, dir(stream)):
39 for meth in filter(clone, dir(stream)):
49 setattr(self, meth, getattr(stream, meth))
40 setattr(self, meth, getattr(stream, meth))
50
41
51 def __repr__(self):
42 def __repr__(self):
52 cls = self.__class__
43 cls = self.__class__
53 tpl = '{mod}.{cls}({args})'
44 tpl = '{mod}.{cls}({args})'
54 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
45 return tpl.format(mod=cls.__module__, cls=cls.__name__, args=self.stream)
55
46
56 def write(self,data):
47 def write(self,data):
57 try:
48 try:
58 self._swrite(data)
49 self._swrite(data)
59 except:
50 except:
60 try:
51 try:
61 # print handles some unicode issues which may trip a plain
52 # print handles some unicode issues which may trip a plain
62 # write() call. Emulate write() by using an empty end
53 # write() call. Emulate write() by using an empty end
63 # argument.
54 # argument.
64 print(data, end='', file=self.stream)
55 print(data, end='', file=self.stream)
65 except:
56 except:
66 # if we get here, something is seriously broken.
57 # if we get here, something is seriously broken.
67 print('ERROR - failed to write data to stream:', self.stream,
58 print('ERROR - failed to write data to stream:', self.stream,
68 file=sys.stderr)
59 file=sys.stderr)
69
60
70 def writelines(self, lines):
61 def writelines(self, lines):
71 if isinstance(lines, string_types):
62 if isinstance(lines, string_types):
72 lines = [lines]
63 lines = [lines]
73 for line in lines:
64 for line in lines:
74 self.write(line)
65 self.write(line)
75
66
76 # This class used to have a writeln method, but regular files and streams
67 # This class used to have a writeln method, but regular files and streams
77 # in Python don't have this method. We need to keep this completely
68 # in Python don't have this method. We need to keep this completely
78 # compatible so we removed it.
69 # compatible so we removed it.
79
70
80 @property
71 @property
81 def closed(self):
72 def closed(self):
82 return self.stream.closed
73 return self.stream.closed
83
74
84 def close(self):
75 def close(self):
85 pass
76 pass
86
77
87 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
78 # setup stdin/stdout/stderr to sys.stdin/sys.stdout/sys.stderr
88 devnull = open(os.devnull, 'w')
79 devnull = open(os.devnull, 'w')
89 stdin = IOStream(sys.stdin, fallback=devnull)
80 stdin = IOStream(sys.stdin, fallback=devnull)
90 stdout = IOStream(sys.stdout, fallback=devnull)
81 stdout = IOStream(sys.stdout, fallback=devnull)
91 stderr = IOStream(sys.stderr, fallback=devnull)
82 stderr = IOStream(sys.stderr, fallback=devnull)
92
83
93 class IOTerm:
84 class IOTerm:
94 """ Term holds the file or file-like objects for handling I/O operations.
85 """ Term holds the file or file-like objects for handling I/O operations.
95
86
96 These are normally just sys.stdin, sys.stdout and sys.stderr but for
87 These are normally just sys.stdin, sys.stdout and sys.stderr but for
97 Windows they can can replaced to allow editing the strings before they are
88 Windows they can can replaced to allow editing the strings before they are
98 displayed."""
89 displayed."""
99
90
100 # In the future, having IPython channel all its I/O operations through
91 # In the future, having IPython channel all its I/O operations through
101 # this class will make it easier to embed it into other environments which
92 # this class will make it easier to embed it into other environments which
102 # are not a normal terminal (such as a GUI-based shell)
93 # are not a normal terminal (such as a GUI-based shell)
103 def __init__(self, stdin=None, stdout=None, stderr=None):
94 def __init__(self, stdin=None, stdout=None, stderr=None):
104 mymodule = sys.modules[__name__]
95 mymodule = sys.modules[__name__]
105 self.stdin = IOStream(stdin, mymodule.stdin)
96 self.stdin = IOStream(stdin, mymodule.stdin)
106 self.stdout = IOStream(stdout, mymodule.stdout)
97 self.stdout = IOStream(stdout, mymodule.stdout)
107 self.stderr = IOStream(stderr, mymodule.stderr)
98 self.stderr = IOStream(stderr, mymodule.stderr)
108
99
109
100
110 class Tee(object):
101 class Tee(object):
111 """A class to duplicate an output stream to stdout/err.
102 """A class to duplicate an output stream to stdout/err.
112
103
113 This works in a manner very similar to the Unix 'tee' command.
104 This works in a manner very similar to the Unix 'tee' command.
114
105
115 When the object is closed or deleted, it closes the original file given to
106 When the object is closed or deleted, it closes the original file given to
116 it for duplication.
107 it for duplication.
117 """
108 """
118 # Inspired by:
109 # Inspired by:
119 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
110 # http://mail.python.org/pipermail/python-list/2007-May/442737.html
120
111
121 def __init__(self, file_or_name, mode="w", channel='stdout'):
112 def __init__(self, file_or_name, mode="w", channel='stdout'):
122 """Construct a new Tee object.
113 """Construct a new Tee object.
123
114
124 Parameters
115 Parameters
125 ----------
116 ----------
126 file_or_name : filename or open filehandle (writable)
117 file_or_name : filename or open filehandle (writable)
127 File that will be duplicated
118 File that will be duplicated
128
119
129 mode : optional, valid mode for open().
120 mode : optional, valid mode for open().
130 If a filename was give, open with this mode.
121 If a filename was give, open with this mode.
131
122
132 channel : str, one of ['stdout', 'stderr']
123 channel : str, one of ['stdout', 'stderr']
133 """
124 """
134 if channel not in ['stdout', 'stderr']:
125 if channel not in ['stdout', 'stderr']:
135 raise ValueError('Invalid channel spec %s' % channel)
126 raise ValueError('Invalid channel spec %s' % channel)
136
127
137 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
128 if hasattr(file_or_name, 'write') and hasattr(file_or_name, 'seek'):
138 self.file = file_or_name
129 self.file = file_or_name
139 else:
130 else:
140 self.file = open(file_or_name, mode)
131 self.file = open(file_or_name, mode)
141 self.channel = channel
132 self.channel = channel
142 self.ostream = getattr(sys, channel)
133 self.ostream = getattr(sys, channel)
143 setattr(sys, channel, self)
134 setattr(sys, channel, self)
144 self._closed = False
135 self._closed = False
145
136
146 def close(self):
137 def close(self):
147 """Close the file and restore the channel."""
138 """Close the file and restore the channel."""
148 self.flush()
139 self.flush()
149 setattr(sys, self.channel, self.ostream)
140 setattr(sys, self.channel, self.ostream)
150 self.file.close()
141 self.file.close()
151 self._closed = True
142 self._closed = True
152
143
153 def write(self, data):
144 def write(self, data):
154 """Write data to both channels."""
145 """Write data to both channels."""
155 self.file.write(data)
146 self.file.write(data)
156 self.ostream.write(data)
147 self.ostream.write(data)
157 self.ostream.flush()
148 self.ostream.flush()
158
149
159 def flush(self):
150 def flush(self):
160 """Flush both channels."""
151 """Flush both channels."""
161 self.file.flush()
152 self.file.flush()
162 self.ostream.flush()
153 self.ostream.flush()
163
154
164 def __del__(self):
155 def __del__(self):
165 if not self._closed:
156 if not self._closed:
166 self.close()
157 self.close()
167
158
168
159
169 def ask_yes_no(prompt, default=None, interrupt=None):
160 def ask_yes_no(prompt, default=None, interrupt=None):
170 """Asks a question and returns a boolean (y/n) answer.
161 """Asks a question and returns a boolean (y/n) answer.
171
162
172 If default is given (one of 'y','n'), it is used if the user input is
163 If default is given (one of 'y','n'), it is used if the user input is
173 empty. If interrupt is given (one of 'y','n'), it is used if the user
164 empty. If interrupt is given (one of 'y','n'), it is used if the user
174 presses Ctrl-C. Otherwise the question is repeated until an answer is
165 presses Ctrl-C. Otherwise the question is repeated until an answer is
175 given.
166 given.
176
167
177 An EOF is treated as the default answer. If there is no default, an
168 An EOF is treated as the default answer. If there is no default, an
178 exception is raised to prevent infinite loops.
169 exception is raised to prevent infinite loops.
179
170
180 Valid answers are: y/yes/n/no (match is not case sensitive)."""
171 Valid answers are: y/yes/n/no (match is not case sensitive)."""
181
172
182 answers = {'y':True,'n':False,'yes':True,'no':False}
173 answers = {'y':True,'n':False,'yes':True,'no':False}
183 ans = None
174 ans = None
184 while ans not in answers.keys():
175 while ans not in answers.keys():
185 try:
176 try:
186 ans = input(prompt+' ').lower()
177 ans = input(prompt+' ').lower()
187 if not ans: # response was an empty string
178 if not ans: # response was an empty string
188 ans = default
179 ans = default
189 except KeyboardInterrupt:
180 except KeyboardInterrupt:
190 if interrupt:
181 if interrupt:
191 ans = interrupt
182 ans = interrupt
192 except EOFError:
183 except EOFError:
193 if default in answers.keys():
184 if default in answers.keys():
194 ans = default
185 ans = default
195 print()
186 print()
196 else:
187 else:
197 raise
188 raise
198
189
199 return answers[ans]
190 return answers[ans]
200
191
201
192
202 def temp_pyfile(src, ext='.py'):
193 def temp_pyfile(src, ext='.py'):
203 """Make a temporary python file, return filename and filehandle.
194 """Make a temporary python file, return filename and filehandle.
204
195
205 Parameters
196 Parameters
206 ----------
197 ----------
207 src : string or list of strings (no need for ending newlines if list)
198 src : string or list of strings (no need for ending newlines if list)
208 Source code to be written to the file.
199 Source code to be written to the file.
209
200
210 ext : optional, string
201 ext : optional, string
211 Extension for the generated file.
202 Extension for the generated file.
212
203
213 Returns
204 Returns
214 -------
205 -------
215 (filename, open filehandle)
206 (filename, open filehandle)
216 It is the caller's responsibility to close the open file and unlink it.
207 It is the caller's responsibility to close the open file and unlink it.
217 """
208 """
218 fname = tempfile.mkstemp(ext)[1]
209 fname = tempfile.mkstemp(ext)[1]
219 f = open(fname,'w')
210 f = open(fname,'w')
220 f.write(src)
211 f.write(src)
221 f.flush()
212 f.flush()
222 return fname, f
213 return fname, f
223
214
224 def _copy_metadata(src, dst):
215 def atomic_writing(*args, **kwargs):
225 """Copy the set of metadata we want for atomic_writing.
216 """DEPRECATED: moved to IPython.html.services.contents.fileio"""
226
217 warn("IPython.utils.io.atomic_writing has moved to IPython.html.services.contents.fileio")
227 Permission bits and flags. We'd like to copy file ownership as well, but we
218 from IPython.html.services.contents.fileio import atomic_writing
228 can't do that.
219 return atomic_writing(*args, **kwargs)
229 """
230 shutil.copymode(src, dst)
231 st = os.stat(src)
232 if hasattr(os, 'chflags') and hasattr(st, 'st_flags'):
233 os.chflags(dst, st.st_flags)
234
235 @contextmanager
236 def atomic_writing(path, text=True, encoding='utf-8', **kwargs):
237 """Context manager to write to a file only if the entire write is successful.
238
239 This works by creating a temporary file in the same directory, and renaming
240 it over the old file if the context is exited without an error. If other
241 file names are hard linked to the target file, this relationship will not be
242 preserved.
243
244 On Windows, there is a small chink in the atomicity: the target file is
245 deleted before renaming the temporary file over it. This appears to be
246 unavoidable.
247
248 Parameters
249 ----------
250 path : str
251 The target file to write to.
252
253 text : bool, optional
254 Whether to open the file in text mode (i.e. to write unicode). Default is
255 True.
256
257 encoding : str, optional
258 The encoding to use for files opened in text mode. Default is UTF-8.
259
260 **kwargs
261 Passed to :func:`io.open`.
262 """
263 # realpath doesn't work on Windows: http://bugs.python.org/issue9949
264 # Luckily, we only need to resolve the file itself being a symlink, not
265 # any of its directories, so this will suffice:
266 if os.path.islink(path):
267 path = os.path.join(os.path.dirname(path), os.readlink(path))
268
269 dirname, basename = os.path.split(path)
270 tmp_dir = tempfile.mkdtemp(prefix=basename, dir=dirname)
271 tmp_path = os.path.join(tmp_dir, basename)
272 if text:
273 fileobj = io.open(tmp_path, 'w', encoding=encoding, **kwargs)
274 else:
275 fileobj = io.open(tmp_path, 'wb', **kwargs)
276
277 try:
278 yield fileobj
279 except:
280 fileobj.close()
281 shutil.rmtree(tmp_dir)
282 raise
283
284 # Flush to disk
285 fileobj.flush()
286 os.fsync(fileobj.fileno())
287
288 # Written successfully, now rename it
289 fileobj.close()
290
291 # Copy permission bits, access time, etc.
292 try:
293 _copy_metadata(path, tmp_path)
294 except OSError:
295 # e.g. the file didn't already exist. Ignore any failure to copy metadata
296 pass
297
298 if os.name == 'nt' and os.path.exists(path):
299 # Rename over existing file doesn't work on Windows
300 os.remove(path)
301
302 os.rename(tmp_path, path)
303 shutil.rmtree(tmp_dir)
304
305
220
306 def raw_print(*args, **kw):
221 def raw_print(*args, **kw):
307 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
222 """Raw print to sys.__stdout__, otherwise identical interface to print()."""
308
223
309 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
224 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
310 file=sys.__stdout__)
225 file=sys.__stdout__)
311 sys.__stdout__.flush()
226 sys.__stdout__.flush()
312
227
313
228
314 def raw_print_err(*args, **kw):
229 def raw_print_err(*args, **kw):
315 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
230 """Raw print to sys.__stderr__, otherwise identical interface to print()."""
316
231
317 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
232 print(*args, sep=kw.get('sep', ' '), end=kw.get('end', '\n'),
318 file=sys.__stderr__)
233 file=sys.__stderr__)
319 sys.__stderr__.flush()
234 sys.__stderr__.flush()
320
235
321
236
322 # Short aliases for quick debugging, do NOT use these in production code.
237 # Short aliases for quick debugging, do NOT use these in production code.
323 rprint = raw_print
238 rprint = raw_print
324 rprinte = raw_print_err
239 rprinte = raw_print_err
325
240
326 def unicode_std_stream(stream='stdout'):
327 u"""Get a wrapper to write unicode to stdout/stderr as UTF-8.
328
329 This ignores environment variables and default encodings, to reliably write
330 unicode to stdout or stderr.
331
241
332 ::
242 def unicode_std_stream(stream='stdout'):
333
243 """DEPRECATED, moved to jupyter_nbconvert.utils.io"""
334 unicode_std_stream().write(u'Ε‚@e¢ŧ←')
244 warn("IPython.utils.io.unicode_std_stream has moved to jupyter_nbconvert.utils.io")
335 """
245 from jupyter_nbconvert.utils.io import unicode_std_stream
336 assert stream in ('stdout', 'stderr')
246 return unicode_std_stream(stream)
337 stream = getattr(sys, stream)
338 if PY3:
339 try:
340 stream_b = stream.buffer
341 except AttributeError:
342 # sys.stdout has been replaced - use it directly
343 return stream
344 else:
345 stream_b = stream
346
347 return codecs.getwriter('utf-8')(stream_b)
@@ -1,231 +1,87 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for io.py"""
2 """Tests for io.py"""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 from __future__ import print_function
7 from __future__ import print_function
8 from __future__ import absolute_import
8 from __future__ import absolute_import
9
9
10 import io as stdlib_io
10 import io as stdlib_io
11 import os.path
11 import os.path
12 import stat
12 import stat
13 import sys
13 import sys
14
14
15 from subprocess import Popen, PIPE
15 from subprocess import Popen, PIPE
16 import unittest
16 import unittest
17
17
18 import nose.tools as nt
18 import nose.tools as nt
19
19
20 from IPython.testing.decorators import skipif, skip_win32
20 from IPython.testing.decorators import skipif, skip_win32
21 from IPython.utils.io import (Tee, capture_output, unicode_std_stream,
21 from IPython.utils.io import Tee, capture_output
22 atomic_writing,
23 )
24 from IPython.utils.py3compat import doctest_refactor_print, PY3
22 from IPython.utils.py3compat import doctest_refactor_print, PY3
25 from IPython.utils.tempdir import TemporaryDirectory
23 from IPython.utils.tempdir import TemporaryDirectory
26
24
27 if PY3:
25 if PY3:
28 from io import StringIO
26 from io import StringIO
29 else:
27 else:
30 from StringIO import StringIO
28 from StringIO import StringIO
31
29
32
30
33 def test_tee_simple():
31 def test_tee_simple():
34 "Very simple check with stdout only"
32 "Very simple check with stdout only"
35 chan = StringIO()
33 chan = StringIO()
36 text = 'Hello'
34 text = 'Hello'
37 tee = Tee(chan, channel='stdout')
35 tee = Tee(chan, channel='stdout')
38 print(text, file=chan)
36 print(text, file=chan)
39 nt.assert_equal(chan.getvalue(), text+"\n")
37 nt.assert_equal(chan.getvalue(), text+"\n")
40
38
41
39
42 class TeeTestCase(unittest.TestCase):
40 class TeeTestCase(unittest.TestCase):
43
41
44 def tchan(self, channel, check='close'):
42 def tchan(self, channel, check='close'):
45 trap = StringIO()
43 trap = StringIO()
46 chan = StringIO()
44 chan = StringIO()
47 text = 'Hello'
45 text = 'Hello'
48
46
49 std_ori = getattr(sys, channel)
47 std_ori = getattr(sys, channel)
50 setattr(sys, channel, trap)
48 setattr(sys, channel, trap)
51
49
52 tee = Tee(chan, channel=channel)
50 tee = Tee(chan, channel=channel)
53 print(text, end='', file=chan)
51 print(text, end='', file=chan)
54 setattr(sys, channel, std_ori)
52 setattr(sys, channel, std_ori)
55 trap_val = trap.getvalue()
53 trap_val = trap.getvalue()
56 nt.assert_equal(chan.getvalue(), text)
54 nt.assert_equal(chan.getvalue(), text)
57 if check=='close':
55 if check=='close':
58 tee.close()
56 tee.close()
59 else:
57 else:
60 del tee
58 del tee
61
59
62 def test(self):
60 def test(self):
63 for chan in ['stdout', 'stderr']:
61 for chan in ['stdout', 'stderr']:
64 for check in ['close', 'del']:
62 for check in ['close', 'del']:
65 self.tchan(chan, check)
63 self.tchan(chan, check)
66
64
67 def test_io_init():
65 def test_io_init():
68 """Test that io.stdin/out/err exist at startup"""
66 """Test that io.stdin/out/err exist at startup"""
69 for name in ('stdin', 'stdout', 'stderr'):
67 for name in ('stdin', 'stdout', 'stderr'):
70 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
68 cmd = doctest_refactor_print("from IPython.utils import io;print io.%s.__class__"%name)
71 p = Popen([sys.executable, '-c', cmd],
69 p = Popen([sys.executable, '-c', cmd],
72 stdout=PIPE)
70 stdout=PIPE)
73 p.wait()
71 p.wait()
74 classname = p.stdout.read().strip().decode('ascii')
72 classname = p.stdout.read().strip().decode('ascii')
75 # __class__ is a reference to the class object in Python 3, so we can't
73 # __class__ is a reference to the class object in Python 3, so we can't
76 # just test for string equality.
74 # just test for string equality.
77 assert 'IPython.utils.io.IOStream' in classname, classname
75 assert 'IPython.utils.io.IOStream' in classname, classname
78
76
79 def test_capture_output():
77 def test_capture_output():
80 """capture_output() context works"""
78 """capture_output() context works"""
81
79
82 with capture_output() as io:
80 with capture_output() as io:
83 print('hi, stdout')
81 print('hi, stdout')
84 print('hi, stderr', file=sys.stderr)
82 print('hi, stderr', file=sys.stderr)
85
83
86 nt.assert_equal(io.stdout, 'hi, stdout\n')
84 nt.assert_equal(io.stdout, 'hi, stdout\n')
87 nt.assert_equal(io.stderr, 'hi, stderr\n')
85 nt.assert_equal(io.stderr, 'hi, stderr\n')
88
86
89 def test_UnicodeStdStream():
87
90 # Test wrapping a bytes-level stdout
91 if PY3:
92 stdoutb = stdlib_io.BytesIO()
93 stdout = stdlib_io.TextIOWrapper(stdoutb, encoding='ascii')
94 else:
95 stdout = stdoutb = stdlib_io.BytesIO()
96
97 orig_stdout = sys.stdout
98 sys.stdout = stdout
99 try:
100 sample = u"@Ε‚e¢ŧ←"
101 unicode_std_stream().write(sample)
102
103 output = stdoutb.getvalue().decode('utf-8')
104 nt.assert_equal(output, sample)
105 assert not stdout.closed
106 finally:
107 sys.stdout = orig_stdout
108
109 @skipif(not PY3, "Not applicable on Python 2")
110 def test_UnicodeStdStream_nowrap():
111 # If we replace stdout with a StringIO, it shouldn't get wrapped.
112 orig_stdout = sys.stdout
113 sys.stdout = StringIO()
114 try:
115 nt.assert_is(unicode_std_stream(), sys.stdout)
116 assert not sys.stdout.closed
117 finally:
118 sys.stdout = orig_stdout
119
120 def test_atomic_writing():
121 class CustomExc(Exception): pass
122
123 with TemporaryDirectory() as td:
124 f1 = os.path.join(td, 'penguin')
125 with stdlib_io.open(f1, 'w') as f:
126 f.write(u'Before')
127
128 if os.name != 'nt':
129 os.chmod(f1, 0o701)
130 orig_mode = stat.S_IMODE(os.stat(f1).st_mode)
131
132 f2 = os.path.join(td, 'flamingo')
133 try:
134 os.symlink(f1, f2)
135 have_symlink = True
136 except (AttributeError, NotImplementedError, OSError):
137 # AttributeError: Python doesn't support it
138 # NotImplementedError: The system doesn't support it
139 # OSError: The user lacks the privilege (Windows)
140 have_symlink = False
141
142 with nt.assert_raises(CustomExc):
143 with atomic_writing(f1) as f:
144 f.write(u'Failing write')
145 raise CustomExc
146
147 # Because of the exception, the file should not have been modified
148 with stdlib_io.open(f1, 'r') as f:
149 nt.assert_equal(f.read(), u'Before')
150
151 with atomic_writing(f1) as f:
152 f.write(u'Overwritten')
153
154 with stdlib_io.open(f1, 'r') as f:
155 nt.assert_equal(f.read(), u'Overwritten')
156
157 if os.name != 'nt':
158 mode = stat.S_IMODE(os.stat(f1).st_mode)
159 nt.assert_equal(mode, orig_mode)
160
161 if have_symlink:
162 # Check that writing over a file preserves a symlink
163 with atomic_writing(f2) as f:
164 f.write(u'written from symlink')
165
166 with stdlib_io.open(f1, 'r') as f:
167 nt.assert_equal(f.read(), u'written from symlink')
168
169 def _save_umask():
170 global umask
171 umask = os.umask(0)
172 os.umask(umask)
173
174 def _restore_umask():
175 os.umask(umask)
176
177 @skip_win32
178 @nt.with_setup(_save_umask, _restore_umask)
179 def test_atomic_writing_umask():
180 with TemporaryDirectory() as td:
181 os.umask(0o022)
182 f1 = os.path.join(td, '1')
183 with atomic_writing(f1) as f:
184 f.write(u'1')
185 mode = stat.S_IMODE(os.stat(f1).st_mode)
186 nt.assert_equal(mode, 0o644, '{:o} != 644'.format(mode))
187
188 os.umask(0o057)
189 f2 = os.path.join(td, '2')
190 with atomic_writing(f2) as f:
191 f.write(u'2')
192 mode = stat.S_IMODE(os.stat(f2).st_mode)
193 nt.assert_equal(mode, 0o620, '{:o} != 620'.format(mode))
194
195
196 def test_atomic_writing_newlines():
197 with TemporaryDirectory() as td:
198 path = os.path.join(td, 'testfile')
199
200 lf = u'a\nb\nc\n'
201 plat = lf.replace(u'\n', os.linesep)
202 crlf = lf.replace(u'\n', u'\r\n')
203
204 # test default
205 with stdlib_io.open(path, 'w') as f:
206 f.write(lf)
207 with stdlib_io.open(path, 'r', newline='') as f:
208 read = f.read()
209 nt.assert_equal(read, plat)
210
211 # test newline=LF
212 with stdlib_io.open(path, 'w', newline='\n') as f:
213 f.write(lf)
214 with stdlib_io.open(path, 'r', newline='') as f:
215 read = f.read()
216 nt.assert_equal(read, lf)
217
218 # test newline=CRLF
219 with atomic_writing(path, newline='\r\n') as f:
220 f.write(lf)
221 with stdlib_io.open(path, 'r', newline='') as f:
222 read = f.read()
223 nt.assert_equal(read, crlf)
224
225 # test newline=no convert
226 text = u'crlf\r\ncr\rlf\n'
227 with atomic_writing(path, newline='') as f:
228 f.write(text)
229 with stdlib_io.open(path, 'r', newline='') as f:
230 read = f.read()
231 nt.assert_equal(read, text)
@@ -1,889 +1,891 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9 """
9 """
10
10
11 # Copyright (c) IPython Development Team.
11 # Copyright (c) IPython Development Team.
12 # Distributed under the terms of the Modified BSD License.
12 # Distributed under the terms of the Modified BSD License.
13
13
14 import hashlib
14 import hashlib
15 import hmac
15 import hmac
16 import logging
16 import logging
17 import os
17 import os
18 import pprint
18 import pprint
19 import random
19 import random
20 import uuid
20 import uuid
21 import warnings
21 import warnings
22 from datetime import datetime
22 from datetime import datetime
23
23
24 try:
24 try:
25 import cPickle
25 import cPickle
26 pickle = cPickle
26 pickle = cPickle
27 except:
27 except:
28 cPickle = None
28 cPickle = None
29 import pickle
29 import pickle
30
30
31 try:
31 try:
32 # py3
32 # py3
33 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
33 PICKLE_PROTOCOL = pickle.DEFAULT_PROTOCOL
34 except AttributeError:
34 except AttributeError:
35 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
35 PICKLE_PROTOCOL = pickle.HIGHEST_PROTOCOL
36
36
37 try:
37 try:
38 # We are using compare_digest to limit the surface of timing attacks
38 # We are using compare_digest to limit the surface of timing attacks
39 from hmac import compare_digest
39 from hmac import compare_digest
40 except ImportError:
40 except ImportError:
41 # Python < 2.7.7: When digests don't match no feedback is provided,
41 # Python < 2.7.7: When digests don't match no feedback is provided,
42 # limiting the surface of attack
42 # limiting the surface of attack
43 def compare_digest(a,b): return a == b
43 def compare_digest(a,b): return a == b
44
44
45 import zmq
45 import zmq
46 from zmq.utils import jsonapi
46 from zmq.utils import jsonapi
47 from zmq.eventloop.ioloop import IOLoop
47 from zmq.eventloop.ioloop import IOLoop
48 from zmq.eventloop.zmqstream import ZMQStream
48 from zmq.eventloop.zmqstream import ZMQStream
49
49
50 from IPython.core.release import kernel_protocol_version
50 from IPython.core.release import kernel_protocol_version
51 from IPython.config.configurable import Configurable, LoggingConfigurable
51 from IPython.config.configurable import Configurable, LoggingConfigurable
52 from IPython.utils import io
53 from IPython.utils.importstring import import_item
52 from IPython.utils.importstring import import_item
54 from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
53 from jupyter_client.jsonutil import extract_dates, squash_dates, date_default
55 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
54 from IPython.utils.py3compat import (str_to_bytes, str_to_unicode, unicode_type,
56 iteritems)
55 iteritems)
57 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
56 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
58 DottedObjectName, CUnicode, Dict, Integer,
57 DottedObjectName, CUnicode, Dict, Integer,
59 TraitError,
58 TraitError,
60 )
59 )
61 from jupyter_client.adapter import adapt
60 from jupyter_client.adapter import adapt
61 from traitlets.log import get_logger
62
62
63
63 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
64 # utility functions
65 # utility functions
65 #-----------------------------------------------------------------------------
66 #-----------------------------------------------------------------------------
66
67
67 def squash_unicode(obj):
68 def squash_unicode(obj):
68 """coerce unicode back to bytestrings."""
69 """coerce unicode back to bytestrings."""
69 if isinstance(obj,dict):
70 if isinstance(obj,dict):
70 for key in obj.keys():
71 for key in obj.keys():
71 obj[key] = squash_unicode(obj[key])
72 obj[key] = squash_unicode(obj[key])
72 if isinstance(key, unicode_type):
73 if isinstance(key, unicode_type):
73 obj[squash_unicode(key)] = obj.pop(key)
74 obj[squash_unicode(key)] = obj.pop(key)
74 elif isinstance(obj, list):
75 elif isinstance(obj, list):
75 for i,v in enumerate(obj):
76 for i,v in enumerate(obj):
76 obj[i] = squash_unicode(v)
77 obj[i] = squash_unicode(v)
77 elif isinstance(obj, unicode_type):
78 elif isinstance(obj, unicode_type):
78 obj = obj.encode('utf8')
79 obj = obj.encode('utf8')
79 return obj
80 return obj
80
81
81 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
82 # globals and defaults
83 # globals and defaults
83 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
84
85
85 # default values for the thresholds:
86 # default values for the thresholds:
86 MAX_ITEMS = 64
87 MAX_ITEMS = 64
87 MAX_BYTES = 1024
88 MAX_BYTES = 1024
88
89
89 # ISO8601-ify datetime objects
90 # ISO8601-ify datetime objects
90 # allow unicode
91 # allow unicode
91 # disallow nan, because it's not actually valid JSON
92 # disallow nan, because it's not actually valid JSON
92 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
93 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,
93 ensure_ascii=False, allow_nan=False,
94 ensure_ascii=False, allow_nan=False,
94 )
95 )
95 json_unpacker = lambda s: jsonapi.loads(s)
96 json_unpacker = lambda s: jsonapi.loads(s)
96
97
97 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
98 pickle_packer = lambda o: pickle.dumps(squash_dates(o), PICKLE_PROTOCOL)
98 pickle_unpacker = pickle.loads
99 pickle_unpacker = pickle.loads
99
100
100 default_packer = json_packer
101 default_packer = json_packer
101 default_unpacker = json_unpacker
102 default_unpacker = json_unpacker
102
103
103 DELIM = b"<IDS|MSG>"
104 DELIM = b"<IDS|MSG>"
104 # singleton dummy tracker, which will always report as done
105 # singleton dummy tracker, which will always report as done
105 DONE = zmq.MessageTracker()
106 DONE = zmq.MessageTracker()
106
107
107 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
108 # Mixin tools for apps that use Sessions
109 # Mixin tools for apps that use Sessions
109 #-----------------------------------------------------------------------------
110 #-----------------------------------------------------------------------------
110
111
111 session_aliases = dict(
112 session_aliases = dict(
112 ident = 'Session.session',
113 ident = 'Session.session',
113 user = 'Session.username',
114 user = 'Session.username',
114 keyfile = 'Session.keyfile',
115 keyfile = 'Session.keyfile',
115 )
116 )
116
117
117 session_flags = {
118 session_flags = {
118 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
119 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
119 'keyfile' : '' }},
120 'keyfile' : '' }},
120 """Use HMAC digests for authentication of messages.
121 """Use HMAC digests for authentication of messages.
121 Setting this flag will generate a new UUID to use as the HMAC key.
122 Setting this flag will generate a new UUID to use as the HMAC key.
122 """),
123 """),
123 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
124 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
124 """Don't authenticate messages."""),
125 """Don't authenticate messages."""),
125 }
126 }
126
127
127 def default_secure(cfg):
128 def default_secure(cfg):
128 """Set the default behavior for a config environment to be secure.
129 """Set the default behavior for a config environment to be secure.
129
130
130 If Session.key/keyfile have not been set, set Session.key to
131 If Session.key/keyfile have not been set, set Session.key to
131 a new random UUID.
132 a new random UUID.
132 """
133 """
133 warnings.warn("default_secure is deprecated", DeprecationWarning)
134 warnings.warn("default_secure is deprecated", DeprecationWarning)
134 if 'Session' in cfg:
135 if 'Session' in cfg:
135 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
136 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
136 return
137 return
137 # key/keyfile not specified, generate new UUID:
138 # key/keyfile not specified, generate new UUID:
138 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
139 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
139
140
140
141
141 #-----------------------------------------------------------------------------
142 #-----------------------------------------------------------------------------
142 # Classes
143 # Classes
143 #-----------------------------------------------------------------------------
144 #-----------------------------------------------------------------------------
144
145
145 class SessionFactory(LoggingConfigurable):
146 class SessionFactory(LoggingConfigurable):
146 """The Base class for configurables that have a Session, Context, logger,
147 """The Base class for configurables that have a Session, Context, logger,
147 and IOLoop.
148 and IOLoop.
148 """
149 """
149
150
150 logname = Unicode('')
151 logname = Unicode('')
151 def _logname_changed(self, name, old, new):
152 def _logname_changed(self, name, old, new):
152 self.log = logging.getLogger(new)
153 self.log = logging.getLogger(new)
153
154
154 # not configurable:
155 # not configurable:
155 context = Instance('zmq.Context')
156 context = Instance('zmq.Context')
156 def _context_default(self):
157 def _context_default(self):
157 return zmq.Context.instance()
158 return zmq.Context.instance()
158
159
159 session = Instance('jupyter_client.session.Session',
160 session = Instance('jupyter_client.session.Session',
160 allow_none=True)
161 allow_none=True)
161
162
162 loop = Instance('zmq.eventloop.ioloop.IOLoop')
163 loop = Instance('zmq.eventloop.ioloop.IOLoop')
163 def _loop_default(self):
164 def _loop_default(self):
164 return IOLoop.instance()
165 return IOLoop.instance()
165
166
166 def __init__(self, **kwargs):
167 def __init__(self, **kwargs):
167 super(SessionFactory, self).__init__(**kwargs)
168 super(SessionFactory, self).__init__(**kwargs)
168
169
169 if self.session is None:
170 if self.session is None:
170 # construct the session
171 # construct the session
171 self.session = Session(**kwargs)
172 self.session = Session(**kwargs)
172
173
173
174
174 class Message(object):
175 class Message(object):
175 """A simple message object that maps dict keys to attributes.
176 """A simple message object that maps dict keys to attributes.
176
177
177 A Message can be created from a dict and a dict from a Message instance
178 A Message can be created from a dict and a dict from a Message instance
178 simply by calling dict(msg_obj)."""
179 simply by calling dict(msg_obj)."""
179
180
180 def __init__(self, msg_dict):
181 def __init__(self, msg_dict):
181 dct = self.__dict__
182 dct = self.__dict__
182 for k, v in iteritems(dict(msg_dict)):
183 for k, v in iteritems(dict(msg_dict)):
183 if isinstance(v, dict):
184 if isinstance(v, dict):
184 v = Message(v)
185 v = Message(v)
185 dct[k] = v
186 dct[k] = v
186
187
187 # Having this iterator lets dict(msg_obj) work out of the box.
188 # Having this iterator lets dict(msg_obj) work out of the box.
188 def __iter__(self):
189 def __iter__(self):
189 return iter(iteritems(self.__dict__))
190 return iter(iteritems(self.__dict__))
190
191
191 def __repr__(self):
192 def __repr__(self):
192 return repr(self.__dict__)
193 return repr(self.__dict__)
193
194
194 def __str__(self):
195 def __str__(self):
195 return pprint.pformat(self.__dict__)
196 return pprint.pformat(self.__dict__)
196
197
197 def __contains__(self, k):
198 def __contains__(self, k):
198 return k in self.__dict__
199 return k in self.__dict__
199
200
200 def __getitem__(self, k):
201 def __getitem__(self, k):
201 return self.__dict__[k]
202 return self.__dict__[k]
202
203
203
204
204 def msg_header(msg_id, msg_type, username, session):
205 def msg_header(msg_id, msg_type, username, session):
205 date = datetime.now()
206 date = datetime.now()
206 version = kernel_protocol_version
207 version = kernel_protocol_version
207 return locals()
208 return locals()
208
209
209 def extract_header(msg_or_header):
210 def extract_header(msg_or_header):
210 """Given a message or header, return the header."""
211 """Given a message or header, return the header."""
211 if not msg_or_header:
212 if not msg_or_header:
212 return {}
213 return {}
213 try:
214 try:
214 # See if msg_or_header is the entire message.
215 # See if msg_or_header is the entire message.
215 h = msg_or_header['header']
216 h = msg_or_header['header']
216 except KeyError:
217 except KeyError:
217 try:
218 try:
218 # See if msg_or_header is just the header
219 # See if msg_or_header is just the header
219 h = msg_or_header['msg_id']
220 h = msg_or_header['msg_id']
220 except KeyError:
221 except KeyError:
221 raise
222 raise
222 else:
223 else:
223 h = msg_or_header
224 h = msg_or_header
224 if not isinstance(h, dict):
225 if not isinstance(h, dict):
225 h = dict(h)
226 h = dict(h)
226 return h
227 return h
227
228
228 class Session(Configurable):
229 class Session(Configurable):
229 """Object for handling serialization and sending of messages.
230 """Object for handling serialization and sending of messages.
230
231
231 The Session object handles building messages and sending them
232 The Session object handles building messages and sending them
232 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
233 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
233 other over the network via Session objects, and only need to work with the
234 other over the network via Session objects, and only need to work with the
234 dict-based IPython message spec. The Session will handle
235 dict-based IPython message spec. The Session will handle
235 serialization/deserialization, security, and metadata.
236 serialization/deserialization, security, and metadata.
236
237
237 Sessions support configurable serialization via packer/unpacker traits,
238 Sessions support configurable serialization via packer/unpacker traits,
238 and signing with HMAC digests via the key/keyfile traits.
239 and signing with HMAC digests via the key/keyfile traits.
239
240
240 Parameters
241 Parameters
241 ----------
242 ----------
242
243
243 debug : bool
244 debug : bool
244 whether to trigger extra debugging statements
245 whether to trigger extra debugging statements
245 packer/unpacker : str : 'json', 'pickle' or import_string
246 packer/unpacker : str : 'json', 'pickle' or import_string
246 importstrings for methods to serialize message parts. If just
247 importstrings for methods to serialize message parts. If just
247 'json' or 'pickle', predefined JSON and pickle packers will be used.
248 'json' or 'pickle', predefined JSON and pickle packers will be used.
248 Otherwise, the entire importstring must be used.
249 Otherwise, the entire importstring must be used.
249
250
250 The functions must accept at least valid JSON input, and output *bytes*.
251 The functions must accept at least valid JSON input, and output *bytes*.
251
252
252 For example, to use msgpack:
253 For example, to use msgpack:
253 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
254 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
254 pack/unpack : callables
255 pack/unpack : callables
255 You can also set the pack/unpack callables for serialization directly.
256 You can also set the pack/unpack callables for serialization directly.
256 session : bytes
257 session : bytes
257 the ID of this Session object. The default is to generate a new UUID.
258 the ID of this Session object. The default is to generate a new UUID.
258 username : unicode
259 username : unicode
259 username added to message headers. The default is to ask the OS.
260 username added to message headers. The default is to ask the OS.
260 key : bytes
261 key : bytes
261 The key used to initialize an HMAC signature. If unset, messages
262 The key used to initialize an HMAC signature. If unset, messages
262 will not be signed or checked.
263 will not be signed or checked.
263 keyfile : filepath
264 keyfile : filepath
264 The file containing a key. If this is set, `key` will be initialized
265 The file containing a key. If this is set, `key` will be initialized
265 to the contents of the file.
266 to the contents of the file.
266
267
267 """
268 """
268
269
# --- serialization configuration -------------------------------------
# Traits selecting the (un)packer plus the change handlers that keep the
# actual pack/unpack callables in sync with the trait values.

debug = Bool(False, config=True, help="""Debug output in the Session""")

packer = DottedObjectName('json', config=True,
    help="""The name of the packer for serializing messages.
    Should be one of 'json', 'pickle', or an import name
    for a custom callable serializer.""")

def _packer_changed(self, name, old, new):
    """Resolve the `packer` trait name to a concrete pack callable.

    'json' / 'pickle' select the predefined packer+unpacker pairs (and
    mirror the choice onto `unpacker`); any other value is treated as an
    import string for a custom serializer.
    """
    if new.lower() == 'json':
        self.pack = json_packer
        self.unpack = json_unpacker
        self.unpacker = new
    elif new.lower() == 'pickle':
        self.pack = pickle_packer
        self.unpack = pickle_unpacker
        self.unpacker = new
    else:
        self.pack = import_item(str(new))

unpacker = DottedObjectName('json', config=True,
    help="""The name of the unpacker for unserializing messages.
    Only used with custom functions for `packer`.""")

def _unpacker_changed(self, name, old, new):
    """Resolve the `unpacker` trait name to a concrete unpack callable.

    Mirrors `_packer_changed`: the predefined names set both callables,
    a custom import string only replaces `unpack`.
    """
    if new.lower() == 'json':
        self.pack = json_packer
        self.unpack = json_unpacker
        self.packer = new
    elif new.lower() == 'pickle':
        self.pack = pickle_packer
        self.unpack = pickle_unpacker
        self.packer = new
    else:
        self.unpack = import_item(str(new))

session = CUnicode(u'', config=True,
    help="""The UUID identifying this session.""")

def _session_default(self):
    """Generate a fresh UUID for the session and cache its bytes form."""
    u = unicode_type(uuid.uuid4())
    self.bsession = u.encode('ascii')
    return u
308
309
def _session_changed(self, name, old, new):
    """Keep `bsession` (the ascii-bytes session id) in sync with `session`."""
    self.bsession = self.session.encode('ascii')
311
312
# bsession is the session as bytes (kept in sync by _session_changed)
bsession = CBytes(b'')

username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
    help="""Username for the Session. Default is your system username.""",
    config=True)

metadata = Dict({}, config=True,
    help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")

# if 0, no adapting to do.
adapt_version = Integer(0)

# message signature related traits:

key = CBytes(config=True,
    help="""execution key, for signing messages.""")

def _key_default(self):
    """Default to a random UUID-derived key so messages are signed by default."""
    return str_to_bytes(str(uuid.uuid4()))
331
332
def _key_changed(self):
    """Rebuild the HMAC authenticator whenever the signing key changes."""
    self._new_auth()
334
335
signature_scheme = Unicode('hmac-sha256', config=True,
    help="""The digest scheme used to construct the message signatures.
    Must have the form 'hmac-HASH'.""")
def _signature_scheme_changed(self, name, old, new):
    """Validate the signature scheme and refresh digest_mod + authenticator.

    The scheme must look like 'hmac-HASH', where HASH names a constructor
    in :mod:`hashlib` (e.g. 'hmac-sha256').

    Raises
    ------
    TraitError
        If the scheme does not start with 'hmac-' or names an unknown hash.
    """
    if not new.startswith('hmac-'):
        raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
    hash_name = new.split('-', 1)[1]
    try:
        self.digest_mod = getattr(hashlib, hash_name)
    except AttributeError:
        raise TraitError("hashlib has no such attribute: %s" % hash_name)
    self._new_auth()
347
348
# the hashlib constructor used for signing (default: sha256)
digest_mod = Any()
def _digest_mod_default(self):
    return hashlib.sha256

# the live HMAC object used to sign messages, or None when signing is off
auth = Instance(hmac.HMAC, allow_none=True)
353
354
def _new_auth(self):
    """(Re)build the HMAC authenticator from `key` and `digest_mod`.

    An empty key disables signing by setting `auth` to None, which makes
    `sign` return b''.
    """
    if self.key:
        self.auth = hmac.HMAC(self.key, digestmod=self.digest_mod)
    else:
        self.auth = None
359
360
# digests of previously seen messages, for replay protection
digest_history = Set()
digest_history_size = Integer(2**16, config=True,
    help="""The maximum number of digests to remember.

    The digest history will be culled when it exceeds this value.
    """
)

keyfile = Unicode('', config=True,
    help="""path to file containing execution key.""")
def _keyfile_changed(self, name, old, new):
    """Load the signing key from the new keyfile path (whitespace-stripped)."""
    with open(new, 'rb') as f:
        self.key = f.read().strip()
373
374
# for protecting against sends from forks
pid = Integer()

# serialization traits:

pack = Any(default_packer)  # the actual packer function
def _pack_changed(self, name, old, new):
    """Reject non-callable values assigned to `pack`."""
    if not callable(new):
        raise TypeError("packer must be callable, not %s" % type(new))
383
384
unpack = Any(default_unpacker)  # the actual unpacker function
def _unpack_changed(self, name, old, new):
    """Reject non-callable values assigned to `unpack`.

    Beyond being callable, the unpacker is not checked - its output is
    assumed to be pack's inverse.
    """
    if not callable(new):
        raise TypeError("unpacker must be callable, not %s" % type(new))
389
390
# thresholds:
copy_threshold = Integer(2**16, config=True,
    help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
buffer_threshold = Integer(MAX_BYTES, config=True,
    help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
item_threshold = Integer(MAX_ITEMS, config=True,
    help="""The maximum number of items for a container to be introspected for custom serialization.
    Containers larger than this are pickled outright.
    """
)
400
401
401
402
def __init__(self, **kwargs):
    """create a Session object

    Parameters
    ----------

    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        importstrings for methods to serialize message parts. If just
        'json' or 'pickle', predefined JSON and pickle packers will be used.
        Otherwise, the entire importstring must be used.

        The functions must accept at least valid JSON input, and output
        *bytes*.

        For example, to use msgpack:
        packer = 'msgpack.packb', unpacker='msgpack.unpackb'
    pack/unpack : callables
        You can also set the pack/unpack callables for serialization
        directly.
    session : unicode (must be ascii)
        the ID of this Session object. The default is to generate a new
        UUID.
    bsession : bytes
        The session as bytes
    username : unicode
        username added to message headers. The default is to ask the OS.
    key : bytes
        The key used to initialize an HMAC signature. If unset, messages
        will not be signed or checked.
    signature_scheme : str
        The message digest scheme. Currently must be of the form 'hmac-HASH',
        where 'HASH' is a hashing function available in Python's hashlib.
        The default is 'hmac-sha256'.
        This is ignored if 'key' is empty.
    keyfile : filepath
        The file containing a key. If this is set, `key` will be
        initialized to the contents of the file.
    """
    super(Session, self).__init__(**kwargs)
    self._check_packers()
    # cache a pre-packed empty dict, used when a message has no content
    self.none = self.pack({})
    # ensure self._session_default() if necessary, so bsession is defined:
    self.session
    # record the creating pid, to guard against sends from forked children
    self.pid = os.getpid()
    self._new_auth()
449
450
@property
def msg_id(self):
    """always return new uuid"""
    return str(uuid.uuid4())
454
455
def _check_packers(self):
    """check packers for datetime support.

    Verifies that pack/unpack round-trip a simple message, that pack
    produces bytes, and that datetimes are not deserialized back into
    datetime objects; if the datetime check fails, the packers are
    wrapped with date-squashing helpers.

    Raises
    ------
    ValueError
        If the packers cannot serialize a simple message, do not produce
        bytes, or are not inverses of each other.
    """
    pack = self.pack
    unpack = self.unpack

    # check simple serialization
    msg = dict(a=[1, 'hi'])
    try:
        packed = pack(msg)
    except Exception as e:
        msg = "packer '{packer}' could not serialize a simple message: {e}{jsonmsg}"
        if self.packer == 'json':
            # include which json implementation zmq picked, to aid debugging
            jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
        else:
            jsonmsg = ""
        raise ValueError(
            msg.format(packer=self.packer, e=e, jsonmsg=jsonmsg)
        )

    # ensure packed message is bytes
    if not isinstance(packed, bytes):
        raise ValueError("message packed to %r, but bytes are required" % type(packed))

    # check that unpack is pack's inverse
    try:
        unpacked = unpack(packed)
        assert unpacked == msg
    except Exception as e:
        msg = "unpacker '{unpacker}' could not handle output from packer '{packer}': {e}{jsonmsg}"
        if self.packer == 'json':
            jsonmsg = "\nzmq.utils.jsonapi.jsonmod = %s" % jsonapi.jsonmod
        else:
            jsonmsg = ""
        raise ValueError(
            msg.format(packer=self.packer, unpacker=self.unpacker, e=e, jsonmsg=jsonmsg)
        )

    # check datetime support
    msg = dict(t=datetime.now())
    try:
        unpacked = unpack(pack(msg))
        if isinstance(unpacked['t'], datetime):
            raise ValueError("Shouldn't deserialize to datetime")
    except Exception:
        # datetimes not handled natively: wrap pack to squash dates first
        self.pack = lambda o: pack(squash_dates(o))
        self.unpack = lambda s: unpack(s)
501
502
def msg_header(self, msg_type):
    """Create a new header dict (delegates to the module-level msg_header helper)."""
    return msg_header(self.msg_id, msg_type, self.username, self.session)
504
505
def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
    """Return the nested message dict.

    This format is different from what is sent over the wire. The
    serialize/deserialize methods converts this nested message dict to the wire
    format, which is a list of message parts.
    """
    msg = {}
    header = self.msg_header(msg_type) if header is None else header
    msg['header'] = header
    msg['msg_id'] = header['msg_id']
    msg['msg_type'] = header['msg_type']
    msg['parent_header'] = {} if parent is None else extract_header(parent)
    msg['content'] = {} if content is None else content
    # start from the session-wide metadata defaults, then overlay per-message
    msg['metadata'] = self.metadata.copy()
    if metadata is not None:
        msg['metadata'].update(metadata)
    return msg
523
524
def sign(self, msg_list):
    """Sign a message with HMAC digest. If no auth, return b''.

    Parameters
    ----------
    msg_list : list
        The [p_header,p_parent,p_content] part of the message list.
    """
    if self.auth is None:
        return b''
    # work on a copy so the stored HMAC state is never mutated
    h = self.auth.copy()
    for m in msg_list:
        h.update(m)
    return str_to_bytes(h.hexdigest())
538
539
def serialize(self, msg, ident=None):
    """Serialize the message components to bytes.

    This is roughly the inverse of deserialize. The serialize/deserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg : dict or Message
        The next message dict as returned by the self.msg method.

    Returns
    -------
    msg_list : list
        The list of bytes objects to be sent with the format::

            [ident1, ident2, ..., DELIM, HMAC, p_header, p_parent,
             p_metadata, p_content, buffer1, buffer2, ...]

        In this list, the ``p_*`` entities are the packed or serialized
        versions, so if JSON is used, these are utf8 encoded JSON strings.
    """
    content = msg.get('content', {})
    if content is None:
        # use the cached pre-packed empty dict
        content = self.none
    elif isinstance(content, dict):
        content = self.pack(content)
    elif isinstance(content, bytes):
        # content is already packed, as in a relayed message
        pass
    elif isinstance(content, unicode_type):
        # should be bytes, but JSON often spits out unicode
        content = content.encode('utf8')
    else:
        raise TypeError("Content incorrect type: %s" % type(content))

    real_message = [self.pack(msg['header']),
                    self.pack(msg['parent_header']),
                    self.pack(msg['metadata']),
                    content,
                    ]

    to_send = []

    if isinstance(ident, list):
        # accept list of idents
        to_send.extend(ident)
    elif ident is not None:
        to_send.append(ident)
    to_send.append(DELIM)

    # signature covers only the four real message parts, not idents
    signature = self.sign(real_message)
    to_send.append(signature)

    to_send.extend(real_message)

    return to_send
597
598
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
         buffers=None, track=False, header=None, metadata=None):
    """Build and send a message via stream or socket.

    The message format used by this function internally is as follows:

    [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
     buffer1,buffer2,...]

    The serialize/deserialize methods convert the nested message dict into this
    format.

    Parameters
    ----------

    stream : zmq.Socket or ZMQStream
        The socket-like object used to send the data.
    msg_or_type : str or Message/dict
        Normally, msg_or_type will be a msg_type unless a message is being
        sent more than once. If a header is supplied, this can be set to
        None and the msg_type will be pulled from the header.

    content : dict or None
        The content of the message (ignored if msg_or_type is a message).
    header : dict or None
        The header dict for the message (ignored if msg_to_type is a message).
    parent : Message or dict or None
        The parent or parent header describing the parent of this message
        (ignored if msg_or_type is a message).
    ident : bytes or list of bytes
        The zmq.IDENTITY routing path.
    metadata : dict or None
        The metadata describing the message
    buffers : list or None
        The already-serialized buffers to be appended to the message.
    track : bool
        Whether to track. Only for use with Sockets, because ZMQStream
        objects cannot track messages.


    Returns
    -------
    msg : dict
        The constructed message.
    """
    if not isinstance(stream, zmq.Socket):
        # ZMQStreams and dummy sockets do not support tracking.
        track = False

    if isinstance(msg_or_type, (Message, dict)):
        # We got a Message or message dict, not a msg_type so don't
        # build a new Message.
        msg = msg_or_type
        buffers = buffers or msg.get('buffers', [])
    else:
        msg = self.msg(msg_or_type, content=content, parent=parent,
                       header=header, metadata=metadata)
    if not os.getpid() == self.pid:
        # sending from a fork would corrupt the parent's zmq state;
        # NOTE: Logger.warn is a deprecated alias of Logger.warning
        get_logger().warning("WARNING: attempted to send message from fork\n%s",
            msg
        )
        return
    buffers = [] if buffers is None else buffers
    if self.adapt_version:
        msg = adapt(msg, self.adapt_version)
    to_send = self.serialize(msg, ident)
    to_send.extend(buffers)
    longest = max([ len(s) for s in to_send ])
    copy = (longest < self.copy_threshold)

    if buffers and track and not copy:
        # only really track when we are doing zero-copy buffers
        tracker = stream.send_multipart(to_send, copy=False, track=True)
    else:
        # use dummy tracker, which will be done immediately
        tracker = DONE
        stream.send_multipart(to_send, copy=copy)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
683
685
def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
    """Send a raw message via ident path.

    This method is used to send a already serialized message.

    Parameters
    ----------
    stream : ZMQStream or Socket
        The ZMQ stream or socket to use for sending the message.
    msg_list : list
        The serialized list of messages to send. This only includes the
        [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
        the message.
    ident : ident or list
        A single ident or a list of idents to use in sending.
    """
    to_send = []
    if isinstance(ident, bytes):
        # normalize a single ident to a list
        ident = [ident]
    if ident is not None:
        to_send.extend(ident)

    to_send.append(DELIM)
    # sign the already-serialized parts, then append them
    to_send.append(self.sign(msg_list))
    to_send.extend(msg_list)
    stream.send_multipart(to_send, flags, copy=copy)
710
712
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Parameters
    ----------
    socket : ZMQStream or Socket
        The socket or stream to use in receiving.

    Returns
    -------
    [idents], msg
        [idents] is a list of idents and msg is a nested message dict of
        same format as self.msg returns.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg_list = socket.recv_multipart(mode, copy=copy)
    except zmq.ZMQError as e:
        if e.errno == zmq.EAGAIN:
            # We can convert EAGAIN to None as we know in this case
            # recv_multipart won't return None.
            return None, None
        else:
            # bare raise preserves the original traceback
            raise
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg_list = self.feed_identities(msg_list, copy)
    try:
        return idents, self.deserialize(msg_list, content=content, copy=copy)
    except Exception as e:
        # TODO: handle it
        raise e
744
746
def feed_identities(self, msg_list, copy=True):
    """Split the identities from the rest of the message.

    Feed until DELIM is reached, then return the prefix as idents and
    remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
    but that would be silly.

    Parameters
    ----------
    msg_list : a list of Message or bytes objects
        The message to be split.
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg_list) : two lists
        idents will always be a list of bytes, each of which is a ZMQ
        identity. msg_list will be a list of bytes or zmq.Messages of the
        form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
        should be unpackable/unserializable via self.deserialize at this
        point.
    """
    if copy:
        # frames are plain bytes: list.index finds the delimiter directly
        idx = msg_list.index(DELIM)
        return msg_list[:idx], msg_list[idx+1:]
    else:
        # frames are zmq.Message objects: compare each frame's .bytes payload
        failed = True
        for idx, m in enumerate(msg_list):
            if m.bytes == DELIM:
                failed = False
                break
        if failed:
            raise ValueError("DELIM not in msg_list")
        idents, msg_list = msg_list[:idx], msg_list[idx+1:]
        return [m.bytes for m in idents], msg_list
781
783
782 def _add_digest(self, signature):
784 def _add_digest(self, signature):
783 """add a digest to history to protect against replay attacks"""
785 """add a digest to history to protect against replay attacks"""
784 if self.digest_history_size == 0:
786 if self.digest_history_size == 0:
785 # no history, never add digests
787 # no history, never add digests
786 return
788 return
787
789
788 self.digest_history.add(signature)
790 self.digest_history.add(signature)
789 if len(self.digest_history) > self.digest_history_size:
791 if len(self.digest_history) > self.digest_history_size:
790 # threshold reached, cull 10%
792 # threshold reached, cull 10%
791 self._cull_digest_history()
793 self._cull_digest_history()
792
794
793 def _cull_digest_history(self):
795 def _cull_digest_history(self):
794 """cull the digest history
796 """cull the digest history
795
797
796 Removes a randomly selected 10% of the digest history
798 Removes a randomly selected 10% of the digest history
797 """
799 """
798 current = len(self.digest_history)
800 current = len(self.digest_history)
799 n_to_cull = max(int(current // 10), current - self.digest_history_size)
801 n_to_cull = max(int(current // 10), current - self.digest_history_size)
800 if n_to_cull >= current:
802 if n_to_cull >= current:
801 self.digest_history = set()
803 self.digest_history = set()
802 return
804 return
803 to_cull = random.sample(self.digest_history, n_to_cull)
805 to_cull = random.sample(self.digest_history, n_to_cull)
804 self.digest_history.difference_update(to_cull)
806 self.digest_history.difference_update(to_cull)
805
807
def deserialize(self, msg_list, content=True, copy=True):
    """Unserialize a msg_list to a nested message dict.

    This is roughly the inverse of serialize. The serialize/deserialize
    methods work with full message lists, whereas pack/unpack work with
    the individual message parts in the message list.

    Parameters
    ----------
    msg_list : list of bytes or Message objects
        The list of message parts of the form [HMAC,p_header,p_parent,
        p_metadata,p_content,buffer1,buffer2,...].
    content : bool (True)
        Whether to unpack the content dict (True), or leave it packed
        (False).
    copy : bool (True)
        Whether msg_list contains bytes (True) or the non-copying Message
        objects in each place (False).

    Returns
    -------
    msg : dict
        The nested message dict with top-level keys [header, parent_header,
        content, buffers]. The buffers are returned as memoryviews.

    Raises
    ------
    ValueError
        If the message is unsigned, replayed (duplicate signature), or its
        HMAC signature does not verify.
    TypeError
        If msg_list has fewer than five elements.
    """
    minlen = 5
    message = {}
    if not copy:
        # pyzmq didn't copy the first parts of the message, so we'll do it
        # here (the first `minlen` frames are needed as bytes for signing
        # and unpacking below)
        for i in range(minlen):
            msg_list[i] = msg_list[i].bytes
    if self.auth is not None:
        signature = msg_list[0]
        if not signature:
            raise ValueError("Unsigned Message")
        if signature in self.digest_history:
            # replay protection: reject a signature we have already seen
            raise ValueError("Duplicate Signature: %r" % signature)
        self._add_digest(signature)
        # recompute the HMAC over header/parent/metadata/content and
        # compare in constant time (timing-attack safe)
        check = self.sign(msg_list[1:5])
        if not compare_digest(signature, check):
            raise ValueError("Invalid Signature: %r" % signature)
    if not len(msg_list) >= minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    header = self.unpack(msg_list[1])
    message['header'] = extract_dates(header)
    message['msg_id'] = header['msg_id']
    message['msg_type'] = header['msg_type']
    message['parent_header'] = extract_dates(self.unpack(msg_list[2]))
    message['metadata'] = self.unpack(msg_list[3])
    if content:
        message['content'] = self.unpack(msg_list[4])
    else:
        # leave content packed (bytes) for the caller to unpack later
        message['content'] = msg_list[4]
    buffers = [memoryview(b) for b in msg_list[5:]]
    if buffers and buffers[0].shape is None:
        # force copy to workaround pyzmq #646
        buffers = [memoryview(b.bytes) for b in msg_list[5:]]
    message['buffers'] = buffers
    # adapt to the current version
    return adapt(message)
866
868
def unserialize(self, *args, **kwargs):
    """Deprecated alias for :meth:`deserialize`."""
    message = "Session.unserialize is deprecated. Use Session.deserialize."
    warnings.warn(message, DeprecationWarning)
    return self.deserialize(*args, **kwargs)
873
875
874
876
def test_msg2obj():
    """Message should mirror dict content via attribute and item access."""
    source = dict(x=1)
    msg = Message(source)
    assert msg.x == source['x']

    # Nested dicts become nested attribute access.
    source['y'] = dict(z=1)
    msg = Message(source)
    assert msg.y.z == source['y']['z']

    # Item access works with dynamic keys too.
    outer, inner = 'y', 'z'
    assert msg[outer][inner] == source[outer][inner]

    # Round-trip back to a plain dict preserves values.
    roundtrip = dict(msg)
    assert roundtrip['x'] == source['x']
    assert roundtrip['y']['z'] == source['y']['z']
@@ -1,141 +1,128 b''
1 """Utility for calling pandoc"""
1 """Utility for calling pandoc"""
2 #-----------------------------------------------------------------------------
2 # Copyright (c) IPython Development Team.
3 # Copyright (c) 2014 the IPython Development Team.
4 #
5 # Distributed under the terms of the Modified BSD License.
3 # Distributed under the terms of the Modified BSD License.
6 #
7 # The full license is in the file COPYING.txt, distributed with this software.
8 #-----------------------------------------------------------------------------
9
4
10 #-----------------------------------------------------------------------------
5 from __future__ import print_function, absolute_import
11 # Imports
12 #-----------------------------------------------------------------------------
13 from __future__ import print_function
14
6
15 # Stdlib imports
16 import subprocess
7 import subprocess
17 import warnings
8 import warnings
18 import re
9 import re
19 from io import TextIOWrapper, BytesIO
10 from io import TextIOWrapper, BytesIO
20
11
21 # IPython imports
22 from IPython.utils.py3compat import cast_bytes
12 from IPython.utils.py3compat import cast_bytes
23 from IPython.utils.version import check_version
13 from IPython.utils.version import check_version
24 from IPython.utils.process import is_cmd_found, FindCmdError
14 from IPython.utils.process import is_cmd_found, FindCmdError
25
15
26 from .exceptions import ConversionException
16 from .exceptions import ConversionException
27
17
28 #-----------------------------------------------------------------------------
29 # Classes and functions
30 #-----------------------------------------------------------------------------
31 _minimal_version = "1.12.1"
18 _minimal_version = "1.12.1"
32
19
def pandoc(source, fmt, to, extra_args=None, encoding='utf-8'):
    """Convert an input string in format `from` to format `to` via pandoc.

    Parameters
    ----------
    source : string
        Input string, assumed to be valid format `from`.
    fmt : string
        The name of the input format (markdown, etc.)
    to : string
        The name of the output format (html, etc.)

    Returns
    -------
    out : unicode
        Output as returned by pandoc.

    Raises
    ------
    PandocMissing
        If pandoc is not installed.

    Any error messages generated by pandoc are printed to stderr.
    """
    # Raises (PandocMissing) or warns if the pandoc install is unusable,
    # popping us out of here before we spawn a subprocess.
    check_pandoc_version()

    command = ['pandoc', '-f', fmt, '-t', to]
    if extra_args:
        command.extend(extra_args)

    proc = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
    raw, _ = proc.communicate(cast_bytes(source, encoding))
    decoded = TextIOWrapper(BytesIO(raw), encoding, 'replace').read()
    return decoded.rstrip('\n')
70
57
71
58
def get_pandoc_version():
    """Gets the Pandoc version if Pandoc is installed.

    Probes ``pandoc -v`` on the first call, caches the result, and returns
    the cached value afterwards (until :func:`clean_cache` is called).

    Raises
    ------
    PandocMissing
        If pandoc is unavailable.
    """
    global __version

    if __version is None:
        if not is_cmd_found('pandoc'):
            raise PandocMissing()

        output = subprocess.check_output(['pandoc', '-v'],
                                         universal_newlines=True)
        first_line = output.splitlines()[0]
        # The version is the first whitespace-separated token on the first
        # line that looks like a dotted number, e.g. "1.12.1".
        pattern = re.compile(r"^\d+(\.\d+){1,}$")
        for token in first_line.split():
            if pattern.match(token):
                __version = token
                break
    return __version
99
86
100
87
def check_pandoc_version():
    """Returns True if minimal pandoc version is met.

    Emits a RuntimeWarning (and returns False) when the version cannot be
    determined or is older than ``_minimal_version``.

    Raises
    ------
    PandocMissing
        If pandoc is unavailable.
    """
    v = get_pandoc_version()
    if v is None:
        # BUGFIX: the original concatenated "...include the" + "output..."
        # without a space, rendering as "include theoutput".
        warnings.warn("Sorry, we cannot determine the version of pandoc.\n"
                      "Please consider reporting this issue and include the "
                      "output of pandoc --version.\nContinuing...",
                      RuntimeWarning, stacklevel=2)
        return False
    ok = check_version(v, _minimal_version)
    if not ok:
        # BUGFIX: insert the missing separator between "Try updating." and
        # the URL (previously rendered as "Try updating.http://...").
        warnings.warn("You are using an old version of pandoc (%s)\n" % v +
                      "Recommended version is %s.\nTry updating. " % _minimal_version +
                      "http://johnmacfarlane.net/pandoc/installing.html.\nContinuing with doubts...",
                      RuntimeWarning, stacklevel=2)
    return ok
123
110
124 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
125 # Exception handling
112 # Exception handling
126 #-----------------------------------------------------------------------------
113 #-----------------------------------------------------------------------------
class PandocMissing(ConversionException):
    """Exception raised when Pandoc is missing."""

    def __init__(self, *args, **kwargs):
        # Build the user-facing hint once; the text must stay identical to
        # what callers and tests expect.
        message = ("Pandoc wasn't found.\n" +
                   "Please check that pandoc is installed:\n" +
                   "http://johnmacfarlane.net/pandoc/installing.html")
        super(PandocMissing, self).__init__(message)
134 #-----------------------------------------------------------------------------
121 #-----------------------------------------------------------------------------
135 # Internal state management
122 # Internal state management
136 #-----------------------------------------------------------------------------
123 #-----------------------------------------------------------------------------
def clean_cache():
    """Forget the cached pandoc version so the next call probes again."""
    global __version
    __version = None


# Module-level cache for the detected pandoc version (None = not probed yet).
__version = None
@@ -1,34 +1,23 b''
1 """
1 """
2 Contains Stdout writer
2 Contains Stdout writer
3 """
3 """
4 #-----------------------------------------------------------------------------
5 #Copyright (c) 2013, the IPython Development Team.
6 #
7 #Distributed under the terms of the Modified BSD License.
8 #
9 #The full license is in the file COPYING.txt, distributed with this software.
10 #-----------------------------------------------------------------------------
11
4
12 #-----------------------------------------------------------------------------
5 # Copyright (c) Jupyter Development Team.
13 # Imports
6 # Distributed under the terms of the Modified BSD License.
14 #-----------------------------------------------------------------------------
15
7
16 from IPython.utils import io
8 from jupyter_nbconvert.utils import io
17 from .base import WriterBase
9 from .base import WriterBase
18
10
19 #-----------------------------------------------------------------------------
20 # Classes
21 #-----------------------------------------------------------------------------
22
11
class StdoutWriter(WriterBase):
    """Consumes output from nbconvert export...() methods and writes to the
    stdout stream."""

    def write(self, output, resources, **kw):
        """
        Consume and write Jinja output.

        See base for more...
        """
        # Write through the unicode-safe stdout wrapper so encoding is
        # handled consistently on Python 2 and 3.
        stream = io.unicode_std_stream()
        stream.write(output)
General Comments 0
You need to be logged in to leave comments. Login now