##// END OF EJS Templates
Merge pull request #10135 from srinivasreddy/dead_code...
Matthias Bussonnier -
r23101:eadc81ae merge
parent child Browse files
Show More
@@ -1,987 +1,980 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tools for inspecting Python objects.
2 """Tools for inspecting Python objects.
3
3
4 Uses syntax highlighting for presenting the various information elements.
4 Uses syntax highlighting for presenting the various information elements.
5
5
6 Similar in spirit to the inspect module, but all calls take a name argument to
6 Similar in spirit to the inspect module, but all calls take a name argument to
7 reference the name under which an object is being read.
7 reference the name under which an object is being read.
8 """
8 """
9
9
10 # Copyright (c) IPython Development Team.
10 # Copyright (c) IPython Development Team.
11 # Distributed under the terms of the Modified BSD License.
11 # Distributed under the terms of the Modified BSD License.
12
12
13 __all__ = ['Inspector','InspectColors']
13 __all__ = ['Inspector','InspectColors']
14
14
15 # stdlib modules
15 # stdlib modules
16 import inspect
16 import inspect
17 from inspect import signature
17 from inspect import signature
18 import linecache
18 import linecache
19 import warnings
19 import warnings
20 import os
20 import os
21 from textwrap import dedent
21 from textwrap import dedent
22 import types
22 import types
23 import io as stdlib_io
23 import io as stdlib_io
24 from itertools import zip_longest
24 from itertools import zip_longest
25
25
26 # IPython's own
26 # IPython's own
27 from IPython.core import page
27 from IPython.core import page
28 from IPython.lib.pretty import pretty
28 from IPython.lib.pretty import pretty
29 from IPython.testing.skipdoctest import skip_doctest
29 from IPython.testing.skipdoctest import skip_doctest
30 from IPython.utils import PyColorize
30 from IPython.utils import PyColorize
31 from IPython.utils import openpy
31 from IPython.utils import openpy
32 from IPython.utils import py3compat
32 from IPython.utils import py3compat
33 from IPython.utils.dir2 import safe_hasattr
33 from IPython.utils.dir2 import safe_hasattr
34 from IPython.utils.path import compress_user
34 from IPython.utils.path import compress_user
35 from IPython.utils.text import indent
35 from IPython.utils.text import indent
36 from IPython.utils.wildcard import list_namespace
36 from IPython.utils.wildcard import list_namespace
37 from IPython.utils.coloransi import TermColors, ColorScheme, ColorSchemeTable
37 from IPython.utils.coloransi import TermColors, ColorScheme, ColorSchemeTable
38 from IPython.utils.py3compat import cast_unicode, PY3
38 from IPython.utils.py3compat import cast_unicode
39 from IPython.utils.colorable import Colorable
39 from IPython.utils.colorable import Colorable
40 from IPython.utils.decorators import undoc
40 from IPython.utils.decorators import undoc
41
41
42 from pygments import highlight
42 from pygments import highlight
43 from pygments.lexers import PythonLexer
43 from pygments.lexers import PythonLexer
44 from pygments.formatters import HtmlFormatter
44 from pygments.formatters import HtmlFormatter
45
45
46 def pylight(code):
46 def pylight(code):
47 return highlight(code, PythonLexer(), HtmlFormatter(noclasses=True))
47 return highlight(code, PythonLexer(), HtmlFormatter(noclasses=True))
48
48
49 # builtin docstrings to ignore
49 # builtin docstrings to ignore
50 _func_call_docstring = types.FunctionType.__call__.__doc__
50 _func_call_docstring = types.FunctionType.__call__.__doc__
51 _object_init_docstring = object.__init__.__doc__
51 _object_init_docstring = object.__init__.__doc__
52 _builtin_type_docstrings = {
52 _builtin_type_docstrings = {
53 inspect.getdoc(t) for t in (types.ModuleType, types.MethodType,
53 inspect.getdoc(t) for t in (types.ModuleType, types.MethodType,
54 types.FunctionType, property)
54 types.FunctionType, property)
55 }
55 }
56
56
57 _builtin_func_type = type(all)
57 _builtin_func_type = type(all)
58 _builtin_meth_type = type(str.upper) # Bound methods have the same type as builtin functions
58 _builtin_meth_type = type(str.upper) # Bound methods have the same type as builtin functions
59 #****************************************************************************
59 #****************************************************************************
60 # Builtin color schemes
60 # Builtin color schemes
61
61
62 Colors = TermColors # just a shorthand
62 Colors = TermColors # just a shorthand
63
63
64 InspectColors = PyColorize.ANSICodeColors
64 InspectColors = PyColorize.ANSICodeColors
65
65
66 #****************************************************************************
66 #****************************************************************************
67 # Auxiliary functions and objects
67 # Auxiliary functions and objects
68
68
69 # See the messaging spec for the definition of all these fields. This list
69 # See the messaging spec for the definition of all these fields. This list
70 # effectively defines the order of display
70 # effectively defines the order of display
71 info_fields = ['type_name', 'base_class', 'string_form', 'namespace',
71 info_fields = ['type_name', 'base_class', 'string_form', 'namespace',
72 'length', 'file', 'definition', 'docstring', 'source',
72 'length', 'file', 'definition', 'docstring', 'source',
73 'init_definition', 'class_docstring', 'init_docstring',
73 'init_definition', 'class_docstring', 'init_docstring',
74 'call_def', 'call_docstring',
74 'call_def', 'call_docstring',
75 # These won't be printed but will be used to determine how to
75 # These won't be printed but will be used to determine how to
76 # format the object
76 # format the object
77 'ismagic', 'isalias', 'isclass', 'argspec', 'found', 'name'
77 'ismagic', 'isalias', 'isclass', 'argspec', 'found', 'name'
78 ]
78 ]
79
79
80
80
81 def object_info(**kw):
81 def object_info(**kw):
82 """Make an object info dict with all fields present."""
82 """Make an object info dict with all fields present."""
83 infodict = dict(zip_longest(info_fields, [None]))
83 infodict = dict(zip_longest(info_fields, [None]))
84 infodict.update(kw)
84 infodict.update(kw)
85 return infodict
85 return infodict
86
86
87
87
88 def get_encoding(obj):
88 def get_encoding(obj):
89 """Get encoding for python source file defining obj
89 """Get encoding for python source file defining obj
90
90
91 Returns None if obj is not defined in a sourcefile.
91 Returns None if obj is not defined in a sourcefile.
92 """
92 """
93 ofile = find_file(obj)
93 ofile = find_file(obj)
94 # run contents of file through pager starting at line where the object
94 # run contents of file through pager starting at line where the object
95 # is defined, as long as the file isn't binary and is actually on the
95 # is defined, as long as the file isn't binary and is actually on the
96 # filesystem.
96 # filesystem.
97 if ofile is None:
97 if ofile is None:
98 return None
98 return None
99 elif ofile.endswith(('.so', '.dll', '.pyd')):
99 elif ofile.endswith(('.so', '.dll', '.pyd')):
100 return None
100 return None
101 elif not os.path.isfile(ofile):
101 elif not os.path.isfile(ofile):
102 return None
102 return None
103 else:
103 else:
104 # Print only text files, not extension binaries. Note that
104 # Print only text files, not extension binaries. Note that
105 # getsourcelines returns lineno with 1-offset and page() uses
105 # getsourcelines returns lineno with 1-offset and page() uses
106 # 0-offset, so we must adjust.
106 # 0-offset, so we must adjust.
107 with stdlib_io.open(ofile, 'rb') as buffer: # Tweaked to use io.open for Python 2
107 with stdlib_io.open(ofile, 'rb') as buffer: # Tweaked to use io.open for Python 2
108 encoding, lines = openpy.detect_encoding(buffer.readline)
108 encoding, lines = openpy.detect_encoding(buffer.readline)
109 return encoding
109 return encoding
110
110
111 def getdoc(obj):
111 def getdoc(obj):
112 """Stable wrapper around inspect.getdoc.
112 """Stable wrapper around inspect.getdoc.
113
113
114 This can't crash because of attribute problems.
114 This can't crash because of attribute problems.
115
115
116 It also attempts to call a getdoc() method on the given object. This
116 It also attempts to call a getdoc() method on the given object. This
117 allows objects which provide their docstrings via non-standard mechanisms
117 allows objects which provide their docstrings via non-standard mechanisms
118 (like Pyro proxies) to still be inspected by ipython's ? system.
118 (like Pyro proxies) to still be inspected by ipython's ? system.
119 """
119 """
120 # Allow objects to offer customized documentation via a getdoc method:
120 # Allow objects to offer customized documentation via a getdoc method:
121 try:
121 try:
122 ds = obj.getdoc()
122 ds = obj.getdoc()
123 except Exception:
123 except Exception:
124 pass
124 pass
125 else:
125 else:
126 # if we get extra info, we add it to the normal docstring.
126 # if we get extra info, we add it to the normal docstring.
127 if isinstance(ds, str):
127 if isinstance(ds, str):
128 return inspect.cleandoc(ds)
128 return inspect.cleandoc(ds)
129 try:
129 try:
130 docstr = inspect.getdoc(obj)
130 docstr = inspect.getdoc(obj)
131 encoding = get_encoding(obj)
131 encoding = get_encoding(obj)
132 return py3compat.cast_unicode(docstr, encoding=encoding)
132 return py3compat.cast_unicode(docstr, encoding=encoding)
133 except Exception:
133 except Exception:
134 # Harden against an inspect failure, which can occur with
134 # Harden against an inspect failure, which can occur with
135 # extensions modules.
135 # extensions modules.
136 raise
136 raise
137 return None
137 return None
138
138
139
139
140 def getsource(obj, oname=''):
140 def getsource(obj, oname=''):
141 """Wrapper around inspect.getsource.
141 """Wrapper around inspect.getsource.
142
142
143 This can be modified by other projects to provide customized source
143 This can be modified by other projects to provide customized source
144 extraction.
144 extraction.
145
145
146 Parameters
146 Parameters
147 ----------
147 ----------
148 obj : object
148 obj : object
149 an object whose source code we will attempt to extract
149 an object whose source code we will attempt to extract
150 oname : str
150 oname : str
151 (optional) a name under which the object is known
151 (optional) a name under which the object is known
152
152
153 Returns
153 Returns
154 -------
154 -------
155 src : unicode or None
155 src : unicode or None
156
156
157 """
157 """
158
158
159 if isinstance(obj, property):
159 if isinstance(obj, property):
160 sources = []
160 sources = []
161 for attrname in ['fget', 'fset', 'fdel']:
161 for attrname in ['fget', 'fset', 'fdel']:
162 fn = getattr(obj, attrname)
162 fn = getattr(obj, attrname)
163 if fn is not None:
163 if fn is not None:
164 encoding = get_encoding(fn)
164 encoding = get_encoding(fn)
165 oname_prefix = ('%s.' % oname) if oname else ''
165 oname_prefix = ('%s.' % oname) if oname else ''
166 sources.append(cast_unicode(
166 sources.append(cast_unicode(
167 ''.join(('# ', oname_prefix, attrname)),
167 ''.join(('# ', oname_prefix, attrname)),
168 encoding=encoding))
168 encoding=encoding))
169 if inspect.isfunction(fn):
169 if inspect.isfunction(fn):
170 sources.append(dedent(getsource(fn)))
170 sources.append(dedent(getsource(fn)))
171 else:
171 else:
172 # Default str/repr only prints function name,
172 # Default str/repr only prints function name,
173 # pretty.pretty prints module name too.
173 # pretty.pretty prints module name too.
174 sources.append(cast_unicode(
174 sources.append(cast_unicode(
175 '%s%s = %s\n' % (
175 '%s%s = %s\n' % (
176 oname_prefix, attrname, pretty(fn)),
176 oname_prefix, attrname, pretty(fn)),
177 encoding=encoding))
177 encoding=encoding))
178 if sources:
178 if sources:
179 return '\n'.join(sources)
179 return '\n'.join(sources)
180 else:
180 else:
181 return None
181 return None
182
182
183 else:
183 else:
184 # Get source for non-property objects.
184 # Get source for non-property objects.
185
185
186 obj = _get_wrapped(obj)
186 obj = _get_wrapped(obj)
187
187
188 try:
188 try:
189 src = inspect.getsource(obj)
189 src = inspect.getsource(obj)
190 except TypeError:
190 except TypeError:
191 # The object itself provided no meaningful source, try looking for
191 # The object itself provided no meaningful source, try looking for
192 # its class definition instead.
192 # its class definition instead.
193 if hasattr(obj, '__class__'):
193 if hasattr(obj, '__class__'):
194 try:
194 try:
195 src = inspect.getsource(obj.__class__)
195 src = inspect.getsource(obj.__class__)
196 except TypeError:
196 except TypeError:
197 return None
197 return None
198
198
199 encoding = get_encoding(obj)
199 encoding = get_encoding(obj)
200 return cast_unicode(src, encoding=encoding)
200 return cast_unicode(src, encoding=encoding)
201
201
202
202
203 def is_simple_callable(obj):
203 def is_simple_callable(obj):
204 """True if obj is a function ()"""
204 """True if obj is a function ()"""
205 return (inspect.isfunction(obj) or inspect.ismethod(obj) or \
205 return (inspect.isfunction(obj) or inspect.ismethod(obj) or \
206 isinstance(obj, _builtin_func_type) or isinstance(obj, _builtin_meth_type))
206 isinstance(obj, _builtin_func_type) or isinstance(obj, _builtin_meth_type))
207
207
208
208
209 def getargspec(obj):
209 def getargspec(obj):
210 """Wrapper around :func:`inspect.getfullargspec` on Python 3, and
210 """Wrapper around :func:`inspect.getfullargspec` on Python 3, and
211 :func:inspect.getargspec` on Python 2.
211 :func:inspect.getargspec` on Python 2.
212
212
213 In addition to functions and methods, this can also handle objects with a
213 In addition to functions and methods, this can also handle objects with a
214 ``__call__`` attribute.
214 ``__call__`` attribute.
215 """
215 """
216 if safe_hasattr(obj, '__call__') and not is_simple_callable(obj):
216 if safe_hasattr(obj, '__call__') and not is_simple_callable(obj):
217 obj = obj.__call__
217 obj = obj.__call__
218
218
219 return inspect.getfullargspec(obj) if PY3 else inspect.getargspec(obj)
219 return inspect.getfullargspec(obj)
220
220
221
221
222 def format_argspec(argspec):
222 def format_argspec(argspec):
223 """Format argspect, convenience wrapper around inspect's.
223 """Format argspect, convenience wrapper around inspect's.
224
224
225 This takes a dict instead of ordered arguments and calls
225 This takes a dict instead of ordered arguments and calls
226 inspect.format_argspec with the arguments in the necessary order.
226 inspect.format_argspec with the arguments in the necessary order.
227 """
227 """
228 return inspect.formatargspec(argspec['args'], argspec['varargs'],
228 return inspect.formatargspec(argspec['args'], argspec['varargs'],
229 argspec['varkw'], argspec['defaults'])
229 argspec['varkw'], argspec['defaults'])
230
230
231 @undoc
231 @undoc
232 def call_tip(oinfo, format_call=True):
232 def call_tip(oinfo, format_call=True):
233 """DEPRECATED. Extract call tip data from an oinfo dict.
233 """DEPRECATED. Extract call tip data from an oinfo dict.
234 """
234 """
235 warnings.warn('`call_tip` function is deprecated as of IPython 6.0'
235 warnings.warn('`call_tip` function is deprecated as of IPython 6.0'
236 'and will be removed in future versions.', DeprecationWarning, stacklevel=2)
236 'and will be removed in future versions.', DeprecationWarning, stacklevel=2)
237 # Get call definition
237 # Get call definition
238 argspec = oinfo.get('argspec')
238 argspec = oinfo.get('argspec')
239 if argspec is None:
239 if argspec is None:
240 call_line = None
240 call_line = None
241 else:
241 else:
242 # Callable objects will have 'self' as their first argument, prune
242 # Callable objects will have 'self' as their first argument, prune
243 # it out if it's there for clarity (since users do *not* pass an
243 # it out if it's there for clarity (since users do *not* pass an
244 # extra first argument explicitly).
244 # extra first argument explicitly).
245 try:
245 try:
246 has_self = argspec['args'][0] == 'self'
246 has_self = argspec['args'][0] == 'self'
247 except (KeyError, IndexError):
247 except (KeyError, IndexError):
248 pass
248 pass
249 else:
249 else:
250 if has_self:
250 if has_self:
251 argspec['args'] = argspec['args'][1:]
251 argspec['args'] = argspec['args'][1:]
252
252
253 call_line = oinfo['name']+format_argspec(argspec)
253 call_line = oinfo['name']+format_argspec(argspec)
254
254
255 # Now get docstring.
255 # Now get docstring.
256 # The priority is: call docstring, constructor docstring, main one.
256 # The priority is: call docstring, constructor docstring, main one.
257 doc = oinfo.get('call_docstring')
257 doc = oinfo.get('call_docstring')
258 if doc is None:
258 if doc is None:
259 doc = oinfo.get('init_docstring')
259 doc = oinfo.get('init_docstring')
260 if doc is None:
260 if doc is None:
261 doc = oinfo.get('docstring','')
261 doc = oinfo.get('docstring','')
262
262
263 return call_line, doc
263 return call_line, doc
264
264
265
265
266 def _get_wrapped(obj):
266 def _get_wrapped(obj):
267 """Get the original object if wrapped in one or more @decorators
267 """Get the original object if wrapped in one or more @decorators
268
268
269 Some objects automatically construct similar objects on any unrecognised
269 Some objects automatically construct similar objects on any unrecognised
270 attribute access (e.g. unittest.mock.call). To protect against infinite loops,
270 attribute access (e.g. unittest.mock.call). To protect against infinite loops,
271 this will arbitrarily cut off after 100 levels of obj.__wrapped__
271 this will arbitrarily cut off after 100 levels of obj.__wrapped__
272 attribute access. --TK, Jan 2016
272 attribute access. --TK, Jan 2016
273 """
273 """
274 orig_obj = obj
274 orig_obj = obj
275 i = 0
275 i = 0
276 while safe_hasattr(obj, '__wrapped__'):
276 while safe_hasattr(obj, '__wrapped__'):
277 obj = obj.__wrapped__
277 obj = obj.__wrapped__
278 i += 1
278 i += 1
279 if i > 100:
279 if i > 100:
280 # __wrapped__ is probably a lie, so return the thing we started with
280 # __wrapped__ is probably a lie, so return the thing we started with
281 return orig_obj
281 return orig_obj
282 return obj
282 return obj
283
283
284 def find_file(obj):
284 def find_file(obj):
285 """Find the absolute path to the file where an object was defined.
285 """Find the absolute path to the file where an object was defined.
286
286
287 This is essentially a robust wrapper around `inspect.getabsfile`.
287 This is essentially a robust wrapper around `inspect.getabsfile`.
288
288
289 Returns None if no file can be found.
289 Returns None if no file can be found.
290
290
291 Parameters
291 Parameters
292 ----------
292 ----------
293 obj : any Python object
293 obj : any Python object
294
294
295 Returns
295 Returns
296 -------
296 -------
297 fname : str
297 fname : str
298 The absolute path to the file where the object was defined.
298 The absolute path to the file where the object was defined.
299 """
299 """
300 obj = _get_wrapped(obj)
300 obj = _get_wrapped(obj)
301
301
302 fname = None
302 fname = None
303 try:
303 try:
304 fname = inspect.getabsfile(obj)
304 fname = inspect.getabsfile(obj)
305 except TypeError:
305 except TypeError:
306 # For an instance, the file that matters is where its class was
306 # For an instance, the file that matters is where its class was
307 # declared.
307 # declared.
308 if hasattr(obj, '__class__'):
308 if hasattr(obj, '__class__'):
309 try:
309 try:
310 fname = inspect.getabsfile(obj.__class__)
310 fname = inspect.getabsfile(obj.__class__)
311 except TypeError:
311 except TypeError:
312 # Can happen for builtins
312 # Can happen for builtins
313 pass
313 pass
314 except:
314 except:
315 pass
315 pass
316 return cast_unicode(fname)
316 return cast_unicode(fname)
317
317
318
318
319 def find_source_lines(obj):
319 def find_source_lines(obj):
320 """Find the line number in a file where an object was defined.
320 """Find the line number in a file where an object was defined.
321
321
322 This is essentially a robust wrapper around `inspect.getsourcelines`.
322 This is essentially a robust wrapper around `inspect.getsourcelines`.
323
323
324 Returns None if no file can be found.
324 Returns None if no file can be found.
325
325
326 Parameters
326 Parameters
327 ----------
327 ----------
328 obj : any Python object
328 obj : any Python object
329
329
330 Returns
330 Returns
331 -------
331 -------
332 lineno : int
332 lineno : int
333 The line number where the object definition starts.
333 The line number where the object definition starts.
334 """
334 """
335 obj = _get_wrapped(obj)
335 obj = _get_wrapped(obj)
336
336
337 try:
337 try:
338 try:
338 try:
339 lineno = inspect.getsourcelines(obj)[1]
339 lineno = inspect.getsourcelines(obj)[1]
340 except TypeError:
340 except TypeError:
341 # For instances, try the class object like getsource() does
341 # For instances, try the class object like getsource() does
342 if hasattr(obj, '__class__'):
342 if hasattr(obj, '__class__'):
343 lineno = inspect.getsourcelines(obj.__class__)[1]
343 lineno = inspect.getsourcelines(obj.__class__)[1]
344 else:
344 else:
345 lineno = None
345 lineno = None
346 except:
346 except:
347 return None
347 return None
348
348
349 return lineno
349 return lineno
350
350
351 class Inspector(Colorable):
351 class Inspector(Colorable):
352
352
353 def __init__(self, color_table=InspectColors,
353 def __init__(self, color_table=InspectColors,
354 code_color_table=PyColorize.ANSICodeColors,
354 code_color_table=PyColorize.ANSICodeColors,
355 scheme='NoColor',
355 scheme='NoColor',
356 str_detail_level=0,
356 str_detail_level=0,
357 parent=None, config=None):
357 parent=None, config=None):
358 super(Inspector, self).__init__(parent=parent, config=config)
358 super(Inspector, self).__init__(parent=parent, config=config)
359 self.color_table = color_table
359 self.color_table = color_table
360 self.parser = PyColorize.Parser(out='str', parent=self, style=scheme)
360 self.parser = PyColorize.Parser(out='str', parent=self, style=scheme)
361 self.format = self.parser.format
361 self.format = self.parser.format
362 self.str_detail_level = str_detail_level
362 self.str_detail_level = str_detail_level
363 self.set_active_scheme(scheme)
363 self.set_active_scheme(scheme)
364
364
365 def _getdef(self,obj,oname=''):
365 def _getdef(self,obj,oname=''):
366 """Return the call signature for any callable object.
366 """Return the call signature for any callable object.
367
367
368 If any exception is generated, None is returned instead and the
368 If any exception is generated, None is returned instead and the
369 exception is suppressed."""
369 exception is suppressed."""
370 try:
370 try:
371 hdef = oname + str(signature(obj))
371 hdef = oname + str(signature(obj))
372 return cast_unicode(hdef)
372 return cast_unicode(hdef)
373 except:
373 except:
374 return None
374 return None
375
375
376 def __head(self,h):
376 def __head(self,h):
377 """Return a header string with proper colors."""
377 """Return a header string with proper colors."""
378 return '%s%s%s' % (self.color_table.active_colors.header,h,
378 return '%s%s%s' % (self.color_table.active_colors.header,h,
379 self.color_table.active_colors.normal)
379 self.color_table.active_colors.normal)
380
380
381 def set_active_scheme(self, scheme):
381 def set_active_scheme(self, scheme):
382 self.color_table.set_active_scheme(scheme)
382 self.color_table.set_active_scheme(scheme)
383 self.parser.color_table.set_active_scheme(scheme)
383 self.parser.color_table.set_active_scheme(scheme)
384
384
385 def noinfo(self, msg, oname):
385 def noinfo(self, msg, oname):
386 """Generic message when no information is found."""
386 """Generic message when no information is found."""
387 print('No %s found' % msg, end=' ')
387 print('No %s found' % msg, end=' ')
388 if oname:
388 if oname:
389 print('for %s' % oname)
389 print('for %s' % oname)
390 else:
390 else:
391 print()
391 print()
392
392
393 def pdef(self, obj, oname=''):
393 def pdef(self, obj, oname=''):
394 """Print the call signature for any callable object.
394 """Print the call signature for any callable object.
395
395
396 If the object is a class, print the constructor information."""
396 If the object is a class, print the constructor information."""
397
397
398 if not callable(obj):
398 if not callable(obj):
399 print('Object is not callable.')
399 print('Object is not callable.')
400 return
400 return
401
401
402 header = ''
402 header = ''
403
403
404 if inspect.isclass(obj):
404 if inspect.isclass(obj):
405 header = self.__head('Class constructor information:\n')
405 header = self.__head('Class constructor information:\n')
406 elif (not py3compat.PY3) and type(obj) is types.InstanceType:
406
407 obj = obj.__call__
408
407
409 output = self._getdef(obj,oname)
408 output = self._getdef(obj,oname)
410 if output is None:
409 if output is None:
411 self.noinfo('definition header',oname)
410 self.noinfo('definition header',oname)
412 else:
411 else:
413 print(header,self.format(output), end=' ')
412 print(header,self.format(output), end=' ')
414
413
415 # In Python 3, all classes are new-style, so they all have __init__.
414 # In Python 3, all classes are new-style, so they all have __init__.
416 @skip_doctest
415 @skip_doctest
417 def pdoc(self, obj, oname='', formatter=None):
416 def pdoc(self, obj, oname='', formatter=None):
418 """Print the docstring for any object.
417 """Print the docstring for any object.
419
418
420 Optional:
419 Optional:
421 -formatter: a function to run the docstring through for specially
420 -formatter: a function to run the docstring through for specially
422 formatted docstrings.
421 formatted docstrings.
423
422
424 Examples
423 Examples
425 --------
424 --------
426
425
427 In [1]: class NoInit:
426 In [1]: class NoInit:
428 ...: pass
427 ...: pass
429
428
430 In [2]: class NoDoc:
429 In [2]: class NoDoc:
431 ...: def __init__(self):
430 ...: def __init__(self):
432 ...: pass
431 ...: pass
433
432
434 In [3]: %pdoc NoDoc
433 In [3]: %pdoc NoDoc
435 No documentation found for NoDoc
434 No documentation found for NoDoc
436
435
437 In [4]: %pdoc NoInit
436 In [4]: %pdoc NoInit
438 No documentation found for NoInit
437 No documentation found for NoInit
439
438
440 In [5]: obj = NoInit()
439 In [5]: obj = NoInit()
441
440
442 In [6]: %pdoc obj
441 In [6]: %pdoc obj
443 No documentation found for obj
442 No documentation found for obj
444
443
445 In [5]: obj2 = NoDoc()
444 In [5]: obj2 = NoDoc()
446
445
447 In [6]: %pdoc obj2
446 In [6]: %pdoc obj2
448 No documentation found for obj2
447 No documentation found for obj2
449 """
448 """
450
449
451 head = self.__head # For convenience
450 head = self.__head # For convenience
452 lines = []
451 lines = []
453 ds = getdoc(obj)
452 ds = getdoc(obj)
454 if formatter:
453 if formatter:
455 ds = formatter(ds).get('plain/text', ds)
454 ds = formatter(ds).get('plain/text', ds)
456 if ds:
455 if ds:
457 lines.append(head("Class docstring:"))
456 lines.append(head("Class docstring:"))
458 lines.append(indent(ds))
457 lines.append(indent(ds))
459 if inspect.isclass(obj) and hasattr(obj, '__init__'):
458 if inspect.isclass(obj) and hasattr(obj, '__init__'):
460 init_ds = getdoc(obj.__init__)
459 init_ds = getdoc(obj.__init__)
461 if init_ds is not None:
460 if init_ds is not None:
462 lines.append(head("Init docstring:"))
461 lines.append(head("Init docstring:"))
463 lines.append(indent(init_ds))
462 lines.append(indent(init_ds))
464 elif hasattr(obj,'__call__'):
463 elif hasattr(obj,'__call__'):
465 call_ds = getdoc(obj.__call__)
464 call_ds = getdoc(obj.__call__)
466 if call_ds:
465 if call_ds:
467 lines.append(head("Call docstring:"))
466 lines.append(head("Call docstring:"))
468 lines.append(indent(call_ds))
467 lines.append(indent(call_ds))
469
468
470 if not lines:
469 if not lines:
471 self.noinfo('documentation',oname)
470 self.noinfo('documentation',oname)
472 else:
471 else:
473 page.page('\n'.join(lines))
472 page.page('\n'.join(lines))
474
473
475 def psource(self, obj, oname=''):
474 def psource(self, obj, oname=''):
476 """Print the source code for an object."""
475 """Print the source code for an object."""
477
476
478 # Flush the source cache because inspect can return out-of-date source
477 # Flush the source cache because inspect can return out-of-date source
479 linecache.checkcache()
478 linecache.checkcache()
480 try:
479 try:
481 src = getsource(obj, oname=oname)
480 src = getsource(obj, oname=oname)
482 except Exception:
481 except Exception:
483 src = None
482 src = None
484
483
485 if src is None:
484 if src is None:
486 self.noinfo('source', oname)
485 self.noinfo('source', oname)
487 else:
486 else:
488 page.page(self.format(src))
487 page.page(self.format(src))
489
488
490 def pfile(self, obj, oname=''):
489 def pfile(self, obj, oname=''):
491 """Show the whole file where an object was defined."""
490 """Show the whole file where an object was defined."""
492
491
493 lineno = find_source_lines(obj)
492 lineno = find_source_lines(obj)
494 if lineno is None:
493 if lineno is None:
495 self.noinfo('file', oname)
494 self.noinfo('file', oname)
496 return
495 return
497
496
498 ofile = find_file(obj)
497 ofile = find_file(obj)
499 # run contents of file through pager starting at line where the object
498 # run contents of file through pager starting at line where the object
500 # is defined, as long as the file isn't binary and is actually on the
499 # is defined, as long as the file isn't binary and is actually on the
501 # filesystem.
500 # filesystem.
502 if ofile.endswith(('.so', '.dll', '.pyd')):
501 if ofile.endswith(('.so', '.dll', '.pyd')):
503 print('File %r is binary, not printing.' % ofile)
502 print('File %r is binary, not printing.' % ofile)
504 elif not os.path.isfile(ofile):
503 elif not os.path.isfile(ofile):
505 print('File %r does not exist, not printing.' % ofile)
504 print('File %r does not exist, not printing.' % ofile)
506 else:
505 else:
507 # Print only text files, not extension binaries. Note that
506 # Print only text files, not extension binaries. Note that
508 # getsourcelines returns lineno with 1-offset and page() uses
507 # getsourcelines returns lineno with 1-offset and page() uses
509 # 0-offset, so we must adjust.
508 # 0-offset, so we must adjust.
510 page.page(self.format(openpy.read_py_file(ofile, skip_encoding_cookie=False)), lineno - 1)
509 page.page(self.format(openpy.read_py_file(ofile, skip_encoding_cookie=False)), lineno - 1)
511
510
512 def _format_fields(self, fields, title_width=0):
511 def _format_fields(self, fields, title_width=0):
513 """Formats a list of fields for display.
512 """Formats a list of fields for display.
514
513
515 Parameters
514 Parameters
516 ----------
515 ----------
517 fields : list
516 fields : list
518 A list of 2-tuples: (field_title, field_content)
517 A list of 2-tuples: (field_title, field_content)
519 title_width : int
518 title_width : int
520 How many characters to pad titles to. Default to longest title.
519 How many characters to pad titles to. Default to longest title.
521 """
520 """
522 out = []
521 out = []
523 header = self.__head
522 header = self.__head
524 if title_width == 0:
523 if title_width == 0:
525 title_width = max(len(title) + 2 for title, _ in fields)
524 title_width = max(len(title) + 2 for title, _ in fields)
526 for title, content in fields:
525 for title, content in fields:
527 if len(content.splitlines()) > 1:
526 if len(content.splitlines()) > 1:
528 title = header(title + ':') + '\n'
527 title = header(title + ':') + '\n'
529 else:
528 else:
530 title = header((title + ':').ljust(title_width))
529 title = header((title + ':').ljust(title_width))
531 out.append(cast_unicode(title) + cast_unicode(content))
530 out.append(cast_unicode(title) + cast_unicode(content))
532 return "\n".join(out)
531 return "\n".join(out)
533
532
534 def _mime_format(self, text, formatter=None):
533 def _mime_format(self, text, formatter=None):
535 """Return a mime bundle representation of the input text.
534 """Return a mime bundle representation of the input text.
536
535
537 - if `formatter` is None, the returned mime bundle has
536 - if `formatter` is None, the returned mime bundle has
538 a `text/plain` field, with the input text.
537 a `text/plain` field, with the input text.
539 a `text/html` field with a `<pre>` tag containing the input text.
538 a `text/html` field with a `<pre>` tag containing the input text.
540
539
541 - if `formatter` is not None, it must be a callable transforming the
540 - if `formatter` is not None, it must be a callable transforming the
542 input text into a mime bundle. Default values for `text/plain` and
541 input text into a mime bundle. Default values for `text/plain` and
543 `text/html` representations are the ones described above.
542 `text/html` representations are the ones described above.
544
543
545 Note:
544 Note:
546
545
547 Formatters returning strings are supported but this behavior is deprecated.
546 Formatters returning strings are supported but this behavior is deprecated.
548
547
549 """
548 """
550 text = cast_unicode(text)
549 text = cast_unicode(text)
551 defaults = {
550 defaults = {
552 'text/plain': text,
551 'text/plain': text,
553 'text/html': '<pre>' + text + '</pre>'
552 'text/html': '<pre>' + text + '</pre>'
554 }
553 }
555
554
556 if formatter is None:
555 if formatter is None:
557 return defaults
556 return defaults
558 else:
557 else:
559 formatted = formatter(text)
558 formatted = formatter(text)
560
559
561 if not isinstance(formatted, dict):
560 if not isinstance(formatted, dict):
562 # Handle the deprecated behavior of a formatter returning
561 # Handle the deprecated behavior of a formatter returning
563 # a string instead of a mime bundle.
562 # a string instead of a mime bundle.
564 return {
563 return {
565 'text/plain': formatted,
564 'text/plain': formatted,
566 'text/html': '<pre>' + formatted + '</pre>'
565 'text/html': '<pre>' + formatted + '</pre>'
567 }
566 }
568
567
569 else:
568 else:
570 return dict(defaults, **formatted)
569 return dict(defaults, **formatted)
571
570
572
571
573 def format_mime(self, bundle):
572 def format_mime(self, bundle):
574
573
575 text_plain = bundle['text/plain']
574 text_plain = bundle['text/plain']
576
575
577 text = ''
576 text = ''
578 heads, bodies = list(zip(*text_plain))
577 heads, bodies = list(zip(*text_plain))
579 _len = max(len(h) for h in heads)
578 _len = max(len(h) for h in heads)
580
579
581 for head, body in zip(heads, bodies):
580 for head, body in zip(heads, bodies):
582 body = body.strip('\n')
581 body = body.strip('\n')
583 delim = '\n' if '\n' in body else ' '
582 delim = '\n' if '\n' in body else ' '
584 text += self.__head(head+':') + (_len - len(head))*' ' +delim + body +'\n'
583 text += self.__head(head+':') + (_len - len(head))*' ' +delim + body +'\n'
585
584
586 bundle['text/plain'] = text
585 bundle['text/plain'] = text
587 return bundle
586 return bundle
588
587
589 def _get_info(self, obj, oname='', formatter=None, info=None, detail_level=0):
588 def _get_info(self, obj, oname='', formatter=None, info=None, detail_level=0):
590 """Retrieve an info dict and format it."""
589 """Retrieve an info dict and format it."""
591
590
592 info = self._info(obj, oname=oname, info=info, detail_level=detail_level)
591 info = self._info(obj, oname=oname, info=info, detail_level=detail_level)
593
592
594 _mime = {
593 _mime = {
595 'text/plain': [],
594 'text/plain': [],
596 'text/html': '',
595 'text/html': '',
597 }
596 }
598
597
599 def append_field(bundle, title, key, formatter=None):
598 def append_field(bundle, title, key, formatter=None):
600 field = info[key]
599 field = info[key]
601 if field is not None:
600 if field is not None:
602 formatted_field = self._mime_format(field, formatter)
601 formatted_field = self._mime_format(field, formatter)
603 bundle['text/plain'].append((title, formatted_field['text/plain']))
602 bundle['text/plain'].append((title, formatted_field['text/plain']))
604 bundle['text/html'] += '<h1>' + title + '</h1>\n' + formatted_field['text/html'] + '\n'
603 bundle['text/html'] += '<h1>' + title + '</h1>\n' + formatted_field['text/html'] + '\n'
605
604
606 def code_formatter(text):
605 def code_formatter(text):
607 return {
606 return {
608 'text/plain': self.format(text),
607 'text/plain': self.format(text),
609 'text/html': pylight(text)
608 'text/html': pylight(text)
610 }
609 }
611
610
612 if info['isalias']:
611 if info['isalias']:
613 append_field(_mime, 'Repr', 'string_form')
612 append_field(_mime, 'Repr', 'string_form')
614
613
615 elif info['ismagic']:
614 elif info['ismagic']:
616 if detail_level > 0:
615 if detail_level > 0:
617 append_field(_mime, 'Source', 'source', code_formatter)
616 append_field(_mime, 'Source', 'source', code_formatter)
618 else:
617 else:
619 append_field(_mime, 'Docstring', 'docstring', formatter)
618 append_field(_mime, 'Docstring', 'docstring', formatter)
620 append_field(_mime, 'File', 'file')
619 append_field(_mime, 'File', 'file')
621
620
622 elif info['isclass'] or is_simple_callable(obj):
621 elif info['isclass'] or is_simple_callable(obj):
623 # Functions, methods, classes
622 # Functions, methods, classes
624 append_field(_mime, 'Signature', 'definition', code_formatter)
623 append_field(_mime, 'Signature', 'definition', code_formatter)
625 append_field(_mime, 'Init signature', 'init_definition', code_formatter)
624 append_field(_mime, 'Init signature', 'init_definition', code_formatter)
626 if detail_level > 0 and info['source']:
625 if detail_level > 0 and info['source']:
627 append_field(_mime, 'Source', 'source', code_formatter)
626 append_field(_mime, 'Source', 'source', code_formatter)
628 else:
627 else:
629 append_field(_mime, 'Docstring', 'docstring', formatter)
628 append_field(_mime, 'Docstring', 'docstring', formatter)
630 append_field(_mime, 'Init docstring', 'init_docstring', formatter)
629 append_field(_mime, 'Init docstring', 'init_docstring', formatter)
631
630
632 append_field(_mime, 'File', 'file')
631 append_field(_mime, 'File', 'file')
633 append_field(_mime, 'Type', 'type_name')
632 append_field(_mime, 'Type', 'type_name')
634
633
635 else:
634 else:
636 # General Python objects
635 # General Python objects
637 append_field(_mime, 'Signature', 'definition', code_formatter)
636 append_field(_mime, 'Signature', 'definition', code_formatter)
638 append_field(_mime, 'Call signature', 'call_def', code_formatter)
637 append_field(_mime, 'Call signature', 'call_def', code_formatter)
639
640 append_field(_mime, 'Type', 'type_name')
638 append_field(_mime, 'Type', 'type_name')
641
642 # Base class for old-style instances
643 if (not py3compat.PY3) and isinstance(obj, types.InstanceType) and info['base_class']:
644 append_field(_mime, 'Base Class', 'base_class')
645
646 append_field(_mime, 'String form', 'string_form')
639 append_field(_mime, 'String form', 'string_form')
647
640
648 # Namespace
641 # Namespace
649 if info['namespace'] != 'Interactive':
642 if info['namespace'] != 'Interactive':
650 append_field(_mime, 'Namespace', 'namespace')
643 append_field(_mime, 'Namespace', 'namespace')
651
644
652 append_field(_mime, 'Length', 'length')
645 append_field(_mime, 'Length', 'length')
653 append_field(_mime, 'File', 'file')
646 append_field(_mime, 'File', 'file')
654
647
655 # Source or docstring, depending on detail level and whether
648 # Source or docstring, depending on detail level and whether
656 # source found.
649 # source found.
657 if detail_level > 0:
650 if detail_level > 0:
658 append_field(_mime, 'Source', 'source', code_formatter)
651 append_field(_mime, 'Source', 'source', code_formatter)
659 else:
652 else:
660 append_field(_mime, 'Docstring', 'docstring', formatter)
653 append_field(_mime, 'Docstring', 'docstring', formatter)
661
654
662 append_field(_mime, 'Class docstring', 'class_docstring', formatter)
655 append_field(_mime, 'Class docstring', 'class_docstring', formatter)
663 append_field(_mime, 'Init docstring', 'init_docstring', formatter)
656 append_field(_mime, 'Init docstring', 'init_docstring', formatter)
664 append_field(_mime, 'Call docstring', 'call_docstring', formatter)
657 append_field(_mime, 'Call docstring', 'call_docstring', formatter)
665
658
666
659
667 return self.format_mime(_mime)
660 return self.format_mime(_mime)
668
661
669 def pinfo(self, obj, oname='', formatter=None, info=None, detail_level=0, enable_html_pager=True):
662 def pinfo(self, obj, oname='', formatter=None, info=None, detail_level=0, enable_html_pager=True):
670 """Show detailed information about an object.
663 """Show detailed information about an object.
671
664
672 Optional arguments:
665 Optional arguments:
673
666
674 - oname: name of the variable pointing to the object.
667 - oname: name of the variable pointing to the object.
675
668
676 - formatter: callable (optional)
669 - formatter: callable (optional)
677 A special formatter for docstrings.
670 A special formatter for docstrings.
678
671
679 The formatter is a callable that takes a string as an input
672 The formatter is a callable that takes a string as an input
680 and returns either a formatted string or a mime type bundle
673 and returns either a formatted string or a mime type bundle
681 in the form of a dictionnary.
674 in the form of a dictionnary.
682
675
683 Although the support of custom formatter returning a string
676 Although the support of custom formatter returning a string
684 instead of a mime type bundle is deprecated.
677 instead of a mime type bundle is deprecated.
685
678
686 - info: a structure with some information fields which may have been
679 - info: a structure with some information fields which may have been
687 precomputed already.
680 precomputed already.
688
681
689 - detail_level: if set to 1, more information is given.
682 - detail_level: if set to 1, more information is given.
690 """
683 """
691 info = self._get_info(obj, oname, formatter, info, detail_level)
684 info = self._get_info(obj, oname, formatter, info, detail_level)
692 if not enable_html_pager:
685 if not enable_html_pager:
693 del info['text/html']
686 del info['text/html']
694 page.page(info)
687 page.page(info)
695
688
696 def info(self, obj, oname='', formatter=None, info=None, detail_level=0):
689 def info(self, obj, oname='', formatter=None, info=None, detail_level=0):
697 """DEPRECATED. Compute a dict with detailed information about an object.
690 """DEPRECATED. Compute a dict with detailed information about an object.
698 """
691 """
699 if formatter is not None:
692 if formatter is not None:
700 warnings.warn('The `formatter` keyword argument to `Inspector.info`'
693 warnings.warn('The `formatter` keyword argument to `Inspector.info`'
701 'is deprecated as of IPython 5.0 and will have no effects.',
694 'is deprecated as of IPython 5.0 and will have no effects.',
702 DeprecationWarning, stacklevel=2)
695 DeprecationWarning, stacklevel=2)
703 return self._info(obj, oname=oname, info=info, detail_level=detail_level)
696 return self._info(obj, oname=oname, info=info, detail_level=detail_level)
704
697
705 def _info(self, obj, oname='', info=None, detail_level=0):
698 def _info(self, obj, oname='', info=None, detail_level=0):
706 """Compute a dict with detailed information about an object.
699 """Compute a dict with detailed information about an object.
707
700
708 Optional arguments:
701 Optional arguments:
709
702
710 - oname: name of the variable pointing to the object.
703 - oname: name of the variable pointing to the object.
711
704
712 - info: a structure with some information fields which may have been
705 - info: a structure with some information fields which may have been
713 precomputed already.
706 precomputed already.
714
707
715 - detail_level: if set to 1, more information is given.
708 - detail_level: if set to 1, more information is given.
716 """
709 """
717
710
718 obj_type = type(obj)
711 obj_type = type(obj)
719
712
720 if info is None:
713 if info is None:
721 ismagic = 0
714 ismagic = 0
722 isalias = 0
715 isalias = 0
723 ospace = ''
716 ospace = ''
724 else:
717 else:
725 ismagic = info.ismagic
718 ismagic = info.ismagic
726 isalias = info.isalias
719 isalias = info.isalias
727 ospace = info.namespace
720 ospace = info.namespace
728
721
729 # Get docstring, special-casing aliases:
722 # Get docstring, special-casing aliases:
730 if isalias:
723 if isalias:
731 if not callable(obj):
724 if not callable(obj):
732 try:
725 try:
733 ds = "Alias to the system command:\n %s" % obj[1]
726 ds = "Alias to the system command:\n %s" % obj[1]
734 except:
727 except:
735 ds = "Alias: " + str(obj)
728 ds = "Alias: " + str(obj)
736 else:
729 else:
737 ds = "Alias to " + str(obj)
730 ds = "Alias to " + str(obj)
738 if obj.__doc__:
731 if obj.__doc__:
739 ds += "\nDocstring:\n" + obj.__doc__
732 ds += "\nDocstring:\n" + obj.__doc__
740 else:
733 else:
741 ds = getdoc(obj)
734 ds = getdoc(obj)
742 if ds is None:
735 if ds is None:
743 ds = '<no docstring>'
736 ds = '<no docstring>'
744
737
745 # store output in a dict, we initialize it here and fill it as we go
738 # store output in a dict, we initialize it here and fill it as we go
746 out = dict(name=oname, found=True, isalias=isalias, ismagic=ismagic)
739 out = dict(name=oname, found=True, isalias=isalias, ismagic=ismagic)
747
740
748 string_max = 200 # max size of strings to show (snipped if longer)
741 string_max = 200 # max size of strings to show (snipped if longer)
749 shalf = int((string_max - 5) / 2)
742 shalf = int((string_max - 5) / 2)
750
743
751 if ismagic:
744 if ismagic:
752 obj_type_name = 'Magic function'
745 obj_type_name = 'Magic function'
753 elif isalias:
746 elif isalias:
754 obj_type_name = 'System alias'
747 obj_type_name = 'System alias'
755 else:
748 else:
756 obj_type_name = obj_type.__name__
749 obj_type_name = obj_type.__name__
757 out['type_name'] = obj_type_name
750 out['type_name'] = obj_type_name
758
751
759 try:
752 try:
760 bclass = obj.__class__
753 bclass = obj.__class__
761 out['base_class'] = str(bclass)
754 out['base_class'] = str(bclass)
762 except: pass
755 except: pass
763
756
764 # String form, but snip if too long in ? form (full in ??)
757 # String form, but snip if too long in ? form (full in ??)
765 if detail_level >= self.str_detail_level:
758 if detail_level >= self.str_detail_level:
766 try:
759 try:
767 ostr = str(obj)
760 ostr = str(obj)
768 str_head = 'string_form'
761 str_head = 'string_form'
769 if not detail_level and len(ostr)>string_max:
762 if not detail_level and len(ostr)>string_max:
770 ostr = ostr[:shalf] + ' <...> ' + ostr[-shalf:]
763 ostr = ostr[:shalf] + ' <...> ' + ostr[-shalf:]
771 ostr = ("\n" + " " * len(str_head.expandtabs())).\
764 ostr = ("\n" + " " * len(str_head.expandtabs())).\
772 join(q.strip() for q in ostr.split("\n"))
765 join(q.strip() for q in ostr.split("\n"))
773 out[str_head] = ostr
766 out[str_head] = ostr
774 except:
767 except:
775 pass
768 pass
776
769
777 if ospace:
770 if ospace:
778 out['namespace'] = ospace
771 out['namespace'] = ospace
779
772
780 # Length (for strings and lists)
773 # Length (for strings and lists)
781 try:
774 try:
782 out['length'] = str(len(obj))
775 out['length'] = str(len(obj))
783 except: pass
776 except: pass
784
777
785 # Filename where object was defined
778 # Filename where object was defined
786 binary_file = False
779 binary_file = False
787 fname = find_file(obj)
780 fname = find_file(obj)
788 if fname is None:
781 if fname is None:
789 # if anything goes wrong, we don't want to show source, so it's as
782 # if anything goes wrong, we don't want to show source, so it's as
790 # if the file was binary
783 # if the file was binary
791 binary_file = True
784 binary_file = True
792 else:
785 else:
793 if fname.endswith(('.so', '.dll', '.pyd')):
786 if fname.endswith(('.so', '.dll', '.pyd')):
794 binary_file = True
787 binary_file = True
795 elif fname.endswith('<string>'):
788 elif fname.endswith('<string>'):
796 fname = 'Dynamically generated function. No source code available.'
789 fname = 'Dynamically generated function. No source code available.'
797 out['file'] = compress_user(fname)
790 out['file'] = compress_user(fname)
798
791
799 # Original source code for a callable, class or property.
792 # Original source code for a callable, class or property.
800 if detail_level:
793 if detail_level:
801 # Flush the source cache because inspect can return out-of-date
794 # Flush the source cache because inspect can return out-of-date
802 # source
795 # source
803 linecache.checkcache()
796 linecache.checkcache()
804 try:
797 try:
805 if isinstance(obj, property) or not binary_file:
798 if isinstance(obj, property) or not binary_file:
806 src = getsource(obj, oname)
799 src = getsource(obj, oname)
807 if src is not None:
800 if src is not None:
808 src = src.rstrip()
801 src = src.rstrip()
809 out['source'] = src
802 out['source'] = src
810
803
811 except Exception:
804 except Exception:
812 pass
805 pass
813
806
814 # Add docstring only if no source is to be shown (avoid repetitions).
807 # Add docstring only if no source is to be shown (avoid repetitions).
815 if ds and out.get('source', None) is None:
808 if ds and out.get('source', None) is None:
816 out['docstring'] = ds
809 out['docstring'] = ds
817
810
818 # Constructor docstring for classes
811 # Constructor docstring for classes
819 if inspect.isclass(obj):
812 if inspect.isclass(obj):
820 out['isclass'] = True
813 out['isclass'] = True
821
814
822 # get the init signature:
815 # get the init signature:
823 try:
816 try:
824 init_def = self._getdef(obj, oname)
817 init_def = self._getdef(obj, oname)
825 except AttributeError:
818 except AttributeError:
826 init_def = None
819 init_def = None
827
820
828 # get the __init__ docstring
821 # get the __init__ docstring
829 try:
822 try:
830 obj_init = obj.__init__
823 obj_init = obj.__init__
831 except AttributeError:
824 except AttributeError:
832 init_ds = None
825 init_ds = None
833 else:
826 else:
834 if init_def is None:
827 if init_def is None:
835 # Get signature from init if top-level sig failed.
828 # Get signature from init if top-level sig failed.
836 # Can happen for built-in types (list, etc.).
829 # Can happen for built-in types (list, etc.).
837 try:
830 try:
838 init_def = self._getdef(obj_init, oname)
831 init_def = self._getdef(obj_init, oname)
839 except AttributeError:
832 except AttributeError:
840 pass
833 pass
841 init_ds = getdoc(obj_init)
834 init_ds = getdoc(obj_init)
842 # Skip Python's auto-generated docstrings
835 # Skip Python's auto-generated docstrings
843 if init_ds == _object_init_docstring:
836 if init_ds == _object_init_docstring:
844 init_ds = None
837 init_ds = None
845
838
846 if init_def:
839 if init_def:
847 out['init_definition'] = init_def
840 out['init_definition'] = init_def
848
841
849 if init_ds:
842 if init_ds:
850 out['init_docstring'] = init_ds
843 out['init_docstring'] = init_ds
851
844
852 # and class docstring for instances:
845 # and class docstring for instances:
853 else:
846 else:
854 # reconstruct the function definition and print it:
847 # reconstruct the function definition and print it:
855 defln = self._getdef(obj, oname)
848 defln = self._getdef(obj, oname)
856 if defln:
849 if defln:
857 out['definition'] = defln
850 out['definition'] = defln
858
851
859 # First, check whether the instance docstring is identical to the
852 # First, check whether the instance docstring is identical to the
860 # class one, and print it separately if they don't coincide. In
853 # class one, and print it separately if they don't coincide. In
861 # most cases they will, but it's nice to print all the info for
854 # most cases they will, but it's nice to print all the info for
862 # objects which use instance-customized docstrings.
855 # objects which use instance-customized docstrings.
863 if ds:
856 if ds:
864 try:
857 try:
865 cls = getattr(obj,'__class__')
858 cls = getattr(obj,'__class__')
866 except:
859 except:
867 class_ds = None
860 class_ds = None
868 else:
861 else:
869 class_ds = getdoc(cls)
862 class_ds = getdoc(cls)
870 # Skip Python's auto-generated docstrings
863 # Skip Python's auto-generated docstrings
871 if class_ds in _builtin_type_docstrings:
864 if class_ds in _builtin_type_docstrings:
872 class_ds = None
865 class_ds = None
873 if class_ds and ds != class_ds:
866 if class_ds and ds != class_ds:
874 out['class_docstring'] = class_ds
867 out['class_docstring'] = class_ds
875
868
876 # Next, try to show constructor docstrings
869 # Next, try to show constructor docstrings
877 try:
870 try:
878 init_ds = getdoc(obj.__init__)
871 init_ds = getdoc(obj.__init__)
879 # Skip Python's auto-generated docstrings
872 # Skip Python's auto-generated docstrings
880 if init_ds == _object_init_docstring:
873 if init_ds == _object_init_docstring:
881 init_ds = None
874 init_ds = None
882 except AttributeError:
875 except AttributeError:
883 init_ds = None
876 init_ds = None
884 if init_ds:
877 if init_ds:
885 out['init_docstring'] = init_ds
878 out['init_docstring'] = init_ds
886
879
887 # Call form docstring for callable instances
880 # Call form docstring for callable instances
888 if safe_hasattr(obj, '__call__') and not is_simple_callable(obj):
881 if safe_hasattr(obj, '__call__') and not is_simple_callable(obj):
889 call_def = self._getdef(obj.__call__, oname)
882 call_def = self._getdef(obj.__call__, oname)
890 if call_def and (call_def != out.get('definition')):
883 if call_def and (call_def != out.get('definition')):
891 # it may never be the case that call def and definition differ,
884 # it may never be the case that call def and definition differ,
892 # but don't include the same signature twice
885 # but don't include the same signature twice
893 out['call_def'] = call_def
886 out['call_def'] = call_def
894 call_ds = getdoc(obj.__call__)
887 call_ds = getdoc(obj.__call__)
895 # Skip Python's auto-generated docstrings
888 # Skip Python's auto-generated docstrings
896 if call_ds == _func_call_docstring:
889 if call_ds == _func_call_docstring:
897 call_ds = None
890 call_ds = None
898 if call_ds:
891 if call_ds:
899 out['call_docstring'] = call_ds
892 out['call_docstring'] = call_ds
900
893
901 # Compute the object's argspec as a callable. The key is to decide
894 # Compute the object's argspec as a callable. The key is to decide
902 # whether to pull it from the object itself, from its __init__ or
895 # whether to pull it from the object itself, from its __init__ or
903 # from its __call__ method.
896 # from its __call__ method.
904
897
905 if inspect.isclass(obj):
898 if inspect.isclass(obj):
906 # Old-style classes need not have an __init__
899 # Old-style classes need not have an __init__
907 callable_obj = getattr(obj, "__init__", None)
900 callable_obj = getattr(obj, "__init__", None)
908 elif callable(obj):
901 elif callable(obj):
909 callable_obj = obj
902 callable_obj = obj
910 else:
903 else:
911 callable_obj = None
904 callable_obj = None
912
905
913 if callable_obj is not None:
906 if callable_obj is not None:
914 try:
907 try:
915 argspec = getargspec(callable_obj)
908 argspec = getargspec(callable_obj)
916 except (TypeError, AttributeError):
909 except (TypeError, AttributeError):
917 # For extensions/builtins we can't retrieve the argspec
910 # For extensions/builtins we can't retrieve the argspec
918 pass
911 pass
919 else:
912 else:
920 # named tuples' _asdict() method returns an OrderedDict, but we
913 # named tuples' _asdict() method returns an OrderedDict, but we
921 # we want a normal
914 # we want a normal
922 out['argspec'] = argspec_dict = dict(argspec._asdict())
915 out['argspec'] = argspec_dict = dict(argspec._asdict())
923 # We called this varkw before argspec became a named tuple.
916 # We called this varkw before argspec became a named tuple.
924 # With getfullargspec it's also called varkw.
917 # With getfullargspec it's also called varkw.
925 if 'varkw' not in argspec_dict:
918 if 'varkw' not in argspec_dict:
926 argspec_dict['varkw'] = argspec_dict.pop('keywords')
919 argspec_dict['varkw'] = argspec_dict.pop('keywords')
927
920
928 return object_info(**out)
921 return object_info(**out)
929
922
930 def psearch(self,pattern,ns_table,ns_search=[],
923 def psearch(self,pattern,ns_table,ns_search=[],
931 ignore_case=False,show_all=False):
924 ignore_case=False,show_all=False):
932 """Search namespaces with wildcards for objects.
925 """Search namespaces with wildcards for objects.
933
926
934 Arguments:
927 Arguments:
935
928
936 - pattern: string containing shell-like wildcards to use in namespace
929 - pattern: string containing shell-like wildcards to use in namespace
937 searches and optionally a type specification to narrow the search to
930 searches and optionally a type specification to narrow the search to
938 objects of that type.
931 objects of that type.
939
932
940 - ns_table: dict of name->namespaces for search.
933 - ns_table: dict of name->namespaces for search.
941
934
942 Optional arguments:
935 Optional arguments:
943
936
944 - ns_search: list of namespace names to include in search.
937 - ns_search: list of namespace names to include in search.
945
938
946 - ignore_case(False): make the search case-insensitive.
939 - ignore_case(False): make the search case-insensitive.
947
940
948 - show_all(False): show all names, including those starting with
941 - show_all(False): show all names, including those starting with
949 underscores.
942 underscores.
950 """
943 """
951 #print 'ps pattern:<%r>' % pattern # dbg
944 #print 'ps pattern:<%r>' % pattern # dbg
952
945
953 # defaults
946 # defaults
954 type_pattern = 'all'
947 type_pattern = 'all'
955 filter = ''
948 filter = ''
956
949
957 cmds = pattern.split()
950 cmds = pattern.split()
958 len_cmds = len(cmds)
951 len_cmds = len(cmds)
959 if len_cmds == 1:
952 if len_cmds == 1:
960 # Only filter pattern given
953 # Only filter pattern given
961 filter = cmds[0]
954 filter = cmds[0]
962 elif len_cmds == 2:
955 elif len_cmds == 2:
963 # Both filter and type specified
956 # Both filter and type specified
964 filter,type_pattern = cmds
957 filter,type_pattern = cmds
965 else:
958 else:
966 raise ValueError('invalid argument string for psearch: <%s>' %
959 raise ValueError('invalid argument string for psearch: <%s>' %
967 pattern)
960 pattern)
968
961
969 # filter search namespaces
962 # filter search namespaces
970 for name in ns_search:
963 for name in ns_search:
971 if name not in ns_table:
964 if name not in ns_table:
972 raise ValueError('invalid namespace <%s>. Valid names: %s' %
965 raise ValueError('invalid namespace <%s>. Valid names: %s' %
973 (name,ns_table.keys()))
966 (name,ns_table.keys()))
974
967
975 #print 'type_pattern:',type_pattern # dbg
968 #print 'type_pattern:',type_pattern # dbg
976 search_result, namespaces_seen = set(), set()
969 search_result, namespaces_seen = set(), set()
977 for ns_name in ns_search:
970 for ns_name in ns_search:
978 ns = ns_table[ns_name]
971 ns = ns_table[ns_name]
979 # Normally, locals and globals are the same, so we just check one.
972 # Normally, locals and globals are the same, so we just check one.
980 if id(ns) in namespaces_seen:
973 if id(ns) in namespaces_seen:
981 continue
974 continue
982 namespaces_seen.add(id(ns))
975 namespaces_seen.add(id(ns))
983 tmp_res = list_namespace(ns, type_pattern, filter,
976 tmp_res = list_namespace(ns, type_pattern, filter,
984 ignore_case=ignore_case, show_all=show_all)
977 ignore_case=ignore_case, show_all=show_all)
985 search_result.update(tmp_res)
978 search_result.update(tmp_res)
986
979
987 page.page('\n'.join(sorted(search_result)))
980 page.page('\n'.join(sorted(search_result)))
@@ -1,74 +1,74 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for the compilerop module.
2 """Tests for the compilerop module.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team.
5 # Copyright (C) 2010-2011 The IPython Development Team.
6 #
6 #
7 # Distributed under the terms of the BSD License.
7 # Distributed under the terms of the BSD License.
8 #
8 #
9 # The full license is in the file COPYING.txt, distributed with this software.
9 # The full license is in the file COPYING.txt, distributed with this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 # Stdlib imports
16 # Stdlib imports
17 import linecache
17 import linecache
18 import sys
18 import sys
19
19
20 # Third-party imports
20 # Third-party imports
21 import nose.tools as nt
21 import nose.tools as nt
22
22
23 # Our own imports
23 # Our own imports
24 from IPython.core import compilerop
24 from IPython.core import compilerop
25 from IPython.utils import py3compat
25 from IPython.utils import py3compat
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Test functions
28 # Test functions
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 def test_code_name():
31 def test_code_name():
32 code = 'x=1'
32 code = 'x=1'
33 name = compilerop.code_name(code)
33 name = compilerop.code_name(code)
34 nt.assert_true(name.startswith('<ipython-input-0'))
34 nt.assert_true(name.startswith('<ipython-input-0'))
35
35
36
36
37 def test_code_name2():
37 def test_code_name2():
38 code = 'x=1'
38 code = 'x=1'
39 name = compilerop.code_name(code, 9)
39 name = compilerop.code_name(code, 9)
40 nt.assert_true(name.startswith('<ipython-input-9'))
40 nt.assert_true(name.startswith('<ipython-input-9'))
41
41
42
42
43 def test_cache():
43 def test_cache():
44 """Test the compiler correctly compiles and caches inputs
44 """Test the compiler correctly compiles and caches inputs
45 """
45 """
46 cp = compilerop.CachingCompiler()
46 cp = compilerop.CachingCompiler()
47 ncache = len(linecache.cache)
47 ncache = len(linecache.cache)
48 cp.cache('x=1')
48 cp.cache('x=1')
49 nt.assert_true(len(linecache.cache) > ncache)
49 nt.assert_true(len(linecache.cache) > ncache)
50
50
51 def setUp():
51 def setUp():
52 # Check we're in a proper Python 2 environment (some imports, such
52 # Check we're in a proper Python 2 environment (some imports, such
53 # as GTK, can change the default encoding, which can hide bugs.)
53 # as GTK, can change the default encoding, which can hide bugs.)
54 nt.assert_equal(sys.getdefaultencoding(), "utf-8" if py3compat.PY3 else "ascii")
54 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
55
55
56 def test_cache_unicode():
56 def test_cache_unicode():
57 cp = compilerop.CachingCompiler()
57 cp = compilerop.CachingCompiler()
58 ncache = len(linecache.cache)
58 ncache = len(linecache.cache)
59 cp.cache(u"t = 'žćčőđ'")
59 cp.cache(u"t = 'žćčőđ'")
60 nt.assert_true(len(linecache.cache) > ncache)
60 nt.assert_true(len(linecache.cache) > ncache)
61
61
62 def test_compiler_check_cache():
62 def test_compiler_check_cache():
63 """Test the compiler properly manages the cache.
63 """Test the compiler properly manages the cache.
64 """
64 """
65 # Rather simple-minded tests that just exercise the API
65 # Rather simple-minded tests that just exercise the API
66 cp = compilerop.CachingCompiler()
66 cp = compilerop.CachingCompiler()
67 cp.cache('x=1', 99)
67 cp.cache('x=1', 99)
68 # Ensure now that after clearing the cache, our entries survive
68 # Ensure now that after clearing the cache, our entries survive
69 linecache.checkcache()
69 linecache.checkcache()
70 for k in linecache.cache:
70 for k in linecache.cache:
71 if k.startswith('<ipython-input-99'):
71 if k.startswith('<ipython-input-99'):
72 break
72 break
73 else:
73 else:
74 raise AssertionError('Entry for input-99 missing from linecache')
74 raise AssertionError('Entry for input-99 missing from linecache')
@@ -1,211 +1,211 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for the IPython tab-completion machinery.
2 """Tests for the IPython tab-completion machinery.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Module imports
5 # Module imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7
7
8 # stdlib
8 # stdlib
9 import io
9 import io
10 import os
10 import os
11 import sys
11 import sys
12 import tempfile
12 import tempfile
13 from datetime import datetime
13 from datetime import datetime
14
14
15 # third party
15 # third party
16 import nose.tools as nt
16 import nose.tools as nt
17
17
18 # our own packages
18 # our own packages
19 from traitlets.config.loader import Config
19 from traitlets.config.loader import Config
20 from IPython.utils.tempdir import TemporaryDirectory
20 from IPython.utils.tempdir import TemporaryDirectory
21 from IPython.core.history import HistoryManager, extract_hist_ranges
21 from IPython.core.history import HistoryManager, extract_hist_ranges
22 from IPython.utils import py3compat
22 from IPython.utils import py3compat
23
23
24 def setUp():
24 def setUp():
25 nt.assert_equal(sys.getdefaultencoding(), "utf-8" if py3compat.PY3 else "ascii")
25 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
26
26
27 def test_history():
27 def test_history():
28 ip = get_ipython()
28 ip = get_ipython()
29 with TemporaryDirectory() as tmpdir:
29 with TemporaryDirectory() as tmpdir:
30 hist_manager_ori = ip.history_manager
30 hist_manager_ori = ip.history_manager
31 hist_file = os.path.join(tmpdir, 'history.sqlite')
31 hist_file = os.path.join(tmpdir, 'history.sqlite')
32 try:
32 try:
33 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
33 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
34 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
34 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
35 for i, h in enumerate(hist, start=1):
35 for i, h in enumerate(hist, start=1):
36 ip.history_manager.store_inputs(i, h)
36 ip.history_manager.store_inputs(i, h)
37
37
38 ip.history_manager.db_log_output = True
38 ip.history_manager.db_log_output = True
39 # Doesn't match the input, but we'll just check it's stored.
39 # Doesn't match the input, but we'll just check it's stored.
40 ip.history_manager.output_hist_reprs[3] = "spam"
40 ip.history_manager.output_hist_reprs[3] = "spam"
41 ip.history_manager.store_output(3)
41 ip.history_manager.store_output(3)
42
42
43 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
43 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
44
44
45 # Detailed tests for _get_range_session
45 # Detailed tests for _get_range_session
46 grs = ip.history_manager._get_range_session
46 grs = ip.history_manager._get_range_session
47 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
47 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
48 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
48 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
49 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
49 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
50
50
51 # Check whether specifying a range beyond the end of the current
51 # Check whether specifying a range beyond the end of the current
52 # session results in an error (gh-804)
52 # session results in an error (gh-804)
53 ip.magic('%hist 2-500')
53 ip.magic('%hist 2-500')
54
54
55 # Check that we can write non-ascii characters to a file
55 # Check that we can write non-ascii characters to a file
56 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
56 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
57 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
57 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
58 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
58 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
59 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
59 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
60
60
61 # New session
61 # New session
62 ip.history_manager.reset()
62 ip.history_manager.reset()
63 newcmds = [u"z=5",
63 newcmds = [u"z=5",
64 u"class X(object):\n pass",
64 u"class X(object):\n pass",
65 u"k='p'",
65 u"k='p'",
66 u"z=5"]
66 u"z=5"]
67 for i, cmd in enumerate(newcmds, start=1):
67 for i, cmd in enumerate(newcmds, start=1):
68 ip.history_manager.store_inputs(i, cmd)
68 ip.history_manager.store_inputs(i, cmd)
69 gothist = ip.history_manager.get_range(start=1, stop=4)
69 gothist = ip.history_manager.get_range(start=1, stop=4)
70 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
70 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
71 # Previous session:
71 # Previous session:
72 gothist = ip.history_manager.get_range(-1, 1, 4)
72 gothist = ip.history_manager.get_range(-1, 1, 4)
73 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
73 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
74
74
75 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
75 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
76
76
77 # Check get_hist_tail
77 # Check get_hist_tail
78 gothist = ip.history_manager.get_tail(5, output=True,
78 gothist = ip.history_manager.get_tail(5, output=True,
79 include_latest=True)
79 include_latest=True)
80 expected = [(1, 3, (hist[-1], "spam"))] \
80 expected = [(1, 3, (hist[-1], "spam"))] \
81 + [(s, n, (c, None)) for (s, n, c) in newhist]
81 + [(s, n, (c, None)) for (s, n, c) in newhist]
82 nt.assert_equal(list(gothist), expected)
82 nt.assert_equal(list(gothist), expected)
83
83
84 gothist = ip.history_manager.get_tail(2)
84 gothist = ip.history_manager.get_tail(2)
85 expected = newhist[-3:-1]
85 expected = newhist[-3:-1]
86 nt.assert_equal(list(gothist), expected)
86 nt.assert_equal(list(gothist), expected)
87
87
88 # Check get_hist_search
88 # Check get_hist_search
89 gothist = ip.history_manager.search("*test*")
89 gothist = ip.history_manager.search("*test*")
90 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
90 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
91
91
92 gothist = ip.history_manager.search("*=*")
92 gothist = ip.history_manager.search("*=*")
93 nt.assert_equal(list(gothist),
93 nt.assert_equal(list(gothist),
94 [(1, 1, hist[0]),
94 [(1, 1, hist[0]),
95 (1, 2, hist[1]),
95 (1, 2, hist[1]),
96 (1, 3, hist[2]),
96 (1, 3, hist[2]),
97 newhist[0],
97 newhist[0],
98 newhist[2],
98 newhist[2],
99 newhist[3]])
99 newhist[3]])
100
100
101 gothist = ip.history_manager.search("*=*", n=4)
101 gothist = ip.history_manager.search("*=*", n=4)
102 nt.assert_equal(list(gothist),
102 nt.assert_equal(list(gothist),
103 [(1, 3, hist[2]),
103 [(1, 3, hist[2]),
104 newhist[0],
104 newhist[0],
105 newhist[2],
105 newhist[2],
106 newhist[3]])
106 newhist[3]])
107
107
108 gothist = ip.history_manager.search("*=*", unique=True)
108 gothist = ip.history_manager.search("*=*", unique=True)
109 nt.assert_equal(list(gothist),
109 nt.assert_equal(list(gothist),
110 [(1, 1, hist[0]),
110 [(1, 1, hist[0]),
111 (1, 2, hist[1]),
111 (1, 2, hist[1]),
112 (1, 3, hist[2]),
112 (1, 3, hist[2]),
113 newhist[2],
113 newhist[2],
114 newhist[3]])
114 newhist[3]])
115
115
116 gothist = ip.history_manager.search("*=*", unique=True, n=3)
116 gothist = ip.history_manager.search("*=*", unique=True, n=3)
117 nt.assert_equal(list(gothist),
117 nt.assert_equal(list(gothist),
118 [(1, 3, hist[2]),
118 [(1, 3, hist[2]),
119 newhist[2],
119 newhist[2],
120 newhist[3]])
120 newhist[3]])
121
121
122 gothist = ip.history_manager.search("b*", output=True)
122 gothist = ip.history_manager.search("b*", output=True)
123 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
123 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
124
124
125 # Cross testing: check that magic %save can get previous session.
125 # Cross testing: check that magic %save can get previous session.
126 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
126 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
127 ip.magic("save " + testfilename + " ~1/1-3")
127 ip.magic("save " + testfilename + " ~1/1-3")
128 with io.open(testfilename, encoding='utf-8') as testfile:
128 with io.open(testfilename, encoding='utf-8') as testfile:
129 nt.assert_equal(testfile.read(),
129 nt.assert_equal(testfile.read(),
130 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
130 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
131
131
132 # Duplicate line numbers - check that it doesn't crash, and
132 # Duplicate line numbers - check that it doesn't crash, and
133 # gets a new session
133 # gets a new session
134 ip.history_manager.store_inputs(1, "rogue")
134 ip.history_manager.store_inputs(1, "rogue")
135 ip.history_manager.writeout_cache()
135 ip.history_manager.writeout_cache()
136 nt.assert_equal(ip.history_manager.session_number, 3)
136 nt.assert_equal(ip.history_manager.session_number, 3)
137 finally:
137 finally:
138 # Ensure saving thread is shut down before we try to clean up the files
138 # Ensure saving thread is shut down before we try to clean up the files
139 ip.history_manager.save_thread.stop()
139 ip.history_manager.save_thread.stop()
140 # Forcibly close database rather than relying on garbage collection
140 # Forcibly close database rather than relying on garbage collection
141 ip.history_manager.db.close()
141 ip.history_manager.db.close()
142 # Restore history manager
142 # Restore history manager
143 ip.history_manager = hist_manager_ori
143 ip.history_manager = hist_manager_ori
144
144
145
145
146 def test_extract_hist_ranges():
146 def test_extract_hist_ranges():
147 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
147 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
148 expected = [(0, 1, 2), # 0 == current session
148 expected = [(0, 1, 2), # 0 == current session
149 (2, 3, 4),
149 (2, 3, 4),
150 (-4, 5, 7),
150 (-4, 5, 7),
151 (-4, 7, 10),
151 (-4, 7, 10),
152 (-9, 2, None), # None == to end
152 (-9, 2, None), # None == to end
153 (-8, 1, None),
153 (-8, 1, None),
154 (-7, 1, 6),
154 (-7, 1, 6),
155 (-10, 1, None)]
155 (-10, 1, None)]
156 actual = list(extract_hist_ranges(instr))
156 actual = list(extract_hist_ranges(instr))
157 nt.assert_equal(actual, expected)
157 nt.assert_equal(actual, expected)
158
158
159 def test_magic_rerun():
159 def test_magic_rerun():
160 """Simple test for %rerun (no args -> rerun last line)"""
160 """Simple test for %rerun (no args -> rerun last line)"""
161 ip = get_ipython()
161 ip = get_ipython()
162 ip.run_cell("a = 10", store_history=True)
162 ip.run_cell("a = 10", store_history=True)
163 ip.run_cell("a += 1", store_history=True)
163 ip.run_cell("a += 1", store_history=True)
164 nt.assert_equal(ip.user_ns["a"], 11)
164 nt.assert_equal(ip.user_ns["a"], 11)
165 ip.run_cell("%rerun", store_history=True)
165 ip.run_cell("%rerun", store_history=True)
166 nt.assert_equal(ip.user_ns["a"], 12)
166 nt.assert_equal(ip.user_ns["a"], 12)
167
167
168 def test_timestamp_type():
168 def test_timestamp_type():
169 ip = get_ipython()
169 ip = get_ipython()
170 info = ip.history_manager.get_session_info()
170 info = ip.history_manager.get_session_info()
171 nt.assert_true(isinstance(info[1], datetime))
171 nt.assert_true(isinstance(info[1], datetime))
172
172
173 def test_hist_file_config():
173 def test_hist_file_config():
174 cfg = Config()
174 cfg = Config()
175 tfile = tempfile.NamedTemporaryFile(delete=False)
175 tfile = tempfile.NamedTemporaryFile(delete=False)
176 cfg.HistoryManager.hist_file = tfile.name
176 cfg.HistoryManager.hist_file = tfile.name
177 try:
177 try:
178 hm = HistoryManager(shell=get_ipython(), config=cfg)
178 hm = HistoryManager(shell=get_ipython(), config=cfg)
179 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
179 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
180 finally:
180 finally:
181 try:
181 try:
182 os.remove(tfile.name)
182 os.remove(tfile.name)
183 except OSError:
183 except OSError:
184 # same catch as in testing.tools.TempFileMixin
184 # same catch as in testing.tools.TempFileMixin
185 # On Windows, even though we close the file, we still can't
185 # On Windows, even though we close the file, we still can't
186 # delete it. I have no clue why
186 # delete it. I have no clue why
187 pass
187 pass
188
188
189 def test_histmanager_disabled():
189 def test_histmanager_disabled():
190 """Ensure that disabling the history manager doesn't create a database."""
190 """Ensure that disabling the history manager doesn't create a database."""
191 cfg = Config()
191 cfg = Config()
192 cfg.HistoryAccessor.enabled = False
192 cfg.HistoryAccessor.enabled = False
193
193
194 ip = get_ipython()
194 ip = get_ipython()
195 with TemporaryDirectory() as tmpdir:
195 with TemporaryDirectory() as tmpdir:
196 hist_manager_ori = ip.history_manager
196 hist_manager_ori = ip.history_manager
197 hist_file = os.path.join(tmpdir, 'history.sqlite')
197 hist_file = os.path.join(tmpdir, 'history.sqlite')
198 cfg.HistoryManager.hist_file = hist_file
198 cfg.HistoryManager.hist_file = hist_file
199 try:
199 try:
200 ip.history_manager = HistoryManager(shell=ip, config=cfg)
200 ip.history_manager = HistoryManager(shell=ip, config=cfg)
201 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
201 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
202 for i, h in enumerate(hist, start=1):
202 for i, h in enumerate(hist, start=1):
203 ip.history_manager.store_inputs(i, h)
203 ip.history_manager.store_inputs(i, h)
204 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
204 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
205 ip.history_manager.reset()
205 ip.history_manager.reset()
206 ip.history_manager.end_session()
206 ip.history_manager.end_session()
207 finally:
207 finally:
208 ip.history_manager = hist_manager_ori
208 ip.history_manager = hist_manager_ori
209
209
210 # hist_file should not be created
210 # hist_file should not be created
211 nt.assert_false(os.path.exists(hist_file))
211 nt.assert_false(os.path.exists(hist_file))
@@ -1,904 +1,899 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import ast
12 import ast
13 import os
13 import os
14 import signal
14 import signal
15 import shutil
15 import shutil
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import unittest
18 import unittest
19 from unittest import mock
19 from unittest import mock
20 from io import StringIO
20
21
21 from os.path import join
22 from os.path import join
22
23
23 import nose.tools as nt
24 import nose.tools as nt
24
25
25 from IPython.core.error import InputRejected
26 from IPython.core.error import InputRejected
26 from IPython.core.inputtransformer import InputTransformer
27 from IPython.core.inputtransformer import InputTransformer
27 from IPython.testing.decorators import (
28 from IPython.testing.decorators import (
28 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
29 skipif, skip_win32, onlyif_unicode_paths, onlyif_cmds_exist,
29 )
30 )
30 from IPython.testing import tools as tt
31 from IPython.testing import tools as tt
31 from IPython.utils.process import find_cmd
32 from IPython.utils.process import find_cmd
32 from IPython.utils import py3compat
33 from IPython.utils import py3compat
33 from IPython.utils.py3compat import PY3
34
35 if PY3:
36 from io import StringIO
37 else:
38 from StringIO import StringIO
39
34
40 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
41 # Globals
36 # Globals
42 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
43 # This is used by every single test, no point repeating it ad nauseam
38 # This is used by every single test, no point repeating it ad nauseam
44 ip = get_ipython()
39 ip = get_ipython()
45
40
46 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
47 # Tests
42 # Tests
48 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
49
44
50 class DerivedInterrupt(KeyboardInterrupt):
45 class DerivedInterrupt(KeyboardInterrupt):
51 pass
46 pass
52
47
53 class InteractiveShellTestCase(unittest.TestCase):
48 class InteractiveShellTestCase(unittest.TestCase):
54 def test_naked_string_cells(self):
49 def test_naked_string_cells(self):
55 """Test that cells with only naked strings are fully executed"""
50 """Test that cells with only naked strings are fully executed"""
56 # First, single-line inputs
51 # First, single-line inputs
57 ip.run_cell('"a"\n')
52 ip.run_cell('"a"\n')
58 self.assertEqual(ip.user_ns['_'], 'a')
53 self.assertEqual(ip.user_ns['_'], 'a')
59 # And also multi-line cells
54 # And also multi-line cells
60 ip.run_cell('"""a\nb"""\n')
55 ip.run_cell('"""a\nb"""\n')
61 self.assertEqual(ip.user_ns['_'], 'a\nb')
56 self.assertEqual(ip.user_ns['_'], 'a\nb')
62
57
63 def test_run_empty_cell(self):
58 def test_run_empty_cell(self):
64 """Just make sure we don't get a horrible error with a blank
59 """Just make sure we don't get a horrible error with a blank
65 cell of input. Yes, I did overlook that."""
60 cell of input. Yes, I did overlook that."""
66 old_xc = ip.execution_count
61 old_xc = ip.execution_count
67 res = ip.run_cell('')
62 res = ip.run_cell('')
68 self.assertEqual(ip.execution_count, old_xc)
63 self.assertEqual(ip.execution_count, old_xc)
69 self.assertEqual(res.execution_count, None)
64 self.assertEqual(res.execution_count, None)
70
65
71 def test_run_cell_multiline(self):
66 def test_run_cell_multiline(self):
72 """Multi-block, multi-line cells must execute correctly.
67 """Multi-block, multi-line cells must execute correctly.
73 """
68 """
74 src = '\n'.join(["x=1",
69 src = '\n'.join(["x=1",
75 "y=2",
70 "y=2",
76 "if 1:",
71 "if 1:",
77 " x += 1",
72 " x += 1",
78 " y += 1",])
73 " y += 1",])
79 res = ip.run_cell(src)
74 res = ip.run_cell(src)
80 self.assertEqual(ip.user_ns['x'], 2)
75 self.assertEqual(ip.user_ns['x'], 2)
81 self.assertEqual(ip.user_ns['y'], 3)
76 self.assertEqual(ip.user_ns['y'], 3)
82 self.assertEqual(res.success, True)
77 self.assertEqual(res.success, True)
83 self.assertEqual(res.result, None)
78 self.assertEqual(res.result, None)
84
79
85 def test_multiline_string_cells(self):
80 def test_multiline_string_cells(self):
86 "Code sprinkled with multiline strings should execute (GH-306)"
81 "Code sprinkled with multiline strings should execute (GH-306)"
87 ip.run_cell('tmp=0')
82 ip.run_cell('tmp=0')
88 self.assertEqual(ip.user_ns['tmp'], 0)
83 self.assertEqual(ip.user_ns['tmp'], 0)
89 res = ip.run_cell('tmp=1;"""a\nb"""\n')
84 res = ip.run_cell('tmp=1;"""a\nb"""\n')
90 self.assertEqual(ip.user_ns['tmp'], 1)
85 self.assertEqual(ip.user_ns['tmp'], 1)
91 self.assertEqual(res.success, True)
86 self.assertEqual(res.success, True)
92 self.assertEqual(res.result, "a\nb")
87 self.assertEqual(res.result, "a\nb")
93
88
94 def test_dont_cache_with_semicolon(self):
89 def test_dont_cache_with_semicolon(self):
95 "Ending a line with semicolon should not cache the returned object (GH-307)"
90 "Ending a line with semicolon should not cache the returned object (GH-307)"
96 oldlen = len(ip.user_ns['Out'])
91 oldlen = len(ip.user_ns['Out'])
97 for cell in ['1;', '1;1;']:
92 for cell in ['1;', '1;1;']:
98 res = ip.run_cell(cell, store_history=True)
93 res = ip.run_cell(cell, store_history=True)
99 newlen = len(ip.user_ns['Out'])
94 newlen = len(ip.user_ns['Out'])
100 self.assertEqual(oldlen, newlen)
95 self.assertEqual(oldlen, newlen)
101 self.assertIsNone(res.result)
96 self.assertIsNone(res.result)
102 i = 0
97 i = 0
103 #also test the default caching behavior
98 #also test the default caching behavior
104 for cell in ['1', '1;1']:
99 for cell in ['1', '1;1']:
105 ip.run_cell(cell, store_history=True)
100 ip.run_cell(cell, store_history=True)
106 newlen = len(ip.user_ns['Out'])
101 newlen = len(ip.user_ns['Out'])
107 i += 1
102 i += 1
108 self.assertEqual(oldlen+i, newlen)
103 self.assertEqual(oldlen+i, newlen)
109
104
110 def test_syntax_error(self):
105 def test_syntax_error(self):
111 res = ip.run_cell("raise = 3")
106 res = ip.run_cell("raise = 3")
112 self.assertIsInstance(res.error_before_exec, SyntaxError)
107 self.assertIsInstance(res.error_before_exec, SyntaxError)
113
108
114 def test_In_variable(self):
109 def test_In_variable(self):
115 "Verify that In variable grows with user input (GH-284)"
110 "Verify that In variable grows with user input (GH-284)"
116 oldlen = len(ip.user_ns['In'])
111 oldlen = len(ip.user_ns['In'])
117 ip.run_cell('1;', store_history=True)
112 ip.run_cell('1;', store_history=True)
118 newlen = len(ip.user_ns['In'])
113 newlen = len(ip.user_ns['In'])
119 self.assertEqual(oldlen+1, newlen)
114 self.assertEqual(oldlen+1, newlen)
120 self.assertEqual(ip.user_ns['In'][-1],'1;')
115 self.assertEqual(ip.user_ns['In'][-1],'1;')
121
116
122 def test_magic_names_in_string(self):
117 def test_magic_names_in_string(self):
123 ip.run_cell('a = """\n%exit\n"""')
118 ip.run_cell('a = """\n%exit\n"""')
124 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
119 self.assertEqual(ip.user_ns['a'], '\n%exit\n')
125
120
126 def test_trailing_newline(self):
121 def test_trailing_newline(self):
127 """test that running !(command) does not raise a SyntaxError"""
122 """test that running !(command) does not raise a SyntaxError"""
128 ip.run_cell('!(true)\n', False)
123 ip.run_cell('!(true)\n', False)
129 ip.run_cell('!(true)\n\n\n', False)
124 ip.run_cell('!(true)\n\n\n', False)
130
125
131 def test_gh_597(self):
126 def test_gh_597(self):
132 """Pretty-printing lists of objects with non-ascii reprs may cause
127 """Pretty-printing lists of objects with non-ascii reprs may cause
133 problems."""
128 problems."""
134 class Spam(object):
129 class Spam(object):
135 def __repr__(self):
130 def __repr__(self):
136 return "\xe9"*50
131 return "\xe9"*50
137 import IPython.core.formatters
132 import IPython.core.formatters
138 f = IPython.core.formatters.PlainTextFormatter()
133 f = IPython.core.formatters.PlainTextFormatter()
139 f([Spam(),Spam()])
134 f([Spam(),Spam()])
140
135
141
136
142 def test_future_flags(self):
137 def test_future_flags(self):
143 """Check that future flags are used for parsing code (gh-777)"""
138 """Check that future flags are used for parsing code (gh-777)"""
144 ip.run_cell('from __future__ import barry_as_FLUFL')
139 ip.run_cell('from __future__ import barry_as_FLUFL')
145 try:
140 try:
146 ip.run_cell('prfunc_return_val = 1 <> 2')
141 ip.run_cell('prfunc_return_val = 1 <> 2')
147 assert 'prfunc_return_val' in ip.user_ns
142 assert 'prfunc_return_val' in ip.user_ns
148 finally:
143 finally:
149 # Reset compiler flags so we don't mess up other tests.
144 # Reset compiler flags so we don't mess up other tests.
150 ip.compile.reset_compiler_flags()
145 ip.compile.reset_compiler_flags()
151
146
152 def test_can_pickle(self):
147 def test_can_pickle(self):
153 "Can we pickle objects defined interactively (GH-29)"
148 "Can we pickle objects defined interactively (GH-29)"
154 ip = get_ipython()
149 ip = get_ipython()
155 ip.reset()
150 ip.reset()
156 ip.run_cell(("class Mylist(list):\n"
151 ip.run_cell(("class Mylist(list):\n"
157 " def __init__(self,x=[]):\n"
152 " def __init__(self,x=[]):\n"
158 " list.__init__(self,x)"))
153 " list.__init__(self,x)"))
159 ip.run_cell("w=Mylist([1,2,3])")
154 ip.run_cell("w=Mylist([1,2,3])")
160
155
161 from pickle import dumps
156 from pickle import dumps
162
157
163 # We need to swap in our main module - this is only necessary
158 # We need to swap in our main module - this is only necessary
164 # inside the test framework, because IPython puts the interactive module
159 # inside the test framework, because IPython puts the interactive module
165 # in place (but the test framework undoes this).
160 # in place (but the test framework undoes this).
166 _main = sys.modules['__main__']
161 _main = sys.modules['__main__']
167 sys.modules['__main__'] = ip.user_module
162 sys.modules['__main__'] = ip.user_module
168 try:
163 try:
169 res = dumps(ip.user_ns["w"])
164 res = dumps(ip.user_ns["w"])
170 finally:
165 finally:
171 sys.modules['__main__'] = _main
166 sys.modules['__main__'] = _main
172 self.assertTrue(isinstance(res, bytes))
167 self.assertTrue(isinstance(res, bytes))
173
168
174 def test_global_ns(self):
169 def test_global_ns(self):
175 "Code in functions must be able to access variables outside them."
170 "Code in functions must be able to access variables outside them."
176 ip = get_ipython()
171 ip = get_ipython()
177 ip.run_cell("a = 10")
172 ip.run_cell("a = 10")
178 ip.run_cell(("def f(x):\n"
173 ip.run_cell(("def f(x):\n"
179 " return x + a"))
174 " return x + a"))
180 ip.run_cell("b = f(12)")
175 ip.run_cell("b = f(12)")
181 self.assertEqual(ip.user_ns["b"], 22)
176 self.assertEqual(ip.user_ns["b"], 22)
182
177
183 def test_bad_custom_tb(self):
178 def test_bad_custom_tb(self):
184 """Check that InteractiveShell is protected from bad custom exception handlers"""
179 """Check that InteractiveShell is protected from bad custom exception handlers"""
185 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
180 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
186 self.assertEqual(ip.custom_exceptions, (IOError,))
181 self.assertEqual(ip.custom_exceptions, (IOError,))
187 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
182 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
188 ip.run_cell(u'raise IOError("foo")')
183 ip.run_cell(u'raise IOError("foo")')
189 self.assertEqual(ip.custom_exceptions, ())
184 self.assertEqual(ip.custom_exceptions, ())
190
185
191 def test_bad_custom_tb_return(self):
186 def test_bad_custom_tb_return(self):
192 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
187 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
193 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
188 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
194 self.assertEqual(ip.custom_exceptions, (NameError,))
189 self.assertEqual(ip.custom_exceptions, (NameError,))
195 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
190 with tt.AssertPrints("Custom TB Handler failed", channel='stderr'):
196 ip.run_cell(u'a=abracadabra')
191 ip.run_cell(u'a=abracadabra')
197 self.assertEqual(ip.custom_exceptions, ())
192 self.assertEqual(ip.custom_exceptions, ())
198
193
199 def test_drop_by_id(self):
194 def test_drop_by_id(self):
200 myvars = {"a":object(), "b":object(), "c": object()}
195 myvars = {"a":object(), "b":object(), "c": object()}
201 ip.push(myvars, interactive=False)
196 ip.push(myvars, interactive=False)
202 for name in myvars:
197 for name in myvars:
203 assert name in ip.user_ns, name
198 assert name in ip.user_ns, name
204 assert name in ip.user_ns_hidden, name
199 assert name in ip.user_ns_hidden, name
205 ip.user_ns['b'] = 12
200 ip.user_ns['b'] = 12
206 ip.drop_by_id(myvars)
201 ip.drop_by_id(myvars)
207 for name in ["a", "c"]:
202 for name in ["a", "c"]:
208 assert name not in ip.user_ns, name
203 assert name not in ip.user_ns, name
209 assert name not in ip.user_ns_hidden, name
204 assert name not in ip.user_ns_hidden, name
210 assert ip.user_ns['b'] == 12
205 assert ip.user_ns['b'] == 12
211 ip.reset()
206 ip.reset()
212
207
213 def test_var_expand(self):
208 def test_var_expand(self):
214 ip.user_ns['f'] = u'Ca\xf1o'
209 ip.user_ns['f'] = u'Ca\xf1o'
215 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
210 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
216 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
211 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
217 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
212 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
218 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
213 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
219
214
220 ip.user_ns['f'] = b'Ca\xc3\xb1o'
215 ip.user_ns['f'] = b'Ca\xc3\xb1o'
221 # This should not raise any exception:
216 # This should not raise any exception:
222 ip.var_expand(u'echo $f')
217 ip.var_expand(u'echo $f')
223
218
224 def test_var_expand_local(self):
219 def test_var_expand_local(self):
225 """Test local variable expansion in !system and %magic calls"""
220 """Test local variable expansion in !system and %magic calls"""
226 # !system
221 # !system
227 ip.run_cell('def test():\n'
222 ip.run_cell('def test():\n'
228 ' lvar = "ttt"\n'
223 ' lvar = "ttt"\n'
229 ' ret = !echo {lvar}\n'
224 ' ret = !echo {lvar}\n'
230 ' return ret[0]\n')
225 ' return ret[0]\n')
231 res = ip.user_ns['test']()
226 res = ip.user_ns['test']()
232 nt.assert_in('ttt', res)
227 nt.assert_in('ttt', res)
233
228
234 # %magic
229 # %magic
235 ip.run_cell('def makemacro():\n'
230 ip.run_cell('def makemacro():\n'
236 ' macroname = "macro_var_expand_locals"\n'
231 ' macroname = "macro_var_expand_locals"\n'
237 ' %macro {macroname} codestr\n')
232 ' %macro {macroname} codestr\n')
238 ip.user_ns['codestr'] = "str(12)"
233 ip.user_ns['codestr'] = "str(12)"
239 ip.run_cell('makemacro()')
234 ip.run_cell('makemacro()')
240 nt.assert_in('macro_var_expand_locals', ip.user_ns)
235 nt.assert_in('macro_var_expand_locals', ip.user_ns)
241
236
242 def test_var_expand_self(self):
237 def test_var_expand_self(self):
243 """Test variable expansion with the name 'self', which was failing.
238 """Test variable expansion with the name 'self', which was failing.
244
239
245 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
240 See https://github.com/ipython/ipython/issues/1878#issuecomment-7698218
246 """
241 """
247 ip.run_cell('class cTest:\n'
242 ip.run_cell('class cTest:\n'
248 ' classvar="see me"\n'
243 ' classvar="see me"\n'
249 ' def test(self):\n'
244 ' def test(self):\n'
250 ' res = !echo Variable: {self.classvar}\n'
245 ' res = !echo Variable: {self.classvar}\n'
251 ' return res[0]\n')
246 ' return res[0]\n')
252 nt.assert_in('see me', ip.user_ns['cTest']().test())
247 nt.assert_in('see me', ip.user_ns['cTest']().test())
253
248
254 def test_bad_var_expand(self):
249 def test_bad_var_expand(self):
255 """var_expand on invalid formats shouldn't raise"""
250 """var_expand on invalid formats shouldn't raise"""
256 # SyntaxError
251 # SyntaxError
257 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
252 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
258 # NameError
253 # NameError
259 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
254 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
260 # ZeroDivisionError
255 # ZeroDivisionError
261 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
256 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
262
257
263 def test_silent_postexec(self):
258 def test_silent_postexec(self):
264 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
259 """run_cell(silent=True) doesn't invoke pre/post_run_cell callbacks"""
265 pre_explicit = mock.Mock()
260 pre_explicit = mock.Mock()
266 pre_always = mock.Mock()
261 pre_always = mock.Mock()
267 post_explicit = mock.Mock()
262 post_explicit = mock.Mock()
268 post_always = mock.Mock()
263 post_always = mock.Mock()
269
264
270 ip.events.register('pre_run_cell', pre_explicit)
265 ip.events.register('pre_run_cell', pre_explicit)
271 ip.events.register('pre_execute', pre_always)
266 ip.events.register('pre_execute', pre_always)
272 ip.events.register('post_run_cell', post_explicit)
267 ip.events.register('post_run_cell', post_explicit)
273 ip.events.register('post_execute', post_always)
268 ip.events.register('post_execute', post_always)
274
269
275 try:
270 try:
276 ip.run_cell("1", silent=True)
271 ip.run_cell("1", silent=True)
277 assert pre_always.called
272 assert pre_always.called
278 assert not pre_explicit.called
273 assert not pre_explicit.called
279 assert post_always.called
274 assert post_always.called
280 assert not post_explicit.called
275 assert not post_explicit.called
281 # double-check that non-silent exec did what we expected
276 # double-check that non-silent exec did what we expected
282 # silent to avoid
277 # silent to avoid
283 ip.run_cell("1")
278 ip.run_cell("1")
284 assert pre_explicit.called
279 assert pre_explicit.called
285 assert post_explicit.called
280 assert post_explicit.called
286 finally:
281 finally:
287 # remove post-exec
282 # remove post-exec
288 ip.events.unregister('pre_run_cell', pre_explicit)
283 ip.events.unregister('pre_run_cell', pre_explicit)
289 ip.events.unregister('pre_execute', pre_always)
284 ip.events.unregister('pre_execute', pre_always)
290 ip.events.unregister('post_run_cell', post_explicit)
285 ip.events.unregister('post_run_cell', post_explicit)
291 ip.events.unregister('post_execute', post_always)
286 ip.events.unregister('post_execute', post_always)
292
287
293 def test_silent_noadvance(self):
288 def test_silent_noadvance(self):
294 """run_cell(silent=True) doesn't advance execution_count"""
289 """run_cell(silent=True) doesn't advance execution_count"""
295 ec = ip.execution_count
290 ec = ip.execution_count
296 # silent should force store_history=False
291 # silent should force store_history=False
297 ip.run_cell("1", store_history=True, silent=True)
292 ip.run_cell("1", store_history=True, silent=True)
298
293
299 self.assertEqual(ec, ip.execution_count)
294 self.assertEqual(ec, ip.execution_count)
300 # double-check that non-silent exec did what we expected
295 # double-check that non-silent exec did what we expected
301 # silent to avoid
296 # silent to avoid
302 ip.run_cell("1", store_history=True)
297 ip.run_cell("1", store_history=True)
303 self.assertEqual(ec+1, ip.execution_count)
298 self.assertEqual(ec+1, ip.execution_count)
304
299
305 def test_silent_nodisplayhook(self):
300 def test_silent_nodisplayhook(self):
306 """run_cell(silent=True) doesn't trigger displayhook"""
301 """run_cell(silent=True) doesn't trigger displayhook"""
307 d = dict(called=False)
302 d = dict(called=False)
308
303
309 trap = ip.display_trap
304 trap = ip.display_trap
310 save_hook = trap.hook
305 save_hook = trap.hook
311
306
312 def failing_hook(*args, **kwargs):
307 def failing_hook(*args, **kwargs):
313 d['called'] = True
308 d['called'] = True
314
309
315 try:
310 try:
316 trap.hook = failing_hook
311 trap.hook = failing_hook
317 res = ip.run_cell("1", silent=True)
312 res = ip.run_cell("1", silent=True)
318 self.assertFalse(d['called'])
313 self.assertFalse(d['called'])
319 self.assertIsNone(res.result)
314 self.assertIsNone(res.result)
320 # double-check that non-silent exec did what we expected
315 # double-check that non-silent exec did what we expected
321 # silent to avoid
316 # silent to avoid
322 ip.run_cell("1")
317 ip.run_cell("1")
323 self.assertTrue(d['called'])
318 self.assertTrue(d['called'])
324 finally:
319 finally:
325 trap.hook = save_hook
320 trap.hook = save_hook
326
321
327 def test_ofind_line_magic(self):
322 def test_ofind_line_magic(self):
328 from IPython.core.magic import register_line_magic
323 from IPython.core.magic import register_line_magic
329
324
330 @register_line_magic
325 @register_line_magic
331 def lmagic(line):
326 def lmagic(line):
332 "A line magic"
327 "A line magic"
333
328
334 # Get info on line magic
329 # Get info on line magic
335 lfind = ip._ofind('lmagic')
330 lfind = ip._ofind('lmagic')
336 info = dict(found=True, isalias=False, ismagic=True,
331 info = dict(found=True, isalias=False, ismagic=True,
337 namespace = 'IPython internal', obj= lmagic.__wrapped__,
332 namespace = 'IPython internal', obj= lmagic.__wrapped__,
338 parent = None)
333 parent = None)
339 nt.assert_equal(lfind, info)
334 nt.assert_equal(lfind, info)
340
335
341 def test_ofind_cell_magic(self):
336 def test_ofind_cell_magic(self):
342 from IPython.core.magic import register_cell_magic
337 from IPython.core.magic import register_cell_magic
343
338
344 @register_cell_magic
339 @register_cell_magic
345 def cmagic(line, cell):
340 def cmagic(line, cell):
346 "A cell magic"
341 "A cell magic"
347
342
348 # Get info on cell magic
343 # Get info on cell magic
349 find = ip._ofind('cmagic')
344 find = ip._ofind('cmagic')
350 info = dict(found=True, isalias=False, ismagic=True,
345 info = dict(found=True, isalias=False, ismagic=True,
351 namespace = 'IPython internal', obj= cmagic.__wrapped__,
346 namespace = 'IPython internal', obj= cmagic.__wrapped__,
352 parent = None)
347 parent = None)
353 nt.assert_equal(find, info)
348 nt.assert_equal(find, info)
354
349
355 def test_ofind_property_with_error(self):
350 def test_ofind_property_with_error(self):
356 class A(object):
351 class A(object):
357 @property
352 @property
358 def foo(self):
353 def foo(self):
359 raise NotImplementedError()
354 raise NotImplementedError()
360 a = A()
355 a = A()
361
356
362 found = ip._ofind('a.foo', [('locals', locals())])
357 found = ip._ofind('a.foo', [('locals', locals())])
363 info = dict(found=True, isalias=False, ismagic=False,
358 info = dict(found=True, isalias=False, ismagic=False,
364 namespace='locals', obj=A.foo, parent=a)
359 namespace='locals', obj=A.foo, parent=a)
365 nt.assert_equal(found, info)
360 nt.assert_equal(found, info)
366
361
367 def test_ofind_multiple_attribute_lookups(self):
362 def test_ofind_multiple_attribute_lookups(self):
368 class A(object):
363 class A(object):
369 @property
364 @property
370 def foo(self):
365 def foo(self):
371 raise NotImplementedError()
366 raise NotImplementedError()
372
367
373 a = A()
368 a = A()
374 a.a = A()
369 a.a = A()
375 a.a.a = A()
370 a.a.a = A()
376
371
377 found = ip._ofind('a.a.a.foo', [('locals', locals())])
372 found = ip._ofind('a.a.a.foo', [('locals', locals())])
378 info = dict(found=True, isalias=False, ismagic=False,
373 info = dict(found=True, isalias=False, ismagic=False,
379 namespace='locals', obj=A.foo, parent=a.a.a)
374 namespace='locals', obj=A.foo, parent=a.a.a)
380 nt.assert_equal(found, info)
375 nt.assert_equal(found, info)
381
376
382 def test_ofind_slotted_attributes(self):
377 def test_ofind_slotted_attributes(self):
383 class A(object):
378 class A(object):
384 __slots__ = ['foo']
379 __slots__ = ['foo']
385 def __init__(self):
380 def __init__(self):
386 self.foo = 'bar'
381 self.foo = 'bar'
387
382
388 a = A()
383 a = A()
389 found = ip._ofind('a.foo', [('locals', locals())])
384 found = ip._ofind('a.foo', [('locals', locals())])
390 info = dict(found=True, isalias=False, ismagic=False,
385 info = dict(found=True, isalias=False, ismagic=False,
391 namespace='locals', obj=a.foo, parent=a)
386 namespace='locals', obj=a.foo, parent=a)
392 nt.assert_equal(found, info)
387 nt.assert_equal(found, info)
393
388
394 found = ip._ofind('a.bar', [('locals', locals())])
389 found = ip._ofind('a.bar', [('locals', locals())])
395 info = dict(found=False, isalias=False, ismagic=False,
390 info = dict(found=False, isalias=False, ismagic=False,
396 namespace=None, obj=None, parent=a)
391 namespace=None, obj=None, parent=a)
397 nt.assert_equal(found, info)
392 nt.assert_equal(found, info)
398
393
399 def test_ofind_prefers_property_to_instance_level_attribute(self):
394 def test_ofind_prefers_property_to_instance_level_attribute(self):
400 class A(object):
395 class A(object):
401 @property
396 @property
402 def foo(self):
397 def foo(self):
403 return 'bar'
398 return 'bar'
404 a = A()
399 a = A()
405 a.__dict__['foo'] = 'baz'
400 a.__dict__['foo'] = 'baz'
406 nt.assert_equal(a.foo, 'bar')
401 nt.assert_equal(a.foo, 'bar')
407 found = ip._ofind('a.foo', [('locals', locals())])
402 found = ip._ofind('a.foo', [('locals', locals())])
408 nt.assert_is(found['obj'], A.foo)
403 nt.assert_is(found['obj'], A.foo)
409
404
410 def test_custom_syntaxerror_exception(self):
405 def test_custom_syntaxerror_exception(self):
411 called = []
406 called = []
412 def my_handler(shell, etype, value, tb, tb_offset=None):
407 def my_handler(shell, etype, value, tb, tb_offset=None):
413 called.append(etype)
408 called.append(etype)
414 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
409 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
415
410
416 ip.set_custom_exc((SyntaxError,), my_handler)
411 ip.set_custom_exc((SyntaxError,), my_handler)
417 try:
412 try:
418 ip.run_cell("1f")
413 ip.run_cell("1f")
419 # Check that this was called, and only once.
414 # Check that this was called, and only once.
420 self.assertEqual(called, [SyntaxError])
415 self.assertEqual(called, [SyntaxError])
421 finally:
416 finally:
422 # Reset the custom exception hook
417 # Reset the custom exception hook
423 ip.set_custom_exc((), None)
418 ip.set_custom_exc((), None)
424
419
425 def test_custom_exception(self):
420 def test_custom_exception(self):
426 called = []
421 called = []
427 def my_handler(shell, etype, value, tb, tb_offset=None):
422 def my_handler(shell, etype, value, tb, tb_offset=None):
428 called.append(etype)
423 called.append(etype)
429 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
424 shell.showtraceback((etype, value, tb), tb_offset=tb_offset)
430
425
431 ip.set_custom_exc((ValueError,), my_handler)
426 ip.set_custom_exc((ValueError,), my_handler)
432 try:
427 try:
433 res = ip.run_cell("raise ValueError('test')")
428 res = ip.run_cell("raise ValueError('test')")
434 # Check that this was called, and only once.
429 # Check that this was called, and only once.
435 self.assertEqual(called, [ValueError])
430 self.assertEqual(called, [ValueError])
436 # Check that the error is on the result object
431 # Check that the error is on the result object
437 self.assertIsInstance(res.error_in_exec, ValueError)
432 self.assertIsInstance(res.error_in_exec, ValueError)
438 finally:
433 finally:
439 # Reset the custom exception hook
434 # Reset the custom exception hook
440 ip.set_custom_exc((), None)
435 ip.set_custom_exc((), None)
441
436
442 def test_mktempfile(self):
437 def test_mktempfile(self):
443 filename = ip.mktempfile()
438 filename = ip.mktempfile()
444 # Check that we can open the file again on Windows
439 # Check that we can open the file again on Windows
445 with open(filename, 'w') as f:
440 with open(filename, 'w') as f:
446 f.write('abc')
441 f.write('abc')
447
442
448 filename = ip.mktempfile(data='blah')
443 filename = ip.mktempfile(data='blah')
449 with open(filename, 'r') as f:
444 with open(filename, 'r') as f:
450 self.assertEqual(f.read(), 'blah')
445 self.assertEqual(f.read(), 'blah')
451
446
452 def test_new_main_mod(self):
447 def test_new_main_mod(self):
453 # Smoketest to check that this accepts a unicode module name
448 # Smoketest to check that this accepts a unicode module name
454 name = u'jiefmw'
449 name = u'jiefmw'
455 mod = ip.new_main_mod(u'%s.py' % name, name)
450 mod = ip.new_main_mod(u'%s.py' % name, name)
456 self.assertEqual(mod.__name__, name)
451 self.assertEqual(mod.__name__, name)
457
452
458 def test_get_exception_only(self):
453 def test_get_exception_only(self):
459 try:
454 try:
460 raise KeyboardInterrupt
455 raise KeyboardInterrupt
461 except KeyboardInterrupt:
456 except KeyboardInterrupt:
462 msg = ip.get_exception_only()
457 msg = ip.get_exception_only()
463 self.assertEqual(msg, 'KeyboardInterrupt\n')
458 self.assertEqual(msg, 'KeyboardInterrupt\n')
464
459
465 try:
460 try:
466 raise DerivedInterrupt("foo")
461 raise DerivedInterrupt("foo")
467 except KeyboardInterrupt:
462 except KeyboardInterrupt:
468 msg = ip.get_exception_only()
463 msg = ip.get_exception_only()
469 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
464 self.assertEqual(msg, 'IPython.core.tests.test_interactiveshell.DerivedInterrupt: foo\n')
470
465
471 def test_inspect_text(self):
466 def test_inspect_text(self):
472 ip.run_cell('a = 5')
467 ip.run_cell('a = 5')
473 text = ip.object_inspect_text('a')
468 text = ip.object_inspect_text('a')
474 self.assertIsInstance(text, str)
469 self.assertIsInstance(text, str)
475
470
476
471
477 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
472 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
478
473
479 @onlyif_unicode_paths
474 @onlyif_unicode_paths
480 def setUp(self):
475 def setUp(self):
481 self.BASETESTDIR = tempfile.mkdtemp()
476 self.BASETESTDIR = tempfile.mkdtemp()
482 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
477 self.TESTDIR = join(self.BASETESTDIR, u"Γ₯Àâ")
483 os.mkdir(self.TESTDIR)
478 os.mkdir(self.TESTDIR)
484 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
479 with open(join(self.TESTDIR, u"Γ₯Àâtestscript.py"), "w") as sfile:
485 sfile.write("pass\n")
480 sfile.write("pass\n")
486 self.oldpath = os.getcwd()
481 self.oldpath = os.getcwd()
487 os.chdir(self.TESTDIR)
482 os.chdir(self.TESTDIR)
488 self.fname = u"Γ₯Àâtestscript.py"
483 self.fname = u"Γ₯Àâtestscript.py"
489
484
490 def tearDown(self):
485 def tearDown(self):
491 os.chdir(self.oldpath)
486 os.chdir(self.oldpath)
492 shutil.rmtree(self.BASETESTDIR)
487 shutil.rmtree(self.BASETESTDIR)
493
488
494 @onlyif_unicode_paths
489 @onlyif_unicode_paths
495 def test_1(self):
490 def test_1(self):
496 """Test safe_execfile with non-ascii path
491 """Test safe_execfile with non-ascii path
497 """
492 """
498 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
493 ip.safe_execfile(self.fname, {}, raise_exceptions=True)
499
494
500 class ExitCodeChecks(tt.TempFileMixin):
495 class ExitCodeChecks(tt.TempFileMixin):
501 def test_exit_code_ok(self):
496 def test_exit_code_ok(self):
502 self.system('exit 0')
497 self.system('exit 0')
503 self.assertEqual(ip.user_ns['_exit_code'], 0)
498 self.assertEqual(ip.user_ns['_exit_code'], 0)
504
499
505 def test_exit_code_error(self):
500 def test_exit_code_error(self):
506 self.system('exit 1')
501 self.system('exit 1')
507 self.assertEqual(ip.user_ns['_exit_code'], 1)
502 self.assertEqual(ip.user_ns['_exit_code'], 1)
508
503
509 @skipif(not hasattr(signal, 'SIGALRM'))
504 @skipif(not hasattr(signal, 'SIGALRM'))
510 def test_exit_code_signal(self):
505 def test_exit_code_signal(self):
511 self.mktmp("import signal, time\n"
506 self.mktmp("import signal, time\n"
512 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
507 "signal.setitimer(signal.ITIMER_REAL, 0.1)\n"
513 "time.sleep(1)\n")
508 "time.sleep(1)\n")
514 self.system("%s %s" % (sys.executable, self.fname))
509 self.system("%s %s" % (sys.executable, self.fname))
515 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
510 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGALRM)
516
511
517 @onlyif_cmds_exist("csh")
512 @onlyif_cmds_exist("csh")
518 def test_exit_code_signal_csh(self):
513 def test_exit_code_signal_csh(self):
519 SHELL = os.environ.get('SHELL', None)
514 SHELL = os.environ.get('SHELL', None)
520 os.environ['SHELL'] = find_cmd("csh")
515 os.environ['SHELL'] = find_cmd("csh")
521 try:
516 try:
522 self.test_exit_code_signal()
517 self.test_exit_code_signal()
523 finally:
518 finally:
524 if SHELL is not None:
519 if SHELL is not None:
525 os.environ['SHELL'] = SHELL
520 os.environ['SHELL'] = SHELL
526 else:
521 else:
527 del os.environ['SHELL']
522 del os.environ['SHELL']
528
523
529 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
524 class TestSystemRaw(unittest.TestCase, ExitCodeChecks):
530 system = ip.system_raw
525 system = ip.system_raw
531
526
532 @onlyif_unicode_paths
527 @onlyif_unicode_paths
533 def test_1(self):
528 def test_1(self):
534 """Test system_raw with non-ascii cmd
529 """Test system_raw with non-ascii cmd
535 """
530 """
536 cmd = u'''python -c "'Γ₯Àâ'" '''
531 cmd = u'''python -c "'Γ₯Àâ'" '''
537 ip.system_raw(cmd)
532 ip.system_raw(cmd)
538
533
539 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
534 @mock.patch('subprocess.call', side_effect=KeyboardInterrupt)
540 @mock.patch('os.system', side_effect=KeyboardInterrupt)
535 @mock.patch('os.system', side_effect=KeyboardInterrupt)
541 def test_control_c(self, *mocks):
536 def test_control_c(self, *mocks):
542 try:
537 try:
543 self.system("sleep 1 # wont happen")
538 self.system("sleep 1 # wont happen")
544 except KeyboardInterrupt:
539 except KeyboardInterrupt:
545 self.fail("system call should intercept "
540 self.fail("system call should intercept "
546 "keyboard interrupt from subprocess.call")
541 "keyboard interrupt from subprocess.call")
547 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
542 self.assertEqual(ip.user_ns['_exit_code'], -signal.SIGINT)
548
543
549 # TODO: Exit codes are currently ignored on Windows.
544 # TODO: Exit codes are currently ignored on Windows.
550 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
545 class TestSystemPipedExitCode(unittest.TestCase, ExitCodeChecks):
551 system = ip.system_piped
546 system = ip.system_piped
552
547
553 @skip_win32
548 @skip_win32
554 def test_exit_code_ok(self):
549 def test_exit_code_ok(self):
555 ExitCodeChecks.test_exit_code_ok(self)
550 ExitCodeChecks.test_exit_code_ok(self)
556
551
557 @skip_win32
552 @skip_win32
558 def test_exit_code_error(self):
553 def test_exit_code_error(self):
559 ExitCodeChecks.test_exit_code_error(self)
554 ExitCodeChecks.test_exit_code_error(self)
560
555
561 @skip_win32
556 @skip_win32
562 def test_exit_code_signal(self):
557 def test_exit_code_signal(self):
563 ExitCodeChecks.test_exit_code_signal(self)
558 ExitCodeChecks.test_exit_code_signal(self)
564
559
565 class TestModules(unittest.TestCase, tt.TempFileMixin):
560 class TestModules(unittest.TestCase, tt.TempFileMixin):
566 def test_extraneous_loads(self):
561 def test_extraneous_loads(self):
567 """Test we're not loading modules on startup that we shouldn't.
562 """Test we're not loading modules on startup that we shouldn't.
568 """
563 """
569 self.mktmp("import sys\n"
564 self.mktmp("import sys\n"
570 "print('numpy' in sys.modules)\n"
565 "print('numpy' in sys.modules)\n"
571 "print('ipyparallel' in sys.modules)\n"
566 "print('ipyparallel' in sys.modules)\n"
572 "print('ipykernel' in sys.modules)\n"
567 "print('ipykernel' in sys.modules)\n"
573 )
568 )
574 out = "False\nFalse\nFalse\n"
569 out = "False\nFalse\nFalse\n"
575 tt.ipexec_validate(self.fname, out)
570 tt.ipexec_validate(self.fname, out)
576
571
577 class Negator(ast.NodeTransformer):
572 class Negator(ast.NodeTransformer):
578 """Negates all number literals in an AST."""
573 """Negates all number literals in an AST."""
579 def visit_Num(self, node):
574 def visit_Num(self, node):
580 node.n = -node.n
575 node.n = -node.n
581 return node
576 return node
582
577
583 class TestAstTransform(unittest.TestCase):
578 class TestAstTransform(unittest.TestCase):
584 def setUp(self):
579 def setUp(self):
585 self.negator = Negator()
580 self.negator = Negator()
586 ip.ast_transformers.append(self.negator)
581 ip.ast_transformers.append(self.negator)
587
582
588 def tearDown(self):
583 def tearDown(self):
589 ip.ast_transformers.remove(self.negator)
584 ip.ast_transformers.remove(self.negator)
590
585
591 def test_run_cell(self):
586 def test_run_cell(self):
592 with tt.AssertPrints('-34'):
587 with tt.AssertPrints('-34'):
593 ip.run_cell('print (12 + 22)')
588 ip.run_cell('print (12 + 22)')
594
589
595 # A named reference to a number shouldn't be transformed.
590 # A named reference to a number shouldn't be transformed.
596 ip.user_ns['n'] = 55
591 ip.user_ns['n'] = 55
597 with tt.AssertNotPrints('-55'):
592 with tt.AssertNotPrints('-55'):
598 ip.run_cell('print (n)')
593 ip.run_cell('print (n)')
599
594
600 def test_timeit(self):
595 def test_timeit(self):
601 called = set()
596 called = set()
602 def f(x):
597 def f(x):
603 called.add(x)
598 called.add(x)
604 ip.push({'f':f})
599 ip.push({'f':f})
605
600
606 with tt.AssertPrints("average of "):
601 with tt.AssertPrints("average of "):
607 ip.run_line_magic("timeit", "-n1 f(1)")
602 ip.run_line_magic("timeit", "-n1 f(1)")
608 self.assertEqual(called, {-1})
603 self.assertEqual(called, {-1})
609 called.clear()
604 called.clear()
610
605
611 with tt.AssertPrints("average of "):
606 with tt.AssertPrints("average of "):
612 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
607 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
613 self.assertEqual(called, {-2, -3})
608 self.assertEqual(called, {-2, -3})
614
609
615 def test_time(self):
610 def test_time(self):
616 called = []
611 called = []
617 def f(x):
612 def f(x):
618 called.append(x)
613 called.append(x)
619 ip.push({'f':f})
614 ip.push({'f':f})
620
615
621 # Test with an expression
616 # Test with an expression
622 with tt.AssertPrints("Wall time: "):
617 with tt.AssertPrints("Wall time: "):
623 ip.run_line_magic("time", "f(5+9)")
618 ip.run_line_magic("time", "f(5+9)")
624 self.assertEqual(called, [-14])
619 self.assertEqual(called, [-14])
625 called[:] = []
620 called[:] = []
626
621
627 # Test with a statement (different code path)
622 # Test with a statement (different code path)
628 with tt.AssertPrints("Wall time: "):
623 with tt.AssertPrints("Wall time: "):
629 ip.run_line_magic("time", "a = f(-3 + -2)")
624 ip.run_line_magic("time", "a = f(-3 + -2)")
630 self.assertEqual(called, [5])
625 self.assertEqual(called, [5])
631
626
632 def test_macro(self):
627 def test_macro(self):
633 ip.push({'a':10})
628 ip.push({'a':10})
634 # The AST transformation makes this do a+=-1
629 # The AST transformation makes this do a+=-1
635 ip.define_macro("amacro", "a+=1\nprint(a)")
630 ip.define_macro("amacro", "a+=1\nprint(a)")
636
631
637 with tt.AssertPrints("9"):
632 with tt.AssertPrints("9"):
638 ip.run_cell("amacro")
633 ip.run_cell("amacro")
639 with tt.AssertPrints("8"):
634 with tt.AssertPrints("8"):
640 ip.run_cell("amacro")
635 ip.run_cell("amacro")
641
636
642 class IntegerWrapper(ast.NodeTransformer):
637 class IntegerWrapper(ast.NodeTransformer):
643 """Wraps all integers in a call to Integer()"""
638 """Wraps all integers in a call to Integer()"""
644 def visit_Num(self, node):
639 def visit_Num(self, node):
645 if isinstance(node.n, int):
640 if isinstance(node.n, int):
646 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
641 return ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
647 args=[node], keywords=[])
642 args=[node], keywords=[])
648 return node
643 return node
649
644
650 class TestAstTransform2(unittest.TestCase):
645 class TestAstTransform2(unittest.TestCase):
651 def setUp(self):
646 def setUp(self):
652 self.intwrapper = IntegerWrapper()
647 self.intwrapper = IntegerWrapper()
653 ip.ast_transformers.append(self.intwrapper)
648 ip.ast_transformers.append(self.intwrapper)
654
649
655 self.calls = []
650 self.calls = []
656 def Integer(*args):
651 def Integer(*args):
657 self.calls.append(args)
652 self.calls.append(args)
658 return args
653 return args
659 ip.push({"Integer": Integer})
654 ip.push({"Integer": Integer})
660
655
661 def tearDown(self):
656 def tearDown(self):
662 ip.ast_transformers.remove(self.intwrapper)
657 ip.ast_transformers.remove(self.intwrapper)
663 del ip.user_ns['Integer']
658 del ip.user_ns['Integer']
664
659
665 def test_run_cell(self):
660 def test_run_cell(self):
666 ip.run_cell("n = 2")
661 ip.run_cell("n = 2")
667 self.assertEqual(self.calls, [(2,)])
662 self.assertEqual(self.calls, [(2,)])
668
663
669 # This shouldn't throw an error
664 # This shouldn't throw an error
670 ip.run_cell("o = 2.0")
665 ip.run_cell("o = 2.0")
671 self.assertEqual(ip.user_ns['o'], 2.0)
666 self.assertEqual(ip.user_ns['o'], 2.0)
672
667
673 def test_timeit(self):
668 def test_timeit(self):
674 called = set()
669 called = set()
675 def f(x):
670 def f(x):
676 called.add(x)
671 called.add(x)
677 ip.push({'f':f})
672 ip.push({'f':f})
678
673
679 with tt.AssertPrints("average of "):
674 with tt.AssertPrints("average of "):
680 ip.run_line_magic("timeit", "-n1 f(1)")
675 ip.run_line_magic("timeit", "-n1 f(1)")
681 self.assertEqual(called, {(1,)})
676 self.assertEqual(called, {(1,)})
682 called.clear()
677 called.clear()
683
678
684 with tt.AssertPrints("average of "):
679 with tt.AssertPrints("average of "):
685 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
680 ip.run_cell_magic("timeit", "-n1 f(2)", "f(3)")
686 self.assertEqual(called, {(2,), (3,)})
681 self.assertEqual(called, {(2,), (3,)})
687
682
688 class ErrorTransformer(ast.NodeTransformer):
683 class ErrorTransformer(ast.NodeTransformer):
689 """Throws an error when it sees a number."""
684 """Throws an error when it sees a number."""
690 def visit_Num(self, node):
685 def visit_Num(self, node):
691 raise ValueError("test")
686 raise ValueError("test")
692
687
693 class TestAstTransformError(unittest.TestCase):
688 class TestAstTransformError(unittest.TestCase):
694 def test_unregistering(self):
689 def test_unregistering(self):
695 err_transformer = ErrorTransformer()
690 err_transformer = ErrorTransformer()
696 ip.ast_transformers.append(err_transformer)
691 ip.ast_transformers.append(err_transformer)
697
692
698 with tt.AssertPrints("unregister", channel='stderr'):
693 with tt.AssertPrints("unregister", channel='stderr'):
699 ip.run_cell("1 + 2")
694 ip.run_cell("1 + 2")
700
695
701 # This should have been removed.
696 # This should have been removed.
702 nt.assert_not_in(err_transformer, ip.ast_transformers)
697 nt.assert_not_in(err_transformer, ip.ast_transformers)
703
698
704
699
705 class StringRejector(ast.NodeTransformer):
700 class StringRejector(ast.NodeTransformer):
706 """Throws an InputRejected when it sees a string literal.
701 """Throws an InputRejected when it sees a string literal.
707
702
708 Used to verify that NodeTransformers can signal that a piece of code should
703 Used to verify that NodeTransformers can signal that a piece of code should
709 not be executed by throwing an InputRejected.
704 not be executed by throwing an InputRejected.
710 """
705 """
711
706
712 def visit_Str(self, node):
707 def visit_Str(self, node):
713 raise InputRejected("test")
708 raise InputRejected("test")
714
709
715
710
716 class TestAstTransformInputRejection(unittest.TestCase):
711 class TestAstTransformInputRejection(unittest.TestCase):
717
712
718 def setUp(self):
713 def setUp(self):
719 self.transformer = StringRejector()
714 self.transformer = StringRejector()
720 ip.ast_transformers.append(self.transformer)
715 ip.ast_transformers.append(self.transformer)
721
716
722 def tearDown(self):
717 def tearDown(self):
723 ip.ast_transformers.remove(self.transformer)
718 ip.ast_transformers.remove(self.transformer)
724
719
725 def test_input_rejection(self):
720 def test_input_rejection(self):
726 """Check that NodeTransformers can reject input."""
721 """Check that NodeTransformers can reject input."""
727
722
728 expect_exception_tb = tt.AssertPrints("InputRejected: test")
723 expect_exception_tb = tt.AssertPrints("InputRejected: test")
729 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
724 expect_no_cell_output = tt.AssertNotPrints("'unsafe'", suppress=False)
730
725
731 # Run the same check twice to verify that the transformer is not
726 # Run the same check twice to verify that the transformer is not
732 # disabled after raising.
727 # disabled after raising.
733 with expect_exception_tb, expect_no_cell_output:
728 with expect_exception_tb, expect_no_cell_output:
734 ip.run_cell("'unsafe'")
729 ip.run_cell("'unsafe'")
735
730
736 with expect_exception_tb, expect_no_cell_output:
731 with expect_exception_tb, expect_no_cell_output:
737 res = ip.run_cell("'unsafe'")
732 res = ip.run_cell("'unsafe'")
738
733
739 self.assertIsInstance(res.error_before_exec, InputRejected)
734 self.assertIsInstance(res.error_before_exec, InputRejected)
740
735
741 def test__IPYTHON__():
736 def test__IPYTHON__():
742 # This shouldn't raise a NameError, that's all
737 # This shouldn't raise a NameError, that's all
743 __IPYTHON__
738 __IPYTHON__
744
739
745
740
746 class DummyRepr(object):
741 class DummyRepr(object):
747 def __repr__(self):
742 def __repr__(self):
748 return "DummyRepr"
743 return "DummyRepr"
749
744
750 def _repr_html_(self):
745 def _repr_html_(self):
751 return "<b>dummy</b>"
746 return "<b>dummy</b>"
752
747
753 def _repr_javascript_(self):
748 def _repr_javascript_(self):
754 return "console.log('hi');", {'key': 'value'}
749 return "console.log('hi');", {'key': 'value'}
755
750
756
751
757 def test_user_variables():
752 def test_user_variables():
758 # enable all formatters
753 # enable all formatters
759 ip.display_formatter.active_types = ip.display_formatter.format_types
754 ip.display_formatter.active_types = ip.display_formatter.format_types
760
755
761 ip.user_ns['dummy'] = d = DummyRepr()
756 ip.user_ns['dummy'] = d = DummyRepr()
762 keys = {'dummy', 'doesnotexist'}
757 keys = {'dummy', 'doesnotexist'}
763 r = ip.user_expressions({ key:key for key in keys})
758 r = ip.user_expressions({ key:key for key in keys})
764
759
765 nt.assert_equal(keys, set(r.keys()))
760 nt.assert_equal(keys, set(r.keys()))
766 dummy = r['dummy']
761 dummy = r['dummy']
767 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
762 nt.assert_equal({'status', 'data', 'metadata'}, set(dummy.keys()))
768 nt.assert_equal(dummy['status'], 'ok')
763 nt.assert_equal(dummy['status'], 'ok')
769 data = dummy['data']
764 data = dummy['data']
770 metadata = dummy['metadata']
765 metadata = dummy['metadata']
771 nt.assert_equal(data.get('text/html'), d._repr_html_())
766 nt.assert_equal(data.get('text/html'), d._repr_html_())
772 js, jsmd = d._repr_javascript_()
767 js, jsmd = d._repr_javascript_()
773 nt.assert_equal(data.get('application/javascript'), js)
768 nt.assert_equal(data.get('application/javascript'), js)
774 nt.assert_equal(metadata.get('application/javascript'), jsmd)
769 nt.assert_equal(metadata.get('application/javascript'), jsmd)
775
770
776 dne = r['doesnotexist']
771 dne = r['doesnotexist']
777 nt.assert_equal(dne['status'], 'error')
772 nt.assert_equal(dne['status'], 'error')
778 nt.assert_equal(dne['ename'], 'NameError')
773 nt.assert_equal(dne['ename'], 'NameError')
779
774
780 # back to text only
775 # back to text only
781 ip.display_formatter.active_types = ['text/plain']
776 ip.display_formatter.active_types = ['text/plain']
782
777
783 def test_user_expression():
778 def test_user_expression():
784 # enable all formatters
779 # enable all formatters
785 ip.display_formatter.active_types = ip.display_formatter.format_types
780 ip.display_formatter.active_types = ip.display_formatter.format_types
786 query = {
781 query = {
787 'a' : '1 + 2',
782 'a' : '1 + 2',
788 'b' : '1/0',
783 'b' : '1/0',
789 }
784 }
790 r = ip.user_expressions(query)
785 r = ip.user_expressions(query)
791 import pprint
786 import pprint
792 pprint.pprint(r)
787 pprint.pprint(r)
793 nt.assert_equal(set(r.keys()), set(query.keys()))
788 nt.assert_equal(set(r.keys()), set(query.keys()))
794 a = r['a']
789 a = r['a']
795 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
790 nt.assert_equal({'status', 'data', 'metadata'}, set(a.keys()))
796 nt.assert_equal(a['status'], 'ok')
791 nt.assert_equal(a['status'], 'ok')
797 data = a['data']
792 data = a['data']
798 metadata = a['metadata']
793 metadata = a['metadata']
799 nt.assert_equal(data.get('text/plain'), '3')
794 nt.assert_equal(data.get('text/plain'), '3')
800
795
801 b = r['b']
796 b = r['b']
802 nt.assert_equal(b['status'], 'error')
797 nt.assert_equal(b['status'], 'error')
803 nt.assert_equal(b['ename'], 'ZeroDivisionError')
798 nt.assert_equal(b['ename'], 'ZeroDivisionError')
804
799
805 # back to text only
800 # back to text only
806 ip.display_formatter.active_types = ['text/plain']
801 ip.display_formatter.active_types = ['text/plain']
807
802
808
803
809
804
810
805
811
806
812 class TestSyntaxErrorTransformer(unittest.TestCase):
807 class TestSyntaxErrorTransformer(unittest.TestCase):
813 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
808 """Check that SyntaxError raised by an input transformer is handled by run_cell()"""
814
809
815 class SyntaxErrorTransformer(InputTransformer):
810 class SyntaxErrorTransformer(InputTransformer):
816
811
817 def push(self, line):
812 def push(self, line):
818 pos = line.find('syntaxerror')
813 pos = line.find('syntaxerror')
819 if pos >= 0:
814 if pos >= 0:
820 e = SyntaxError('input contains "syntaxerror"')
815 e = SyntaxError('input contains "syntaxerror"')
821 e.text = line
816 e.text = line
822 e.offset = pos + 1
817 e.offset = pos + 1
823 raise e
818 raise e
824 return line
819 return line
825
820
826 def reset(self):
821 def reset(self):
827 pass
822 pass
828
823
829 def setUp(self):
824 def setUp(self):
830 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
825 self.transformer = TestSyntaxErrorTransformer.SyntaxErrorTransformer()
831 ip.input_splitter.python_line_transforms.append(self.transformer)
826 ip.input_splitter.python_line_transforms.append(self.transformer)
832 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
827 ip.input_transformer_manager.python_line_transforms.append(self.transformer)
833
828
834 def tearDown(self):
829 def tearDown(self):
835 ip.input_splitter.python_line_transforms.remove(self.transformer)
830 ip.input_splitter.python_line_transforms.remove(self.transformer)
836 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
831 ip.input_transformer_manager.python_line_transforms.remove(self.transformer)
837
832
838 def test_syntaxerror_input_transformer(self):
833 def test_syntaxerror_input_transformer(self):
839 with tt.AssertPrints('1234'):
834 with tt.AssertPrints('1234'):
840 ip.run_cell('1234')
835 ip.run_cell('1234')
841 with tt.AssertPrints('SyntaxError: invalid syntax'):
836 with tt.AssertPrints('SyntaxError: invalid syntax'):
842 ip.run_cell('1 2 3') # plain python syntax error
837 ip.run_cell('1 2 3') # plain python syntax error
843 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
838 with tt.AssertPrints('SyntaxError: input contains "syntaxerror"'):
844 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
839 ip.run_cell('2345 # syntaxerror') # input transformer syntax error
845 with tt.AssertPrints('3456'):
840 with tt.AssertPrints('3456'):
846 ip.run_cell('3456')
841 ip.run_cell('3456')
847
842
848
843
849
844
850 def test_warning_suppression():
845 def test_warning_suppression():
851 ip.run_cell("import warnings")
846 ip.run_cell("import warnings")
852 try:
847 try:
853 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
848 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
854 ip.run_cell("warnings.warn('asdf')")
849 ip.run_cell("warnings.warn('asdf')")
855 # Here's the real test -- if we run that again, we should get the
850 # Here's the real test -- if we run that again, we should get the
856 # warning again. Traditionally, each warning was only issued once per
851 # warning again. Traditionally, each warning was only issued once per
857 # IPython session (approximately), even if the user typed in new and
852 # IPython session (approximately), even if the user typed in new and
858 # different code that should have also triggered the warning, leading
853 # different code that should have also triggered the warning, leading
859 # to much confusion.
854 # to much confusion.
860 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
855 with tt.AssertPrints("UserWarning: asdf", channel="stderr"):
861 ip.run_cell("warnings.warn('asdf')")
856 ip.run_cell("warnings.warn('asdf')")
862 finally:
857 finally:
863 ip.run_cell("del warnings")
858 ip.run_cell("del warnings")
864
859
865
860
866 def test_deprecation_warning():
861 def test_deprecation_warning():
867 ip.run_cell("""
862 ip.run_cell("""
868 import warnings
863 import warnings
869 def wrn():
864 def wrn():
870 warnings.warn(
865 warnings.warn(
871 "I AM A WARNING",
866 "I AM A WARNING",
872 DeprecationWarning
867 DeprecationWarning
873 )
868 )
874 """)
869 """)
875 try:
870 try:
876 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
871 with tt.AssertPrints("I AM A WARNING", channel="stderr"):
877 ip.run_cell("wrn()")
872 ip.run_cell("wrn()")
878 finally:
873 finally:
879 ip.run_cell("del warnings")
874 ip.run_cell("del warnings")
880 ip.run_cell("del wrn")
875 ip.run_cell("del wrn")
881
876
882
877
883 class TestImportNoDeprecate(tt.TempFileMixin):
878 class TestImportNoDeprecate(tt.TempFileMixin):
884
879
885 def setup(self):
880 def setup(self):
886 """Make a valid python temp file."""
881 """Make a valid python temp file."""
887 self.mktmp("""
882 self.mktmp("""
888 import warnings
883 import warnings
889 def wrn():
884 def wrn():
890 warnings.warn(
885 warnings.warn(
891 "I AM A WARNING",
886 "I AM A WARNING",
892 DeprecationWarning
887 DeprecationWarning
893 )
888 )
894 """)
889 """)
895
890
896 def test_no_dep(self):
891 def test_no_dep(self):
897 """
892 """
898 No deprecation warning should be raised from imported functions
893 No deprecation warning should be raised from imported functions
899 """
894 """
900 ip.run_cell("from {} import wrn".format(self.fname))
895 ip.run_cell("from {} import wrn".format(self.fname))
901
896
902 with tt.AssertNotPrints("I AM A WARNING"):
897 with tt.AssertNotPrints("I AM A WARNING"):
903 ip.run_cell("wrn()")
898 ip.run_cell("wrn()")
904 ip.run_cell("del wrn")
899 ip.run_cell("del wrn")
@@ -1,987 +1,979 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for various magic functions.
2 """Tests for various magic functions.
3
3
4 Needs to be run by nose (to make ipython session available).
4 Needs to be run by nose (to make ipython session available).
5 """
5 """
6
6
7 import io
7 import io
8 import os
8 import os
9 import sys
9 import sys
10 import warnings
10 import warnings
11 from unittest import TestCase
11 from unittest import TestCase
12
12 from importlib import invalidate_caches
13 try:
13 from io import StringIO
14 from importlib import invalidate_caches # Required from Python 3.3
15 except ImportError:
16 def invalidate_caches():
17 pass
18
14
19 import nose.tools as nt
15 import nose.tools as nt
20
16
21 from IPython import get_ipython
17 from IPython import get_ipython
22 from IPython.core import magic
18 from IPython.core import magic
23 from IPython.core.error import UsageError
19 from IPython.core.error import UsageError
24 from IPython.core.magic import (Magics, magics_class, line_magic,
20 from IPython.core.magic import (Magics, magics_class, line_magic,
25 cell_magic,
21 cell_magic,
26 register_line_magic, register_cell_magic)
22 register_line_magic, register_cell_magic)
27 from IPython.core.magics import execution, script, code
23 from IPython.core.magics import execution, script, code
28 from IPython.testing import decorators as dec
24 from IPython.testing import decorators as dec
29 from IPython.testing import tools as tt
25 from IPython.testing import tools as tt
30 from IPython.utils import py3compat
26 from IPython.utils import py3compat
31 from IPython.utils.io import capture_output
27 from IPython.utils.io import capture_output
32 from IPython.utils.tempdir import TemporaryDirectory
28 from IPython.utils.tempdir import TemporaryDirectory
33 from IPython.utils.process import find_cmd
29 from IPython.utils.process import find_cmd
34
30
35 if py3compat.PY3:
36 from io import StringIO
37 else:
38 from StringIO import StringIO
39
31
40
32
41 _ip = get_ipython()
33 _ip = get_ipython()
42
34
43 @magic.magics_class
35 @magic.magics_class
44 class DummyMagics(magic.Magics): pass
36 class DummyMagics(magic.Magics): pass
45
37
46 def test_extract_code_ranges():
38 def test_extract_code_ranges():
47 instr = "1 3 5-6 7-9 10:15 17: :10 10- -13 :"
39 instr = "1 3 5-6 7-9 10:15 17: :10 10- -13 :"
48 expected = [(0, 1),
40 expected = [(0, 1),
49 (2, 3),
41 (2, 3),
50 (4, 6),
42 (4, 6),
51 (6, 9),
43 (6, 9),
52 (9, 14),
44 (9, 14),
53 (16, None),
45 (16, None),
54 (None, 9),
46 (None, 9),
55 (9, None),
47 (9, None),
56 (None, 13),
48 (None, 13),
57 (None, None)]
49 (None, None)]
58 actual = list(code.extract_code_ranges(instr))
50 actual = list(code.extract_code_ranges(instr))
59 nt.assert_equal(actual, expected)
51 nt.assert_equal(actual, expected)
60
52
61 def test_extract_symbols():
53 def test_extract_symbols():
62 source = """import foo\na = 10\ndef b():\n return 42\n\n\nclass A: pass\n\n\n"""
54 source = """import foo\na = 10\ndef b():\n return 42\n\n\nclass A: pass\n\n\n"""
63 symbols_args = ["a", "b", "A", "A,b", "A,a", "z"]
55 symbols_args = ["a", "b", "A", "A,b", "A,a", "z"]
64 expected = [([], ['a']),
56 expected = [([], ['a']),
65 (["def b():\n return 42\n"], []),
57 (["def b():\n return 42\n"], []),
66 (["class A: pass\n"], []),
58 (["class A: pass\n"], []),
67 (["class A: pass\n", "def b():\n return 42\n"], []),
59 (["class A: pass\n", "def b():\n return 42\n"], []),
68 (["class A: pass\n"], ['a']),
60 (["class A: pass\n"], ['a']),
69 ([], ['z'])]
61 ([], ['z'])]
70 for symbols, exp in zip(symbols_args, expected):
62 for symbols, exp in zip(symbols_args, expected):
71 nt.assert_equal(code.extract_symbols(source, symbols), exp)
63 nt.assert_equal(code.extract_symbols(source, symbols), exp)
72
64
73
65
74 def test_extract_symbols_raises_exception_with_non_python_code():
66 def test_extract_symbols_raises_exception_with_non_python_code():
75 source = ("=begin A Ruby program :)=end\n"
67 source = ("=begin A Ruby program :)=end\n"
76 "def hello\n"
68 "def hello\n"
77 "puts 'Hello world'\n"
69 "puts 'Hello world'\n"
78 "end")
70 "end")
79 with nt.assert_raises(SyntaxError):
71 with nt.assert_raises(SyntaxError):
80 code.extract_symbols(source, "hello")
72 code.extract_symbols(source, "hello")
81
73
82 def test_config():
74 def test_config():
83 """ test that config magic does not raise
75 """ test that config magic does not raise
84 can happen if Configurable init is moved too early into
76 can happen if Configurable init is moved too early into
85 Magics.__init__ as then a Config object will be registerd as a
77 Magics.__init__ as then a Config object will be registerd as a
86 magic.
78 magic.
87 """
79 """
88 ## should not raise.
80 ## should not raise.
89 _ip.magic('config')
81 _ip.magic('config')
90
82
91 def test_rehashx():
83 def test_rehashx():
92 # clear up everything
84 # clear up everything
93 _ip.alias_manager.clear_aliases()
85 _ip.alias_manager.clear_aliases()
94 del _ip.db['syscmdlist']
86 del _ip.db['syscmdlist']
95
87
96 _ip.magic('rehashx')
88 _ip.magic('rehashx')
97 # Practically ALL ipython development systems will have more than 10 aliases
89 # Practically ALL ipython development systems will have more than 10 aliases
98
90
99 nt.assert_true(len(_ip.alias_manager.aliases) > 10)
91 nt.assert_true(len(_ip.alias_manager.aliases) > 10)
100 for name, cmd in _ip.alias_manager.aliases:
92 for name, cmd in _ip.alias_manager.aliases:
101 # we must strip dots from alias names
93 # we must strip dots from alias names
102 nt.assert_not_in('.', name)
94 nt.assert_not_in('.', name)
103
95
104 # rehashx must fill up syscmdlist
96 # rehashx must fill up syscmdlist
105 scoms = _ip.db['syscmdlist']
97 scoms = _ip.db['syscmdlist']
106 nt.assert_true(len(scoms) > 10)
98 nt.assert_true(len(scoms) > 10)
107
99
108
100
109 def test_magic_parse_options():
101 def test_magic_parse_options():
110 """Test that we don't mangle paths when parsing magic options."""
102 """Test that we don't mangle paths when parsing magic options."""
111 ip = get_ipython()
103 ip = get_ipython()
112 path = 'c:\\x'
104 path = 'c:\\x'
113 m = DummyMagics(ip)
105 m = DummyMagics(ip)
114 opts = m.parse_options('-f %s' % path,'f:')[0]
106 opts = m.parse_options('-f %s' % path,'f:')[0]
115 # argv splitting is os-dependent
107 # argv splitting is os-dependent
116 if os.name == 'posix':
108 if os.name == 'posix':
117 expected = 'c:x'
109 expected = 'c:x'
118 else:
110 else:
119 expected = path
111 expected = path
120 nt.assert_equal(opts['f'], expected)
112 nt.assert_equal(opts['f'], expected)
121
113
122 def test_magic_parse_long_options():
114 def test_magic_parse_long_options():
123 """Magic.parse_options can handle --foo=bar long options"""
115 """Magic.parse_options can handle --foo=bar long options"""
124 ip = get_ipython()
116 ip = get_ipython()
125 m = DummyMagics(ip)
117 m = DummyMagics(ip)
126 opts, _ = m.parse_options('--foo --bar=bubble', 'a', 'foo', 'bar=')
118 opts, _ = m.parse_options('--foo --bar=bubble', 'a', 'foo', 'bar=')
127 nt.assert_in('foo', opts)
119 nt.assert_in('foo', opts)
128 nt.assert_in('bar', opts)
120 nt.assert_in('bar', opts)
129 nt.assert_equal(opts['bar'], "bubble")
121 nt.assert_equal(opts['bar'], "bubble")
130
122
131
123
132 @dec.skip_without('sqlite3')
124 @dec.skip_without('sqlite3')
133 def doctest_hist_f():
125 def doctest_hist_f():
134 """Test %hist -f with temporary filename.
126 """Test %hist -f with temporary filename.
135
127
136 In [9]: import tempfile
128 In [9]: import tempfile
137
129
138 In [10]: tfile = tempfile.mktemp('.py','tmp-ipython-')
130 In [10]: tfile = tempfile.mktemp('.py','tmp-ipython-')
139
131
140 In [11]: %hist -nl -f $tfile 3
132 In [11]: %hist -nl -f $tfile 3
141
133
142 In [13]: import os; os.unlink(tfile)
134 In [13]: import os; os.unlink(tfile)
143 """
135 """
144
136
145
137
146 @dec.skip_without('sqlite3')
138 @dec.skip_without('sqlite3')
147 def doctest_hist_r():
139 def doctest_hist_r():
148 """Test %hist -r
140 """Test %hist -r
149
141
150 XXX - This test is not recording the output correctly. For some reason, in
142 XXX - This test is not recording the output correctly. For some reason, in
151 testing mode the raw history isn't getting populated. No idea why.
143 testing mode the raw history isn't getting populated. No idea why.
152 Disabling the output checking for now, though at least we do run it.
144 Disabling the output checking for now, though at least we do run it.
153
145
154 In [1]: 'hist' in _ip.lsmagic()
146 In [1]: 'hist' in _ip.lsmagic()
155 Out[1]: True
147 Out[1]: True
156
148
157 In [2]: x=1
149 In [2]: x=1
158
150
159 In [3]: %hist -rl 2
151 In [3]: %hist -rl 2
160 x=1 # random
152 x=1 # random
161 %hist -r 2
153 %hist -r 2
162 """
154 """
163
155
164
156
165 @dec.skip_without('sqlite3')
157 @dec.skip_without('sqlite3')
166 def doctest_hist_op():
158 def doctest_hist_op():
167 """Test %hist -op
159 """Test %hist -op
168
160
169 In [1]: class b(float):
161 In [1]: class b(float):
170 ...: pass
162 ...: pass
171 ...:
163 ...:
172
164
173 In [2]: class s(object):
165 In [2]: class s(object):
174 ...: def __str__(self):
166 ...: def __str__(self):
175 ...: return 's'
167 ...: return 's'
176 ...:
168 ...:
177
169
178 In [3]:
170 In [3]:
179
171
180 In [4]: class r(b):
172 In [4]: class r(b):
181 ...: def __repr__(self):
173 ...: def __repr__(self):
182 ...: return 'r'
174 ...: return 'r'
183 ...:
175 ...:
184
176
185 In [5]: class sr(s,r): pass
177 In [5]: class sr(s,r): pass
186 ...:
178 ...:
187
179
188 In [6]:
180 In [6]:
189
181
190 In [7]: bb=b()
182 In [7]: bb=b()
191
183
192 In [8]: ss=s()
184 In [8]: ss=s()
193
185
194 In [9]: rr=r()
186 In [9]: rr=r()
195
187
196 In [10]: ssrr=sr()
188 In [10]: ssrr=sr()
197
189
198 In [11]: 4.5
190 In [11]: 4.5
199 Out[11]: 4.5
191 Out[11]: 4.5
200
192
201 In [12]: str(ss)
193 In [12]: str(ss)
202 Out[12]: 's'
194 Out[12]: 's'
203
195
204 In [13]:
196 In [13]:
205
197
206 In [14]: %hist -op
198 In [14]: %hist -op
207 >>> class b:
199 >>> class b:
208 ... pass
200 ... pass
209 ...
201 ...
210 >>> class s(b):
202 >>> class s(b):
211 ... def __str__(self):
203 ... def __str__(self):
212 ... return 's'
204 ... return 's'
213 ...
205 ...
214 >>>
206 >>>
215 >>> class r(b):
207 >>> class r(b):
216 ... def __repr__(self):
208 ... def __repr__(self):
217 ... return 'r'
209 ... return 'r'
218 ...
210 ...
219 >>> class sr(s,r): pass
211 >>> class sr(s,r): pass
220 >>>
212 >>>
221 >>> bb=b()
213 >>> bb=b()
222 >>> ss=s()
214 >>> ss=s()
223 >>> rr=r()
215 >>> rr=r()
224 >>> ssrr=sr()
216 >>> ssrr=sr()
225 >>> 4.5
217 >>> 4.5
226 4.5
218 4.5
227 >>> str(ss)
219 >>> str(ss)
228 's'
220 's'
229 >>>
221 >>>
230 """
222 """
231
223
232 def test_hist_pof():
224 def test_hist_pof():
233 ip = get_ipython()
225 ip = get_ipython()
234 ip.run_cell(u"1+2", store_history=True)
226 ip.run_cell(u"1+2", store_history=True)
235 #raise Exception(ip.history_manager.session_number)
227 #raise Exception(ip.history_manager.session_number)
236 #raise Exception(list(ip.history_manager._get_range_session()))
228 #raise Exception(list(ip.history_manager._get_range_session()))
237 with TemporaryDirectory() as td:
229 with TemporaryDirectory() as td:
238 tf = os.path.join(td, 'hist.py')
230 tf = os.path.join(td, 'hist.py')
239 ip.run_line_magic('history', '-pof %s' % tf)
231 ip.run_line_magic('history', '-pof %s' % tf)
240 assert os.path.isfile(tf)
232 assert os.path.isfile(tf)
241
233
242
234
243 @dec.skip_without('sqlite3')
235 @dec.skip_without('sqlite3')
244 def test_macro():
236 def test_macro():
245 ip = get_ipython()
237 ip = get_ipython()
246 ip.history_manager.reset() # Clear any existing history.
238 ip.history_manager.reset() # Clear any existing history.
247 cmds = ["a=1", "def b():\n return a**2", "print(a,b())"]
239 cmds = ["a=1", "def b():\n return a**2", "print(a,b())"]
248 for i, cmd in enumerate(cmds, start=1):
240 for i, cmd in enumerate(cmds, start=1):
249 ip.history_manager.store_inputs(i, cmd)
241 ip.history_manager.store_inputs(i, cmd)
250 ip.magic("macro test 1-3")
242 ip.magic("macro test 1-3")
251 nt.assert_equal(ip.user_ns["test"].value, "\n".join(cmds)+"\n")
243 nt.assert_equal(ip.user_ns["test"].value, "\n".join(cmds)+"\n")
252
244
253 # List macros
245 # List macros
254 nt.assert_in("test", ip.magic("macro"))
246 nt.assert_in("test", ip.magic("macro"))
255
247
256
248
257 @dec.skip_without('sqlite3')
249 @dec.skip_without('sqlite3')
258 def test_macro_run():
250 def test_macro_run():
259 """Test that we can run a multi-line macro successfully."""
251 """Test that we can run a multi-line macro successfully."""
260 ip = get_ipython()
252 ip = get_ipython()
261 ip.history_manager.reset()
253 ip.history_manager.reset()
262 cmds = ["a=10", "a+=1", py3compat.doctest_refactor_print("print a"),
254 cmds = ["a=10", "a+=1", py3compat.doctest_refactor_print("print a"),
263 "%macro test 2-3"]
255 "%macro test 2-3"]
264 for cmd in cmds:
256 for cmd in cmds:
265 ip.run_cell(cmd, store_history=True)
257 ip.run_cell(cmd, store_history=True)
266 nt.assert_equal(ip.user_ns["test"].value,
258 nt.assert_equal(ip.user_ns["test"].value,
267 py3compat.doctest_refactor_print("a+=1\nprint a\n"))
259 py3compat.doctest_refactor_print("a+=1\nprint a\n"))
268 with tt.AssertPrints("12"):
260 with tt.AssertPrints("12"):
269 ip.run_cell("test")
261 ip.run_cell("test")
270 with tt.AssertPrints("13"):
262 with tt.AssertPrints("13"):
271 ip.run_cell("test")
263 ip.run_cell("test")
272
264
273
265
274 def test_magic_magic():
266 def test_magic_magic():
275 """Test %magic"""
267 """Test %magic"""
276 ip = get_ipython()
268 ip = get_ipython()
277 with capture_output() as captured:
269 with capture_output() as captured:
278 ip.magic("magic")
270 ip.magic("magic")
279
271
280 stdout = captured.stdout
272 stdout = captured.stdout
281 nt.assert_in('%magic', stdout)
273 nt.assert_in('%magic', stdout)
282 nt.assert_in('IPython', stdout)
274 nt.assert_in('IPython', stdout)
283 nt.assert_in('Available', stdout)
275 nt.assert_in('Available', stdout)
284
276
285
277
286 @dec.skipif_not_numpy
278 @dec.skipif_not_numpy
287 def test_numpy_reset_array_undec():
279 def test_numpy_reset_array_undec():
288 "Test '%reset array' functionality"
280 "Test '%reset array' functionality"
289 _ip.ex('import numpy as np')
281 _ip.ex('import numpy as np')
290 _ip.ex('a = np.empty(2)')
282 _ip.ex('a = np.empty(2)')
291 nt.assert_in('a', _ip.user_ns)
283 nt.assert_in('a', _ip.user_ns)
292 _ip.magic('reset -f array')
284 _ip.magic('reset -f array')
293 nt.assert_not_in('a', _ip.user_ns)
285 nt.assert_not_in('a', _ip.user_ns)
294
286
295 def test_reset_out():
287 def test_reset_out():
296 "Test '%reset out' magic"
288 "Test '%reset out' magic"
297 _ip.run_cell("parrot = 'dead'", store_history=True)
289 _ip.run_cell("parrot = 'dead'", store_history=True)
298 # test '%reset -f out', make an Out prompt
290 # test '%reset -f out', make an Out prompt
299 _ip.run_cell("parrot", store_history=True)
291 _ip.run_cell("parrot", store_history=True)
300 nt.assert_true('dead' in [_ip.user_ns[x] for x in ('_','__','___')])
292 nt.assert_true('dead' in [_ip.user_ns[x] for x in ('_','__','___')])
301 _ip.magic('reset -f out')
293 _ip.magic('reset -f out')
302 nt.assert_false('dead' in [_ip.user_ns[x] for x in ('_','__','___')])
294 nt.assert_false('dead' in [_ip.user_ns[x] for x in ('_','__','___')])
303 nt.assert_equal(len(_ip.user_ns['Out']), 0)
295 nt.assert_equal(len(_ip.user_ns['Out']), 0)
304
296
305 def test_reset_in():
297 def test_reset_in():
306 "Test '%reset in' magic"
298 "Test '%reset in' magic"
307 # test '%reset -f in'
299 # test '%reset -f in'
308 _ip.run_cell("parrot", store_history=True)
300 _ip.run_cell("parrot", store_history=True)
309 nt.assert_true('parrot' in [_ip.user_ns[x] for x in ('_i','_ii','_iii')])
301 nt.assert_true('parrot' in [_ip.user_ns[x] for x in ('_i','_ii','_iii')])
310 _ip.magic('%reset -f in')
302 _ip.magic('%reset -f in')
311 nt.assert_false('parrot' in [_ip.user_ns[x] for x in ('_i','_ii','_iii')])
303 nt.assert_false('parrot' in [_ip.user_ns[x] for x in ('_i','_ii','_iii')])
312 nt.assert_equal(len(set(_ip.user_ns['In'])), 1)
304 nt.assert_equal(len(set(_ip.user_ns['In'])), 1)
313
305
314 def test_reset_dhist():
306 def test_reset_dhist():
315 "Test '%reset dhist' magic"
307 "Test '%reset dhist' magic"
316 _ip.run_cell("tmp = [d for d in _dh]") # copy before clearing
308 _ip.run_cell("tmp = [d for d in _dh]") # copy before clearing
317 _ip.magic('cd ' + os.path.dirname(nt.__file__))
309 _ip.magic('cd ' + os.path.dirname(nt.__file__))
318 _ip.magic('cd -')
310 _ip.magic('cd -')
319 nt.assert_true(len(_ip.user_ns['_dh']) > 0)
311 nt.assert_true(len(_ip.user_ns['_dh']) > 0)
320 _ip.magic('reset -f dhist')
312 _ip.magic('reset -f dhist')
321 nt.assert_equal(len(_ip.user_ns['_dh']), 0)
313 nt.assert_equal(len(_ip.user_ns['_dh']), 0)
322 _ip.run_cell("_dh = [d for d in tmp]") #restore
314 _ip.run_cell("_dh = [d for d in tmp]") #restore
323
315
324 def test_reset_in_length():
316 def test_reset_in_length():
325 "Test that '%reset in' preserves In[] length"
317 "Test that '%reset in' preserves In[] length"
326 _ip.run_cell("print 'foo'")
318 _ip.run_cell("print 'foo'")
327 _ip.run_cell("reset -f in")
319 _ip.run_cell("reset -f in")
328 nt.assert_equal(len(_ip.user_ns['In']), _ip.displayhook.prompt_count+1)
320 nt.assert_equal(len(_ip.user_ns['In']), _ip.displayhook.prompt_count+1)
329
321
330 def test_tb_syntaxerror():
322 def test_tb_syntaxerror():
331 """test %tb after a SyntaxError"""
323 """test %tb after a SyntaxError"""
332 ip = get_ipython()
324 ip = get_ipython()
333 ip.run_cell("for")
325 ip.run_cell("for")
334
326
335 # trap and validate stdout
327 # trap and validate stdout
336 save_stdout = sys.stdout
328 save_stdout = sys.stdout
337 try:
329 try:
338 sys.stdout = StringIO()
330 sys.stdout = StringIO()
339 ip.run_cell("%tb")
331 ip.run_cell("%tb")
340 out = sys.stdout.getvalue()
332 out = sys.stdout.getvalue()
341 finally:
333 finally:
342 sys.stdout = save_stdout
334 sys.stdout = save_stdout
343 # trim output, and only check the last line
335 # trim output, and only check the last line
344 last_line = out.rstrip().splitlines()[-1].strip()
336 last_line = out.rstrip().splitlines()[-1].strip()
345 nt.assert_equal(last_line, "SyntaxError: invalid syntax")
337 nt.assert_equal(last_line, "SyntaxError: invalid syntax")
346
338
347
339
348 def test_time():
340 def test_time():
349 ip = get_ipython()
341 ip = get_ipython()
350
342
351 with tt.AssertPrints("Wall time: "):
343 with tt.AssertPrints("Wall time: "):
352 ip.run_cell("%time None")
344 ip.run_cell("%time None")
353
345
354 ip.run_cell("def f(kmjy):\n"
346 ip.run_cell("def f(kmjy):\n"
355 " %time print (2*kmjy)")
347 " %time print (2*kmjy)")
356
348
357 with tt.AssertPrints("Wall time: "):
349 with tt.AssertPrints("Wall time: "):
358 with tt.AssertPrints("hihi", suppress=False):
350 with tt.AssertPrints("hihi", suppress=False):
359 ip.run_cell("f('hi')")
351 ip.run_cell("f('hi')")
360
352
361
353
362 @dec.skip_win32
354 @dec.skip_win32
363 def test_time2():
355 def test_time2():
364 ip = get_ipython()
356 ip = get_ipython()
365
357
366 with tt.AssertPrints("CPU times: user "):
358 with tt.AssertPrints("CPU times: user "):
367 ip.run_cell("%time None")
359 ip.run_cell("%time None")
368
360
369 def test_time3():
361 def test_time3():
370 """Erroneous magic function calls, issue gh-3334"""
362 """Erroneous magic function calls, issue gh-3334"""
371 ip = get_ipython()
363 ip = get_ipython()
372 ip.user_ns.pop('run', None)
364 ip.user_ns.pop('run', None)
373
365
374 with tt.AssertNotPrints("not found", channel='stderr'):
366 with tt.AssertNotPrints("not found", channel='stderr'):
375 ip.run_cell("%%time\n"
367 ip.run_cell("%%time\n"
376 "run = 0\n"
368 "run = 0\n"
377 "run += 1")
369 "run += 1")
378
370
379 def test_doctest_mode():
371 def test_doctest_mode():
380 "Toggle doctest_mode twice, it should be a no-op and run without error"
372 "Toggle doctest_mode twice, it should be a no-op and run without error"
381 _ip.magic('doctest_mode')
373 _ip.magic('doctest_mode')
382 _ip.magic('doctest_mode')
374 _ip.magic('doctest_mode')
383
375
384
376
385 def test_parse_options():
377 def test_parse_options():
386 """Tests for basic options parsing in magics."""
378 """Tests for basic options parsing in magics."""
387 # These are only the most minimal of tests, more should be added later. At
379 # These are only the most minimal of tests, more should be added later. At
388 # the very least we check that basic text/unicode calls work OK.
380 # the very least we check that basic text/unicode calls work OK.
389 m = DummyMagics(_ip)
381 m = DummyMagics(_ip)
390 nt.assert_equal(m.parse_options('foo', '')[1], 'foo')
382 nt.assert_equal(m.parse_options('foo', '')[1], 'foo')
391 nt.assert_equal(m.parse_options(u'foo', '')[1], u'foo')
383 nt.assert_equal(m.parse_options(u'foo', '')[1], u'foo')
392
384
393
385
394 def test_dirops():
386 def test_dirops():
395 """Test various directory handling operations."""
387 """Test various directory handling operations."""
396 # curpath = lambda :os.path.splitdrive(os.getcwd())[1].replace('\\','/')
388 # curpath = lambda :os.path.splitdrive(os.getcwd())[1].replace('\\','/')
397 curpath = os.getcwd
389 curpath = os.getcwd
398 startdir = os.getcwd()
390 startdir = os.getcwd()
399 ipdir = os.path.realpath(_ip.ipython_dir)
391 ipdir = os.path.realpath(_ip.ipython_dir)
400 try:
392 try:
401 _ip.magic('cd "%s"' % ipdir)
393 _ip.magic('cd "%s"' % ipdir)
402 nt.assert_equal(curpath(), ipdir)
394 nt.assert_equal(curpath(), ipdir)
403 _ip.magic('cd -')
395 _ip.magic('cd -')
404 nt.assert_equal(curpath(), startdir)
396 nt.assert_equal(curpath(), startdir)
405 _ip.magic('pushd "%s"' % ipdir)
397 _ip.magic('pushd "%s"' % ipdir)
406 nt.assert_equal(curpath(), ipdir)
398 nt.assert_equal(curpath(), ipdir)
407 _ip.magic('popd')
399 _ip.magic('popd')
408 nt.assert_equal(curpath(), startdir)
400 nt.assert_equal(curpath(), startdir)
409 finally:
401 finally:
410 os.chdir(startdir)
402 os.chdir(startdir)
411
403
412
404
413 def test_xmode():
405 def test_xmode():
414 # Calling xmode three times should be a no-op
406 # Calling xmode three times should be a no-op
415 xmode = _ip.InteractiveTB.mode
407 xmode = _ip.InteractiveTB.mode
416 for i in range(3):
408 for i in range(3):
417 _ip.magic("xmode")
409 _ip.magic("xmode")
418 nt.assert_equal(_ip.InteractiveTB.mode, xmode)
410 nt.assert_equal(_ip.InteractiveTB.mode, xmode)
419
411
420 def test_reset_hard():
412 def test_reset_hard():
421 monitor = []
413 monitor = []
422 class A(object):
414 class A(object):
423 def __del__(self):
415 def __del__(self):
424 monitor.append(1)
416 monitor.append(1)
425 def __repr__(self):
417 def __repr__(self):
426 return "<A instance>"
418 return "<A instance>"
427
419
428 _ip.user_ns["a"] = A()
420 _ip.user_ns["a"] = A()
429 _ip.run_cell("a")
421 _ip.run_cell("a")
430
422
431 nt.assert_equal(monitor, [])
423 nt.assert_equal(monitor, [])
432 _ip.magic("reset -f")
424 _ip.magic("reset -f")
433 nt.assert_equal(monitor, [1])
425 nt.assert_equal(monitor, [1])
434
426
435 class TestXdel(tt.TempFileMixin):
427 class TestXdel(tt.TempFileMixin):
436 def test_xdel(self):
428 def test_xdel(self):
437 """Test that references from %run are cleared by xdel."""
429 """Test that references from %run are cleared by xdel."""
438 src = ("class A(object):\n"
430 src = ("class A(object):\n"
439 " monitor = []\n"
431 " monitor = []\n"
440 " def __del__(self):\n"
432 " def __del__(self):\n"
441 " self.monitor.append(1)\n"
433 " self.monitor.append(1)\n"
442 "a = A()\n")
434 "a = A()\n")
443 self.mktmp(src)
435 self.mktmp(src)
444 # %run creates some hidden references...
436 # %run creates some hidden references...
445 _ip.magic("run %s" % self.fname)
437 _ip.magic("run %s" % self.fname)
446 # ... as does the displayhook.
438 # ... as does the displayhook.
447 _ip.run_cell("a")
439 _ip.run_cell("a")
448
440
449 monitor = _ip.user_ns["A"].monitor
441 monitor = _ip.user_ns["A"].monitor
450 nt.assert_equal(monitor, [])
442 nt.assert_equal(monitor, [])
451
443
452 _ip.magic("xdel a")
444 _ip.magic("xdel a")
453
445
454 # Check that a's __del__ method has been called.
446 # Check that a's __del__ method has been called.
455 nt.assert_equal(monitor, [1])
447 nt.assert_equal(monitor, [1])
456
448
457 def doctest_who():
449 def doctest_who():
458 """doctest for %who
450 """doctest for %who
459
451
460 In [1]: %reset -f
452 In [1]: %reset -f
461
453
462 In [2]: alpha = 123
454 In [2]: alpha = 123
463
455
464 In [3]: beta = 'beta'
456 In [3]: beta = 'beta'
465
457
466 In [4]: %who int
458 In [4]: %who int
467 alpha
459 alpha
468
460
469 In [5]: %who str
461 In [5]: %who str
470 beta
462 beta
471
463
472 In [6]: %whos
464 In [6]: %whos
473 Variable Type Data/Info
465 Variable Type Data/Info
474 ----------------------------
466 ----------------------------
475 alpha int 123
467 alpha int 123
476 beta str beta
468 beta str beta
477
469
478 In [7]: %who_ls
470 In [7]: %who_ls
479 Out[7]: ['alpha', 'beta']
471 Out[7]: ['alpha', 'beta']
480 """
472 """
481
473
482 def test_whos():
474 def test_whos():
483 """Check that whos is protected against objects where repr() fails."""
475 """Check that whos is protected against objects where repr() fails."""
484 class A(object):
476 class A(object):
485 def __repr__(self):
477 def __repr__(self):
486 raise Exception()
478 raise Exception()
487 _ip.user_ns['a'] = A()
479 _ip.user_ns['a'] = A()
488 _ip.magic("whos")
480 _ip.magic("whos")
489
481
490 @py3compat.u_format
482 @py3compat.u_format
491 def doctest_precision():
483 def doctest_precision():
492 """doctest for %precision
484 """doctest for %precision
493
485
494 In [1]: f = get_ipython().display_formatter.formatters['text/plain']
486 In [1]: f = get_ipython().display_formatter.formatters['text/plain']
495
487
496 In [2]: %precision 5
488 In [2]: %precision 5
497 Out[2]: {u}'%.5f'
489 Out[2]: {u}'%.5f'
498
490
499 In [3]: f.float_format
491 In [3]: f.float_format
500 Out[3]: {u}'%.5f'
492 Out[3]: {u}'%.5f'
501
493
502 In [4]: %precision %e
494 In [4]: %precision %e
503 Out[4]: {u}'%e'
495 Out[4]: {u}'%e'
504
496
505 In [5]: f(3.1415927)
497 In [5]: f(3.1415927)
506 Out[5]: {u}'3.141593e+00'
498 Out[5]: {u}'3.141593e+00'
507 """
499 """
508
500
509 def test_psearch():
501 def test_psearch():
510 with tt.AssertPrints("dict.fromkeys"):
502 with tt.AssertPrints("dict.fromkeys"):
511 _ip.run_cell("dict.fr*?")
503 _ip.run_cell("dict.fr*?")
512
504
513 def test_timeit_shlex():
505 def test_timeit_shlex():
514 """test shlex issues with timeit (#1109)"""
506 """test shlex issues with timeit (#1109)"""
515 _ip.ex("def f(*a,**kw): pass")
507 _ip.ex("def f(*a,**kw): pass")
516 _ip.magic('timeit -n1 "this is a bug".count(" ")')
508 _ip.magic('timeit -n1 "this is a bug".count(" ")')
517 _ip.magic('timeit -r1 -n1 f(" ", 1)')
509 _ip.magic('timeit -r1 -n1 f(" ", 1)')
518 _ip.magic('timeit -r1 -n1 f(" ", 1, " ", 2, " ")')
510 _ip.magic('timeit -r1 -n1 f(" ", 1, " ", 2, " ")')
519 _ip.magic('timeit -r1 -n1 ("a " + "b")')
511 _ip.magic('timeit -r1 -n1 ("a " + "b")')
520 _ip.magic('timeit -r1 -n1 f("a " + "b")')
512 _ip.magic('timeit -r1 -n1 f("a " + "b")')
521 _ip.magic('timeit -r1 -n1 f("a " + "b ")')
513 _ip.magic('timeit -r1 -n1 f("a " + "b ")')
522
514
523
515
524 def test_timeit_arguments():
516 def test_timeit_arguments():
525 "Test valid timeit arguments, should not cause SyntaxError (GH #1269)"
517 "Test valid timeit arguments, should not cause SyntaxError (GH #1269)"
526 _ip.magic("timeit ('#')")
518 _ip.magic("timeit ('#')")
527
519
528
520
529 def test_timeit_special_syntax():
521 def test_timeit_special_syntax():
530 "Test %%timeit with IPython special syntax"
522 "Test %%timeit with IPython special syntax"
531 @register_line_magic
523 @register_line_magic
532 def lmagic(line):
524 def lmagic(line):
533 ip = get_ipython()
525 ip = get_ipython()
534 ip.user_ns['lmagic_out'] = line
526 ip.user_ns['lmagic_out'] = line
535
527
536 # line mode test
528 # line mode test
537 _ip.run_line_magic('timeit', '-n1 -r1 %lmagic my line')
529 _ip.run_line_magic('timeit', '-n1 -r1 %lmagic my line')
538 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line')
530 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line')
539 # cell mode test
531 # cell mode test
540 _ip.run_cell_magic('timeit', '-n1 -r1', '%lmagic my line2')
532 _ip.run_cell_magic('timeit', '-n1 -r1', '%lmagic my line2')
541 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line2')
533 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line2')
542
534
543 def test_timeit_return():
535 def test_timeit_return():
544 """
536 """
545 test wether timeit -o return object
537 test wether timeit -o return object
546 """
538 """
547
539
548 res = _ip.run_line_magic('timeit','-n10 -r10 -o 1')
540 res = _ip.run_line_magic('timeit','-n10 -r10 -o 1')
549 assert(res is not None)
541 assert(res is not None)
550
542
551 def test_timeit_quiet():
543 def test_timeit_quiet():
552 """
544 """
553 test quiet option of timeit magic
545 test quiet option of timeit magic
554 """
546 """
555 with tt.AssertNotPrints("loops"):
547 with tt.AssertNotPrints("loops"):
556 _ip.run_cell("%timeit -n1 -r1 -q 1")
548 _ip.run_cell("%timeit -n1 -r1 -q 1")
557
549
558 def test_timeit_return_quiet():
550 def test_timeit_return_quiet():
559 with tt.AssertNotPrints("loops"):
551 with tt.AssertNotPrints("loops"):
560 res = _ip.run_line_magic('timeit', '-n1 -r1 -q -o 1')
552 res = _ip.run_line_magic('timeit', '-n1 -r1 -q -o 1')
561 assert (res is not None)
553 assert (res is not None)
562
554
563 @dec.skipif(execution.profile is None)
555 @dec.skipif(execution.profile is None)
564 def test_prun_special_syntax():
556 def test_prun_special_syntax():
565 "Test %%prun with IPython special syntax"
557 "Test %%prun with IPython special syntax"
566 @register_line_magic
558 @register_line_magic
567 def lmagic(line):
559 def lmagic(line):
568 ip = get_ipython()
560 ip = get_ipython()
569 ip.user_ns['lmagic_out'] = line
561 ip.user_ns['lmagic_out'] = line
570
562
571 # line mode test
563 # line mode test
572 _ip.run_line_magic('prun', '-q %lmagic my line')
564 _ip.run_line_magic('prun', '-q %lmagic my line')
573 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line')
565 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line')
574 # cell mode test
566 # cell mode test
575 _ip.run_cell_magic('prun', '-q', '%lmagic my line2')
567 _ip.run_cell_magic('prun', '-q', '%lmagic my line2')
576 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line2')
568 nt.assert_equal(_ip.user_ns['lmagic_out'], 'my line2')
577
569
578 @dec.skipif(execution.profile is None)
570 @dec.skipif(execution.profile is None)
579 def test_prun_quotes():
571 def test_prun_quotes():
580 "Test that prun does not clobber string escapes (GH #1302)"
572 "Test that prun does not clobber string escapes (GH #1302)"
581 _ip.magic(r"prun -q x = '\t'")
573 _ip.magic(r"prun -q x = '\t'")
582 nt.assert_equal(_ip.user_ns['x'], '\t')
574 nt.assert_equal(_ip.user_ns['x'], '\t')
583
575
584 def test_extension():
576 def test_extension():
585 # Debugging information for failures of this test
577 # Debugging information for failures of this test
586 print('sys.path:')
578 print('sys.path:')
587 for p in sys.path:
579 for p in sys.path:
588 print(' ', p)
580 print(' ', p)
589 print('CWD', os.getcwd())
581 print('CWD', os.getcwd())
590
582
591 nt.assert_raises(ImportError, _ip.magic, "load_ext daft_extension")
583 nt.assert_raises(ImportError, _ip.magic, "load_ext daft_extension")
592 daft_path = os.path.join(os.path.dirname(__file__), "daft_extension")
584 daft_path = os.path.join(os.path.dirname(__file__), "daft_extension")
593 sys.path.insert(0, daft_path)
585 sys.path.insert(0, daft_path)
594 try:
586 try:
595 _ip.user_ns.pop('arq', None)
587 _ip.user_ns.pop('arq', None)
596 invalidate_caches() # Clear import caches
588 invalidate_caches() # Clear import caches
597 _ip.magic("load_ext daft_extension")
589 _ip.magic("load_ext daft_extension")
598 nt.assert_equal(_ip.user_ns['arq'], 185)
590 nt.assert_equal(_ip.user_ns['arq'], 185)
599 _ip.magic("unload_ext daft_extension")
591 _ip.magic("unload_ext daft_extension")
600 assert 'arq' not in _ip.user_ns
592 assert 'arq' not in _ip.user_ns
601 finally:
593 finally:
602 sys.path.remove(daft_path)
594 sys.path.remove(daft_path)
603
595
604
596
605 def test_notebook_export_json():
597 def test_notebook_export_json():
606 _ip = get_ipython()
598 _ip = get_ipython()
607 _ip.history_manager.reset() # Clear any existing history.
599 _ip.history_manager.reset() # Clear any existing history.
608 cmds = [u"a=1", u"def b():\n return a**2", u"print('noΓ«l, Γ©tΓ©', b())"]
600 cmds = [u"a=1", u"def b():\n return a**2", u"print('noΓ«l, Γ©tΓ©', b())"]
609 for i, cmd in enumerate(cmds, start=1):
601 for i, cmd in enumerate(cmds, start=1):
610 _ip.history_manager.store_inputs(i, cmd)
602 _ip.history_manager.store_inputs(i, cmd)
611 with TemporaryDirectory() as td:
603 with TemporaryDirectory() as td:
612 outfile = os.path.join(td, "nb.ipynb")
604 outfile = os.path.join(td, "nb.ipynb")
613 _ip.magic("notebook -e %s" % outfile)
605 _ip.magic("notebook -e %s" % outfile)
614
606
615
607
616 class TestEnv(TestCase):
608 class TestEnv(TestCase):
617
609
618 def test_env(self):
610 def test_env(self):
619 env = _ip.magic("env")
611 env = _ip.magic("env")
620 self.assertTrue(isinstance(env, dict))
612 self.assertTrue(isinstance(env, dict))
621
613
622 def test_env_get_set_simple(self):
614 def test_env_get_set_simple(self):
623 env = _ip.magic("env var val1")
615 env = _ip.magic("env var val1")
624 self.assertEqual(env, None)
616 self.assertEqual(env, None)
625 self.assertEqual(os.environ['var'], 'val1')
617 self.assertEqual(os.environ['var'], 'val1')
626 self.assertEqual(_ip.magic("env var"), 'val1')
618 self.assertEqual(_ip.magic("env var"), 'val1')
627 env = _ip.magic("env var=val2")
619 env = _ip.magic("env var=val2")
628 self.assertEqual(env, None)
620 self.assertEqual(env, None)
629 self.assertEqual(os.environ['var'], 'val2')
621 self.assertEqual(os.environ['var'], 'val2')
630
622
631 def test_env_get_set_complex(self):
623 def test_env_get_set_complex(self):
632 env = _ip.magic("env var 'val1 '' 'val2")
624 env = _ip.magic("env var 'val1 '' 'val2")
633 self.assertEqual(env, None)
625 self.assertEqual(env, None)
634 self.assertEqual(os.environ['var'], "'val1 '' 'val2")
626 self.assertEqual(os.environ['var'], "'val1 '' 'val2")
635 self.assertEqual(_ip.magic("env var"), "'val1 '' 'val2")
627 self.assertEqual(_ip.magic("env var"), "'val1 '' 'val2")
636 env = _ip.magic('env var=val2 val3="val4')
628 env = _ip.magic('env var=val2 val3="val4')
637 self.assertEqual(env, None)
629 self.assertEqual(env, None)
638 self.assertEqual(os.environ['var'], 'val2 val3="val4')
630 self.assertEqual(os.environ['var'], 'val2 val3="val4')
639
631
640 def test_env_set_bad_input(self):
632 def test_env_set_bad_input(self):
641 self.assertRaises(UsageError, lambda: _ip.magic("set_env var"))
633 self.assertRaises(UsageError, lambda: _ip.magic("set_env var"))
642
634
643 def test_env_set_whitespace(self):
635 def test_env_set_whitespace(self):
644 self.assertRaises(UsageError, lambda: _ip.magic("env var A=B"))
636 self.assertRaises(UsageError, lambda: _ip.magic("env var A=B"))
645
637
646
638
647 class CellMagicTestCase(TestCase):
639 class CellMagicTestCase(TestCase):
648
640
649 def check_ident(self, magic):
641 def check_ident(self, magic):
650 # Manually called, we get the result
642 # Manually called, we get the result
651 out = _ip.run_cell_magic(magic, 'a', 'b')
643 out = _ip.run_cell_magic(magic, 'a', 'b')
652 nt.assert_equal(out, ('a','b'))
644 nt.assert_equal(out, ('a','b'))
653 # Via run_cell, it goes into the user's namespace via displayhook
645 # Via run_cell, it goes into the user's namespace via displayhook
654 _ip.run_cell('%%' + magic +' c\nd')
646 _ip.run_cell('%%' + magic +' c\nd')
655 nt.assert_equal(_ip.user_ns['_'], ('c','d'))
647 nt.assert_equal(_ip.user_ns['_'], ('c','d'))
656
648
657 def test_cell_magic_func_deco(self):
649 def test_cell_magic_func_deco(self):
658 "Cell magic using simple decorator"
650 "Cell magic using simple decorator"
659 @register_cell_magic
651 @register_cell_magic
660 def cellm(line, cell):
652 def cellm(line, cell):
661 return line, cell
653 return line, cell
662
654
663 self.check_ident('cellm')
655 self.check_ident('cellm')
664
656
665 def test_cell_magic_reg(self):
657 def test_cell_magic_reg(self):
666 "Cell magic manually registered"
658 "Cell magic manually registered"
667 def cellm(line, cell):
659 def cellm(line, cell):
668 return line, cell
660 return line, cell
669
661
670 _ip.register_magic_function(cellm, 'cell', 'cellm2')
662 _ip.register_magic_function(cellm, 'cell', 'cellm2')
671 self.check_ident('cellm2')
663 self.check_ident('cellm2')
672
664
673 def test_cell_magic_class(self):
665 def test_cell_magic_class(self):
674 "Cell magics declared via a class"
666 "Cell magics declared via a class"
675 @magics_class
667 @magics_class
676 class MyMagics(Magics):
668 class MyMagics(Magics):
677
669
678 @cell_magic
670 @cell_magic
679 def cellm3(self, line, cell):
671 def cellm3(self, line, cell):
680 return line, cell
672 return line, cell
681
673
682 _ip.register_magics(MyMagics)
674 _ip.register_magics(MyMagics)
683 self.check_ident('cellm3')
675 self.check_ident('cellm3')
684
676
685 def test_cell_magic_class2(self):
677 def test_cell_magic_class2(self):
686 "Cell magics declared via a class, #2"
678 "Cell magics declared via a class, #2"
687 @magics_class
679 @magics_class
688 class MyMagics2(Magics):
680 class MyMagics2(Magics):
689
681
690 @cell_magic('cellm4')
682 @cell_magic('cellm4')
691 def cellm33(self, line, cell):
683 def cellm33(self, line, cell):
692 return line, cell
684 return line, cell
693
685
694 _ip.register_magics(MyMagics2)
686 _ip.register_magics(MyMagics2)
695 self.check_ident('cellm4')
687 self.check_ident('cellm4')
696 # Check that nothing is registered as 'cellm33'
688 # Check that nothing is registered as 'cellm33'
697 c33 = _ip.find_cell_magic('cellm33')
689 c33 = _ip.find_cell_magic('cellm33')
698 nt.assert_equal(c33, None)
690 nt.assert_equal(c33, None)
699
691
700 def test_file():
692 def test_file():
701 """Basic %%file"""
693 """Basic %%file"""
702 ip = get_ipython()
694 ip = get_ipython()
703 with TemporaryDirectory() as td:
695 with TemporaryDirectory() as td:
704 fname = os.path.join(td, 'file1')
696 fname = os.path.join(td, 'file1')
705 ip.run_cell_magic("file", fname, u'\n'.join([
697 ip.run_cell_magic("file", fname, u'\n'.join([
706 'line1',
698 'line1',
707 'line2',
699 'line2',
708 ]))
700 ]))
709 with open(fname) as f:
701 with open(fname) as f:
710 s = f.read()
702 s = f.read()
711 nt.assert_in('line1\n', s)
703 nt.assert_in('line1\n', s)
712 nt.assert_in('line2', s)
704 nt.assert_in('line2', s)
713
705
714 def test_file_var_expand():
706 def test_file_var_expand():
715 """%%file $filename"""
707 """%%file $filename"""
716 ip = get_ipython()
708 ip = get_ipython()
717 with TemporaryDirectory() as td:
709 with TemporaryDirectory() as td:
718 fname = os.path.join(td, 'file1')
710 fname = os.path.join(td, 'file1')
719 ip.user_ns['filename'] = fname
711 ip.user_ns['filename'] = fname
720 ip.run_cell_magic("file", '$filename', u'\n'.join([
712 ip.run_cell_magic("file", '$filename', u'\n'.join([
721 'line1',
713 'line1',
722 'line2',
714 'line2',
723 ]))
715 ]))
724 with open(fname) as f:
716 with open(fname) as f:
725 s = f.read()
717 s = f.read()
726 nt.assert_in('line1\n', s)
718 nt.assert_in('line1\n', s)
727 nt.assert_in('line2', s)
719 nt.assert_in('line2', s)
728
720
729 def test_file_unicode():
721 def test_file_unicode():
730 """%%file with unicode cell"""
722 """%%file with unicode cell"""
731 ip = get_ipython()
723 ip = get_ipython()
732 with TemporaryDirectory() as td:
724 with TemporaryDirectory() as td:
733 fname = os.path.join(td, 'file1')
725 fname = os.path.join(td, 'file1')
734 ip.run_cell_magic("file", fname, u'\n'.join([
726 ip.run_cell_magic("file", fname, u'\n'.join([
735 u'linΓ©1',
727 u'linΓ©1',
736 u'linΓ©2',
728 u'linΓ©2',
737 ]))
729 ]))
738 with io.open(fname, encoding='utf-8') as f:
730 with io.open(fname, encoding='utf-8') as f:
739 s = f.read()
731 s = f.read()
740 nt.assert_in(u'linΓ©1\n', s)
732 nt.assert_in(u'linΓ©1\n', s)
741 nt.assert_in(u'linΓ©2', s)
733 nt.assert_in(u'linΓ©2', s)
742
734
743 def test_file_amend():
735 def test_file_amend():
744 """%%file -a amends files"""
736 """%%file -a amends files"""
745 ip = get_ipython()
737 ip = get_ipython()
746 with TemporaryDirectory() as td:
738 with TemporaryDirectory() as td:
747 fname = os.path.join(td, 'file2')
739 fname = os.path.join(td, 'file2')
748 ip.run_cell_magic("file", fname, u'\n'.join([
740 ip.run_cell_magic("file", fname, u'\n'.join([
749 'line1',
741 'line1',
750 'line2',
742 'line2',
751 ]))
743 ]))
752 ip.run_cell_magic("file", "-a %s" % fname, u'\n'.join([
744 ip.run_cell_magic("file", "-a %s" % fname, u'\n'.join([
753 'line3',
745 'line3',
754 'line4',
746 'line4',
755 ]))
747 ]))
756 with open(fname) as f:
748 with open(fname) as f:
757 s = f.read()
749 s = f.read()
758 nt.assert_in('line1\n', s)
750 nt.assert_in('line1\n', s)
759 nt.assert_in('line3\n', s)
751 nt.assert_in('line3\n', s)
760
752
761
753
762 def test_script_config():
754 def test_script_config():
763 ip = get_ipython()
755 ip = get_ipython()
764 ip.config.ScriptMagics.script_magics = ['whoda']
756 ip.config.ScriptMagics.script_magics = ['whoda']
765 sm = script.ScriptMagics(shell=ip)
757 sm = script.ScriptMagics(shell=ip)
766 nt.assert_in('whoda', sm.magics['cell'])
758 nt.assert_in('whoda', sm.magics['cell'])
767
759
768 @dec.skip_win32
760 @dec.skip_win32
769 def test_script_out():
761 def test_script_out():
770 ip = get_ipython()
762 ip = get_ipython()
771 ip.run_cell_magic("script", "--out output sh", "echo 'hi'")
763 ip.run_cell_magic("script", "--out output sh", "echo 'hi'")
772 nt.assert_equal(ip.user_ns['output'], 'hi\n')
764 nt.assert_equal(ip.user_ns['output'], 'hi\n')
773
765
774 @dec.skip_win32
766 @dec.skip_win32
775 def test_script_err():
767 def test_script_err():
776 ip = get_ipython()
768 ip = get_ipython()
777 ip.run_cell_magic("script", "--err error sh", "echo 'hello' >&2")
769 ip.run_cell_magic("script", "--err error sh", "echo 'hello' >&2")
778 nt.assert_equal(ip.user_ns['error'], 'hello\n')
770 nt.assert_equal(ip.user_ns['error'], 'hello\n')
779
771
780 @dec.skip_win32
772 @dec.skip_win32
781 def test_script_out_err():
773 def test_script_out_err():
782 ip = get_ipython()
774 ip = get_ipython()
783 ip.run_cell_magic("script", "--out output --err error sh", "echo 'hi'\necho 'hello' >&2")
775 ip.run_cell_magic("script", "--out output --err error sh", "echo 'hi'\necho 'hello' >&2")
784 nt.assert_equal(ip.user_ns['output'], 'hi\n')
776 nt.assert_equal(ip.user_ns['output'], 'hi\n')
785 nt.assert_equal(ip.user_ns['error'], 'hello\n')
777 nt.assert_equal(ip.user_ns['error'], 'hello\n')
786
778
787 @dec.skip_win32
779 @dec.skip_win32
788 def test_script_bg_out():
780 def test_script_bg_out():
789 ip = get_ipython()
781 ip = get_ipython()
790 ip.run_cell_magic("script", "--bg --out output sh", "echo 'hi'")
782 ip.run_cell_magic("script", "--bg --out output sh", "echo 'hi'")
791 nt.assert_equal(ip.user_ns['output'].read(), b'hi\n')
783 nt.assert_equal(ip.user_ns['output'].read(), b'hi\n')
792
784
793 @dec.skip_win32
785 @dec.skip_win32
794 def test_script_bg_err():
786 def test_script_bg_err():
795 ip = get_ipython()
787 ip = get_ipython()
796 ip.run_cell_magic("script", "--bg --err error sh", "echo 'hello' >&2")
788 ip.run_cell_magic("script", "--bg --err error sh", "echo 'hello' >&2")
797 nt.assert_equal(ip.user_ns['error'].read(), b'hello\n')
789 nt.assert_equal(ip.user_ns['error'].read(), b'hello\n')
798
790
799 @dec.skip_win32
791 @dec.skip_win32
800 def test_script_bg_out_err():
792 def test_script_bg_out_err():
801 ip = get_ipython()
793 ip = get_ipython()
802 ip.run_cell_magic("script", "--bg --out output --err error sh", "echo 'hi'\necho 'hello' >&2")
794 ip.run_cell_magic("script", "--bg --out output --err error sh", "echo 'hi'\necho 'hello' >&2")
803 nt.assert_equal(ip.user_ns['output'].read(), b'hi\n')
795 nt.assert_equal(ip.user_ns['output'].read(), b'hi\n')
804 nt.assert_equal(ip.user_ns['error'].read(), b'hello\n')
796 nt.assert_equal(ip.user_ns['error'].read(), b'hello\n')
805
797
806 def test_script_defaults():
798 def test_script_defaults():
807 ip = get_ipython()
799 ip = get_ipython()
808 for cmd in ['sh', 'bash', 'perl', 'ruby']:
800 for cmd in ['sh', 'bash', 'perl', 'ruby']:
809 try:
801 try:
810 find_cmd(cmd)
802 find_cmd(cmd)
811 except Exception:
803 except Exception:
812 pass
804 pass
813 else:
805 else:
814 nt.assert_in(cmd, ip.magics_manager.magics['cell'])
806 nt.assert_in(cmd, ip.magics_manager.magics['cell'])
815
807
816
808
817 @magics_class
809 @magics_class
818 class FooFoo(Magics):
810 class FooFoo(Magics):
819 """class with both %foo and %%foo magics"""
811 """class with both %foo and %%foo magics"""
820 @line_magic('foo')
812 @line_magic('foo')
821 def line_foo(self, line):
813 def line_foo(self, line):
822 "I am line foo"
814 "I am line foo"
823 pass
815 pass
824
816
825 @cell_magic("foo")
817 @cell_magic("foo")
826 def cell_foo(self, line, cell):
818 def cell_foo(self, line, cell):
827 "I am cell foo, not line foo"
819 "I am cell foo, not line foo"
828 pass
820 pass
829
821
830 def test_line_cell_info():
822 def test_line_cell_info():
831 """%%foo and %foo magics are distinguishable to inspect"""
823 """%%foo and %foo magics are distinguishable to inspect"""
832 ip = get_ipython()
824 ip = get_ipython()
833 ip.magics_manager.register(FooFoo)
825 ip.magics_manager.register(FooFoo)
834 oinfo = ip.object_inspect('foo')
826 oinfo = ip.object_inspect('foo')
835 nt.assert_true(oinfo['found'])
827 nt.assert_true(oinfo['found'])
836 nt.assert_true(oinfo['ismagic'])
828 nt.assert_true(oinfo['ismagic'])
837
829
838 oinfo = ip.object_inspect('%%foo')
830 oinfo = ip.object_inspect('%%foo')
839 nt.assert_true(oinfo['found'])
831 nt.assert_true(oinfo['found'])
840 nt.assert_true(oinfo['ismagic'])
832 nt.assert_true(oinfo['ismagic'])
841 nt.assert_equal(oinfo['docstring'], FooFoo.cell_foo.__doc__)
833 nt.assert_equal(oinfo['docstring'], FooFoo.cell_foo.__doc__)
842
834
843 oinfo = ip.object_inspect('%foo')
835 oinfo = ip.object_inspect('%foo')
844 nt.assert_true(oinfo['found'])
836 nt.assert_true(oinfo['found'])
845 nt.assert_true(oinfo['ismagic'])
837 nt.assert_true(oinfo['ismagic'])
846 nt.assert_equal(oinfo['docstring'], FooFoo.line_foo.__doc__)
838 nt.assert_equal(oinfo['docstring'], FooFoo.line_foo.__doc__)
847
839
848 def test_multiple_magics():
840 def test_multiple_magics():
849 ip = get_ipython()
841 ip = get_ipython()
850 foo1 = FooFoo(ip)
842 foo1 = FooFoo(ip)
851 foo2 = FooFoo(ip)
843 foo2 = FooFoo(ip)
852 mm = ip.magics_manager
844 mm = ip.magics_manager
853 mm.register(foo1)
845 mm.register(foo1)
854 nt.assert_true(mm.magics['line']['foo'].__self__ is foo1)
846 nt.assert_true(mm.magics['line']['foo'].__self__ is foo1)
855 mm.register(foo2)
847 mm.register(foo2)
856 nt.assert_true(mm.magics['line']['foo'].__self__ is foo2)
848 nt.assert_true(mm.magics['line']['foo'].__self__ is foo2)
857
849
858 def test_alias_magic():
850 def test_alias_magic():
859 """Test %alias_magic."""
851 """Test %alias_magic."""
860 ip = get_ipython()
852 ip = get_ipython()
861 mm = ip.magics_manager
853 mm = ip.magics_manager
862
854
863 # Basic operation: both cell and line magics are created, if possible.
855 # Basic operation: both cell and line magics are created, if possible.
864 ip.run_line_magic('alias_magic', 'timeit_alias timeit')
856 ip.run_line_magic('alias_magic', 'timeit_alias timeit')
865 nt.assert_in('timeit_alias', mm.magics['line'])
857 nt.assert_in('timeit_alias', mm.magics['line'])
866 nt.assert_in('timeit_alias', mm.magics['cell'])
858 nt.assert_in('timeit_alias', mm.magics['cell'])
867
859
868 # --cell is specified, line magic not created.
860 # --cell is specified, line magic not created.
869 ip.run_line_magic('alias_magic', '--cell timeit_cell_alias timeit')
861 ip.run_line_magic('alias_magic', '--cell timeit_cell_alias timeit')
870 nt.assert_not_in('timeit_cell_alias', mm.magics['line'])
862 nt.assert_not_in('timeit_cell_alias', mm.magics['line'])
871 nt.assert_in('timeit_cell_alias', mm.magics['cell'])
863 nt.assert_in('timeit_cell_alias', mm.magics['cell'])
872
864
873 # Test that line alias is created successfully.
865 # Test that line alias is created successfully.
874 ip.run_line_magic('alias_magic', '--line env_alias env')
866 ip.run_line_magic('alias_magic', '--line env_alias env')
875 nt.assert_equal(ip.run_line_magic('env', ''),
867 nt.assert_equal(ip.run_line_magic('env', ''),
876 ip.run_line_magic('env_alias', ''))
868 ip.run_line_magic('env_alias', ''))
877
869
878 def test_save():
870 def test_save():
879 """Test %save."""
871 """Test %save."""
880 ip = get_ipython()
872 ip = get_ipython()
881 ip.history_manager.reset() # Clear any existing history.
873 ip.history_manager.reset() # Clear any existing history.
882 cmds = [u"a=1", u"def b():\n return a**2", u"print(a, b())"]
874 cmds = [u"a=1", u"def b():\n return a**2", u"print(a, b())"]
883 for i, cmd in enumerate(cmds, start=1):
875 for i, cmd in enumerate(cmds, start=1):
884 ip.history_manager.store_inputs(i, cmd)
876 ip.history_manager.store_inputs(i, cmd)
885 with TemporaryDirectory() as tmpdir:
877 with TemporaryDirectory() as tmpdir:
886 file = os.path.join(tmpdir, "testsave.py")
878 file = os.path.join(tmpdir, "testsave.py")
887 ip.run_line_magic("save", "%s 1-10" % file)
879 ip.run_line_magic("save", "%s 1-10" % file)
888 with open(file) as f:
880 with open(file) as f:
889 content = f.read()
881 content = f.read()
890 nt.assert_equal(content.count(cmds[0]), 1)
882 nt.assert_equal(content.count(cmds[0]), 1)
891 nt.assert_in('coding: utf-8', content)
883 nt.assert_in('coding: utf-8', content)
892 ip.run_line_magic("save", "-a %s 1-10" % file)
884 ip.run_line_magic("save", "-a %s 1-10" % file)
893 with open(file) as f:
885 with open(file) as f:
894 content = f.read()
886 content = f.read()
895 nt.assert_equal(content.count(cmds[0]), 2)
887 nt.assert_equal(content.count(cmds[0]), 2)
896 nt.assert_in('coding: utf-8', content)
888 nt.assert_in('coding: utf-8', content)
897
889
898
890
899 def test_store():
891 def test_store():
900 """Test %store."""
892 """Test %store."""
901 ip = get_ipython()
893 ip = get_ipython()
902 ip.run_line_magic('load_ext', 'storemagic')
894 ip.run_line_magic('load_ext', 'storemagic')
903
895
904 # make sure the storage is empty
896 # make sure the storage is empty
905 ip.run_line_magic('store', '-z')
897 ip.run_line_magic('store', '-z')
906 ip.user_ns['var'] = 42
898 ip.user_ns['var'] = 42
907 ip.run_line_magic('store', 'var')
899 ip.run_line_magic('store', 'var')
908 ip.user_ns['var'] = 39
900 ip.user_ns['var'] = 39
909 ip.run_line_magic('store', '-r')
901 ip.run_line_magic('store', '-r')
910 nt.assert_equal(ip.user_ns['var'], 42)
902 nt.assert_equal(ip.user_ns['var'], 42)
911
903
912 ip.run_line_magic('store', '-d var')
904 ip.run_line_magic('store', '-d var')
913 ip.user_ns['var'] = 39
905 ip.user_ns['var'] = 39
914 ip.run_line_magic('store' , '-r')
906 ip.run_line_magic('store' , '-r')
915 nt.assert_equal(ip.user_ns['var'], 39)
907 nt.assert_equal(ip.user_ns['var'], 39)
916
908
917
909
918 def _run_edit_test(arg_s, exp_filename=None,
910 def _run_edit_test(arg_s, exp_filename=None,
919 exp_lineno=-1,
911 exp_lineno=-1,
920 exp_contents=None,
912 exp_contents=None,
921 exp_is_temp=None):
913 exp_is_temp=None):
922 ip = get_ipython()
914 ip = get_ipython()
923 M = code.CodeMagics(ip)
915 M = code.CodeMagics(ip)
924 last_call = ['','']
916 last_call = ['','']
925 opts,args = M.parse_options(arg_s,'prxn:')
917 opts,args = M.parse_options(arg_s,'prxn:')
926 filename, lineno, is_temp = M._find_edit_target(ip, args, opts, last_call)
918 filename, lineno, is_temp = M._find_edit_target(ip, args, opts, last_call)
927
919
928 if exp_filename is not None:
920 if exp_filename is not None:
929 nt.assert_equal(exp_filename, filename)
921 nt.assert_equal(exp_filename, filename)
930 if exp_contents is not None:
922 if exp_contents is not None:
931 with io.open(filename, 'r', encoding='utf-8') as f:
923 with io.open(filename, 'r', encoding='utf-8') as f:
932 contents = f.read()
924 contents = f.read()
933 nt.assert_equal(exp_contents, contents)
925 nt.assert_equal(exp_contents, contents)
934 if exp_lineno != -1:
926 if exp_lineno != -1:
935 nt.assert_equal(exp_lineno, lineno)
927 nt.assert_equal(exp_lineno, lineno)
936 if exp_is_temp is not None:
928 if exp_is_temp is not None:
937 nt.assert_equal(exp_is_temp, is_temp)
929 nt.assert_equal(exp_is_temp, is_temp)
938
930
939
931
940 def test_edit_interactive():
932 def test_edit_interactive():
941 """%edit on interactively defined objects"""
933 """%edit on interactively defined objects"""
942 ip = get_ipython()
934 ip = get_ipython()
943 n = ip.execution_count
935 n = ip.execution_count
944 ip.run_cell(u"def foo(): return 1", store_history=True)
936 ip.run_cell(u"def foo(): return 1", store_history=True)
945
937
946 try:
938 try:
947 _run_edit_test("foo")
939 _run_edit_test("foo")
948 except code.InteractivelyDefined as e:
940 except code.InteractivelyDefined as e:
949 nt.assert_equal(e.index, n)
941 nt.assert_equal(e.index, n)
950 else:
942 else:
951 raise AssertionError("Should have raised InteractivelyDefined")
943 raise AssertionError("Should have raised InteractivelyDefined")
952
944
953
945
954 def test_edit_cell():
946 def test_edit_cell():
955 """%edit [cell id]"""
947 """%edit [cell id]"""
956 ip = get_ipython()
948 ip = get_ipython()
957
949
958 ip.run_cell(u"def foo(): return 1", store_history=True)
950 ip.run_cell(u"def foo(): return 1", store_history=True)
959
951
960 # test
952 # test
961 _run_edit_test("1", exp_contents=ip.user_ns['In'][1], exp_is_temp=True)
953 _run_edit_test("1", exp_contents=ip.user_ns['In'][1], exp_is_temp=True)
962
954
963 def test_bookmark():
955 def test_bookmark():
964 ip = get_ipython()
956 ip = get_ipython()
965 ip.run_line_magic('bookmark', 'bmname')
957 ip.run_line_magic('bookmark', 'bmname')
966 with tt.AssertPrints('bmname'):
958 with tt.AssertPrints('bmname'):
967 ip.run_line_magic('bookmark', '-l')
959 ip.run_line_magic('bookmark', '-l')
968 ip.run_line_magic('bookmark', '-d bmname')
960 ip.run_line_magic('bookmark', '-d bmname')
969
961
970 def test_ls_magic():
962 def test_ls_magic():
971 ip = get_ipython()
963 ip = get_ipython()
972 json_formatter = ip.display_formatter.formatters['application/json']
964 json_formatter = ip.display_formatter.formatters['application/json']
973 json_formatter.enabled = True
965 json_formatter.enabled = True
974 lsmagic = ip.magic('lsmagic')
966 lsmagic = ip.magic('lsmagic')
975 with warnings.catch_warnings(record=True) as w:
967 with warnings.catch_warnings(record=True) as w:
976 j = json_formatter(lsmagic)
968 j = json_formatter(lsmagic)
977 nt.assert_equal(sorted(j), ['cell', 'line'])
969 nt.assert_equal(sorted(j), ['cell', 'line'])
978 nt.assert_equal(w, []) # no warnings
970 nt.assert_equal(w, []) # no warnings
979
971
980 def test_strip_initial_indent():
972 def test_strip_initial_indent():
981 def sii(s):
973 def sii(s):
982 lines = s.splitlines()
974 lines = s.splitlines()
983 return '\n'.join(code.strip_initial_indent(lines))
975 return '\n'.join(code.strip_initial_indent(lines))
984
976
985 nt.assert_equal(sii(" a = 1\nb = 2"), "a = 1\nb = 2")
977 nt.assert_equal(sii(" a = 1\nb = 2"), "a = 1\nb = 2")
986 nt.assert_equal(sii(" a\n b\nc"), "a\n b\nc")
978 nt.assert_equal(sii(" a\n b\nc"), "a\n b\nc")
987 nt.assert_equal(sii("a\n b"), "a\n b")
979 nt.assert_equal(sii("a\n b"), "a\n b")
@@ -1,202 +1,197 b''
1 """Tests for various magic functions specific to the terminal frontend.
1 """Tests for various magic functions specific to the terminal frontend.
2
2
3 Needs to be run by nose (to make ipython session available).
3 Needs to be run by nose (to make ipython session available).
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Imports
7 # Imports
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 import sys
10 import sys
11 from io import StringIO
11 from unittest import TestCase
12 from unittest import TestCase
12
13
13 import nose.tools as nt
14 import nose.tools as nt
14
15
15 from IPython.testing import tools as tt
16 from IPython.testing import tools as tt
16 from IPython.utils.py3compat import PY3
17
18 if PY3:
19 from io import StringIO
20 else:
21 from StringIO import StringIO
22
17
23 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
24 # Globals
19 # Globals
25 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
26 ip = get_ipython()
21 ip = get_ipython()
27
22
28 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
29 # Test functions begin
24 # Test functions begin
30 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
31
26
32 def check_cpaste(code, should_fail=False):
27 def check_cpaste(code, should_fail=False):
33 """Execute code via 'cpaste' and ensure it was executed, unless
28 """Execute code via 'cpaste' and ensure it was executed, unless
34 should_fail is set.
29 should_fail is set.
35 """
30 """
36 ip.user_ns['code_ran'] = False
31 ip.user_ns['code_ran'] = False
37
32
38 src = StringIO()
33 src = StringIO()
39 if not hasattr(src, 'encoding'):
34 if not hasattr(src, 'encoding'):
40 # IPython expects stdin to have an encoding attribute
35 # IPython expects stdin to have an encoding attribute
41 src.encoding = None
36 src.encoding = None
42 src.write(code)
37 src.write(code)
43 src.write('\n--\n')
38 src.write('\n--\n')
44 src.seek(0)
39 src.seek(0)
45
40
46 stdin_save = sys.stdin
41 stdin_save = sys.stdin
47 sys.stdin = src
42 sys.stdin = src
48
43
49 try:
44 try:
50 context = tt.AssertPrints if should_fail else tt.AssertNotPrints
45 context = tt.AssertPrints if should_fail else tt.AssertNotPrints
51 with context("Traceback (most recent call last)"):
46 with context("Traceback (most recent call last)"):
52 ip.magic('cpaste')
47 ip.magic('cpaste')
53
48
54 if not should_fail:
49 if not should_fail:
55 assert ip.user_ns['code_ran'], "%r failed" % code
50 assert ip.user_ns['code_ran'], "%r failed" % code
56 finally:
51 finally:
57 sys.stdin = stdin_save
52 sys.stdin = stdin_save
58
53
59 def test_cpaste():
54 def test_cpaste():
60 """Test cpaste magic"""
55 """Test cpaste magic"""
61
56
62 def runf():
57 def runf():
63 """Marker function: sets a flag when executed.
58 """Marker function: sets a flag when executed.
64 """
59 """
65 ip.user_ns['code_ran'] = True
60 ip.user_ns['code_ran'] = True
66 return 'runf' # return string so '+ runf()' doesn't result in success
61 return 'runf' # return string so '+ runf()' doesn't result in success
67
62
68 tests = {'pass': ["runf()",
63 tests = {'pass': ["runf()",
69 "In [1]: runf()",
64 "In [1]: runf()",
70 "In [1]: if 1:\n ...: runf()",
65 "In [1]: if 1:\n ...: runf()",
71 "> > > runf()",
66 "> > > runf()",
72 ">>> runf()",
67 ">>> runf()",
73 " >>> runf()",
68 " >>> runf()",
74 ],
69 ],
75
70
76 'fail': ["1 + runf()",
71 'fail': ["1 + runf()",
77 "++ runf()",
72 "++ runf()",
78 ]}
73 ]}
79
74
80 ip.user_ns['runf'] = runf
75 ip.user_ns['runf'] = runf
81
76
82 for code in tests['pass']:
77 for code in tests['pass']:
83 check_cpaste(code)
78 check_cpaste(code)
84
79
85 for code in tests['fail']:
80 for code in tests['fail']:
86 check_cpaste(code, should_fail=True)
81 check_cpaste(code, should_fail=True)
87
82
88
83
89 class PasteTestCase(TestCase):
84 class PasteTestCase(TestCase):
90 """Multiple tests for clipboard pasting"""
85 """Multiple tests for clipboard pasting"""
91
86
92 def paste(self, txt, flags='-q'):
87 def paste(self, txt, flags='-q'):
93 """Paste input text, by default in quiet mode"""
88 """Paste input text, by default in quiet mode"""
94 ip.hooks.clipboard_get = lambda : txt
89 ip.hooks.clipboard_get = lambda : txt
95 ip.magic('paste '+flags)
90 ip.magic('paste '+flags)
96
91
97 def setUp(self):
92 def setUp(self):
98 # Inject fake clipboard hook but save original so we can restore it later
93 # Inject fake clipboard hook but save original so we can restore it later
99 self.original_clip = ip.hooks.clipboard_get
94 self.original_clip = ip.hooks.clipboard_get
100
95
101 def tearDown(self):
96 def tearDown(self):
102 # Restore original hook
97 # Restore original hook
103 ip.hooks.clipboard_get = self.original_clip
98 ip.hooks.clipboard_get = self.original_clip
104
99
105 def test_paste(self):
100 def test_paste(self):
106 ip.user_ns.pop('x', None)
101 ip.user_ns.pop('x', None)
107 self.paste('x = 1')
102 self.paste('x = 1')
108 nt.assert_equal(ip.user_ns['x'], 1)
103 nt.assert_equal(ip.user_ns['x'], 1)
109 ip.user_ns.pop('x')
104 ip.user_ns.pop('x')
110
105
111 def test_paste_pyprompt(self):
106 def test_paste_pyprompt(self):
112 ip.user_ns.pop('x', None)
107 ip.user_ns.pop('x', None)
113 self.paste('>>> x=2')
108 self.paste('>>> x=2')
114 nt.assert_equal(ip.user_ns['x'], 2)
109 nt.assert_equal(ip.user_ns['x'], 2)
115 ip.user_ns.pop('x')
110 ip.user_ns.pop('x')
116
111
117 def test_paste_py_multi(self):
112 def test_paste_py_multi(self):
118 self.paste("""
113 self.paste("""
119 >>> x = [1,2,3]
114 >>> x = [1,2,3]
120 >>> y = []
115 >>> y = []
121 >>> for i in x:
116 >>> for i in x:
122 ... y.append(i**2)
117 ... y.append(i**2)
123 ...
118 ...
124 """)
119 """)
125 nt.assert_equal(ip.user_ns['x'], [1,2,3])
120 nt.assert_equal(ip.user_ns['x'], [1,2,3])
126 nt.assert_equal(ip.user_ns['y'], [1,4,9])
121 nt.assert_equal(ip.user_ns['y'], [1,4,9])
127
122
128 def test_paste_py_multi_r(self):
123 def test_paste_py_multi_r(self):
129 "Now, test that self.paste -r works"
124 "Now, test that self.paste -r works"
130 self.test_paste_py_multi()
125 self.test_paste_py_multi()
131 nt.assert_equal(ip.user_ns.pop('x'), [1,2,3])
126 nt.assert_equal(ip.user_ns.pop('x'), [1,2,3])
132 nt.assert_equal(ip.user_ns.pop('y'), [1,4,9])
127 nt.assert_equal(ip.user_ns.pop('y'), [1,4,9])
133 nt.assert_false('x' in ip.user_ns)
128 nt.assert_false('x' in ip.user_ns)
134 ip.magic('paste -r')
129 ip.magic('paste -r')
135 nt.assert_equal(ip.user_ns['x'], [1,2,3])
130 nt.assert_equal(ip.user_ns['x'], [1,2,3])
136 nt.assert_equal(ip.user_ns['y'], [1,4,9])
131 nt.assert_equal(ip.user_ns['y'], [1,4,9])
137
132
138 def test_paste_email(self):
133 def test_paste_email(self):
139 "Test pasting of email-quoted contents"
134 "Test pasting of email-quoted contents"
140 self.paste("""\
135 self.paste("""\
141 >> def foo(x):
136 >> def foo(x):
142 >> return x + 1
137 >> return x + 1
143 >> xx = foo(1.1)""")
138 >> xx = foo(1.1)""")
144 nt.assert_equal(ip.user_ns['xx'], 2.1)
139 nt.assert_equal(ip.user_ns['xx'], 2.1)
145
140
146 def test_paste_email2(self):
141 def test_paste_email2(self):
147 "Email again; some programs add a space also at each quoting level"
142 "Email again; some programs add a space also at each quoting level"
148 self.paste("""\
143 self.paste("""\
149 > > def foo(x):
144 > > def foo(x):
150 > > return x + 1
145 > > return x + 1
151 > > yy = foo(2.1) """)
146 > > yy = foo(2.1) """)
152 nt.assert_equal(ip.user_ns['yy'], 3.1)
147 nt.assert_equal(ip.user_ns['yy'], 3.1)
153
148
154 def test_paste_email_py(self):
149 def test_paste_email_py(self):
155 "Email quoting of interactive input"
150 "Email quoting of interactive input"
156 self.paste("""\
151 self.paste("""\
157 >> >>> def f(x):
152 >> >>> def f(x):
158 >> ... return x+1
153 >> ... return x+1
159 >> ...
154 >> ...
160 >> >>> zz = f(2.5) """)
155 >> >>> zz = f(2.5) """)
161 nt.assert_equal(ip.user_ns['zz'], 3.5)
156 nt.assert_equal(ip.user_ns['zz'], 3.5)
162
157
163 def test_paste_echo(self):
158 def test_paste_echo(self):
164 "Also test self.paste echoing, by temporarily faking the writer"
159 "Also test self.paste echoing, by temporarily faking the writer"
165 w = StringIO()
160 w = StringIO()
166 writer = ip.write
161 writer = ip.write
167 ip.write = w.write
162 ip.write = w.write
168 code = """
163 code = """
169 a = 100
164 a = 100
170 b = 200"""
165 b = 200"""
171 try:
166 try:
172 self.paste(code,'')
167 self.paste(code,'')
173 out = w.getvalue()
168 out = w.getvalue()
174 finally:
169 finally:
175 ip.write = writer
170 ip.write = writer
176 nt.assert_equal(ip.user_ns['a'], 100)
171 nt.assert_equal(ip.user_ns['a'], 100)
177 nt.assert_equal(ip.user_ns['b'], 200)
172 nt.assert_equal(ip.user_ns['b'], 200)
178 nt.assert_equal(out, code+"\n## -- End pasted text --\n")
173 nt.assert_equal(out, code+"\n## -- End pasted text --\n")
179
174
180 def test_paste_leading_commas(self):
175 def test_paste_leading_commas(self):
181 "Test multiline strings with leading commas"
176 "Test multiline strings with leading commas"
182 tm = ip.magics_manager.registry['TerminalMagics']
177 tm = ip.magics_manager.registry['TerminalMagics']
183 s = '''\
178 s = '''\
184 a = """
179 a = """
185 ,1,2,3
180 ,1,2,3
186 """'''
181 """'''
187 ip.user_ns.pop('foo', None)
182 ip.user_ns.pop('foo', None)
188 tm.store_or_execute(s, 'foo')
183 tm.store_or_execute(s, 'foo')
189 nt.assert_in('foo', ip.user_ns)
184 nt.assert_in('foo', ip.user_ns)
190
185
191
186
192 def test_paste_trailing_question(self):
187 def test_paste_trailing_question(self):
193 "Test pasting sources with trailing question marks"
188 "Test pasting sources with trailing question marks"
194 tm = ip.magics_manager.registry['TerminalMagics']
189 tm = ip.magics_manager.registry['TerminalMagics']
195 s = '''\
190 s = '''\
196 def funcfoo():
191 def funcfoo():
197 if True: #am i true?
192 if True: #am i true?
198 return 'fooresult'
193 return 'fooresult'
199 '''
194 '''
200 ip.user_ns.pop('funcfoo', None)
195 ip.user_ns.pop('funcfoo', None)
201 self.paste(s)
196 self.paste(s)
202 nt.assert_equal(ip.user_ns['funcfoo'](), 'fooresult')
197 nt.assert_equal(ip.user_ns['funcfoo'](), 'fooresult')
@@ -1,419 +1,409 b''
1 """Tests for the object inspection functionality.
1 """Tests for the object inspection functionality.
2 """
2 """
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7
7
8 from inspect import Signature, Parameter
8 from inspect import Signature, Parameter
9 import os
9 import os
10 import re
10 import re
11 import sys
11 import sys
12
12
13 import nose.tools as nt
13 import nose.tools as nt
14
14
15 from .. import oinspect
15 from .. import oinspect
16 from IPython.core.magic import (Magics, magics_class, line_magic,
16 from IPython.core.magic import (Magics, magics_class, line_magic,
17 cell_magic, line_cell_magic,
17 cell_magic, line_cell_magic,
18 register_line_magic, register_cell_magic,
18 register_line_magic, register_cell_magic,
19 register_line_cell_magic)
19 register_line_cell_magic)
20 from decorator import decorator
20 from decorator import decorator
21 from IPython import get_ipython
21 from IPython import get_ipython
22 from IPython.testing.decorators import skipif
22 from IPython.testing.decorators import skipif
23 from IPython.testing.tools import AssertPrints, AssertNotPrints
23 from IPython.testing.tools import AssertPrints, AssertNotPrints
24 from IPython.utils.path import compress_user
24 from IPython.utils.path import compress_user
25 from IPython.utils import py3compat
25 from IPython.utils import py3compat
26
26
27
27
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
29 # Globals and constants
29 # Globals and constants
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31
31
32 inspector = oinspect.Inspector()
32 inspector = oinspect.Inspector()
33 ip = get_ipython()
33 ip = get_ipython()
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Local utilities
36 # Local utilities
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38
38
39 # WARNING: since this test checks the line number where a function is
39 # WARNING: since this test checks the line number where a function is
40 # defined, if any code is inserted above, the following line will need to be
40 # defined, if any code is inserted above, the following line will need to be
41 # updated. Do NOT insert any whitespace between the next line and the function
41 # updated. Do NOT insert any whitespace between the next line and the function
42 # definition below.
42 # definition below.
43 THIS_LINE_NUMBER = 43 # Put here the actual number of this line
43 THIS_LINE_NUMBER = 43 # Put here the actual number of this line
44
44
45 from unittest import TestCase
45 from unittest import TestCase
46
46
47 class Test(TestCase):
47 class Test(TestCase):
48
48
49 def test_find_source_lines(self):
49 def test_find_source_lines(self):
50 self.assertEqual(oinspect.find_source_lines(Test.test_find_source_lines),
50 self.assertEqual(oinspect.find_source_lines(Test.test_find_source_lines),
51 THIS_LINE_NUMBER+6)
51 THIS_LINE_NUMBER+6)
52
52
53
53
54 # A couple of utilities to ensure these tests work the same from a source or a
54 # A couple of utilities to ensure these tests work the same from a source or a
55 # binary install
55 # binary install
56 def pyfile(fname):
56 def pyfile(fname):
57 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
57 return os.path.normcase(re.sub('.py[co]$', '.py', fname))
58
58
59
59
60 def match_pyfiles(f1, f2):
60 def match_pyfiles(f1, f2):
61 nt.assert_equal(pyfile(f1), pyfile(f2))
61 nt.assert_equal(pyfile(f1), pyfile(f2))
62
62
63
63
64 def test_find_file():
64 def test_find_file():
65 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
65 match_pyfiles(oinspect.find_file(test_find_file), os.path.abspath(__file__))
66
66
67
67
68 def test_find_file_decorated1():
68 def test_find_file_decorated1():
69
69
70 @decorator
70 @decorator
71 def noop1(f):
71 def noop1(f):
72 def wrapper(*a, **kw):
72 def wrapper(*a, **kw):
73 return f(*a, **kw)
73 return f(*a, **kw)
74 return wrapper
74 return wrapper
75
75
76 @noop1
76 @noop1
77 def f(x):
77 def f(x):
78 "My docstring"
78 "My docstring"
79
79
80 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
80 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
81 nt.assert_equal(f.__doc__, "My docstring")
81 nt.assert_equal(f.__doc__, "My docstring")
82
82
83
83
84 def test_find_file_decorated2():
84 def test_find_file_decorated2():
85
85
86 @decorator
86 @decorator
87 def noop2(f, *a, **kw):
87 def noop2(f, *a, **kw):
88 return f(*a, **kw)
88 return f(*a, **kw)
89
89
90 @noop2
90 @noop2
91 @noop2
91 @noop2
92 @noop2
92 @noop2
93 def f(x):
93 def f(x):
94 "My docstring 2"
94 "My docstring 2"
95
95
96 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
96 match_pyfiles(oinspect.find_file(f), os.path.abspath(__file__))
97 nt.assert_equal(f.__doc__, "My docstring 2")
97 nt.assert_equal(f.__doc__, "My docstring 2")
98
98
99
99
100 def test_find_file_magic():
100 def test_find_file_magic():
101 run = ip.find_line_magic('run')
101 run = ip.find_line_magic('run')
102 nt.assert_not_equal(oinspect.find_file(run), None)
102 nt.assert_not_equal(oinspect.find_file(run), None)
103
103
104
104
105 # A few generic objects we can then inspect in the tests below
105 # A few generic objects we can then inspect in the tests below
106
106
107 class Call(object):
107 class Call(object):
108 """This is the class docstring."""
108 """This is the class docstring."""
109
109
110 def __init__(self, x, y=1):
110 def __init__(self, x, y=1):
111 """This is the constructor docstring."""
111 """This is the constructor docstring."""
112
112
113 def __call__(self, *a, **kw):
113 def __call__(self, *a, **kw):
114 """This is the call docstring."""
114 """This is the call docstring."""
115
115
116 def method(self, x, z=2):
116 def method(self, x, z=2):
117 """Some method's docstring"""
117 """Some method's docstring"""
118
118
119 class HasSignature(object):
119 class HasSignature(object):
120 """This is the class docstring."""
120 """This is the class docstring."""
121 __signature__ = Signature([Parameter('test', Parameter.POSITIONAL_OR_KEYWORD)])
121 __signature__ = Signature([Parameter('test', Parameter.POSITIONAL_OR_KEYWORD)])
122
122
123 def __init__(self, *args):
123 def __init__(self, *args):
124 """This is the init docstring"""
124 """This is the init docstring"""
125
125
126
126
127 class SimpleClass(object):
127 class SimpleClass(object):
128 def method(self, x, z=2):
128 def method(self, x, z=2):
129 """Some method's docstring"""
129 """Some method's docstring"""
130
130
131
131
132 class OldStyle:
132 class OldStyle:
133 """An old-style class for testing."""
133 """An old-style class for testing."""
134 pass
134 pass
135
135
136
136
137 def f(x, y=2, *a, **kw):
137 def f(x, y=2, *a, **kw):
138 """A simple function."""
138 """A simple function."""
139
139
140
140
141 def g(y, z=3, *a, **kw):
141 def g(y, z=3, *a, **kw):
142 pass # no docstring
142 pass # no docstring
143
143
144
144
145 @register_line_magic
145 @register_line_magic
146 def lmagic(line):
146 def lmagic(line):
147 "A line magic"
147 "A line magic"
148
148
149
149
150 @register_cell_magic
150 @register_cell_magic
151 def cmagic(line, cell):
151 def cmagic(line, cell):
152 "A cell magic"
152 "A cell magic"
153
153
154
154
155 @register_line_cell_magic
155 @register_line_cell_magic
156 def lcmagic(line, cell=None):
156 def lcmagic(line, cell=None):
157 "A line/cell magic"
157 "A line/cell magic"
158
158
159
159
160 @magics_class
160 @magics_class
161 class SimpleMagics(Magics):
161 class SimpleMagics(Magics):
162 @line_magic
162 @line_magic
163 def Clmagic(self, cline):
163 def Clmagic(self, cline):
164 "A class-based line magic"
164 "A class-based line magic"
165
165
166 @cell_magic
166 @cell_magic
167 def Ccmagic(self, cline, ccell):
167 def Ccmagic(self, cline, ccell):
168 "A class-based cell magic"
168 "A class-based cell magic"
169
169
170 @line_cell_magic
170 @line_cell_magic
171 def Clcmagic(self, cline, ccell=None):
171 def Clcmagic(self, cline, ccell=None):
172 "A class-based line/cell magic"
172 "A class-based line/cell magic"
173
173
174
174
175 class Awkward(object):
175 class Awkward(object):
176 def __getattr__(self, name):
176 def __getattr__(self, name):
177 raise Exception(name)
177 raise Exception(name)
178
178
179 class NoBoolCall:
179 class NoBoolCall:
180 """
180 """
181 callable with `__bool__` raising should still be inspect-able.
181 callable with `__bool__` raising should still be inspect-able.
182 """
182 """
183
183
184 def __call__(self):
184 def __call__(self):
185 """does nothing"""
185 """does nothing"""
186 pass
186 pass
187
187
188 def __bool__(self):
188 def __bool__(self):
189 """just raise NotImplemented"""
189 """just raise NotImplemented"""
190 raise NotImplementedError('Must be implemented')
190 raise NotImplementedError('Must be implemented')
191
191
192
192
193 class SerialLiar(object):
193 class SerialLiar(object):
194 """Attribute accesses always get another copy of the same class.
194 """Attribute accesses always get another copy of the same class.
195
195
196 unittest.mock.call does something similar, but it's not ideal for testing
196 unittest.mock.call does something similar, but it's not ideal for testing
197 as the failure mode is to eat all your RAM. This gives up after 10k levels.
197 as the failure mode is to eat all your RAM. This gives up after 10k levels.
198 """
198 """
199 def __init__(self, max_fibbing_twig, lies_told=0):
199 def __init__(self, max_fibbing_twig, lies_told=0):
200 if lies_told > 10000:
200 if lies_told > 10000:
201 raise RuntimeError('Nose too long, honesty is the best policy')
201 raise RuntimeError('Nose too long, honesty is the best policy')
202 self.max_fibbing_twig = max_fibbing_twig
202 self.max_fibbing_twig = max_fibbing_twig
203 self.lies_told = lies_told
203 self.lies_told = lies_told
204 max_fibbing_twig[0] = max(max_fibbing_twig[0], lies_told)
204 max_fibbing_twig[0] = max(max_fibbing_twig[0], lies_told)
205
205
206 def __getattr__(self, item):
206 def __getattr__(self, item):
207 return SerialLiar(self.max_fibbing_twig, self.lies_told + 1)
207 return SerialLiar(self.max_fibbing_twig, self.lies_told + 1)
208
208
209 #-----------------------------------------------------------------------------
209 #-----------------------------------------------------------------------------
210 # Tests
210 # Tests
211 #-----------------------------------------------------------------------------
211 #-----------------------------------------------------------------------------
212
212
213 def test_info():
213 def test_info():
214 "Check that Inspector.info fills out various fields as expected."
214 "Check that Inspector.info fills out various fields as expected."
215 i = inspector.info(Call, oname='Call')
215 i = inspector.info(Call, oname='Call')
216 nt.assert_equal(i['type_name'], 'type')
216 nt.assert_equal(i['type_name'], 'type')
217 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
217 expted_class = str(type(type)) # <class 'type'> (Python 3) or <type 'type'>
218 nt.assert_equal(i['base_class'], expted_class)
218 nt.assert_equal(i['base_class'], expted_class)
219 nt.assert_regex(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'( at 0x[0-9a-f]{1,9})?>")
219 nt.assert_regex(i['string_form'], "<class 'IPython.core.tests.test_oinspect.Call'( at 0x[0-9a-f]{1,9})?>")
220 fname = __file__
220 fname = __file__
221 if fname.endswith(".pyc"):
221 if fname.endswith(".pyc"):
222 fname = fname[:-1]
222 fname = fname[:-1]
223 # case-insensitive comparison needed on some filesystems
223 # case-insensitive comparison needed on some filesystems
224 # e.g. Windows:
224 # e.g. Windows:
225 nt.assert_equal(i['file'].lower(), compress_user(fname).lower())
225 nt.assert_equal(i['file'].lower(), compress_user(fname).lower())
226 nt.assert_equal(i['definition'], None)
226 nt.assert_equal(i['definition'], None)
227 nt.assert_equal(i['docstring'], Call.__doc__)
227 nt.assert_equal(i['docstring'], Call.__doc__)
228 nt.assert_equal(i['source'], None)
228 nt.assert_equal(i['source'], None)
229 nt.assert_true(i['isclass'])
229 nt.assert_true(i['isclass'])
230 _self_py2 = '' if py3compat.PY3 else 'self, '
230 nt.assert_equal(i['init_definition'], "Call(x, y=1)")
231 nt.assert_equal(i['init_definition'], "Call(%sx, y=1)" % _self_py2)
232 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
231 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
233
232
234 i = inspector.info(Call, detail_level=1)
233 i = inspector.info(Call, detail_level=1)
235 nt.assert_not_equal(i['source'], None)
234 nt.assert_not_equal(i['source'], None)
236 nt.assert_equal(i['docstring'], None)
235 nt.assert_equal(i['docstring'], None)
237
236
238 c = Call(1)
237 c = Call(1)
239 c.__doc__ = "Modified instance docstring"
238 c.__doc__ = "Modified instance docstring"
240 i = inspector.info(c)
239 i = inspector.info(c)
241 nt.assert_equal(i['type_name'], 'Call')
240 nt.assert_equal(i['type_name'], 'Call')
242 nt.assert_equal(i['docstring'], "Modified instance docstring")
241 nt.assert_equal(i['docstring'], "Modified instance docstring")
243 nt.assert_equal(i['class_docstring'], Call.__doc__)
242 nt.assert_equal(i['class_docstring'], Call.__doc__)
244 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
243 nt.assert_equal(i['init_docstring'], Call.__init__.__doc__)
245 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
244 nt.assert_equal(i['call_docstring'], Call.__call__.__doc__)
246
245
247 # Test old-style classes, which for example may not have an __init__ method.
248 if not py3compat.PY3:
249 i = inspector.info(OldStyle)
250 nt.assert_equal(i['type_name'], 'classobj')
251
252 i = inspector.info(OldStyle())
253 nt.assert_equal(i['type_name'], 'instance')
254 nt.assert_equal(i['docstring'], OldStyle.__doc__)
255
256 def test_class_signature():
246 def test_class_signature():
257 info = inspector.info(HasSignature, 'HasSignature')
247 info = inspector.info(HasSignature, 'HasSignature')
258 nt.assert_equal(info['init_definition'], "HasSignature(test)")
248 nt.assert_equal(info['init_definition'], "HasSignature(test)")
259 nt.assert_equal(info['init_docstring'], HasSignature.__init__.__doc__)
249 nt.assert_equal(info['init_docstring'], HasSignature.__init__.__doc__)
260
250
261 def test_info_awkward():
251 def test_info_awkward():
262 # Just test that this doesn't throw an error.
252 # Just test that this doesn't throw an error.
263 inspector.info(Awkward())
253 inspector.info(Awkward())
264
254
265 def test_bool_raise():
255 def test_bool_raise():
266 inspector.info(NoBoolCall())
256 inspector.info(NoBoolCall())
267
257
268 def test_info_serialliar():
258 def test_info_serialliar():
269 fib_tracker = [0]
259 fib_tracker = [0]
270 inspector.info(SerialLiar(fib_tracker))
260 inspector.info(SerialLiar(fib_tracker))
271
261
272 # Nested attribute access should be cut off at 100 levels deep to avoid
262 # Nested attribute access should be cut off at 100 levels deep to avoid
273 # infinite loops: https://github.com/ipython/ipython/issues/9122
263 # infinite loops: https://github.com/ipython/ipython/issues/9122
274 nt.assert_less(fib_tracker[0], 9000)
264 nt.assert_less(fib_tracker[0], 9000)
275
265
276 def test_calldef_none():
266 def test_calldef_none():
277 # We should ignore __call__ for all of these.
267 # We should ignore __call__ for all of these.
278 for obj in [f, SimpleClass().method, any, str.upper]:
268 for obj in [f, SimpleClass().method, any, str.upper]:
279 print(obj)
269 print(obj)
280 i = inspector.info(obj)
270 i = inspector.info(obj)
281 nt.assert_is(i['call_def'], None)
271 nt.assert_is(i['call_def'], None)
282
272
283 def f_kwarg(pos, *, kwonly):
273 def f_kwarg(pos, *, kwonly):
284 pass
274 pass
285
275
286 def test_definition_kwonlyargs():
276 def test_definition_kwonlyargs():
287 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
277 i = inspector.info(f_kwarg, oname='f_kwarg') # analysis:ignore
288 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)")
278 nt.assert_equal(i['definition'], "f_kwarg(pos, *, kwonly)")
289
279
290 def test_getdoc():
280 def test_getdoc():
291 class A(object):
281 class A(object):
292 """standard docstring"""
282 """standard docstring"""
293 pass
283 pass
294
284
295 class B(object):
285 class B(object):
296 """standard docstring"""
286 """standard docstring"""
297 def getdoc(self):
287 def getdoc(self):
298 return "custom docstring"
288 return "custom docstring"
299
289
300 class C(object):
290 class C(object):
301 """standard docstring"""
291 """standard docstring"""
302 def getdoc(self):
292 def getdoc(self):
303 return None
293 return None
304
294
305 a = A()
295 a = A()
306 b = B()
296 b = B()
307 c = C()
297 c = C()
308
298
309 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
299 nt.assert_equal(oinspect.getdoc(a), "standard docstring")
310 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
300 nt.assert_equal(oinspect.getdoc(b), "custom docstring")
311 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
301 nt.assert_equal(oinspect.getdoc(c), "standard docstring")
312
302
313
303
314 def test_empty_property_has_no_source():
304 def test_empty_property_has_no_source():
315 i = inspector.info(property(), detail_level=1)
305 i = inspector.info(property(), detail_level=1)
316 nt.assert_is(i['source'], None)
306 nt.assert_is(i['source'], None)
317
307
318
308
319 def test_property_sources():
309 def test_property_sources():
320 import zlib
310 import zlib
321
311
322 class A(object):
312 class A(object):
323 @property
313 @property
324 def foo(self):
314 def foo(self):
325 return 'bar'
315 return 'bar'
326
316
327 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
317 foo = foo.setter(lambda self, v: setattr(self, 'bar', v))
328
318
329 id = property(id)
319 id = property(id)
330 compress = property(zlib.compress)
320 compress = property(zlib.compress)
331
321
332 i = inspector.info(A.foo, detail_level=1)
322 i = inspector.info(A.foo, detail_level=1)
333 nt.assert_in('def foo(self):', i['source'])
323 nt.assert_in('def foo(self):', i['source'])
334 nt.assert_in('lambda self, v:', i['source'])
324 nt.assert_in('lambda self, v:', i['source'])
335
325
336 i = inspector.info(A.id, detail_level=1)
326 i = inspector.info(A.id, detail_level=1)
337 nt.assert_in('fget = <function id>', i['source'])
327 nt.assert_in('fget = <function id>', i['source'])
338
328
339 i = inspector.info(A.compress, detail_level=1)
329 i = inspector.info(A.compress, detail_level=1)
340 nt.assert_in('fget = <function zlib.compress>', i['source'])
330 nt.assert_in('fget = <function zlib.compress>', i['source'])
341
331
342
332
343 def test_property_docstring_is_in_info_for_detail_level_0():
333 def test_property_docstring_is_in_info_for_detail_level_0():
344 class A(object):
334 class A(object):
345 @property
335 @property
346 def foobar(self):
336 def foobar(self):
347 """This is `foobar` property."""
337 """This is `foobar` property."""
348 pass
338 pass
349
339
350 ip.user_ns['a_obj'] = A()
340 ip.user_ns['a_obj'] = A()
351 nt.assert_equal(
341 nt.assert_equal(
352 'This is `foobar` property.',
342 'This is `foobar` property.',
353 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
343 ip.object_inspect('a_obj.foobar', detail_level=0)['docstring'])
354
344
355 ip.user_ns['a_cls'] = A
345 ip.user_ns['a_cls'] = A
356 nt.assert_equal(
346 nt.assert_equal(
357 'This is `foobar` property.',
347 'This is `foobar` property.',
358 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
348 ip.object_inspect('a_cls.foobar', detail_level=0)['docstring'])
359
349
360
350
361 def test_pdef():
351 def test_pdef():
362 # See gh-1914
352 # See gh-1914
363 def foo(): pass
353 def foo(): pass
364 inspector.pdef(foo, 'foo')
354 inspector.pdef(foo, 'foo')
365
355
366
356
367 def test_pinfo_nonascii():
357 def test_pinfo_nonascii():
368 # See gh-1177
358 # See gh-1177
369 from . import nonascii2
359 from . import nonascii2
370 ip.user_ns['nonascii2'] = nonascii2
360 ip.user_ns['nonascii2'] = nonascii2
371 ip._inspect('pinfo', 'nonascii2', detail_level=1)
361 ip._inspect('pinfo', 'nonascii2', detail_level=1)
372
362
373
363
374 def test_pinfo_docstring_no_source():
364 def test_pinfo_docstring_no_source():
375 """Docstring should be included with detail_level=1 if there is no source"""
365 """Docstring should be included with detail_level=1 if there is no source"""
376 with AssertPrints('Docstring:'):
366 with AssertPrints('Docstring:'):
377 ip._inspect('pinfo', 'str.format', detail_level=0)
367 ip._inspect('pinfo', 'str.format', detail_level=0)
378 with AssertPrints('Docstring:'):
368 with AssertPrints('Docstring:'):
379 ip._inspect('pinfo', 'str.format', detail_level=1)
369 ip._inspect('pinfo', 'str.format', detail_level=1)
380
370
381
371
382 def test_pinfo_no_docstring_if_source():
372 def test_pinfo_no_docstring_if_source():
383 """Docstring should not be included with detail_level=1 if source is found"""
373 """Docstring should not be included with detail_level=1 if source is found"""
384 def foo():
374 def foo():
385 """foo has a docstring"""
375 """foo has a docstring"""
386
376
387 ip.user_ns['foo'] = foo
377 ip.user_ns['foo'] = foo
388
378
389 with AssertPrints('Docstring:'):
379 with AssertPrints('Docstring:'):
390 ip._inspect('pinfo', 'foo', detail_level=0)
380 ip._inspect('pinfo', 'foo', detail_level=0)
391 with AssertPrints('Source:'):
381 with AssertPrints('Source:'):
392 ip._inspect('pinfo', 'foo', detail_level=1)
382 ip._inspect('pinfo', 'foo', detail_level=1)
393 with AssertNotPrints('Docstring:'):
383 with AssertNotPrints('Docstring:'):
394 ip._inspect('pinfo', 'foo', detail_level=1)
384 ip._inspect('pinfo', 'foo', detail_level=1)
395
385
396
386
397 def test_pinfo_magic():
387 def test_pinfo_magic():
398 with AssertPrints('Docstring:'):
388 with AssertPrints('Docstring:'):
399 ip._inspect('pinfo', 'lsmagic', detail_level=0)
389 ip._inspect('pinfo', 'lsmagic', detail_level=0)
400
390
401 with AssertPrints('Source:'):
391 with AssertPrints('Source:'):
402 ip._inspect('pinfo', 'lsmagic', detail_level=1)
392 ip._inspect('pinfo', 'lsmagic', detail_level=1)
403
393
404
394
405 def test_init_colors():
395 def test_init_colors():
406 # ensure colors are not present in signature info
396 # ensure colors are not present in signature info
407 info = inspector.info(HasSignature)
397 info = inspector.info(HasSignature)
408 init_def = info['init_definition']
398 init_def = info['init_definition']
409 nt.assert_not_in('[0m', init_def)
399 nt.assert_not_in('[0m', init_def)
410
400
411
401
412 def test_builtin_init():
402 def test_builtin_init():
413 info = inspector.info(list)
403 info = inspector.info(list)
414 init_def = info['init_definition']
404 init_def = info['init_definition']
415 # Python < 3.4 can't get init definition from builtins,
405 # Python < 3.4 can't get init definition from builtins,
416 # but still exercise the inspection in case of error-raising bugs.
406 # but still exercise the inspection in case of error-raising bugs.
417 if sys.version_info >= (3,4):
407 if sys.version_info >= (3,4):
418 nt.assert_is_not_none(init_def)
408 nt.assert_is_not_none(init_def)
419
409
@@ -1,351 +1,347 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.core.ultratb
2 """Tests for IPython.core.ultratb
3 """
3 """
4 import io
4 import io
5 import sys
5 import sys
6 import os.path
6 import os.path
7 from textwrap import dedent
7 from textwrap import dedent
8 import traceback
8 import traceback
9 import unittest
9 import unittest
10 from unittest import mock
10 from unittest import mock
11
11
12 from ..ultratb import ColorTB, VerboseTB, find_recursion
12 from ..ultratb import ColorTB, VerboseTB, find_recursion
13
13
14
14
15 from IPython.testing import tools as tt
15 from IPython.testing import tools as tt
16 from IPython.testing.decorators import onlyif_unicode_paths
16 from IPython.testing.decorators import onlyif_unicode_paths
17 from IPython.utils.syspathcontext import prepended_to_syspath
17 from IPython.utils.syspathcontext import prepended_to_syspath
18 from IPython.utils.tempdir import TemporaryDirectory
18 from IPython.utils.tempdir import TemporaryDirectory
19 from IPython.utils.py3compat import PY3
20
19
21 ip = get_ipython()
20 ip = get_ipython()
22
21
23 file_1 = """1
22 file_1 = """1
24 2
23 2
25 3
24 3
26 def f():
25 def f():
27 1/0
26 1/0
28 """
27 """
29
28
30 file_2 = """def f():
29 file_2 = """def f():
31 1/0
30 1/0
32 """
31 """
33
32
34 class ChangedPyFileTest(unittest.TestCase):
33 class ChangedPyFileTest(unittest.TestCase):
35 def test_changing_py_file(self):
34 def test_changing_py_file(self):
36 """Traceback produced if the line where the error occurred is missing?
35 """Traceback produced if the line where the error occurred is missing?
37
36
38 https://github.com/ipython/ipython/issues/1456
37 https://github.com/ipython/ipython/issues/1456
39 """
38 """
40 with TemporaryDirectory() as td:
39 with TemporaryDirectory() as td:
41 fname = os.path.join(td, "foo.py")
40 fname = os.path.join(td, "foo.py")
42 with open(fname, "w") as f:
41 with open(fname, "w") as f:
43 f.write(file_1)
42 f.write(file_1)
44
43
45 with prepended_to_syspath(td):
44 with prepended_to_syspath(td):
46 ip.run_cell("import foo")
45 ip.run_cell("import foo")
47
46
48 with tt.AssertPrints("ZeroDivisionError"):
47 with tt.AssertPrints("ZeroDivisionError"):
49 ip.run_cell("foo.f()")
48 ip.run_cell("foo.f()")
50
49
51 # Make the file shorter, so the line of the error is missing.
50 # Make the file shorter, so the line of the error is missing.
52 with open(fname, "w") as f:
51 with open(fname, "w") as f:
53 f.write(file_2)
52 f.write(file_2)
54
53
55 # For some reason, this was failing on the *second* call after
54 # For some reason, this was failing on the *second* call after
56 # changing the file, so we call f() twice.
55 # changing the file, so we call f() twice.
57 with tt.AssertNotPrints("Internal Python error", channel='stderr'):
56 with tt.AssertNotPrints("Internal Python error", channel='stderr'):
58 with tt.AssertPrints("ZeroDivisionError"):
57 with tt.AssertPrints("ZeroDivisionError"):
59 ip.run_cell("foo.f()")
58 ip.run_cell("foo.f()")
60 with tt.AssertPrints("ZeroDivisionError"):
59 with tt.AssertPrints("ZeroDivisionError"):
61 ip.run_cell("foo.f()")
60 ip.run_cell("foo.f()")
62
61
63 iso_8859_5_file = u'''# coding: iso-8859-5
62 iso_8859_5_file = u'''# coding: iso-8859-5
64
63
65 def fail():
64 def fail():
66 """Π΄Π±Π˜Π–"""
65 """Π΄Π±Π˜Π–"""
67 1/0 # Π΄Π±Π˜Π–
66 1/0 # Π΄Π±Π˜Π–
68 '''
67 '''
69
68
70 class NonAsciiTest(unittest.TestCase):
69 class NonAsciiTest(unittest.TestCase):
71 @onlyif_unicode_paths
70 @onlyif_unicode_paths
72 def test_nonascii_path(self):
71 def test_nonascii_path(self):
73 # Non-ascii directory name as well.
72 # Non-ascii directory name as well.
74 with TemporaryDirectory(suffix=u'Γ©') as td:
73 with TemporaryDirectory(suffix=u'Γ©') as td:
75 fname = os.path.join(td, u"fooΓ©.py")
74 fname = os.path.join(td, u"fooΓ©.py")
76 with open(fname, "w") as f:
75 with open(fname, "w") as f:
77 f.write(file_1)
76 f.write(file_1)
78
77
79 with prepended_to_syspath(td):
78 with prepended_to_syspath(td):
80 ip.run_cell("import foo")
79 ip.run_cell("import foo")
81
80
82 with tt.AssertPrints("ZeroDivisionError"):
81 with tt.AssertPrints("ZeroDivisionError"):
83 ip.run_cell("foo.f()")
82 ip.run_cell("foo.f()")
84
83
85 def test_iso8859_5(self):
84 def test_iso8859_5(self):
86 with TemporaryDirectory() as td:
85 with TemporaryDirectory() as td:
87 fname = os.path.join(td, 'dfghjkl.py')
86 fname = os.path.join(td, 'dfghjkl.py')
88
87
89 with io.open(fname, 'w', encoding='iso-8859-5') as f:
88 with io.open(fname, 'w', encoding='iso-8859-5') as f:
90 f.write(iso_8859_5_file)
89 f.write(iso_8859_5_file)
91
90
92 with prepended_to_syspath(td):
91 with prepended_to_syspath(td):
93 ip.run_cell("from dfghjkl import fail")
92 ip.run_cell("from dfghjkl import fail")
94
93
95 with tt.AssertPrints("ZeroDivisionError"):
94 with tt.AssertPrints("ZeroDivisionError"):
96 with tt.AssertPrints(u'Π΄Π±Π˜Π–', suppress=False):
95 with tt.AssertPrints(u'Π΄Π±Π˜Π–', suppress=False):
97 ip.run_cell('fail()')
96 ip.run_cell('fail()')
98
97
99 def test_nonascii_msg(self):
98 def test_nonascii_msg(self):
100 cell = u"raise Exception('Γ©')"
99 cell = u"raise Exception('Γ©')"
101 expected = u"Exception('Γ©')"
100 expected = u"Exception('Γ©')"
102 ip.run_cell("%xmode plain")
101 ip.run_cell("%xmode plain")
103 with tt.AssertPrints(expected):
102 with tt.AssertPrints(expected):
104 ip.run_cell(cell)
103 ip.run_cell(cell)
105
104
106 ip.run_cell("%xmode verbose")
105 ip.run_cell("%xmode verbose")
107 with tt.AssertPrints(expected):
106 with tt.AssertPrints(expected):
108 ip.run_cell(cell)
107 ip.run_cell(cell)
109
108
110 ip.run_cell("%xmode context")
109 ip.run_cell("%xmode context")
111 with tt.AssertPrints(expected):
110 with tt.AssertPrints(expected):
112 ip.run_cell(cell)
111 ip.run_cell(cell)
113
112
114
113
115 class NestedGenExprTestCase(unittest.TestCase):
114 class NestedGenExprTestCase(unittest.TestCase):
116 """
115 """
117 Regression test for the following issues:
116 Regression test for the following issues:
118 https://github.com/ipython/ipython/issues/8293
117 https://github.com/ipython/ipython/issues/8293
119 https://github.com/ipython/ipython/issues/8205
118 https://github.com/ipython/ipython/issues/8205
120 """
119 """
121 def test_nested_genexpr(self):
120 def test_nested_genexpr(self):
122 code = dedent(
121 code = dedent(
123 """\
122 """\
124 class SpecificException(Exception):
123 class SpecificException(Exception):
125 pass
124 pass
126
125
127 def foo(x):
126 def foo(x):
128 raise SpecificException("Success!")
127 raise SpecificException("Success!")
129
128
130 sum(sum(foo(x) for _ in [0]) for x in [0])
129 sum(sum(foo(x) for _ in [0]) for x in [0])
131 """
130 """
132 )
131 )
133 with tt.AssertPrints('SpecificException: Success!', suppress=False):
132 with tt.AssertPrints('SpecificException: Success!', suppress=False):
134 ip.run_cell(code)
133 ip.run_cell(code)
135
134
136
135
137 indentationerror_file = """if True:
136 indentationerror_file = """if True:
138 zoon()
137 zoon()
139 """
138 """
140
139
141 class IndentationErrorTest(unittest.TestCase):
140 class IndentationErrorTest(unittest.TestCase):
142 def test_indentationerror_shows_line(self):
141 def test_indentationerror_shows_line(self):
143 # See issue gh-2398
142 # See issue gh-2398
144 with tt.AssertPrints("IndentationError"):
143 with tt.AssertPrints("IndentationError"):
145 with tt.AssertPrints("zoon()", suppress=False):
144 with tt.AssertPrints("zoon()", suppress=False):
146 ip.run_cell(indentationerror_file)
145 ip.run_cell(indentationerror_file)
147
146
148 with TemporaryDirectory() as td:
147 with TemporaryDirectory() as td:
149 fname = os.path.join(td, "foo.py")
148 fname = os.path.join(td, "foo.py")
150 with open(fname, "w") as f:
149 with open(fname, "w") as f:
151 f.write(indentationerror_file)
150 f.write(indentationerror_file)
152
151
153 with tt.AssertPrints("IndentationError"):
152 with tt.AssertPrints("IndentationError"):
154 with tt.AssertPrints("zoon()", suppress=False):
153 with tt.AssertPrints("zoon()", suppress=False):
155 ip.magic('run %s' % fname)
154 ip.magic('run %s' % fname)
156
155
157 se_file_1 = """1
156 se_file_1 = """1
158 2
157 2
159 7/
158 7/
160 """
159 """
161
160
162 se_file_2 = """7/
161 se_file_2 = """7/
163 """
162 """
164
163
165 class SyntaxErrorTest(unittest.TestCase):
164 class SyntaxErrorTest(unittest.TestCase):
166 def test_syntaxerror_without_lineno(self):
165 def test_syntaxerror_without_lineno(self):
167 with tt.AssertNotPrints("TypeError"):
166 with tt.AssertNotPrints("TypeError"):
168 with tt.AssertPrints("line unknown"):
167 with tt.AssertPrints("line unknown"):
169 ip.run_cell("raise SyntaxError()")
168 ip.run_cell("raise SyntaxError()")
170
169
171 def test_changing_py_file(self):
170 def test_changing_py_file(self):
172 with TemporaryDirectory() as td:
171 with TemporaryDirectory() as td:
173 fname = os.path.join(td, "foo.py")
172 fname = os.path.join(td, "foo.py")
174 with open(fname, 'w') as f:
173 with open(fname, 'w') as f:
175 f.write(se_file_1)
174 f.write(se_file_1)
176
175
177 with tt.AssertPrints(["7/", "SyntaxError"]):
176 with tt.AssertPrints(["7/", "SyntaxError"]):
178 ip.magic("run " + fname)
177 ip.magic("run " + fname)
179
178
180 # Modify the file
179 # Modify the file
181 with open(fname, 'w') as f:
180 with open(fname, 'w') as f:
182 f.write(se_file_2)
181 f.write(se_file_2)
183
182
184 # The SyntaxError should point to the correct line
183 # The SyntaxError should point to the correct line
185 with tt.AssertPrints(["7/", "SyntaxError"]):
184 with tt.AssertPrints(["7/", "SyntaxError"]):
186 ip.magic("run " + fname)
185 ip.magic("run " + fname)
187
186
188 def test_non_syntaxerror(self):
187 def test_non_syntaxerror(self):
189 # SyntaxTB may be called with an error other than a SyntaxError
188 # SyntaxTB may be called with an error other than a SyntaxError
190 # See e.g. gh-4361
189 # See e.g. gh-4361
191 try:
190 try:
192 raise ValueError('QWERTY')
191 raise ValueError('QWERTY')
193 except ValueError:
192 except ValueError:
194 with tt.AssertPrints('QWERTY'):
193 with tt.AssertPrints('QWERTY'):
195 ip.showsyntaxerror()
194 ip.showsyntaxerror()
196
195
197
196
198 class Python3ChainedExceptionsTest(unittest.TestCase):
197 class Python3ChainedExceptionsTest(unittest.TestCase):
199 DIRECT_CAUSE_ERROR_CODE = """
198 DIRECT_CAUSE_ERROR_CODE = """
200 try:
199 try:
201 x = 1 + 2
200 x = 1 + 2
202 print(not_defined_here)
201 print(not_defined_here)
203 except Exception as e:
202 except Exception as e:
204 x += 55
203 x += 55
205 x - 1
204 x - 1
206 y = {}
205 y = {}
207 raise KeyError('uh') from e
206 raise KeyError('uh') from e
208 """
207 """
209
208
210 EXCEPTION_DURING_HANDLING_CODE = """
209 EXCEPTION_DURING_HANDLING_CODE = """
211 try:
210 try:
212 x = 1 + 2
211 x = 1 + 2
213 print(not_defined_here)
212 print(not_defined_here)
214 except Exception as e:
213 except Exception as e:
215 x += 55
214 x += 55
216 x - 1
215 x - 1
217 y = {}
216 y = {}
218 raise KeyError('uh')
217 raise KeyError('uh')
219 """
218 """
220
219
221 SUPPRESS_CHAINING_CODE = """
220 SUPPRESS_CHAINING_CODE = """
222 try:
221 try:
223 1/0
222 1/0
224 except Exception:
223 except Exception:
225 raise ValueError("Yikes") from None
224 raise ValueError("Yikes") from None
226 """
225 """
227
226
228 def test_direct_cause_error(self):
227 def test_direct_cause_error(self):
229 if PY3:
228 with tt.AssertPrints(["KeyError", "NameError", "direct cause"]):
230 with tt.AssertPrints(["KeyError", "NameError", "direct cause"]):
229 ip.run_cell(self.DIRECT_CAUSE_ERROR_CODE)
231 ip.run_cell(self.DIRECT_CAUSE_ERROR_CODE)
232
230
233 def test_exception_during_handling_error(self):
231 def test_exception_during_handling_error(self):
234 if PY3:
232 with tt.AssertPrints(["KeyError", "NameError", "During handling"]):
235 with tt.AssertPrints(["KeyError", "NameError", "During handling"]):
233 ip.run_cell(self.EXCEPTION_DURING_HANDLING_CODE)
236 ip.run_cell(self.EXCEPTION_DURING_HANDLING_CODE)
237
234
238 def test_suppress_exception_chaining(self):
235 def test_suppress_exception_chaining(self):
239 if PY3:
236 with tt.AssertNotPrints("ZeroDivisionError"), \
240 with tt.AssertNotPrints("ZeroDivisionError"), \
237 tt.AssertPrints("ValueError", suppress=False):
241 tt.AssertPrints("ValueError", suppress=False):
238 ip.run_cell(self.SUPPRESS_CHAINING_CODE)
242 ip.run_cell(self.SUPPRESS_CHAINING_CODE)
243
239
244
240
245 class RecursionTest(unittest.TestCase):
241 class RecursionTest(unittest.TestCase):
246 DEFINITIONS = """
242 DEFINITIONS = """
247 def non_recurs():
243 def non_recurs():
248 1/0
244 1/0
249
245
250 def r1():
246 def r1():
251 r1()
247 r1()
252
248
253 def r3a():
249 def r3a():
254 r3b()
250 r3b()
255
251
256 def r3b():
252 def r3b():
257 r3c()
253 r3c()
258
254
259 def r3c():
255 def r3c():
260 r3a()
256 r3a()
261
257
262 def r3o1():
258 def r3o1():
263 r3a()
259 r3a()
264
260
265 def r3o2():
261 def r3o2():
266 r3o1()
262 r3o1()
267 """
263 """
268 def setUp(self):
264 def setUp(self):
269 ip.run_cell(self.DEFINITIONS)
265 ip.run_cell(self.DEFINITIONS)
270
266
271 def test_no_recursion(self):
267 def test_no_recursion(self):
272 with tt.AssertNotPrints("frames repeated"):
268 with tt.AssertNotPrints("frames repeated"):
273 ip.run_cell("non_recurs()")
269 ip.run_cell("non_recurs()")
274
270
275 def test_recursion_one_frame(self):
271 def test_recursion_one_frame(self):
276 with tt.AssertPrints("1 frames repeated"):
272 with tt.AssertPrints("1 frames repeated"):
277 ip.run_cell("r1()")
273 ip.run_cell("r1()")
278
274
279 def test_recursion_three_frames(self):
275 def test_recursion_three_frames(self):
280 with tt.AssertPrints("3 frames repeated"):
276 with tt.AssertPrints("3 frames repeated"):
281 ip.run_cell("r3o2()")
277 ip.run_cell("r3o2()")
282
278
283 def test_find_recursion(self):
279 def test_find_recursion(self):
284 captured = []
280 captured = []
285 def capture_exc(*args, **kwargs):
281 def capture_exc(*args, **kwargs):
286 captured.append(sys.exc_info())
282 captured.append(sys.exc_info())
287 with mock.patch.object(ip, 'showtraceback', capture_exc):
283 with mock.patch.object(ip, 'showtraceback', capture_exc):
288 ip.run_cell("r3o2()")
284 ip.run_cell("r3o2()")
289
285
290 self.assertEqual(len(captured), 1)
286 self.assertEqual(len(captured), 1)
291 etype, evalue, tb = captured[0]
287 etype, evalue, tb = captured[0]
292 self.assertIn("recursion", str(evalue))
288 self.assertIn("recursion", str(evalue))
293
289
294 records = ip.InteractiveTB.get_records(tb, 3, ip.InteractiveTB.tb_offset)
290 records = ip.InteractiveTB.get_records(tb, 3, ip.InteractiveTB.tb_offset)
295 for r in records[:10]:
291 for r in records[:10]:
296 print(r[1:4])
292 print(r[1:4])
297
293
298 # The outermost frames should be:
294 # The outermost frames should be:
299 # 0: the 'cell' that was running when the exception came up
295 # 0: the 'cell' that was running when the exception came up
300 # 1: r3o2()
296 # 1: r3o2()
301 # 2: r3o1()
297 # 2: r3o1()
302 # 3: r3a()
298 # 3: r3a()
303 # Then repeating r3b, r3c, r3a
299 # Then repeating r3b, r3c, r3a
304 last_unique, repeat_length = find_recursion(etype, evalue, records)
300 last_unique, repeat_length = find_recursion(etype, evalue, records)
305 self.assertEqual(last_unique, 2)
301 self.assertEqual(last_unique, 2)
306 self.assertEqual(repeat_length, 3)
302 self.assertEqual(repeat_length, 3)
307
303
308
304
309 #----------------------------------------------------------------------------
305 #----------------------------------------------------------------------------
310
306
311 # module testing (minimal)
307 # module testing (minimal)
312 def test_handlers():
308 def test_handlers():
313 def spam(c, d_e):
309 def spam(c, d_e):
314 (d, e) = d_e
310 (d, e) = d_e
315 x = c + d
311 x = c + d
316 y = c * d
312 y = c * d
317 foo(x, y)
313 foo(x, y)
318
314
319 def foo(a, b, bar=1):
315 def foo(a, b, bar=1):
320 eggs(a, b + bar)
316 eggs(a, b + bar)
321
317
322 def eggs(f, g, z=globals()):
318 def eggs(f, g, z=globals()):
323 h = f + g
319 h = f + g
324 i = f - g
320 i = f - g
325 return h / i
321 return h / i
326
322
327 buff = io.StringIO()
323 buff = io.StringIO()
328
324
329 buff.write('')
325 buff.write('')
330 buff.write('*** Before ***')
326 buff.write('*** Before ***')
331 try:
327 try:
332 buff.write(spam(1, (2, 3)))
328 buff.write(spam(1, (2, 3)))
333 except:
329 except:
334 traceback.print_exc(file=buff)
330 traceback.print_exc(file=buff)
335
331
336 handler = ColorTB(ostream=buff)
332 handler = ColorTB(ostream=buff)
337 buff.write('*** ColorTB ***')
333 buff.write('*** ColorTB ***')
338 try:
334 try:
339 buff.write(spam(1, (2, 3)))
335 buff.write(spam(1, (2, 3)))
340 except:
336 except:
341 handler(*sys.exc_info())
337 handler(*sys.exc_info())
342 buff.write('')
338 buff.write('')
343
339
344 handler = VerboseTB(ostream=buff)
340 handler = VerboseTB(ostream=buff)
345 buff.write('*** VerboseTB ***')
341 buff.write('*** VerboseTB ***')
346 try:
342 try:
347 buff.write(spam(1, (2, 3)))
343 buff.write(spam(1, (2, 3)))
348 except:
344 except:
349 handler(*sys.exc_info())
345 handler(*sys.exc_info())
350 buff.write('')
346 buff.write('')
351
347
@@ -1,544 +1,532 b''
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport foo, bar``
59 ``%aimport foo, bar``
60
60
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
61 Import modules 'foo', 'bar' and mark them to be autoreloaded for ``%autoreload 1``
62
62
63 ``%aimport -foo``
63 ``%aimport -foo``
64
64
65 Mark module 'foo' to not be autoreloaded.
65 Mark module 'foo' to not be autoreloaded.
66
66
67 Caveats
67 Caveats
68 =======
68 =======
69
69
70 Reloading Python modules in a reliable way is in general difficult,
70 Reloading Python modules in a reliable way is in general difficult,
71 and unexpected things may occur. ``%autoreload`` tries to work around
71 and unexpected things may occur. ``%autoreload`` tries to work around
72 common pitfalls by replacing function code objects and parts of
72 common pitfalls by replacing function code objects and parts of
73 classes previously in the module with new versions. This makes the
73 classes previously in the module with new versions. This makes the
74 following things to work:
74 following things to work:
75
75
76 - Functions and classes imported via 'from xxx import foo' are upgraded
76 - Functions and classes imported via 'from xxx import foo' are upgraded
77 to new versions when 'xxx' is reloaded.
77 to new versions when 'xxx' is reloaded.
78
78
79 - Methods and properties of classes are upgraded on reload, so that
79 - Methods and properties of classes are upgraded on reload, so that
80 calling 'c.foo()' on an object 'c' created before the reload causes
80 calling 'c.foo()' on an object 'c' created before the reload causes
81 the new code for 'foo' to be executed.
81 the new code for 'foo' to be executed.
82
82
83 Some of the known remaining caveats are:
83 Some of the known remaining caveats are:
84
84
85 - Replacing code objects does not always succeed: changing a @property
85 - Replacing code objects does not always succeed: changing a @property
86 in a class to an ordinary method or a method to a member variable
86 in a class to an ordinary method or a method to a member variable
87 can cause problems (but in old objects only).
87 can cause problems (but in old objects only).
88
88
89 - Functions that are removed (eg. via monkey-patching) from a module
89 - Functions that are removed (eg. via monkey-patching) from a module
90 before it is reloaded are not upgraded.
90 before it is reloaded are not upgraded.
91
91
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
92 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
93 """
93 """
94
94
95 skip_doctest = True
95 skip_doctest = True
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Copyright (C) 2000 Thomas Heller
98 # Copyright (C) 2000 Thomas Heller
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
99 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
100 # Copyright (C) 2012 The IPython Development Team
100 # Copyright (C) 2012 The IPython Development Team
101 #
101 #
102 # Distributed under the terms of the BSD License. The full license is in
102 # Distributed under the terms of the BSD License. The full license is in
103 # the file COPYING, distributed as part of this software.
103 # the file COPYING, distributed as part of this software.
104 #-----------------------------------------------------------------------------
104 #-----------------------------------------------------------------------------
105 #
105 #
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
106 # This IPython module is written by Pauli Virtanen, based on the autoreload
107 # code by Thomas Heller.
107 # code by Thomas Heller.
108
108
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110 # Imports
110 # Imports
111 #-----------------------------------------------------------------------------
111 #-----------------------------------------------------------------------------
112
112
113 import os
113 import os
114 import sys
114 import sys
115 import traceback
115 import traceback
116 import types
116 import types
117 import weakref
117 import weakref
118 from importlib import import_module
118 from importlib import import_module
119
119
120 try:
120 try:
121 # Reload is not defined by default in Python3.
121 # Reload is not defined by default in Python3.
122 reload
122 reload
123 except NameError:
123 except NameError:
124 from imp import reload
124 from imp import reload
125
125
126 from IPython.utils import openpy
126 from IPython.utils import openpy
127 from IPython.utils.py3compat import PY3
127 from IPython.utils.py3compat import PY3
128
128
129 #------------------------------------------------------------------------------
129 #------------------------------------------------------------------------------
130 # Autoreload functionality
130 # Autoreload functionality
131 #------------------------------------------------------------------------------
131 #------------------------------------------------------------------------------
132
132
133 class ModuleReloader(object):
133 class ModuleReloader(object):
134 enabled = False
134 enabled = False
135 """Whether this reloader is enabled"""
135 """Whether this reloader is enabled"""
136
136
137 check_all = True
137 check_all = True
138 """Autoreload all modules, not just those listed in 'modules'"""
138 """Autoreload all modules, not just those listed in 'modules'"""
139
139
140 def __init__(self):
140 def __init__(self):
141 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
141 # Modules that failed to reload: {module: mtime-on-failed-reload, ...}
142 self.failed = {}
142 self.failed = {}
143 # Modules specially marked as autoreloadable.
143 # Modules specially marked as autoreloadable.
144 self.modules = {}
144 self.modules = {}
145 # Modules specially marked as not autoreloadable.
145 # Modules specially marked as not autoreloadable.
146 self.skip_modules = {}
146 self.skip_modules = {}
147 # (module-name, name) -> weakref, for replacing old code objects
147 # (module-name, name) -> weakref, for replacing old code objects
148 self.old_objects = {}
148 self.old_objects = {}
149 # Module modification timestamps
149 # Module modification timestamps
150 self.modules_mtimes = {}
150 self.modules_mtimes = {}
151
151
152 # Cache module modification times
152 # Cache module modification times
153 self.check(check_all=True, do_reload=False)
153 self.check(check_all=True, do_reload=False)
154
154
155 def mark_module_skipped(self, module_name):
155 def mark_module_skipped(self, module_name):
156 """Skip reloading the named module in the future"""
156 """Skip reloading the named module in the future"""
157 try:
157 try:
158 del self.modules[module_name]
158 del self.modules[module_name]
159 except KeyError:
159 except KeyError:
160 pass
160 pass
161 self.skip_modules[module_name] = True
161 self.skip_modules[module_name] = True
162
162
163 def mark_module_reloadable(self, module_name):
163 def mark_module_reloadable(self, module_name):
164 """Reload the named module in the future (if it is imported)"""
164 """Reload the named module in the future (if it is imported)"""
165 try:
165 try:
166 del self.skip_modules[module_name]
166 del self.skip_modules[module_name]
167 except KeyError:
167 except KeyError:
168 pass
168 pass
169 self.modules[module_name] = True
169 self.modules[module_name] = True
170
170
171 def aimport_module(self, module_name):
171 def aimport_module(self, module_name):
172 """Import a module, and mark it reloadable
172 """Import a module, and mark it reloadable
173
173
174 Returns
174 Returns
175 -------
175 -------
176 top_module : module
176 top_module : module
177 The imported module if it is top-level, or the top-level
177 The imported module if it is top-level, or the top-level
178 top_name : module
178 top_name : module
179 Name of top_module
179 Name of top_module
180
180
181 """
181 """
182 self.mark_module_reloadable(module_name)
182 self.mark_module_reloadable(module_name)
183
183
184 import_module(module_name)
184 import_module(module_name)
185 top_name = module_name.split('.')[0]
185 top_name = module_name.split('.')[0]
186 top_module = sys.modules[top_name]
186 top_module = sys.modules[top_name]
187 return top_module, top_name
187 return top_module, top_name
188
188
189 def filename_and_mtime(self, module):
189 def filename_and_mtime(self, module):
190 if not hasattr(module, '__file__') or module.__file__ is None:
190 if not hasattr(module, '__file__') or module.__file__ is None:
191 return None, None
191 return None, None
192
192
193 if getattr(module, '__name__', None) == '__main__':
193 if getattr(module, '__name__', None) == '__main__':
194 # we cannot reload(__main__)
194 # we cannot reload(__main__)
195 return None, None
195 return None, None
196
196
197 filename = module.__file__
197 filename = module.__file__
198 path, ext = os.path.splitext(filename)
198 path, ext = os.path.splitext(filename)
199
199
200 if ext.lower() == '.py':
200 if ext.lower() == '.py':
201 py_filename = filename
201 py_filename = filename
202 else:
202 else:
203 try:
203 try:
204 py_filename = openpy.source_from_cache(filename)
204 py_filename = openpy.source_from_cache(filename)
205 except ValueError:
205 except ValueError:
206 return None, None
206 return None, None
207
207
208 try:
208 try:
209 pymtime = os.stat(py_filename).st_mtime
209 pymtime = os.stat(py_filename).st_mtime
210 except OSError:
210 except OSError:
211 return None, None
211 return None, None
212
212
213 return py_filename, pymtime
213 return py_filename, pymtime
214
214
215 def check(self, check_all=False, do_reload=True):
215 def check(self, check_all=False, do_reload=True):
216 """Check whether some modules need to be reloaded."""
216 """Check whether some modules need to be reloaded."""
217
217
218 if not self.enabled and not check_all:
218 if not self.enabled and not check_all:
219 return
219 return
220
220
221 if check_all or self.check_all:
221 if check_all or self.check_all:
222 modules = list(sys.modules.keys())
222 modules = list(sys.modules.keys())
223 else:
223 else:
224 modules = list(self.modules.keys())
224 modules = list(self.modules.keys())
225
225
226 for modname in modules:
226 for modname in modules:
227 m = sys.modules.get(modname, None)
227 m = sys.modules.get(modname, None)
228
228
229 if modname in self.skip_modules:
229 if modname in self.skip_modules:
230 continue
230 continue
231
231
232 py_filename, pymtime = self.filename_and_mtime(m)
232 py_filename, pymtime = self.filename_and_mtime(m)
233 if py_filename is None:
233 if py_filename is None:
234 continue
234 continue
235
235
236 try:
236 try:
237 if pymtime <= self.modules_mtimes[modname]:
237 if pymtime <= self.modules_mtimes[modname]:
238 continue
238 continue
239 except KeyError:
239 except KeyError:
240 self.modules_mtimes[modname] = pymtime
240 self.modules_mtimes[modname] = pymtime
241 continue
241 continue
242 else:
242 else:
243 if self.failed.get(py_filename, None) == pymtime:
243 if self.failed.get(py_filename, None) == pymtime:
244 continue
244 continue
245
245
246 self.modules_mtimes[modname] = pymtime
246 self.modules_mtimes[modname] = pymtime
247
247
248 # If we've reached this point, we should try to reload the module
248 # If we've reached this point, we should try to reload the module
249 if do_reload:
249 if do_reload:
250 try:
250 try:
251 superreload(m, reload, self.old_objects)
251 superreload(m, reload, self.old_objects)
252 if py_filename in self.failed:
252 if py_filename in self.failed:
253 del self.failed[py_filename]
253 del self.failed[py_filename]
254 except:
254 except:
255 print("[autoreload of %s failed: %s]" % (
255 print("[autoreload of %s failed: %s]" % (
256 modname, traceback.format_exc(1)), file=sys.stderr)
256 modname, traceback.format_exc(1)), file=sys.stderr)
257 self.failed[py_filename] = pymtime
257 self.failed[py_filename] = pymtime
258
258
259 #------------------------------------------------------------------------------
259 #------------------------------------------------------------------------------
260 # superreload
260 # superreload
261 #------------------------------------------------------------------------------
261 #------------------------------------------------------------------------------
262
262
263 if PY3:
263
264 func_attrs = ['__code__', '__defaults__', '__doc__',
264 func_attrs = ['__code__', '__defaults__', '__doc__',
265 '__closure__', '__globals__', '__dict__']
265 '__closure__', '__globals__', '__dict__']
266 else:
267 func_attrs = ['func_code', 'func_defaults', 'func_doc',
268 'func_closure', 'func_globals', 'func_dict']
269
266
270
267
271 def update_function(old, new):
268 def update_function(old, new):
272 """Upgrade the code object of a function"""
269 """Upgrade the code object of a function"""
273 for name in func_attrs:
270 for name in func_attrs:
274 try:
271 try:
275 setattr(old, name, getattr(new, name))
272 setattr(old, name, getattr(new, name))
276 except (AttributeError, TypeError):
273 except (AttributeError, TypeError):
277 pass
274 pass
278
275
279
276
280 def update_class(old, new):
277 def update_class(old, new):
281 """Replace stuff in the __dict__ of a class, and upgrade
278 """Replace stuff in the __dict__ of a class, and upgrade
282 method code objects"""
279 method code objects"""
283 for key in list(old.__dict__.keys()):
280 for key in list(old.__dict__.keys()):
284 old_obj = getattr(old, key)
281 old_obj = getattr(old, key)
285
282
286 try:
283 try:
287 new_obj = getattr(new, key)
284 new_obj = getattr(new, key)
288 except AttributeError:
285 except AttributeError:
289 # obsolete attribute: remove it
286 # obsolete attribute: remove it
290 try:
287 try:
291 delattr(old, key)
288 delattr(old, key)
292 except (AttributeError, TypeError):
289 except (AttributeError, TypeError):
293 pass
290 pass
294 continue
291 continue
295
292
296 if update_generic(old_obj, new_obj): continue
293 if update_generic(old_obj, new_obj): continue
297
294
298 try:
295 try:
299 setattr(old, key, getattr(new, key))
296 setattr(old, key, getattr(new, key))
300 except (AttributeError, TypeError):
297 except (AttributeError, TypeError):
301 pass # skip non-writable attributes
298 pass # skip non-writable attributes
302
299
303
300
304 def update_property(old, new):
301 def update_property(old, new):
305 """Replace get/set/del functions of a property"""
302 """Replace get/set/del functions of a property"""
306 update_generic(old.fdel, new.fdel)
303 update_generic(old.fdel, new.fdel)
307 update_generic(old.fget, new.fget)
304 update_generic(old.fget, new.fget)
308 update_generic(old.fset, new.fset)
305 update_generic(old.fset, new.fset)
309
306
310
307
311 def isinstance2(a, b, typ):
308 def isinstance2(a, b, typ):
312 return isinstance(a, typ) and isinstance(b, typ)
309 return isinstance(a, typ) and isinstance(b, typ)
313
310
314
311
315 UPDATE_RULES = [
312 UPDATE_RULES = [
316 (lambda a, b: isinstance2(a, b, type),
313 (lambda a, b: isinstance2(a, b, type),
317 update_class),
314 update_class),
318 (lambda a, b: isinstance2(a, b, types.FunctionType),
315 (lambda a, b: isinstance2(a, b, types.FunctionType),
319 update_function),
316 update_function),
320 (lambda a, b: isinstance2(a, b, property),
317 (lambda a, b: isinstance2(a, b, property),
321 update_property),
318 update_property),
322 ]
319 ]
323
320 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
324
321 lambda a, b: update_function(a.__func__, b.__func__)),
325 if PY3:
322 ])
326 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
327 lambda a, b: update_function(a.__func__, b.__func__)),
328 ])
329 else:
330 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
331 update_class),
332 (lambda a, b: isinstance2(a, b, types.MethodType),
333 lambda a, b: update_function(a.__func__, b.__func__)),
334 ])
335
323
336
324
337 def update_generic(a, b):
325 def update_generic(a, b):
338 for type_check, update in UPDATE_RULES:
326 for type_check, update in UPDATE_RULES:
339 if type_check(a, b):
327 if type_check(a, b):
340 update(a, b)
328 update(a, b)
341 return True
329 return True
342 return False
330 return False
343
331
344
332
345 class StrongRef(object):
333 class StrongRef(object):
346 def __init__(self, obj):
334 def __init__(self, obj):
347 self.obj = obj
335 self.obj = obj
348 def __call__(self):
336 def __call__(self):
349 return self.obj
337 return self.obj
350
338
351
339
352 def superreload(module, reload=reload, old_objects={}):
340 def superreload(module, reload=reload, old_objects={}):
353 """Enhanced version of the builtin reload function.
341 """Enhanced version of the builtin reload function.
354
342
355 superreload remembers objects previously in the module, and
343 superreload remembers objects previously in the module, and
356
344
357 - upgrades the class dictionary of every old class in the module
345 - upgrades the class dictionary of every old class in the module
358 - upgrades the code object of every old function and method
346 - upgrades the code object of every old function and method
359 - clears the module's namespace before reloading
347 - clears the module's namespace before reloading
360
348
361 """
349 """
362
350
363 # collect old objects in the module
351 # collect old objects in the module
364 for name, obj in list(module.__dict__.items()):
352 for name, obj in list(module.__dict__.items()):
365 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
353 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
366 continue
354 continue
367 key = (module.__name__, name)
355 key = (module.__name__, name)
368 try:
356 try:
369 old_objects.setdefault(key, []).append(weakref.ref(obj))
357 old_objects.setdefault(key, []).append(weakref.ref(obj))
370 except TypeError:
358 except TypeError:
371 # weakref doesn't work for all types;
359 # weakref doesn't work for all types;
372 # create strong references for 'important' cases
360 # create strong references for 'important' cases
373 if not PY3 and isinstance(obj, types.ClassType):
361 if not PY3 and isinstance(obj, types.ClassType):
374 old_objects.setdefault(key, []).append(StrongRef(obj))
362 old_objects.setdefault(key, []).append(StrongRef(obj))
375
363
376 # reload module
364 # reload module
377 try:
365 try:
378 # clear namespace first from old cruft
366 # clear namespace first from old cruft
379 old_dict = module.__dict__.copy()
367 old_dict = module.__dict__.copy()
380 old_name = module.__name__
368 old_name = module.__name__
381 module.__dict__.clear()
369 module.__dict__.clear()
382 module.__dict__['__name__'] = old_name
370 module.__dict__['__name__'] = old_name
383 module.__dict__['__loader__'] = old_dict['__loader__']
371 module.__dict__['__loader__'] = old_dict['__loader__']
384 except (TypeError, AttributeError, KeyError):
372 except (TypeError, AttributeError, KeyError):
385 pass
373 pass
386
374
387 try:
375 try:
388 module = reload(module)
376 module = reload(module)
389 except:
377 except:
390 # restore module dictionary on failed reload
378 # restore module dictionary on failed reload
391 module.__dict__.update(old_dict)
379 module.__dict__.update(old_dict)
392 raise
380 raise
393
381
394 # iterate over all objects and update functions & classes
382 # iterate over all objects and update functions & classes
395 for name, new_obj in list(module.__dict__.items()):
383 for name, new_obj in list(module.__dict__.items()):
396 key = (module.__name__, name)
384 key = (module.__name__, name)
397 if key not in old_objects: continue
385 if key not in old_objects: continue
398
386
399 new_refs = []
387 new_refs = []
400 for old_ref in old_objects[key]:
388 for old_ref in old_objects[key]:
401 old_obj = old_ref()
389 old_obj = old_ref()
402 if old_obj is None: continue
390 if old_obj is None: continue
403 new_refs.append(old_ref)
391 new_refs.append(old_ref)
404 update_generic(old_obj, new_obj)
392 update_generic(old_obj, new_obj)
405
393
406 if new_refs:
394 if new_refs:
407 old_objects[key] = new_refs
395 old_objects[key] = new_refs
408 else:
396 else:
409 del old_objects[key]
397 del old_objects[key]
410
398
411 return module
399 return module
412
400
413 #------------------------------------------------------------------------------
401 #------------------------------------------------------------------------------
414 # IPython connectivity
402 # IPython connectivity
415 #------------------------------------------------------------------------------
403 #------------------------------------------------------------------------------
416
404
417 from IPython.core.magic import Magics, magics_class, line_magic
405 from IPython.core.magic import Magics, magics_class, line_magic
418
406
419 @magics_class
407 @magics_class
420 class AutoreloadMagics(Magics):
408 class AutoreloadMagics(Magics):
421 def __init__(self, *a, **kw):
409 def __init__(self, *a, **kw):
422 super(AutoreloadMagics, self).__init__(*a, **kw)
410 super(AutoreloadMagics, self).__init__(*a, **kw)
423 self._reloader = ModuleReloader()
411 self._reloader = ModuleReloader()
424 self._reloader.check_all = False
412 self._reloader.check_all = False
425 self.loaded_modules = set(sys.modules)
413 self.loaded_modules = set(sys.modules)
426
414
427 @line_magic
415 @line_magic
428 def autoreload(self, parameter_s=''):
416 def autoreload(self, parameter_s=''):
429 r"""%autoreload => Reload modules automatically
417 r"""%autoreload => Reload modules automatically
430
418
431 %autoreload
419 %autoreload
432 Reload all modules (except those excluded by %aimport) automatically
420 Reload all modules (except those excluded by %aimport) automatically
433 now.
421 now.
434
422
435 %autoreload 0
423 %autoreload 0
436 Disable automatic reloading.
424 Disable automatic reloading.
437
425
438 %autoreload 1
426 %autoreload 1
439 Reload all modules imported with %aimport every time before executing
427 Reload all modules imported with %aimport every time before executing
440 the Python code typed.
428 the Python code typed.
441
429
442 %autoreload 2
430 %autoreload 2
443 Reload all modules (except those excluded by %aimport) every time
431 Reload all modules (except those excluded by %aimport) every time
444 before executing the Python code typed.
432 before executing the Python code typed.
445
433
446 Reloading Python modules in a reliable way is in general
434 Reloading Python modules in a reliable way is in general
447 difficult, and unexpected things may occur. %autoreload tries to
435 difficult, and unexpected things may occur. %autoreload tries to
448 work around common pitfalls by replacing function code objects and
436 work around common pitfalls by replacing function code objects and
449 parts of classes previously in the module with new versions. This
437 parts of classes previously in the module with new versions. This
450 makes the following things to work:
438 makes the following things to work:
451
439
452 - Functions and classes imported via 'from xxx import foo' are upgraded
440 - Functions and classes imported via 'from xxx import foo' are upgraded
453 to new versions when 'xxx' is reloaded.
441 to new versions when 'xxx' is reloaded.
454
442
455 - Methods and properties of classes are upgraded on reload, so that
443 - Methods and properties of classes are upgraded on reload, so that
456 calling 'c.foo()' on an object 'c' created before the reload causes
444 calling 'c.foo()' on an object 'c' created before the reload causes
457 the new code for 'foo' to be executed.
445 the new code for 'foo' to be executed.
458
446
459 Some of the known remaining caveats are:
447 Some of the known remaining caveats are:
460
448
461 - Replacing code objects does not always succeed: changing a @property
449 - Replacing code objects does not always succeed: changing a @property
462 in a class to an ordinary method or a method to a member variable
450 in a class to an ordinary method or a method to a member variable
463 can cause problems (but in old objects only).
451 can cause problems (but in old objects only).
464
452
465 - Functions that are removed (eg. via monkey-patching) from a module
453 - Functions that are removed (eg. via monkey-patching) from a module
466 before it is reloaded are not upgraded.
454 before it is reloaded are not upgraded.
467
455
468 - C extension modules cannot be reloaded, and so cannot be
456 - C extension modules cannot be reloaded, and so cannot be
469 autoreloaded.
457 autoreloaded.
470
458
471 """
459 """
472 if parameter_s == '':
460 if parameter_s == '':
473 self._reloader.check(True)
461 self._reloader.check(True)
474 elif parameter_s == '0':
462 elif parameter_s == '0':
475 self._reloader.enabled = False
463 self._reloader.enabled = False
476 elif parameter_s == '1':
464 elif parameter_s == '1':
477 self._reloader.check_all = False
465 self._reloader.check_all = False
478 self._reloader.enabled = True
466 self._reloader.enabled = True
479 elif parameter_s == '2':
467 elif parameter_s == '2':
480 self._reloader.check_all = True
468 self._reloader.check_all = True
481 self._reloader.enabled = True
469 self._reloader.enabled = True
482
470
483 @line_magic
471 @line_magic
484 def aimport(self, parameter_s='', stream=None):
472 def aimport(self, parameter_s='', stream=None):
485 """%aimport => Import modules for automatic reloading.
473 """%aimport => Import modules for automatic reloading.
486
474
487 %aimport
475 %aimport
488 List modules to automatically import and not to import.
476 List modules to automatically import and not to import.
489
477
490 %aimport foo
478 %aimport foo
491 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
479 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
492
480
493 %aimport foo, bar
481 %aimport foo, bar
494 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
482 Import modules 'foo', 'bar' and mark them to be autoreloaded for %autoreload 1
495
483
496 %aimport -foo
484 %aimport -foo
497 Mark module 'foo' to not be autoreloaded for %autoreload 1
485 Mark module 'foo' to not be autoreloaded for %autoreload 1
498 """
486 """
499 modname = parameter_s
487 modname = parameter_s
500 if not modname:
488 if not modname:
501 to_reload = sorted(self._reloader.modules.keys())
489 to_reload = sorted(self._reloader.modules.keys())
502 to_skip = sorted(self._reloader.skip_modules.keys())
490 to_skip = sorted(self._reloader.skip_modules.keys())
503 if stream is None:
491 if stream is None:
504 stream = sys.stdout
492 stream = sys.stdout
505 if self._reloader.check_all:
493 if self._reloader.check_all:
506 stream.write("Modules to reload:\nall-except-skipped\n")
494 stream.write("Modules to reload:\nall-except-skipped\n")
507 else:
495 else:
508 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
496 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
509 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
497 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
510 elif modname.startswith('-'):
498 elif modname.startswith('-'):
511 modname = modname[1:]
499 modname = modname[1:]
512 self._reloader.mark_module_skipped(modname)
500 self._reloader.mark_module_skipped(modname)
513 else:
501 else:
514 for _module in ([_.strip() for _ in modname.split(',')]):
502 for _module in ([_.strip() for _ in modname.split(',')]):
515 top_module, top_name = self._reloader.aimport_module(_module)
503 top_module, top_name = self._reloader.aimport_module(_module)
516
504
517 # Inject module to user namespace
505 # Inject module to user namespace
518 self.shell.push({top_name: top_module})
506 self.shell.push({top_name: top_module})
519
507
520 def pre_run_cell(self):
508 def pre_run_cell(self):
521 if self._reloader.enabled:
509 if self._reloader.enabled:
522 try:
510 try:
523 self._reloader.check()
511 self._reloader.check()
524 except:
512 except:
525 pass
513 pass
526
514
527 def post_execute_hook(self):
515 def post_execute_hook(self):
528 """Cache the modification times of any modules imported in this execution
516 """Cache the modification times of any modules imported in this execution
529 """
517 """
530 newly_loaded_modules = set(sys.modules) - self.loaded_modules
518 newly_loaded_modules = set(sys.modules) - self.loaded_modules
531 for modname in newly_loaded_modules:
519 for modname in newly_loaded_modules:
532 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
520 _, pymtime = self._reloader.filename_and_mtime(sys.modules[modname])
533 if pymtime is not None:
521 if pymtime is not None:
534 self._reloader.modules_mtimes[modname] = pymtime
522 self._reloader.modules_mtimes[modname] = pymtime
535
523
536 self.loaded_modules.update(newly_loaded_modules)
524 self.loaded_modules.update(newly_loaded_modules)
537
525
538
526
539 def load_ipython_extension(ip):
527 def load_ipython_extension(ip):
540 """Load the extension in IPython."""
528 """Load the extension in IPython."""
541 auto_reload = AutoreloadMagics(ip)
529 auto_reload = AutoreloadMagics(ip)
542 ip.register_magics(auto_reload)
530 ip.register_magics(auto_reload)
543 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
531 ip.events.register('pre_run_cell', auto_reload.pre_run_cell)
544 ip.events.register('post_execute', auto_reload.post_execute_hook)
532 ip.events.register('post_execute', auto_reload.post_execute_hook)
@@ -1,321 +1,316 b''
1 """Tests for autoreload extension.
1 """Tests for autoreload extension.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 IPython Development Team.
4 # Copyright (c) 2012 IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import shutil
18 import shutil
19 import random
19 import random
20 import time
20 import time
21 from io import StringIO
21
22
22 import nose.tools as nt
23 import nose.tools as nt
23 import IPython.testing.tools as tt
24 import IPython.testing.tools as tt
24
25
25 from IPython.extensions.autoreload import AutoreloadMagics
26 from IPython.extensions.autoreload import AutoreloadMagics
26 from IPython.core.events import EventManager, pre_run_cell
27 from IPython.core.events import EventManager, pre_run_cell
27 from IPython.utils.py3compat import PY3
28
29 if PY3:
30 from io import StringIO
31 else:
32 from StringIO import StringIO
33
28
34 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
35 # Test fixture
30 # Test fixture
36 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
37
32
38 noop = lambda *a, **kw: None
33 noop = lambda *a, **kw: None
39
34
40 class FakeShell(object):
35 class FakeShell(object):
41
36
42 def __init__(self):
37 def __init__(self):
43 self.ns = {}
38 self.ns = {}
44 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
39 self.events = EventManager(self, {'pre_run_cell', pre_run_cell})
45 self.auto_magics = AutoreloadMagics(shell=self)
40 self.auto_magics = AutoreloadMagics(shell=self)
46 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
41 self.events.register('pre_run_cell', self.auto_magics.pre_run_cell)
47
42
48 register_magics = set_hook = noop
43 register_magics = set_hook = noop
49
44
50 def run_code(self, code):
45 def run_code(self, code):
51 self.events.trigger('pre_run_cell')
46 self.events.trigger('pre_run_cell')
52 exec(code, self.ns)
47 exec(code, self.ns)
53 self.auto_magics.post_execute_hook()
48 self.auto_magics.post_execute_hook()
54
49
55 def push(self, items):
50 def push(self, items):
56 self.ns.update(items)
51 self.ns.update(items)
57
52
58 def magic_autoreload(self, parameter):
53 def magic_autoreload(self, parameter):
59 self.auto_magics.autoreload(parameter)
54 self.auto_magics.autoreload(parameter)
60
55
61 def magic_aimport(self, parameter, stream=None):
56 def magic_aimport(self, parameter, stream=None):
62 self.auto_magics.aimport(parameter, stream=stream)
57 self.auto_magics.aimport(parameter, stream=stream)
63 self.auto_magics.post_execute_hook()
58 self.auto_magics.post_execute_hook()
64
59
65
60
66 class Fixture(object):
61 class Fixture(object):
67 """Fixture for creating test module files"""
62 """Fixture for creating test module files"""
68
63
69 test_dir = None
64 test_dir = None
70 old_sys_path = None
65 old_sys_path = None
71 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
66 filename_chars = "abcdefghijklmopqrstuvwxyz0123456789"
72
67
73 def setUp(self):
68 def setUp(self):
74 self.test_dir = tempfile.mkdtemp()
69 self.test_dir = tempfile.mkdtemp()
75 self.old_sys_path = list(sys.path)
70 self.old_sys_path = list(sys.path)
76 sys.path.insert(0, self.test_dir)
71 sys.path.insert(0, self.test_dir)
77 self.shell = FakeShell()
72 self.shell = FakeShell()
78
73
79 def tearDown(self):
74 def tearDown(self):
80 shutil.rmtree(self.test_dir)
75 shutil.rmtree(self.test_dir)
81 sys.path = self.old_sys_path
76 sys.path = self.old_sys_path
82
77
83 self.test_dir = None
78 self.test_dir = None
84 self.old_sys_path = None
79 self.old_sys_path = None
85 self.shell = None
80 self.shell = None
86
81
87 def get_module(self):
82 def get_module(self):
88 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
83 module_name = "tmpmod_" + "".join(random.sample(self.filename_chars,20))
89 if module_name in sys.modules:
84 if module_name in sys.modules:
90 del sys.modules[module_name]
85 del sys.modules[module_name]
91 file_name = os.path.join(self.test_dir, module_name + ".py")
86 file_name = os.path.join(self.test_dir, module_name + ".py")
92 return module_name, file_name
87 return module_name, file_name
93
88
94 def write_file(self, filename, content):
89 def write_file(self, filename, content):
95 """
90 """
96 Write a file, and force a timestamp difference of at least one second
91 Write a file, and force a timestamp difference of at least one second
97
92
98 Notes
93 Notes
99 -----
94 -----
100 Python's .pyc files record the timestamp of their compilation
95 Python's .pyc files record the timestamp of their compilation
101 with a time resolution of one second.
96 with a time resolution of one second.
102
97
103 Therefore, we need to force a timestamp difference between .py
98 Therefore, we need to force a timestamp difference between .py
104 and .pyc, without having the .py file be timestamped in the
99 and .pyc, without having the .py file be timestamped in the
105 future, and without changing the timestamp of the .pyc file
100 future, and without changing the timestamp of the .pyc file
106 (because that is stored in the file). The only reliable way
101 (because that is stored in the file). The only reliable way
107 to achieve this seems to be to sleep.
102 to achieve this seems to be to sleep.
108 """
103 """
109
104
110 # Sleep one second + eps
105 # Sleep one second + eps
111 time.sleep(1.05)
106 time.sleep(1.05)
112
107
113 # Write
108 # Write
114 f = open(filename, 'w')
109 f = open(filename, 'w')
115 try:
110 try:
116 f.write(content)
111 f.write(content)
117 finally:
112 finally:
118 f.close()
113 f.close()
119
114
120 def new_module(self, code):
115 def new_module(self, code):
121 mod_name, mod_fn = self.get_module()
116 mod_name, mod_fn = self.get_module()
122 f = open(mod_fn, 'w')
117 f = open(mod_fn, 'w')
123 try:
118 try:
124 f.write(code)
119 f.write(code)
125 finally:
120 finally:
126 f.close()
121 f.close()
127 return mod_name, mod_fn
122 return mod_name, mod_fn
128
123
129 #-----------------------------------------------------------------------------
124 #-----------------------------------------------------------------------------
130 # Test automatic reloading
125 # Test automatic reloading
131 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
132
127
133 class TestAutoreload(Fixture):
128 class TestAutoreload(Fixture):
134 def _check_smoketest(self, use_aimport=True):
129 def _check_smoketest(self, use_aimport=True):
135 """
130 """
136 Functional test for the automatic reloader using either
131 Functional test for the automatic reloader using either
137 '%autoreload 1' or '%autoreload 2'
132 '%autoreload 1' or '%autoreload 2'
138 """
133 """
139
134
140 mod_name, mod_fn = self.new_module("""
135 mod_name, mod_fn = self.new_module("""
141 x = 9
136 x = 9
142
137
143 z = 123 # this item will be deleted
138 z = 123 # this item will be deleted
144
139
145 def foo(y):
140 def foo(y):
146 return y + 3
141 return y + 3
147
142
148 class Baz(object):
143 class Baz(object):
149 def __init__(self, x):
144 def __init__(self, x):
150 self.x = x
145 self.x = x
151 def bar(self, y):
146 def bar(self, y):
152 return self.x + y
147 return self.x + y
153 @property
148 @property
154 def quux(self):
149 def quux(self):
155 return 42
150 return 42
156 def zzz(self):
151 def zzz(self):
157 '''This method will be deleted below'''
152 '''This method will be deleted below'''
158 return 99
153 return 99
159
154
160 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
155 class Bar: # old-style class: weakref doesn't work for it on Python < 2.7
161 def foo(self):
156 def foo(self):
162 return 1
157 return 1
163 """)
158 """)
164
159
165 #
160 #
166 # Import module, and mark for reloading
161 # Import module, and mark for reloading
167 #
162 #
168 if use_aimport:
163 if use_aimport:
169 self.shell.magic_autoreload("1")
164 self.shell.magic_autoreload("1")
170 self.shell.magic_aimport(mod_name)
165 self.shell.magic_aimport(mod_name)
171 stream = StringIO()
166 stream = StringIO()
172 self.shell.magic_aimport("", stream=stream)
167 self.shell.magic_aimport("", stream=stream)
173 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
168 nt.assert_in(("Modules to reload:\n%s" % mod_name), stream.getvalue())
174
169
175 with nt.assert_raises(ImportError):
170 with nt.assert_raises(ImportError):
176 self.shell.magic_aimport("tmpmod_as318989e89ds")
171 self.shell.magic_aimport("tmpmod_as318989e89ds")
177 else:
172 else:
178 self.shell.magic_autoreload("2")
173 self.shell.magic_autoreload("2")
179 self.shell.run_code("import %s" % mod_name)
174 self.shell.run_code("import %s" % mod_name)
180 stream = StringIO()
175 stream = StringIO()
181 self.shell.magic_aimport("", stream=stream)
176 self.shell.magic_aimport("", stream=stream)
182 nt.assert_true("Modules to reload:\nall-except-skipped" in
177 nt.assert_true("Modules to reload:\nall-except-skipped" in
183 stream.getvalue())
178 stream.getvalue())
184 nt.assert_in(mod_name, self.shell.ns)
179 nt.assert_in(mod_name, self.shell.ns)
185
180
186 mod = sys.modules[mod_name]
181 mod = sys.modules[mod_name]
187
182
188 #
183 #
189 # Test module contents
184 # Test module contents
190 #
185 #
191 old_foo = mod.foo
186 old_foo = mod.foo
192 old_obj = mod.Baz(9)
187 old_obj = mod.Baz(9)
193 old_obj2 = mod.Bar()
188 old_obj2 = mod.Bar()
194
189
195 def check_module_contents():
190 def check_module_contents():
196 nt.assert_equal(mod.x, 9)
191 nt.assert_equal(mod.x, 9)
197 nt.assert_equal(mod.z, 123)
192 nt.assert_equal(mod.z, 123)
198
193
199 nt.assert_equal(old_foo(0), 3)
194 nt.assert_equal(old_foo(0), 3)
200 nt.assert_equal(mod.foo(0), 3)
195 nt.assert_equal(mod.foo(0), 3)
201
196
202 obj = mod.Baz(9)
197 obj = mod.Baz(9)
203 nt.assert_equal(old_obj.bar(1), 10)
198 nt.assert_equal(old_obj.bar(1), 10)
204 nt.assert_equal(obj.bar(1), 10)
199 nt.assert_equal(obj.bar(1), 10)
205 nt.assert_equal(obj.quux, 42)
200 nt.assert_equal(obj.quux, 42)
206 nt.assert_equal(obj.zzz(), 99)
201 nt.assert_equal(obj.zzz(), 99)
207
202
208 obj2 = mod.Bar()
203 obj2 = mod.Bar()
209 nt.assert_equal(old_obj2.foo(), 1)
204 nt.assert_equal(old_obj2.foo(), 1)
210 nt.assert_equal(obj2.foo(), 1)
205 nt.assert_equal(obj2.foo(), 1)
211
206
212 check_module_contents()
207 check_module_contents()
213
208
214 #
209 #
215 # Simulate a failed reload: no reload should occur and exactly
210 # Simulate a failed reload: no reload should occur and exactly
216 # one error message should be printed
211 # one error message should be printed
217 #
212 #
218 self.write_file(mod_fn, """
213 self.write_file(mod_fn, """
219 a syntax error
214 a syntax error
220 """)
215 """)
221
216
222 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
217 with tt.AssertPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
223 self.shell.run_code("pass") # trigger reload
218 self.shell.run_code("pass") # trigger reload
224 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
219 with tt.AssertNotPrints(('[autoreload of %s failed:' % mod_name), channel='stderr'):
225 self.shell.run_code("pass") # trigger another reload
220 self.shell.run_code("pass") # trigger another reload
226 check_module_contents()
221 check_module_contents()
227
222
228 #
223 #
229 # Rewrite module (this time reload should succeed)
224 # Rewrite module (this time reload should succeed)
230 #
225 #
231 self.write_file(mod_fn, """
226 self.write_file(mod_fn, """
232 x = 10
227 x = 10
233
228
234 def foo(y):
229 def foo(y):
235 return y + 4
230 return y + 4
236
231
237 class Baz(object):
232 class Baz(object):
238 def __init__(self, x):
233 def __init__(self, x):
239 self.x = x
234 self.x = x
240 def bar(self, y):
235 def bar(self, y):
241 return self.x + y + 1
236 return self.x + y + 1
242 @property
237 @property
243 def quux(self):
238 def quux(self):
244 return 43
239 return 43
245
240
246 class Bar: # old-style class
241 class Bar: # old-style class
247 def foo(self):
242 def foo(self):
248 return 2
243 return 2
249 """)
244 """)
250
245
251 def check_module_contents():
246 def check_module_contents():
252 nt.assert_equal(mod.x, 10)
247 nt.assert_equal(mod.x, 10)
253 nt.assert_false(hasattr(mod, 'z'))
248 nt.assert_false(hasattr(mod, 'z'))
254
249
255 nt.assert_equal(old_foo(0), 4) # superreload magic!
250 nt.assert_equal(old_foo(0), 4) # superreload magic!
256 nt.assert_equal(mod.foo(0), 4)
251 nt.assert_equal(mod.foo(0), 4)
257
252
258 obj = mod.Baz(9)
253 obj = mod.Baz(9)
259 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
254 nt.assert_equal(old_obj.bar(1), 11) # superreload magic!
260 nt.assert_equal(obj.bar(1), 11)
255 nt.assert_equal(obj.bar(1), 11)
261
256
262 nt.assert_equal(old_obj.quux, 43)
257 nt.assert_equal(old_obj.quux, 43)
263 nt.assert_equal(obj.quux, 43)
258 nt.assert_equal(obj.quux, 43)
264
259
265 nt.assert_false(hasattr(old_obj, 'zzz'))
260 nt.assert_false(hasattr(old_obj, 'zzz'))
266 nt.assert_false(hasattr(obj, 'zzz'))
261 nt.assert_false(hasattr(obj, 'zzz'))
267
262
268 obj2 = mod.Bar()
263 obj2 = mod.Bar()
269 nt.assert_equal(old_obj2.foo(), 2)
264 nt.assert_equal(old_obj2.foo(), 2)
270 nt.assert_equal(obj2.foo(), 2)
265 nt.assert_equal(obj2.foo(), 2)
271
266
272 self.shell.run_code("pass") # trigger reload
267 self.shell.run_code("pass") # trigger reload
273 check_module_contents()
268 check_module_contents()
274
269
275 #
270 #
276 # Another failure case: deleted file (shouldn't reload)
271 # Another failure case: deleted file (shouldn't reload)
277 #
272 #
278 os.unlink(mod_fn)
273 os.unlink(mod_fn)
279
274
280 self.shell.run_code("pass") # trigger reload
275 self.shell.run_code("pass") # trigger reload
281 check_module_contents()
276 check_module_contents()
282
277
283 #
278 #
284 # Disable autoreload and rewrite module: no reload should occur
279 # Disable autoreload and rewrite module: no reload should occur
285 #
280 #
286 if use_aimport:
281 if use_aimport:
287 self.shell.magic_aimport("-" + mod_name)
282 self.shell.magic_aimport("-" + mod_name)
288 stream = StringIO()
283 stream = StringIO()
289 self.shell.magic_aimport("", stream=stream)
284 self.shell.magic_aimport("", stream=stream)
290 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
285 nt.assert_true(("Modules to skip:\n%s" % mod_name) in
291 stream.getvalue())
286 stream.getvalue())
292
287
293 # This should succeed, although no such module exists
288 # This should succeed, although no such module exists
294 self.shell.magic_aimport("-tmpmod_as318989e89ds")
289 self.shell.magic_aimport("-tmpmod_as318989e89ds")
295 else:
290 else:
296 self.shell.magic_autoreload("0")
291 self.shell.magic_autoreload("0")
297
292
298 self.write_file(mod_fn, """
293 self.write_file(mod_fn, """
299 x = -99
294 x = -99
300 """)
295 """)
301
296
302 self.shell.run_code("pass") # trigger reload
297 self.shell.run_code("pass") # trigger reload
303 self.shell.run_code("pass")
298 self.shell.run_code("pass")
304 check_module_contents()
299 check_module_contents()
305
300
306 #
301 #
307 # Re-enable autoreload: reload should now occur
302 # Re-enable autoreload: reload should now occur
308 #
303 #
309 if use_aimport:
304 if use_aimport:
310 self.shell.magic_aimport(mod_name)
305 self.shell.magic_aimport(mod_name)
311 else:
306 else:
312 self.shell.magic_autoreload("")
307 self.shell.magic_autoreload("")
313
308
314 self.shell.run_code("pass") # trigger reload
309 self.shell.run_code("pass") # trigger reload
315 nt.assert_equal(mod.x, -99)
310 nt.assert_equal(mod.x, -99)
316
311
317 def test_smoketest_aimport(self):
312 def test_smoketest_aimport(self):
318 self._check_smoketest(use_aimport=True)
313 self._check_smoketest(use_aimport=True)
319
314
320 def test_smoketest_autoreload(self):
315 def test_smoketest_autoreload(self):
321 self._check_smoketest(use_aimport=False)
316 self._check_smoketest(use_aimport=False)
@@ -1,201 +1,201 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tools for handling LaTeX."""
2 """Tools for handling LaTeX."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 from io import BytesIO, open
7 from io import BytesIO, open
8 import os
8 import os
9 import tempfile
9 import tempfile
10 import shutil
10 import shutil
11 import subprocess
11 import subprocess
12 from base64 import encodebytes
12 from base64 import encodebytes
13
13
14 from IPython.utils.process import find_cmd, FindCmdError
14 from IPython.utils.process import find_cmd, FindCmdError
15 from traitlets.config import get_config
15 from traitlets.config import get_config
16 from traitlets.config.configurable import SingletonConfigurable
16 from traitlets.config.configurable import SingletonConfigurable
17 from traitlets import List, Bool, Unicode
17 from traitlets import List, Bool, Unicode
18 from IPython.utils.py3compat import cast_unicode, cast_unicode_py2 as u, PY3
18 from IPython.utils.py3compat import cast_unicode, cast_unicode_py2 as u
19
19
20
20
21 class LaTeXTool(SingletonConfigurable):
21 class LaTeXTool(SingletonConfigurable):
22 """An object to store configuration of the LaTeX tool."""
22 """An object to store configuration of the LaTeX tool."""
23 def _config_default(self):
23 def _config_default(self):
24 return get_config()
24 return get_config()
25
25
26 backends = List(
26 backends = List(
27 Unicode(), ["matplotlib", "dvipng"],
27 Unicode(), ["matplotlib", "dvipng"],
28 help="Preferred backend to draw LaTeX math equations. "
28 help="Preferred backend to draw LaTeX math equations. "
29 "Backends in the list are checked one by one and the first "
29 "Backends in the list are checked one by one and the first "
30 "usable one is used. Note that `matplotlib` backend "
30 "usable one is used. Note that `matplotlib` backend "
31 "is usable only for inline style equations. To draw "
31 "is usable only for inline style equations. To draw "
32 "display style equations, `dvipng` backend must be specified. ",
32 "display style equations, `dvipng` backend must be specified. ",
33 # It is a List instead of Enum, to make configuration more
33 # It is a List instead of Enum, to make configuration more
34 # flexible. For example, to use matplotlib mainly but dvipng
34 # flexible. For example, to use matplotlib mainly but dvipng
35 # for display style, the default ["matplotlib", "dvipng"] can
35 # for display style, the default ["matplotlib", "dvipng"] can
36 # be used. To NOT use dvipng so that other repr such as
36 # be used. To NOT use dvipng so that other repr such as
37 # unicode pretty printing is used, you can use ["matplotlib"].
37 # unicode pretty printing is used, you can use ["matplotlib"].
38 ).tag(config=True)
38 ).tag(config=True)
39
39
40 use_breqn = Bool(
40 use_breqn = Bool(
41 True,
41 True,
42 help="Use breqn.sty to automatically break long equations. "
42 help="Use breqn.sty to automatically break long equations. "
43 "This configuration takes effect only for dvipng backend.",
43 "This configuration takes effect only for dvipng backend.",
44 ).tag(config=True)
44 ).tag(config=True)
45
45
46 packages = List(
46 packages = List(
47 ['amsmath', 'amsthm', 'amssymb', 'bm'],
47 ['amsmath', 'amsthm', 'amssymb', 'bm'],
48 help="A list of packages to use for dvipng backend. "
48 help="A list of packages to use for dvipng backend. "
49 "'breqn' will be automatically appended when use_breqn=True.",
49 "'breqn' will be automatically appended when use_breqn=True.",
50 ).tag(config=True)
50 ).tag(config=True)
51
51
52 preamble = Unicode(
52 preamble = Unicode(
53 help="Additional preamble to use when generating LaTeX source "
53 help="Additional preamble to use when generating LaTeX source "
54 "for dvipng backend.",
54 "for dvipng backend.",
55 ).tag(config=True)
55 ).tag(config=True)
56
56
57
57
58 def latex_to_png(s, encode=False, backend=None, wrap=False):
58 def latex_to_png(s, encode=False, backend=None, wrap=False):
59 """Render a LaTeX string to PNG.
59 """Render a LaTeX string to PNG.
60
60
61 Parameters
61 Parameters
62 ----------
62 ----------
63 s : str
63 s : str
64 The raw string containing valid inline LaTeX.
64 The raw string containing valid inline LaTeX.
65 encode : bool, optional
65 encode : bool, optional
66 Should the PNG data base64 encoded to make it JSON'able.
66 Should the PNG data base64 encoded to make it JSON'able.
67 backend : {matplotlib, dvipng}
67 backend : {matplotlib, dvipng}
68 Backend for producing PNG data.
68 Backend for producing PNG data.
69 wrap : bool
69 wrap : bool
70 If true, Automatically wrap `s` as a LaTeX equation.
70 If true, Automatically wrap `s` as a LaTeX equation.
71
71
72 None is returned when the backend cannot be used.
72 None is returned when the backend cannot be used.
73
73
74 """
74 """
75 s = cast_unicode(s)
75 s = cast_unicode(s)
76 allowed_backends = LaTeXTool.instance().backends
76 allowed_backends = LaTeXTool.instance().backends
77 if backend is None:
77 if backend is None:
78 backend = allowed_backends[0]
78 backend = allowed_backends[0]
79 if backend not in allowed_backends:
79 if backend not in allowed_backends:
80 return None
80 return None
81 if backend == 'matplotlib':
81 if backend == 'matplotlib':
82 f = latex_to_png_mpl
82 f = latex_to_png_mpl
83 elif backend == 'dvipng':
83 elif backend == 'dvipng':
84 f = latex_to_png_dvipng
84 f = latex_to_png_dvipng
85 else:
85 else:
86 raise ValueError('No such backend {0}'.format(backend))
86 raise ValueError('No such backend {0}'.format(backend))
87 bin_data = f(s, wrap)
87 bin_data = f(s, wrap)
88 if encode and bin_data:
88 if encode and bin_data:
89 bin_data = encodebytes(bin_data)
89 bin_data = encodebytes(bin_data)
90 return bin_data
90 return bin_data
91
91
92
92
93 def latex_to_png_mpl(s, wrap):
93 def latex_to_png_mpl(s, wrap):
94 try:
94 try:
95 from matplotlib import mathtext
95 from matplotlib import mathtext
96 from pyparsing import ParseFatalException
96 from pyparsing import ParseFatalException
97 except ImportError:
97 except ImportError:
98 return None
98 return None
99
99
100 # mpl mathtext doesn't support display math, force inline
100 # mpl mathtext doesn't support display math, force inline
101 s = s.replace('$$', '$')
101 s = s.replace('$$', '$')
102 if wrap:
102 if wrap:
103 s = u'${0}$'.format(s)
103 s = u'${0}$'.format(s)
104
104
105 try:
105 try:
106 mt = mathtext.MathTextParser('bitmap')
106 mt = mathtext.MathTextParser('bitmap')
107 f = BytesIO()
107 f = BytesIO()
108 mt.to_png(f, s, fontsize=12)
108 mt.to_png(f, s, fontsize=12)
109 return f.getvalue()
109 return f.getvalue()
110 except (ValueError, RuntimeError, ParseFatalException):
110 except (ValueError, RuntimeError, ParseFatalException):
111 return None
111 return None
112
112
113
113
114 def latex_to_png_dvipng(s, wrap):
114 def latex_to_png_dvipng(s, wrap):
115 try:
115 try:
116 find_cmd('latex')
116 find_cmd('latex')
117 find_cmd('dvipng')
117 find_cmd('dvipng')
118 except FindCmdError:
118 except FindCmdError:
119 return None
119 return None
120 try:
120 try:
121 workdir = tempfile.mkdtemp()
121 workdir = tempfile.mkdtemp()
122 tmpfile = os.path.join(workdir, "tmp.tex")
122 tmpfile = os.path.join(workdir, "tmp.tex")
123 dvifile = os.path.join(workdir, "tmp.dvi")
123 dvifile = os.path.join(workdir, "tmp.dvi")
124 outfile = os.path.join(workdir, "tmp.png")
124 outfile = os.path.join(workdir, "tmp.png")
125
125
126 with open(tmpfile, "w", encoding='utf8') as f:
126 with open(tmpfile, "w", encoding='utf8') as f:
127 f.writelines(genelatex(s, wrap))
127 f.writelines(genelatex(s, wrap))
128
128
129 with open(os.devnull, 'wb') as devnull:
129 with open(os.devnull, 'wb') as devnull:
130 subprocess.check_call(
130 subprocess.check_call(
131 ["latex", "-halt-on-error", "-interaction", "batchmode", tmpfile],
131 ["latex", "-halt-on-error", "-interaction", "batchmode", tmpfile],
132 cwd=workdir, stdout=devnull, stderr=devnull)
132 cwd=workdir, stdout=devnull, stderr=devnull)
133
133
134 subprocess.check_call(
134 subprocess.check_call(
135 ["dvipng", "-T", "tight", "-x", "1500", "-z", "9",
135 ["dvipng", "-T", "tight", "-x", "1500", "-z", "9",
136 "-bg", "transparent", "-o", outfile, dvifile], cwd=workdir,
136 "-bg", "transparent", "-o", outfile, dvifile], cwd=workdir,
137 stdout=devnull, stderr=devnull)
137 stdout=devnull, stderr=devnull)
138
138
139 with open(outfile, "rb") as f:
139 with open(outfile, "rb") as f:
140 return f.read()
140 return f.read()
141 except subprocess.CalledProcessError:
141 except subprocess.CalledProcessError:
142 return None
142 return None
143 finally:
143 finally:
144 shutil.rmtree(workdir)
144 shutil.rmtree(workdir)
145
145
146
146
147 def kpsewhich(filename):
147 def kpsewhich(filename):
148 """Invoke kpsewhich command with an argument `filename`."""
148 """Invoke kpsewhich command with an argument `filename`."""
149 try:
149 try:
150 find_cmd("kpsewhich")
150 find_cmd("kpsewhich")
151 proc = subprocess.Popen(
151 proc = subprocess.Popen(
152 ["kpsewhich", filename],
152 ["kpsewhich", filename],
153 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
153 stdout=subprocess.PIPE, stderr=subprocess.PIPE)
154 (stdout, stderr) = proc.communicate()
154 (stdout, stderr) = proc.communicate()
155 return stdout.strip().decode('utf8', 'replace')
155 return stdout.strip().decode('utf8', 'replace')
156 except FindCmdError:
156 except FindCmdError:
157 pass
157 pass
158
158
159
159
160 def genelatex(body, wrap):
160 def genelatex(body, wrap):
161 """Generate LaTeX document for dvipng backend."""
161 """Generate LaTeX document for dvipng backend."""
162 lt = LaTeXTool.instance()
162 lt = LaTeXTool.instance()
163 breqn = wrap and lt.use_breqn and kpsewhich("breqn.sty")
163 breqn = wrap and lt.use_breqn and kpsewhich("breqn.sty")
164 yield u(r'\documentclass{article}')
164 yield u(r'\documentclass{article}')
165 packages = lt.packages
165 packages = lt.packages
166 if breqn:
166 if breqn:
167 packages = packages + ['breqn']
167 packages = packages + ['breqn']
168 for pack in packages:
168 for pack in packages:
169 yield u(r'\usepackage{{{0}}}'.format(pack))
169 yield u(r'\usepackage{{{0}}}'.format(pack))
170 yield u(r'\pagestyle{empty}')
170 yield u(r'\pagestyle{empty}')
171 if lt.preamble:
171 if lt.preamble:
172 yield lt.preamble
172 yield lt.preamble
173 yield u(r'\begin{document}')
173 yield u(r'\begin{document}')
174 if breqn:
174 if breqn:
175 yield u(r'\begin{dmath*}')
175 yield u(r'\begin{dmath*}')
176 yield body
176 yield body
177 yield u(r'\end{dmath*}')
177 yield u(r'\end{dmath*}')
178 elif wrap:
178 elif wrap:
179 yield u'$${0}$$'.format(body)
179 yield u'$${0}$$'.format(body)
180 else:
180 else:
181 yield body
181 yield body
182 yield u'\end{document}'
182 yield u'\end{document}'
183
183
184
184
185 _data_uri_template_png = u"""<img src="data:image/png;base64,%s" alt=%s />"""
185 _data_uri_template_png = u"""<img src="data:image/png;base64,%s" alt=%s />"""
186
186
187 def latex_to_html(s, alt='image'):
187 def latex_to_html(s, alt='image'):
188 """Render LaTeX to HTML with embedded PNG data using data URIs.
188 """Render LaTeX to HTML with embedded PNG data using data URIs.
189
189
190 Parameters
190 Parameters
191 ----------
191 ----------
192 s : str
192 s : str
193 The raw string containing valid inline LateX.
193 The raw string containing valid inline LateX.
194 alt : str
194 alt : str
195 The alt text to use for the HTML.
195 The alt text to use for the HTML.
196 """
196 """
197 base64_data = latex_to_png(s, encode=True).decode('ascii')
197 base64_data = latex_to_png(s, encode=True).decode('ascii')
198 if base64_data:
198 if base64_data:
199 return _data_uri_template_png % (base64_data, alt)
199 return _data_uri_template_png % (base64_data, alt)
200
200
201
201
@@ -1,866 +1,858 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Python advanced pretty printer. This pretty printer is intended to
3 Python advanced pretty printer. This pretty printer is intended to
4 replace the old `pprint` python module which does not allow developers
4 replace the old `pprint` python module which does not allow developers
5 to provide their own pretty print callbacks.
5 to provide their own pretty print callbacks.
6
6
7 This module is based on ruby's `prettyprint.rb` library by `Tanaka Akira`.
7 This module is based on ruby's `prettyprint.rb` library by `Tanaka Akira`.
8
8
9
9
10 Example Usage
10 Example Usage
11 -------------
11 -------------
12
12
13 To directly print the representation of an object use `pprint`::
13 To directly print the representation of an object use `pprint`::
14
14
15 from pretty import pprint
15 from pretty import pprint
16 pprint(complex_object)
16 pprint(complex_object)
17
17
18 To get a string of the output use `pretty`::
18 To get a string of the output use `pretty`::
19
19
20 from pretty import pretty
20 from pretty import pretty
21 string = pretty(complex_object)
21 string = pretty(complex_object)
22
22
23
23
24 Extending
24 Extending
25 ---------
25 ---------
26
26
27 The pretty library allows developers to add pretty printing rules for their
27 The pretty library allows developers to add pretty printing rules for their
28 own objects. This process is straightforward. All you have to do is to
28 own objects. This process is straightforward. All you have to do is to
29 add a `_repr_pretty_` method to your object and call the methods on the
29 add a `_repr_pretty_` method to your object and call the methods on the
30 pretty printer passed::
30 pretty printer passed::
31
31
32 class MyObject(object):
32 class MyObject(object):
33
33
34 def _repr_pretty_(self, p, cycle):
34 def _repr_pretty_(self, p, cycle):
35 ...
35 ...
36
36
37 Here is an example implementation of a `_repr_pretty_` method for a list
37 Here is an example implementation of a `_repr_pretty_` method for a list
38 subclass::
38 subclass::
39
39
40 class MyList(list):
40 class MyList(list):
41
41
42 def _repr_pretty_(self, p, cycle):
42 def _repr_pretty_(self, p, cycle):
43 if cycle:
43 if cycle:
44 p.text('MyList(...)')
44 p.text('MyList(...)')
45 else:
45 else:
46 with p.group(8, 'MyList([', '])'):
46 with p.group(8, 'MyList([', '])'):
47 for idx, item in enumerate(self):
47 for idx, item in enumerate(self):
48 if idx:
48 if idx:
49 p.text(',')
49 p.text(',')
50 p.breakable()
50 p.breakable()
51 p.pretty(item)
51 p.pretty(item)
52
52
53 The `cycle` parameter is `True` if pretty detected a cycle. You *have* to
53 The `cycle` parameter is `True` if pretty detected a cycle. You *have* to
54 react to that or the result is an infinite loop. `p.text()` just adds
54 react to that or the result is an infinite loop. `p.text()` just adds
55 non breaking text to the output, `p.breakable()` either adds a whitespace
55 non breaking text to the output, `p.breakable()` either adds a whitespace
56 or breaks here. If you pass it an argument it's used instead of the
56 or breaks here. If you pass it an argument it's used instead of the
57 default space. `p.pretty` prettyprints another object using the pretty print
57 default space. `p.pretty` prettyprints another object using the pretty print
58 method.
58 method.
59
59
60 The first parameter to the `group` function specifies the extra indentation
60 The first parameter to the `group` function specifies the extra indentation
61 of the next line. In this example the next item will either be on the same
61 of the next line. In this example the next item will either be on the same
62 line (if the items are short enough) or aligned with the right edge of the
62 line (if the items are short enough) or aligned with the right edge of the
63 opening bracket of `MyList`.
63 opening bracket of `MyList`.
64
64
65 If you just want to indent something you can use the group function
65 If you just want to indent something you can use the group function
66 without open / close parameters. You can also use this code::
66 without open / close parameters. You can also use this code::
67
67
68 with p.indent(2):
68 with p.indent(2):
69 ...
69 ...
70
70
71 Inheritance diagram:
71 Inheritance diagram:
72
72
73 .. inheritance-diagram:: IPython.lib.pretty
73 .. inheritance-diagram:: IPython.lib.pretty
74 :parts: 3
74 :parts: 3
75
75
76 :copyright: 2007 by Armin Ronacher.
76 :copyright: 2007 by Armin Ronacher.
77 Portions (c) 2009 by Robert Kern.
77 Portions (c) 2009 by Robert Kern.
78 :license: BSD License.
78 :license: BSD License.
79 """
79 """
80 from contextlib import contextmanager
80 from contextlib import contextmanager
81 import sys
81 import sys
82 import types
82 import types
83 import re
83 import re
84 import datetime
84 import datetime
85 from collections import deque
85 from collections import deque
86
86
87 from IPython.utils.py3compat import PY3, PYPY, cast_unicode
87 from IPython.utils.py3compat import PYPY, cast_unicode
88 from IPython.utils.encoding import get_stream_enc
88 from IPython.utils.encoding import get_stream_enc
89
89
90 from io import StringIO
90 from io import StringIO
91
91
92
92
93 __all__ = ['pretty', 'pprint', 'PrettyPrinter', 'RepresentationPrinter',
93 __all__ = ['pretty', 'pprint', 'PrettyPrinter', 'RepresentationPrinter',
94 'for_type', 'for_type_by_name']
94 'for_type', 'for_type_by_name']
95
95
96
96
97 MAX_SEQ_LENGTH = 1000
97 MAX_SEQ_LENGTH = 1000
98 _re_pattern_type = type(re.compile(''))
98 _re_pattern_type = type(re.compile(''))
99
99
100 def _safe_getattr(obj, attr, default=None):
100 def _safe_getattr(obj, attr, default=None):
101 """Safe version of getattr.
101 """Safe version of getattr.
102
102
103 Same as getattr, but will return ``default`` on any Exception,
103 Same as getattr, but will return ``default`` on any Exception,
104 rather than raising.
104 rather than raising.
105 """
105 """
106 try:
106 try:
107 return getattr(obj, attr, default)
107 return getattr(obj, attr, default)
108 except Exception:
108 except Exception:
109 return default
109 return default
110
110
111 if PY3:
111 CUnicodeIO = StringIO
112 CUnicodeIO = StringIO
113 else:
114 class CUnicodeIO(StringIO):
115 """StringIO that casts str to unicode on Python 2"""
116 def write(self, text):
117 return super(CUnicodeIO, self).write(
118 cast_unicode(text, encoding=get_stream_enc(sys.stdout)))
119
120
112
121 def pretty(obj, verbose=False, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
113 def pretty(obj, verbose=False, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
122 """
114 """
123 Pretty print the object's representation.
115 Pretty print the object's representation.
124 """
116 """
125 stream = CUnicodeIO()
117 stream = CUnicodeIO()
126 printer = RepresentationPrinter(stream, verbose, max_width, newline, max_seq_length=max_seq_length)
118 printer = RepresentationPrinter(stream, verbose, max_width, newline, max_seq_length=max_seq_length)
127 printer.pretty(obj)
119 printer.pretty(obj)
128 printer.flush()
120 printer.flush()
129 return stream.getvalue()
121 return stream.getvalue()
130
122
131
123
132 def pprint(obj, verbose=False, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
124 def pprint(obj, verbose=False, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
133 """
125 """
134 Like `pretty` but print to stdout.
126 Like `pretty` but print to stdout.
135 """
127 """
136 printer = RepresentationPrinter(sys.stdout, verbose, max_width, newline, max_seq_length=max_seq_length)
128 printer = RepresentationPrinter(sys.stdout, verbose, max_width, newline, max_seq_length=max_seq_length)
137 printer.pretty(obj)
129 printer.pretty(obj)
138 printer.flush()
130 printer.flush()
139 sys.stdout.write(newline)
131 sys.stdout.write(newline)
140 sys.stdout.flush()
132 sys.stdout.flush()
141
133
142 class _PrettyPrinterBase(object):
134 class _PrettyPrinterBase(object):
143
135
144 @contextmanager
136 @contextmanager
145 def indent(self, indent):
137 def indent(self, indent):
146 """with statement support for indenting/dedenting."""
138 """with statement support for indenting/dedenting."""
147 self.indentation += indent
139 self.indentation += indent
148 try:
140 try:
149 yield
141 yield
150 finally:
142 finally:
151 self.indentation -= indent
143 self.indentation -= indent
152
144
153 @contextmanager
145 @contextmanager
154 def group(self, indent=0, open='', close=''):
146 def group(self, indent=0, open='', close=''):
155 """like begin_group / end_group but for the with statement."""
147 """like begin_group / end_group but for the with statement."""
156 self.begin_group(indent, open)
148 self.begin_group(indent, open)
157 try:
149 try:
158 yield
150 yield
159 finally:
151 finally:
160 self.end_group(indent, close)
152 self.end_group(indent, close)
161
153
162 class PrettyPrinter(_PrettyPrinterBase):
154 class PrettyPrinter(_PrettyPrinterBase):
163 """
155 """
164 Baseclass for the `RepresentationPrinter` prettyprinter that is used to
156 Baseclass for the `RepresentationPrinter` prettyprinter that is used to
165 generate pretty reprs of objects. Contrary to the `RepresentationPrinter`
157 generate pretty reprs of objects. Contrary to the `RepresentationPrinter`
166 this printer knows nothing about the default pprinters or the `_repr_pretty_`
158 this printer knows nothing about the default pprinters or the `_repr_pretty_`
167 callback method.
159 callback method.
168 """
160 """
169
161
170 def __init__(self, output, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
162 def __init__(self, output, max_width=79, newline='\n', max_seq_length=MAX_SEQ_LENGTH):
171 self.output = output
163 self.output = output
172 self.max_width = max_width
164 self.max_width = max_width
173 self.newline = newline
165 self.newline = newline
174 self.max_seq_length = max_seq_length
166 self.max_seq_length = max_seq_length
175 self.output_width = 0
167 self.output_width = 0
176 self.buffer_width = 0
168 self.buffer_width = 0
177 self.buffer = deque()
169 self.buffer = deque()
178
170
179 root_group = Group(0)
171 root_group = Group(0)
180 self.group_stack = [root_group]
172 self.group_stack = [root_group]
181 self.group_queue = GroupQueue(root_group)
173 self.group_queue = GroupQueue(root_group)
182 self.indentation = 0
174 self.indentation = 0
183
175
184 def _break_outer_groups(self):
176 def _break_outer_groups(self):
185 while self.max_width < self.output_width + self.buffer_width:
177 while self.max_width < self.output_width + self.buffer_width:
186 group = self.group_queue.deq()
178 group = self.group_queue.deq()
187 if not group:
179 if not group:
188 return
180 return
189 while group.breakables:
181 while group.breakables:
190 x = self.buffer.popleft()
182 x = self.buffer.popleft()
191 self.output_width = x.output(self.output, self.output_width)
183 self.output_width = x.output(self.output, self.output_width)
192 self.buffer_width -= x.width
184 self.buffer_width -= x.width
193 while self.buffer and isinstance(self.buffer[0], Text):
185 while self.buffer and isinstance(self.buffer[0], Text):
194 x = self.buffer.popleft()
186 x = self.buffer.popleft()
195 self.output_width = x.output(self.output, self.output_width)
187 self.output_width = x.output(self.output, self.output_width)
196 self.buffer_width -= x.width
188 self.buffer_width -= x.width
197
189
198 def text(self, obj):
190 def text(self, obj):
199 """Add literal text to the output."""
191 """Add literal text to the output."""
200 width = len(obj)
192 width = len(obj)
201 if self.buffer:
193 if self.buffer:
202 text = self.buffer[-1]
194 text = self.buffer[-1]
203 if not isinstance(text, Text):
195 if not isinstance(text, Text):
204 text = Text()
196 text = Text()
205 self.buffer.append(text)
197 self.buffer.append(text)
206 text.add(obj, width)
198 text.add(obj, width)
207 self.buffer_width += width
199 self.buffer_width += width
208 self._break_outer_groups()
200 self._break_outer_groups()
209 else:
201 else:
210 self.output.write(obj)
202 self.output.write(obj)
211 self.output_width += width
203 self.output_width += width
212
204
213 def breakable(self, sep=' '):
205 def breakable(self, sep=' '):
214 """
206 """
215 Add a breakable separator to the output. This does not mean that it
207 Add a breakable separator to the output. This does not mean that it
216 will automatically break here. If no breaking on this position takes
208 will automatically break here. If no breaking on this position takes
217 place the `sep` is inserted which default to one space.
209 place the `sep` is inserted which default to one space.
218 """
210 """
219 width = len(sep)
211 width = len(sep)
220 group = self.group_stack[-1]
212 group = self.group_stack[-1]
221 if group.want_break:
213 if group.want_break:
222 self.flush()
214 self.flush()
223 self.output.write(self.newline)
215 self.output.write(self.newline)
224 self.output.write(' ' * self.indentation)
216 self.output.write(' ' * self.indentation)
225 self.output_width = self.indentation
217 self.output_width = self.indentation
226 self.buffer_width = 0
218 self.buffer_width = 0
227 else:
219 else:
228 self.buffer.append(Breakable(sep, width, self))
220 self.buffer.append(Breakable(sep, width, self))
229 self.buffer_width += width
221 self.buffer_width += width
230 self._break_outer_groups()
222 self._break_outer_groups()
231
223
232 def break_(self):
224 def break_(self):
233 """
225 """
234 Explicitly insert a newline into the output, maintaining correct indentation.
226 Explicitly insert a newline into the output, maintaining correct indentation.
235 """
227 """
236 self.flush()
228 self.flush()
237 self.output.write(self.newline)
229 self.output.write(self.newline)
238 self.output.write(' ' * self.indentation)
230 self.output.write(' ' * self.indentation)
239 self.output_width = self.indentation
231 self.output_width = self.indentation
240 self.buffer_width = 0
232 self.buffer_width = 0
241
233
242
234
243 def begin_group(self, indent=0, open=''):
235 def begin_group(self, indent=0, open=''):
244 """
236 """
245 Begin a group. If you want support for python < 2.5 which doesn't has
237 Begin a group. If you want support for python < 2.5 which doesn't has
246 the with statement this is the preferred way:
238 the with statement this is the preferred way:
247
239
248 p.begin_group(1, '{')
240 p.begin_group(1, '{')
249 ...
241 ...
250 p.end_group(1, '}')
242 p.end_group(1, '}')
251
243
252 The python 2.5 expression would be this:
244 The python 2.5 expression would be this:
253
245
254 with p.group(1, '{', '}'):
246 with p.group(1, '{', '}'):
255 ...
247 ...
256
248
257 The first parameter specifies the indentation for the next line (usually
249 The first parameter specifies the indentation for the next line (usually
258 the width of the opening text), the second the opening text. All
250 the width of the opening text), the second the opening text. All
259 parameters are optional.
251 parameters are optional.
260 """
252 """
261 if open:
253 if open:
262 self.text(open)
254 self.text(open)
263 group = Group(self.group_stack[-1].depth + 1)
255 group = Group(self.group_stack[-1].depth + 1)
264 self.group_stack.append(group)
256 self.group_stack.append(group)
265 self.group_queue.enq(group)
257 self.group_queue.enq(group)
266 self.indentation += indent
258 self.indentation += indent
267
259
268 def _enumerate(self, seq):
260 def _enumerate(self, seq):
269 """like enumerate, but with an upper limit on the number of items"""
261 """like enumerate, but with an upper limit on the number of items"""
270 for idx, x in enumerate(seq):
262 for idx, x in enumerate(seq):
271 if self.max_seq_length and idx >= self.max_seq_length:
263 if self.max_seq_length and idx >= self.max_seq_length:
272 self.text(',')
264 self.text(',')
273 self.breakable()
265 self.breakable()
274 self.text('...')
266 self.text('...')
275 return
267 return
276 yield idx, x
268 yield idx, x
277
269
278 def end_group(self, dedent=0, close=''):
270 def end_group(self, dedent=0, close=''):
279 """End a group. See `begin_group` for more details."""
271 """End a group. See `begin_group` for more details."""
280 self.indentation -= dedent
272 self.indentation -= dedent
281 group = self.group_stack.pop()
273 group = self.group_stack.pop()
282 if not group.breakables:
274 if not group.breakables:
283 self.group_queue.remove(group)
275 self.group_queue.remove(group)
284 if close:
276 if close:
285 self.text(close)
277 self.text(close)
286
278
287 def flush(self):
279 def flush(self):
288 """Flush data that is left in the buffer."""
280 """Flush data that is left in the buffer."""
289 for data in self.buffer:
281 for data in self.buffer:
290 self.output_width += data.output(self.output, self.output_width)
282 self.output_width += data.output(self.output, self.output_width)
291 self.buffer.clear()
283 self.buffer.clear()
292 self.buffer_width = 0
284 self.buffer_width = 0
293
285
294
286
295 def _get_mro(obj_class):
287 def _get_mro(obj_class):
296 """ Get a reasonable method resolution order of a class and its superclasses
288 """ Get a reasonable method resolution order of a class and its superclasses
297 for both old-style and new-style classes.
289 for both old-style and new-style classes.
298 """
290 """
299 if not hasattr(obj_class, '__mro__'):
291 if not hasattr(obj_class, '__mro__'):
300 # Old-style class. Mix in object to make a fake new-style class.
292 # Old-style class. Mix in object to make a fake new-style class.
301 try:
293 try:
302 obj_class = type(obj_class.__name__, (obj_class, object), {})
294 obj_class = type(obj_class.__name__, (obj_class, object), {})
303 except TypeError:
295 except TypeError:
304 # Old-style extension type that does not descend from object.
296 # Old-style extension type that does not descend from object.
305 # FIXME: try to construct a more thorough MRO.
297 # FIXME: try to construct a more thorough MRO.
306 mro = [obj_class]
298 mro = [obj_class]
307 else:
299 else:
308 mro = obj_class.__mro__[1:-1]
300 mro = obj_class.__mro__[1:-1]
309 else:
301 else:
310 mro = obj_class.__mro__
302 mro = obj_class.__mro__
311 return mro
303 return mro
312
304
313
305
314 class RepresentationPrinter(PrettyPrinter):
306 class RepresentationPrinter(PrettyPrinter):
315 """
307 """
316 Special pretty printer that has a `pretty` method that calls the pretty
308 Special pretty printer that has a `pretty` method that calls the pretty
317 printer for a python object.
309 printer for a python object.
318
310
319 This class stores processing data on `self` so you must *never* use
311 This class stores processing data on `self` so you must *never* use
320 this class in a threaded environment. Always lock it or reinstanciate
312 this class in a threaded environment. Always lock it or reinstanciate
321 it.
313 it.
322
314
323 Instances also have a verbose flag callbacks can access to control their
315 Instances also have a verbose flag callbacks can access to control their
324 output. For example the default instance repr prints all attributes and
316 output. For example the default instance repr prints all attributes and
325 methods that are not prefixed by an underscore if the printer is in
317 methods that are not prefixed by an underscore if the printer is in
326 verbose mode.
318 verbose mode.
327 """
319 """
328
320
329 def __init__(self, output, verbose=False, max_width=79, newline='\n',
321 def __init__(self, output, verbose=False, max_width=79, newline='\n',
330 singleton_pprinters=None, type_pprinters=None, deferred_pprinters=None,
322 singleton_pprinters=None, type_pprinters=None, deferred_pprinters=None,
331 max_seq_length=MAX_SEQ_LENGTH):
323 max_seq_length=MAX_SEQ_LENGTH):
332
324
333 PrettyPrinter.__init__(self, output, max_width, newline, max_seq_length=max_seq_length)
325 PrettyPrinter.__init__(self, output, max_width, newline, max_seq_length=max_seq_length)
334 self.verbose = verbose
326 self.verbose = verbose
335 self.stack = []
327 self.stack = []
336 if singleton_pprinters is None:
328 if singleton_pprinters is None:
337 singleton_pprinters = _singleton_pprinters.copy()
329 singleton_pprinters = _singleton_pprinters.copy()
338 self.singleton_pprinters = singleton_pprinters
330 self.singleton_pprinters = singleton_pprinters
339 if type_pprinters is None:
331 if type_pprinters is None:
340 type_pprinters = _type_pprinters.copy()
332 type_pprinters = _type_pprinters.copy()
341 self.type_pprinters = type_pprinters
333 self.type_pprinters = type_pprinters
342 if deferred_pprinters is None:
334 if deferred_pprinters is None:
343 deferred_pprinters = _deferred_type_pprinters.copy()
335 deferred_pprinters = _deferred_type_pprinters.copy()
344 self.deferred_pprinters = deferred_pprinters
336 self.deferred_pprinters = deferred_pprinters
345
337
346 def pretty(self, obj):
338 def pretty(self, obj):
347 """Pretty print the given object."""
339 """Pretty print the given object."""
348 obj_id = id(obj)
340 obj_id = id(obj)
349 cycle = obj_id in self.stack
341 cycle = obj_id in self.stack
350 self.stack.append(obj_id)
342 self.stack.append(obj_id)
351 self.begin_group()
343 self.begin_group()
352 try:
344 try:
353 obj_class = _safe_getattr(obj, '__class__', None) or type(obj)
345 obj_class = _safe_getattr(obj, '__class__', None) or type(obj)
354 # First try to find registered singleton printers for the type.
346 # First try to find registered singleton printers for the type.
355 try:
347 try:
356 printer = self.singleton_pprinters[obj_id]
348 printer = self.singleton_pprinters[obj_id]
357 except (TypeError, KeyError):
349 except (TypeError, KeyError):
358 pass
350 pass
359 else:
351 else:
360 return printer(obj, self, cycle)
352 return printer(obj, self, cycle)
361 # Next walk the mro and check for either:
353 # Next walk the mro and check for either:
362 # 1) a registered printer
354 # 1) a registered printer
363 # 2) a _repr_pretty_ method
355 # 2) a _repr_pretty_ method
364 for cls in _get_mro(obj_class):
356 for cls in _get_mro(obj_class):
365 if cls in self.type_pprinters:
357 if cls in self.type_pprinters:
366 # printer registered in self.type_pprinters
358 # printer registered in self.type_pprinters
367 return self.type_pprinters[cls](obj, self, cycle)
359 return self.type_pprinters[cls](obj, self, cycle)
368 else:
360 else:
369 # deferred printer
361 # deferred printer
370 printer = self._in_deferred_types(cls)
362 printer = self._in_deferred_types(cls)
371 if printer is not None:
363 if printer is not None:
372 return printer(obj, self, cycle)
364 return printer(obj, self, cycle)
373 else:
365 else:
374 # Finally look for special method names.
366 # Finally look for special method names.
375 # Some objects automatically create any requested
367 # Some objects automatically create any requested
376 # attribute. Try to ignore most of them by checking for
368 # attribute. Try to ignore most of them by checking for
377 # callability.
369 # callability.
378 if '_repr_pretty_' in cls.__dict__:
370 if '_repr_pretty_' in cls.__dict__:
379 meth = cls._repr_pretty_
371 meth = cls._repr_pretty_
380 if callable(meth):
372 if callable(meth):
381 return meth(obj, self, cycle)
373 return meth(obj, self, cycle)
382 return _default_pprint(obj, self, cycle)
374 return _default_pprint(obj, self, cycle)
383 finally:
375 finally:
384 self.end_group()
376 self.end_group()
385 self.stack.pop()
377 self.stack.pop()
386
378
387 def _in_deferred_types(self, cls):
379 def _in_deferred_types(self, cls):
388 """
380 """
389 Check if the given class is specified in the deferred type registry.
381 Check if the given class is specified in the deferred type registry.
390
382
391 Returns the printer from the registry if it exists, and None if the
383 Returns the printer from the registry if it exists, and None if the
392 class is not in the registry. Successful matches will be moved to the
384 class is not in the registry. Successful matches will be moved to the
393 regular type registry for future use.
385 regular type registry for future use.
394 """
386 """
395 mod = _safe_getattr(cls, '__module__', None)
387 mod = _safe_getattr(cls, '__module__', None)
396 name = _safe_getattr(cls, '__name__', None)
388 name = _safe_getattr(cls, '__name__', None)
397 key = (mod, name)
389 key = (mod, name)
398 printer = None
390 printer = None
399 if key in self.deferred_pprinters:
391 if key in self.deferred_pprinters:
400 # Move the printer over to the regular registry.
392 # Move the printer over to the regular registry.
401 printer = self.deferred_pprinters.pop(key)
393 printer = self.deferred_pprinters.pop(key)
402 self.type_pprinters[cls] = printer
394 self.type_pprinters[cls] = printer
403 return printer
395 return printer
404
396
405
397
406 class Printable(object):
398 class Printable(object):
407
399
408 def output(self, stream, output_width):
400 def output(self, stream, output_width):
409 return output_width
401 return output_width
410
402
411
403
412 class Text(Printable):
404 class Text(Printable):
413
405
414 def __init__(self):
406 def __init__(self):
415 self.objs = []
407 self.objs = []
416 self.width = 0
408 self.width = 0
417
409
418 def output(self, stream, output_width):
410 def output(self, stream, output_width):
419 for obj in self.objs:
411 for obj in self.objs:
420 stream.write(obj)
412 stream.write(obj)
421 return output_width + self.width
413 return output_width + self.width
422
414
423 def add(self, obj, width):
415 def add(self, obj, width):
424 self.objs.append(obj)
416 self.objs.append(obj)
425 self.width += width
417 self.width += width
426
418
427
419
428 class Breakable(Printable):
420 class Breakable(Printable):
429
421
430 def __init__(self, seq, width, pretty):
422 def __init__(self, seq, width, pretty):
431 self.obj = seq
423 self.obj = seq
432 self.width = width
424 self.width = width
433 self.pretty = pretty
425 self.pretty = pretty
434 self.indentation = pretty.indentation
426 self.indentation = pretty.indentation
435 self.group = pretty.group_stack[-1]
427 self.group = pretty.group_stack[-1]
436 self.group.breakables.append(self)
428 self.group.breakables.append(self)
437
429
438 def output(self, stream, output_width):
430 def output(self, stream, output_width):
439 self.group.breakables.popleft()
431 self.group.breakables.popleft()
440 if self.group.want_break:
432 if self.group.want_break:
441 stream.write(self.pretty.newline)
433 stream.write(self.pretty.newline)
442 stream.write(' ' * self.indentation)
434 stream.write(' ' * self.indentation)
443 return self.indentation
435 return self.indentation
444 if not self.group.breakables:
436 if not self.group.breakables:
445 self.pretty.group_queue.remove(self.group)
437 self.pretty.group_queue.remove(self.group)
446 stream.write(self.obj)
438 stream.write(self.obj)
447 return output_width + self.width
439 return output_width + self.width
448
440
449
441
450 class Group(Printable):
442 class Group(Printable):
451
443
452 def __init__(self, depth):
444 def __init__(self, depth):
453 self.depth = depth
445 self.depth = depth
454 self.breakables = deque()
446 self.breakables = deque()
455 self.want_break = False
447 self.want_break = False
456
448
457
449
458 class GroupQueue(object):
450 class GroupQueue(object):
459
451
460 def __init__(self, *groups):
452 def __init__(self, *groups):
461 self.queue = []
453 self.queue = []
462 for group in groups:
454 for group in groups:
463 self.enq(group)
455 self.enq(group)
464
456
465 def enq(self, group):
457 def enq(self, group):
466 depth = group.depth
458 depth = group.depth
467 while depth > len(self.queue) - 1:
459 while depth > len(self.queue) - 1:
468 self.queue.append([])
460 self.queue.append([])
469 self.queue[depth].append(group)
461 self.queue[depth].append(group)
470
462
471 def deq(self):
463 def deq(self):
472 for stack in self.queue:
464 for stack in self.queue:
473 for idx, group in enumerate(reversed(stack)):
465 for idx, group in enumerate(reversed(stack)):
474 if group.breakables:
466 if group.breakables:
475 del stack[idx]
467 del stack[idx]
476 group.want_break = True
468 group.want_break = True
477 return group
469 return group
478 for group in stack:
470 for group in stack:
479 group.want_break = True
471 group.want_break = True
480 del stack[:]
472 del stack[:]
481
473
482 def remove(self, group):
474 def remove(self, group):
483 try:
475 try:
484 self.queue[group.depth].remove(group)
476 self.queue[group.depth].remove(group)
485 except ValueError:
477 except ValueError:
486 pass
478 pass
487
479
488 try:
480 try:
489 _baseclass_reprs = (object.__repr__, types.InstanceType.__repr__)
481 _baseclass_reprs = (object.__repr__, types.InstanceType.__repr__)
490 except AttributeError: # Python 3
482 except AttributeError: # Python 3
491 _baseclass_reprs = (object.__repr__,)
483 _baseclass_reprs = (object.__repr__,)
492
484
493
485
494 def _default_pprint(obj, p, cycle):
486 def _default_pprint(obj, p, cycle):
495 """
487 """
496 The default print function. Used if an object does not provide one and
488 The default print function. Used if an object does not provide one and
497 it's none of the builtin objects.
489 it's none of the builtin objects.
498 """
490 """
499 klass = _safe_getattr(obj, '__class__', None) or type(obj)
491 klass = _safe_getattr(obj, '__class__', None) or type(obj)
500 if _safe_getattr(klass, '__repr__', None) not in _baseclass_reprs:
492 if _safe_getattr(klass, '__repr__', None) not in _baseclass_reprs:
501 # A user-provided repr. Find newlines and replace them with p.break_()
493 # A user-provided repr. Find newlines and replace them with p.break_()
502 _repr_pprint(obj, p, cycle)
494 _repr_pprint(obj, p, cycle)
503 return
495 return
504 p.begin_group(1, '<')
496 p.begin_group(1, '<')
505 p.pretty(klass)
497 p.pretty(klass)
506 p.text(' at 0x%x' % id(obj))
498 p.text(' at 0x%x' % id(obj))
507 if cycle:
499 if cycle:
508 p.text(' ...')
500 p.text(' ...')
509 elif p.verbose:
501 elif p.verbose:
510 first = True
502 first = True
511 for key in dir(obj):
503 for key in dir(obj):
512 if not key.startswith('_'):
504 if not key.startswith('_'):
513 try:
505 try:
514 value = getattr(obj, key)
506 value = getattr(obj, key)
515 except AttributeError:
507 except AttributeError:
516 continue
508 continue
517 if isinstance(value, types.MethodType):
509 if isinstance(value, types.MethodType):
518 continue
510 continue
519 if not first:
511 if not first:
520 p.text(',')
512 p.text(',')
521 p.breakable()
513 p.breakable()
522 p.text(key)
514 p.text(key)
523 p.text('=')
515 p.text('=')
524 step = len(key) + 1
516 step = len(key) + 1
525 p.indentation += step
517 p.indentation += step
526 p.pretty(value)
518 p.pretty(value)
527 p.indentation -= step
519 p.indentation -= step
528 first = False
520 first = False
529 p.end_group(1, '>')
521 p.end_group(1, '>')
530
522
531
523
532 def _seq_pprinter_factory(start, end, basetype):
524 def _seq_pprinter_factory(start, end, basetype):
533 """
525 """
534 Factory that returns a pprint function useful for sequences. Used by
526 Factory that returns a pprint function useful for sequences. Used by
535 the default pprint for tuples, dicts, and lists.
527 the default pprint for tuples, dicts, and lists.
536 """
528 """
537 def inner(obj, p, cycle):
529 def inner(obj, p, cycle):
538 typ = type(obj)
530 typ = type(obj)
539 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
531 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
540 # If the subclass provides its own repr, use it instead.
532 # If the subclass provides its own repr, use it instead.
541 return p.text(typ.__repr__(obj))
533 return p.text(typ.__repr__(obj))
542
534
543 if cycle:
535 if cycle:
544 return p.text(start + '...' + end)
536 return p.text(start + '...' + end)
545 step = len(start)
537 step = len(start)
546 p.begin_group(step, start)
538 p.begin_group(step, start)
547 for idx, x in p._enumerate(obj):
539 for idx, x in p._enumerate(obj):
548 if idx:
540 if idx:
549 p.text(',')
541 p.text(',')
550 p.breakable()
542 p.breakable()
551 p.pretty(x)
543 p.pretty(x)
552 if len(obj) == 1 and type(obj) is tuple:
544 if len(obj) == 1 and type(obj) is tuple:
553 # Special case for 1-item tuples.
545 # Special case for 1-item tuples.
554 p.text(',')
546 p.text(',')
555 p.end_group(step, end)
547 p.end_group(step, end)
556 return inner
548 return inner
557
549
558
550
559 def _set_pprinter_factory(start, end, basetype):
551 def _set_pprinter_factory(start, end, basetype):
560 """
552 """
561 Factory that returns a pprint function useful for sets and frozensets.
553 Factory that returns a pprint function useful for sets and frozensets.
562 """
554 """
563 def inner(obj, p, cycle):
555 def inner(obj, p, cycle):
564 typ = type(obj)
556 typ = type(obj)
565 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
557 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
566 # If the subclass provides its own repr, use it instead.
558 # If the subclass provides its own repr, use it instead.
567 return p.text(typ.__repr__(obj))
559 return p.text(typ.__repr__(obj))
568
560
569 if cycle:
561 if cycle:
570 return p.text(start + '...' + end)
562 return p.text(start + '...' + end)
571 if len(obj) == 0:
563 if len(obj) == 0:
572 # Special case.
564 # Special case.
573 p.text(basetype.__name__ + '()')
565 p.text(basetype.__name__ + '()')
574 else:
566 else:
575 step = len(start)
567 step = len(start)
576 p.begin_group(step, start)
568 p.begin_group(step, start)
577 # Like dictionary keys, we will try to sort the items if there aren't too many
569 # Like dictionary keys, we will try to sort the items if there aren't too many
578 items = obj
570 items = obj
579 if not (p.max_seq_length and len(obj) >= p.max_seq_length):
571 if not (p.max_seq_length and len(obj) >= p.max_seq_length):
580 try:
572 try:
581 items = sorted(obj)
573 items = sorted(obj)
582 except Exception:
574 except Exception:
583 # Sometimes the items don't sort.
575 # Sometimes the items don't sort.
584 pass
576 pass
585 for idx, x in p._enumerate(items):
577 for idx, x in p._enumerate(items):
586 if idx:
578 if idx:
587 p.text(',')
579 p.text(',')
588 p.breakable()
580 p.breakable()
589 p.pretty(x)
581 p.pretty(x)
590 p.end_group(step, end)
582 p.end_group(step, end)
591 return inner
583 return inner
592
584
593
585
594 def _dict_pprinter_factory(start, end, basetype=None):
586 def _dict_pprinter_factory(start, end, basetype=None):
595 """
587 """
596 Factory that returns a pprint function used by the default pprint of
588 Factory that returns a pprint function used by the default pprint of
597 dicts and dict proxies.
589 dicts and dict proxies.
598 """
590 """
599 def inner(obj, p, cycle):
591 def inner(obj, p, cycle):
600 typ = type(obj)
592 typ = type(obj)
601 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
593 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
602 # If the subclass provides its own repr, use it instead.
594 # If the subclass provides its own repr, use it instead.
603 return p.text(typ.__repr__(obj))
595 return p.text(typ.__repr__(obj))
604
596
605 if cycle:
597 if cycle:
606 return p.text('{...}')
598 return p.text('{...}')
607 step = len(start)
599 step = len(start)
608 p.begin_group(step, start)
600 p.begin_group(step, start)
609 keys = obj.keys()
601 keys = obj.keys()
610 # if dict isn't large enough to be truncated, sort keys before displaying
602 # if dict isn't large enough to be truncated, sort keys before displaying
611 if not (p.max_seq_length and len(obj) >= p.max_seq_length):
603 if not (p.max_seq_length and len(obj) >= p.max_seq_length):
612 try:
604 try:
613 keys = sorted(keys)
605 keys = sorted(keys)
614 except Exception:
606 except Exception:
615 # Sometimes the keys don't sort.
607 # Sometimes the keys don't sort.
616 pass
608 pass
617 for idx, key in p._enumerate(keys):
609 for idx, key in p._enumerate(keys):
618 if idx:
610 if idx:
619 p.text(',')
611 p.text(',')
620 p.breakable()
612 p.breakable()
621 p.pretty(key)
613 p.pretty(key)
622 p.text(': ')
614 p.text(': ')
623 p.pretty(obj[key])
615 p.pretty(obj[key])
624 p.end_group(step, end)
616 p.end_group(step, end)
625 return inner
617 return inner
626
618
627
619
628 def _super_pprint(obj, p, cycle):
620 def _super_pprint(obj, p, cycle):
629 """The pprint for the super type."""
621 """The pprint for the super type."""
630 p.begin_group(8, '<super: ')
622 p.begin_group(8, '<super: ')
631 p.pretty(obj.__thisclass__)
623 p.pretty(obj.__thisclass__)
632 p.text(',')
624 p.text(',')
633 p.breakable()
625 p.breakable()
634 if PYPY: # In PyPy, super() objects don't have __self__ attributes
626 if PYPY: # In PyPy, super() objects don't have __self__ attributes
635 dself = obj.__repr__.__self__
627 dself = obj.__repr__.__self__
636 p.pretty(None if dself is obj else dself)
628 p.pretty(None if dself is obj else dself)
637 else:
629 else:
638 p.pretty(obj.__self__)
630 p.pretty(obj.__self__)
639 p.end_group(8, '>')
631 p.end_group(8, '>')
640
632
641
633
642 def _re_pattern_pprint(obj, p, cycle):
634 def _re_pattern_pprint(obj, p, cycle):
643 """The pprint function for regular expression patterns."""
635 """The pprint function for regular expression patterns."""
644 p.text('re.compile(')
636 p.text('re.compile(')
645 pattern = repr(obj.pattern)
637 pattern = repr(obj.pattern)
646 if pattern[:1] in 'uU':
638 if pattern[:1] in 'uU':
647 pattern = pattern[1:]
639 pattern = pattern[1:]
648 prefix = 'ur'
640 prefix = 'ur'
649 else:
641 else:
650 prefix = 'r'
642 prefix = 'r'
651 pattern = prefix + pattern.replace('\\\\', '\\')
643 pattern = prefix + pattern.replace('\\\\', '\\')
652 p.text(pattern)
644 p.text(pattern)
653 if obj.flags:
645 if obj.flags:
654 p.text(',')
646 p.text(',')
655 p.breakable()
647 p.breakable()
656 done_one = False
648 done_one = False
657 for flag in ('TEMPLATE', 'IGNORECASE', 'LOCALE', 'MULTILINE', 'DOTALL',
649 for flag in ('TEMPLATE', 'IGNORECASE', 'LOCALE', 'MULTILINE', 'DOTALL',
658 'UNICODE', 'VERBOSE', 'DEBUG'):
650 'UNICODE', 'VERBOSE', 'DEBUG'):
659 if obj.flags & getattr(re, flag):
651 if obj.flags & getattr(re, flag):
660 if done_one:
652 if done_one:
661 p.text('|')
653 p.text('|')
662 p.text('re.' + flag)
654 p.text('re.' + flag)
663 done_one = True
655 done_one = True
664 p.text(')')
656 p.text(')')
665
657
666
658
667 def _type_pprint(obj, p, cycle):
659 def _type_pprint(obj, p, cycle):
668 """The pprint for classes and types."""
660 """The pprint for classes and types."""
669 # Heap allocated types might not have the module attribute,
661 # Heap allocated types might not have the module attribute,
670 # and others may set it to None.
662 # and others may set it to None.
671
663
672 # Checks for a __repr__ override in the metaclass. Can't compare the
664 # Checks for a __repr__ override in the metaclass. Can't compare the
673 # type(obj).__repr__ directly because in PyPy the representation function
665 # type(obj).__repr__ directly because in PyPy the representation function
674 # inherited from type isn't the same type.__repr__
666 # inherited from type isn't the same type.__repr__
675 if [m for m in _get_mro(type(obj)) if "__repr__" in vars(m)][:1] != [type]:
667 if [m for m in _get_mro(type(obj)) if "__repr__" in vars(m)][:1] != [type]:
676 _repr_pprint(obj, p, cycle)
668 _repr_pprint(obj, p, cycle)
677 return
669 return
678
670
679 mod = _safe_getattr(obj, '__module__', None)
671 mod = _safe_getattr(obj, '__module__', None)
680 try:
672 try:
681 name = obj.__qualname__
673 name = obj.__qualname__
682 if not isinstance(name, str):
674 if not isinstance(name, str):
683 # This can happen if the type implements __qualname__ as a property
675 # This can happen if the type implements __qualname__ as a property
684 # or other descriptor in Python 2.
676 # or other descriptor in Python 2.
685 raise Exception("Try __name__")
677 raise Exception("Try __name__")
686 except Exception:
678 except Exception:
687 name = obj.__name__
679 name = obj.__name__
688 if not isinstance(name, str):
680 if not isinstance(name, str):
689 name = '<unknown type>'
681 name = '<unknown type>'
690
682
691 if mod in (None, '__builtin__', 'builtins', 'exceptions'):
683 if mod in (None, '__builtin__', 'builtins', 'exceptions'):
692 p.text(name)
684 p.text(name)
693 else:
685 else:
694 p.text(mod + '.' + name)
686 p.text(mod + '.' + name)
695
687
696
688
697 def _repr_pprint(obj, p, cycle):
689 def _repr_pprint(obj, p, cycle):
698 """A pprint that just redirects to the normal repr function."""
690 """A pprint that just redirects to the normal repr function."""
699 # Find newlines and replace them with p.break_()
691 # Find newlines and replace them with p.break_()
700 output = repr(obj)
692 output = repr(obj)
701 for idx,output_line in enumerate(output.splitlines()):
693 for idx,output_line in enumerate(output.splitlines()):
702 if idx:
694 if idx:
703 p.break_()
695 p.break_()
704 p.text(output_line)
696 p.text(output_line)
705
697
706
698
707 def _function_pprint(obj, p, cycle):
699 def _function_pprint(obj, p, cycle):
708 """Base pprint for all functions and builtin functions."""
700 """Base pprint for all functions and builtin functions."""
709 name = _safe_getattr(obj, '__qualname__', obj.__name__)
701 name = _safe_getattr(obj, '__qualname__', obj.__name__)
710 mod = obj.__module__
702 mod = obj.__module__
711 if mod and mod not in ('__builtin__', 'builtins', 'exceptions'):
703 if mod and mod not in ('__builtin__', 'builtins', 'exceptions'):
712 name = mod + '.' + name
704 name = mod + '.' + name
713 p.text('<function %s>' % name)
705 p.text('<function %s>' % name)
714
706
715
707
716 def _exception_pprint(obj, p, cycle):
708 def _exception_pprint(obj, p, cycle):
717 """Base pprint for all exceptions."""
709 """Base pprint for all exceptions."""
718 name = getattr(obj.__class__, '__qualname__', obj.__class__.__name__)
710 name = getattr(obj.__class__, '__qualname__', obj.__class__.__name__)
719 if obj.__class__.__module__ not in ('exceptions', 'builtins'):
711 if obj.__class__.__module__ not in ('exceptions', 'builtins'):
720 name = '%s.%s' % (obj.__class__.__module__, name)
712 name = '%s.%s' % (obj.__class__.__module__, name)
721 step = len(name) + 1
713 step = len(name) + 1
722 p.begin_group(step, name + '(')
714 p.begin_group(step, name + '(')
723 for idx, arg in enumerate(getattr(obj, 'args', ())):
715 for idx, arg in enumerate(getattr(obj, 'args', ())):
724 if idx:
716 if idx:
725 p.text(',')
717 p.text(',')
726 p.breakable()
718 p.breakable()
727 p.pretty(arg)
719 p.pretty(arg)
728 p.end_group(step, ')')
720 p.end_group(step, ')')
729
721
730
722
731 #: the exception base
723 #: the exception base
732 try:
724 try:
733 _exception_base = BaseException
725 _exception_base = BaseException
734 except NameError:
726 except NameError:
735 _exception_base = Exception
727 _exception_base = Exception
736
728
737
729
738 #: printers for builtin types
730 #: printers for builtin types
739 _type_pprinters = {
731 _type_pprinters = {
740 int: _repr_pprint,
732 int: _repr_pprint,
741 float: _repr_pprint,
733 float: _repr_pprint,
742 str: _repr_pprint,
734 str: _repr_pprint,
743 tuple: _seq_pprinter_factory('(', ')', tuple),
735 tuple: _seq_pprinter_factory('(', ')', tuple),
744 list: _seq_pprinter_factory('[', ']', list),
736 list: _seq_pprinter_factory('[', ']', list),
745 dict: _dict_pprinter_factory('{', '}', dict),
737 dict: _dict_pprinter_factory('{', '}', dict),
746
738
747 set: _set_pprinter_factory('{', '}', set),
739 set: _set_pprinter_factory('{', '}', set),
748 frozenset: _set_pprinter_factory('frozenset({', '})', frozenset),
740 frozenset: _set_pprinter_factory('frozenset({', '})', frozenset),
749 super: _super_pprint,
741 super: _super_pprint,
750 _re_pattern_type: _re_pattern_pprint,
742 _re_pattern_type: _re_pattern_pprint,
751 type: _type_pprint,
743 type: _type_pprint,
752 types.FunctionType: _function_pprint,
744 types.FunctionType: _function_pprint,
753 types.BuiltinFunctionType: _function_pprint,
745 types.BuiltinFunctionType: _function_pprint,
754 types.MethodType: _repr_pprint,
746 types.MethodType: _repr_pprint,
755
747
756 datetime.datetime: _repr_pprint,
748 datetime.datetime: _repr_pprint,
757 datetime.timedelta: _repr_pprint,
749 datetime.timedelta: _repr_pprint,
758 _exception_base: _exception_pprint
750 _exception_base: _exception_pprint
759 }
751 }
760
752
761 try:
753 try:
762 # In PyPy, types.DictProxyType is dict, setting the dictproxy printer
754 # In PyPy, types.DictProxyType is dict, setting the dictproxy printer
763 # using dict.setdefault avoids overwritting the dict printer
755 # using dict.setdefault avoids overwritting the dict printer
764 _type_pprinters.setdefault(types.DictProxyType,
756 _type_pprinters.setdefault(types.DictProxyType,
765 _dict_pprinter_factory('dict_proxy({', '})'))
757 _dict_pprinter_factory('dict_proxy({', '})'))
766 _type_pprinters[types.ClassType] = _type_pprint
758 _type_pprinters[types.ClassType] = _type_pprint
767 _type_pprinters[types.SliceType] = _repr_pprint
759 _type_pprinters[types.SliceType] = _repr_pprint
768 except AttributeError: # Python 3
760 except AttributeError: # Python 3
769 _type_pprinters[types.MappingProxyType] = \
761 _type_pprinters[types.MappingProxyType] = \
770 _dict_pprinter_factory('mappingproxy({', '})')
762 _dict_pprinter_factory('mappingproxy({', '})')
771 _type_pprinters[slice] = _repr_pprint
763 _type_pprinters[slice] = _repr_pprint
772
764
773 try:
765 try:
774 _type_pprinters[long] = _repr_pprint
766 _type_pprinters[long] = _repr_pprint
775 _type_pprinters[unicode] = _repr_pprint
767 _type_pprinters[unicode] = _repr_pprint
776 except NameError:
768 except NameError:
777 _type_pprinters[range] = _repr_pprint
769 _type_pprinters[range] = _repr_pprint
778 _type_pprinters[bytes] = _repr_pprint
770 _type_pprinters[bytes] = _repr_pprint
779
771
780 #: printers for types specified by name
772 #: printers for types specified by name
781 _deferred_type_pprinters = {
773 _deferred_type_pprinters = {
782 }
774 }
783
775
784 def for_type(typ, func):
776 def for_type(typ, func):
785 """
777 """
786 Add a pretty printer for a given type.
778 Add a pretty printer for a given type.
787 """
779 """
788 oldfunc = _type_pprinters.get(typ, None)
780 oldfunc = _type_pprinters.get(typ, None)
789 if func is not None:
781 if func is not None:
790 # To support easy restoration of old pprinters, we need to ignore Nones.
782 # To support easy restoration of old pprinters, we need to ignore Nones.
791 _type_pprinters[typ] = func
783 _type_pprinters[typ] = func
792 return oldfunc
784 return oldfunc
793
785
794 def for_type_by_name(type_module, type_name, func):
786 def for_type_by_name(type_module, type_name, func):
795 """
787 """
796 Add a pretty printer for a type specified by the module and name of a type
788 Add a pretty printer for a type specified by the module and name of a type
797 rather than the type object itself.
789 rather than the type object itself.
798 """
790 """
799 key = (type_module, type_name)
791 key = (type_module, type_name)
800 oldfunc = _deferred_type_pprinters.get(key, None)
792 oldfunc = _deferred_type_pprinters.get(key, None)
801 if func is not None:
793 if func is not None:
802 # To support easy restoration of old pprinters, we need to ignore Nones.
794 # To support easy restoration of old pprinters, we need to ignore Nones.
803 _deferred_type_pprinters[key] = func
795 _deferred_type_pprinters[key] = func
804 return oldfunc
796 return oldfunc
805
797
806
798
807 #: printers for the default singletons
799 #: printers for the default singletons
808 _singleton_pprinters = dict.fromkeys(map(id, [None, True, False, Ellipsis,
800 _singleton_pprinters = dict.fromkeys(map(id, [None, True, False, Ellipsis,
809 NotImplemented]), _repr_pprint)
801 NotImplemented]), _repr_pprint)
810
802
811
803
812 def _defaultdict_pprint(obj, p, cycle):
804 def _defaultdict_pprint(obj, p, cycle):
813 name = obj.__class__.__name__
805 name = obj.__class__.__name__
814 with p.group(len(name) + 1, name + '(', ')'):
806 with p.group(len(name) + 1, name + '(', ')'):
815 if cycle:
807 if cycle:
816 p.text('...')
808 p.text('...')
817 else:
809 else:
818 p.pretty(obj.default_factory)
810 p.pretty(obj.default_factory)
819 p.text(',')
811 p.text(',')
820 p.breakable()
812 p.breakable()
821 p.pretty(dict(obj))
813 p.pretty(dict(obj))
822
814
823 def _ordereddict_pprint(obj, p, cycle):
815 def _ordereddict_pprint(obj, p, cycle):
824 name = obj.__class__.__name__
816 name = obj.__class__.__name__
825 with p.group(len(name) + 1, name + '(', ')'):
817 with p.group(len(name) + 1, name + '(', ')'):
826 if cycle:
818 if cycle:
827 p.text('...')
819 p.text('...')
828 elif len(obj):
820 elif len(obj):
829 p.pretty(list(obj.items()))
821 p.pretty(list(obj.items()))
830
822
831 def _deque_pprint(obj, p, cycle):
823 def _deque_pprint(obj, p, cycle):
832 name = obj.__class__.__name__
824 name = obj.__class__.__name__
833 with p.group(len(name) + 1, name + '(', ')'):
825 with p.group(len(name) + 1, name + '(', ')'):
834 if cycle:
826 if cycle:
835 p.text('...')
827 p.text('...')
836 else:
828 else:
837 p.pretty(list(obj))
829 p.pretty(list(obj))
838
830
839
831
840 def _counter_pprint(obj, p, cycle):
832 def _counter_pprint(obj, p, cycle):
841 name = obj.__class__.__name__
833 name = obj.__class__.__name__
842 with p.group(len(name) + 1, name + '(', ')'):
834 with p.group(len(name) + 1, name + '(', ')'):
843 if cycle:
835 if cycle:
844 p.text('...')
836 p.text('...')
845 elif len(obj):
837 elif len(obj):
846 p.pretty(dict(obj))
838 p.pretty(dict(obj))
847
839
848 for_type_by_name('collections', 'defaultdict', _defaultdict_pprint)
840 for_type_by_name('collections', 'defaultdict', _defaultdict_pprint)
849 for_type_by_name('collections', 'OrderedDict', _ordereddict_pprint)
841 for_type_by_name('collections', 'OrderedDict', _ordereddict_pprint)
850 for_type_by_name('collections', 'deque', _deque_pprint)
842 for_type_by_name('collections', 'deque', _deque_pprint)
851 for_type_by_name('collections', 'Counter', _counter_pprint)
843 for_type_by_name('collections', 'Counter', _counter_pprint)
852
844
853 if __name__ == '__main__':
845 if __name__ == '__main__':
854 from random import randrange
846 from random import randrange
855 class Foo(object):
847 class Foo(object):
856 def __init__(self):
848 def __init__(self):
857 self.foo = 1
849 self.foo = 1
858 self.bar = re.compile(r'\s+')
850 self.bar = re.compile(r'\s+')
859 self.blub = dict.fromkeys(range(30), randrange(1, 40))
851 self.blub = dict.fromkeys(range(30), randrange(1, 40))
860 self.hehe = 23424.234234
852 self.hehe = 23424.234234
861 self.list = ["blub", "blah", self]
853 self.list = ["blub", "blah", self]
862
854
863 def get_foo(self):
855 def get_foo(self):
864 print("foo")
856 print("foo")
865
857
866 pprint(Foo(), verbose=True)
858 pprint(Foo(), verbose=True)
@@ -1,1183 +1,1177 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Sphinx directive to support embedded IPython code.
3 Sphinx directive to support embedded IPython code.
4
4
5 This directive allows pasting of entire interactive IPython sessions, prompts
5 This directive allows pasting of entire interactive IPython sessions, prompts
6 and all, and their code will actually get re-executed at doc build time, with
6 and all, and their code will actually get re-executed at doc build time, with
7 all prompts renumbered sequentially. It also allows you to input code as a pure
7 all prompts renumbered sequentially. It also allows you to input code as a pure
8 python input by giving the argument python to the directive. The output looks
8 python input by giving the argument python to the directive. The output looks
9 like an interactive ipython section.
9 like an interactive ipython section.
10
10
11 To enable this directive, simply list it in your Sphinx ``conf.py`` file
11 To enable this directive, simply list it in your Sphinx ``conf.py`` file
12 (making sure the directory where you placed it is visible to sphinx, as is
12 (making sure the directory where you placed it is visible to sphinx, as is
13 needed for all Sphinx directives). For example, to enable syntax highlighting
13 needed for all Sphinx directives). For example, to enable syntax highlighting
14 and the IPython directive::
14 and the IPython directive::
15
15
16 extensions = ['IPython.sphinxext.ipython_console_highlighting',
16 extensions = ['IPython.sphinxext.ipython_console_highlighting',
17 'IPython.sphinxext.ipython_directive']
17 'IPython.sphinxext.ipython_directive']
18
18
19 The IPython directive outputs code-blocks with the language 'ipython'. So
19 The IPython directive outputs code-blocks with the language 'ipython'. So
20 if you do not have the syntax highlighting extension enabled as well, then
20 if you do not have the syntax highlighting extension enabled as well, then
21 all rendered code-blocks will be uncolored. By default this directive assumes
21 all rendered code-blocks will be uncolored. By default this directive assumes
22 that your prompts are unchanged IPython ones, but this can be customized.
22 that your prompts are unchanged IPython ones, but this can be customized.
23 The configurable options that can be placed in conf.py are:
23 The configurable options that can be placed in conf.py are:
24
24
25 ipython_savefig_dir:
25 ipython_savefig_dir:
26 The directory in which to save the figures. This is relative to the
26 The directory in which to save the figures. This is relative to the
27 Sphinx source directory. The default is `html_static_path`.
27 Sphinx source directory. The default is `html_static_path`.
28 ipython_rgxin:
28 ipython_rgxin:
29 The compiled regular expression to denote the start of IPython input
29 The compiled regular expression to denote the start of IPython input
30 lines. The default is re.compile('In \[(\d+)\]:\s?(.*)\s*'). You
30 lines. The default is re.compile('In \[(\d+)\]:\s?(.*)\s*'). You
31 shouldn't need to change this.
31 shouldn't need to change this.
32 ipython_rgxout:
32 ipython_rgxout:
33 The compiled regular expression to denote the start of IPython output
33 The compiled regular expression to denote the start of IPython output
34 lines. The default is re.compile('Out\[(\d+)\]:\s?(.*)\s*'). You
34 lines. The default is re.compile('Out\[(\d+)\]:\s?(.*)\s*'). You
35 shouldn't need to change this.
35 shouldn't need to change this.
36 ipython_promptin:
36 ipython_promptin:
37 The string to represent the IPython input prompt in the generated ReST.
37 The string to represent the IPython input prompt in the generated ReST.
38 The default is 'In [%d]:'. This expects that the line numbers are used
38 The default is 'In [%d]:'. This expects that the line numbers are used
39 in the prompt.
39 in the prompt.
40 ipython_promptout:
40 ipython_promptout:
41 The string to represent the IPython prompt in the generated ReST. The
41 The string to represent the IPython prompt in the generated ReST. The
42 default is 'Out [%d]:'. This expects that the line numbers are used
42 default is 'Out [%d]:'. This expects that the line numbers are used
43 in the prompt.
43 in the prompt.
44 ipython_mplbackend:
44 ipython_mplbackend:
45 The string which specifies if the embedded Sphinx shell should import
45 The string which specifies if the embedded Sphinx shell should import
46 Matplotlib and set the backend. The value specifies a backend that is
46 Matplotlib and set the backend. The value specifies a backend that is
47 passed to `matplotlib.use()` before any lines in `ipython_execlines` are
47 passed to `matplotlib.use()` before any lines in `ipython_execlines` are
48 executed. If not specified in conf.py, then the default value of 'agg' is
48 executed. If not specified in conf.py, then the default value of 'agg' is
49 used. To use the IPython directive without matplotlib as a dependency, set
49 used. To use the IPython directive without matplotlib as a dependency, set
50 the value to `None`. It may end up that matplotlib is still imported
50 the value to `None`. It may end up that matplotlib is still imported
51 if the user specifies so in `ipython_execlines` or makes use of the
51 if the user specifies so in `ipython_execlines` or makes use of the
52 @savefig pseudo decorator.
52 @savefig pseudo decorator.
53 ipython_execlines:
53 ipython_execlines:
54 A list of strings to be exec'd in the embedded Sphinx shell. Typical
54 A list of strings to be exec'd in the embedded Sphinx shell. Typical
55 usage is to make certain packages always available. Set this to an empty
55 usage is to make certain packages always available. Set this to an empty
56 list if you wish to have no imports always available. If specified in
56 list if you wish to have no imports always available. If specified in
57 conf.py as `None`, then it has the effect of making no imports available.
57 conf.py as `None`, then it has the effect of making no imports available.
58 If omitted from conf.py altogether, then the default value of
58 If omitted from conf.py altogether, then the default value of
59 ['import numpy as np', 'import matplotlib.pyplot as plt'] is used.
59 ['import numpy as np', 'import matplotlib.pyplot as plt'] is used.
60 ipython_holdcount
60 ipython_holdcount
61 When the @suppress pseudo-decorator is used, the execution count can be
61 When the @suppress pseudo-decorator is used, the execution count can be
62 incremented or not. The default behavior is to hold the execution count,
62 incremented or not. The default behavior is to hold the execution count,
63 corresponding to a value of `True`. Set this to `False` to increment
63 corresponding to a value of `True`. Set this to `False` to increment
64 the execution count after each suppressed command.
64 the execution count after each suppressed command.
65
65
66 As an example, to use the IPython directive when `matplotlib` is not available,
66 As an example, to use the IPython directive when `matplotlib` is not available,
67 one sets the backend to `None`::
67 one sets the backend to `None`::
68
68
69 ipython_mplbackend = None
69 ipython_mplbackend = None
70
70
71 An example usage of the directive is:
71 An example usage of the directive is:
72
72
73 .. code-block:: rst
73 .. code-block:: rst
74
74
75 .. ipython::
75 .. ipython::
76
76
77 In [1]: x = 1
77 In [1]: x = 1
78
78
79 In [2]: y = x**2
79 In [2]: y = x**2
80
80
81 In [3]: print(y)
81 In [3]: print(y)
82
82
83 See http://matplotlib.org/sampledoc/ipython_directive.html for additional
83 See http://matplotlib.org/sampledoc/ipython_directive.html for additional
84 documentation.
84 documentation.
85
85
86 Pseudo-Decorators
86 Pseudo-Decorators
87 =================
87 =================
88
88
89 Note: Only one decorator is supported per input. If more than one decorator
89 Note: Only one decorator is supported per input. If more than one decorator
90 is specified, then only the last one is used.
90 is specified, then only the last one is used.
91
91
92 In addition to the Pseudo-Decorators/options described at the above link,
92 In addition to the Pseudo-Decorators/options described at the above link,
93 several enhancements have been made. The directive will emit a message to the
93 several enhancements have been made. The directive will emit a message to the
94 console at build-time if code-execution resulted in an exception or warning.
94 console at build-time if code-execution resulted in an exception or warning.
95 You can suppress these on a per-block basis by specifying the :okexcept:
95 You can suppress these on a per-block basis by specifying the :okexcept:
96 or :okwarning: options:
96 or :okwarning: options:
97
97
98 .. code-block:: rst
98 .. code-block:: rst
99
99
100 .. ipython::
100 .. ipython::
101 :okexcept:
101 :okexcept:
102 :okwarning:
102 :okwarning:
103
103
104 In [1]: 1/0
104 In [1]: 1/0
105 In [2]: # raise warning.
105 In [2]: # raise warning.
106
106
107 ToDo
107 ToDo
108 ----
108 ----
109
109
110 - Turn the ad-hoc test() function into a real test suite.
110 - Turn the ad-hoc test() function into a real test suite.
111 - Break up ipython-specific functionality from matplotlib stuff into better
111 - Break up ipython-specific functionality from matplotlib stuff into better
112 separated code.
112 separated code.
113
113
114 Authors
114 Authors
115 -------
115 -------
116
116
117 - John D Hunter: orignal author.
117 - John D Hunter: orignal author.
118 - Fernando Perez: refactoring, documentation, cleanups, port to 0.11.
118 - Fernando Perez: refactoring, documentation, cleanups, port to 0.11.
119 - VΓ‘clavΕ milauer <eudoxos-AT-arcig.cz>: Prompt generalizations.
119 - VΓ‘clavΕ milauer <eudoxos-AT-arcig.cz>: Prompt generalizations.
120 - Skipper Seabold, refactoring, cleanups, pure python addition
120 - Skipper Seabold, refactoring, cleanups, pure python addition
121 """
121 """
122
122
123 #-----------------------------------------------------------------------------
123 #-----------------------------------------------------------------------------
124 # Imports
124 # Imports
125 #-----------------------------------------------------------------------------
125 #-----------------------------------------------------------------------------
126
126
127 # Stdlib
127 # Stdlib
128 import atexit
128 import atexit
129 import os
129 import os
130 import re
130 import re
131 import sys
131 import sys
132 import tempfile
132 import tempfile
133 import ast
133 import ast
134 import warnings
134 import warnings
135 import shutil
135 import shutil
136
136 from io import StringIO
137
137
138 # Third-party
138 # Third-party
139 from docutils.parsers.rst import directives
139 from docutils.parsers.rst import directives
140 from sphinx.util.compat import Directive
140 from sphinx.util.compat import Directive
141
141
142 # Our own
142 # Our own
143 from traitlets.config import Config
143 from traitlets.config import Config
144 from IPython import InteractiveShell
144 from IPython import InteractiveShell
145 from IPython.core.profiledir import ProfileDir
145 from IPython.core.profiledir import ProfileDir
146 from IPython.utils import io
146 from IPython.utils import io
147 from IPython.utils.py3compat import PY3
148
149 if PY3:
150 from io import StringIO
151 else:
152 from StringIO import StringIO
153
147
154 #-----------------------------------------------------------------------------
148 #-----------------------------------------------------------------------------
155 # Globals
149 # Globals
156 #-----------------------------------------------------------------------------
150 #-----------------------------------------------------------------------------
157 # for tokenizing blocks
151 # for tokenizing blocks
158 COMMENT, INPUT, OUTPUT = range(3)
152 COMMENT, INPUT, OUTPUT = range(3)
159
153
160 #-----------------------------------------------------------------------------
154 #-----------------------------------------------------------------------------
161 # Functions and class declarations
155 # Functions and class declarations
162 #-----------------------------------------------------------------------------
156 #-----------------------------------------------------------------------------
163
157
164 def block_parser(part, rgxin, rgxout, fmtin, fmtout):
158 def block_parser(part, rgxin, rgxout, fmtin, fmtout):
165 """
159 """
166 part is a string of ipython text, comprised of at most one
160 part is a string of ipython text, comprised of at most one
167 input, one output, comments, and blank lines. The block parser
161 input, one output, comments, and blank lines. The block parser
168 parses the text into a list of::
162 parses the text into a list of::
169
163
170 blocks = [ (TOKEN0, data0), (TOKEN1, data1), ...]
164 blocks = [ (TOKEN0, data0), (TOKEN1, data1), ...]
171
165
172 where TOKEN is one of [COMMENT | INPUT | OUTPUT ] and
166 where TOKEN is one of [COMMENT | INPUT | OUTPUT ] and
173 data is, depending on the type of token::
167 data is, depending on the type of token::
174
168
175 COMMENT : the comment string
169 COMMENT : the comment string
176
170
177 INPUT: the (DECORATOR, INPUT_LINE, REST) where
171 INPUT: the (DECORATOR, INPUT_LINE, REST) where
178 DECORATOR: the input decorator (or None)
172 DECORATOR: the input decorator (or None)
179 INPUT_LINE: the input as string (possibly multi-line)
173 INPUT_LINE: the input as string (possibly multi-line)
180 REST : any stdout generated by the input line (not OUTPUT)
174 REST : any stdout generated by the input line (not OUTPUT)
181
175
182 OUTPUT: the output string, possibly multi-line
176 OUTPUT: the output string, possibly multi-line
183
177
184 """
178 """
185 block = []
179 block = []
186 lines = part.split('\n')
180 lines = part.split('\n')
187 N = len(lines)
181 N = len(lines)
188 i = 0
182 i = 0
189 decorator = None
183 decorator = None
190 while 1:
184 while 1:
191
185
192 if i==N:
186 if i==N:
193 # nothing left to parse -- the last line
187 # nothing left to parse -- the last line
194 break
188 break
195
189
196 line = lines[i]
190 line = lines[i]
197 i += 1
191 i += 1
198 line_stripped = line.strip()
192 line_stripped = line.strip()
199 if line_stripped.startswith('#'):
193 if line_stripped.startswith('#'):
200 block.append((COMMENT, line))
194 block.append((COMMENT, line))
201 continue
195 continue
202
196
203 if line_stripped.startswith('@'):
197 if line_stripped.startswith('@'):
204 # Here is where we assume there is, at most, one decorator.
198 # Here is where we assume there is, at most, one decorator.
205 # Might need to rethink this.
199 # Might need to rethink this.
206 decorator = line_stripped
200 decorator = line_stripped
207 continue
201 continue
208
202
209 # does this look like an input line?
203 # does this look like an input line?
210 matchin = rgxin.match(line)
204 matchin = rgxin.match(line)
211 if matchin:
205 if matchin:
212 lineno, inputline = int(matchin.group(1)), matchin.group(2)
206 lineno, inputline = int(matchin.group(1)), matchin.group(2)
213
207
214 # the ....: continuation string
208 # the ....: continuation string
215 continuation = ' %s:'%''.join(['.']*(len(str(lineno))+2))
209 continuation = ' %s:'%''.join(['.']*(len(str(lineno))+2))
216 Nc = len(continuation)
210 Nc = len(continuation)
217 # input lines can continue on for more than one line, if
211 # input lines can continue on for more than one line, if
218 # we have a '\' line continuation char or a function call
212 # we have a '\' line continuation char or a function call
219 # echo line 'print'. The input line can only be
213 # echo line 'print'. The input line can only be
220 # terminated by the end of the block or an output line, so
214 # terminated by the end of the block or an output line, so
221 # we parse out the rest of the input line if it is
215 # we parse out the rest of the input line if it is
222 # multiline as well as any echo text
216 # multiline as well as any echo text
223
217
224 rest = []
218 rest = []
225 while i<N:
219 while i<N:
226
220
227 # look ahead; if the next line is blank, or a comment, or
221 # look ahead; if the next line is blank, or a comment, or
228 # an output line, we're done
222 # an output line, we're done
229
223
230 nextline = lines[i]
224 nextline = lines[i]
231 matchout = rgxout.match(nextline)
225 matchout = rgxout.match(nextline)
232 #print "nextline=%s, continuation=%s, starts=%s"%(nextline, continuation, nextline.startswith(continuation))
226 #print "nextline=%s, continuation=%s, starts=%s"%(nextline, continuation, nextline.startswith(continuation))
233 if matchout or nextline.startswith('#'):
227 if matchout or nextline.startswith('#'):
234 break
228 break
235 elif nextline.startswith(continuation):
229 elif nextline.startswith(continuation):
236 # The default ipython_rgx* treat the space following the colon as optional.
230 # The default ipython_rgx* treat the space following the colon as optional.
237 # However, If the space is there we must consume it or code
231 # However, If the space is there we must consume it or code
238 # employing the cython_magic extension will fail to execute.
232 # employing the cython_magic extension will fail to execute.
239 #
233 #
240 # This works with the default ipython_rgx* patterns,
234 # This works with the default ipython_rgx* patterns,
241 # If you modify them, YMMV.
235 # If you modify them, YMMV.
242 nextline = nextline[Nc:]
236 nextline = nextline[Nc:]
243 if nextline and nextline[0] == ' ':
237 if nextline and nextline[0] == ' ':
244 nextline = nextline[1:]
238 nextline = nextline[1:]
245
239
246 inputline += '\n' + nextline
240 inputline += '\n' + nextline
247 else:
241 else:
248 rest.append(nextline)
242 rest.append(nextline)
249 i+= 1
243 i+= 1
250
244
251 block.append((INPUT, (decorator, inputline, '\n'.join(rest))))
245 block.append((INPUT, (decorator, inputline, '\n'.join(rest))))
252 continue
246 continue
253
247
254 # if it looks like an output line grab all the text to the end
248 # if it looks like an output line grab all the text to the end
255 # of the block
249 # of the block
256 matchout = rgxout.match(line)
250 matchout = rgxout.match(line)
257 if matchout:
251 if matchout:
258 lineno, output = int(matchout.group(1)), matchout.group(2)
252 lineno, output = int(matchout.group(1)), matchout.group(2)
259 if i<N-1:
253 if i<N-1:
260 output = '\n'.join([output] + lines[i:])
254 output = '\n'.join([output] + lines[i:])
261
255
262 block.append((OUTPUT, output))
256 block.append((OUTPUT, output))
263 break
257 break
264
258
265 return block
259 return block
266
260
267
261
268 class EmbeddedSphinxShell(object):
262 class EmbeddedSphinxShell(object):
269 """An embedded IPython instance to run inside Sphinx"""
263 """An embedded IPython instance to run inside Sphinx"""
270
264
271 def __init__(self, exec_lines=None):
265 def __init__(self, exec_lines=None):
272
266
273 self.cout = StringIO()
267 self.cout = StringIO()
274
268
275 if exec_lines is None:
269 if exec_lines is None:
276 exec_lines = []
270 exec_lines = []
277
271
278 # Create config object for IPython
272 # Create config object for IPython
279 config = Config()
273 config = Config()
280 config.HistoryManager.hist_file = ':memory:'
274 config.HistoryManager.hist_file = ':memory:'
281 config.InteractiveShell.autocall = False
275 config.InteractiveShell.autocall = False
282 config.InteractiveShell.autoindent = False
276 config.InteractiveShell.autoindent = False
283 config.InteractiveShell.colors = 'NoColor'
277 config.InteractiveShell.colors = 'NoColor'
284
278
285 # create a profile so instance history isn't saved
279 # create a profile so instance history isn't saved
286 tmp_profile_dir = tempfile.mkdtemp(prefix='profile_')
280 tmp_profile_dir = tempfile.mkdtemp(prefix='profile_')
287 profname = 'auto_profile_sphinx_build'
281 profname = 'auto_profile_sphinx_build'
288 pdir = os.path.join(tmp_profile_dir,profname)
282 pdir = os.path.join(tmp_profile_dir,profname)
289 profile = ProfileDir.create_profile_dir(pdir)
283 profile = ProfileDir.create_profile_dir(pdir)
290
284
291 # Create and initialize global ipython, but don't start its mainloop.
285 # Create and initialize global ipython, but don't start its mainloop.
292 # This will persist across different EmbededSphinxShell instances.
286 # This will persist across different EmbededSphinxShell instances.
293 IP = InteractiveShell.instance(config=config, profile_dir=profile)
287 IP = InteractiveShell.instance(config=config, profile_dir=profile)
294 atexit.register(self.cleanup)
288 atexit.register(self.cleanup)
295
289
296 sys.stdout = self.cout
290 sys.stdout = self.cout
297 sys.stderr = self.cout
291 sys.stderr = self.cout
298
292
299 # For debugging, so we can see normal output, use this:
293 # For debugging, so we can see normal output, use this:
300 #from IPython.utils.io import Tee
294 #from IPython.utils.io import Tee
301 #sys.stdout = Tee(self.cout, channel='stdout') # dbg
295 #sys.stdout = Tee(self.cout, channel='stdout') # dbg
302 #sys.stderr = Tee(self.cout, channel='stderr') # dbg
296 #sys.stderr = Tee(self.cout, channel='stderr') # dbg
303
297
304 # Store a few parts of IPython we'll need.
298 # Store a few parts of IPython we'll need.
305 self.IP = IP
299 self.IP = IP
306 self.user_ns = self.IP.user_ns
300 self.user_ns = self.IP.user_ns
307 self.user_global_ns = self.IP.user_global_ns
301 self.user_global_ns = self.IP.user_global_ns
308
302
309 self.input = ''
303 self.input = ''
310 self.output = ''
304 self.output = ''
311 self.tmp_profile_dir = tmp_profile_dir
305 self.tmp_profile_dir = tmp_profile_dir
312
306
313 self.is_verbatim = False
307 self.is_verbatim = False
314 self.is_doctest = False
308 self.is_doctest = False
315 self.is_suppress = False
309 self.is_suppress = False
316
310
317 # Optionally, provide more detailed information to shell.
311 # Optionally, provide more detailed information to shell.
318 # this is assigned by the SetUp method of IPythonDirective
312 # this is assigned by the SetUp method of IPythonDirective
319 # to point at itself.
313 # to point at itself.
320 #
314 #
321 # So, you can access handy things at self.directive.state
315 # So, you can access handy things at self.directive.state
322 self.directive = None
316 self.directive = None
323
317
324 # on the first call to the savefig decorator, we'll import
318 # on the first call to the savefig decorator, we'll import
325 # pyplot as plt so we can make a call to the plt.gcf().savefig
319 # pyplot as plt so we can make a call to the plt.gcf().savefig
326 self._pyplot_imported = False
320 self._pyplot_imported = False
327
321
328 # Prepopulate the namespace.
322 # Prepopulate the namespace.
329 for line in exec_lines:
323 for line in exec_lines:
330 self.process_input_line(line, store_history=False)
324 self.process_input_line(line, store_history=False)
331
325
332 def cleanup(self):
326 def cleanup(self):
333 shutil.rmtree(self.tmp_profile_dir, ignore_errors=True)
327 shutil.rmtree(self.tmp_profile_dir, ignore_errors=True)
334
328
335 def clear_cout(self):
329 def clear_cout(self):
336 self.cout.seek(0)
330 self.cout.seek(0)
337 self.cout.truncate(0)
331 self.cout.truncate(0)
338
332
339 def process_input_line(self, line, store_history=True):
333 def process_input_line(self, line, store_history=True):
340 """process the input, capturing stdout"""
334 """process the input, capturing stdout"""
341
335
342 stdout = sys.stdout
336 stdout = sys.stdout
343 splitter = self.IP.input_splitter
337 splitter = self.IP.input_splitter
344 try:
338 try:
345 sys.stdout = self.cout
339 sys.stdout = self.cout
346 splitter.push(line)
340 splitter.push(line)
347 more = splitter.push_accepts_more()
341 more = splitter.push_accepts_more()
348 if not more:
342 if not more:
349 source_raw = splitter.raw_reset()
343 source_raw = splitter.raw_reset()
350 self.IP.run_cell(source_raw, store_history=store_history)
344 self.IP.run_cell(source_raw, store_history=store_history)
351 finally:
345 finally:
352 sys.stdout = stdout
346 sys.stdout = stdout
353
347
354 def process_image(self, decorator):
348 def process_image(self, decorator):
355 """
349 """
356 # build out an image directive like
350 # build out an image directive like
357 # .. image:: somefile.png
351 # .. image:: somefile.png
358 # :width 4in
352 # :width 4in
359 #
353 #
360 # from an input like
354 # from an input like
361 # savefig somefile.png width=4in
355 # savefig somefile.png width=4in
362 """
356 """
363 savefig_dir = self.savefig_dir
357 savefig_dir = self.savefig_dir
364 source_dir = self.source_dir
358 source_dir = self.source_dir
365 saveargs = decorator.split(' ')
359 saveargs = decorator.split(' ')
366 filename = saveargs[1]
360 filename = saveargs[1]
367 # insert relative path to image file in source
361 # insert relative path to image file in source
368 outfile = os.path.relpath(os.path.join(savefig_dir,filename),
362 outfile = os.path.relpath(os.path.join(savefig_dir,filename),
369 source_dir)
363 source_dir)
370
364
371 imagerows = ['.. image:: %s'%outfile]
365 imagerows = ['.. image:: %s'%outfile]
372
366
373 for kwarg in saveargs[2:]:
367 for kwarg in saveargs[2:]:
374 arg, val = kwarg.split('=')
368 arg, val = kwarg.split('=')
375 arg = arg.strip()
369 arg = arg.strip()
376 val = val.strip()
370 val = val.strip()
377 imagerows.append(' :%s: %s'%(arg, val))
371 imagerows.append(' :%s: %s'%(arg, val))
378
372
379 image_file = os.path.basename(outfile) # only return file name
373 image_file = os.path.basename(outfile) # only return file name
380 image_directive = '\n'.join(imagerows)
374 image_directive = '\n'.join(imagerows)
381 return image_file, image_directive
375 return image_file, image_directive
382
376
383 # Callbacks for each type of token
377 # Callbacks for each type of token
384 def process_input(self, data, input_prompt, lineno):
378 def process_input(self, data, input_prompt, lineno):
385 """
379 """
386 Process data block for INPUT token.
380 Process data block for INPUT token.
387
381
388 """
382 """
389 decorator, input, rest = data
383 decorator, input, rest = data
390 image_file = None
384 image_file = None
391 image_directive = None
385 image_directive = None
392
386
393 is_verbatim = decorator=='@verbatim' or self.is_verbatim
387 is_verbatim = decorator=='@verbatim' or self.is_verbatim
394 is_doctest = (decorator is not None and \
388 is_doctest = (decorator is not None and \
395 decorator.startswith('@doctest')) or self.is_doctest
389 decorator.startswith('@doctest')) or self.is_doctest
396 is_suppress = decorator=='@suppress' or self.is_suppress
390 is_suppress = decorator=='@suppress' or self.is_suppress
397 is_okexcept = decorator=='@okexcept' or self.is_okexcept
391 is_okexcept = decorator=='@okexcept' or self.is_okexcept
398 is_okwarning = decorator=='@okwarning' or self.is_okwarning
392 is_okwarning = decorator=='@okwarning' or self.is_okwarning
399 is_savefig = decorator is not None and \
393 is_savefig = decorator is not None and \
400 decorator.startswith('@savefig')
394 decorator.startswith('@savefig')
401
395
402 input_lines = input.split('\n')
396 input_lines = input.split('\n')
403 if len(input_lines) > 1:
397 if len(input_lines) > 1:
404 if input_lines[-1] != "":
398 if input_lines[-1] != "":
405 input_lines.append('') # make sure there's a blank line
399 input_lines.append('') # make sure there's a blank line
406 # so splitter buffer gets reset
400 # so splitter buffer gets reset
407
401
408 continuation = ' %s:'%''.join(['.']*(len(str(lineno))+2))
402 continuation = ' %s:'%''.join(['.']*(len(str(lineno))+2))
409
403
410 if is_savefig:
404 if is_savefig:
411 image_file, image_directive = self.process_image(decorator)
405 image_file, image_directive = self.process_image(decorator)
412
406
413 ret = []
407 ret = []
414 is_semicolon = False
408 is_semicolon = False
415
409
416 # Hold the execution count, if requested to do so.
410 # Hold the execution count, if requested to do so.
417 if is_suppress and self.hold_count:
411 if is_suppress and self.hold_count:
418 store_history = False
412 store_history = False
419 else:
413 else:
420 store_history = True
414 store_history = True
421
415
422 # Note: catch_warnings is not thread safe
416 # Note: catch_warnings is not thread safe
423 with warnings.catch_warnings(record=True) as ws:
417 with warnings.catch_warnings(record=True) as ws:
424 for i, line in enumerate(input_lines):
418 for i, line in enumerate(input_lines):
425 if line.endswith(';'):
419 if line.endswith(';'):
426 is_semicolon = True
420 is_semicolon = True
427
421
428 if i == 0:
422 if i == 0:
429 # process the first input line
423 # process the first input line
430 if is_verbatim:
424 if is_verbatim:
431 self.process_input_line('')
425 self.process_input_line('')
432 self.IP.execution_count += 1 # increment it anyway
426 self.IP.execution_count += 1 # increment it anyway
433 else:
427 else:
434 # only submit the line in non-verbatim mode
428 # only submit the line in non-verbatim mode
435 self.process_input_line(line, store_history=store_history)
429 self.process_input_line(line, store_history=store_history)
436 formatted_line = '%s %s'%(input_prompt, line)
430 formatted_line = '%s %s'%(input_prompt, line)
437 else:
431 else:
438 # process a continuation line
432 # process a continuation line
439 if not is_verbatim:
433 if not is_verbatim:
440 self.process_input_line(line, store_history=store_history)
434 self.process_input_line(line, store_history=store_history)
441
435
442 formatted_line = '%s %s'%(continuation, line)
436 formatted_line = '%s %s'%(continuation, line)
443
437
444 if not is_suppress:
438 if not is_suppress:
445 ret.append(formatted_line)
439 ret.append(formatted_line)
446
440
447 if not is_suppress and len(rest.strip()) and is_verbatim:
441 if not is_suppress and len(rest.strip()) and is_verbatim:
448 # The "rest" is the standard output of the input. This needs to be
442 # The "rest" is the standard output of the input. This needs to be
449 # added when in verbatim mode. If there is no "rest", then we don't
443 # added when in verbatim mode. If there is no "rest", then we don't
450 # add it, as the new line will be added by the processed output.
444 # add it, as the new line will be added by the processed output.
451 ret.append(rest)
445 ret.append(rest)
452
446
453 # Fetch the processed output. (This is not the submitted output.)
447 # Fetch the processed output. (This is not the submitted output.)
454 self.cout.seek(0)
448 self.cout.seek(0)
455 processed_output = self.cout.read()
449 processed_output = self.cout.read()
456 if not is_suppress and not is_semicolon:
450 if not is_suppress and not is_semicolon:
457 #
451 #
458 # In IPythonDirective.run, the elements of `ret` are eventually
452 # In IPythonDirective.run, the elements of `ret` are eventually
459 # combined such that '' entries correspond to newlines. So if
453 # combined such that '' entries correspond to newlines. So if
460 # `processed_output` is equal to '', then the adding it to `ret`
454 # `processed_output` is equal to '', then the adding it to `ret`
461 # ensures that there is a blank line between consecutive inputs
455 # ensures that there is a blank line between consecutive inputs
462 # that have no outputs, as in:
456 # that have no outputs, as in:
463 #
457 #
464 # In [1]: x = 4
458 # In [1]: x = 4
465 #
459 #
466 # In [2]: x = 5
460 # In [2]: x = 5
467 #
461 #
468 # When there is processed output, it has a '\n' at the tail end. So
462 # When there is processed output, it has a '\n' at the tail end. So
469 # adding the output to `ret` will provide the necessary spacing
463 # adding the output to `ret` will provide the necessary spacing
470 # between consecutive input/output blocks, as in:
464 # between consecutive input/output blocks, as in:
471 #
465 #
472 # In [1]: x
466 # In [1]: x
473 # Out[1]: 5
467 # Out[1]: 5
474 #
468 #
475 # In [2]: x
469 # In [2]: x
476 # Out[2]: 5
470 # Out[2]: 5
477 #
471 #
478 # When there is stdout from the input, it also has a '\n' at the
472 # When there is stdout from the input, it also has a '\n' at the
479 # tail end, and so this ensures proper spacing as well. E.g.:
473 # tail end, and so this ensures proper spacing as well. E.g.:
480 #
474 #
481 # In [1]: print x
475 # In [1]: print x
482 # 5
476 # 5
483 #
477 #
484 # In [2]: x = 5
478 # In [2]: x = 5
485 #
479 #
486 # When in verbatim mode, `processed_output` is empty (because
480 # When in verbatim mode, `processed_output` is empty (because
487 # nothing was passed to IP. Sometimes the submitted code block has
481 # nothing was passed to IP. Sometimes the submitted code block has
488 # an Out[] portion and sometimes it does not. When it does not, we
482 # an Out[] portion and sometimes it does not. When it does not, we
489 # need to ensure proper spacing, so we have to add '' to `ret`.
483 # need to ensure proper spacing, so we have to add '' to `ret`.
490 # However, if there is an Out[] in the submitted code, then we do
484 # However, if there is an Out[] in the submitted code, then we do
491 # not want to add a newline as `process_output` has stuff to add.
485 # not want to add a newline as `process_output` has stuff to add.
492 # The difficulty is that `process_input` doesn't know if
486 # The difficulty is that `process_input` doesn't know if
493 # `process_output` will be called---so it doesn't know if there is
487 # `process_output` will be called---so it doesn't know if there is
494 # Out[] in the code block. The requires that we include a hack in
488 # Out[] in the code block. The requires that we include a hack in
495 # `process_block`. See the comments there.
489 # `process_block`. See the comments there.
496 #
490 #
497 ret.append(processed_output)
491 ret.append(processed_output)
498 elif is_semicolon:
492 elif is_semicolon:
499 # Make sure there is a newline after the semicolon.
493 # Make sure there is a newline after the semicolon.
500 ret.append('')
494 ret.append('')
501
495
502 # context information
496 # context information
503 filename = "Unknown"
497 filename = "Unknown"
504 lineno = 0
498 lineno = 0
505 if self.directive.state:
499 if self.directive.state:
506 filename = self.directive.state.document.current_source
500 filename = self.directive.state.document.current_source
507 lineno = self.directive.state.document.current_line
501 lineno = self.directive.state.document.current_line
508
502
509 # output any exceptions raised during execution to stdout
503 # output any exceptions raised during execution to stdout
510 # unless :okexcept: has been specified.
504 # unless :okexcept: has been specified.
511 if not is_okexcept and "Traceback" in processed_output:
505 if not is_okexcept and "Traceback" in processed_output:
512 s = "\nException in %s at block ending on line %s\n" % (filename, lineno)
506 s = "\nException in %s at block ending on line %s\n" % (filename, lineno)
513 s += "Specify :okexcept: as an option in the ipython:: block to suppress this message\n"
507 s += "Specify :okexcept: as an option in the ipython:: block to suppress this message\n"
514 sys.stdout.write('\n\n>>>' + ('-' * 73))
508 sys.stdout.write('\n\n>>>' + ('-' * 73))
515 sys.stdout.write(s)
509 sys.stdout.write(s)
516 sys.stdout.write(processed_output)
510 sys.stdout.write(processed_output)
517 sys.stdout.write('<<<' + ('-' * 73) + '\n\n')
511 sys.stdout.write('<<<' + ('-' * 73) + '\n\n')
518
512
519 # output any warning raised during execution to stdout
513 # output any warning raised during execution to stdout
520 # unless :okwarning: has been specified.
514 # unless :okwarning: has been specified.
521 if not is_okwarning:
515 if not is_okwarning:
522 for w in ws:
516 for w in ws:
523 s = "\nWarning in %s at block ending on line %s\n" % (filename, lineno)
517 s = "\nWarning in %s at block ending on line %s\n" % (filename, lineno)
524 s += "Specify :okwarning: as an option in the ipython:: block to suppress this message\n"
518 s += "Specify :okwarning: as an option in the ipython:: block to suppress this message\n"
525 sys.stdout.write('\n\n>>>' + ('-' * 73))
519 sys.stdout.write('\n\n>>>' + ('-' * 73))
526 sys.stdout.write(s)
520 sys.stdout.write(s)
527 sys.stdout.write(('-' * 76) + '\n')
521 sys.stdout.write(('-' * 76) + '\n')
528 s=warnings.formatwarning(w.message, w.category,
522 s=warnings.formatwarning(w.message, w.category,
529 w.filename, w.lineno, w.line)
523 w.filename, w.lineno, w.line)
530 sys.stdout.write(s)
524 sys.stdout.write(s)
531 sys.stdout.write('<<<' + ('-' * 73) + '\n')
525 sys.stdout.write('<<<' + ('-' * 73) + '\n')
532
526
533 self.cout.truncate(0)
527 self.cout.truncate(0)
534
528
535 return (ret, input_lines, processed_output,
529 return (ret, input_lines, processed_output,
536 is_doctest, decorator, image_file, image_directive)
530 is_doctest, decorator, image_file, image_directive)
537
531
538
532
539 def process_output(self, data, output_prompt, input_lines, output,
533 def process_output(self, data, output_prompt, input_lines, output,
540 is_doctest, decorator, image_file):
534 is_doctest, decorator, image_file):
541 """
535 """
542 Process data block for OUTPUT token.
536 Process data block for OUTPUT token.
543
537
544 """
538 """
545 # Recall: `data` is the submitted output, and `output` is the processed
539 # Recall: `data` is the submitted output, and `output` is the processed
546 # output from `input_lines`.
540 # output from `input_lines`.
547
541
548 TAB = ' ' * 4
542 TAB = ' ' * 4
549
543
550 if is_doctest and output is not None:
544 if is_doctest and output is not None:
551
545
552 found = output # This is the processed output
546 found = output # This is the processed output
553 found = found.strip()
547 found = found.strip()
554 submitted = data.strip()
548 submitted = data.strip()
555
549
556 if self.directive is None:
550 if self.directive is None:
557 source = 'Unavailable'
551 source = 'Unavailable'
558 content = 'Unavailable'
552 content = 'Unavailable'
559 else:
553 else:
560 source = self.directive.state.document.current_source
554 source = self.directive.state.document.current_source
561 content = self.directive.content
555 content = self.directive.content
562 # Add tabs and join into a single string.
556 # Add tabs and join into a single string.
563 content = '\n'.join([TAB + line for line in content])
557 content = '\n'.join([TAB + line for line in content])
564
558
565 # Make sure the output contains the output prompt.
559 # Make sure the output contains the output prompt.
566 ind = found.find(output_prompt)
560 ind = found.find(output_prompt)
567 if ind < 0:
561 if ind < 0:
568 e = ('output does not contain output prompt\n\n'
562 e = ('output does not contain output prompt\n\n'
569 'Document source: {0}\n\n'
563 'Document source: {0}\n\n'
570 'Raw content: \n{1}\n\n'
564 'Raw content: \n{1}\n\n'
571 'Input line(s):\n{TAB}{2}\n\n'
565 'Input line(s):\n{TAB}{2}\n\n'
572 'Output line(s):\n{TAB}{3}\n\n')
566 'Output line(s):\n{TAB}{3}\n\n')
573 e = e.format(source, content, '\n'.join(input_lines),
567 e = e.format(source, content, '\n'.join(input_lines),
574 repr(found), TAB=TAB)
568 repr(found), TAB=TAB)
575 raise RuntimeError(e)
569 raise RuntimeError(e)
576 found = found[len(output_prompt):].strip()
570 found = found[len(output_prompt):].strip()
577
571
578 # Handle the actual doctest comparison.
572 # Handle the actual doctest comparison.
579 if decorator.strip() == '@doctest':
573 if decorator.strip() == '@doctest':
580 # Standard doctest
574 # Standard doctest
581 if found != submitted:
575 if found != submitted:
582 e = ('doctest failure\n\n'
576 e = ('doctest failure\n\n'
583 'Document source: {0}\n\n'
577 'Document source: {0}\n\n'
584 'Raw content: \n{1}\n\n'
578 'Raw content: \n{1}\n\n'
585 'On input line(s):\n{TAB}{2}\n\n'
579 'On input line(s):\n{TAB}{2}\n\n'
586 'we found output:\n{TAB}{3}\n\n'
580 'we found output:\n{TAB}{3}\n\n'
587 'instead of the expected:\n{TAB}{4}\n\n')
581 'instead of the expected:\n{TAB}{4}\n\n')
588 e = e.format(source, content, '\n'.join(input_lines),
582 e = e.format(source, content, '\n'.join(input_lines),
589 repr(found), repr(submitted), TAB=TAB)
583 repr(found), repr(submitted), TAB=TAB)
590 raise RuntimeError(e)
584 raise RuntimeError(e)
591 else:
585 else:
592 self.custom_doctest(decorator, input_lines, found, submitted)
586 self.custom_doctest(decorator, input_lines, found, submitted)
593
587
594 # When in verbatim mode, this holds additional submitted output
588 # When in verbatim mode, this holds additional submitted output
595 # to be written in the final Sphinx output.
589 # to be written in the final Sphinx output.
596 # https://github.com/ipython/ipython/issues/5776
590 # https://github.com/ipython/ipython/issues/5776
597 out_data = []
591 out_data = []
598
592
599 is_verbatim = decorator=='@verbatim' or self.is_verbatim
593 is_verbatim = decorator=='@verbatim' or self.is_verbatim
600 if is_verbatim and data.strip():
594 if is_verbatim and data.strip():
601 # Note that `ret` in `process_block` has '' as its last element if
595 # Note that `ret` in `process_block` has '' as its last element if
602 # the code block was in verbatim mode. So if there is no submitted
596 # the code block was in verbatim mode. So if there is no submitted
603 # output, then we will have proper spacing only if we do not add
597 # output, then we will have proper spacing only if we do not add
604 # an additional '' to `out_data`. This is why we condition on
598 # an additional '' to `out_data`. This is why we condition on
605 # `and data.strip()`.
599 # `and data.strip()`.
606
600
607 # The submitted output has no output prompt. If we want the
601 # The submitted output has no output prompt. If we want the
608 # prompt and the code to appear, we need to join them now
602 # prompt and the code to appear, we need to join them now
609 # instead of adding them separately---as this would create an
603 # instead of adding them separately---as this would create an
610 # undesired newline. How we do this ultimately depends on the
604 # undesired newline. How we do this ultimately depends on the
611 # format of the output regex. I'll do what works for the default
605 # format of the output regex. I'll do what works for the default
612 # prompt for now, and we might have to adjust if it doesn't work
606 # prompt for now, and we might have to adjust if it doesn't work
613 # in other cases. Finally, the submitted output does not have
607 # in other cases. Finally, the submitted output does not have
614 # a trailing newline, so we must add it manually.
608 # a trailing newline, so we must add it manually.
615 out_data.append("{0} {1}\n".format(output_prompt, data))
609 out_data.append("{0} {1}\n".format(output_prompt, data))
616
610
617 return out_data
611 return out_data
618
612
619 def process_comment(self, data):
613 def process_comment(self, data):
620 """Process data fPblock for COMMENT token."""
614 """Process data fPblock for COMMENT token."""
621 if not self.is_suppress:
615 if not self.is_suppress:
622 return [data]
616 return [data]
623
617
624 def save_image(self, image_file):
618 def save_image(self, image_file):
625 """
619 """
626 Saves the image file to disk.
620 Saves the image file to disk.
627 """
621 """
628 self.ensure_pyplot()
622 self.ensure_pyplot()
629 command = 'plt.gcf().savefig("%s")'%image_file
623 command = 'plt.gcf().savefig("%s")'%image_file
630 #print 'SAVEFIG', command # dbg
624 #print 'SAVEFIG', command # dbg
631 self.process_input_line('bookmark ipy_thisdir', store_history=False)
625 self.process_input_line('bookmark ipy_thisdir', store_history=False)
632 self.process_input_line('cd -b ipy_savedir', store_history=False)
626 self.process_input_line('cd -b ipy_savedir', store_history=False)
633 self.process_input_line(command, store_history=False)
627 self.process_input_line(command, store_history=False)
634 self.process_input_line('cd -b ipy_thisdir', store_history=False)
628 self.process_input_line('cd -b ipy_thisdir', store_history=False)
635 self.process_input_line('bookmark -d ipy_thisdir', store_history=False)
629 self.process_input_line('bookmark -d ipy_thisdir', store_history=False)
636 self.clear_cout()
630 self.clear_cout()
637
631
638 def process_block(self, block):
632 def process_block(self, block):
639 """
633 """
640 process block from the block_parser and return a list of processed lines
634 process block from the block_parser and return a list of processed lines
641 """
635 """
642 ret = []
636 ret = []
643 output = None
637 output = None
644 input_lines = None
638 input_lines = None
645 lineno = self.IP.execution_count
639 lineno = self.IP.execution_count
646
640
647 input_prompt = self.promptin % lineno
641 input_prompt = self.promptin % lineno
648 output_prompt = self.promptout % lineno
642 output_prompt = self.promptout % lineno
649 image_file = None
643 image_file = None
650 image_directive = None
644 image_directive = None
651
645
652 found_input = False
646 found_input = False
653 for token, data in block:
647 for token, data in block:
654 if token == COMMENT:
648 if token == COMMENT:
655 out_data = self.process_comment(data)
649 out_data = self.process_comment(data)
656 elif token == INPUT:
650 elif token == INPUT:
657 found_input = True
651 found_input = True
658 (out_data, input_lines, output, is_doctest,
652 (out_data, input_lines, output, is_doctest,
659 decorator, image_file, image_directive) = \
653 decorator, image_file, image_directive) = \
660 self.process_input(data, input_prompt, lineno)
654 self.process_input(data, input_prompt, lineno)
661 elif token == OUTPUT:
655 elif token == OUTPUT:
662 if not found_input:
656 if not found_input:
663
657
664 TAB = ' ' * 4
658 TAB = ' ' * 4
665 linenumber = 0
659 linenumber = 0
666 source = 'Unavailable'
660 source = 'Unavailable'
667 content = 'Unavailable'
661 content = 'Unavailable'
668 if self.directive:
662 if self.directive:
669 linenumber = self.directive.state.document.current_line
663 linenumber = self.directive.state.document.current_line
670 source = self.directive.state.document.current_source
664 source = self.directive.state.document.current_source
671 content = self.directive.content
665 content = self.directive.content
672 # Add tabs and join into a single string.
666 # Add tabs and join into a single string.
673 content = '\n'.join([TAB + line for line in content])
667 content = '\n'.join([TAB + line for line in content])
674
668
675 e = ('\n\nInvalid block: Block contains an output prompt '
669 e = ('\n\nInvalid block: Block contains an output prompt '
676 'without an input prompt.\n\n'
670 'without an input prompt.\n\n'
677 'Document source: {0}\n\n'
671 'Document source: {0}\n\n'
678 'Content begins at line {1}: \n\n{2}\n\n'
672 'Content begins at line {1}: \n\n{2}\n\n'
679 'Problematic block within content: \n\n{TAB}{3}\n\n')
673 'Problematic block within content: \n\n{TAB}{3}\n\n')
680 e = e.format(source, linenumber, content, block, TAB=TAB)
674 e = e.format(source, linenumber, content, block, TAB=TAB)
681
675
682 # Write, rather than include in exception, since Sphinx
676 # Write, rather than include in exception, since Sphinx
683 # will truncate tracebacks.
677 # will truncate tracebacks.
684 sys.stdout.write(e)
678 sys.stdout.write(e)
685 raise RuntimeError('An invalid block was detected.')
679 raise RuntimeError('An invalid block was detected.')
686
680
687 out_data = \
681 out_data = \
688 self.process_output(data, output_prompt, input_lines,
682 self.process_output(data, output_prompt, input_lines,
689 output, is_doctest, decorator,
683 output, is_doctest, decorator,
690 image_file)
684 image_file)
691 if out_data:
685 if out_data:
692 # Then there was user submitted output in verbatim mode.
686 # Then there was user submitted output in verbatim mode.
693 # We need to remove the last element of `ret` that was
687 # We need to remove the last element of `ret` that was
694 # added in `process_input`, as it is '' and would introduce
688 # added in `process_input`, as it is '' and would introduce
695 # an undesirable newline.
689 # an undesirable newline.
696 assert(ret[-1] == '')
690 assert(ret[-1] == '')
697 del ret[-1]
691 del ret[-1]
698
692
699 if out_data:
693 if out_data:
700 ret.extend(out_data)
694 ret.extend(out_data)
701
695
702 # save the image files
696 # save the image files
703 if image_file is not None:
697 if image_file is not None:
704 self.save_image(image_file)
698 self.save_image(image_file)
705
699
706 return ret, image_directive
700 return ret, image_directive
707
701
708 def ensure_pyplot(self):
702 def ensure_pyplot(self):
709 """
703 """
710 Ensures that pyplot has been imported into the embedded IPython shell.
704 Ensures that pyplot has been imported into the embedded IPython shell.
711
705
712 Also, makes sure to set the backend appropriately if not set already.
706 Also, makes sure to set the backend appropriately if not set already.
713
707
714 """
708 """
715 # We are here if the @figure pseudo decorator was used. Thus, it's
709 # We are here if the @figure pseudo decorator was used. Thus, it's
716 # possible that we could be here even if python_mplbackend were set to
710 # possible that we could be here even if python_mplbackend were set to
717 # `None`. That's also strange and perhaps worthy of raising an
711 # `None`. That's also strange and perhaps worthy of raising an
718 # exception, but for now, we just set the backend to 'agg'.
712 # exception, but for now, we just set the backend to 'agg'.
719
713
720 if not self._pyplot_imported:
714 if not self._pyplot_imported:
721 if 'matplotlib.backends' not in sys.modules:
715 if 'matplotlib.backends' not in sys.modules:
722 # Then ipython_matplotlib was set to None but there was a
716 # Then ipython_matplotlib was set to None but there was a
723 # call to the @figure decorator (and ipython_execlines did
717 # call to the @figure decorator (and ipython_execlines did
724 # not set a backend).
718 # not set a backend).
725 #raise Exception("No backend was set, but @figure was used!")
719 #raise Exception("No backend was set, but @figure was used!")
726 import matplotlib
720 import matplotlib
727 matplotlib.use('agg')
721 matplotlib.use('agg')
728
722
729 # Always import pyplot into embedded shell.
723 # Always import pyplot into embedded shell.
730 self.process_input_line('import matplotlib.pyplot as plt',
724 self.process_input_line('import matplotlib.pyplot as plt',
731 store_history=False)
725 store_history=False)
732 self._pyplot_imported = True
726 self._pyplot_imported = True
733
727
734 def process_pure_python(self, content):
728 def process_pure_python(self, content):
735 """
729 """
736 content is a list of strings. it is unedited directive content
730 content is a list of strings. it is unedited directive content
737
731
738 This runs it line by line in the InteractiveShell, prepends
732 This runs it line by line in the InteractiveShell, prepends
739 prompts as needed capturing stderr and stdout, then returns
733 prompts as needed capturing stderr and stdout, then returns
740 the content as a list as if it were ipython code
734 the content as a list as if it were ipython code
741 """
735 """
742 output = []
736 output = []
743 savefig = False # keep up with this to clear figure
737 savefig = False # keep up with this to clear figure
744 multiline = False # to handle line continuation
738 multiline = False # to handle line continuation
745 multiline_start = None
739 multiline_start = None
746 fmtin = self.promptin
740 fmtin = self.promptin
747
741
748 ct = 0
742 ct = 0
749
743
750 for lineno, line in enumerate(content):
744 for lineno, line in enumerate(content):
751
745
752 line_stripped = line.strip()
746 line_stripped = line.strip()
753 if not len(line):
747 if not len(line):
754 output.append(line)
748 output.append(line)
755 continue
749 continue
756
750
757 # handle decorators
751 # handle decorators
758 if line_stripped.startswith('@'):
752 if line_stripped.startswith('@'):
759 output.extend([line])
753 output.extend([line])
760 if 'savefig' in line:
754 if 'savefig' in line:
761 savefig = True # and need to clear figure
755 savefig = True # and need to clear figure
762 continue
756 continue
763
757
764 # handle comments
758 # handle comments
765 if line_stripped.startswith('#'):
759 if line_stripped.startswith('#'):
766 output.extend([line])
760 output.extend([line])
767 continue
761 continue
768
762
769 # deal with lines checking for multiline
763 # deal with lines checking for multiline
770 continuation = u' %s:'% ''.join(['.']*(len(str(ct))+2))
764 continuation = u' %s:'% ''.join(['.']*(len(str(ct))+2))
771 if not multiline:
765 if not multiline:
772 modified = u"%s %s" % (fmtin % ct, line_stripped)
766 modified = u"%s %s" % (fmtin % ct, line_stripped)
773 output.append(modified)
767 output.append(modified)
774 ct += 1
768 ct += 1
775 try:
769 try:
776 ast.parse(line_stripped)
770 ast.parse(line_stripped)
777 output.append(u'')
771 output.append(u'')
778 except Exception: # on a multiline
772 except Exception: # on a multiline
779 multiline = True
773 multiline = True
780 multiline_start = lineno
774 multiline_start = lineno
781 else: # still on a multiline
775 else: # still on a multiline
782 modified = u'%s %s' % (continuation, line)
776 modified = u'%s %s' % (continuation, line)
783 output.append(modified)
777 output.append(modified)
784
778
785 # if the next line is indented, it should be part of multiline
779 # if the next line is indented, it should be part of multiline
786 if len(content) > lineno + 1:
780 if len(content) > lineno + 1:
787 nextline = content[lineno + 1]
781 nextline = content[lineno + 1]
788 if len(nextline) - len(nextline.lstrip()) > 3:
782 if len(nextline) - len(nextline.lstrip()) > 3:
789 continue
783 continue
790 try:
784 try:
791 mod = ast.parse(
785 mod = ast.parse(
792 '\n'.join(content[multiline_start:lineno+1]))
786 '\n'.join(content[multiline_start:lineno+1]))
793 if isinstance(mod.body[0], ast.FunctionDef):
787 if isinstance(mod.body[0], ast.FunctionDef):
794 # check to see if we have the whole function
788 # check to see if we have the whole function
795 for element in mod.body[0].body:
789 for element in mod.body[0].body:
796 if isinstance(element, ast.Return):
790 if isinstance(element, ast.Return):
797 multiline = False
791 multiline = False
798 else:
792 else:
799 output.append(u'')
793 output.append(u'')
800 multiline = False
794 multiline = False
801 except Exception:
795 except Exception:
802 pass
796 pass
803
797
804 if savefig: # clear figure if plotted
798 if savefig: # clear figure if plotted
805 self.ensure_pyplot()
799 self.ensure_pyplot()
806 self.process_input_line('plt.clf()', store_history=False)
800 self.process_input_line('plt.clf()', store_history=False)
807 self.clear_cout()
801 self.clear_cout()
808 savefig = False
802 savefig = False
809
803
810 return output
804 return output
811
805
812 def custom_doctest(self, decorator, input_lines, found, submitted):
806 def custom_doctest(self, decorator, input_lines, found, submitted):
813 """
807 """
814 Perform a specialized doctest.
808 Perform a specialized doctest.
815
809
816 """
810 """
817 from .custom_doctests import doctests
811 from .custom_doctests import doctests
818
812
819 args = decorator.split()
813 args = decorator.split()
820 doctest_type = args[1]
814 doctest_type = args[1]
821 if doctest_type in doctests:
815 if doctest_type in doctests:
822 doctests[doctest_type](self, args, input_lines, found, submitted)
816 doctests[doctest_type](self, args, input_lines, found, submitted)
823 else:
817 else:
824 e = "Invalid option to @doctest: {0}".format(doctest_type)
818 e = "Invalid option to @doctest: {0}".format(doctest_type)
825 raise Exception(e)
819 raise Exception(e)
826
820
827
821
828 class IPythonDirective(Directive):
822 class IPythonDirective(Directive):
829
823
830 has_content = True
824 has_content = True
831 required_arguments = 0
825 required_arguments = 0
832 optional_arguments = 4 # python, suppress, verbatim, doctest
826 optional_arguments = 4 # python, suppress, verbatim, doctest
833 final_argumuent_whitespace = True
827 final_argumuent_whitespace = True
834 option_spec = { 'python': directives.unchanged,
828 option_spec = { 'python': directives.unchanged,
835 'suppress' : directives.flag,
829 'suppress' : directives.flag,
836 'verbatim' : directives.flag,
830 'verbatim' : directives.flag,
837 'doctest' : directives.flag,
831 'doctest' : directives.flag,
838 'okexcept': directives.flag,
832 'okexcept': directives.flag,
839 'okwarning': directives.flag
833 'okwarning': directives.flag
840 }
834 }
841
835
842 shell = None
836 shell = None
843
837
844 seen_docs = set()
838 seen_docs = set()
845
839
846 def get_config_options(self):
840 def get_config_options(self):
847 # contains sphinx configuration variables
841 # contains sphinx configuration variables
848 config = self.state.document.settings.env.config
842 config = self.state.document.settings.env.config
849
843
850 # get config variables to set figure output directory
844 # get config variables to set figure output directory
851 outdir = self.state.document.settings.env.app.outdir
845 outdir = self.state.document.settings.env.app.outdir
852 savefig_dir = config.ipython_savefig_dir
846 savefig_dir = config.ipython_savefig_dir
853 source_dir = os.path.dirname(self.state.document.current_source)
847 source_dir = os.path.dirname(self.state.document.current_source)
854 if savefig_dir is None:
848 if savefig_dir is None:
855 savefig_dir = config.html_static_path or '_static'
849 savefig_dir = config.html_static_path or '_static'
856 if isinstance(savefig_dir, list):
850 if isinstance(savefig_dir, list):
857 savefig_dir = os.path.join(*savefig_dir)
851 savefig_dir = os.path.join(*savefig_dir)
858 savefig_dir = os.path.join(outdir, savefig_dir)
852 savefig_dir = os.path.join(outdir, savefig_dir)
859
853
860 # get regex and prompt stuff
854 # get regex and prompt stuff
861 rgxin = config.ipython_rgxin
855 rgxin = config.ipython_rgxin
862 rgxout = config.ipython_rgxout
856 rgxout = config.ipython_rgxout
863 promptin = config.ipython_promptin
857 promptin = config.ipython_promptin
864 promptout = config.ipython_promptout
858 promptout = config.ipython_promptout
865 mplbackend = config.ipython_mplbackend
859 mplbackend = config.ipython_mplbackend
866 exec_lines = config.ipython_execlines
860 exec_lines = config.ipython_execlines
867 hold_count = config.ipython_holdcount
861 hold_count = config.ipython_holdcount
868
862
869 return (savefig_dir, source_dir, rgxin, rgxout,
863 return (savefig_dir, source_dir, rgxin, rgxout,
870 promptin, promptout, mplbackend, exec_lines, hold_count)
864 promptin, promptout, mplbackend, exec_lines, hold_count)
871
865
872 def setup(self):
866 def setup(self):
873 # Get configuration values.
867 # Get configuration values.
874 (savefig_dir, source_dir, rgxin, rgxout, promptin, promptout,
868 (savefig_dir, source_dir, rgxin, rgxout, promptin, promptout,
875 mplbackend, exec_lines, hold_count) = self.get_config_options()
869 mplbackend, exec_lines, hold_count) = self.get_config_options()
876
870
877 if self.shell is None:
871 if self.shell is None:
878 # We will be here many times. However, when the
872 # We will be here many times. However, when the
879 # EmbeddedSphinxShell is created, its interactive shell member
873 # EmbeddedSphinxShell is created, its interactive shell member
880 # is the same for each instance.
874 # is the same for each instance.
881
875
882 if mplbackend and 'matplotlib.backends' not in sys.modules:
876 if mplbackend and 'matplotlib.backends' not in sys.modules:
883 import matplotlib
877 import matplotlib
884 matplotlib.use(mplbackend)
878 matplotlib.use(mplbackend)
885
879
886 # Must be called after (potentially) importing matplotlib and
880 # Must be called after (potentially) importing matplotlib and
887 # setting its backend since exec_lines might import pylab.
881 # setting its backend since exec_lines might import pylab.
888 self.shell = EmbeddedSphinxShell(exec_lines)
882 self.shell = EmbeddedSphinxShell(exec_lines)
889
883
890 # Store IPython directive to enable better error messages
884 # Store IPython directive to enable better error messages
891 self.shell.directive = self
885 self.shell.directive = self
892
886
893 # reset the execution count if we haven't processed this doc
887 # reset the execution count if we haven't processed this doc
894 #NOTE: this may be borked if there are multiple seen_doc tmp files
888 #NOTE: this may be borked if there are multiple seen_doc tmp files
895 #check time stamp?
889 #check time stamp?
896 if not self.state.document.current_source in self.seen_docs:
890 if not self.state.document.current_source in self.seen_docs:
897 self.shell.IP.history_manager.reset()
891 self.shell.IP.history_manager.reset()
898 self.shell.IP.execution_count = 1
892 self.shell.IP.execution_count = 1
899 self.seen_docs.add(self.state.document.current_source)
893 self.seen_docs.add(self.state.document.current_source)
900
894
901 # and attach to shell so we don't have to pass them around
895 # and attach to shell so we don't have to pass them around
902 self.shell.rgxin = rgxin
896 self.shell.rgxin = rgxin
903 self.shell.rgxout = rgxout
897 self.shell.rgxout = rgxout
904 self.shell.promptin = promptin
898 self.shell.promptin = promptin
905 self.shell.promptout = promptout
899 self.shell.promptout = promptout
906 self.shell.savefig_dir = savefig_dir
900 self.shell.savefig_dir = savefig_dir
907 self.shell.source_dir = source_dir
901 self.shell.source_dir = source_dir
908 self.shell.hold_count = hold_count
902 self.shell.hold_count = hold_count
909
903
910 # setup bookmark for saving figures directory
904 # setup bookmark for saving figures directory
911 self.shell.process_input_line('bookmark ipy_savedir %s'%savefig_dir,
905 self.shell.process_input_line('bookmark ipy_savedir %s'%savefig_dir,
912 store_history=False)
906 store_history=False)
913 self.shell.clear_cout()
907 self.shell.clear_cout()
914
908
915 return rgxin, rgxout, promptin, promptout
909 return rgxin, rgxout, promptin, promptout
916
910
917 def teardown(self):
911 def teardown(self):
918 # delete last bookmark
912 # delete last bookmark
919 self.shell.process_input_line('bookmark -d ipy_savedir',
913 self.shell.process_input_line('bookmark -d ipy_savedir',
920 store_history=False)
914 store_history=False)
921 self.shell.clear_cout()
915 self.shell.clear_cout()
922
916
923 def run(self):
917 def run(self):
924 debug = False
918 debug = False
925
919
926 #TODO, any reason block_parser can't be a method of embeddable shell
920 #TODO, any reason block_parser can't be a method of embeddable shell
927 # then we wouldn't have to carry these around
921 # then we wouldn't have to carry these around
928 rgxin, rgxout, promptin, promptout = self.setup()
922 rgxin, rgxout, promptin, promptout = self.setup()
929
923
930 options = self.options
924 options = self.options
931 self.shell.is_suppress = 'suppress' in options
925 self.shell.is_suppress = 'suppress' in options
932 self.shell.is_doctest = 'doctest' in options
926 self.shell.is_doctest = 'doctest' in options
933 self.shell.is_verbatim = 'verbatim' in options
927 self.shell.is_verbatim = 'verbatim' in options
934 self.shell.is_okexcept = 'okexcept' in options
928 self.shell.is_okexcept = 'okexcept' in options
935 self.shell.is_okwarning = 'okwarning' in options
929 self.shell.is_okwarning = 'okwarning' in options
936
930
937 # handle pure python code
931 # handle pure python code
938 if 'python' in self.arguments:
932 if 'python' in self.arguments:
939 content = self.content
933 content = self.content
940 self.content = self.shell.process_pure_python(content)
934 self.content = self.shell.process_pure_python(content)
941
935
942 # parts consists of all text within the ipython-block.
936 # parts consists of all text within the ipython-block.
943 # Each part is an input/output block.
937 # Each part is an input/output block.
944 parts = '\n'.join(self.content).split('\n\n')
938 parts = '\n'.join(self.content).split('\n\n')
945
939
946 lines = ['.. code-block:: ipython', '']
940 lines = ['.. code-block:: ipython', '']
947 figures = []
941 figures = []
948
942
949 for part in parts:
943 for part in parts:
950 block = block_parser(part, rgxin, rgxout, promptin, promptout)
944 block = block_parser(part, rgxin, rgxout, promptin, promptout)
951 if len(block):
945 if len(block):
952 rows, figure = self.shell.process_block(block)
946 rows, figure = self.shell.process_block(block)
953 for row in rows:
947 for row in rows:
954 lines.extend([' {0}'.format(line)
948 lines.extend([' {0}'.format(line)
955 for line in row.split('\n')])
949 for line in row.split('\n')])
956
950
957 if figure is not None:
951 if figure is not None:
958 figures.append(figure)
952 figures.append(figure)
959
953
960 for figure in figures:
954 for figure in figures:
961 lines.append('')
955 lines.append('')
962 lines.extend(figure.split('\n'))
956 lines.extend(figure.split('\n'))
963 lines.append('')
957 lines.append('')
964
958
965 if len(lines) > 2:
959 if len(lines) > 2:
966 if debug:
960 if debug:
967 print('\n'.join(lines))
961 print('\n'.join(lines))
968 else:
962 else:
969 # This has to do with input, not output. But if we comment
963 # This has to do with input, not output. But if we comment
970 # these lines out, then no IPython code will appear in the
964 # these lines out, then no IPython code will appear in the
971 # final output.
965 # final output.
972 self.state_machine.insert_input(
966 self.state_machine.insert_input(
973 lines, self.state_machine.input_lines.source(0))
967 lines, self.state_machine.input_lines.source(0))
974
968
975 # cleanup
969 # cleanup
976 self.teardown()
970 self.teardown()
977
971
978 return []
972 return []
979
973
980 # Enable as a proper Sphinx directive
974 # Enable as a proper Sphinx directive
981 def setup(app):
975 def setup(app):
982 setup.app = app
976 setup.app = app
983
977
984 app.add_directive('ipython', IPythonDirective)
978 app.add_directive('ipython', IPythonDirective)
985 app.add_config_value('ipython_savefig_dir', None, 'env')
979 app.add_config_value('ipython_savefig_dir', None, 'env')
986 app.add_config_value('ipython_rgxin',
980 app.add_config_value('ipython_rgxin',
987 re.compile('In \[(\d+)\]:\s?(.*)\s*'), 'env')
981 re.compile('In \[(\d+)\]:\s?(.*)\s*'), 'env')
988 app.add_config_value('ipython_rgxout',
982 app.add_config_value('ipython_rgxout',
989 re.compile('Out\[(\d+)\]:\s?(.*)\s*'), 'env')
983 re.compile('Out\[(\d+)\]:\s?(.*)\s*'), 'env')
990 app.add_config_value('ipython_promptin', 'In [%d]:', 'env')
984 app.add_config_value('ipython_promptin', 'In [%d]:', 'env')
991 app.add_config_value('ipython_promptout', 'Out[%d]:', 'env')
985 app.add_config_value('ipython_promptout', 'Out[%d]:', 'env')
992
986
993 # We could just let matplotlib pick whatever is specified as the default
987 # We could just let matplotlib pick whatever is specified as the default
994 # backend in the matplotlibrc file, but this would cause issues if the
988 # backend in the matplotlibrc file, but this would cause issues if the
995 # backend didn't work in headless environments. For this reason, 'agg'
989 # backend didn't work in headless environments. For this reason, 'agg'
996 # is a good default backend choice.
990 # is a good default backend choice.
997 app.add_config_value('ipython_mplbackend', 'agg', 'env')
991 app.add_config_value('ipython_mplbackend', 'agg', 'env')
998
992
999 # If the user sets this config value to `None`, then EmbeddedSphinxShell's
993 # If the user sets this config value to `None`, then EmbeddedSphinxShell's
1000 # __init__ method will treat it as [].
994 # __init__ method will treat it as [].
1001 execlines = ['import numpy as np', 'import matplotlib.pyplot as plt']
995 execlines = ['import numpy as np', 'import matplotlib.pyplot as plt']
1002 app.add_config_value('ipython_execlines', execlines, 'env')
996 app.add_config_value('ipython_execlines', execlines, 'env')
1003
997
1004 app.add_config_value('ipython_holdcount', True, 'env')
998 app.add_config_value('ipython_holdcount', True, 'env')
1005
999
1006 metadata = {'parallel_read_safe': True, 'parallel_write_safe': True}
1000 metadata = {'parallel_read_safe': True, 'parallel_write_safe': True}
1007 return metadata
1001 return metadata
1008
1002
1009 # Simple smoke test, needs to be converted to a proper automatic test.
1003 # Simple smoke test, needs to be converted to a proper automatic test.
1010 def test():
1004 def test():
1011
1005
1012 examples = [
1006 examples = [
1013 r"""
1007 r"""
1014 In [9]: pwd
1008 In [9]: pwd
1015 Out[9]: '/home/jdhunter/py4science/book'
1009 Out[9]: '/home/jdhunter/py4science/book'
1016
1010
1017 In [10]: cd bookdata/
1011 In [10]: cd bookdata/
1018 /home/jdhunter/py4science/book/bookdata
1012 /home/jdhunter/py4science/book/bookdata
1019
1013
1020 In [2]: from pylab import *
1014 In [2]: from pylab import *
1021
1015
1022 In [2]: ion()
1016 In [2]: ion()
1023
1017
1024 In [3]: im = imread('stinkbug.png')
1018 In [3]: im = imread('stinkbug.png')
1025
1019
1026 @savefig mystinkbug.png width=4in
1020 @savefig mystinkbug.png width=4in
1027 In [4]: imshow(im)
1021 In [4]: imshow(im)
1028 Out[4]: <matplotlib.image.AxesImage object at 0x39ea850>
1022 Out[4]: <matplotlib.image.AxesImage object at 0x39ea850>
1029
1023
1030 """,
1024 """,
1031 r"""
1025 r"""
1032
1026
1033 In [1]: x = 'hello world'
1027 In [1]: x = 'hello world'
1034
1028
1035 # string methods can be
1029 # string methods can be
1036 # used to alter the string
1030 # used to alter the string
1037 @doctest
1031 @doctest
1038 In [2]: x.upper()
1032 In [2]: x.upper()
1039 Out[2]: 'HELLO WORLD'
1033 Out[2]: 'HELLO WORLD'
1040
1034
1041 @verbatim
1035 @verbatim
1042 In [3]: x.st<TAB>
1036 In [3]: x.st<TAB>
1043 x.startswith x.strip
1037 x.startswith x.strip
1044 """,
1038 """,
1045 r"""
1039 r"""
1046
1040
1047 In [130]: url = 'http://ichart.finance.yahoo.com/table.csv?s=CROX\
1041 In [130]: url = 'http://ichart.finance.yahoo.com/table.csv?s=CROX\
1048 .....: &d=9&e=22&f=2009&g=d&a=1&br=8&c=2006&ignore=.csv'
1042 .....: &d=9&e=22&f=2009&g=d&a=1&br=8&c=2006&ignore=.csv'
1049
1043
1050 In [131]: print url.split('&')
1044 In [131]: print url.split('&')
1051 ['http://ichart.finance.yahoo.com/table.csv?s=CROX', 'd=9', 'e=22', 'f=2009', 'g=d', 'a=1', 'b=8', 'c=2006', 'ignore=.csv']
1045 ['http://ichart.finance.yahoo.com/table.csv?s=CROX', 'd=9', 'e=22', 'f=2009', 'g=d', 'a=1', 'b=8', 'c=2006', 'ignore=.csv']
1052
1046
1053 In [60]: import urllib
1047 In [60]: import urllib
1054
1048
1055 """,
1049 """,
1056 r"""\
1050 r"""\
1057
1051
1058 In [133]: import numpy.random
1052 In [133]: import numpy.random
1059
1053
1060 @suppress
1054 @suppress
1061 In [134]: numpy.random.seed(2358)
1055 In [134]: numpy.random.seed(2358)
1062
1056
1063 @doctest
1057 @doctest
1064 In [135]: numpy.random.rand(10,2)
1058 In [135]: numpy.random.rand(10,2)
1065 Out[135]:
1059 Out[135]:
1066 array([[ 0.64524308, 0.59943846],
1060 array([[ 0.64524308, 0.59943846],
1067 [ 0.47102322, 0.8715456 ],
1061 [ 0.47102322, 0.8715456 ],
1068 [ 0.29370834, 0.74776844],
1062 [ 0.29370834, 0.74776844],
1069 [ 0.99539577, 0.1313423 ],
1063 [ 0.99539577, 0.1313423 ],
1070 [ 0.16250302, 0.21103583],
1064 [ 0.16250302, 0.21103583],
1071 [ 0.81626524, 0.1312433 ],
1065 [ 0.81626524, 0.1312433 ],
1072 [ 0.67338089, 0.72302393],
1066 [ 0.67338089, 0.72302393],
1073 [ 0.7566368 , 0.07033696],
1067 [ 0.7566368 , 0.07033696],
1074 [ 0.22591016, 0.77731835],
1068 [ 0.22591016, 0.77731835],
1075 [ 0.0072729 , 0.34273127]])
1069 [ 0.0072729 , 0.34273127]])
1076
1070
1077 """,
1071 """,
1078
1072
1079 r"""
1073 r"""
1080 In [106]: print x
1074 In [106]: print x
1081 jdh
1075 jdh
1082
1076
1083 In [109]: for i in range(10):
1077 In [109]: for i in range(10):
1084 .....: print i
1078 .....: print i
1085 .....:
1079 .....:
1086 .....:
1080 .....:
1087 0
1081 0
1088 1
1082 1
1089 2
1083 2
1090 3
1084 3
1091 4
1085 4
1092 5
1086 5
1093 6
1087 6
1094 7
1088 7
1095 8
1089 8
1096 9
1090 9
1097 """,
1091 """,
1098
1092
1099 r"""
1093 r"""
1100
1094
1101 In [144]: from pylab import *
1095 In [144]: from pylab import *
1102
1096
1103 In [145]: ion()
1097 In [145]: ion()
1104
1098
1105 # use a semicolon to suppress the output
1099 # use a semicolon to suppress the output
1106 @savefig test_hist.png width=4in
1100 @savefig test_hist.png width=4in
1107 In [151]: hist(np.random.randn(10000), 100);
1101 In [151]: hist(np.random.randn(10000), 100);
1108
1102
1109
1103
1110 @savefig test_plot.png width=4in
1104 @savefig test_plot.png width=4in
1111 In [151]: plot(np.random.randn(10000), 'o');
1105 In [151]: plot(np.random.randn(10000), 'o');
1112 """,
1106 """,
1113
1107
1114 r"""
1108 r"""
1115 # use a semicolon to suppress the output
1109 # use a semicolon to suppress the output
1116 In [151]: plt.clf()
1110 In [151]: plt.clf()
1117
1111
1118 @savefig plot_simple.png width=4in
1112 @savefig plot_simple.png width=4in
1119 In [151]: plot([1,2,3])
1113 In [151]: plot([1,2,3])
1120
1114
1121 @savefig hist_simple.png width=4in
1115 @savefig hist_simple.png width=4in
1122 In [151]: hist(np.random.randn(10000), 100);
1116 In [151]: hist(np.random.randn(10000), 100);
1123
1117
1124 """,
1118 """,
1125 r"""
1119 r"""
1126 # update the current fig
1120 # update the current fig
1127 In [151]: ylabel('number')
1121 In [151]: ylabel('number')
1128
1122
1129 In [152]: title('normal distribution')
1123 In [152]: title('normal distribution')
1130
1124
1131
1125
1132 @savefig hist_with_text.png
1126 @savefig hist_with_text.png
1133 In [153]: grid(True)
1127 In [153]: grid(True)
1134
1128
1135 @doctest float
1129 @doctest float
1136 In [154]: 0.1 + 0.2
1130 In [154]: 0.1 + 0.2
1137 Out[154]: 0.3
1131 Out[154]: 0.3
1138
1132
1139 @doctest float
1133 @doctest float
1140 In [155]: np.arange(16).reshape(4,4)
1134 In [155]: np.arange(16).reshape(4,4)
1141 Out[155]:
1135 Out[155]:
1142 array([[ 0, 1, 2, 3],
1136 array([[ 0, 1, 2, 3],
1143 [ 4, 5, 6, 7],
1137 [ 4, 5, 6, 7],
1144 [ 8, 9, 10, 11],
1138 [ 8, 9, 10, 11],
1145 [12, 13, 14, 15]])
1139 [12, 13, 14, 15]])
1146
1140
1147 In [1]: x = np.arange(16, dtype=float).reshape(4,4)
1141 In [1]: x = np.arange(16, dtype=float).reshape(4,4)
1148
1142
1149 In [2]: x[0,0] = np.inf
1143 In [2]: x[0,0] = np.inf
1150
1144
1151 In [3]: x[0,1] = np.nan
1145 In [3]: x[0,1] = np.nan
1152
1146
1153 @doctest float
1147 @doctest float
1154 In [4]: x
1148 In [4]: x
1155 Out[4]:
1149 Out[4]:
1156 array([[ inf, nan, 2., 3.],
1150 array([[ inf, nan, 2., 3.],
1157 [ 4., 5., 6., 7.],
1151 [ 4., 5., 6., 7.],
1158 [ 8., 9., 10., 11.],
1152 [ 8., 9., 10., 11.],
1159 [ 12., 13., 14., 15.]])
1153 [ 12., 13., 14., 15.]])
1160
1154
1161
1155
1162 """,
1156 """,
1163 ]
1157 ]
1164 # skip local-file depending first example:
1158 # skip local-file depending first example:
1165 examples = examples[1:]
1159 examples = examples[1:]
1166
1160
1167 #ipython_directive.DEBUG = True # dbg
1161 #ipython_directive.DEBUG = True # dbg
1168 #options = dict(suppress=True) # dbg
1162 #options = dict(suppress=True) # dbg
1169 options = dict()
1163 options = dict()
1170 for example in examples:
1164 for example in examples:
1171 content = example.split('\n')
1165 content = example.split('\n')
1172 IPythonDirective('debug', arguments=None, options=options,
1166 IPythonDirective('debug', arguments=None, options=options,
1173 content=content, lineno=0,
1167 content=content, lineno=0,
1174 content_offset=None, block_text=None,
1168 content_offset=None, block_text=None,
1175 state=None, state_machine=None,
1169 state=None, state_machine=None,
1176 )
1170 )
1177
1171
1178 # Run test suite as a script
1172 # Run test suite as a script
1179 if __name__=='__main__':
1173 if __name__=='__main__':
1180 if not os.path.isdir('_static'):
1174 if not os.path.isdir('_static'):
1181 os.mkdir('_static')
1175 os.mkdir('_static')
1182 test()
1176 test()
1183 print('All OK? Check figures in _static/')
1177 print('All OK? Check figures in _static/')
@@ -1,500 +1,478 b''
1 """IPython terminal interface using prompt_toolkit"""
1 """IPython terminal interface using prompt_toolkit"""
2
2
3 import os
3 import os
4 import sys
4 import sys
5 import warnings
5 import warnings
6 from warnings import warn
6 from warnings import warn
7
7
8 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
8 from IPython.core.interactiveshell import InteractiveShell, InteractiveShellABC
9 from IPython.utils import io
9 from IPython.utils import io
10 from IPython.utils.py3compat import PY3, cast_unicode_py2, input
10 from IPython.utils.py3compat import cast_unicode_py2, input
11 from IPython.utils.terminal import toggle_set_term_title, set_term_title
11 from IPython.utils.terminal import toggle_set_term_title, set_term_title
12 from IPython.utils.process import abbrev_cwd
12 from IPython.utils.process import abbrev_cwd
13 from traitlets import Bool, Unicode, Dict, Integer, observe, Instance, Type, default, Enum, Union
13 from traitlets import Bool, Unicode, Dict, Integer, observe, Instance, Type, default, Enum, Union
14
14
15 from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
15 from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
16 from prompt_toolkit.filters import (HasFocus, Condition, IsDone)
16 from prompt_toolkit.filters import (HasFocus, Condition, IsDone)
17 from prompt_toolkit.history import InMemoryHistory
17 from prompt_toolkit.history import InMemoryHistory
18 from prompt_toolkit.shortcuts import create_prompt_application, create_eventloop, create_prompt_layout, create_output
18 from prompt_toolkit.shortcuts import create_prompt_application, create_eventloop, create_prompt_layout, create_output
19 from prompt_toolkit.interface import CommandLineInterface
19 from prompt_toolkit.interface import CommandLineInterface
20 from prompt_toolkit.key_binding.manager import KeyBindingManager
20 from prompt_toolkit.key_binding.manager import KeyBindingManager
21 from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
21 from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
22 from prompt_toolkit.styles import PygmentsStyle, DynamicStyle
22 from prompt_toolkit.styles import PygmentsStyle, DynamicStyle
23
23
24 from pygments.styles import get_style_by_name, get_all_styles
24 from pygments.styles import get_style_by_name, get_all_styles
25 from pygments.style import Style
25 from pygments.style import Style
26 from pygments.token import Token
26 from pygments.token import Token
27
27
28 from .debugger import TerminalPdb, Pdb
28 from .debugger import TerminalPdb, Pdb
29 from .magics import TerminalMagics
29 from .magics import TerminalMagics
30 from .pt_inputhooks import get_inputhook_name_and_func
30 from .pt_inputhooks import get_inputhook_name_and_func
31 from .prompts import Prompts, ClassicPrompts, RichPromptDisplayHook
31 from .prompts import Prompts, ClassicPrompts, RichPromptDisplayHook
32 from .ptutils import IPythonPTCompleter, IPythonPTLexer
32 from .ptutils import IPythonPTCompleter, IPythonPTLexer
33 from .shortcuts import register_ipython_shortcuts
33 from .shortcuts import register_ipython_shortcuts
34
34
35 DISPLAY_BANNER_DEPRECATED = object()
35 DISPLAY_BANNER_DEPRECATED = object()
36
36
37
37
38 from pygments.style import Style
38 from pygments.style import Style
39
39
40 class _NoStyle(Style): pass
40 class _NoStyle(Style): pass
41
41
42
42
43
43
44 _style_overrides_light_bg = {
44 _style_overrides_light_bg = {
45 Token.Prompt: '#0000ff',
45 Token.Prompt: '#0000ff',
46 Token.PromptNum: '#0000ee bold',
46 Token.PromptNum: '#0000ee bold',
47 Token.OutPrompt: '#cc0000',
47 Token.OutPrompt: '#cc0000',
48 Token.OutPromptNum: '#bb0000 bold',
48 Token.OutPromptNum: '#bb0000 bold',
49 }
49 }
50
50
51 _style_overrides_linux = {
51 _style_overrides_linux = {
52 Token.Prompt: '#00cc00',
52 Token.Prompt: '#00cc00',
53 Token.PromptNum: '#00bb00 bold',
53 Token.PromptNum: '#00bb00 bold',
54 Token.OutPrompt: '#cc0000',
54 Token.OutPrompt: '#cc0000',
55 Token.OutPromptNum: '#bb0000 bold',
55 Token.OutPromptNum: '#bb0000 bold',
56 }
56 }
57
57
58
58
59
59
60 def get_default_editor():
60 def get_default_editor():
61 try:
61 try:
62 ed = os.environ['EDITOR']
62 return os.environ['EDITOR']
63 if not PY3:
64 ed = ed.decode()
65 return ed
66 except KeyError:
63 except KeyError:
67 pass
64 pass
68 except UnicodeError:
65 except UnicodeError:
69 warn("$EDITOR environment variable is not pure ASCII. Using platform "
66 warn("$EDITOR environment variable is not pure ASCII. Using platform "
70 "default editor.")
67 "default editor.")
71
68
72 if os.name == 'posix':
69 if os.name == 'posix':
73 return 'vi' # the only one guaranteed to be there!
70 return 'vi' # the only one guaranteed to be there!
74 else:
71 else:
75 return 'notepad' # same in Windows!
72 return 'notepad' # same in Windows!
76
73
77 # conservatively check for tty
74 # conservatively check for tty
78 # overridden streams can result in things like:
75 # overridden streams can result in things like:
79 # - sys.stdin = None
76 # - sys.stdin = None
80 # - no isatty method
77 # - no isatty method
81 for _name in ('stdin', 'stdout', 'stderr'):
78 for _name in ('stdin', 'stdout', 'stderr'):
82 _stream = getattr(sys, _name)
79 _stream = getattr(sys, _name)
83 if not _stream or not hasattr(_stream, 'isatty') or not _stream.isatty():
80 if not _stream or not hasattr(_stream, 'isatty') or not _stream.isatty():
84 _is_tty = False
81 _is_tty = False
85 break
82 break
86 else:
83 else:
87 _is_tty = True
84 _is_tty = True
88
85
89
86
90 _use_simple_prompt = ('IPY_TEST_SIMPLE_PROMPT' in os.environ) or (not _is_tty)
87 _use_simple_prompt = ('IPY_TEST_SIMPLE_PROMPT' in os.environ) or (not _is_tty)
91
88
92 class TerminalInteractiveShell(InteractiveShell):
89 class TerminalInteractiveShell(InteractiveShell):
93 space_for_menu = Integer(6, help='Number of line at the bottom of the screen '
90 space_for_menu = Integer(6, help='Number of line at the bottom of the screen '
94 'to reserve for the completion menu'
91 'to reserve for the completion menu'
95 ).tag(config=True)
92 ).tag(config=True)
96
93
97 def _space_for_menu_changed(self, old, new):
94 def _space_for_menu_changed(self, old, new):
98 self._update_layout()
95 self._update_layout()
99
96
100 pt_cli = None
97 pt_cli = None
101 debugger_history = None
98 debugger_history = None
102 _pt_app = None
99 _pt_app = None
103
100
104 simple_prompt = Bool(_use_simple_prompt,
101 simple_prompt = Bool(_use_simple_prompt,
105 help="""Use `raw_input` for the REPL, without completion, multiline input, and prompt colors.
102 help="""Use `raw_input` for the REPL, without completion, multiline input, and prompt colors.
106
103
107 Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
104 Useful when controlling IPython as a subprocess, and piping STDIN/OUT/ERR. Known usage are:
108 IPython own testing machinery, and emacs inferior-shell integration through elpy.
105 IPython own testing machinery, and emacs inferior-shell integration through elpy.
109
106
110 This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
107 This mode default to `True` if the `IPY_TEST_SIMPLE_PROMPT`
111 environment variable is set, or the current terminal is not a tty.
108 environment variable is set, or the current terminal is not a tty.
112
109
113 """
110 """
114 ).tag(config=True)
111 ).tag(config=True)
115
112
116 @property
113 @property
117 def debugger_cls(self):
114 def debugger_cls(self):
118 return Pdb if self.simple_prompt else TerminalPdb
115 return Pdb if self.simple_prompt else TerminalPdb
119
116
120 confirm_exit = Bool(True,
117 confirm_exit = Bool(True,
121 help="""
118 help="""
122 Set to confirm when you try to exit IPython with an EOF (Control-D
119 Set to confirm when you try to exit IPython with an EOF (Control-D
123 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
120 in Unix, Control-Z/Enter in Windows). By typing 'exit' or 'quit',
124 you can force a direct exit without any confirmation.""",
121 you can force a direct exit without any confirmation.""",
125 ).tag(config=True)
122 ).tag(config=True)
126
123
127 editing_mode = Unicode('emacs',
124 editing_mode = Unicode('emacs',
128 help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
125 help="Shortcut style to use at the prompt. 'vi' or 'emacs'.",
129 ).tag(config=True)
126 ).tag(config=True)
130
127
131 mouse_support = Bool(False,
128 mouse_support = Bool(False,
132 help="Enable mouse support in the prompt"
129 help="Enable mouse support in the prompt"
133 ).tag(config=True)
130 ).tag(config=True)
134
131
135 highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
132 highlighting_style = Union([Unicode('legacy'), Type(klass=Style)],
136 help="""The name or class of a Pygments style to use for syntax
133 help="""The name or class of a Pygments style to use for syntax
137 highlighting: \n %s""" % ', '.join(get_all_styles())
134 highlighting: \n %s""" % ', '.join(get_all_styles())
138 ).tag(config=True)
135 ).tag(config=True)
139
136
140
137
141 @observe('highlighting_style')
138 @observe('highlighting_style')
142 @observe('colors')
139 @observe('colors')
143 def _highlighting_style_changed(self, change):
140 def _highlighting_style_changed(self, change):
144 self.refresh_style()
141 self.refresh_style()
145
142
146 def refresh_style(self):
143 def refresh_style(self):
147 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
144 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
148
145
149
146
150 highlighting_style_overrides = Dict(
147 highlighting_style_overrides = Dict(
151 help="Override highlighting format for specific tokens"
148 help="Override highlighting format for specific tokens"
152 ).tag(config=True)
149 ).tag(config=True)
153
150
154 true_color = Bool(False,
151 true_color = Bool(False,
155 help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
152 help=("Use 24bit colors instead of 256 colors in prompt highlighting. "
156 "If your terminal supports true color, the following command "
153 "If your terminal supports true color, the following command "
157 "should print 'TRUECOLOR' in orange: "
154 "should print 'TRUECOLOR' in orange: "
158 "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
155 "printf \"\\x1b[38;2;255;100;0mTRUECOLOR\\x1b[0m\\n\"")
159 ).tag(config=True)
156 ).tag(config=True)
160
157
161 editor = Unicode(get_default_editor(),
158 editor = Unicode(get_default_editor(),
162 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
159 help="Set the editor used by IPython (default to $EDITOR/vi/notepad)."
163 ).tag(config=True)
160 ).tag(config=True)
164
161
165 prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)
162 prompts_class = Type(Prompts, help='Class used to generate Prompt token for prompt_toolkit').tag(config=True)
166
163
167 prompts = Instance(Prompts)
164 prompts = Instance(Prompts)
168
165
169 @default('prompts')
166 @default('prompts')
170 def _prompts_default(self):
167 def _prompts_default(self):
171 return self.prompts_class(self)
168 return self.prompts_class(self)
172
169
173 @observe('prompts')
170 @observe('prompts')
174 def _(self, change):
171 def _(self, change):
175 self._update_layout()
172 self._update_layout()
176
173
177 @default('displayhook_class')
174 @default('displayhook_class')
178 def _displayhook_class_default(self):
175 def _displayhook_class_default(self):
179 return RichPromptDisplayHook
176 return RichPromptDisplayHook
180
177
181 term_title = Bool(True,
178 term_title = Bool(True,
182 help="Automatically set the terminal title"
179 help="Automatically set the terminal title"
183 ).tag(config=True)
180 ).tag(config=True)
184
181
185 display_completions = Enum(('column', 'multicolumn','readlinelike'),
182 display_completions = Enum(('column', 'multicolumn','readlinelike'),
186 help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
183 help= ( "Options for displaying tab completions, 'column', 'multicolumn', and "
187 "'readlinelike'. These options are for `prompt_toolkit`, see "
184 "'readlinelike'. These options are for `prompt_toolkit`, see "
188 "`prompt_toolkit` documentation for more information."
185 "`prompt_toolkit` documentation for more information."
189 ),
186 ),
190 default_value='multicolumn').tag(config=True)
187 default_value='multicolumn').tag(config=True)
191
188
192 highlight_matching_brackets = Bool(True,
189 highlight_matching_brackets = Bool(True,
193 help="Highlight matching brackets .",
190 help="Highlight matching brackets .",
194 ).tag(config=True)
191 ).tag(config=True)
195
192
196 @observe('term_title')
193 @observe('term_title')
197 def init_term_title(self, change=None):
194 def init_term_title(self, change=None):
198 # Enable or disable the terminal title.
195 # Enable or disable the terminal title.
199 if self.term_title:
196 if self.term_title:
200 toggle_set_term_title(True)
197 toggle_set_term_title(True)
201 set_term_title('IPython: ' + abbrev_cwd())
198 set_term_title('IPython: ' + abbrev_cwd())
202 else:
199 else:
203 toggle_set_term_title(False)
200 toggle_set_term_title(False)
204
201
205 def init_display_formatter(self):
202 def init_display_formatter(self):
206 super(TerminalInteractiveShell, self).init_display_formatter()
203 super(TerminalInteractiveShell, self).init_display_formatter()
207 # terminal only supports plain text
204 # terminal only supports plain text
208 self.display_formatter.active_types = ['text/plain']
205 self.display_formatter.active_types = ['text/plain']
209
206
210 def init_prompt_toolkit_cli(self):
207 def init_prompt_toolkit_cli(self):
211 if self.simple_prompt:
208 if self.simple_prompt:
212 # Fall back to plain non-interactive output for tests.
209 # Fall back to plain non-interactive output for tests.
213 # This is very limited, and only accepts a single line.
210 # This is very limited, and only accepts a single line.
214 def prompt():
211 def prompt():
215 return cast_unicode_py2(input('In [%d]: ' % self.execution_count))
212 return cast_unicode_py2(input('In [%d]: ' % self.execution_count))
216 self.prompt_for_code = prompt
213 self.prompt_for_code = prompt
217 return
214 return
218
215
219 # Set up keyboard shortcuts
216 # Set up keyboard shortcuts
220 kbmanager = KeyBindingManager.for_prompt()
217 kbmanager = KeyBindingManager.for_prompt()
221 register_ipython_shortcuts(kbmanager.registry, self)
218 register_ipython_shortcuts(kbmanager.registry, self)
222
219
223 # Pre-populate history from IPython's history database
220 # Pre-populate history from IPython's history database
224 history = InMemoryHistory()
221 history = InMemoryHistory()
225 last_cell = u""
222 last_cell = u""
226 for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
223 for __, ___, cell in self.history_manager.get_tail(self.history_load_length,
227 include_latest=True):
224 include_latest=True):
228 # Ignore blank lines and consecutive duplicates
225 # Ignore blank lines and consecutive duplicates
229 cell = cell.rstrip()
226 cell = cell.rstrip()
230 if cell and (cell != last_cell):
227 if cell and (cell != last_cell):
231 history.append(cell)
228 history.append(cell)
232 last_cell = cell
229 last_cell = cell
233
230
234 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
231 self._style = self._make_style_from_name_or_cls(self.highlighting_style)
235 style = DynamicStyle(lambda: self._style)
232 style = DynamicStyle(lambda: self._style)
236
233
237 editing_mode = getattr(EditingMode, self.editing_mode.upper())
234 editing_mode = getattr(EditingMode, self.editing_mode.upper())
238
235
239 self._pt_app = create_prompt_application(
236 self._pt_app = create_prompt_application(
240 editing_mode=editing_mode,
237 editing_mode=editing_mode,
241 key_bindings_registry=kbmanager.registry,
238 key_bindings_registry=kbmanager.registry,
242 history=history,
239 history=history,
243 completer=IPythonPTCompleter(shell=self),
240 completer=IPythonPTCompleter(shell=self),
244 enable_history_search=True,
241 enable_history_search=True,
245 style=style,
242 style=style,
246 mouse_support=self.mouse_support,
243 mouse_support=self.mouse_support,
247 **self._layout_options()
244 **self._layout_options()
248 )
245 )
249 self._eventloop = create_eventloop(self.inputhook)
246 self._eventloop = create_eventloop(self.inputhook)
250 self.pt_cli = CommandLineInterface(
247 self.pt_cli = CommandLineInterface(
251 self._pt_app, eventloop=self._eventloop,
248 self._pt_app, eventloop=self._eventloop,
252 output=create_output(true_color=self.true_color))
249 output=create_output(true_color=self.true_color))
253
250
254 def _make_style_from_name_or_cls(self, name_or_cls):
251 def _make_style_from_name_or_cls(self, name_or_cls):
255 """
252 """
256 Small wrapper that make an IPython compatible style from a style name
253 Small wrapper that make an IPython compatible style from a style name
257
254
258 We need that to add style for prompt ... etc.
255 We need that to add style for prompt ... etc.
259 """
256 """
260 style_overrides = {}
257 style_overrides = {}
261 if name_or_cls == 'legacy':
258 if name_or_cls == 'legacy':
262 legacy = self.colors.lower()
259 legacy = self.colors.lower()
263 if legacy == 'linux':
260 if legacy == 'linux':
264 style_cls = get_style_by_name('monokai')
261 style_cls = get_style_by_name('monokai')
265 style_overrides = _style_overrides_linux
262 style_overrides = _style_overrides_linux
266 elif legacy == 'lightbg':
263 elif legacy == 'lightbg':
267 style_overrides = _style_overrides_light_bg
264 style_overrides = _style_overrides_light_bg
268 style_cls = get_style_by_name('pastie')
265 style_cls = get_style_by_name('pastie')
269 elif legacy == 'neutral':
266 elif legacy == 'neutral':
270 # The default theme needs to be visible on both a dark background
267 # The default theme needs to be visible on both a dark background
271 # and a light background, because we can't tell what the terminal
268 # and a light background, because we can't tell what the terminal
272 # looks like. These tweaks to the default theme help with that.
269 # looks like. These tweaks to the default theme help with that.
273 style_cls = get_style_by_name('default')
270 style_cls = get_style_by_name('default')
274 style_overrides.update({
271 style_overrides.update({
275 Token.Number: '#007700',
272 Token.Number: '#007700',
276 Token.Operator: 'noinherit',
273 Token.Operator: 'noinherit',
277 Token.String: '#BB6622',
274 Token.String: '#BB6622',
278 Token.Name.Function: '#2080D0',
275 Token.Name.Function: '#2080D0',
279 Token.Name.Class: 'bold #2080D0',
276 Token.Name.Class: 'bold #2080D0',
280 Token.Name.Namespace: 'bold #2080D0',
277 Token.Name.Namespace: 'bold #2080D0',
281 Token.Prompt: '#009900',
278 Token.Prompt: '#009900',
282 Token.PromptNum: '#00ff00 bold',
279 Token.PromptNum: '#00ff00 bold',
283 Token.OutPrompt: '#990000',
280 Token.OutPrompt: '#990000',
284 Token.OutPromptNum: '#ff0000 bold',
281 Token.OutPromptNum: '#ff0000 bold',
285 })
282 })
286 elif legacy =='nocolor':
283 elif legacy =='nocolor':
287 style_cls=_NoStyle
284 style_cls=_NoStyle
288 style_overrides = {}
285 style_overrides = {}
289 else :
286 else :
290 raise ValueError('Got unknown colors: ', legacy)
287 raise ValueError('Got unknown colors: ', legacy)
291 else :
288 else :
292 if isinstance(name_or_cls, str):
289 if isinstance(name_or_cls, str):
293 style_cls = get_style_by_name(name_or_cls)
290 style_cls = get_style_by_name(name_or_cls)
294 else:
291 else:
295 style_cls = name_or_cls
292 style_cls = name_or_cls
296 style_overrides = {
293 style_overrides = {
297 Token.Prompt: '#009900',
294 Token.Prompt: '#009900',
298 Token.PromptNum: '#00ff00 bold',
295 Token.PromptNum: '#00ff00 bold',
299 Token.OutPrompt: '#990000',
296 Token.OutPrompt: '#990000',
300 Token.OutPromptNum: '#ff0000 bold',
297 Token.OutPromptNum: '#ff0000 bold',
301 }
298 }
302 style_overrides.update(self.highlighting_style_overrides)
299 style_overrides.update(self.highlighting_style_overrides)
303 style = PygmentsStyle.from_defaults(pygments_style_cls=style_cls,
300 style = PygmentsStyle.from_defaults(pygments_style_cls=style_cls,
304 style_dict=style_overrides)
301 style_dict=style_overrides)
305
302
306 return style
303 return style
307
304
308 def _layout_options(self):
305 def _layout_options(self):
309 """
306 """
310 Return the current layout option for the current Terminal InteractiveShell
307 Return the current layout option for the current Terminal InteractiveShell
311 """
308 """
312 return {
309 return {
313 'lexer':IPythonPTLexer(),
310 'lexer':IPythonPTLexer(),
314 'reserve_space_for_menu':self.space_for_menu,
311 'reserve_space_for_menu':self.space_for_menu,
315 'get_prompt_tokens':self.prompts.in_prompt_tokens,
312 'get_prompt_tokens':self.prompts.in_prompt_tokens,
316 'get_continuation_tokens':self.prompts.continuation_prompt_tokens,
313 'get_continuation_tokens':self.prompts.continuation_prompt_tokens,
317 'multiline':True,
314 'multiline':True,
318 'display_completions_in_columns': (self.display_completions == 'multicolumn'),
315 'display_completions_in_columns': (self.display_completions == 'multicolumn'),
319
316
320 # Highlight matching brackets, but only when this setting is
317 # Highlight matching brackets, but only when this setting is
321 # enabled, and only when the DEFAULT_BUFFER has the focus.
318 # enabled, and only when the DEFAULT_BUFFER has the focus.
322 'extra_input_processors': [ConditionalProcessor(
319 'extra_input_processors': [ConditionalProcessor(
323 processor=HighlightMatchingBracketProcessor(chars='[](){}'),
320 processor=HighlightMatchingBracketProcessor(chars='[](){}'),
324 filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
321 filter=HasFocus(DEFAULT_BUFFER) & ~IsDone() &
325 Condition(lambda cli: self.highlight_matching_brackets))],
322 Condition(lambda cli: self.highlight_matching_brackets))],
326 }
323 }
327
324
328 def _update_layout(self):
325 def _update_layout(self):
329 """
326 """
330 Ask for a re computation of the application layout, if for example ,
327 Ask for a re computation of the application layout, if for example ,
331 some configuration options have changed.
328 some configuration options have changed.
332 """
329 """
333 if self._pt_app:
330 if self._pt_app:
334 self._pt_app.layout = create_prompt_layout(**self._layout_options())
331 self._pt_app.layout = create_prompt_layout(**self._layout_options())
335
332
336 def prompt_for_code(self):
333 def prompt_for_code(self):
337 document = self.pt_cli.run(
334 document = self.pt_cli.run(
338 pre_run=self.pre_prompt, reset_current_buffer=True)
335 pre_run=self.pre_prompt, reset_current_buffer=True)
339 return document.text
336 return document.text
340
337
341 def enable_win_unicode_console(self):
338 def enable_win_unicode_console(self):
342 if sys.version_info >= (3, 6):
339 if sys.version_info >= (3, 6):
343 # Since PEP 528, Python uses the unicode APIs for the Windows
340 # Since PEP 528, Python uses the unicode APIs for the Windows
344 # console by default, so WUC shouldn't be needed.
341 # console by default, so WUC shouldn't be needed.
345 return
342 return
346
343
347 import win_unicode_console
344 import win_unicode_console
348
345 win_unicode_console.enable()
349 if PY3:
350 win_unicode_console.enable()
351 else:
352 # https://github.com/ipython/ipython/issues/9768
353 from win_unicode_console.streams import (TextStreamWrapper,
354 stdout_text_transcoded, stderr_text_transcoded)
355
356 class LenientStrStreamWrapper(TextStreamWrapper):
357 def write(self, s):
358 if isinstance(s, bytes):
359 s = s.decode(self.encoding, 'replace')
360
361 self.base.write(s)
362
363 stdout_text_str = LenientStrStreamWrapper(stdout_text_transcoded)
364 stderr_text_str = LenientStrStreamWrapper(stderr_text_transcoded)
365
366 win_unicode_console.enable(stdout=stdout_text_str,
367 stderr=stderr_text_str)
368
346
369 def init_io(self):
347 def init_io(self):
370 if sys.platform not in {'win32', 'cli'}:
348 if sys.platform not in {'win32', 'cli'}:
371 return
349 return
372
350
373 self.enable_win_unicode_console()
351 self.enable_win_unicode_console()
374
352
375 import colorama
353 import colorama
376 colorama.init()
354 colorama.init()
377
355
378 # For some reason we make these wrappers around stdout/stderr.
356 # For some reason we make these wrappers around stdout/stderr.
379 # For now, we need to reset them so all output gets coloured.
357 # For now, we need to reset them so all output gets coloured.
380 # https://github.com/ipython/ipython/issues/8669
358 # https://github.com/ipython/ipython/issues/8669
381 # io.std* are deprecated, but don't show our own deprecation warnings
359 # io.std* are deprecated, but don't show our own deprecation warnings
382 # during initialization of the deprecated API.
360 # during initialization of the deprecated API.
383 with warnings.catch_warnings():
361 with warnings.catch_warnings():
384 warnings.simplefilter('ignore', DeprecationWarning)
362 warnings.simplefilter('ignore', DeprecationWarning)
385 io.stdout = io.IOStream(sys.stdout)
363 io.stdout = io.IOStream(sys.stdout)
386 io.stderr = io.IOStream(sys.stderr)
364 io.stderr = io.IOStream(sys.stderr)
387
365
388 def init_magics(self):
366 def init_magics(self):
389 super(TerminalInteractiveShell, self).init_magics()
367 super(TerminalInteractiveShell, self).init_magics()
390 self.register_magics(TerminalMagics)
368 self.register_magics(TerminalMagics)
391
369
392 def init_alias(self):
370 def init_alias(self):
393 # The parent class defines aliases that can be safely used with any
371 # The parent class defines aliases that can be safely used with any
394 # frontend.
372 # frontend.
395 super(TerminalInteractiveShell, self).init_alias()
373 super(TerminalInteractiveShell, self).init_alias()
396
374
397 # Now define aliases that only make sense on the terminal, because they
375 # Now define aliases that only make sense on the terminal, because they
398 # need direct access to the console in a way that we can't emulate in
376 # need direct access to the console in a way that we can't emulate in
399 # GUI or web frontend
377 # GUI or web frontend
400 if os.name == 'posix':
378 if os.name == 'posix':
401 for cmd in ['clear', 'more', 'less', 'man']:
379 for cmd in ['clear', 'more', 'less', 'man']:
402 self.alias_manager.soft_define_alias(cmd, cmd)
380 self.alias_manager.soft_define_alias(cmd, cmd)
403
381
404
382
405 def __init__(self, *args, **kwargs):
383 def __init__(self, *args, **kwargs):
406 super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
384 super(TerminalInteractiveShell, self).__init__(*args, **kwargs)
407 self.init_prompt_toolkit_cli()
385 self.init_prompt_toolkit_cli()
408 self.init_term_title()
386 self.init_term_title()
409 self.keep_running = True
387 self.keep_running = True
410
388
411 self.debugger_history = InMemoryHistory()
389 self.debugger_history = InMemoryHistory()
412
390
413 def ask_exit(self):
391 def ask_exit(self):
414 self.keep_running = False
392 self.keep_running = False
415
393
416 rl_next_input = None
394 rl_next_input = None
417
395
418 def pre_prompt(self):
396 def pre_prompt(self):
419 if self.rl_next_input:
397 if self.rl_next_input:
420 self.pt_cli.application.buffer.text = cast_unicode_py2(self.rl_next_input)
398 self.pt_cli.application.buffer.text = cast_unicode_py2(self.rl_next_input)
421 self.rl_next_input = None
399 self.rl_next_input = None
422
400
423 def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
401 def interact(self, display_banner=DISPLAY_BANNER_DEPRECATED):
424
402
425 if display_banner is not DISPLAY_BANNER_DEPRECATED:
403 if display_banner is not DISPLAY_BANNER_DEPRECATED:
426 warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
404 warn('interact `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
427
405
428 self.keep_running = True
406 self.keep_running = True
429 while self.keep_running:
407 while self.keep_running:
430 print(self.separate_in, end='')
408 print(self.separate_in, end='')
431
409
432 try:
410 try:
433 code = self.prompt_for_code()
411 code = self.prompt_for_code()
434 except EOFError:
412 except EOFError:
435 if (not self.confirm_exit) \
413 if (not self.confirm_exit) \
436 or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
414 or self.ask_yes_no('Do you really want to exit ([y]/n)?','y','n'):
437 self.ask_exit()
415 self.ask_exit()
438
416
439 else:
417 else:
440 if code:
418 if code:
441 self.run_cell(code, store_history=True)
419 self.run_cell(code, store_history=True)
442
420
443 def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
421 def mainloop(self, display_banner=DISPLAY_BANNER_DEPRECATED):
444 # An extra layer of protection in case someone mashing Ctrl-C breaks
422 # An extra layer of protection in case someone mashing Ctrl-C breaks
445 # out of our internal code.
423 # out of our internal code.
446 if display_banner is not DISPLAY_BANNER_DEPRECATED:
424 if display_banner is not DISPLAY_BANNER_DEPRECATED:
447 warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
425 warn('mainloop `display_banner` argument is deprecated since IPython 5.0. Call `show_banner()` if needed.', DeprecationWarning, stacklevel=2)
448 while True:
426 while True:
449 try:
427 try:
450 self.interact()
428 self.interact()
451 break
429 break
452 except KeyboardInterrupt:
430 except KeyboardInterrupt:
453 print("\nKeyboardInterrupt escaped interact()\n")
431 print("\nKeyboardInterrupt escaped interact()\n")
454
432
455 _inputhook = None
433 _inputhook = None
456 def inputhook(self, context):
434 def inputhook(self, context):
457 if self._inputhook is not None:
435 if self._inputhook is not None:
458 self._inputhook(context)
436 self._inputhook(context)
459
437
460 active_eventloop = None
438 active_eventloop = None
461 def enable_gui(self, gui=None):
439 def enable_gui(self, gui=None):
462 if gui:
440 if gui:
463 self.active_eventloop, self._inputhook =\
441 self.active_eventloop, self._inputhook =\
464 get_inputhook_name_and_func(gui)
442 get_inputhook_name_and_func(gui)
465 else:
443 else:
466 self.active_eventloop = self._inputhook = None
444 self.active_eventloop = self._inputhook = None
467
445
468 # Run !system commands directly, not through pipes, so terminal programs
446 # Run !system commands directly, not through pipes, so terminal programs
469 # work correctly.
447 # work correctly.
470 system = InteractiveShell.system_raw
448 system = InteractiveShell.system_raw
471
449
472 def auto_rewrite_input(self, cmd):
450 def auto_rewrite_input(self, cmd):
473 """Overridden from the parent class to use fancy rewriting prompt"""
451 """Overridden from the parent class to use fancy rewriting prompt"""
474 if not self.show_rewritten_input:
452 if not self.show_rewritten_input:
475 return
453 return
476
454
477 tokens = self.prompts.rewrite_prompt_tokens()
455 tokens = self.prompts.rewrite_prompt_tokens()
478 if self.pt_cli:
456 if self.pt_cli:
479 self.pt_cli.print_tokens(tokens)
457 self.pt_cli.print_tokens(tokens)
480 print(cmd)
458 print(cmd)
481 else:
459 else:
482 prompt = ''.join(s for t, s in tokens)
460 prompt = ''.join(s for t, s in tokens)
483 print(prompt, cmd, sep='')
461 print(prompt, cmd, sep='')
484
462
485 _prompts_before = None
463 _prompts_before = None
486 def switch_doctest_mode(self, mode):
464 def switch_doctest_mode(self, mode):
487 """Switch prompts to classic for %doctest_mode"""
465 """Switch prompts to classic for %doctest_mode"""
488 if mode:
466 if mode:
489 self._prompts_before = self.prompts
467 self._prompts_before = self.prompts
490 self.prompts = ClassicPrompts(self)
468 self.prompts = ClassicPrompts(self)
491 elif self._prompts_before:
469 elif self._prompts_before:
492 self.prompts = self._prompts_before
470 self.prompts = self._prompts_before
493 self._prompts_before = None
471 self._prompts_before = None
494 self._update_layout()
472 self._update_layout()
495
473
496
474
497 InteractiveShellABC.register(TerminalInteractiveShell)
475 InteractiveShellABC.register(TerminalInteractiveShell)
498
476
499 if __name__ == '__main__':
477 if __name__ == '__main__':
500 TerminalInteractiveShell.instance().interact()
478 TerminalInteractiveShell.instance().interact()
@@ -1,95 +1,90 b''
1 # Code borrowed from ptpython
1 # Code borrowed from ptpython
2 # https://github.com/jonathanslenders/ptpython/blob/86b71a89626114b18898a0af463978bdb32eeb70/ptpython/eventloop.py
2 # https://github.com/jonathanslenders/ptpython/blob/86b71a89626114b18898a0af463978bdb32eeb70/ptpython/eventloop.py
3
3
4 # Copyright (c) 2015, Jonathan Slenders
4 # Copyright (c) 2015, Jonathan Slenders
5 # All rights reserved.
5 # All rights reserved.
6 #
6 #
7 # Redistribution and use in source and binary forms, with or without modification,
7 # Redistribution and use in source and binary forms, with or without modification,
8 # are permitted provided that the following conditions are met:
8 # are permitted provided that the following conditions are met:
9 #
9 #
10 # * Redistributions of source code must retain the above copyright notice, this
10 # * Redistributions of source code must retain the above copyright notice, this
11 # list of conditions and the following disclaimer.
11 # list of conditions and the following disclaimer.
12 #
12 #
13 # * Redistributions in binary form must reproduce the above copyright notice, this
13 # * Redistributions in binary form must reproduce the above copyright notice, this
14 # list of conditions and the following disclaimer in the documentation and/or
14 # list of conditions and the following disclaimer in the documentation and/or
15 # other materials provided with the distribution.
15 # other materials provided with the distribution.
16 #
16 #
17 # * Neither the name of the {organization} nor the names of its
17 # * Neither the name of the {organization} nor the names of its
18 # contributors may be used to endorse or promote products derived from
18 # contributors may be used to endorse or promote products derived from
19 # this software without specific prior written permission.
19 # this software without specific prior written permission.
20 #
20 #
21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
21 # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
22 # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
24 # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
25 # ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
26 # (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
27 # LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
28 # ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31
31
32 """
32 """
33 Wrapper around the eventloop that gives some time to the Tkinter GUI to process
33 Wrapper around the eventloop that gives some time to the Tkinter GUI to process
34 events when it's loaded and while we are waiting for input at the REPL. This
34 events when it's loaded and while we are waiting for input at the REPL. This
35 way we don't block the UI of for instance ``turtle`` and other Tk libraries.
35 way we don't block the UI of for instance ``turtle`` and other Tk libraries.
36
36
37 (Normally Tkinter registeres it's callbacks in ``PyOS_InputHook`` to integrate
37 (Normally Tkinter registeres it's callbacks in ``PyOS_InputHook`` to integrate
38 in readline. ``prompt-toolkit`` doesn't understand that input hook, but this
38 in readline. ``prompt-toolkit`` doesn't understand that input hook, but this
39 will fix it for Tk.)
39 will fix it for Tk.)
40 """
40 """
41 import time
41 import time
42
42
43 from IPython.utils.py3compat import PY3
44
45 import _tkinter
43 import _tkinter
46 if PY3:
44 import tkinter
47 import tkinter
48 else:
49 import Tkinter as tkinter # Python 2
50
45
51 def inputhook(inputhook_context):
46 def inputhook(inputhook_context):
52 """
47 """
53 Inputhook for Tk.
48 Inputhook for Tk.
54 Run the Tk eventloop until prompt-toolkit needs to process the next input.
49 Run the Tk eventloop until prompt-toolkit needs to process the next input.
55 """
50 """
56 # Get the current TK application.
51 # Get the current TK application.
57 root = tkinter._default_root
52 root = tkinter._default_root
58
53
59 def wait_using_filehandler():
54 def wait_using_filehandler():
60 """
55 """
61 Run the TK eventloop until the file handler that we got from the
56 Run the TK eventloop until the file handler that we got from the
62 inputhook becomes readable.
57 inputhook becomes readable.
63 """
58 """
64 # Add a handler that sets the stop flag when `prompt-toolkit` has input
59 # Add a handler that sets the stop flag when `prompt-toolkit` has input
65 # to process.
60 # to process.
66 stop = [False]
61 stop = [False]
67 def done(*a):
62 def done(*a):
68 stop[0] = True
63 stop[0] = True
69
64
70 root.createfilehandler(inputhook_context.fileno(), _tkinter.READABLE, done)
65 root.createfilehandler(inputhook_context.fileno(), _tkinter.READABLE, done)
71
66
72 # Run the TK event loop as long as we don't receive input.
67 # Run the TK event loop as long as we don't receive input.
73 while root.dooneevent(_tkinter.ALL_EVENTS):
68 while root.dooneevent(_tkinter.ALL_EVENTS):
74 if stop[0]:
69 if stop[0]:
75 break
70 break
76
71
77 root.deletefilehandler(inputhook_context.fileno())
72 root.deletefilehandler(inputhook_context.fileno())
78
73
79 def wait_using_polling():
74 def wait_using_polling():
80 """
75 """
81 Windows TK doesn't support 'createfilehandler'.
76 Windows TK doesn't support 'createfilehandler'.
82 So, run the TK eventloop and poll until input is ready.
77 So, run the TK eventloop and poll until input is ready.
83 """
78 """
84 while not inputhook_context.input_is_ready():
79 while not inputhook_context.input_is_ready():
85 while root.dooneevent(_tkinter.ALL_EVENTS | _tkinter.DONT_WAIT):
80 while root.dooneevent(_tkinter.ALL_EVENTS | _tkinter.DONT_WAIT):
86 pass
81 pass
87 # Sleep to make the CPU idle, but not too long, so that the UI
82 # Sleep to make the CPU idle, but not too long, so that the UI
88 # stays responsive.
83 # stays responsive.
89 time.sleep(.01)
84 time.sleep(.01)
90
85
91 if root is not None:
86 if root is not None:
92 if hasattr(root, 'createfilehandler'):
87 if hasattr(root, 'createfilehandler'):
93 wait_using_filehandler()
88 wait_using_filehandler()
94 else:
89 else:
95 wait_using_polling()
90 wait_using_polling()
@@ -1,378 +1,378 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Decorators for labeling test objects.
2 """Decorators for labeling test objects.
3
3
4 Decorators that merely return a modified version of the original function
4 Decorators that merely return a modified version of the original function
5 object are straightforward. Decorators that return a new function object need
5 object are straightforward. Decorators that return a new function object need
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
6 to use nose.tools.make_decorator(original_function)(decorator) in returning the
7 decorator, in order to preserve metadata such as function name, setup and
7 decorator, in order to preserve metadata such as function name, setup and
8 teardown functions and so on - see nose.tools for more information.
8 teardown functions and so on - see nose.tools for more information.
9
9
10 This module provides a set of useful decorators meant to be ready to use in
10 This module provides a set of useful decorators meant to be ready to use in
11 your own tests. See the bottom of the file for the ready-made ones, and if you
11 your own tests. See the bottom of the file for the ready-made ones, and if you
12 find yourself writing a new one that may be of generic use, add it here.
12 find yourself writing a new one that may be of generic use, add it here.
13
13
14 Included decorators:
14 Included decorators:
15
15
16
16
17 Lightweight testing that remains unittest-compatible.
17 Lightweight testing that remains unittest-compatible.
18
18
19 - An @as_unittest decorator can be used to tag any normal parameter-less
19 - An @as_unittest decorator can be used to tag any normal parameter-less
20 function as a unittest TestCase. Then, both nose and normal unittest will
20 function as a unittest TestCase. Then, both nose and normal unittest will
21 recognize it as such. This will make it easier to migrate away from Nose if
21 recognize it as such. This will make it easier to migrate away from Nose if
22 we ever need/want to while maintaining very lightweight tests.
22 we ever need/want to while maintaining very lightweight tests.
23
23
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
24 NOTE: This file contains IPython-specific decorators. Using the machinery in
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
25 IPython.external.decorators, we import either numpy.testing.decorators if numpy is
26 available, OR use equivalent code in IPython.external._decorators, which
26 available, OR use equivalent code in IPython.external._decorators, which
27 we've copied verbatim from numpy.
27 we've copied verbatim from numpy.
28
28
29 """
29 """
30
30
31 # Copyright (c) IPython Development Team.
31 # Copyright (c) IPython Development Team.
32 # Distributed under the terms of the Modified BSD License.
32 # Distributed under the terms of the Modified BSD License.
33
33
34 import sys
34 import sys
35 import os
35 import os
36 import tempfile
36 import tempfile
37 import unittest
37 import unittest
38 import warnings
38 import warnings
39 from importlib import import_module
39 from importlib import import_module
40
40
41 from decorator import decorator
41 from decorator import decorator
42
42
43 # Expose the unittest-driven decorators
43 # Expose the unittest-driven decorators
44 from .ipunittest import ipdoctest, ipdocstring
44 from .ipunittest import ipdoctest, ipdocstring
45
45
46 # Grab the numpy-specific decorators which we keep in a file that we
46 # Grab the numpy-specific decorators which we keep in a file that we
47 # occasionally update from upstream: decorators.py is a copy of
47 # occasionally update from upstream: decorators.py is a copy of
48 # numpy.testing.decorators, we expose all of it here.
48 # numpy.testing.decorators, we expose all of it here.
49 from IPython.external.decorators import *
49 from IPython.external.decorators import *
50
50
51 # For onlyif_cmd_exists decorator
51 # For onlyif_cmd_exists decorator
52 from IPython.utils.py3compat import which, PY2, PY3, PYPY
52 from IPython.utils.py3compat import which
53
53
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55 # Classes and functions
55 # Classes and functions
56 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
57
57
58 # Simple example of the basic idea
58 # Simple example of the basic idea
59 def as_unittest(func):
59 def as_unittest(func):
60 """Decorator to make a simple function into a normal test via unittest."""
60 """Decorator to make a simple function into a normal test via unittest."""
61 class Tester(unittest.TestCase):
61 class Tester(unittest.TestCase):
62 def test(self):
62 def test(self):
63 func()
63 func()
64
64
65 Tester.__name__ = func.__name__
65 Tester.__name__ = func.__name__
66
66
67 return Tester
67 return Tester
68
68
69 # Utility functions
69 # Utility functions
70
70
71 def apply_wrapper(wrapper, func):
71 def apply_wrapper(wrapper, func):
72 """Apply a wrapper to a function for decoration.
72 """Apply a wrapper to a function for decoration.
73
73
74 This mixes Michele Simionato's decorator tool with nose's make_decorator,
74 This mixes Michele Simionato's decorator tool with nose's make_decorator,
75 to apply a wrapper in a decorator so that all nose attributes, as well as
75 to apply a wrapper in a decorator so that all nose attributes, as well as
76 function signature and other properties, survive the decoration cleanly.
76 function signature and other properties, survive the decoration cleanly.
77 This will ensure that wrapped functions can still be well introspected via
77 This will ensure that wrapped functions can still be well introspected via
78 IPython, for example.
78 IPython, for example.
79 """
79 """
80 warnings.warn("The function `apply_wrapper` is deprecated since IPython 4.0",
80 warnings.warn("The function `apply_wrapper` is deprecated since IPython 4.0",
81 DeprecationWarning, stacklevel=2)
81 DeprecationWarning, stacklevel=2)
82 import nose.tools
82 import nose.tools
83
83
84 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
84 return decorator(wrapper,nose.tools.make_decorator(func)(wrapper))
85
85
86
86
87 def make_label_dec(label, ds=None):
87 def make_label_dec(label, ds=None):
88 """Factory function to create a decorator that applies one or more labels.
88 """Factory function to create a decorator that applies one or more labels.
89
89
90 Parameters
90 Parameters
91 ----------
91 ----------
92 label : string or sequence
92 label : string or sequence
93 One or more labels that will be applied by the decorator to the functions
93 One or more labels that will be applied by the decorator to the functions
94 it decorates. Labels are attributes of the decorated function with their
94 it decorates. Labels are attributes of the decorated function with their
95 value set to True.
95 value set to True.
96
96
97 ds : string
97 ds : string
98 An optional docstring for the resulting decorator. If not given, a
98 An optional docstring for the resulting decorator. If not given, a
99 default docstring is auto-generated.
99 default docstring is auto-generated.
100
100
101 Returns
101 Returns
102 -------
102 -------
103 A decorator.
103 A decorator.
104
104
105 Examples
105 Examples
106 --------
106 --------
107
107
108 A simple labeling decorator:
108 A simple labeling decorator:
109
109
110 >>> slow = make_label_dec('slow')
110 >>> slow = make_label_dec('slow')
111 >>> slow.__doc__
111 >>> slow.__doc__
112 "Labels a test as 'slow'."
112 "Labels a test as 'slow'."
113
113
114 And one that uses multiple labels and a custom docstring:
114 And one that uses multiple labels and a custom docstring:
115
115
116 >>> rare = make_label_dec(['slow','hard'],
116 >>> rare = make_label_dec(['slow','hard'],
117 ... "Mix labels 'slow' and 'hard' for rare tests.")
117 ... "Mix labels 'slow' and 'hard' for rare tests.")
118 >>> rare.__doc__
118 >>> rare.__doc__
119 "Mix labels 'slow' and 'hard' for rare tests."
119 "Mix labels 'slow' and 'hard' for rare tests."
120
120
121 Now, let's test using this one:
121 Now, let's test using this one:
122 >>> @rare
122 >>> @rare
123 ... def f(): pass
123 ... def f(): pass
124 ...
124 ...
125 >>>
125 >>>
126 >>> f.slow
126 >>> f.slow
127 True
127 True
128 >>> f.hard
128 >>> f.hard
129 True
129 True
130 """
130 """
131
131
132 warnings.warn("The function `make_label_dec` is deprecated since IPython 4.0",
132 warnings.warn("The function `make_label_dec` is deprecated since IPython 4.0",
133 DeprecationWarning, stacklevel=2)
133 DeprecationWarning, stacklevel=2)
134 if isinstance(label, str):
134 if isinstance(label, str):
135 labels = [label]
135 labels = [label]
136 else:
136 else:
137 labels = label
137 labels = label
138
138
139 # Validate that the given label(s) are OK for use in setattr() by doing a
139 # Validate that the given label(s) are OK for use in setattr() by doing a
140 # dry run on a dummy function.
140 # dry run on a dummy function.
141 tmp = lambda : None
141 tmp = lambda : None
142 for label in labels:
142 for label in labels:
143 setattr(tmp,label,True)
143 setattr(tmp,label,True)
144
144
145 # This is the actual decorator we'll return
145 # This is the actual decorator we'll return
146 def decor(f):
146 def decor(f):
147 for label in labels:
147 for label in labels:
148 setattr(f,label,True)
148 setattr(f,label,True)
149 return f
149 return f
150
150
151 # Apply the user's docstring, or autogenerate a basic one
151 # Apply the user's docstring, or autogenerate a basic one
152 if ds is None:
152 if ds is None:
153 ds = "Labels a test as %r." % label
153 ds = "Labels a test as %r." % label
154 decor.__doc__ = ds
154 decor.__doc__ = ds
155
155
156 return decor
156 return decor
157
157
158
158
159 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
159 # Inspired by numpy's skipif, but uses the full apply_wrapper utility to
160 # preserve function metadata better and allows the skip condition to be a
160 # preserve function metadata better and allows the skip condition to be a
161 # callable.
161 # callable.
162 def skipif(skip_condition, msg=None):
162 def skipif(skip_condition, msg=None):
163 ''' Make function raise SkipTest exception if skip_condition is true
163 ''' Make function raise SkipTest exception if skip_condition is true
164
164
165 Parameters
165 Parameters
166 ----------
166 ----------
167
167
168 skip_condition : bool or callable
168 skip_condition : bool or callable
169 Flag to determine whether to skip test. If the condition is a
169 Flag to determine whether to skip test. If the condition is a
170 callable, it is used at runtime to dynamically make the decision. This
170 callable, it is used at runtime to dynamically make the decision. This
171 is useful for tests that may require costly imports, to delay the cost
171 is useful for tests that may require costly imports, to delay the cost
172 until the test suite is actually executed.
172 until the test suite is actually executed.
173 msg : string
173 msg : string
174 Message to give on raising a SkipTest exception.
174 Message to give on raising a SkipTest exception.
175
175
176 Returns
176 Returns
177 -------
177 -------
178 decorator : function
178 decorator : function
179 Decorator, which, when applied to a function, causes SkipTest
179 Decorator, which, when applied to a function, causes SkipTest
180 to be raised when the skip_condition was True, and the function
180 to be raised when the skip_condition was True, and the function
181 to be called normally otherwise.
181 to be called normally otherwise.
182
182
183 Notes
183 Notes
184 -----
184 -----
185 You will see from the code that we had to further decorate the
185 You will see from the code that we had to further decorate the
186 decorator with the nose.tools.make_decorator function in order to
186 decorator with the nose.tools.make_decorator function in order to
187 transmit function name, and various other metadata.
187 transmit function name, and various other metadata.
188 '''
188 '''
189
189
190 def skip_decorator(f):
190 def skip_decorator(f):
191 # Local import to avoid a hard nose dependency and only incur the
191 # Local import to avoid a hard nose dependency and only incur the
192 # import time overhead at actual test-time.
192 # import time overhead at actual test-time.
193 import nose
193 import nose
194
194
195 # Allow for both boolean or callable skip conditions.
195 # Allow for both boolean or callable skip conditions.
196 if callable(skip_condition):
196 if callable(skip_condition):
197 skip_val = skip_condition
197 skip_val = skip_condition
198 else:
198 else:
199 skip_val = lambda : skip_condition
199 skip_val = lambda : skip_condition
200
200
201 def get_msg(func,msg=None):
201 def get_msg(func,msg=None):
202 """Skip message with information about function being skipped."""
202 """Skip message with information about function being skipped."""
203 if msg is None: out = 'Test skipped due to test condition.'
203 if msg is None: out = 'Test skipped due to test condition.'
204 else: out = msg
204 else: out = msg
205 return "Skipping test: %s. %s" % (func.__name__,out)
205 return "Skipping test: %s. %s" % (func.__name__,out)
206
206
207 # We need to define *two* skippers because Python doesn't allow both
207 # We need to define *two* skippers because Python doesn't allow both
208 # return with value and yield inside the same function.
208 # return with value and yield inside the same function.
209 def skipper_func(*args, **kwargs):
209 def skipper_func(*args, **kwargs):
210 """Skipper for normal test functions."""
210 """Skipper for normal test functions."""
211 if skip_val():
211 if skip_val():
212 raise nose.SkipTest(get_msg(f,msg))
212 raise nose.SkipTest(get_msg(f,msg))
213 else:
213 else:
214 return f(*args, **kwargs)
214 return f(*args, **kwargs)
215
215
216 def skipper_gen(*args, **kwargs):
216 def skipper_gen(*args, **kwargs):
217 """Skipper for test generators."""
217 """Skipper for test generators."""
218 if skip_val():
218 if skip_val():
219 raise nose.SkipTest(get_msg(f,msg))
219 raise nose.SkipTest(get_msg(f,msg))
220 else:
220 else:
221 for x in f(*args, **kwargs):
221 for x in f(*args, **kwargs):
222 yield x
222 yield x
223
223
224 # Choose the right skipper to use when building the actual generator.
224 # Choose the right skipper to use when building the actual generator.
225 if nose.util.isgenerator(f):
225 if nose.util.isgenerator(f):
226 skipper = skipper_gen
226 skipper = skipper_gen
227 else:
227 else:
228 skipper = skipper_func
228 skipper = skipper_func
229
229
230 return nose.tools.make_decorator(f)(skipper)
230 return nose.tools.make_decorator(f)(skipper)
231
231
232 return skip_decorator
232 return skip_decorator
233
233
234 # A version with the condition set to true, common case just to attach a message
234 # A version with the condition set to true, common case just to attach a message
235 # to a skip decorator
235 # to a skip decorator
236 def skip(msg=None):
236 def skip(msg=None):
237 """Decorator factory - mark a test function for skipping from test suite.
237 """Decorator factory - mark a test function for skipping from test suite.
238
238
239 Parameters
239 Parameters
240 ----------
240 ----------
241 msg : string
241 msg : string
242 Optional message to be added.
242 Optional message to be added.
243
243
244 Returns
244 Returns
245 -------
245 -------
246 decorator : function
246 decorator : function
247 Decorator, which, when applied to a function, causes SkipTest
247 Decorator, which, when applied to a function, causes SkipTest
248 to be raised, with the optional message added.
248 to be raised, with the optional message added.
249 """
249 """
250
250
251 return skipif(True,msg)
251 return skipif(True,msg)
252
252
253
253
254 def onlyif(condition, msg):
254 def onlyif(condition, msg):
255 """The reverse from skipif, see skipif for details."""
255 """The reverse from skipif, see skipif for details."""
256
256
257 if callable(condition):
257 if callable(condition):
258 skip_condition = lambda : not condition()
258 skip_condition = lambda : not condition()
259 else:
259 else:
260 skip_condition = lambda : not condition
260 skip_condition = lambda : not condition
261
261
262 return skipif(skip_condition, msg)
262 return skipif(skip_condition, msg)
263
263
264 #-----------------------------------------------------------------------------
264 #-----------------------------------------------------------------------------
265 # Utility functions for decorators
265 # Utility functions for decorators
266 def module_not_available(module):
266 def module_not_available(module):
267 """Can module be imported? Returns true if module does NOT import.
267 """Can module be imported? Returns true if module does NOT import.
268
268
269 This is used to make a decorator to skip tests that require module to be
269 This is used to make a decorator to skip tests that require module to be
270 available, but delay the 'import numpy' to test execution time.
270 available, but delay the 'import numpy' to test execution time.
271 """
271 """
272 try:
272 try:
273 mod = import_module(module)
273 mod = import_module(module)
274 mod_not_avail = False
274 mod_not_avail = False
275 except ImportError:
275 except ImportError:
276 mod_not_avail = True
276 mod_not_avail = True
277
277
278 return mod_not_avail
278 return mod_not_avail
279
279
280
280
281 def decorated_dummy(dec, name):
281 def decorated_dummy(dec, name):
282 """Return a dummy function decorated with dec, with the given name.
282 """Return a dummy function decorated with dec, with the given name.
283
283
284 Examples
284 Examples
285 --------
285 --------
286 import IPython.testing.decorators as dec
286 import IPython.testing.decorators as dec
287 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
287 setup = dec.decorated_dummy(dec.skip_if_no_x11, __name__)
288 """
288 """
289 warnings.warn("The function `decorated_dummy` is deprecated since IPython 4.0",
289 warnings.warn("The function `decorated_dummy` is deprecated since IPython 4.0",
290 DeprecationWarning, stacklevel=2)
290 DeprecationWarning, stacklevel=2)
291 dummy = lambda: None
291 dummy = lambda: None
292 dummy.__name__ = name
292 dummy.__name__ = name
293 return dec(dummy)
293 return dec(dummy)
294
294
295 #-----------------------------------------------------------------------------
295 #-----------------------------------------------------------------------------
296 # Decorators for public use
296 # Decorators for public use
297
297
298 # Decorators to skip certain tests on specific platforms.
298 # Decorators to skip certain tests on specific platforms.
299 skip_win32 = skipif(sys.platform == 'win32',
299 skip_win32 = skipif(sys.platform == 'win32',
300 "This test does not run under Windows")
300 "This test does not run under Windows")
301 skip_linux = skipif(sys.platform.startswith('linux'),
301 skip_linux = skipif(sys.platform.startswith('linux'),
302 "This test does not run under Linux")
302 "This test does not run under Linux")
303 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
303 skip_osx = skipif(sys.platform == 'darwin',"This test does not run under OS X")
304
304
305
305
306 # Decorators to skip tests if not on specific platforms.
306 # Decorators to skip tests if not on specific platforms.
307 skip_if_not_win32 = skipif(sys.platform != 'win32',
307 skip_if_not_win32 = skipif(sys.platform != 'win32',
308 "This test only runs under Windows")
308 "This test only runs under Windows")
309 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
309 skip_if_not_linux = skipif(not sys.platform.startswith('linux'),
310 "This test only runs under Linux")
310 "This test only runs under Linux")
311 skip_if_not_osx = skipif(sys.platform != 'darwin',
311 skip_if_not_osx = skipif(sys.platform != 'darwin',
312 "This test only runs under OSX")
312 "This test only runs under OSX")
313
313
314
314
315 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
315 _x11_skip_cond = (sys.platform not in ('darwin', 'win32') and
316 os.environ.get('DISPLAY', '') == '')
316 os.environ.get('DISPLAY', '') == '')
317 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
317 _x11_skip_msg = "Skipped under *nix when X11/XOrg not available"
318
318
319 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
319 skip_if_no_x11 = skipif(_x11_skip_cond, _x11_skip_msg)
320
320
321 # not a decorator itself, returns a dummy function to be used as setup
321 # not a decorator itself, returns a dummy function to be used as setup
322 def skip_file_no_x11(name):
322 def skip_file_no_x11(name):
323 warnings.warn("The function `skip_file_no_x11` is deprecated since IPython 4.0",
323 warnings.warn("The function `skip_file_no_x11` is deprecated since IPython 4.0",
324 DeprecationWarning, stacklevel=2)
324 DeprecationWarning, stacklevel=2)
325 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
325 return decorated_dummy(skip_if_no_x11, name) if _x11_skip_cond else None
326
326
327 # Other skip decorators
327 # Other skip decorators
328
328
329 # generic skip without module
329 # generic skip without module
330 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
330 skip_without = lambda mod: skipif(module_not_available(mod), "This test requires %s" % mod)
331
331
332 skipif_not_numpy = skip_without('numpy')
332 skipif_not_numpy = skip_without('numpy')
333
333
334 skipif_not_matplotlib = skip_without('matplotlib')
334 skipif_not_matplotlib = skip_without('matplotlib')
335
335
336 skipif_not_sympy = skip_without('sympy')
336 skipif_not_sympy = skip_without('sympy')
337
337
338 skip_known_failure = knownfailureif(True,'This test is known to fail')
338 skip_known_failure = knownfailureif(True,'This test is known to fail')
339
339
340 # A null 'decorator', useful to make more readable code that needs to pick
340 # A null 'decorator', useful to make more readable code that needs to pick
341 # between different decorators based on OS or other conditions
341 # between different decorators based on OS or other conditions
342 null_deco = lambda f: f
342 null_deco = lambda f: f
343
343
344 # Some tests only run where we can use unicode paths. Note that we can't just
344 # Some tests only run where we can use unicode paths. Note that we can't just
345 # check os.path.supports_unicode_filenames, which is always False on Linux.
345 # check os.path.supports_unicode_filenames, which is always False on Linux.
346 try:
346 try:
347 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
347 f = tempfile.NamedTemporaryFile(prefix=u"tmp€")
348 except UnicodeEncodeError:
348 except UnicodeEncodeError:
349 unicode_paths = False
349 unicode_paths = False
350 else:
350 else:
351 unicode_paths = True
351 unicode_paths = True
352 f.close()
352 f.close()
353
353
354 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
354 onlyif_unicode_paths = onlyif(unicode_paths, ("This test is only applicable "
355 "where we can use unicode in filenames."))
355 "where we can use unicode in filenames."))
356
356
357
357
358 def onlyif_cmds_exist(*commands):
358 def onlyif_cmds_exist(*commands):
359 """
359 """
360 Decorator to skip test when at least one of `commands` is not found.
360 Decorator to skip test when at least one of `commands` is not found.
361 """
361 """
362 for cmd in commands:
362 for cmd in commands:
363 if not which(cmd):
363 if not which(cmd):
364 return skip("This test runs only if command '{0}' "
364 return skip("This test runs only if command '{0}' "
365 "is installed".format(cmd))
365 "is installed".format(cmd))
366 return null_deco
366 return null_deco
367
367
368 def onlyif_any_cmd_exists(*commands):
368 def onlyif_any_cmd_exists(*commands):
369 """
369 """
370 Decorator to skip test unless at least one of `commands` is found.
370 Decorator to skip test unless at least one of `commands` is found.
371 """
371 """
372 warnings.warn("The function `onlyif_any_cmd_exists` is deprecated since IPython 4.0",
372 warnings.warn("The function `onlyif_any_cmd_exists` is deprecated since IPython 4.0",
373 DeprecationWarning, stacklevel=2)
373 DeprecationWarning, stacklevel=2)
374 for cmd in commands:
374 for cmd in commands:
375 if which(cmd):
375 if which(cmd):
376 return null_deco
376 return null_deco
377 return skip("This test runs only if one of the commands {0} "
377 return skip("This test runs only if one of the commands {0} "
378 "is installed".format(commands))
378 "is installed".format(commands))
@@ -1,770 +1,766 b''
1 """Nose Plugin that supports IPython doctests.
1 """Nose Plugin that supports IPython doctests.
2
2
3 Limitations:
3 Limitations:
4
4
5 - When generating examples for use as doctests, make sure that you have
5 - When generating examples for use as doctests, make sure that you have
6 pretty-printing OFF. This can be done either by setting the
6 pretty-printing OFF. This can be done either by setting the
7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
7 ``PlainTextFormatter.pprint`` option in your configuration file to False, or
8 by interactively disabling it with %Pprint. This is required so that IPython
8 by interactively disabling it with %Pprint. This is required so that IPython
9 output matches that of normal Python, which is used by doctest for internal
9 output matches that of normal Python, which is used by doctest for internal
10 execution.
10 execution.
11
11
12 - Do not rely on specific prompt numbers for results (such as using
12 - Do not rely on specific prompt numbers for results (such as using
13 '_34==True', for example). For IPython tests run via an external process the
13 '_34==True', for example). For IPython tests run via an external process the
14 prompt numbers may be different, and IPython tests run as normal python code
14 prompt numbers may be different, and IPython tests run as normal python code
15 won't even have these special _NN variables set at all.
15 won't even have these special _NN variables set at all.
16 """
16 """
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Module imports
19 # Module imports
20
20
21 # From the standard library
21 # From the standard library
22 import doctest
22 import doctest
23 import inspect
23 import inspect
24 import logging
24 import logging
25 import os
25 import os
26 import re
26 import re
27 import sys
27 import sys
28 from importlib import import_module
28 from importlib import import_module
29 from io import StringIO
29
30
30 from testpath import modified_env
31 from testpath import modified_env
31
32
32 from inspect import getmodule
33 from inspect import getmodule
33
34
34 # We are overriding the default doctest runner, so we need to import a few
35 # We are overriding the default doctest runner, so we need to import a few
35 # things from doctest directly
36 # things from doctest directly
36 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
37 from doctest import (REPORTING_FLAGS, REPORT_ONLY_FIRST_FAILURE,
37 _unittest_reportflags, DocTestRunner,
38 _unittest_reportflags, DocTestRunner,
38 _extract_future_flags, pdb, _OutputRedirectingPdb,
39 _extract_future_flags, pdb, _OutputRedirectingPdb,
39 _exception_traceback,
40 _exception_traceback,
40 linecache)
41 linecache)
41
42
42 # Third-party modules
43 # Third-party modules
43
44
44 from nose.plugins import doctests, Plugin
45 from nose.plugins import doctests, Plugin
45 from nose.util import anyp, tolist
46 from nose.util import anyp, tolist
46
47
47 # Our own imports
48 # Our own imports
48 from IPython.utils.py3compat import builtin_mod, PY3
49 from IPython.utils.py3compat import builtin_mod
49
50 if PY3:
51 from io import StringIO
52 else:
53 from StringIO import StringIO
54
50
55 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
56 # Module globals and other constants
52 # Module globals and other constants
57 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
58
54
59 log = logging.getLogger(__name__)
55 log = logging.getLogger(__name__)
60
56
61
57
62 #-----------------------------------------------------------------------------
58 #-----------------------------------------------------------------------------
63 # Classes and functions
59 # Classes and functions
64 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
65
61
66 def is_extension_module(filename):
62 def is_extension_module(filename):
67 """Return whether the given filename is an extension module.
63 """Return whether the given filename is an extension module.
68
64
69 This simply checks that the extension is either .so or .pyd.
65 This simply checks that the extension is either .so or .pyd.
70 """
66 """
71 return os.path.splitext(filename)[1].lower() in ('.so','.pyd')
67 return os.path.splitext(filename)[1].lower() in ('.so','.pyd')
72
68
73
69
74 class DocTestSkip(object):
70 class DocTestSkip(object):
75 """Object wrapper for doctests to be skipped."""
71 """Object wrapper for doctests to be skipped."""
76
72
77 ds_skip = """Doctest to skip.
73 ds_skip = """Doctest to skip.
78 >>> 1 #doctest: +SKIP
74 >>> 1 #doctest: +SKIP
79 """
75 """
80
76
81 def __init__(self,obj):
77 def __init__(self,obj):
82 self.obj = obj
78 self.obj = obj
83
79
84 def __getattribute__(self,key):
80 def __getattribute__(self,key):
85 if key == '__doc__':
81 if key == '__doc__':
86 return DocTestSkip.ds_skip
82 return DocTestSkip.ds_skip
87 else:
83 else:
88 return getattr(object.__getattribute__(self,'obj'),key)
84 return getattr(object.__getattribute__(self,'obj'),key)
89
85
90 # Modified version of the one in the stdlib, that fixes a python bug (doctests
86 # Modified version of the one in the stdlib, that fixes a python bug (doctests
91 # not found in extension modules, http://bugs.python.org/issue3158)
87 # not found in extension modules, http://bugs.python.org/issue3158)
92 class DocTestFinder(doctest.DocTestFinder):
88 class DocTestFinder(doctest.DocTestFinder):
93
89
94 def _from_module(self, module, object):
90 def _from_module(self, module, object):
95 """
91 """
96 Return true if the given object is defined in the given
92 Return true if the given object is defined in the given
97 module.
93 module.
98 """
94 """
99 if module is None:
95 if module is None:
100 return True
96 return True
101 elif inspect.isfunction(object):
97 elif inspect.isfunction(object):
102 return module.__dict__ is object.__globals__
98 return module.__dict__ is object.__globals__
103 elif inspect.isbuiltin(object):
99 elif inspect.isbuiltin(object):
104 return module.__name__ == object.__module__
100 return module.__name__ == object.__module__
105 elif inspect.isclass(object):
101 elif inspect.isclass(object):
106 return module.__name__ == object.__module__
102 return module.__name__ == object.__module__
107 elif inspect.ismethod(object):
103 elif inspect.ismethod(object):
108 # This one may be a bug in cython that fails to correctly set the
104 # This one may be a bug in cython that fails to correctly set the
109 # __module__ attribute of methods, but since the same error is easy
105 # __module__ attribute of methods, but since the same error is easy
110 # to make by extension code writers, having this safety in place
106 # to make by extension code writers, having this safety in place
111 # isn't such a bad idea
107 # isn't such a bad idea
112 return module.__name__ == object.__self__.__class__.__module__
108 return module.__name__ == object.__self__.__class__.__module__
113 elif inspect.getmodule(object) is not None:
109 elif inspect.getmodule(object) is not None:
114 return module is inspect.getmodule(object)
110 return module is inspect.getmodule(object)
115 elif hasattr(object, '__module__'):
111 elif hasattr(object, '__module__'):
116 return module.__name__ == object.__module__
112 return module.__name__ == object.__module__
117 elif isinstance(object, property):
113 elif isinstance(object, property):
118 return True # [XX] no way not be sure.
114 return True # [XX] no way not be sure.
119 elif inspect.ismethoddescriptor(object):
115 elif inspect.ismethoddescriptor(object):
120 # Unbound PyQt signals reach this point in Python 3.4b3, and we want
116 # Unbound PyQt signals reach this point in Python 3.4b3, and we want
121 # to avoid throwing an error. See also http://bugs.python.org/issue3158
117 # to avoid throwing an error. See also http://bugs.python.org/issue3158
122 return False
118 return False
123 else:
119 else:
124 raise ValueError("object must be a class or function, got %r" % object)
120 raise ValueError("object must be a class or function, got %r" % object)
125
121
126 def _find(self, tests, obj, name, module, source_lines, globs, seen):
122 def _find(self, tests, obj, name, module, source_lines, globs, seen):
127 """
123 """
128 Find tests for the given object and any contained objects, and
124 Find tests for the given object and any contained objects, and
129 add them to `tests`.
125 add them to `tests`.
130 """
126 """
131 print('_find for:', obj, name, module) # dbg
127 print('_find for:', obj, name, module) # dbg
132 if hasattr(obj,"skip_doctest"):
128 if hasattr(obj,"skip_doctest"):
133 #print 'SKIPPING DOCTEST FOR:',obj # dbg
129 #print 'SKIPPING DOCTEST FOR:',obj # dbg
134 obj = DocTestSkip(obj)
130 obj = DocTestSkip(obj)
135
131
136 doctest.DocTestFinder._find(self,tests, obj, name, module,
132 doctest.DocTestFinder._find(self,tests, obj, name, module,
137 source_lines, globs, seen)
133 source_lines, globs, seen)
138
134
139 # Below we re-run pieces of the above method with manual modifications,
135 # Below we re-run pieces of the above method with manual modifications,
140 # because the original code is buggy and fails to correctly identify
136 # because the original code is buggy and fails to correctly identify
141 # doctests in extension modules.
137 # doctests in extension modules.
142
138
143 # Local shorthands
139 # Local shorthands
144 from inspect import isroutine, isclass
140 from inspect import isroutine, isclass
145
141
146 # Look for tests in a module's contained objects.
142 # Look for tests in a module's contained objects.
147 if inspect.ismodule(obj) and self._recurse:
143 if inspect.ismodule(obj) and self._recurse:
148 for valname, val in obj.__dict__.items():
144 for valname, val in obj.__dict__.items():
149 valname1 = '%s.%s' % (name, valname)
145 valname1 = '%s.%s' % (name, valname)
150 if ( (isroutine(val) or isclass(val))
146 if ( (isroutine(val) or isclass(val))
151 and self._from_module(module, val) ):
147 and self._from_module(module, val) ):
152
148
153 self._find(tests, val, valname1, module, source_lines,
149 self._find(tests, val, valname1, module, source_lines,
154 globs, seen)
150 globs, seen)
155
151
156 # Look for tests in a class's contained objects.
152 # Look for tests in a class's contained objects.
157 if inspect.isclass(obj) and self._recurse:
153 if inspect.isclass(obj) and self._recurse:
158 #print 'RECURSE into class:',obj # dbg
154 #print 'RECURSE into class:',obj # dbg
159 for valname, val in obj.__dict__.items():
155 for valname, val in obj.__dict__.items():
160 # Special handling for staticmethod/classmethod.
156 # Special handling for staticmethod/classmethod.
161 if isinstance(val, staticmethod):
157 if isinstance(val, staticmethod):
162 val = getattr(obj, valname)
158 val = getattr(obj, valname)
163 if isinstance(val, classmethod):
159 if isinstance(val, classmethod):
164 val = getattr(obj, valname).__func__
160 val = getattr(obj, valname).__func__
165
161
166 # Recurse to methods, properties, and nested classes.
162 # Recurse to methods, properties, and nested classes.
167 if ((inspect.isfunction(val) or inspect.isclass(val) or
163 if ((inspect.isfunction(val) or inspect.isclass(val) or
168 inspect.ismethod(val) or
164 inspect.ismethod(val) or
169 isinstance(val, property)) and
165 isinstance(val, property)) and
170 self._from_module(module, val)):
166 self._from_module(module, val)):
171 valname = '%s.%s' % (name, valname)
167 valname = '%s.%s' % (name, valname)
172 self._find(tests, val, valname, module, source_lines,
168 self._find(tests, val, valname, module, source_lines,
173 globs, seen)
169 globs, seen)
174
170
175
171
176 class IPDoctestOutputChecker(doctest.OutputChecker):
172 class IPDoctestOutputChecker(doctest.OutputChecker):
177 """Second-chance checker with support for random tests.
173 """Second-chance checker with support for random tests.
178
174
179 If the default comparison doesn't pass, this checker looks in the expected
175 If the default comparison doesn't pass, this checker looks in the expected
180 output string for flags that tell us to ignore the output.
176 output string for flags that tell us to ignore the output.
181 """
177 """
182
178
183 random_re = re.compile(r'#\s*random\s+')
179 random_re = re.compile(r'#\s*random\s+')
184
180
185 def check_output(self, want, got, optionflags):
181 def check_output(self, want, got, optionflags):
186 """Check output, accepting special markers embedded in the output.
182 """Check output, accepting special markers embedded in the output.
187
183
188 If the output didn't pass the default validation but the special string
184 If the output didn't pass the default validation but the special string
189 '#random' is included, we accept it."""
185 '#random' is included, we accept it."""
190
186
191 # Let the original tester verify first, in case people have valid tests
187 # Let the original tester verify first, in case people have valid tests
192 # that happen to have a comment saying '#random' embedded in.
188 # that happen to have a comment saying '#random' embedded in.
193 ret = doctest.OutputChecker.check_output(self, want, got,
189 ret = doctest.OutputChecker.check_output(self, want, got,
194 optionflags)
190 optionflags)
195 if not ret and self.random_re.search(want):
191 if not ret and self.random_re.search(want):
196 #print >> sys.stderr, 'RANDOM OK:',want # dbg
192 #print >> sys.stderr, 'RANDOM OK:',want # dbg
197 return True
193 return True
198
194
199 return ret
195 return ret
200
196
201
197
202 class DocTestCase(doctests.DocTestCase):
198 class DocTestCase(doctests.DocTestCase):
203 """Proxy for DocTestCase: provides an address() method that
199 """Proxy for DocTestCase: provides an address() method that
204 returns the correct address for the doctest case. Otherwise
200 returns the correct address for the doctest case. Otherwise
205 acts as a proxy to the test case. To provide hints for address(),
201 acts as a proxy to the test case. To provide hints for address(),
206 an obj may also be passed -- this will be used as the test object
202 an obj may also be passed -- this will be used as the test object
207 for purposes of determining the test address, if it is provided.
203 for purposes of determining the test address, if it is provided.
208 """
204 """
209
205
210 # Note: this method was taken from numpy's nosetester module.
206 # Note: this method was taken from numpy's nosetester module.
211
207
212 # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
208 # Subclass nose.plugins.doctests.DocTestCase to work around a bug in
213 # its constructor that blocks non-default arguments from being passed
209 # its constructor that blocks non-default arguments from being passed
214 # down into doctest.DocTestCase
210 # down into doctest.DocTestCase
215
211
216 def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
212 def __init__(self, test, optionflags=0, setUp=None, tearDown=None,
217 checker=None, obj=None, result_var='_'):
213 checker=None, obj=None, result_var='_'):
218 self._result_var = result_var
214 self._result_var = result_var
219 doctests.DocTestCase.__init__(self, test,
215 doctests.DocTestCase.__init__(self, test,
220 optionflags=optionflags,
216 optionflags=optionflags,
221 setUp=setUp, tearDown=tearDown,
217 setUp=setUp, tearDown=tearDown,
222 checker=checker)
218 checker=checker)
223 # Now we must actually copy the original constructor from the stdlib
219 # Now we must actually copy the original constructor from the stdlib
224 # doctest class, because we can't call it directly and a bug in nose
220 # doctest class, because we can't call it directly and a bug in nose
225 # means it never gets passed the right arguments.
221 # means it never gets passed the right arguments.
226
222
227 self._dt_optionflags = optionflags
223 self._dt_optionflags = optionflags
228 self._dt_checker = checker
224 self._dt_checker = checker
229 self._dt_test = test
225 self._dt_test = test
230 self._dt_test_globs_ori = test.globs
226 self._dt_test_globs_ori = test.globs
231 self._dt_setUp = setUp
227 self._dt_setUp = setUp
232 self._dt_tearDown = tearDown
228 self._dt_tearDown = tearDown
233
229
234 # XXX - store this runner once in the object!
230 # XXX - store this runner once in the object!
235 runner = IPDocTestRunner(optionflags=optionflags,
231 runner = IPDocTestRunner(optionflags=optionflags,
236 checker=checker, verbose=False)
232 checker=checker, verbose=False)
237 self._dt_runner = runner
233 self._dt_runner = runner
238
234
239
235
240 # Each doctest should remember the directory it was loaded from, so
236 # Each doctest should remember the directory it was loaded from, so
241 # things like %run work without too many contortions
237 # things like %run work without too many contortions
242 self._ori_dir = os.path.dirname(test.filename)
238 self._ori_dir = os.path.dirname(test.filename)
243
239
244 # Modified runTest from the default stdlib
240 # Modified runTest from the default stdlib
245 def runTest(self):
241 def runTest(self):
246 test = self._dt_test
242 test = self._dt_test
247 runner = self._dt_runner
243 runner = self._dt_runner
248
244
249 old = sys.stdout
245 old = sys.stdout
250 new = StringIO()
246 new = StringIO()
251 optionflags = self._dt_optionflags
247 optionflags = self._dt_optionflags
252
248
253 if not (optionflags & REPORTING_FLAGS):
249 if not (optionflags & REPORTING_FLAGS):
254 # The option flags don't include any reporting flags,
250 # The option flags don't include any reporting flags,
255 # so add the default reporting flags
251 # so add the default reporting flags
256 optionflags |= _unittest_reportflags
252 optionflags |= _unittest_reportflags
257
253
258 try:
254 try:
259 # Save our current directory and switch out to the one where the
255 # Save our current directory and switch out to the one where the
260 # test was originally created, in case another doctest did a
256 # test was originally created, in case another doctest did a
261 # directory change. We'll restore this in the finally clause.
257 # directory change. We'll restore this in the finally clause.
262 curdir = os.getcwd()
258 curdir = os.getcwd()
263 #print 'runTest in dir:', self._ori_dir # dbg
259 #print 'runTest in dir:', self._ori_dir # dbg
264 os.chdir(self._ori_dir)
260 os.chdir(self._ori_dir)
265
261
266 runner.DIVIDER = "-"*70
262 runner.DIVIDER = "-"*70
267 failures, tries = runner.run(test,out=new.write,
263 failures, tries = runner.run(test,out=new.write,
268 clear_globs=False)
264 clear_globs=False)
269 finally:
265 finally:
270 sys.stdout = old
266 sys.stdout = old
271 os.chdir(curdir)
267 os.chdir(curdir)
272
268
273 if failures:
269 if failures:
274 raise self.failureException(self.format_failure(new.getvalue()))
270 raise self.failureException(self.format_failure(new.getvalue()))
275
271
276 def setUp(self):
272 def setUp(self):
277 """Modified test setup that syncs with ipython namespace"""
273 """Modified test setup that syncs with ipython namespace"""
278 #print "setUp test", self._dt_test.examples # dbg
274 #print "setUp test", self._dt_test.examples # dbg
279 if isinstance(self._dt_test.examples[0], IPExample):
275 if isinstance(self._dt_test.examples[0], IPExample):
280 # for IPython examples *only*, we swap the globals with the ipython
276 # for IPython examples *only*, we swap the globals with the ipython
281 # namespace, after updating it with the globals (which doctest
277 # namespace, after updating it with the globals (which doctest
282 # fills with the necessary info from the module being tested).
278 # fills with the necessary info from the module being tested).
283 self.user_ns_orig = {}
279 self.user_ns_orig = {}
284 self.user_ns_orig.update(_ip.user_ns)
280 self.user_ns_orig.update(_ip.user_ns)
285 _ip.user_ns.update(self._dt_test.globs)
281 _ip.user_ns.update(self._dt_test.globs)
286 # We must remove the _ key in the namespace, so that Python's
282 # We must remove the _ key in the namespace, so that Python's
287 # doctest code sets it naturally
283 # doctest code sets it naturally
288 _ip.user_ns.pop('_', None)
284 _ip.user_ns.pop('_', None)
289 _ip.user_ns['__builtins__'] = builtin_mod
285 _ip.user_ns['__builtins__'] = builtin_mod
290 self._dt_test.globs = _ip.user_ns
286 self._dt_test.globs = _ip.user_ns
291
287
292 super(DocTestCase, self).setUp()
288 super(DocTestCase, self).setUp()
293
289
294 def tearDown(self):
290 def tearDown(self):
295
291
296 # Undo the test.globs reassignment we made, so that the parent class
292 # Undo the test.globs reassignment we made, so that the parent class
297 # teardown doesn't destroy the ipython namespace
293 # teardown doesn't destroy the ipython namespace
298 if isinstance(self._dt_test.examples[0], IPExample):
294 if isinstance(self._dt_test.examples[0], IPExample):
299 self._dt_test.globs = self._dt_test_globs_ori
295 self._dt_test.globs = self._dt_test_globs_ori
300 _ip.user_ns.clear()
296 _ip.user_ns.clear()
301 _ip.user_ns.update(self.user_ns_orig)
297 _ip.user_ns.update(self.user_ns_orig)
302
298
303 # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
299 # XXX - fperez: I am not sure if this is truly a bug in nose 0.11, but
304 # it does look like one to me: its tearDown method tries to run
300 # it does look like one to me: its tearDown method tries to run
305 #
301 #
306 # delattr(builtin_mod, self._result_var)
302 # delattr(builtin_mod, self._result_var)
307 #
303 #
308 # without checking that the attribute really is there; it implicitly
304 # without checking that the attribute really is there; it implicitly
309 # assumes it should have been set via displayhook. But if the
305 # assumes it should have been set via displayhook. But if the
310 # displayhook was never called, this doesn't necessarily happen. I
306 # displayhook was never called, this doesn't necessarily happen. I
311 # haven't been able to find a little self-contained example outside of
307 # haven't been able to find a little self-contained example outside of
312 # ipython that would show the problem so I can report it to the nose
308 # ipython that would show the problem so I can report it to the nose
313 # team, but it does happen a lot in our code.
309 # team, but it does happen a lot in our code.
314 #
310 #
315 # So here, we just protect as narrowly as possible by trapping an
311 # So here, we just protect as narrowly as possible by trapping an
316 # attribute error whose message would be the name of self._result_var,
312 # attribute error whose message would be the name of self._result_var,
317 # and letting any other error propagate.
313 # and letting any other error propagate.
318 try:
314 try:
319 super(DocTestCase, self).tearDown()
315 super(DocTestCase, self).tearDown()
320 except AttributeError as exc:
316 except AttributeError as exc:
321 if exc.args[0] != self._result_var:
317 if exc.args[0] != self._result_var:
322 raise
318 raise
323
319
324
320
325 # A simple subclassing of the original with a different class name, so we can
321 # A simple subclassing of the original with a different class name, so we can
326 # distinguish and treat differently IPython examples from pure python ones.
322 # distinguish and treat differently IPython examples from pure python ones.
327 class IPExample(doctest.Example): pass
323 class IPExample(doctest.Example): pass
328
324
329
325
330 class IPExternalExample(doctest.Example):
326 class IPExternalExample(doctest.Example):
331 """Doctest examples to be run in an external process."""
327 """Doctest examples to be run in an external process."""
332
328
333 def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
329 def __init__(self, source, want, exc_msg=None, lineno=0, indent=0,
334 options=None):
330 options=None):
335 # Parent constructor
331 # Parent constructor
336 doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options)
332 doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options)
337
333
338 # An EXTRA newline is needed to prevent pexpect hangs
334 # An EXTRA newline is needed to prevent pexpect hangs
339 self.source += '\n'
335 self.source += '\n'
340
336
341
337
342 class IPDocTestParser(doctest.DocTestParser):
338 class IPDocTestParser(doctest.DocTestParser):
343 """
339 """
344 A class used to parse strings containing doctest examples.
340 A class used to parse strings containing doctest examples.
345
341
346 Note: This is a version modified to properly recognize IPython input and
342 Note: This is a version modified to properly recognize IPython input and
347 convert any IPython examples into valid Python ones.
343 convert any IPython examples into valid Python ones.
348 """
344 """
349 # This regular expression is used to find doctest examples in a
345 # This regular expression is used to find doctest examples in a
350 # string. It defines three groups: `source` is the source code
346 # string. It defines three groups: `source` is the source code
351 # (including leading indentation and prompts); `indent` is the
347 # (including leading indentation and prompts); `indent` is the
352 # indentation of the first (PS1) line of the source code; and
348 # indentation of the first (PS1) line of the source code; and
353 # `want` is the expected output (including leading indentation).
349 # `want` is the expected output (including leading indentation).
354
350
355 # Classic Python prompts or default IPython ones
351 # Classic Python prompts or default IPython ones
356 _PS1_PY = r'>>>'
352 _PS1_PY = r'>>>'
357 _PS2_PY = r'\.\.\.'
353 _PS2_PY = r'\.\.\.'
358
354
359 _PS1_IP = r'In\ \[\d+\]:'
355 _PS1_IP = r'In\ \[\d+\]:'
360 _PS2_IP = r'\ \ \ \.\.\.+:'
356 _PS2_IP = r'\ \ \ \.\.\.+:'
361
357
362 _RE_TPL = r'''
358 _RE_TPL = r'''
363 # Source consists of a PS1 line followed by zero or more PS2 lines.
359 # Source consists of a PS1 line followed by zero or more PS2 lines.
364 (?P<source>
360 (?P<source>
365 (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*) # PS1 line
361 (?:^(?P<indent> [ ]*) (?P<ps1> %s) .*) # PS1 line
366 (?:\n [ ]* (?P<ps2> %s) .*)*) # PS2 lines
362 (?:\n [ ]* (?P<ps2> %s) .*)*) # PS2 lines
367 \n? # a newline
363 \n? # a newline
368 # Want consists of any non-blank lines that do not start with PS1.
364 # Want consists of any non-blank lines that do not start with PS1.
369 (?P<want> (?:(?![ ]*$) # Not a blank line
365 (?P<want> (?:(?![ ]*$) # Not a blank line
370 (?![ ]*%s) # Not a line starting with PS1
366 (?![ ]*%s) # Not a line starting with PS1
371 (?![ ]*%s) # Not a line starting with PS2
367 (?![ ]*%s) # Not a line starting with PS2
372 .*$\n? # But any other line
368 .*$\n? # But any other line
373 )*)
369 )*)
374 '''
370 '''
375
371
376 _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
372 _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY),
377 re.MULTILINE | re.VERBOSE)
373 re.MULTILINE | re.VERBOSE)
378
374
379 _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
375 _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP),
380 re.MULTILINE | re.VERBOSE)
376 re.MULTILINE | re.VERBOSE)
381
377
382 # Mark a test as being fully random. In this case, we simply append the
378 # Mark a test as being fully random. In this case, we simply append the
383 # random marker ('#random') to each individual example's output. This way
379 # random marker ('#random') to each individual example's output. This way
384 # we don't need to modify any other code.
380 # we don't need to modify any other code.
385 _RANDOM_TEST = re.compile(r'#\s*all-random\s+')
381 _RANDOM_TEST = re.compile(r'#\s*all-random\s+')
386
382
387 # Mark tests to be executed in an external process - currently unsupported.
383 # Mark tests to be executed in an external process - currently unsupported.
388 _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')
384 _EXTERNAL_IP = re.compile(r'#\s*ipdoctest:\s*EXTERNAL')
389
385
390 def ip2py(self,source):
386 def ip2py(self,source):
391 """Convert input IPython source into valid Python."""
387 """Convert input IPython source into valid Python."""
392 block = _ip.input_transformer_manager.transform_cell(source)
388 block = _ip.input_transformer_manager.transform_cell(source)
393 if len(block.splitlines()) == 1:
389 if len(block.splitlines()) == 1:
394 return _ip.prefilter(block)
390 return _ip.prefilter(block)
395 else:
391 else:
396 return block
392 return block
397
393
398 def parse(self, string, name='<string>'):
394 def parse(self, string, name='<string>'):
399 """
395 """
400 Divide the given string into examples and intervening text,
396 Divide the given string into examples and intervening text,
401 and return them as a list of alternating Examples and strings.
397 and return them as a list of alternating Examples and strings.
402 Line numbers for the Examples are 0-based. The optional
398 Line numbers for the Examples are 0-based. The optional
403 argument `name` is a name identifying this string, and is only
399 argument `name` is a name identifying this string, and is only
404 used for error messages.
400 used for error messages.
405 """
401 """
406
402
407 #print 'Parse string:\n',string # dbg
403 #print 'Parse string:\n',string # dbg
408
404
409 string = string.expandtabs()
405 string = string.expandtabs()
410 # If all lines begin with the same indentation, then strip it.
406 # If all lines begin with the same indentation, then strip it.
411 min_indent = self._min_indent(string)
407 min_indent = self._min_indent(string)
412 if min_indent > 0:
408 if min_indent > 0:
413 string = '\n'.join([l[min_indent:] for l in string.split('\n')])
409 string = '\n'.join([l[min_indent:] for l in string.split('\n')])
414
410
415 output = []
411 output = []
416 charno, lineno = 0, 0
412 charno, lineno = 0, 0
417
413
418 # We make 'all random' tests by adding the '# random' mark to every
414 # We make 'all random' tests by adding the '# random' mark to every
419 # block of output in the test.
415 # block of output in the test.
420 if self._RANDOM_TEST.search(string):
416 if self._RANDOM_TEST.search(string):
421 random_marker = '\n# random'
417 random_marker = '\n# random'
422 else:
418 else:
423 random_marker = ''
419 random_marker = ''
424
420
425 # Whether to convert the input from ipython to python syntax
421 # Whether to convert the input from ipython to python syntax
426 ip2py = False
422 ip2py = False
427 # Find all doctest examples in the string. First, try them as Python
423 # Find all doctest examples in the string. First, try them as Python
428 # examples, then as IPython ones
424 # examples, then as IPython ones
429 terms = list(self._EXAMPLE_RE_PY.finditer(string))
425 terms = list(self._EXAMPLE_RE_PY.finditer(string))
430 if terms:
426 if terms:
431 # Normal Python example
427 # Normal Python example
432 #print '-'*70 # dbg
428 #print '-'*70 # dbg
433 #print 'PyExample, Source:\n',string # dbg
429 #print 'PyExample, Source:\n',string # dbg
434 #print '-'*70 # dbg
430 #print '-'*70 # dbg
435 Example = doctest.Example
431 Example = doctest.Example
436 else:
432 else:
437 # It's an ipython example. Note that IPExamples are run
433 # It's an ipython example. Note that IPExamples are run
438 # in-process, so their syntax must be turned into valid python.
434 # in-process, so their syntax must be turned into valid python.
439 # IPExternalExamples are run out-of-process (via pexpect) so they
435 # IPExternalExamples are run out-of-process (via pexpect) so they
440 # don't need any filtering (a real ipython will be executing them).
436 # don't need any filtering (a real ipython will be executing them).
441 terms = list(self._EXAMPLE_RE_IP.finditer(string))
437 terms = list(self._EXAMPLE_RE_IP.finditer(string))
442 if self._EXTERNAL_IP.search(string):
438 if self._EXTERNAL_IP.search(string):
443 #print '-'*70 # dbg
439 #print '-'*70 # dbg
444 #print 'IPExternalExample, Source:\n',string # dbg
440 #print 'IPExternalExample, Source:\n',string # dbg
445 #print '-'*70 # dbg
441 #print '-'*70 # dbg
446 Example = IPExternalExample
442 Example = IPExternalExample
447 else:
443 else:
448 #print '-'*70 # dbg
444 #print '-'*70 # dbg
449 #print 'IPExample, Source:\n',string # dbg
445 #print 'IPExample, Source:\n',string # dbg
450 #print '-'*70 # dbg
446 #print '-'*70 # dbg
451 Example = IPExample
447 Example = IPExample
452 ip2py = True
448 ip2py = True
453
449
454 for m in terms:
450 for m in terms:
455 # Add the pre-example text to `output`.
451 # Add the pre-example text to `output`.
456 output.append(string[charno:m.start()])
452 output.append(string[charno:m.start()])
457 # Update lineno (lines before this example)
453 # Update lineno (lines before this example)
458 lineno += string.count('\n', charno, m.start())
454 lineno += string.count('\n', charno, m.start())
459 # Extract info from the regexp match.
455 # Extract info from the regexp match.
460 (source, options, want, exc_msg) = \
456 (source, options, want, exc_msg) = \
461 self._parse_example(m, name, lineno,ip2py)
457 self._parse_example(m, name, lineno,ip2py)
462
458
463 # Append the random-output marker (it defaults to empty in most
459 # Append the random-output marker (it defaults to empty in most
464 # cases, it's only non-empty for 'all-random' tests):
460 # cases, it's only non-empty for 'all-random' tests):
465 want += random_marker
461 want += random_marker
466
462
467 if Example is IPExternalExample:
463 if Example is IPExternalExample:
468 options[doctest.NORMALIZE_WHITESPACE] = True
464 options[doctest.NORMALIZE_WHITESPACE] = True
469 want += '\n'
465 want += '\n'
470
466
471 # Create an Example, and add it to the list.
467 # Create an Example, and add it to the list.
472 if not self._IS_BLANK_OR_COMMENT(source):
468 if not self._IS_BLANK_OR_COMMENT(source):
473 output.append(Example(source, want, exc_msg,
469 output.append(Example(source, want, exc_msg,
474 lineno=lineno,
470 lineno=lineno,
475 indent=min_indent+len(m.group('indent')),
471 indent=min_indent+len(m.group('indent')),
476 options=options))
472 options=options))
477 # Update lineno (lines inside this example)
473 # Update lineno (lines inside this example)
478 lineno += string.count('\n', m.start(), m.end())
474 lineno += string.count('\n', m.start(), m.end())
479 # Update charno.
475 # Update charno.
480 charno = m.end()
476 charno = m.end()
481 # Add any remaining post-example text to `output`.
477 # Add any remaining post-example text to `output`.
482 output.append(string[charno:])
478 output.append(string[charno:])
483 return output
479 return output
484
480
485 def _parse_example(self, m, name, lineno,ip2py=False):
481 def _parse_example(self, m, name, lineno,ip2py=False):
486 """
482 """
487 Given a regular expression match from `_EXAMPLE_RE` (`m`),
483 Given a regular expression match from `_EXAMPLE_RE` (`m`),
488 return a pair `(source, want)`, where `source` is the matched
484 return a pair `(source, want)`, where `source` is the matched
489 example's source code (with prompts and indentation stripped);
485 example's source code (with prompts and indentation stripped);
490 and `want` is the example's expected output (with indentation
486 and `want` is the example's expected output (with indentation
491 stripped).
487 stripped).
492
488
493 `name` is the string's name, and `lineno` is the line number
489 `name` is the string's name, and `lineno` is the line number
494 where the example starts; both are used for error messages.
490 where the example starts; both are used for error messages.
495
491
496 Optional:
492 Optional:
497 `ip2py`: if true, filter the input via IPython to convert the syntax
493 `ip2py`: if true, filter the input via IPython to convert the syntax
498 into valid python.
494 into valid python.
499 """
495 """
500
496
501 # Get the example's indentation level.
497 # Get the example's indentation level.
502 indent = len(m.group('indent'))
498 indent = len(m.group('indent'))
503
499
504 # Divide source into lines; check that they're properly
500 # Divide source into lines; check that they're properly
505 # indented; and then strip their indentation & prompts.
501 # indented; and then strip their indentation & prompts.
506 source_lines = m.group('source').split('\n')
502 source_lines = m.group('source').split('\n')
507
503
508 # We're using variable-length input prompts
504 # We're using variable-length input prompts
509 ps1 = m.group('ps1')
505 ps1 = m.group('ps1')
510 ps2 = m.group('ps2')
506 ps2 = m.group('ps2')
511 ps1_len = len(ps1)
507 ps1_len = len(ps1)
512
508
513 self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
509 self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len)
514 if ps2:
510 if ps2:
515 self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)
511 self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno)
516
512
517 source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])
513 source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines])
518
514
519 if ip2py:
515 if ip2py:
520 # Convert source input from IPython into valid Python syntax
516 # Convert source input from IPython into valid Python syntax
521 source = self.ip2py(source)
517 source = self.ip2py(source)
522
518
523 # Divide want into lines; check that it's properly indented; and
519 # Divide want into lines; check that it's properly indented; and
524 # then strip the indentation. Spaces before the last newline should
520 # then strip the indentation. Spaces before the last newline should
525 # be preserved, so plain rstrip() isn't good enough.
521 # be preserved, so plain rstrip() isn't good enough.
526 want = m.group('want')
522 want = m.group('want')
527 want_lines = want.split('\n')
523 want_lines = want.split('\n')
528 if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
524 if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]):
529 del want_lines[-1] # forget final newline & spaces after it
525 del want_lines[-1] # forget final newline & spaces after it
530 self._check_prefix(want_lines, ' '*indent, name,
526 self._check_prefix(want_lines, ' '*indent, name,
531 lineno + len(source_lines))
527 lineno + len(source_lines))
532
528
533 # Remove ipython output prompt that might be present in the first line
529 # Remove ipython output prompt that might be present in the first line
534 want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])
530 want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0])
535
531
536 want = '\n'.join([wl[indent:] for wl in want_lines])
532 want = '\n'.join([wl[indent:] for wl in want_lines])
537
533
538 # If `want` contains a traceback message, then extract it.
534 # If `want` contains a traceback message, then extract it.
539 m = self._EXCEPTION_RE.match(want)
535 m = self._EXCEPTION_RE.match(want)
540 if m:
536 if m:
541 exc_msg = m.group('msg')
537 exc_msg = m.group('msg')
542 else:
538 else:
543 exc_msg = None
539 exc_msg = None
544
540
545 # Extract options from the source.
541 # Extract options from the source.
546 options = self._find_options(source, name, lineno)
542 options = self._find_options(source, name, lineno)
547
543
548 return source, options, want, exc_msg
544 return source, options, want, exc_msg
549
545
550 def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
546 def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len):
551 """
547 """
552 Given the lines of a source string (including prompts and
548 Given the lines of a source string (including prompts and
553 leading indentation), check to make sure that every prompt is
549 leading indentation), check to make sure that every prompt is
554 followed by a space character. If any line is not followed by
550 followed by a space character. If any line is not followed by
555 a space character, then raise ValueError.
551 a space character, then raise ValueError.
556
552
557 Note: IPython-modified version which takes the input prompt length as a
553 Note: IPython-modified version which takes the input prompt length as a
558 parameter, so that prompts of variable length can be dealt with.
554 parameter, so that prompts of variable length can be dealt with.
559 """
555 """
560 space_idx = indent+ps1_len
556 space_idx = indent+ps1_len
561 min_len = space_idx+1
557 min_len = space_idx+1
562 for i, line in enumerate(lines):
558 for i, line in enumerate(lines):
563 if len(line) >= min_len and line[space_idx] != ' ':
559 if len(line) >= min_len and line[space_idx] != ' ':
564 raise ValueError('line %r of the docstring for %s '
560 raise ValueError('line %r of the docstring for %s '
565 'lacks blank after %s: %r' %
561 'lacks blank after %s: %r' %
566 (lineno+i+1, name,
562 (lineno+i+1, name,
567 line[indent:space_idx], line))
563 line[indent:space_idx], line))
568
564
569
565
570 SKIP = doctest.register_optionflag('SKIP')
566 SKIP = doctest.register_optionflag('SKIP')
571
567
572
568
573 class IPDocTestRunner(doctest.DocTestRunner,object):
569 class IPDocTestRunner(doctest.DocTestRunner,object):
574 """Test runner that synchronizes the IPython namespace with test globals.
570 """Test runner that synchronizes the IPython namespace with test globals.
575 """
571 """
576
572
577 def run(self, test, compileflags=None, out=None, clear_globs=True):
573 def run(self, test, compileflags=None, out=None, clear_globs=True):
578
574
579 # Hack: ipython needs access to the execution context of the example,
575 # Hack: ipython needs access to the execution context of the example,
580 # so that it can propagate user variables loaded by %run into
576 # so that it can propagate user variables loaded by %run into
581 # test.globs. We put them here into our modified %run as a function
577 # test.globs. We put them here into our modified %run as a function
582 # attribute. Our new %run will then only make the namespace update
578 # attribute. Our new %run will then only make the namespace update
583 # when called (rather than unconconditionally updating test.globs here
579 # when called (rather than unconconditionally updating test.globs here
584 # for all examples, most of which won't be calling %run anyway).
580 # for all examples, most of which won't be calling %run anyway).
585 #_ip._ipdoctest_test_globs = test.globs
581 #_ip._ipdoctest_test_globs = test.globs
586 #_ip._ipdoctest_test_filename = test.filename
582 #_ip._ipdoctest_test_filename = test.filename
587
583
588 test.globs.update(_ip.user_ns)
584 test.globs.update(_ip.user_ns)
589
585
590 # Override terminal size to standardise traceback format
586 # Override terminal size to standardise traceback format
591 with modified_env({'COLUMNS': '80', 'LINES': '24'}):
587 with modified_env({'COLUMNS': '80', 'LINES': '24'}):
592 return super(IPDocTestRunner,self).run(test,
588 return super(IPDocTestRunner,self).run(test,
593 compileflags,out,clear_globs)
589 compileflags,out,clear_globs)
594
590
595
591
596 class DocFileCase(doctest.DocFileCase):
592 class DocFileCase(doctest.DocFileCase):
597 """Overrides to provide filename
593 """Overrides to provide filename
598 """
594 """
599 def address(self):
595 def address(self):
600 return (self._dt_test.filename, None, None)
596 return (self._dt_test.filename, None, None)
601
597
602
598
603 class ExtensionDoctest(doctests.Doctest):
599 class ExtensionDoctest(doctests.Doctest):
604 """Nose Plugin that supports doctests in extension modules.
600 """Nose Plugin that supports doctests in extension modules.
605 """
601 """
606 name = 'extdoctest' # call nosetests with --with-extdoctest
602 name = 'extdoctest' # call nosetests with --with-extdoctest
607 enabled = True
603 enabled = True
608
604
609 def options(self, parser, env=os.environ):
605 def options(self, parser, env=os.environ):
610 Plugin.options(self, parser, env)
606 Plugin.options(self, parser, env)
611 parser.add_option('--doctest-tests', action='store_true',
607 parser.add_option('--doctest-tests', action='store_true',
612 dest='doctest_tests',
608 dest='doctest_tests',
613 default=env.get('NOSE_DOCTEST_TESTS',True),
609 default=env.get('NOSE_DOCTEST_TESTS',True),
614 help="Also look for doctests in test modules. "
610 help="Also look for doctests in test modules. "
615 "Note that classes, methods and functions should "
611 "Note that classes, methods and functions should "
616 "have either doctests or non-doctest tests, "
612 "have either doctests or non-doctest tests, "
617 "not both. [NOSE_DOCTEST_TESTS]")
613 "not both. [NOSE_DOCTEST_TESTS]")
618 parser.add_option('--doctest-extension', action="append",
614 parser.add_option('--doctest-extension', action="append",
619 dest="doctestExtension",
615 dest="doctestExtension",
620 help="Also look for doctests in files with "
616 help="Also look for doctests in files with "
621 "this extension [NOSE_DOCTEST_EXTENSION]")
617 "this extension [NOSE_DOCTEST_EXTENSION]")
622 # Set the default as a list, if given in env; otherwise
618 # Set the default as a list, if given in env; otherwise
623 # an additional value set on the command line will cause
619 # an additional value set on the command line will cause
624 # an error.
620 # an error.
625 env_setting = env.get('NOSE_DOCTEST_EXTENSION')
621 env_setting = env.get('NOSE_DOCTEST_EXTENSION')
626 if env_setting is not None:
622 if env_setting is not None:
627 parser.set_defaults(doctestExtension=tolist(env_setting))
623 parser.set_defaults(doctestExtension=tolist(env_setting))
628
624
629
625
630 def configure(self, options, config):
626 def configure(self, options, config):
631 Plugin.configure(self, options, config)
627 Plugin.configure(self, options, config)
632 # Pull standard doctest plugin out of config; we will do doctesting
628 # Pull standard doctest plugin out of config; we will do doctesting
633 config.plugins.plugins = [p for p in config.plugins.plugins
629 config.plugins.plugins = [p for p in config.plugins.plugins
634 if p.name != 'doctest']
630 if p.name != 'doctest']
635 self.doctest_tests = options.doctest_tests
631 self.doctest_tests = options.doctest_tests
636 self.extension = tolist(options.doctestExtension)
632 self.extension = tolist(options.doctestExtension)
637
633
638 self.parser = doctest.DocTestParser()
634 self.parser = doctest.DocTestParser()
639 self.finder = DocTestFinder()
635 self.finder = DocTestFinder()
640 self.checker = IPDoctestOutputChecker()
636 self.checker = IPDoctestOutputChecker()
641 self.globs = None
637 self.globs = None
642 self.extraglobs = None
638 self.extraglobs = None
643
639
644
640
645 def loadTestsFromExtensionModule(self,filename):
641 def loadTestsFromExtensionModule(self,filename):
646 bpath,mod = os.path.split(filename)
642 bpath,mod = os.path.split(filename)
647 modname = os.path.splitext(mod)[0]
643 modname = os.path.splitext(mod)[0]
648 try:
644 try:
649 sys.path.append(bpath)
645 sys.path.append(bpath)
650 module = import_module(modname)
646 module = import_module(modname)
651 tests = list(self.loadTestsFromModule(module))
647 tests = list(self.loadTestsFromModule(module))
652 finally:
648 finally:
653 sys.path.pop()
649 sys.path.pop()
654 return tests
650 return tests
655
651
656 # NOTE: the method below is almost a copy of the original one in nose, with
652 # NOTE: the method below is almost a copy of the original one in nose, with
657 # a few modifications to control output checking.
653 # a few modifications to control output checking.
658
654
659 def loadTestsFromModule(self, module):
655 def loadTestsFromModule(self, module):
660 #print '*** ipdoctest - lTM',module # dbg
656 #print '*** ipdoctest - lTM',module # dbg
661
657
662 if not self.matches(module.__name__):
658 if not self.matches(module.__name__):
663 log.debug("Doctest doesn't want module %s", module)
659 log.debug("Doctest doesn't want module %s", module)
664 return
660 return
665
661
666 tests = self.finder.find(module,globs=self.globs,
662 tests = self.finder.find(module,globs=self.globs,
667 extraglobs=self.extraglobs)
663 extraglobs=self.extraglobs)
668 if not tests:
664 if not tests:
669 return
665 return
670
666
671 # always use whitespace and ellipsis options
667 # always use whitespace and ellipsis options
672 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
668 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
673
669
674 tests.sort()
670 tests.sort()
675 module_file = module.__file__
671 module_file = module.__file__
676 if module_file[-4:] in ('.pyc', '.pyo'):
672 if module_file[-4:] in ('.pyc', '.pyo'):
677 module_file = module_file[:-1]
673 module_file = module_file[:-1]
678 for test in tests:
674 for test in tests:
679 if not test.examples:
675 if not test.examples:
680 continue
676 continue
681 if not test.filename:
677 if not test.filename:
682 test.filename = module_file
678 test.filename = module_file
683
679
684 yield DocTestCase(test,
680 yield DocTestCase(test,
685 optionflags=optionflags,
681 optionflags=optionflags,
686 checker=self.checker)
682 checker=self.checker)
687
683
688
684
689 def loadTestsFromFile(self, filename):
685 def loadTestsFromFile(self, filename):
690 #print "ipdoctest - from file", filename # dbg
686 #print "ipdoctest - from file", filename # dbg
691 if is_extension_module(filename):
687 if is_extension_module(filename):
692 for t in self.loadTestsFromExtensionModule(filename):
688 for t in self.loadTestsFromExtensionModule(filename):
693 yield t
689 yield t
694 else:
690 else:
695 if self.extension and anyp(filename.endswith, self.extension):
691 if self.extension and anyp(filename.endswith, self.extension):
696 name = os.path.basename(filename)
692 name = os.path.basename(filename)
697 dh = open(filename)
693 dh = open(filename)
698 try:
694 try:
699 doc = dh.read()
695 doc = dh.read()
700 finally:
696 finally:
701 dh.close()
697 dh.close()
702 test = self.parser.get_doctest(
698 test = self.parser.get_doctest(
703 doc, globs={'__file__': filename}, name=name,
699 doc, globs={'__file__': filename}, name=name,
704 filename=filename, lineno=0)
700 filename=filename, lineno=0)
705 if test.examples:
701 if test.examples:
706 #print 'FileCase:',test.examples # dbg
702 #print 'FileCase:',test.examples # dbg
707 yield DocFileCase(test)
703 yield DocFileCase(test)
708 else:
704 else:
709 yield False # no tests to load
705 yield False # no tests to load
710
706
711
707
712 class IPythonDoctest(ExtensionDoctest):
708 class IPythonDoctest(ExtensionDoctest):
713 """Nose Plugin that supports doctests in extension modules.
709 """Nose Plugin that supports doctests in extension modules.
714 """
710 """
715 name = 'ipdoctest' # call nosetests with --with-ipdoctest
711 name = 'ipdoctest' # call nosetests with --with-ipdoctest
716 enabled = True
712 enabled = True
717
713
718 def makeTest(self, obj, parent):
714 def makeTest(self, obj, parent):
719 """Look for doctests in the given object, which will be a
715 """Look for doctests in the given object, which will be a
720 function, method or class.
716 function, method or class.
721 """
717 """
722 #print 'Plugin analyzing:', obj, parent # dbg
718 #print 'Plugin analyzing:', obj, parent # dbg
723 # always use whitespace and ellipsis options
719 # always use whitespace and ellipsis options
724 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
720 optionflags = doctest.NORMALIZE_WHITESPACE | doctest.ELLIPSIS
725
721
726 doctests = self.finder.find(obj, module=getmodule(parent))
722 doctests = self.finder.find(obj, module=getmodule(parent))
727 if doctests:
723 if doctests:
728 for test in doctests:
724 for test in doctests:
729 if len(test.examples) == 0:
725 if len(test.examples) == 0:
730 continue
726 continue
731
727
732 yield DocTestCase(test, obj=obj,
728 yield DocTestCase(test, obj=obj,
733 optionflags=optionflags,
729 optionflags=optionflags,
734 checker=self.checker)
730 checker=self.checker)
735
731
736 def options(self, parser, env=os.environ):
732 def options(self, parser, env=os.environ):
737 #print "Options for nose plugin:", self.name # dbg
733 #print "Options for nose plugin:", self.name # dbg
738 Plugin.options(self, parser, env)
734 Plugin.options(self, parser, env)
739 parser.add_option('--ipdoctest-tests', action='store_true',
735 parser.add_option('--ipdoctest-tests', action='store_true',
740 dest='ipdoctest_tests',
736 dest='ipdoctest_tests',
741 default=env.get('NOSE_IPDOCTEST_TESTS',True),
737 default=env.get('NOSE_IPDOCTEST_TESTS',True),
742 help="Also look for doctests in test modules. "
738 help="Also look for doctests in test modules. "
743 "Note that classes, methods and functions should "
739 "Note that classes, methods and functions should "
744 "have either doctests or non-doctest tests, "
740 "have either doctests or non-doctest tests, "
745 "not both. [NOSE_IPDOCTEST_TESTS]")
741 "not both. [NOSE_IPDOCTEST_TESTS]")
746 parser.add_option('--ipdoctest-extension', action="append",
742 parser.add_option('--ipdoctest-extension', action="append",
747 dest="ipdoctest_extension",
743 dest="ipdoctest_extension",
748 help="Also look for doctests in files with "
744 help="Also look for doctests in files with "
749 "this extension [NOSE_IPDOCTEST_EXTENSION]")
745 "this extension [NOSE_IPDOCTEST_EXTENSION]")
750 # Set the default as a list, if given in env; otherwise
746 # Set the default as a list, if given in env; otherwise
751 # an additional value set on the command line will cause
747 # an additional value set on the command line will cause
752 # an error.
748 # an error.
753 env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
749 env_setting = env.get('NOSE_IPDOCTEST_EXTENSION')
754 if env_setting is not None:
750 if env_setting is not None:
755 parser.set_defaults(ipdoctest_extension=tolist(env_setting))
751 parser.set_defaults(ipdoctest_extension=tolist(env_setting))
756
752
757 def configure(self, options, config):
753 def configure(self, options, config):
758 #print "Configuring nose plugin:", self.name # dbg
754 #print "Configuring nose plugin:", self.name # dbg
759 Plugin.configure(self, options, config)
755 Plugin.configure(self, options, config)
760 # Pull standard doctest plugin out of config; we will do doctesting
756 # Pull standard doctest plugin out of config; we will do doctesting
761 config.plugins.plugins = [p for p in config.plugins.plugins
757 config.plugins.plugins = [p for p in config.plugins.plugins
762 if p.name != 'doctest']
758 if p.name != 'doctest']
763 self.doctest_tests = options.ipdoctest_tests
759 self.doctest_tests = options.ipdoctest_tests
764 self.extension = tolist(options.ipdoctest_extension)
760 self.extension = tolist(options.ipdoctest_extension)
765
761
766 self.parser = IPDocTestParser()
762 self.parser = IPDocTestParser()
767 self.finder = DocTestFinder(parser=self.parser)
763 self.finder = DocTestFinder(parser=self.parser)
768 self.checker = IPDoctestOutputChecker()
764 self.checker = IPDoctestOutputChecker()
769 self.globs = None
765 self.globs = None
770 self.extraglobs = None
766 self.extraglobs = None
@@ -1,476 +1,468 b''
1 """Generic testing tools.
1 """Generic testing tools.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 """
6 """
7
7
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import os
12 import os
13 import re
13 import re
14 import sys
14 import sys
15 import tempfile
15 import tempfile
16
16
17 from contextlib import contextmanager
17 from contextlib import contextmanager
18 from io import StringIO
18 from io import StringIO
19 from subprocess import Popen, PIPE
19 from subprocess import Popen, PIPE
20 from unittest.mock import patch
20 from unittest.mock import patch
21
21
22 try:
22 try:
23 # These tools are used by parts of the runtime, so we make the nose
23 # These tools are used by parts of the runtime, so we make the nose
24 # dependency optional at this point. Nose is a hard dependency to run the
24 # dependency optional at this point. Nose is a hard dependency to run the
25 # test suite, but NOT to use ipython itself.
25 # test suite, but NOT to use ipython itself.
26 import nose.tools as nt
26 import nose.tools as nt
27 has_nose = True
27 has_nose = True
28 except ImportError:
28 except ImportError:
29 has_nose = False
29 has_nose = False
30
30
31 from traitlets.config.loader import Config
31 from traitlets.config.loader import Config
32 from IPython.utils.process import get_output_error_code
32 from IPython.utils.process import get_output_error_code
33 from IPython.utils.text import list_strings
33 from IPython.utils.text import list_strings
34 from IPython.utils.io import temp_pyfile, Tee
34 from IPython.utils.io import temp_pyfile, Tee
35 from IPython.utils import py3compat
35 from IPython.utils import py3compat
36 from IPython.utils.encoding import DEFAULT_ENCODING
36 from IPython.utils.encoding import DEFAULT_ENCODING
37
37
38 from . import decorators as dec
38 from . import decorators as dec
39 from . import skipdoctest
39 from . import skipdoctest
40
40
41
41
42 # The docstring for full_path doctests differently on win32 (different path
42 # The docstring for full_path doctests differently on win32 (different path
43 # separator) so just skip the doctest there. The example remains informative.
43 # separator) so just skip the doctest there. The example remains informative.
44 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
44 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45
45
46 @doctest_deco
46 @doctest_deco
47 def full_path(startPath,files):
47 def full_path(startPath,files):
48 """Make full paths for all the listed files, based on startPath.
48 """Make full paths for all the listed files, based on startPath.
49
49
50 Only the base part of startPath is kept, since this routine is typically
50 Only the base part of startPath is kept, since this routine is typically
51 used with a script's ``__file__`` variable as startPath. The base of startPath
51 used with a script's ``__file__`` variable as startPath. The base of startPath
52 is then prepended to all the listed files, forming the output list.
52 is then prepended to all the listed files, forming the output list.
53
53
54 Parameters
54 Parameters
55 ----------
55 ----------
56 startPath : string
56 startPath : string
57 Initial path to use as the base for the results. This path is split
57 Initial path to use as the base for the results. This path is split
58 using os.path.split() and only its first component is kept.
58 using os.path.split() and only its first component is kept.
59
59
60 files : string or list
60 files : string or list
61 One or more files.
61 One or more files.
62
62
63 Examples
63 Examples
64 --------
64 --------
65
65
66 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
66 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 ['/foo/a.txt', '/foo/b.txt']
67 ['/foo/a.txt', '/foo/b.txt']
68
68
69 >>> full_path('/foo',['a.txt','b.txt'])
69 >>> full_path('/foo',['a.txt','b.txt'])
70 ['/a.txt', '/b.txt']
70 ['/a.txt', '/b.txt']
71
71
72 If a single file is given, the output is still a list::
72 If a single file is given, the output is still a list::
73
73
74 >>> full_path('/foo','a.txt')
74 >>> full_path('/foo','a.txt')
75 ['/a.txt']
75 ['/a.txt']
76 """
76 """
77
77
78 files = list_strings(files)
78 files = list_strings(files)
79 base = os.path.split(startPath)[0]
79 base = os.path.split(startPath)[0]
80 return [ os.path.join(base,f) for f in files ]
80 return [ os.path.join(base,f) for f in files ]
81
81
82
82
83 def parse_test_output(txt):
83 def parse_test_output(txt):
84 """Parse the output of a test run and return errors, failures.
84 """Parse the output of a test run and return errors, failures.
85
85
86 Parameters
86 Parameters
87 ----------
87 ----------
88 txt : str
88 txt : str
89 Text output of a test run, assumed to contain a line of one of the
89 Text output of a test run, assumed to contain a line of one of the
90 following forms::
90 following forms::
91
91
92 'FAILED (errors=1)'
92 'FAILED (errors=1)'
93 'FAILED (failures=1)'
93 'FAILED (failures=1)'
94 'FAILED (errors=1, failures=1)'
94 'FAILED (errors=1, failures=1)'
95
95
96 Returns
96 Returns
97 -------
97 -------
98 nerr, nfail
98 nerr, nfail
99 number of errors and failures.
99 number of errors and failures.
100 """
100 """
101
101
102 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
102 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 if err_m:
103 if err_m:
104 nerr = int(err_m.group(1))
104 nerr = int(err_m.group(1))
105 nfail = 0
105 nfail = 0
106 return nerr, nfail
106 return nerr, nfail
107
107
108 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
108 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 if fail_m:
109 if fail_m:
110 nerr = 0
110 nerr = 0
111 nfail = int(fail_m.group(1))
111 nfail = int(fail_m.group(1))
112 return nerr, nfail
112 return nerr, nfail
113
113
114 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
114 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 re.MULTILINE)
115 re.MULTILINE)
116 if both_m:
116 if both_m:
117 nerr = int(both_m.group(1))
117 nerr = int(both_m.group(1))
118 nfail = int(both_m.group(2))
118 nfail = int(both_m.group(2))
119 return nerr, nfail
119 return nerr, nfail
120
120
121 # If the input didn't match any of these forms, assume no error/failures
121 # If the input didn't match any of these forms, assume no error/failures
122 return 0, 0
122 return 0, 0
123
123
124
124
125 # So nose doesn't think this is a test
125 # So nose doesn't think this is a test
126 parse_test_output.__test__ = False
126 parse_test_output.__test__ = False
127
127
128
128
129 def default_argv():
129 def default_argv():
130 """Return a valid default argv for creating testing instances of ipython"""
130 """Return a valid default argv for creating testing instances of ipython"""
131
131
132 return ['--quick', # so no config file is loaded
132 return ['--quick', # so no config file is loaded
133 # Other defaults to minimize side effects on stdout
133 # Other defaults to minimize side effects on stdout
134 '--colors=NoColor', '--no-term-title','--no-banner',
134 '--colors=NoColor', '--no-term-title','--no-banner',
135 '--autocall=0']
135 '--autocall=0']
136
136
137
137
138 def default_config():
138 def default_config():
139 """Return a config object with good defaults for testing."""
139 """Return a config object with good defaults for testing."""
140 config = Config()
140 config = Config()
141 config.TerminalInteractiveShell.colors = 'NoColor'
141 config.TerminalInteractiveShell.colors = 'NoColor'
142 config.TerminalTerminalInteractiveShell.term_title = False,
142 config.TerminalTerminalInteractiveShell.term_title = False,
143 config.TerminalInteractiveShell.autocall = 0
143 config.TerminalInteractiveShell.autocall = 0
144 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
144 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 config.HistoryManager.hist_file = f.name
145 config.HistoryManager.hist_file = f.name
146 f.close()
146 f.close()
147 config.HistoryManager.db_cache_size = 10000
147 config.HistoryManager.db_cache_size = 10000
148 return config
148 return config
149
149
150
150
151 def get_ipython_cmd(as_string=False):
151 def get_ipython_cmd(as_string=False):
152 """
152 """
153 Return appropriate IPython command line name. By default, this will return
153 Return appropriate IPython command line name. By default, this will return
154 a list that can be used with subprocess.Popen, for example, but passing
154 a list that can be used with subprocess.Popen, for example, but passing
155 `as_string=True` allows for returning the IPython command as a string.
155 `as_string=True` allows for returning the IPython command as a string.
156
156
157 Parameters
157 Parameters
158 ----------
158 ----------
159 as_string: bool
159 as_string: bool
160 Flag to allow to return the command as a string.
160 Flag to allow to return the command as a string.
161 """
161 """
162 ipython_cmd = [sys.executable, "-m", "IPython"]
162 ipython_cmd = [sys.executable, "-m", "IPython"]
163
163
164 if as_string:
164 if as_string:
165 ipython_cmd = " ".join(ipython_cmd)
165 ipython_cmd = " ".join(ipython_cmd)
166
166
167 return ipython_cmd
167 return ipython_cmd
168
168
169 def ipexec(fname, options=None, commands=()):
169 def ipexec(fname, options=None, commands=()):
170 """Utility to call 'ipython filename'.
170 """Utility to call 'ipython filename'.
171
171
172 Starts IPython with a minimal and safe configuration to make startup as fast
172 Starts IPython with a minimal and safe configuration to make startup as fast
173 as possible.
173 as possible.
174
174
175 Note that this starts IPython in a subprocess!
175 Note that this starts IPython in a subprocess!
176
176
177 Parameters
177 Parameters
178 ----------
178 ----------
179 fname : str
179 fname : str
180 Name of file to be executed (should have .py or .ipy extension).
180 Name of file to be executed (should have .py or .ipy extension).
181
181
182 options : optional, list
182 options : optional, list
183 Extra command-line flags to be passed to IPython.
183 Extra command-line flags to be passed to IPython.
184
184
185 commands : optional, list
185 commands : optional, list
186 Commands to send in on stdin
186 Commands to send in on stdin
187
187
188 Returns
188 Returns
189 -------
189 -------
190 (stdout, stderr) of ipython subprocess.
190 (stdout, stderr) of ipython subprocess.
191 """
191 """
192 if options is None: options = []
192 if options is None: options = []
193
193
194 cmdargs = default_argv() + options
194 cmdargs = default_argv() + options
195
195
196 test_dir = os.path.dirname(__file__)
196 test_dir = os.path.dirname(__file__)
197
197
198 ipython_cmd = get_ipython_cmd()
198 ipython_cmd = get_ipython_cmd()
199 # Absolute path for filename
199 # Absolute path for filename
200 full_fname = os.path.join(test_dir, fname)
200 full_fname = os.path.join(test_dir, fname)
201 full_cmd = ipython_cmd + cmdargs + [full_fname]
201 full_cmd = ipython_cmd + cmdargs + [full_fname]
202 env = os.environ.copy()
202 env = os.environ.copy()
203 # FIXME: ignore all warnings in ipexec while we have shims
203 # FIXME: ignore all warnings in ipexec while we have shims
204 # should we keep suppressing warnings here, even after removing shims?
204 # should we keep suppressing warnings here, even after removing shims?
205 env['PYTHONWARNINGS'] = 'ignore'
205 env['PYTHONWARNINGS'] = 'ignore'
206 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
206 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 for k, v in env.items():
207 for k, v in env.items():
208 # Debug a bizarre failure we've seen on Windows:
208 # Debug a bizarre failure we've seen on Windows:
209 # TypeError: environment can only contain strings
209 # TypeError: environment can only contain strings
210 if not isinstance(v, str):
210 if not isinstance(v, str):
211 print(k, v)
211 print(k, v)
212 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
212 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 out, err = p.communicate(input=py3compat.str_to_bytes('\n'.join(commands)) or None)
213 out, err = p.communicate(input=py3compat.str_to_bytes('\n'.join(commands)) or None)
214 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
214 out, err = py3compat.bytes_to_str(out), py3compat.bytes_to_str(err)
215 # `import readline` causes 'ESC[?1034h' to be output sometimes,
215 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 # so strip that out before doing comparisons
216 # so strip that out before doing comparisons
217 if out:
217 if out:
218 out = re.sub(r'\x1b\[[^h]+h', '', out)
218 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 return out, err
219 return out, err
220
220
221
221
222 def ipexec_validate(fname, expected_out, expected_err='',
222 def ipexec_validate(fname, expected_out, expected_err='',
223 options=None, commands=()):
223 options=None, commands=()):
224 """Utility to call 'ipython filename' and validate output/error.
224 """Utility to call 'ipython filename' and validate output/error.
225
225
226 This function raises an AssertionError if the validation fails.
226 This function raises an AssertionError if the validation fails.
227
227
228 Note that this starts IPython in a subprocess!
228 Note that this starts IPython in a subprocess!
229
229
230 Parameters
230 Parameters
231 ----------
231 ----------
232 fname : str
232 fname : str
233 Name of the file to be executed (should have .py or .ipy extension).
233 Name of the file to be executed (should have .py or .ipy extension).
234
234
235 expected_out : str
235 expected_out : str
236 Expected stdout of the process.
236 Expected stdout of the process.
237
237
238 expected_err : optional, str
238 expected_err : optional, str
239 Expected stderr of the process.
239 Expected stderr of the process.
240
240
241 options : optional, list
241 options : optional, list
242 Extra command-line flags to be passed to IPython.
242 Extra command-line flags to be passed to IPython.
243
243
244 Returns
244 Returns
245 -------
245 -------
246 None
246 None
247 """
247 """
248
248
249 import nose.tools as nt
249 import nose.tools as nt
250
250
251 out, err = ipexec(fname, options, commands)
251 out, err = ipexec(fname, options, commands)
252 #print 'OUT', out # dbg
252 #print 'OUT', out # dbg
253 #print 'ERR', err # dbg
253 #print 'ERR', err # dbg
254 # If there are any errors, we must check those befor stdout, as they may be
254 # If there are any errors, we must check those befor stdout, as they may be
255 # more informative than simply having an empty stdout.
255 # more informative than simply having an empty stdout.
256 if err:
256 if err:
257 if expected_err:
257 if expected_err:
258 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
258 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 else:
259 else:
260 raise ValueError('Running file %r produced error: %r' %
260 raise ValueError('Running file %r produced error: %r' %
261 (fname, err))
261 (fname, err))
262 # If no errors or output on stderr was expected, match stdout
262 # If no errors or output on stderr was expected, match stdout
263 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
263 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264
264
265
265
266 class TempFileMixin(object):
266 class TempFileMixin(object):
267 """Utility class to create temporary Python/IPython files.
267 """Utility class to create temporary Python/IPython files.
268
268
269 Meant as a mixin class for test cases."""
269 Meant as a mixin class for test cases."""
270
270
271 def mktmp(self, src, ext='.py'):
271 def mktmp(self, src, ext='.py'):
272 """Make a valid python temp file."""
272 """Make a valid python temp file."""
273 fname, f = temp_pyfile(src, ext)
273 fname, f = temp_pyfile(src, ext)
274 self.tmpfile = f
274 self.tmpfile = f
275 self.fname = fname
275 self.fname = fname
276
276
277 def tearDown(self):
277 def tearDown(self):
278 if hasattr(self, 'tmpfile'):
278 if hasattr(self, 'tmpfile'):
279 # If the tmpfile wasn't made because of skipped tests, like in
279 # If the tmpfile wasn't made because of skipped tests, like in
280 # win32, there's nothing to cleanup.
280 # win32, there's nothing to cleanup.
281 self.tmpfile.close()
281 self.tmpfile.close()
282 try:
282 try:
283 os.unlink(self.fname)
283 os.unlink(self.fname)
284 except:
284 except:
285 # On Windows, even though we close the file, we still can't
285 # On Windows, even though we close the file, we still can't
286 # delete it. I have no clue why
286 # delete it. I have no clue why
287 pass
287 pass
288
288
289 def __enter__(self):
289 def __enter__(self):
290 return self
290 return self
291
291
292 def __exit__(self, exc_type, exc_value, traceback):
292 def __exit__(self, exc_type, exc_value, traceback):
293 self.tearDown()
293 self.tearDown()
294
294
295
295
296 pair_fail_msg = ("Testing {0}\n\n"
296 pair_fail_msg = ("Testing {0}\n\n"
297 "In:\n"
297 "In:\n"
298 " {1!r}\n"
298 " {1!r}\n"
299 "Expected:\n"
299 "Expected:\n"
300 " {2!r}\n"
300 " {2!r}\n"
301 "Got:\n"
301 "Got:\n"
302 " {3!r}\n")
302 " {3!r}\n")
303 def check_pairs(func, pairs):
303 def check_pairs(func, pairs):
304 """Utility function for the common case of checking a function with a
304 """Utility function for the common case of checking a function with a
305 sequence of input/output pairs.
305 sequence of input/output pairs.
306
306
307 Parameters
307 Parameters
308 ----------
308 ----------
309 func : callable
309 func : callable
310 The function to be tested. Should accept a single argument.
310 The function to be tested. Should accept a single argument.
311 pairs : iterable
311 pairs : iterable
312 A list of (input, expected_output) tuples.
312 A list of (input, expected_output) tuples.
313
313
314 Returns
314 Returns
315 -------
315 -------
316 None. Raises an AssertionError if any output does not match the expected
316 None. Raises an AssertionError if any output does not match the expected
317 value.
317 value.
318 """
318 """
319 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
319 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
320 for inp, expected in pairs:
320 for inp, expected in pairs:
321 out = func(inp)
321 out = func(inp)
322 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
322 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
323
323
324
324
325 if py3compat.PY3:
325 MyStringIO = StringIO
326 MyStringIO = StringIO
327 else:
328 # In Python 2, stdout/stderr can have either bytes or unicode written to them,
329 # so we need a class that can handle both.
330 class MyStringIO(StringIO):
331 def write(self, s):
332 s = py3compat.cast_unicode(s, encoding=DEFAULT_ENCODING)
333 super(MyStringIO, self).write(s)
334
326
335 _re_type = type(re.compile(r''))
327 _re_type = type(re.compile(r''))
336
328
337 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
329 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
338 -------
330 -------
339 {2!s}
331 {2!s}
340 -------
332 -------
341 """
333 """
342
334
343 class AssertPrints(object):
335 class AssertPrints(object):
344 """Context manager for testing that code prints certain text.
336 """Context manager for testing that code prints certain text.
345
337
346 Examples
338 Examples
347 --------
339 --------
348 >>> with AssertPrints("abc", suppress=False):
340 >>> with AssertPrints("abc", suppress=False):
349 ... print("abcd")
341 ... print("abcd")
350 ... print("def")
342 ... print("def")
351 ...
343 ...
352 abcd
344 abcd
353 def
345 def
354 """
346 """
355 def __init__(self, s, channel='stdout', suppress=True):
347 def __init__(self, s, channel='stdout', suppress=True):
356 self.s = s
348 self.s = s
357 if isinstance(self.s, (str, _re_type)):
349 if isinstance(self.s, (str, _re_type)):
358 self.s = [self.s]
350 self.s = [self.s]
359 self.channel = channel
351 self.channel = channel
360 self.suppress = suppress
352 self.suppress = suppress
361
353
362 def __enter__(self):
354 def __enter__(self):
363 self.orig_stream = getattr(sys, self.channel)
355 self.orig_stream = getattr(sys, self.channel)
364 self.buffer = MyStringIO()
356 self.buffer = MyStringIO()
365 self.tee = Tee(self.buffer, channel=self.channel)
357 self.tee = Tee(self.buffer, channel=self.channel)
366 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
358 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
367
359
368 def __exit__(self, etype, value, traceback):
360 def __exit__(self, etype, value, traceback):
369 try:
361 try:
370 if value is not None:
362 if value is not None:
371 # If an error was raised, don't check anything else
363 # If an error was raised, don't check anything else
372 return False
364 return False
373 self.tee.flush()
365 self.tee.flush()
374 setattr(sys, self.channel, self.orig_stream)
366 setattr(sys, self.channel, self.orig_stream)
375 printed = self.buffer.getvalue()
367 printed = self.buffer.getvalue()
376 for s in self.s:
368 for s in self.s:
377 if isinstance(s, _re_type):
369 if isinstance(s, _re_type):
378 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
370 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
379 else:
371 else:
380 assert s in printed, notprinted_msg.format(s, self.channel, printed)
372 assert s in printed, notprinted_msg.format(s, self.channel, printed)
381 return False
373 return False
382 finally:
374 finally:
383 self.tee.close()
375 self.tee.close()
384
376
385 printed_msg = """Found {0!r} in printed output (on {1}):
377 printed_msg = """Found {0!r} in printed output (on {1}):
386 -------
378 -------
387 {2!s}
379 {2!s}
388 -------
380 -------
389 """
381 """
390
382
391 class AssertNotPrints(AssertPrints):
383 class AssertNotPrints(AssertPrints):
392 """Context manager for checking that certain output *isn't* produced.
384 """Context manager for checking that certain output *isn't* produced.
393
385
394 Counterpart of AssertPrints"""
386 Counterpart of AssertPrints"""
395 def __exit__(self, etype, value, traceback):
387 def __exit__(self, etype, value, traceback):
396 try:
388 try:
397 if value is not None:
389 if value is not None:
398 # If an error was raised, don't check anything else
390 # If an error was raised, don't check anything else
399 self.tee.close()
391 self.tee.close()
400 return False
392 return False
401 self.tee.flush()
393 self.tee.flush()
402 setattr(sys, self.channel, self.orig_stream)
394 setattr(sys, self.channel, self.orig_stream)
403 printed = self.buffer.getvalue()
395 printed = self.buffer.getvalue()
404 for s in self.s:
396 for s in self.s:
405 if isinstance(s, _re_type):
397 if isinstance(s, _re_type):
406 assert not s.search(printed),printed_msg.format(
398 assert not s.search(printed),printed_msg.format(
407 s.pattern, self.channel, printed)
399 s.pattern, self.channel, printed)
408 else:
400 else:
409 assert s not in printed, printed_msg.format(
401 assert s not in printed, printed_msg.format(
410 s, self.channel, printed)
402 s, self.channel, printed)
411 return False
403 return False
412 finally:
404 finally:
413 self.tee.close()
405 self.tee.close()
414
406
415 @contextmanager
407 @contextmanager
416 def mute_warn():
408 def mute_warn():
417 from IPython.utils import warn
409 from IPython.utils import warn
418 save_warn = warn.warn
410 save_warn = warn.warn
419 warn.warn = lambda *a, **kw: None
411 warn.warn = lambda *a, **kw: None
420 try:
412 try:
421 yield
413 yield
422 finally:
414 finally:
423 warn.warn = save_warn
415 warn.warn = save_warn
424
416
425 @contextmanager
417 @contextmanager
426 def make_tempfile(name):
418 def make_tempfile(name):
427 """ Create an empty, named, temporary file for the duration of the context.
419 """ Create an empty, named, temporary file for the duration of the context.
428 """
420 """
429 f = open(name, 'w')
421 f = open(name, 'w')
430 f.close()
422 f.close()
431 try:
423 try:
432 yield
424 yield
433 finally:
425 finally:
434 os.unlink(name)
426 os.unlink(name)
435
427
436 def fake_input(inputs):
428 def fake_input(inputs):
437 """Temporarily replace the input() function to return the given values
429 """Temporarily replace the input() function to return the given values
438
430
439 Use as a context manager:
431 Use as a context manager:
440
432
441 with fake_input(['result1', 'result2']):
433 with fake_input(['result1', 'result2']):
442 ...
434 ...
443
435
444 Values are returned in order. If input() is called again after the last value
436 Values are returned in order. If input() is called again after the last value
445 was used, EOFError is raised.
437 was used, EOFError is raised.
446 """
438 """
447 it = iter(inputs)
439 it = iter(inputs)
448 def mock_input(prompt=''):
440 def mock_input(prompt=''):
449 try:
441 try:
450 return next(it)
442 return next(it)
451 except StopIteration:
443 except StopIteration:
452 raise EOFError('No more inputs given')
444 raise EOFError('No more inputs given')
453
445
454 return patch('builtins.input', mock_input)
446 return patch('builtins.input', mock_input)
455
447
456 def help_output_test(subcommand=''):
448 def help_output_test(subcommand=''):
457 """test that `ipython [subcommand] -h` works"""
449 """test that `ipython [subcommand] -h` works"""
458 cmd = get_ipython_cmd() + [subcommand, '-h']
450 cmd = get_ipython_cmd() + [subcommand, '-h']
459 out, err, rc = get_output_error_code(cmd)
451 out, err, rc = get_output_error_code(cmd)
460 nt.assert_equal(rc, 0, err)
452 nt.assert_equal(rc, 0, err)
461 nt.assert_not_in("Traceback", err)
453 nt.assert_not_in("Traceback", err)
462 nt.assert_in("Options", out)
454 nt.assert_in("Options", out)
463 nt.assert_in("--help-all", out)
455 nt.assert_in("--help-all", out)
464 return out, err
456 return out, err
465
457
466
458
467 def help_all_output_test(subcommand=''):
459 def help_all_output_test(subcommand=''):
468 """test that `ipython [subcommand] --help-all` works"""
460 """test that `ipython [subcommand] --help-all` works"""
469 cmd = get_ipython_cmd() + [subcommand, '--help-all']
461 cmd = get_ipython_cmd() + [subcommand, '--help-all']
470 out, err, rc = get_output_error_code(cmd)
462 out, err, rc = get_output_error_code(cmd)
471 nt.assert_equal(rc, 0, err)
463 nt.assert_equal(rc, 0, err)
472 nt.assert_not_in("Traceback", err)
464 nt.assert_not_in("Traceback", err)
473 nt.assert_in("Options", out)
465 nt.assert_in("Options", out)
474 nt.assert_in("Class parameters", out)
466 nt.assert_in("Class parameters", out)
475 return out, err
467 return out, err
476
468
@@ -1,489 +1,482 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Tests for IPython.utils.path.py"""
2 """Tests for IPython.utils.path.py"""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import errno
7 import errno
8 import os
8 import os
9 import shutil
9 import shutil
10 import sys
10 import sys
11 import tempfile
11 import tempfile
12 import warnings
12 import warnings
13 from contextlib import contextmanager
13 from contextlib import contextmanager
14 from unittest.mock import patch
14 from unittest.mock import patch
15 from os.path import join, abspath, split
15 from os.path import join, abspath, split
16 from imp import reload
16
17
17 from nose import SkipTest, with_setup
18 from nose import SkipTest, with_setup
18 import nose.tools as nt
19 import nose.tools as nt
19
20
20 import IPython
21 import IPython
21 from IPython import paths
22 from IPython import paths
22 from IPython.testing import decorators as dec
23 from IPython.testing import decorators as dec
23 from IPython.testing.decorators import (skip_if_not_win32, skip_win32,
24 from IPython.testing.decorators import (skip_if_not_win32, skip_win32,
24 onlyif_unicode_paths,)
25 onlyif_unicode_paths,)
25 from IPython.testing.tools import make_tempfile, AssertPrints
26 from IPython.testing.tools import make_tempfile, AssertPrints
26 from IPython.utils import path
27 from IPython.utils import path
27 from IPython.utils import py3compat
28 from IPython.utils import py3compat
28 from IPython.utils.tempdir import TemporaryDirectory
29 from IPython.utils.tempdir import TemporaryDirectory
29
30
30 # Platform-dependent imports
31 # Platform-dependent imports
31 try:
32 try:
32 import winreg as wreg # Py 3
33 import winreg as wreg
33 except ImportError:
34 except ImportError:
35 #Fake _winreg module on non-windows platforms
36 import types
37 wr_name = "winreg"
38 sys.modules[wr_name] = types.ModuleType(wr_name)
34 try:
39 try:
35 import _winreg as wreg # Py 2
40 import winreg as wreg
36 except ImportError:
41 except ImportError:
37 #Fake _winreg module on none windows platforms
42 import _winreg as wreg
38 import types
39 wr_name = "winreg" if py3compat.PY3 else "_winreg"
40 sys.modules[wr_name] = types.ModuleType(wr_name)
41 try:
42 import winreg as wreg
43 except ImportError:
44 import _winreg as wreg
45 #Add entries that needs to be stubbed by the testing code
43 #Add entries that needs to be stubbed by the testing code
46 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
44 (wreg.OpenKey, wreg.QueryValueEx,) = (None, None)
47
45
48 try:
49 reload
50 except NameError: # Python 3
51 from imp import reload
52
53 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
54 # Globals
47 # Globals
55 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
56 env = os.environ
49 env = os.environ
57 TMP_TEST_DIR = tempfile.mkdtemp()
50 TMP_TEST_DIR = tempfile.mkdtemp()
58 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
51 HOME_TEST_DIR = join(TMP_TEST_DIR, "home_test_dir")
59 #
52 #
60 # Setup/teardown functions/decorators
53 # Setup/teardown functions/decorators
61 #
54 #
62
55
63 def setup():
56 def setup():
64 """Setup testenvironment for the module:
57 """Setup testenvironment for the module:
65
58
66 - Adds dummy home dir tree
59 - Adds dummy home dir tree
67 """
60 """
68 # Do not mask exceptions here. In particular, catching WindowsError is a
61 # Do not mask exceptions here. In particular, catching WindowsError is a
69 # problem because that exception is only defined on Windows...
62 # problem because that exception is only defined on Windows...
70 os.makedirs(os.path.join(HOME_TEST_DIR, 'ipython'))
63 os.makedirs(os.path.join(HOME_TEST_DIR, 'ipython'))
71
64
72
65
73 def teardown():
66 def teardown():
74 """Teardown testenvironment for the module:
67 """Teardown testenvironment for the module:
75
68
76 - Remove dummy home dir tree
69 - Remove dummy home dir tree
77 """
70 """
78 # Note: we remove the parent test dir, which is the root of all test
71 # Note: we remove the parent test dir, which is the root of all test
79 # subdirs we may have created. Use shutil instead of os.removedirs, so
72 # subdirs we may have created. Use shutil instead of os.removedirs, so
80 # that non-empty directories are all recursively removed.
73 # that non-empty directories are all recursively removed.
81 shutil.rmtree(TMP_TEST_DIR)
74 shutil.rmtree(TMP_TEST_DIR)
82
75
83
76
84 def setup_environment():
77 def setup_environment():
85 """Setup testenvironment for some functions that are tested
78 """Setup testenvironment for some functions that are tested
86 in this module. In particular this functions stores attributes
79 in this module. In particular this functions stores attributes
87 and other things that we need to stub in some test functions.
80 and other things that we need to stub in some test functions.
88 This needs to be done on a function level and not module level because
81 This needs to be done on a function level and not module level because
89 each testfunction needs a pristine environment.
82 each testfunction needs a pristine environment.
90 """
83 """
91 global oldstuff, platformstuff
84 global oldstuff, platformstuff
92 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
85 oldstuff = (env.copy(), os.name, sys.platform, path.get_home_dir, IPython.__file__, os.getcwd())
93
86
94 def teardown_environment():
87 def teardown_environment():
95 """Restore things that were remembered by the setup_environment function
88 """Restore things that were remembered by the setup_environment function
96 """
89 """
97 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
90 (oldenv, os.name, sys.platform, path.get_home_dir, IPython.__file__, old_wd) = oldstuff
98 os.chdir(old_wd)
91 os.chdir(old_wd)
99 reload(path)
92 reload(path)
100
93
101 for key in list(env):
94 for key in list(env):
102 if key not in oldenv:
95 if key not in oldenv:
103 del env[key]
96 del env[key]
104 env.update(oldenv)
97 env.update(oldenv)
105 if hasattr(sys, 'frozen'):
98 if hasattr(sys, 'frozen'):
106 del sys.frozen
99 del sys.frozen
107
100
108 # Build decorator that uses the setup_environment/setup_environment
101 # Build decorator that uses the setup_environment/setup_environment
109 with_environment = with_setup(setup_environment, teardown_environment)
102 with_environment = with_setup(setup_environment, teardown_environment)
110
103
111 @skip_if_not_win32
104 @skip_if_not_win32
112 @with_environment
105 @with_environment
113 def test_get_home_dir_1():
106 def test_get_home_dir_1():
114 """Testcase for py2exe logic, un-compressed lib
107 """Testcase for py2exe logic, un-compressed lib
115 """
108 """
116 unfrozen = path.get_home_dir()
109 unfrozen = path.get_home_dir()
117 sys.frozen = True
110 sys.frozen = True
118
111
119 #fake filename for IPython.__init__
112 #fake filename for IPython.__init__
120 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
113 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Lib/IPython/__init__.py"))
121
114
122 home_dir = path.get_home_dir()
115 home_dir = path.get_home_dir()
123 nt.assert_equal(home_dir, unfrozen)
116 nt.assert_equal(home_dir, unfrozen)
124
117
125
118
126 @skip_if_not_win32
119 @skip_if_not_win32
127 @with_environment
120 @with_environment
128 def test_get_home_dir_2():
121 def test_get_home_dir_2():
129 """Testcase for py2exe logic, compressed lib
122 """Testcase for py2exe logic, compressed lib
130 """
123 """
131 unfrozen = path.get_home_dir()
124 unfrozen = path.get_home_dir()
132 sys.frozen = True
125 sys.frozen = True
133 #fake filename for IPython.__init__
126 #fake filename for IPython.__init__
134 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
127 IPython.__file__ = abspath(join(HOME_TEST_DIR, "Library.zip/IPython/__init__.py")).lower()
135
128
136 home_dir = path.get_home_dir(True)
129 home_dir = path.get_home_dir(True)
137 nt.assert_equal(home_dir, unfrozen)
130 nt.assert_equal(home_dir, unfrozen)
138
131
139
132
140 @with_environment
133 @with_environment
141 def test_get_home_dir_3():
134 def test_get_home_dir_3():
142 """get_home_dir() uses $HOME if set"""
135 """get_home_dir() uses $HOME if set"""
143 env["HOME"] = HOME_TEST_DIR
136 env["HOME"] = HOME_TEST_DIR
144 home_dir = path.get_home_dir(True)
137 home_dir = path.get_home_dir(True)
145 # get_home_dir expands symlinks
138 # get_home_dir expands symlinks
146 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
139 nt.assert_equal(home_dir, os.path.realpath(env["HOME"]))
147
140
148
141
149 @with_environment
142 @with_environment
150 def test_get_home_dir_4():
143 def test_get_home_dir_4():
151 """get_home_dir() still works if $HOME is not set"""
144 """get_home_dir() still works if $HOME is not set"""
152
145
153 if 'HOME' in env: del env['HOME']
146 if 'HOME' in env: del env['HOME']
154 # this should still succeed, but we don't care what the answer is
147 # this should still succeed, but we don't care what the answer is
155 home = path.get_home_dir(False)
148 home = path.get_home_dir(False)
156
149
157 @with_environment
150 @with_environment
158 def test_get_home_dir_5():
151 def test_get_home_dir_5():
159 """raise HomeDirError if $HOME is specified, but not a writable dir"""
152 """raise HomeDirError if $HOME is specified, but not a writable dir"""
160 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
153 env['HOME'] = abspath(HOME_TEST_DIR+'garbage')
161 # set os.name = posix, to prevent My Documents fallback on Windows
154 # set os.name = posix, to prevent My Documents fallback on Windows
162 os.name = 'posix'
155 os.name = 'posix'
163 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
156 nt.assert_raises(path.HomeDirError, path.get_home_dir, True)
164
157
165 # Should we stub wreg fully so we can run the test on all platforms?
158 # Should we stub wreg fully so we can run the test on all platforms?
166 @skip_if_not_win32
159 @skip_if_not_win32
167 @with_environment
160 @with_environment
168 def test_get_home_dir_8():
161 def test_get_home_dir_8():
169 """Using registry hack for 'My Documents', os=='nt'
162 """Using registry hack for 'My Documents', os=='nt'
170
163
171 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
164 HOMESHARE, HOMEDRIVE, HOMEPATH, USERPROFILE and others are missing.
172 """
165 """
173 os.name = 'nt'
166 os.name = 'nt'
174 # Remove from stub environment all keys that may be set
167 # Remove from stub environment all keys that may be set
175 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
168 for key in ['HOME', 'HOMESHARE', 'HOMEDRIVE', 'HOMEPATH', 'USERPROFILE']:
176 env.pop(key, None)
169 env.pop(key, None)
177
170
178 class key:
171 class key:
179 def Close(self):
172 def Close(self):
180 pass
173 pass
181
174
182 with patch.object(wreg, 'OpenKey', return_value=key()), \
175 with patch.object(wreg, 'OpenKey', return_value=key()), \
183 patch.object(wreg, 'QueryValueEx', return_value=[abspath(HOME_TEST_DIR)]):
176 patch.object(wreg, 'QueryValueEx', return_value=[abspath(HOME_TEST_DIR)]):
184 home_dir = path.get_home_dir()
177 home_dir = path.get_home_dir()
185 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
178 nt.assert_equal(home_dir, abspath(HOME_TEST_DIR))
186
179
187 @with_environment
180 @with_environment
188 def test_get_xdg_dir_0():
181 def test_get_xdg_dir_0():
189 """test_get_xdg_dir_0, check xdg_dir"""
182 """test_get_xdg_dir_0, check xdg_dir"""
190 reload(path)
183 reload(path)
191 path._writable_dir = lambda path: True
184 path._writable_dir = lambda path: True
192 path.get_home_dir = lambda : 'somewhere'
185 path.get_home_dir = lambda : 'somewhere'
193 os.name = "posix"
186 os.name = "posix"
194 sys.platform = "linux2"
187 sys.platform = "linux2"
195 env.pop('IPYTHON_DIR', None)
188 env.pop('IPYTHON_DIR', None)
196 env.pop('IPYTHONDIR', None)
189 env.pop('IPYTHONDIR', None)
197 env.pop('XDG_CONFIG_HOME', None)
190 env.pop('XDG_CONFIG_HOME', None)
198
191
199 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
192 nt.assert_equal(path.get_xdg_dir(), os.path.join('somewhere', '.config'))
200
193
201
194
202 @with_environment
195 @with_environment
203 def test_get_xdg_dir_1():
196 def test_get_xdg_dir_1():
204 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
197 """test_get_xdg_dir_1, check nonexistant xdg_dir"""
205 reload(path)
198 reload(path)
206 path.get_home_dir = lambda : HOME_TEST_DIR
199 path.get_home_dir = lambda : HOME_TEST_DIR
207 os.name = "posix"
200 os.name = "posix"
208 sys.platform = "linux2"
201 sys.platform = "linux2"
209 env.pop('IPYTHON_DIR', None)
202 env.pop('IPYTHON_DIR', None)
210 env.pop('IPYTHONDIR', None)
203 env.pop('IPYTHONDIR', None)
211 env.pop('XDG_CONFIG_HOME', None)
204 env.pop('XDG_CONFIG_HOME', None)
212 nt.assert_equal(path.get_xdg_dir(), None)
205 nt.assert_equal(path.get_xdg_dir(), None)
213
206
214 @with_environment
207 @with_environment
215 def test_get_xdg_dir_2():
208 def test_get_xdg_dir_2():
216 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
209 """test_get_xdg_dir_2, check xdg_dir default to ~/.config"""
217 reload(path)
210 reload(path)
218 path.get_home_dir = lambda : HOME_TEST_DIR
211 path.get_home_dir = lambda : HOME_TEST_DIR
219 os.name = "posix"
212 os.name = "posix"
220 sys.platform = "linux2"
213 sys.platform = "linux2"
221 env.pop('IPYTHON_DIR', None)
214 env.pop('IPYTHON_DIR', None)
222 env.pop('IPYTHONDIR', None)
215 env.pop('IPYTHONDIR', None)
223 env.pop('XDG_CONFIG_HOME', None)
216 env.pop('XDG_CONFIG_HOME', None)
224 cfgdir=os.path.join(path.get_home_dir(), '.config')
217 cfgdir=os.path.join(path.get_home_dir(), '.config')
225 if not os.path.exists(cfgdir):
218 if not os.path.exists(cfgdir):
226 os.makedirs(cfgdir)
219 os.makedirs(cfgdir)
227
220
228 nt.assert_equal(path.get_xdg_dir(), cfgdir)
221 nt.assert_equal(path.get_xdg_dir(), cfgdir)
229
222
230 @with_environment
223 @with_environment
231 def test_get_xdg_dir_3():
224 def test_get_xdg_dir_3():
232 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
225 """test_get_xdg_dir_3, check xdg_dir not used on OS X"""
233 reload(path)
226 reload(path)
234 path.get_home_dir = lambda : HOME_TEST_DIR
227 path.get_home_dir = lambda : HOME_TEST_DIR
235 os.name = "posix"
228 os.name = "posix"
236 sys.platform = "darwin"
229 sys.platform = "darwin"
237 env.pop('IPYTHON_DIR', None)
230 env.pop('IPYTHON_DIR', None)
238 env.pop('IPYTHONDIR', None)
231 env.pop('IPYTHONDIR', None)
239 env.pop('XDG_CONFIG_HOME', None)
232 env.pop('XDG_CONFIG_HOME', None)
240 cfgdir=os.path.join(path.get_home_dir(), '.config')
233 cfgdir=os.path.join(path.get_home_dir(), '.config')
241 if not os.path.exists(cfgdir):
234 if not os.path.exists(cfgdir):
242 os.makedirs(cfgdir)
235 os.makedirs(cfgdir)
243
236
244 nt.assert_equal(path.get_xdg_dir(), None)
237 nt.assert_equal(path.get_xdg_dir(), None)
245
238
246 def test_filefind():
239 def test_filefind():
247 """Various tests for filefind"""
240 """Various tests for filefind"""
248 f = tempfile.NamedTemporaryFile()
241 f = tempfile.NamedTemporaryFile()
249 # print 'fname:',f.name
242 # print 'fname:',f.name
250 alt_dirs = paths.get_ipython_dir()
243 alt_dirs = paths.get_ipython_dir()
251 t = path.filefind(f.name, alt_dirs)
244 t = path.filefind(f.name, alt_dirs)
252 # print 'found:',t
245 # print 'found:',t
253
246
254
247
255 @dec.skip_if_not_win32
248 @dec.skip_if_not_win32
256 def test_get_long_path_name_win32():
249 def test_get_long_path_name_win32():
257 with TemporaryDirectory() as tmpdir:
250 with TemporaryDirectory() as tmpdir:
258
251
259 # Make a long path. Expands the path of tmpdir prematurely as it may already have a long
252 # Make a long path. Expands the path of tmpdir prematurely as it may already have a long
260 # path component, so ensure we include the long form of it
253 # path component, so ensure we include the long form of it
261 long_path = os.path.join(path.get_long_path_name(tmpdir), u'this is my long path name')
254 long_path = os.path.join(path.get_long_path_name(tmpdir), u'this is my long path name')
262 os.makedirs(long_path)
255 os.makedirs(long_path)
263
256
264 # Test to see if the short path evaluates correctly.
257 # Test to see if the short path evaluates correctly.
265 short_path = os.path.join(tmpdir, u'THISIS~1')
258 short_path = os.path.join(tmpdir, u'THISIS~1')
266 evaluated_path = path.get_long_path_name(short_path)
259 evaluated_path = path.get_long_path_name(short_path)
267 nt.assert_equal(evaluated_path.lower(), long_path.lower())
260 nt.assert_equal(evaluated_path.lower(), long_path.lower())
268
261
269
262
270 @dec.skip_win32
263 @dec.skip_win32
271 def test_get_long_path_name():
264 def test_get_long_path_name():
272 p = path.get_long_path_name('/usr/local')
265 p = path.get_long_path_name('/usr/local')
273 nt.assert_equal(p,'/usr/local')
266 nt.assert_equal(p,'/usr/local')
274
267
275 @dec.skip_win32 # can't create not-user-writable dir on win
268 @dec.skip_win32 # can't create not-user-writable dir on win
276 @with_environment
269 @with_environment
277 def test_not_writable_ipdir():
270 def test_not_writable_ipdir():
278 tmpdir = tempfile.mkdtemp()
271 tmpdir = tempfile.mkdtemp()
279 os.name = "posix"
272 os.name = "posix"
280 env.pop('IPYTHON_DIR', None)
273 env.pop('IPYTHON_DIR', None)
281 env.pop('IPYTHONDIR', None)
274 env.pop('IPYTHONDIR', None)
282 env.pop('XDG_CONFIG_HOME', None)
275 env.pop('XDG_CONFIG_HOME', None)
283 env['HOME'] = tmpdir
276 env['HOME'] = tmpdir
284 ipdir = os.path.join(tmpdir, '.ipython')
277 ipdir = os.path.join(tmpdir, '.ipython')
285 os.mkdir(ipdir, 0o555)
278 os.mkdir(ipdir, 0o555)
286 try:
279 try:
287 open(os.path.join(ipdir, "_foo_"), 'w').close()
280 open(os.path.join(ipdir, "_foo_"), 'w').close()
288 except IOError:
281 except IOError:
289 pass
282 pass
290 else:
283 else:
291 # I can still write to an unwritable dir,
284 # I can still write to an unwritable dir,
292 # assume I'm root and skip the test
285 # assume I'm root and skip the test
293 raise SkipTest("I can't create directories that I can't write to")
286 raise SkipTest("I can't create directories that I can't write to")
294 with AssertPrints('is not a writable location', channel='stderr'):
287 with AssertPrints('is not a writable location', channel='stderr'):
295 ipdir = paths.get_ipython_dir()
288 ipdir = paths.get_ipython_dir()
296 env.pop('IPYTHON_DIR', None)
289 env.pop('IPYTHON_DIR', None)
297
290
298 @with_environment
291 @with_environment
299 def test_get_py_filename():
292 def test_get_py_filename():
300 os.chdir(TMP_TEST_DIR)
293 os.chdir(TMP_TEST_DIR)
301 with make_tempfile('foo.py'):
294 with make_tempfile('foo.py'):
302 nt.assert_equal(path.get_py_filename('foo.py'), 'foo.py')
295 nt.assert_equal(path.get_py_filename('foo.py'), 'foo.py')
303 nt.assert_equal(path.get_py_filename('foo'), 'foo.py')
296 nt.assert_equal(path.get_py_filename('foo'), 'foo.py')
304 with make_tempfile('foo'):
297 with make_tempfile('foo'):
305 nt.assert_equal(path.get_py_filename('foo'), 'foo')
298 nt.assert_equal(path.get_py_filename('foo'), 'foo')
306 nt.assert_raises(IOError, path.get_py_filename, 'foo.py')
299 nt.assert_raises(IOError, path.get_py_filename, 'foo.py')
307 nt.assert_raises(IOError, path.get_py_filename, 'foo')
300 nt.assert_raises(IOError, path.get_py_filename, 'foo')
308 nt.assert_raises(IOError, path.get_py_filename, 'foo.py')
301 nt.assert_raises(IOError, path.get_py_filename, 'foo.py')
309 true_fn = 'foo with spaces.py'
302 true_fn = 'foo with spaces.py'
310 with make_tempfile(true_fn):
303 with make_tempfile(true_fn):
311 nt.assert_equal(path.get_py_filename('foo with spaces'), true_fn)
304 nt.assert_equal(path.get_py_filename('foo with spaces'), true_fn)
312 nt.assert_equal(path.get_py_filename('foo with spaces.py'), true_fn)
305 nt.assert_equal(path.get_py_filename('foo with spaces.py'), true_fn)
313 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"')
306 nt.assert_raises(IOError, path.get_py_filename, '"foo with spaces.py"')
314 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'")
307 nt.assert_raises(IOError, path.get_py_filename, "'foo with spaces.py'")
315
308
316 @onlyif_unicode_paths
309 @onlyif_unicode_paths
317 def test_unicode_in_filename():
310 def test_unicode_in_filename():
318 """When a file doesn't exist, the exception raised should be safe to call
311 """When a file doesn't exist, the exception raised should be safe to call
319 str() on - i.e. in Python 2 it must only have ASCII characters.
312 str() on - i.e. in Python 2 it must only have ASCII characters.
320
313
321 https://github.com/ipython/ipython/issues/875
314 https://github.com/ipython/ipython/issues/875
322 """
315 """
323 try:
316 try:
324 # these calls should not throw unicode encode exceptions
317 # these calls should not throw unicode encode exceptions
325 path.get_py_filename(u'fooéè.py', force_win32=False)
318 path.get_py_filename(u'fooéè.py', force_win32=False)
326 except IOError as ex:
319 except IOError as ex:
327 str(ex)
320 str(ex)
328
321
329
322
330 class TestShellGlob(object):
323 class TestShellGlob(object):
331
324
332 @classmethod
325 @classmethod
333 def setUpClass(cls):
326 def setUpClass(cls):
334 cls.filenames_start_with_a = ['a0', 'a1', 'a2']
327 cls.filenames_start_with_a = ['a0', 'a1', 'a2']
335 cls.filenames_end_with_b = ['0b', '1b', '2b']
328 cls.filenames_end_with_b = ['0b', '1b', '2b']
336 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
329 cls.filenames = cls.filenames_start_with_a + cls.filenames_end_with_b
337 cls.tempdir = TemporaryDirectory()
330 cls.tempdir = TemporaryDirectory()
338 td = cls.tempdir.name
331 td = cls.tempdir.name
339
332
340 with cls.in_tempdir():
333 with cls.in_tempdir():
341 # Create empty files
334 # Create empty files
342 for fname in cls.filenames:
335 for fname in cls.filenames:
343 open(os.path.join(td, fname), 'w').close()
336 open(os.path.join(td, fname), 'w').close()
344
337
345 @classmethod
338 @classmethod
346 def tearDownClass(cls):
339 def tearDownClass(cls):
347 cls.tempdir.cleanup()
340 cls.tempdir.cleanup()
348
341
349 @classmethod
342 @classmethod
350 @contextmanager
343 @contextmanager
351 def in_tempdir(cls):
344 def in_tempdir(cls):
352 save = os.getcwd()
345 save = os.getcwd()
353 try:
346 try:
354 os.chdir(cls.tempdir.name)
347 os.chdir(cls.tempdir.name)
355 yield
348 yield
356 finally:
349 finally:
357 os.chdir(save)
350 os.chdir(save)
358
351
359 def check_match(self, patterns, matches):
352 def check_match(self, patterns, matches):
360 with self.in_tempdir():
353 with self.in_tempdir():
361 # glob returns unordered list. that's why sorted is required.
354 # glob returns unordered list. that's why sorted is required.
362 nt.assert_equal(sorted(path.shellglob(patterns)),
355 nt.assert_equal(sorted(path.shellglob(patterns)),
363 sorted(matches))
356 sorted(matches))
364
357
365 def common_cases(self):
358 def common_cases(self):
366 return [
359 return [
367 (['*'], self.filenames),
360 (['*'], self.filenames),
368 (['a*'], self.filenames_start_with_a),
361 (['a*'], self.filenames_start_with_a),
369 (['*c'], ['*c']),
362 (['*c'], ['*c']),
370 (['*', 'a*', '*b', '*c'], self.filenames
363 (['*', 'a*', '*b', '*c'], self.filenames
371 + self.filenames_start_with_a
364 + self.filenames_start_with_a
372 + self.filenames_end_with_b
365 + self.filenames_end_with_b
373 + ['*c']),
366 + ['*c']),
374 (['a[012]'], self.filenames_start_with_a),
367 (['a[012]'], self.filenames_start_with_a),
375 ]
368 ]
376
369
377 @skip_win32
370 @skip_win32
378 def test_match_posix(self):
371 def test_match_posix(self):
379 for (patterns, matches) in self.common_cases() + [
372 for (patterns, matches) in self.common_cases() + [
380 ([r'\*'], ['*']),
373 ([r'\*'], ['*']),
381 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
374 ([r'a\*', 'a*'], ['a*'] + self.filenames_start_with_a),
382 ([r'a\[012]'], ['a[012]']),
375 ([r'a\[012]'], ['a[012]']),
383 ]:
376 ]:
384 yield (self.check_match, patterns, matches)
377 yield (self.check_match, patterns, matches)
385
378
386 @skip_if_not_win32
379 @skip_if_not_win32
387 def test_match_windows(self):
380 def test_match_windows(self):
388 for (patterns, matches) in self.common_cases() + [
381 for (patterns, matches) in self.common_cases() + [
389 # In windows, backslash is interpreted as path
382 # In windows, backslash is interpreted as path
390 # separator. Therefore, you can't escape glob
383 # separator. Therefore, you can't escape glob
391 # using it.
384 # using it.
392 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
385 ([r'a\*', 'a*'], [r'a\*'] + self.filenames_start_with_a),
393 ([r'a\[012]'], [r'a\[012]']),
386 ([r'a\[012]'], [r'a\[012]']),
394 ]:
387 ]:
395 yield (self.check_match, patterns, matches)
388 yield (self.check_match, patterns, matches)
396
389
397
390
398 def test_unescape_glob():
391 def test_unescape_glob():
399 nt.assert_equal(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
392 nt.assert_equal(path.unescape_glob(r'\*\[\!\]\?'), '*[!]?')
400 nt.assert_equal(path.unescape_glob(r'\\*'), r'\*')
393 nt.assert_equal(path.unescape_glob(r'\\*'), r'\*')
401 nt.assert_equal(path.unescape_glob(r'\\\*'), r'\*')
394 nt.assert_equal(path.unescape_glob(r'\\\*'), r'\*')
402 nt.assert_equal(path.unescape_glob(r'\\a'), r'\a')
395 nt.assert_equal(path.unescape_glob(r'\\a'), r'\a')
403 nt.assert_equal(path.unescape_glob(r'\a'), r'\a')
396 nt.assert_equal(path.unescape_glob(r'\a'), r'\a')
404
397
405
398
406 def test_ensure_dir_exists():
399 def test_ensure_dir_exists():
407 with TemporaryDirectory() as td:
400 with TemporaryDirectory() as td:
408 d = os.path.join(td, u'βˆ‚ir')
401 d = os.path.join(td, u'βˆ‚ir')
409 path.ensure_dir_exists(d) # create it
402 path.ensure_dir_exists(d) # create it
410 assert os.path.isdir(d)
403 assert os.path.isdir(d)
411 path.ensure_dir_exists(d) # no-op
404 path.ensure_dir_exists(d) # no-op
412 f = os.path.join(td, u'Ζ’ile')
405 f = os.path.join(td, u'Ζ’ile')
413 open(f, 'w').close() # touch
406 open(f, 'w').close() # touch
414 with nt.assert_raises(IOError):
407 with nt.assert_raises(IOError):
415 path.ensure_dir_exists(f)
408 path.ensure_dir_exists(f)
416
409
417 class TestLinkOrCopy(object):
410 class TestLinkOrCopy(object):
418 def setUp(self):
411 def setUp(self):
419 self.tempdir = TemporaryDirectory()
412 self.tempdir = TemporaryDirectory()
420 self.src = self.dst("src")
413 self.src = self.dst("src")
421 with open(self.src, "w") as f:
414 with open(self.src, "w") as f:
422 f.write("Hello, world!")
415 f.write("Hello, world!")
423
416
424 def tearDown(self):
417 def tearDown(self):
425 self.tempdir.cleanup()
418 self.tempdir.cleanup()
426
419
427 def dst(self, *args):
420 def dst(self, *args):
428 return os.path.join(self.tempdir.name, *args)
421 return os.path.join(self.tempdir.name, *args)
429
422
430 def assert_inode_not_equal(self, a, b):
423 def assert_inode_not_equal(self, a, b):
431 nt.assert_not_equal(os.stat(a).st_ino, os.stat(b).st_ino,
424 nt.assert_not_equal(os.stat(a).st_ino, os.stat(b).st_ino,
432 "%r and %r do reference the same indoes" %(a, b))
425 "%r and %r do reference the same indoes" %(a, b))
433
426
434 def assert_inode_equal(self, a, b):
427 def assert_inode_equal(self, a, b):
435 nt.assert_equal(os.stat(a).st_ino, os.stat(b).st_ino,
428 nt.assert_equal(os.stat(a).st_ino, os.stat(b).st_ino,
436 "%r and %r do not reference the same indoes" %(a, b))
429 "%r and %r do not reference the same indoes" %(a, b))
437
430
438 def assert_content_equal(self, a, b):
431 def assert_content_equal(self, a, b):
439 with open(a) as a_f:
432 with open(a) as a_f:
440 with open(b) as b_f:
433 with open(b) as b_f:
441 nt.assert_equal(a_f.read(), b_f.read())
434 nt.assert_equal(a_f.read(), b_f.read())
442
435
443 @skip_win32
436 @skip_win32
444 def test_link_successful(self):
437 def test_link_successful(self):
445 dst = self.dst("target")
438 dst = self.dst("target")
446 path.link_or_copy(self.src, dst)
439 path.link_or_copy(self.src, dst)
447 self.assert_inode_equal(self.src, dst)
440 self.assert_inode_equal(self.src, dst)
448
441
449 @skip_win32
442 @skip_win32
450 def test_link_into_dir(self):
443 def test_link_into_dir(self):
451 dst = self.dst("some_dir")
444 dst = self.dst("some_dir")
452 os.mkdir(dst)
445 os.mkdir(dst)
453 path.link_or_copy(self.src, dst)
446 path.link_or_copy(self.src, dst)
454 expected_dst = self.dst("some_dir", os.path.basename(self.src))
447 expected_dst = self.dst("some_dir", os.path.basename(self.src))
455 self.assert_inode_equal(self.src, expected_dst)
448 self.assert_inode_equal(self.src, expected_dst)
456
449
457 @skip_win32
450 @skip_win32
458 def test_target_exists(self):
451 def test_target_exists(self):
459 dst = self.dst("target")
452 dst = self.dst("target")
460 open(dst, "w").close()
453 open(dst, "w").close()
461 path.link_or_copy(self.src, dst)
454 path.link_or_copy(self.src, dst)
462 self.assert_inode_equal(self.src, dst)
455 self.assert_inode_equal(self.src, dst)
463
456
464 @skip_win32
457 @skip_win32
465 def test_no_link(self):
458 def test_no_link(self):
466 real_link = os.link
459 real_link = os.link
467 try:
460 try:
468 del os.link
461 del os.link
469 dst = self.dst("target")
462 dst = self.dst("target")
470 path.link_or_copy(self.src, dst)
463 path.link_or_copy(self.src, dst)
471 self.assert_content_equal(self.src, dst)
464 self.assert_content_equal(self.src, dst)
472 self.assert_inode_not_equal(self.src, dst)
465 self.assert_inode_not_equal(self.src, dst)
473 finally:
466 finally:
474 os.link = real_link
467 os.link = real_link
475
468
476 @skip_if_not_win32
469 @skip_if_not_win32
477 def test_windows(self):
470 def test_windows(self):
478 dst = self.dst("target")
471 dst = self.dst("target")
479 path.link_or_copy(self.src, dst)
472 path.link_or_copy(self.src, dst)
480 self.assert_content_equal(self.src, dst)
473 self.assert_content_equal(self.src, dst)
481
474
482 def test_link_twice(self):
475 def test_link_twice(self):
483 # Linking the same file twice shouldn't leave duplicates around.
476 # Linking the same file twice shouldn't leave duplicates around.
484 # See https://github.com/ipython/ipython/issues/6450
477 # See https://github.com/ipython/ipython/issues/6450
485 dst = self.dst('target')
478 dst = self.dst('target')
486 path.link_or_copy(self.src, dst)
479 path.link_or_copy(self.src, dst)
487 path.link_or_copy(self.src, dst)
480 path.link_or_copy(self.src, dst)
488 self.assert_inode_equal(self.src, dst)
481 self.assert_inode_equal(self.src, dst)
489 nt.assert_equal(sorted(os.listdir(self.tempdir.name)), ['src', 'target'])
482 nt.assert_equal(sorted(os.listdir(self.tempdir.name)), ['src', 'target'])
General Comments 0
You need to be logged in to leave comments. Login now