##// END OF EJS Templates
Merge pull request #10865 from hugovk/rm-python2...
Thomas Kluyver -
r24018:5b2b7dd0 merge
parent child Browse files
Show More
@@ -0,0 +1,19
1 # Top-most EditorConfig file
2 root = true
3
4 [*]
5 # Unix-style newlines with a newline ending every file
6 end_of_line = lf
7 insert_final_newline = true
8 charset = utf-8
9
10 # Four-space indentation
11 indent_size = 4
12 indent_style = space
13
14 trim_trailing_whitespace = false
15
16 [*.yml]
17 # Two-space indentation
18 indent_size = 2
19 indent_style = space
@@ -1,25 +1,26
1 MANIFEST
1 MANIFEST
2 build
2 build
3 dist
3 dist
4 _build
4 _build
5 docs/man/*.gz
5 docs/man/*.gz
6 docs/source/api/generated
6 docs/source/api/generated
7 docs/source/config/options
7 docs/source/config/options
8 docs/source/config/shortcuts/*.csv
8 docs/source/config/shortcuts/*.csv
9 docs/source/interactive/magics-generated.txt
9 docs/source/interactive/magics-generated.txt
10 docs/source/config/shortcuts/*.csv
10 docs/source/config/shortcuts/*.csv
11 docs/gh-pages
11 docs/gh-pages
12 jupyter_notebook/notebook/static/mathjax
12 jupyter_notebook/notebook/static/mathjax
13 jupyter_notebook/static/style/*.map
13 jupyter_notebook/static/style/*.map
14 *.py[co]
14 *.py[co]
15 __pycache__
15 __pycache__
16 *.egg-info
16 *.egg-info
17 *~
17 *~
18 *.bak
18 *.bak
19 .ipynb_checkpoints
19 .ipynb_checkpoints
20 .tox
20 .tox
21 .DS_Store
21 .DS_Store
22 \#*#
22 \#*#
23 .#*
23 .#*
24 .cache
24 .coverage
25 .coverage
25 *.swp
26 *.swp
@@ -1,279 +1,279
1 """Magic functions for running cells in various scripts."""
1 """Magic functions for running cells in various scripts."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import errno
6 import errno
7 import os
7 import os
8 import sys
8 import sys
9 import signal
9 import signal
10 import time
10 import time
11 from subprocess import Popen, PIPE
11 from subprocess import Popen, PIPE
12 import atexit
12 import atexit
13
13
14 from IPython.core import magic_arguments
14 from IPython.core import magic_arguments
15 from IPython.core.magic import (
15 from IPython.core.magic import (
16 Magics, magics_class, line_magic, cell_magic
16 Magics, magics_class, line_magic, cell_magic
17 )
17 )
18 from IPython.lib.backgroundjobs import BackgroundJobManager
18 from IPython.lib.backgroundjobs import BackgroundJobManager
19 from IPython.utils import py3compat
19 from IPython.utils import py3compat
20 from IPython.utils.process import arg_split
20 from IPython.utils.process import arg_split
21 from traitlets import List, Dict, default
21 from traitlets import List, Dict, default
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Magic implementation classes
24 # Magic implementation classes
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 def script_args(f):
27 def script_args(f):
28 """single decorator for adding script args"""
28 """single decorator for adding script args"""
29 args = [
29 args = [
30 magic_arguments.argument(
30 magic_arguments.argument(
31 '--out', type=str,
31 '--out', type=str,
32 help="""The variable in which to store stdout from the script.
32 help="""The variable in which to store stdout from the script.
33 If the script is backgrounded, this will be the stdout *pipe*,
33 If the script is backgrounded, this will be the stdout *pipe*,
34 instead of the stderr text itself.
34 instead of the stderr text itself.
35 """
35 """
36 ),
36 ),
37 magic_arguments.argument(
37 magic_arguments.argument(
38 '--err', type=str,
38 '--err', type=str,
39 help="""The variable in which to store stderr from the script.
39 help="""The variable in which to store stderr from the script.
40 If the script is backgrounded, this will be the stderr *pipe*,
40 If the script is backgrounded, this will be the stderr *pipe*,
41 instead of the stderr text itself.
41 instead of the stderr text itself.
42 """
42 """
43 ),
43 ),
44 magic_arguments.argument(
44 magic_arguments.argument(
45 '--bg', action="store_true",
45 '--bg', action="store_true",
46 help="""Whether to run the script in the background.
46 help="""Whether to run the script in the background.
47 If given, the only way to see the output of the command is
47 If given, the only way to see the output of the command is
48 with --out/err.
48 with --out/err.
49 """
49 """
50 ),
50 ),
51 magic_arguments.argument(
51 magic_arguments.argument(
52 '--proc', type=str,
52 '--proc', type=str,
53 help="""The variable in which to store Popen instance.
53 help="""The variable in which to store Popen instance.
54 This is used only when --bg option is given.
54 This is used only when --bg option is given.
55 """
55 """
56 ),
56 ),
57 ]
57 ]
58 for arg in args:
58 for arg in args:
59 f = arg(f)
59 f = arg(f)
60 return f
60 return f
61
61
62 @magics_class
62 @magics_class
63 class ScriptMagics(Magics):
63 class ScriptMagics(Magics):
64 """Magics for talking to scripts
64 """Magics for talking to scripts
65
65
66 This defines a base `%%script` cell magic for running a cell
66 This defines a base `%%script` cell magic for running a cell
67 with a program in a subprocess, and registers a few top-level
67 with a program in a subprocess, and registers a few top-level
68 magics that call %%script with common interpreters.
68 magics that call %%script with common interpreters.
69 """
69 """
70 script_magics = List(
70 script_magics = List(
71 help="""Extra script cell magics to define
71 help="""Extra script cell magics to define
72
72
73 This generates simple wrappers of `%%script foo` as `%%foo`.
73 This generates simple wrappers of `%%script foo` as `%%foo`.
74
74
75 If you want to add script magics that aren't on your path,
75 If you want to add script magics that aren't on your path,
76 specify them in script_paths
76 specify them in script_paths
77 """,
77 """,
78 ).tag(config=True)
78 ).tag(config=True)
79 @default('script_magics')
79 @default('script_magics')
80 def _script_magics_default(self):
80 def _script_magics_default(self):
81 """default to a common list of programs"""
81 """default to a common list of programs"""
82
82
83 defaults = [
83 defaults = [
84 'sh',
84 'sh',
85 'bash',
85 'bash',
86 'perl',
86 'perl',
87 'ruby',
87 'ruby',
88 'python',
88 'python',
89 'python2',
89 'python2',
90 'python3',
90 'python3',
91 'pypy',
91 'pypy',
92 ]
92 ]
93 if os.name == 'nt':
93 if os.name == 'nt':
94 defaults.extend([
94 defaults.extend([
95 'cmd',
95 'cmd',
96 ])
96 ])
97
97
98 return defaults
98 return defaults
99
99
100 script_paths = Dict(
100 script_paths = Dict(
101 help="""Dict mapping short 'ruby' names to full paths, such as '/opt/secret/bin/ruby'
101 help="""Dict mapping short 'ruby' names to full paths, such as '/opt/secret/bin/ruby'
102
102
103 Only necessary for items in script_magics where the default path will not
103 Only necessary for items in script_magics where the default path will not
104 find the right interpreter.
104 find the right interpreter.
105 """
105 """
106 ).tag(config=True)
106 ).tag(config=True)
107
107
108 def __init__(self, shell=None):
108 def __init__(self, shell=None):
109 super(ScriptMagics, self).__init__(shell=shell)
109 super(ScriptMagics, self).__init__(shell=shell)
110 self._generate_script_magics()
110 self._generate_script_magics()
111 self.job_manager = BackgroundJobManager()
111 self.job_manager = BackgroundJobManager()
112 self.bg_processes = []
112 self.bg_processes = []
113 atexit.register(self.kill_bg_processes)
113 atexit.register(self.kill_bg_processes)
114
114
115 def __del__(self):
115 def __del__(self):
116 self.kill_bg_processes()
116 self.kill_bg_processes()
117
117
118 def _generate_script_magics(self):
118 def _generate_script_magics(self):
119 cell_magics = self.magics['cell']
119 cell_magics = self.magics['cell']
120 for name in self.script_magics:
120 for name in self.script_magics:
121 cell_magics[name] = self._make_script_magic(name)
121 cell_magics[name] = self._make_script_magic(name)
122
122
123 def _make_script_magic(self, name):
123 def _make_script_magic(self, name):
124 """make a named magic, that calls %%script with a particular program"""
124 """make a named magic, that calls %%script with a particular program"""
125 # expand to explicit path if necessary:
125 # expand to explicit path if necessary:
126 script = self.script_paths.get(name, name)
126 script = self.script_paths.get(name, name)
127
127
128 @magic_arguments.magic_arguments()
128 @magic_arguments.magic_arguments()
129 @script_args
129 @script_args
130 def named_script_magic(line, cell):
130 def named_script_magic(line, cell):
131 # if line, add it as cl-flags
131 # if line, add it as cl-flags
132 if line:
132 if line:
133 line = "%s %s" % (script, line)
133 line = "%s %s" % (script, line)
134 else:
134 else:
135 line = script
135 line = script
136 return self.shebang(line, cell)
136 return self.shebang(line, cell)
137
137
138 # write a basic docstring:
138 # write a basic docstring:
139 named_script_magic.__doc__ = \
139 named_script_magic.__doc__ = \
140 """%%{name} script magic
140 """%%{name} script magic
141
141
142 Run cells with {script} in a subprocess.
142 Run cells with {script} in a subprocess.
143
143
144 This is a shortcut for `%%script {script}`
144 This is a shortcut for `%%script {script}`
145 """.format(**locals())
145 """.format(**locals())
146
146
147 return named_script_magic
147 return named_script_magic
148
148
149 @magic_arguments.magic_arguments()
149 @magic_arguments.magic_arguments()
150 @script_args
150 @script_args
151 @cell_magic("script")
151 @cell_magic("script")
152 def shebang(self, line, cell):
152 def shebang(self, line, cell):
153 """Run a cell via a shell command
153 """Run a cell via a shell command
154
154
155 The `%%script` line is like the #! line of script,
155 The `%%script` line is like the #! line of script,
156 specifying a program (bash, perl, ruby, etc.) with which to run.
156 specifying a program (bash, perl, ruby, etc.) with which to run.
157
157
158 The rest of the cell is run by that program.
158 The rest of the cell is run by that program.
159
159
160 Examples
160 Examples
161 --------
161 --------
162 ::
162 ::
163
163
164 In [1]: %%script bash
164 In [1]: %%script bash
165 ...: for i in 1 2 3; do
165 ...: for i in 1 2 3; do
166 ...: echo $i
166 ...: echo $i
167 ...: done
167 ...: done
168 1
168 1
169 2
169 2
170 3
170 3
171 """
171 """
172 argv = arg_split(line, posix = not sys.platform.startswith('win'))
172 argv = arg_split(line, posix = not sys.platform.startswith('win'))
173 args, cmd = self.shebang.parser.parse_known_args(argv)
173 args, cmd = self.shebang.parser.parse_known_args(argv)
174
174
175 try:
175 try:
176 p = Popen(cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE)
176 p = Popen(cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE)
177 except OSError as e:
177 except OSError as e:
178 if e.errno == errno.ENOENT:
178 if e.errno == errno.ENOENT:
179 print("Couldn't find program: %r" % cmd[0])
179 print("Couldn't find program: %r" % cmd[0])
180 return
180 return
181 else:
181 else:
182 raise
182 raise
183
183
184 if not cell.endswith('\n'):
184 if not cell.endswith('\n'):
185 cell += '\n'
185 cell += '\n'
186 cell = cell.encode('utf8', 'replace')
186 cell = cell.encode('utf8', 'replace')
187 if args.bg:
187 if args.bg:
188 self.bg_processes.append(p)
188 self.bg_processes.append(p)
189 self._gc_bg_processes()
189 self._gc_bg_processes()
190 if args.out:
190 if args.out:
191 self.shell.user_ns[args.out] = p.stdout
191 self.shell.user_ns[args.out] = p.stdout
192 if args.err:
192 if args.err:
193 self.shell.user_ns[args.err] = p.stderr
193 self.shell.user_ns[args.err] = p.stderr
194 self.job_manager.new(self._run_script, p, cell, daemon=True)
194 self.job_manager.new(self._run_script, p, cell, daemon=True)
195 if args.proc:
195 if args.proc:
196 self.shell.user_ns[args.proc] = p
196 self.shell.user_ns[args.proc] = p
197 return
197 return
198
198
199 try:
199 try:
200 out, err = p.communicate(cell)
200 out, err = p.communicate(cell)
201 except KeyboardInterrupt:
201 except KeyboardInterrupt:
202 try:
202 try:
203 p.send_signal(signal.SIGINT)
203 p.send_signal(signal.SIGINT)
204 time.sleep(0.1)
204 time.sleep(0.1)
205 if p.poll() is not None:
205 if p.poll() is not None:
206 print("Process is interrupted.")
206 print("Process is interrupted.")
207 return
207 return
208 p.terminate()
208 p.terminate()
209 time.sleep(0.1)
209 time.sleep(0.1)
210 if p.poll() is not None:
210 if p.poll() is not None:
211 print("Process is terminated.")
211 print("Process is terminated.")
212 return
212 return
213 p.kill()
213 p.kill()
214 print("Process is killed.")
214 print("Process is killed.")
215 except OSError:
215 except OSError:
216 pass
216 pass
217 except Exception as e:
217 except Exception as e:
218 print("Error while terminating subprocess (pid=%i): %s" \
218 print("Error while terminating subprocess (pid=%i): %s" \
219 % (p.pid, e))
219 % (p.pid, e))
220 return
220 return
221 out = py3compat.bytes_to_str(out)
221 out = py3compat.decode(out)
222 err = py3compat.bytes_to_str(err)
222 err = py3compat.decode(err)
223 if args.out:
223 if args.out:
224 self.shell.user_ns[args.out] = out
224 self.shell.user_ns[args.out] = out
225 else:
225 else:
226 sys.stdout.write(out)
226 sys.stdout.write(out)
227 sys.stdout.flush()
227 sys.stdout.flush()
228 if args.err:
228 if args.err:
229 self.shell.user_ns[args.err] = err
229 self.shell.user_ns[args.err] = err
230 else:
230 else:
231 sys.stderr.write(err)
231 sys.stderr.write(err)
232 sys.stderr.flush()
232 sys.stderr.flush()
233
233
234 def _run_script(self, p, cell):
234 def _run_script(self, p, cell):
235 """callback for running the script in the background"""
235 """callback for running the script in the background"""
236 p.stdin.write(cell)
236 p.stdin.write(cell)
237 p.stdin.close()
237 p.stdin.close()
238 p.wait()
238 p.wait()
239
239
240 @line_magic("killbgscripts")
240 @line_magic("killbgscripts")
241 def killbgscripts(self, _nouse_=''):
241 def killbgscripts(self, _nouse_=''):
242 """Kill all BG processes started by %%script and its family."""
242 """Kill all BG processes started by %%script and its family."""
243 self.kill_bg_processes()
243 self.kill_bg_processes()
244 print("All background processes were killed.")
244 print("All background processes were killed.")
245
245
246 def kill_bg_processes(self):
246 def kill_bg_processes(self):
247 """Kill all BG processes which are still running."""
247 """Kill all BG processes which are still running."""
248 if not self.bg_processes:
248 if not self.bg_processes:
249 return
249 return
250 for p in self.bg_processes:
250 for p in self.bg_processes:
251 if p.poll() is None:
251 if p.poll() is None:
252 try:
252 try:
253 p.send_signal(signal.SIGINT)
253 p.send_signal(signal.SIGINT)
254 except:
254 except:
255 pass
255 pass
256 time.sleep(0.1)
256 time.sleep(0.1)
257 self._gc_bg_processes()
257 self._gc_bg_processes()
258 if not self.bg_processes:
258 if not self.bg_processes:
259 return
259 return
260 for p in self.bg_processes:
260 for p in self.bg_processes:
261 if p.poll() is None:
261 if p.poll() is None:
262 try:
262 try:
263 p.terminate()
263 p.terminate()
264 except:
264 except:
265 pass
265 pass
266 time.sleep(0.1)
266 time.sleep(0.1)
267 self._gc_bg_processes()
267 self._gc_bg_processes()
268 if not self.bg_processes:
268 if not self.bg_processes:
269 return
269 return
270 for p in self.bg_processes:
270 for p in self.bg_processes:
271 if p.poll() is None:
271 if p.poll() is None:
272 try:
272 try:
273 p.kill()
273 p.kill()
274 except:
274 except:
275 pass
275 pass
276 self._gc_bg_processes()
276 self._gc_bg_processes()
277
277
278 def _gc_bg_processes(self):
278 def _gc_bg_processes(self):
279 self.bg_processes = [p for p in self.bg_processes if p.poll() is None]
279 self.bg_processes = [p for p in self.bg_processes if p.poll() is None]
@@ -1,74 +1,73
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for IPython.core.application"""
2 """Tests for IPython.core.application"""
3
3
4 import os
4 import os
5 import tempfile
5 import tempfile
6
6
7 import nose.tools as nt
7 import nose.tools as nt
8
8
9 from traitlets import Unicode
9 from traitlets import Unicode
10
10
11 from IPython.core.application import BaseIPythonApplication
11 from IPython.core.application import BaseIPythonApplication
12 from IPython.testing import decorators as dec
12 from IPython.testing import decorators as dec
13 from IPython.utils import py3compat
14 from IPython.utils.tempdir import TemporaryDirectory
13 from IPython.utils.tempdir import TemporaryDirectory
15
14
16
15
17 @dec.onlyif_unicode_paths
16 @dec.onlyif_unicode_paths
18 def test_unicode_cwd():
17 def test_unicode_cwd():
19 """Check that IPython starts with non-ascii characters in the path."""
18 """Check that IPython starts with non-ascii characters in the path."""
20 wd = tempfile.mkdtemp(suffix=u"€")
19 wd = tempfile.mkdtemp(suffix=u"€")
21
20
22 old_wd = os.getcwd()
21 old_wd = os.getcwd()
23 os.chdir(wd)
22 os.chdir(wd)
24 #raise Exception(repr(os.getcwd()))
23 #raise Exception(repr(os.getcwd()))
25 try:
24 try:
26 app = BaseIPythonApplication()
25 app = BaseIPythonApplication()
27 # The lines below are copied from Application.initialize()
26 # The lines below are copied from Application.initialize()
28 app.init_profile_dir()
27 app.init_profile_dir()
29 app.init_config_files()
28 app.init_config_files()
30 app.load_config_file(suppress_errors=False)
29 app.load_config_file(suppress_errors=False)
31 finally:
30 finally:
32 os.chdir(old_wd)
31 os.chdir(old_wd)
33
32
34 @dec.onlyif_unicode_paths
33 @dec.onlyif_unicode_paths
35 def test_unicode_ipdir():
34 def test_unicode_ipdir():
36 """Check that IPython starts with non-ascii characters in the IP dir."""
35 """Check that IPython starts with non-ascii characters in the IP dir."""
37 ipdir = tempfile.mkdtemp(suffix=u"€")
36 ipdir = tempfile.mkdtemp(suffix=u"€")
38
37
39 # Create the config file, so it tries to load it.
38 # Create the config file, so it tries to load it.
40 with open(os.path.join(ipdir, 'ipython_config.py'), "w") as f:
39 with open(os.path.join(ipdir, 'ipython_config.py'), "w") as f:
41 pass
40 pass
42
41
43 old_ipdir1 = os.environ.pop("IPYTHONDIR", None)
42 old_ipdir1 = os.environ.pop("IPYTHONDIR", None)
44 old_ipdir2 = os.environ.pop("IPYTHON_DIR", None)
43 old_ipdir2 = os.environ.pop("IPYTHON_DIR", None)
45 os.environ["IPYTHONDIR"] = ipdir
44 os.environ["IPYTHONDIR"] = ipdir
46 try:
45 try:
47 app = BaseIPythonApplication()
46 app = BaseIPythonApplication()
48 # The lines below are copied from Application.initialize()
47 # The lines below are copied from Application.initialize()
49 app.init_profile_dir()
48 app.init_profile_dir()
50 app.init_config_files()
49 app.init_config_files()
51 app.load_config_file(suppress_errors=False)
50 app.load_config_file(suppress_errors=False)
52 finally:
51 finally:
53 if old_ipdir1:
52 if old_ipdir1:
54 os.environ["IPYTHONDIR"] = old_ipdir1
53 os.environ["IPYTHONDIR"] = old_ipdir1
55 if old_ipdir2:
54 if old_ipdir2:
56 os.environ["IPYTHONDIR"] = old_ipdir2
55 os.environ["IPYTHONDIR"] = old_ipdir2
57
56
58 def test_cli_priority():
57 def test_cli_priority():
59 with TemporaryDirectory() as td:
58 with TemporaryDirectory() as td:
60
59
61 class TestApp(BaseIPythonApplication):
60 class TestApp(BaseIPythonApplication):
62 test = Unicode().tag(config=True)
61 test = Unicode().tag(config=True)
63
62
64 # Create the config file, so it tries to load it.
63 # Create the config file, so it tries to load it.
65 with open(os.path.join(td, 'ipython_config.py'), "w") as f:
64 with open(os.path.join(td, 'ipython_config.py'), "w") as f:
66 f.write("c.TestApp.test = 'config file'")
65 f.write("c.TestApp.test = 'config file'")
67
66
68 app = TestApp()
67 app = TestApp()
69 app.initialize(['--profile-dir', td])
68 app.initialize(['--profile-dir', td])
70 nt.assert_equal(app.test, 'config file')
69 nt.assert_equal(app.test, 'config file')
71 app = TestApp()
70 app = TestApp()
72 app.initialize(['--profile-dir', td, '--TestApp.test=cli'])
71 app.initialize(['--profile-dir', td, '--TestApp.test=cli'])
73 nt.assert_equal(app.test, 'cli')
72 nt.assert_equal(app.test, 'cli')
74
73
@@ -1,74 +1,73
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for the compilerop module.
2 """Tests for the compilerop module.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team.
5 # Copyright (C) 2010-2011 The IPython Development Team.
6 #
6 #
7 # Distributed under the terms of the BSD License.
7 # Distributed under the terms of the BSD License.
8 #
8 #
9 # The full license is in the file COPYING.txt, distributed with this software.
9 # The full license is in the file COPYING.txt, distributed with this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 # Stdlib imports
16 # Stdlib imports
17 import linecache
17 import linecache
18 import sys
18 import sys
19
19
20 # Third-party imports
20 # Third-party imports
21 import nose.tools as nt
21 import nose.tools as nt
22
22
23 # Our own imports
23 # Our own imports
24 from IPython.core import compilerop
24 from IPython.core import compilerop
25 from IPython.utils import py3compat
26
25
27 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
28 # Test functions
27 # Test functions
29 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
30
29
31 def test_code_name():
30 def test_code_name():
32 code = 'x=1'
31 code = 'x=1'
33 name = compilerop.code_name(code)
32 name = compilerop.code_name(code)
34 nt.assert_true(name.startswith('<ipython-input-0'))
33 nt.assert_true(name.startswith('<ipython-input-0'))
35
34
36
35
37 def test_code_name2():
36 def test_code_name2():
38 code = 'x=1'
37 code = 'x=1'
39 name = compilerop.code_name(code, 9)
38 name = compilerop.code_name(code, 9)
40 nt.assert_true(name.startswith('<ipython-input-9'))
39 nt.assert_true(name.startswith('<ipython-input-9'))
41
40
42
41
43 def test_cache():
42 def test_cache():
44 """Test the compiler correctly compiles and caches inputs
43 """Test the compiler correctly compiles and caches inputs
45 """
44 """
46 cp = compilerop.CachingCompiler()
45 cp = compilerop.CachingCompiler()
47 ncache = len(linecache.cache)
46 ncache = len(linecache.cache)
48 cp.cache('x=1')
47 cp.cache('x=1')
49 nt.assert_true(len(linecache.cache) > ncache)
48 nt.assert_true(len(linecache.cache) > ncache)
50
49
51 def setUp():
50 def setUp():
52 # Check we're in a proper Python 2 environment (some imports, such
51 # Check we're in a proper Python 2 environment (some imports, such
53 # as GTK, can change the default encoding, which can hide bugs.)
52 # as GTK, can change the default encoding, which can hide bugs.)
54 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
53 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
55
54
56 def test_cache_unicode():
55 def test_cache_unicode():
57 cp = compilerop.CachingCompiler()
56 cp = compilerop.CachingCompiler()
58 ncache = len(linecache.cache)
57 ncache = len(linecache.cache)
59 cp.cache(u"t = 'žćčšđ'")
58 cp.cache(u"t = 'žćčšđ'")
60 nt.assert_true(len(linecache.cache) > ncache)
59 nt.assert_true(len(linecache.cache) > ncache)
61
60
62 def test_compiler_check_cache():
61 def test_compiler_check_cache():
63 """Test the compiler properly manages the cache.
62 """Test the compiler properly manages the cache.
64 """
63 """
65 # Rather simple-minded tests that just exercise the API
64 # Rather simple-minded tests that just exercise the API
66 cp = compilerop.CachingCompiler()
65 cp = compilerop.CachingCompiler()
67 cp.cache('x=1', 99)
66 cp.cache('x=1', 99)
68 # Ensure now that after clearing the cache, our entries survive
67 # Ensure now that after clearing the cache, our entries survive
69 linecache.checkcache()
68 linecache.checkcache()
70 for k in linecache.cache:
69 for k in linecache.cache:
71 if k.startswith('<ipython-input-99'):
70 if k.startswith('<ipython-input-99'):
72 break
71 break
73 else:
72 else:
74 raise AssertionError('Entry for input-99 missing from linecache')
73 raise AssertionError('Entry for input-99 missing from linecache')
@@ -1,162 +1,161
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for completerlib.
2 """Tests for completerlib.
3
3
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Imports
7 # Imports
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 import os
10 import os
11 import shutil
11 import shutil
12 import sys
12 import sys
13 import tempfile
13 import tempfile
14 import unittest
14 import unittest
15 from os.path import join
15 from os.path import join
16
16
17 import nose.tools as nt
17 import nose.tools as nt
18
18
19 from IPython.core.completerlib import magic_run_completer, module_completion
19 from IPython.core.completerlib import magic_run_completer, module_completion
20 from IPython.utils import py3compat
21 from IPython.utils.tempdir import TemporaryDirectory
20 from IPython.utils.tempdir import TemporaryDirectory
22 from IPython.testing.decorators import onlyif_unicode_paths
21 from IPython.testing.decorators import onlyif_unicode_paths
23
22
24
23
25 class MockEvent(object):
24 class MockEvent(object):
26 def __init__(self, line):
25 def __init__(self, line):
27 self.line = line
26 self.line = line
28
27
29 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
30 # Test functions begin
29 # Test functions begin
31 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
32 class Test_magic_run_completer(unittest.TestCase):
31 class Test_magic_run_completer(unittest.TestCase):
33 files = [u"aao.py", u"a.py", u"b.py", u"aao.txt"]
32 files = [u"aao.py", u"a.py", u"b.py", u"aao.txt"]
34 dirs = [u"adir/", "bdir/"]
33 dirs = [u"adir/", "bdir/"]
35
34
36 def setUp(self):
35 def setUp(self):
37 self.BASETESTDIR = tempfile.mkdtemp()
36 self.BASETESTDIR = tempfile.mkdtemp()
38 for fil in self.files:
37 for fil in self.files:
39 with open(join(self.BASETESTDIR, fil), "w") as sfile:
38 with open(join(self.BASETESTDIR, fil), "w") as sfile:
40 sfile.write("pass\n")
39 sfile.write("pass\n")
41 for d in self.dirs:
40 for d in self.dirs:
42 os.mkdir(join(self.BASETESTDIR, d))
41 os.mkdir(join(self.BASETESTDIR, d))
43
42
44 self.oldpath = os.getcwd()
43 self.oldpath = os.getcwd()
45 os.chdir(self.BASETESTDIR)
44 os.chdir(self.BASETESTDIR)
46
45
47 def tearDown(self):
46 def tearDown(self):
48 os.chdir(self.oldpath)
47 os.chdir(self.oldpath)
49 shutil.rmtree(self.BASETESTDIR)
48 shutil.rmtree(self.BASETESTDIR)
50
49
51 def test_1(self):
50 def test_1(self):
52 """Test magic_run_completer, should match two alterntives
51 """Test magic_run_completer, should match two alterntives
53 """
52 """
54 event = MockEvent(u"%run a")
53 event = MockEvent(u"%run a")
55 mockself = None
54 mockself = None
56 match = set(magic_run_completer(mockself, event))
55 match = set(magic_run_completer(mockself, event))
57 self.assertEqual(match, {u"a.py", u"aao.py", u"adir/"})
56 self.assertEqual(match, {u"a.py", u"aao.py", u"adir/"})
58
57
59 def test_2(self):
58 def test_2(self):
60 """Test magic_run_completer, should match one alterntive
59 """Test magic_run_completer, should match one alterntive
61 """
60 """
62 event = MockEvent(u"%run aa")
61 event = MockEvent(u"%run aa")
63 mockself = None
62 mockself = None
64 match = set(magic_run_completer(mockself, event))
63 match = set(magic_run_completer(mockself, event))
65 self.assertEqual(match, {u"aao.py"})
64 self.assertEqual(match, {u"aao.py"})
66
65
67 def test_3(self):
66 def test_3(self):
68 """Test magic_run_completer with unterminated " """
67 """Test magic_run_completer with unterminated " """
69 event = MockEvent(u'%run "a')
68 event = MockEvent(u'%run "a')
70 mockself = None
69 mockself = None
71 match = set(magic_run_completer(mockself, event))
70 match = set(magic_run_completer(mockself, event))
72 self.assertEqual(match, {u"a.py", u"aao.py", u"adir/"})
71 self.assertEqual(match, {u"a.py", u"aao.py", u"adir/"})
73
72
74 def test_completion_more_args(self):
73 def test_completion_more_args(self):
75 event = MockEvent(u'%run a.py ')
74 event = MockEvent(u'%run a.py ')
76 match = set(magic_run_completer(None, event))
75 match = set(magic_run_completer(None, event))
77 self.assertEqual(match, set(self.files + self.dirs))
76 self.assertEqual(match, set(self.files + self.dirs))
78
77
79 def test_completion_in_dir(self):
78 def test_completion_in_dir(self):
80 # Github issue #3459
79 # Github issue #3459
81 event = MockEvent(u'%run a.py {}'.format(join(self.BASETESTDIR, 'a')))
80 event = MockEvent(u'%run a.py {}'.format(join(self.BASETESTDIR, 'a')))
82 print(repr(event.line))
81 print(repr(event.line))
83 match = set(magic_run_completer(None, event))
82 match = set(magic_run_completer(None, event))
84 # We specifically use replace here rather than normpath, because
83 # We specifically use replace here rather than normpath, because
85 # at one point there were duplicates 'adir' and 'adir/', and normpath
84 # at one point there were duplicates 'adir' and 'adir/', and normpath
86 # would hide the failure for that.
85 # would hide the failure for that.
87 self.assertEqual(match, {join(self.BASETESTDIR, f).replace('\\','/')
86 self.assertEqual(match, {join(self.BASETESTDIR, f).replace('\\','/')
88 for f in (u'a.py', u'aao.py', u'aao.txt', u'adir/')})
87 for f in (u'a.py', u'aao.py', u'aao.txt', u'adir/')})
89
88
90 class Test_magic_run_completer_nonascii(unittest.TestCase):
89 class Test_magic_run_completer_nonascii(unittest.TestCase):
91 @onlyif_unicode_paths
90 @onlyif_unicode_paths
92 def setUp(self):
91 def setUp(self):
93 self.BASETESTDIR = tempfile.mkdtemp()
92 self.BASETESTDIR = tempfile.mkdtemp()
94 for fil in [u"aaø.py", u"a.py", u"b.py"]:
93 for fil in [u"aaø.py", u"a.py", u"b.py"]:
95 with open(join(self.BASETESTDIR, fil), "w") as sfile:
94 with open(join(self.BASETESTDIR, fil), "w") as sfile:
96 sfile.write("pass\n")
95 sfile.write("pass\n")
97 self.oldpath = os.getcwd()
96 self.oldpath = os.getcwd()
98 os.chdir(self.BASETESTDIR)
97 os.chdir(self.BASETESTDIR)
99
98
100 def tearDown(self):
99 def tearDown(self):
101 os.chdir(self.oldpath)
100 os.chdir(self.oldpath)
102 shutil.rmtree(self.BASETESTDIR)
101 shutil.rmtree(self.BASETESTDIR)
103
102
104 @onlyif_unicode_paths
103 @onlyif_unicode_paths
105 def test_1(self):
104 def test_1(self):
106 """Test magic_run_completer, should match two alterntives
105 """Test magic_run_completer, should match two alterntives
107 """
106 """
108 event = MockEvent(u"%run a")
107 event = MockEvent(u"%run a")
109 mockself = None
108 mockself = None
110 match = set(magic_run_completer(mockself, event))
109 match = set(magic_run_completer(mockself, event))
111 self.assertEqual(match, {u"a.py", u"aaø.py"})
110 self.assertEqual(match, {u"a.py", u"aaø.py"})
112
111
113 @onlyif_unicode_paths
112 @onlyif_unicode_paths
114 def test_2(self):
113 def test_2(self):
115 """Test magic_run_completer, should match one alterntive
114 """Test magic_run_completer, should match one alterntive
116 """
115 """
117 event = MockEvent(u"%run aa")
116 event = MockEvent(u"%run aa")
118 mockself = None
117 mockself = None
119 match = set(magic_run_completer(mockself, event))
118 match = set(magic_run_completer(mockself, event))
120 self.assertEqual(match, {u"aaø.py"})
119 self.assertEqual(match, {u"aaø.py"})
121
120
122 @onlyif_unicode_paths
121 @onlyif_unicode_paths
123 def test_3(self):
122 def test_3(self):
124 """Test magic_run_completer with unterminated " """
123 """Test magic_run_completer with unterminated " """
125 event = MockEvent(u'%run "a')
124 event = MockEvent(u'%run "a')
126 mockself = None
125 mockself = None
127 match = set(magic_run_completer(mockself, event))
126 match = set(magic_run_completer(mockself, event))
128 self.assertEqual(match, {u"a.py", u"aaø.py"})
127 self.assertEqual(match, {u"a.py", u"aaø.py"})
129
128
130 # module_completer:
129 # module_completer:
131
130
132 def test_import_invalid_module():
131 def test_import_invalid_module():
133 """Testing of issue https://github.com/ipython/ipython/issues/1107"""
132 """Testing of issue https://github.com/ipython/ipython/issues/1107"""
134 invalid_module_names = {'foo-bar', 'foo:bar', '10foo'}
133 invalid_module_names = {'foo-bar', 'foo:bar', '10foo'}
135 valid_module_names = {'foobar'}
134 valid_module_names = {'foobar'}
136 with TemporaryDirectory() as tmpdir:
135 with TemporaryDirectory() as tmpdir:
137 sys.path.insert( 0, tmpdir )
136 sys.path.insert( 0, tmpdir )
138 for name in invalid_module_names | valid_module_names:
137 for name in invalid_module_names | valid_module_names:
139 filename = os.path.join(tmpdir, name + '.py')
138 filename = os.path.join(tmpdir, name + '.py')
140 open(filename, 'w').close()
139 open(filename, 'w').close()
141
140
142 s = set( module_completion('import foo') )
141 s = set( module_completion('import foo') )
143 intersection = s.intersection(invalid_module_names)
142 intersection = s.intersection(invalid_module_names)
144 nt.assert_equal(intersection, set())
143 nt.assert_equal(intersection, set())
145
144
146 assert valid_module_names.issubset(s), valid_module_names.intersection(s)
145 assert valid_module_names.issubset(s), valid_module_names.intersection(s)
147
146
148
147
149 def test_bad_module_all():
148 def test_bad_module_all():
150 """Test module with invalid __all__
149 """Test module with invalid __all__
151
150
152 https://github.com/ipython/ipython/issues/9678
151 https://github.com/ipython/ipython/issues/9678
153 """
152 """
154 testsdir = os.path.dirname(__file__)
153 testsdir = os.path.dirname(__file__)
155 sys.path.insert(0, testsdir)
154 sys.path.insert(0, testsdir)
156 try:
155 try:
157 results = module_completion('from bad_all import ')
156 results = module_completion('from bad_all import ')
158 nt.assert_in('puppies', results)
157 nt.assert_in('puppies', results)
159 for r in results:
158 for r in results:
160 nt.assert_is_instance(r, str)
159 nt.assert_is_instance(r, str)
161 finally:
160 finally:
162 sys.path.remove(testsdir)
161 sys.path.remove(testsdir)
@@ -1,211 +1,210
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for the IPython tab-completion machinery.
2 """Tests for the IPython tab-completion machinery.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Module imports
5 # Module imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7
7
8 # stdlib
8 # stdlib
9 import io
9 import io
10 import os
10 import os
11 import sys
11 import sys
12 import tempfile
12 import tempfile
13 from datetime import datetime
13 from datetime import datetime
14
14
15 # third party
15 # third party
16 import nose.tools as nt
16 import nose.tools as nt
17
17
18 # our own packages
18 # our own packages
19 from traitlets.config.loader import Config
19 from traitlets.config.loader import Config
20 from IPython.utils.tempdir import TemporaryDirectory
20 from IPython.utils.tempdir import TemporaryDirectory
21 from IPython.core.history import HistoryManager, extract_hist_ranges
21 from IPython.core.history import HistoryManager, extract_hist_ranges
22 from IPython.utils import py3compat
23
22
24 def setUp():
23 def setUp():
25 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
24 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
26
25
27 def test_history():
26 def test_history():
28 ip = get_ipython()
27 ip = get_ipython()
29 with TemporaryDirectory() as tmpdir:
28 with TemporaryDirectory() as tmpdir:
30 hist_manager_ori = ip.history_manager
29 hist_manager_ori = ip.history_manager
31 hist_file = os.path.join(tmpdir, 'history.sqlite')
30 hist_file = os.path.join(tmpdir, 'history.sqlite')
32 try:
31 try:
33 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
32 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
34 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
33 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
35 for i, h in enumerate(hist, start=1):
34 for i, h in enumerate(hist, start=1):
36 ip.history_manager.store_inputs(i, h)
35 ip.history_manager.store_inputs(i, h)
37
36
38 ip.history_manager.db_log_output = True
37 ip.history_manager.db_log_output = True
39 # Doesn't match the input, but we'll just check it's stored.
38 # Doesn't match the input, but we'll just check it's stored.
40 ip.history_manager.output_hist_reprs[3] = "spam"
39 ip.history_manager.output_hist_reprs[3] = "spam"
41 ip.history_manager.store_output(3)
40 ip.history_manager.store_output(3)
42
41
43 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
42 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
44
43
45 # Detailed tests for _get_range_session
44 # Detailed tests for _get_range_session
46 grs = ip.history_manager._get_range_session
45 grs = ip.history_manager._get_range_session
47 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
46 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
48 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
47 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
49 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
48 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
50
49
51 # Check whether specifying a range beyond the end of the current
50 # Check whether specifying a range beyond the end of the current
52 # session results in an error (gh-804)
51 # session results in an error (gh-804)
53 ip.magic('%hist 2-500')
52 ip.magic('%hist 2-500')
54
53
55 # Check that we can write non-ascii characters to a file
54 # Check that we can write non-ascii characters to a file
56 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
55 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
57 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
56 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
58 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
57 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
59 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
58 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
60
59
61 # New session
60 # New session
62 ip.history_manager.reset()
61 ip.history_manager.reset()
63 newcmds = [u"z=5",
62 newcmds = [u"z=5",
64 u"class X(object):\n pass",
63 u"class X(object):\n pass",
65 u"k='p'",
64 u"k='p'",
66 u"z=5"]
65 u"z=5"]
67 for i, cmd in enumerate(newcmds, start=1):
66 for i, cmd in enumerate(newcmds, start=1):
68 ip.history_manager.store_inputs(i, cmd)
67 ip.history_manager.store_inputs(i, cmd)
69 gothist = ip.history_manager.get_range(start=1, stop=4)
68 gothist = ip.history_manager.get_range(start=1, stop=4)
70 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
69 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
71 # Previous session:
70 # Previous session:
72 gothist = ip.history_manager.get_range(-1, 1, 4)
71 gothist = ip.history_manager.get_range(-1, 1, 4)
73 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
72 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
74
73
75 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
74 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
76
75
77 # Check get_hist_tail
76 # Check get_hist_tail
78 gothist = ip.history_manager.get_tail(5, output=True,
77 gothist = ip.history_manager.get_tail(5, output=True,
79 include_latest=True)
78 include_latest=True)
80 expected = [(1, 3, (hist[-1], "spam"))] \
79 expected = [(1, 3, (hist[-1], "spam"))] \
81 + [(s, n, (c, None)) for (s, n, c) in newhist]
80 + [(s, n, (c, None)) for (s, n, c) in newhist]
82 nt.assert_equal(list(gothist), expected)
81 nt.assert_equal(list(gothist), expected)
83
82
84 gothist = ip.history_manager.get_tail(2)
83 gothist = ip.history_manager.get_tail(2)
85 expected = newhist[-3:-1]
84 expected = newhist[-3:-1]
86 nt.assert_equal(list(gothist), expected)
85 nt.assert_equal(list(gothist), expected)
87
86
88 # Check get_hist_search
87 # Check get_hist_search
89 gothist = ip.history_manager.search("*test*")
88 gothist = ip.history_manager.search("*test*")
90 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
89 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
91
90
92 gothist = ip.history_manager.search("*=*")
91 gothist = ip.history_manager.search("*=*")
93 nt.assert_equal(list(gothist),
92 nt.assert_equal(list(gothist),
94 [(1, 1, hist[0]),
93 [(1, 1, hist[0]),
95 (1, 2, hist[1]),
94 (1, 2, hist[1]),
96 (1, 3, hist[2]),
95 (1, 3, hist[2]),
97 newhist[0],
96 newhist[0],
98 newhist[2],
97 newhist[2],
99 newhist[3]])
98 newhist[3]])
100
99
101 gothist = ip.history_manager.search("*=*", n=4)
100 gothist = ip.history_manager.search("*=*", n=4)
102 nt.assert_equal(list(gothist),
101 nt.assert_equal(list(gothist),
103 [(1, 3, hist[2]),
102 [(1, 3, hist[2]),
104 newhist[0],
103 newhist[0],
105 newhist[2],
104 newhist[2],
106 newhist[3]])
105 newhist[3]])
107
106
108 gothist = ip.history_manager.search("*=*", unique=True)
107 gothist = ip.history_manager.search("*=*", unique=True)
109 nt.assert_equal(list(gothist),
108 nt.assert_equal(list(gothist),
110 [(1, 1, hist[0]),
109 [(1, 1, hist[0]),
111 (1, 2, hist[1]),
110 (1, 2, hist[1]),
112 (1, 3, hist[2]),
111 (1, 3, hist[2]),
113 newhist[2],
112 newhist[2],
114 newhist[3]])
113 newhist[3]])
115
114
116 gothist = ip.history_manager.search("*=*", unique=True, n=3)
115 gothist = ip.history_manager.search("*=*", unique=True, n=3)
117 nt.assert_equal(list(gothist),
116 nt.assert_equal(list(gothist),
118 [(1, 3, hist[2]),
117 [(1, 3, hist[2]),
119 newhist[2],
118 newhist[2],
120 newhist[3]])
119 newhist[3]])
121
120
122 gothist = ip.history_manager.search("b*", output=True)
121 gothist = ip.history_manager.search("b*", output=True)
123 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
122 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
124
123
125 # Cross testing: check that magic %save can get previous session.
124 # Cross testing: check that magic %save can get previous session.
126 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
125 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
127 ip.magic("save " + testfilename + " ~1/1-3")
126 ip.magic("save " + testfilename + " ~1/1-3")
128 with io.open(testfilename, encoding='utf-8') as testfile:
127 with io.open(testfilename, encoding='utf-8') as testfile:
129 nt.assert_equal(testfile.read(),
128 nt.assert_equal(testfile.read(),
130 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
129 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
131
130
132 # Duplicate line numbers - check that it doesn't crash, and
131 # Duplicate line numbers - check that it doesn't crash, and
133 # gets a new session
132 # gets a new session
134 ip.history_manager.store_inputs(1, "rogue")
133 ip.history_manager.store_inputs(1, "rogue")
135 ip.history_manager.writeout_cache()
134 ip.history_manager.writeout_cache()
136 nt.assert_equal(ip.history_manager.session_number, 3)
135 nt.assert_equal(ip.history_manager.session_number, 3)
137 finally:
136 finally:
138 # Ensure saving thread is shut down before we try to clean up the files
137 # Ensure saving thread is shut down before we try to clean up the files
139 ip.history_manager.save_thread.stop()
138 ip.history_manager.save_thread.stop()
140 # Forcibly close database rather than relying on garbage collection
139 # Forcibly close database rather than relying on garbage collection
141 ip.history_manager.db.close()
140 ip.history_manager.db.close()
142 # Restore history manager
141 # Restore history manager
143 ip.history_manager = hist_manager_ori
142 ip.history_manager = hist_manager_ori
144
143
145
144
146 def test_extract_hist_ranges():
145 def test_extract_hist_ranges():
147 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
146 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
148 expected = [(0, 1, 2), # 0 == current session
147 expected = [(0, 1, 2), # 0 == current session
149 (2, 3, 4),
148 (2, 3, 4),
150 (-4, 5, 7),
149 (-4, 5, 7),
151 (-4, 7, 10),
150 (-4, 7, 10),
152 (-9, 2, None), # None == to end
151 (-9, 2, None), # None == to end
153 (-8, 1, None),
152 (-8, 1, None),
154 (-7, 1, 6),
153 (-7, 1, 6),
155 (-10, 1, None)]
154 (-10, 1, None)]
156 actual = list(extract_hist_ranges(instr))
155 actual = list(extract_hist_ranges(instr))
157 nt.assert_equal(actual, expected)
156 nt.assert_equal(actual, expected)
158
157
159 def test_magic_rerun():
158 def test_magic_rerun():
160 """Simple test for %rerun (no args -> rerun last line)"""
159 """Simple test for %rerun (no args -> rerun last line)"""
161 ip = get_ipython()
160 ip = get_ipython()
162 ip.run_cell("a = 10", store_history=True)
161 ip.run_cell("a = 10", store_history=True)
163 ip.run_cell("a += 1", store_history=True)
162 ip.run_cell("a += 1", store_history=True)
164 nt.assert_equal(ip.user_ns["a"], 11)
163 nt.assert_equal(ip.user_ns["a"], 11)
165 ip.run_cell("%rerun", store_history=True)
164 ip.run_cell("%rerun", store_history=True)
166 nt.assert_equal(ip.user_ns["a"], 12)
165 nt.assert_equal(ip.user_ns["a"], 12)
167
166
168 def test_timestamp_type():
167 def test_timestamp_type():
169 ip = get_ipython()
168 ip = get_ipython()
170 info = ip.history_manager.get_session_info()
169 info = ip.history_manager.get_session_info()
171 nt.assert_true(isinstance(info[1], datetime))
170 nt.assert_true(isinstance(info[1], datetime))
172
171
173 def test_hist_file_config():
172 def test_hist_file_config():
174 cfg = Config()
173 cfg = Config()
175 tfile = tempfile.NamedTemporaryFile(delete=False)
174 tfile = tempfile.NamedTemporaryFile(delete=False)
176 cfg.HistoryManager.hist_file = tfile.name
175 cfg.HistoryManager.hist_file = tfile.name
177 try:
176 try:
178 hm = HistoryManager(shell=get_ipython(), config=cfg)
177 hm = HistoryManager(shell=get_ipython(), config=cfg)
179 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
178 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
180 finally:
179 finally:
181 try:
180 try:
182 os.remove(tfile.name)
181 os.remove(tfile.name)
183 except OSError:
182 except OSError:
184 # same catch as in testing.tools.TempFileMixin
183 # same catch as in testing.tools.TempFileMixin
185 # On Windows, even though we close the file, we still can't
184 # On Windows, even though we close the file, we still can't
186 # delete it. I have no clue why
185 # delete it. I have no clue why
187 pass
186 pass
188
187
189 def test_histmanager_disabled():
188 def test_histmanager_disabled():
190 """Ensure that disabling the history manager doesn't create a database."""
189 """Ensure that disabling the history manager doesn't create a database."""
191 cfg = Config()
190 cfg = Config()
192 cfg.HistoryAccessor.enabled = False
191 cfg.HistoryAccessor.enabled = False
193
192
194 ip = get_ipython()
193 ip = get_ipython()
195 with TemporaryDirectory() as tmpdir:
194 with TemporaryDirectory() as tmpdir:
196 hist_manager_ori = ip.history_manager
195 hist_manager_ori = ip.history_manager
197 hist_file = os.path.join(tmpdir, 'history.sqlite')
196 hist_file = os.path.join(tmpdir, 'history.sqlite')
198 cfg.HistoryManager.hist_file = hist_file
197 cfg.HistoryManager.hist_file = hist_file
199 try:
198 try:
200 ip.history_manager = HistoryManager(shell=ip, config=cfg)
199 ip.history_manager = HistoryManager(shell=ip, config=cfg)
201 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
200 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
202 for i, h in enumerate(hist, start=1):
201 for i, h in enumerate(hist, start=1):
203 ip.history_manager.store_inputs(i, h)
202 ip.history_manager.store_inputs(i, h)
204 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
203 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
205 ip.history_manager.reset()
204 ip.history_manager.reset()
206 ip.history_manager.end_session()
205 ip.history_manager.end_session()
207 finally:
206 finally:
208 ip.history_manager = hist_manager_ori
207 ip.history_manager = hist_manager_ori
209
208
210 # hist_file should not be created
209 # hist_file should not be created
211 nt.assert_false(os.path.exists(hist_file))
210 nt.assert_false(os.path.exists(hist_file))
@@ -1,926 +1,924
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import ast
12 import ast
13 import os
13 import os
14 import signal
14 import signal
15 import shutil
15 import shutil
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import unittest
18 import unittest
19 from unittest import mock
19 from unittest import mock
20 from io import StringIO
21
20
22 from os.path import join
21 from os.path import join
23
22
24 import nose.tools as nt
23 import nose.tools as nt
25
24
26 from IPython.core.error import InputRejected
25 from IPython.core.error import InputRejected
27 from IPython.core.inputtransformer import InputTransformer
26 from IPython.core.inputtransformer import InputTransformer
28 from IPython.testing.decorators import (
27 from IPython.testing.decorators import (
29 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
28 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
30 )
29 )
31 from IPython.testing import tools as tt
30 from IPython.testing import tools as tt
32 from IPython.utils.process import find_cmd
31 from IPython.utils.process import find_cmd
33 from IPython.utils import py3compat
34
32
35 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
36 # Globals
34 # Globals
37 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
38 # This is used by every single test, no point repeating it ad nauseam
36 # This is used by every single test, no point repeating it ad nauseam
39 ip = get_ipython()
37 ip = get_ipython()
40
38
41 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
42 # Tests
40 # Tests
43 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
44
42
45 class DerivedInterrupt(KeyboardInterrupt):
43 class DerivedInterrupt(KeyboardInterrupt):
46 pass
44 pass
47
45
48 class InteractiveShellTestCase(unittest.TestCase):
46 class InteractiveShellTestCase(unittest.TestCase):
49 def test_naked_string_cells(self):
47 def test_naked_string_cells(self):
50 """Test that cells with only naked strings are fully executed"""
48 """Test that cells with only naked strings are fully executed"""
51 # First, single-line inputs
49 # First, single-line inputs
52 ip.run_cell('"a"\n')
50 ip.run_cell('"a"\n')
53 self.assertEqual(ip.user_ns['_'], 'a')
51 self.assertEqual(ip.user_ns['_'], 'a')
54 # And also multi-line cells
52 # And also multi-line cells
55 ip.run_cell('"""a\nb"""\n')
53 ip.run_cell('"""a\nb"""\n')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
54 self.assertEqual(ip.user_ns['_'], 'a\nb')
57
55
58 def test_run_empty_cell(self):
56 def test_run_empty_cell(self):
59 """Just make sure we don't get a horrible error with a blank
57 """Just make sure we don't get a horrible error with a blank
60 cell of input. Yes, I did overlook that."""
58 cell of input. Yes, I did overlook that."""
61 old_xc = ip.execution_count
59 old_xc = ip.execution_count
62 res = ip.run_cell('')
60 res = ip.run_cell('')
63 self.assertEqual(ip.execution_count, old_xc)
61 self.assertEqual(ip.execution_count, old_xc)
64 self.assertEqual(res.execution_count, None)
62 self.assertEqual(res.execution_count, None)
65
63
66 def test_run_cell_multiline(self):
64 def test_run_cell_multiline(self):
67 """Multi-block, multi-line cells must execute correctly.
65 """Multi-block, multi-line cells must execute correctly.
68 """
66 """
69 src = '\n'.join(["x=1",
67 src = '\n'.join(["x=1",
70 "y=2",
68 "y=2",
71 "if 1:",
69 "if 1:",
72 " x += 1",
70 " x += 1",
73 " y += 1",])
71 " y += 1",])
74 res = ip.run_cell(src)
72 res = ip.run_cell(src)
75 self.assertEqual(ip.user_ns['x'], 2)
73 self.assertEqual(ip.user_ns['x'], 2)
76 self.assertEqual(ip.user_ns['y'], 3)
74 self.assertEqual(ip.user_ns['y'], 3)
77 self.assertEqual(res.success, True)
75 self.assertEqual(res.success, True)
78 self.assertEqual(res.result, None)
76 self.assertEqual(res.result, None)
79
77
80 def test_multiline_string_cells(self):
78 def test_multiline_string_cells(self):
81 "Code sprinkled with multiline strings should execute (GH-306)"
79 "Code sprinkled with multiline strings should execute (GH-306)"
82 ip.run_cell('tmp=0')
80 ip.run_cell('tmp=0')
83 self.assertEqual(ip.user_ns['tmp'], 0)
81 self.assertEqual(ip.user_ns['tmp'], 0)
84 res = ip.run_cell('tmp=1;"""a\nb"""\n')
82 res = ip.run_cell('tmp=1;"""a\nb"""\n')
85 self.assertEqual(ip.user_ns['tmp'], 1)
83 self.assertEqual(ip.user_ns['tmp'], 1)
86 self.assertEqual(res.success, True)
84 self.assertEqual(res.success, True)
87 self.assertEqual(res.result, "a\nb")
85 self.assertEqual(res.result, "a\nb")
88
86
89 def test_dont_cache_with_semicolon(self):
87 def test_dont_cache_with_semicolon(self):
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
88 "Ending a line with semicolon should not cache the returned object (GH-307)"
91 oldlen = len(ip.user_ns['Out'])
89 oldlen = len(ip.user_ns['Out'])
92 for cell in ['1;', '1;1;']:
90 for cell in ['1;', '1;1;']:
93 res = ip.run_cell(cell, store_history=True)
91 res = ip.run_cell(cell, store_history=True)
94 newlen = len(ip.user_ns['Out'])
92 newlen = len(ip.user_ns['Out'])
95 self.assertEqual(oldlen, newlen)
93 self.assertEqual(oldlen, newlen)
96 self.assertIsNone(res.result)
94 self.assertIsNone(res.result)
97 i = 0
95 i = 0
98 #also test the default caching behavior
96 #also test the default caching behavior
99 for cell in ['1', '1;1']:
97 for cell in ['1', '1;1']:
100 ip.run_cell(cell, store_history=True)
98 ip.run_cell(cell, store_history=True)
101 newlen = len(ip.user_ns['Out'])
99 newlen = len(ip.user_ns['Out'])
102 i += 1
100 i += 1
103 self.assertEqual(oldlen+i, newlen)
101 self.assertEqual(oldlen+i, newlen)
104
102
105 def test_syntax_error(self):
103 def test_syntax_error(self):
106 res = ip.run_cell("raise = 3")
104 res = ip.run_cell("raise = 3")
107 self.assertIsInstance(res.error_before_exec, SyntaxError)
105 self.assertIsInstance(res.error_before_exec, SyntaxError)
108
106
109 def test_In_variable(self):
107 def test_In_variable(self):
110 "Verify that In variable grows with user input (GH-284)"
108 "Verify that In variable grows with user input (GH-284)"
111 oldlen = len(ip.user_ns['In'])
109 oldlen = len(ip.user_ns['In'])
112 ip.run_cell('1;', store_history=True)
110 ip.run_cell('1;', store_history=True)
113 newlen = len(ip.user_ns['In'])
111 newlen = len(ip.user_ns['In'])
114 self.assertEqual(oldlen+1, newlen)
112 self.assertEqual(oldlen+1, newlen)
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
113 self.assertEqual(ip.user_ns['In'][-1],'1;')
116
114
117 def test_magic_names_in_string(self):
115 def test_magic_names_in_string(self):
118 ip.run_cell('a = """\n%exit\n"""')
116 ip.run_cell('a = """\n%exit\n"""')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
117 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
120
118
121 def test_trailing_newline(self):
119 def test_trailing_newline(self):
122 """test that running !(command) does not raise a SyntaxError"""
120 """test that running !(command) does not raise a SyntaxError"""
123 ip.run_cell('!(true)\n', False)
121 ip.run_cell('!(true)\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
122 ip.run_cell('!(true)\n\n\n', False)
125
123
126 def test_gh_597(self):
124 def test_gh_597(self):
127 """Pretty-printing lists of objects with non-ascii reprs may cause
125 """Pretty-printing lists of objects with non-ascii reprs may cause
128 problems."""
126 problems."""
129 class Spam(object):
127 class Spam(object):
130 def __repr__(self):
128 def __repr__(self):
131 return "\xe9"*50
129 return "\xe9"*50
132 import IPython.core.formatters
130 import IPython.core.formatters
133 f = IPython.core.formatters.PlainTextFormatter()
131 f = IPython.core.formatters.PlainTextFormatter()
134 f([Spam(),Spam()])
132 f([Spam(),Spam()])
135
133
136
134
137 def test_future_flags(self):
135 def test_future_flags(self):
138 """Check that future flags are used for parsing code (gh-777)"""
136 """Check that future flags are used for parsing code (gh-777)"""
139 ip.run_cell('from __future__ import barry_as_FLUFL')
137 ip.run_cell('from __future__ import barry_as_FLUFL')
140 try:
138 try:
141 ip.run_cell('prfunc_return_val = 1 <> 2')
139 ip.run_cell('prfunc_return_val = 1 <> 2')
142 assert 'prfunc_return_val' in ip.user_ns
140 assert 'prfunc_return_val' in ip.user_ns
143 finally:
141 finally:
144 # Reset compiler flags so we don't mess up other tests.
142 # Reset compiler flags so we don't mess up other tests.
145 ip.compile.reset_compiler_flags()
143 ip.compile.reset_compiler_flags()
146
144
147 def test_can_pickle(self):
145 def test_can_pickle(self):
148 "Can we pickle objects defined interactively (GH-29)"
146 "Can we pickle objects defined interactively (GH-29)"
149 ip = get_ipython()
147 ip = get_ipython()
150 ip.reset()
148 ip.reset()
151 ip.run_cell(("class Mylist(list):\n"
149 ip.run_cell(("class Mylist(list):\n"
152 " def __init__(self,x=[]):\n"
150 " def __init__(self,x=[]):\n"
153 " list.__init__(self,x)"))
151 " list.__init__(self,x)"))
154 ip.run_cell("w=Mylist([1,2,3])")
152 ip.run_cell("w=Mylist([1,2,3])")
155
153
156 from pickle import dumps
154 from pickle import dumps
157
155
158 # We need to swap in our main module - this is only necessary
156 # We need to swap in our main module - this is only necessary
159 # inside the test framework, because IPython puts the interactive module
157 # inside the test framework, because IPython puts the interactive module
160 # in place (but the test framework undoes this).
158 # in place (but the test framework undoes this).
161 _main = sys.modules['__main__']
159 _main = sys.modules['__main__']
162 sys.modules['__main__'] = ip.user_module
160 sys.modules['__main__'] = ip.user_module
163 try:
161 try:
164 res = dumps(ip.user_ns["w"])
162 res = dumps(ip.user_ns["w"])
165 finally:
163 finally:
166 sys.modules['__main__'] = _main
164 sys.modules['__main__'] = _main
167 self.assertTrue(isinstance(res, bytes))
165 self.assertTrue(isinstance(res, bytes))
168
166
169 def test_global_ns(self):
167 def test_global_ns(self):
170 "Code in functions must be able to access variables outside them."
168 "Code in functions must be able to access variables outside them."
171 ip = get_ipython()
169 ip = get_ipython()
172 ip.run_cell("a = 10")
170 ip.run_cell("a = 10")
173 ip.run_cell(("def f(x):\n"
171 ip.run_cell(("def f(x):\n"
174 " return x + a"))
172 " return x + a"))
175 ip.run_cell("b = f(12)")
173 ip.run_cell("b = f(12)")
176 self.assertEqual(ip.user_ns["b"], 22)
174 self.assertEqual(ip.user_ns["b"], 22)
177
175
178 def test_bad_custom_tb(self):
176 def test_bad_custom_tb(self):
179 """Check that InteractiveShell is protected from bad custom exception handlers"""
177 """Check that InteractiveShell is protected from bad custom exception handlers"""
180 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
178 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
181 self.assertEqual(ip.custom_exceptions, (IOError,))
179 self.assertEqual(ip.custom_exceptions, (IOError,))
182 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
180 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
183 ip.run_cell(u'raise IOError("foo")')
181 ip.run_cell(u'raise IOError("foo")')
184 self.assertEqual(ip.custom_exceptions, ())
182 self.assertEqual(ip.custom_exceptions, ())
185
183
186 def test_bad_custom_tb_return(self):
184 def test_bad_custom_tb_return(self):
187 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
185 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
188 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
186 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
189 self.assertEqual(ip.custom_exceptions, (NameError,))
187 self.assertEqual(ip.custom_exceptions, (NameError,))
190 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
188 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
191 ip.run_cell(u'a=abracadabra')
189 ip.run_cell(u'a=abracadabra')
192 self.assertEqual(ip.custom_exceptions, ())
190 self.assertEqual(ip.custom_exceptions, ())
193
191
194 def test_drop_by_id(self):
192 def test_drop_by_id(self):
195 myvars = {"a":object(), "b":object(), "c": object()}
193 myvars = {"a":object(), "b":object(), "c": object()}
196 ip.push(myvars, interactive=False)
194 ip.push(myvars, interactive=False)
197 for name in myvars:
195 for name in myvars:
198 assert name in ip.user_ns, name
196 assert name in ip.user_ns, name
199 assert name in ip.user_ns_hidden, name
197 assert name in ip.user_ns_hidden, name
200 ip.user_ns['b'] = 12
198 ip.user_ns['b'] = 12
201 ip.drop_by_id(myvars)
199 ip.drop_by_id(myvars)
202 for name in ["a", "c"]:
200 for name in ["a", "c"]:
203 assert name not in ip.user_ns, name
201 assert name not in ip.user_ns, name
204 assert name not in ip.user_ns_hidden, name
202 assert name not in ip.user_ns_hidden, name
205 assert ip.user_ns['b'] == 12
203 assert ip.user_ns['b'] == 12
206 ip.reset()
204 ip.reset()
207
205
208 def test_var_expand(self):
206 def test_var_expand(self):
209 ip.user_ns['f'] = u'Ca\xf1o'
207 ip.user_ns['f'] = u'Ca\xf1o'
210 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
208 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
211 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
209 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
212 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
210 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
213 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
211 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
214
212
215 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
213 self.assertEqual(ip.var_expand(u"grep x | awk '{print $1}'"), u"grep x | awk '{print $1}'")
216
214
217 ip.user_ns['f'] = b'Ca\xc3\xb1o'
215 ip.user_ns['f'] = b'Ca\xc3\xb1o'
218 # This should not raise any exception:
216 # This should not raise any exception:
219 ip.var_expand(u'echo $f')
217 ip.var_expand(u'echo $f')
220
218
221 def test_var_expand_local(self):
219 def test_var_expand_local(self):
222 """Test local variable expansion in !system and %magic calls"""
220 """Test local variable expansion in !system and %magic calls"""
223 # !system
221 # !system
224 ip.run_cell('def test():\n'
222 ip.run_cell('def test():\n'
225 ' lvar = "ttt"\n'
223 ' lvar = "ttt"\n'
226 ' ret = !echo {lvar}\n'
224 ' ret = !echo {lvar}\n'
227 ' return ret[0]\n')
225 ' return ret[0]\n')
228 res = ip.user_ns['test']()
226 res = ip.user_ns['test']()
229 nt.assert_in('ttt', res)
227 nt.assert_in('ttt', res)
230
228
231 # %magic
229 # %magic
232 ip.run_cell('def makemacro():\n'
230 ip.run_cell('def makemacro():\n'
233 ' macroname = "macro_var_expand_locals"\n'
231 ' macroname = "macro_var_expand_locals"\n'
234 ' %macro {macroname} codestr\n')
232 ' %macro {macroname} codestr\n')
235 ip.user_ns['codestr'] = "str(12)"
233 ip.user_ns['codestr'] = "str(12)"
236 ip.run_cell('makemacro()')
234 ip.run_cell('makemacro()')
237 nt.assert_in('macro_var_expand_locals', ip.user_ns)
235 nt.assert_in('macro_var_expand_locals', ip.user_ns)
238
236
239 def test_var_expand_self(self):
237 def test_var_expand_self(self):
240 """Test variable expansion with the name 'self', which was failing.
238 """Test variable expansion with the name 'self', which was failing.
241
239
242 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
240 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
243 """
241 """
244 ip.run_cell('class cTest:\n'
242 ip.run_cell('class cTest:\n'
245 ' classvar="see me"\n'
243 ' classvar="see me"\n'
246 ' def test(self):\n'
244 ' def test(self):\n'
247 ' res = !echo Variable: {self.classvar}\n'
245 ' res = !echo Variable: {self.classvar}\n'
248 ' return res[0]\n')
246 ' return res[0]\n')
249 nt.assert_in('see me', ip.user_ns['cTest']().test())
247 nt.assert_in('see me', ip.user_ns['cTest']().test())
250
248
251 def test_bad_var_expand(self):
249 def test_bad_var_expand(self):
252 """var_expand on invalid formats shouldn't raise"""
250 """var_expand on invalid formats shouldn't raise"""
253 # SyntaxError
251 # SyntaxError
254 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
252 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
255 # NameError
253 # NameError
256 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
254 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
257 # ZeroDivisionError
255 # ZeroDivisionError
258 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
256 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
259
257
260 def test_silent_postexec(self):
258 def test_silent_postexec(self):
261 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
259 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
262 pre_explicit = mock.Mock()
260 pre_explicit = mock.Mock()
263 pre_always = mock.Mock()
261 pre_always = mock.Mock()
264 post_explicit = mock.Mock()
262 post_explicit = mock.Mock()
265 post_always = mock.Mock()
263 post_always = mock.Mock()
266 all_mocks = [pre_explicit, pre_always, post_explicit, post_always]
264 all_mocks = [pre_explicit, pre_always, post_explicit, post_always]
267
265
268 ip.events.register('pre_run_cell', pre_explicit)
266 ip.events.register('pre_run_cell', pre_explicit)
269 ip.events.register('pre_execute', pre_always)
267 ip.events.register('pre_execute', pre_always)
270 ip.events.register('post_run_cell', post_explicit)
268 ip.events.register('post_run_cell', post_explicit)
271 ip.events.register('post_execute', post_always)
269 ip.events.register('post_execute', post_always)
272
270
273 try:
271 try:
274 ip.run_cell("1", silent=True)
272 ip.run_cell("1", silent=True)
275 assert pre_always.called
273 assert pre_always.called
276 assert not pre_explicit.called
274 assert not pre_explicit.called
277 assert post_always.called
275 assert post_always.called
278 assert not post_explicit.called
276 assert not post_explicit.called
279 # double-check that non-silent exec did what we expected
277 # double-check that non-silent exec did what we expected
280 # silent to avoid
278 # silent to avoid
281 ip.run_cell("1")
279 ip.run_cell("1")
282 assert pre_explicit.called
280 assert pre_explicit.called
283 assert post_explicit.called
281 assert post_explicit.called
284 info, = pre_explicit.call_args[0]
282 info, = pre_explicit.call_args[0]
285 result, = post_explicit.call_args[0]
283 result, = post_explicit.call_args[0]
286 self.assertEqual(info, result.info)
284 self.assertEqual(info, result.info)
287 # check that post hooks are always called
285 # check that post hooks are always called
288 [m.reset_mock() for m in all_mocks]
286 [m.reset_mock() for m in all_mocks]
289 ip.run_cell("syntax error")
287 ip.run_cell("syntax error")
290 assert pre_always.called
288 assert pre_always.called
291 assert pre_explicit.called
289 assert pre_explicit.called
292 assert post_always.called
290 assert post_always.called
293 assert post_explicit.called
291 assert post_explicit.called
294 info, = pre_explicit.call_args[0]
292 info, = pre_explicit.call_args[0]
295 result, = post_explicit.call_args[0]
293 result, = post_explicit.call_args[0]
296 self.assertEqual(info, result.info)
294 self.assertEqual(info, result.info)
297 finally:
295 finally:
298 # remove post-exec
296 # remove post-exec
299 ip.events.unregister('pre_run_cell', pre_explicit)
297 ip.events.unregister('pre_run_cell', pre_explicit)
300 ip.events.unregister('pre_execute', pre_always)
298 ip.events.unregister('pre_execute', pre_always)
301 ip.events.unregister('post_run_cell', post_explicit)
299 ip.events.unregister('post_run_cell', post_explicit)
302 ip.events.unregister('post_execute', post_always)
300 ip.events.unregister('post_execute', post_always)
303
301
304 def test_silent_noadvance(self):
302 def test_silent_noadvance(self):
305 """run_cell(silent=True) doesn't advance execution_count"""
303 """run_cell(silent=True) doesn't advance execution_count"""
306 ec = ip.execution_count
304 ec = ip.execution_count
307 # silent should force store_history=False
305 # silent should force store_history=False
308 ip.run_cell("1", store_history=True, silent=True)
306 ip.run_cell("1", store_history=True, silent=True)
309
307
310 self.assertEqual(ec, ip.execution_count)
308 self.assertEqual(ec, ip.execution_count)
311 # double-check that non-silent exec did what we expected
309 # double-check that non-silent exec did what we expected
312 # silent to avoid
310 # silent to avoid
313 ip.run_cell("1", store_history=True)
311 ip.run_cell("1", store_history=True)
314 self.assertEqual(ec+1, ip.execution_count)
312 self.assertEqual(ec+1, ip.execution_count)
315
313
316 def test_silent_nodisplayhook(self):
314 def test_silent_nodisplayhook(self):
317 """run_cell(silent=True) doesn't trigger displayhook"""
315 """run_cell(silent=True) doesn't trigger displayhook"""
318 d = dict(called=False)
316 d = dict(called=False)
319
317
320 trap = ip.display_trap
318 trap = ip.display_trap
321 save_hook = trap.hook
319 save_hook = trap.hook
322
320
323 def failing_hook(*args, **kwargs):
321 def failing_hook(*args, **kwargs):
324 d['called'] = True
322 d['called'] = True
325
323
326 try:
324 try:
327 trap.hook = failing_hook
325 trap.hook = failing_hook
328 res = ip.run_cell("1", silent=True)
326 res = ip.run_cell("1", silent=True)
329 self.assertFalse(d['called'])
327 self.assertFalse(d['called'])
330 self.assertIsNone(res.result)
328 self.assertIsNone(res.result)
331 # double-check that non-silent exec did what we expected
329 # double-check that non-silent exec did what we expected
332 # silent to avoid
330 # silent to avoid
333 ip.run_cell("1")
331 ip.run_cell("1")
334 self.assertTrue(d['called'])
332 self.assertTrue(d['called'])
335 finally:
333 finally:
336 trap.hook = save_hook
334 trap.hook = save_hook
337
335
338 def test_ofind_line_magic(self):
336 def test_ofind_line_magic(self):
339 from IPython.core.magic import register_line_magic
337 from IPython.core.magic import register_line_magic
340
338
341 @register_line_magic
339 @register_line_magic
342 def lmagic(line):
340 def lmagic(line):
343 "A line magic"
341 "A line magic"
344
342
345 # Get info on line magic
343 # Get info on line magic
346 lfind = ip._ofind('lmagic')
344 lfind = ip._ofind('lmagic')
347 info = dict(found=True, isalias=False, ismagic=True,
345 info = dict(found=True, isalias=False, ismagic=True,
348 namespace = 'IPython internal', obj= lmagic.__wrapped__,
346 namespace = 'IPython internal', obj= lmagic.__wrapped__,
349 parent = None)
347 parent = None)
350 nt.assert_equal(lfind, info)
348 nt.assert_equal(lfind, info)
351
349
352 def test_ofind_cell_magic(self):
350 def test_ofind_cell_magic(self):
353 from IPython.core.magic import register_cell_magic
351 from IPython.core.magic import register_cell_magic
354
352
355 @register_cell_magic
353 @register_cell_magic
356 def cmagic(line, cell):
354 def cmagic(line, cell):
357 "A cell magic"
355 "A cell magic"
358
356
359 # Get info on cell magic
357 # Get info on cell magic
360 find = ip._ofind('cmagic')
358 find = ip._ofind('cmagic')
361 info = dict(found=True, isalias=False, ismagic=True,
359 info = dict(found=True, isalias=False, ismagic=True,
362 namespace = 'IPython internal', obj= cmagic.__wrapped__,
360 namespace = 'IPython internal', obj= cmagic.__wrapped__,
363 parent = None)
361 parent = None)
364 nt.assert_equal(find, info)
362 nt.assert_equal(find, info)
365
363
366 def test_ofind_property_with_error(self):
364 def test_ofind_property_with_error(self):
367 class A(object):
365 class A(object):
368 @property
366 @property
369 def foo(self):
367 def foo(self):
370 raise NotImplementedError()
368 raise NotImplementedError()
371 a = A()
369 a = A()
372
370
373 found = ip._ofind('a.foo', [('locals', locals())])
371 found = ip._ofind('a.foo', [('locals', locals())])
374 info = dict(found=True, isalias=False, ismagic=False,
372 info = dict(found=True, isalias=False, ismagic=False,
375 namespace='locals', obj=A.foo, parent=a)
373 namespace='locals', obj=A.foo, parent=a)
376 nt.assert_equal(found, info)
374 nt.assert_equal(found, info)
377
375
378 def test_ofind_multiple_attribute_lookups(self):
376 def test_ofind_multiple_attribute_lookups(self):
379 class A(object):
377 class A(object):
380 @property
378 @property
381 def foo(self):
379 def foo(self):
382 raise NotImplementedError()
380 raise NotImplementedError()
383
381
384 a = A()
382 a = A()
385 a.a = A()
383 a.a = A()
386 a.a.a = A()
384 a.a.a = A()
387
385
388 found = ip._ofind('a.a.a.foo', [('locals', locals())])
386 found = ip._ofind('a.a.a.foo', [('locals', locals())])
389 info = dict(found=True, isalias=False, ismagic=False,
387 info = dict(found=True, isalias=False, ismagic=False,
390 namespace='locals', obj=A.foo, parent=a.a.a)
388 namespace='locals', obj=A.foo, parent=a.a.a)
391 nt.assert_equal(found, info)
389 nt.assert_equal(found, info)
392
390
393 def test_ofind_slotted_attributes(self):
391 def test_ofind_slotted_attributes(self):
394 class A(object):
392 class A(object):
395 __slots__ = ['foo']
393 __slots__ = ['foo']
396 def __init__(self):
394 def __init__(self):
397 self.foo = 'bar'
395 self.foo = 'bar'
398
396
399 a = A()
397 a = A()
400 found = ip._ofind('a.foo', [('locals', locals())])
398 found = ip._ofind('a.foo', [('locals', locals())])
401 info = dict(found=True, isalias=False, ismagic=False,
399 info = dict(found=True, isalias=False, ismagic=False,
402 namespace='locals', obj=a.foo, parent=a)
400 namespace='locals', obj=a.foo, parent=a)
403 nt.assert_equal(found, info)
401 nt.assert_equal(found, info)
404
402
405 found = ip._ofind('a.bar', [('locals', locals())])
403 found = ip._ofind('a.bar', [('locals', locals())])
406 info = dict(found=False, isalias=False, ismagic=False,
404 info = dict(found=False, isalias=False, ismagic=False,
407 namespace=None, obj=None, parent=a)
405 namespace=None, obj=None, parent=a)
408 nt.assert_equal(found, info)
406 nt.assert_equal(found, info)
409
407
410 def test_ofind_prefers_property_to_instance_level_attribute(self):
408 def test_ofind_prefers_property_to_instance_level_attribute(self):
411 class A(object):
409 class A(object):
412 @property
410 @property
413 def foo(self):
411 def foo(self):
414 return 'bar'
412 return 'bar'
415 a = A()
413 a = A()
416 a.__dict__['foo'] = 'baz'
414 a.__dict__['foo'] = 'baz'
417 nt.assert_equal(a.foo, 'bar')
415 nt.assert_equal(a.foo, 'bar')
418 found = ip._ofind('a.foo', [('locals', locals())])
416 found = ip._ofind('a.foo', [('locals', locals())])
419 nt.assert_is(found['obj'], A.foo)
417 nt.assert_is(found['obj'], A.foo)
420
418
421 def test_custom_syntaxerror_exception(self):
419 def test_custom_syntaxerror_exception(self):
422 called = []
420 called = []
423 def my_handler(shell, etype, value, tb, tb_offset=None):
421 def my_handler(shell, etype, value, tb, tb_offset=None):
424 called.append(etype)
422 called.append(etype)
425 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
423 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
426
424
427 ip.set_custom_exc((SyntaxError,), my_handler)
425 ip.set_custom_exc((SyntaxError,), my_handler)
428 try:
426 try:
429 ip.run_cell("1f")
427 ip.run_cell("1f")
430 # Check that this was called, and only once.
428 # Check that this was called, and only once.
431 self.assertEqual(called, [SyntaxError])
429 self.assertEqual(called, [SyntaxError])
432 finally:
430 finally:
433 # Reset the custom exception hook
431 # Reset the custom exception hook
434 ip.set_custom_exc((), None)
432 ip.set_custom_exc((), None)
435
433
436 def test_custom_exception(self):
434 def test_custom_exception(self):
437 called = []
435 called = []
438 def my_handler(shell, etype, value, tb, tb_offset=None):
436 def my_handler(shell, etype, value, tb, tb_offset=None):
439 called.append(etype)
437 called.append(etype)
440 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
438 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
441
439
442 ip.set_custom_exc((ValueError,), my_handler)
440 ip.set_custom_exc((ValueError,), my_handler)
443 try:
441 try:
444 res = ip.run_cell("raise ValueError('test')")
442 res = ip.run_cell("raise ValueError('test')")
445 # Check that this was called, and only once.
443 # Check that this was called, and only once.
446 self.assertEqual(called, [ValueError])
444 self.assertEqual(called, [ValueError])
447 # Check that the error is on the result object
445 # Check that the error is on the result object
448 self.assertIsInstance(res.error_in_exec, ValueError)
446 self.assertIsInstance(res.error_in_exec, ValueError)
449 finally:
447 finally:
450 # Reset the custom exception hook
448 # Reset the custom exception hook
451 ip.set_custom_exc((), None)
449 ip.set_custom_exc((), None)
452
450
453 def test_mktempfile(self):
451 def test_mktempfile(self):
454 filename = ip.mktempfile()
452 filename = ip.mktempfile()
455 # Check that we can open the file again on Windows
453 # Check that we can open the file again on Windows
456 with open(filename, 'w') as f:
454 with open(filename, 'w') as f:
457 f.write('abc')
455 f.write('abc')
458
456
459 filename = ip.mktempfile(data='blah')
457 filename = ip.mktempfile(data='blah')
460 with open(filename, 'r') as f:
458 with open(filename, 'r') as f:
461 self.assertEqual(f.read(), 'blah')
459 self.assertEqual(f.read(), 'blah')
462
460
463 def test_new_main_mod(self):
461 def test_new_main_mod(self):
464 # Smoketest to check that this accepts a unicode module name
462 # Smoketest to check that this accepts a unicode module name
465 name = u'jiefmw'
463 name = u'jiefmw'
466 mod = ip.new_main_mod(u'%s.py' % name, name)
464 mod = ip.new_main_mod(u'%s.py' % name, name)
467 self.assertEqual(mod.__name__, name)
465 self.assertEqual(mod.__name__, name)
468
466
469 def test_get_exception_only(self):
467 def test_get_exception_only(self):
470 try:
468 try:
471 raise KeyboardInterrupt
469 raise KeyboardInterrupt
472 except KeyboardInterrupt:
470 except KeyboardInterrupt:
473 msg = ip.get_exception_only()
471 msg = ip.get_exception_only()
474 self.assertEqual(msg, 'KeyboardInterrupt\n')
472 self.assertEqual(msg, 'KeyboardInterrupt\n')
475
473
476 try:
474 try:
477 raise DerivedInterrupt("foo")
475 raise DerivedInterrupt("foo")
478 except KeyboardInterrupt:
476 except KeyboardInterrupt:
479 msg = ip.get_exception_only()
477 msg = ip.get_exception_only()
480 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
478 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
481
479
482 def test_inspect_text(self):
480 def test_inspect_text(self):
483 ip.run_cell('a = 5')
481 ip.run_cell('a = 5')
484 text = ip.object_inspect_text('a')
482 text = ip.object_inspect_text('a')
485 self.assertIsInstance(text, str)
483 self.assertIsInstance(text, str)
486
484
487 def test_last_execution_result(self):
485 def test_last_execution_result(self):
488 """ Check that last execution result gets set correctly (GH-10702) """
486 """ Check that last execution result gets set correctly (GH-10702) """
489 result = ip.run_cell('a = 5; a')
487 result = ip.run_cell('a = 5; a')
490 self.assertTrue(ip.last_execution_succeeded)
488 self.assertTrue(ip.last_execution_succeeded)
491 self.assertEqual(ip.last_execution_result.result, 5)
489 self.assertEqual(ip.last_execution_result.result, 5)
492
490
493 result = ip.run_cell('a = x_invalid_id_x')
491 result = ip.run_cell('a = x_invalid_id_x')
494 self.assertFalse(ip.last_execution_succeeded)
492 self.assertFalse(ip.last_execution_succeeded)
495 self.assertFalse(ip.last_execution_result.success)
493 self.assertFalse(ip.last_execution_result.success)
496 self.assertIsInstance(ip.last_execution_result.error_in_exec, NameError)
494 self.assertIsInstance(ip.last_execution_result.error_in_exec, NameError)
497
495
498
496
499 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
497 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
500
498
501 @onlyif_unicode_paths
499 @onlyif_unicode_paths
502 def setUp(self):
500 def setUp(self):
503 self.BASETESTDIR = tempfile.mkdtemp()
501 self.BASETESTDIR = tempfile.mkdtemp()
504 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
502 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
505 os.mkdir(self.TESTDIR)
503 os.mkdir(self.TESTDIR)
506 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
504 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
507 sfile.write("pass\n")
505 sfile.write("pass\n")
508 self.oldpath = os.getcwd()
506 self.oldpath = os.getcwd()
509 os.chdir(self.TESTDIR)
507 os.chdir(self.TESTDIR)
510 self.fname = u"åäötestscript.py"
508 self.fname = u"åäötestscript.py"
511
509
512 def tearDown(self):
510 def tearDown(self):
513 os.chdir(self.oldpath)
511 os.chdir(self.oldpath)
514 shutil.rmtree(self.BASETESTDIR)
512 shutil.rmtree(self.BASETESTDIR)
515
513
516 @onlyif_unicode_paths
514 @onlyif_unicode_paths
517 def test_1(self):
515 def test_1(self):
518 """Test safe_execfile with non-ascii path
516 """Test safe_execfile with non-ascii path
519 """
517 """
520 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
518 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
521
519
522 class ExitCodeChecks(tt.TempFileMixin):
520 class ExitCodeChecks(tt.TempFileMixin):
523 def test_exit_code_ok(self):
521 def test_exit_code_ok(self):
524 self.system('exit 0')
522 self.system('exit 0')
525 self.assertEqual(ip.user_ns['_exit_code'], 0)
523 self.assertEqual(ip.user_ns['_exit_code'], 0)
526
524
527 def test_exit_code_error(self):
525 def test_exit_code_error(self):
528 self.system('exit 1')
526 self.system('exit 1')
529 self.assertEqual(ip.user_ns['_exit_code'], 1)
527 self.assertEqual(ip.user_ns['_exit_code'], 1)
530
528
531 @skipif(not hasattr(signal, 'SIGALRM'))
529 @skipif(not hasattr(signal, 'SIGALRM'))
532 def test_exit_code_signal(self):
530 def test_exit_code_signal(self):
533 self.mktmp("import signal, time\n"
531 self.mktmp("import signal, time\n"
534 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
532 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
535 "time.sleep(1)\n")
533 "time.sleep(1)\n")
536 self.system("%s %s" % (sys.executable, self.fname))
534 self.system("%s %s" % (sys.executable, self.fname))
537 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
535 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
538
536
539 @onlyif_cmds_exist("csh")
537 @onlyif_cmds_exist("csh")
540 def test_exit_code_signal_csh(self):
538 def test_exit_code_signal_csh(self):
541 SHELL = os.environ.get('SHELL', None)
539 SHELL = os.environ.get('SHELL', None)
542 os.environ['SHELL'] = find_cmd("csh")
540 os.environ['SHELL'] = find_cmd("csh")
543 try:
541 try:
544 self.test_exit_code_signal()
542 self.test_exit_code_signal()
545 finally:
543 finally:
546 if SHELL is not None:
544 if SHELL is not None:
547 os.environ['SHELL'] = SHELL
545 os.environ['SHELL'] = SHELL
548 else:
546 else:
549 del os.environ['SHELL']
547 del os.environ['SHELL']
550
548
551 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
549 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
552 system = ip.system_raw
550 system = ip.system_raw
553
551
554 @onlyif_unicode_paths
552 @onlyif_unicode_paths
555 def test_1(self):
553 def test_1(self):
556 """Test system_raw with non-ascii cmd
554 """Test system_raw with non-ascii cmd
557 """
555 """
558 cmd = u'''python -c "'åäö'" '''
556 cmd = u'''python -c "'åäö'" '''
559 ip.system_raw(cmd)
557 ip.system_raw(cmd)
560
558
561 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
559 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
562 @mock.patch('os.system', side_effect=KeyboardInterrupt)
560 @mock.patch('os.system', side_effect=KeyboardInterrupt)
563 def test_control_c(self, *mocks):
561 def test_control_c(self, *mocks):
564 try:
562 try:
565 self.system("sleep 1 # wont happen")
563 self.system("sleep 1 # wont happen")
566 except KeyboardInterrupt:
564 except KeyboardInterrupt:
567 self.fail("system call should intercept "
565 self.fail("system call should intercept "
568 "keyboard interrupt from subprocess.call")
566 "keyboard interrupt from subprocess.call")
569 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
567 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
570
568
571 # TODO: Exit codes are currently ignored on Windows.
569 # TODO: Exit codes are currently ignored on Windows.
572 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
570 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
573 system = ip.system_piped
571 system = ip.system_piped
574
572
575 @skip_win32
573 @skip_win32
576 def test_exit_code_ok(self):
574 def test_exit_code_ok(self):
577 ExitCodeChecks.test_exit_code_ok(self)
575 ExitCodeChecks.test_exit_code_ok(self)
578
576
579 @skip_win32
577 @skip_win32
580 def test_exit_code_error(self):
578 def test_exit_code_error(self):
581 ExitCodeChecks.test_exit_code_error(self)
579 ExitCodeChecks.test_exit_code_error(self)
582
580
583 @skip_win32
581 @skip_win32
584 def test_exit_code_signal(self):
582 def test_exit_code_signal(self):
585 ExitCodeChecks.test_exit_code_signal(self)
583 ExitCodeChecks.test_exit_code_signal(self)
586
584
587 class TestModules(unittest.TestCase, tt.TempFileMixin):
585 class TestModules(unittest.TestCase, tt.TempFileMixin):
588 def test_extraneous_loads(self):
586 def test_extraneous_loads(self):
589 """Test we're not loading modules on startup that we shouldn't.
587 """Test we're not loading modules on startup that we shouldn't.
590 """
588 """
591 self.mktmp("import sys\n"
589 self.mktmp("import sys\n"
592 "print('numpy' in sys.modules)\n"
590 "print('numpy' in sys.modules)\n"
593 "print('ipyparallel' in sys.modules)\n"
591 "print('ipyparallel' in sys.modules)\n"
594 "print('ipykernel' in sys.modules)\n"
592 "print('ipykernel' in sys.modules)\n"
595 )
593 )
596 out = "False\nFalse\nFalse\n"
594 out = "False\nFalse\nFalse\n"
597 tt.ipexec_validate(self.fname, out)
595 tt.ipexec_validate(self.fname, out)
598
596
599 class Negator(ast.NodeTransformer):
597 class Negator(ast.NodeTransformer):
600 """Negates all number literals in an AST."""
598 """Negates all number literals in an AST."""
601 def visit_Num(self, node):
599 def visit_Num(self, node):
602 node.n = -node.n
600 node.n = -node.n
603 return node
601 return node
604
602
605 class TestAstTransform(unittest.TestCase):
603 class TestAstTransform(unittest.TestCase):
606 def setUp(self):
604 def setUp(self):
607 self.negator = Negator()
605 self.negator = Negator()
608 ip.ast_transformers.append(self.negator)
606 ip.ast_transformers.append(self.negator)
609
607
610 def tearDown(self):
608 def tearDown(self):
611 ip.ast_transformers.remove(self.negator)
609 ip.ast_transformers.remove(self.negator)
612
610
613 def test_run_cell(self):
611 def test_run_cell(self):
614 with tt.AssertPrints('-34'):
612 with tt.AssertPrints('-34'):
615 ip.run_cell('print (12 + 22)')
613 ip.run_cell('print (12 + 22)')
616
614
617 # A named reference to a number shouldn't be transformed.
615 # A named reference to a number shouldn't be transformed.
618 ip.user_ns['n'] = 55
616 ip.user_ns['n'] = 55
619 with tt.AssertNotPrints('-55'):
617 with tt.AssertNotPrints('-55'):
620 ip.run_cell('print (n)')
618 ip.run_cell('print (n)')
621
619
622 def test_timeit(self):
620 def test_timeit(self):
623 called = set()
621 called = set()
624 def f(x):
622 def f(x):
625 called.add(x)
623 called.add(x)
626 ip.push({'f':f})
624 ip.push({'f':f})
627
625
628 with tt.AssertPrints("std. dev. of"):
626 with tt.AssertPrints("std. dev. of"):
629 ip.run_line_magic("timeit", "-n1 f(1)")
627 ip.run_line_magic("timeit", "-n1 f(1)")
630 self.assertEqual(called, {-1})
628 self.assertEqual(called, {-1})
631 called.clear()
629 called.clear()
632
630
633 with tt.AssertPrints("std. dev. of"):
631 with tt.AssertPrints("std. dev. of"):
634 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
632 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
635 self.assertEqual(called, {-2, -3})
633 self.assertEqual(called, {-2, -3})
636
634
637 def test_time(self):
635 def test_time(self):
638 called = []
636 called = []
639 def f(x):
637 def f(x):
640 called.append(x)
638 called.append(x)
641 ip.push({'f':f})
639 ip.push({'f':f})
642
640
643 # Test with an expression
641 # Test with an expression
644 with tt.AssertPrints("Wall time: "):
642 with tt.AssertPrints("Wall time: "):
645 ip.run_line_magic("time", "f(5+9)")
643 ip.run_line_magic("time", "f(5+9)")
646 self.assertEqual(called, [-14])
644 self.assertEqual(called, [-14])
647 called[:] = []
645 called[:] = []
648
646
649 # Test with a statement (different code path)
647 # Test with a statement (different code path)
650 with tt.AssertPrints("Wall time: "):
648 with tt.AssertPrints("Wall time: "):
651 ip.run_line_magic("time", "a = f(-3 + -2)")
649 ip.run_line_magic("time", "a = f(-3 + -2)")
652 self.assertEqual(called, [5])
650 self.assertEqual(called, [5])
653
651
654 def test_macro(self):
652 def test_macro(self):
655 ip.push({'a':10})
653 ip.push({'a':10})
656 # The AST transformation makes this do a+=-1
654 # The AST transformation makes this do a+=-1
657 ip.define_macro("amacro", "a+=1\nprint(a)")
655 ip.define_macro("amacro", "a+=1\nprint(a)")
658
656
659 with tt.AssertPrints("9"):
657 with tt.AssertPrints("9"):
660 ip.run_cell("amacro")
658 ip.run_cell("amacro")
661 with tt.AssertPrints("8"):
659 with tt.AssertPrints("8"):
662 ip.run_cell("amacro")
660 ip.run_cell("amacro")
663
661
664 class IntegerWrapper(ast.NodeTransformer):
662 class IntegerWrapper(ast.NodeTransformer):
665 """Wraps all integers in a call to Integer()"""
663 """Wraps all integers in a call to Integer()"""
666 def visit_Num(self, node):
664 def visit_Num(self, node):
667 if isinstance(node.n, int):
665 if isinstance(node.n, int):
668 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
666 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
669 args=[node], keywords=[])
667 args=[node], keywords=[])
670 return node
668 return node
671
669
672 class TestAstTransform2(unittest.TestCase):
670 class TestAstTransform2(unittest.TestCase):
673 def setUp(self):
671 def setUp(self):
674 self.intwrapper = IntegerWrapper()
672 self.intwrapper = IntegerWrapper()
675 ip.ast_transformers.append(self.intwrapper)
673 ip.ast_transformers.append(self.intwrapper)
676
674
677 self.calls = []
675 self.calls = []
678 def Integer(*args):
676 def Integer(*args):
679 self.calls.append(args)
677 self.calls.append(args)
680 return args
678 return args
681 ip.push({"Integer": Integer})
679 ip.push({"Integer": Integer})
682
680
683 def tearDown(self):
681 def tearDown(self):
684 ip.ast_transformers.remove(self.intwrapper)
682 ip.ast_transformers.remove(self.intwrapper)
685 del ip.user_ns['Integer']
683 del ip.user_ns['Integer']
686
684
687 def test_run_cell(self):
685 def test_run_cell(self):
688 ip.run_cell("n = 2")
686 ip.run_cell("n = 2")
689 self.assertEqual(self.calls, [(2,)])
687 self.assertEqual(self.calls, [(2,)])
690
688
691 # This shouldn't throw an error
689 # This shouldn't throw an error
692 ip.run_cell("o = 2.0")
690 ip.run_cell("o = 2.0")
693 self.assertEqual(ip.user_ns['o'], 2.0)
691 self.assertEqual(ip.user_ns['o'], 2.0)
694
692
695 def test_timeit(self):
693 def test_timeit(self):
696 called = set()
694 called = set()
697 def f(x):
695 def f(x):
698 called.add(x)
696 called.add(x)
699 ip.push({'f':f})
697 ip.push({'f':f})
700
698
701 with tt.AssertPrints("std. dev. of"):
699 with tt.AssertPrints("std. dev. of"):
702 ip.run_line_magic("timeit", "-n1 f(1)")
700 ip.run_line_magic("timeit", "-n1 f(1)")
703 self.assertEqual(called, {(1,)})
701 self.assertEqual(called, {(1,)})
704 called.clear()
702 called.clear()
705
703
706 with tt.AssertPrints("std. dev. of"):
704 with tt.AssertPrints("std. dev. of"):
707 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
705 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
708 self.assertEqual(called, {(2,), (3,)})
706 self.assertEqual(called, {(2,), (3,)})
709
707
710 class ErrorTransformer(ast.NodeTransformer):
708 class ErrorTransformer(ast.NodeTransformer):
711 """Throws an error when it sees a number."""
709 """Throws an error when it sees a number."""
712 def visit_Num(self, node):
710 def visit_Num(self, node):
713 raise ValueError("test")
711 raise ValueError("test")
714
712
715 class TestAstTransformError(unittest.TestCase):
713 class TestAstTransformError(unittest.TestCase):
716 def test_unregistering(self):
714 def test_unregistering(self):
717 err_transformer = ErrorTransformer()
715 err_transformer = ErrorTransformer()
718 ip.ast_transformers.append(err_transformer)
716 ip.ast_transformers.append(err_transformer)
719
717
720 with tt.AssertPrints("unregister", channel='stderr'):
718 with tt.AssertPrints("unregister", channel='stderr'):
721 ip.run_cell("1 + 2")
719 ip.run_cell("1 + 2")
722
720
723 # This should have been removed.
721 # This should have been removed.
724 nt.assert_not_in(err_transformer, ip.ast_transformers)
722 nt.assert_not_in(err_transformer, ip.ast_transformers)
725
723
726
724
727 class StringRejector(ast.NodeTransformer):
725 class StringRejector(ast.NodeTransformer):
728 """Throws an InputRejected when it sees a string literal.
726 """Throws an InputRejected when it sees a string literal.
729
727
730 Used to verify that NodeTransformers can signal that a piece of code should
728 Used to verify that NodeTransformers can signal that a piece of code should
731 not be executed by throwing an InputRejected.
729 not be executed by throwing an InputRejected.
732 """
730 """
733
731
734 def visit_Str(self, node):
732 def visit_Str(self, node):
735 raise InputRejected("test")
733 raise InputRejected("test")
736
734
737
735
738 class TestAstTransformInputRejection(unittest.TestCase):
736 class TestAstTransformInputRejection(unittest.TestCase):
739
737
740 def setUp(self):
738 def setUp(self):
741 self.transformer = StringRejector()
739 self.transformer = StringRejector()
742 ip.ast_transformers.append(self.transformer)
740 ip.ast_transformers.append(self.transformer)
743
741
744 def tearDown(self):
742 def tearDown(self):
745 ip.ast_transformers.remove(self.transformer)
743 ip.ast_transformers.remove(self.transformer)
746
744
747 def test_input_rejection(self):
745 def test_input_rejection(self):
748 """Check that NodeTransformers can reject input."""
746 """Check that NodeTransformers can reject input."""
749
747
750 expect_exception_tb = tt.AssertPrints("InputRejected: test")
748 expect_exception_tb = tt.AssertPrints("InputRejected: test")
751 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
749 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
752
750
753 # Run the same check twice to verify that the transformer is not
751 # Run the same check twice to verify that the transformer is not
754 # disabled after raising.
752 # disabled after raising.
755 with expect_exception_tb, expect_no_cell_output:
753 with expect_exception_tb, expect_no_cell_output:
756 ip.run_cell("'unsafe'")
754 ip.run_cell("'unsafe'")
757
755
758 with expect_exception_tb, expect_no_cell_output:
756 with expect_exception_tb, expect_no_cell_output:
759 res = ip.run_cell("'unsafe'")
757 res = ip.run_cell("'unsafe'")
760
758
761 self.assertIsInstance(res.error_before_exec, InputRejected)
759 self.assertIsInstance(res.error_before_exec, InputRejected)
762
760
763 def test__IPYTHON__():
761 def test__IPYTHON__():
764 # This shouldn't raise a NameError, that's all
762 # This shouldn't raise a NameError, that's all
765 __IPYTHON__
763 __IPYTHON__
766
764
767
765
768 class DummyRepr(object):
766 class DummyRepr(object):
769 def __repr__(self):
767 def __repr__(self):
770 return "DummyRepr"
768 return "DummyRepr"
771
769
772 def _repr_html_(self):
770 def _repr_html_(self):
773 return "<b>dummy</b>"
771 return "<b>dummy</b>"
774
772
775 def _repr_javascript_(self):
773 def _repr_javascript_(self):
776 return "console.log('hi');", {'key': 'value'}
774 return "console.log('hi');", {'key': 'value'}
777
775
778
776
779 def test_user_variables():
777 def test_user_variables():
780 # enable all formatters
778 # enable all formatters
781 ip.display_formatter.active_types = ip.display_formatter.format_types
779 ip.display_formatter.active_types = ip.display_formatter.format_types
782
780
783 ip.user_ns['dummy'] = d = DummyRepr()
781 ip.user_ns['dummy'] = d = DummyRepr()
784 keys = {'dummy', 'doesnotexist'}
782 keys = {'dummy', 'doesnotexist'}
785 r = ip.user_expressions({ key:key for key in keys})
783 r = ip.user_expressions({ key:key for key in keys})
786
784
787 nt.assert_equal(keys, set(r.keys()))
785 nt.assert_equal(keys, set(r.keys()))
788 dummy = r['dummy']
786 dummy = r['dummy']
789 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
787 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
790 nt.assert_equal(dummy['status'], 'ok')
788 nt.assert_equal(dummy['status'], 'ok')
791 data = dummy['data']
789 data = dummy['data']
792 metadata = dummy['metadata']
790 metadata = dummy['metadata']
793 nt.assert_equal(data.get('text/html'), d._repr_html_())
791 nt.assert_equal(data.get('text/html'), d._repr_html_())
794 js, jsmd = d._repr_javascript_()
792 js, jsmd = d._repr_javascript_()
795 nt.assert_equal(data.get('application/javascript'), js)
793 nt.assert_equal(data.get('application/javascript'), js)
796 nt.assert_equal(metadata.get('application/javascript'), jsmd)
794 nt.assert_equal(metadata.get('application/javascript'), jsmd)
797
795
798 dne = r['doesnotexist']
796 dne = r['doesnotexist']
799 nt.assert_equal(dne['status'], 'error')
797 nt.assert_equal(dne['status'], 'error')
800 nt.assert_equal(dne['ename'], 'NameError')
798 nt.assert_equal(dne['ename'], 'NameError')
801
799
802 # back to text only
800 # back to text only
803 ip.display_formatter.active_types = ['text/plain']
801 ip.display_formatter.active_types = ['text/plain']
804
802
805 def test_user_expression():
803 def test_user_expression():
806 # enable all formatters
804 # enable all formatters
807 ip.display_formatter.active_types = ip.display_formatter.format_types
805 ip.display_formatter.active_types = ip.display_formatter.format_types
808 query = {
806 query = {
809 'a' : '1 + 2',
807 'a' : '1 + 2',
810 'b' : '1/0',
808 'b' : '1/0',
811 }
809 }
812 r = ip.user_expressions(query)
810 r = ip.user_expressions(query)
813 import pprint
811 import pprint
814 pprint.pprint(r)
812 pprint.pprint(r)
815 nt.assert_equal(set(r.keys()), set(query.keys()))
813 nt.assert_equal(set(r.keys()), set(query.keys()))
816 a = r['a']
814 a = r['a']
817 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
815 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
818 nt.assert_equal(a['status'], 'ok')
816 nt.assert_equal(a['status'], 'ok')
819 data = a['data']
817 data = a['data']
820 metadata = a['metadata']
818 metadata = a['metadata']
821 nt.assert_equal(data.get('text/plain'), '3')
819 nt.assert_equal(data.get('text/plain'), '3')
822
820
823 b = r['b']
821 b = r['b']
824 nt.assert_equal(b['status'], 'error')
822 nt.assert_equal(b['status'], 'error')
825 nt.assert_equal(b['ename'], 'ZeroDivisionError')
823 nt.assert_equal(b['ename'], 'ZeroDivisionError')
826
824
827 # back to text only
825 # back to text only
828 ip.display_formatter.active_types = ['text/plain']
826 ip.display_formatter.active_types = ['text/plain']
829
827
830
828
831
829
832
830
833
831
834 class TestSyntaxErrorTransformer(unittest.TestCase):
832 class TestSyntaxErrorTransformer(unittest.TestCase):
835 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
833 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
836
834
837 class SyntaxErrorTransformer(InputTransformer):
835 class SyntaxErrorTransformer(InputTransformer):
838
836
839 def push(self, line):
837 def push(self, line):
840 pos = line.find('syntaxerror')
838 pos = line.find('syntaxerror')
841 if pos >= 0:
839 if pos >= 0:
842 e = SyntaxError('input contains "syntaxerror"')
840 e = SyntaxError('input contains "syntaxerror"')
843 e.text = line
841 e.text = line
844 e.offset = pos + 1
842 e.offset = pos + 1
845 raise e
843 raise e
846 return line
844 return line
847
845
848 def reset(self):
846 def reset(self):
849 pass
847 pass
850
848
851 def setUp(self):
849 def setUp(self):
852 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
850 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
853 ip.input_splitter.python_line_transforms.append(self.transformer)
851 ip.input_splitter.python_line_transforms.append(self.transformer)
854 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
852 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
855
853
856 def tearDown(self):
854 def tearDown(self):
857 ip.input_splitter.python_line_transforms.remove(self.transformer)
855 ip.input_splitter.python_line_transforms.remove(self.transformer)
858 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
856 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
859
857
860 def test_syntaxerror_input_transformer(self):
858 def test_syntaxerror_input_transformer(self):
861 with tt.AssertPrints('1234'):
859 with tt.AssertPrints('1234'):
862 ip.run_cell('1234')
860 ip.run_cell('1234')
863 with tt.AssertPrints('SyntaxError: invalid syntax'):
861 with tt.AssertPrints('SyntaxError: invalid syntax'):
864 ip.run_cell('1 2 3') # plain python syntax error
862 ip.run_cell('1 2 3') # plain python syntax error
865 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
863 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
866 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
864 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
867 with tt.AssertPrints('3456'):
865 with tt.AssertPrints('3456'):
868 ip.run_cell('3456')
866 ip.run_cell('3456')
869
867
870
868
871
869
872 def test_warning_suppression():
870 def test_warning_suppression():
873 ip.run_cell("import warnings")
871 ip.run_cell("import warnings")
874 try:
872 try:
875 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
873 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
876 ip.run_cell("warnings.warn('asdf')")
874 ip.run_cell("warnings.warn('asdf')")
877 # Here's the real test -- if we run that again, we should get the
875 # Here's the real test -- if we run that again, we should get the
878 # warning again. Traditionally, each warning was only issued once per
876 # warning again. Traditionally, each warning was only issued once per
879 # IPython session (approximately), even if the user typed in new and
877 # IPython session (approximately), even if the user typed in new and
880 # different code that should have also triggered the warning, leading
878 # different code that should have also triggered the warning, leading
881 # to much confusion.
879 # to much confusion.
882 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
880 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
883 ip.run_cell("warnings.warn('asdf')")
881 ip.run_cell("warnings.warn('asdf')")
884 finally:
882 finally:
885 ip.run_cell("del warnings")
883 ip.run_cell("del warnings")
886
884
887
885
888 def test_deprecation_warning():
886 def test_deprecation_warning():
889 ip.run_cell("""
887 ip.run_cell("""
890 import warnings
888 import warnings
891 def wrn():
889 def wrn():
892 warnings.warn(
890 warnings.warn(
893 "I AM A WARNING",
891 "I AM A WARNING",
894 DeprecationWarning
892 DeprecationWarning
895 )
893 )
896 """)
894 """)
897 try:
895 try:
898 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
896 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
899 ip.run_cell("wrn()")
897 ip.run_cell("wrn()")
900 finally:
898 finally:
901 ip.run_cell("del warnings")
899 ip.run_cell("del warnings")
902 ip.run_cell("del wrn")
900 ip.run_cell("del wrn")
903
901
904
902
905 class TestImportNoDeprecate(tt.TempFileMixin):
903 class TestImportNoDeprecate(tt.TempFileMixin):
906
904
907 def setup(self):
905 def setup(self):
908 """Make a valid python temp file."""
906 """Make a valid python temp file."""
909 self.mktmp("""
907 self.mktmp("""
910 import warnings
908 import warnings
911 def wrn():
909 def wrn():
912 warnings.warn(
910 warnings.warn(
913 "I AM A WARNING",
911 "I AM A WARNING",
914 DeprecationWarning
912 DeprecationWarning
915 )
913 )
916 """)
914 """)
917
915
918 def test_no_dep(self):
916 def test_no_dep(self):
919 """
917 """
920 No deprecation warning should be raised from imported functions
918 No deprecation warning should be raised from imported functions
921 """
919 """
922 ip.run_cell("from {} import wrn".format(self.fname))
920 ip.run_cell("from {} import wrn".format(self.fname))
923
921
924 with tt.AssertNotPrints("I AM A WARNING"):
922 with tt.AssertNotPrints("I AM A WARNING"):
925 ip.run_cell("wrn()")
923 ip.run_cell("wrn()")
926 ip.run_cell("del wrn")
924 ip.run_cell("del wrn")
@@ -1,434 +1,432
1 """Tests for the object inspection functionality.
1 """Tests for the object inspection functionality.
2 """
2 """
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7
7
8 from inspect import Signature, Parameter
8 from inspect import Signature, Parameter
9 import os
9 import os
10 import re
10 import re
11 import sys
11 import sys
12
12
13 import nose.tools as nt
13 import nose.tools as nt
14
14
15 from .. import oinspect
15 from .. import oinspect
16 from IPython.core.magic import (Magics, magics_class, line_magic,
16 from IPython.core.magic import (Magics, magics_class, line_magic,
17 cell_magic, line_cell_magic,
17 cell_magic, line_cell_magic,
18 register_line_magic, register_cell_magic,
18 register_line_magic, register_cell_magic,
19 register_line_cell_magic)
19 register_line_cell_magic)
20 from decorator import decorator
20 from decorator import decorator
21 from IPython import get_ipython
21 from IPython import get_ipython
22 from IPython.testing.decorators import skipif
23 from IPython.testing.tools import AssertPrints, AssertNotPrints
22 from IPython.testing.tools import AssertPrints, AssertNotPrints
24 from IPython.utils.path import compress_user
23 from IPython.utils.path import compress_user
25 from IPython.utils import py3compat
26
24
27
25
28 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
29 # Globals and constants
27 # Globals and constants
30 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
31
29
32 inspector = oinspect.Inspector()
30 inspector = oinspect.Inspector()
33 ip = get_ipython()
31 ip = get_ipython()
34
32
35 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
36 # Local utilities
34 # Local utilities
37 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
38
36
39 # WARNING: since this test checks the line number where a function is
37 # WARNING: since this test checks the line number where a function is
40 # defined, if any code is inserted above, the following line will need to be
38 # defined, if any code is inserted above, the following line will need to be
41 # updated. Do NOT insert any whitespace between the next line and the function
39 # updated. Do NOT insert any whitespace between the next line and the function
42 # definition below.
40 # definition below.
43 THIS_LINE_NUMBER = 43 # Put here the actual number of this line
41 THIS_LINE_NUMBER = 41 # Put here the actual number of this line
44
42
45 from unittest import TestCase
43 from unittest import TestCase
46
44
47 class Test(TestCase):
45 class Test(TestCase):
48
46
49 def test_find_source_lines(self):
47 def test_find_source_lines(self):
50 self.assertEqual(oinspect.find_source_lines(Test.test_find_source_lines),
48 self.assertEqual(oinspect.find_source_lines(Test.test_find_source_lines),
51 THIS_LINE_NUMBER+6)
49 THIS_LINE_NUMBER+6)
52
50
53
51
54 # A couple of utilities to ensure these tests work the same from a source or a
52 # A couple of utilities to ensure these tests work the same from a source or a
55 # binary install
53 # binary install
56 def pyfile(fname):
54 def pyfile(fname):
57 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
55 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
58
56
59
57
60 def match_pyfiles(f1, f2):
58 def match_pyfiles(f1, f2):
61 nt.assert_equal(pyfile(f1), pyfile(f2))
59 nt.assert_equal(pyfile(f1), pyfile(f2))
62
60
63
61
64 def test_find_file():
62 def test_find_file():
65 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
63 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
66
64
67
65
68 def test_find_file_decorated1():
66 def test_find_file_decorated1():
69
67
70 @decorator
68 @decorator
71 def noop1(f):
69 def noop1(f):
72 def wrapper(*a, **kw):
70 def wrapper(*a, **kw):
73 return f(*a, **kw)
71 return f(*a, **kw)
74 return wrapper
72 return wrapper
75
73
76 @noop1
74 @noop1
77 def f(x):
75 def f(x):
78 "My docstring"
76 "My docstring"
79
77
80 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
78 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
81 nt.assert_equal(f.__doc__, "My docstring")
79 nt.assert_equal(f.__doc__, "My docstring")
82
80
83
81
84 def test_find_file_decorated2():
82 def test_find_file_decorated2():
85
83
86 @decorator
84 @decorator
87 def noop2(f, *a, **kw):
85 def noop2(f, *a, **kw):
88 return f(*a, **kw)
86 return f(*a, **kw)
89
87
90 @noop2
88 @noop2
91 @noop2
89 @noop2
92 @noop2
90 @noop2
93 def f(x):
91 def f(x):
94 "My docstring 2"
92 "My docstring 2"
95
93
96 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
94 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
97 nt.assert_equal(f.__doc__, "My docstring 2")
95 nt.assert_equal(f.__doc__, "My docstring 2")
98
96
99
97
100 def test_find_file_magic():
98 def test_find_file_magic():
101 run = ip.find_line_magic('run')
99 run = ip.find_line_magic('run')
102 nt.assert_not_equal(oinspect.find_file(run), None)
100 nt.assert_not_equal(oinspect.find_file(run), None)
103
101
104
102
105 # A few generic objects we can then inspect in the tests below
103 # A few generic objects we can then inspect in the tests below
106
104
107 class Call(object):
105 class Call(object):
108 """This is the class docstring."""
106 """This is the class docstring."""
109
107
110 def __init__(self, x, y=1):
108 def __init__(self, x, y=1):
111 """This is the constructor docstring."""
109 """This is the constructor docstring."""
112
110
113 def __call__(self, *a, **kw):
111 def __call__(self, *a, **kw):
114 """This is the call docstring."""
112 """This is the call docstring."""
115
113
116 def method(self, x, z=2):
114 def method(self, x, z=2):
117 """Some method's docstring"""
115 """Some method's docstring"""
118
116
119 class HasSignature(object):
117 class HasSignature(object):
120 """This is the class docstring."""
118 """This is the class docstring."""
121 __signature__ = Signature([Parameter('test', Parameter.POSITIONAL_OR_KEYWORD)])
119 __signature__ = Signature([Parameter('test', Parameter.POSITIONAL_OR_KEYWORD)])
122
120
123 def __init__(self, *args):
121 def __init__(self, *args):
124 """This is the init docstring"""
122 """This is the init docstring"""
125
123
126
124
127 class SimpleClass(object):
125 class SimpleClass(object):
128 def method(self, x, z=2):
126 def method(self, x, z=2):
129 """Some method's docstring"""
127 """Some method's docstring"""
130
128
131
129
132 class OldStyle:
130 class OldStyle:
133 """An old-style class for testing."""
131 """An old-style class for testing."""
134 pass
132 pass
135
133
136
134
137 def f(x, y=2, *a, **kw):
135 def f(x, y=2, *a, **kw):
138 """A simple function."""
136 """A simple function."""
139
137
140
138
141 def g(y, z=3, *a, **kw):
139 def g(y, z=3, *a, **kw):
142 pass # no docstring
140 pass # no docstring
143
141
144
142
145 @register_line_magic
143 @register_line_magic
146 def lmagic(line):
144 def lmagic(line):
147 "A line magic"
145 "A line magic"
148
146
149
147
150 @register_cell_magic
148 @register_cell_magic
151 def cmagic(line, cell):
149 def cmagic(line, cell):
152 "A cell magic"
150 "A cell magic"
153
151
154
152
155 @register_line_cell_magic
153 @register_line_cell_magic
156 def lcmagic(line, cell=None):
154 def lcmagic(line, cell=None):
157 "A line/cell magic"
155 "A line/cell magic"
158
156
159
157
160 @magics_class
158 @magics_class
161 class SimpleMagics(Magics):
159 class SimpleMagics(Magics):
162 @line_magic
160 @line_magic
163 def Clmagic(self, cline):
161 def Clmagic(self, cline):
164 "A class-based line magic"
162 "A class-based line magic"
165
163
166 @cell_magic
164 @cell_magic
167 def Ccmagic(self, cline, ccell):
165 def Ccmagic(self, cline, ccell):
168 "A class-based cell magic"
166 "A class-based cell magic"
169
167
170 @line_cell_magic
168 @line_cell_magic
171 def Clcmagic(self, cline, ccell=None):
169 def Clcmagic(self, cline, ccell=None):
172 "A class-based line/cell magic"
170 "A class-based line/cell magic"
173
171
174
172
175 class Awkward(object):
173 class Awkward(object):
176 def __getattr__(self, name):
174 def __getattr__(self, name):
177 raise Exception(name)
175 raise Exception(name)
178
176
179 class NoBoolCall:
177 class NoBoolCall:
180 """
178 """
181 callable with `__bool__` raising should still be inspect-able.
179 callable with `__bool__` raising should still be inspect-able.
182 """
180 """
183
181
184 def __call__(self):
182 def __call__(self):
185 """does nothing"""
183 """does nothing"""
186 pass
184 pass
187
185
188 def __bool__(self):
186 def __bool__(self):
189 """just raise NotImplemented"""
187 """just raise NotImplemented"""
190 raise NotImplementedError('Must be implemented')
188 raise NotImplementedError('Must be implemented')
191
189
192
190
193 class SerialLiar(object):
191 class SerialLiar(object):
194 """Attribute accesses always get another copy of the same class.
192 """Attribute accesses always get another copy of the same class.
195
193
196 unittest.mock.call does something similar, but it's not ideal for testing
194 unittest.mock.call does something similar, but it's not ideal for testing
197 as the failure mode is to eat all your RAM. This gives up after 10k levels.
195 as the failure mode is to eat all your RAM. This gives up after 10k levels.
198 """
196 """
199 def __init__(self, max_fibbing_twig, lies_told=0):
197 def __init__(self, max_fibbing_twig, lies_told=0):
200 if lies_told > 10000:
198 if lies_told > 10000:
201 raise RuntimeError('Nose too long, honesty is the best policy')
199 raise RuntimeError('Nose too long, honesty is the best policy')
202 self.max_fibbing_twig = max_fibbing_twig
200 self.max_fibbing_twig = max_fibbing_twig
203 self.lies_told = lies_told
201 self.lies_told = lies_told
204 max_fibbing_twig[0] = max(max_fibbing_twig[0], lies_told)
202 max_fibbing_twig[0] = max(max_fibbing_twig[0], lies_told)
205
203
206 def __getattr__(self, item):
204 def __getattr__(self, item):
207 return SerialLiar(self.max_fibbing_twig, self.lies_told + 1)
205 return SerialLiar(self.max_fibbing_twig, self.lies_told + 1)
208
206
209 #-----------------------------------------------------------------------------
207 #-----------------------------------------------------------------------------
210 # Tests
208 # Tests
211 #-----------------------------------------------------------------------------
209 #-----------------------------------------------------------------------------
212
210
213 def test_info():
211 def test_info():
214 "Check that Inspector.info fills out various fields as expected."
212 "Check that Inspector.info fills out various fields as expected."
215 i = inspector.info(Call, oname='Call')
213 i = inspector.info(Call, oname='Call')
216 nt.assert_equal(i['type_name'], 'type')
214 nt.assert_equal(i['type_name'], 'type')
217 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
215 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
218 nt.assert_equal(i['base_class'], expted_class)
216 nt.assert_equal(i['base_class'], expted_class)
219 nt.assert_regex(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'( at 0x[0-9a-f]{1,9})?>")
217 nt.assert_regex(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'( at 0x[0-9a-f]{1,9})?>")
220 fname = __file__
218 fname = __file__
221 if fname.endswith(".pyc"):
219 if fname.endswith(".pyc"):
222 fname = fname[:-1]
220 fname = fname[:-1]
223 # case-insensitive comparison needed on some filesystems
221 # case-insensitive comparison needed on some filesystems
224 # e.g. Windows:
222 # e.g. Windows:
225 nt.assert_equal(i['file'].lower(), compress_user(fname).lower())
223 nt.assert_equal(i['file'].lower(), compress_user(fname).lower())
226 nt.assert_equal(i['definition'], None)
224 nt.assert_equal(i['definition'], None)
227 nt.assert_equal(i['docstring'], Call.__doc__)
225 nt.assert_equal(i['docstring'], Call.__doc__)
228 nt.assert_equal(i['source'], None)
226 nt.assert_equal(i['source'], None)
229 nt.assert_true(i['isclass'])
227 nt.assert_true(i['isclass'])
230 nt.assert_equal(i['init_definition'], "Call(x, y=1)")
228 nt.assert_equal(i['init_definition'], "Call(x, y=1)")
231 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
229 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
232
230
233 i = inspector.info(Call, detail_level=1)
231 i = inspector.info(Call, detail_level=1)
234 nt.assert_not_equal(i['source'], None)
232 nt.assert_not_equal(i['source'], None)
235 nt.assert_equal(i['docstring'], None)
233 nt.assert_equal(i['docstring'], None)
236
234
237 c = Call(1)
235 c = Call(1)
238 c.__doc__ = "Modified instance docstring"
236 c.__doc__ = "Modified instance docstring"
239 i = inspector.info(c)
237 i = inspector.info(c)
240 nt.assert_equal(i['type_name'], 'Call')
238 nt.assert_equal(i['type_name'], 'Call')
241 nt.assert_equal(i['docstring'], "Modified instance docstring")
239 nt.assert_equal(i['docstring'], "Modified instance docstring")
242 nt.assert_equal(i['class_docstring'], Call.__doc__)
240 nt.assert_equal(i['class_docstring'], Call.__doc__)
243 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
241 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
244 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
242 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
245
243
246 def test_class_signature():
244 def test_class_signature():
247 info = inspector.info(HasSignature, 'HasSignature')
245 info = inspector.info(HasSignature, 'HasSignature')
248 nt.assert_equal(info['init_definition'], "HasSignature(test)")
246 nt.assert_equal(info['init_definition'], "HasSignature(test)")
249 nt.assert_equal(info['init_docstring'], HasSignature.__init__.__doc__)
247 nt.assert_equal(info['init_docstring'], HasSignature.__init__.__doc__)
250
248
251 def test_info_awkward():
249 def test_info_awkward():
252 # Just test that this doesn't throw an error.
250 # Just test that this doesn't throw an error.
253 inspector.info(Awkward())
251 inspector.info(Awkward())
254
252
255 def test_bool_raise():
253 def test_bool_raise():
256 inspector.info(NoBoolCall())
254 inspector.info(NoBoolCall())
257
255
258 def test_info_serialliar():
256 def test_info_serialliar():
259 fib_tracker = [0]
257 fib_tracker = [0]
260 inspector.info(SerialLiar(fib_tracker))
258 inspector.info(SerialLiar(fib_tracker))
261
259
262 # Nested attribute access should be cut off at 100 levels deep to avoid
260 # Nested attribute access should be cut off at 100 levels deep to avoid
263 # infinite loops: https://github.com/ipython/ipython/issues/9122
261 # infinite loops: https://github.com/ipython/ipython/issues/9122
264 nt.assert_less(fib_tracker[0], 9000)
262 nt.assert_less(fib_tracker[0], 9000)
265
263
266 def test_calldef_none():
264 def test_calldef_none():
267 # We should ignore __call__ for all of these.
265 # We should ignore __call__ for all of these.
268 for obj in [f, SimpleClass().method, any, str.upper]:
266 for obj in [f, SimpleClass().method, any, str.upper]:
269 print(obj)
267 print(obj)
270 i = inspector.info(obj)
268 i = inspector.info(obj)
271 nt.assert_is(i['call_def'], None)
269 nt.assert_is(i['call_def'], None)
272
270
273 def f_kwarg(pos, *, kwonly):
271 def f_kwarg(pos, *, kwonly):
274 pass
272 pass
275
273
276 def test_definition_kwonlyargs():
274 def test_definition_kwonlyargs():
277 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
275 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
278 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)")
276 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)")
279
277
280 def test_getdoc():
278 def test_getdoc():
281 class A(object):
279 class A(object):
282 """standard docstring"""
280 """standard docstring"""
283 pass
281 pass
284
282
285 class B(object):
283 class B(object):
286 """standard docstring"""
284 """standard docstring"""
287 def getdoc(self):
285 def getdoc(self):
288 return "custom docstring"
286 return "custom docstring"
289
287
290 class C(object):
288 class C(object):
291 """standard docstring"""
289 """standard docstring"""
292 def getdoc(self):
290 def getdoc(self):
293 return None
291 return None
294
292
295 a = A()
293 a = A()
296 b = B()
294 b = B()
297 c = C()
295 c = C()
298
296
299 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
297 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
300 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
298 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
301 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
299 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
302
300
303
301
304 def test_empty_property_has_no_source():
302 def test_empty_property_has_no_source():
305 i = inspector.info(property(), detail_level=1)
303 i = inspector.info(property(), detail_level=1)
306 nt.assert_is(i['source'], None)
304 nt.assert_is(i['source'], None)
307
305
308
306
309 def test_property_sources():
307 def test_property_sources():
310 import zlib
308 import zlib
311
309
312 class A(object):
310 class A(object):
313 @property
311 @property
314 def foo(self):
312 def foo(self):
315 return 'bar'
313 return 'bar'
316
314
317 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
315 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
318
316
319 id = property(id)
317 id = property(id)
320 compress = property(zlib.compress)
318 compress = property(zlib.compress)
321
319
322 i = inspector.info(A.foo, detail_level=1)
320 i = inspector.info(A.foo, detail_level=1)
323 nt.assert_in('def foo(self):', i['source'])
321 nt.assert_in('def foo(self):', i['source'])
324 nt.assert_in('lambda self, v:', i['source'])
322 nt.assert_in('lambda self, v:', i['source'])
325
323
326 i = inspector.info(A.id, detail_level=1)
324 i = inspector.info(A.id, detail_level=1)
327 nt.assert_in('fget = <function id>', i['source'])
325 nt.assert_in('fget = <function id>', i['source'])
328
326
329 i = inspector.info(A.compress, detail_level=1)
327 i = inspector.info(A.compress, detail_level=1)
330 nt.assert_in('fget = <function zlib.compress>', i['source'])
328 nt.assert_in('fget = <function zlib.compress>', i['source'])
331
329
332
330
333 def test_property_docstring_is_in_info_for_detail_level_0():
331 def test_property_docstring_is_in_info_for_detail_level_0():
334 class A(object):
332 class A(object):
335 @property
333 @property
336 def foobar(self):
334 def foobar(self):
337 """This is `foobar` property."""
335 """This is `foobar` property."""
338 pass
336 pass
339
337
340 ip.user_ns['a_obj'] = A()
338 ip.user_ns['a_obj'] = A()
341 nt.assert_equal(
339 nt.assert_equal(
342 'This is `foobar` property.',
340 'This is `foobar` property.',
343 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
341 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
344
342
345 ip.user_ns['a_cls'] = A
343 ip.user_ns['a_cls'] = A
346 nt.assert_equal(
344 nt.assert_equal(
347 'This is `foobar` property.',
345 'This is `foobar` property.',
348 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
346 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
349
347
350
348
351 def test_pdef():
349 def test_pdef():
352 # See gh-1914
350 # See gh-1914
353 def foo(): pass
351 def foo(): pass
354 inspector.pdef(foo, 'foo')
352 inspector.pdef(foo, 'foo')
355
353
356
354
357 def test_pinfo_nonascii():
355 def test_pinfo_nonascii():
358 # See gh-1177
356 # See gh-1177
359 from . import nonascii2
357 from . import nonascii2
360 ip.user_ns['nonascii2'] = nonascii2
358 ip.user_ns['nonascii2'] = nonascii2
361 ip._inspect('pinfo', 'nonascii2', detail_level=1)
359 ip._inspect('pinfo', 'nonascii2', detail_level=1)
362
360
363
361
364 def test_pinfo_docstring_no_source():
362 def test_pinfo_docstring_no_source():
365 """Docstring should be included with detail_level=1 if there is no source"""
363 """Docstring should be included with detail_level=1 if there is no source"""
366 with AssertPrints('Docstring:'):
364 with AssertPrints('Docstring:'):
367 ip._inspect('pinfo', 'str.format', detail_level=0)
365 ip._inspect('pinfo', 'str.format', detail_level=0)
368 with AssertPrints('Docstring:'):
366 with AssertPrints('Docstring:'):
369 ip._inspect('pinfo', 'str.format', detail_level=1)
367 ip._inspect('pinfo', 'str.format', detail_level=1)
370
368
371
369
372 def test_pinfo_no_docstring_if_source():
370 def test_pinfo_no_docstring_if_source():
373 """Docstring should not be included with detail_level=1 if source is found"""
371 """Docstring should not be included with detail_level=1 if source is found"""
374 def foo():
372 def foo():
375 """foo has a docstring"""
373 """foo has a docstring"""
376
374
377 ip.user_ns['foo'] = foo
375 ip.user_ns['foo'] = foo
378
376
379 with AssertPrints('Docstring:'):
377 with AssertPrints('Docstring:'):
380 ip._inspect('pinfo', 'foo', detail_level=0)
378 ip._inspect('pinfo', 'foo', detail_level=0)
381 with AssertPrints('Source:'):
379 with AssertPrints('Source:'):
382 ip._inspect('pinfo', 'foo', detail_level=1)
380 ip._inspect('pinfo', 'foo', detail_level=1)
383 with AssertNotPrints('Docstring:'):
381 with AssertNotPrints('Docstring:'):
384 ip._inspect('pinfo', 'foo', detail_level=1)
382 ip._inspect('pinfo', 'foo', detail_level=1)
385
383
386
384
387 def test_pinfo_docstring_if_detail_and_no_source():
385 def test_pinfo_docstring_if_detail_and_no_source():
388 """ Docstring should be displayed if source info not available """
386 """ Docstring should be displayed if source info not available """
389 obj_def = '''class Foo(object):
387 obj_def = '''class Foo(object):
390 """ This is a docstring for Foo """
388 """ This is a docstring for Foo """
391 def bar(self):
389 def bar(self):
392 """ This is a docstring for Foo.bar """
390 """ This is a docstring for Foo.bar """
393 pass
391 pass
394 '''
392 '''
395
393
396 ip.run_cell(obj_def)
394 ip.run_cell(obj_def)
397 ip.run_cell('foo = Foo()')
395 ip.run_cell('foo = Foo()')
398
396
399 with AssertNotPrints("Source:"):
397 with AssertNotPrints("Source:"):
400 with AssertPrints('Docstring:'):
398 with AssertPrints('Docstring:'):
401 ip._inspect('pinfo', 'foo', detail_level=0)
399 ip._inspect('pinfo', 'foo', detail_level=0)
402 with AssertPrints('Docstring:'):
400 with AssertPrints('Docstring:'):
403 ip._inspect('pinfo', 'foo', detail_level=1)
401 ip._inspect('pinfo', 'foo', detail_level=1)
404 with AssertPrints('Docstring:'):
402 with AssertPrints('Docstring:'):
405 ip._inspect('pinfo', 'foo.bar', detail_level=0)
403 ip._inspect('pinfo', 'foo.bar', detail_level=0)
406
404
407 with AssertNotPrints('Docstring:'):
405 with AssertNotPrints('Docstring:'):
408 with AssertPrints('Source:'):
406 with AssertPrints('Source:'):
409 ip._inspect('pinfo', 'foo.bar', detail_level=1)
407 ip._inspect('pinfo', 'foo.bar', detail_level=1)
410
408
411
409
412 def test_pinfo_magic():
410 def test_pinfo_magic():
413 with AssertPrints('Docstring:'):
411 with AssertPrints('Docstring:'):
414 ip._inspect('pinfo', 'lsmagic', detail_level=0)
412 ip._inspect('pinfo', 'lsmagic', detail_level=0)
415
413
416 with AssertPrints('Source:'):
414 with AssertPrints('Source:'):
417 ip._inspect('pinfo', 'lsmagic', detail_level=1)
415 ip._inspect('pinfo', 'lsmagic', detail_level=1)
418
416
419
417
420 def test_init_colors():
418 def test_init_colors():
421 # ensure colors are not present in signature info
419 # ensure colors are not present in signature info
422 info = inspector.info(HasSignature)
420 info = inspector.info(HasSignature)
423 init_def = info['init_definition']
421 init_def = info['init_definition']
424 nt.assert_not_in('[0m', init_def)
422 nt.assert_not_in('[0m', init_def)
425
423
426
424
427 def test_builtin_init():
425 def test_builtin_init():
428 info = inspector.info(list)
426 info = inspector.info(list)
429 init_def = info['init_definition']
427 init_def = info['init_definition']
430 # Python < 3.4 can't get init definition from builtins,
428 # Python < 3.4 can't get init definition from builtins,
431 # but still exercise the inspection in case of error-raising bugs.
429 # but still exercise the inspection in case of error-raising bugs.
432 if sys.version_info >= (3,4):
430 if sys.version_info >= (3,4):
433 nt.assert_is_not_none(init_def)
431 nt.assert_is_not_none(init_def)
434
432
@@ -1,39 +1,38
1 # coding: utf-8
1 # coding: utf-8
2 import nose.tools as nt
2 import nose.tools as nt
3
3
4 from IPython.core.splitinput import split_user_input, LineInfo
4 from IPython.core.splitinput import split_user_input, LineInfo
5 from IPython.testing import tools as tt
5 from IPython.testing import tools as tt
6 from IPython.utils import py3compat
7
6
8 tests = [
7 tests = [
9 ('x=1', ('', '', 'x', '=1')),
8 ('x=1', ('', '', 'x', '=1')),
10 ('?', ('', '?', '', '')),
9 ('?', ('', '?', '', '')),
11 ('??', ('', '??', '', '')),
10 ('??', ('', '??', '', '')),
12 (' ?', (' ', '?', '', '')),
11 (' ?', (' ', '?', '', '')),
13 (' ??', (' ', '??', '', '')),
12 (' ??', (' ', '??', '', '')),
14 ('??x', ('', '??', 'x', '')),
13 ('??x', ('', '??', 'x', '')),
15 ('?x=1', ('', '?', 'x', '=1')),
14 ('?x=1', ('', '?', 'x', '=1')),
16 ('!ls', ('', '!', 'ls', '')),
15 ('!ls', ('', '!', 'ls', '')),
17 (' !ls', (' ', '!', 'ls', '')),
16 (' !ls', (' ', '!', 'ls', '')),
18 ('!!ls', ('', '!!', 'ls', '')),
17 ('!!ls', ('', '!!', 'ls', '')),
19 (' !!ls', (' ', '!!', 'ls', '')),
18 (' !!ls', (' ', '!!', 'ls', '')),
20 (',ls', ('', ',', 'ls', '')),
19 (',ls', ('', ',', 'ls', '')),
21 (';ls', ('', ';', 'ls', '')),
20 (';ls', ('', ';', 'ls', '')),
22 (' ;ls', (' ', ';', 'ls', '')),
21 (' ;ls', (' ', ';', 'ls', '')),
23 ('f.g(x)', ('', '', 'f.g', '(x)')),
22 ('f.g(x)', ('', '', 'f.g', '(x)')),
24 ('f.g (x)', ('', '', 'f.g', '(x)')),
23 ('f.g (x)', ('', '', 'f.g', '(x)')),
25 ('?%hist1', ('', '?', '%hist1', '')),
24 ('?%hist1', ('', '?', '%hist1', '')),
26 ('?%%hist2', ('', '?', '%%hist2', '')),
25 ('?%%hist2', ('', '?', '%%hist2', '')),
27 ('??%hist3', ('', '??', '%hist3', '')),
26 ('??%hist3', ('', '??', '%hist3', '')),
28 ('??%%hist4', ('', '??', '%%hist4', '')),
27 ('??%%hist4', ('', '??', '%%hist4', '')),
29 ('?x*', ('', '?', 'x*', '')),
28 ('?x*', ('', '?', 'x*', '')),
30 ]
29 ]
31 tests.append((u"Pérez Fernando", (u'', u'', u'Pérez', u'Fernando')))
30 tests.append((u"Pérez Fernando", (u'', u'', u'Pérez', u'Fernando')))
32
31
33 def test_split_user_input():
32 def test_split_user_input():
34 return tt.check_pairs(split_user_input, tests)
33 return tt.check_pairs(split_user_input, tests)
35
34
36 def test_LineInfo():
35 def test_LineInfo():
37 """Simple test for LineInfo construction and str()"""
36 """Simple test for LineInfo construction and str()"""
38 linfo = LineInfo(' %cd /home')
37 linfo = LineInfo(' %cd /home')
39 nt.assert_equal(str(linfo), 'LineInfo [ |%|cd|/home]')
38 nt.assert_equal(str(linfo), 'LineInfo [ |%|cd|/home]')
@@ -1,1464 +1,1461
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Verbose and colourful traceback formatting.
3 Verbose and colourful traceback formatting.
4
4
5 **ColorTB**
5 **ColorTB**
6
6
7 I've always found it a bit hard to visually parse tracebacks in Python. The
7 I've always found it a bit hard to visually parse tracebacks in Python. The
8 ColorTB class is a solution to that problem. It colors the different parts of a
8 ColorTB class is a solution to that problem. It colors the different parts of a
9 traceback in a manner similar to what you would expect from a syntax-highlighting
9 traceback in a manner similar to what you would expect from a syntax-highlighting
10 text editor.
10 text editor.
11
11
12 Installation instructions for ColorTB::
12 Installation instructions for ColorTB::
13
13
14 import sys,ultratb
14 import sys,ultratb
15 sys.excepthook = ultratb.ColorTB()
15 sys.excepthook = ultratb.ColorTB()
16
16
17 **VerboseTB**
17 **VerboseTB**
18
18
19 I've also included a port of Ka-Ping Yee's "cgitb.py" that produces all kinds
19 I've also included a port of Ka-Ping Yee's "cgitb.py" that produces all kinds
20 of useful info when a traceback occurs. Ping originally had it spit out HTML
20 of useful info when a traceback occurs. Ping originally had it spit out HTML
21 and intended it for CGI programmers, but why should they have all the fun? I
21 and intended it for CGI programmers, but why should they have all the fun? I
22 altered it to spit out colored text to the terminal. It's a bit overwhelming,
22 altered it to spit out colored text to the terminal. It's a bit overwhelming,
23 but kind of neat, and maybe useful for long-running programs that you believe
23 but kind of neat, and maybe useful for long-running programs that you believe
24 are bug-free. If a crash *does* occur in that type of program you want details.
24 are bug-free. If a crash *does* occur in that type of program you want details.
25 Give it a shot--you'll love it or you'll hate it.
25 Give it a shot--you'll love it or you'll hate it.
26
26
27 .. note::
27 .. note::
28
28
29 The Verbose mode prints the variables currently visible where the exception
29 The Verbose mode prints the variables currently visible where the exception
30 happened (shortening their strings if too long). This can potentially be
30 happened (shortening their strings if too long). This can potentially be
31 very slow, if you happen to have a huge data structure whose string
31 very slow, if you happen to have a huge data structure whose string
32 representation is complex to compute. Your computer may appear to freeze for
32 representation is complex to compute. Your computer may appear to freeze for
33 a while with cpu usage at 100%. If this occurs, you can cancel the traceback
33 a while with cpu usage at 100%. If this occurs, you can cancel the traceback
34 with Ctrl-C (maybe hitting it more than once).
34 with Ctrl-C (maybe hitting it more than once).
35
35
36 If you encounter this kind of situation often, you may want to use the
36 If you encounter this kind of situation often, you may want to use the
37 Verbose_novars mode instead of the regular Verbose, which avoids formatting
37 Verbose_novars mode instead of the regular Verbose, which avoids formatting
38 variables (but otherwise includes the information and context given by
38 variables (but otherwise includes the information and context given by
39 Verbose).
39 Verbose).
40
40
41 .. note::
41 .. note::
42
42
43 The verbose mode print all variables in the stack, which means it can
43 The verbose mode print all variables in the stack, which means it can
44 potentially leak sensitive information like access keys, or unencryted
44 potentially leak sensitive information like access keys, or unencryted
45 password.
45 password.
46
46
47 Installation instructions for VerboseTB::
47 Installation instructions for VerboseTB::
48
48
49 import sys,ultratb
49 import sys,ultratb
50 sys.excepthook = ultratb.VerboseTB()
50 sys.excepthook = ultratb.VerboseTB()
51
51
52 Note: Much of the code in this module was lifted verbatim from the standard
52 Note: Much of the code in this module was lifted verbatim from the standard
53 library module 'traceback.py' and Ka-Ping Yee's 'cgitb.py'.
53 library module 'traceback.py' and Ka-Ping Yee's 'cgitb.py'.
54
54
55 Color schemes
55 Color schemes
56 -------------
56 -------------
57
57
58 The colors are defined in the class TBTools through the use of the
58 The colors are defined in the class TBTools through the use of the
59 ColorSchemeTable class. Currently the following exist:
59 ColorSchemeTable class. Currently the following exist:
60
60
61 - NoColor: allows all of this module to be used in any terminal (the color
61 - NoColor: allows all of this module to be used in any terminal (the color
62 escapes are just dummy blank strings).
62 escapes are just dummy blank strings).
63
63
64 - Linux: is meant to look good in a terminal like the Linux console (black
64 - Linux: is meant to look good in a terminal like the Linux console (black
65 or very dark background).
65 or very dark background).
66
66
67 - LightBG: similar to Linux but swaps dark/light colors to be more readable
67 - LightBG: similar to Linux but swaps dark/light colors to be more readable
68 in light background terminals.
68 in light background terminals.
69
69
70 - Neutral: a neutral color scheme that should be readable on both light and
70 - Neutral: a neutral color scheme that should be readable on both light and
71 dark background
71 dark background
72
72
73 You can implement other color schemes easily, the syntax is fairly
73 You can implement other color schemes easily, the syntax is fairly
74 self-explanatory. Please send back new schemes you develop to the author for
74 self-explanatory. Please send back new schemes you develop to the author for
75 possible inclusion in future releases.
75 possible inclusion in future releases.
76
76
77 Inheritance diagram:
77 Inheritance diagram:
78
78
79 .. inheritance-diagram:: IPython.core.ultratb
79 .. inheritance-diagram:: IPython.core.ultratb
80 :parts: 3
80 :parts: 3
81 """
81 """
82
82
83 #*****************************************************************************
83 #*****************************************************************************
84 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
84 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
85 # Copyright (C) 2001-2004 Fernando Perez <fperez@colorado.edu>
85 # Copyright (C) 2001-2004 Fernando Perez <fperez@colorado.edu>
86 #
86 #
87 # Distributed under the terms of the BSD License. The full license is in
87 # Distributed under the terms of the BSD License. The full license is in
88 # the file COPYING, distributed as part of this software.
88 # the file COPYING, distributed as part of this software.
89 #*****************************************************************************
89 #*****************************************************************************
90
90
91
91
92 import dis
92 import dis
93 import inspect
93 import inspect
94 import keyword
94 import keyword
95 import linecache
95 import linecache
96 import os
96 import os
97 import pydoc
97 import pydoc
98 import re
98 import re
99 import sys
99 import sys
100 import time
100 import time
101 import tokenize
101 import tokenize
102 import traceback
102 import traceback
103
103
104 try: # Python 2
104 try: # Python 2
105 generate_tokens = tokenize.generate_tokens
105 generate_tokens = tokenize.generate_tokens
106 except AttributeError: # Python 3
106 except AttributeError: # Python 3
107 generate_tokens = tokenize.tokenize
107 generate_tokens = tokenize.tokenize
108
108
109 # For purposes of monkeypatching inspect to fix a bug in it.
109 # For purposes of monkeypatching inspect to fix a bug in it.
110 from inspect import getsourcefile, getfile, getmodule, \
110 from inspect import getsourcefile, getfile, getmodule, \
111 ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode
111 ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode
112
112
113 # IPython's own modules
113 # IPython's own modules
114 from IPython import get_ipython
114 from IPython import get_ipython
115 from IPython.core import debugger
115 from IPython.core import debugger
116 from IPython.core.display_trap import DisplayTrap
116 from IPython.core.display_trap import DisplayTrap
117 from IPython.core.excolors import exception_colors
117 from IPython.core.excolors import exception_colors
118 from IPython.utils import PyColorize
118 from IPython.utils import PyColorize
119 from IPython.utils import openpy
119 from IPython.utils import openpy
120 from IPython.utils import path as util_path
120 from IPython.utils import path as util_path
121 from IPython.utils import py3compat
121 from IPython.utils import py3compat
122 from IPython.utils.data import uniq_stable
122 from IPython.utils.data import uniq_stable
123 from IPython.utils.terminal import get_terminal_size
123 from IPython.utils.terminal import get_terminal_size
124 from logging import info, error, debug
124 from logging import info, error, debug
125
125
126 import IPython.utils.colorable as colorable
126 import IPython.utils.colorable as colorable
127
127
128 # Globals
128 # Globals
129 # amount of space to put line numbers before verbose tracebacks
129 # amount of space to put line numbers before verbose tracebacks
130 INDENT_SIZE = 8
130 INDENT_SIZE = 8
131
131
132 # Default color scheme. This is used, for example, by the traceback
132 # Default color scheme. This is used, for example, by the traceback
133 # formatter. When running in an actual IPython instance, the user's rc.colors
133 # formatter. When running in an actual IPython instance, the user's rc.colors
134 # value is used, but having a module global makes this functionality available
134 # value is used, but having a module global makes this functionality available
135 # to users of ultratb who are NOT running inside ipython.
135 # to users of ultratb who are NOT running inside ipython.
136 DEFAULT_SCHEME = 'NoColor'
136 DEFAULT_SCHEME = 'NoColor'
137
137
138 # ---------------------------------------------------------------------------
138 # ---------------------------------------------------------------------------
139 # Code begins
139 # Code begins
140
140
141 # Utility functions
141 # Utility functions
142 def inspect_error():
142 def inspect_error():
143 """Print a message about internal inspect errors.
143 """Print a message about internal inspect errors.
144
144
145 These are unfortunately quite common."""
145 These are unfortunately quite common."""
146
146
147 error('Internal Python error in the inspect module.\n'
147 error('Internal Python error in the inspect module.\n'
148 'Below is the traceback from this internal error.\n')
148 'Below is the traceback from this internal error.\n')
149
149
150
150
151 # This function is a monkeypatch we apply to the Python inspect module. We have
151 # This function is a monkeypatch we apply to the Python inspect module. We have
152 # now found when it's needed (see discussion on issue gh-1456), and we have a
152 # now found when it's needed (see discussion on issue gh-1456), and we have a
153 # test case (IPython.core.tests.test_ultratb.ChangedPyFileTest) that fails if
153 # test case (IPython.core.tests.test_ultratb.ChangedPyFileTest) that fails if
154 # the monkeypatch is not applied. TK, Aug 2012.
154 # the monkeypatch is not applied. TK, Aug 2012.
155 def findsource(object):
155 def findsource(object):
156 """Return the entire source file and starting line number for an object.
156 """Return the entire source file and starting line number for an object.
157
157
158 The argument may be a module, class, method, function, traceback, frame,
158 The argument may be a module, class, method, function, traceback, frame,
159 or code object. The source code is returned as a list of all the lines
159 or code object. The source code is returned as a list of all the lines
160 in the file and the line number indexes a line in that list. An IOError
160 in the file and the line number indexes a line in that list. An IOError
161 is raised if the source code cannot be retrieved.
161 is raised if the source code cannot be retrieved.
162
162
163 FIXED version with which we monkeypatch the stdlib to work around a bug."""
163 FIXED version with which we monkeypatch the stdlib to work around a bug."""
164
164
165 file = getsourcefile(object) or getfile(object)
165 file = getsourcefile(object) or getfile(object)
166 # If the object is a frame, then trying to get the globals dict from its
166 # If the object is a frame, then trying to get the globals dict from its
167 # module won't work. Instead, the frame object itself has the globals
167 # module won't work. Instead, the frame object itself has the globals
168 # dictionary.
168 # dictionary.
169 globals_dict = None
169 globals_dict = None
170 if inspect.isframe(object):
170 if inspect.isframe(object):
171 # XXX: can this ever be false?
171 # XXX: can this ever be false?
172 globals_dict = object.f_globals
172 globals_dict = object.f_globals
173 else:
173 else:
174 module = getmodule(object, file)
174 module = getmodule(object, file)
175 if module:
175 if module:
176 globals_dict = module.__dict__
176 globals_dict = module.__dict__
177 lines = linecache.getlines(file, globals_dict)
177 lines = linecache.getlines(file, globals_dict)
178 if not lines:
178 if not lines:
179 raise IOError('could not get source code')
179 raise IOError('could not get source code')
180
180
181 if ismodule(object):
181 if ismodule(object):
182 return lines, 0
182 return lines, 0
183
183
184 if isclass(object):
184 if isclass(object):
185 name = object.__name__
185 name = object.__name__
186 pat = re.compile(r'^(\s*)class\s*' + name + r'\b')
186 pat = re.compile(r'^(\s*)class\s*' + name + r'\b')
187 # make some effort to find the best matching class definition:
187 # make some effort to find the best matching class definition:
188 # use the one with the least indentation, which is the one
188 # use the one with the least indentation, which is the one
189 # that's most probably not inside a function definition.
189 # that's most probably not inside a function definition.
190 candidates = []
190 candidates = []
191 for i, line in enumerate(lines):
191 for i, line in enumerate(lines):
192 match = pat.match(line)
192 match = pat.match(line)
193 if match:
193 if match:
194 # if it's at toplevel, it's already the best one
194 # if it's at toplevel, it's already the best one
195 if line[0] == 'c':
195 if line[0] == 'c':
196 return lines, i
196 return lines, i
197 # else add whitespace to candidate list
197 # else add whitespace to candidate list
198 candidates.append((match.group(1), i))
198 candidates.append((match.group(1), i))
199 if candidates:
199 if candidates:
200 # this will sort by whitespace, and by line number,
200 # this will sort by whitespace, and by line number,
201 # less whitespace first
201 # less whitespace first
202 candidates.sort()
202 candidates.sort()
203 return lines, candidates[0][1]
203 return lines, candidates[0][1]
204 else:
204 else:
205 raise IOError('could not find class definition')
205 raise IOError('could not find class definition')
206
206
207 if ismethod(object):
207 if ismethod(object):
208 object = object.__func__
208 object = object.__func__
209 if isfunction(object):
209 if isfunction(object):
210 object = object.__code__
210 object = object.__code__
211 if istraceback(object):
211 if istraceback(object):
212 object = object.tb_frame
212 object = object.tb_frame
213 if isframe(object):
213 if isframe(object):
214 object = object.f_code
214 object = object.f_code
215 if iscode(object):
215 if iscode(object):
216 if not hasattr(object, 'co_firstlineno'):
216 if not hasattr(object, 'co_firstlineno'):
217 raise IOError('could not find function definition')
217 raise IOError('could not find function definition')
218 pat = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)')
218 pat = re.compile(r'^(\s*def\s)|(.*(?<!\w)lambda(:|\s))|^(\s*@)')
219 pmatch = pat.match
219 pmatch = pat.match
220 # fperez - fix: sometimes, co_firstlineno can give a number larger than
220 # fperez - fix: sometimes, co_firstlineno can give a number larger than
221 # the length of lines, which causes an error. Safeguard against that.
221 # the length of lines, which causes an error. Safeguard against that.
222 lnum = min(object.co_firstlineno, len(lines)) - 1
222 lnum = min(object.co_firstlineno, len(lines)) - 1
223 while lnum > 0:
223 while lnum > 0:
224 if pmatch(lines[lnum]):
224 if pmatch(lines[lnum]):
225 break
225 break
226 lnum -= 1
226 lnum -= 1
227
227
228 return lines, lnum
228 return lines, lnum
229 raise IOError('could not find code object')
229 raise IOError('could not find code object')
230
230
231
231
232 # This is a patched version of inspect.getargs that applies the (unmerged)
232 # This is a patched version of inspect.getargs that applies the (unmerged)
233 # patch for http://bugs.python.org/issue14611 by Stefano Taschini. This fixes
233 # patch for http://bugs.python.org/issue14611 by Stefano Taschini. This fixes
234 # https://github.com/ipython/ipython/issues/8205 and
234 # https://github.com/ipython/ipython/issues/8205 and
235 # https://github.com/ipython/ipython/issues/8293
235 # https://github.com/ipython/ipython/issues/8293
236 def getargs(co):
236 def getargs(co):
237 """Get information about the arguments accepted by a code object.
237 """Get information about the arguments accepted by a code object.
238
238
239 Three things are returned: (args, varargs, varkw), where 'args' is
239 Three things are returned: (args, varargs, varkw), where 'args' is
240 a list of argument names (possibly containing nested lists), and
240 a list of argument names (possibly containing nested lists), and
241 'varargs' and 'varkw' are the names of the * and ** arguments or None."""
241 'varargs' and 'varkw' are the names of the * and ** arguments or None."""
242 if not iscode(co):
242 if not iscode(co):
243 raise TypeError('{!r} is not a code object'.format(co))
243 raise TypeError('{!r} is not a code object'.format(co))
244
244
245 nargs = co.co_argcount
245 nargs = co.co_argcount
246 names = co.co_varnames
246 names = co.co_varnames
247 args = list(names[:nargs])
247 args = list(names[:nargs])
248 step = 0
248 step = 0
249
249
250 # The following acrobatics are for anonymous (tuple) arguments.
250 # The following acrobatics are for anonymous (tuple) arguments.
251 for i in range(nargs):
251 for i in range(nargs):
252 if args[i][:1] in ('', '.'):
252 if args[i][:1] in ('', '.'):
253 stack, remain, count = [], [], []
253 stack, remain, count = [], [], []
254 while step < len(co.co_code):
254 while step < len(co.co_code):
255 op = ord(co.co_code[step])
255 op = ord(co.co_code[step])
256 step = step + 1
256 step = step + 1
257 if op >= dis.HAVE_ARGUMENT:
257 if op >= dis.HAVE_ARGUMENT:
258 opname = dis.opname[op]
258 opname = dis.opname[op]
259 value = ord(co.co_code[step]) + ord(co.co_code[step+1])*256
259 value = ord(co.co_code[step]) + ord(co.co_code[step+1])*256
260 step = step + 2
260 step = step + 2
261 if opname in ('UNPACK_TUPLE', 'UNPACK_SEQUENCE'):
261 if opname in ('UNPACK_TUPLE', 'UNPACK_SEQUENCE'):
262 remain.append(value)
262 remain.append(value)
263 count.append(value)
263 count.append(value)
264 elif opname in ('STORE_FAST', 'STORE_DEREF'):
264 elif opname in ('STORE_FAST', 'STORE_DEREF'):
265 if op in dis.haslocal:
265 if op in dis.haslocal:
266 stack.append(co.co_varnames[value])
266 stack.append(co.co_varnames[value])
267 elif op in dis.hasfree:
267 elif op in dis.hasfree:
268 stack.append((co.co_cellvars + co.co_freevars)[value])
268 stack.append((co.co_cellvars + co.co_freevars)[value])
269 # Special case for sublists of length 1: def foo((bar))
269 # Special case for sublists of length 1: def foo((bar))
270 # doesn't generate the UNPACK_TUPLE bytecode, so if
270 # doesn't generate the UNPACK_TUPLE bytecode, so if
271 # `remain` is empty here, we have such a sublist.
271 # `remain` is empty here, we have such a sublist.
272 if not remain:
272 if not remain:
273 stack[0] = [stack[0]]
273 stack[0] = [stack[0]]
274 break
274 break
275 else:
275 else:
276 remain[-1] = remain[-1] - 1
276 remain[-1] = remain[-1] - 1
277 while remain[-1] == 0:
277 while remain[-1] == 0:
278 remain.pop()
278 remain.pop()
279 size = count.pop()
279 size = count.pop()
280 stack[-size:] = [stack[-size:]]
280 stack[-size:] = [stack[-size:]]
281 if not remain:
281 if not remain:
282 break
282 break
283 remain[-1] = remain[-1] - 1
283 remain[-1] = remain[-1] - 1
284 if not remain:
284 if not remain:
285 break
285 break
286 args[i] = stack[0]
286 args[i] = stack[0]
287
287
288 varargs = None
288 varargs = None
289 if co.co_flags & inspect.CO_VARARGS:
289 if co.co_flags & inspect.CO_VARARGS:
290 varargs = co.co_varnames[nargs]
290 varargs = co.co_varnames[nargs]
291 nargs = nargs + 1
291 nargs = nargs + 1
292 varkw = None
292 varkw = None
293 if co.co_flags & inspect.CO_VARKEYWORDS:
293 if co.co_flags & inspect.CO_VARKEYWORDS:
294 varkw = co.co_varnames[nargs]
294 varkw = co.co_varnames[nargs]
295 return inspect.Arguments(args, varargs, varkw)
295 return inspect.Arguments(args, varargs, varkw)
296
296
297
297
298 # Monkeypatch inspect to apply our bugfix.
298 # Monkeypatch inspect to apply our bugfix.
299 def with_patch_inspect(f):
299 def with_patch_inspect(f):
300 """
300 """
301 Deprecated since IPython 6.0
301 Deprecated since IPython 6.0
302 decorator for monkeypatching inspect.findsource
302 decorator for monkeypatching inspect.findsource
303 """
303 """
304
304
305 def wrapped(*args, **kwargs):
305 def wrapped(*args, **kwargs):
306 save_findsource = inspect.findsource
306 save_findsource = inspect.findsource
307 save_getargs = inspect.getargs
307 save_getargs = inspect.getargs
308 inspect.findsource = findsource
308 inspect.findsource = findsource
309 inspect.getargs = getargs
309 inspect.getargs = getargs
310 try:
310 try:
311 return f(*args, **kwargs)
311 return f(*args, **kwargs)
312 finally:
312 finally:
313 inspect.findsource = save_findsource
313 inspect.findsource = save_findsource
314 inspect.getargs = save_getargs
314 inspect.getargs = save_getargs
315
315
316 return wrapped
316 return wrapped
317
317
318
318
319 def fix_frame_records_filenames(records):
319 def fix_frame_records_filenames(records):
320 """Try to fix the filenames in each record from inspect.getinnerframes().
320 """Try to fix the filenames in each record from inspect.getinnerframes().
321
321
322 Particularly, modules loaded from within zip files have useless filenames
322 Particularly, modules loaded from within zip files have useless filenames
323 attached to their code object, and inspect.getinnerframes() just uses it.
323 attached to their code object, and inspect.getinnerframes() just uses it.
324 """
324 """
325 fixed_records = []
325 fixed_records = []
326 for frame, filename, line_no, func_name, lines, index in records:
326 for frame, filename, line_no, func_name, lines, index in records:
327 # Look inside the frame's globals dictionary for __file__,
327 # Look inside the frame's globals dictionary for __file__,
328 # which should be better. However, keep Cython filenames since
328 # which should be better. However, keep Cython filenames since
329 # we prefer the source filenames over the compiled .so file.
329 # we prefer the source filenames over the compiled .so file.
330 if not filename.endswith(('.pyx', '.pxd', '.pxi')):
330 if not filename.endswith(('.pyx', '.pxd', '.pxi')):
331 better_fn = frame.f_globals.get('__file__', None)
331 better_fn = frame.f_globals.get('__file__', None)
332 if isinstance(better_fn, str):
332 if isinstance(better_fn, str):
333 # Check the type just in case someone did something weird with
333 # Check the type just in case someone did something weird with
334 # __file__. It might also be None if the error occurred during
334 # __file__. It might also be None if the error occurred during
335 # import.
335 # import.
336 filename = better_fn
336 filename = better_fn
337 fixed_records.append((frame, filename, line_no, func_name, lines, index))
337 fixed_records.append((frame, filename, line_no, func_name, lines, index))
338 return fixed_records
338 return fixed_records
339
339
340
340
341 @with_patch_inspect
341 @with_patch_inspect
342 def _fixed_getinnerframes(etb, context=1, tb_offset=0):
342 def _fixed_getinnerframes(etb, context=1, tb_offset=0):
343 LNUM_POS, LINES_POS, INDEX_POS = 2, 4, 5
343 LNUM_POS, LINES_POS, INDEX_POS = 2, 4, 5
344
344
345 records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
345 records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
346 # If the error is at the console, don't build any context, since it would
346 # If the error is at the console, don't build any context, since it would
347 # otherwise produce 5 blank lines printed out (there is no file at the
347 # otherwise produce 5 blank lines printed out (there is no file at the
348 # console)
348 # console)
349 rec_check = records[tb_offset:]
349 rec_check = records[tb_offset:]
350 try:
350 try:
351 rname = rec_check[0][1]
351 rname = rec_check[0][1]
352 if rname == '<ipython console>' or rname.endswith('<string>'):
352 if rname == '<ipython console>' or rname.endswith('<string>'):
353 return rec_check
353 return rec_check
354 except IndexError:
354 except IndexError:
355 pass
355 pass
356
356
357 aux = traceback.extract_tb(etb)
357 aux = traceback.extract_tb(etb)
358 assert len(records) == len(aux)
358 assert len(records) == len(aux)
359 for i, (file, lnum, _, _) in enumerate(aux):
359 for i, (file, lnum, _, _) in enumerate(aux):
360 maybeStart = lnum - 1 - context // 2
360 maybeStart = lnum - 1 - context // 2
361 start = max(maybeStart, 0)
361 start = max(maybeStart, 0)
362 end = start + context
362 end = start + context
363 lines = linecache.getlines(file)[start:end]
363 lines = linecache.getlines(file)[start:end]
364 buf = list(records[i])
364 buf = list(records[i])
365 buf[LNUM_POS] = lnum
365 buf[LNUM_POS] = lnum
366 buf[INDEX_POS] = lnum - 1 - start
366 buf[INDEX_POS] = lnum - 1 - start
367 buf[LINES_POS] = lines
367 buf[LINES_POS] = lines
368 records[i] = tuple(buf)
368 records[i] = tuple(buf)
369 return records[tb_offset:]
369 return records[tb_offset:]
370
370
371 # Helper function -- largely belongs to VerboseTB, but we need the same
371 # Helper function -- largely belongs to VerboseTB, but we need the same
372 # functionality to produce a pseudo verbose TB for SyntaxErrors, so that they
372 # functionality to produce a pseudo verbose TB for SyntaxErrors, so that they
373 # can be recognized properly by ipython.el's py-traceback-line-re
373 # can be recognized properly by ipython.el's py-traceback-line-re
374 # (SyntaxErrors have to be treated specially because they have no traceback)
374 # (SyntaxErrors have to be treated specially because they have no traceback)
375
375
376
376
377 def _format_traceback_lines(lnum, index, lines, Colors, lvals=None, _line_format=(lambda x,_:x,None)):
377 def _format_traceback_lines(lnum, index, lines, Colors, lvals=None, _line_format=(lambda x,_:x,None)):
378 numbers_width = INDENT_SIZE - 1
378 numbers_width = INDENT_SIZE - 1
379 res = []
379 res = []
380 i = lnum - index
380 i = lnum - index
381
381
382 for line in lines:
382 for line in lines:
383 line = py3compat.cast_unicode(line)
383 line = py3compat.cast_unicode(line)
384
384
385 new_line, err = _line_format(line, 'str')
385 new_line, err = _line_format(line, 'str')
386 if not err: line = new_line
386 if not err: line = new_line
387
387
388 if i == lnum:
388 if i == lnum:
389 # This is the line with the error
389 # This is the line with the error
390 pad = numbers_width - len(str(i))
390 pad = numbers_width - len(str(i))
391 num = '%s%s' % (debugger.make_arrow(pad), str(lnum))
391 num = '%s%s' % (debugger.make_arrow(pad), str(lnum))
392 line = '%s%s%s %s%s' % (Colors.linenoEm, num,
392 line = '%s%s%s %s%s' % (Colors.linenoEm, num,
393 Colors.line, line, Colors.Normal)
393 Colors.line, line, Colors.Normal)
394 else:
394 else:
395 num = '%*s' % (numbers_width, i)
395 num = '%*s' % (numbers_width, i)
396 line = '%s%s%s %s' % (Colors.lineno, num,
396 line = '%s%s%s %s' % (Colors.lineno, num,
397 Colors.Normal, line)
397 Colors.Normal, line)
398
398
399 res.append(line)
399 res.append(line)
400 if lvals and i == lnum:
400 if lvals and i == lnum:
401 res.append(lvals + '\n')
401 res.append(lvals + '\n')
402 i = i + 1
402 i = i + 1
403 return res
403 return res
404
404
405 def is_recursion_error(etype, value, records):
405 def is_recursion_error(etype, value, records):
406 try:
406 try:
407 # RecursionError is new in Python 3.5
407 # RecursionError is new in Python 3.5
408 recursion_error_type = RecursionError
408 recursion_error_type = RecursionError
409 except NameError:
409 except NameError:
410 recursion_error_type = RuntimeError
410 recursion_error_type = RuntimeError
411
411
412 # The default recursion limit is 1000, but some of that will be taken up
412 # The default recursion limit is 1000, but some of that will be taken up
413 # by stack frames in IPython itself. >500 frames probably indicates
413 # by stack frames in IPython itself. >500 frames probably indicates
414 # a recursion error.
414 # a recursion error.
415 return (etype is recursion_error_type) \
415 return (etype is recursion_error_type) \
416 and "recursion" in str(value).lower() \
416 and "recursion" in str(value).lower() \
417 and len(records) > 500
417 and len(records) > 500
418
418
419 def find_recursion(etype, value, records):
419 def find_recursion(etype, value, records):
420 """Identify the repeating stack frames from a RecursionError traceback
420 """Identify the repeating stack frames from a RecursionError traceback
421
421
422 'records' is a list as returned by VerboseTB.get_records()
422 'records' is a list as returned by VerboseTB.get_records()
423
423
424 Returns (last_unique, repeat_length)
424 Returns (last_unique, repeat_length)
425 """
425 """
426 # This involves a bit of guesswork - we want to show enough of the traceback
426 # This involves a bit of guesswork - we want to show enough of the traceback
427 # to indicate where the recursion is occurring. We guess that the innermost
427 # to indicate where the recursion is occurring. We guess that the innermost
428 # quarter of the traceback (250 frames by default) is repeats, and find the
428 # quarter of the traceback (250 frames by default) is repeats, and find the
429 # first frame (from in to out) that looks different.
429 # first frame (from in to out) that looks different.
430 if not is_recursion_error(etype, value, records):
430 if not is_recursion_error(etype, value, records):
431 return len(records), 0
431 return len(records), 0
432
432
433 # Select filename, lineno, func_name to track frames with
433 # Select filename, lineno, func_name to track frames with
434 records = [r[1:4] for r in records]
434 records = [r[1:4] for r in records]
435 inner_frames = records[-(len(records)//4):]
435 inner_frames = records[-(len(records)//4):]
436 frames_repeated = set(inner_frames)
436 frames_repeated = set(inner_frames)
437
437
438 last_seen_at = {}
438 last_seen_at = {}
439 longest_repeat = 0
439 longest_repeat = 0
440 i = len(records)
440 i = len(records)
441 for frame in reversed(records):
441 for frame in reversed(records):
442 i -= 1
442 i -= 1
443 if frame not in frames_repeated:
443 if frame not in frames_repeated:
444 last_unique = i
444 last_unique = i
445 break
445 break
446
446
447 if frame in last_seen_at:
447 if frame in last_seen_at:
448 distance = last_seen_at[frame] - i
448 distance = last_seen_at[frame] - i
449 longest_repeat = max(longest_repeat, distance)
449 longest_repeat = max(longest_repeat, distance)
450
450
451 last_seen_at[frame] = i
451 last_seen_at[frame] = i
452 else:
452 else:
453 last_unique = 0 # The whole traceback was recursion
453 last_unique = 0 # The whole traceback was recursion
454
454
455 return last_unique, longest_repeat
455 return last_unique, longest_repeat
456
456
457 #---------------------------------------------------------------------------
457 #---------------------------------------------------------------------------
458 # Module classes
458 # Module classes
459 class TBTools(colorable.Colorable):
459 class TBTools(colorable.Colorable):
460 """Basic tools used by all traceback printer classes."""
460 """Basic tools used by all traceback printer classes."""
461
461
462 # Number of frames to skip when reporting tracebacks
462 # Number of frames to skip when reporting tracebacks
463 tb_offset = 0
463 tb_offset = 0
464
464
465 def __init__(self, color_scheme='NoColor', call_pdb=False, ostream=None, parent=None, config=None):
465 def __init__(self, color_scheme='NoColor', call_pdb=False, ostream=None, parent=None, config=None):
466 # Whether to call the interactive pdb debugger after printing
466 # Whether to call the interactive pdb debugger after printing
467 # tracebacks or not
467 # tracebacks or not
468 super(TBTools, self).__init__(parent=parent, config=config)
468 super(TBTools, self).__init__(parent=parent, config=config)
469 self.call_pdb = call_pdb
469 self.call_pdb = call_pdb
470
470
471 # Output stream to write to. Note that we store the original value in
471 # Output stream to write to. Note that we store the original value in
472 # a private attribute and then make the public ostream a property, so
472 # a private attribute and then make the public ostream a property, so
473 # that we can delay accessing sys.stdout until runtime. The way
473 # that we can delay accessing sys.stdout until runtime. The way
474 # things are written now, the sys.stdout object is dynamically managed
474 # things are written now, the sys.stdout object is dynamically managed
475 # so a reference to it should NEVER be stored statically. This
475 # so a reference to it should NEVER be stored statically. This
476 # property approach confines this detail to a single location, and all
476 # property approach confines this detail to a single location, and all
477 # subclasses can simply access self.ostream for writing.
477 # subclasses can simply access self.ostream for writing.
478 self._ostream = ostream
478 self._ostream = ostream
479
479
480 # Create color table
480 # Create color table
481 self.color_scheme_table = exception_colors()
481 self.color_scheme_table = exception_colors()
482
482
483 self.set_colors(color_scheme)
483 self.set_colors(color_scheme)
484 self.old_scheme = color_scheme # save initial value for toggles
484 self.old_scheme = color_scheme # save initial value for toggles
485
485
486 if call_pdb:
486 if call_pdb:
487 self.pdb = debugger.Pdb()
487 self.pdb = debugger.Pdb()
488 else:
488 else:
489 self.pdb = None
489 self.pdb = None
490
490
491 def _get_ostream(self):
491 def _get_ostream(self):
492 """Output stream that exceptions are written to.
492 """Output stream that exceptions are written to.
493
493
494 Valid values are:
494 Valid values are:
495
495
496 - None: the default, which means that IPython will dynamically resolve
496 - None: the default, which means that IPython will dynamically resolve
497 to sys.stdout. This ensures compatibility with most tools, including
497 to sys.stdout. This ensures compatibility with most tools, including
498 Windows (where plain stdout doesn't recognize ANSI escapes).
498 Windows (where plain stdout doesn't recognize ANSI escapes).
499
499
500 - Any object with 'write' and 'flush' attributes.
500 - Any object with 'write' and 'flush' attributes.
501 """
501 """
502 return sys.stdout if self._ostream is None else self._ostream
502 return sys.stdout if self._ostream is None else self._ostream
503
503
504 def _set_ostream(self, val):
504 def _set_ostream(self, val):
505 assert val is None or (hasattr(val, 'write') and hasattr(val, 'flush'))
505 assert val is None or (hasattr(val, 'write') and hasattr(val, 'flush'))
506 self._ostream = val
506 self._ostream = val
507
507
508 ostream = property(_get_ostream, _set_ostream)
508 ostream = property(_get_ostream, _set_ostream)
509
509
510 def set_colors(self, *args, **kw):
510 def set_colors(self, *args, **kw):
511 """Shorthand access to the color table scheme selector method."""
511 """Shorthand access to the color table scheme selector method."""
512
512
513 # Set own color table
513 # Set own color table
514 self.color_scheme_table.set_active_scheme(*args, **kw)
514 self.color_scheme_table.set_active_scheme(*args, **kw)
515 # for convenience, set Colors to the active scheme
515 # for convenience, set Colors to the active scheme
516 self.Colors = self.color_scheme_table.active_colors
516 self.Colors = self.color_scheme_table.active_colors
517 # Also set colors of debugger
517 # Also set colors of debugger
518 if hasattr(self, 'pdb') and self.pdb is not None:
518 if hasattr(self, 'pdb') and self.pdb is not None:
519 self.pdb.set_colors(*args, **kw)
519 self.pdb.set_colors(*args, **kw)
520
520
521 def color_toggle(self):
521 def color_toggle(self):
522 """Toggle between the currently active color scheme and NoColor."""
522 """Toggle between the currently active color scheme and NoColor."""
523
523
524 if self.color_scheme_table.active_scheme_name == 'NoColor':
524 if self.color_scheme_table.active_scheme_name == 'NoColor':
525 self.color_scheme_table.set_active_scheme(self.old_scheme)
525 self.color_scheme_table.set_active_scheme(self.old_scheme)
526 self.Colors = self.color_scheme_table.active_colors
526 self.Colors = self.color_scheme_table.active_colors
527 else:
527 else:
528 self.old_scheme = self.color_scheme_table.active_scheme_name
528 self.old_scheme = self.color_scheme_table.active_scheme_name
529 self.color_scheme_table.set_active_scheme('NoColor')
529 self.color_scheme_table.set_active_scheme('NoColor')
530 self.Colors = self.color_scheme_table.active_colors
530 self.Colors = self.color_scheme_table.active_colors
531
531
532 def stb2text(self, stb):
532 def stb2text(self, stb):
533 """Convert a structured traceback (a list) to a string."""
533 """Convert a structured traceback (a list) to a string."""
534 return '\n'.join(stb)
534 return '\n'.join(stb)
535
535
536 def text(self, etype, value, tb, tb_offset=None, context=5):
536 def text(self, etype, value, tb, tb_offset=None, context=5):
537 """Return formatted traceback.
537 """Return formatted traceback.
538
538
539 Subclasses may override this if they add extra arguments.
539 Subclasses may override this if they add extra arguments.
540 """
540 """
541 tb_list = self.structured_traceback(etype, value, tb,
541 tb_list = self.structured_traceback(etype, value, tb,
542 tb_offset, context)
542 tb_offset, context)
543 return self.stb2text(tb_list)
543 return self.stb2text(tb_list)
544
544
545 def structured_traceback(self, etype, evalue, tb, tb_offset=None,
545 def structured_traceback(self, etype, evalue, tb, tb_offset=None,
546 context=5, mode=None):
546 context=5, mode=None):
547 """Return a list of traceback frames.
547 """Return a list of traceback frames.
548
548
549 Must be implemented by each class.
549 Must be implemented by each class.
550 """
550 """
551 raise NotImplementedError()
551 raise NotImplementedError()
552
552
553
553
554 #---------------------------------------------------------------------------
554 #---------------------------------------------------------------------------
555 class ListTB(TBTools):
555 class ListTB(TBTools):
556 """Print traceback information from a traceback list, with optional color.
556 """Print traceback information from a traceback list, with optional color.
557
557
558 Calling requires 3 arguments: (etype, evalue, elist)
558 Calling requires 3 arguments: (etype, evalue, elist)
559 as would be obtained by::
559 as would be obtained by::
560
560
561 etype, evalue, tb = sys.exc_info()
561 etype, evalue, tb = sys.exc_info()
562 if tb:
562 if tb:
563 elist = traceback.extract_tb(tb)
563 elist = traceback.extract_tb(tb)
564 else:
564 else:
565 elist = None
565 elist = None
566
566
567 It can thus be used by programs which need to process the traceback before
567 It can thus be used by programs which need to process the traceback before
568 printing (such as console replacements based on the code module from the
568 printing (such as console replacements based on the code module from the
569 standard library).
569 standard library).
570
570
571 Because they are meant to be called without a full traceback (only a
571 Because they are meant to be called without a full traceback (only a
572 list), instances of this class can't call the interactive pdb debugger."""
572 list), instances of this class can't call the interactive pdb debugger."""
573
573
574 def __init__(self, color_scheme='NoColor', call_pdb=False, ostream=None, parent=None, config=None):
574 def __init__(self, color_scheme='NoColor', call_pdb=False, ostream=None, parent=None, config=None):
575 TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
575 TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
576 ostream=ostream, parent=parent,config=config)
576 ostream=ostream, parent=parent,config=config)
577
577
578 def __call__(self, etype, value, elist):
578 def __call__(self, etype, value, elist):
579 self.ostream.flush()
579 self.ostream.flush()
580 self.ostream.write(self.text(etype, value, elist))
580 self.ostream.write(self.text(etype, value, elist))
581 self.ostream.write('\n')
581 self.ostream.write('\n')
582
582
583 def structured_traceback(self, etype, value, elist, tb_offset=None,
583 def structured_traceback(self, etype, value, elist, tb_offset=None,
584 context=5):
584 context=5):
585 """Return a color formatted string with the traceback info.
585 """Return a color formatted string with the traceback info.
586
586
587 Parameters
587 Parameters
588 ----------
588 ----------
589 etype : exception type
589 etype : exception type
590 Type of the exception raised.
590 Type of the exception raised.
591
591
592 value : object
592 value : object
593 Data stored in the exception
593 Data stored in the exception
594
594
595 elist : list
595 elist : list
596 List of frames, see class docstring for details.
596 List of frames, see class docstring for details.
597
597
598 tb_offset : int, optional
598 tb_offset : int, optional
599 Number of frames in the traceback to skip. If not given, the
599 Number of frames in the traceback to skip. If not given, the
600 instance value is used (set in constructor).
600 instance value is used (set in constructor).
601
601
602 context : int, optional
602 context : int, optional
603 Number of lines of context information to print.
603 Number of lines of context information to print.
604
604
605 Returns
605 Returns
606 -------
606 -------
607 String with formatted exception.
607 String with formatted exception.
608 """
608 """
609 tb_offset = self.tb_offset if tb_offset is None else tb_offset
609 tb_offset = self.tb_offset if tb_offset is None else tb_offset
610 Colors = self.Colors
610 Colors = self.Colors
611 out_list = []
611 out_list = []
612 if elist:
612 if elist:
613
613
614 if tb_offset and len(elist) > tb_offset:
614 if tb_offset and len(elist) > tb_offset:
615 elist = elist[tb_offset:]
615 elist = elist[tb_offset:]
616
616
617 out_list.append('Traceback %s(most recent call last)%s:' %
617 out_list.append('Traceback %s(most recent call last)%s:' %
618 (Colors.normalEm, Colors.Normal) + '\n')
618 (Colors.normalEm, Colors.Normal) + '\n')
619 out_list.extend(self._format_list(elist))
619 out_list.extend(self._format_list(elist))
620 # The exception info should be a single entry in the list.
620 # The exception info should be a single entry in the list.
621 lines = ''.join(self._format_exception_only(etype, value))
621 lines = ''.join(self._format_exception_only(etype, value))
622 out_list.append(lines)
622 out_list.append(lines)
623
623
624 # Note: this code originally read:
624 # Note: this code originally read:
625
625
626 ## for line in lines[:-1]:
626 ## for line in lines[:-1]:
627 ## out_list.append(" "+line)
627 ## out_list.append(" "+line)
628 ## out_list.append(lines[-1])
628 ## out_list.append(lines[-1])
629
629
630 # This means it was indenting everything but the last line by a little
630 # This means it was indenting everything but the last line by a little
631 # bit. I've disabled this for now, but if we see ugliness somewhere we
631 # bit. I've disabled this for now, but if we see ugliness somewhere we
632 # can restore it.
632 # can restore it.
633
633
634 return out_list
634 return out_list
635
635
636 def _format_list(self, extracted_list):
636 def _format_list(self, extracted_list):
637 """Format a list of traceback entry tuples for printing.
637 """Format a list of traceback entry tuples for printing.
638
638
639 Given a list of tuples as returned by extract_tb() or
639 Given a list of tuples as returned by extract_tb() or
640 extract_stack(), return a list of strings ready for printing.
640 extract_stack(), return a list of strings ready for printing.
641 Each string in the resulting list corresponds to the item with the
641 Each string in the resulting list corresponds to the item with the
642 same index in the argument list. Each string ends in a newline;
642 same index in the argument list. Each string ends in a newline;
643 the strings may contain internal newlines as well, for those items
643 the strings may contain internal newlines as well, for those items
644 whose source text line is not None.
644 whose source text line is not None.
645
645
646 Lifted almost verbatim from traceback.py
646 Lifted almost verbatim from traceback.py
647 """
647 """
648
648
649 Colors = self.Colors
649 Colors = self.Colors
650 list = []
650 list = []
651 for filename, lineno, name, line in extracted_list[:-1]:
651 for filename, lineno, name, line in extracted_list[:-1]:
652 item = ' File %s"%s"%s, line %s%d%s, in %s%s%s\n' % \
652 item = ' File %s"%s"%s, line %s%d%s, in %s%s%s\n' % \
653 (Colors.filename, filename, Colors.Normal,
653 (Colors.filename, filename, Colors.Normal,
654 Colors.lineno, lineno, Colors.Normal,
654 Colors.lineno, lineno, Colors.Normal,
655 Colors.name, name, Colors.Normal)
655 Colors.name, name, Colors.Normal)
656 if line:
656 if line:
657 item += ' %s\n' % line.strip()
657 item += ' %s\n' % line.strip()
658 list.append(item)
658 list.append(item)
659 # Emphasize the last entry
659 # Emphasize the last entry
660 filename, lineno, name, line = extracted_list[-1]
660 filename, lineno, name, line = extracted_list[-1]
661 item = '%s File %s"%s"%s, line %s%d%s, in %s%s%s%s\n' % \
661 item = '%s File %s"%s"%s, line %s%d%s, in %s%s%s%s\n' % \
662 (Colors.normalEm,
662 (Colors.normalEm,
663 Colors.filenameEm, filename, Colors.normalEm,
663 Colors.filenameEm, filename, Colors.normalEm,
664 Colors.linenoEm, lineno, Colors.normalEm,
664 Colors.linenoEm, lineno, Colors.normalEm,
665 Colors.nameEm, name, Colors.normalEm,
665 Colors.nameEm, name, Colors.normalEm,
666 Colors.Normal)
666 Colors.Normal)
667 if line:
667 if line:
668 item += '%s %s%s\n' % (Colors.line, line.strip(),
668 item += '%s %s%s\n' % (Colors.line, line.strip(),
669 Colors.Normal)
669 Colors.Normal)
670 list.append(item)
670 list.append(item)
671 return list
671 return list
672
672
673 def _format_exception_only(self, etype, value):
673 def _format_exception_only(self, etype, value):
674 """Format the exception part of a traceback.
674 """Format the exception part of a traceback.
675
675
676 The arguments are the exception type and value such as given by
676 The arguments are the exception type and value such as given by
677 sys.exc_info()[:2]. The return value is a list of strings, each ending
677 sys.exc_info()[:2]. The return value is a list of strings, each ending
678 in a newline. Normally, the list contains a single string; however,
678 in a newline. Normally, the list contains a single string; however,
679 for SyntaxError exceptions, it contains several lines that (when
679 for SyntaxError exceptions, it contains several lines that (when
680 printed) display detailed information about where the syntax error
680 printed) display detailed information about where the syntax error
681 occurred. The message indicating which exception occurred is the
681 occurred. The message indicating which exception occurred is the
682 always last string in the list.
682 always last string in the list.
683
683
684 Also lifted nearly verbatim from traceback.py
684 Also lifted nearly verbatim from traceback.py
685 """
685 """
686 have_filedata = False
686 have_filedata = False
687 Colors = self.Colors
687 Colors = self.Colors
688 list = []
688 list = []
689 stype = py3compat.cast_unicode(Colors.excName + etype.__name__ + Colors.Normal)
689 stype = py3compat.cast_unicode(Colors.excName + etype.__name__ + Colors.Normal)
690 if value is None:
690 if value is None:
691 # Not sure if this can still happen in Python 2.6 and above
691 # Not sure if this can still happen in Python 2.6 and above
692 list.append(stype + '\n')
692 list.append(stype + '\n')
693 else:
693 else:
694 if issubclass(etype, SyntaxError):
694 if issubclass(etype, SyntaxError):
695 have_filedata = True
695 have_filedata = True
696 if not value.filename: value.filename = "<string>"
696 if not value.filename: value.filename = "<string>"
697 if value.lineno:
697 if value.lineno:
698 lineno = value.lineno
698 lineno = value.lineno
699 textline = linecache.getline(value.filename, value.lineno)
699 textline = linecache.getline(value.filename, value.lineno)
700 else:
700 else:
701 lineno = 'unknown'
701 lineno = 'unknown'
702 textline = ''
702 textline = ''
703 list.append('%s File %s"%s"%s, line %s%s%s\n' % \
703 list.append('%s File %s"%s"%s, line %s%s%s\n' % \
704 (Colors.normalEm,
704 (Colors.normalEm,
705 Colors.filenameEm, py3compat.cast_unicode(value.filename), Colors.normalEm,
705 Colors.filenameEm, py3compat.cast_unicode(value.filename), Colors.normalEm,
706 Colors.linenoEm, lineno, Colors.Normal ))
706 Colors.linenoEm, lineno, Colors.Normal ))
707 if textline == '':
707 if textline == '':
708 textline = py3compat.cast_unicode(value.text, "utf-8")
708 textline = py3compat.cast_unicode(value.text, "utf-8")
709
709
710 if textline is not None:
710 if textline is not None:
711 i = 0
711 i = 0
712 while i < len(textline) and textline[i].isspace():
712 while i < len(textline) and textline[i].isspace():
713 i += 1
713 i += 1
714 list.append('%s %s%s\n' % (Colors.line,
714 list.append('%s %s%s\n' % (Colors.line,
715 textline.strip(),
715 textline.strip(),
716 Colors.Normal))
716 Colors.Normal))
717 if value.offset is not None:
717 if value.offset is not None:
718 s = ' '
718 s = ' '
719 for c in textline[i:value.offset - 1]:
719 for c in textline[i:value.offset - 1]:
720 if c.isspace():
720 if c.isspace():
721 s += c
721 s += c
722 else:
722 else:
723 s += ' '
723 s += ' '
724 list.append('%s%s^%s\n' % (Colors.caret, s,
724 list.append('%s%s^%s\n' % (Colors.caret, s,
725 Colors.Normal))
725 Colors.Normal))
726
726
727 try:
727 try:
728 s = value.msg
728 s = value.msg
729 except Exception:
729 except Exception:
730 s = self._some_str(value)
730 s = self._some_str(value)
731 if s:
731 if s:
732 list.append('%s%s:%s %s\n' % (stype, Colors.excName,
732 list.append('%s%s:%s %s\n' % (stype, Colors.excName,
733 Colors.Normal, s))
733 Colors.Normal, s))
734 else:
734 else:
735 list.append('%s\n' % stype)
735 list.append('%s\n' % stype)
736
736
737 # sync with user hooks
737 # sync with user hooks
738 if have_filedata:
738 if have_filedata:
739 ipinst = get_ipython()
739 ipinst = get_ipython()
740 if ipinst is not None:
740 if ipinst is not None:
741 ipinst.hooks.synchronize_with_editor(value.filename, value.lineno, 0)
741 ipinst.hooks.synchronize_with_editor(value.filename, value.lineno, 0)
742
742
743 return list
743 return list
744
744
745 def get_exception_only(self, etype, value):
745 def get_exception_only(self, etype, value):
746 """Only print the exception type and message, without a traceback.
746 """Only print the exception type and message, without a traceback.
747
747
748 Parameters
748 Parameters
749 ----------
749 ----------
750 etype : exception type
750 etype : exception type
751 value : exception value
751 value : exception value
752 """
752 """
753 return ListTB.structured_traceback(self, etype, value, [])
753 return ListTB.structured_traceback(self, etype, value, [])
754
754
755 def show_exception_only(self, etype, evalue):
755 def show_exception_only(self, etype, evalue):
756 """Only print the exception type and message, without a traceback.
756 """Only print the exception type and message, without a traceback.
757
757
758 Parameters
758 Parameters
759 ----------
759 ----------
760 etype : exception type
760 etype : exception type
761 value : exception value
761 value : exception value
762 """
762 """
763 # This method needs to use __call__ from *this* class, not the one from
763 # This method needs to use __call__ from *this* class, not the one from
764 # a subclass whose signature or behavior may be different
764 # a subclass whose signature or behavior may be different
765 ostream = self.ostream
765 ostream = self.ostream
766 ostream.flush()
766 ostream.flush()
767 ostream.write('\n'.join(self.get_exception_only(etype, evalue)))
767 ostream.write('\n'.join(self.get_exception_only(etype, evalue)))
768 ostream.flush()
768 ostream.flush()
769
769
770 def _some_str(self, value):
770 def _some_str(self, value):
771 # Lifted from traceback.py
771 # Lifted from traceback.py
772 try:
772 try:
773 return py3compat.cast_unicode(str(value))
773 return py3compat.cast_unicode(str(value))
774 except:
774 except:
775 return u'<unprintable %s object>' % type(value).__name__
775 return u'<unprintable %s object>' % type(value).__name__
776
776
777
777
778 #----------------------------------------------------------------------------
778 #----------------------------------------------------------------------------
779 class VerboseTB(TBTools):
779 class VerboseTB(TBTools):
780 """A port of Ka-Ping Yee's cgitb.py module that outputs color text instead
780 """A port of Ka-Ping Yee's cgitb.py module that outputs color text instead
781 of HTML. Requires inspect and pydoc. Crazy, man.
781 of HTML. Requires inspect and pydoc. Crazy, man.
782
782
783 Modified version which optionally strips the topmost entries from the
783 Modified version which optionally strips the topmost entries from the
784 traceback, to be used with alternate interpreters (because their own code
784 traceback, to be used with alternate interpreters (because their own code
785 would appear in the traceback)."""
785 would appear in the traceback)."""
786
786
787 def __init__(self, color_scheme='Linux', call_pdb=False, ostream=None,
787 def __init__(self, color_scheme='Linux', call_pdb=False, ostream=None,
788 tb_offset=0, long_header=False, include_vars=True,
788 tb_offset=0, long_header=False, include_vars=True,
789 check_cache=None, debugger_cls = None,
789 check_cache=None, debugger_cls = None,
790 parent=None, config=None):
790 parent=None, config=None):
791 """Specify traceback offset, headers and color scheme.
791 """Specify traceback offset, headers and color scheme.
792
792
793 Define how many frames to drop from the tracebacks. Calling it with
793 Define how many frames to drop from the tracebacks. Calling it with
794 tb_offset=1 allows use of this handler in interpreters which will have
794 tb_offset=1 allows use of this handler in interpreters which will have
795 their own code at the top of the traceback (VerboseTB will first
795 their own code at the top of the traceback (VerboseTB will first
796 remove that frame before printing the traceback info)."""
796 remove that frame before printing the traceback info)."""
797 TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
797 TBTools.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
798 ostream=ostream, parent=parent, config=config)
798 ostream=ostream, parent=parent, config=config)
799 self.tb_offset = tb_offset
799 self.tb_offset = tb_offset
800 self.long_header = long_header
800 self.long_header = long_header
801 self.include_vars = include_vars
801 self.include_vars = include_vars
802 # By default we use linecache.checkcache, but the user can provide a
802 # By default we use linecache.checkcache, but the user can provide a
803 # different check_cache implementation. This is used by the IPython
803 # different check_cache implementation. This is used by the IPython
804 # kernel to provide tracebacks for interactive code that is cached,
804 # kernel to provide tracebacks for interactive code that is cached,
805 # by a compiler instance that flushes the linecache but preserves its
805 # by a compiler instance that flushes the linecache but preserves its
806 # own code cache.
806 # own code cache.
807 if check_cache is None:
807 if check_cache is None:
808 check_cache = linecache.checkcache
808 check_cache = linecache.checkcache
809 self.check_cache = check_cache
809 self.check_cache = check_cache
810
810
811 self.debugger_cls = debugger_cls or debugger.Pdb
811 self.debugger_cls = debugger_cls or debugger.Pdb
812
812
813 def format_records(self, records, last_unique, recursion_repeat):
813 def format_records(self, records, last_unique, recursion_repeat):
814 """Format the stack frames of the traceback"""
814 """Format the stack frames of the traceback"""
815 frames = []
815 frames = []
816 for r in records[:last_unique+recursion_repeat+1]:
816 for r in records[:last_unique+recursion_repeat+1]:
817 #print '*** record:',file,lnum,func,lines,index # dbg
817 #print '*** record:',file,lnum,func,lines,index # dbg
818 frames.append(self.format_record(*r))
818 frames.append(self.format_record(*r))
819
819
820 if recursion_repeat:
820 if recursion_repeat:
821 frames.append('... last %d frames repeated, from the frame below ...\n' % recursion_repeat)
821 frames.append('... last %d frames repeated, from the frame below ...\n' % recursion_repeat)
822 frames.append(self.format_record(*records[last_unique+recursion_repeat+1]))
822 frames.append(self.format_record(*records[last_unique+recursion_repeat+1]))
823
823
824 return frames
824 return frames
825
825
826 def format_record(self, frame, file, lnum, func, lines, index):
826 def format_record(self, frame, file, lnum, func, lines, index):
827 """Format a single stack frame"""
827 """Format a single stack frame"""
828 Colors = self.Colors # just a shorthand + quicker name lookup
828 Colors = self.Colors # just a shorthand + quicker name lookup
829 ColorsNormal = Colors.Normal # used a lot
829 ColorsNormal = Colors.Normal # used a lot
830 col_scheme = self.color_scheme_table.active_scheme_name
830 col_scheme = self.color_scheme_table.active_scheme_name
831 indent = ' ' * INDENT_SIZE
831 indent = ' ' * INDENT_SIZE
832 em_normal = '%s\n%s%s' % (Colors.valEm, indent, ColorsNormal)
832 em_normal = '%s\n%s%s' % (Colors.valEm, indent, ColorsNormal)
833 undefined = '%sundefined%s' % (Colors.em, ColorsNormal)
833 undefined = '%sundefined%s' % (Colors.em, ColorsNormal)
834 tpl_link = '%s%%s%s' % (Colors.filenameEm, ColorsNormal)
834 tpl_link = '%s%%s%s' % (Colors.filenameEm, ColorsNormal)
835 tpl_call = 'in %s%%s%s%%s%s' % (Colors.vName, Colors.valEm,
835 tpl_call = 'in %s%%s%s%%s%s' % (Colors.vName, Colors.valEm,
836 ColorsNormal)
836 ColorsNormal)
837 tpl_call_fail = 'in %s%%s%s(***failed resolving arguments***)%s' % \
837 tpl_call_fail = 'in %s%%s%s(***failed resolving arguments***)%s' % \
838 (Colors.vName, Colors.valEm, ColorsNormal)
838 (Colors.vName, Colors.valEm, ColorsNormal)
839 tpl_local_var = '%s%%s%s' % (Colors.vName, ColorsNormal)
839 tpl_local_var = '%s%%s%s' % (Colors.vName, ColorsNormal)
840 tpl_global_var = '%sglobal%s %s%%s%s' % (Colors.em, ColorsNormal,
840 tpl_global_var = '%sglobal%s %s%%s%s' % (Colors.em, ColorsNormal,
841 Colors.vName, ColorsNormal)
841 Colors.vName, ColorsNormal)
842 tpl_name_val = '%%s %s= %%s%s' % (Colors.valEm, ColorsNormal)
842 tpl_name_val = '%%s %s= %%s%s' % (Colors.valEm, ColorsNormal)
843
843
844 tpl_line = '%s%%s%s %%s' % (Colors.lineno, ColorsNormal)
844 tpl_line = '%s%%s%s %%s' % (Colors.lineno, ColorsNormal)
845 tpl_line_em = '%s%%s%s %%s%s' % (Colors.linenoEm, Colors.line,
845 tpl_line_em = '%s%%s%s %%s%s' % (Colors.linenoEm, Colors.line,
846 ColorsNormal)
846 ColorsNormal)
847
847
848 abspath = os.path.abspath
848 abspath = os.path.abspath
849
849
850
850
851 if not file:
851 if not file:
852 file = '?'
852 file = '?'
853 elif file.startswith(str("<")) and file.endswith(str(">")):
853 elif file.startswith(str("<")) and file.endswith(str(">")):
854 # Not a real filename, no problem...
854 # Not a real filename, no problem...
855 pass
855 pass
856 elif not os.path.isabs(file):
856 elif not os.path.isabs(file):
857 # Try to make the filename absolute by trying all
857 # Try to make the filename absolute by trying all
858 # sys.path entries (which is also what linecache does)
858 # sys.path entries (which is also what linecache does)
859 for dirname in sys.path:
859 for dirname in sys.path:
860 try:
860 try:
861 fullname = os.path.join(dirname, file)
861 fullname = os.path.join(dirname, file)
862 if os.path.isfile(fullname):
862 if os.path.isfile(fullname):
863 file = os.path.abspath(fullname)
863 file = os.path.abspath(fullname)
864 break
864 break
865 except Exception:
865 except Exception:
866 # Just in case that sys.path contains very
866 # Just in case that sys.path contains very
867 # strange entries...
867 # strange entries...
868 pass
868 pass
869
869
870 file = py3compat.cast_unicode(file, util_path.fs_encoding)
870 file = py3compat.cast_unicode(file, util_path.fs_encoding)
871 link = tpl_link % util_path.compress_user(file)
871 link = tpl_link % util_path.compress_user(file)
872 args, varargs, varkw, locals = inspect.getargvalues(frame)
872 args, varargs, varkw, locals = inspect.getargvalues(frame)
873
873
874 if func == '?':
874 if func == '?':
875 call = ''
875 call = ''
876 else:
876 else:
877 # Decide whether to include variable details or not
877 # Decide whether to include variable details or not
878 var_repr = self.include_vars and eqrepr or nullrepr
878 var_repr = self.include_vars and eqrepr or nullrepr
879 try:
879 try:
880 call = tpl_call % (func, inspect.formatargvalues(args,
880 call = tpl_call % (func, inspect.formatargvalues(args,
881 varargs, varkw,
881 varargs, varkw,
882 locals, formatvalue=var_repr))
882 locals, formatvalue=var_repr))
883 except KeyError:
883 except KeyError:
884 # This happens in situations like errors inside generator
884 # This happens in situations like errors inside generator
885 # expressions, where local variables are listed in the
885 # expressions, where local variables are listed in the
886 # line, but can't be extracted from the frame. I'm not
886 # line, but can't be extracted from the frame. I'm not
887 # 100% sure this isn't actually a bug in inspect itself,
887 # 100% sure this isn't actually a bug in inspect itself,
888 # but since there's no info for us to compute with, the
888 # but since there's no info for us to compute with, the
889 # best we can do is report the failure and move on. Here
889 # best we can do is report the failure and move on. Here
890 # we must *not* call any traceback construction again,
890 # we must *not* call any traceback construction again,
891 # because that would mess up use of %debug later on. So we
891 # because that would mess up use of %debug later on. So we
892 # simply report the failure and move on. The only
892 # simply report the failure and move on. The only
893 # limitation will be that this frame won't have locals
893 # limitation will be that this frame won't have locals
894 # listed in the call signature. Quite subtle problem...
894 # listed in the call signature. Quite subtle problem...
895 # I can't think of a good way to validate this in a unit
895 # I can't think of a good way to validate this in a unit
896 # test, but running a script consisting of:
896 # test, but running a script consisting of:
897 # dict( (k,v.strip()) for (k,v) in range(10) )
897 # dict( (k,v.strip()) for (k,v) in range(10) )
898 # will illustrate the error, if this exception catch is
898 # will illustrate the error, if this exception catch is
899 # disabled.
899 # disabled.
900 call = tpl_call_fail % func
900 call = tpl_call_fail % func
901
901
902 # Don't attempt to tokenize binary files.
902 # Don't attempt to tokenize binary files.
903 if file.endswith(('.so', '.pyd', '.dll')):
903 if file.endswith(('.so', '.pyd', '.dll')):
904 return '%s %s\n' % (link, call)
904 return '%s %s\n' % (link, call)
905
905
906 elif file.endswith(('.pyc', '.pyo')):
906 elif file.endswith(('.pyc', '.pyo')):
907 # Look up the corresponding source file.
907 # Look up the corresponding source file.
908 try:
908 try:
909 file = openpy.source_from_cache(file)
909 file = openpy.source_from_cache(file)
910 except ValueError:
910 except ValueError:
911 # Failed to get the source file for some reason
911 # Failed to get the source file for some reason
912 # E.g. https://github.com/ipython/ipython/issues/9486
912 # E.g. https://github.com/ipython/ipython/issues/9486
913 return '%s %s\n' % (link, call)
913 return '%s %s\n' % (link, call)
914
914
915 def linereader(file=file, lnum=[lnum], getline=linecache.getline):
915 def linereader(file=file, lnum=[lnum], getline=linecache.getline):
916 line = getline(file, lnum[0])
916 line = getline(file, lnum[0])
917 lnum[0] += 1
917 lnum[0] += 1
918 return line
918 return line
919
919
920 # Build the list of names on this line of code where the exception
920 # Build the list of names on this line of code where the exception
921 # occurred.
921 # occurred.
922 try:
922 try:
923 names = []
923 names = []
924 name_cont = False
924 name_cont = False
925
925
926 for token_type, token, start, end, line in generate_tokens(linereader):
926 for token_type, token, start, end, line in generate_tokens(linereader):
927 # build composite names
927 # build composite names
928 if token_type == tokenize.NAME and token not in keyword.kwlist:
928 if token_type == tokenize.NAME and token not in keyword.kwlist:
929 if name_cont:
929 if name_cont:
930 # Continuation of a dotted name
930 # Continuation of a dotted name
931 try:
931 try:
932 names[-1].append(token)
932 names[-1].append(token)
933 except IndexError:
933 except IndexError:
934 names.append([token])
934 names.append([token])
935 name_cont = False
935 name_cont = False
936 else:
936 else:
937 # Regular new names. We append everything, the caller
937 # Regular new names. We append everything, the caller
938 # will be responsible for pruning the list later. It's
938 # will be responsible for pruning the list later. It's
939 # very tricky to try to prune as we go, b/c composite
939 # very tricky to try to prune as we go, b/c composite
940 # names can fool us. The pruning at the end is easy
940 # names can fool us. The pruning at the end is easy
941 # to do (or the caller can print a list with repeated
941 # to do (or the caller can print a list with repeated
942 # names if so desired.
942 # names if so desired.
943 names.append([token])
943 names.append([token])
944 elif token == '.':
944 elif token == '.':
945 name_cont = True
945 name_cont = True
946 elif token_type == tokenize.NEWLINE:
946 elif token_type == tokenize.NEWLINE:
947 break
947 break
948
948
949 except (IndexError, UnicodeDecodeError, SyntaxError):
949 except (IndexError, UnicodeDecodeError, SyntaxError):
950 # signals exit of tokenizer
950 # signals exit of tokenizer
951 # SyntaxError can occur if the file is not actually Python
951 # SyntaxError can occur if the file is not actually Python
952 # - see gh-6300
952 # - see gh-6300
953 pass
953 pass
954 except tokenize.TokenError as msg:
954 except tokenize.TokenError as msg:
955 # Tokenizing may fail for various reasons, many of which are
955 # Tokenizing may fail for various reasons, many of which are
956 # harmless. (A good example is when the line in question is the
956 # harmless. (A good example is when the line in question is the
957 # close of a triple-quoted string, cf gh-6864). We don't want to
957 # close of a triple-quoted string, cf gh-6864). We don't want to
958 # show this to users, but want make it available for debugging
958 # show this to users, but want make it available for debugging
959 # purposes.
959 # purposes.
960 _m = ("An unexpected error occurred while tokenizing input\n"
960 _m = ("An unexpected error occurred while tokenizing input\n"
961 "The following traceback may be corrupted or invalid\n"
961 "The following traceback may be corrupted or invalid\n"
962 "The error message is: %s\n" % msg)
962 "The error message is: %s\n" % msg)
963 debug(_m)
963 debug(_m)
964
964
965 # Join composite names (e.g. "dict.fromkeys")
965 # Join composite names (e.g. "dict.fromkeys")
966 names = ['.'.join(n) for n in names]
966 names = ['.'.join(n) for n in names]
967 # prune names list of duplicates, but keep the right order
967 # prune names list of duplicates, but keep the right order
968 unique_names = uniq_stable(names)
968 unique_names = uniq_stable(names)
969
969
970 # Start loop over vars
970 # Start loop over vars
971 lvals = []
971 lvals = []
972 if self.include_vars:
972 if self.include_vars:
973 for name_full in unique_names:
973 for name_full in unique_names:
974 name_base = name_full.split('.', 1)[0]
974 name_base = name_full.split('.', 1)[0]
975 if name_base in frame.f_code.co_varnames:
975 if name_base in frame.f_code.co_varnames:
976 if name_base in locals:
976 if name_base in locals:
977 try:
977 try:
978 value = repr(eval(name_full, locals))
978 value = repr(eval(name_full, locals))
979 except:
979 except:
980 value = undefined
980 value = undefined
981 else:
981 else:
982 value = undefined
982 value = undefined
983 name = tpl_local_var % name_full
983 name = tpl_local_var % name_full
984 else:
984 else:
985 if name_base in frame.f_globals:
985 if name_base in frame.f_globals:
986 try:
986 try:
987 value = repr(eval(name_full, frame.f_globals))
987 value = repr(eval(name_full, frame.f_globals))
988 except:
988 except:
989 value = undefined
989 value = undefined
990 else:
990 else:
991 value = undefined
991 value = undefined
992 name = tpl_global_var % name_full
992 name = tpl_global_var % name_full
993 lvals.append(tpl_name_val % (name, value))
993 lvals.append(tpl_name_val % (name, value))
994 if lvals:
994 if lvals:
995 lvals = '%s%s' % (indent, em_normal.join(lvals))
995 lvals = '%s%s' % (indent, em_normal.join(lvals))
996 else:
996 else:
997 lvals = ''
997 lvals = ''
998
998
999 level = '%s %s\n' % (link, call)
999 level = '%s %s\n' % (link, call)
1000
1000
1001 if index is None:
1001 if index is None:
1002 return level
1002 return level
1003 else:
1003 else:
1004 _line_format = PyColorize.Parser(style=col_scheme, parent=self).format2
1004 _line_format = PyColorize.Parser(style=col_scheme, parent=self).format2
1005 return '%s%s' % (level, ''.join(
1005 return '%s%s' % (level, ''.join(
1006 _format_traceback_lines(lnum, index, lines, Colors, lvals,
1006 _format_traceback_lines(lnum, index, lines, Colors, lvals,
1007 _line_format)))
1007 _line_format)))
1008
1008
1009 def prepare_chained_exception_message(self, cause):
1009 def prepare_chained_exception_message(self, cause):
1010 direct_cause = "\nThe above exception was the direct cause of the following exception:\n"
1010 direct_cause = "\nThe above exception was the direct cause of the following exception:\n"
1011 exception_during_handling = "\nDuring handling of the above exception, another exception occurred:\n"
1011 exception_during_handling = "\nDuring handling of the above exception, another exception occurred:\n"
1012
1012
1013 if cause:
1013 if cause:
1014 message = [[direct_cause]]
1014 message = [[direct_cause]]
1015 else:
1015 else:
1016 message = [[exception_during_handling]]
1016 message = [[exception_during_handling]]
1017 return message
1017 return message
1018
1018
1019 def prepare_header(self, etype, long_version=False):
1019 def prepare_header(self, etype, long_version=False):
1020 colors = self.Colors # just a shorthand + quicker name lookup
1020 colors = self.Colors # just a shorthand + quicker name lookup
1021 colorsnormal = colors.Normal # used a lot
1021 colorsnormal = colors.Normal # used a lot
1022 exc = '%s%s%s' % (colors.excName, etype, colorsnormal)
1022 exc = '%s%s%s' % (colors.excName, etype, colorsnormal)
1023 width = min(75, get_terminal_size()[0])
1023 width = min(75, get_terminal_size()[0])
1024 if long_version:
1024 if long_version:
1025 # Header with the exception type, python version, and date
1025 # Header with the exception type, python version, and date
1026 pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
1026 pyver = 'Python ' + sys.version.split()[0] + ': ' + sys.executable
1027 date = time.ctime(time.time())
1027 date = time.ctime(time.time())
1028
1028
1029 head = '%s%s%s\n%s%s%s\n%s' % (colors.topline, '-' * width, colorsnormal,
1029 head = '%s%s%s\n%s%s%s\n%s' % (colors.topline, '-' * width, colorsnormal,
1030 exc, ' ' * (width - len(str(etype)) - len(pyver)),
1030 exc, ' ' * (width - len(str(etype)) - len(pyver)),
1031 pyver, date.rjust(width) )
1031 pyver, date.rjust(width) )
1032 head += "\nA problem occurred executing Python code. Here is the sequence of function" \
1032 head += "\nA problem occurred executing Python code. Here is the sequence of function" \
1033 "\ncalls leading up to the error, with the most recent (innermost) call last."
1033 "\ncalls leading up to the error, with the most recent (innermost) call last."
1034 else:
1034 else:
1035 # Simplified header
1035 # Simplified header
1036 head = '%s%s' % (exc, 'Traceback (most recent call last)'. \
1036 head = '%s%s' % (exc, 'Traceback (most recent call last)'. \
1037 rjust(width - len(str(etype))) )
1037 rjust(width - len(str(etype))) )
1038
1038
1039 return head
1039 return head
1040
1040
1041 def format_exception(self, etype, evalue):
1041 def format_exception(self, etype, evalue):
1042 colors = self.Colors # just a shorthand + quicker name lookup
1042 colors = self.Colors # just a shorthand + quicker name lookup
1043 colorsnormal = colors.Normal # used a lot
1043 colorsnormal = colors.Normal # used a lot
1044 indent = ' ' * INDENT_SIZE
1044 indent = ' ' * INDENT_SIZE
1045 # Get (safely) a string form of the exception info
1045 # Get (safely) a string form of the exception info
1046 try:
1046 try:
1047 etype_str, evalue_str = map(str, (etype, evalue))
1047 etype_str, evalue_str = map(str, (etype, evalue))
1048 except:
1048 except:
1049 # User exception is improperly defined.
1049 # User exception is improperly defined.
1050 etype, evalue = str, sys.exc_info()[:2]
1050 etype, evalue = str, sys.exc_info()[:2]
1051 etype_str, evalue_str = map(str, (etype, evalue))
1051 etype_str, evalue_str = map(str, (etype, evalue))
1052 # ... and format it
1052 # ... and format it
1053 return ['%s%s%s: %s' % (colors.excName, etype_str,
1053 return ['%s%s%s: %s' % (colors.excName, etype_str,
1054 colorsnormal, py3compat.cast_unicode(evalue_str))]
1054 colorsnormal, py3compat.cast_unicode(evalue_str))]
1055
1055
1056 def format_exception_as_a_whole(self, etype, evalue, etb, number_of_lines_of_context, tb_offset):
1056 def format_exception_as_a_whole(self, etype, evalue, etb, number_of_lines_of_context, tb_offset):
1057 """Formats the header, traceback and exception message for a single exception.
1057 """Formats the header, traceback and exception message for a single exception.
1058
1058
1059 This may be called multiple times by Python 3 exception chaining
1059 This may be called multiple times by Python 3 exception chaining
1060 (PEP 3134).
1060 (PEP 3134).
1061 """
1061 """
1062 # some locals
1062 # some locals
1063 orig_etype = etype
1063 orig_etype = etype
1064 try:
1064 try:
1065 etype = etype.__name__
1065 etype = etype.__name__
1066 except AttributeError:
1066 except AttributeError:
1067 pass
1067 pass
1068
1068
1069 tb_offset = self.tb_offset if tb_offset is None else tb_offset
1069 tb_offset = self.tb_offset if tb_offset is None else tb_offset
1070 head = self.prepare_header(etype, self.long_header)
1070 head = self.prepare_header(etype, self.long_header)
1071 records = self.get_records(etb, number_of_lines_of_context, tb_offset)
1071 records = self.get_records(etb, number_of_lines_of_context, tb_offset)
1072
1072
1073 if records is None:
1073 if records is None:
1074 return ""
1074 return ""
1075
1075
1076 last_unique, recursion_repeat = find_recursion(orig_etype, evalue, records)
1076 last_unique, recursion_repeat = find_recursion(orig_etype, evalue, records)
1077
1077
1078 frames = self.format_records(records, last_unique, recursion_repeat)
1078 frames = self.format_records(records, last_unique, recursion_repeat)
1079
1079
1080 formatted_exception = self.format_exception(etype, evalue)
1080 formatted_exception = self.format_exception(etype, evalue)
1081 if records:
1081 if records:
1082 filepath, lnum = records[-1][1:3]
1082 filepath, lnum = records[-1][1:3]
1083 filepath = os.path.abspath(filepath)
1083 filepath = os.path.abspath(filepath)
1084 ipinst = get_ipython()
1084 ipinst = get_ipython()
1085 if ipinst is not None:
1085 if ipinst is not None:
1086 ipinst.hooks.synchronize_with_editor(filepath, lnum, 0)
1086 ipinst.hooks.synchronize_with_editor(filepath, lnum, 0)
1087
1087
1088 return [[head] + frames + [''.join(formatted_exception[0])]]
1088 return [[head] + frames + [''.join(formatted_exception[0])]]
1089
1089
1090 def get_records(self, etb, number_of_lines_of_context, tb_offset):
1090 def get_records(self, etb, number_of_lines_of_context, tb_offset):
1091 try:
1091 try:
1092 # Try the default getinnerframes and Alex's: Alex's fixes some
1092 # Try the default getinnerframes and Alex's: Alex's fixes some
1093 # problems, but it generates empty tracebacks for console errors
1093 # problems, but it generates empty tracebacks for console errors
1094 # (5 blanks lines) where none should be returned.
1094 # (5 blanks lines) where none should be returned.
1095 return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
1095 return _fixed_getinnerframes(etb, number_of_lines_of_context, tb_offset)
1096 except UnicodeDecodeError:
1096 except UnicodeDecodeError:
1097 # This can occur if a file's encoding magic comment is wrong.
1097 # This can occur if a file's encoding magic comment is wrong.
1098 # I can't see a way to recover without duplicating a bunch of code
1098 # I can't see a way to recover without duplicating a bunch of code
1099 # from the stdlib traceback module. --TK
1099 # from the stdlib traceback module. --TK
1100 error('\nUnicodeDecodeError while processing traceback.\n')
1100 error('\nUnicodeDecodeError while processing traceback.\n')
1101 return None
1101 return None
1102 except:
1102 except:
1103 # FIXME: I've been getting many crash reports from python 2.3
1103 # FIXME: I've been getting many crash reports from python 2.3
1104 # users, traceable to inspect.py. If I can find a small test-case
1104 # users, traceable to inspect.py. If I can find a small test-case
1105 # to reproduce this, I should either write a better workaround or
1105 # to reproduce this, I should either write a better workaround or
1106 # file a bug report against inspect (if that's the real problem).
1106 # file a bug report against inspect (if that's the real problem).
1107 # So far, I haven't been able to find an isolated example to
1107 # So far, I haven't been able to find an isolated example to
1108 # reproduce the problem.
1108 # reproduce the problem.
1109 inspect_error()
1109 inspect_error()
1110 traceback.print_exc(file=self.ostream)
1110 traceback.print_exc(file=self.ostream)
1111 info('\nUnfortunately, your original traceback can not be constructed.\n')
1111 info('\nUnfortunately, your original traceback can not be constructed.\n')
1112 return None
1112 return None
1113
1113
1114 def get_parts_of_chained_exception(self, evalue):
1114 def get_parts_of_chained_exception(self, evalue):
1115 def get_chained_exception(exception_value):
1115 def get_chained_exception(exception_value):
1116 cause = getattr(exception_value, '__cause__', None)
1116 cause = getattr(exception_value, '__cause__', None)
1117 if cause:
1117 if cause:
1118 return cause
1118 return cause
1119 if getattr(exception_value, '__suppress_context__', False):
1119 if getattr(exception_value, '__suppress_context__', False):
1120 return None
1120 return None
1121 return getattr(exception_value, '__context__', None)
1121 return getattr(exception_value, '__context__', None)
1122
1122
1123 chained_evalue = get_chained_exception(evalue)
1123 chained_evalue = get_chained_exception(evalue)
1124
1124
1125 if chained_evalue:
1125 if chained_evalue:
1126 return chained_evalue.__class__, chained_evalue, chained_evalue.__traceback__
1126 return chained_evalue.__class__, chained_evalue, chained_evalue.__traceback__
1127
1127
1128 def structured_traceback(self, etype, evalue, etb, tb_offset=None,
1128 def structured_traceback(self, etype, evalue, etb, tb_offset=None,
1129 number_of_lines_of_context=5):
1129 number_of_lines_of_context=5):
1130 """Return a nice text document describing the traceback."""
1130 """Return a nice text document describing the traceback."""
1131
1131
1132 formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
1132 formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
1133 tb_offset)
1133 tb_offset)
1134
1134
1135 colors = self.Colors # just a shorthand + quicker name lookup
1135 colors = self.Colors # just a shorthand + quicker name lookup
1136 colorsnormal = colors.Normal # used a lot
1136 colorsnormal = colors.Normal # used a lot
1137 head = '%s%s%s' % (colors.topline, '-' * min(75, get_terminal_size()[0]), colorsnormal)
1137 head = '%s%s%s' % (colors.topline, '-' * min(75, get_terminal_size()[0]), colorsnormal)
1138 structured_traceback_parts = [head]
1138 structured_traceback_parts = [head]
1139 if py3compat.PY3:
1139 chained_exceptions_tb_offset = 0
1140 chained_exceptions_tb_offset = 0
1140 lines_of_context = 3
1141 lines_of_context = 3
1141 formatted_exceptions = formatted_exception
1142 formatted_exceptions = formatted_exception
1142 exception = self.get_parts_of_chained_exception(evalue)
1143 if exception:
1144 formatted_exceptions += self.prepare_chained_exception_message(evalue.__cause__)
1145 etype, evalue, etb = exception
1146 else:
1147 evalue = None
1148 chained_exc_ids = set()
1149 while evalue:
1150 formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
1151 chained_exceptions_tb_offset)
1143 exception = self.get_parts_of_chained_exception(evalue)
1152 exception = self.get_parts_of_chained_exception(evalue)
1144 if exception:
1153
1154 if exception and not id(exception[1]) in chained_exc_ids:
1155 chained_exc_ids.add(id(exception[1])) # trace exception to avoid infinite 'cause' loop
1145 formatted_exceptions += self.prepare_chained_exception_message(evalue.__cause__)
1156 formatted_exceptions += self.prepare_chained_exception_message(evalue.__cause__)
1146 etype, evalue, etb = exception
1157 etype, evalue, etb = exception
1147 else:
1158 else:
1148 evalue = None
1159 evalue = None
1149 chained_exc_ids = set()
1150 while evalue:
1151 formatted_exceptions += self.format_exception_as_a_whole(etype, evalue, etb, lines_of_context,
1152 chained_exceptions_tb_offset)
1153 exception = self.get_parts_of_chained_exception(evalue)
1154
1155 if exception and not id(exception[1]) in chained_exc_ids:
1156 chained_exc_ids.add(id(exception[1])) # trace exception to avoid infinite 'cause' loop
1157 formatted_exceptions += self.prepare_chained_exception_message(evalue.__cause__)
1158 etype, evalue, etb = exception
1159 else:
1160 evalue = None
1161
1160
1162 # we want to see exceptions in a reversed order:
1161 # we want to see exceptions in a reversed order:
1163 # the first exception should be on top
1162 # the first exception should be on top
1164 for formatted_exception in reversed(formatted_exceptions):
1163 for formatted_exception in reversed(formatted_exceptions):
1165 structured_traceback_parts += formatted_exception
1164 structured_traceback_parts += formatted_exception
1166 else:
1167 structured_traceback_parts += formatted_exception[0]
1168
1165
1169 return structured_traceback_parts
1166 return structured_traceback_parts
1170
1167
1171 def debugger(self, force=False):
1168 def debugger(self, force=False):
1172 """Call up the pdb debugger if desired, always clean up the tb
1169 """Call up the pdb debugger if desired, always clean up the tb
1173 reference.
1170 reference.
1174
1171
1175 Keywords:
1172 Keywords:
1176
1173
1177 - force(False): by default, this routine checks the instance call_pdb
1174 - force(False): by default, this routine checks the instance call_pdb
1178 flag and does not actually invoke the debugger if the flag is false.
1175 flag and does not actually invoke the debugger if the flag is false.
1179 The 'force' option forces the debugger to activate even if the flag
1176 The 'force' option forces the debugger to activate even if the flag
1180 is false.
1177 is false.
1181
1178
1182 If the call_pdb flag is set, the pdb interactive debugger is
1179 If the call_pdb flag is set, the pdb interactive debugger is
1183 invoked. In all cases, the self.tb reference to the current traceback
1180 invoked. In all cases, the self.tb reference to the current traceback
1184 is deleted to prevent lingering references which hamper memory
1181 is deleted to prevent lingering references which hamper memory
1185 management.
1182 management.
1186
1183
1187 Note that each call to pdb() does an 'import readline', so if your app
1184 Note that each call to pdb() does an 'import readline', so if your app
1188 requires a special setup for the readline completers, you'll have to
1185 requires a special setup for the readline completers, you'll have to
1189 fix that by hand after invoking the exception handler."""
1186 fix that by hand after invoking the exception handler."""
1190
1187
1191 if force or self.call_pdb:
1188 if force or self.call_pdb:
1192 if self.pdb is None:
1189 if self.pdb is None:
1193 self.pdb = self.debugger_cls()
1190 self.pdb = self.debugger_cls()
1194 # the system displayhook may have changed, restore the original
1191 # the system displayhook may have changed, restore the original
1195 # for pdb
1192 # for pdb
1196 display_trap = DisplayTrap(hook=sys.__displayhook__)
1193 display_trap = DisplayTrap(hook=sys.__displayhook__)
1197 with display_trap:
1194 with display_trap:
1198 self.pdb.reset()
1195 self.pdb.reset()
1199 # Find the right frame so we don't pop up inside ipython itself
1196 # Find the right frame so we don't pop up inside ipython itself
1200 if hasattr(self, 'tb') and self.tb is not None:
1197 if hasattr(self, 'tb') and self.tb is not None:
1201 etb = self.tb
1198 etb = self.tb
1202 else:
1199 else:
1203 etb = self.tb = sys.last_traceback
1200 etb = self.tb = sys.last_traceback
1204 while self.tb is not None and self.tb.tb_next is not None:
1201 while self.tb is not None and self.tb.tb_next is not None:
1205 self.tb = self.tb.tb_next
1202 self.tb = self.tb.tb_next
1206 if etb and etb.tb_next:
1203 if etb and etb.tb_next:
1207 etb = etb.tb_next
1204 etb = etb.tb_next
1208 self.pdb.botframe = etb.tb_frame
1205 self.pdb.botframe = etb.tb_frame
1209 self.pdb.interaction(self.tb.tb_frame, self.tb)
1206 self.pdb.interaction(self.tb.tb_frame, self.tb)
1210
1207
1211 if hasattr(self, 'tb'):
1208 if hasattr(self, 'tb'):
1212 del self.tb
1209 del self.tb
1213
1210
1214 def handler(self, info=None):
1211 def handler(self, info=None):
1215 (etype, evalue, etb) = info or sys.exc_info()
1212 (etype, evalue, etb) = info or sys.exc_info()
1216 self.tb = etb
1213 self.tb = etb
1217 ostream = self.ostream
1214 ostream = self.ostream
1218 ostream.flush()
1215 ostream.flush()
1219 ostream.write(self.text(etype, evalue, etb))
1216 ostream.write(self.text(etype, evalue, etb))
1220 ostream.write('\n')
1217 ostream.write('\n')
1221 ostream.flush()
1218 ostream.flush()
1222
1219
1223 # Changed so an instance can just be called as VerboseTB_inst() and print
1220 # Changed so an instance can just be called as VerboseTB_inst() and print
1224 # out the right info on its own.
1221 # out the right info on its own.
1225 def __call__(self, etype=None, evalue=None, etb=None):
1222 def __call__(self, etype=None, evalue=None, etb=None):
1226 """This hook can replace sys.excepthook (for Python 2.1 or higher)."""
1223 """This hook can replace sys.excepthook (for Python 2.1 or higher)."""
1227 if etb is None:
1224 if etb is None:
1228 self.handler()
1225 self.handler()
1229 else:
1226 else:
1230 self.handler((etype, evalue, etb))
1227 self.handler((etype, evalue, etb))
1231 try:
1228 try:
1232 self.debugger()
1229 self.debugger()
1233 except KeyboardInterrupt:
1230 except KeyboardInterrupt:
1234 print("\nKeyboardInterrupt")
1231 print("\nKeyboardInterrupt")
1235
1232
1236
1233
1237 #----------------------------------------------------------------------------
1234 #----------------------------------------------------------------------------
1238 class FormattedTB(VerboseTB, ListTB):
1235 class FormattedTB(VerboseTB, ListTB):
1239 """Subclass ListTB but allow calling with a traceback.
1236 """Subclass ListTB but allow calling with a traceback.
1240
1237
1241 It can thus be used as a sys.excepthook for Python > 2.1.
1238 It can thus be used as a sys.excepthook for Python > 2.1.
1242
1239
1243 Also adds 'Context' and 'Verbose' modes, not available in ListTB.
1240 Also adds 'Context' and 'Verbose' modes, not available in ListTB.
1244
1241
1245 Allows a tb_offset to be specified. This is useful for situations where
1242 Allows a tb_offset to be specified. This is useful for situations where
1246 one needs to remove a number of topmost frames from the traceback (such as
1243 one needs to remove a number of topmost frames from the traceback (such as
1247 occurs with python programs that themselves execute other python code,
1244 occurs with python programs that themselves execute other python code,
1248 like Python shells). """
1245 like Python shells). """
1249
1246
1250 def __init__(self, mode='Plain', color_scheme='Linux', call_pdb=False,
1247 def __init__(self, mode='Plain', color_scheme='Linux', call_pdb=False,
1251 ostream=None,
1248 ostream=None,
1252 tb_offset=0, long_header=False, include_vars=False,
1249 tb_offset=0, long_header=False, include_vars=False,
1253 check_cache=None, debugger_cls=None,
1250 check_cache=None, debugger_cls=None,
1254 parent=None, config=None):
1251 parent=None, config=None):
1255
1252
1256 # NEVER change the order of this list. Put new modes at the end:
1253 # NEVER change the order of this list. Put new modes at the end:
1257 self.valid_modes = ['Plain', 'Context', 'Verbose']
1254 self.valid_modes = ['Plain', 'Context', 'Verbose']
1258 self.verbose_modes = self.valid_modes[1:3]
1255 self.verbose_modes = self.valid_modes[1:3]
1259
1256
1260 VerboseTB.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
1257 VerboseTB.__init__(self, color_scheme=color_scheme, call_pdb=call_pdb,
1261 ostream=ostream, tb_offset=tb_offset,
1258 ostream=ostream, tb_offset=tb_offset,
1262 long_header=long_header, include_vars=include_vars,
1259 long_header=long_header, include_vars=include_vars,
1263 check_cache=check_cache, debugger_cls=debugger_cls,
1260 check_cache=check_cache, debugger_cls=debugger_cls,
1264 parent=parent, config=config)
1261 parent=parent, config=config)
1265
1262
1266 # Different types of tracebacks are joined with different separators to
1263 # Different types of tracebacks are joined with different separators to
1267 # form a single string. They are taken from this dict
1264 # form a single string. They are taken from this dict
1268 self._join_chars = dict(Plain='', Context='\n', Verbose='\n')
1265 self._join_chars = dict(Plain='', Context='\n', Verbose='\n')
1269 # set_mode also sets the tb_join_char attribute
1266 # set_mode also sets the tb_join_char attribute
1270 self.set_mode(mode)
1267 self.set_mode(mode)
1271
1268
1272 def _extract_tb(self, tb):
1269 def _extract_tb(self, tb):
1273 if tb:
1270 if tb:
1274 return traceback.extract_tb(tb)
1271 return traceback.extract_tb(tb)
1275 else:
1272 else:
1276 return None
1273 return None
1277
1274
1278 def structured_traceback(self, etype, value, tb, tb_offset=None, number_of_lines_of_context=5):
1275 def structured_traceback(self, etype, value, tb, tb_offset=None, number_of_lines_of_context=5):
1279 tb_offset = self.tb_offset if tb_offset is None else tb_offset
1276 tb_offset = self.tb_offset if tb_offset is None else tb_offset
1280 mode = self.mode
1277 mode = self.mode
1281 if mode in self.verbose_modes:
1278 if mode in self.verbose_modes:
1282 # Verbose modes need a full traceback
1279 # Verbose modes need a full traceback
1283 return VerboseTB.structured_traceback(
1280 return VerboseTB.structured_traceback(
1284 self, etype, value, tb, tb_offset, number_of_lines_of_context
1281 self, etype, value, tb, tb_offset, number_of_lines_of_context
1285 )
1282 )
1286 else:
1283 else:
1287 # We must check the source cache because otherwise we can print
1284 # We must check the source cache because otherwise we can print
1288 # out-of-date source code.
1285 # out-of-date source code.
1289 self.check_cache()
1286 self.check_cache()
1290 # Now we can extract and format the exception
1287 # Now we can extract and format the exception
1291 elist = self._extract_tb(tb)
1288 elist = self._extract_tb(tb)
1292 return ListTB.structured_traceback(
1289 return ListTB.structured_traceback(
1293 self, etype, value, elist, tb_offset, number_of_lines_of_context
1290 self, etype, value, elist, tb_offset, number_of_lines_of_context
1294 )
1291 )
1295
1292
1296 def stb2text(self, stb):
1293 def stb2text(self, stb):
1297 """Convert a structured traceback (a list) to a string."""
1294 """Convert a structured traceback (a list) to a string."""
1298 return self.tb_join_char.join(stb)
1295 return self.tb_join_char.join(stb)
1299
1296
1300
1297
1301 def set_mode(self, mode=None):
1298 def set_mode(self, mode=None):
1302 """Switch to the desired mode.
1299 """Switch to the desired mode.
1303
1300
1304 If mode is not specified, cycles through the available modes."""
1301 If mode is not specified, cycles through the available modes."""
1305
1302
1306 if not mode:
1303 if not mode:
1307 new_idx = (self.valid_modes.index(self.mode) + 1 ) % \
1304 new_idx = (self.valid_modes.index(self.mode) + 1 ) % \
1308 len(self.valid_modes)
1305 len(self.valid_modes)
1309 self.mode = self.valid_modes[new_idx]
1306 self.mode = self.valid_modes[new_idx]
1310 elif mode not in self.valid_modes:
1307 elif mode not in self.valid_modes:
1311 raise ValueError('Unrecognized mode in FormattedTB: <' + mode + '>\n'
1308 raise ValueError('Unrecognized mode in FormattedTB: <' + mode + '>\n'
1312 'Valid modes: ' + str(self.valid_modes))
1309 'Valid modes: ' + str(self.valid_modes))
1313 else:
1310 else:
1314 self.mode = mode
1311 self.mode = mode
1315 # include variable details only in 'Verbose' mode
1312 # include variable details only in 'Verbose' mode
1316 self.include_vars = (self.mode == self.valid_modes[2])
1313 self.include_vars = (self.mode == self.valid_modes[2])
1317 # Set the join character for generating text tracebacks
1314 # Set the join character for generating text tracebacks
1318 self.tb_join_char = self._join_chars[self.mode]
1315 self.tb_join_char = self._join_chars[self.mode]
1319
1316
1320 # some convenient shortcuts
1317 # some convenient shortcuts
1321 def plain(self):
1318 def plain(self):
1322 self.set_mode(self.valid_modes[0])
1319 self.set_mode(self.valid_modes[0])
1323
1320
1324 def context(self):
1321 def context(self):
1325 self.set_mode(self.valid_modes[1])
1322 self.set_mode(self.valid_modes[1])
1326
1323
1327 def verbose(self):
1324 def verbose(self):
1328 self.set_mode(self.valid_modes[2])
1325 self.set_mode(self.valid_modes[2])
1329
1326
1330
1327
1331 #----------------------------------------------------------------------------
1328 #----------------------------------------------------------------------------
1332 class AutoFormattedTB(FormattedTB):
1329 class AutoFormattedTB(FormattedTB):
1333 """A traceback printer which can be called on the fly.
1330 """A traceback printer which can be called on the fly.
1334
1331
1335 It will find out about exceptions by itself.
1332 It will find out about exceptions by itself.
1336
1333
1337 A brief example::
1334 A brief example::
1338
1335
1339 AutoTB = AutoFormattedTB(mode = 'Verbose',color_scheme='Linux')
1336 AutoTB = AutoFormattedTB(mode = 'Verbose',color_scheme='Linux')
1340 try:
1337 try:
1341 ...
1338 ...
1342 except:
1339 except:
1343 AutoTB() # or AutoTB(out=logfile) where logfile is an open file object
1340 AutoTB() # or AutoTB(out=logfile) where logfile is an open file object
1344 """
1341 """
1345
1342
1346 def __call__(self, etype=None, evalue=None, etb=None,
1343 def __call__(self, etype=None, evalue=None, etb=None,
1347 out=None, tb_offset=None):
1344 out=None, tb_offset=None):
1348 """Print out a formatted exception traceback.
1345 """Print out a formatted exception traceback.
1349
1346
1350 Optional arguments:
1347 Optional arguments:
1351 - out: an open file-like object to direct output to.
1348 - out: an open file-like object to direct output to.
1352
1349
1353 - tb_offset: the number of frames to skip over in the stack, on a
1350 - tb_offset: the number of frames to skip over in the stack, on a
1354 per-call basis (this overrides temporarily the instance's tb_offset
1351 per-call basis (this overrides temporarily the instance's tb_offset
1355 given at initialization time. """
1352 given at initialization time. """
1356
1353
1357 if out is None:
1354 if out is None:
1358 out = self.ostream
1355 out = self.ostream
1359 out.flush()
1356 out.flush()
1360 out.write(self.text(etype, evalue, etb, tb_offset))
1357 out.write(self.text(etype, evalue, etb, tb_offset))
1361 out.write('\n')
1358 out.write('\n')
1362 out.flush()
1359 out.flush()
1363 # FIXME: we should remove the auto pdb behavior from here and leave
1360 # FIXME: we should remove the auto pdb behavior from here and leave
1364 # that to the clients.
1361 # that to the clients.
1365 try:
1362 try:
1366 self.debugger()
1363 self.debugger()
1367 except KeyboardInterrupt:
1364 except KeyboardInterrupt:
1368 print("\nKeyboardInterrupt")
1365 print("\nKeyboardInterrupt")
1369
1366
1370 def structured_traceback(self, etype=None, value=None, tb=None,
1367 def structured_traceback(self, etype=None, value=None, tb=None,
1371 tb_offset=None, number_of_lines_of_context=5):
1368 tb_offset=None, number_of_lines_of_context=5):
1372 if etype is None:
1369 if etype is None:
1373 etype, value, tb = sys.exc_info()
1370 etype, value, tb = sys.exc_info()
1374 self.tb = tb
1371 self.tb = tb
1375 return FormattedTB.structured_traceback(
1372 return FormattedTB.structured_traceback(
1376 self, etype, value, tb, tb_offset, number_of_lines_of_context)
1373 self, etype, value, tb, tb_offset, number_of_lines_of_context)
1377
1374
1378
1375
1379 #---------------------------------------------------------------------------
1376 #---------------------------------------------------------------------------
1380
1377
1381 # A simple class to preserve Nathan's original functionality.
1378 # A simple class to preserve Nathan's original functionality.
1382 class ColorTB(FormattedTB):
1379 class ColorTB(FormattedTB):
1383 """Shorthand to initialize a FormattedTB in Linux colors mode."""
1380 """Shorthand to initialize a FormattedTB in Linux colors mode."""
1384
1381
1385 def __init__(self, color_scheme='Linux', call_pdb=0, **kwargs):
1382 def __init__(self, color_scheme='Linux', call_pdb=0, **kwargs):
1386 FormattedTB.__init__(self, color_scheme=color_scheme,
1383 FormattedTB.__init__(self, color_scheme=color_scheme,
1387 call_pdb=call_pdb, **kwargs)
1384 call_pdb=call_pdb, **kwargs)
1388
1385
1389
1386
1390 class SyntaxTB(ListTB):
1387 class SyntaxTB(ListTB):
1391 """Extension which holds some state: the last exception value"""
1388 """Extension which holds some state: the last exception value"""
1392
1389
1393 def __init__(self, color_scheme='NoColor', parent=None, config=None):
1390 def __init__(self, color_scheme='NoColor', parent=None, config=None):
1394 ListTB.__init__(self, color_scheme, parent=parent, config=config)
1391 ListTB.__init__(self, color_scheme, parent=parent, config=config)
1395 self.last_syntax_error = None
1392 self.last_syntax_error = None
1396
1393
1397 def __call__(self, etype, value, elist):
1394 def __call__(self, etype, value, elist):
1398 self.last_syntax_error = value
1395 self.last_syntax_error = value
1399
1396
1400 ListTB.__call__(self, etype, value, elist)
1397 ListTB.__call__(self, etype, value, elist)
1401
1398
1402 def structured_traceback(self, etype, value, elist, tb_offset=None,
1399 def structured_traceback(self, etype, value, elist, tb_offset=None,
1403 context=5):
1400 context=5):
1404 # If the source file has been edited, the line in the syntax error can
1401 # If the source file has been edited, the line in the syntax error can
1405 # be wrong (retrieved from an outdated cache). This replaces it with
1402 # be wrong (retrieved from an outdated cache). This replaces it with
1406 # the current value.
1403 # the current value.
1407 if isinstance(value, SyntaxError) \
1404 if isinstance(value, SyntaxError) \
1408 and isinstance(value.filename, str) \
1405 and isinstance(value.filename, str) \
1409 and isinstance(value.lineno, int):
1406 and isinstance(value.lineno, int):
1410 linecache.checkcache(value.filename)
1407 linecache.checkcache(value.filename)
1411 newtext = linecache.getline(value.filename, value.lineno)
1408 newtext = linecache.getline(value.filename, value.lineno)
1412 if newtext:
1409 if newtext:
1413 value.text = newtext
1410 value.text = newtext
1414 self.last_syntax_error = value
1411 self.last_syntax_error = value
1415 return super(SyntaxTB, self).structured_traceback(etype, value, elist,
1412 return super(SyntaxTB, self).structured_traceback(etype, value, elist,
1416 tb_offset=tb_offset, context=context)
1413 tb_offset=tb_offset, context=context)
1417
1414
1418 def clear_err_state(self):
1415 def clear_err_state(self):
1419 """Return the current error state and clear it"""
1416 """Return the current error state and clear it"""
1420 e = self.last_syntax_error
1417 e = self.last_syntax_error
1421 self.last_syntax_error = None
1418 self.last_syntax_error = None
1422 return e
1419 return e
1423
1420
1424 def stb2text(self, stb):
1421 def stb2text(self, stb):
1425 """Convert a structured traceback (a list) to a string."""
1422 """Convert a structured traceback (a list) to a string."""
1426 return ''.join(stb)
1423 return ''.join(stb)
1427
1424
1428
1425
1429 # some internal-use functions
1426 # some internal-use functions
1430 def text_repr(value):
1427 def text_repr(value):
1431 """Hopefully pretty robust repr equivalent."""
1428 """Hopefully pretty robust repr equivalent."""
1432 # this is pretty horrible but should always return *something*
1429 # this is pretty horrible but should always return *something*
1433 try:
1430 try:
1434 return pydoc.text.repr(value)
1431 return pydoc.text.repr(value)
1435 except KeyboardInterrupt:
1432 except KeyboardInterrupt:
1436 raise
1433 raise
1437 except:
1434 except:
1438 try:
1435 try:
1439 return repr(value)
1436 return repr(value)
1440 except KeyboardInterrupt:
1437 except KeyboardInterrupt:
1441 raise
1438 raise
1442 except:
1439 except:
1443 try:
1440 try:
1444 # all still in an except block so we catch
1441 # all still in an except block so we catch
1445 # getattr raising
1442 # getattr raising
1446 name = getattr(value, '__name__', None)
1443 name = getattr(value, '__name__', None)
1447 if name:
1444 if name:
1448 # ick, recursion
1445 # ick, recursion
1449 return text_repr(name)
1446 return text_repr(name)
1450 klass = getattr(value, '__class__', None)
1447 klass = getattr(value, '__class__', None)
1451 if klass:
1448 if klass:
1452 return '%s instance' % text_repr(klass)
1449 return '%s instance' % text_repr(klass)
1453 except KeyboardInterrupt:
1450 except KeyboardInterrupt:
1454 raise
1451 raise
1455 except:
1452 except:
1456 return 'UNRECOVERABLE REPR FAILURE'
1453 return 'UNRECOVERABLE REPR FAILURE'
1457
1454
1458
1455
1459 def eqrepr(value, repr=text_repr):
1456 def eqrepr(value, repr=text_repr):
1460 return '=%s' % repr(value)
1457 return '=%s' % repr(value)
1461
1458
1462
1459
1463 def nullrepr(value, repr=text_repr):
1460 def nullrepr(value, repr=text_repr):
1464 return ''
1461 return ''
@@ -1,525 +1,524
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 from importlib import import_module
118 from importlib import import_module
119 from IPython.utils.py3compat import PY3
120 from imp import reload
119 from imp import reload
121
120
122 from IPython.utils import openpy
121 from IPython.utils import openpy
123
122
124 #------------------------------------------------------------------------------
123 #------------------------------------------------------------------------------
125 # Autoreload functionality
124 # Autoreload functionality
126 #------------------------------------------------------------------------------
125 #------------------------------------------------------------------------------
127
126
128 class ModuleReloader(object):
127 class ModuleReloader(object):
129 enabled = False
128 enabled = False
130 """Whether this reloader is enabled"""
129 """Whether this reloader is enabled"""
131
130
132 check_all = True
131 check_all = True
133 """Autoreload all modules, not just those listed in 'modules'"""
132 """Autoreload all modules, not just those listed in 'modules'"""
134
133
135 def __init__(self):
134 def __init__(self):
136 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
135 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
137 self.failed = {}
136 self.failed = {}
138 # Modules specially marked as autoreloadable.
137 # Modules specially marked as autoreloadable.
139 self.modules = {}
138 self.modules = {}
140 # Modules specially marked as not autoreloadable.
139 # Modules specially marked as not autoreloadable.
141 self.skip_modules = {}
140 self.skip_modules = {}
142 # (module-name, name) -> weakref, for replacing old code objects
141 # (module-name, name) -> weakref, for replacing old code objects
143 self.old_objects = {}
142 self.old_objects = {}
144 # Module modification timestamps
143 # Module modification timestamps
145 self.modules_mtimes = {}
144 self.modules_mtimes = {}
146
145
147 # Cache module modification times
146 # Cache module modification times
148 self.check(check_all=True, do_reload=False)
147 self.check(check_all=True, do_reload=False)
149
148
150 def mark_module_skipped(self, module_name):
149 def mark_module_skipped(self, module_name):
151 """Skip reloading the named module in the future"""
150 """Skip reloading the named module in the future"""
152 try:
151 try:
153 del self.modules[module_name]
152 del self.modules[module_name]
154 except KeyError:
153 except KeyError:
155 pass
154 pass
156 self.skip_modules[module_name] = True
155 self.skip_modules[module_name] = True
157
156
158 def mark_module_reloadable(self, module_name):
157 def mark_module_reloadable(self, module_name):
159 """Reload the named module in the future (if it is imported)"""
158 """Reload the named module in the future (if it is imported)"""
160 try:
159 try:
161 del self.skip_modules[module_name]
160 del self.skip_modules[module_name]
162 except KeyError:
161 except KeyError:
163 pass
162 pass
164 self.modules[module_name] = True
163 self.modules[module_name] = True
165
164
166 def aimport_module(self, module_name):
165 def aimport_module(self, module_name):
167 """Import a module, and mark it reloadable
166 """Import a module, and mark it reloadable
168
167
169 Returns
168 Returns
170 -------
169 -------
171 top_module : module
170 top_module : module
172 The imported module if it is top-level, or the top-level
171 The imported module if it is top-level, or the top-level
173 top_name : module
172 top_name : module
174 Name of top_module
173 Name of top_module
175
174
176 """
175 """
177 self.mark_module_reloadable(module_name)
176 self.mark_module_reloadable(module_name)
178
177
179 import_module(module_name)
178 import_module(module_name)
180 top_name = module_name.split('.')[0]
179 top_name = module_name.split('.')[0]
181 top_module = sys.modules[top_name]
180 top_module = sys.modules[top_name]
182 return top_module, top_name
181 return top_module, top_name
183
182
184 def filename_and_mtime(self, module):
183 def filename_and_mtime(self, module):
185 if not hasattr(module, '__file__') or module.__file__ is None:
184 if not hasattr(module, '__file__') or module.__file__ is None:
186 return None, None
185 return None, None
187
186
188 if getattr(module, '__name__', None) in ['__mp_main__', '__main__']:
187 if getattr(module, '__name__', None) in ['__mp_main__', '__main__']:
189 # we cannot reload(__main__) or reload(__mp_main__)
188 # we cannot reload(__main__) or reload(__mp_main__)
190 return None, None
189 return None, None
191
190
192 filename = module.__file__
191 filename = module.__file__
193 path, ext = os.path.splitext(filename)
192 path, ext = os.path.splitext(filename)
194
193
195 if ext.lower() == '.py':
194 if ext.lower() == '.py':
196 py_filename = filename
195 py_filename = filename
197 else:
196 else:
198 try:
197 try:
199 py_filename = openpy.source_from_cache(filename)
198 py_filename = openpy.source_from_cache(filename)
200 except ValueError:
199 except ValueError:
201 return None, None
200 return None, None
202
201
203 try:
202 try:
204 pymtime = os.stat(py_filename).st_mtime
203 pymtime = os.stat(py_filename).st_mtime
205 except OSError:
204 except OSError:
206 return None, None
205 return None, None
207
206
208 return py_filename, pymtime
207 return py_filename, pymtime
209
208
210 def check(self, check_all=False, do_reload=True):
209 def check(self, check_all=False, do_reload=True):
211 """Check whether some modules need to be reloaded."""
210 """Check whether some modules need to be reloaded."""
212
211
213 if not self.enabled and not check_all:
212 if not self.enabled and not check_all:
214 return
213 return
215
214
216 if check_all or self.check_all:
215 if check_all or self.check_all:
217 modules = list(sys.modules.keys())
216 modules = list(sys.modules.keys())
218 else:
217 else:
219 modules = list(self.modules.keys())
218 modules = list(self.modules.keys())
220
219
221 for modname in modules:
220 for modname in modules:
222 m = sys.modules.get(modname, None)
221 m = sys.modules.get(modname, None)
223
222
224 if modname in self.skip_modules:
223 if modname in self.skip_modules:
225 continue
224 continue
226
225
227 py_filename, pymtime = self.filename_and_mtime(m)
226 py_filename, pymtime = self.filename_and_mtime(m)
228 if py_filename is None:
227 if py_filename is None:
229 continue
228 continue
230
229
231 try:
230 try:
232 if pymtime <= self.modules_mtimes[modname]:
231 if pymtime <= self.modules_mtimes[modname]:
233 continue
232 continue
234 except KeyError:
233 except KeyError:
235 self.modules_mtimes[modname] = pymtime
234 self.modules_mtimes[modname] = pymtime
236 continue
235 continue
237 else:
236 else:
238 if self.failed.get(py_filename, None) == pymtime:
237 if self.failed.get(py_filename, None) == pymtime:
239 continue
238 continue
240
239
241 self.modules_mtimes[modname] = pymtime
240 self.modules_mtimes[modname] = pymtime
242
241
243 # If we've reached this point, we should try to reload the module
242 # If we've reached this point, we should try to reload the module
244 if do_reload:
243 if do_reload:
245 try:
244 try:
246 superreload(m, reload, self.old_objects)
245 superreload(m, reload, self.old_objects)
247 if py_filename in self.failed:
246 if py_filename in self.failed:
248 del self.failed[py_filename]
247 del self.failed[py_filename]
249 except:
248 except:
250 print("[autoreload of %s failed: %s]" % (
249 print("[autoreload of %s failed: %s]" % (
251 modname, traceback.format_exc(10)), file=sys.stderr)
250 modname, traceback.format_exc(10)), file=sys.stderr)
252 self.failed[py_filename] = pymtime
251 self.failed[py_filename] = pymtime
253
252
254 #------------------------------------------------------------------------------
253 #------------------------------------------------------------------------------
255 # superreload
254 # superreload
256 #------------------------------------------------------------------------------
255 #------------------------------------------------------------------------------
257
256
258
257
259 func_attrs = ['__code__', '__defaults__', '__doc__',
258 func_attrs = ['__code__', '__defaults__', '__doc__',
260 '__closure__', '__globals__', '__dict__']
259 '__closure__', '__globals__', '__dict__']
261
260
262
261
263 def update_function(old, new):
262 def update_function(old, new):
264 """Upgrade the code object of a function"""
263 """Upgrade the code object of a function"""
265 for name in func_attrs:
264 for name in func_attrs:
266 try:
265 try:
267 setattr(old, name, getattr(new, name))
266 setattr(old, name, getattr(new, name))
268 except (AttributeError, TypeError):
267 except (AttributeError, TypeError):
269 pass
268 pass
270
269
271
270
272 def update_class(old, new):
271 def update_class(old, new):
273 """Replace stuff in the __dict__ of a class, and upgrade
272 """Replace stuff in the __dict__ of a class, and upgrade
274 method code objects"""
273 method code objects"""
275 for key in list(old.__dict__.keys()):
274 for key in list(old.__dict__.keys()):
276 old_obj = getattr(old, key)
275 old_obj = getattr(old, key)
277 try:
276 try:
278 new_obj = getattr(new, key)
277 new_obj = getattr(new, key)
279 if old_obj == new_obj:
278 if old_obj == new_obj:
280 continue
279 continue
281 except AttributeError:
280 except AttributeError:
282 # obsolete attribute: remove it
281 # obsolete attribute: remove it
283 try:
282 try:
284 delattr(old, key)
283 delattr(old, key)
285 except (AttributeError, TypeError):
284 except (AttributeError, TypeError):
286 pass
285 pass
287 continue
286 continue
288
287
289 if update_generic(old_obj, new_obj): continue
288 if update_generic(old_obj, new_obj): continue
290
289
291 try:
290 try:
292 setattr(old, key, getattr(new, key))
291 setattr(old, key, getattr(new, key))
293 except (AttributeError, TypeError):
292 except (AttributeError, TypeError):
294 pass # skip non-writable attributes
293 pass # skip non-writable attributes
295
294
296
295
297 def update_property(old, new):
296 def update_property(old, new):
298 """Replace get/set/del functions of a property"""
297 """Replace get/set/del functions of a property"""
299 update_generic(old.fdel, new.fdel)
298 update_generic(old.fdel, new.fdel)
300 update_generic(old.fget, new.fget)
299 update_generic(old.fget, new.fget)
301 update_generic(old.fset, new.fset)
300 update_generic(old.fset, new.fset)
302
301
303
302
304 def isinstance2(a, b, typ):
303 def isinstance2(a, b, typ):
305 return isinstance(a, typ) and isinstance(b, typ)
304 return isinstance(a, typ) and isinstance(b, typ)
306
305
307
306
308 UPDATE_RULES = [
307 UPDATE_RULES = [
309 (lambda a, b: isinstance2(a, b, type),
308 (lambda a, b: isinstance2(a, b, type),
310 update_class),
309 update_class),
311 (lambda a, b: isinstance2(a, b, types.FunctionType),
310 (lambda a, b: isinstance2(a, b, types.FunctionType),
312 update_function),
311 update_function),
313 (lambda a, b: isinstance2(a, b, property),
312 (lambda a, b: isinstance2(a, b, property),
314 update_property),
313 update_property),
315 ]
314 ]
316 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
315 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
317 lambda a, b: update_function(a.__func__, b.__func__)),
316 lambda a, b: update_function(a.__func__, b.__func__)),
318 ])
317 ])
319
318
320
319
321 def update_generic(a, b):
320 def update_generic(a, b):
322 for type_check, update in UPDATE_RULES:
321 for type_check, update in UPDATE_RULES:
323 if type_check(a, b):
322 if type_check(a, b):
324 update(a, b)
323 update(a, b)
325 return True
324 return True
326 return False
325 return False
327
326
328
327
329 class StrongRef(object):
328 class StrongRef(object):
330 def __init__(self, obj):
329 def __init__(self, obj):
331 self.obj = obj
330 self.obj = obj
332 def __call__(self):
331 def __call__(self):
333 return self.obj
332 return self.obj
334
333
335
334
336 def superreload(module, reload=reload, old_objects={}):
335 def superreload(module, reload=reload, old_objects={}):
337 """Enhanced version of the builtin reload function.
336 """Enhanced version of the builtin reload function.
338
337
339 superreload remembers objects previously in the module, and
338 superreload remembers objects previously in the module, and
340
339
341 - upgrades the class dictionary of every old class in the module
340 - upgrades the class dictionary of every old class in the module
342 - upgrades the code object of every old function and method
341 - upgrades the code object of every old function and method
343 - clears the module's namespace before reloading
342 - clears the module's namespace before reloading
344
343
345 """
344 """
346
345
347 # collect old objects in the module
346 # collect old objects in the module
348 for name, obj in list(module.__dict__.items()):
347 for name, obj in list(module.__dict__.items()):
349 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
348 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
350 continue
349 continue
351 key = (module.__name__, name)
350 key = (module.__name__, name)
352 try:
351 try:
353 old_objects.setdefault(key, []).append(weakref.ref(obj))
352 old_objects.setdefault(key, []).append(weakref.ref(obj))
354 except TypeError:
353 except TypeError:
355 pass
354 pass
356
355
357 # reload module
356 # reload module
358 try:
357 try:
359 # clear namespace first from old cruft
358 # clear namespace first from old cruft
360 old_dict = module.__dict__.copy()
359 old_dict = module.__dict__.copy()
361 old_name = module.__name__
360 old_name = module.__name__
362 module.__dict__.clear()
361 module.__dict__.clear()
363 module.__dict__['__name__'] = old_name
362 module.__dict__['__name__'] = old_name
364 module.__dict__['__loader__'] = old_dict['__loader__']
363 module.__dict__['__loader__'] = old_dict['__loader__']
365 except (TypeError, AttributeError, KeyError):
364 except (TypeError, AttributeError, KeyError):
366 pass
365 pass
367
366
368 try:
367 try:
369 module = reload(module)
368 module = reload(module)
370 except:
369 except:
371 # restore module dictionary on failed reload
370 # restore module dictionary on failed reload
372 module.__dict__.update(old_dict)
371 module.__dict__.update(old_dict)
373 raise
372 raise
374
373
375 # iterate over all objects and update functions & classes
374 # iterate over all objects and update functions & classes
376 for name, new_obj in list(module.__dict__.items()):
375 for name, new_obj in list(module.__dict__.items()):
377 key = (module.__name__, name)
376 key = (module.__name__, name)
378 if key not in old_objects: continue
377 if key not in old_objects: continue
379
378
380 new_refs = []
379 new_refs = []
381 for old_ref in old_objects[key]:
380 for old_ref in old_objects[key]:
382 old_obj = old_ref()
381 old_obj = old_ref()
383 if old_obj is None: continue
382 if old_obj is None: continue
384 new_refs.append(old_ref)
383 new_refs.append(old_ref)
385 update_generic(old_obj, new_obj)
384 update_generic(old_obj, new_obj)
386
385
387 if new_refs:
386 if new_refs:
388 old_objects[key] = new_refs
387 old_objects[key] = new_refs
389 else:
388 else:
390 del old_objects[key]
389 del old_objects[key]
391
390
392 return module
391 return module
393
392
394 #------------------------------------------------------------------------------
393 #------------------------------------------------------------------------------
395 # IPython connectivity
394 # IPython connectivity
396 #------------------------------------------------------------------------------
395 #------------------------------------------------------------------------------
397
396
398 from IPython.core.magic import Magics, magics_class, line_magic
397 from IPython.core.magic import Magics, magics_class, line_magic
399
398
400 @magics_class
399 @magics_class
401 class AutoreloadMagics(Magics):
400 class AutoreloadMagics(Magics):
402 def __init__(self, *a, **kw):
401 def __init__(self, *a, **kw):
403 super(AutoreloadMagics, self).__init__(*a, **kw)
402 super(AutoreloadMagics, self).__init__(*a, **kw)
404 self._reloader = ModuleReloader()
403 self._reloader = ModuleReloader()
405 self._reloader.check_all = False
404 self._reloader.check_all = False
406 self.loaded_modules = set(sys.modules)
405 self.loaded_modules = set(sys.modules)
407
406
408 @line_magic
407 @line_magic
409 def autoreload(self, parameter_s=''):
408 def autoreload(self, parameter_s=''):
410 r"""%autoreload => Reload modules automatically
409 r"""%autoreload => Reload modules automatically
411
410
412 %autoreload
411 %autoreload
413 Reload all modules (except those excluded by %aimport) automatically
412 Reload all modules (except those excluded by %aimport) automatically
414 now.
413 now.
415
414
416 %autoreload 0
415 %autoreload 0
417 Disable automatic reloading.
416 Disable automatic reloading.
418
417
419 %autoreload 1
418 %autoreload 1
420 Reload all modules imported with %aimport every time before executing
419 Reload all modules imported with %aimport every time before executing
421 the Python code typed.
420 the Python code typed.
422
421
423 %autoreload 2
422 %autoreload 2
424 Reload all modules (except those excluded by %aimport) every time
423 Reload all modules (except those excluded by %aimport) every time
425 before executing the Python code typed.
424 before executing the Python code typed.
426
425
427 Reloading Python modules in a reliable way is in general
426 Reloading Python modules in a reliable way is in general
428 difficult, and unexpected things may occur. %autoreload tries to
427 difficult, and unexpected things may occur. %autoreload tries to
429 work around common pitfalls by replacing function code objects and
428 work around common pitfalls by replacing function code objects and
430 parts of classes previously in the module with new versions. This
429 parts of classes previously in the module with new versions. This
431 makes the following things to work:
430 makes the following things to work:
432
431
433 - Functions and classes imported via 'from xxx import foo' are upgraded
432 - Functions and classes imported via 'from xxx import foo' are upgraded
434 to new versions when 'xxx' is reloaded.
433 to new versions when 'xxx' is reloaded.
435
434
436 - Methods and properties of classes are upgraded on reload, so that
435 - Methods and properties of classes are upgraded on reload, so that
437 calling 'c.foo()' on an object 'c' created before the reload causes
436 calling 'c.foo()' on an object 'c' created before the reload causes
438 the new code for 'foo' to be executed.
437 the new code for 'foo' to be executed.
439
438
440 Some of the known remaining caveats are:
439 Some of the known remaining caveats are:
441
440
442 - Replacing code objects does not always succeed: changing a @property
441 - Replacing code objects does not always succeed: changing a @property
443 in a class to an ordinary method or a method to a member variable
442 in a class to an ordinary method or a method to a member variable
444 can cause problems (but in old objects only).
443 can cause problems (but in old objects only).
445
444
446 - Functions that are removed (eg. via monkey-patching) from a module
445 - Functions that are removed (eg. via monkey-patching) from a module
447 before it is reloaded are not upgraded.
446 before it is reloaded are not upgraded.
448
447
449 - C extension modules cannot be reloaded, and so cannot be
448 - C extension modules cannot be reloaded, and so cannot be
450 autoreloaded.
449 autoreloaded.
451
450
452 """
451 """
453 if parameter_s == '':
452 if parameter_s == '':
454 self._reloader.check(True)
453 self._reloader.check(True)
455 elif parameter_s == '0':
454 elif parameter_s == '0':
456 self._reloader.enabled = False
455 self._reloader.enabled = False
457 elif parameter_s == '1':
456 elif parameter_s == '1':
458 self._reloader.check_all = False
457 self._reloader.check_all = False
459 self._reloader.enabled = True
458 self._reloader.enabled = True
460 elif parameter_s == '2':
459 elif parameter_s == '2':
461 self._reloader.check_all = True
460 self._reloader.check_all = True
462 self._reloader.enabled = True
461 self._reloader.enabled = True
463
462
464 @line_magic
463 @line_magic
465 def aimport(self, parameter_s='', stream=None):
464 def aimport(self, parameter_s='', stream=None):
466 """%aimport => Import modules for automatic reloading.
465 """%aimport => Import modules for automatic reloading.
467
466
468 %aimport
467 %aimport
469 List modules to automatically import and not to import.
468 List modules to automatically import and not to import.
470
469
471 %aimport foo
470 %aimport foo
472 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
471 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
473
472
474 %aimport foo, bar
473 %aimport foo, bar
475 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
474 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
476
475
477 %aimport -foo
476 %aimport -foo
478 Mark module 'foo' to not be autoreloaded for %autoreload 1
477 Mark module 'foo' to not be autoreloaded for %autoreload 1
479 """
478 """
480 modname = parameter_s
479 modname = parameter_s
481 if not modname:
480 if not modname:
482 to_reload = sorted(self._reloader.modules.keys())
481 to_reload = sorted(self._reloader.modules.keys())
483 to_skip = sorted(self._reloader.skip_modules.keys())
482 to_skip = sorted(self._reloader.skip_modules.keys())
484 if stream is None:
483 if stream is None:
485 stream = sys.stdout
484 stream = sys.stdout
486 if self._reloader.check_all:
485 if self._reloader.check_all:
487 stream.write("Modules to reload:\nall-except-skipped\n")
486 stream.write("Modules to reload:\nall-except-skipped\n")
488 else:
487 else:
489 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
488 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
490 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
489 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
491 elif modname.startswith('-'):
490 elif modname.startswith('-'):
492 modname = modname[1:]
491 modname = modname[1:]
493 self._reloader.mark_module_skipped(modname)
492 self._reloader.mark_module_skipped(modname)
494 else:
493 else:
495 for _module in ([_.strip() for _ in modname.split(',')]):
494 for _module in ([_.strip() for _ in modname.split(',')]):
496 top_module, top_name = self._reloader.aimport_module(_module)
495 top_module, top_name = self._reloader.aimport_module(_module)
497
496
498 # Inject module to user namespace
497 # Inject module to user namespace
499 self.shell.push({top_name: top_module})
498 self.shell.push({top_name: top_module})
500
499
501 def pre_run_cell(self):
500 def pre_run_cell(self):
502 if self._reloader.enabled:
501 if self._reloader.enabled:
503 try:
502 try:
504 self._reloader.check()
503 self._reloader.check()
505 except:
504 except:
506 pass
505 pass
507
506
508 def post_execute_hook(self):
507 def post_execute_hook(self):
509 """Cache the modification times of any modules imported in this execution
508 """Cache the modification times of any modules imported in this execution
510 """
509 """
511 newly_loaded_modules = set(sys.modules) - self.loaded_modules
510 newly_loaded_modules = set(sys.modules) - self.loaded_modules
512 for modname in newly_loaded_modules:
511 for modname in newly_loaded_modules:
513 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
512 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
514 if pymtime is not None:
513 if pymtime is not None:
515 self._reloader.modules_mtimes[modname] = pymtime
514 self._reloader.modules_mtimes[modname] = pymtime
516
515
517 self.loaded_modules.update(newly_loaded_modules)
516 self.loaded_modules.update(newly_loaded_modules)
518
517
519
518
520 def load_ipython_extension(ip):
519 def load_ipython_extension(ip):
521 """Load the extension in IPython."""
520 """Load the extension in IPython."""
522 auto_reload = AutoreloadMagics(ip)
521 auto_reload = AutoreloadMagics(ip)
523 ip.register_magics(auto_reload)
522 ip.register_magics(auto_reload)
524 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
523 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
525 ip.events.register('post_execute', auto_reload.post_execute_hook)
524 ip.events.register('post_execute', auto_reload.post_execute_hook)
@@ -1,114 +1,114
1 """
1 """
2 Password generation for the IPython notebook.
2 Password generation for the IPython notebook.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Imports
5 # Imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Stdlib
7 # Stdlib
8 import getpass
8 import getpass
9 import hashlib
9 import hashlib
10 import random
10 import random
11
11
12 # Our own
12 # Our own
13 from IPython.core.error import UsageError
13 from IPython.core.error import UsageError
14 from IPython.utils.py3compat import cast_bytes, str_to_bytes
14 from IPython.utils.py3compat import encode
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Globals
17 # Globals
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 # Length of the salt in nr of hex chars, which implies salt_len * 4
20 # Length of the salt in nr of hex chars, which implies salt_len * 4
21 # bits of randomness.
21 # bits of randomness.
22 salt_len = 12
22 salt_len = 12
23
23
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25 # Functions
25 # Functions
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27
27
28 def passwd(passphrase=None, algorithm='sha1'):
28 def passwd(passphrase=None, algorithm='sha1'):
29 """Generate hashed password and salt for use in notebook configuration.
29 """Generate hashed password and salt for use in notebook configuration.
30
30
31 In the notebook configuration, set `c.NotebookApp.password` to
31 In the notebook configuration, set `c.NotebookApp.password` to
32 the generated string.
32 the generated string.
33
33
34 Parameters
34 Parameters
35 ----------
35 ----------
36 passphrase : str
36 passphrase : str
37 Password to hash. If unspecified, the user is asked to input
37 Password to hash. If unspecified, the user is asked to input
38 and verify a password.
38 and verify a password.
39 algorithm : str
39 algorithm : str
40 Hashing algorithm to use (e.g, 'sha1' or any argument supported
40 Hashing algorithm to use (e.g, 'sha1' or any argument supported
41 by :func:`hashlib.new`).
41 by :func:`hashlib.new`).
42
42
43 Returns
43 Returns
44 -------
44 -------
45 hashed_passphrase : str
45 hashed_passphrase : str
46 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
46 Hashed password, in the format 'hash_algorithm:salt:passphrase_hash'.
47
47
48 Examples
48 Examples
49 --------
49 --------
50 >>> passwd('mypassword')
50 >>> passwd('mypassword')
51 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
51 'sha1:7cf3:b7d6da294ea9592a9480c8f52e63cd42cfb9dd12'
52
52
53 """
53 """
54 if passphrase is None:
54 if passphrase is None:
55 for i in range(3):
55 for i in range(3):
56 p0 = getpass.getpass('Enter password: ')
56 p0 = getpass.getpass('Enter password: ')
57 p1 = getpass.getpass('Verify password: ')
57 p1 = getpass.getpass('Verify password: ')
58 if p0 == p1:
58 if p0 == p1:
59 passphrase = p0
59 passphrase = p0
60 break
60 break
61 else:
61 else:
62 print('Passwords do not match.')
62 print('Passwords do not match.')
63 else:
63 else:
64 raise UsageError('No matching passwords found. Giving up.')
64 raise UsageError('No matching passwords found. Giving up.')
65
65
66 h = hashlib.new(algorithm)
66 h = hashlib.new(algorithm)
67 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
67 salt = ('%0' + str(salt_len) + 'x') % random.getrandbits(4 * salt_len)
68 h.update(cast_bytes(passphrase, 'utf-8') + str_to_bytes(salt, 'ascii'))
68 h.update(encode(passphrase, 'utf-8') + encode(salt, 'ascii'))
69
69
70 return ':'.join((algorithm, salt, h.hexdigest()))
70 return ':'.join((algorithm, salt, h.hexdigest()))
71
71
72
72
73 def passwd_check(hashed_passphrase, passphrase):
73 def passwd_check(hashed_passphrase, passphrase):
74 """Verify that a given passphrase matches its hashed version.
74 """Verify that a given passphrase matches its hashed version.
75
75
76 Parameters
76 Parameters
77 ----------
77 ----------
78 hashed_passphrase : str
78 hashed_passphrase : str
79 Hashed password, in the format returned by `passwd`.
79 Hashed password, in the format returned by `passwd`.
80 passphrase : str
80 passphrase : str
81 Passphrase to validate.
81 Passphrase to validate.
82
82
83 Returns
83 Returns
84 -------
84 -------
85 valid : bool
85 valid : bool
86 True if the passphrase matches the hash.
86 True if the passphrase matches the hash.
87
87
88 Examples
88 Examples
89 --------
89 --------
90 >>> from IPython.lib.security import passwd_check
90 >>> from IPython.lib.security import passwd_check
91 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
91 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
92 ... 'mypassword')
92 ... 'mypassword')
93 True
93 True
94
94
95 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
95 >>> passwd_check('sha1:0e112c3ddfce:a68df677475c2b47b6e86d0467eec97ac5f4b85a',
96 ... 'anotherpassword')
96 ... 'anotherpassword')
97 False
97 False
98 """
98 """
99 try:
99 try:
100 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
100 algorithm, salt, pw_digest = hashed_passphrase.split(':', 2)
101 except (ValueError, TypeError):
101 except (ValueError, TypeError):
102 return False
102 return False
103
103
104 try:
104 try:
105 h = hashlib.new(algorithm)
105 h = hashlib.new(algorithm)
106 except ValueError:
106 except ValueError:
107 return False
107 return False
108
108
109 if len(pw_digest) == 0:
109 if len(pw_digest) == 0:
110 return False
110 return False
111
111
112 h.update(cast_bytes(passphrase, 'utf-8') + cast_bytes(salt, 'ascii'))
112 h.update(encode(passphrase, 'utf-8') + encode(salt, 'ascii'))
113
113
114 return h.hexdigest() == pw_digest
114 return h.hexdigest() == pw_digest
@@ -1,540 +1,540
1 """IPython terminal interface using prompt_toolkit"""
1 """IPython terminal interface using prompt_toolkit"""
2
2
3 import os
3 import os
4 import sys
4 import sys
5 import warnings
5 import warnings
6 from warnings import warn
6 from warnings import warn
7
7
8 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
8 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
9 from IPython.utils import io
9 from IPython.utils import io
10 from IPython.utils.py3compat import input, cast_unicode_py2
10 from IPython.utils.py3compat import input
11 from IPython.utils.terminal import toggle_set_term_title, set_term_title
11 from IPython.utils.terminal import toggle_set_term_title, set_term_title
12 from IPython.utils.process import abbrev_cwd
12 from IPython.utils.process import abbrev_cwd
13 from traitlets import (
13 from traitlets import (
14 Bool, Unicode, Dict, Integer, observe, Instance, Type, default, Enum, Union,
14 Bool, Unicode, Dict, Integer, observe, Instance, Type, default, Enum, Union,
15 Any,
15 Any,
16 )
16 )
17
17
18 from prompt_toolkit.document import Document
18 from prompt_toolkit.document import Document
19 from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
19 from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
20 from prompt_toolkit.filters import (HasFocus, Condition, IsDone)
20 from prompt_toolkit.filters import (HasFocus, Condition, IsDone)
21 from prompt_toolkit.history import InMemoryHistory
21 from prompt_toolkit.history import InMemoryHistory
22 from prompt_toolkit.shortcuts import create_prompt_application, create_eventloop, create_prompt_layout, create_output
22 from prompt_toolkit.shortcuts import create_prompt_application, create_eventloop, create_prompt_layout, create_output
23 from prompt_toolkit.interface import CommandLineInterface
23 from prompt_toolkit.interface import CommandLineInterface
24 from prompt_toolkit.key_binding.manager import KeyBindingManager
24 from prompt_toolkit.key_binding.manager import KeyBindingManager
25 from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
25 from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
26 from prompt_toolkit.styles import PygmentsStyle, DynamicStyle
26 from prompt_toolkit.styles import PygmentsStyle, DynamicStyle
27
27
28 from pygments.styles import get_style_by_name
28 from pygments.styles import get_style_by_name
29 from pygments.style import Style
29 from pygments.style import Style
30 from pygments.token import Token
30 from pygments.token import Token
31
31
32 from .debugger import TerminalPdb, Pdb
32 from .debugger import TerminalPdb, Pdb
33 from .magics import TerminalMagics
33 from .magics import TerminalMagics
34 from .pt_inputhooks import get_inputhook_name_and_func
34 from .pt_inputhooks import get_inputhook_name_and_func
35 from .prompts import Prompts, ClassicPrompts, RichPromptDisplayHook
35 from .prompts import Prompts, ClassicPrompts, RichPromptDisplayHook
36 from .ptutils import IPythonPTCompleter, IPythonPTLexer
36 from .ptutils import IPythonPTCompleter, IPythonPTLexer
37 from .shortcuts import register_ipython_shortcuts
37 from .shortcuts import register_ipython_shortcuts
38
38
39 DISPLAY_BANNER_DEPRECATED = object()
39 DISPLAY_BANNER_DEPRECATED = object()
40
40
41
41
42 class _NoStyle(Style): pass
42 class _NoStyle(Style): pass
43
43
44
44
45
45
46 _style_overrides_light_bg = {
46 _style_overrides_light_bg = {
47 Token.Prompt: '#0000ff',
47 Token.Prompt: '#0000ff',
48 Token.PromptNum: '#0000ee bold',
48 Token.PromptNum: '#0000ee bold',
49 Token.OutPrompt: '#cc0000',
49 Token.OutPrompt: '#cc0000',
50 Token.OutPromptNum: '#bb0000 bold',
50 Token.OutPromptNum: '#bb0000 bold',
51 }
51 }
52
52
53 _style_overrides_linux = {
53 _style_overrides_linux = {
54 Token.Prompt: '#00cc00',
54 Token.Prompt: '#00cc00',
55 Token.PromptNum: '#00bb00 bold',
55 Token.PromptNum: '#00bb00 bold',
56 Token.OutPrompt: '#cc0000',
56 Token.OutPrompt: '#cc0000',
57 Token.OutPromptNum: '#bb0000 bold',
57 Token.OutPromptNum: '#bb0000 bold',
58 }
58 }
59
59
60
60
61
61
62 def get_default_editor():
62 def get_default_editor():
63 try:
63 try:
64 return os.environ['EDITOR']
64 return os.environ['EDITOR']
65 except KeyError:
65 except KeyError:
66 pass
66 pass
67 except UnicodeError:
67 except UnicodeError:
68 warn("$EDITOR environment variable is not pure ASCII. Using platform "
68 warn("$EDITOR environment variable is not pure ASCII. Using platform "
69 "default editor.")
69 "default editor.")
70
70
71 if os.name == 'posix':
71 if os.name == 'posix':
72 return 'vi' # the only one guaranteed to be there!
72 return 'vi' # the only one guaranteed to be there!
73 else:
73 else:
74 return 'notepad' # same in Windows!
74 return 'notepad' # same in Windows!
75
75
76 # conservatively check for tty
76 # conservatively check for tty
77 # overridden streams can result in things like:
77 # overridden streams can result in things like:
78 # - sys.stdin = None
78 # - sys.stdin = None
79 # - no isatty method
79 # - no isatty method
80 for _name in ('stdin', 'stdout', 'stderr'):
80 for _name in ('stdin', 'stdout', 'stderr'):
81 _stream = getattr(sys, _name)
81 _stream = getattr(sys, _name)
82 if not _stream or not hasattr(_stream, 'isatty') or not _stream.isatty():
82 if not _stream or not hasattr(_stream, 'isatty') or not _stream.isatty():
83 _is_tty = False
83 _is_tty = False
84 break
84 break
85 else:
85 else:
86 _is_tty = True
86 _is_tty = True
87
87
88
88
89 _use_simple_prompt = ('IPY_TEST_SIMPLE_PROMPT' in os.environ) or (not _is_tty)
89 _use_simple_prompt = ('IPY_TEST_SIMPLE_PROMPT' in os.environ) or (not _is_tty)
90
90
91 class TerminalInteractiveShell(InteractiveShell):
91 class TerminalInteractiveShell(InteractiveShell):
92 space_for_menu = Integer(6, help='Number of line at the bottom of the screen '
92 space_for_menu = Integer(6, help='Number of line at the bottom of the screen '
93 'to reserve for the completion menu'
93 'to reserve for the completion menu'
94 ).tag(config=True)
94 ).tag(config=True)
95
95
96 def _space_for_menu_changed(self, old, new):
96 def _space_for_menu_changed(self, old, new):
97 self._update_layout()
97 self._update_layout()
98
98
99 pt_cli = None
99 pt_cli = None
100 debugger_history = None
100 debugger_history = None
101 _pt_app = None
101 _pt_app = None
102
102
103 simple_prompt = Bool(_use_simple_prompt,
103 simple_prompt = Bool(_use_simple_prompt,
104 help="""Use `raw_input` for the REPL, without completion and prompt colors.
104 help="""Use `raw_input` for the REPL, without completion and prompt colors.
105
105
106 Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
106 Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
107 IPython own testing machinery, and emacs inferior-shell integration through elpy.
107 IPython own testing machinery, and emacs inferior-shell integration through elpy.
108
108
109 This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
109 This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
110 environment variable is set, or the current terminal is not a tty."""
110 environment variable is set, or the current terminal is not a tty."""
111 ).tag(config=True)
111 ).tag(config=True)
112
112
113 @property
113 @property
114 def debugger_cls(self):
114 def debugger_cls(self):
115 return Pdb if self.simple_prompt else TerminalPdb
115 return Pdb if self.simple_prompt else TerminalPdb
116
116
117 confirm_exit = Bool(True,
117 confirm_exit = Bool(True,
118 help="""
118 help="""
119 Set to confirm when you try to exit IPython with an EOF (Control-D
119 Set to confirm when you try to exit IPython with an EOF (Control-D
120 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
120 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
121 you can force a direct exit without any confirmation.""",
121 you can force a direct exit without any confirmation.""",
122 ).tag(config=True)
122 ).tag(config=True)
123
123
124 editing_mode = Unicode('emacs',
124 editing_mode = Unicode('emacs',
125 help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
125 help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
126 ).tag(config=True)
126 ).tag(config=True)
127
127
128 mouse_support = Bool(False,
128 mouse_support = Bool(False,
129 help="Enable mouse support in the prompt\n(Note: prevents selecting text with the mouse)"
129 help="Enable mouse support in the prompt\n(Note: prevents selecting text with the mouse)"
130 ).tag(config=True)
130 ).tag(config=True)
131
131
132 # We don't load the list of styles for the help string, because loading
132 # We don't load the list of styles for the help string, because loading
133 # Pygments plugins takes time and can cause unexpected errors.
133 # Pygments plugins takes time and can cause unexpected errors.
134 highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
134 highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
135 help="""The name or class of a Pygments style to use for syntax
135 help="""The name or class of a Pygments style to use for syntax
136 highlighting. To see available styles, run `pygmentize -L styles`."""
136 highlighting. To see available styles, run `pygmentize -L styles`."""
137 ).tag(config=True)
137 ).tag(config=True)
138
138
139
139
140 @observe('highlighting_style')
140 @observe('highlighting_style')
141 @observe('colors')
141 @observe('colors')
142 def _highlighting_style_changed(self, change):
142 def _highlighting_style_changed(self, change):
143 self.refresh_style()
143 self.refresh_style()
144
144
145 def refresh_style(self):
145 def refresh_style(self):
146 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
146 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
147
147
148
148
149 highlighting_style_overrides = Dict(
149 highlighting_style_overrides = Dict(
150 help="Override highlighting format for specific tokens"
150 help="Override highlighting format for specific tokens"
151 ).tag(config=True)
151 ).tag(config=True)
152
152
153 true_color = Bool(False,
153 true_color = Bool(False,
154 help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
154 help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
155 "If your terminal supports true color, the following command "
155 "If your terminal supports true color, the following command "
156 "should print 'TRUECOLOR' in orange: "
156 "should print 'TRUECOLOR' in orange: "
157 "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
157 "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
158 ).tag(config=True)
158 ).tag(config=True)
159
159
160 editor = Unicode(get_default_editor(),
160 editor = Unicode(get_default_editor(),
161 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
161 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
162 ).tag(config=True)
162 ).tag(config=True)
163
163
164 prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)
164 prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)
165
165
166 prompts = Instance(Prompts)
166 prompts = Instance(Prompts)
167
167
168 @default('prompts')
168 @default('prompts')
169 def _prompts_default(self):
169 def _prompts_default(self):
170 return self.prompts_class(self)
170 return self.prompts_class(self)
171
171
172 @observe('prompts')
172 @observe('prompts')
173 def _(self, change):
173 def _(self, change):
174 self._update_layout()
174 self._update_layout()
175
175
176 @default('displayhook_class')
176 @default('displayhook_class')
177 def _displayhook_class_default(self):
177 def _displayhook_class_default(self):
178 return RichPromptDisplayHook
178 return RichPromptDisplayHook
179
179
180 term_title = Bool(True,
180 term_title = Bool(True,
181 help="Automatically set the terminal title"
181 help="Automatically set the terminal title"
182 ).tag(config=True)
182 ).tag(config=True)
183
183
184 term_title_format = Unicode("IPython: {cwd}",
184 term_title_format = Unicode("IPython: {cwd}",
185 help="Customize the terminal title format. This is a python format string. " +
185 help="Customize the terminal title format. This is a python format string. " +
186 "Available substitutions are: {cwd}."
186 "Available substitutions are: {cwd}."
187 ).tag(config=True)
187 ).tag(config=True)
188
188
189 display_completions = Enum(('column', 'multicolumn','readlinelike'),
189 display_completions = Enum(('column', 'multicolumn','readlinelike'),
190 help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
190 help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
191 "'readlinelike'. These options are for `prompt_toolkit`, see "
191 "'readlinelike'. These options are for `prompt_toolkit`, see "
192 "`prompt_toolkit` documentation for more information."
192 "`prompt_toolkit` documentation for more information."
193 ),
193 ),
194 default_value='multicolumn').tag(config=True)
194 default_value='multicolumn').tag(config=True)
195
195
196 highlight_matching_brackets = Bool(True,
196 highlight_matching_brackets = Bool(True,
197 help="Highlight matching brackets.",
197 help="Highlight matching brackets.",
198 ).tag(config=True)
198 ).tag(config=True)
199
199
200 extra_open_editor_shortcuts = Bool(False,
200 extra_open_editor_shortcuts = Bool(False,
201 help="Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
201 help="Enable vi (v) or Emacs (C-X C-E) shortcuts to open an external editor. "
202 "This is in addition to the F2 binding, which is always enabled."
202 "This is in addition to the F2 binding, which is always enabled."
203 ).tag(config=True)
203 ).tag(config=True)
204
204
205 handle_return = Any(None,
205 handle_return = Any(None,
206 help="Provide an alternative handler to be called when the user presses "
206 help="Provide an alternative handler to be called when the user presses "
207 "Return. This is an advanced option intended for debugging, which "
207 "Return. This is an advanced option intended for debugging, which "
208 "may be changed or removed in later releases."
208 "may be changed or removed in later releases."
209 ).tag(config=True)
209 ).tag(config=True)
210
210
211 @observe('term_title')
211 @observe('term_title')
212 def init_term_title(self, change=None):
212 def init_term_title(self, change=None):
213 # Enable or disable the terminal title.
213 # Enable or disable the terminal title.
214 if self.term_title:
214 if self.term_title:
215 toggle_set_term_title(True)
215 toggle_set_term_title(True)
216 set_term_title(self.term_title_format.format(cwd=abbrev_cwd()))
216 set_term_title(self.term_title_format.format(cwd=abbrev_cwd()))
217 else:
217 else:
218 toggle_set_term_title(False)
218 toggle_set_term_title(False)
219
219
220 def init_display_formatter(self):
220 def init_display_formatter(self):
221 super(TerminalInteractiveShell, self).init_display_formatter()
221 super(TerminalInteractiveShell, self).init_display_formatter()
222 # terminal only supports plain text
222 # terminal only supports plain text
223 self.display_formatter.active_types = ['text/plain']
223 self.display_formatter.active_types = ['text/plain']
224 # disable `_ipython_display_`
224 # disable `_ipython_display_`
225 self.display_formatter.ipython_display_formatter.enabled = False
225 self.display_formatter.ipython_display_formatter.enabled = False
226
226
227 def init_prompt_toolkit_cli(self):
227 def init_prompt_toolkit_cli(self):
228 if self.simple_prompt:
228 if self.simple_prompt:
229 # Fall back to plain non-interactive output for tests.
229 # Fall back to plain non-interactive output for tests.
230 # This is very limited, and only accepts a single line.
230 # This is very limited, and only accepts a single line.
231 def prompt():
231 def prompt():
232 isp = self.input_splitter
232 isp = self.input_splitter
233 prompt_text = "".join(x[1] for x in self.prompts.in_prompt_tokens())
233 prompt_text = "".join(x[1] for x in self.prompts.in_prompt_tokens())
234 prompt_continuation = "".join(x[1] for x in self.prompts.continuation_prompt_tokens())
234 prompt_continuation = "".join(x[1] for x in self.prompts.continuation_prompt_tokens())
235 while isp.push_accepts_more():
235 while isp.push_accepts_more():
236 line = cast_unicode_py2(input(prompt_text))
236 line = input(prompt_text)
237 isp.push(line)
237 isp.push(line)
238 prompt_text = prompt_continuation
238 prompt_text = prompt_continuation
239 return isp.source_reset()
239 return isp.source_reset()
240 self.prompt_for_code = prompt
240 self.prompt_for_code = prompt
241 return
241 return
242
242
243 # Set up keyboard shortcuts
243 # Set up keyboard shortcuts
244 kbmanager = KeyBindingManager.for_prompt(
244 kbmanager = KeyBindingManager.for_prompt(
245 enable_open_in_editor=self.extra_open_editor_shortcuts,
245 enable_open_in_editor=self.extra_open_editor_shortcuts,
246 )
246 )
247 register_ipython_shortcuts(kbmanager.registry, self)
247 register_ipython_shortcuts(kbmanager.registry, self)
248
248
249 # Pre-populate history from IPython's history database
249 # Pre-populate history from IPython's history database
250 history = InMemoryHistory()
250 history = InMemoryHistory()
251 last_cell = u""
251 last_cell = u""
252 for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
252 for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
253 include_latest=True):
253 include_latest=True):
254 # Ignore blank lines and consecutive duplicates
254 # Ignore blank lines and consecutive duplicates
255 cell = cell.rstrip()
255 cell = cell.rstrip()
256 if cell and (cell != last_cell):
256 if cell and (cell != last_cell):
257 history.append(cell)
257 history.append(cell)
258 last_cell = cell
258 last_cell = cell
259
259
260 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
260 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
261 self.style = DynamicStyle(lambda: self._style)
261 self.style = DynamicStyle(lambda: self._style)
262
262
263 editing_mode = getattr(EditingMode, self.editing_mode.upper())
263 editing_mode = getattr(EditingMode, self.editing_mode.upper())
264
264
265 def patch_stdout(**kwargs):
265 def patch_stdout(**kwargs):
266 return self.pt_cli.patch_stdout_context(**kwargs)
266 return self.pt_cli.patch_stdout_context(**kwargs)
267
267
268 self._pt_app = create_prompt_application(
268 self._pt_app = create_prompt_application(
269 editing_mode=editing_mode,
269 editing_mode=editing_mode,
270 key_bindings_registry=kbmanager.registry,
270 key_bindings_registry=kbmanager.registry,
271 history=history,
271 history=history,
272 completer=IPythonPTCompleter(shell=self,
272 completer=IPythonPTCompleter(shell=self,
273 patch_stdout=patch_stdout),
273 patch_stdout=patch_stdout),
274 enable_history_search=True,
274 enable_history_search=True,
275 style=self.style,
275 style=self.style,
276 mouse_support=self.mouse_support,
276 mouse_support=self.mouse_support,
277 **self._layout_options()
277 **self._layout_options()
278 )
278 )
279 self._eventloop = create_eventloop(self.inputhook)
279 self._eventloop = create_eventloop(self.inputhook)
280 self.pt_cli = CommandLineInterface(
280 self.pt_cli = CommandLineInterface(
281 self._pt_app, eventloop=self._eventloop,
281 self._pt_app, eventloop=self._eventloop,
282 output=create_output(true_color=self.true_color))
282 output=create_output(true_color=self.true_color))
283
283
284 def _make_style_from_name_or_cls(self, name_or_cls):
284 def _make_style_from_name_or_cls(self, name_or_cls):
285 """
285 """
286 Small wrapper that make an IPython compatible style from a style name
286 Small wrapper that make an IPython compatible style from a style name
287
287
288 We need that to add style for prompt ... etc.
288 We need that to add style for prompt ... etc.
289 """
289 """
290 style_overrides = {}
290 style_overrides = {}
291 if name_or_cls == 'legacy':
291 if name_or_cls == 'legacy':
292 legacy = self.colors.lower()
292 legacy = self.colors.lower()
293 if legacy == 'linux':
293 if legacy == 'linux':
294 style_cls = get_style_by_name('monokai')
294 style_cls = get_style_by_name('monokai')
295 style_overrides = _style_overrides_linux
295 style_overrides = _style_overrides_linux
296 elif legacy == 'lightbg':
296 elif legacy == 'lightbg':
297 style_overrides = _style_overrides_light_bg
297 style_overrides = _style_overrides_light_bg
298 style_cls = get_style_by_name('pastie')
298 style_cls = get_style_by_name('pastie')
299 elif legacy == 'neutral':
299 elif legacy == 'neutral':
300 # The default theme needs to be visible on both a dark background
300 # The default theme needs to be visible on both a dark background
301 # and a light background, because we can't tell what the terminal
301 # and a light background, because we can't tell what the terminal
302 # looks like. These tweaks to the default theme help with that.
302 # looks like. These tweaks to the default theme help with that.
303 style_cls = get_style_by_name('default')
303 style_cls = get_style_by_name('default')
304 style_overrides.update({
304 style_overrides.update({
305 Token.Number: '#007700',
305 Token.Number: '#007700',
306 Token.Operator: 'noinherit',
306 Token.Operator: 'noinherit',
307 Token.String: '#BB6622',
307 Token.String: '#BB6622',
308 Token.Name.Function: '#2080D0',
308 Token.Name.Function: '#2080D0',
309 Token.Name.Class: 'bold #2080D0',
309 Token.Name.Class: 'bold #2080D0',
310 Token.Name.Namespace: 'bold #2080D0',
310 Token.Name.Namespace: 'bold #2080D0',
311 Token.Prompt: '#009900',
311 Token.Prompt: '#009900',
312 Token.PromptNum: '#00ff00 bold',
312 Token.PromptNum: '#00ff00 bold',
313 Token.OutPrompt: '#990000',
313 Token.OutPrompt: '#990000',
314 Token.OutPromptNum: '#ff0000 bold',
314 Token.OutPromptNum: '#ff0000 bold',
315 })
315 })
316
316
317 # Hack: Due to limited color support on the Windows console
317 # Hack: Due to limited color support on the Windows console
318 # the prompt colors will be wrong without this
318 # the prompt colors will be wrong without this
319 if os.name == 'nt':
319 if os.name == 'nt':
320 style_overrides.update({
320 style_overrides.update({
321 Token.Prompt: '#ansidarkgreen',
321 Token.Prompt: '#ansidarkgreen',
322 Token.PromptNum: '#ansigreen bold',
322 Token.PromptNum: '#ansigreen bold',
323 Token.OutPrompt: '#ansidarkred',
323 Token.OutPrompt: '#ansidarkred',
324 Token.OutPromptNum: '#ansired bold',
324 Token.OutPromptNum: '#ansired bold',
325 })
325 })
326 elif legacy =='nocolor':
326 elif legacy =='nocolor':
327 style_cls=_NoStyle
327 style_cls=_NoStyle
328 style_overrides = {}
328 style_overrides = {}
329 else :
329 else :
330 raise ValueError('Got unknown colors: ', legacy)
330 raise ValueError('Got unknown colors: ', legacy)
331 else :
331 else :
332 if isinstance(name_or_cls, str):
332 if isinstance(name_or_cls, str):
333 style_cls = get_style_by_name(name_or_cls)
333 style_cls = get_style_by_name(name_or_cls)
334 else:
334 else:
335 style_cls = name_or_cls
335 style_cls = name_or_cls
336 style_overrides = {
336 style_overrides = {
337 Token.Prompt: '#009900',
337 Token.Prompt: '#009900',
338 Token.PromptNum: '#00ff00 bold',
338 Token.PromptNum: '#00ff00 bold',
339 Token.OutPrompt: '#990000',
339 Token.OutPrompt: '#990000',
340 Token.OutPromptNum: '#ff0000 bold',
340 Token.OutPromptNum: '#ff0000 bold',
341 }
341 }
342 style_overrides.update(self.highlighting_style_overrides)
342 style_overrides.update(self.highlighting_style_overrides)
343 style = PygmentsStyle.from_defaults(pygments_style_cls=style_cls,
343 style = PygmentsStyle.from_defaults(pygments_style_cls=style_cls,
344 style_dict=style_overrides)
344 style_dict=style_overrides)
345
345
346 return style
346 return style
347
347
348 def _layout_options(self):
348 def _layout_options(self):
349 """
349 """
350 Return the current layout option for the current Terminal InteractiveShell
350 Return the current layout option for the current Terminal InteractiveShell
351 """
351 """
352 return {
352 return {
353 'lexer':IPythonPTLexer(),
353 'lexer':IPythonPTLexer(),
354 'reserve_space_for_menu':self.space_for_menu,
354 'reserve_space_for_menu':self.space_for_menu,
355 'get_prompt_tokens':self.prompts.in_prompt_tokens,
355 'get_prompt_tokens':self.prompts.in_prompt_tokens,
356 'get_continuation_tokens':self.prompts.continuation_prompt_tokens,
356 'get_continuation_tokens':self.prompts.continuation_prompt_tokens,
357 'multiline':True,
357 'multiline':True,
358 'display_completions_in_columns': (self.display_completions == 'multicolumn'),
358 'display_completions_in_columns': (self.display_completions == 'multicolumn'),
359
359
360 # Highlight matching brackets, but only when this setting is
360 # Highlight matching brackets, but only when this setting is
361 # enabled, and only when the DEFAULT_BUFFER has the focus.
361 # enabled, and only when the DEFAULT_BUFFER has the focus.
362 'extra_input_processors': [ConditionalProcessor(
362 'extra_input_processors': [ConditionalProcessor(
363 processor=HighlightMatchingBracketProcessor(chars='[](){}'),
363 processor=HighlightMatchingBracketProcessor(chars='[](){}'),
364 filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
364 filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
365 Condition(lambda cli: self.highlight_matching_brackets))],
365 Condition(lambda cli: self.highlight_matching_brackets))],
366 }
366 }
367
367
368 def _update_layout(self):
368 def _update_layout(self):
369 """
369 """
370 Ask for a re computation of the application layout, if for example ,
370 Ask for a re computation of the application layout, if for example ,
371 some configuration options have changed.
371 some configuration options have changed.
372 """
372 """
373 if self._pt_app:
373 if self._pt_app:
374 self._pt_app.layout = create_prompt_layout(**self._layout_options())
374 self._pt_app.layout = create_prompt_layout(**self._layout_options())
375
375
376 def prompt_for_code(self):
376 def prompt_for_code(self):
377 with self.pt_cli.patch_stdout_context(raw=True):
377 with self.pt_cli.patch_stdout_context(raw=True):
378 document = self.pt_cli.run(
378 document = self.pt_cli.run(
379 pre_run=self.pre_prompt, reset_current_buffer=True)
379 pre_run=self.pre_prompt, reset_current_buffer=True)
380 return document.text
380 return document.text
381
381
382 def enable_win_unicode_console(self):
382 def enable_win_unicode_console(self):
383 if sys.version_info >= (3, 6):
383 if sys.version_info >= (3, 6):
384 # Since PEP 528, Python uses the unicode APIs for the Windows
384 # Since PEP 528, Python uses the unicode APIs for the Windows
385 # console by default, so WUC shouldn't be needed.
385 # console by default, so WUC shouldn't be needed.
386 return
386 return
387
387
388 import win_unicode_console
388 import win_unicode_console
389 win_unicode_console.enable()
389 win_unicode_console.enable()
390
390
391 def init_io(self):
391 def init_io(self):
392 if sys.platform not in {'win32', 'cli'}:
392 if sys.platform not in {'win32', 'cli'}:
393 return
393 return
394
394
395 self.enable_win_unicode_console()
395 self.enable_win_unicode_console()
396
396
397 import colorama
397 import colorama
398 colorama.init()
398 colorama.init()
399
399
400 # For some reason we make these wrappers around stdout/stderr.
400 # For some reason we make these wrappers around stdout/stderr.
401 # For now, we need to reset them so all output gets coloured.
401 # For now, we need to reset them so all output gets coloured.
402 # https://github.com/ipython/ipython/issues/8669
402 # https://github.com/ipython/ipython/issues/8669
403 # io.std* are deprecated, but don't show our own deprecation warnings
403 # io.std* are deprecated, but don't show our own deprecation warnings
404 # during initialization of the deprecated API.
404 # during initialization of the deprecated API.
405 with warnings.catch_warnings():
405 with warnings.catch_warnings():
406 warnings.simplefilter('ignore', DeprecationWarning)
406 warnings.simplefilter('ignore', DeprecationWarning)
407 io.stdout = io.IOStream(sys.stdout)
407 io.stdout = io.IOStream(sys.stdout)
408 io.stderr = io.IOStream(sys.stderr)
408 io.stderr = io.IOStream(sys.stderr)
409
409
410 def init_magics(self):
410 def init_magics(self):
411 super(TerminalInteractiveShell, self).init_magics()
411 super(TerminalInteractiveShell, self).init_magics()
412 self.register_magics(TerminalMagics)
412 self.register_magics(TerminalMagics)
413
413
414 def init_alias(self):
414 def init_alias(self):
415 # The parent class defines aliases that can be safely used with any
415 # The parent class defines aliases that can be safely used with any
416 # frontend.
416 # frontend.
417 super(TerminalInteractiveShell, self).init_alias()
417 super(TerminalInteractiveShell, self).init_alias()
418
418
419 # Now define aliases that only make sense on the terminal, because they
419 # Now define aliases that only make sense on the terminal, because they
420 # need direct access to the console in a way that we can't emulate in
420 # need direct access to the console in a way that we can't emulate in
421 # GUI or web frontend
421 # GUI or web frontend
422 if os.name == 'posix':
422 if os.name == 'posix':
423 for cmd in ['clear', 'more', 'less', 'man']:
423 for cmd in ['clear', 'more', 'less', 'man']:
424 self.alias_manager.soft_define_alias(cmd, cmd)
424 self.alias_manager.soft_define_alias(cmd, cmd)
425
425
426
426
427 def __init__(self, *args, **kwargs):
427 def __init__(self, *args, **kwargs):
428 super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
428 super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
429 self.init_prompt_toolkit_cli()
429 self.init_prompt_toolkit_cli()
430 self.init_term_title()
430 self.init_term_title()
431 self.keep_running = True
431 self.keep_running = True
432
432
433 self.debugger_history = InMemoryHistory()
433 self.debugger_history = InMemoryHistory()
434
434
435 def ask_exit(self):
435 def ask_exit(self):
436 self.keep_running = False
436 self.keep_running = False
437
437
438 rl_next_input = None
438 rl_next_input = None
439
439
440 def pre_prompt(self):
440 def pre_prompt(self):
441 if self.rl_next_input:
441 if self.rl_next_input:
442 # We can't set the buffer here, because it will be reset just after
442 # We can't set the buffer here, because it will be reset just after
443 # this. Adding a callable to pre_run_callables does what we need
443 # this. Adding a callable to pre_run_callables does what we need
444 # after the buffer is reset.
444 # after the buffer is reset.
445 s = self.rl_next_input
445 s = self.rl_next_input
446 def set_doc():
446 def set_doc():
447 self.pt_cli.application.buffer.document = Document(s)
447 self.pt_cli.application.buffer.document = Document(s)
448 if hasattr(self.pt_cli, 'pre_run_callables'):
448 if hasattr(self.pt_cli, 'pre_run_callables'):
449 self.pt_cli.pre_run_callables.append(set_doc)
449 self.pt_cli.pre_run_callables.append(set_doc)
450 else:
450 else:
451 # Older version of prompt_toolkit; it's OK to set the document
451 # Older version of prompt_toolkit; it's OK to set the document
452 # directly here.
452 # directly here.
453 set_doc()
453 set_doc()
454 self.rl_next_input = None
454 self.rl_next_input = None
455
455
456 def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
456 def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
457
457
458 if display_banner is not DISPLAY_BANNER_DEPRECATED:
458 if display_banner is not DISPLAY_BANNER_DEPRECATED:
459 warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
459 warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
460
460
461 self.keep_running = True
461 self.keep_running = True
462 while self.keep_running:
462 while self.keep_running:
463 print(self.separate_in, end='')
463 print(self.separate_in, end='')
464
464
465 try:
465 try:
466 code = self.prompt_for_code()
466 code = self.prompt_for_code()
467 except EOFError:
467 except EOFError:
468 if (not self.confirm_exit) \
468 if (not self.confirm_exit) \
469 or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
469 or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
470 self.ask_exit()
470 self.ask_exit()
471
471
472 else:
472 else:
473 if code:
473 if code:
474 self.run_cell(code, store_history=True)
474 self.run_cell(code, store_history=True)
475
475
476 def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
476 def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
477 # An extra layer of protection in case someone mashing Ctrl-C breaks
477 # An extra layer of protection in case someone mashing Ctrl-C breaks
478 # out of our internal code.
478 # out of our internal code.
479 if display_banner is not DISPLAY_BANNER_DEPRECATED:
479 if display_banner is not DISPLAY_BANNER_DEPRECATED:
480 warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
480 warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
481 while True:
481 while True:
482 try:
482 try:
483 self.interact()
483 self.interact()
484 break
484 break
485 except KeyboardInterrupt as e:
485 except KeyboardInterrupt as e:
486 print("\n%s escaped interact()\n" % type(e).__name__)
486 print("\n%s escaped interact()\n" % type(e).__name__)
487 finally:
487 finally:
488 # An interrupt during the eventloop will mess up the
488 # An interrupt during the eventloop will mess up the
489 # internal state of the prompt_toolkit library.
489 # internal state of the prompt_toolkit library.
490 # Stopping the eventloop fixes this, see
490 # Stopping the eventloop fixes this, see
491 # https://github.com/ipython/ipython/pull/9867
491 # https://github.com/ipython/ipython/pull/9867
492 if hasattr(self, '_eventloop'):
492 if hasattr(self, '_eventloop'):
493 self._eventloop.stop()
493 self._eventloop.stop()
494
494
495 _inputhook = None
495 _inputhook = None
496 def inputhook(self, context):
496 def inputhook(self, context):
497 if self._inputhook is not None:
497 if self._inputhook is not None:
498 self._inputhook(context)
498 self._inputhook(context)
499
499
500 active_eventloop = None
500 active_eventloop = None
501 def enable_gui(self, gui=None):
501 def enable_gui(self, gui=None):
502 if gui:
502 if gui:
503 self.active_eventloop, self._inputhook =\
503 self.active_eventloop, self._inputhook =\
504 get_inputhook_name_and_func(gui)
504 get_inputhook_name_and_func(gui)
505 else:
505 else:
506 self.active_eventloop = self._inputhook = None
506 self.active_eventloop = self._inputhook = None
507
507
508 # Run !system commands directly, not through pipes, so terminal programs
508 # Run !system commands directly, not through pipes, so terminal programs
509 # work correctly.
509 # work correctly.
510 system = InteractiveShell.system_raw
510 system = InteractiveShell.system_raw
511
511
512 def auto_rewrite_input(self, cmd):
512 def auto_rewrite_input(self, cmd):
513 """Overridden from the parent class to use fancy rewriting prompt"""
513 """Overridden from the parent class to use fancy rewriting prompt"""
514 if not self.show_rewritten_input:
514 if not self.show_rewritten_input:
515 return
515 return
516
516
517 tokens = self.prompts.rewrite_prompt_tokens()
517 tokens = self.prompts.rewrite_prompt_tokens()
518 if self.pt_cli:
518 if self.pt_cli:
519 self.pt_cli.print_tokens(tokens)
519 self.pt_cli.print_tokens(tokens)
520 print(cmd)
520 print(cmd)
521 else:
521 else:
522 prompt = ''.join(s for t, s in tokens)
522 prompt = ''.join(s for t, s in tokens)
523 print(prompt, cmd, sep='')
523 print(prompt, cmd, sep='')
524
524
525 _prompts_before = None
525 _prompts_before = None
526 def switch_doctest_mode(self, mode):
526 def switch_doctest_mode(self, mode):
527 """Switch prompts to classic for %doctest_mode"""
527 """Switch prompts to classic for %doctest_mode"""
528 if mode:
528 if mode:
529 self._prompts_before = self.prompts
529 self._prompts_before = self.prompts
530 self.prompts = ClassicPrompts(self)
530 self.prompts = ClassicPrompts(self)
531 elif self._prompts_before:
531 elif self._prompts_before:
532 self.prompts = self._prompts_before
532 self.prompts = self._prompts_before
533 self._prompts_before = None
533 self._prompts_before = None
534 self._update_layout()
534 self._update_layout()
535
535
536
536
537 InteractiveShellABC.register(TerminalInteractiveShell)
537 InteractiveShellABC.register(TerminalInteractiveShell)
538
538
539 if __name__ == '__main__':
539 if __name__ == '__main__':
540 TerminalInteractiveShell.instance().interact()
540 TerminalInteractiveShell.instance().interact()
@@ -1,378 +1,376
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Decorators for labeling test objects.
2 """Decorators for labeling test objects.
3
3
4 Decorators that merely return a modified version of the original function
4 Decorators that merely return a modified version of the original function
5 object are straightforward. Decorators that return a new function object need
5 object are straightforward. Decorators that return a new function object need
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
7 decorator, in order to preserve metadata such as function name, setup and
7 decorator, in order to preserve metadata such as function name, setup and
8 teardown functions and so on - see nose.tools for more information.
8 teardown functions and so on - see nose.tools for more information.
9
9
10 This module provides a set of useful decorators meant to be ready to use in
10 This module provides a set of useful decorators meant to be ready to use in
11 your own tests. See the bottom of the file for the ready-made ones, and if you
11 your own tests. See the bottom of the file for the ready-made ones, and if you
12 find yourself writing a new one that may be of generic use, add it here.
12 find yourself writing a new one that may be of generic use, add it here.
13
13
14 Included decorators:
14 Included decorators:
15
15
16
16
17 Lightweight testing that remains unittest-compatible.
17 Lightweight testing that remains unittest-compatible.
18
18
19 - An @as_unittest decorator can be used to tag any normal parameter-less
19 - An @as_unittest decorator can be used to tag any normal parameter-less
20 function as a unittest TestCase. Then, both nose and normal unittest will
20 function as a unittest TestCase. Then, both nose and normal unittest will
21 recognize it as such. This will make it easier to migrate away from Nose if
21 recognize it as such. This will make it easier to migrate away from Nose if
22 we ever need/want to while maintaining very lightweight tests.
22 we ever need/want to while maintaining very lightweight tests.
23
23
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
26 available, OR use equivalent code in IPython.external._decorators, which
26 available, OR use equivalent code in IPython.external._decorators, which
27 we've copied verbatim from numpy.
27 we've copied verbatim from numpy.
28
28
29 """
29 """
30
30
31 # Copyright (c) IPython Development Team.
31 # Copyright (c) IPython Development Team.
32 # Distributed under the terms of the Modified BSD License.
32 # Distributed under the terms of the Modified BSD License.
33
33
34 import sys
35 import os
34 import os
35 import shutil
36 import sys
36 import tempfile
37 import tempfile
37 import unittest
38 import unittest
38 import warnings
39 import warnings
39 from importlib import import_module
40 from importlib import import_module
40
41
41 from decorator import decorator
42 from decorator import decorator
42
43
43 # Expose the unittest-driven decorators
44 # Expose the unittest-driven decorators
44 from .ipunittest import ipdoctest, ipdocstring
45 from .ipunittest import ipdoctest, ipdocstring
45
46
46 # Grab the numpy-specific decorators which we keep in a file that we
47 # Grab the numpy-specific decorators which we keep in a file that we
47 # occasionally update from upstream: decorators.py is a copy of
48 # occasionally update from upstream: decorators.py is a copy of
48 # numpy.testing.decorators, we expose all of it here.
49 # numpy.testing.decorators, we expose all of it here.
49 from IPython.external.decorators import *
50 from IPython.external.decorators import *
50
51
51 # For onlyif_cmd_exists decorator
52 from IPython.utils.py3compat import which
53
54 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
55 # Classes and functions
53 # Classes and functions
56 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
57
55
58 # Simple example of the basic idea
56 # Simple example of the basic idea
59 def as_unittest(func):
57 def as_unittest(func):
60 """Decorator to make a simple function into a normal test via unittest."""
58 """Decorator to make a simple function into a normal test via unittest."""
61 class Tester(unittest.TestCase):
59 class Tester(unittest.TestCase):
62 def test(self):
60 def test(self):
63 func()
61 func()
64
62
65 Tester.__name__ = func.__name__
63 Tester.__name__ = func.__name__
66
64
67 return Tester
65 return Tester
68
66
69 # Utility functions
67 # Utility functions
70
68
71 def apply_wrapper(wrapper, func):
69 def apply_wrapper(wrapper, func):
72 """Apply a wrapper to a function for decoration.
70 """Apply a wrapper to a function for decoration.
73
71
74 This mixes Michele Simionato's decorator tool with nose's make_decorator,
72 This mixes Michele Simionato's decorator tool with nose's make_decorator,
75 to apply a wrapper in a decorator so that all nose attributes, as well as
73 to apply a wrapper in a decorator so that all nose attributes, as well as
76 function signature and other properties, survive the decoration cleanly.
74 function signature and other properties, survive the decoration cleanly.
77 This will ensure that wrapped functions can still be well introspected via
75 This will ensure that wrapped functions can still be well introspected via
78 IPython, for example.
76 IPython, for example.
79 """
77 """
80 warnings.warn("The function `apply_wrapper` is deprecated since IPython 4.0",
78 warnings.warn("The function `apply_wrapper` is deprecated since IPython 4.0",
81 DeprecationWarning, stacklevel=2)
79 DeprecationWarning, stacklevel=2)
82 import nose.tools
80 import nose.tools
83
81
84 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
82 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
85
83
86
84
87 def make_label_dec(label, ds=None):
85 def make_label_dec(label, ds=None):
88 """Factory function to create a decorator that applies one or more labels.
86 """Factory function to create a decorator that applies one or more labels.
89
87
90 Parameters
88 Parameters
91 ----------
89 ----------
92 label : string or sequence
90 label : string or sequence
93 One or more labels that will be applied by the decorator to the functions
91 One or more labels that will be applied by the decorator to the functions
94 it decorates. Labels are attributes of the decorated function with their
92 it decorates. Labels are attributes of the decorated function with their
95 value set to True.
93 value set to True.
96
94
97 ds : string
95 ds : string
98 An optional docstring for the resulting decorator. If not given, a
96 An optional docstring for the resulting decorator. If not given, a
99 default docstring is auto-generated.
97 default docstring is auto-generated.
100
98
101 Returns
99 Returns
102 -------
100 -------
103 A decorator.
101 A decorator.
104
102
105 Examples
103 Examples
106 --------
104 --------
107
105
108 A simple labeling decorator:
106 A simple labeling decorator:
109
107
110 >>> slow = make_label_dec('slow')
108 >>> slow = make_label_dec('slow')
111 >>> slow.__doc__
109 >>> slow.__doc__
112 "Labels a test as 'slow'."
110 "Labels a test as 'slow'."
113
111
114 And one that uses multiple labels and a custom docstring:
112 And one that uses multiple labels and a custom docstring:
115
113
116 >>> rare = make_label_dec(['slow','hard'],
114 >>> rare = make_label_dec(['slow','hard'],
117 ... "Mix labels 'slow' and 'hard' for rare tests.")
115 ... "Mix labels 'slow' and 'hard' for rare tests.")
118 >>> rare.__doc__
116 >>> rare.__doc__
119 "Mix labels 'slow' and 'hard' for rare tests."
117 "Mix labels 'slow' and 'hard' for rare tests."
120
118
121 Now, let's test using this one:
119 Now, let's test using this one:
122 >>> @rare
120 >>> @rare
123 ... def f(): pass
121 ... def f(): pass
124 ...
122 ...
125 >>>
123 >>>
126 >>> f.slow
124 >>> f.slow
127 True
125 True
128 >>> f.hard
126 >>> f.hard
129 True
127 True
130 """
128 """
131
129
132 warnings.warn("The function `make_label_dec` is deprecated since IPython 4.0",
130 warnings.warn("The function `make_label_dec` is deprecated since IPython 4.0",
133 DeprecationWarning, stacklevel=2)
131 DeprecationWarning, stacklevel=2)
134 if isinstance(label, str):
132 if isinstance(label, str):
135 labels = [label]
133 labels = [label]
136 else:
134 else:
137 labels = label
135 labels = label
138
136
139 # Validate that the given label(s) are OK for use in setattr() by doing a
137 # Validate that the given label(s) are OK for use in setattr() by doing a
140 # dry run on a dummy function.
138 # dry run on a dummy function.
141 tmp = lambda : None
139 tmp = lambda : None
142 for label in labels:
140 for label in labels:
143 setattr(tmp,label,True)
141 setattr(tmp,label,True)
144
142
145 # This is the actual decorator we'll return
143 # This is the actual decorator we'll return
146 def decor(f):
144 def decor(f):
147 for label in labels:
145 for label in labels:
148 setattr(f,label,True)
146 setattr(f,label,True)
149 return f
147 return f
150
148
151 # Apply the user's docstring, or autogenerate a basic one
149 # Apply the user's docstring, or autogenerate a basic one
152 if ds is None:
150 if ds is None:
153 ds = "Labels a test as %r." % label
151 ds = "Labels a test as %r." % label
154 decor.__doc__ = ds
152 decor.__doc__ = ds
155
153
156 return decor
154 return decor
157
155
158
156
159 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
157 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
160 # preserve function metadata better and allows the skip condition to be a
158 # preserve function metadata better and allows the skip condition to be a
161 # callable.
159 # callable.
162 def skipif(skip_condition, msg=None):
160 def skipif(skip_condition, msg=None):
163 ''' Make function raise SkipTest exception if skip_condition is true
161 ''' Make function raise SkipTest exception if skip_condition is true
164
162
165 Parameters
163 Parameters
166 ----------
164 ----------
167
165
168 skip_condition : bool or callable
166 skip_condition : bool or callable
169 Flag to determine whether to skip test. If the condition is a
167 Flag to determine whether to skip test. If the condition is a
170 callable, it is used at runtime to dynamically make the decision. This
168 callable, it is used at runtime to dynamically make the decision. This
171 is useful for tests that may require costly imports, to delay the cost
169 is useful for tests that may require costly imports, to delay the cost
172 until the test suite is actually executed.
170 until the test suite is actually executed.
173 msg : string
171 msg : string
174 Message to give on raising a SkipTest exception.
172 Message to give on raising a SkipTest exception.
175
173
176 Returns
174 Returns
177 -------
175 -------
178 decorator : function
176 decorator : function
179 Decorator, which, when applied to a function, causes SkipTest
177 Decorator, which, when applied to a function, causes SkipTest
180 to be raised when the skip_condition was True, and the function
178 to be raised when the skip_condition was True, and the function
181 to be called normally otherwise.
179 to be called normally otherwise.
182
180
183 Notes
181 Notes
184 -----
182 -----
185 You will see from the code that we had to further decorate the
183 You will see from the code that we had to further decorate the
186 decorator with the nose.tools.make_decorator function in order to
184 decorator with the nose.tools.make_decorator function in order to
187 transmit function name, and various other metadata.
185 transmit function name, and various other metadata.
188 '''
186 '''
189
187
190 def skip_decorator(f):
188 def skip_decorator(f):
191 # Local import to avoid a hard nose dependency and only incur the
189 # Local import to avoid a hard nose dependency and only incur the
192 # import time overhead at actual test-time.
190 # import time overhead at actual test-time.
193 import nose
191 import nose
194
192
195 # Allow for both boolean or callable skip conditions.
193 # Allow for both boolean or callable skip conditions.
196 if callable(skip_condition):
194 if callable(skip_condition):
197 skip_val = skip_condition
195 skip_val = skip_condition
198 else:
196 else:
199 skip_val = lambda : skip_condition
197 skip_val = lambda : skip_condition
200
198
201 def get_msg(func,msg=None):
199 def get_msg(func,msg=None):
202 """Skip message with information about function being skipped."""
200 """Skip message with information about function being skipped."""
203 if msg is None: out = 'Test skipped due to test condition.'
201 if msg is None: out = 'Test skipped due to test condition.'
204 else: out = msg
202 else: out = msg
205 return "Skipping test: %s. %s" % (func.__name__,out)
203 return "Skipping test: %s. %s" % (func.__name__,out)
206
204
207 # We need to define *two* skippers because Python doesn't allow both
205 # We need to define *two* skippers because Python doesn't allow both
208 # return with value and yield inside the same function.
206 # return with value and yield inside the same function.
209 def skipper_func(*args, **kwargs):
207 def skipper_func(*args, **kwargs):
210 """Skipper for normal test functions."""
208 """Skipper for normal test functions."""
211 if skip_val():
209 if skip_val():
212 raise nose.SkipTest(get_msg(f,msg))
210 raise nose.SkipTest(get_msg(f,msg))
213 else:
211 else:
214 return f(*args, **kwargs)
212 return f(*args, **kwargs)
215
213
216 def skipper_gen(*args, **kwargs):
214 def skipper_gen(*args, **kwargs):
217 """Skipper for test generators."""
215 """Skipper for test generators."""
218 if skip_val():
216 if skip_val():
219 raise nose.SkipTest(get_msg(f,msg))
217 raise nose.SkipTest(get_msg(f,msg))
220 else:
218 else:
221 for x in f(*args, **kwargs):
219 for x in f(*args, **kwargs):
222 yield x
220 yield x
223
221
224 # Choose the right skipper to use when building the actual generator.
222 # Choose the right skipper to use when building the actual generator.
225 if nose.util.isgenerator(f):
223 if nose.util.isgenerator(f):
226 skipper = skipper_gen
224 skipper = skipper_gen
227 else:
225 else:
228 skipper = skipper_func
226 skipper = skipper_func
229
227
230 return nose.tools.make_decorator(f)(skipper)
228 return nose.tools.make_decorator(f)(skipper)
231
229
232 return skip_decorator
230 return skip_decorator
233
231
234 # A version with the condition set to true, common case just to attach a message
232 # A version with the condition set to true, common case just to attach a message
235 # to a skip decorator
233 # to a skip decorator
236 def skip(msg=None):
234 def skip(msg=None):
237 """Decorator factory - mark a test function for skipping from test suite.
235 """Decorator factory - mark a test function for skipping from test suite.
238
236
239 Parameters
237 Parameters
240 ----------
238 ----------
241 msg : string
239 msg : string
242 Optional message to be added.
240 Optional message to be added.
243
241
244 Returns
242 Returns
245 -------
243 -------
246 decorator : function
244 decorator : function
247 Decorator, which, when applied to a function, causes SkipTest
245 Decorator, which, when applied to a function, causes SkipTest
248 to be raised, with the optional message added.
246 to be raised, with the optional message added.
249 """
247 """
250
248
251 return skipif(True,msg)
249 return skipif(True,msg)
252
250
253
251
254 def onlyif(condition, msg):
252 def onlyif(condition, msg):
255 """The reverse from skipif, see skipif for details."""
253 """The reverse from skipif, see skipif for details."""
256
254
257 if callable(condition):
255 if callable(condition):
258 skip_condition = lambda : not condition()
256 skip_condition = lambda : not condition()
259 else:
257 else:
260 skip_condition = lambda : not condition
258 skip_condition = lambda : not condition
261
259
262 return skipif(skip_condition, msg)
260 return skipif(skip_condition, msg)
263
261
264 #-----------------------------------------------------------------------------
262 #-----------------------------------------------------------------------------
265 # Utility functions for decorators
263 # Utility functions for decorators
266 def module_not_available(module):
264 def module_not_available(module):
267 """Can module be imported? Returns true if module does NOT import.
265 """Can module be imported? Returns true if module does NOT import.
268
266
269 This is used to make a decorator to skip tests that require module to be
267 This is used to make a decorator to skip tests that require module to be
270 available, but delay the 'import numpy' to test execution time.
268 available, but delay the 'import numpy' to test execution time.
271 """
269 """
272 try:
270 try:
273 mod = import_module(module)
271 mod = import_module(module)
274 mod_not_avail = False
272 mod_not_avail = False
275 except ImportError:
273 except ImportError:
276 mod_not_avail = True
274 mod_not_avail = True
277
275
278 return mod_not_avail
276 return mod_not_avail
279
277
280
278
281 def decorated_dummy(dec, name):
279 def decorated_dummy(dec, name):
282 """Return a dummy function decorated with dec, with the given name.
280 """Return a dummy function decorated with dec, with the given name.
283
281
284 Examples
282 Examples
285 --------
283 --------
286 import IPython.testing.decorators as dec
284 import IPython.testing.decorators as dec
287 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
285 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
288 """
286 """
289 warnings.warn("The function `decorated_dummy` is deprecated since IPython 4.0",
287 warnings.warn("The function `decorated_dummy` is deprecated since IPython 4.0",
290 DeprecationWarning, stacklevel=2)
288 DeprecationWarning, stacklevel=2)
291 dummy = lambda: None
289 dummy = lambda: None
292 dummy.__name__ = name
290 dummy.__name__ = name
293 return dec(dummy)
291 return dec(dummy)
294
292
295 #-----------------------------------------------------------------------------
293 #-----------------------------------------------------------------------------
296 # Decorators for public use
294 # Decorators for public use
297
295
298 # Decorators to skip certain tests on specific platforms.
296 # Decorators to skip certain tests on specific platforms.
299 skip_win32 = skipif(sys.platform == 'win32',
297 skip_win32 = skipif(sys.platform == 'win32',
300 "This test does not run under Windows")
298 "This test does not run under Windows")
301 skip_linux = skipif(sys.platform.startswith('linux'),
299 skip_linux = skipif(sys.platform.startswith('linux'),
302 "This test does not run under Linux")
300 "This test does not run under Linux")
303 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
301 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
304
302
305
303
306 # Decorators to skip tests if not on specific platforms.
304 # Decorators to skip tests if not on specific platforms.
307 skip_if_not_win32 = skipif(sys.platform != 'win32',
305 skip_if_not_win32 = skipif(sys.platform != 'win32',
308 "This test only runs under Windows")
306 "This test only runs under Windows")
309 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
307 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
310 "This test only runs under Linux")
308 "This test only runs under Linux")
311 skip_if_not_osx = skipif(sys.platform != 'darwin',
309 skip_if_not_osx = skipif(sys.platform != 'darwin',
312 "This test only runs under OSX")
310 "This test only runs under OSX")
313
311
314
312
315 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
313 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
316 os.environ.get('DISPLAY', '') == '')
314 os.environ.get('DISPLAY', '') == '')
317 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
315 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
318
316
319 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
317 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
320
318
321 # not a decorator itself, returns a dummy function to be used as setup
319 # not a decorator itself, returns a dummy function to be used as setup
322 def skip_file_no_x11(name):
320 def skip_file_no_x11(name):
323 warnings.warn("The function `skip_file_no_x11` is deprecated since IPython 4.0",
321 warnings.warn("The function `skip_file_no_x11` is deprecated since IPython 4.0",
324 DeprecationWarning, stacklevel=2)
322 DeprecationWarning, stacklevel=2)
325 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
323 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
326
324
327 # Other skip decorators
325 # Other skip decorators
328
326
329 # generic skip without module
327 # generic skip without module
330 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
328 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
331
329
332 skipif_not_numpy = skip_without('numpy')
330 skipif_not_numpy = skip_without('numpy')
333
331
334 skipif_not_matplotlib = skip_without('matplotlib')
332 skipif_not_matplotlib = skip_without('matplotlib')
335
333
336 skipif_not_sympy = skip_without('sympy')
334 skipif_not_sympy = skip_without('sympy')
337
335
338 skip_known_failure = knownfailureif(True,'This test is known to fail')
336 skip_known_failure = knownfailureif(True,'This test is known to fail')
339
337
340 # A null 'decorator', useful to make more readable code that needs to pick
338 # A null 'decorator', useful to make more readable code that needs to pick
341 # between different decorators based on OS or other conditions
339 # between different decorators based on OS or other conditions
342 null_deco = lambda f: f
340 null_deco = lambda f: f
343
341
344 # Some tests only run where we can use unicode paths. Note that we can't just
342 # Some tests only run where we can use unicode paths. Note that we can't just
345 # check os.path.supports_unicode_filenames, which is always False on Linux.
343 # check os.path.supports_unicode_filenames, which is always False on Linux.
346 try:
344 try:
347 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
345 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
348 except UnicodeEncodeError:
346 except UnicodeEncodeError:
349 unicode_paths = False
347 unicode_paths = False
350 else:
348 else:
351 unicode_paths = True
349 unicode_paths = True
352 f.close()
350 f.close()
353
351
354 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
352 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
355 "where we can use unicode in filenames."))
353 "where we can use unicode in filenames."))
356
354
357
355
358 def onlyif_cmds_exist(*commands):
356 def onlyif_cmds_exist(*commands):
359 """
357 """
360 Decorator to skip test when at least one of `commands` is not found.
358 Decorator to skip test when at least one of `commands` is not found.
361 """
359 """
362 for cmd in commands:
360 for cmd in commands:
363 if not which(cmd):
361 if not shutil.which(cmd):
364 return skip("This test runs only if command '{0}' "
362 return skip("This test runs only if command '{0}' "
365 "is installed".format(cmd))
363 "is installed".format(cmd))
366 return null_deco
364 return null_deco
367
365
368 def onlyif_any_cmd_exists(*commands):
366 def onlyif_any_cmd_exists(*commands):
369 """
367 """
370 Decorator to skip test unless at least one of `commands` is found.
368 Decorator to skip test unless at least one of `commands` is found.
371 """
369 """
372 warnings.warn("The function `onlyif_any_cmd_exists` is deprecated since IPython 4.0",
370 warnings.warn("The function `onlyif_any_cmd_exists` is deprecated since IPython 4.0",
373 DeprecationWarning, stacklevel=2)
371 DeprecationWarning, stacklevel=2)
374 for cmd in commands:
372 for cmd in commands:
375 if which(cmd):
373 if shutil.which(cmd):
376 return null_deco
374 return null_deco
377 return skip("This test runs only if one of the commands {0} "
375 return skip("This test runs only if one of the commands {0} "
378 "is installed".format(commands))
376 "is installed".format(commands))
@@ -1,136 +1,136
1 """Global IPython app to support test running.
1 """Global IPython app to support test running.
2
2
3 We must start our own ipython object and heavily muck with it so that all the
3 We must start our own ipython object and heavily muck with it so that all the
4 modifications IPython makes to system behavior don't send the doctest machinery
4 modifications IPython makes to system behavior don't send the doctest machinery
5 into a fit. This code should be considered a gross hack, but it gets the job
5 into a fit. This code should be considered a gross hack, but it gets the job
6 done.
6 done.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import builtins as builtin_mod
12 import builtins as builtin_mod
13 import sys
13 import sys
14 import types
14 import warnings
15 import warnings
15
16
16 from . import tools
17 from . import tools
17
18
18 from IPython.core import page
19 from IPython.core import page
19 from IPython.utils import io
20 from IPython.utils import io
20 from IPython.utils import py3compat
21 from IPython.terminal.interactiveshell import TerminalInteractiveShell
21 from IPython.terminal.interactiveshell import TerminalInteractiveShell
22
22
23
23
24 class StreamProxy(io.IOStream):
24 class StreamProxy(io.IOStream):
25 """Proxy for sys.stdout/err. This will request the stream *at call time*
25 """Proxy for sys.stdout/err. This will request the stream *at call time*
26 allowing for nose's Capture plugin's redirection of sys.stdout/err.
26 allowing for nose's Capture plugin's redirection of sys.stdout/err.
27
27
28 Parameters
28 Parameters
29 ----------
29 ----------
30 name : str
30 name : str
31 The name of the stream. This will be requested anew at every call
31 The name of the stream. This will be requested anew at every call
32 """
32 """
33
33
34 def __init__(self, name):
34 def __init__(self, name):
35 warnings.warn("StreamProxy is deprecated and unused as of IPython 5", DeprecationWarning,
35 warnings.warn("StreamProxy is deprecated and unused as of IPython 5", DeprecationWarning,
36 stacklevel=2,
36 stacklevel=2,
37 )
37 )
38 self.name=name
38 self.name=name
39
39
40 @property
40 @property
41 def stream(self):
41 def stream(self):
42 return getattr(sys, self.name)
42 return getattr(sys, self.name)
43
43
44 def flush(self):
44 def flush(self):
45 self.stream.flush()
45 self.stream.flush()
46
46
47
47
48 def get_ipython():
48 def get_ipython():
49 # This will get replaced by the real thing once we start IPython below
49 # This will get replaced by the real thing once we start IPython below
50 return start_ipython()
50 return start_ipython()
51
51
52
52
53 # A couple of methods to override those in the running IPython to interact
53 # A couple of methods to override those in the running IPython to interact
54 # better with doctest (doctest captures on raw stdout, so we need to direct
54 # better with doctest (doctest captures on raw stdout, so we need to direct
55 # various types of output there otherwise it will miss them).
55 # various types of output there otherwise it will miss them).
56
56
57 def xsys(self, cmd):
57 def xsys(self, cmd):
58 """Replace the default system call with a capturing one for doctest.
58 """Replace the default system call with a capturing one for doctest.
59 """
59 """
60 # We use getoutput, but we need to strip it because pexpect captures
60 # We use getoutput, but we need to strip it because pexpect captures
61 # the trailing newline differently from commands.getoutput
61 # the trailing newline differently from commands.getoutput
62 print(self.getoutput(cmd, split=False, depth=1).rstrip(), end='', file=sys.stdout)
62 print(self.getoutput(cmd, split=False, depth=1).rstrip(), end='', file=sys.stdout)
63 sys.stdout.flush()
63 sys.stdout.flush()
64
64
65
65
66 def _showtraceback(self, etype, evalue, stb):
66 def _showtraceback(self, etype, evalue, stb):
67 """Print the traceback purely on stdout for doctest to capture it.
67 """Print the traceback purely on stdout for doctest to capture it.
68 """
68 """
69 print(self.InteractiveTB.stb2text(stb), file=sys.stdout)
69 print(self.InteractiveTB.stb2text(stb), file=sys.stdout)
70
70
71
71
72 def start_ipython():
72 def start_ipython():
73 """Start a global IPython shell, which we need for IPython-specific syntax.
73 """Start a global IPython shell, which we need for IPython-specific syntax.
74 """
74 """
75 global get_ipython
75 global get_ipython
76
76
77 # This function should only ever run once!
77 # This function should only ever run once!
78 if hasattr(start_ipython, 'already_called'):
78 if hasattr(start_ipython, 'already_called'):
79 return
79 return
80 start_ipython.already_called = True
80 start_ipython.already_called = True
81
81
82 # Store certain global objects that IPython modifies
82 # Store certain global objects that IPython modifies
83 _displayhook = sys.displayhook
83 _displayhook = sys.displayhook
84 _excepthook = sys.excepthook
84 _excepthook = sys.excepthook
85 _main = sys.modules.get('__main__')
85 _main = sys.modules.get('__main__')
86
86
87 # Create custom argv and namespaces for our IPython to be test-friendly
87 # Create custom argv and namespaces for our IPython to be test-friendly
88 config = tools.default_config()
88 config = tools.default_config()
89 config.TerminalInteractiveShell.simple_prompt = True
89 config.TerminalInteractiveShell.simple_prompt = True
90
90
91 # Create and initialize our test-friendly IPython instance.
91 # Create and initialize our test-friendly IPython instance.
92 shell = TerminalInteractiveShell.instance(config=config,
92 shell = TerminalInteractiveShell.instance(config=config,
93 )
93 )
94
94
95 # A few more tweaks needed for playing nicely with doctests...
95 # A few more tweaks needed for playing nicely with doctests...
96
96
97 # remove history file
97 # remove history file
98 shell.tempfiles.append(config.HistoryManager.hist_file)
98 shell.tempfiles.append(config.HistoryManager.hist_file)
99
99
100 # These traps are normally only active for interactive use, set them
100 # These traps are normally only active for interactive use, set them
101 # permanently since we'll be mocking interactive sessions.
101 # permanently since we'll be mocking interactive sessions.
102 shell.builtin_trap.activate()
102 shell.builtin_trap.activate()
103
103
104 # Modify the IPython system call with one that uses getoutput, so that we
104 # Modify the IPython system call with one that uses getoutput, so that we
105 # can capture subcommands and print them to Python's stdout, otherwise the
105 # can capture subcommands and print them to Python's stdout, otherwise the
106 # doctest machinery would miss them.
106 # doctest machinery would miss them.
107 shell.system = py3compat.MethodType(xsys, shell)
107 shell.system = types.MethodType(xsys, shell)
108
108
109 shell._showtraceback = py3compat.MethodType(_showtraceback, shell)
109 shell._showtraceback = types.MethodType(_showtraceback, shell)
110
110
111 # IPython is ready, now clean up some global state...
111 # IPython is ready, now clean up some global state...
112
112
113 # Deactivate the various python system hooks added by ipython for
113 # Deactivate the various python system hooks added by ipython for
114 # interactive convenience so we don't confuse the doctest system
114 # interactive convenience so we don't confuse the doctest system
115 sys.modules['__main__'] = _main
115 sys.modules['__main__'] = _main
116 sys.displayhook = _displayhook
116 sys.displayhook = _displayhook
117 sys.excepthook = _excepthook
117 sys.excepthook = _excepthook
118
118
119 # So that ipython magics and aliases can be doctested (they work by making
119 # So that ipython magics and aliases can be doctested (they work by making
120 # a call into a global _ip object). Also make the top-level get_ipython
120 # a call into a global _ip object). Also make the top-level get_ipython
121 # now return this without recursively calling here again.
121 # now return this without recursively calling here again.
122 _ip = shell
122 _ip = shell
123 get_ipython = _ip.get_ipython
123 get_ipython = _ip.get_ipython
124 builtin_mod._ip = _ip
124 builtin_mod._ip = _ip
125 builtin_mod.get_ipython = get_ipython
125 builtin_mod.get_ipython = get_ipython
126
126
127 # Override paging, so we don't require user interaction during the tests.
127 # Override paging, so we don't require user interaction during the tests.
128 def nopage(strng, start=0, screen_lines=0, pager_cmd=None):
128 def nopage(strng, start=0, screen_lines=0, pager_cmd=None):
129 if isinstance(strng, dict):
129 if isinstance(strng, dict):
130 strng = strng.get('text/plain', '')
130 strng = strng.get('text/plain', '')
131 print(strng)
131 print(strng)
132
132
133 page.orig_page = page.pager_page
133 page.orig_page = page.pager_page
134 page.pager_page = nopage
134 page.pager_page = nopage
135
135
136 return _ip
136 return _ip
@@ -1,454 +1,454
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """IPython Test Suite Runner.
2 """IPython Test Suite Runner.
3
3
4 This module provides a main entry point to a user script to test IPython
4 This module provides a main entry point to a user script to test IPython
5 itself from the command line. There are two ways of running this script:
5 itself from the command line. There are two ways of running this script:
6
6
7 1. With the syntax `iptest all`. This runs our entire test suite by
7 1. With the syntax `iptest all`. This runs our entire test suite by
8 calling this script (with different arguments) recursively. This
8 calling this script (with different arguments) recursively. This
9 causes modules and package to be tested in different processes, using nose
9 causes modules and package to be tested in different processes, using nose
10 or trial where appropriate.
10 or trial where appropriate.
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
12 the script simply calls nose, but with special command line flags and
12 the script simply calls nose, but with special command line flags and
13 plugins loaded.
13 plugins loaded.
14
14
15 """
15 """
16
16
17 # Copyright (c) IPython Development Team.
17 # Copyright (c) IPython Development Team.
18 # Distributed under the terms of the Modified BSD License.
18 # Distributed under the terms of the Modified BSD License.
19
19
20
20
21 import glob
21 import glob
22 from io import BytesIO
22 from io import BytesIO
23 import os
23 import os
24 import os.path as path
24 import os.path as path
25 import sys
25 import sys
26 from threading import Thread, Lock, Event
26 from threading import Thread, Lock, Event
27 import warnings
27 import warnings
28
28
29 import nose.plugins.builtin
29 import nose.plugins.builtin
30 from nose.plugins.xunit import Xunit
30 from nose.plugins.xunit import Xunit
31 from nose import SkipTest
31 from nose import SkipTest
32 from nose.core import TestProgram
32 from nose.core import TestProgram
33 from nose.plugins import Plugin
33 from nose.plugins import Plugin
34 from nose.util import safe_str
34 from nose.util import safe_str
35
35
36 from IPython import version_info
36 from IPython import version_info
37 from IPython.utils.py3compat import bytes_to_str
37 from IPython.utils.py3compat import decode
38 from IPython.utils.importstring import import_item
38 from IPython.utils.importstring import import_item
39 from IPython.testing.plugin.ipdoctest import IPythonDoctest
39 from IPython.testing.plugin.ipdoctest import IPythonDoctest
40 from IPython.external.decorators import KnownFailure, knownfailureif
40 from IPython.external.decorators import KnownFailure, knownfailureif
41
41
42 pjoin = path.join
42 pjoin = path.join
43
43
44
44
45 # Enable printing all warnings raise by IPython's modules
45 # Enable printing all warnings raise by IPython's modules
46 warnings.filterwarnings('ignore', message='.*Matplotlib is building the font cache.*', category=UserWarning, module='.*')
46 warnings.filterwarnings('ignore', message='.*Matplotlib is building the font cache.*', category=UserWarning, module='.*')
47 warnings.filterwarnings('error', message='.*', category=ResourceWarning, module='.*')
47 warnings.filterwarnings('error', message='.*', category=ResourceWarning, module='.*')
48 warnings.filterwarnings('error', message=".*{'config': True}.*", category=DeprecationWarning, module='IPy.*')
48 warnings.filterwarnings('error', message=".*{'config': True}.*", category=DeprecationWarning, module='IPy.*')
49 warnings.filterwarnings('default', message='.*', category=Warning, module='IPy.*')
49 warnings.filterwarnings('default', message='.*', category=Warning, module='IPy.*')
50
50
51 warnings.filterwarnings('error', message='.*apply_wrapper.*', category=DeprecationWarning, module='.*')
51 warnings.filterwarnings('error', message='.*apply_wrapper.*', category=DeprecationWarning, module='.*')
52 warnings.filterwarnings('error', message='.*make_label_dec', category=DeprecationWarning, module='.*')
52 warnings.filterwarnings('error', message='.*make_label_dec', category=DeprecationWarning, module='.*')
53 warnings.filterwarnings('error', message='.*decorated_dummy.*', category=DeprecationWarning, module='.*')
53 warnings.filterwarnings('error', message='.*decorated_dummy.*', category=DeprecationWarning, module='.*')
54 warnings.filterwarnings('error', message='.*skip_file_no_x11.*', category=DeprecationWarning, module='.*')
54 warnings.filterwarnings('error', message='.*skip_file_no_x11.*', category=DeprecationWarning, module='.*')
55 warnings.filterwarnings('error', message='.*onlyif_any_cmd_exists.*', category=DeprecationWarning, module='.*')
55 warnings.filterwarnings('error', message='.*onlyif_any_cmd_exists.*', category=DeprecationWarning, module='.*')
56
56
57 warnings.filterwarnings('error', message='.*disable_gui.*', category=DeprecationWarning, module='.*')
57 warnings.filterwarnings('error', message='.*disable_gui.*', category=DeprecationWarning, module='.*')
58
58
59 warnings.filterwarnings('error', message='.*ExceptionColors global is deprecated.*', category=DeprecationWarning, module='.*')
59 warnings.filterwarnings('error', message='.*ExceptionColors global is deprecated.*', category=DeprecationWarning, module='.*')
60
60
61 # Jedi older versions
61 # Jedi older versions
62 warnings.filterwarnings(
62 warnings.filterwarnings(
63 'error', message='.*elementwise != comparison failed and.*', category=FutureWarning, module='.*')
63 'error', message='.*elementwise != comparison failed and.*', category=FutureWarning, module='.*')
64
64
65 if version_info < (6,):
65 if version_info < (6,):
66 # nose.tools renames all things from `camelCase` to `snake_case` which raise an
66 # nose.tools renames all things from `camelCase` to `snake_case` which raise an
67 # warning with the runner they also import from standard import library. (as of Dec 2015)
67 # warning with the runner they also import from standard import library. (as of Dec 2015)
68 # Ignore, let's revisit that in a couple of years for IPython 6.
68 # Ignore, let's revisit that in a couple of years for IPython 6.
69 warnings.filterwarnings(
69 warnings.filterwarnings(
70 'ignore', message='.*Please use assertEqual instead', category=Warning, module='IPython.*')
70 'ignore', message='.*Please use assertEqual instead', category=Warning, module='IPython.*')
71
71
72 if version_info < (7,):
72 if version_info < (7,):
73 warnings.filterwarnings('ignore', message='.*Completer.complete.*',
73 warnings.filterwarnings('ignore', message='.*Completer.complete.*',
74 category=PendingDeprecationWarning, module='.*')
74 category=PendingDeprecationWarning, module='.*')
75 else:
75 else:
76 warnings.warn(
76 warnings.warn(
77 'Completer.complete was pending deprecation and should be changed to Deprecated', FutureWarning)
77 'Completer.complete was pending deprecation and should be changed to Deprecated', FutureWarning)
78
78
79
79
80
80
81 # ------------------------------------------------------------------------------
81 # ------------------------------------------------------------------------------
82 # Monkeypatch Xunit to count known failures as skipped.
82 # Monkeypatch Xunit to count known failures as skipped.
83 # ------------------------------------------------------------------------------
83 # ------------------------------------------------------------------------------
84 def monkeypatch_xunit():
84 def monkeypatch_xunit():
85 try:
85 try:
86 knownfailureif(True)(lambda: None)()
86 knownfailureif(True)(lambda: None)()
87 except Exception as e:
87 except Exception as e:
88 KnownFailureTest = type(e)
88 KnownFailureTest = type(e)
89
89
90 def addError(self, test, err, capt=None):
90 def addError(self, test, err, capt=None):
91 if issubclass(err[0], KnownFailureTest):
91 if issubclass(err[0], KnownFailureTest):
92 err = (SkipTest,) + err[1:]
92 err = (SkipTest,) + err[1:]
93 return self.orig_addError(test, err, capt)
93 return self.orig_addError(test, err, capt)
94
94
95 Xunit.orig_addError = Xunit.addError
95 Xunit.orig_addError = Xunit.addError
96 Xunit.addError = addError
96 Xunit.addError = addError
97
97
98 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
99 # Check which dependencies are installed and greater than minimum version.
99 # Check which dependencies are installed and greater than minimum version.
100 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
101 def extract_version(mod):
101 def extract_version(mod):
102 return mod.__version__
102 return mod.__version__
103
103
104 def test_for(item, min_version=None, callback=extract_version):
104 def test_for(item, min_version=None, callback=extract_version):
105 """Test to see if item is importable, and optionally check against a minimum
105 """Test to see if item is importable, and optionally check against a minimum
106 version.
106 version.
107
107
108 If min_version is given, the default behavior is to check against the
108 If min_version is given, the default behavior is to check against the
109 `__version__` attribute of the item, but specifying `callback` allows you to
109 `__version__` attribute of the item, but specifying `callback` allows you to
110 extract the value you are interested in. e.g::
110 extract the value you are interested in. e.g::
111
111
112 In [1]: import sys
112 In [1]: import sys
113
113
114 In [2]: from IPython.testing.iptest import test_for
114 In [2]: from IPython.testing.iptest import test_for
115
115
116 In [3]: test_for('sys', (2,6), callback=lambda sys: sys.version_info)
116 In [3]: test_for('sys', (2,6), callback=lambda sys: sys.version_info)
117 Out[3]: True
117 Out[3]: True
118
118
119 """
119 """
120 try:
120 try:
121 check = import_item(item)
121 check = import_item(item)
122 except (ImportError, RuntimeError):
122 except (ImportError, RuntimeError):
123 # GTK reports Runtime error if it can't be initialized even if it's
123 # GTK reports Runtime error if it can't be initialized even if it's
124 # importable.
124 # importable.
125 return False
125 return False
126 else:
126 else:
127 if min_version:
127 if min_version:
128 if callback:
128 if callback:
129 # extra processing step to get version to compare
129 # extra processing step to get version to compare
130 check = callback(check)
130 check = callback(check)
131
131
132 return check >= min_version
132 return check >= min_version
133 else:
133 else:
134 return True
134 return True
135
135
136 # Global dict where we can store information on what we have and what we don't
136 # Global dict where we can store information on what we have and what we don't
137 # have available at test run time
137 # have available at test run time
138 have = {'matplotlib': test_for('matplotlib'),
138 have = {'matplotlib': test_for('matplotlib'),
139 'pygments': test_for('pygments'),
139 'pygments': test_for('pygments'),
140 'sqlite3': test_for('sqlite3')}
140 'sqlite3': test_for('sqlite3')}
141
141
142 #-----------------------------------------------------------------------------
142 #-----------------------------------------------------------------------------
143 # Test suite definitions
143 # Test suite definitions
144 #-----------------------------------------------------------------------------
144 #-----------------------------------------------------------------------------
145
145
146 test_group_names = ['core',
146 test_group_names = ['core',
147 'extensions', 'lib', 'terminal', 'testing', 'utils',
147 'extensions', 'lib', 'terminal', 'testing', 'utils',
148 ]
148 ]
149
149
150 class TestSection(object):
150 class TestSection(object):
151 def __init__(self, name, includes):
151 def __init__(self, name, includes):
152 self.name = name
152 self.name = name
153 self.includes = includes
153 self.includes = includes
154 self.excludes = []
154 self.excludes = []
155 self.dependencies = []
155 self.dependencies = []
156 self.enabled = True
156 self.enabled = True
157
157
158 def exclude(self, module):
158 def exclude(self, module):
159 if not module.startswith('IPython'):
159 if not module.startswith('IPython'):
160 module = self.includes[0] + "." + module
160 module = self.includes[0] + "." + module
161 self.excludes.append(module.replace('.', os.sep))
161 self.excludes.append(module.replace('.', os.sep))
162
162
163 def requires(self, *packages):
163 def requires(self, *packages):
164 self.dependencies.extend(packages)
164 self.dependencies.extend(packages)
165
165
166 @property
166 @property
167 def will_run(self):
167 def will_run(self):
168 return self.enabled and all(have[p] for p in self.dependencies)
168 return self.enabled and all(have[p] for p in self.dependencies)
169
169
170 # Name -> (include, exclude, dependencies_met)
170 # Name -> (include, exclude, dependencies_met)
171 test_sections = {n:TestSection(n, ['IPython.%s' % n]) for n in test_group_names}
171 test_sections = {n:TestSection(n, ['IPython.%s' % n]) for n in test_group_names}
172
172
173
173
174 # Exclusions and dependencies
174 # Exclusions and dependencies
175 # ---------------------------
175 # ---------------------------
176
176
177 # core:
177 # core:
178 sec = test_sections['core']
178 sec = test_sections['core']
179 if not have['sqlite3']:
179 if not have['sqlite3']:
180 sec.exclude('tests.test_history')
180 sec.exclude('tests.test_history')
181 sec.exclude('history')
181 sec.exclude('history')
182 if not have['matplotlib']:
182 if not have['matplotlib']:
183 sec.exclude('pylabtools'),
183 sec.exclude('pylabtools'),
184 sec.exclude('tests.test_pylabtools')
184 sec.exclude('tests.test_pylabtools')
185
185
186 # lib:
186 # lib:
187 sec = test_sections['lib']
187 sec = test_sections['lib']
188 sec.exclude('kernel')
188 sec.exclude('kernel')
189 if not have['pygments']:
189 if not have['pygments']:
190 sec.exclude('tests.test_lexers')
190 sec.exclude('tests.test_lexers')
191 # We do this unconditionally, so that the test suite doesn't import
191 # We do this unconditionally, so that the test suite doesn't import
192 # gtk, changing the default encoding and masking some unicode bugs.
192 # gtk, changing the default encoding and masking some unicode bugs.
193 sec.exclude('inputhookgtk')
193 sec.exclude('inputhookgtk')
194 # We also do this unconditionally, because wx can interfere with Unix signals.
194 # We also do this unconditionally, because wx can interfere with Unix signals.
195 # There are currently no tests for it anyway.
195 # There are currently no tests for it anyway.
196 sec.exclude('inputhookwx')
196 sec.exclude('inputhookwx')
197 # Testing inputhook will need a lot of thought, to figure out
197 # Testing inputhook will need a lot of thought, to figure out
198 # how to have tests that don't lock up with the gui event
198 # how to have tests that don't lock up with the gui event
199 # loops in the picture
199 # loops in the picture
200 sec.exclude('inputhook')
200 sec.exclude('inputhook')
201
201
202 # testing:
202 # testing:
203 sec = test_sections['testing']
203 sec = test_sections['testing']
204 # These have to be skipped on win32 because they use echo, rm, cd, etc.
204 # These have to be skipped on win32 because they use echo, rm, cd, etc.
205 # See ticket https://github.com/ipython/ipython/issues/87
205 # See ticket https://github.com/ipython/ipython/issues/87
206 if sys.platform == 'win32':
206 if sys.platform == 'win32':
207 sec.exclude('plugin.test_exampleip')
207 sec.exclude('plugin.test_exampleip')
208 sec.exclude('plugin.dtexample')
208 sec.exclude('plugin.dtexample')
209
209
210 # don't run jupyter_console tests found via shim
210 # don't run jupyter_console tests found via shim
211 test_sections['terminal'].exclude('console')
211 test_sections['terminal'].exclude('console')
212
212
213 # extensions:
213 # extensions:
214 sec = test_sections['extensions']
214 sec = test_sections['extensions']
215 # This is deprecated in favour of rpy2
215 # This is deprecated in favour of rpy2
216 sec.exclude('rmagic')
216 sec.exclude('rmagic')
217 # autoreload does some strange stuff, so move it to its own test section
217 # autoreload does some strange stuff, so move it to its own test section
218 sec.exclude('autoreload')
218 sec.exclude('autoreload')
219 sec.exclude('tests.test_autoreload')
219 sec.exclude('tests.test_autoreload')
220 test_sections['autoreload'] = TestSection('autoreload',
220 test_sections['autoreload'] = TestSection('autoreload',
221 ['IPython.extensions.autoreload', 'IPython.extensions.tests.test_autoreload'])
221 ['IPython.extensions.autoreload', 'IPython.extensions.tests.test_autoreload'])
222 test_group_names.append('autoreload')
222 test_group_names.append('autoreload')
223
223
224
224
225 #-----------------------------------------------------------------------------
225 #-----------------------------------------------------------------------------
226 # Functions and classes
226 # Functions and classes
227 #-----------------------------------------------------------------------------
227 #-----------------------------------------------------------------------------
228
228
229 def check_exclusions_exist():
229 def check_exclusions_exist():
230 from IPython.paths import get_ipython_package_dir
230 from IPython.paths import get_ipython_package_dir
231 from warnings import warn
231 from warnings import warn
232 parent = os.path.dirname(get_ipython_package_dir())
232 parent = os.path.dirname(get_ipython_package_dir())
233 for sec in test_sections:
233 for sec in test_sections:
234 for pattern in sec.exclusions:
234 for pattern in sec.exclusions:
235 fullpath = pjoin(parent, pattern)
235 fullpath = pjoin(parent, pattern)
236 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
236 if not os.path.exists(fullpath) and not glob.glob(fullpath + '.*'):
237 warn("Excluding nonexistent file: %r" % pattern)
237 warn("Excluding nonexistent file: %r" % pattern)
238
238
239
239
240 class ExclusionPlugin(Plugin):
240 class ExclusionPlugin(Plugin):
241 """A nose plugin to effect our exclusions of files and directories.
241 """A nose plugin to effect our exclusions of files and directories.
242 """
242 """
243 name = 'exclusions'
243 name = 'exclusions'
244 score = 3000 # Should come before any other plugins
244 score = 3000 # Should come before any other plugins
245
245
246 def __init__(self, exclude_patterns=None):
246 def __init__(self, exclude_patterns=None):
247 """
247 """
248 Parameters
248 Parameters
249 ----------
249 ----------
250
250
251 exclude_patterns : sequence of strings, optional
251 exclude_patterns : sequence of strings, optional
252 Filenames containing these patterns (as raw strings, not as regular
252 Filenames containing these patterns (as raw strings, not as regular
253 expressions) are excluded from the tests.
253 expressions) are excluded from the tests.
254 """
254 """
255 self.exclude_patterns = exclude_patterns or []
255 self.exclude_patterns = exclude_patterns or []
256 super(ExclusionPlugin, self).__init__()
256 super(ExclusionPlugin, self).__init__()
257
257
258 def options(self, parser, env=os.environ):
258 def options(self, parser, env=os.environ):
259 Plugin.options(self, parser, env)
259 Plugin.options(self, parser, env)
260
260
261 def configure(self, options, config):
261 def configure(self, options, config):
262 Plugin.configure(self, options, config)
262 Plugin.configure(self, options, config)
263 # Override nose trying to disable plugin.
263 # Override nose trying to disable plugin.
264 self.enabled = True
264 self.enabled = True
265
265
266 def wantFile(self, filename):
266 def wantFile(self, filename):
267 """Return whether the given filename should be scanned for tests.
267 """Return whether the given filename should be scanned for tests.
268 """
268 """
269 if any(pat in filename for pat in self.exclude_patterns):
269 if any(pat in filename for pat in self.exclude_patterns):
270 return False
270 return False
271 return None
271 return None
272
272
273 def wantDirectory(self, directory):
273 def wantDirectory(self, directory):
274 """Return whether the given directory should be scanned for tests.
274 """Return whether the given directory should be scanned for tests.
275 """
275 """
276 if any(pat in directory for pat in self.exclude_patterns):
276 if any(pat in directory for pat in self.exclude_patterns):
277 return False
277 return False
278 return None
278 return None
279
279
280
280
281 class StreamCapturer(Thread):
281 class StreamCapturer(Thread):
282 daemon = True # Don't hang if main thread crashes
282 daemon = True # Don't hang if main thread crashes
283 started = False
283 started = False
284 def __init__(self, echo=False):
284 def __init__(self, echo=False):
285 super(StreamCapturer, self).__init__()
285 super(StreamCapturer, self).__init__()
286 self.echo = echo
286 self.echo = echo
287 self.streams = []
287 self.streams = []
288 self.buffer = BytesIO()
288 self.buffer = BytesIO()
289 self.readfd, self.writefd = os.pipe()
289 self.readfd, self.writefd = os.pipe()
290 self.buffer_lock = Lock()
290 self.buffer_lock = Lock()
291 self.stop = Event()
291 self.stop = Event()
292
292
293 def run(self):
293 def run(self):
294 self.started = True
294 self.started = True
295
295
296 while not self.stop.is_set():
296 while not self.stop.is_set():
297 chunk = os.read(self.readfd, 1024)
297 chunk = os.read(self.readfd, 1024)
298
298
299 with self.buffer_lock:
299 with self.buffer_lock:
300 self.buffer.write(chunk)
300 self.buffer.write(chunk)
301 if self.echo:
301 if self.echo:
302 sys.stdout.write(bytes_to_str(chunk))
302 sys.stdout.write(decode(chunk))
303
303
304 os.close(self.readfd)
304 os.close(self.readfd)
305 os.close(self.writefd)
305 os.close(self.writefd)
306
306
307 def reset_buffer(self):
307 def reset_buffer(self):
308 with self.buffer_lock:
308 with self.buffer_lock:
309 self.buffer.truncate(0)
309 self.buffer.truncate(0)
310 self.buffer.seek(0)
310 self.buffer.seek(0)
311
311
312 def get_buffer(self):
312 def get_buffer(self):
313 with self.buffer_lock:
313 with self.buffer_lock:
314 return self.buffer.getvalue()
314 return self.buffer.getvalue()
315
315
316 def ensure_started(self):
316 def ensure_started(self):
317 if not self.started:
317 if not self.started:
318 self.start()
318 self.start()
319
319
320 def halt(self):
320 def halt(self):
321 """Safely stop the thread."""
321 """Safely stop the thread."""
322 if not self.started:
322 if not self.started:
323 return
323 return
324
324
325 self.stop.set()
325 self.stop.set()
326 os.write(self.writefd, b'\0') # Ensure we're not locked in a read()
326 os.write(self.writefd, b'\0') # Ensure we're not locked in a read()
327 self.join()
327 self.join()
328
328
329 class SubprocessStreamCapturePlugin(Plugin):
329 class SubprocessStreamCapturePlugin(Plugin):
330 name='subprocstreams'
330 name='subprocstreams'
331 def __init__(self):
331 def __init__(self):
332 Plugin.__init__(self)
332 Plugin.__init__(self)
333 self.stream_capturer = StreamCapturer()
333 self.stream_capturer = StreamCapturer()
334 self.destination = os.environ.get('IPTEST_SUBPROC_STREAMS', 'capture')
334 self.destination = os.environ.get('IPTEST_SUBPROC_STREAMS', 'capture')
335 # This is ugly, but distant parts of the test machinery need to be able
335 # This is ugly, but distant parts of the test machinery need to be able
336 # to redirect streams, so we make the object globally accessible.
336 # to redirect streams, so we make the object globally accessible.
337 nose.iptest_stdstreams_fileno = self.get_write_fileno
337 nose.iptest_stdstreams_fileno = self.get_write_fileno
338
338
339 def get_write_fileno(self):
339 def get_write_fileno(self):
340 if self.destination == 'capture':
340 if self.destination == 'capture':
341 self.stream_capturer.ensure_started()
341 self.stream_capturer.ensure_started()
342 return self.stream_capturer.writefd
342 return self.stream_capturer.writefd
343 elif self.destination == 'discard':
343 elif self.destination == 'discard':
344 return os.open(os.devnull, os.O_WRONLY)
344 return os.open(os.devnull, os.O_WRONLY)
345 else:
345 else:
346 return sys.__stdout__.fileno()
346 return sys.__stdout__.fileno()
347
347
348 def configure(self, options, config):
348 def configure(self, options, config):
349 Plugin.configure(self, options, config)
349 Plugin.configure(self, options, config)
350 # Override nose trying to disable plugin.
350 # Override nose trying to disable plugin.
351 if self.destination == 'capture':
351 if self.destination == 'capture':
352 self.enabled = True
352 self.enabled = True
353
353
354 def startTest(self, test):
354 def startTest(self, test):
355 # Reset log capture
355 # Reset log capture
356 self.stream_capturer.reset_buffer()
356 self.stream_capturer.reset_buffer()
357
357
358 def formatFailure(self, test, err):
358 def formatFailure(self, test, err):
359 # Show output
359 # Show output
360 ec, ev, tb = err
360 ec, ev, tb = err
361 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
361 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
362 if captured.strip():
362 if captured.strip():
363 ev = safe_str(ev)
363 ev = safe_str(ev)
364 out = [ev, '>> begin captured subprocess output <<',
364 out = [ev, '>> begin captured subprocess output <<',
365 captured,
365 captured,
366 '>> end captured subprocess output <<']
366 '>> end captured subprocess output <<']
367 return ec, '\n'.join(out), tb
367 return ec, '\n'.join(out), tb
368
368
369 return err
369 return err
370
370
371 formatError = formatFailure
371 formatError = formatFailure
372
372
373 def finalize(self, result):
373 def finalize(self, result):
374 self.stream_capturer.halt()
374 self.stream_capturer.halt()
375
375
376
376
377 def run_iptest():
377 def run_iptest():
378 """Run the IPython test suite using nose.
378 """Run the IPython test suite using nose.
379
379
380 This function is called when this script is **not** called with the form
380 This function is called when this script is **not** called with the form
381 `iptest all`. It simply calls nose with appropriate command line flags
381 `iptest all`. It simply calls nose with appropriate command line flags
382 and accepts all of the standard nose arguments.
382 and accepts all of the standard nose arguments.
383 """
383 """
384 # Apply our monkeypatch to Xunit
384 # Apply our monkeypatch to Xunit
385 if '--with-xunit' in sys.argv and not hasattr(Xunit, 'orig_addError'):
385 if '--with-xunit' in sys.argv and not hasattr(Xunit, 'orig_addError'):
386 monkeypatch_xunit()
386 monkeypatch_xunit()
387
387
388 arg1 = sys.argv[1]
388 arg1 = sys.argv[1]
389 if arg1 in test_sections:
389 if arg1 in test_sections:
390 section = test_sections[arg1]
390 section = test_sections[arg1]
391 sys.argv[1:2] = section.includes
391 sys.argv[1:2] = section.includes
392 elif arg1.startswith('IPython.') and arg1[8:] in test_sections:
392 elif arg1.startswith('IPython.') and arg1[8:] in test_sections:
393 section = test_sections[arg1[8:]]
393 section = test_sections[arg1[8:]]
394 sys.argv[1:2] = section.includes
394 sys.argv[1:2] = section.includes
395 else:
395 else:
396 section = TestSection(arg1, includes=[arg1])
396 section = TestSection(arg1, includes=[arg1])
397
397
398
398
399 argv = sys.argv + [ '--detailed-errors', # extra info in tracebacks
399 argv = sys.argv + [ '--detailed-errors', # extra info in tracebacks
400 # We add --exe because of setuptools' imbecility (it
400 # We add --exe because of setuptools' imbecility (it
401 # blindly does chmod +x on ALL files). Nose does the
401 # blindly does chmod +x on ALL files). Nose does the
402 # right thing and it tries to avoid executables,
402 # right thing and it tries to avoid executables,
403 # setuptools unfortunately forces our hand here. This
403 # setuptools unfortunately forces our hand here. This
404 # has been discussed on the distutils list and the
404 # has been discussed on the distutils list and the
405 # setuptools devs refuse to fix this problem!
405 # setuptools devs refuse to fix this problem!
406 '--exe',
406 '--exe',
407 ]
407 ]
408 if '-a' not in argv and '-A' not in argv:
408 if '-a' not in argv and '-A' not in argv:
409 argv = argv + ['-a', '!crash']
409 argv = argv + ['-a', '!crash']
410
410
411 if nose.__version__ >= '0.11':
411 if nose.__version__ >= '0.11':
412 # I don't fully understand why we need this one, but depending on what
412 # I don't fully understand why we need this one, but depending on what
413 # directory the test suite is run from, if we don't give it, 0 tests
413 # directory the test suite is run from, if we don't give it, 0 tests
414 # get run. Specifically, if the test suite is run from the source dir
414 # get run. Specifically, if the test suite is run from the source dir
415 # with an argument (like 'iptest.py IPython.core', 0 tests are run,
415 # with an argument (like 'iptest.py IPython.core', 0 tests are run,
416 # even if the same call done in this directory works fine). It appears
416 # even if the same call done in this directory works fine). It appears
417 # that if the requested package is in the current dir, nose bails early
417 # that if the requested package is in the current dir, nose bails early
418 # by default. Since it's otherwise harmless, leave it in by default
418 # by default. Since it's otherwise harmless, leave it in by default
419 # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
419 # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
420 argv.append('--traverse-namespace')
420 argv.append('--traverse-namespace')
421
421
422 plugins = [ ExclusionPlugin(section.excludes), KnownFailure(),
422 plugins = [ ExclusionPlugin(section.excludes), KnownFailure(),
423 SubprocessStreamCapturePlugin() ]
423 SubprocessStreamCapturePlugin() ]
424
424
425 # we still have some vestigial doctests in core
425 # we still have some vestigial doctests in core
426 if (section.name.startswith(('core', 'IPython.core', 'IPython.utils'))):
426 if (section.name.startswith(('core', 'IPython.core', 'IPython.utils'))):
427 plugins.append(IPythonDoctest())
427 plugins.append(IPythonDoctest())
428 argv.extend([
428 argv.extend([
429 '--with-ipdoctest',
429 '--with-ipdoctest',
430 '--ipdoctest-tests',
430 '--ipdoctest-tests',
431 '--ipdoctest-extension=txt',
431 '--ipdoctest-extension=txt',
432 ])
432 ])
433
433
434
434
435 # Use working directory set by parent process (see iptestcontroller)
435 # Use working directory set by parent process (see iptestcontroller)
436 if 'IPTEST_WORKING_DIR' in os.environ:
436 if 'IPTEST_WORKING_DIR' in os.environ:
437 os.chdir(os.environ['IPTEST_WORKING_DIR'])
437 os.chdir(os.environ['IPTEST_WORKING_DIR'])
438
438
439 # We need a global ipython running in this process, but the special
439 # We need a global ipython running in this process, but the special
440 # in-process group spawns its own IPython kernels, so for *that* group we
440 # in-process group spawns its own IPython kernels, so for *that* group we
441 # must avoid also opening the global one (otherwise there's a conflict of
441 # must avoid also opening the global one (otherwise there's a conflict of
442 # singletons). Ultimately the solution to this problem is to refactor our
442 # singletons). Ultimately the solution to this problem is to refactor our
443 # assumptions about what needs to be a singleton and what doesn't (app
443 # assumptions about what needs to be a singleton and what doesn't (app
444 # objects should, individual shells shouldn't). But for now, this
444 # objects should, individual shells shouldn't). But for now, this
445 # workaround allows the test suite for the inprocess module to complete.
445 # workaround allows the test suite for the inprocess module to complete.
446 if 'kernel.inprocess' not in section.name:
446 if 'kernel.inprocess' not in section.name:
447 from IPython.testing import globalipapp
447 from IPython.testing import globalipapp
448 globalipapp.start_ipython()
448 globalipapp.start_ipython()
449
449
450 # Now nose can run
450 # Now nose can run
451 TestProgram(argv=argv, addplugins=plugins)
451 TestProgram(argv=argv, addplugins=plugins)
452
452
453 if __name__ == '__main__':
453 if __name__ == '__main__':
454 run_iptest()
454 run_iptest()
@@ -1,510 +1,510
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """IPython Test Process Controller
2 """IPython Test Process Controller
3
3
4 This module runs one or more subprocesses which will actually run the IPython
4 This module runs one or more subprocesses which will actually run the IPython
5 test suite.
5 test suite.
6
6
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12
12
13 import argparse
13 import argparse
14 import multiprocessing.pool
14 import multiprocessing.pool
15 import os
15 import os
16 import stat
16 import stat
17 import shutil
17 import shutil
18 import signal
18 import signal
19 import sys
19 import sys
20 import subprocess
20 import subprocess
21 import time
21 import time
22
22
23 from .iptest import (
23 from .iptest import (
24 have, test_group_names as py_test_group_names, test_sections, StreamCapturer,
24 have, test_group_names as py_test_group_names, test_sections, StreamCapturer,
25 )
25 )
26 from IPython.utils.path import compress_user
26 from IPython.utils.path import compress_user
27 from IPython.utils.py3compat import bytes_to_str
27 from IPython.utils.py3compat import decode
28 from IPython.utils.sysinfo import get_sys_info
28 from IPython.utils.sysinfo import get_sys_info
29 from IPython.utils.tempdir import TemporaryDirectory
29 from IPython.utils.tempdir import TemporaryDirectory
30
30
31 def popen_wait(p, timeout):
31 def popen_wait(p, timeout):
32 return p.wait(timeout)
32 return p.wait(timeout)
33
33
34 class TestController(object):
34 class TestController(object):
35 """Run tests in a subprocess
35 """Run tests in a subprocess
36 """
36 """
37 #: str, IPython test suite to be executed.
37 #: str, IPython test suite to be executed.
38 section = None
38 section = None
39 #: list, command line arguments to be executed
39 #: list, command line arguments to be executed
40 cmd = None
40 cmd = None
41 #: dict, extra environment variables to set for the subprocess
41 #: dict, extra environment variables to set for the subprocess
42 env = None
42 env = None
43 #: list, TemporaryDirectory instances to clear up when the process finishes
43 #: list, TemporaryDirectory instances to clear up when the process finishes
44 dirs = None
44 dirs = None
45 #: subprocess.Popen instance
45 #: subprocess.Popen instance
46 process = None
46 process = None
47 #: str, process stdout+stderr
47 #: str, process stdout+stderr
48 stdout = None
48 stdout = None
49
49
50 def __init__(self):
50 def __init__(self):
51 self.cmd = []
51 self.cmd = []
52 self.env = {}
52 self.env = {}
53 self.dirs = []
53 self.dirs = []
54
54
55 def setup(self):
55 def setup(self):
56 """Create temporary directories etc.
56 """Create temporary directories etc.
57
57
58 This is only called when we know the test group will be run. Things
58 This is only called when we know the test group will be run. Things
59 created here may be cleaned up by self.cleanup().
59 created here may be cleaned up by self.cleanup().
60 """
60 """
61 pass
61 pass
62
62
63 def launch(self, buffer_output=False, capture_output=False):
63 def launch(self, buffer_output=False, capture_output=False):
64 # print('*** ENV:', self.env) # dbg
64 # print('*** ENV:', self.env) # dbg
65 # print('*** CMD:', self.cmd) # dbg
65 # print('*** CMD:', self.cmd) # dbg
66 env = os.environ.copy()
66 env = os.environ.copy()
67 env.update(self.env)
67 env.update(self.env)
68 if buffer_output:
68 if buffer_output:
69 capture_output = True
69 capture_output = True
70 self.stdout_capturer = c = StreamCapturer(echo=not buffer_output)
70 self.stdout_capturer = c = StreamCapturer(echo=not buffer_output)
71 c.start()
71 c.start()
72 stdout = c.writefd if capture_output else None
72 stdout = c.writefd if capture_output else None
73 stderr = subprocess.STDOUT if capture_output else None
73 stderr = subprocess.STDOUT if capture_output else None
74 self.process = subprocess.Popen(self.cmd, stdout=stdout,
74 self.process = subprocess.Popen(self.cmd, stdout=stdout,
75 stderr=stderr, env=env)
75 stderr=stderr, env=env)
76
76
77 def wait(self):
77 def wait(self):
78 self.process.wait()
78 self.process.wait()
79 self.stdout_capturer.halt()
79 self.stdout_capturer.halt()
80 self.stdout = self.stdout_capturer.get_buffer()
80 self.stdout = self.stdout_capturer.get_buffer()
81 return self.process.returncode
81 return self.process.returncode
82
82
83 def print_extra_info(self):
83 def print_extra_info(self):
84 """Print extra information about this test run.
84 """Print extra information about this test run.
85
85
86 If we're running in parallel and showing the concise view, this is only
86 If we're running in parallel and showing the concise view, this is only
87 called if the test group fails. Otherwise, it's called before the test
87 called if the test group fails. Otherwise, it's called before the test
88 group is started.
88 group is started.
89
89
90 The base implementation does nothing, but it can be overridden by
90 The base implementation does nothing, but it can be overridden by
91 subclasses.
91 subclasses.
92 """
92 """
93 return
93 return
94
94
95 def cleanup_process(self):
95 def cleanup_process(self):
96 """Cleanup on exit by killing any leftover processes."""
96 """Cleanup on exit by killing any leftover processes."""
97 subp = self.process
97 subp = self.process
98 if subp is None or (subp.poll() is not None):
98 if subp is None or (subp.poll() is not None):
99 return # Process doesn't exist, or is already dead.
99 return # Process doesn't exist, or is already dead.
100
100
101 try:
101 try:
102 print('Cleaning up stale PID: %d' % subp.pid)
102 print('Cleaning up stale PID: %d' % subp.pid)
103 subp.kill()
103 subp.kill()
104 except: # (OSError, WindowsError) ?
104 except: # (OSError, WindowsError) ?
105 # This is just a best effort, if we fail or the process was
105 # This is just a best effort, if we fail or the process was
106 # really gone, ignore it.
106 # really gone, ignore it.
107 pass
107 pass
108 else:
108 else:
109 for i in range(10):
109 for i in range(10):
110 if subp.poll() is None:
110 if subp.poll() is None:
111 time.sleep(0.1)
111 time.sleep(0.1)
112 else:
112 else:
113 break
113 break
114
114
115 if subp.poll() is None:
115 if subp.poll() is None:
116 # The process did not die...
116 # The process did not die...
117 print('... failed. Manual cleanup may be required.')
117 print('... failed. Manual cleanup may be required.')
118
118
119 def cleanup(self):
119 def cleanup(self):
120 "Kill process if it's still alive, and clean up temporary directories"
120 "Kill process if it's still alive, and clean up temporary directories"
121 self.cleanup_process()
121 self.cleanup_process()
122 for td in self.dirs:
122 for td in self.dirs:
123 td.cleanup()
123 td.cleanup()
124
124
125 __del__ = cleanup
125 __del__ = cleanup
126
126
127
127
128 class PyTestController(TestController):
128 class PyTestController(TestController):
129 """Run Python tests using IPython.testing.iptest"""
129 """Run Python tests using IPython.testing.iptest"""
130 #: str, Python command to execute in subprocess
130 #: str, Python command to execute in subprocess
131 pycmd = None
131 pycmd = None
132
132
133 def __init__(self, section, options):
133 def __init__(self, section, options):
134 """Create new test runner."""
134 """Create new test runner."""
135 TestController.__init__(self)
135 TestController.__init__(self)
136 self.section = section
136 self.section = section
137 # pycmd is put into cmd[2] in PyTestController.launch()
137 # pycmd is put into cmd[2] in PyTestController.launch()
138 self.cmd = [sys.executable, '-c', None, section]
138 self.cmd = [sys.executable, '-c', None, section]
139 self.pycmd = "from IPython.testing.iptest import run_iptest; run_iptest()"
139 self.pycmd = "from IPython.testing.iptest import run_iptest; run_iptest()"
140 self.options = options
140 self.options = options
141
141
142 def setup(self):
142 def setup(self):
143 ipydir = TemporaryDirectory()
143 ipydir = TemporaryDirectory()
144 self.dirs.append(ipydir)
144 self.dirs.append(ipydir)
145 self.env['IPYTHONDIR'] = ipydir.name
145 self.env['IPYTHONDIR'] = ipydir.name
146 self.workingdir = workingdir = TemporaryDirectory()
146 self.workingdir = workingdir = TemporaryDirectory()
147 self.dirs.append(workingdir)
147 self.dirs.append(workingdir)
148 self.env['IPTEST_WORKING_DIR'] = workingdir.name
148 self.env['IPTEST_WORKING_DIR'] = workingdir.name
149 # This means we won't get odd effects from our own matplotlib config
149 # This means we won't get odd effects from our own matplotlib config
150 self.env['MPLCONFIGDIR'] = workingdir.name
150 self.env['MPLCONFIGDIR'] = workingdir.name
151 # For security reasons (http://bugs.python.org/issue16202), use
151 # For security reasons (http://bugs.python.org/issue16202), use
152 # a temporary directory to which other users have no access.
152 # a temporary directory to which other users have no access.
153 self.env['TMPDIR'] = workingdir.name
153 self.env['TMPDIR'] = workingdir.name
154
154
155 # Add a non-accessible directory to PATH (see gh-7053)
155 # Add a non-accessible directory to PATH (see gh-7053)
156 noaccess = os.path.join(self.workingdir.name, "_no_access_")
156 noaccess = os.path.join(self.workingdir.name, "_no_access_")
157 self.noaccess = noaccess
157 self.noaccess = noaccess
158 os.mkdir(noaccess, 0)
158 os.mkdir(noaccess, 0)
159
159
160 PATH = os.environ.get('PATH', '')
160 PATH = os.environ.get('PATH', '')
161 if PATH:
161 if PATH:
162 PATH = noaccess + os.pathsep + PATH
162 PATH = noaccess + os.pathsep + PATH
163 else:
163 else:
164 PATH = noaccess
164 PATH = noaccess
165 self.env['PATH'] = PATH
165 self.env['PATH'] = PATH
166
166
167 # From options:
167 # From options:
168 if self.options.xunit:
168 if self.options.xunit:
169 self.add_xunit()
169 self.add_xunit()
170 if self.options.coverage:
170 if self.options.coverage:
171 self.add_coverage()
171 self.add_coverage()
172 self.env['IPTEST_SUBPROC_STREAMS'] = self.options.subproc_streams
172 self.env['IPTEST_SUBPROC_STREAMS'] = self.options.subproc_streams
173 self.cmd.extend(self.options.extra_args)
173 self.cmd.extend(self.options.extra_args)
174
174
175 def cleanup(self):
175 def cleanup(self):
176 """
176 """
177 Make the non-accessible directory created in setup() accessible
177 Make the non-accessible directory created in setup() accessible
178 again, otherwise deleting the workingdir will fail.
178 again, otherwise deleting the workingdir will fail.
179 """
179 """
180 os.chmod(self.noaccess, stat.S_IRWXU)
180 os.chmod(self.noaccess, stat.S_IRWXU)
181 TestController.cleanup(self)
181 TestController.cleanup(self)
182
182
183 @property
183 @property
184 def will_run(self):
184 def will_run(self):
185 try:
185 try:
186 return test_sections[self.section].will_run
186 return test_sections[self.section].will_run
187 except KeyError:
187 except KeyError:
188 return True
188 return True
189
189
190 def add_xunit(self):
190 def add_xunit(self):
191 xunit_file = os.path.abspath(self.section + '.xunit.xml')
191 xunit_file = os.path.abspath(self.section + '.xunit.xml')
192 self.cmd.extend(['--with-xunit', '--xunit-file', xunit_file])
192 self.cmd.extend(['--with-xunit', '--xunit-file', xunit_file])
193
193
194 def add_coverage(self):
194 def add_coverage(self):
195 try:
195 try:
196 sources = test_sections[self.section].includes
196 sources = test_sections[self.section].includes
197 except KeyError:
197 except KeyError:
198 sources = ['IPython']
198 sources = ['IPython']
199
199
200 coverage_rc = ("[run]\n"
200 coverage_rc = ("[run]\n"
201 "data_file = {data_file}\n"
201 "data_file = {data_file}\n"
202 "source =\n"
202 "source =\n"
203 " {source}\n"
203 " {source}\n"
204 ).format(data_file=os.path.abspath('.coverage.'+self.section),
204 ).format(data_file=os.path.abspath('.coverage.'+self.section),
205 source="\n ".join(sources))
205 source="\n ".join(sources))
206 config_file = os.path.join(self.workingdir.name, '.coveragerc')
206 config_file = os.path.join(self.workingdir.name, '.coveragerc')
207 with open(config_file, 'w') as f:
207 with open(config_file, 'w') as f:
208 f.write(coverage_rc)
208 f.write(coverage_rc)
209
209
210 self.env['COVERAGE_PROCESS_START'] = config_file
210 self.env['COVERAGE_PROCESS_START'] = config_file
211 self.pycmd = "import coverage; coverage.process_startup(); " + self.pycmd
211 self.pycmd = "import coverage; coverage.process_startup(); " + self.pycmd
212
212
213 def launch(self, buffer_output=False):
213 def launch(self, buffer_output=False):
214 self.cmd[2] = self.pycmd
214 self.cmd[2] = self.pycmd
215 super(PyTestController, self).launch(buffer_output=buffer_output)
215 super(PyTestController, self).launch(buffer_output=buffer_output)
216
216
217
217
218 def prepare_controllers(options):
218 def prepare_controllers(options):
219 """Returns two lists of TestController instances, those to run, and those
219 """Returns two lists of TestController instances, those to run, and those
220 not to run."""
220 not to run."""
221 testgroups = options.testgroups
221 testgroups = options.testgroups
222 if not testgroups:
222 if not testgroups:
223 testgroups = py_test_group_names
223 testgroups = py_test_group_names
224
224
225 controllers = [PyTestController(name, options) for name in testgroups]
225 controllers = [PyTestController(name, options) for name in testgroups]
226
226
227 to_run = [c for c in controllers if c.will_run]
227 to_run = [c for c in controllers if c.will_run]
228 not_run = [c for c in controllers if not c.will_run]
228 not_run = [c for c in controllers if not c.will_run]
229 return to_run, not_run
229 return to_run, not_run
230
230
231 def do_run(controller, buffer_output=True):
231 def do_run(controller, buffer_output=True):
232 """Setup and run a test controller.
232 """Setup and run a test controller.
233
233
234 If buffer_output is True, no output is displayed, to avoid it appearing
234 If buffer_output is True, no output is displayed, to avoid it appearing
235 interleaved. In this case, the caller is responsible for displaying test
235 interleaved. In this case, the caller is responsible for displaying test
236 output on failure.
236 output on failure.
237
237
238 Returns
238 Returns
239 -------
239 -------
240 controller : TestController
240 controller : TestController
241 The same controller as passed in, as a convenience for using map() type
241 The same controller as passed in, as a convenience for using map() type
242 APIs.
242 APIs.
243 exitcode : int
243 exitcode : int
244 The exit code of the test subprocess. Non-zero indicates failure.
244 The exit code of the test subprocess. Non-zero indicates failure.
245 """
245 """
246 try:
246 try:
247 try:
247 try:
248 controller.setup()
248 controller.setup()
249 if not buffer_output:
249 if not buffer_output:
250 controller.print_extra_info()
250 controller.print_extra_info()
251 controller.launch(buffer_output=buffer_output)
251 controller.launch(buffer_output=buffer_output)
252 except Exception:
252 except Exception:
253 import traceback
253 import traceback
254 traceback.print_exc()
254 traceback.print_exc()
255 return controller, 1 # signal failure
255 return controller, 1 # signal failure
256
256
257 exitcode = controller.wait()
257 exitcode = controller.wait()
258 return controller, exitcode
258 return controller, exitcode
259
259
260 except KeyboardInterrupt:
260 except KeyboardInterrupt:
261 return controller, -signal.SIGINT
261 return controller, -signal.SIGINT
262 finally:
262 finally:
263 controller.cleanup()
263 controller.cleanup()
264
264
265 def report():
265 def report():
266 """Return a string with a summary report of test-related variables."""
266 """Return a string with a summary report of test-related variables."""
267 inf = get_sys_info()
267 inf = get_sys_info()
268 out = []
268 out = []
269 def _add(name, value):
269 def _add(name, value):
270 out.append((name, value))
270 out.append((name, value))
271
271
272 _add('IPython version', inf['ipython_version'])
272 _add('IPython version', inf['ipython_version'])
273 _add('IPython commit', "{} ({})".format(inf['commit_hash'], inf['commit_source']))
273 _add('IPython commit', "{} ({})".format(inf['commit_hash'], inf['commit_source']))
274 _add('IPython package', compress_user(inf['ipython_path']))
274 _add('IPython package', compress_user(inf['ipython_path']))
275 _add('Python version', inf['sys_version'].replace('\n',''))
275 _add('Python version', inf['sys_version'].replace('\n',''))
276 _add('sys.executable', compress_user(inf['sys_executable']))
276 _add('sys.executable', compress_user(inf['sys_executable']))
277 _add('Platform', inf['platform'])
277 _add('Platform', inf['platform'])
278
278
279 width = max(len(n) for (n,v) in out)
279 width = max(len(n) for (n,v) in out)
280 out = ["{:<{width}}: {}\n".format(n, v, width=width) for (n,v) in out]
280 out = ["{:<{width}}: {}\n".format(n, v, width=width) for (n,v) in out]
281
281
282 avail = []
282 avail = []
283 not_avail = []
283 not_avail = []
284
284
285 for k, is_avail in have.items():
285 for k, is_avail in have.items():
286 if is_avail:
286 if is_avail:
287 avail.append(k)
287 avail.append(k)
288 else:
288 else:
289 not_avail.append(k)
289 not_avail.append(k)
290
290
291 if avail:
291 if avail:
292 out.append('\nTools and libraries available at test time:\n')
292 out.append('\nTools and libraries available at test time:\n')
293 avail.sort()
293 avail.sort()
294 out.append(' ' + ' '.join(avail)+'\n')
294 out.append(' ' + ' '.join(avail)+'\n')
295
295
296 if not_avail:
296 if not_avail:
297 out.append('\nTools and libraries NOT available at test time:\n')
297 out.append('\nTools and libraries NOT available at test time:\n')
298 not_avail.sort()
298 not_avail.sort()
299 out.append(' ' + ' '.join(not_avail)+'\n')
299 out.append(' ' + ' '.join(not_avail)+'\n')
300
300
301 return ''.join(out)
301 return ''.join(out)
302
302
303 def run_iptestall(options):
303 def run_iptestall(options):
304 """Run the entire IPython test suite by calling nose and trial.
304 """Run the entire IPython test suite by calling nose and trial.
305
305
306 This function constructs :class:`IPTester` instances for all IPython
306 This function constructs :class:`IPTester` instances for all IPython
307 modules and package and then runs each of them. This causes the modules
307 modules and package and then runs each of them. This causes the modules
308 and packages of IPython to be tested each in their own subprocess using
308 and packages of IPython to be tested each in their own subprocess using
309 nose.
309 nose.
310
310
311 Parameters
311 Parameters
312 ----------
312 ----------
313
313
314 All parameters are passed as attributes of the options object.
314 All parameters are passed as attributes of the options object.
315
315
316 testgroups : list of str
316 testgroups : list of str
317 Run only these sections of the test suite. If empty, run all the available
317 Run only these sections of the test suite. If empty, run all the available
318 sections.
318 sections.
319
319
320 fast : int or None
320 fast : int or None
321 Run the test suite in parallel, using n simultaneous processes. If None
321 Run the test suite in parallel, using n simultaneous processes. If None
322 is passed, one process is used per CPU core. Default 1 (i.e. sequential)
322 is passed, one process is used per CPU core. Default 1 (i.e. sequential)
323
323
324 inc_slow : bool
324 inc_slow : bool
325 Include slow tests. By default, these tests aren't run.
325 Include slow tests. By default, these tests aren't run.
326
326
327 url : unicode
327 url : unicode
328 Address:port to use when running the JS tests.
328 Address:port to use when running the JS tests.
329
329
330 xunit : bool
330 xunit : bool
331 Produce Xunit XML output. This is written to multiple foo.xunit.xml files.
331 Produce Xunit XML output. This is written to multiple foo.xunit.xml files.
332
332
333 coverage : bool or str
333 coverage : bool or str
334 Measure code coverage from tests. True will store the raw coverage data,
334 Measure code coverage from tests. True will store the raw coverage data,
335 or pass 'html' or 'xml' to get reports.
335 or pass 'html' or 'xml' to get reports.
336
336
337 extra_args : list
337 extra_args : list
338 Extra arguments to pass to the test subprocesses, e.g. '-v'
338 Extra arguments to pass to the test subprocesses, e.g. '-v'
339 """
339 """
340 to_run, not_run = prepare_controllers(options)
340 to_run, not_run = prepare_controllers(options)
341
341
342 def justify(ltext, rtext, width=70, fill='-'):
342 def justify(ltext, rtext, width=70, fill='-'):
343 ltext += ' '
343 ltext += ' '
344 rtext = (' ' + rtext).rjust(width - len(ltext), fill)
344 rtext = (' ' + rtext).rjust(width - len(ltext), fill)
345 return ltext + rtext
345 return ltext + rtext
346
346
347 # Run all test runners, tracking execution time
347 # Run all test runners, tracking execution time
348 failed = []
348 failed = []
349 t_start = time.time()
349 t_start = time.time()
350
350
351 print()
351 print()
352 if options.fast == 1:
352 if options.fast == 1:
353 # This actually means sequential, i.e. with 1 job
353 # This actually means sequential, i.e. with 1 job
354 for controller in to_run:
354 for controller in to_run:
355 print('Test group:', controller.section)
355 print('Test group:', controller.section)
356 sys.stdout.flush() # Show in correct order when output is piped
356 sys.stdout.flush() # Show in correct order when output is piped
357 controller, res = do_run(controller, buffer_output=False)
357 controller, res = do_run(controller, buffer_output=False)
358 if res:
358 if res:
359 failed.append(controller)
359 failed.append(controller)
360 if res == -signal.SIGINT:
360 if res == -signal.SIGINT:
361 print("Interrupted")
361 print("Interrupted")
362 break
362 break
363 print()
363 print()
364
364
365 else:
365 else:
366 # Run tests concurrently
366 # Run tests concurrently
367 try:
367 try:
368 pool = multiprocessing.pool.ThreadPool(options.fast)
368 pool = multiprocessing.pool.ThreadPool(options.fast)
369 for (controller, res) in pool.imap_unordered(do_run, to_run):
369 for (controller, res) in pool.imap_unordered(do_run, to_run):
370 res_string = 'OK' if res == 0 else 'FAILED'
370 res_string = 'OK' if res == 0 else 'FAILED'
371 print(justify('Test group: ' + controller.section, res_string))
371 print(justify('Test group: ' + controller.section, res_string))
372 if res:
372 if res:
373 controller.print_extra_info()
373 controller.print_extra_info()
374 print(bytes_to_str(controller.stdout))
374 print(decode(controller.stdout))
375 failed.append(controller)
375 failed.append(controller)
376 if res == -signal.SIGINT:
376 if res == -signal.SIGINT:
377 print("Interrupted")
377 print("Interrupted")
378 break
378 break
379 except KeyboardInterrupt:
379 except KeyboardInterrupt:
380 return
380 return
381
381
382 for controller in not_run:
382 for controller in not_run:
383 print(justify('Test group: ' + controller.section, 'NOT RUN'))
383 print(justify('Test group: ' + controller.section, 'NOT RUN'))
384
384
385 t_end = time.time()
385 t_end = time.time()
386 t_tests = t_end - t_start
386 t_tests = t_end - t_start
387 nrunners = len(to_run)
387 nrunners = len(to_run)
388 nfail = len(failed)
388 nfail = len(failed)
389 # summarize results
389 # summarize results
390 print('_'*70)
390 print('_'*70)
391 print('Test suite completed for system with the following information:')
391 print('Test suite completed for system with the following information:')
392 print(report())
392 print(report())
393 took = "Took %.3fs." % t_tests
393 took = "Took %.3fs." % t_tests
394 print('Status: ', end='')
394 print('Status: ', end='')
395 if not failed:
395 if not failed:
396 print('OK (%d test groups).' % nrunners, took)
396 print('OK (%d test groups).' % nrunners, took)
397 else:
397 else:
398 # If anything went wrong, point out what command to rerun manually to
398 # If anything went wrong, point out what command to rerun manually to
399 # see the actual errors and individual summary
399 # see the actual errors and individual summary
400 failed_sections = [c.section for c in failed]
400 failed_sections = [c.section for c in failed]
401 print('ERROR - {} out of {} test groups failed ({}).'.format(nfail,
401 print('ERROR - {} out of {} test groups failed ({}).'.format(nfail,
402 nrunners, ', '.join(failed_sections)), took)
402 nrunners, ', '.join(failed_sections)), took)
403 print()
403 print()
404 print('You may wish to rerun these, with:')
404 print('You may wish to rerun these, with:')
405 print(' iptest', *failed_sections)
405 print(' iptest', *failed_sections)
406 print()
406 print()
407
407
408 if options.coverage:
408 if options.coverage:
409 from coverage import coverage, CoverageException
409 from coverage import coverage, CoverageException
410 cov = coverage(data_file='.coverage')
410 cov = coverage(data_file='.coverage')
411 cov.combine()
411 cov.combine()
412 cov.save()
412 cov.save()
413
413
414 # Coverage HTML report
414 # Coverage HTML report
415 if options.coverage == 'html':
415 if options.coverage == 'html':
416 html_dir = 'ipy_htmlcov'
416 html_dir = 'ipy_htmlcov'
417 shutil.rmtree(html_dir, ignore_errors=True)
417 shutil.rmtree(html_dir, ignore_errors=True)
418 print("Writing HTML coverage report to %s/ ... " % html_dir, end="")
418 print("Writing HTML coverage report to %s/ ... " % html_dir, end="")
419 sys.stdout.flush()
419 sys.stdout.flush()
420
420
421 # Custom HTML reporter to clean up module names.
421 # Custom HTML reporter to clean up module names.
422 from coverage.html import HtmlReporter
422 from coverage.html import HtmlReporter
423 class CustomHtmlReporter(HtmlReporter):
423 class CustomHtmlReporter(HtmlReporter):
424 def find_code_units(self, morfs):
424 def find_code_units(self, morfs):
425 super(CustomHtmlReporter, self).find_code_units(morfs)
425 super(CustomHtmlReporter, self).find_code_units(morfs)
426 for cu in self.code_units:
426 for cu in self.code_units:
427 nameparts = cu.name.split(os.sep)
427 nameparts = cu.name.split(os.sep)
428 if 'IPython' not in nameparts:
428 if 'IPython' not in nameparts:
429 continue
429 continue
430 ix = nameparts.index('IPython')
430 ix = nameparts.index('IPython')
431 cu.name = '.'.join(nameparts[ix:])
431 cu.name = '.'.join(nameparts[ix:])
432
432
433 # Reimplement the html_report method with our custom reporter
433 # Reimplement the html_report method with our custom reporter
434 cov.get_data()
434 cov.get_data()
435 cov.config.from_args(omit='*{0}tests{0}*'.format(os.sep), html_dir=html_dir,
435 cov.config.from_args(omit='*{0}tests{0}*'.format(os.sep), html_dir=html_dir,
436 html_title='IPython test coverage',
436 html_title='IPython test coverage',
437 )
437 )
438 reporter = CustomHtmlReporter(cov, cov.config)
438 reporter = CustomHtmlReporter(cov, cov.config)
439 reporter.report(None)
439 reporter.report(None)
440 print('done.')
440 print('done.')
441
441
442 # Coverage XML report
442 # Coverage XML report
443 elif options.coverage == 'xml':
443 elif options.coverage == 'xml':
444 try:
444 try:
445 cov.xml_report(outfile='ipy_coverage.xml')
445 cov.xml_report(outfile='ipy_coverage.xml')
446 except CoverageException as e:
446 except CoverageException as e:
447 print('Generating coverage report failed. Are you running javascript tests only?')
447 print('Generating coverage report failed. Are you running javascript tests only?')
448 import traceback
448 import traceback
449 traceback.print_exc()
449 traceback.print_exc()
450
450
451 if failed:
451 if failed:
452 # Ensure that our exit code indicates failure
452 # Ensure that our exit code indicates failure
453 sys.exit(1)
453 sys.exit(1)
454
454
455 argparser = argparse.ArgumentParser(description='Run IPython test suite')
455 argparser = argparse.ArgumentParser(description='Run IPython test suite')
456 argparser.add_argument('testgroups', nargs='*',
456 argparser.add_argument('testgroups', nargs='*',
457 help='Run specified groups of tests. If omitted, run '
457 help='Run specified groups of tests. If omitted, run '
458 'all tests.')
458 'all tests.')
459 argparser.add_argument('--all', action='store_true',
459 argparser.add_argument('--all', action='store_true',
460 help='Include slow tests not run by default.')
460 help='Include slow tests not run by default.')
461 argparser.add_argument('--url', help="URL to use for the JS tests.")
461 argparser.add_argument('--url', help="URL to use for the JS tests.")
462 argparser.add_argument('-j', '--fast', nargs='?', const=None, default=1, type=int,
462 argparser.add_argument('-j', '--fast', nargs='?', const=None, default=1, type=int,
463 help='Run test sections in parallel. This starts as many '
463 help='Run test sections in parallel. This starts as many '
464 'processes as you have cores, or you can specify a number.')
464 'processes as you have cores, or you can specify a number.')
465 argparser.add_argument('--xunit', action='store_true',
465 argparser.add_argument('--xunit', action='store_true',
466 help='Produce Xunit XML results')
466 help='Produce Xunit XML results')
467 argparser.add_argument('--coverage', nargs='?', const=True, default=False,
467 argparser.add_argument('--coverage', nargs='?', const=True, default=False,
468 help="Measure test coverage. Specify 'html' or "
468 help="Measure test coverage. Specify 'html' or "
469 "'xml' to get reports.")
469 "'xml' to get reports.")
470 argparser.add_argument('--subproc-streams', default='capture',
470 argparser.add_argument('--subproc-streams', default='capture',
471 help="What to do with stdout/stderr from subprocesses. "
471 help="What to do with stdout/stderr from subprocesses. "
472 "'capture' (default), 'show' and 'discard' are the options.")
472 "'capture' (default), 'show' and 'discard' are the options.")
473
473
474 def default_options():
474 def default_options():
475 """Get an argparse Namespace object with the default arguments, to pass to
475 """Get an argparse Namespace object with the default arguments, to pass to
476 :func:`run_iptestall`.
476 :func:`run_iptestall`.
477 """
477 """
478 options = argparser.parse_args([])
478 options = argparser.parse_args([])
479 options.extra_args = []
479 options.extra_args = []
480 return options
480 return options
481
481
482 def main():
482 def main():
483 # iptest doesn't work correctly if the working directory is the
483 # iptest doesn't work correctly if the working directory is the
484 # root of the IPython source tree. Tell the user to avoid
484 # root of the IPython source tree. Tell the user to avoid
485 # frustration.
485 # frustration.
486 if os.path.exists(os.path.join(os.getcwd(),
486 if os.path.exists(os.path.join(os.getcwd(),
487 'IPython', 'testing', '__main__.py')):
487 'IPython', 'testing', '__main__.py')):
488 print("Don't run iptest from the IPython source directory",
488 print("Don't run iptest from the IPython source directory",
489 file=sys.stderr)
489 file=sys.stderr)
490 sys.exit(1)
490 sys.exit(1)
491 # Arguments after -- should be passed through to nose. Argparse treats
491 # Arguments after -- should be passed through to nose. Argparse treats
492 # everything after -- as regular positional arguments, so we separate them
492 # everything after -- as regular positional arguments, so we separate them
493 # first.
493 # first.
494 try:
494 try:
495 ix = sys.argv.index('--')
495 ix = sys.argv.index('--')
496 except ValueError:
496 except ValueError:
497 to_parse = sys.argv[1:]
497 to_parse = sys.argv[1:]
498 extra_args = []
498 extra_args = []
499 else:
499 else:
500 to_parse = sys.argv[1:ix]
500 to_parse = sys.argv[1:ix]
501 extra_args = sys.argv[ix+1:]
501 extra_args = sys.argv[ix+1:]
502
502
503 options = argparser.parse_args(to_parse)
503 options = argparser.parse_args(to_parse)
504 options.extra_args = extra_args
504 options.extra_args = extra_args
505
505
506 run_iptestall(options)
506 run_iptestall(options)
507
507
508
508
509 if __name__ == '__main__':
509 if __name__ == '__main__':
510 main()
510 main()
@@ -1,468 +1,467
1 """Generic testing tools.
1 """Generic testing tools.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 """
6 """
7
7
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import os
12 import os
13 import re
13 import re
14 import sys
14 import sys
15 import tempfile
15 import tempfile
16
16
17 from contextlib import contextmanager
17 from contextlib import contextmanager
18 from io import StringIO
18 from io import StringIO
19 from subprocess import Popen, PIPE
19 from subprocess import Popen, PIPE
20 from unittest.mock import patch
20 from unittest.mock import patch
21
21
22 try:
22 try:
23 # These tools are used by parts of the runtime, so we make the nose
23 # These tools are used by parts of the runtime, so we make the nose
24 # dependency optional at this point. Nose is a hard dependency to run the
24 # dependency optional at this point. Nose is a hard dependency to run the
25 # test suite, but NOT to use ipython itself.
25 # test suite, but NOT to use ipython itself.
26 import nose.tools as nt
26 import nose.tools as nt
27 has_nose = True
27 has_nose = True
28 except ImportError:
28 except ImportError:
29 has_nose = False
29 has_nose = False
30
30
31 from traitlets.config.loader import Config
31 from traitlets.config.loader import Config
32 from IPython.utils.process import get_output_error_code
32 from IPython.utils.process import get_output_error_code
33 from IPython.utils.text import list_strings
33 from IPython.utils.text import list_strings
34 from IPython.utils.io import temp_pyfile, Tee
34 from IPython.utils.io import temp_pyfile, Tee
35 from IPython.utils import py3compat
35 from IPython.utils import py3compat
36 from IPython.utils.encoding import DEFAULT_ENCODING
37
36
38 from . import decorators as dec
37 from . import decorators as dec
39 from . import skipdoctest
38 from . import skipdoctest
40
39
41
40
42 # The docstring for full_path doctests differently on win32 (different path
41 # The docstring for full_path doctests differently on win32 (different path
43 # separator) so just skip the doctest there. The example remains informative.
42 # separator) so just skip the doctest there. The example remains informative.
44 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
43 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45
44
46 @doctest_deco
45 @doctest_deco
47 def full_path(startPath,files):
46 def full_path(startPath,files):
48 """Make full paths for all the listed files, based on startPath.
47 """Make full paths for all the listed files, based on startPath.
49
48
50 Only the base part of startPath is kept, since this routine is typically
49 Only the base part of startPath is kept, since this routine is typically
51 used with a script's ``__file__`` variable as startPath. The base of startPath
50 used with a script's ``__file__`` variable as startPath. The base of startPath
52 is then prepended to all the listed files, forming the output list.
51 is then prepended to all the listed files, forming the output list.
53
52
54 Parameters
53 Parameters
55 ----------
54 ----------
56 startPath : string
55 startPath : string
57 Initial path to use as the base for the results. This path is split
56 Initial path to use as the base for the results. This path is split
58 using os.path.split() and only its first component is kept.
57 using os.path.split() and only its first component is kept.
59
58
60 files : string or list
59 files : string or list
61 One or more files.
60 One or more files.
62
61
63 Examples
62 Examples
64 --------
63 --------
65
64
66 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
65 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 ['/foo/a.txt', '/foo/b.txt']
66 ['/foo/a.txt', '/foo/b.txt']
68
67
69 >>> full_path('/foo',['a.txt','b.txt'])
68 >>> full_path('/foo',['a.txt','b.txt'])
70 ['/a.txt', '/b.txt']
69 ['/a.txt', '/b.txt']
71
70
72 If a single file is given, the output is still a list::
71 If a single file is given, the output is still a list::
73
72
74 >>> full_path('/foo','a.txt')
73 >>> full_path('/foo','a.txt')
75 ['/a.txt']
74 ['/a.txt']
76 """
75 """
77
76
78 files = list_strings(files)
77 files = list_strings(files)
79 base = os.path.split(startPath)[0]
78 base = os.path.split(startPath)[0]
80 return [ os.path.join(base,f) for f in files ]
79 return [ os.path.join(base,f) for f in files ]
81
80
82
81
83 def parse_test_output(txt):
82 def parse_test_output(txt):
84 """Parse the output of a test run and return errors, failures.
83 """Parse the output of a test run and return errors, failures.
85
84
86 Parameters
85 Parameters
87 ----------
86 ----------
88 txt : str
87 txt : str
89 Text output of a test run, assumed to contain a line of one of the
88 Text output of a test run, assumed to contain a line of one of the
90 following forms::
89 following forms::
91
90
92 'FAILED (errors=1)'
91 'FAILED (errors=1)'
93 'FAILED (failures=1)'
92 'FAILED (failures=1)'
94 'FAILED (errors=1, failures=1)'
93 'FAILED (errors=1, failures=1)'
95
94
96 Returns
95 Returns
97 -------
96 -------
98 nerr, nfail
97 nerr, nfail
99 number of errors and failures.
98 number of errors and failures.
100 """
99 """
101
100
102 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
101 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 if err_m:
102 if err_m:
104 nerr = int(err_m.group(1))
103 nerr = int(err_m.group(1))
105 nfail = 0
104 nfail = 0
106 return nerr, nfail
105 return nerr, nfail
107
106
108 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
107 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 if fail_m:
108 if fail_m:
110 nerr = 0
109 nerr = 0
111 nfail = int(fail_m.group(1))
110 nfail = int(fail_m.group(1))
112 return nerr, nfail
111 return nerr, nfail
113
112
114 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
113 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 re.MULTILINE)
114 re.MULTILINE)
116 if both_m:
115 if both_m:
117 nerr = int(both_m.group(1))
116 nerr = int(both_m.group(1))
118 nfail = int(both_m.group(2))
117 nfail = int(both_m.group(2))
119 return nerr, nfail
118 return nerr, nfail
120
119
121 # If the input didn't match any of these forms, assume no error/failures
120 # If the input didn't match any of these forms, assume no error/failures
122 return 0, 0
121 return 0, 0
123
122
124
123
125 # So nose doesn't think this is a test
124 # So nose doesn't think this is a test
126 parse_test_output.__test__ = False
125 parse_test_output.__test__ = False
127
126
128
127
129 def default_argv():
128 def default_argv():
130 """Return a valid default argv for creating testing instances of ipython"""
129 """Return a valid default argv for creating testing instances of ipython"""
131
130
132 return ['--quick', # so no config file is loaded
131 return ['--quick', # so no config file is loaded
133 # Other defaults to minimize side effects on stdout
132 # Other defaults to minimize side effects on stdout
134 '--colors=NoColor', '--no-term-title','--no-banner',
133 '--colors=NoColor', '--no-term-title','--no-banner',
135 '--autocall=0']
134 '--autocall=0']
136
135
137
136
138 def default_config():
137 def default_config():
139 """Return a config object with good defaults for testing."""
138 """Return a config object with good defaults for testing."""
140 config = Config()
139 config = Config()
141 config.TerminalInteractiveShell.colors = 'NoColor'
140 config.TerminalInteractiveShell.colors = 'NoColor'
142 config.TerminalTerminalInteractiveShell.term_title = False,
141 config.TerminalTerminalInteractiveShell.term_title = False,
143 config.TerminalInteractiveShell.autocall = 0
142 config.TerminalInteractiveShell.autocall = 0
144 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
143 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 config.HistoryManager.hist_file = f.name
144 config.HistoryManager.hist_file = f.name
146 f.close()
145 f.close()
147 config.HistoryManager.db_cache_size = 10000
146 config.HistoryManager.db_cache_size = 10000
148 return config
147 return config
149
148
150
149
151 def get_ipython_cmd(as_string=False):
150 def get_ipython_cmd(as_string=False):
152 """
151 """
153 Return appropriate IPython command line name. By default, this will return
152 Return appropriate IPython command line name. By default, this will return
154 a list that can be used with subprocess.Popen, for example, but passing
153 a list that can be used with subprocess.Popen, for example, but passing
155 `as_string=True` allows for returning the IPython command as a string.
154 `as_string=True` allows for returning the IPython command as a string.
156
155
157 Parameters
156 Parameters
158 ----------
157 ----------
159 as_string: bool
158 as_string: bool
160 Flag to allow to return the command as a string.
159 Flag to allow to return the command as a string.
161 """
160 """
162 ipython_cmd = [sys.executable, "-m", "IPython"]
161 ipython_cmd = [sys.executable, "-m", "IPython"]
163
162
164 if as_string:
163 if as_string:
165 ipython_cmd = " ".join(ipython_cmd)
164 ipython_cmd = " ".join(ipython_cmd)
166
165
167 return ipython_cmd
166 return ipython_cmd
168
167
169 def ipexec(fname, options=None, commands=()):
168 def ipexec(fname, options=None, commands=()):
170 """Utility to call 'ipython filename'.
169 """Utility to call 'ipython filename'.
171
170
172 Starts IPython with a minimal and safe configuration to make startup as fast
171 Starts IPython with a minimal and safe configuration to make startup as fast
173 as possible.
172 as possible.
174
173
175 Note that this starts IPython in a subprocess!
174 Note that this starts IPython in a subprocess!
176
175
177 Parameters
176 Parameters
178 ----------
177 ----------
179 fname : str
178 fname : str
180 Name of file to be executed (should have .py or .ipy extension).
179 Name of file to be executed (should have .py or .ipy extension).
181
180
182 options : optional, list
181 options : optional, list
183 Extra command-line flags to be passed to IPython.
182 Extra command-line flags to be passed to IPython.
184
183
185 commands : optional, list
184 commands : optional, list
186 Commands to send in on stdin
185 Commands to send in on stdin
187
186
188 Returns
187 Returns
189 -------
188 -------
190 (stdout, stderr) of ipython subprocess.
189 (stdout, stderr) of ipython subprocess.
191 """
190 """
192 if options is None: options = []
191 if options is None: options = []
193
192
194 cmdargs = default_argv() + options
193 cmdargs = default_argv() + options
195
194
196 test_dir = os.path.dirname(__file__)
195 test_dir = os.path.dirname(__file__)
197
196
198 ipython_cmd = get_ipython_cmd()
197 ipython_cmd = get_ipython_cmd()
199 # Absolute path for filename
198 # Absolute path for filename
200 full_fname = os.path.join(test_dir, fname)
199 full_fname = os.path.join(test_dir, fname)
201 full_cmd = ipython_cmd + cmdargs + [full_fname]
200 full_cmd = ipython_cmd + cmdargs + [full_fname]
202 env = os.environ.copy()
201 env = os.environ.copy()
203 # FIXME: ignore all warnings in ipexec while we have shims
202 # FIXME: ignore all warnings in ipexec while we have shims
204 # should we keep suppressing warnings here, even after removing shims?
203 # should we keep suppressing warnings here, even after removing shims?
205 env['PYTHONWARNINGS'] = 'ignore'
204 env['PYTHONWARNINGS'] = 'ignore'
206 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
205 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 for k, v in env.items():
206 for k, v in env.items():
208 # Debug a bizarre failure we've seen on Windows:
207 # Debug a bizarre failure we've seen on Windows:
209 # TypeError: environment can only contain strings
208 # TypeError: environment can only contain strings
210 if not isinstance(v, str):
209 if not isinstance(v, str):
211 print(k, v)
210 print(k, v)
212 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
211 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 out, err = p.communicate(input=py3compat.str_to_bytes('\n'.join(commands)) or None)
212 out, err = p.communicate(input=py3compat.encode('\n'.join(commands)) or None)
214 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
213 out, err = py3compat.decode(out), py3compat.decode(err)
215 # `import readline` causes 'ESC[?1034h' to be output sometimes,
214 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 # so strip that out before doing comparisons
215 # so strip that out before doing comparisons
217 if out:
216 if out:
218 out = re.sub(r'\x1b\[[^h]+h', '', out)
217 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 return out, err
218 return out, err
220
219
221
220
222 def ipexec_validate(fname, expected_out, expected_err='',
221 def ipexec_validate(fname, expected_out, expected_err='',
223 options=None, commands=()):
222 options=None, commands=()):
224 """Utility to call 'ipython filename' and validate output/error.
223 """Utility to call 'ipython filename' and validate output/error.
225
224
226 This function raises an AssertionError if the validation fails.
225 This function raises an AssertionError if the validation fails.
227
226
228 Note that this starts IPython in a subprocess!
227 Note that this starts IPython in a subprocess!
229
228
230 Parameters
229 Parameters
231 ----------
230 ----------
232 fname : str
231 fname : str
233 Name of the file to be executed (should have .py or .ipy extension).
232 Name of the file to be executed (should have .py or .ipy extension).
234
233
235 expected_out : str
234 expected_out : str
236 Expected stdout of the process.
235 Expected stdout of the process.
237
236
238 expected_err : optional, str
237 expected_err : optional, str
239 Expected stderr of the process.
238 Expected stderr of the process.
240
239
241 options : optional, list
240 options : optional, list
242 Extra command-line flags to be passed to IPython.
241 Extra command-line flags to be passed to IPython.
243
242
244 Returns
243 Returns
245 -------
244 -------
246 None
245 None
247 """
246 """
248
247
249 import nose.tools as nt
248 import nose.tools as nt
250
249
251 out, err = ipexec(fname, options, commands)
250 out, err = ipexec(fname, options, commands)
252 #print 'OUT', out # dbg
251 #print 'OUT', out # dbg
253 #print 'ERR', err # dbg
252 #print 'ERR', err # dbg
254 # If there are any errors, we must check those befor stdout, as they may be
253 # If there are any errors, we must check those befor stdout, as they may be
255 # more informative than simply having an empty stdout.
254 # more informative than simply having an empty stdout.
256 if err:
255 if err:
257 if expected_err:
256 if expected_err:
258 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
257 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 else:
258 else:
260 raise ValueError('Running file %r produced error: %r' %
259 raise ValueError('Running file %r produced error: %r' %
261 (fname, err))
260 (fname, err))
262 # If no errors or output on stderr was expected, match stdout
261 # If no errors or output on stderr was expected, match stdout
263 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
262 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264
263
265
264
266 class TempFileMixin(object):
265 class TempFileMixin(object):
267 """Utility class to create temporary Python/IPython files.
266 """Utility class to create temporary Python/IPython files.
268
267
269 Meant as a mixin class for test cases."""
268 Meant as a mixin class for test cases."""
270
269
271 def mktmp(self, src, ext='.py'):
270 def mktmp(self, src, ext='.py'):
272 """Make a valid python temp file."""
271 """Make a valid python temp file."""
273 fname, f = temp_pyfile(src, ext)
272 fname, f = temp_pyfile(src, ext)
274 self.tmpfile = f
273 self.tmpfile = f
275 self.fname = fname
274 self.fname = fname
276
275
277 def tearDown(self):
276 def tearDown(self):
278 if hasattr(self, 'tmpfile'):
277 if hasattr(self, 'tmpfile'):
279 # If the tmpfile wasn't made because of skipped tests, like in
278 # If the tmpfile wasn't made because of skipped tests, like in
280 # win32, there's nothing to cleanup.
279 # win32, there's nothing to cleanup.
281 self.tmpfile.close()
280 self.tmpfile.close()
282 try:
281 try:
283 os.unlink(self.fname)
282 os.unlink(self.fname)
284 except:
283 except:
285 # On Windows, even though we close the file, we still can't
284 # On Windows, even though we close the file, we still can't
286 # delete it. I have no clue why
285 # delete it. I have no clue why
287 pass
286 pass
288
287
289 def __enter__(self):
288 def __enter__(self):
290 return self
289 return self
291
290
292 def __exit__(self, exc_type, exc_value, traceback):
291 def __exit__(self, exc_type, exc_value, traceback):
293 self.tearDown()
292 self.tearDown()
294
293
295
294
296 pair_fail_msg = ("Testing {0}\n\n"
295 pair_fail_msg = ("Testing {0}\n\n"
297 "In:\n"
296 "In:\n"
298 " {1!r}\n"
297 " {1!r}\n"
299 "Expected:\n"
298 "Expected:\n"
300 " {2!r}\n"
299 " {2!r}\n"
301 "Got:\n"
300 "Got:\n"
302 " {3!r}\n")
301 " {3!r}\n")
303 def check_pairs(func, pairs):
302 def check_pairs(func, pairs):
304 """Utility function for the common case of checking a function with a
303 """Utility function for the common case of checking a function with a
305 sequence of input/output pairs.
304 sequence of input/output pairs.
306
305
307 Parameters
306 Parameters
308 ----------
307 ----------
309 func : callable
308 func : callable
310 The function to be tested. Should accept a single argument.
309 The function to be tested. Should accept a single argument.
311 pairs : iterable
310 pairs : iterable
312 A list of (input, expected_output) tuples.
311 A list of (input, expected_output) tuples.
313
312
314 Returns
313 Returns
315 -------
314 -------
316 None. Raises an AssertionError if any output does not match the expected
315 None. Raises an AssertionError if any output does not match the expected
317 value.
316 value.
318 """
317 """
319 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
318 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
320 for inp, expected in pairs:
319 for inp, expected in pairs:
321 out = func(inp)
320 out = func(inp)
322 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
321 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
323
322
324
323
325 MyStringIO = StringIO
324 MyStringIO = StringIO
326
325
327 _re_type = type(re.compile(r''))
326 _re_type = type(re.compile(r''))
328
327
329 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
328 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
330 -------
329 -------
331 {2!s}
330 {2!s}
332 -------
331 -------
333 """
332 """
334
333
335 class AssertPrints(object):
334 class AssertPrints(object):
336 """Context manager for testing that code prints certain text.
335 """Context manager for testing that code prints certain text.
337
336
338 Examples
337 Examples
339 --------
338 --------
340 >>> with AssertPrints("abc", suppress=False):
339 >>> with AssertPrints("abc", suppress=False):
341 ... print("abcd")
340 ... print("abcd")
342 ... print("def")
341 ... print("def")
343 ...
342 ...
344 abcd
343 abcd
345 def
344 def
346 """
345 """
347 def __init__(self, s, channel='stdout', suppress=True):
346 def __init__(self, s, channel='stdout', suppress=True):
348 self.s = s
347 self.s = s
349 if isinstance(self.s, (str, _re_type)):
348 if isinstance(self.s, (str, _re_type)):
350 self.s = [self.s]
349 self.s = [self.s]
351 self.channel = channel
350 self.channel = channel
352 self.suppress = suppress
351 self.suppress = suppress
353
352
354 def __enter__(self):
353 def __enter__(self):
355 self.orig_stream = getattr(sys, self.channel)
354 self.orig_stream = getattr(sys, self.channel)
356 self.buffer = MyStringIO()
355 self.buffer = MyStringIO()
357 self.tee = Tee(self.buffer, channel=self.channel)
356 self.tee = Tee(self.buffer, channel=self.channel)
358 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
357 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
359
358
360 def __exit__(self, etype, value, traceback):
359 def __exit__(self, etype, value, traceback):
361 try:
360 try:
362 if value is not None:
361 if value is not None:
363 # If an error was raised, don't check anything else
362 # If an error was raised, don't check anything else
364 return False
363 return False
365 self.tee.flush()
364 self.tee.flush()
366 setattr(sys, self.channel, self.orig_stream)
365 setattr(sys, self.channel, self.orig_stream)
367 printed = self.buffer.getvalue()
366 printed = self.buffer.getvalue()
368 for s in self.s:
367 for s in self.s:
369 if isinstance(s, _re_type):
368 if isinstance(s, _re_type):
370 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
369 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
371 else:
370 else:
372 assert s in printed, notprinted_msg.format(s, self.channel, printed)
371 assert s in printed, notprinted_msg.format(s, self.channel, printed)
373 return False
372 return False
374 finally:
373 finally:
375 self.tee.close()
374 self.tee.close()
376
375
377 printed_msg = """Found {0!r} in printed output (on {1}):
376 printed_msg = """Found {0!r} in printed output (on {1}):
378 -------
377 -------
379 {2!s}
378 {2!s}
380 -------
379 -------
381 """
380 """
382
381
383 class AssertNotPrints(AssertPrints):
382 class AssertNotPrints(AssertPrints):
384 """Context manager for checking that certain output *isn't* produced.
383 """Context manager for checking that certain output *isn't* produced.
385
384
386 Counterpart of AssertPrints"""
385 Counterpart of AssertPrints"""
387 def __exit__(self, etype, value, traceback):
386 def __exit__(self, etype, value, traceback):
388 try:
387 try:
389 if value is not None:
388 if value is not None:
390 # If an error was raised, don't check anything else
389 # If an error was raised, don't check anything else
391 self.tee.close()
390 self.tee.close()
392 return False
391 return False
393 self.tee.flush()
392 self.tee.flush()
394 setattr(sys, self.channel, self.orig_stream)
393 setattr(sys, self.channel, self.orig_stream)
395 printed = self.buffer.getvalue()
394 printed = self.buffer.getvalue()
396 for s in self.s:
395 for s in self.s:
397 if isinstance(s, _re_type):
396 if isinstance(s, _re_type):
398 assert not s.search(printed),printed_msg.format(
397 assert not s.search(printed),printed_msg.format(
399 s.pattern, self.channel, printed)
398 s.pattern, self.channel, printed)
400 else:
399 else:
401 assert s not in printed, printed_msg.format(
400 assert s not in printed, printed_msg.format(
402 s, self.channel, printed)
401 s, self.channel, printed)
403 return False
402 return False
404 finally:
403 finally:
405 self.tee.close()
404 self.tee.close()
406
405
407 @contextmanager
406 @contextmanager
408 def mute_warn():
407 def mute_warn():
409 from IPython.utils import warn
408 from IPython.utils import warn
410 save_warn = warn.warn
409 save_warn = warn.warn
411 warn.warn = lambda *a, **kw: None
410 warn.warn = lambda *a, **kw: None
412 try:
411 try:
413 yield
412 yield
414 finally:
413 finally:
415 warn.warn = save_warn
414 warn.warn = save_warn
416
415
417 @contextmanager
416 @contextmanager
418 def make_tempfile(name):
417 def make_tempfile(name):
419 """ Create an empty, named, temporary file for the duration of the context.
418 """ Create an empty, named, temporary file for the duration of the context.
420 """
419 """
421 f = open(name, 'w')
420 f = open(name, 'w')
422 f.close()
421 f.close()
423 try:
422 try:
424 yield
423 yield
425 finally:
424 finally:
426 os.unlink(name)
425 os.unlink(name)
427
426
428 def fake_input(inputs):
427 def fake_input(inputs):
429 """Temporarily replace the input() function to return the given values
428 """Temporarily replace the input() function to return the given values
430
429
431 Use as a context manager:
430 Use as a context manager:
432
431
433 with fake_input(['result1', 'result2']):
432 with fake_input(['result1', 'result2']):
434 ...
433 ...
435
434
436 Values are returned in order. If input() is called again after the last value
435 Values are returned in order. If input() is called again after the last value
437 was used, EOFError is raised.
436 was used, EOFError is raised.
438 """
437 """
439 it = iter(inputs)
438 it = iter(inputs)
440 def mock_input(prompt=''):
439 def mock_input(prompt=''):
441 try:
440 try:
442 return next(it)
441 return next(it)
443 except StopIteration:
442 except StopIteration:
444 raise EOFError('No more inputs given')
443 raise EOFError('No more inputs given')
445
444
446 return patch('builtins.input', mock_input)
445 return patch('builtins.input', mock_input)
447
446
448 def help_output_test(subcommand=''):
447 def help_output_test(subcommand=''):
449 """test that `ipython [subcommand] -h` works"""
448 """test that `ipython [subcommand] -h` works"""
450 cmd = get_ipython_cmd() + [subcommand, '-h']
449 cmd = get_ipython_cmd() + [subcommand, '-h']
451 out, err, rc = get_output_error_code(cmd)
450 out, err, rc = get_output_error_code(cmd)
452 nt.assert_equal(rc, 0, err)
451 nt.assert_equal(rc, 0, err)
453 nt.assert_not_in("Traceback", err)
452 nt.assert_not_in("Traceback", err)
454 nt.assert_in("Options", out)
453 nt.assert_in("Options", out)
455 nt.assert_in("--help-all", out)
454 nt.assert_in("--help-all", out)
456 return out, err
455 return out, err
457
456
458
457
459 def help_all_output_test(subcommand=''):
458 def help_all_output_test(subcommand=''):
460 """test that `ipython [subcommand] --help-all` works"""
459 """test that `ipython [subcommand] --help-all` works"""
461 cmd = get_ipython_cmd() + [subcommand, '--help-all']
460 cmd = get_ipython_cmd() + [subcommand, '--help-all']
462 out, err, rc = get_output_error_code(cmd)
461 out, err, rc = get_output_error_code(cmd)
463 nt.assert_equal(rc, 0, err)
462 nt.assert_equal(rc, 0, err)
464 nt.assert_not_in("Traceback", err)
463 nt.assert_not_in("Traceback", err)
465 nt.assert_in("Options", out)
464 nt.assert_in("Options", out)
466 nt.assert_in("Class", out)
465 nt.assert_in("Class", out)
467 return out, err
466 return out, err
468
467
@@ -1,78 +1,78
1 """cli-specific implementation of process utilities.
1 """cli-specific implementation of process utilities.
2
2
3 cli - Common Language Infrastructure for IronPython. Code
3 cli - Common Language Infrastructure for IronPython. Code
4 can run on any operating system. Check os.name for os-
4 can run on any operating system. Check os.name for os-
5 specific settings.
5 specific settings.
6
6
7 This file is only meant to be imported by process.py, not by end-users.
7 This file is only meant to be imported by process.py, not by end-users.
8
8
9 This file is largely untested. To become a full drop-in process
9 This file is largely untested. To become a full drop-in process
10 interface for IronPython will probably require you to help fill
10 interface for IronPython will probably require you to help fill
11 in the details.
11 in the details.
12 """
12 """
13
13
14 # Import cli libraries:
14 # Import cli libraries:
15 import clr
15 import clr
16 import System
16 import System
17
17
18 # Import Python libraries:
18 # Import Python libraries:
19 import os
19 import os
20
20
21 # Import IPython libraries:
21 # Import IPython libraries:
22 from IPython.utils import py3compat
22 from IPython.utils import py3compat
23 from ._process_common import arg_split
23 from ._process_common import arg_split
24
24
25 def _find_cmd(cmd):
25 def _find_cmd(cmd):
26 """Find the full path to a command using which."""
26 """Find the full path to a command using which."""
27 paths = System.Environment.GetEnvironmentVariable("PATH").Split(os.pathsep)
27 paths = System.Environment.GetEnvironmentVariable("PATH").Split(os.pathsep)
28 for path in paths:
28 for path in paths:
29 filename = os.path.join(path, cmd)
29 filename = os.path.join(path, cmd)
30 if System.IO.File.Exists(filename):
30 if System.IO.File.Exists(filename):
31 return py3compat.bytes_to_str(filename)
31 return py3compat.decode(filename)
32 raise OSError("command %r not found" % cmd)
32 raise OSError("command %r not found" % cmd)
33
33
34 def system(cmd):
34 def system(cmd):
35 """
35 """
36 system(cmd) should work in a cli environment on Mac OSX, Linux,
36 system(cmd) should work in a cli environment on Mac OSX, Linux,
37 and Windows
37 and Windows
38 """
38 """
39 psi = System.Diagnostics.ProcessStartInfo(cmd)
39 psi = System.Diagnostics.ProcessStartInfo(cmd)
40 psi.RedirectStandardOutput = True
40 psi.RedirectStandardOutput = True
41 psi.RedirectStandardError = True
41 psi.RedirectStandardError = True
42 psi.WindowStyle = System.Diagnostics.ProcessWindowStyle.Normal
42 psi.WindowStyle = System.Diagnostics.ProcessWindowStyle.Normal
43 psi.UseShellExecute = False
43 psi.UseShellExecute = False
44 # Start up process:
44 # Start up process:
45 reg = System.Diagnostics.Process.Start(psi)
45 reg = System.Diagnostics.Process.Start(psi)
46
46
47 def getoutput(cmd):
47 def getoutput(cmd):
48 """
48 """
49 getoutput(cmd) should work in a cli environment on Mac OSX, Linux,
49 getoutput(cmd) should work in a cli environment on Mac OSX, Linux,
50 and Windows
50 and Windows
51 """
51 """
52 psi = System.Diagnostics.ProcessStartInfo(cmd)
52 psi = System.Diagnostics.ProcessStartInfo(cmd)
53 psi.RedirectStandardOutput = True
53 psi.RedirectStandardOutput = True
54 psi.RedirectStandardError = True
54 psi.RedirectStandardError = True
55 psi.WindowStyle = System.Diagnostics.ProcessWindowStyle.Normal
55 psi.WindowStyle = System.Diagnostics.ProcessWindowStyle.Normal
56 psi.UseShellExecute = False
56 psi.UseShellExecute = False
57 # Start up process:
57 # Start up process:
58 reg = System.Diagnostics.Process.Start(psi)
58 reg = System.Diagnostics.Process.Start(psi)
59 myOutput = reg.StandardOutput
59 myOutput = reg.StandardOutput
60 output = myOutput.ReadToEnd()
60 output = myOutput.ReadToEnd()
61 myError = reg.StandardError
61 myError = reg.StandardError
62 error = myError.ReadToEnd()
62 error = myError.ReadToEnd()
63 return output
63 return output
64
64
65 def check_pid(pid):
65 def check_pid(pid):
66 """
66 """
67 Check if a process with the given PID (pid) exists
67 Check if a process with the given PID (pid) exists
68 """
68 """
69 try:
69 try:
70 System.Diagnostics.Process.GetProcessById(pid)
70 System.Diagnostics.Process.GetProcessById(pid)
71 # process with given pid is running
71 # process with given pid is running
72 return True
72 return True
73 except System.InvalidOperationException:
73 except System.InvalidOperationException:
74 # process wasn't started by this object (but is running)
74 # process wasn't started by this object (but is running)
75 return True
75 return True
76 except System.ArgumentException:
76 except System.ArgumentException:
77 # process with given pid isn't running
77 # process with given pid isn't running
78 return False
78 return False
@@ -1,212 +1,212
1 """Common utilities for the various process_* implementations.
1 """Common utilities for the various process_* implementations.
2
2
3 This file is only meant to be imported by the platform-specific implementations
3 This file is only meant to be imported by the platform-specific implementations
4 of subprocess utilities, and it contains tools that are common to all of them.
4 of subprocess utilities, and it contains tools that are common to all of them.
5 """
5 """
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 import subprocess
17 import subprocess
18 import shlex
18 import shlex
19 import sys
19 import sys
20 import os
20 import os
21
21
22 from IPython.utils import py3compat
22 from IPython.utils import py3compat
23
23
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25 # Function definitions
25 # Function definitions
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27
27
28 def read_no_interrupt(p):
28 def read_no_interrupt(p):
29 """Read from a pipe ignoring EINTR errors.
29 """Read from a pipe ignoring EINTR errors.
30
30
31 This is necessary because when reading from pipes with GUI event loops
31 This is necessary because when reading from pipes with GUI event loops
32 running in the background, often interrupts are raised that stop the
32 running in the background, often interrupts are raised that stop the
33 command from completing."""
33 command from completing."""
34 import errno
34 import errno
35
35
36 try:
36 try:
37 return p.read()
37 return p.read()
38 except IOError as err:
38 except IOError as err:
39 if err.errno != errno.EINTR:
39 if err.errno != errno.EINTR:
40 raise
40 raise
41
41
42
42
43 def process_handler(cmd, callback, stderr=subprocess.PIPE):
43 def process_handler(cmd, callback, stderr=subprocess.PIPE):
44 """Open a command in a shell subprocess and execute a callback.
44 """Open a command in a shell subprocess and execute a callback.
45
45
46 This function provides common scaffolding for creating subprocess.Popen()
46 This function provides common scaffolding for creating subprocess.Popen()
47 calls. It creates a Popen object and then calls the callback with it.
47 calls. It creates a Popen object and then calls the callback with it.
48
48
49 Parameters
49 Parameters
50 ----------
50 ----------
51 cmd : str or list
51 cmd : str or list
52 A command to be executed by the system, using :class:`subprocess.Popen`.
52 A command to be executed by the system, using :class:`subprocess.Popen`.
53 If a string is passed, it will be run in the system shell. If a list is
53 If a string is passed, it will be run in the system shell. If a list is
54 passed, it will be used directly as arguments.
54 passed, it will be used directly as arguments.
55
55
56 callback : callable
56 callback : callable
57 A one-argument function that will be called with the Popen object.
57 A one-argument function that will be called with the Popen object.
58
58
59 stderr : file descriptor number, optional
59 stderr : file descriptor number, optional
60 By default this is set to ``subprocess.PIPE``, but you can also pass the
60 By default this is set to ``subprocess.PIPE``, but you can also pass the
61 value ``subprocess.STDOUT`` to force the subprocess' stderr to go into
61 value ``subprocess.STDOUT`` to force the subprocess' stderr to go into
62 the same file descriptor as its stdout. This is useful to read stdout
62 the same file descriptor as its stdout. This is useful to read stdout
63 and stderr combined in the order they are generated.
63 and stderr combined in the order they are generated.
64
64
65 Returns
65 Returns
66 -------
66 -------
67 The return value of the provided callback is returned.
67 The return value of the provided callback is returned.
68 """
68 """
69 sys.stdout.flush()
69 sys.stdout.flush()
70 sys.stderr.flush()
70 sys.stderr.flush()
71 # On win32, close_fds can't be true when using pipes for stdin/out/err
71 # On win32, close_fds can't be true when using pipes for stdin/out/err
72 close_fds = sys.platform != 'win32'
72 close_fds = sys.platform != 'win32'
73 # Determine if cmd should be run with system shell.
73 # Determine if cmd should be run with system shell.
74 shell = isinstance(cmd, str)
74 shell = isinstance(cmd, str)
75 # On POSIX systems run shell commands with user-preferred shell.
75 # On POSIX systems run shell commands with user-preferred shell.
76 executable = None
76 executable = None
77 if shell and os.name == 'posix' and 'SHELL' in os.environ:
77 if shell and os.name == 'posix' and 'SHELL' in os.environ:
78 executable = os.environ['SHELL']
78 executable = os.environ['SHELL']
79 p = subprocess.Popen(cmd, shell=shell,
79 p = subprocess.Popen(cmd, shell=shell,
80 executable=executable,
80 executable=executable,
81 stdin=subprocess.PIPE,
81 stdin=subprocess.PIPE,
82 stdout=subprocess.PIPE,
82 stdout=subprocess.PIPE,
83 stderr=stderr,
83 stderr=stderr,
84 close_fds=close_fds)
84 close_fds=close_fds)
85
85
86 try:
86 try:
87 out = callback(p)
87 out = callback(p)
88 except KeyboardInterrupt:
88 except KeyboardInterrupt:
89 print('^C')
89 print('^C')
90 sys.stdout.flush()
90 sys.stdout.flush()
91 sys.stderr.flush()
91 sys.stderr.flush()
92 out = None
92 out = None
93 finally:
93 finally:
94 # Make really sure that we don't leave processes behind, in case the
94 # Make really sure that we don't leave processes behind, in case the
95 # call above raises an exception
95 # call above raises an exception
96 # We start by assuming the subprocess finished (to avoid NameErrors
96 # We start by assuming the subprocess finished (to avoid NameErrors
97 # later depending on the path taken)
97 # later depending on the path taken)
98 if p.returncode is None:
98 if p.returncode is None:
99 try:
99 try:
100 p.terminate()
100 p.terminate()
101 p.poll()
101 p.poll()
102 except OSError:
102 except OSError:
103 pass
103 pass
104 # One last try on our way out
104 # One last try on our way out
105 if p.returncode is None:
105 if p.returncode is None:
106 try:
106 try:
107 p.kill()
107 p.kill()
108 except OSError:
108 except OSError:
109 pass
109 pass
110
110
111 return out
111 return out
112
112
113
113
114 def getoutput(cmd):
114 def getoutput(cmd):
115 """Run a command and return its stdout/stderr as a string.
115 """Run a command and return its stdout/stderr as a string.
116
116
117 Parameters
117 Parameters
118 ----------
118 ----------
119 cmd : str or list
119 cmd : str or list
120 A command to be executed in the system shell.
120 A command to be executed in the system shell.
121
121
122 Returns
122 Returns
123 -------
123 -------
124 output : str
124 output : str
125 A string containing the combination of stdout and stderr from the
125 A string containing the combination of stdout and stderr from the
126 subprocess, in whatever order the subprocess originally wrote to its
126 subprocess, in whatever order the subprocess originally wrote to its
127 file descriptors (so the order of the information in this string is the
127 file descriptors (so the order of the information in this string is the
128 correct order as would be seen if running the command in a terminal).
128 correct order as would be seen if running the command in a terminal).
129 """
129 """
130 out = process_handler(cmd, lambda p: p.communicate()[0], subprocess.STDOUT)
130 out = process_handler(cmd, lambda p: p.communicate()[0], subprocess.STDOUT)
131 if out is None:
131 if out is None:
132 return ''
132 return ''
133 return py3compat.bytes_to_str(out)
133 return py3compat.decode(out)
134
134
135
135
136 def getoutputerror(cmd):
136 def getoutputerror(cmd):
137 """Return (standard output, standard error) of executing cmd in a shell.
137 """Return (standard output, standard error) of executing cmd in a shell.
138
138
139 Accepts the same arguments as os.system().
139 Accepts the same arguments as os.system().
140
140
141 Parameters
141 Parameters
142 ----------
142 ----------
143 cmd : str or list
143 cmd : str or list
144 A command to be executed in the system shell.
144 A command to be executed in the system shell.
145
145
146 Returns
146 Returns
147 -------
147 -------
148 stdout : str
148 stdout : str
149 stderr : str
149 stderr : str
150 """
150 """
151 return get_output_error_code(cmd)[:2]
151 return get_output_error_code(cmd)[:2]
152
152
153 def get_output_error_code(cmd):
153 def get_output_error_code(cmd):
154 """Return (standard output, standard error, return code) of executing cmd
154 """Return (standard output, standard error, return code) of executing cmd
155 in a shell.
155 in a shell.
156
156
157 Accepts the same arguments as os.system().
157 Accepts the same arguments as os.system().
158
158
159 Parameters
159 Parameters
160 ----------
160 ----------
161 cmd : str or list
161 cmd : str or list
162 A command to be executed in the system shell.
162 A command to be executed in the system shell.
163
163
164 Returns
164 Returns
165 -------
165 -------
166 stdout : str
166 stdout : str
167 stderr : str
167 stderr : str
168 returncode: int
168 returncode: int
169 """
169 """
170
170
171 out_err, p = process_handler(cmd, lambda p: (p.communicate(), p))
171 out_err, p = process_handler(cmd, lambda p: (p.communicate(), p))
172 if out_err is None:
172 if out_err is None:
173 return '', '', p.returncode
173 return '', '', p.returncode
174 out, err = out_err
174 out, err = out_err
175 return py3compat.bytes_to_str(out), py3compat.bytes_to_str(err), p.returncode
175 return py3compat.decode(out), py3compat.decode(err), p.returncode
176
176
177 def arg_split(s, posix=False, strict=True):
177 def arg_split(s, posix=False, strict=True):
178 """Split a command line's arguments in a shell-like manner.
178 """Split a command line's arguments in a shell-like manner.
179
179
180 This is a modified version of the standard library's shlex.split()
180 This is a modified version of the standard library's shlex.split()
181 function, but with a default of posix=False for splitting, so that quotes
181 function, but with a default of posix=False for splitting, so that quotes
182 in inputs are respected.
182 in inputs are respected.
183
183
184 if strict=False, then any errors shlex.split would raise will result in the
184 if strict=False, then any errors shlex.split would raise will result in the
185 unparsed remainder being the last element of the list, rather than raising.
185 unparsed remainder being the last element of the list, rather than raising.
186 This is because we sometimes use arg_split to parse things other than
186 This is because we sometimes use arg_split to parse things other than
187 command-line args.
187 command-line args.
188 """
188 """
189
189
190 lex = shlex.shlex(s, posix=posix)
190 lex = shlex.shlex(s, posix=posix)
191 lex.whitespace_split = True
191 lex.whitespace_split = True
192 # Extract tokens, ensuring that things like leaving open quotes
192 # Extract tokens, ensuring that things like leaving open quotes
193 # does not cause this to raise. This is important, because we
193 # does not cause this to raise. This is important, because we
194 # sometimes pass Python source through this (e.g. %timeit f(" ")),
194 # sometimes pass Python source through this (e.g. %timeit f(" ")),
195 # and it shouldn't raise an exception.
195 # and it shouldn't raise an exception.
196 # It may be a bad idea to parse things that are not command-line args
196 # It may be a bad idea to parse things that are not command-line args
197 # through this function, but we do, so let's be safe about it.
197 # through this function, but we do, so let's be safe about it.
198 lex.commenters='' #fix for GH-1269
198 lex.commenters='' #fix for GH-1269
199 tokens = []
199 tokens = []
200 while True:
200 while True:
201 try:
201 try:
202 tokens.append(next(lex))
202 tokens.append(next(lex))
203 except StopIteration:
203 except StopIteration:
204 break
204 break
205 except ValueError:
205 except ValueError:
206 if strict:
206 if strict:
207 raise
207 raise
208 # couldn't parse, get remaining blob as last token
208 # couldn't parse, get remaining blob as last token
209 tokens.append(lex.token)
209 tokens.append(lex.token)
210 break
210 break
211
211
212 return tokens
212 return tokens
@@ -1,224 +1,224
1 """Posix-specific implementation of process utilities.
1 """Posix-specific implementation of process utilities.
2
2
3 This file is only meant to be imported by process.py, not by end-users.
3 This file is only meant to be imported by process.py, not by end-users.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 # Stdlib
17 # Stdlib
18 import errno
18 import errno
19 import os
19 import os
20 import subprocess as sp
20 import subprocess as sp
21 import sys
21 import sys
22
22
23 import pexpect
23 import pexpect
24
24
25 # Our own
25 # Our own
26 from ._process_common import getoutput, arg_split
26 from ._process_common import getoutput, arg_split
27 from IPython.utils import py3compat
27 from IPython.utils import py3compat
28 from IPython.utils.encoding import DEFAULT_ENCODING
28 from IPython.utils.encoding import DEFAULT_ENCODING
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Function definitions
31 # Function definitions
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 def _find_cmd(cmd):
34 def _find_cmd(cmd):
35 """Find the full path to a command using which."""
35 """Find the full path to a command using which."""
36
36
37 path = sp.Popen(['/usr/bin/env', 'which', cmd],
37 path = sp.Popen(['/usr/bin/env', 'which', cmd],
38 stdout=sp.PIPE, stderr=sp.PIPE).communicate()[0]
38 stdout=sp.PIPE, stderr=sp.PIPE).communicate()[0]
39 return py3compat.bytes_to_str(path)
39 return py3compat.decode(path)
40
40
41
41
42 class ProcessHandler(object):
42 class ProcessHandler(object):
43 """Execute subprocesses under the control of pexpect.
43 """Execute subprocesses under the control of pexpect.
44 """
44 """
45 # Timeout in seconds to wait on each reading of the subprocess' output.
45 # Timeout in seconds to wait on each reading of the subprocess' output.
46 # This should not be set too low to avoid cpu overusage from our side,
46 # This should not be set too low to avoid cpu overusage from our side,
47 # since we read in a loop whose period is controlled by this timeout.
47 # since we read in a loop whose period is controlled by this timeout.
48 read_timeout = 0.05
48 read_timeout = 0.05
49
49
50 # Timeout to give a process if we receive SIGINT, between sending the
50 # Timeout to give a process if we receive SIGINT, between sending the
51 # SIGINT to the process and forcefully terminating it.
51 # SIGINT to the process and forcefully terminating it.
52 terminate_timeout = 0.2
52 terminate_timeout = 0.2
53
53
54 # File object where stdout and stderr of the subprocess will be written
54 # File object where stdout and stderr of the subprocess will be written
55 logfile = None
55 logfile = None
56
56
57 # Shell to call for subprocesses to execute
57 # Shell to call for subprocesses to execute
58 _sh = None
58 _sh = None
59
59
60 @property
60 @property
61 def sh(self):
61 def sh(self):
62 if self._sh is None:
62 if self._sh is None:
63 self._sh = pexpect.which('sh')
63 self._sh = pexpect.which('sh')
64 if self._sh is None:
64 if self._sh is None:
65 raise OSError('"sh" shell not found')
65 raise OSError('"sh" shell not found')
66
66
67 return self._sh
67 return self._sh
68
68
69 def __init__(self, logfile=None, read_timeout=None, terminate_timeout=None):
69 def __init__(self, logfile=None, read_timeout=None, terminate_timeout=None):
70 """Arguments are used for pexpect calls."""
70 """Arguments are used for pexpect calls."""
71 self.read_timeout = (ProcessHandler.read_timeout if read_timeout is
71 self.read_timeout = (ProcessHandler.read_timeout if read_timeout is
72 None else read_timeout)
72 None else read_timeout)
73 self.terminate_timeout = (ProcessHandler.terminate_timeout if
73 self.terminate_timeout = (ProcessHandler.terminate_timeout if
74 terminate_timeout is None else
74 terminate_timeout is None else
75 terminate_timeout)
75 terminate_timeout)
76 self.logfile = sys.stdout if logfile is None else logfile
76 self.logfile = sys.stdout if logfile is None else logfile
77
77
78 def getoutput(self, cmd):
78 def getoutput(self, cmd):
79 """Run a command and return its stdout/stderr as a string.
79 """Run a command and return its stdout/stderr as a string.
80
80
81 Parameters
81 Parameters
82 ----------
82 ----------
83 cmd : str
83 cmd : str
84 A command to be executed in the system shell.
84 A command to be executed in the system shell.
85
85
86 Returns
86 Returns
87 -------
87 -------
88 output : str
88 output : str
89 A string containing the combination of stdout and stderr from the
89 A string containing the combination of stdout and stderr from the
90 subprocess, in whatever order the subprocess originally wrote to its
90 subprocess, in whatever order the subprocess originally wrote to its
91 file descriptors (so the order of the information in this string is the
91 file descriptors (so the order of the information in this string is the
92 correct order as would be seen if running the command in a terminal).
92 correct order as would be seen if running the command in a terminal).
93 """
93 """
94 try:
94 try:
95 return pexpect.run(self.sh, args=['-c', cmd]).replace('\r\n', '\n')
95 return pexpect.run(self.sh, args=['-c', cmd]).replace('\r\n', '\n')
96 except KeyboardInterrupt:
96 except KeyboardInterrupt:
97 print('^C', file=sys.stderr, end='')
97 print('^C', file=sys.stderr, end='')
98
98
99 def getoutput_pexpect(self, cmd):
99 def getoutput_pexpect(self, cmd):
100 """Run a command and return its stdout/stderr as a string.
100 """Run a command and return its stdout/stderr as a string.
101
101
102 Parameters
102 Parameters
103 ----------
103 ----------
104 cmd : str
104 cmd : str
105 A command to be executed in the system shell.
105 A command to be executed in the system shell.
106
106
107 Returns
107 Returns
108 -------
108 -------
109 output : str
109 output : str
110 A string containing the combination of stdout and stderr from the
110 A string containing the combination of stdout and stderr from the
111 subprocess, in whatever order the subprocess originally wrote to its
111 subprocess, in whatever order the subprocess originally wrote to its
112 file descriptors (so the order of the information in this string is the
112 file descriptors (so the order of the information in this string is the
113 correct order as would be seen if running the command in a terminal).
113 correct order as would be seen if running the command in a terminal).
114 """
114 """
115 try:
115 try:
116 return pexpect.run(self.sh, args=['-c', cmd]).replace('\r\n', '\n')
116 return pexpect.run(self.sh, args=['-c', cmd]).replace('\r\n', '\n')
117 except KeyboardInterrupt:
117 except KeyboardInterrupt:
118 print('^C', file=sys.stderr, end='')
118 print('^C', file=sys.stderr, end='')
119
119
120 def system(self, cmd):
120 def system(self, cmd):
121 """Execute a command in a subshell.
121 """Execute a command in a subshell.
122
122
123 Parameters
123 Parameters
124 ----------
124 ----------
125 cmd : str
125 cmd : str
126 A command to be executed in the system shell.
126 A command to be executed in the system shell.
127
127
128 Returns
128 Returns
129 -------
129 -------
130 int : child's exitstatus
130 int : child's exitstatus
131 """
131 """
132 # Get likely encoding for the output.
132 # Get likely encoding for the output.
133 enc = DEFAULT_ENCODING
133 enc = DEFAULT_ENCODING
134
134
135 # Patterns to match on the output, for pexpect. We read input and
135 # Patterns to match on the output, for pexpect. We read input and
136 # allow either a short timeout or EOF
136 # allow either a short timeout or EOF
137 patterns = [pexpect.TIMEOUT, pexpect.EOF]
137 patterns = [pexpect.TIMEOUT, pexpect.EOF]
138 # the index of the EOF pattern in the list.
138 # the index of the EOF pattern in the list.
139 # even though we know it's 1, this call means we don't have to worry if
139 # even though we know it's 1, this call means we don't have to worry if
140 # we change the above list, and forget to change this value:
140 # we change the above list, and forget to change this value:
141 EOF_index = patterns.index(pexpect.EOF)
141 EOF_index = patterns.index(pexpect.EOF)
142 # The size of the output stored so far in the process output buffer.
142 # The size of the output stored so far in the process output buffer.
143 # Since pexpect only appends to this buffer, each time we print we
143 # Since pexpect only appends to this buffer, each time we print we
144 # record how far we've printed, so that next time we only print *new*
144 # record how far we've printed, so that next time we only print *new*
145 # content from the buffer.
145 # content from the buffer.
146 out_size = 0
146 out_size = 0
147 try:
147 try:
148 # Since we're not really searching the buffer for text patterns, we
148 # Since we're not really searching the buffer for text patterns, we
149 # can set pexpect's search window to be tiny and it won't matter.
149 # can set pexpect's search window to be tiny and it won't matter.
150 # We only search for the 'patterns' timeout or EOF, which aren't in
150 # We only search for the 'patterns' timeout or EOF, which aren't in
151 # the text itself.
151 # the text itself.
152 #child = pexpect.spawn(pcmd, searchwindowsize=1)
152 #child = pexpect.spawn(pcmd, searchwindowsize=1)
153 if hasattr(pexpect, 'spawnb'):
153 if hasattr(pexpect, 'spawnb'):
154 child = pexpect.spawnb(self.sh, args=['-c', cmd]) # Pexpect-U
154 child = pexpect.spawnb(self.sh, args=['-c', cmd]) # Pexpect-U
155 else:
155 else:
156 child = pexpect.spawn(self.sh, args=['-c', cmd]) # Vanilla Pexpect
156 child = pexpect.spawn(self.sh, args=['-c', cmd]) # Vanilla Pexpect
157 flush = sys.stdout.flush
157 flush = sys.stdout.flush
158 while True:
158 while True:
159 # res is the index of the pattern that caused the match, so we
159 # res is the index of the pattern that caused the match, so we
160 # know whether we've finished (if we matched EOF) or not
160 # know whether we've finished (if we matched EOF) or not
161 res_idx = child.expect_list(patterns, self.read_timeout)
161 res_idx = child.expect_list(patterns, self.read_timeout)
162 print(child.before[out_size:].decode(enc, 'replace'), end='')
162 print(child.before[out_size:].decode(enc, 'replace'), end='')
163 flush()
163 flush()
164 if res_idx==EOF_index:
164 if res_idx==EOF_index:
165 break
165 break
166 # Update the pointer to what we've already printed
166 # Update the pointer to what we've already printed
167 out_size = len(child.before)
167 out_size = len(child.before)
168 except KeyboardInterrupt:
168 except KeyboardInterrupt:
169 # We need to send ^C to the process. The ascii code for '^C' is 3
169 # We need to send ^C to the process. The ascii code for '^C' is 3
170 # (the character is known as ETX for 'End of Text', see
170 # (the character is known as ETX for 'End of Text', see
171 # curses.ascii.ETX).
171 # curses.ascii.ETX).
172 child.sendline(chr(3))
172 child.sendline(chr(3))
173 # Read and print any more output the program might produce on its
173 # Read and print any more output the program might produce on its
174 # way out.
174 # way out.
175 try:
175 try:
176 out_size = len(child.before)
176 out_size = len(child.before)
177 child.expect_list(patterns, self.terminate_timeout)
177 child.expect_list(patterns, self.terminate_timeout)
178 print(child.before[out_size:].decode(enc, 'replace'), end='')
178 print(child.before[out_size:].decode(enc, 'replace'), end='')
179 sys.stdout.flush()
179 sys.stdout.flush()
180 except KeyboardInterrupt:
180 except KeyboardInterrupt:
181 # Impatient users tend to type it multiple times
181 # Impatient users tend to type it multiple times
182 pass
182 pass
183 finally:
183 finally:
184 # Ensure the subprocess really is terminated
184 # Ensure the subprocess really is terminated
185 child.terminate(force=True)
185 child.terminate(force=True)
186 # add isalive check, to ensure exitstatus is set:
186 # add isalive check, to ensure exitstatus is set:
187 child.isalive()
187 child.isalive()
188
188
189 # We follow the subprocess pattern, returning either the exit status
189 # We follow the subprocess pattern, returning either the exit status
190 # as a positive number, or the terminating signal as a negative
190 # as a positive number, or the terminating signal as a negative
191 # number.
191 # number.
192 # on Linux, sh returns 128+n for signals terminating child processes on Linux
192 # on Linux, sh returns 128+n for signals terminating child processes on Linux
193 # on BSD (OS X), the signal code is set instead
193 # on BSD (OS X), the signal code is set instead
194 if child.exitstatus is None:
194 if child.exitstatus is None:
195 # on WIFSIGNALED, pexpect sets signalstatus, leaving exitstatus=None
195 # on WIFSIGNALED, pexpect sets signalstatus, leaving exitstatus=None
196 if child.signalstatus is None:
196 if child.signalstatus is None:
197 # this condition may never occur,
197 # this condition may never occur,
198 # but let's be certain we always return an integer.
198 # but let's be certain we always return an integer.
199 return 0
199 return 0
200 return -child.signalstatus
200 return -child.signalstatus
201 if child.exitstatus > 128:
201 if child.exitstatus > 128:
202 return -(child.exitstatus - 128)
202 return -(child.exitstatus - 128)
203 return child.exitstatus
203 return child.exitstatus
204
204
205
205
206 # Make system() with a functional interface for outside use. Note that we use
206 # Make system() with a functional interface for outside use. Note that we use
207 # getoutput() from the _common utils, which is built on top of popen(). Using
207 # getoutput() from the _common utils, which is built on top of popen(). Using
208 # pexpect to get subprocess output produces difficult to parse output, since
208 # pexpect to get subprocess output produces difficult to parse output, since
209 # programs think they are talking to a tty and produce highly formatted output
209 # programs think they are talking to a tty and produce highly formatted output
210 # (ls is a good example) that makes them hard.
210 # (ls is a good example) that makes them hard.
211 system = ProcessHandler().system
211 system = ProcessHandler().system
212
212
213 def check_pid(pid):
213 def check_pid(pid):
214 try:
214 try:
215 os.kill(pid, 0)
215 os.kill(pid, 0)
216 except OSError as err:
216 except OSError as err:
217 if err.errno == errno.ESRCH:
217 if err.errno == errno.ESRCH:
218 return False
218 return False
219 elif err.errno == errno.EPERM:
219 elif err.errno == errno.EPERM:
220 # Don't have permission to signal the process - probably means it exists
220 # Don't have permission to signal the process - probably means it exists
221 return True
221 return True
222 raise
222 raise
223 else:
223 else:
224 return True
224 return True
@@ -1,191 +1,191
1 """Windows-specific implementation of process utilities.
1 """Windows-specific implementation of process utilities.
2
2
3 This file is only meant to be imported by process.py, not by end-users.
3 This file is only meant to be imported by process.py, not by end-users.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 # stdlib
17 # stdlib
18 import os
18 import os
19 import sys
19 import sys
20 import ctypes
20 import ctypes
21
21
22 from ctypes import c_int, POINTER
22 from ctypes import c_int, POINTER
23 from ctypes.wintypes import LPCWSTR, HLOCAL
23 from ctypes.wintypes import LPCWSTR, HLOCAL
24 from subprocess import STDOUT
24 from subprocess import STDOUT
25
25
26 # our own imports
26 # our own imports
27 from ._process_common import read_no_interrupt, process_handler, arg_split as py_arg_split
27 from ._process_common import read_no_interrupt, process_handler, arg_split as py_arg_split
28 from . import py3compat
28 from . import py3compat
29 from .encoding import DEFAULT_ENCODING
29 from .encoding import DEFAULT_ENCODING
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Function definitions
32 # Function definitions
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 class AvoidUNCPath(object):
35 class AvoidUNCPath(object):
36 """A context manager to protect command execution from UNC paths.
36 """A context manager to protect command execution from UNC paths.
37
37
38 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
38 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
39 This context manager temporarily changes directory to the 'C:' drive on
39 This context manager temporarily changes directory to the 'C:' drive on
40 entering, and restores the original working directory on exit.
40 entering, and restores the original working directory on exit.
41
41
42 The context manager returns the starting working directory *if* it made a
42 The context manager returns the starting working directory *if* it made a
43 change and None otherwise, so that users can apply the necessary adjustment
43 change and None otherwise, so that users can apply the necessary adjustment
44 to their system calls in the event of a change.
44 to their system calls in the event of a change.
45
45
46 Examples
46 Examples
47 --------
47 --------
48 ::
48 ::
49 cmd = 'dir'
49 cmd = 'dir'
50 with AvoidUNCPath() as path:
50 with AvoidUNCPath() as path:
51 if path is not None:
51 if path is not None:
52 cmd = '"pushd %s &&"%s' % (path, cmd)
52 cmd = '"pushd %s &&"%s' % (path, cmd)
53 os.system(cmd)
53 os.system(cmd)
54 """
54 """
55 def __enter__(self):
55 def __enter__(self):
56 self.path = os.getcwd()
56 self.path = os.getcwd()
57 self.is_unc_path = self.path.startswith(r"\\")
57 self.is_unc_path = self.path.startswith(r"\\")
58 if self.is_unc_path:
58 if self.is_unc_path:
59 # change to c drive (as cmd.exe cannot handle UNC addresses)
59 # change to c drive (as cmd.exe cannot handle UNC addresses)
60 os.chdir("C:")
60 os.chdir("C:")
61 return self.path
61 return self.path
62 else:
62 else:
63 # We return None to signal that there was no change in the working
63 # We return None to signal that there was no change in the working
64 # directory
64 # directory
65 return None
65 return None
66
66
67 def __exit__(self, exc_type, exc_value, traceback):
67 def __exit__(self, exc_type, exc_value, traceback):
68 if self.is_unc_path:
68 if self.is_unc_path:
69 os.chdir(self.path)
69 os.chdir(self.path)
70
70
71
71
72 def _find_cmd(cmd):
72 def _find_cmd(cmd):
73 """Find the full path to a .bat or .exe using the win32api module."""
73 """Find the full path to a .bat or .exe using the win32api module."""
74 try:
74 try:
75 from win32api import SearchPath
75 from win32api import SearchPath
76 except ImportError:
76 except ImportError:
77 raise ImportError('you need to have pywin32 installed for this to work')
77 raise ImportError('you need to have pywin32 installed for this to work')
78 else:
78 else:
79 PATH = os.environ['PATH']
79 PATH = os.environ['PATH']
80 extensions = ['.exe', '.com', '.bat', '.py']
80 extensions = ['.exe', '.com', '.bat', '.py']
81 path = None
81 path = None
82 for ext in extensions:
82 for ext in extensions:
83 try:
83 try:
84 path = SearchPath(PATH, cmd, ext)[0]
84 path = SearchPath(PATH, cmd, ext)[0]
85 except:
85 except:
86 pass
86 pass
87 if path is None:
87 if path is None:
88 raise OSError("command %r not found" % cmd)
88 raise OSError("command %r not found" % cmd)
89 else:
89 else:
90 return path
90 return path
91
91
92
92
93 def _system_body(p):
93 def _system_body(p):
94 """Callback for _system."""
94 """Callback for _system."""
95 enc = DEFAULT_ENCODING
95 enc = DEFAULT_ENCODING
96 for line in read_no_interrupt(p.stdout).splitlines():
96 for line in read_no_interrupt(p.stdout).splitlines():
97 line = line.decode(enc, 'replace')
97 line = line.decode(enc, 'replace')
98 print(line, file=sys.stdout)
98 print(line, file=sys.stdout)
99 for line in read_no_interrupt(p.stderr).splitlines():
99 for line in read_no_interrupt(p.stderr).splitlines():
100 line = line.decode(enc, 'replace')
100 line = line.decode(enc, 'replace')
101 print(line, file=sys.stderr)
101 print(line, file=sys.stderr)
102
102
103 # Wait to finish for returncode
103 # Wait to finish for returncode
104 return p.wait()
104 return p.wait()
105
105
106
106
107 def system(cmd):
107 def system(cmd):
108 """Win32 version of os.system() that works with network shares.
108 """Win32 version of os.system() that works with network shares.
109
109
110 Note that this implementation returns None, as meant for use in IPython.
110 Note that this implementation returns None, as meant for use in IPython.
111
111
112 Parameters
112 Parameters
113 ----------
113 ----------
114 cmd : str or list
114 cmd : str or list
115 A command to be executed in the system shell.
115 A command to be executed in the system shell.
116
116
117 Returns
117 Returns
118 -------
118 -------
119 None : we explicitly do NOT return the subprocess status code, as this
119 None : we explicitly do NOT return the subprocess status code, as this
120 utility is meant to be used extensively in IPython, where any return value
120 utility is meant to be used extensively in IPython, where any return value
121 would trigger :func:`sys.displayhook` calls.
121 would trigger :func:`sys.displayhook` calls.
122 """
122 """
123 # The controller provides interactivity with both
123 # The controller provides interactivity with both
124 # stdin and stdout
124 # stdin and stdout
125 #import _process_win32_controller
125 #import _process_win32_controller
126 #_process_win32_controller.system(cmd)
126 #_process_win32_controller.system(cmd)
127
127
128 with AvoidUNCPath() as path:
128 with AvoidUNCPath() as path:
129 if path is not None:
129 if path is not None:
130 cmd = '"pushd %s &&"%s' % (path, cmd)
130 cmd = '"pushd %s &&"%s' % (path, cmd)
131 return process_handler(cmd, _system_body)
131 return process_handler(cmd, _system_body)
132
132
133 def getoutput(cmd):
133 def getoutput(cmd):
134 """Return standard output of executing cmd in a shell.
134 """Return standard output of executing cmd in a shell.
135
135
136 Accepts the same arguments as os.system().
136 Accepts the same arguments as os.system().
137
137
138 Parameters
138 Parameters
139 ----------
139 ----------
140 cmd : str or list
140 cmd : str or list
141 A command to be executed in the system shell.
141 A command to be executed in the system shell.
142
142
143 Returns
143 Returns
144 -------
144 -------
145 stdout : str
145 stdout : str
146 """
146 """
147
147
148 with AvoidUNCPath() as path:
148 with AvoidUNCPath() as path:
149 if path is not None:
149 if path is not None:
150 cmd = '"pushd %s &&"%s' % (path, cmd)
150 cmd = '"pushd %s &&"%s' % (path, cmd)
151 out = process_handler(cmd, lambda p: p.communicate()[0], STDOUT)
151 out = process_handler(cmd, lambda p: p.communicate()[0], STDOUT)
152
152
153 if out is None:
153 if out is None:
154 out = b''
154 out = b''
155 return py3compat.bytes_to_str(out)
155 return py3compat.decode(out)
156
156
157 try:
157 try:
158 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
158 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
159 CommandLineToArgvW.arg_types = [LPCWSTR, POINTER(c_int)]
159 CommandLineToArgvW.arg_types = [LPCWSTR, POINTER(c_int)]
160 CommandLineToArgvW.restype = POINTER(LPCWSTR)
160 CommandLineToArgvW.restype = POINTER(LPCWSTR)
161 LocalFree = ctypes.windll.kernel32.LocalFree
161 LocalFree = ctypes.windll.kernel32.LocalFree
162 LocalFree.res_type = HLOCAL
162 LocalFree.res_type = HLOCAL
163 LocalFree.arg_types = [HLOCAL]
163 LocalFree.arg_types = [HLOCAL]
164
164
165 def arg_split(commandline, posix=False, strict=True):
165 def arg_split(commandline, posix=False, strict=True):
166 """Split a command line's arguments in a shell-like manner.
166 """Split a command line's arguments in a shell-like manner.
167
167
168 This is a special version for windows that use a ctypes call to CommandLineToArgvW
168 This is a special version for windows that use a ctypes call to CommandLineToArgvW
169 to do the argv splitting. The posix paramter is ignored.
169 to do the argv splitting. The posix paramter is ignored.
170
170
171 If strict=False, process_common.arg_split(...strict=False) is used instead.
171 If strict=False, process_common.arg_split(...strict=False) is used instead.
172 """
172 """
173 #CommandLineToArgvW returns path to executable if called with empty string.
173 #CommandLineToArgvW returns path to executable if called with empty string.
174 if commandline.strip() == "":
174 if commandline.strip() == "":
175 return []
175 return []
176 if not strict:
176 if not strict:
177 # not really a cl-arg, fallback on _process_common
177 # not really a cl-arg, fallback on _process_common
178 return py_arg_split(commandline, posix=posix, strict=strict)
178 return py_arg_split(commandline, posix=posix, strict=strict)
179 argvn = c_int()
179 argvn = c_int()
180 result_pointer = CommandLineToArgvW(py3compat.cast_unicode(commandline.lstrip()), ctypes.byref(argvn))
180 result_pointer = CommandLineToArgvW(py3compat.cast_unicode(commandline.lstrip()), ctypes.byref(argvn))
181 result_array_type = LPCWSTR * argvn.value
181 result_array_type = LPCWSTR * argvn.value
182 result = [arg for arg in result_array_type.from_address(ctypes.addressof(result_pointer.contents))]
182 result = [arg for arg in result_array_type.from_address(ctypes.addressof(result_pointer.contents))]
183 retval = LocalFree(result_pointer)
183 retval = LocalFree(result_pointer)
184 return result
184 return result
185 except AttributeError:
185 except AttributeError:
186 arg_split = py_arg_split
186 arg_split = py_arg_split
187
187
188 def check_pid(pid):
188 def check_pid(pid):
189 # OpenProcess returns 0 if no such process (of ours) exists
189 # OpenProcess returns 0 if no such process (of ours) exists
190 # positive int otherwise
190 # positive int otherwise
191 return bool(ctypes.windll.kernel32.OpenProcess(1,0,pid))
191 return bool(ctypes.windll.kernel32.OpenProcess(1,0,pid))
@@ -1,576 +1,573
1 """Windows-specific implementation of process utilities with direct WinAPI.
1 """Windows-specific implementation of process utilities with direct WinAPI.
2
2
3 This file is meant to be used by process.py
3 This file is meant to be used by process.py
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13
13
14 # stdlib
14 # stdlib
15 import os, sys, threading
15 import os, sys, threading
16 import ctypes, msvcrt
16 import ctypes, msvcrt
17
17
18 # local imports
19 from . import py3compat
20
21 # Win32 API types needed for the API calls
18 # Win32 API types needed for the API calls
22 from ctypes import POINTER
19 from ctypes import POINTER
23 from ctypes.wintypes import HANDLE, HLOCAL, LPVOID, WORD, DWORD, BOOL, \
20 from ctypes.wintypes import HANDLE, HLOCAL, LPVOID, WORD, DWORD, BOOL, \
24 ULONG, LPCWSTR
21 ULONG, LPCWSTR
25 LPDWORD = POINTER(DWORD)
22 LPDWORD = POINTER(DWORD)
26 LPHANDLE = POINTER(HANDLE)
23 LPHANDLE = POINTER(HANDLE)
27 ULONG_PTR = POINTER(ULONG)
24 ULONG_PTR = POINTER(ULONG)
28 class SECURITY_ATTRIBUTES(ctypes.Structure):
25 class SECURITY_ATTRIBUTES(ctypes.Structure):
29 _fields_ = [("nLength", DWORD),
26 _fields_ = [("nLength", DWORD),
30 ("lpSecurityDescriptor", LPVOID),
27 ("lpSecurityDescriptor", LPVOID),
31 ("bInheritHandle", BOOL)]
28 ("bInheritHandle", BOOL)]
32 LPSECURITY_ATTRIBUTES = POINTER(SECURITY_ATTRIBUTES)
29 LPSECURITY_ATTRIBUTES = POINTER(SECURITY_ATTRIBUTES)
33 class STARTUPINFO(ctypes.Structure):
30 class STARTUPINFO(ctypes.Structure):
34 _fields_ = [("cb", DWORD),
31 _fields_ = [("cb", DWORD),
35 ("lpReserved", LPCWSTR),
32 ("lpReserved", LPCWSTR),
36 ("lpDesktop", LPCWSTR),
33 ("lpDesktop", LPCWSTR),
37 ("lpTitle", LPCWSTR),
34 ("lpTitle", LPCWSTR),
38 ("dwX", DWORD),
35 ("dwX", DWORD),
39 ("dwY", DWORD),
36 ("dwY", DWORD),
40 ("dwXSize", DWORD),
37 ("dwXSize", DWORD),
41 ("dwYSize", DWORD),
38 ("dwYSize", DWORD),
42 ("dwXCountChars", DWORD),
39 ("dwXCountChars", DWORD),
43 ("dwYCountChars", DWORD),
40 ("dwYCountChars", DWORD),
44 ("dwFillAttribute", DWORD),
41 ("dwFillAttribute", DWORD),
45 ("dwFlags", DWORD),
42 ("dwFlags", DWORD),
46 ("wShowWindow", WORD),
43 ("wShowWindow", WORD),
47 ("cbReserved2", WORD),
44 ("cbReserved2", WORD),
48 ("lpReserved2", LPVOID),
45 ("lpReserved2", LPVOID),
49 ("hStdInput", HANDLE),
46 ("hStdInput", HANDLE),
50 ("hStdOutput", HANDLE),
47 ("hStdOutput", HANDLE),
51 ("hStdError", HANDLE)]
48 ("hStdError", HANDLE)]
52 LPSTARTUPINFO = POINTER(STARTUPINFO)
49 LPSTARTUPINFO = POINTER(STARTUPINFO)
53 class PROCESS_INFORMATION(ctypes.Structure):
50 class PROCESS_INFORMATION(ctypes.Structure):
54 _fields_ = [("hProcess", HANDLE),
51 _fields_ = [("hProcess", HANDLE),
55 ("hThread", HANDLE),
52 ("hThread", HANDLE),
56 ("dwProcessId", DWORD),
53 ("dwProcessId", DWORD),
57 ("dwThreadId", DWORD)]
54 ("dwThreadId", DWORD)]
58 LPPROCESS_INFORMATION = POINTER(PROCESS_INFORMATION)
55 LPPROCESS_INFORMATION = POINTER(PROCESS_INFORMATION)
59
56
60 # Win32 API constants needed
57 # Win32 API constants needed
61 ERROR_HANDLE_EOF = 38
58 ERROR_HANDLE_EOF = 38
62 ERROR_BROKEN_PIPE = 109
59 ERROR_BROKEN_PIPE = 109
63 ERROR_NO_DATA = 232
60 ERROR_NO_DATA = 232
64 HANDLE_FLAG_INHERIT = 0x0001
61 HANDLE_FLAG_INHERIT = 0x0001
65 STARTF_USESTDHANDLES = 0x0100
62 STARTF_USESTDHANDLES = 0x0100
66 CREATE_SUSPENDED = 0x0004
63 CREATE_SUSPENDED = 0x0004
67 CREATE_NEW_CONSOLE = 0x0010
64 CREATE_NEW_CONSOLE = 0x0010
68 CREATE_NO_WINDOW = 0x08000000
65 CREATE_NO_WINDOW = 0x08000000
69 STILL_ACTIVE = 259
66 STILL_ACTIVE = 259
70 WAIT_TIMEOUT = 0x0102
67 WAIT_TIMEOUT = 0x0102
71 WAIT_FAILED = 0xFFFFFFFF
68 WAIT_FAILED = 0xFFFFFFFF
72 INFINITE = 0xFFFFFFFF
69 INFINITE = 0xFFFFFFFF
73 DUPLICATE_SAME_ACCESS = 0x00000002
70 DUPLICATE_SAME_ACCESS = 0x00000002
74 ENABLE_ECHO_INPUT = 0x0004
71 ENABLE_ECHO_INPUT = 0x0004
75 ENABLE_LINE_INPUT = 0x0002
72 ENABLE_LINE_INPUT = 0x0002
76 ENABLE_PROCESSED_INPUT = 0x0001
73 ENABLE_PROCESSED_INPUT = 0x0001
77
74
78 # Win32 API functions needed
75 # Win32 API functions needed
79 GetLastError = ctypes.windll.kernel32.GetLastError
76 GetLastError = ctypes.windll.kernel32.GetLastError
80 GetLastError.argtypes = []
77 GetLastError.argtypes = []
81 GetLastError.restype = DWORD
78 GetLastError.restype = DWORD
82
79
83 CreateFile = ctypes.windll.kernel32.CreateFileW
80 CreateFile = ctypes.windll.kernel32.CreateFileW
84 CreateFile.argtypes = [LPCWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
81 CreateFile.argtypes = [LPCWSTR, DWORD, DWORD, LPVOID, DWORD, DWORD, HANDLE]
85 CreateFile.restype = HANDLE
82 CreateFile.restype = HANDLE
86
83
87 CreatePipe = ctypes.windll.kernel32.CreatePipe
84 CreatePipe = ctypes.windll.kernel32.CreatePipe
88 CreatePipe.argtypes = [POINTER(HANDLE), POINTER(HANDLE),
85 CreatePipe.argtypes = [POINTER(HANDLE), POINTER(HANDLE),
89 LPSECURITY_ATTRIBUTES, DWORD]
86 LPSECURITY_ATTRIBUTES, DWORD]
90 CreatePipe.restype = BOOL
87 CreatePipe.restype = BOOL
91
88
92 CreateProcess = ctypes.windll.kernel32.CreateProcessW
89 CreateProcess = ctypes.windll.kernel32.CreateProcessW
93 CreateProcess.argtypes = [LPCWSTR, LPCWSTR, LPSECURITY_ATTRIBUTES,
90 CreateProcess.argtypes = [LPCWSTR, LPCWSTR, LPSECURITY_ATTRIBUTES,
94 LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID, LPCWSTR, LPSTARTUPINFO,
91 LPSECURITY_ATTRIBUTES, BOOL, DWORD, LPVOID, LPCWSTR, LPSTARTUPINFO,
95 LPPROCESS_INFORMATION]
92 LPPROCESS_INFORMATION]
96 CreateProcess.restype = BOOL
93 CreateProcess.restype = BOOL
97
94
98 GetExitCodeProcess = ctypes.windll.kernel32.GetExitCodeProcess
95 GetExitCodeProcess = ctypes.windll.kernel32.GetExitCodeProcess
99 GetExitCodeProcess.argtypes = [HANDLE, LPDWORD]
96 GetExitCodeProcess.argtypes = [HANDLE, LPDWORD]
100 GetExitCodeProcess.restype = BOOL
97 GetExitCodeProcess.restype = BOOL
101
98
102 GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
99 GetCurrentProcess = ctypes.windll.kernel32.GetCurrentProcess
103 GetCurrentProcess.argtypes = []
100 GetCurrentProcess.argtypes = []
104 GetCurrentProcess.restype = HANDLE
101 GetCurrentProcess.restype = HANDLE
105
102
106 ResumeThread = ctypes.windll.kernel32.ResumeThread
103 ResumeThread = ctypes.windll.kernel32.ResumeThread
107 ResumeThread.argtypes = [HANDLE]
104 ResumeThread.argtypes = [HANDLE]
108 ResumeThread.restype = DWORD
105 ResumeThread.restype = DWORD
109
106
110 ReadFile = ctypes.windll.kernel32.ReadFile
107 ReadFile = ctypes.windll.kernel32.ReadFile
111 ReadFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
108 ReadFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
112 ReadFile.restype = BOOL
109 ReadFile.restype = BOOL
113
110
114 WriteFile = ctypes.windll.kernel32.WriteFile
111 WriteFile = ctypes.windll.kernel32.WriteFile
115 WriteFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
112 WriteFile.argtypes = [HANDLE, LPVOID, DWORD, LPDWORD, LPVOID]
116 WriteFile.restype = BOOL
113 WriteFile.restype = BOOL
117
114
118 GetConsoleMode = ctypes.windll.kernel32.GetConsoleMode
115 GetConsoleMode = ctypes.windll.kernel32.GetConsoleMode
119 GetConsoleMode.argtypes = [HANDLE, LPDWORD]
116 GetConsoleMode.argtypes = [HANDLE, LPDWORD]
120 GetConsoleMode.restype = BOOL
117 GetConsoleMode.restype = BOOL
121
118
122 SetConsoleMode = ctypes.windll.kernel32.SetConsoleMode
119 SetConsoleMode = ctypes.windll.kernel32.SetConsoleMode
123 SetConsoleMode.argtypes = [HANDLE, DWORD]
120 SetConsoleMode.argtypes = [HANDLE, DWORD]
124 SetConsoleMode.restype = BOOL
121 SetConsoleMode.restype = BOOL
125
122
126 FlushConsoleInputBuffer = ctypes.windll.kernel32.FlushConsoleInputBuffer
123 FlushConsoleInputBuffer = ctypes.windll.kernel32.FlushConsoleInputBuffer
127 FlushConsoleInputBuffer.argtypes = [HANDLE]
124 FlushConsoleInputBuffer.argtypes = [HANDLE]
128 FlushConsoleInputBuffer.restype = BOOL
125 FlushConsoleInputBuffer.restype = BOOL
129
126
130 WaitForSingleObject = ctypes.windll.kernel32.WaitForSingleObject
127 WaitForSingleObject = ctypes.windll.kernel32.WaitForSingleObject
131 WaitForSingleObject.argtypes = [HANDLE, DWORD]
128 WaitForSingleObject.argtypes = [HANDLE, DWORD]
132 WaitForSingleObject.restype = DWORD
129 WaitForSingleObject.restype = DWORD
133
130
134 DuplicateHandle = ctypes.windll.kernel32.DuplicateHandle
131 DuplicateHandle = ctypes.windll.kernel32.DuplicateHandle
135 DuplicateHandle.argtypes = [HANDLE, HANDLE, HANDLE, LPHANDLE,
132 DuplicateHandle.argtypes = [HANDLE, HANDLE, HANDLE, LPHANDLE,
136 DWORD, BOOL, DWORD]
133 DWORD, BOOL, DWORD]
137 DuplicateHandle.restype = BOOL
134 DuplicateHandle.restype = BOOL
138
135
139 SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
136 SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
140 SetHandleInformation.argtypes = [HANDLE, DWORD, DWORD]
137 SetHandleInformation.argtypes = [HANDLE, DWORD, DWORD]
141 SetHandleInformation.restype = BOOL
138 SetHandleInformation.restype = BOOL
142
139
143 CloseHandle = ctypes.windll.kernel32.CloseHandle
140 CloseHandle = ctypes.windll.kernel32.CloseHandle
144 CloseHandle.argtypes = [HANDLE]
141 CloseHandle.argtypes = [HANDLE]
145 CloseHandle.restype = BOOL
142 CloseHandle.restype = BOOL
146
143
147 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
144 CommandLineToArgvW = ctypes.windll.shell32.CommandLineToArgvW
148 CommandLineToArgvW.argtypes = [LPCWSTR, POINTER(ctypes.c_int)]
145 CommandLineToArgvW.argtypes = [LPCWSTR, POINTER(ctypes.c_int)]
149 CommandLineToArgvW.restype = POINTER(LPCWSTR)
146 CommandLineToArgvW.restype = POINTER(LPCWSTR)
150
147
151 LocalFree = ctypes.windll.kernel32.LocalFree
148 LocalFree = ctypes.windll.kernel32.LocalFree
152 LocalFree.argtypes = [HLOCAL]
149 LocalFree.argtypes = [HLOCAL]
153 LocalFree.restype = HLOCAL
150 LocalFree.restype = HLOCAL
154
151
155 class AvoidUNCPath(object):
152 class AvoidUNCPath(object):
156 """A context manager to protect command execution from UNC paths.
153 """A context manager to protect command execution from UNC paths.
157
154
158 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
155 In the Win32 API, commands can't be invoked with the cwd being a UNC path.
159 This context manager temporarily changes directory to the 'C:' drive on
156 This context manager temporarily changes directory to the 'C:' drive on
160 entering, and restores the original working directory on exit.
157 entering, and restores the original working directory on exit.
161
158
162 The context manager returns the starting working directory *if* it made a
159 The context manager returns the starting working directory *if* it made a
163 change and None otherwise, so that users can apply the necessary adjustment
160 change and None otherwise, so that users can apply the necessary adjustment
164 to their system calls in the event of a change.
161 to their system calls in the event of a change.
165
162
166 Examples
163 Examples
167 --------
164 --------
168 ::
165 ::
169 cmd = 'dir'
166 cmd = 'dir'
170 with AvoidUNCPath() as path:
167 with AvoidUNCPath() as path:
171 if path is not None:
168 if path is not None:
172 cmd = '"pushd %s &&"%s' % (path, cmd)
169 cmd = '"pushd %s &&"%s' % (path, cmd)
173 os.system(cmd)
170 os.system(cmd)
174 """
171 """
175 def __enter__(self):
172 def __enter__(self):
176 self.path = os.getcwd()
173 self.path = os.getcwd()
177 self.is_unc_path = self.path.startswith(r"\\")
174 self.is_unc_path = self.path.startswith(r"\\")
178 if self.is_unc_path:
175 if self.is_unc_path:
179 # change to c drive (as cmd.exe cannot handle UNC addresses)
176 # change to c drive (as cmd.exe cannot handle UNC addresses)
180 os.chdir("C:")
177 os.chdir("C:")
181 return self.path
178 return self.path
182 else:
179 else:
183 # We return None to signal that there was no change in the working
180 # We return None to signal that there was no change in the working
184 # directory
181 # directory
185 return None
182 return None
186
183
187 def __exit__(self, exc_type, exc_value, traceback):
184 def __exit__(self, exc_type, exc_value, traceback):
188 if self.is_unc_path:
185 if self.is_unc_path:
189 os.chdir(self.path)
186 os.chdir(self.path)
190
187
191
188
192 class Win32ShellCommandController(object):
189 class Win32ShellCommandController(object):
193 """Runs a shell command in a 'with' context.
190 """Runs a shell command in a 'with' context.
194
191
195 This implementation is Win32-specific.
192 This implementation is Win32-specific.
196
193
197 Example:
194 Example:
198 # Runs the command interactively with default console stdin/stdout
195 # Runs the command interactively with default console stdin/stdout
199 with ShellCommandController('python -i') as scc:
196 with ShellCommandController('python -i') as scc:
200 scc.run()
197 scc.run()
201
198
202 # Runs the command using the provided functions for stdin/stdout
199 # Runs the command using the provided functions for stdin/stdout
203 def my_stdout_func(s):
200 def my_stdout_func(s):
204 # print or save the string 's'
201 # print or save the string 's'
205 write_to_stdout(s)
202 write_to_stdout(s)
206 def my_stdin_func():
203 def my_stdin_func():
207 # If input is available, return it as a string.
204 # If input is available, return it as a string.
208 if input_available():
205 if input_available():
209 return get_input()
206 return get_input()
210 # If no input available, return None after a short delay to
207 # If no input available, return None after a short delay to
211 # keep from blocking.
208 # keep from blocking.
212 else:
209 else:
213 time.sleep(0.01)
210 time.sleep(0.01)
214 return None
211 return None
215
212
216 with ShellCommandController('python -i') as scc:
213 with ShellCommandController('python -i') as scc:
217 scc.run(my_stdout_func, my_stdin_func)
214 scc.run(my_stdout_func, my_stdin_func)
218 """
215 """
219
216
220 def __init__(self, cmd, mergeout = True):
217 def __init__(self, cmd, mergeout = True):
221 """Initializes the shell command controller.
218 """Initializes the shell command controller.
222
219
223 The cmd is the program to execute, and mergeout is
220 The cmd is the program to execute, and mergeout is
224 whether to blend stdout and stderr into one output
221 whether to blend stdout and stderr into one output
225 in stdout. Merging them together in this fashion more
222 in stdout. Merging them together in this fashion more
226 reliably keeps stdout and stderr in the correct order
223 reliably keeps stdout and stderr in the correct order
227 especially for interactive shell usage.
224 especially for interactive shell usage.
228 """
225 """
229 self.cmd = cmd
226 self.cmd = cmd
230 self.mergeout = mergeout
227 self.mergeout = mergeout
231
228
232 def __enter__(self):
229 def __enter__(self):
233 cmd = self.cmd
230 cmd = self.cmd
234 mergeout = self.mergeout
231 mergeout = self.mergeout
235
232
236 self.hstdout, self.hstdin, self.hstderr = None, None, None
233 self.hstdout, self.hstdin, self.hstderr = None, None, None
237 self.piProcInfo = None
234 self.piProcInfo = None
238 try:
235 try:
239 p_hstdout, c_hstdout, p_hstderr, \
236 p_hstdout, c_hstdout, p_hstderr, \
240 c_hstderr, p_hstdin, c_hstdin = [None]*6
237 c_hstderr, p_hstdin, c_hstdin = [None]*6
241
238
242 # SECURITY_ATTRIBUTES with inherit handle set to True
239 # SECURITY_ATTRIBUTES with inherit handle set to True
243 saAttr = SECURITY_ATTRIBUTES()
240 saAttr = SECURITY_ATTRIBUTES()
244 saAttr.nLength = ctypes.sizeof(saAttr)
241 saAttr.nLength = ctypes.sizeof(saAttr)
245 saAttr.bInheritHandle = True
242 saAttr.bInheritHandle = True
246 saAttr.lpSecurityDescriptor = None
243 saAttr.lpSecurityDescriptor = None
247
244
248 def create_pipe(uninherit):
245 def create_pipe(uninherit):
249 """Creates a Windows pipe, which consists of two handles.
246 """Creates a Windows pipe, which consists of two handles.
250
247
251 The 'uninherit' parameter controls which handle is not
248 The 'uninherit' parameter controls which handle is not
252 inherited by the child process.
249 inherited by the child process.
253 """
250 """
254 handles = HANDLE(), HANDLE()
251 handles = HANDLE(), HANDLE()
255 if not CreatePipe(ctypes.byref(handles[0]),
252 if not CreatePipe(ctypes.byref(handles[0]),
256 ctypes.byref(handles[1]), ctypes.byref(saAttr), 0):
253 ctypes.byref(handles[1]), ctypes.byref(saAttr), 0):
257 raise ctypes.WinError()
254 raise ctypes.WinError()
258 if not SetHandleInformation(handles[uninherit],
255 if not SetHandleInformation(handles[uninherit],
259 HANDLE_FLAG_INHERIT, 0):
256 HANDLE_FLAG_INHERIT, 0):
260 raise ctypes.WinError()
257 raise ctypes.WinError()
261 return handles[0].value, handles[1].value
258 return handles[0].value, handles[1].value
262
259
263 p_hstdout, c_hstdout = create_pipe(uninherit=0)
260 p_hstdout, c_hstdout = create_pipe(uninherit=0)
264 # 'mergeout' signals that stdout and stderr should be merged.
261 # 'mergeout' signals that stdout and stderr should be merged.
265 # We do that by using one pipe for both of them.
262 # We do that by using one pipe for both of them.
266 if mergeout:
263 if mergeout:
267 c_hstderr = HANDLE()
264 c_hstderr = HANDLE()
268 if not DuplicateHandle(GetCurrentProcess(), c_hstdout,
265 if not DuplicateHandle(GetCurrentProcess(), c_hstdout,
269 GetCurrentProcess(), ctypes.byref(c_hstderr),
266 GetCurrentProcess(), ctypes.byref(c_hstderr),
270 0, True, DUPLICATE_SAME_ACCESS):
267 0, True, DUPLICATE_SAME_ACCESS):
271 raise ctypes.WinError()
268 raise ctypes.WinError()
272 else:
269 else:
273 p_hstderr, c_hstderr = create_pipe(uninherit=0)
270 p_hstderr, c_hstderr = create_pipe(uninherit=0)
274 c_hstdin, p_hstdin = create_pipe(uninherit=1)
271 c_hstdin, p_hstdin = create_pipe(uninherit=1)
275
272
276 # Create the process object
273 # Create the process object
277 piProcInfo = PROCESS_INFORMATION()
274 piProcInfo = PROCESS_INFORMATION()
278 siStartInfo = STARTUPINFO()
275 siStartInfo = STARTUPINFO()
279 siStartInfo.cb = ctypes.sizeof(siStartInfo)
276 siStartInfo.cb = ctypes.sizeof(siStartInfo)
280 siStartInfo.hStdInput = c_hstdin
277 siStartInfo.hStdInput = c_hstdin
281 siStartInfo.hStdOutput = c_hstdout
278 siStartInfo.hStdOutput = c_hstdout
282 siStartInfo.hStdError = c_hstderr
279 siStartInfo.hStdError = c_hstderr
283 siStartInfo.dwFlags = STARTF_USESTDHANDLES
280 siStartInfo.dwFlags = STARTF_USESTDHANDLES
284 dwCreationFlags = CREATE_SUSPENDED | CREATE_NO_WINDOW # | CREATE_NEW_CONSOLE
281 dwCreationFlags = CREATE_SUSPENDED | CREATE_NO_WINDOW # | CREATE_NEW_CONSOLE
285
282
286 if not CreateProcess(None,
283 if not CreateProcess(None,
287 u"cmd.exe /c " + cmd,
284 u"cmd.exe /c " + cmd,
288 None, None, True, dwCreationFlags,
285 None, None, True, dwCreationFlags,
289 None, None, ctypes.byref(siStartInfo),
286 None, None, ctypes.byref(siStartInfo),
290 ctypes.byref(piProcInfo)):
287 ctypes.byref(piProcInfo)):
291 raise ctypes.WinError()
288 raise ctypes.WinError()
292
289
293 # Close this process's versions of the child handles
290 # Close this process's versions of the child handles
294 CloseHandle(c_hstdin)
291 CloseHandle(c_hstdin)
295 c_hstdin = None
292 c_hstdin = None
296 CloseHandle(c_hstdout)
293 CloseHandle(c_hstdout)
297 c_hstdout = None
294 c_hstdout = None
298 if c_hstderr is not None:
295 if c_hstderr is not None:
299 CloseHandle(c_hstderr)
296 CloseHandle(c_hstderr)
300 c_hstderr = None
297 c_hstderr = None
301
298
302 # Transfer ownership of the parent handles to the object
299 # Transfer ownership of the parent handles to the object
303 self.hstdin = p_hstdin
300 self.hstdin = p_hstdin
304 p_hstdin = None
301 p_hstdin = None
305 self.hstdout = p_hstdout
302 self.hstdout = p_hstdout
306 p_hstdout = None
303 p_hstdout = None
307 if not mergeout:
304 if not mergeout:
308 self.hstderr = p_hstderr
305 self.hstderr = p_hstderr
309 p_hstderr = None
306 p_hstderr = None
310 self.piProcInfo = piProcInfo
307 self.piProcInfo = piProcInfo
311
308
312 finally:
309 finally:
313 if p_hstdin:
310 if p_hstdin:
314 CloseHandle(p_hstdin)
311 CloseHandle(p_hstdin)
315 if c_hstdin:
312 if c_hstdin:
316 CloseHandle(c_hstdin)
313 CloseHandle(c_hstdin)
317 if p_hstdout:
314 if p_hstdout:
318 CloseHandle(p_hstdout)
315 CloseHandle(p_hstdout)
319 if c_hstdout:
316 if c_hstdout:
320 CloseHandle(c_hstdout)
317 CloseHandle(c_hstdout)
321 if p_hstderr:
318 if p_hstderr:
322 CloseHandle(p_hstderr)
319 CloseHandle(p_hstderr)
323 if c_hstderr:
320 if c_hstderr:
324 CloseHandle(c_hstderr)
321 CloseHandle(c_hstderr)
325
322
326 return self
323 return self
327
324
328 def _stdin_thread(self, handle, hprocess, func, stdout_func):
325 def _stdin_thread(self, handle, hprocess, func, stdout_func):
329 exitCode = DWORD()
326 exitCode = DWORD()
330 bytesWritten = DWORD(0)
327 bytesWritten = DWORD(0)
331 while True:
328 while True:
332 #print("stdin thread loop start")
329 #print("stdin thread loop start")
333 # Get the input string (may be bytes or unicode)
330 # Get the input string (may be bytes or unicode)
334 data = func()
331 data = func()
335
332
336 # None signals to poll whether the process has exited
333 # None signals to poll whether the process has exited
337 if data is None:
334 if data is None:
338 #print("checking for process completion")
335 #print("checking for process completion")
339 if not GetExitCodeProcess(hprocess, ctypes.byref(exitCode)):
336 if not GetExitCodeProcess(hprocess, ctypes.byref(exitCode)):
340 raise ctypes.WinError()
337 raise ctypes.WinError()
341 if exitCode.value != STILL_ACTIVE:
338 if exitCode.value != STILL_ACTIVE:
342 return
339 return
343 # TESTING: Does zero-sized writefile help?
340 # TESTING: Does zero-sized writefile help?
344 if not WriteFile(handle, "", 0,
341 if not WriteFile(handle, "", 0,
345 ctypes.byref(bytesWritten), None):
342 ctypes.byref(bytesWritten), None):
346 raise ctypes.WinError()
343 raise ctypes.WinError()
347 continue
344 continue
348 #print("\nGot str %s\n" % repr(data), file=sys.stderr)
345 #print("\nGot str %s\n" % repr(data), file=sys.stderr)
349
346
350 # Encode the string to the console encoding
347 # Encode the string to the console encoding
351 if isinstance(data, unicode): #FIXME: Python3
348 if isinstance(data, unicode): #FIXME: Python3
352 data = data.encode('utf_8')
349 data = data.encode('utf_8')
353
350
354 # What we have now must be a string of bytes
351 # What we have now must be a string of bytes
355 if not isinstance(data, str): #FIXME: Python3
352 if not isinstance(data, str): #FIXME: Python3
356 raise RuntimeError("internal stdin function string error")
353 raise RuntimeError("internal stdin function string error")
357
354
358 # An empty string signals EOF
355 # An empty string signals EOF
359 if len(data) == 0:
356 if len(data) == 0:
360 return
357 return
361
358
362 # In a windows console, sometimes the input is echoed,
359 # In a windows console, sometimes the input is echoed,
363 # but sometimes not. How do we determine when to do this?
360 # but sometimes not. How do we determine when to do this?
364 stdout_func(data)
361 stdout_func(data)
365 # WriteFile may not accept all the data at once.
362 # WriteFile may not accept all the data at once.
366 # Loop until everything is processed
363 # Loop until everything is processed
367 while len(data) != 0:
364 while len(data) != 0:
368 #print("Calling writefile")
365 #print("Calling writefile")
369 if not WriteFile(handle, data, len(data),
366 if not WriteFile(handle, data, len(data),
370 ctypes.byref(bytesWritten), None):
367 ctypes.byref(bytesWritten), None):
371 # This occurs at exit
368 # This occurs at exit
372 if GetLastError() == ERROR_NO_DATA:
369 if GetLastError() == ERROR_NO_DATA:
373 return
370 return
374 raise ctypes.WinError()
371 raise ctypes.WinError()
375 #print("Called writefile")
372 #print("Called writefile")
376 data = data[bytesWritten.value:]
373 data = data[bytesWritten.value:]
377
374
378 def _stdout_thread(self, handle, func):
375 def _stdout_thread(self, handle, func):
379 # Allocate the output buffer
376 # Allocate the output buffer
380 data = ctypes.create_string_buffer(4096)
377 data = ctypes.create_string_buffer(4096)
381 while True:
378 while True:
382 bytesRead = DWORD(0)
379 bytesRead = DWORD(0)
383 if not ReadFile(handle, data, 4096,
380 if not ReadFile(handle, data, 4096,
384 ctypes.byref(bytesRead), None):
381 ctypes.byref(bytesRead), None):
385 le = GetLastError()
382 le = GetLastError()
386 if le == ERROR_BROKEN_PIPE:
383 if le == ERROR_BROKEN_PIPE:
387 return
384 return
388 else:
385 else:
389 raise ctypes.WinError()
386 raise ctypes.WinError()
390 # FIXME: Python3
387 # FIXME: Python3
391 s = data.value[0:bytesRead.value]
388 s = data.value[0:bytesRead.value]
392 #print("\nv: %s" % repr(s), file=sys.stderr)
389 #print("\nv: %s" % repr(s), file=sys.stderr)
393 func(s.decode('utf_8', 'replace'))
390 func(s.decode('utf_8', 'replace'))
394
391
395 def run(self, stdout_func = None, stdin_func = None, stderr_func = None):
392 def run(self, stdout_func = None, stdin_func = None, stderr_func = None):
396 """Runs the process, using the provided functions for I/O.
393 """Runs the process, using the provided functions for I/O.
397
394
398 The function stdin_func should return strings whenever a
395 The function stdin_func should return strings whenever a
399 character or characters become available.
396 character or characters become available.
400 The functions stdout_func and stderr_func are called whenever
397 The functions stdout_func and stderr_func are called whenever
401 something is printed to stdout or stderr, respectively.
398 something is printed to stdout or stderr, respectively.
402 These functions are called from different threads (but not
399 These functions are called from different threads (but not
403 concurrently, because of the GIL).
400 concurrently, because of the GIL).
404 """
401 """
405 if stdout_func is None and stdin_func is None and stderr_func is None:
402 if stdout_func is None and stdin_func is None and stderr_func is None:
406 return self._run_stdio()
403 return self._run_stdio()
407
404
408 if stderr_func is not None and self.mergeout:
405 if stderr_func is not None and self.mergeout:
409 raise RuntimeError("Shell command was initiated with "
406 raise RuntimeError("Shell command was initiated with "
410 "merged stdin/stdout, but a separate stderr_func "
407 "merged stdin/stdout, but a separate stderr_func "
411 "was provided to the run() method")
408 "was provided to the run() method")
412
409
413 # Create a thread for each input/output handle
410 # Create a thread for each input/output handle
414 stdin_thread = None
411 stdin_thread = None
415 threads = []
412 threads = []
416 if stdin_func:
413 if stdin_func:
417 stdin_thread = threading.Thread(target=self._stdin_thread,
414 stdin_thread = threading.Thread(target=self._stdin_thread,
418 args=(self.hstdin, self.piProcInfo.hProcess,
415 args=(self.hstdin, self.piProcInfo.hProcess,
419 stdin_func, stdout_func))
416 stdin_func, stdout_func))
420 threads.append(threading.Thread(target=self._stdout_thread,
417 threads.append(threading.Thread(target=self._stdout_thread,
421 args=(self.hstdout, stdout_func)))
418 args=(self.hstdout, stdout_func)))
422 if not self.mergeout:
419 if not self.mergeout:
423 if stderr_func is None:
420 if stderr_func is None:
424 stderr_func = stdout_func
421 stderr_func = stdout_func
425 threads.append(threading.Thread(target=self._stdout_thread,
422 threads.append(threading.Thread(target=self._stdout_thread,
426 args=(self.hstderr, stderr_func)))
423 args=(self.hstderr, stderr_func)))
427 # Start the I/O threads and the process
424 # Start the I/O threads and the process
428 if ResumeThread(self.piProcInfo.hThread) == 0xFFFFFFFF:
425 if ResumeThread(self.piProcInfo.hThread) == 0xFFFFFFFF:
429 raise ctypes.WinError()
426 raise ctypes.WinError()
430 if stdin_thread is not None:
427 if stdin_thread is not None:
431 stdin_thread.start()
428 stdin_thread.start()
432 for thread in threads:
429 for thread in threads:
433 thread.start()
430 thread.start()
434 # Wait for the process to complete
431 # Wait for the process to complete
435 if WaitForSingleObject(self.piProcInfo.hProcess, INFINITE) == \
432 if WaitForSingleObject(self.piProcInfo.hProcess, INFINITE) == \
436 WAIT_FAILED:
433 WAIT_FAILED:
437 raise ctypes.WinError()
434 raise ctypes.WinError()
438 # Wait for the I/O threads to complete
435 # Wait for the I/O threads to complete
439 for thread in threads:
436 for thread in threads:
440 thread.join()
437 thread.join()
441
438
442 # Wait for the stdin thread to complete
439 # Wait for the stdin thread to complete
443 if stdin_thread is not None:
440 if stdin_thread is not None:
444 stdin_thread.join()
441 stdin_thread.join()
445
442
446 def _stdin_raw_nonblock(self):
443 def _stdin_raw_nonblock(self):
447 """Use the raw Win32 handle of sys.stdin to do non-blocking reads"""
444 """Use the raw Win32 handle of sys.stdin to do non-blocking reads"""
448 # WARNING: This is experimental, and produces inconsistent results.
445 # WARNING: This is experimental, and produces inconsistent results.
449 # It's possible for the handle not to be appropriate for use
446 # It's possible for the handle not to be appropriate for use
450 # with WaitForSingleObject, among other things.
447 # with WaitForSingleObject, among other things.
451 handle = msvcrt.get_osfhandle(sys.stdin.fileno())
448 handle = msvcrt.get_osfhandle(sys.stdin.fileno())
452 result = WaitForSingleObject(handle, 100)
449 result = WaitForSingleObject(handle, 100)
453 if result == WAIT_FAILED:
450 if result == WAIT_FAILED:
454 raise ctypes.WinError()
451 raise ctypes.WinError()
455 elif result == WAIT_TIMEOUT:
452 elif result == WAIT_TIMEOUT:
456 print(".", end='')
453 print(".", end='')
457 return None
454 return None
458 else:
455 else:
459 data = ctypes.create_string_buffer(256)
456 data = ctypes.create_string_buffer(256)
460 bytesRead = DWORD(0)
457 bytesRead = DWORD(0)
461 print('?', end='')
458 print('?', end='')
462
459
463 if not ReadFile(handle, data, 256,
460 if not ReadFile(handle, data, 256,
464 ctypes.byref(bytesRead), None):
461 ctypes.byref(bytesRead), None):
465 raise ctypes.WinError()
462 raise ctypes.WinError()
466 # This ensures the non-blocking works with an actual console
463 # This ensures the non-blocking works with an actual console
467 # Not checking the error, so the processing will still work with
464 # Not checking the error, so the processing will still work with
468 # other handle types
465 # other handle types
469 FlushConsoleInputBuffer(handle)
466 FlushConsoleInputBuffer(handle)
470
467
471 data = data.value
468 data = data.value
472 data = data.replace('\r\n', '\n')
469 data = data.replace('\r\n', '\n')
473 data = data.replace('\r', '\n')
470 data = data.replace('\r', '\n')
474 print(repr(data) + " ", end='')
471 print(repr(data) + " ", end='')
475 return data
472 return data
476
473
477 def _stdin_raw_block(self):
474 def _stdin_raw_block(self):
478 """Use a blocking stdin read"""
475 """Use a blocking stdin read"""
479 # The big problem with the blocking read is that it doesn't
476 # The big problem with the blocking read is that it doesn't
480 # exit when it's supposed to in all contexts. An extra
477 # exit when it's supposed to in all contexts. An extra
481 # key-press may be required to trigger the exit.
478 # key-press may be required to trigger the exit.
482 try:
479 try:
483 data = sys.stdin.read(1)
480 data = sys.stdin.read(1)
484 data = data.replace('\r', '\n')
481 data = data.replace('\r', '\n')
485 return data
482 return data
486 except WindowsError as we:
483 except WindowsError as we:
487 if we.winerror == ERROR_NO_DATA:
484 if we.winerror == ERROR_NO_DATA:
488 # This error occurs when the pipe is closed
485 # This error occurs when the pipe is closed
489 return None
486 return None
490 else:
487 else:
491 # Otherwise let the error propagate
488 # Otherwise let the error propagate
492 raise we
489 raise we
493
490
494 def _stdout_raw(self, s):
491 def _stdout_raw(self, s):
495 """Writes the string to stdout"""
492 """Writes the string to stdout"""
496 print(s, end='', file=sys.stdout)
493 print(s, end='', file=sys.stdout)
497 sys.stdout.flush()
494 sys.stdout.flush()
498
495
499 def _stderr_raw(self, s):
496 def _stderr_raw(self, s):
500 """Writes the string to stdout"""
497 """Writes the string to stdout"""
501 print(s, end='', file=sys.stderr)
498 print(s, end='', file=sys.stderr)
502 sys.stderr.flush()
499 sys.stderr.flush()
503
500
504 def _run_stdio(self):
501 def _run_stdio(self):
505 """Runs the process using the system standard I/O.
502 """Runs the process using the system standard I/O.
506
503
507 IMPORTANT: stdin needs to be asynchronous, so the Python
504 IMPORTANT: stdin needs to be asynchronous, so the Python
508 sys.stdin object is not used. Instead,
505 sys.stdin object is not used. Instead,
509 msvcrt.kbhit/getwch are used asynchronously.
506 msvcrt.kbhit/getwch are used asynchronously.
510 """
507 """
511 # Disable Line and Echo mode
508 # Disable Line and Echo mode
512 #lpMode = DWORD()
509 #lpMode = DWORD()
513 #handle = msvcrt.get_osfhandle(sys.stdin.fileno())
510 #handle = msvcrt.get_osfhandle(sys.stdin.fileno())
514 #if GetConsoleMode(handle, ctypes.byref(lpMode)):
511 #if GetConsoleMode(handle, ctypes.byref(lpMode)):
515 # set_console_mode = True
512 # set_console_mode = True
516 # if not SetConsoleMode(handle, lpMode.value &
513 # if not SetConsoleMode(handle, lpMode.value &
517 # ~(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)):
514 # ~(ENABLE_ECHO_INPUT | ENABLE_LINE_INPUT | ENABLE_PROCESSED_INPUT)):
518 # raise ctypes.WinError()
515 # raise ctypes.WinError()
519
516
520 if self.mergeout:
517 if self.mergeout:
521 return self.run(stdout_func = self._stdout_raw,
518 return self.run(stdout_func = self._stdout_raw,
522 stdin_func = self._stdin_raw_block)
519 stdin_func = self._stdin_raw_block)
523 else:
520 else:
524 return self.run(stdout_func = self._stdout_raw,
521 return self.run(stdout_func = self._stdout_raw,
525 stdin_func = self._stdin_raw_block,
522 stdin_func = self._stdin_raw_block,
526 stderr_func = self._stderr_raw)
523 stderr_func = self._stderr_raw)
527
524
528 # Restore the previous console mode
525 # Restore the previous console mode
529 #if set_console_mode:
526 #if set_console_mode:
530 # if not SetConsoleMode(handle, lpMode.value):
527 # if not SetConsoleMode(handle, lpMode.value):
531 # raise ctypes.WinError()
528 # raise ctypes.WinError()
532
529
533 def __exit__(self, exc_type, exc_value, traceback):
530 def __exit__(self, exc_type, exc_value, traceback):
534 if self.hstdin:
531 if self.hstdin:
535 CloseHandle(self.hstdin)
532 CloseHandle(self.hstdin)
536 self.hstdin = None
533 self.hstdin = None
537 if self.hstdout:
534 if self.hstdout:
538 CloseHandle(self.hstdout)
535 CloseHandle(self.hstdout)
539 self.hstdout = None
536 self.hstdout = None
540 if self.hstderr:
537 if self.hstderr:
541 CloseHandle(self.hstderr)
538 CloseHandle(self.hstderr)
542 self.hstderr = None
539 self.hstderr = None
543 if self.piProcInfo is not None:
540 if self.piProcInfo is not None:
544 CloseHandle(self.piProcInfo.hProcess)
541 CloseHandle(self.piProcInfo.hProcess)
545 CloseHandle(self.piProcInfo.hThread)
542 CloseHandle(self.piProcInfo.hThread)
546 self.piProcInfo = None
543 self.piProcInfo = None
547
544
548
545
549 def system(cmd):
546 def system(cmd):
550 """Win32 version of os.system() that works with network shares.
547 """Win32 version of os.system() that works with network shares.
551
548
552 Note that this implementation returns None, as meant for use in IPython.
549 Note that this implementation returns None, as meant for use in IPython.
553
550
554 Parameters
551 Parameters
555 ----------
552 ----------
556 cmd : str
553 cmd : str
557 A command to be executed in the system shell.
554 A command to be executed in the system shell.
558
555
559 Returns
556 Returns
560 -------
557 -------
561 None : we explicitly do NOT return the subprocess status code, as this
558 None : we explicitly do NOT return the subprocess status code, as this
562 utility is meant to be used extensively in IPython, where any return value
559 utility is meant to be used extensively in IPython, where any return value
563 would trigger :func:`sys.displayhook` calls.
560 would trigger :func:`sys.displayhook` calls.
564 """
561 """
565 with AvoidUNCPath() as path:
562 with AvoidUNCPath() as path:
566 if path is not None:
563 if path is not None:
567 cmd = '"pushd %s &&"%s' % (path, cmd)
564 cmd = '"pushd %s &&"%s' % (path, cmd)
568 with Win32ShellCommandController(cmd) as scc:
565 with Win32ShellCommandController(cmd) as scc:
569 scc.run()
566 scc.run()
570
567
571
568
572 if __name__ == "__main__":
569 if __name__ == "__main__":
573 print("Test starting!")
570 print("Test starting!")
574 #system("cmd")
571 #system("cmd")
575 system("python -i")
572 system("python -i")
576 print("Test finished!")
573 print("Test finished!")
@@ -1,69 +1,69
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with external processes.
3 Utilities for working with external processes.
4 """
4 """
5
5
6 # Copyright (c) IPython Development Team.
6 # Copyright (c) IPython Development Team.
7 # Distributed under the terms of the Modified BSD License.
7 # Distributed under the terms of the Modified BSD License.
8
8
9
9
10 import os
10 import os
11 import shutil
11 import sys
12 import sys
12
13
13 if sys.platform == 'win32':
14 if sys.platform == 'win32':
14 from ._process_win32 import system, getoutput, arg_split, check_pid
15 from ._process_win32 import system, getoutput, arg_split, check_pid
15 elif sys.platform == 'cli':
16 elif sys.platform == 'cli':
16 from ._process_cli import system, getoutput, arg_split, check_pid
17 from ._process_cli import system, getoutput, arg_split, check_pid
17 else:
18 else:
18 from ._process_posix import system, getoutput, arg_split, check_pid
19 from ._process_posix import system, getoutput, arg_split, check_pid
19
20
20 from ._process_common import getoutputerror, get_output_error_code, process_handler
21 from ._process_common import getoutputerror, get_output_error_code, process_handler
21 from . import py3compat
22
22
23
23
24 class FindCmdError(Exception):
24 class FindCmdError(Exception):
25 pass
25 pass
26
26
27
27
28 def find_cmd(cmd):
28 def find_cmd(cmd):
29 """Find absolute path to executable cmd in a cross platform manner.
29 """Find absolute path to executable cmd in a cross platform manner.
30
30
31 This function tries to determine the full path to a command line program
31 This function tries to determine the full path to a command line program
32 using `which` on Unix/Linux/OS X and `win32api` on Windows. Most of the
32 using `which` on Unix/Linux/OS X and `win32api` on Windows. Most of the
33 time it will use the version that is first on the users `PATH`.
33 time it will use the version that is first on the users `PATH`.
34
34
35 Warning, don't use this to find IPython command line programs as there
35 Warning, don't use this to find IPython command line programs as there
36 is a risk you will find the wrong one. Instead find those using the
36 is a risk you will find the wrong one. Instead find those using the
37 following code and looking for the application itself::
37 following code and looking for the application itself::
38
38
39 import sys
39 import sys
40 argv = [sys.executable, '-m', 'IPython']
40 argv = [sys.executable, '-m', 'IPython']
41
41
42 Parameters
42 Parameters
43 ----------
43 ----------
44 cmd : str
44 cmd : str
45 The command line program to look for.
45 The command line program to look for.
46 """
46 """
47 path = py3compat.which(cmd)
47 path = shutil.which(cmd)
48 if path is None:
48 if path is None:
49 raise FindCmdError('command could not be found: %s' % cmd)
49 raise FindCmdError('command could not be found: %s' % cmd)
50 return path
50 return path
51
51
52
52
53 def abbrev_cwd():
53 def abbrev_cwd():
54 """ Return abbreviated version of cwd, e.g. d:mydir """
54 """ Return abbreviated version of cwd, e.g. d:mydir """
55 cwd = os.getcwd().replace('\\','/')
55 cwd = os.getcwd().replace('\\','/')
56 drivepart = ''
56 drivepart = ''
57 tail = cwd
57 tail = cwd
58 if sys.platform == 'win32':
58 if sys.platform == 'win32':
59 if len(cwd) < 4:
59 if len(cwd) < 4:
60 return cwd
60 return cwd
61 drivepart,tail = os.path.splitdrive(cwd)
61 drivepart,tail = os.path.splitdrive(cwd)
62
62
63
63
64 parts = tail.split('/')
64 parts = tail.split('/')
65 if len(parts) > 2:
65 if len(parts) > 2:
66 tail = '/'.join(parts[-2:])
66 tail = '/'.join(parts[-2:])
67
67
68 return (drivepart + (
68 return (drivepart + (
69 cwd == '/' and '/' or tail))
69 cwd == '/' and '/' or tail))
@@ -1,336 +1,258
1 # coding: utf-8
1 # coding: utf-8
2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
2 """Compatibility tricks for Python 3. Mainly to do with unicode.
3
4 This file is deprecated and will be removed in a future version.
5 """
3 import functools
6 import functools
4 import os
7 import os
5 import sys
8 import sys
6 import re
9 import re
7 import shutil
10 import shutil
8 import types
11 import types
9 import platform
12 import platform
10
13
11 from .encoding import DEFAULT_ENCODING
14 from .encoding import DEFAULT_ENCODING
12
15
13 def no_code(x, encoding=None):
16 def no_code(x, encoding=None):
14 return x
17 return x
15
18
16 def decode(s, encoding=None):
19 def decode(s, encoding=None):
17 encoding = encoding or DEFAULT_ENCODING
20 encoding = encoding or DEFAULT_ENCODING
18 return s.decode(encoding, "replace")
21 return s.decode(encoding, "replace")
19
22
20 def encode(u, encoding=None):
23 def encode(u, encoding=None):
21 encoding = encoding or DEFAULT_ENCODING
24 encoding = encoding or DEFAULT_ENCODING
22 return u.encode(encoding, "replace")
25 return u.encode(encoding, "replace")
23
26
24
27
25 def cast_unicode(s, encoding=None):
28 def cast_unicode(s, encoding=None):
26 if isinstance(s, bytes):
29 if isinstance(s, bytes):
27 return decode(s, encoding)
30 return decode(s, encoding)
28 return s
31 return s
29
32
30 def cast_bytes(s, encoding=None):
33 def cast_bytes(s, encoding=None):
31 if not isinstance(s, bytes):
34 if not isinstance(s, bytes):
32 return encode(s, encoding)
35 return encode(s, encoding)
33 return s
36 return s
34
37
35 def buffer_to_bytes(buf):
38 def buffer_to_bytes(buf):
36 """Cast a buffer object to bytes"""
39 """Cast a buffer object to bytes"""
37 if not isinstance(buf, bytes):
40 if not isinstance(buf, bytes):
38 buf = bytes(buf)
41 buf = bytes(buf)
39 return buf
42 return buf
40
43
41 def _modify_str_or_docstring(str_change_func):
44 def _modify_str_or_docstring(str_change_func):
42 @functools.wraps(str_change_func)
45 @functools.wraps(str_change_func)
43 def wrapper(func_or_str):
46 def wrapper(func_or_str):
44 if isinstance(func_or_str, string_types):
47 if isinstance(func_or_str, string_types):
45 func = None
48 func = None
46 doc = func_or_str
49 doc = func_or_str
47 else:
50 else:
48 func = func_or_str
51 func = func_or_str
49 doc = func.__doc__
52 doc = func.__doc__
50
53
51 # PYTHONOPTIMIZE=2 strips docstrings, so they can disappear unexpectedly
54 # PYTHONOPTIMIZE=2 strips docstrings, so they can disappear unexpectedly
52 if doc is not None:
55 if doc is not None:
53 doc = str_change_func(doc)
56 doc = str_change_func(doc)
54
57
55 if func:
58 if func:
56 func.__doc__ = doc
59 func.__doc__ = doc
57 return func
60 return func
58 return doc
61 return doc
59 return wrapper
62 return wrapper
60
63
61 def safe_unicode(e):
64 def safe_unicode(e):
62 """unicode(e) with various fallbacks. Used for exceptions, which may not be
65 """unicode(e) with various fallbacks. Used for exceptions, which may not be
63 safe to call unicode() on.
66 safe to call unicode() on.
64 """
67 """
65 try:
68 try:
66 return unicode_type(e)
69 return unicode_type(e)
67 except UnicodeError:
70 except UnicodeError:
68 pass
71 pass
69
72
70 try:
73 try:
71 return str_to_unicode(str(e))
74 return str_to_unicode(str(e))
72 except UnicodeError:
75 except UnicodeError:
73 pass
76 pass
74
77
75 try:
78 try:
76 return str_to_unicode(repr(e))
79 return str_to_unicode(repr(e))
77 except UnicodeError:
80 except UnicodeError:
78 pass
81 pass
79
82
80 return u'Unrecoverably corrupt evalue'
83 return u'Unrecoverably corrupt evalue'
81
84
82 # shutil.which from Python 3.4
85 # shutil.which from Python 3.4
83 def _shutil_which(cmd, mode=os.F_OK | os.X_OK, path=None):
86 def _shutil_which(cmd, mode=os.F_OK | os.X_OK, path=None):
84 """Given a command, mode, and a PATH string, return the path which
87 """Given a command, mode, and a PATH string, return the path which
85 conforms to the given mode on the PATH, or None if there is no such
88 conforms to the given mode on the PATH, or None if there is no such
86 file.
89 file.
87
90
88 `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result
91 `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result
89 of os.environ.get("PATH"), or can be overridden with a custom search
92 of os.environ.get("PATH"), or can be overridden with a custom search
90 path.
93 path.
91
94
92 This is a backport of shutil.which from Python 3.4
95 This is a backport of shutil.which from Python 3.4
93 """
96 """
94 # Check that a given file can be accessed with the correct mode.
97 # Check that a given file can be accessed with the correct mode.
95 # Additionally check that `file` is not a directory, as on Windows
98 # Additionally check that `file` is not a directory, as on Windows
96 # directories pass the os.access check.
99 # directories pass the os.access check.
97 def _access_check(fn, mode):
100 def _access_check(fn, mode):
98 return (os.path.exists(fn) and os.access(fn, mode)
101 return (os.path.exists(fn) and os.access(fn, mode)
99 and not os.path.isdir(fn))
102 and not os.path.isdir(fn))
100
103
101 # If we're given a path with a directory part, look it up directly rather
104 # If we're given a path with a directory part, look it up directly rather
102 # than referring to PATH directories. This includes checking relative to the
105 # than referring to PATH directories. This includes checking relative to the
103 # current directory, e.g. ./script
106 # current directory, e.g. ./script
104 if os.path.dirname(cmd):
107 if os.path.dirname(cmd):
105 if _access_check(cmd, mode):
108 if _access_check(cmd, mode):
106 return cmd
109 return cmd
107 return None
110 return None
108
111
109 if path is None:
112 if path is None:
110 path = os.environ.get("PATH", os.defpath)
113 path = os.environ.get("PATH", os.defpath)
111 if not path:
114 if not path:
112 return None
115 return None
113 path = path.split(os.pathsep)
116 path = path.split(os.pathsep)
114
117
115 if sys.platform == "win32":
118 if sys.platform == "win32":
116 # The current directory takes precedence on Windows.
119 # The current directory takes precedence on Windows.
117 if not os.curdir in path:
120 if not os.curdir in path:
118 path.insert(0, os.curdir)
121 path.insert(0, os.curdir)
119
122
120 # PATHEXT is necessary to check on Windows.
123 # PATHEXT is necessary to check on Windows.
121 pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
124 pathext = os.environ.get("PATHEXT", "").split(os.pathsep)
122 # See if the given file matches any of the expected path extensions.
125 # See if the given file matches any of the expected path extensions.
123 # This will allow us to short circuit when given "python.exe".
126 # This will allow us to short circuit when given "python.exe".
124 # If it does match, only test that one, otherwise we have to try
127 # If it does match, only test that one, otherwise we have to try
125 # others.
128 # others.
126 if any(cmd.lower().endswith(ext.lower()) for ext in pathext):
129 if any(cmd.lower().endswith(ext.lower()) for ext in pathext):
127 files = [cmd]
130 files = [cmd]
128 else:
131 else:
129 files = [cmd + ext for ext in pathext]
132 files = [cmd + ext for ext in pathext]
130 else:
133 else:
131 # On other platforms you don't have things like PATHEXT to tell you
134 # On other platforms you don't have things like PATHEXT to tell you
132 # what file suffixes are executable, so just pass on cmd as-is.
135 # what file suffixes are executable, so just pass on cmd as-is.
133 files = [cmd]
136 files = [cmd]
134
137
135 seen = set()
138 seen = set()
136 for dir in path:
139 for dir in path:
137 normdir = os.path.normcase(dir)
140 normdir = os.path.normcase(dir)
138 if not normdir in seen:
141 if not normdir in seen:
139 seen.add(normdir)
142 seen.add(normdir)
140 for thefile in files:
143 for thefile in files:
141 name = os.path.join(dir, thefile)
144 name = os.path.join(dir, thefile)
142 if _access_check(name, mode):
145 if _access_check(name, mode):
143 return name
146 return name
144 return None
147 return None
145
148
146 if sys.version_info[0] >= 3:
149 PY3 = True
147 PY3 = True
148
149 # keep reference to builtin_mod because the kernel overrides that value
150 # to forward requests to a frontend.
151 def input(prompt=''):
152 return builtin_mod.input(prompt)
153
154 builtin_mod_name = "builtins"
155 import builtins as builtin_mod
156
157 str_to_unicode = no_code
158 unicode_to_str = no_code
159 str_to_bytes = encode
160 bytes_to_str = decode
161 cast_bytes_py2 = no_code
162 cast_unicode_py2 = no_code
163 buffer_to_bytes_py2 = no_code
164
165 string_types = (str,)
166 unicode_type = str
167
168 which = shutil.which
169
170 def isidentifier(s, dotted=False):
171 if dotted:
172 return all(isidentifier(a) for a in s.split("."))
173 return s.isidentifier()
174
175 xrange = range
176 def iteritems(d): return iter(d.items())
177 def itervalues(d): return iter(d.values())
178 getcwd = os.getcwd
179
180 MethodType = types.MethodType
181
182 def execfile(fname, glob, loc=None, compiler=None):
183 loc = loc if (loc is not None) else glob
184 with open(fname, 'rb') as f:
185 compiler = compiler or compile
186 exec(compiler(f.read(), fname, 'exec'), glob, loc)
187
188 # Refactor print statements in doctests.
189 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
190 def _print_statement_sub(match):
191 expr = match.groups('expr')
192 return "print(%s)" % expr
193
194 @_modify_str_or_docstring
195 def doctest_refactor_print(doc):
196 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
197 unfortunately doesn't pick up on our doctests.
198
199 Can accept a string or a function, so it can be used as a decorator."""
200 return _print_statement_re.sub(_print_statement_sub, doc)
201
202 # Abstract u'abc' syntax:
203 @_modify_str_or_docstring
204 def u_format(s):
205 """"{u}'abc'" --> "'abc'" (Python 3)
206
207 Accepts a string or a function, so it can be used as a decorator."""
208 return s.format(u='')
209
210 def get_closure(f):
211 """Get a function's closure attribute"""
212 return f.__closure__
213
214 else:
215 PY3 = False
216
217 # keep reference to builtin_mod because the kernel overrides that value
218 # to forward requests to a frontend.
219 def input(prompt=''):
220 return builtin_mod.raw_input(prompt)
221
222 builtin_mod_name = "__builtin__"
223 import __builtin__ as builtin_mod
224
225 str_to_unicode = decode
226 unicode_to_str = encode
227 str_to_bytes = no_code
228 bytes_to_str = no_code
229 cast_bytes_py2 = cast_bytes
230 cast_unicode_py2 = cast_unicode
231 buffer_to_bytes_py2 = buffer_to_bytes
232
233 string_types = (str, unicode)
234 unicode_type = unicode
235
236 import re
237 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
238 def isidentifier(s, dotted=False):
239 if dotted:
240 return all(isidentifier(a) for a in s.split("."))
241 return bool(_name_re.match(s))
242
243 xrange = xrange
244 def iteritems(d): return d.iteritems()
245 def itervalues(d): return d.itervalues()
246 getcwd = os.getcwdu
247
248 def MethodType(func, instance):
249 return types.MethodType(func, instance, type(instance))
250
251 def doctest_refactor_print(func_or_str):
252 return func_or_str
253
254 def get_closure(f):
255 """Get a function's closure attribute"""
256 return f.func_closure
257
258 which = _shutil_which
259
260 # Abstract u'abc' syntax:
261 @_modify_str_or_docstring
262 def u_format(s):
263 """"{u}'abc'" --> "u'abc'" (Python 2)
264
265 Accepts a string or a function, so it can be used as a decorator."""
266 return s.format(u='u')
267
268 if sys.platform == 'win32':
269 def execfile(fname, glob=None, loc=None, compiler=None):
270 loc = loc if (loc is not None) else glob
271 scripttext = builtin_mod.open(fname).read()+ '\n'
272 # compile converts unicode filename to str assuming
273 # ascii. Let's do the conversion before calling compile
274 if isinstance(fname, unicode):
275 filename = unicode_to_str(fname)
276 else:
277 filename = fname
278 compiler = compiler or compile
279 exec(compiler(scripttext, filename, 'exec'), glob, loc)
280
150
281 else:
151 # keep reference to builtin_mod because the kernel overrides that value
282 def execfile(fname, glob=None, loc=None, compiler=None):
152 # to forward requests to a frontend.
283 if isinstance(fname, unicode):
153 def input(prompt=''):
284 filename = fname.encode(sys.getfilesystemencoding())
154 return builtin_mod.input(prompt)
285 else:
155
286 filename = fname
156 builtin_mod_name = "builtins"
287 where = [ns for ns in [glob, loc] if ns is not None]
157 import builtins as builtin_mod
288 if compiler is None:
158
289 builtin_mod.execfile(filename, *where)
159 str_to_unicode = no_code
290 else:
160 unicode_to_str = no_code
291 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
161 str_to_bytes = encode
292 exec(compiler(scripttext, filename, 'exec'), glob, loc)
162 bytes_to_str = decode
163 cast_bytes_py2 = no_code
164 cast_unicode_py2 = no_code
165 buffer_to_bytes_py2 = no_code
166
167 string_types = (str,)
168 unicode_type = str
169
170 which = shutil.which
171
172 def isidentifier(s, dotted=False):
173 if dotted:
174 return all(isidentifier(a) for a in s.split("."))
175 return s.isidentifier()
176
177 xrange = range
178 def iteritems(d): return iter(d.items())
179 def itervalues(d): return iter(d.values())
180 getcwd = os.getcwd
181
182 MethodType = types.MethodType
183
184 def execfile(fname, glob, loc=None, compiler=None):
185 loc = loc if (loc is not None) else glob
186 with open(fname, 'rb') as f:
187 compiler = compiler or compile
188 exec(compiler(f.read(), fname, 'exec'), glob, loc)
189
190 # Refactor print statements in doctests.
191 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
192 def _print_statement_sub(match):
193 expr = match.groups('expr')
194 return "print(%s)" % expr
195
196 @_modify_str_or_docstring
197 def doctest_refactor_print(doc):
198 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
199 unfortunately doesn't pick up on our doctests.
200
201 Can accept a string or a function, so it can be used as a decorator."""
202 return _print_statement_re.sub(_print_statement_sub, doc)
203
204 # Abstract u'abc' syntax:
205 @_modify_str_or_docstring
206 def u_format(s):
207 """"{u}'abc'" --> "'abc'" (Python 3)
208
209 Accepts a string or a function, so it can be used as a decorator."""
210 return s.format(u='')
211
212 def get_closure(f):
213 """Get a function's closure attribute"""
214 return f.__closure__
293
215
294
216
295 PY2 = not PY3
217 PY2 = not PY3
296 PYPY = platform.python_implementation() == "PyPy"
218 PYPY = platform.python_implementation() == "PyPy"
297
219
298
220
299 def annotate(**kwargs):
221 def annotate(**kwargs):
300 """Python 3 compatible function annotation for Python 2."""
222 """Python 3 compatible function annotation for Python 2."""
301 if not kwargs:
223 if not kwargs:
302 raise ValueError('annotations must be provided as keyword arguments')
224 raise ValueError('annotations must be provided as keyword arguments')
303 def dec(f):
225 def dec(f):
304 if hasattr(f, '__annotations__'):
226 if hasattr(f, '__annotations__'):
305 for k, v in kwargs.items():
227 for k, v in kwargs.items():
306 f.__annotations__[k] = v
228 f.__annotations__[k] = v
307 else:
229 else:
308 f.__annotations__ = kwargs
230 f.__annotations__ = kwargs
309 return f
231 return f
310 return dec
232 return dec
311
233
312
234
313 # Parts below taken from six:
235 # Parts below taken from six:
314 # Copyright (c) 2010-2013 Benjamin Peterson
236 # Copyright (c) 2010-2013 Benjamin Peterson
315 #
237 #
316 # Permission is hereby granted, free of charge, to any person obtaining a copy
238 # Permission is hereby granted, free of charge, to any person obtaining a copy
317 # of this software and associated documentation files (the "Software"), to deal
239 # of this software and associated documentation files (the "Software"), to deal
318 # in the Software without restriction, including without limitation the rights
240 # in the Software without restriction, including without limitation the rights
319 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
241 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
320 # copies of the Software, and to permit persons to whom the Software is
242 # copies of the Software, and to permit persons to whom the Software is
321 # furnished to do so, subject to the following conditions:
243 # furnished to do so, subject to the following conditions:
322 #
244 #
323 # The above copyright notice and this permission notice shall be included in all
245 # The above copyright notice and this permission notice shall be included in all
324 # copies or substantial portions of the Software.
246 # copies or substantial portions of the Software.
325 #
247 #
326 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
248 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
327 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
249 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
328 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
250 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
329 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
251 # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
330 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
252 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
331 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
253 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
332 # SOFTWARE.
254 # SOFTWARE.
333
255
334 def with_metaclass(meta, *bases):
256 def with_metaclass(meta, *bases):
335 """Create a base class with a metaclass."""
257 """Create a base class with a metaclass."""
336 return meta("_NewBase", bases, {})
258 return meta("_NewBase", bases, {})
@@ -1,595 +1,589
1 """Patched version of standard library tokenize, to deal with various bugs.
1 """Patched version of standard library tokenize, to deal with various bugs.
2
2
3 Based on Python 3.2 code.
3 Based on Python 3.2 code.
4
4
5 Patches:
5 Patches:
6
6
7 - Gareth Rees' patch for Python issue #12691 (untokenizing)
7 - Gareth Rees' patch for Python issue #12691 (untokenizing)
8 - Except we don't encode the output of untokenize
8 - Except we don't encode the output of untokenize
9 - Python 2 compatible syntax, so that it can be byte-compiled at installation
9 - Python 2 compatible syntax, so that it can be byte-compiled at installation
10 - Newlines in comments and blank lines should be either NL or NEWLINE, depending
10 - Newlines in comments and blank lines should be either NL or NEWLINE, depending
11 on whether they are in a multi-line statement. Filed as Python issue #17061.
11 on whether they are in a multi-line statement. Filed as Python issue #17061.
12 - Export generate_tokens & TokenError
12 - Export generate_tokens & TokenError
13 - u and rb literals are allowed under Python 3.3 and above.
13 - u and rb literals are allowed under Python 3.3 and above.
14
14
15 ------------------------------------------------------------------------------
15 ------------------------------------------------------------------------------
16 Tokenization help for Python programs.
16 Tokenization help for Python programs.
17
17
18 tokenize(readline) is a generator that breaks a stream of bytes into
18 tokenize(readline) is a generator that breaks a stream of bytes into
19 Python tokens. It decodes the bytes according to PEP-0263 for
19 Python tokens. It decodes the bytes according to PEP-0263 for
20 determining source file encoding.
20 determining source file encoding.
21
21
22 It accepts a readline-like method which is called repeatedly to get the
22 It accepts a readline-like method which is called repeatedly to get the
23 next line of input (or b"" for EOF). It generates 5-tuples with these
23 next line of input (or b"" for EOF). It generates 5-tuples with these
24 members:
24 members:
25
25
26 the token type (see token.py)
26 the token type (see token.py)
27 the token (a string)
27 the token (a string)
28 the starting (row, column) indices of the token (a 2-tuple of ints)
28 the starting (row, column) indices of the token (a 2-tuple of ints)
29 the ending (row, column) indices of the token (a 2-tuple of ints)
29 the ending (row, column) indices of the token (a 2-tuple of ints)
30 the original line (string)
30 the original line (string)
31
31
32 It is designed to match the working of the Python tokenizer exactly, except
32 It is designed to match the working of the Python tokenizer exactly, except
33 that it produces COMMENT tokens for comments and gives type OP for all
33 that it produces COMMENT tokens for comments and gives type OP for all
34 operators. Additionally, all token lists start with an ENCODING token
34 operators. Additionally, all token lists start with an ENCODING token
35 which tells you which encoding was used to decode the bytes stream.
35 which tells you which encoding was used to decode the bytes stream.
36 """
36 """
37
37
38 __author__ = 'Ka-Ping Yee <ping@lfw.org>'
38 __author__ = 'Ka-Ping Yee <ping@lfw.org>'
39 __credits__ = ('GvR, ESR, Tim Peters, Thomas Wouters, Fred Drake, '
39 __credits__ = ('GvR, ESR, Tim Peters, Thomas Wouters, Fred Drake, '
40 'Skip Montanaro, Raymond Hettinger, Trent Nelson, '
40 'Skip Montanaro, Raymond Hettinger, Trent Nelson, '
41 'Michael Foord')
41 'Michael Foord')
42 import builtins
42 import builtins
43 import re
43 import re
44 import sys
44 import sys
45 from token import *
45 from token import *
46 from codecs import lookup, BOM_UTF8
46 from codecs import lookup, BOM_UTF8
47 import collections
47 import collections
48 from io import TextIOWrapper
48 from io import TextIOWrapper
49 cookie_re = re.compile("coding[:=]\s*([-\w.]+)")
49 cookie_re = re.compile("coding[:=]\s*([-\w.]+)")
50
50
51 import token
51 import token
52 __all__ = token.__all__ + ["COMMENT", "tokenize", "detect_encoding",
52 __all__ = token.__all__ + ["COMMENT", "tokenize", "detect_encoding",
53 "NL", "untokenize", "ENCODING", "TokenInfo"]
53 "NL", "untokenize", "ENCODING", "TokenInfo"]
54 del token
54 del token
55
55
56 __all__ += ["generate_tokens", "TokenError"]
56 __all__ += ["generate_tokens", "TokenError"]
57
57
58 COMMENT = N_TOKENS
58 COMMENT = N_TOKENS
59 tok_name[COMMENT] = 'COMMENT'
59 tok_name[COMMENT] = 'COMMENT'
60 NL = N_TOKENS + 1
60 NL = N_TOKENS + 1
61 tok_name[NL] = 'NL'
61 tok_name[NL] = 'NL'
62 ENCODING = N_TOKENS + 2
62 ENCODING = N_TOKENS + 2
63 tok_name[ENCODING] = 'ENCODING'
63 tok_name[ENCODING] = 'ENCODING'
64 N_TOKENS += 3
64 N_TOKENS += 3
65
65
66 class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')):
66 class TokenInfo(collections.namedtuple('TokenInfo', 'type string start end line')):
67 def __repr__(self):
67 def __repr__(self):
68 annotated_type = '%d (%s)' % (self.type, tok_name[self.type])
68 annotated_type = '%d (%s)' % (self.type, tok_name[self.type])
69 return ('TokenInfo(type=%s, string=%r, start=%r, end=%r, line=%r)' %
69 return ('TokenInfo(type=%s, string=%r, start=%r, end=%r, line=%r)' %
70 self._replace(type=annotated_type))
70 self._replace(type=annotated_type))
71
71
72 def group(*choices): return '(' + '|'.join(choices) + ')'
72 def group(*choices): return '(' + '|'.join(choices) + ')'
73 def any(*choices): return group(*choices) + '*'
73 def any(*choices): return group(*choices) + '*'
74 def maybe(*choices): return group(*choices) + '?'
74 def maybe(*choices): return group(*choices) + '?'
75
75
76 # Note: we use unicode matching for names ("\w") but ascii matching for
76 # Note: we use unicode matching for names ("\w") but ascii matching for
77 # number literals.
77 # number literals.
78 Whitespace = r'[ \f\t]*'
78 Whitespace = r'[ \f\t]*'
79 Comment = r'#[^\r\n]*'
79 Comment = r'#[^\r\n]*'
80 Ignore = Whitespace + any(r'\\\r?\n' + Whitespace) + maybe(Comment)
80 Ignore = Whitespace + any(r'\\\r?\n' + Whitespace) + maybe(Comment)
81 Name = r'\w+'
81 Name = r'\w+'
82
82
83 Hexnumber = r'0[xX][0-9a-fA-F]+'
83 Hexnumber = r'0[xX][0-9a-fA-F]+'
84 Binnumber = r'0[bB][01]+'
84 Binnumber = r'0[bB][01]+'
85 Octnumber = r'0[oO][0-7]+'
85 Octnumber = r'0[oO][0-7]+'
86 Decnumber = r'(?:0+|[1-9][0-9]*)'
86 Decnumber = r'(?:0+|[1-9][0-9]*)'
87 Intnumber = group(Hexnumber, Binnumber, Octnumber, Decnumber)
87 Intnumber = group(Hexnumber, Binnumber, Octnumber, Decnumber)
88 Exponent = r'[eE][-+]?[0-9]+'
88 Exponent = r'[eE][-+]?[0-9]+'
89 Pointfloat = group(r'[0-9]+\.[0-9]*', r'\.[0-9]+') + maybe(Exponent)
89 Pointfloat = group(r'[0-9]+\.[0-9]*', r'\.[0-9]+') + maybe(Exponent)
90 Expfloat = r'[0-9]+' + Exponent
90 Expfloat = r'[0-9]+' + Exponent
91 Floatnumber = group(Pointfloat, Expfloat)
91 Floatnumber = group(Pointfloat, Expfloat)
92 Imagnumber = group(r'[0-9]+[jJ]', Floatnumber + r'[jJ]')
92 Imagnumber = group(r'[0-9]+[jJ]', Floatnumber + r'[jJ]')
93 Number = group(Imagnumber, Floatnumber, Intnumber)
93 Number = group(Imagnumber, Floatnumber, Intnumber)
94
94 StringPrefix = r'(?:[bB][rR]?|[rR][bB]?|[uU])?'
95 if sys.version_info.minor >= 3:
96 StringPrefix = r'(?:[bB][rR]?|[rR][bB]?|[uU])?'
97 else:
98 StringPrefix = r'(?:[bB]?[rR]?)?'
99
95
100 # Tail end of ' string.
96 # Tail end of ' string.
101 Single = r"[^'\\]*(?:\\.[^'\\]*)*'"
97 Single = r"[^'\\]*(?:\\.[^'\\]*)*'"
102 # Tail end of " string.
98 # Tail end of " string.
103 Double = r'[^"\\]*(?:\\.[^"\\]*)*"'
99 Double = r'[^"\\]*(?:\\.[^"\\]*)*"'
104 # Tail end of ''' string.
100 # Tail end of ''' string.
105 Single3 = r"[^'\\]*(?:(?:\\.|'(?!''))[^'\\]*)*'''"
101 Single3 = r"[^'\\]*(?:(?:\\.|'(?!''))[^'\\]*)*'''"
106 # Tail end of """ string.
102 # Tail end of """ string.
107 Double3 = r'[^"\\]*(?:(?:\\.|"(?!""))[^"\\]*)*"""'
103 Double3 = r'[^"\\]*(?:(?:\\.|"(?!""))[^"\\]*)*"""'
108 Triple = group(StringPrefix + "'''", StringPrefix + '"""')
104 Triple = group(StringPrefix + "'''", StringPrefix + '"""')
109 # Single-line ' or " string.
105 # Single-line ' or " string.
110 String = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'",
106 String = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*'",
111 StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*"')
107 StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*"')
112
108
113 # Because of leftmost-then-longest match semantics, be sure to put the
109 # Because of leftmost-then-longest match semantics, be sure to put the
114 # longest operators first (e.g., if = came before ==, == would get
110 # longest operators first (e.g., if = came before ==, == would get
115 # recognized as two instances of =).
111 # recognized as two instances of =).
116 Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=",
112 Operator = group(r"\*\*=?", r">>=?", r"<<=?", r"!=",
117 r"//=?", r"->",
113 r"//=?", r"->",
118 r"[+\-*/%&|^=<>]=?",
114 r"[+\-*/%&|^=<>]=?",
119 r"~")
115 r"~")
120
116
121 Bracket = '[][(){}]'
117 Bracket = '[][(){}]'
122 Special = group(r'\r?\n', r'\.\.\.', r'[:;.,@]')
118 Special = group(r'\r?\n', r'\.\.\.', r'[:;.,@]')
123 Funny = group(Operator, Bracket, Special)
119 Funny = group(Operator, Bracket, Special)
124
120
125 PlainToken = group(Number, Funny, String, Name)
121 PlainToken = group(Number, Funny, String, Name)
126 Token = Ignore + PlainToken
122 Token = Ignore + PlainToken
127
123
128 # First (or only) line of ' or " string.
124 # First (or only) line of ' or " string.
129 ContStr = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*" +
125 ContStr = group(StringPrefix + r"'[^\n'\\]*(?:\\.[^\n'\\]*)*" +
130 group("'", r'\\\r?\n'),
126 group("'", r'\\\r?\n'),
131 StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*' +
127 StringPrefix + r'"[^\n"\\]*(?:\\.[^\n"\\]*)*' +
132 group('"', r'\\\r?\n'))
128 group('"', r'\\\r?\n'))
133 PseudoExtras = group(r'\\\r?\n', Comment, Triple)
129 PseudoExtras = group(r'\\\r?\n', Comment, Triple)
134 PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
130 PseudoToken = Whitespace + group(PseudoExtras, Number, Funny, ContStr, Name)
135
131
136 def _compile(expr):
132 def _compile(expr):
137 return re.compile(expr, re.UNICODE)
133 return re.compile(expr, re.UNICODE)
138
134
139 tokenprog, pseudoprog, single3prog, double3prog = map(
135 tokenprog, pseudoprog, single3prog, double3prog = map(
140 _compile, (Token, PseudoToken, Single3, Double3))
136 _compile, (Token, PseudoToken, Single3, Double3))
141 endprogs = {"'": _compile(Single), '"': _compile(Double),
137 endprogs = {"'": _compile(Single), '"': _compile(Double),
142 "'''": single3prog, '"""': double3prog,
138 "'''": single3prog, '"""': double3prog,
143 "r'''": single3prog, 'r"""': double3prog,
139 "r'''": single3prog, 'r"""': double3prog,
144 "b'''": single3prog, 'b"""': double3prog,
140 "b'''": single3prog, 'b"""': double3prog,
145 "R'''": single3prog, 'R"""': double3prog,
141 "R'''": single3prog, 'R"""': double3prog,
146 "B'''": single3prog, 'B"""': double3prog,
142 "B'''": single3prog, 'B"""': double3prog,
147 "br'''": single3prog, 'br"""': double3prog,
143 "br'''": single3prog, 'br"""': double3prog,
148 "bR'''": single3prog, 'bR"""': double3prog,
144 "bR'''": single3prog, 'bR"""': double3prog,
149 "Br'''": single3prog, 'Br"""': double3prog,
145 "Br'''": single3prog, 'Br"""': double3prog,
150 "BR'''": single3prog, 'BR"""': double3prog,
146 "BR'''": single3prog, 'BR"""': double3prog,
151 'r': None, 'R': None, 'b': None, 'B': None}
147 'r': None, 'R': None, 'b': None, 'B': None}
152
148
153 triple_quoted = {}
149 triple_quoted = {}
154 for t in ("'''", '"""',
150 for t in ("'''", '"""',
155 "r'''", 'r"""', "R'''", 'R"""',
151 "r'''", 'r"""', "R'''", 'R"""',
156 "b'''", 'b"""', "B'''", 'B"""',
152 "b'''", 'b"""', "B'''", 'B"""',
157 "br'''", 'br"""', "Br'''", 'Br"""',
153 "br'''", 'br"""', "Br'''", 'Br"""',
158 "bR'''", 'bR"""', "BR'''", 'BR"""'):
154 "bR'''", 'bR"""', "BR'''", 'BR"""'):
159 triple_quoted[t] = t
155 triple_quoted[t] = t
160 single_quoted = {}
156 single_quoted = {}
161 for t in ("'", '"',
157 for t in ("'", '"',
162 "r'", 'r"', "R'", 'R"',
158 "r'", 'r"', "R'", 'R"',
163 "b'", 'b"', "B'", 'B"',
159 "b'", 'b"', "B'", 'B"',
164 "br'", 'br"', "Br'", 'Br"',
160 "br'", 'br"', "Br'", 'Br"',
165 "bR'", 'bR"', "BR'", 'BR"' ):
161 "bR'", 'bR"', "BR'", 'BR"' ):
166 single_quoted[t] = t
162 single_quoted[t] = t
167
163
168 if sys.version_info.minor >= 3:
164 for _prefix in ['rb', 'rB', 'Rb', 'RB', 'u', 'U']:
169 # Python 3.3
165 _t2 = _prefix+'"""'
170 for _prefix in ['rb', 'rB', 'Rb', 'RB', 'u', 'U']:
166 endprogs[_t2] = double3prog
171 _t2 = _prefix+'"""'
167 triple_quoted[_t2] = _t2
172 endprogs[_t2] = double3prog
168 _t1 = _prefix + "'''"
173 triple_quoted[_t2] = _t2
169 endprogs[_t1] = single3prog
174 _t1 = _prefix + "'''"
170 triple_quoted[_t1] = _t1
175 endprogs[_t1] = single3prog
171 single_quoted[_prefix+'"'] = _prefix+'"'
176 triple_quoted[_t1] = _t1
172 single_quoted[_prefix+"'"] = _prefix+"'"
177 single_quoted[_prefix+'"'] = _prefix+'"'
173 del _prefix, _t2, _t1
178 single_quoted[_prefix+"'"] = _prefix+"'"
174 endprogs['u'] = None
179 del _prefix, _t2, _t1
175 endprogs['U'] = None
180 endprogs['u'] = None
181 endprogs['U'] = None
182
176
183 del _compile
177 del _compile
184
178
185 tabsize = 8
179 tabsize = 8
186
180
187 class TokenError(Exception): pass
181 class TokenError(Exception): pass
188
182
189 class StopTokenizing(Exception): pass
183 class StopTokenizing(Exception): pass
190
184
191
185
192 class Untokenizer:
186 class Untokenizer:
193
187
194 def __init__(self):
188 def __init__(self):
195 self.tokens = []
189 self.tokens = []
196 self.prev_row = 1
190 self.prev_row = 1
197 self.prev_col = 0
191 self.prev_col = 0
198 self.encoding = 'utf-8'
192 self.encoding = 'utf-8'
199
193
200 def add_whitespace(self, tok_type, start):
194 def add_whitespace(self, tok_type, start):
201 row, col = start
195 row, col = start
202 assert row >= self.prev_row
196 assert row >= self.prev_row
203 col_offset = col - self.prev_col
197 col_offset = col - self.prev_col
204 if col_offset > 0:
198 if col_offset > 0:
205 self.tokens.append(" " * col_offset)
199 self.tokens.append(" " * col_offset)
206 elif row > self.prev_row and tok_type not in (NEWLINE, NL, ENDMARKER):
200 elif row > self.prev_row and tok_type not in (NEWLINE, NL, ENDMARKER):
207 # Line was backslash-continued.
201 # Line was backslash-continued.
208 self.tokens.append(" ")
202 self.tokens.append(" ")
209
203
210 def untokenize(self, tokens):
204 def untokenize(self, tokens):
211 iterable = iter(tokens)
205 iterable = iter(tokens)
212 for t in iterable:
206 for t in iterable:
213 if len(t) == 2:
207 if len(t) == 2:
214 self.compat(t, iterable)
208 self.compat(t, iterable)
215 break
209 break
216 tok_type, token, start, end = t[:4]
210 tok_type, token, start, end = t[:4]
217 if tok_type == ENCODING:
211 if tok_type == ENCODING:
218 self.encoding = token
212 self.encoding = token
219 continue
213 continue
220 self.add_whitespace(tok_type, start)
214 self.add_whitespace(tok_type, start)
221 self.tokens.append(token)
215 self.tokens.append(token)
222 self.prev_row, self.prev_col = end
216 self.prev_row, self.prev_col = end
223 if tok_type in (NEWLINE, NL):
217 if tok_type in (NEWLINE, NL):
224 self.prev_row += 1
218 self.prev_row += 1
225 self.prev_col = 0
219 self.prev_col = 0
226 return "".join(self.tokens)
220 return "".join(self.tokens)
227
221
228 def compat(self, token, iterable):
222 def compat(self, token, iterable):
229 # This import is here to avoid problems when the itertools
223 # This import is here to avoid problems when the itertools
230 # module is not built yet and tokenize is imported.
224 # module is not built yet and tokenize is imported.
231 from itertools import chain
225 from itertools import chain
232 startline = False
226 startline = False
233 prevstring = False
227 prevstring = False
234 indents = []
228 indents = []
235 toks_append = self.tokens.append
229 toks_append = self.tokens.append
236
230
237 for tok in chain([token], iterable):
231 for tok in chain([token], iterable):
238 toknum, tokval = tok[:2]
232 toknum, tokval = tok[:2]
239 if toknum == ENCODING:
233 if toknum == ENCODING:
240 self.encoding = tokval
234 self.encoding = tokval
241 continue
235 continue
242
236
243 if toknum in (NAME, NUMBER):
237 if toknum in (NAME, NUMBER):
244 tokval += ' '
238 tokval += ' '
245
239
246 # Insert a space between two consecutive strings
240 # Insert a space between two consecutive strings
247 if toknum == STRING:
241 if toknum == STRING:
248 if prevstring:
242 if prevstring:
249 tokval = ' ' + tokval
243 tokval = ' ' + tokval
250 prevstring = True
244 prevstring = True
251 else:
245 else:
252 prevstring = False
246 prevstring = False
253
247
254 if toknum == INDENT:
248 if toknum == INDENT:
255 indents.append(tokval)
249 indents.append(tokval)
256 continue
250 continue
257 elif toknum == DEDENT:
251 elif toknum == DEDENT:
258 indents.pop()
252 indents.pop()
259 continue
253 continue
260 elif toknum in (NEWLINE, NL):
254 elif toknum in (NEWLINE, NL):
261 startline = True
255 startline = True
262 elif startline and indents:
256 elif startline and indents:
263 toks_append(indents[-1])
257 toks_append(indents[-1])
264 startline = False
258 startline = False
265 toks_append(tokval)
259 toks_append(tokval)
266
260
267
261
268 def untokenize(tokens):
262 def untokenize(tokens):
269 """
263 """
270 Convert ``tokens`` (an iterable) back into Python source code. Return
264 Convert ``tokens`` (an iterable) back into Python source code. Return
271 a bytes object, encoded using the encoding specified by the last
265 a bytes object, encoded using the encoding specified by the last
272 ENCODING token in ``tokens``, or UTF-8 if no ENCODING token is found.
266 ENCODING token in ``tokens``, or UTF-8 if no ENCODING token is found.
273
267
274 The result is guaranteed to tokenize back to match the input so that
268 The result is guaranteed to tokenize back to match the input so that
275 the conversion is lossless and round-trips are assured. The
269 the conversion is lossless and round-trips are assured. The
276 guarantee applies only to the token type and token string as the
270 guarantee applies only to the token type and token string as the
277 spacing between tokens (column positions) may change.
271 spacing between tokens (column positions) may change.
278
272
279 :func:`untokenize` has two modes. If the input tokens are sequences
273 :func:`untokenize` has two modes. If the input tokens are sequences
280 of length 2 (``type``, ``string``) then spaces are added as necessary to
274 of length 2 (``type``, ``string``) then spaces are added as necessary to
281 preserve the round-trip property.
275 preserve the round-trip property.
282
276
283 If the input tokens are sequences of length 4 or more (``type``,
277 If the input tokens are sequences of length 4 or more (``type``,
284 ``string``, ``start``, ``end``), as returned by :func:`tokenize`, then
278 ``string``, ``start``, ``end``), as returned by :func:`tokenize`, then
285 spaces are added so that each token appears in the result at the
279 spaces are added so that each token appears in the result at the
286 position indicated by ``start`` and ``end``, if possible.
280 position indicated by ``start`` and ``end``, if possible.
287 """
281 """
288 return Untokenizer().untokenize(tokens)
282 return Untokenizer().untokenize(tokens)
289
283
290
284
291 def _get_normal_name(orig_enc):
285 def _get_normal_name(orig_enc):
292 """Imitates get_normal_name in tokenizer.c."""
286 """Imitates get_normal_name in tokenizer.c."""
293 # Only care about the first 12 characters.
287 # Only care about the first 12 characters.
294 enc = orig_enc[:12].lower().replace("_", "-")
288 enc = orig_enc[:12].lower().replace("_", "-")
295 if enc == "utf-8" or enc.startswith("utf-8-"):
289 if enc == "utf-8" or enc.startswith("utf-8-"):
296 return "utf-8"
290 return "utf-8"
297 if enc in ("latin-1", "iso-8859-1", "iso-latin-1") or \
291 if enc in ("latin-1", "iso-8859-1", "iso-latin-1") or \
298 enc.startswith(("latin-1-", "iso-8859-1-", "iso-latin-1-")):
292 enc.startswith(("latin-1-", "iso-8859-1-", "iso-latin-1-")):
299 return "iso-8859-1"
293 return "iso-8859-1"
300 return orig_enc
294 return orig_enc
301
295
302 def detect_encoding(readline):
296 def detect_encoding(readline):
303 """
297 """
304 The detect_encoding() function is used to detect the encoding that should
298 The detect_encoding() function is used to detect the encoding that should
305 be used to decode a Python source file. It requires one argment, readline,
299 be used to decode a Python source file. It requires one argment, readline,
306 in the same way as the tokenize() generator.
300 in the same way as the tokenize() generator.
307
301
308 It will call readline a maximum of twice, and return the encoding used
302 It will call readline a maximum of twice, and return the encoding used
309 (as a string) and a list of any lines (left as bytes) it has read in.
303 (as a string) and a list of any lines (left as bytes) it has read in.
310
304
311 It detects the encoding from the presence of a utf-8 bom or an encoding
305 It detects the encoding from the presence of a utf-8 bom or an encoding
312 cookie as specified in pep-0263. If both a bom and a cookie are present,
306 cookie as specified in pep-0263. If both a bom and a cookie are present,
313 but disagree, a SyntaxError will be raised. If the encoding cookie is an
307 but disagree, a SyntaxError will be raised. If the encoding cookie is an
314 invalid charset, raise a SyntaxError. Note that if a utf-8 bom is found,
308 invalid charset, raise a SyntaxError. Note that if a utf-8 bom is found,
315 'utf-8-sig' is returned.
309 'utf-8-sig' is returned.
316
310
317 If no encoding is specified, then the default of 'utf-8' will be returned.
311 If no encoding is specified, then the default of 'utf-8' will be returned.
318 """
312 """
319 bom_found = False
313 bom_found = False
320 encoding = None
314 encoding = None
321 default = 'utf-8'
315 default = 'utf-8'
322 def read_or_stop():
316 def read_or_stop():
323 try:
317 try:
324 return readline()
318 return readline()
325 except StopIteration:
319 except StopIteration:
326 return b''
320 return b''
327
321
328 def find_cookie(line):
322 def find_cookie(line):
329 try:
323 try:
330 # Decode as UTF-8. Either the line is an encoding declaration,
324 # Decode as UTF-8. Either the line is an encoding declaration,
331 # in which case it should be pure ASCII, or it must be UTF-8
325 # in which case it should be pure ASCII, or it must be UTF-8
332 # per default encoding.
326 # per default encoding.
333 line_string = line.decode('utf-8')
327 line_string = line.decode('utf-8')
334 except UnicodeDecodeError:
328 except UnicodeDecodeError:
335 raise SyntaxError("invalid or missing encoding declaration")
329 raise SyntaxError("invalid or missing encoding declaration")
336
330
337 matches = cookie_re.findall(line_string)
331 matches = cookie_re.findall(line_string)
338 if not matches:
332 if not matches:
339 return None
333 return None
340 encoding = _get_normal_name(matches[0])
334 encoding = _get_normal_name(matches[0])
341 try:
335 try:
342 codec = lookup(encoding)
336 codec = lookup(encoding)
343 except LookupError:
337 except LookupError:
344 # This behaviour mimics the Python interpreter
338 # This behaviour mimics the Python interpreter
345 raise SyntaxError("unknown encoding: " + encoding)
339 raise SyntaxError("unknown encoding: " + encoding)
346
340
347 if bom_found:
341 if bom_found:
348 if encoding != 'utf-8':
342 if encoding != 'utf-8':
349 # This behaviour mimics the Python interpreter
343 # This behaviour mimics the Python interpreter
350 raise SyntaxError('encoding problem: utf-8')
344 raise SyntaxError('encoding problem: utf-8')
351 encoding += '-sig'
345 encoding += '-sig'
352 return encoding
346 return encoding
353
347
354 first = read_or_stop()
348 first = read_or_stop()
355 if first.startswith(BOM_UTF8):
349 if first.startswith(BOM_UTF8):
356 bom_found = True
350 bom_found = True
357 first = first[3:]
351 first = first[3:]
358 default = 'utf-8-sig'
352 default = 'utf-8-sig'
359 if not first:
353 if not first:
360 return default, []
354 return default, []
361
355
362 encoding = find_cookie(first)
356 encoding = find_cookie(first)
363 if encoding:
357 if encoding:
364 return encoding, [first]
358 return encoding, [first]
365
359
366 second = read_or_stop()
360 second = read_or_stop()
367 if not second:
361 if not second:
368 return default, [first]
362 return default, [first]
369
363
370 encoding = find_cookie(second)
364 encoding = find_cookie(second)
371 if encoding:
365 if encoding:
372 return encoding, [first, second]
366 return encoding, [first, second]
373
367
374 return default, [first, second]
368 return default, [first, second]
375
369
376
370
377 def open(filename):
371 def open(filename):
378 """Open a file in read only mode using the encoding detected by
372 """Open a file in read only mode using the encoding detected by
379 detect_encoding().
373 detect_encoding().
380 """
374 """
381 buffer = builtins.open(filename, 'rb')
375 buffer = builtins.open(filename, 'rb')
382 encoding, lines = detect_encoding(buffer.readline)
376 encoding, lines = detect_encoding(buffer.readline)
383 buffer.seek(0)
377 buffer.seek(0)
384 text = TextIOWrapper(buffer, encoding, line_buffering=True)
378 text = TextIOWrapper(buffer, encoding, line_buffering=True)
385 text.mode = 'r'
379 text.mode = 'r'
386 return text
380 return text
387
381
388
382
389 def tokenize(readline):
383 def tokenize(readline):
390 """
384 """
391 The tokenize() generator requires one argment, readline, which
385 The tokenize() generator requires one argment, readline, which
392 must be a callable object which provides the same interface as the
386 must be a callable object which provides the same interface as the
393 readline() method of built-in file objects. Each call to the function
387 readline() method of built-in file objects. Each call to the function
394 should return one line of input as bytes. Alternately, readline
388 should return one line of input as bytes. Alternately, readline
395 can be a callable function terminating with :class:`StopIteration`::
389 can be a callable function terminating with :class:`StopIteration`::
396
390
397 readline = open(myfile, 'rb').__next__ # Example of alternate readline
391 readline = open(myfile, 'rb').__next__ # Example of alternate readline
398
392
399 The generator produces 5-tuples with these members: the token type; the
393 The generator produces 5-tuples with these members: the token type; the
400 token string; a 2-tuple (srow, scol) of ints specifying the row and
394 token string; a 2-tuple (srow, scol) of ints specifying the row and
401 column where the token begins in the source; a 2-tuple (erow, ecol) of
395 column where the token begins in the source; a 2-tuple (erow, ecol) of
402 ints specifying the row and column where the token ends in the source;
396 ints specifying the row and column where the token ends in the source;
403 and the line on which the token was found. The line passed is the
397 and the line on which the token was found. The line passed is the
404 logical line; continuation lines are included.
398 logical line; continuation lines are included.
405
399
406 The first token sequence will always be an ENCODING token
400 The first token sequence will always be an ENCODING token
407 which tells you which encoding was used to decode the bytes stream.
401 which tells you which encoding was used to decode the bytes stream.
408 """
402 """
409 # This import is here to avoid problems when the itertools module is not
403 # This import is here to avoid problems when the itertools module is not
410 # built yet and tokenize is imported.
404 # built yet and tokenize is imported.
411 from itertools import chain, repeat
405 from itertools import chain, repeat
412 encoding, consumed = detect_encoding(readline)
406 encoding, consumed = detect_encoding(readline)
413 rl_gen = iter(readline, b"")
407 rl_gen = iter(readline, b"")
414 empty = repeat(b"")
408 empty = repeat(b"")
415 return _tokenize(chain(consumed, rl_gen, empty).__next__, encoding)
409 return _tokenize(chain(consumed, rl_gen, empty).__next__, encoding)
416
410
417
411
418 def _tokenize(readline, encoding):
412 def _tokenize(readline, encoding):
419 lnum = parenlev = continued = 0
413 lnum = parenlev = continued = 0
420 numchars = '0123456789'
414 numchars = '0123456789'
421 contstr, needcont = '', 0
415 contstr, needcont = '', 0
422 contline = None
416 contline = None
423 indents = [0]
417 indents = [0]
424
418
425 if encoding is not None:
419 if encoding is not None:
426 if encoding == "utf-8-sig":
420 if encoding == "utf-8-sig":
427 # BOM will already have been stripped.
421 # BOM will already have been stripped.
428 encoding = "utf-8"
422 encoding = "utf-8"
429 yield TokenInfo(ENCODING, encoding, (0, 0), (0, 0), '')
423 yield TokenInfo(ENCODING, encoding, (0, 0), (0, 0), '')
430 while True: # loop over lines in stream
424 while True: # loop over lines in stream
431 try:
425 try:
432 line = readline()
426 line = readline()
433 except StopIteration:
427 except StopIteration:
434 line = b''
428 line = b''
435
429
436 if encoding is not None:
430 if encoding is not None:
437 line = line.decode(encoding)
431 line = line.decode(encoding)
438 lnum += 1
432 lnum += 1
439 pos, max = 0, len(line)
433 pos, max = 0, len(line)
440
434
441 if contstr: # continued string
435 if contstr: # continued string
442 if not line:
436 if not line:
443 raise TokenError("EOF in multi-line string", strstart)
437 raise TokenError("EOF in multi-line string", strstart)
444 endmatch = endprog.match(line)
438 endmatch = endprog.match(line)
445 if endmatch:
439 if endmatch:
446 pos = end = endmatch.end(0)
440 pos = end = endmatch.end(0)
447 yield TokenInfo(STRING, contstr + line[:end],
441 yield TokenInfo(STRING, contstr + line[:end],
448 strstart, (lnum, end), contline + line)
442 strstart, (lnum, end), contline + line)
449 contstr, needcont = '', 0
443 contstr, needcont = '', 0
450 contline = None
444 contline = None
451 elif needcont and line[-2:] != '\\\n' and line[-3:] != '\\\r\n':
445 elif needcont and line[-2:] != '\\\n' and line[-3:] != '\\\r\n':
452 yield TokenInfo(ERRORTOKEN, contstr + line,
446 yield TokenInfo(ERRORTOKEN, contstr + line,
453 strstart, (lnum, len(line)), contline)
447 strstart, (lnum, len(line)), contline)
454 contstr = ''
448 contstr = ''
455 contline = None
449 contline = None
456 continue
450 continue
457 else:
451 else:
458 contstr = contstr + line
452 contstr = contstr + line
459 contline = contline + line
453 contline = contline + line
460 continue
454 continue
461
455
462 elif parenlev == 0 and not continued: # new statement
456 elif parenlev == 0 and not continued: # new statement
463 if not line: break
457 if not line: break
464 column = 0
458 column = 0
465 while pos < max: # measure leading whitespace
459 while pos < max: # measure leading whitespace
466 if line[pos] == ' ':
460 if line[pos] == ' ':
467 column += 1
461 column += 1
468 elif line[pos] == '\t':
462 elif line[pos] == '\t':
469 column = (column//tabsize + 1)*tabsize
463 column = (column//tabsize + 1)*tabsize
470 elif line[pos] == '\f':
464 elif line[pos] == '\f':
471 column = 0
465 column = 0
472 else:
466 else:
473 break
467 break
474 pos += 1
468 pos += 1
475 if pos == max:
469 if pos == max:
476 break
470 break
477
471
478 if line[pos] in '#\r\n': # skip comments or blank lines
472 if line[pos] in '#\r\n': # skip comments or blank lines
479 if line[pos] == '#':
473 if line[pos] == '#':
480 comment_token = line[pos:].rstrip('\r\n')
474 comment_token = line[pos:].rstrip('\r\n')
481 nl_pos = pos + len(comment_token)
475 nl_pos = pos + len(comment_token)
482 yield TokenInfo(COMMENT, comment_token,
476 yield TokenInfo(COMMENT, comment_token,
483 (lnum, pos), (lnum, pos + len(comment_token)), line)
477 (lnum, pos), (lnum, pos + len(comment_token)), line)
484 yield TokenInfo(NEWLINE, line[nl_pos:],
478 yield TokenInfo(NEWLINE, line[nl_pos:],
485 (lnum, nl_pos), (lnum, len(line)), line)
479 (lnum, nl_pos), (lnum, len(line)), line)
486 else:
480 else:
487 yield TokenInfo(NEWLINE, line[pos:],
481 yield TokenInfo(NEWLINE, line[pos:],
488 (lnum, pos), (lnum, len(line)), line)
482 (lnum, pos), (lnum, len(line)), line)
489 continue
483 continue
490
484
491 if column > indents[-1]: # count indents or dedents
485 if column > indents[-1]: # count indents or dedents
492 indents.append(column)
486 indents.append(column)
493 yield TokenInfo(INDENT, line[:pos], (lnum, 0), (lnum, pos), line)
487 yield TokenInfo(INDENT, line[:pos], (lnum, 0), (lnum, pos), line)
494 while column < indents[-1]:
488 while column < indents[-1]:
495 if column not in indents:
489 if column not in indents:
496 raise IndentationError(
490 raise IndentationError(
497 "unindent does not match any outer indentation level",
491 "unindent does not match any outer indentation level",
498 ("<tokenize>", lnum, pos, line))
492 ("<tokenize>", lnum, pos, line))
499 indents = indents[:-1]
493 indents = indents[:-1]
500 yield TokenInfo(DEDENT, '', (lnum, pos), (lnum, pos), line)
494 yield TokenInfo(DEDENT, '', (lnum, pos), (lnum, pos), line)
501
495
502 else: # continued statement
496 else: # continued statement
503 if not line:
497 if not line:
504 raise TokenError("EOF in multi-line statement", (lnum, 0))
498 raise TokenError("EOF in multi-line statement", (lnum, 0))
505 continued = 0
499 continued = 0
506
500
507 while pos < max:
501 while pos < max:
508 pseudomatch = pseudoprog.match(line, pos)
502 pseudomatch = pseudoprog.match(line, pos)
509 if pseudomatch: # scan for tokens
503 if pseudomatch: # scan for tokens
510 start, end = pseudomatch.span(1)
504 start, end = pseudomatch.span(1)
511 spos, epos, pos = (lnum, start), (lnum, end), end
505 spos, epos, pos = (lnum, start), (lnum, end), end
512 token, initial = line[start:end], line[start]
506 token, initial = line[start:end], line[start]
513
507
514 if (initial in numchars or # ordinary number
508 if (initial in numchars or # ordinary number
515 (initial == '.' and token != '.' and token != '...')):
509 (initial == '.' and token != '.' and token != '...')):
516 yield TokenInfo(NUMBER, token, spos, epos, line)
510 yield TokenInfo(NUMBER, token, spos, epos, line)
517 elif initial in '\r\n':
511 elif initial in '\r\n':
518 yield TokenInfo(NL if parenlev > 0 else NEWLINE,
512 yield TokenInfo(NL if parenlev > 0 else NEWLINE,
519 token, spos, epos, line)
513 token, spos, epos, line)
520 elif initial == '#':
514 elif initial == '#':
521 assert not token.endswith("\n")
515 assert not token.endswith("\n")
522 yield TokenInfo(COMMENT, token, spos, epos, line)
516 yield TokenInfo(COMMENT, token, spos, epos, line)
523 elif token in triple_quoted:
517 elif token in triple_quoted:
524 endprog = endprogs[token]
518 endprog = endprogs[token]
525 endmatch = endprog.match(line, pos)
519 endmatch = endprog.match(line, pos)
526 if endmatch: # all on one line
520 if endmatch: # all on one line
527 pos = endmatch.end(0)
521 pos = endmatch.end(0)
528 token = line[start:pos]
522 token = line[start:pos]
529 yield TokenInfo(STRING, token, spos, (lnum, pos), line)
523 yield TokenInfo(STRING, token, spos, (lnum, pos), line)
530 else:
524 else:
531 strstart = (lnum, start) # multiple lines
525 strstart = (lnum, start) # multiple lines
532 contstr = line[start:]
526 contstr = line[start:]
533 contline = line
527 contline = line
534 break
528 break
535 elif initial in single_quoted or \
529 elif initial in single_quoted or \
536 token[:2] in single_quoted or \
530 token[:2] in single_quoted or \
537 token[:3] in single_quoted:
531 token[:3] in single_quoted:
538 if token[-1] == '\n': # continued string
532 if token[-1] == '\n': # continued string
539 strstart = (lnum, start)
533 strstart = (lnum, start)
540 endprog = (endprogs[initial] or endprogs[token[1]] or
534 endprog = (endprogs[initial] or endprogs[token[1]] or
541 endprogs[token[2]])
535 endprogs[token[2]])
542 contstr, needcont = line[start:], 1
536 contstr, needcont = line[start:], 1
543 contline = line
537 contline = line
544 break
538 break
545 else: # ordinary string
539 else: # ordinary string
546 yield TokenInfo(STRING, token, spos, epos, line)
540 yield TokenInfo(STRING, token, spos, epos, line)
547 elif initial.isidentifier(): # ordinary name
541 elif initial.isidentifier(): # ordinary name
548 yield TokenInfo(NAME, token, spos, epos, line)
542 yield TokenInfo(NAME, token, spos, epos, line)
549 elif initial == '\\': # continued stmt
543 elif initial == '\\': # continued stmt
550 continued = 1
544 continued = 1
551 else:
545 else:
552 if initial in '([{':
546 if initial in '([{':
553 parenlev += 1
547 parenlev += 1
554 elif initial in ')]}':
548 elif initial in ')]}':
555 parenlev -= 1
549 parenlev -= 1
556 yield TokenInfo(OP, token, spos, epos, line)
550 yield TokenInfo(OP, token, spos, epos, line)
557 else:
551 else:
558 yield TokenInfo(ERRORTOKEN, line[pos],
552 yield TokenInfo(ERRORTOKEN, line[pos],
559 (lnum, pos), (lnum, pos+1), line)
553 (lnum, pos), (lnum, pos+1), line)
560 pos += 1
554 pos += 1
561
555
562 for indent in indents[1:]: # pop remaining indent levels
556 for indent in indents[1:]: # pop remaining indent levels
563 yield TokenInfo(DEDENT, '', (lnum, 0), (lnum, 0), '')
557 yield TokenInfo(DEDENT, '', (lnum, 0), (lnum, 0), '')
564 yield TokenInfo(ENDMARKER, '', (lnum, 0), (lnum, 0), '')
558 yield TokenInfo(ENDMARKER, '', (lnum, 0), (lnum, 0), '')
565
559
566
560
567 # An undocumented, backwards compatible, API for all the places in the standard
561 # An undocumented, backwards compatible, API for all the places in the standard
568 # library that expect to be able to use tokenize with strings
562 # library that expect to be able to use tokenize with strings
569 def generate_tokens(readline):
563 def generate_tokens(readline):
570 return _tokenize(readline, None)
564 return _tokenize(readline, None)
571
565
572 if __name__ == "__main__":
566 if __name__ == "__main__":
573 # Quick sanity check
567 # Quick sanity check
574 s = b'''def parseline(self, line):
568 s = b'''def parseline(self, line):
575 """Parse the line into a command name and a string containing
569 """Parse the line into a command name and a string containing
576 the arguments. Returns a tuple containing (command, args, line).
570 the arguments. Returns a tuple containing (command, args, line).
577 'command' and 'args' may be None if the line couldn't be parsed.
571 'command' and 'args' may be None if the line couldn't be parsed.
578 """
572 """
579 line = line.strip()
573 line = line.strip()
580 if not line:
574 if not line:
581 return None, None, line
575 return None, None, line
582 elif line[0] == '?':
576 elif line[0] == '?':
583 line = 'help ' + line[1:]
577 line = 'help ' + line[1:]
584 elif line[0] == '!':
578 elif line[0] == '!':
585 if hasattr(self, 'do_shell'):
579 if hasattr(self, 'do_shell'):
586 line = 'shell ' + line[1:]
580 line = 'shell ' + line[1:]
587 else:
581 else:
588 return None, None, line
582 return None, None, line
589 i, n = 0, len(line)
583 i, n = 0, len(line)
590 while i < n and line[i] in self.identchars: i = i+1
584 while i < n and line[i] in self.identchars: i = i+1
591 cmd, arg = line[:i], line[i:].strip()
585 cmd, arg = line[:i], line[i:].strip()
592 return cmd, arg, line
586 return cmd, arg, line
593 '''
587 '''
594 for tok in tokenize(iter(s.splitlines()).__next__):
588 for tok in tokenize(iter(s.splitlines()).__next__):
595 print(tok)
589 print(tok)
@@ -1,31 +1,32
1 :orphan:
1 :orphan:
2
2
3 Writing code for Python 2 and 3
3 Writing code for Python 2 and 3
4 ===============================
4 ===============================
5
5
6 .. module:: IPython.utils.py3compat
6 .. module:: IPython.utils.py3compat
7 :synopsis: Python 2 & 3 compatibility helpers
7 :synopsis: Python 2 & 3 compatibility helpers
8
8
9
9
10 IPython 6 requires Python 3, so our compatibility module
10 IPython 6 requires Python 3, so our compatibility module
11 ``IPython.utils.py3compat`` is deprecated. In most cases, we recommend you use
11 ``IPython.utils.py3compat`` is deprecated and will be removed in a future
12 the `six module <https://pythonhosted.org/six/>`__ to support compatible code.
12 version. In most cases, we recommend you use the `six module
13 This is widely used by other projects, so it is familiar to many developers and
13 <https://pythonhosted.org/six/>`__ to support compatible code. This is widely
14 thoroughly battle-tested.
14 used by other projects, so it is familiar to many developers and thoroughly
15 battle-tested.
15
16
16 Our ``py3compat`` module provided some more specific unicode conversions than
17 Our ``py3compat`` module provided some more specific unicode conversions than
17 those offered by ``six``. If you want to use these, copy them into your own code
18 those offered by ``six``. If you want to use these, copy them into your own code
18 from IPython 5.x. Do not rely on importing them from IPython, as the module may
19 from IPython 5.x. Do not rely on importing them from IPython, as the module may
19 be removed in the future.
20 be removed in the future.
20
21
21 .. seealso::
22 .. seealso::
22
23
23 `Porting Python 2 code to Python 3 <https://docs.python.org/3/howto/pyporting.html>`_
24 `Porting Python 2 code to Python 3 <https://docs.python.org/3/howto/pyporting.html>`_
24 Official information in the Python docs.
25 Official information in the Python docs.
25
26
26 `Python-Modernize <http://python-modernize.readthedocs.io/en/latest/>`_
27 `Python-Modernize <http://python-modernize.readthedocs.io/en/latest/>`_
27 A tool which helps make code compatible with Python 3.
28 A tool which helps make code compatible with Python 3.
28
29
29 `Python-Future <http://python-future.org/>`_
30 `Python-Future <http://python-future.org/>`_
30 Another compatibility tool, which focuses on writing code for Python 3 and
31 Another compatibility tool, which focuses on writing code for Python 3 and
31 making it work on Python 2.
32 making it work on Python 2.
@@ -1,468 +1,462
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 This module defines the things that are used in setup.py for building IPython
3 This module defines the things that are used in setup.py for building IPython
4
4
5 This includes:
5 This includes:
6
6
7 * The basic arguments to setup
7 * The basic arguments to setup
8 * Functions for finding things like packages, package data, etc.
8 * Functions for finding things like packages, package data, etc.
9 * A function for checking dependencies.
9 * A function for checking dependencies.
10 """
10 """
11
11
12 # Copyright (c) IPython Development Team.
12 # Copyright (c) IPython Development Team.
13 # Distributed under the terms of the Modified BSD License.
13 # Distributed under the terms of the Modified BSD License.
14
14
15
15
16 import re
16 import re
17 import os
17 import os
18 import sys
18 import sys
19
19
20 from distutils import log
20 from distutils import log
21 from distutils.command.build_py import build_py
21 from distutils.command.build_py import build_py
22 from distutils.command.build_scripts import build_scripts
22 from distutils.command.build_scripts import build_scripts
23 from distutils.command.install import install
23 from distutils.command.install import install
24 from distutils.command.install_scripts import install_scripts
24 from distutils.command.install_scripts import install_scripts
25 from distutils.cmd import Command
25 from distutils.cmd import Command
26 from glob import glob
26 from glob import glob
27
27
28 from setupext import install_data_ext
28 from setupext import install_data_ext
29
29
30 #-------------------------------------------------------------------------------
30 #-------------------------------------------------------------------------------
31 # Useful globals and utility functions
31 # Useful globals and utility functions
32 #-------------------------------------------------------------------------------
32 #-------------------------------------------------------------------------------
33
33
34 # A few handy globals
34 # A few handy globals
35 isfile = os.path.isfile
35 isfile = os.path.isfile
36 pjoin = os.path.join
36 pjoin = os.path.join
37 repo_root = os.path.dirname(os.path.abspath(__file__))
37 repo_root = os.path.dirname(os.path.abspath(__file__))
38
38
39 def oscmd(s):
39 def oscmd(s):
40 print(">", s)
40 print(">", s)
41 os.system(s)
41 os.system(s)
42
42
43 # Py3 compatibility hacks, without assuming IPython itself is installed with
43 def execfile(fname, globs, locs=None):
44 # the full py3compat machinery.
44 locs = locs or globs
45
45 exec(compile(open(fname).read(), fname, "exec"), globs, locs)
46 try:
47 execfile
48 except NameError:
49 def execfile(fname, globs, locs=None):
50 locs = locs or globs
51 exec(compile(open(fname).read(), fname, "exec"), globs, locs)
52
46
53 # A little utility we'll need below, since glob() does NOT allow you to do
47 # A little utility we'll need below, since glob() does NOT allow you to do
54 # exclusion on multiple endings!
48 # exclusion on multiple endings!
55 def file_doesnt_endwith(test,endings):
49 def file_doesnt_endwith(test,endings):
56 """Return true if test is a file and its name does NOT end with any
50 """Return true if test is a file and its name does NOT end with any
57 of the strings listed in endings."""
51 of the strings listed in endings."""
58 if not isfile(test):
52 if not isfile(test):
59 return False
53 return False
60 for e in endings:
54 for e in endings:
61 if test.endswith(e):
55 if test.endswith(e):
62 return False
56 return False
63 return True
57 return True
64
58
65 #---------------------------------------------------------------------------
59 #---------------------------------------------------------------------------
66 # Basic project information
60 # Basic project information
67 #---------------------------------------------------------------------------
61 #---------------------------------------------------------------------------
68
62
69 # release.py contains version, authors, license, url, keywords, etc.
63 # release.py contains version, authors, license, url, keywords, etc.
70 execfile(pjoin(repo_root, 'IPython','core','release.py'), globals())
64 execfile(pjoin(repo_root, 'IPython','core','release.py'), globals())
71
65
72 # Create a dict with the basic information
66 # Create a dict with the basic information
73 # This dict is eventually passed to setup after additional keys are added.
67 # This dict is eventually passed to setup after additional keys are added.
74 setup_args = dict(
68 setup_args = dict(
75 name = name,
69 name = name,
76 version = version,
70 version = version,
77 description = description,
71 description = description,
78 long_description = long_description,
72 long_description = long_description,
79 author = author,
73 author = author,
80 author_email = author_email,
74 author_email = author_email,
81 url = url,
75 url = url,
82 license = license,
76 license = license,
83 platforms = platforms,
77 platforms = platforms,
84 keywords = keywords,
78 keywords = keywords,
85 classifiers = classifiers,
79 classifiers = classifiers,
86 cmdclass = {'install_data': install_data_ext},
80 cmdclass = {'install_data': install_data_ext},
87 )
81 )
88
82
89
83
90 #---------------------------------------------------------------------------
84 #---------------------------------------------------------------------------
91 # Find packages
85 # Find packages
92 #---------------------------------------------------------------------------
86 #---------------------------------------------------------------------------
93
87
94 def find_packages():
88 def find_packages():
95 """
89 """
96 Find all of IPython's packages.
90 Find all of IPython's packages.
97 """
91 """
98 excludes = ['deathrow', 'quarantine']
92 excludes = ['deathrow', 'quarantine']
99 packages = []
93 packages = []
100 for dir,subdirs,files in os.walk('IPython'):
94 for dir,subdirs,files in os.walk('IPython'):
101 package = dir.replace(os.path.sep, '.')
95 package = dir.replace(os.path.sep, '.')
102 if any(package.startswith('IPython.'+exc) for exc in excludes):
96 if any(package.startswith('IPython.'+exc) for exc in excludes):
103 # package is to be excluded (e.g. deathrow)
97 # package is to be excluded (e.g. deathrow)
104 continue
98 continue
105 if '__init__.py' not in files:
99 if '__init__.py' not in files:
106 # not a package
100 # not a package
107 continue
101 continue
108 packages.append(package)
102 packages.append(package)
109 return packages
103 return packages
110
104
111 #---------------------------------------------------------------------------
105 #---------------------------------------------------------------------------
112 # Find package data
106 # Find package data
113 #---------------------------------------------------------------------------
107 #---------------------------------------------------------------------------
114
108
115 def find_package_data():
109 def find_package_data():
116 """
110 """
117 Find IPython's package_data.
111 Find IPython's package_data.
118 """
112 """
119 # This is not enough for these things to appear in an sdist.
113 # This is not enough for these things to appear in an sdist.
120 # We need to muck with the MANIFEST to get this to work
114 # We need to muck with the MANIFEST to get this to work
121
115
122 package_data = {
116 package_data = {
123 'IPython.core' : ['profile/README*'],
117 'IPython.core' : ['profile/README*'],
124 'IPython.core.tests' : ['*.png', '*.jpg', 'daft_extension/*.py'],
118 'IPython.core.tests' : ['*.png', '*.jpg', 'daft_extension/*.py'],
125 'IPython.lib.tests' : ['*.wav'],
119 'IPython.lib.tests' : ['*.wav'],
126 'IPython.testing.plugin' : ['*.txt'],
120 'IPython.testing.plugin' : ['*.txt'],
127 }
121 }
128
122
129 return package_data
123 return package_data
130
124
131
125
132 def check_package_data(package_data):
126 def check_package_data(package_data):
133 """verify that package_data globs make sense"""
127 """verify that package_data globs make sense"""
134 print("checking package data")
128 print("checking package data")
135 for pkg, data in package_data.items():
129 for pkg, data in package_data.items():
136 pkg_root = pjoin(*pkg.split('.'))
130 pkg_root = pjoin(*pkg.split('.'))
137 for d in data:
131 for d in data:
138 path = pjoin(pkg_root, d)
132 path = pjoin(pkg_root, d)
139 if '*' in path:
133 if '*' in path:
140 assert len(glob(path)) > 0, "No files match pattern %s" % path
134 assert len(glob(path)) > 0, "No files match pattern %s" % path
141 else:
135 else:
142 assert os.path.exists(path), "Missing package data: %s" % path
136 assert os.path.exists(path), "Missing package data: %s" % path
143
137
144
138
145 def check_package_data_first(command):
139 def check_package_data_first(command):
146 """decorator for checking package_data before running a given command
140 """decorator for checking package_data before running a given command
147
141
148 Probably only needs to wrap build_py
142 Probably only needs to wrap build_py
149 """
143 """
150 class DecoratedCommand(command):
144 class DecoratedCommand(command):
151 def run(self):
145 def run(self):
152 check_package_data(self.package_data)
146 check_package_data(self.package_data)
153 command.run(self)
147 command.run(self)
154 return DecoratedCommand
148 return DecoratedCommand
155
149
156
150
157 #---------------------------------------------------------------------------
151 #---------------------------------------------------------------------------
158 # Find data files
152 # Find data files
159 #---------------------------------------------------------------------------
153 #---------------------------------------------------------------------------
160
154
161 def make_dir_struct(tag,base,out_base):
155 def make_dir_struct(tag,base,out_base):
162 """Make the directory structure of all files below a starting dir.
156 """Make the directory structure of all files below a starting dir.
163
157
164 This is just a convenience routine to help build a nested directory
158 This is just a convenience routine to help build a nested directory
165 hierarchy because distutils is too stupid to do this by itself.
159 hierarchy because distutils is too stupid to do this by itself.
166
160
167 XXX - this needs a proper docstring!
161 XXX - this needs a proper docstring!
168 """
162 """
169
163
170 # we'll use these a lot below
164 # we'll use these a lot below
171 lbase = len(base)
165 lbase = len(base)
172 pathsep = os.path.sep
166 pathsep = os.path.sep
173 lpathsep = len(pathsep)
167 lpathsep = len(pathsep)
174
168
175 out = []
169 out = []
176 for (dirpath,dirnames,filenames) in os.walk(base):
170 for (dirpath,dirnames,filenames) in os.walk(base):
177 # we need to strip out the dirpath from the base to map it to the
171 # we need to strip out the dirpath from the base to map it to the
178 # output (installation) path. This requires possibly stripping the
172 # output (installation) path. This requires possibly stripping the
179 # path separator, because otherwise pjoin will not work correctly
173 # path separator, because otherwise pjoin will not work correctly
180 # (pjoin('foo/','/bar') returns '/bar').
174 # (pjoin('foo/','/bar') returns '/bar').
181
175
182 dp_eff = dirpath[lbase:]
176 dp_eff = dirpath[lbase:]
183 if dp_eff.startswith(pathsep):
177 if dp_eff.startswith(pathsep):
184 dp_eff = dp_eff[lpathsep:]
178 dp_eff = dp_eff[lpathsep:]
185 # The output path must be anchored at the out_base marker
179 # The output path must be anchored at the out_base marker
186 out_path = pjoin(out_base,dp_eff)
180 out_path = pjoin(out_base,dp_eff)
187 # Now we can generate the final filenames. Since os.walk only produces
181 # Now we can generate the final filenames. Since os.walk only produces
188 # filenames, we must join back with the dirpath to get full valid file
182 # filenames, we must join back with the dirpath to get full valid file
189 # paths:
183 # paths:
190 pfiles = [pjoin(dirpath,f) for f in filenames]
184 pfiles = [pjoin(dirpath,f) for f in filenames]
191 # Finally, generate the entry we need, which is a pari of (output
185 # Finally, generate the entry we need, which is a pari of (output
192 # path, files) for use as a data_files parameter in install_data.
186 # path, files) for use as a data_files parameter in install_data.
193 out.append((out_path, pfiles))
187 out.append((out_path, pfiles))
194
188
195 return out
189 return out
196
190
197
191
198 def find_data_files():
192 def find_data_files():
199 """
193 """
200 Find IPython's data_files.
194 Find IPython's data_files.
201
195
202 Just man pages at this point.
196 Just man pages at this point.
203 """
197 """
204
198
205 manpagebase = pjoin('share', 'man', 'man1')
199 manpagebase = pjoin('share', 'man', 'man1')
206
200
207 # Simple file lists can be made by hand
201 # Simple file lists can be made by hand
208 manpages = [f for f in glob(pjoin('docs','man','*.1.gz')) if isfile(f)]
202 manpages = [f for f in glob(pjoin('docs','man','*.1.gz')) if isfile(f)]
209 if not manpages:
203 if not manpages:
210 # When running from a source tree, the manpages aren't gzipped
204 # When running from a source tree, the manpages aren't gzipped
211 manpages = [f for f in glob(pjoin('docs','man','*.1')) if isfile(f)]
205 manpages = [f for f in glob(pjoin('docs','man','*.1')) if isfile(f)]
212
206
213 # And assemble the entire output list
207 # And assemble the entire output list
214 data_files = [ (manpagebase, manpages) ]
208 data_files = [ (manpagebase, manpages) ]
215
209
216 return data_files
210 return data_files
217
211
218
212
219 def make_man_update_target(manpage):
213 def make_man_update_target(manpage):
220 """Return a target_update-compliant tuple for the given manpage.
214 """Return a target_update-compliant tuple for the given manpage.
221
215
222 Parameters
216 Parameters
223 ----------
217 ----------
224 manpage : string
218 manpage : string
225 Name of the manpage, must include the section number (trailing number).
219 Name of the manpage, must include the section number (trailing number).
226
220
227 Example
221 Example
228 -------
222 -------
229
223
230 >>> make_man_update_target('ipython.1') #doctest: +NORMALIZE_WHITESPACE
224 >>> make_man_update_target('ipython.1') #doctest: +NORMALIZE_WHITESPACE
231 ('docs/man/ipython.1.gz',
225 ('docs/man/ipython.1.gz',
232 ['docs/man/ipython.1'],
226 ['docs/man/ipython.1'],
233 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz')
227 'cd docs/man && gzip -9c ipython.1 > ipython.1.gz')
234 """
228 """
235 man_dir = pjoin('docs', 'man')
229 man_dir = pjoin('docs', 'man')
236 manpage_gz = manpage + '.gz'
230 manpage_gz = manpage + '.gz'
237 manpath = pjoin(man_dir, manpage)
231 manpath = pjoin(man_dir, manpage)
238 manpath_gz = pjoin(man_dir, manpage_gz)
232 manpath_gz = pjoin(man_dir, manpage_gz)
239 gz_cmd = ( "cd %(man_dir)s && gzip -9c %(manpage)s > %(manpage_gz)s" %
233 gz_cmd = ( "cd %(man_dir)s && gzip -9c %(manpage)s > %(manpage_gz)s" %
240 locals() )
234 locals() )
241 return (manpath_gz, [manpath], gz_cmd)
235 return (manpath_gz, [manpath], gz_cmd)
242
236
243 # The two functions below are copied from IPython.utils.path, so we don't need
237 # The two functions below are copied from IPython.utils.path, so we don't need
244 # to import IPython during setup, which fails on Python 3.
238 # to import IPython during setup, which fails on Python 3.
245
239
246 def target_outdated(target,deps):
240 def target_outdated(target,deps):
247 """Determine whether a target is out of date.
241 """Determine whether a target is out of date.
248
242
249 target_outdated(target,deps) -> 1/0
243 target_outdated(target,deps) -> 1/0
250
244
251 deps: list of filenames which MUST exist.
245 deps: list of filenames which MUST exist.
252 target: single filename which may or may not exist.
246 target: single filename which may or may not exist.
253
247
254 If target doesn't exist or is older than any file listed in deps, return
248 If target doesn't exist or is older than any file listed in deps, return
255 true, otherwise return false.
249 true, otherwise return false.
256 """
250 """
257 try:
251 try:
258 target_time = os.path.getmtime(target)
252 target_time = os.path.getmtime(target)
259 except os.error:
253 except os.error:
260 return 1
254 return 1
261 for dep in deps:
255 for dep in deps:
262 dep_time = os.path.getmtime(dep)
256 dep_time = os.path.getmtime(dep)
263 if dep_time > target_time:
257 if dep_time > target_time:
264 #print "For target",target,"Dep failed:",dep # dbg
258 #print "For target",target,"Dep failed:",dep # dbg
265 #print "times (dep,tar):",dep_time,target_time # dbg
259 #print "times (dep,tar):",dep_time,target_time # dbg
266 return 1
260 return 1
267 return 0
261 return 0
268
262
269
263
270 def target_update(target,deps,cmd):
264 def target_update(target,deps,cmd):
271 """Update a target with a given command given a list of dependencies.
265 """Update a target with a given command given a list of dependencies.
272
266
273 target_update(target,deps,cmd) -> runs cmd if target is outdated.
267 target_update(target,deps,cmd) -> runs cmd if target is outdated.
274
268
275 This is just a wrapper around target_outdated() which calls the given
269 This is just a wrapper around target_outdated() which calls the given
276 command if target is outdated."""
270 command if target is outdated."""
277
271
278 if target_outdated(target,deps):
272 if target_outdated(target,deps):
279 os.system(cmd)
273 os.system(cmd)
280
274
281 #---------------------------------------------------------------------------
275 #---------------------------------------------------------------------------
282 # Find scripts
276 # Find scripts
283 #---------------------------------------------------------------------------
277 #---------------------------------------------------------------------------
284
278
285 def find_entry_points():
279 def find_entry_points():
286 """Defines the command line entry points for IPython
280 """Defines the command line entry points for IPython
287
281
288 This always uses setuptools-style entry points. When setuptools is not in
282 This always uses setuptools-style entry points. When setuptools is not in
289 use, our own build_scripts_entrypt class below parses these and builds
283 use, our own build_scripts_entrypt class below parses these and builds
290 command line scripts.
284 command line scripts.
291
285
292 Each of our entry points gets both a plain name, e.g. ipython, and one
286 Each of our entry points gets both a plain name, e.g. ipython, and one
293 suffixed with the Python major version number, e.g. ipython3.
287 suffixed with the Python major version number, e.g. ipython3.
294 """
288 """
295 ep = [
289 ep = [
296 'ipython%s = IPython:start_ipython',
290 'ipython%s = IPython:start_ipython',
297 'iptest%s = IPython.testing.iptestcontroller:main',
291 'iptest%s = IPython.testing.iptestcontroller:main',
298 ]
292 ]
299 suffix = str(sys.version_info[0])
293 suffix = str(sys.version_info[0])
300 return [e % '' for e in ep] + [e % suffix for e in ep]
294 return [e % '' for e in ep] + [e % suffix for e in ep]
301
295
302 script_src = """#!{executable}
296 script_src = """#!{executable}
303 # This script was automatically generated by setup.py
297 # This script was automatically generated by setup.py
304 if __name__ == '__main__':
298 if __name__ == '__main__':
305 from {mod} import {func}
299 from {mod} import {func}
306 {func}()
300 {func}()
307 """
301 """
308
302
309 class build_scripts_entrypt(build_scripts):
303 class build_scripts_entrypt(build_scripts):
310 """Build the command line scripts
304 """Build the command line scripts
311
305
312 Parse setuptools style entry points and write simple scripts to run the
306 Parse setuptools style entry points and write simple scripts to run the
313 target functions.
307 target functions.
314
308
315 On Windows, this also creates .cmd wrappers for the scripts so that you can
309 On Windows, this also creates .cmd wrappers for the scripts so that you can
316 easily launch them from a command line.
310 easily launch them from a command line.
317 """
311 """
318 def run(self):
312 def run(self):
319 self.mkpath(self.build_dir)
313 self.mkpath(self.build_dir)
320 outfiles = []
314 outfiles = []
321 for script in find_entry_points():
315 for script in find_entry_points():
322 name, entrypt = script.split('=')
316 name, entrypt = script.split('=')
323 name = name.strip()
317 name = name.strip()
324 entrypt = entrypt.strip()
318 entrypt = entrypt.strip()
325 outfile = os.path.join(self.build_dir, name)
319 outfile = os.path.join(self.build_dir, name)
326 outfiles.append(outfile)
320 outfiles.append(outfile)
327 print('Writing script to', outfile)
321 print('Writing script to', outfile)
328
322
329 mod, func = entrypt.split(':')
323 mod, func = entrypt.split(':')
330 with open(outfile, 'w') as f:
324 with open(outfile, 'w') as f:
331 f.write(script_src.format(executable=sys.executable,
325 f.write(script_src.format(executable=sys.executable,
332 mod=mod, func=func))
326 mod=mod, func=func))
333
327
334 if sys.platform == 'win32':
328 if sys.platform == 'win32':
335 # Write .cmd wrappers for Windows so 'ipython' etc. work at the
329 # Write .cmd wrappers for Windows so 'ipython' etc. work at the
336 # command line
330 # command line
337 cmd_file = os.path.join(self.build_dir, name + '.cmd')
331 cmd_file = os.path.join(self.build_dir, name + '.cmd')
338 cmd = '@"{python}" "%~dp0\{script}" %*\r\n'.format(
332 cmd = '@"{python}" "%~dp0\{script}" %*\r\n'.format(
339 python=sys.executable, script=name)
333 python=sys.executable, script=name)
340 log.info("Writing %s wrapper script" % cmd_file)
334 log.info("Writing %s wrapper script" % cmd_file)
341 with open(cmd_file, 'w') as f:
335 with open(cmd_file, 'w') as f:
342 f.write(cmd)
336 f.write(cmd)
343
337
344 return outfiles, outfiles
338 return outfiles, outfiles
345
339
346 class install_lib_symlink(Command):
340 class install_lib_symlink(Command):
347 user_options = [
341 user_options = [
348 ('install-dir=', 'd', "directory to install to"),
342 ('install-dir=', 'd', "directory to install to"),
349 ]
343 ]
350
344
351 def initialize_options(self):
345 def initialize_options(self):
352 self.install_dir = None
346 self.install_dir = None
353
347
354 def finalize_options(self):
348 def finalize_options(self):
355 self.set_undefined_options('symlink',
349 self.set_undefined_options('symlink',
356 ('install_lib', 'install_dir'),
350 ('install_lib', 'install_dir'),
357 )
351 )
358
352
359 def run(self):
353 def run(self):
360 if sys.platform == 'win32':
354 if sys.platform == 'win32':
361 raise Exception("This doesn't work on Windows.")
355 raise Exception("This doesn't work on Windows.")
362 pkg = os.path.join(os.getcwd(), 'IPython')
356 pkg = os.path.join(os.getcwd(), 'IPython')
363 dest = os.path.join(self.install_dir, 'IPython')
357 dest = os.path.join(self.install_dir, 'IPython')
364 if os.path.islink(dest):
358 if os.path.islink(dest):
365 print('removing existing symlink at %s' % dest)
359 print('removing existing symlink at %s' % dest)
366 os.unlink(dest)
360 os.unlink(dest)
367 print('symlinking %s -> %s' % (pkg, dest))
361 print('symlinking %s -> %s' % (pkg, dest))
368 os.symlink(pkg, dest)
362 os.symlink(pkg, dest)
369
363
370 class unsymlink(install):
364 class unsymlink(install):
371 def run(self):
365 def run(self):
372 dest = os.path.join(self.install_lib, 'IPython')
366 dest = os.path.join(self.install_lib, 'IPython')
373 if os.path.islink(dest):
367 if os.path.islink(dest):
374 print('removing symlink at %s' % dest)
368 print('removing symlink at %s' % dest)
375 os.unlink(dest)
369 os.unlink(dest)
376 else:
370 else:
377 print('No symlink exists at %s' % dest)
371 print('No symlink exists at %s' % dest)
378
372
379 class install_symlinked(install):
373 class install_symlinked(install):
380 def run(self):
374 def run(self):
381 if sys.platform == 'win32':
375 if sys.platform == 'win32':
382 raise Exception("This doesn't work on Windows.")
376 raise Exception("This doesn't work on Windows.")
383
377
384 # Run all sub-commands (at least those that need to be run)
378 # Run all sub-commands (at least those that need to be run)
385 for cmd_name in self.get_sub_commands():
379 for cmd_name in self.get_sub_commands():
386 self.run_command(cmd_name)
380 self.run_command(cmd_name)
387
381
388 # 'sub_commands': a list of commands this command might have to run to
382 # 'sub_commands': a list of commands this command might have to run to
389 # get its work done. See cmd.py for more info.
383 # get its work done. See cmd.py for more info.
390 sub_commands = [('install_lib_symlink', lambda self:True),
384 sub_commands = [('install_lib_symlink', lambda self:True),
391 ('install_scripts_sym', lambda self:True),
385 ('install_scripts_sym', lambda self:True),
392 ]
386 ]
393
387
394 class install_scripts_for_symlink(install_scripts):
388 class install_scripts_for_symlink(install_scripts):
395 """Redefined to get options from 'symlink' instead of 'install'.
389 """Redefined to get options from 'symlink' instead of 'install'.
396
390
397 I love distutils almost as much as I love setuptools.
391 I love distutils almost as much as I love setuptools.
398 """
392 """
399 def finalize_options(self):
393 def finalize_options(self):
400 self.set_undefined_options('build', ('build_scripts', 'build_dir'))
394 self.set_undefined_options('build', ('build_scripts', 'build_dir'))
401 self.set_undefined_options('symlink',
395 self.set_undefined_options('symlink',
402 ('install_scripts', 'install_dir'),
396 ('install_scripts', 'install_dir'),
403 ('force', 'force'),
397 ('force', 'force'),
404 ('skip_build', 'skip_build'),
398 ('skip_build', 'skip_build'),
405 )
399 )
406
400
407
401
408 #---------------------------------------------------------------------------
402 #---------------------------------------------------------------------------
409 # VCS related
403 # VCS related
410 #---------------------------------------------------------------------------
404 #---------------------------------------------------------------------------
411
405
412
406
413 def git_prebuild(pkg_dir, build_cmd=build_py):
407 def git_prebuild(pkg_dir, build_cmd=build_py):
414 """Return extended build or sdist command class for recording commit
408 """Return extended build or sdist command class for recording commit
415
409
416 records git commit in IPython.utils._sysinfo.commit
410 records git commit in IPython.utils._sysinfo.commit
417
411
418 for use in IPython.utils.sysinfo.sys_info() calls after installation.
412 for use in IPython.utils.sysinfo.sys_info() calls after installation.
419 """
413 """
420
414
421 class MyBuildPy(build_cmd):
415 class MyBuildPy(build_cmd):
422 ''' Subclass to write commit data into installation tree '''
416 ''' Subclass to write commit data into installation tree '''
423 def run(self):
417 def run(self):
424 # loose as `.dev` is suppose to be invalid
418 # loose as `.dev` is suppose to be invalid
425 print("check version number")
419 print("check version number")
426 loose_pep440re = re.compile('^(\d+)\.(\d+)\.(\d+((a|b|rc)\d+)?)(\.post\d+)?(\.dev\d*)?$')
420 loose_pep440re = re.compile('^(\d+)\.(\d+)\.(\d+((a|b|rc)\d+)?)(\.post\d+)?(\.dev\d*)?$')
427 if not loose_pep440re.match(version):
421 if not loose_pep440re.match(version):
428 raise ValueError("Version number '%s' is not valid (should match [N!]N(.N)*[{a|b|rc}N][.postN][.devN])" % version)
422 raise ValueError("Version number '%s' is not valid (should match [N!]N(.N)*[{a|b|rc}N][.postN][.devN])" % version)
429
423
430
424
431 build_cmd.run(self)
425 build_cmd.run(self)
432 # this one will only fire for build commands
426 # this one will only fire for build commands
433 if hasattr(self, 'build_lib'):
427 if hasattr(self, 'build_lib'):
434 self._record_commit(self.build_lib)
428 self._record_commit(self.build_lib)
435
429
436 def make_release_tree(self, base_dir, files):
430 def make_release_tree(self, base_dir, files):
437 # this one will fire for sdist
431 # this one will fire for sdist
438 build_cmd.make_release_tree(self, base_dir, files)
432 build_cmd.make_release_tree(self, base_dir, files)
439 self._record_commit(base_dir)
433 self._record_commit(base_dir)
440
434
441 def _record_commit(self, base_dir):
435 def _record_commit(self, base_dir):
442 import subprocess
436 import subprocess
443 proc = subprocess.Popen('git rev-parse --short HEAD',
437 proc = subprocess.Popen('git rev-parse --short HEAD',
444 stdout=subprocess.PIPE,
438 stdout=subprocess.PIPE,
445 stderr=subprocess.PIPE,
439 stderr=subprocess.PIPE,
446 shell=True)
440 shell=True)
447 repo_commit, _ = proc.communicate()
441 repo_commit, _ = proc.communicate()
448 repo_commit = repo_commit.strip().decode("ascii")
442 repo_commit = repo_commit.strip().decode("ascii")
449
443
450 out_pth = pjoin(base_dir, pkg_dir, 'utils', '_sysinfo.py')
444 out_pth = pjoin(base_dir, pkg_dir, 'utils', '_sysinfo.py')
451 if os.path.isfile(out_pth) and not repo_commit:
445 if os.path.isfile(out_pth) and not repo_commit:
452 # nothing to write, don't clobber
446 # nothing to write, don't clobber
453 return
447 return
454
448
455 print("writing git commit '%s' to %s" % (repo_commit, out_pth))
449 print("writing git commit '%s' to %s" % (repo_commit, out_pth))
456
450
457 # remove to avoid overwriting original via hard link
451 # remove to avoid overwriting original via hard link
458 try:
452 try:
459 os.remove(out_pth)
453 os.remove(out_pth)
460 except (IOError, OSError):
454 except (IOError, OSError):
461 pass
455 pass
462 with open(out_pth, 'w') as out_file:
456 with open(out_pth, 'w') as out_file:
463 out_file.writelines([
457 out_file.writelines([
464 '# GENERATED BY setup.py\n',
458 '# GENERATED BY setup.py\n',
465 'commit = u"%s"\n' % repo_commit,
459 'commit = u"%s"\n' % repo_commit,
466 ])
460 ])
467 return MyBuildPy
461 return MyBuildPy
468
462
1 NO CONTENT: file was removed
NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now