##// END OF EJS Templates
Update history.py to use pathlib...
Jakub Klus -
Show More
@@ -1,881 +1,885 b''
1 """ History related magics and functionality """
1 """ History related magics and functionality """
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6
6
7 import atexit
7 import atexit
8 import datetime
8 import datetime
9 import os
9 from pathlib import Path
10 import re
10 import re
11 import sqlite3
11 import sqlite3
12 import threading
12 import threading
13
13
14 from traitlets.config.configurable import LoggingConfigurable
14 from traitlets.config.configurable import LoggingConfigurable
15 from decorator import decorator
15 from decorator import decorator
16 from IPython.utils.decorators import undoc
16 from IPython.utils.decorators import undoc
17 from IPython.paths import locate_profile
17 from IPython.paths import locate_profile
18 from traitlets import (
18 from traitlets import (
19 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
19 Any, Bool, Dict, Instance, Integer, List, Unicode, Union, TraitError,
20 default, observe,
20 default, observe
21 )
21 )
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Classes and functions
24 # Classes and functions
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 @undoc
27 @undoc
28 class DummyDB(object):
28 class DummyDB(object):
29 """Dummy DB that will act as a black hole for history.
29 """Dummy DB that will act as a black hole for history.
30
30
31 Only used in the absence of sqlite"""
31 Only used in the absence of sqlite"""
32 def execute(*args, **kwargs):
32 def execute(*args, **kwargs):
33 return []
33 return []
34
34
35 def commit(self, *args, **kwargs):
35 def commit(self, *args, **kwargs):
36 pass
36 pass
37
37
38 def __enter__(self, *args, **kwargs):
38 def __enter__(self, *args, **kwargs):
39 pass
39 pass
40
40
41 def __exit__(self, *args, **kwargs):
41 def __exit__(self, *args, **kwargs):
42 pass
42 pass
43
43
44
44
45 @decorator
45 @decorator
46 def only_when_enabled(f, self, *a, **kw):
46 def only_when_enabled(f, self, *a, **kw):
47 """Decorator: return an empty list in the absence of sqlite."""
47 """Decorator: return an empty list in the absence of sqlite."""
48 if not self.enabled:
48 if not self.enabled:
49 return []
49 return []
50 else:
50 else:
51 return f(self, *a, **kw)
51 return f(self, *a, **kw)
52
52
53
53
54 # use 16kB as threshold for whether a corrupt history db should be saved
54 # use 16kB as threshold for whether a corrupt history db should be saved
55 # that should be at least 100 entries or so
55 # that should be at least 100 entries or so
56 _SAVE_DB_SIZE = 16384
56 _SAVE_DB_SIZE = 16384
57
57
58 @decorator
58 @decorator
59 def catch_corrupt_db(f, self, *a, **kw):
59 def catch_corrupt_db(f, self, *a, **kw):
60 """A decorator which wraps HistoryAccessor method calls to catch errors from
60 """A decorator which wraps HistoryAccessor method calls to catch errors from
61 a corrupt SQLite database, move the old database out of the way, and create
61 a corrupt SQLite database, move the old database out of the way, and create
62 a new one.
62 a new one.
63
63
64 We avoid clobbering larger databases because this may be triggered due to filesystem issues,
64 We avoid clobbering larger databases because this may be triggered due to filesystem issues,
65 not just a corrupt file.
65 not just a corrupt file.
66 """
66 """
67 try:
67 try:
68 return f(self, *a, **kw)
68 return f(self, *a, **kw)
69 except (sqlite3.DatabaseError, sqlite3.OperationalError) as e:
69 except (sqlite3.DatabaseError, sqlite3.OperationalError) as e:
70 self._corrupt_db_counter += 1
70 self._corrupt_db_counter += 1
71 self.log.error("Failed to open SQLite history %s (%s).", self.hist_file, e)
71 self.log.error("Failed to open SQLite history %s (%s).", self.hist_file, e)
72 if self.hist_file != ':memory:':
72 if self.hist_file != ':memory:':
73 if self._corrupt_db_counter > self._corrupt_db_limit:
73 if self._corrupt_db_counter > self._corrupt_db_limit:
74 self.hist_file = ':memory:'
74 self.hist_file = ':memory:'
75 self.log.error("Failed to load history too many times, history will not be saved.")
75 self.log.error("Failed to load history too many times, history will not be saved.")
76 elif os.path.isfile(self.hist_file):
76 elif self.hist_file.is_file():
77 # move the file out of the way
77 # move the file out of the way
78 base, ext = os.path.splitext(self.hist_file)
78 base = str(self.hist_file.parent / self.hist_file.stem)
79 size = os.stat(self.hist_file).st_size
79 ext = self.hist_file.suffix
80 size = self.hist_file.stat().st_size
80 if size >= _SAVE_DB_SIZE:
81 if size >= _SAVE_DB_SIZE:
81 # if there's significant content, avoid clobbering
82 # if there's significant content, avoid clobbering
82 now = datetime.datetime.now().isoformat().replace(':', '.')
83 now = datetime.datetime.now().isoformat().replace(':', '.')
83 newpath = base + '-corrupt-' + now + ext
84 newpath = base + '-corrupt-' + now + ext
84 # don't clobber previous corrupt backups
85 # don't clobber previous corrupt backups
85 for i in range(100):
86 for i in range(100):
86 if not os.path.isfile(newpath):
87 if not Path(newpath).exists():
87 break
88 break
88 else:
89 else:
89 newpath = base + '-corrupt-' + now + (u'-%i' % i) + ext
90 newpath = base + '-corrupt-' + now + (u'-%i' % i) + ext
90 else:
91 else:
91 # not much content, possibly empty; don't worry about clobbering
92 # not much content, possibly empty; don't worry about clobbering
92 # maybe we should just delete it?
93 # maybe we should just delete it?
93 newpath = base + '-corrupt' + ext
94 newpath = base + '-corrupt' + ext
94 os.rename(self.hist_file, newpath)
95 self.hist_file.rename(newpath)
95 self.log.error("History file was moved to %s and a new file created.", newpath)
96 self.log.error("History file was moved to %s and a new file created.", newpath)
96 self.init_db()
97 self.init_db()
97 return []
98 return []
98 else:
99 else:
99 # Failed with :memory:, something serious is wrong
100 # Failed with :memory:, something serious is wrong
100 raise
101 raise
101
102
102 class HistoryAccessorBase(LoggingConfigurable):
103 class HistoryAccessorBase(LoggingConfigurable):
103 """An abstract class for History Accessors """
104 """An abstract class for History Accessors """
104
105
105 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
106 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
106 raise NotImplementedError
107 raise NotImplementedError
107
108
108 def search(self, pattern="*", raw=True, search_raw=True,
109 def search(self, pattern="*", raw=True, search_raw=True,
109 output=False, n=None, unique=False):
110 output=False, n=None, unique=False):
110 raise NotImplementedError
111 raise NotImplementedError
111
112
112 def get_range(self, session, start=1, stop=None, raw=True,output=False):
113 def get_range(self, session, start=1, stop=None, raw=True,output=False):
113 raise NotImplementedError
114 raise NotImplementedError
114
115
115 def get_range_by_str(self, rangestr, raw=True, output=False):
116 def get_range_by_str(self, rangestr, raw=True, output=False):
116 raise NotImplementedError
117 raise NotImplementedError
117
118
118
119
119 class HistoryAccessor(HistoryAccessorBase):
120 class HistoryAccessor(HistoryAccessorBase):
120 """Access the history database without adding to it.
121 """Access the history database without adding to it.
121
122
122 This is intended for use by standalone history tools. IPython shells use
123 This is intended for use by standalone history tools. IPython shells use
123 HistoryManager, below, which is a subclass of this."""
124 HistoryManager, below, which is a subclass of this."""
124
125
125 # counter for init_db retries, so we don't keep trying over and over
126 # counter for init_db retries, so we don't keep trying over and over
126 _corrupt_db_counter = 0
127 _corrupt_db_counter = 0
127 # after two failures, fallback on :memory:
128 # after two failures, fallback on :memory:
128 _corrupt_db_limit = 2
129 _corrupt_db_limit = 2
129
130
130 # String holding the path to the history file
131 # String holding the path to the history file
131 hist_file = Unicode(
132 hist_file = Union([Instance(Path), Unicode()],
132 help="""Path to file to use for SQLite history database.
133 help="""Path to file to use for SQLite history database.
133
134
134 By default, IPython will put the history database in the IPython
135 By default, IPython will put the history database in the IPython
135 profile directory. If you would rather share one history among
136 profile directory. If you would rather share one history among
136 profiles, you can set this value in each, so that they are consistent.
137 profiles, you can set this value in each, so that they are consistent.
137
138
138 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
139 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
139 mounts. If you see IPython hanging, try setting this to something on a
140 mounts. If you see IPython hanging, try setting this to something on a
140 local disk, e.g::
141 local disk, e.g::
141
142
142 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
143 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
143
144
144 you can also use the specific value `:memory:` (including the colon
145 you can also use the specific value `:memory:` (including the colon
145 at both end but not the back ticks), to avoid creating an history file.
146 at both end but not the back ticks), to avoid creating an history file.
146
147
147 """).tag(config=True)
148 """
148
149 ).tag(config=True)
150
149 enabled = Bool(True,
151 enabled = Bool(True,
150 help="""enable the SQLite history
152 help="""enable the SQLite history
151
153
152 set enabled=False to disable the SQLite history,
154 set enabled=False to disable the SQLite history,
153 in which case there will be no stored history, no SQLite connection,
155 in which case there will be no stored history, no SQLite connection,
154 and no background saving thread. This may be necessary in some
156 and no background saving thread. This may be necessary in some
155 threaded environments where IPython is embedded.
157 threaded environments where IPython is embedded.
156 """
158 """
157 ).tag(config=True)
159 ).tag(config=True)
158
160
159 connection_options = Dict(
161 connection_options = Dict(
160 help="""Options for configuring the SQLite connection
162 help="""Options for configuring the SQLite connection
161
163
162 These options are passed as keyword args to sqlite3.connect
164 These options are passed as keyword args to sqlite3.connect
163 when establishing database connections.
165 when establishing database connections.
164 """
166 """
165 ).tag(config=True)
167 ).tag(config=True)
166
168
167 # The SQLite database
169 # The SQLite database
168 db = Any()
170 db = Any()
169 @observe('db')
171 @observe('db')
170 def _db_changed(self, change):
172 def _db_changed(self, change):
171 """validate the db, since it can be an Instance of two different types"""
173 """validate the db, since it can be an Instance of two different types"""
172 new = change['new']
174 new = change['new']
173 connection_types = (DummyDB, sqlite3.Connection)
175 connection_types = (DummyDB, sqlite3.Connection)
174 if not isinstance(new, connection_types):
176 if not isinstance(new, connection_types):
175 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
177 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
176 (self.__class__.__name__, new)
178 (self.__class__.__name__, new)
177 raise TraitError(msg)
179 raise TraitError(msg)
178
180
179 def __init__(self, profile='default', hist_file=u'', **traits):
181 def __init__(self, profile='default', hist_file="", **traits):
180 """Create a new history accessor.
182 """Create a new history accessor.
181
183
182 Parameters
184 Parameters
183 ----------
185 ----------
184 profile : str
186 profile : str
185 The name of the profile from which to open history.
187 The name of the profile from which to open history.
186 hist_file : str
188 hist_file : str
187 Path to an SQLite history database stored by IPython. If specified,
189 Path to an SQLite history database stored by IPython. If specified,
188 hist_file overrides profile.
190 hist_file overrides profile.
189 config : :class:`~traitlets.config.loader.Config`
191 config : :class:`~traitlets.config.loader.Config`
190 Config object. hist_file can also be set through this.
192 Config object. hist_file can also be set through this.
191 """
193 """
192 # We need a pointer back to the shell for various tasks.
194 # We need a pointer back to the shell for various tasks.
193 super(HistoryAccessor, self).__init__(**traits)
195 super(HistoryAccessor, self).__init__(**traits)
194 # defer setting hist_file from kwarg until after init,
196 # defer setting hist_file from kwarg until after init,
195 # otherwise the default kwarg value would clobber any value
197 # otherwise the default kwarg value would clobber any value
196 # set by config
198 # set by config
197 if hist_file:
199 if hist_file:
198 self.hist_file = hist_file
200 self.hist_file = hist_file
199
201
200 if self.hist_file == u'':
202 try:
203 self.hist_file
204 except TraitError:
201 # No one has set the hist_file, yet.
205 # No one has set the hist_file, yet.
202 self.hist_file = self._get_hist_file_name(profile)
206 self.hist_file = self._get_hist_file_name(profile)
203
207
204 self.init_db()
208 self.init_db()
205
209
206 def _get_hist_file_name(self, profile='default'):
210 def _get_hist_file_name(self, profile='default'):
207 """Find the history file for the given profile name.
211 """Find the history file for the given profile name.
208
212
209 This is overridden by the HistoryManager subclass, to use the shell's
213 This is overridden by the HistoryManager subclass, to use the shell's
210 active profile.
214 active profile.
211
215
212 Parameters
216 Parameters
213 ----------
217 ----------
214 profile : str
218 profile : str
215 The name of a profile which has a history file.
219 The name of a profile which has a history file.
216 """
220 """
217 return os.path.join(locate_profile(profile), 'history.sqlite')
221 return Path(locate_profile(profile)) / 'history.sqlite'
218
222
219 @catch_corrupt_db
223 @catch_corrupt_db
220 def init_db(self):
224 def init_db(self):
221 """Connect to the database, and create tables if necessary."""
225 """Connect to the database, and create tables if necessary."""
222 if not self.enabled:
226 if not self.enabled:
223 self.db = DummyDB()
227 self.db = DummyDB()
224 return
228 return
225
229
226 # use detect_types so that timestamps return datetime objects
230 # use detect_types so that timestamps return datetime objects
227 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
231 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
228 kwargs.update(self.connection_options)
232 kwargs.update(self.connection_options)
229 self.db = sqlite3.connect(self.hist_file, **kwargs)
233 self.db = sqlite3.connect(self.hist_file, **kwargs)
230 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
234 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
231 primary key autoincrement, start timestamp,
235 primary key autoincrement, start timestamp,
232 end timestamp, num_cmds integer, remark text)""")
236 end timestamp, num_cmds integer, remark text)""")
233 self.db.execute("""CREATE TABLE IF NOT EXISTS history
237 self.db.execute("""CREATE TABLE IF NOT EXISTS history
234 (session integer, line integer, source text, source_raw text,
238 (session integer, line integer, source text, source_raw text,
235 PRIMARY KEY (session, line))""")
239 PRIMARY KEY (session, line))""")
236 # Output history is optional, but ensure the table's there so it can be
240 # Output history is optional, but ensure the table's there so it can be
237 # enabled later.
241 # enabled later.
238 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
242 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
239 (session integer, line integer, output text,
243 (session integer, line integer, output text,
240 PRIMARY KEY (session, line))""")
244 PRIMARY KEY (session, line))""")
241 self.db.commit()
245 self.db.commit()
242 # success! reset corrupt db count
246 # success! reset corrupt db count
243 self._corrupt_db_counter = 0
247 self._corrupt_db_counter = 0
244
248
245 def writeout_cache(self):
249 def writeout_cache(self):
246 """Overridden by HistoryManager to dump the cache before certain
250 """Overridden by HistoryManager to dump the cache before certain
247 database lookups."""
251 database lookups."""
248 pass
252 pass
249
253
250 ## -------------------------------
254 ## -------------------------------
251 ## Methods for retrieving history:
255 ## Methods for retrieving history:
252 ## -------------------------------
256 ## -------------------------------
253 def _run_sql(self, sql, params, raw=True, output=False):
257 def _run_sql(self, sql, params, raw=True, output=False):
254 """Prepares and runs an SQL query for the history database.
258 """Prepares and runs an SQL query for the history database.
255
259
256 Parameters
260 Parameters
257 ----------
261 ----------
258 sql : str
262 sql : str
259 Any filtering expressions to go after SELECT ... FROM ...
263 Any filtering expressions to go after SELECT ... FROM ...
260 params : tuple
264 params : tuple
261 Parameters passed to the SQL query (to replace "?")
265 Parameters passed to the SQL query (to replace "?")
262 raw, output : bool
266 raw, output : bool
263 See :meth:`get_range`
267 See :meth:`get_range`
264
268
265 Returns
269 Returns
266 -------
270 -------
267 Tuples as :meth:`get_range`
271 Tuples as :meth:`get_range`
268 """
272 """
269 toget = 'source_raw' if raw else 'source'
273 toget = 'source_raw' if raw else 'source'
270 sqlfrom = "history"
274 sqlfrom = "history"
271 if output:
275 if output:
272 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
276 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
273 toget = "history.%s, output_history.output" % toget
277 toget = "history.%s, output_history.output" % toget
274 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
278 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
275 (toget, sqlfrom) + sql, params)
279 (toget, sqlfrom) + sql, params)
276 if output: # Regroup into 3-tuples, and parse JSON
280 if output: # Regroup into 3-tuples, and parse JSON
277 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
281 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
278 return cur
282 return cur
279
283
280 @only_when_enabled
284 @only_when_enabled
281 @catch_corrupt_db
285 @catch_corrupt_db
282 def get_session_info(self, session):
286 def get_session_info(self, session):
283 """Get info about a session.
287 """Get info about a session.
284
288
285 Parameters
289 Parameters
286 ----------
290 ----------
287
291
288 session : int
292 session : int
289 Session number to retrieve.
293 Session number to retrieve.
290
294
291 Returns
295 Returns
292 -------
296 -------
293
297
294 session_id : int
298 session_id : int
295 Session ID number
299 Session ID number
296 start : datetime
300 start : datetime
297 Timestamp for the start of the session.
301 Timestamp for the start of the session.
298 end : datetime
302 end : datetime
299 Timestamp for the end of the session, or None if IPython crashed.
303 Timestamp for the end of the session, or None if IPython crashed.
300 num_cmds : int
304 num_cmds : int
301 Number of commands run, or None if IPython crashed.
305 Number of commands run, or None if IPython crashed.
302 remark : unicode
306 remark : unicode
303 A manually set description.
307 A manually set description.
304 """
308 """
305 query = "SELECT * from sessions where session == ?"
309 query = "SELECT * from sessions where session == ?"
306 return self.db.execute(query, (session,)).fetchone()
310 return self.db.execute(query, (session,)).fetchone()
307
311
308 @catch_corrupt_db
312 @catch_corrupt_db
309 def get_last_session_id(self):
313 def get_last_session_id(self):
310 """Get the last session ID currently in the database.
314 """Get the last session ID currently in the database.
311
315
312 Within IPython, this should be the same as the value stored in
316 Within IPython, this should be the same as the value stored in
313 :attr:`HistoryManager.session_number`.
317 :attr:`HistoryManager.session_number`.
314 """
318 """
315 for record in self.get_tail(n=1, include_latest=True):
319 for record in self.get_tail(n=1, include_latest=True):
316 return record[0]
320 return record[0]
317
321
318 @catch_corrupt_db
322 @catch_corrupt_db
319 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
323 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
320 """Get the last n lines from the history database.
324 """Get the last n lines from the history database.
321
325
322 Parameters
326 Parameters
323 ----------
327 ----------
324 n : int
328 n : int
325 The number of lines to get
329 The number of lines to get
326 raw, output : bool
330 raw, output : bool
327 See :meth:`get_range`
331 See :meth:`get_range`
328 include_latest : bool
332 include_latest : bool
329 If False (default), n+1 lines are fetched, and the latest one
333 If False (default), n+1 lines are fetched, and the latest one
330 is discarded. This is intended to be used where the function
334 is discarded. This is intended to be used where the function
331 is called by a user command, which it should not return.
335 is called by a user command, which it should not return.
332
336
333 Returns
337 Returns
334 -------
338 -------
335 Tuples as :meth:`get_range`
339 Tuples as :meth:`get_range`
336 """
340 """
337 self.writeout_cache()
341 self.writeout_cache()
338 if not include_latest:
342 if not include_latest:
339 n += 1
343 n += 1
340 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
344 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
341 (n,), raw=raw, output=output)
345 (n,), raw=raw, output=output)
342 if not include_latest:
346 if not include_latest:
343 return reversed(list(cur)[1:])
347 return reversed(list(cur)[1:])
344 return reversed(list(cur))
348 return reversed(list(cur))
345
349
346 @catch_corrupt_db
350 @catch_corrupt_db
347 def search(self, pattern="*", raw=True, search_raw=True,
351 def search(self, pattern="*", raw=True, search_raw=True,
348 output=False, n=None, unique=False):
352 output=False, n=None, unique=False):
349 """Search the database using unix glob-style matching (wildcards
353 """Search the database using unix glob-style matching (wildcards
350 * and ?).
354 * and ?).
351
355
352 Parameters
356 Parameters
353 ----------
357 ----------
354 pattern : str
358 pattern : str
355 The wildcarded pattern to match when searching
359 The wildcarded pattern to match when searching
356 search_raw : bool
360 search_raw : bool
357 If True, search the raw input, otherwise, the parsed input
361 If True, search the raw input, otherwise, the parsed input
358 raw, output : bool
362 raw, output : bool
359 See :meth:`get_range`
363 See :meth:`get_range`
360 n : None or int
364 n : None or int
361 If an integer is given, it defines the limit of
365 If an integer is given, it defines the limit of
362 returned entries.
366 returned entries.
363 unique : bool
367 unique : bool
364 When it is true, return only unique entries.
368 When it is true, return only unique entries.
365
369
366 Returns
370 Returns
367 -------
371 -------
368 Tuples as :meth:`get_range`
372 Tuples as :meth:`get_range`
369 """
373 """
370 tosearch = "source_raw" if search_raw else "source"
374 tosearch = "source_raw" if search_raw else "source"
371 if output:
375 if output:
372 tosearch = "history." + tosearch
376 tosearch = "history." + tosearch
373 self.writeout_cache()
377 self.writeout_cache()
374 sqlform = "WHERE %s GLOB ?" % tosearch
378 sqlform = "WHERE %s GLOB ?" % tosearch
375 params = (pattern,)
379 params = (pattern,)
376 if unique:
380 if unique:
377 sqlform += ' GROUP BY {0}'.format(tosearch)
381 sqlform += ' GROUP BY {0}'.format(tosearch)
378 if n is not None:
382 if n is not None:
379 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
383 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
380 params += (n,)
384 params += (n,)
381 elif unique:
385 elif unique:
382 sqlform += " ORDER BY session, line"
386 sqlform += " ORDER BY session, line"
383 cur = self._run_sql(sqlform, params, raw=raw, output=output)
387 cur = self._run_sql(sqlform, params, raw=raw, output=output)
384 if n is not None:
388 if n is not None:
385 return reversed(list(cur))
389 return reversed(list(cur))
386 return cur
390 return cur
387
391
388 @catch_corrupt_db
392 @catch_corrupt_db
389 def get_range(self, session, start=1, stop=None, raw=True,output=False):
393 def get_range(self, session, start=1, stop=None, raw=True,output=False):
390 """Retrieve input by session.
394 """Retrieve input by session.
391
395
392 Parameters
396 Parameters
393 ----------
397 ----------
394 session : int
398 session : int
395 Session number to retrieve.
399 Session number to retrieve.
396 start : int
400 start : int
397 First line to retrieve.
401 First line to retrieve.
398 stop : int
402 stop : int
399 End of line range (excluded from output itself). If None, retrieve
403 End of line range (excluded from output itself). If None, retrieve
400 to the end of the session.
404 to the end of the session.
401 raw : bool
405 raw : bool
402 If True, return untranslated input
406 If True, return untranslated input
403 output : bool
407 output : bool
404 If True, attempt to include output. This will be 'real' Python
408 If True, attempt to include output. This will be 'real' Python
405 objects for the current session, or text reprs from previous
409 objects for the current session, or text reprs from previous
406 sessions if db_log_output was enabled at the time. Where no output
410 sessions if db_log_output was enabled at the time. Where no output
407 is found, None is used.
411 is found, None is used.
408
412
409 Returns
413 Returns
410 -------
414 -------
411 entries
415 entries
412 An iterator over the desired lines. Each line is a 3-tuple, either
416 An iterator over the desired lines. Each line is a 3-tuple, either
413 (session, line, input) if output is False, or
417 (session, line, input) if output is False, or
414 (session, line, (input, output)) if output is True.
418 (session, line, (input, output)) if output is True.
415 """
419 """
416 if stop:
420 if stop:
417 lineclause = "line >= ? AND line < ?"
421 lineclause = "line >= ? AND line < ?"
418 params = (session, start, stop)
422 params = (session, start, stop)
419 else:
423 else:
420 lineclause = "line>=?"
424 lineclause = "line>=?"
421 params = (session, start)
425 params = (session, start)
422
426
423 return self._run_sql("WHERE session==? AND %s" % lineclause,
427 return self._run_sql("WHERE session==? AND %s" % lineclause,
424 params, raw=raw, output=output)
428 params, raw=raw, output=output)
425
429
426 def get_range_by_str(self, rangestr, raw=True, output=False):
430 def get_range_by_str(self, rangestr, raw=True, output=False):
427 """Get lines of history from a string of ranges, as used by magic
431 """Get lines of history from a string of ranges, as used by magic
428 commands %hist, %save, %macro, etc.
432 commands %hist, %save, %macro, etc.
429
433
430 Parameters
434 Parameters
431 ----------
435 ----------
432 rangestr : str
436 rangestr : str
433 A string specifying ranges, e.g. "5 ~2/1-4". See
437 A string specifying ranges, e.g. "5 ~2/1-4". See
434 :func:`magic_history` for full details.
438 :func:`magic_history` for full details.
435 raw, output : bool
439 raw, output : bool
436 As :meth:`get_range`
440 As :meth:`get_range`
437
441
438 Returns
442 Returns
439 -------
443 -------
440 Tuples as :meth:`get_range`
444 Tuples as :meth:`get_range`
441 """
445 """
442 for sess, s, e in extract_hist_ranges(rangestr):
446 for sess, s, e in extract_hist_ranges(rangestr):
443 for line in self.get_range(sess, s, e, raw=raw, output=output):
447 for line in self.get_range(sess, s, e, raw=raw, output=output):
444 yield line
448 yield line
445
449
446
450
447 class HistoryManager(HistoryAccessor):
451 class HistoryManager(HistoryAccessor):
448 """A class to organize all history-related functionality in one place.
452 """A class to organize all history-related functionality in one place.
449 """
453 """
450 # Public interface
454 # Public interface
451
455
452 # An instance of the IPython shell we are attached to
456 # An instance of the IPython shell we are attached to
453 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
457 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
454 allow_none=True)
458 allow_none=True)
455 # Lists to hold processed and raw history. These start with a blank entry
459 # Lists to hold processed and raw history. These start with a blank entry
456 # so that we can index them starting from 1
460 # so that we can index them starting from 1
457 input_hist_parsed = List([""])
461 input_hist_parsed = List([""])
458 input_hist_raw = List([""])
462 input_hist_raw = List([""])
459 # A list of directories visited during session
463 # A list of directories visited during session
460 dir_hist = List()
464 dir_hist = List()
461 @default('dir_hist')
465 @default('dir_hist')
462 def _dir_hist_default(self):
466 def _dir_hist_default(self):
463 try:
467 try:
464 return [os.getcwd()]
468 return [Path.cwd()]
465 except OSError:
469 except OSError:
466 return []
470 return []
467
471
468 # A dict of output history, keyed with ints from the shell's
472 # A dict of output history, keyed with ints from the shell's
469 # execution count.
473 # execution count.
470 output_hist = Dict()
474 output_hist = Dict()
471 # The text/plain repr of outputs.
475 # The text/plain repr of outputs.
472 output_hist_reprs = Dict()
476 output_hist_reprs = Dict()
473
477
474 # The number of the current session in the history database
478 # The number of the current session in the history database
475 session_number = Integer()
479 session_number = Integer()
476
480
477 db_log_output = Bool(False,
481 db_log_output = Bool(False,
478 help="Should the history database include output? (default: no)"
482 help="Should the history database include output? (default: no)"
479 ).tag(config=True)
483 ).tag(config=True)
480 db_cache_size = Integer(0,
484 db_cache_size = Integer(0,
481 help="Write to database every x commands (higher values save disk access & power).\n"
485 help="Write to database every x commands (higher values save disk access & power).\n"
482 "Values of 1 or less effectively disable caching."
486 "Values of 1 or less effectively disable caching."
483 ).tag(config=True)
487 ).tag(config=True)
484 # The input and output caches
488 # The input and output caches
485 db_input_cache = List()
489 db_input_cache = List()
486 db_output_cache = List()
490 db_output_cache = List()
487
491
488 # History saving in separate thread
492 # History saving in separate thread
489 save_thread = Instance('IPython.core.history.HistorySavingThread',
493 save_thread = Instance('IPython.core.history.HistorySavingThread',
490 allow_none=True)
494 allow_none=True)
491 save_flag = Instance(threading.Event, allow_none=True)
495 save_flag = Instance(threading.Event, allow_none=True)
492
496
493 # Private interface
497 # Private interface
494 # Variables used to store the three last inputs from the user. On each new
498 # Variables used to store the three last inputs from the user. On each new
495 # history update, we populate the user's namespace with these, shifted as
499 # history update, we populate the user's namespace with these, shifted as
496 # necessary.
500 # necessary.
497 _i00 = Unicode(u'')
501 _i00 = Unicode(u'')
498 _i = Unicode(u'')
502 _i = Unicode(u'')
499 _ii = Unicode(u'')
503 _ii = Unicode(u'')
500 _iii = Unicode(u'')
504 _iii = Unicode(u'')
501
505
502 # A regex matching all forms of the exit command, so that we don't store
506 # A regex matching all forms of the exit command, so that we don't store
503 # them in the history (it's annoying to rewind the first entry and land on
507 # them in the history (it's annoying to rewind the first entry and land on
504 # an exit call).
508 # an exit call).
505 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
509 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
506
510
507 def __init__(self, shell=None, config=None, **traits):
511 def __init__(self, shell=None, config=None, **traits):
508 """Create a new history manager associated with a shell instance.
512 """Create a new history manager associated with a shell instance.
509 """
513 """
510 # We need a pointer back to the shell for various tasks.
514 # We need a pointer back to the shell for various tasks.
511 super(HistoryManager, self).__init__(shell=shell, config=config,
515 super(HistoryManager, self).__init__(shell=shell, config=config,
512 **traits)
516 **traits)
513 self.save_flag = threading.Event()
517 self.save_flag = threading.Event()
514 self.db_input_cache_lock = threading.Lock()
518 self.db_input_cache_lock = threading.Lock()
515 self.db_output_cache_lock = threading.Lock()
519 self.db_output_cache_lock = threading.Lock()
516
520
517 try:
521 try:
518 self.new_session()
522 self.new_session()
519 except sqlite3.OperationalError:
523 except sqlite3.OperationalError:
520 self.log.error("Failed to create history session in %s. History will not be saved.",
524 self.log.error("Failed to create history session in %s. History will not be saved.",
521 self.hist_file, exc_info=True)
525 self.hist_file, exc_info=True)
522 self.hist_file = ':memory:'
526 self.hist_file = ':memory:'
523
527
524 if self.enabled and self.hist_file != ':memory:':
528 if self.enabled and self.hist_file != ':memory:':
525 self.save_thread = HistorySavingThread(self)
529 self.save_thread = HistorySavingThread(self)
526 self.save_thread.start()
530 self.save_thread.start()
527
531
528 def _get_hist_file_name(self, profile=None):
532 def _get_hist_file_name(self, profile=None):
529 """Get default history file name based on the Shell's profile.
533 """Get default history file name based on the Shell's profile.
530
534
531 The profile parameter is ignored, but must exist for compatibility with
535 The profile parameter is ignored, but must exist for compatibility with
532 the parent class."""
536 the parent class."""
533 profile_dir = self.shell.profile_dir.location
537 profile_dir = self.shell.profile_dir.location
534 return os.path.join(profile_dir, 'history.sqlite')
538 return Path(profile_dir)/'history.sqlite'
535
539
536 @only_when_enabled
540 @only_when_enabled
537 def new_session(self, conn=None):
541 def new_session(self, conn=None):
538 """Get a new session number."""
542 """Get a new session number."""
539 if conn is None:
543 if conn is None:
540 conn = self.db
544 conn = self.db
541
545
542 with conn:
546 with conn:
543 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
547 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
544 NULL, "") """, (datetime.datetime.now(),))
548 NULL, "") """, (datetime.datetime.now(),))
545 self.session_number = cur.lastrowid
549 self.session_number = cur.lastrowid
546
550
547 def end_session(self):
551 def end_session(self):
548 """Close the database session, filling in the end time and line count."""
552 """Close the database session, filling in the end time and line count."""
549 self.writeout_cache()
553 self.writeout_cache()
550 with self.db:
554 with self.db:
551 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
555 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
552 session==?""", (datetime.datetime.now(),
556 session==?""", (datetime.datetime.now(),
553 len(self.input_hist_parsed)-1, self.session_number))
557 len(self.input_hist_parsed)-1, self.session_number))
554 self.session_number = 0
558 self.session_number = 0
555
559
556 def name_session(self, name):
560 def name_session(self, name):
557 """Give the current session a name in the history database."""
561 """Give the current session a name in the history database."""
558 with self.db:
562 with self.db:
559 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
563 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
560 (name, self.session_number))
564 (name, self.session_number))
561
565
562 def reset(self, new_session=True):
566 def reset(self, new_session=True):
563 """Clear the session history, releasing all object references, and
567 """Clear the session history, releasing all object references, and
564 optionally open a new session."""
568 optionally open a new session."""
565 self.output_hist.clear()
569 self.output_hist.clear()
566 # The directory history can't be completely empty
570 # The directory history can't be completely empty
567 self.dir_hist[:] = [os.getcwd()]
571 self.dir_hist[:] = [Path.cwd()]
568
572
569 if new_session:
573 if new_session:
570 if self.session_number:
574 if self.session_number:
571 self.end_session()
575 self.end_session()
572 self.input_hist_parsed[:] = [""]
576 self.input_hist_parsed[:] = [""]
573 self.input_hist_raw[:] = [""]
577 self.input_hist_raw[:] = [""]
574 self.new_session()
578 self.new_session()
575
579
576 # ------------------------------
580 # ------------------------------
577 # Methods for retrieving history
581 # Methods for retrieving history
578 # ------------------------------
582 # ------------------------------
579 def get_session_info(self, session=0):
583 def get_session_info(self, session=0):
580 """Get info about a session.
584 """Get info about a session.
581
585
582 Parameters
586 Parameters
583 ----------
587 ----------
584
588
585 session : int
589 session : int
586 Session number to retrieve. The current session is 0, and negative
590 Session number to retrieve. The current session is 0, and negative
587 numbers count back from current session, so -1 is the previous session.
591 numbers count back from current session, so -1 is the previous session.
588
592
589 Returns
593 Returns
590 -------
594 -------
591
595
592 session_id : int
596 session_id : int
593 Session ID number
597 Session ID number
594 start : datetime
598 start : datetime
595 Timestamp for the start of the session.
599 Timestamp for the start of the session.
596 end : datetime
600 end : datetime
597 Timestamp for the end of the session, or None if IPython crashed.
601 Timestamp for the end of the session, or None if IPython crashed.
598 num_cmds : int
602 num_cmds : int
599 Number of commands run, or None if IPython crashed.
603 Number of commands run, or None if IPython crashed.
600 remark : unicode
604 remark : unicode
601 A manually set description.
605 A manually set description.
602 """
606 """
603 if session <= 0:
607 if session <= 0:
604 session += self.session_number
608 session += self.session_number
605
609
606 return super(HistoryManager, self).get_session_info(session=session)
610 return super(HistoryManager, self).get_session_info(session=session)
607
611
608 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
612 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
609 """Get input and output history from the current session. Called by
613 """Get input and output history from the current session. Called by
610 get_range, and takes similar parameters."""
614 get_range, and takes similar parameters."""
611 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
615 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
612
616
613 n = len(input_hist)
617 n = len(input_hist)
614 if start < 0:
618 if start < 0:
615 start += n
619 start += n
616 if not stop or (stop > n):
620 if not stop or (stop > n):
617 stop = n
621 stop = n
618 elif stop < 0:
622 elif stop < 0:
619 stop += n
623 stop += n
620
624
621 for i in range(start, stop):
625 for i in range(start, stop):
622 if output:
626 if output:
623 line = (input_hist[i], self.output_hist_reprs.get(i))
627 line = (input_hist[i], self.output_hist_reprs.get(i))
624 else:
628 else:
625 line = input_hist[i]
629 line = input_hist[i]
626 yield (0, i, line)
630 yield (0, i, line)
627
631
628 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
632 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
629 """Retrieve input by session.
633 """Retrieve input by session.
630
634
631 Parameters
635 Parameters
632 ----------
636 ----------
633 session : int
637 session : int
634 Session number to retrieve. The current session is 0, and negative
638 Session number to retrieve. The current session is 0, and negative
635 numbers count back from current session, so -1 is previous session.
639 numbers count back from current session, so -1 is previous session.
636 start : int
640 start : int
637 First line to retrieve.
641 First line to retrieve.
638 stop : int
642 stop : int
639 End of line range (excluded from output itself). If None, retrieve
643 End of line range (excluded from output itself). If None, retrieve
640 to the end of the session.
644 to the end of the session.
641 raw : bool
645 raw : bool
642 If True, return untranslated input
646 If True, return untranslated input
643 output : bool
647 output : bool
644 If True, attempt to include output. This will be 'real' Python
648 If True, attempt to include output. This will be 'real' Python
645 objects for the current session, or text reprs from previous
649 objects for the current session, or text reprs from previous
646 sessions if db_log_output was enabled at the time. Where no output
650 sessions if db_log_output was enabled at the time. Where no output
647 is found, None is used.
651 is found, None is used.
648
652
649 Returns
653 Returns
650 -------
654 -------
651 entries
655 entries
652 An iterator over the desired lines. Each line is a 3-tuple, either
656 An iterator over the desired lines. Each line is a 3-tuple, either
653 (session, line, input) if output is False, or
657 (session, line, input) if output is False, or
654 (session, line, (input, output)) if output is True.
658 (session, line, (input, output)) if output is True.
655 """
659 """
656 if session <= 0:
660 if session <= 0:
657 session += self.session_number
661 session += self.session_number
658 if session==self.session_number: # Current session
662 if session==self.session_number: # Current session
659 return self._get_range_session(start, stop, raw, output)
663 return self._get_range_session(start, stop, raw, output)
660 return super(HistoryManager, self).get_range(session, start, stop, raw,
664 return super(HistoryManager, self).get_range(session, start, stop, raw,
661 output)
665 output)
662
666
663 ## ----------------------------
667 ## ----------------------------
664 ## Methods for storing history:
668 ## Methods for storing history:
665 ## ----------------------------
669 ## ----------------------------
666 def store_inputs(self, line_num, source, source_raw=None):
670 def store_inputs(self, line_num, source, source_raw=None):
667 """Store source and raw input in history and create input cache
671 """Store source and raw input in history and create input cache
668 variables ``_i*``.
672 variables ``_i*``.
669
673
670 Parameters
674 Parameters
671 ----------
675 ----------
672 line_num : int
676 line_num : int
673 The prompt number of this input.
677 The prompt number of this input.
674
678
675 source : str
679 source : str
676 Python input.
680 Python input.
677
681
678 source_raw : str, optional
682 source_raw : str, optional
679 If given, this is the raw input without any IPython transformations
683 If given, this is the raw input without any IPython transformations
680 applied to it. If not given, ``source`` is used.
684 applied to it. If not given, ``source`` is used.
681 """
685 """
682 if source_raw is None:
686 if source_raw is None:
683 source_raw = source
687 source_raw = source
684 source = source.rstrip('\n')
688 source = source.rstrip('\n')
685 source_raw = source_raw.rstrip('\n')
689 source_raw = source_raw.rstrip('\n')
686
690
687 # do not store exit/quit commands
691 # do not store exit/quit commands
688 if self._exit_re.match(source_raw.strip()):
692 if self._exit_re.match(source_raw.strip()):
689 return
693 return
690
694
691 self.input_hist_parsed.append(source)
695 self.input_hist_parsed.append(source)
692 self.input_hist_raw.append(source_raw)
696 self.input_hist_raw.append(source_raw)
693
697
694 with self.db_input_cache_lock:
698 with self.db_input_cache_lock:
695 self.db_input_cache.append((line_num, source, source_raw))
699 self.db_input_cache.append((line_num, source, source_raw))
696 # Trigger to flush cache and write to DB.
700 # Trigger to flush cache and write to DB.
697 if len(self.db_input_cache) >= self.db_cache_size:
701 if len(self.db_input_cache) >= self.db_cache_size:
698 self.save_flag.set()
702 self.save_flag.set()
699
703
700 # update the auto _i variables
704 # update the auto _i variables
701 self._iii = self._ii
705 self._iii = self._ii
702 self._ii = self._i
706 self._ii = self._i
703 self._i = self._i00
707 self._i = self._i00
704 self._i00 = source_raw
708 self._i00 = source_raw
705
709
706 # hackish access to user namespace to create _i1,_i2... dynamically
710 # hackish access to user namespace to create _i1,_i2... dynamically
707 new_i = '_i%s' % line_num
711 new_i = '_i%s' % line_num
708 to_main = {'_i': self._i,
712 to_main = {'_i': self._i,
709 '_ii': self._ii,
713 '_ii': self._ii,
710 '_iii': self._iii,
714 '_iii': self._iii,
711 new_i : self._i00 }
715 new_i : self._i00 }
712
716
713 if self.shell is not None:
717 if self.shell is not None:
714 self.shell.push(to_main, interactive=False)
718 self.shell.push(to_main, interactive=False)
715
719
716 def store_output(self, line_num):
720 def store_output(self, line_num):
717 """If database output logging is enabled, this saves all the
721 """If database output logging is enabled, this saves all the
718 outputs from the indicated prompt number to the database. It's
722 outputs from the indicated prompt number to the database. It's
719 called by run_cell after code has been executed.
723 called by run_cell after code has been executed.
720
724
721 Parameters
725 Parameters
722 ----------
726 ----------
723 line_num : int
727 line_num : int
724 The line number from which to save outputs
728 The line number from which to save outputs
725 """
729 """
726 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
730 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
727 return
731 return
728 output = self.output_hist_reprs[line_num]
732 output = self.output_hist_reprs[line_num]
729
733
730 with self.db_output_cache_lock:
734 with self.db_output_cache_lock:
731 self.db_output_cache.append((line_num, output))
735 self.db_output_cache.append((line_num, output))
732 if self.db_cache_size <= 1:
736 if self.db_cache_size <= 1:
733 self.save_flag.set()
737 self.save_flag.set()
734
738
735 def _writeout_input_cache(self, conn):
739 def _writeout_input_cache(self, conn):
736 with conn:
740 with conn:
737 for line in self.db_input_cache:
741 for line in self.db_input_cache:
738 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
742 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
739 (self.session_number,)+line)
743 (self.session_number,)+line)
740
744
741 def _writeout_output_cache(self, conn):
745 def _writeout_output_cache(self, conn):
742 with conn:
746 with conn:
743 for line in self.db_output_cache:
747 for line in self.db_output_cache:
744 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
748 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
745 (self.session_number,)+line)
749 (self.session_number,)+line)
746
750
747 @only_when_enabled
751 @only_when_enabled
748 def writeout_cache(self, conn=None):
752 def writeout_cache(self, conn=None):
749 """Write any entries in the cache to the database."""
753 """Write any entries in the cache to the database."""
750 if conn is None:
754 if conn is None:
751 conn = self.db
755 conn = self.db
752
756
753 with self.db_input_cache_lock:
757 with self.db_input_cache_lock:
754 try:
758 try:
755 self._writeout_input_cache(conn)
759 self._writeout_input_cache(conn)
756 except sqlite3.IntegrityError:
760 except sqlite3.IntegrityError:
757 self.new_session(conn)
761 self.new_session(conn)
758 print("ERROR! Session/line number was not unique in",
762 print("ERROR! Session/line number was not unique in",
759 "database. History logging moved to new session",
763 "database. History logging moved to new session",
760 self.session_number)
764 self.session_number)
761 try:
765 try:
762 # Try writing to the new session. If this fails, don't
766 # Try writing to the new session. If this fails, don't
763 # recurse
767 # recurse
764 self._writeout_input_cache(conn)
768 self._writeout_input_cache(conn)
765 except sqlite3.IntegrityError:
769 except sqlite3.IntegrityError:
766 pass
770 pass
767 finally:
771 finally:
768 self.db_input_cache = []
772 self.db_input_cache = []
769
773
770 with self.db_output_cache_lock:
774 with self.db_output_cache_lock:
771 try:
775 try:
772 self._writeout_output_cache(conn)
776 self._writeout_output_cache(conn)
773 except sqlite3.IntegrityError:
777 except sqlite3.IntegrityError:
774 print("!! Session/line number for output was not unique",
778 print("!! Session/line number for output was not unique",
775 "in database. Output will not be stored.")
779 "in database. Output will not be stored.")
776 finally:
780 finally:
777 self.db_output_cache = []
781 self.db_output_cache = []
778
782
779
783
780 class HistorySavingThread(threading.Thread):
784 class HistorySavingThread(threading.Thread):
781 """This thread takes care of writing history to the database, so that
785 """This thread takes care of writing history to the database, so that
782 the UI isn't held up while that happens.
786 the UI isn't held up while that happens.
783
787
784 It waits for the HistoryManager's save_flag to be set, then writes out
788 It waits for the HistoryManager's save_flag to be set, then writes out
785 the history cache. The main thread is responsible for setting the flag when
789 the history cache. The main thread is responsible for setting the flag when
786 the cache size reaches a defined threshold."""
790 the cache size reaches a defined threshold."""
787 daemon = True
791 daemon = True
788 stop_now = False
792 stop_now = False
789 enabled = True
793 enabled = True
790 def __init__(self, history_manager):
794 def __init__(self, history_manager):
791 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
795 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
792 self.history_manager = history_manager
796 self.history_manager = history_manager
793 self.enabled = history_manager.enabled
797 self.enabled = history_manager.enabled
794 atexit.register(self.stop)
798 atexit.register(self.stop)
795
799
796 @only_when_enabled
800 @only_when_enabled
797 def run(self):
801 def run(self):
798 # We need a separate db connection per thread:
802 # We need a separate db connection per thread:
799 try:
803 try:
800 self.db = sqlite3.connect(self.history_manager.hist_file,
804 self.db = sqlite3.connect(self.history_manager.hist_file,
801 **self.history_manager.connection_options
805 **self.history_manager.connection_options
802 )
806 )
803 while True:
807 while True:
804 self.history_manager.save_flag.wait()
808 self.history_manager.save_flag.wait()
805 if self.stop_now:
809 if self.stop_now:
806 self.db.close()
810 self.db.close()
807 return
811 return
808 self.history_manager.save_flag.clear()
812 self.history_manager.save_flag.clear()
809 self.history_manager.writeout_cache(self.db)
813 self.history_manager.writeout_cache(self.db)
810 except Exception as e:
814 except Exception as e:
811 print(("The history saving thread hit an unexpected error (%s)."
815 print(("The history saving thread hit an unexpected error (%s)."
812 "History will not be written to the database.") % repr(e))
816 "History will not be written to the database.") % repr(e))
813
817
814 def stop(self):
818 def stop(self):
815 """This can be called from the main thread to safely stop this thread.
819 """This can be called from the main thread to safely stop this thread.
816
820
817 Note that it does not attempt to write out remaining history before
821 Note that it does not attempt to write out remaining history before
818 exiting. That should be done by calling the HistoryManager's
822 exiting. That should be done by calling the HistoryManager's
819 end_session method."""
823 end_session method."""
820 self.stop_now = True
824 self.stop_now = True
821 self.history_manager.save_flag.set()
825 self.history_manager.save_flag.set()
822 self.join()
826 self.join()
823
827
824
828
825 # To match, e.g. ~5/8-~2/3
829 # To match, e.g. ~5/8-~2/3
826 range_re = re.compile(r"""
830 range_re = re.compile(r"""
827 ((?P<startsess>~?\d+)/)?
831 ((?P<startsess>~?\d+)/)?
828 (?P<start>\d+)?
832 (?P<start>\d+)?
829 ((?P<sep>[\-:])
833 ((?P<sep>[\-:])
830 ((?P<endsess>~?\d+)/)?
834 ((?P<endsess>~?\d+)/)?
831 (?P<end>\d+))?
835 (?P<end>\d+))?
832 $""", re.VERBOSE)
836 $""", re.VERBOSE)
833
837
834
838
835 def extract_hist_ranges(ranges_str):
839 def extract_hist_ranges(ranges_str):
836 """Turn a string of history ranges into 3-tuples of (session, start, stop).
840 """Turn a string of history ranges into 3-tuples of (session, start, stop).
837
841
838 Examples
842 Examples
839 --------
843 --------
840 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
844 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
841 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
845 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
842 """
846 """
843 for range_str in ranges_str.split():
847 for range_str in ranges_str.split():
844 rmatch = range_re.match(range_str)
848 rmatch = range_re.match(range_str)
845 if not rmatch:
849 if not rmatch:
846 continue
850 continue
847 start = rmatch.group("start")
851 start = rmatch.group("start")
848 if start:
852 if start:
849 start = int(start)
853 start = int(start)
850 end = rmatch.group("end")
854 end = rmatch.group("end")
851 # If no end specified, get (a, a + 1)
855 # If no end specified, get (a, a + 1)
852 end = int(end) if end else start + 1
856 end = int(end) if end else start + 1
853 else: # start not specified
857 else: # start not specified
854 if not rmatch.group('startsess'): # no startsess
858 if not rmatch.group('startsess'): # no startsess
855 continue
859 continue
856 start = 1
860 start = 1
857 end = None # provide the entire session hist
861 end = None # provide the entire session hist
858
862
859 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
863 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
860 end += 1
864 end += 1
861 startsess = rmatch.group("startsess") or "0"
865 startsess = rmatch.group("startsess") or "0"
862 endsess = rmatch.group("endsess") or startsess
866 endsess = rmatch.group("endsess") or startsess
863 startsess = int(startsess.replace("~","-"))
867 startsess = int(startsess.replace("~","-"))
864 endsess = int(endsess.replace("~","-"))
868 endsess = int(endsess.replace("~","-"))
865 assert endsess >= startsess, "start session must be earlier than end session"
869 assert endsess >= startsess, "start session must be earlier than end session"
866
870
867 if endsess == startsess:
871 if endsess == startsess:
868 yield (startsess, start, end)
872 yield (startsess, start, end)
869 continue
873 continue
870 # Multiple sessions in one range:
874 # Multiple sessions in one range:
871 yield (startsess, start, None)
875 yield (startsess, start, None)
872 for sess in range(startsess+1, endsess):
876 for sess in range(startsess+1, endsess):
873 yield (sess, 1, None)
877 yield (sess, 1, None)
874 yield (endsess, 1, end)
878 yield (endsess, 1, end)
875
879
876
880
877 def _format_lineno(session, line):
881 def _format_lineno(session, line):
878 """Helper function to format line numbers properly."""
882 """Helper function to format line numbers properly."""
879 if session == 0:
883 if session == 0:
880 return str(line)
884 return str(line)
881 return "%s#%s" % (session, line)
885 return "%s#%s" % (session, line)
@@ -1,214 +1,215 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Tests for the IPython tab-completion machinery.
2 """Tests for the IPython tab-completion machinery.
3 """
3 """
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Module imports
5 # Module imports
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7
7
8 # stdlib
8 # stdlib
9 import io
9 import io
10 import os
10 from pathlib import Path
11 import sys
11 import sys
12 import tempfile
12 import tempfile
13 from datetime import datetime
13 from datetime import datetime
14 import sqlite3
14 import sqlite3
15
15
16 # third party
16 # third party
17 import nose.tools as nt
17 import nose.tools as nt
18
18
19 # our own packages
19 # our own packages
20 from traitlets.config.loader import Config
20 from traitlets.config.loader import Config
21 from IPython.utils.tempdir import TemporaryDirectory
21 from IPython.utils.tempdir import TemporaryDirectory
22 from IPython.core.history import HistoryManager, extract_hist_ranges
22 from IPython.core.history import HistoryManager, extract_hist_ranges
23 from IPython.testing.decorators import skipif
23 from IPython.testing.decorators import skipif
24
24
25 def test_proper_default_encoding():
25 def test_proper_default_encoding():
26 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
26 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
27
27
28 @skipif(sqlite3.sqlite_version_info > (3,24,0))
28 @skipif(sqlite3.sqlite_version_info > (3,24,0))
29 def test_history():
29 def test_history():
30 ip = get_ipython()
30 ip = get_ipython()
31 with TemporaryDirectory() as tmpdir:
31 with TemporaryDirectory() as tmpdir:
32 tmp_path = Path(tmpdir)
32 hist_manager_ori = ip.history_manager
33 hist_manager_ori = ip.history_manager
33 hist_file = os.path.join(tmpdir, 'history.sqlite')
34 hist_file = tmp_path / 'history.sqlite'
34 try:
35 try:
35 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
36 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
36 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
37 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
37 for i, h in enumerate(hist, start=1):
38 for i, h in enumerate(hist, start=1):
38 ip.history_manager.store_inputs(i, h)
39 ip.history_manager.store_inputs(i, h)
39
40
40 ip.history_manager.db_log_output = True
41 ip.history_manager.db_log_output = True
41 # Doesn't match the input, but we'll just check it's stored.
42 # Doesn't match the input, but we'll just check it's stored.
42 ip.history_manager.output_hist_reprs[3] = "spam"
43 ip.history_manager.output_hist_reprs[3] = "spam"
43 ip.history_manager.store_output(3)
44 ip.history_manager.store_output(3)
44
45
45 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
46 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
46
47
47 # Detailed tests for _get_range_session
48 # Detailed tests for _get_range_session
48 grs = ip.history_manager._get_range_session
49 grs = ip.history_manager._get_range_session
49 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
50 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
50 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
51 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
51 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
52 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
52
53
53 # Check whether specifying a range beyond the end of the current
54 # Check whether specifying a range beyond the end of the current
54 # session results in an error (gh-804)
55 # session results in an error (gh-804)
55 ip.magic('%hist 2-500')
56 ip.magic('%hist 2-500')
56
57
57 # Check that we can write non-ascii characters to a file
58 # Check that we can write non-ascii characters to a file
58 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
59 ip.magic("%%hist -f %s" % (tmp_path / "test1"))
59 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
60 ip.magic("%%hist -pf %s" % (tmp_path / "test2"))
60 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
61 ip.magic("%%hist -nf %s" % (tmp_path / "test3"))
61 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
62 ip.magic("%%save %s 1-10" % (tmp_path / "test4"))
62
63
63 # New session
64 # New session
64 ip.history_manager.reset()
65 ip.history_manager.reset()
65 newcmds = [u"z=5",
66 newcmds = [u"z=5",
66 u"class X(object):\n pass",
67 u"class X(object):\n pass",
67 u"k='p'",
68 u"k='p'",
68 u"z=5"]
69 u"z=5"]
69 for i, cmd in enumerate(newcmds, start=1):
70 for i, cmd in enumerate(newcmds, start=1):
70 ip.history_manager.store_inputs(i, cmd)
71 ip.history_manager.store_inputs(i, cmd)
71 gothist = ip.history_manager.get_range(start=1, stop=4)
72 gothist = ip.history_manager.get_range(start=1, stop=4)
72 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
73 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
73 # Previous session:
74 # Previous session:
74 gothist = ip.history_manager.get_range(-1, 1, 4)
75 gothist = ip.history_manager.get_range(-1, 1, 4)
75 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
76 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
76
77
77 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
78 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
78
79
79 # Check get_hist_tail
80 # Check get_hist_tail
80 gothist = ip.history_manager.get_tail(5, output=True,
81 gothist = ip.history_manager.get_tail(5, output=True,
81 include_latest=True)
82 include_latest=True)
82 expected = [(1, 3, (hist[-1], "spam"))] \
83 expected = [(1, 3, (hist[-1], "spam"))] \
83 + [(s, n, (c, None)) for (s, n, c) in newhist]
84 + [(s, n, (c, None)) for (s, n, c) in newhist]
84 nt.assert_equal(list(gothist), expected)
85 nt.assert_equal(list(gothist), expected)
85
86
86 gothist = ip.history_manager.get_tail(2)
87 gothist = ip.history_manager.get_tail(2)
87 expected = newhist[-3:-1]
88 expected = newhist[-3:-1]
88 nt.assert_equal(list(gothist), expected)
89 nt.assert_equal(list(gothist), expected)
89
90
90 # Check get_hist_search
91 # Check get_hist_search
91
92
92 gothist = ip.history_manager.search("*test*")
93 gothist = ip.history_manager.search("*test*")
93 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
94 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
94
95
95 gothist = ip.history_manager.search("*=*")
96 gothist = ip.history_manager.search("*=*")
96 nt.assert_equal(list(gothist),
97 nt.assert_equal(list(gothist),
97 [(1, 1, hist[0]),
98 [(1, 1, hist[0]),
98 (1, 2, hist[1]),
99 (1, 2, hist[1]),
99 (1, 3, hist[2]),
100 (1, 3, hist[2]),
100 newhist[0],
101 newhist[0],
101 newhist[2],
102 newhist[2],
102 newhist[3]])
103 newhist[3]])
103
104
104 gothist = ip.history_manager.search("*=*", n=4)
105 gothist = ip.history_manager.search("*=*", n=4)
105 nt.assert_equal(list(gothist),
106 nt.assert_equal(list(gothist),
106 [(1, 3, hist[2]),
107 [(1, 3, hist[2]),
107 newhist[0],
108 newhist[0],
108 newhist[2],
109 newhist[2],
109 newhist[3]])
110 newhist[3]])
110
111
111 gothist = ip.history_manager.search("*=*", unique=True)
112 gothist = ip.history_manager.search("*=*", unique=True)
112 nt.assert_equal(list(gothist),
113 nt.assert_equal(list(gothist),
113 [(1, 1, hist[0]),
114 [(1, 1, hist[0]),
114 (1, 2, hist[1]),
115 (1, 2, hist[1]),
115 (1, 3, hist[2]),
116 (1, 3, hist[2]),
116 newhist[2],
117 newhist[2],
117 newhist[3]])
118 newhist[3]])
118
119
119 gothist = ip.history_manager.search("*=*", unique=True, n=3)
120 gothist = ip.history_manager.search("*=*", unique=True, n=3)
120 nt.assert_equal(list(gothist),
121 nt.assert_equal(list(gothist),
121 [(1, 3, hist[2]),
122 [(1, 3, hist[2]),
122 newhist[2],
123 newhist[2],
123 newhist[3]])
124 newhist[3]])
124
125
125 gothist = ip.history_manager.search("b*", output=True)
126 gothist = ip.history_manager.search("b*", output=True)
126 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
127 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
127
128
128 # Cross testing: check that magic %save can get previous session.
129 # Cross testing: check that magic %save can get previous session.
129 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
130 testfilename = (tmp_path /"test.py").resolve()
130 ip.magic("save " + testfilename + " ~1/1-3")
131 ip.magic("save " + str(testfilename) + " ~1/1-3")
131 with io.open(testfilename, encoding='utf-8') as testfile:
132 with io.open(testfilename, encoding='utf-8') as testfile:
132 nt.assert_equal(testfile.read(),
133 nt.assert_equal(testfile.read(),
133 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
134 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
134
135
135 # Duplicate line numbers - check that it doesn't crash, and
136 # Duplicate line numbers - check that it doesn't crash, and
136 # gets a new session
137 # gets a new session
137 ip.history_manager.store_inputs(1, "rogue")
138 ip.history_manager.store_inputs(1, "rogue")
138 ip.history_manager.writeout_cache()
139 ip.history_manager.writeout_cache()
139 nt.assert_equal(ip.history_manager.session_number, 3)
140 nt.assert_equal(ip.history_manager.session_number, 3)
140 finally:
141 finally:
141 # Ensure saving thread is shut down before we try to clean up the files
142 # Ensure saving thread is shut down before we try to clean up the files
142 ip.history_manager.save_thread.stop()
143 ip.history_manager.save_thread.stop()
143 # Forcibly close database rather than relying on garbage collection
144 # Forcibly close database rather than relying on garbage collection
144 ip.history_manager.db.close()
145 ip.history_manager.db.close()
145 # Restore history manager
146 # Restore history manager
146 ip.history_manager = hist_manager_ori
147 ip.history_manager = hist_manager_ori
147
148
148
149
149 def test_extract_hist_ranges():
150 def test_extract_hist_ranges():
150 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
151 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
151 expected = [(0, 1, 2), # 0 == current session
152 expected = [(0, 1, 2), # 0 == current session
152 (2, 3, 4),
153 (2, 3, 4),
153 (-4, 5, 7),
154 (-4, 5, 7),
154 (-4, 7, 10),
155 (-4, 7, 10),
155 (-9, 2, None), # None == to end
156 (-9, 2, None), # None == to end
156 (-8, 1, None),
157 (-8, 1, None),
157 (-7, 1, 6),
158 (-7, 1, 6),
158 (-10, 1, None)]
159 (-10, 1, None)]
159 actual = list(extract_hist_ranges(instr))
160 actual = list(extract_hist_ranges(instr))
160 nt.assert_equal(actual, expected)
161 nt.assert_equal(actual, expected)
161
162
162 def test_magic_rerun():
163 def test_magic_rerun():
163 """Simple test for %rerun (no args -> rerun last line)"""
164 """Simple test for %rerun (no args -> rerun last line)"""
164 ip = get_ipython()
165 ip = get_ipython()
165 ip.run_cell("a = 10", store_history=True)
166 ip.run_cell("a = 10", store_history=True)
166 ip.run_cell("a += 1", store_history=True)
167 ip.run_cell("a += 1", store_history=True)
167 nt.assert_equal(ip.user_ns["a"], 11)
168 nt.assert_equal(ip.user_ns["a"], 11)
168 ip.run_cell("%rerun", store_history=True)
169 ip.run_cell("%rerun", store_history=True)
169 nt.assert_equal(ip.user_ns["a"], 12)
170 nt.assert_equal(ip.user_ns["a"], 12)
170
171
171 def test_timestamp_type():
172 def test_timestamp_type():
172 ip = get_ipython()
173 ip = get_ipython()
173 info = ip.history_manager.get_session_info()
174 info = ip.history_manager.get_session_info()
174 nt.assert_true(isinstance(info[1], datetime))
175 nt.assert_true(isinstance(info[1], datetime))
175
176
176 def test_hist_file_config():
177 def test_hist_file_config():
177 cfg = Config()
178 cfg = Config()
178 tfile = tempfile.NamedTemporaryFile(delete=False)
179 tfile = tempfile.NamedTemporaryFile(delete=False)
179 cfg.HistoryManager.hist_file = tfile.name
180 cfg.HistoryManager.hist_file = Path(tfile.name)
180 try:
181 try:
181 hm = HistoryManager(shell=get_ipython(), config=cfg)
182 hm = HistoryManager(shell=get_ipython(), config=cfg)
182 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
183 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
183 finally:
184 finally:
184 try:
185 try:
185 os.remove(tfile.name)
186 Path(tfile.name).unlink()
186 except OSError:
187 except OSError:
187 # same catch as in testing.tools.TempFileMixin
188 # same catch as in testing.tools.TempFileMixin
188 # On Windows, even though we close the file, we still can't
189 # On Windows, even though we close the file, we still can't
189 # delete it. I have no clue why
190 # delete it. I have no clue why
190 pass
191 pass
191
192
192 def test_histmanager_disabled():
193 def test_histmanager_disabled():
193 """Ensure that disabling the history manager doesn't create a database."""
194 """Ensure that disabling the history manager doesn't create a database."""
194 cfg = Config()
195 cfg = Config()
195 cfg.HistoryAccessor.enabled = False
196 cfg.HistoryAccessor.enabled = False
196
197
197 ip = get_ipython()
198 ip = get_ipython()
198 with TemporaryDirectory() as tmpdir:
199 with TemporaryDirectory() as tmpdir:
199 hist_manager_ori = ip.history_manager
200 hist_manager_ori = ip.history_manager
200 hist_file = os.path.join(tmpdir, 'history.sqlite')
201 hist_file = Path(tmpdir) / 'history.sqlite'
201 cfg.HistoryManager.hist_file = hist_file
202 cfg.HistoryManager.hist_file = hist_file
202 try:
203 try:
203 ip.history_manager = HistoryManager(shell=ip, config=cfg)
204 ip.history_manager = HistoryManager(shell=ip, config=cfg)
204 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
205 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
205 for i, h in enumerate(hist, start=1):
206 for i, h in enumerate(hist, start=1):
206 ip.history_manager.store_inputs(i, h)
207 ip.history_manager.store_inputs(i, h)
207 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
208 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
208 ip.history_manager.reset()
209 ip.history_manager.reset()
209 ip.history_manager.end_session()
210 ip.history_manager.end_session()
210 finally:
211 finally:
211 ip.history_manager = hist_manager_ori
212 ip.history_manager = hist_manager_ori
212
213
213 # hist_file should not be created
214 # hist_file should not be created
214 nt.assert_false(os.path.exists(hist_file))
215 nt.assert_false(hist_file.exists())
@@ -1,471 +1,472 b''
1 """Generic testing tools.
1 """Generic testing tools.
2
2
3 Authors
3 Authors
4 -------
4 -------
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 """
6 """
7
7
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 import os
12 import os
13 from pathlib import Path
13 import re
14 import re
14 import sys
15 import sys
15 import tempfile
16 import tempfile
16 import unittest
17 import unittest
17
18
18 from contextlib import contextmanager
19 from contextlib import contextmanager
19 from io import StringIO
20 from io import StringIO
20 from subprocess import Popen, PIPE
21 from subprocess import Popen, PIPE
21 from unittest.mock import patch
22 from unittest.mock import patch
22
23
23 try:
24 try:
24 # These tools are used by parts of the runtime, so we make the nose
25 # These tools are used by parts of the runtime, so we make the nose
25 # dependency optional at this point. Nose is a hard dependency to run the
26 # dependency optional at this point. Nose is a hard dependency to run the
26 # test suite, but NOT to use ipython itself.
27 # test suite, but NOT to use ipython itself.
27 import nose.tools as nt
28 import nose.tools as nt
28 has_nose = True
29 has_nose = True
29 except ImportError:
30 except ImportError:
30 has_nose = False
31 has_nose = False
31
32
32 from traitlets.config.loader import Config
33 from traitlets.config.loader import Config
33 from IPython.utils.process import get_output_error_code
34 from IPython.utils.process import get_output_error_code
34 from IPython.utils.text import list_strings
35 from IPython.utils.text import list_strings
35 from IPython.utils.io import temp_pyfile, Tee
36 from IPython.utils.io import temp_pyfile, Tee
36 from IPython.utils import py3compat
37 from IPython.utils import py3compat
37
38
38 from . import decorators as dec
39 from . import decorators as dec
39 from . import skipdoctest
40 from . import skipdoctest
40
41
41
42
42 # The docstring for full_path doctests differently on win32 (different path
43 # The docstring for full_path doctests differently on win32 (different path
43 # separator) so just skip the doctest there. The example remains informative.
44 # separator) so just skip the doctest there. The example remains informative.
44 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45
46
46 @doctest_deco
47 @doctest_deco
47 def full_path(startPath,files):
48 def full_path(startPath,files):
48 """Make full paths for all the listed files, based on startPath.
49 """Make full paths for all the listed files, based on startPath.
49
50
50 Only the base part of startPath is kept, since this routine is typically
51 Only the base part of startPath is kept, since this routine is typically
51 used with a script's ``__file__`` variable as startPath. The base of startPath
52 used with a script's ``__file__`` variable as startPath. The base of startPath
52 is then prepended to all the listed files, forming the output list.
53 is then prepended to all the listed files, forming the output list.
53
54
54 Parameters
55 Parameters
55 ----------
56 ----------
56 startPath : string
57 startPath : string
57 Initial path to use as the base for the results. This path is split
58 Initial path to use as the base for the results. This path is split
58 using os.path.split() and only its first component is kept.
59 using os.path.split() and only its first component is kept.
59
60
60 files : string or list
61 files : string or list
61 One or more files.
62 One or more files.
62
63
63 Examples
64 Examples
64 --------
65 --------
65
66
66 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 ['/foo/a.txt', '/foo/b.txt']
68 ['/foo/a.txt', '/foo/b.txt']
68
69
69 >>> full_path('/foo',['a.txt','b.txt'])
70 >>> full_path('/foo',['a.txt','b.txt'])
70 ['/a.txt', '/b.txt']
71 ['/a.txt', '/b.txt']
71
72
72 If a single file is given, the output is still a list::
73 If a single file is given, the output is still a list::
73
74
74 >>> full_path('/foo','a.txt')
75 >>> full_path('/foo','a.txt')
75 ['/a.txt']
76 ['/a.txt']
76 """
77 """
77
78
78 files = list_strings(files)
79 files = list_strings(files)
79 base = os.path.split(startPath)[0]
80 base = os.path.split(startPath)[0]
80 return [ os.path.join(base,f) for f in files ]
81 return [ os.path.join(base,f) for f in files ]
81
82
82
83
83 def parse_test_output(txt):
84 def parse_test_output(txt):
84 """Parse the output of a test run and return errors, failures.
85 """Parse the output of a test run and return errors, failures.
85
86
86 Parameters
87 Parameters
87 ----------
88 ----------
88 txt : str
89 txt : str
89 Text output of a test run, assumed to contain a line of one of the
90 Text output of a test run, assumed to contain a line of one of the
90 following forms::
91 following forms::
91
92
92 'FAILED (errors=1)'
93 'FAILED (errors=1)'
93 'FAILED (failures=1)'
94 'FAILED (failures=1)'
94 'FAILED (errors=1, failures=1)'
95 'FAILED (errors=1, failures=1)'
95
96
96 Returns
97 Returns
97 -------
98 -------
98 nerr, nfail
99 nerr, nfail
99 number of errors and failures.
100 number of errors and failures.
100 """
101 """
101
102
102 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 if err_m:
104 if err_m:
104 nerr = int(err_m.group(1))
105 nerr = int(err_m.group(1))
105 nfail = 0
106 nfail = 0
106 return nerr, nfail
107 return nerr, nfail
107
108
108 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 if fail_m:
110 if fail_m:
110 nerr = 0
111 nerr = 0
111 nfail = int(fail_m.group(1))
112 nfail = int(fail_m.group(1))
112 return nerr, nfail
113 return nerr, nfail
113
114
114 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 re.MULTILINE)
116 re.MULTILINE)
116 if both_m:
117 if both_m:
117 nerr = int(both_m.group(1))
118 nerr = int(both_m.group(1))
118 nfail = int(both_m.group(2))
119 nfail = int(both_m.group(2))
119 return nerr, nfail
120 return nerr, nfail
120
121
121 # If the input didn't match any of these forms, assume no error/failures
122 # If the input didn't match any of these forms, assume no error/failures
122 return 0, 0
123 return 0, 0
123
124
124
125
125 # So nose doesn't think this is a test
126 # So nose doesn't think this is a test
126 parse_test_output.__test__ = False
127 parse_test_output.__test__ = False
127
128
128
129
129 def default_argv():
130 def default_argv():
130 """Return a valid default argv for creating testing instances of ipython"""
131 """Return a valid default argv for creating testing instances of ipython"""
131
132
132 return ['--quick', # so no config file is loaded
133 return ['--quick', # so no config file is loaded
133 # Other defaults to minimize side effects on stdout
134 # Other defaults to minimize side effects on stdout
134 '--colors=NoColor', '--no-term-title','--no-banner',
135 '--colors=NoColor', '--no-term-title','--no-banner',
135 '--autocall=0']
136 '--autocall=0']
136
137
137
138
138 def default_config():
139 def default_config():
139 """Return a config object with good defaults for testing."""
140 """Return a config object with good defaults for testing."""
140 config = Config()
141 config = Config()
141 config.TerminalInteractiveShell.colors = 'NoColor'
142 config.TerminalInteractiveShell.colors = 'NoColor'
142 config.TerminalTerminalInteractiveShell.term_title = False,
143 config.TerminalTerminalInteractiveShell.term_title = False,
143 config.TerminalInteractiveShell.autocall = 0
144 config.TerminalInteractiveShell.autocall = 0
144 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 config.HistoryManager.hist_file = f.name
146 config.HistoryManager.hist_file = Path(f.name)
146 f.close()
147 f.close()
147 config.HistoryManager.db_cache_size = 10000
148 config.HistoryManager.db_cache_size = 10000
148 return config
149 return config
149
150
150
151
151 def get_ipython_cmd(as_string=False):
152 def get_ipython_cmd(as_string=False):
152 """
153 """
153 Return appropriate IPython command line name. By default, this will return
154 Return appropriate IPython command line name. By default, this will return
154 a list that can be used with subprocess.Popen, for example, but passing
155 a list that can be used with subprocess.Popen, for example, but passing
155 `as_string=True` allows for returning the IPython command as a string.
156 `as_string=True` allows for returning the IPython command as a string.
156
157
157 Parameters
158 Parameters
158 ----------
159 ----------
159 as_string: bool
160 as_string: bool
160 Flag to allow to return the command as a string.
161 Flag to allow to return the command as a string.
161 """
162 """
162 ipython_cmd = [sys.executable, "-m", "IPython"]
163 ipython_cmd = [sys.executable, "-m", "IPython"]
163
164
164 if as_string:
165 if as_string:
165 ipython_cmd = " ".join(ipython_cmd)
166 ipython_cmd = " ".join(ipython_cmd)
166
167
167 return ipython_cmd
168 return ipython_cmd
168
169
169 def ipexec(fname, options=None, commands=()):
170 def ipexec(fname, options=None, commands=()):
170 """Utility to call 'ipython filename'.
171 """Utility to call 'ipython filename'.
171
172
172 Starts IPython with a minimal and safe configuration to make startup as fast
173 Starts IPython with a minimal and safe configuration to make startup as fast
173 as possible.
174 as possible.
174
175
175 Note that this starts IPython in a subprocess!
176 Note that this starts IPython in a subprocess!
176
177
177 Parameters
178 Parameters
178 ----------
179 ----------
179 fname : str
180 fname : str
180 Name of file to be executed (should have .py or .ipy extension).
181 Name of file to be executed (should have .py or .ipy extension).
181
182
182 options : optional, list
183 options : optional, list
183 Extra command-line flags to be passed to IPython.
184 Extra command-line flags to be passed to IPython.
184
185
185 commands : optional, list
186 commands : optional, list
186 Commands to send in on stdin
187 Commands to send in on stdin
187
188
188 Returns
189 Returns
189 -------
190 -------
190 ``(stdout, stderr)`` of ipython subprocess.
191 ``(stdout, stderr)`` of ipython subprocess.
191 """
192 """
192 if options is None: options = []
193 if options is None: options = []
193
194
194 cmdargs = default_argv() + options
195 cmdargs = default_argv() + options
195
196
196 test_dir = os.path.dirname(__file__)
197 test_dir = os.path.dirname(__file__)
197
198
198 ipython_cmd = get_ipython_cmd()
199 ipython_cmd = get_ipython_cmd()
199 # Absolute path for filename
200 # Absolute path for filename
200 full_fname = os.path.join(test_dir, fname)
201 full_fname = os.path.join(test_dir, fname)
201 full_cmd = ipython_cmd + cmdargs + ['--', full_fname]
202 full_cmd = ipython_cmd + cmdargs + ['--', full_fname]
202 env = os.environ.copy()
203 env = os.environ.copy()
203 # FIXME: ignore all warnings in ipexec while we have shims
204 # FIXME: ignore all warnings in ipexec while we have shims
204 # should we keep suppressing warnings here, even after removing shims?
205 # should we keep suppressing warnings here, even after removing shims?
205 env['PYTHONWARNINGS'] = 'ignore'
206 env['PYTHONWARNINGS'] = 'ignore'
206 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 for k, v in env.items():
208 for k, v in env.items():
208 # Debug a bizarre failure we've seen on Windows:
209 # Debug a bizarre failure we've seen on Windows:
209 # TypeError: environment can only contain strings
210 # TypeError: environment can only contain strings
210 if not isinstance(v, str):
211 if not isinstance(v, str):
211 print(k, v)
212 print(k, v)
212 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 out, err = p.communicate(input=py3compat.encode('\n'.join(commands)) or None)
214 out, err = p.communicate(input=py3compat.encode('\n'.join(commands)) or None)
214 out, err = py3compat.decode(out), py3compat.decode(err)
215 out, err = py3compat.decode(out), py3compat.decode(err)
215 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 # so strip that out before doing comparisons
217 # so strip that out before doing comparisons
217 if out:
218 if out:
218 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 return out, err
220 return out, err
220
221
221
222
222 def ipexec_validate(fname, expected_out, expected_err='',
223 def ipexec_validate(fname, expected_out, expected_err='',
223 options=None, commands=()):
224 options=None, commands=()):
224 """Utility to call 'ipython filename' and validate output/error.
225 """Utility to call 'ipython filename' and validate output/error.
225
226
226 This function raises an AssertionError if the validation fails.
227 This function raises an AssertionError if the validation fails.
227
228
228 Note that this starts IPython in a subprocess!
229 Note that this starts IPython in a subprocess!
229
230
230 Parameters
231 Parameters
231 ----------
232 ----------
232 fname : str
233 fname : str
233 Name of the file to be executed (should have .py or .ipy extension).
234 Name of the file to be executed (should have .py or .ipy extension).
234
235
235 expected_out : str
236 expected_out : str
236 Expected stdout of the process.
237 Expected stdout of the process.
237
238
238 expected_err : optional, str
239 expected_err : optional, str
239 Expected stderr of the process.
240 Expected stderr of the process.
240
241
241 options : optional, list
242 options : optional, list
242 Extra command-line flags to be passed to IPython.
243 Extra command-line flags to be passed to IPython.
243
244
244 Returns
245 Returns
245 -------
246 -------
246 None
247 None
247 """
248 """
248
249
249 import nose.tools as nt
250 import nose.tools as nt
250
251
251 out, err = ipexec(fname, options, commands)
252 out, err = ipexec(fname, options, commands)
252 #print 'OUT', out # dbg
253 #print 'OUT', out # dbg
253 #print 'ERR', err # dbg
254 #print 'ERR', err # dbg
254 # If there are any errors, we must check those before stdout, as they may be
255 # If there are any errors, we must check those before stdout, as they may be
255 # more informative than simply having an empty stdout.
256 # more informative than simply having an empty stdout.
256 if err:
257 if err:
257 if expected_err:
258 if expected_err:
258 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 else:
260 else:
260 raise ValueError('Running file %r produced error: %r' %
261 raise ValueError('Running file %r produced error: %r' %
261 (fname, err))
262 (fname, err))
262 # If no errors or output on stderr was expected, match stdout
263 # If no errors or output on stderr was expected, match stdout
263 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264
265
265
266
266 class TempFileMixin(unittest.TestCase):
267 class TempFileMixin(unittest.TestCase):
267 """Utility class to create temporary Python/IPython files.
268 """Utility class to create temporary Python/IPython files.
268
269
269 Meant as a mixin class for test cases."""
270 Meant as a mixin class for test cases."""
270
271
271 def mktmp(self, src, ext='.py'):
272 def mktmp(self, src, ext='.py'):
272 """Make a valid python temp file."""
273 """Make a valid python temp file."""
273 fname = temp_pyfile(src, ext)
274 fname = temp_pyfile(src, ext)
274 if not hasattr(self, 'tmps'):
275 if not hasattr(self, 'tmps'):
275 self.tmps=[]
276 self.tmps=[]
276 self.tmps.append(fname)
277 self.tmps.append(fname)
277 self.fname = fname
278 self.fname = fname
278
279
279 def tearDown(self):
280 def tearDown(self):
280 # If the tmpfile wasn't made because of skipped tests, like in
281 # If the tmpfile wasn't made because of skipped tests, like in
281 # win32, there's nothing to cleanup.
282 # win32, there's nothing to cleanup.
282 if hasattr(self, 'tmps'):
283 if hasattr(self, 'tmps'):
283 for fname in self.tmps:
284 for fname in self.tmps:
284 # If the tmpfile wasn't made because of skipped tests, like in
285 # If the tmpfile wasn't made because of skipped tests, like in
285 # win32, there's nothing to cleanup.
286 # win32, there's nothing to cleanup.
286 try:
287 try:
287 os.unlink(fname)
288 os.unlink(fname)
288 except:
289 except:
289 # On Windows, even though we close the file, we still can't
290 # On Windows, even though we close the file, we still can't
290 # delete it. I have no clue why
291 # delete it. I have no clue why
291 pass
292 pass
292
293
293 def __enter__(self):
294 def __enter__(self):
294 return self
295 return self
295
296
296 def __exit__(self, exc_type, exc_value, traceback):
297 def __exit__(self, exc_type, exc_value, traceback):
297 self.tearDown()
298 self.tearDown()
298
299
299
300
300 pair_fail_msg = ("Testing {0}\n\n"
301 pair_fail_msg = ("Testing {0}\n\n"
301 "In:\n"
302 "In:\n"
302 " {1!r}\n"
303 " {1!r}\n"
303 "Expected:\n"
304 "Expected:\n"
304 " {2!r}\n"
305 " {2!r}\n"
305 "Got:\n"
306 "Got:\n"
306 " {3!r}\n")
307 " {3!r}\n")
307 def check_pairs(func, pairs):
308 def check_pairs(func, pairs):
308 """Utility function for the common case of checking a function with a
309 """Utility function for the common case of checking a function with a
309 sequence of input/output pairs.
310 sequence of input/output pairs.
310
311
311 Parameters
312 Parameters
312 ----------
313 ----------
313 func : callable
314 func : callable
314 The function to be tested. Should accept a single argument.
315 The function to be tested. Should accept a single argument.
315 pairs : iterable
316 pairs : iterable
316 A list of (input, expected_output) tuples.
317 A list of (input, expected_output) tuples.
317
318
318 Returns
319 Returns
319 -------
320 -------
320 None. Raises an AssertionError if any output does not match the expected
321 None. Raises an AssertionError if any output does not match the expected
321 value.
322 value.
322 """
323 """
323 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
324 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
324 for inp, expected in pairs:
325 for inp, expected in pairs:
325 out = func(inp)
326 out = func(inp)
326 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
327 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
327
328
328
329
329 MyStringIO = StringIO
330 MyStringIO = StringIO
330
331
331 _re_type = type(re.compile(r''))
332 _re_type = type(re.compile(r''))
332
333
333 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 -------
335 -------
335 {2!s}
336 {2!s}
336 -------
337 -------
337 """
338 """
338
339
339 class AssertPrints(object):
340 class AssertPrints(object):
340 """Context manager for testing that code prints certain text.
341 """Context manager for testing that code prints certain text.
341
342
342 Examples
343 Examples
343 --------
344 --------
344 >>> with AssertPrints("abc", suppress=False):
345 >>> with AssertPrints("abc", suppress=False):
345 ... print("abcd")
346 ... print("abcd")
346 ... print("def")
347 ... print("def")
347 ...
348 ...
348 abcd
349 abcd
349 def
350 def
350 """
351 """
351 def __init__(self, s, channel='stdout', suppress=True):
352 def __init__(self, s, channel='stdout', suppress=True):
352 self.s = s
353 self.s = s
353 if isinstance(self.s, (str, _re_type)):
354 if isinstance(self.s, (str, _re_type)):
354 self.s = [self.s]
355 self.s = [self.s]
355 self.channel = channel
356 self.channel = channel
356 self.suppress = suppress
357 self.suppress = suppress
357
358
358 def __enter__(self):
359 def __enter__(self):
359 self.orig_stream = getattr(sys, self.channel)
360 self.orig_stream = getattr(sys, self.channel)
360 self.buffer = MyStringIO()
361 self.buffer = MyStringIO()
361 self.tee = Tee(self.buffer, channel=self.channel)
362 self.tee = Tee(self.buffer, channel=self.channel)
362 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363
364
364 def __exit__(self, etype, value, traceback):
365 def __exit__(self, etype, value, traceback):
365 try:
366 try:
366 if value is not None:
367 if value is not None:
367 # If an error was raised, don't check anything else
368 # If an error was raised, don't check anything else
368 return False
369 return False
369 self.tee.flush()
370 self.tee.flush()
370 setattr(sys, self.channel, self.orig_stream)
371 setattr(sys, self.channel, self.orig_stream)
371 printed = self.buffer.getvalue()
372 printed = self.buffer.getvalue()
372 for s in self.s:
373 for s in self.s:
373 if isinstance(s, _re_type):
374 if isinstance(s, _re_type):
374 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
375 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
375 else:
376 else:
376 assert s in printed, notprinted_msg.format(s, self.channel, printed)
377 assert s in printed, notprinted_msg.format(s, self.channel, printed)
377 return False
378 return False
378 finally:
379 finally:
379 self.tee.close()
380 self.tee.close()
380
381
381 printed_msg = """Found {0!r} in printed output (on {1}):
382 printed_msg = """Found {0!r} in printed output (on {1}):
382 -------
383 -------
383 {2!s}
384 {2!s}
384 -------
385 -------
385 """
386 """
386
387
387 class AssertNotPrints(AssertPrints):
388 class AssertNotPrints(AssertPrints):
388 """Context manager for checking that certain output *isn't* produced.
389 """Context manager for checking that certain output *isn't* produced.
389
390
390 Counterpart of AssertPrints"""
391 Counterpart of AssertPrints"""
391 def __exit__(self, etype, value, traceback):
392 def __exit__(self, etype, value, traceback):
392 try:
393 try:
393 if value is not None:
394 if value is not None:
394 # If an error was raised, don't check anything else
395 # If an error was raised, don't check anything else
395 self.tee.close()
396 self.tee.close()
396 return False
397 return False
397 self.tee.flush()
398 self.tee.flush()
398 setattr(sys, self.channel, self.orig_stream)
399 setattr(sys, self.channel, self.orig_stream)
399 printed = self.buffer.getvalue()
400 printed = self.buffer.getvalue()
400 for s in self.s:
401 for s in self.s:
401 if isinstance(s, _re_type):
402 if isinstance(s, _re_type):
402 assert not s.search(printed),printed_msg.format(
403 assert not s.search(printed),printed_msg.format(
403 s.pattern, self.channel, printed)
404 s.pattern, self.channel, printed)
404 else:
405 else:
405 assert s not in printed, printed_msg.format(
406 assert s not in printed, printed_msg.format(
406 s, self.channel, printed)
407 s, self.channel, printed)
407 return False
408 return False
408 finally:
409 finally:
409 self.tee.close()
410 self.tee.close()
410
411
411 @contextmanager
412 @contextmanager
412 def mute_warn():
413 def mute_warn():
413 from IPython.utils import warn
414 from IPython.utils import warn
414 save_warn = warn.warn
415 save_warn = warn.warn
415 warn.warn = lambda *a, **kw: None
416 warn.warn = lambda *a, **kw: None
416 try:
417 try:
417 yield
418 yield
418 finally:
419 finally:
419 warn.warn = save_warn
420 warn.warn = save_warn
420
421
421 @contextmanager
422 @contextmanager
422 def make_tempfile(name):
423 def make_tempfile(name):
423 """ Create an empty, named, temporary file for the duration of the context.
424 """ Create an empty, named, temporary file for the duration of the context.
424 """
425 """
425 open(name, 'w').close()
426 open(name, 'w').close()
426 try:
427 try:
427 yield
428 yield
428 finally:
429 finally:
429 os.unlink(name)
430 os.unlink(name)
430
431
431 def fake_input(inputs):
432 def fake_input(inputs):
432 """Temporarily replace the input() function to return the given values
433 """Temporarily replace the input() function to return the given values
433
434
434 Use as a context manager:
435 Use as a context manager:
435
436
436 with fake_input(['result1', 'result2']):
437 with fake_input(['result1', 'result2']):
437 ...
438 ...
438
439
439 Values are returned in order. If input() is called again after the last value
440 Values are returned in order. If input() is called again after the last value
440 was used, EOFError is raised.
441 was used, EOFError is raised.
441 """
442 """
442 it = iter(inputs)
443 it = iter(inputs)
443 def mock_input(prompt=''):
444 def mock_input(prompt=''):
444 try:
445 try:
445 return next(it)
446 return next(it)
446 except StopIteration as e:
447 except StopIteration as e:
447 raise EOFError('No more inputs given') from e
448 raise EOFError('No more inputs given') from e
448
449
449 return patch('builtins.input', mock_input)
450 return patch('builtins.input', mock_input)
450
451
451 def help_output_test(subcommand=''):
452 def help_output_test(subcommand=''):
452 """test that `ipython [subcommand] -h` works"""
453 """test that `ipython [subcommand] -h` works"""
453 cmd = get_ipython_cmd() + [subcommand, '-h']
454 cmd = get_ipython_cmd() + [subcommand, '-h']
454 out, err, rc = get_output_error_code(cmd)
455 out, err, rc = get_output_error_code(cmd)
455 nt.assert_equal(rc, 0, err)
456 nt.assert_equal(rc, 0, err)
456 nt.assert_not_in("Traceback", err)
457 nt.assert_not_in("Traceback", err)
457 nt.assert_in("Options", out)
458 nt.assert_in("Options", out)
458 nt.assert_in("--help-all", out)
459 nt.assert_in("--help-all", out)
459 return out, err
460 return out, err
460
461
461
462
462 def help_all_output_test(subcommand=''):
463 def help_all_output_test(subcommand=''):
463 """test that `ipython [subcommand] --help-all` works"""
464 """test that `ipython [subcommand] --help-all` works"""
464 cmd = get_ipython_cmd() + [subcommand, '--help-all']
465 cmd = get_ipython_cmd() + [subcommand, '--help-all']
465 out, err, rc = get_output_error_code(cmd)
466 out, err, rc = get_output_error_code(cmd)
466 nt.assert_equal(rc, 0, err)
467 nt.assert_equal(rc, 0, err)
467 nt.assert_not_in("Traceback", err)
468 nt.assert_not_in("Traceback", err)
468 nt.assert_in("Options", out)
469 nt.assert_in("Options", out)
469 nt.assert_in("Class", out)
470 nt.assert_in("Class", out)
470 return out, err
471 return out, err
471
472
General Comments 0
You need to be logged in to leave comments. Login now