##// END OF EJS Templates
Update history.py to use pathlib...
Jakub Klus -
Show More
@@ -1,881 +1,885 b''
1 1 """ History related magics and functionality """
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6
7 7 import atexit
8 8 import datetime
9 import os
9 from pathlib import Path
10 10 import re
11 11 import sqlite3
12 12 import threading
13 13
14 14 from traitlets.config.configurable import LoggingConfigurable
15 15 from decorator import decorator
16 16 from IPython.utils.decorators import undoc
17 17 from IPython.paths import locate_profile
18 18 from traitlets import (
19 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
20 default, observe,
19 Any, Bool, Dict, Instance, Integer, List, Unicode, Union, TraitError,
20 default, observe
21 21 )
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Classes and functions
25 25 #-----------------------------------------------------------------------------
26 26
27 27 @undoc
28 28 class DummyDB(object):
29 29 """Dummy DB that will act as a black hole for history.
30
30
31 31 Only used in the absence of sqlite"""
32 32 def execute(*args, **kwargs):
33 33 return []
34
34
35 35 def commit(self, *args, **kwargs):
36 36 pass
37
37
38 38 def __enter__(self, *args, **kwargs):
39 39 pass
40
40
41 41 def __exit__(self, *args, **kwargs):
42 42 pass
43 43
44 44
45 45 @decorator
46 46 def only_when_enabled(f, self, *a, **kw):
47 47 """Decorator: return an empty list in the absence of sqlite."""
48 48 if not self.enabled:
49 49 return []
50 50 else:
51 51 return f(self, *a, **kw)
52 52
53 53
54 54 # use 16kB as threshold for whether a corrupt history db should be saved
55 55 # that should be at least 100 entries or so
56 56 _SAVE_DB_SIZE = 16384
57 57
58 58 @decorator
59 59 def catch_corrupt_db(f, self, *a, **kw):
60 60 """A decorator which wraps HistoryAccessor method calls to catch errors from
61 61 a corrupt SQLite database, move the old database out of the way, and create
62 62 a new one.
63 63
64 64 We avoid clobbering larger databases because this may be triggered due to filesystem issues,
65 65 not just a corrupt file.
66 66 """
67 67 try:
68 68 return f(self, *a, **kw)
69 69 except (sqlite3.DatabaseError, sqlite3.OperationalError) as e:
70 70 self._corrupt_db_counter += 1
71 71 self.log.error("Failed to open SQLite history %s (%s).", self.hist_file, e)
72 72 if self.hist_file != ':memory:':
73 73 if self._corrupt_db_counter > self._corrupt_db_limit:
74 74 self.hist_file = ':memory:'
75 75 self.log.error("Failed to load history too many times, history will not be saved.")
76 elif os.path.isfile(self.hist_file):
76 elif self.hist_file.is_file():
77 77 # move the file out of the way
78 base, ext = os.path.splitext(self.hist_file)
79 size = os.stat(self.hist_file).st_size
78 base = str(self.hist_file.parent / self.hist_file.stem)
79 ext = self.hist_file.suffix
80 size = self.hist_file.stat().st_size
80 81 if size >= _SAVE_DB_SIZE:
81 82 # if there's significant content, avoid clobbering
82 83 now = datetime.datetime.now().isoformat().replace(':', '.')
83 84 newpath = base + '-corrupt-' + now + ext
84 85 # don't clobber previous corrupt backups
85 86 for i in range(100):
86 if not os.path.isfile(newpath):
87 if not Path(newpath).exists():
87 88 break
88 89 else:
89 90 newpath = base + '-corrupt-' + now + (u'-%i' % i) + ext
90 91 else:
91 92 # not much content, possibly empty; don't worry about clobbering
92 93 # maybe we should just delete it?
93 94 newpath = base + '-corrupt' + ext
94 os.rename(self.hist_file, newpath)
95 self.hist_file.rename(newpath)
95 96 self.log.error("History file was moved to %s and a new file created.", newpath)
96 97 self.init_db()
97 98 return []
98 99 else:
99 100 # Failed with :memory:, something serious is wrong
100 101 raise
101
102
102 103 class HistoryAccessorBase(LoggingConfigurable):
103 104 """An abstract class for History Accessors """
104 105
105 106 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
106 107 raise NotImplementedError
107 108
108 109 def search(self, pattern="*", raw=True, search_raw=True,
109 110 output=False, n=None, unique=False):
110 111 raise NotImplementedError
111 112
112 113 def get_range(self, session, start=1, stop=None, raw=True,output=False):
113 114 raise NotImplementedError
114 115
115 116 def get_range_by_str(self, rangestr, raw=True, output=False):
116 117 raise NotImplementedError
117 118
118 119
119 120 class HistoryAccessor(HistoryAccessorBase):
120 121 """Access the history database without adding to it.
121
122
122 123 This is intended for use by standalone history tools. IPython shells use
123 124 HistoryManager, below, which is a subclass of this."""
124 125
125 126 # counter for init_db retries, so we don't keep trying over and over
126 127 _corrupt_db_counter = 0
127 128 # after two failures, fallback on :memory:
128 129 _corrupt_db_limit = 2
129 130
130 131 # String holding the path to the history file
131 hist_file = Unicode(
132 hist_file = Union([Instance(Path), Unicode()],
132 133 help="""Path to file to use for SQLite history database.
133
134
134 135 By default, IPython will put the history database in the IPython
135 136 profile directory. If you would rather share one history among
136 137 profiles, you can set this value in each, so that they are consistent.
137
138
138 139 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
139 140 mounts. If you see IPython hanging, try setting this to something on a
140 141 local disk, e.g::
141
142
142 143 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
143 144
144 145 you can also use the specific value `:memory:` (including the colon
145 146 at both end but not the back ticks), to avoid creating an history file.
146
147 """).tag(config=True)
148
147
148 """
149 ).tag(config=True)
150
149 151 enabled = Bool(True,
150 152 help="""enable the SQLite history
151
153
152 154 set enabled=False to disable the SQLite history,
153 155 in which case there will be no stored history, no SQLite connection,
154 156 and no background saving thread. This may be necessary in some
155 157 threaded environments where IPython is embedded.
156 158 """
157 159 ).tag(config=True)
158
160
159 161 connection_options = Dict(
160 162 help="""Options for configuring the SQLite connection
161
163
162 164 These options are passed as keyword args to sqlite3.connect
163 165 when establishing database connections.
164 166 """
165 167 ).tag(config=True)
166 168
167 169 # The SQLite database
168 170 db = Any()
169 171 @observe('db')
170 172 def _db_changed(self, change):
171 173 """validate the db, since it can be an Instance of two different types"""
172 174 new = change['new']
173 175 connection_types = (DummyDB, sqlite3.Connection)
174 176 if not isinstance(new, connection_types):
175 177 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
176 178 (self.__class__.__name__, new)
177 179 raise TraitError(msg)
178
179 def __init__(self, profile='default', hist_file=u'', **traits):
180
181 def __init__(self, profile='default', hist_file="", **traits):
180 182 """Create a new history accessor.
181
183
182 184 Parameters
183 185 ----------
184 186 profile : str
185 187 The name of the profile from which to open history.
186 188 hist_file : str
187 189 Path to an SQLite history database stored by IPython. If specified,
188 190 hist_file overrides profile.
189 191 config : :class:`~traitlets.config.loader.Config`
190 192 Config object. hist_file can also be set through this.
191 193 """
192 194 # We need a pointer back to the shell for various tasks.
193 195 super(HistoryAccessor, self).__init__(**traits)
194 196 # defer setting hist_file from kwarg until after init,
195 197 # otherwise the default kwarg value would clobber any value
196 198 # set by config
197 199 if hist_file:
198 200 self.hist_file = hist_file
199
200 if self.hist_file == u'':
201
202 try:
203 self.hist_file
204 except TraitError:
201 205 # No one has set the hist_file, yet.
202 206 self.hist_file = self._get_hist_file_name(profile)
203
207
204 208 self.init_db()
205
209
206 210 def _get_hist_file_name(self, profile='default'):
207 211 """Find the history file for the given profile name.
208
212
209 213 This is overridden by the HistoryManager subclass, to use the shell's
210 214 active profile.
211
215
212 216 Parameters
213 217 ----------
214 218 profile : str
215 219 The name of a profile which has a history file.
216 220 """
217 return os.path.join(locate_profile(profile), 'history.sqlite')
218
221 return Path(locate_profile(profile)) / 'history.sqlite'
222
219 223 @catch_corrupt_db
220 224 def init_db(self):
221 225 """Connect to the database, and create tables if necessary."""
222 226 if not self.enabled:
223 227 self.db = DummyDB()
224 228 return
225
229
226 230 # use detect_types so that timestamps return datetime objects
227 231 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
228 232 kwargs.update(self.connection_options)
229 233 self.db = sqlite3.connect(self.hist_file, **kwargs)
230 234 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
231 235 primary key autoincrement, start timestamp,
232 236 end timestamp, num_cmds integer, remark text)""")
233 237 self.db.execute("""CREATE TABLE IF NOT EXISTS history
234 238 (session integer, line integer, source text, source_raw text,
235 239 PRIMARY KEY (session, line))""")
236 240 # Output history is optional, but ensure the table's there so it can be
237 241 # enabled later.
238 242 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
239 243 (session integer, line integer, output text,
240 244 PRIMARY KEY (session, line))""")
241 245 self.db.commit()
242 246 # success! reset corrupt db count
243 247 self._corrupt_db_counter = 0
244 248
245 249 def writeout_cache(self):
246 250 """Overridden by HistoryManager to dump the cache before certain
247 251 database lookups."""
248 252 pass
249 253
250 254 ## -------------------------------
251 255 ## Methods for retrieving history:
252 256 ## -------------------------------
253 257 def _run_sql(self, sql, params, raw=True, output=False):
254 258 """Prepares and runs an SQL query for the history database.
255 259
256 260 Parameters
257 261 ----------
258 262 sql : str
259 263 Any filtering expressions to go after SELECT ... FROM ...
260 264 params : tuple
261 265 Parameters passed to the SQL query (to replace "?")
262 266 raw, output : bool
263 267 See :meth:`get_range`
264 268
265 269 Returns
266 270 -------
267 271 Tuples as :meth:`get_range`
268 272 """
269 273 toget = 'source_raw' if raw else 'source'
270 274 sqlfrom = "history"
271 275 if output:
272 276 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
273 277 toget = "history.%s, output_history.output" % toget
274 278 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
275 279 (toget, sqlfrom) + sql, params)
276 280 if output: # Regroup into 3-tuples, and parse JSON
277 281 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
278 282 return cur
279 283
280 284 @only_when_enabled
281 285 @catch_corrupt_db
282 286 def get_session_info(self, session):
283 287 """Get info about a session.
284 288
285 289 Parameters
286 290 ----------
287 291
288 292 session : int
289 293 Session number to retrieve.
290 294
291 295 Returns
292 296 -------
293
297
294 298 session_id : int
295 299 Session ID number
296 300 start : datetime
297 301 Timestamp for the start of the session.
298 302 end : datetime
299 303 Timestamp for the end of the session, or None if IPython crashed.
300 304 num_cmds : int
301 305 Number of commands run, or None if IPython crashed.
302 306 remark : unicode
303 307 A manually set description.
304 308 """
305 309 query = "SELECT * from sessions where session == ?"
306 310 return self.db.execute(query, (session,)).fetchone()
307 311
308 312 @catch_corrupt_db
309 313 def get_last_session_id(self):
310 314 """Get the last session ID currently in the database.
311
315
312 316 Within IPython, this should be the same as the value stored in
313 317 :attr:`HistoryManager.session_number`.
314 318 """
315 319 for record in self.get_tail(n=1, include_latest=True):
316 320 return record[0]
317 321
318 322 @catch_corrupt_db
319 323 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
320 324 """Get the last n lines from the history database.
321 325
322 326 Parameters
323 327 ----------
324 328 n : int
325 329 The number of lines to get
326 330 raw, output : bool
327 331 See :meth:`get_range`
328 332 include_latest : bool
329 333 If False (default), n+1 lines are fetched, and the latest one
330 334 is discarded. This is intended to be used where the function
331 335 is called by a user command, which it should not return.
332 336
333 337 Returns
334 338 -------
335 339 Tuples as :meth:`get_range`
336 340 """
337 341 self.writeout_cache()
338 342 if not include_latest:
339 343 n += 1
340 344 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
341 345 (n,), raw=raw, output=output)
342 346 if not include_latest:
343 347 return reversed(list(cur)[1:])
344 348 return reversed(list(cur))
345 349
346 350 @catch_corrupt_db
347 351 def search(self, pattern="*", raw=True, search_raw=True,
348 352 output=False, n=None, unique=False):
349 353 """Search the database using unix glob-style matching (wildcards
350 354 * and ?).
351 355
352 356 Parameters
353 357 ----------
354 358 pattern : str
355 359 The wildcarded pattern to match when searching
356 360 search_raw : bool
357 361 If True, search the raw input, otherwise, the parsed input
358 362 raw, output : bool
359 363 See :meth:`get_range`
360 364 n : None or int
361 365 If an integer is given, it defines the limit of
362 366 returned entries.
363 367 unique : bool
364 368 When it is true, return only unique entries.
365 369
366 370 Returns
367 371 -------
368 372 Tuples as :meth:`get_range`
369 373 """
370 374 tosearch = "source_raw" if search_raw else "source"
371 375 if output:
372 376 tosearch = "history." + tosearch
373 377 self.writeout_cache()
374 378 sqlform = "WHERE %s GLOB ?" % tosearch
375 379 params = (pattern,)
376 380 if unique:
377 381 sqlform += ' GROUP BY {0}'.format(tosearch)
378 382 if n is not None:
379 383 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
380 384 params += (n,)
381 385 elif unique:
382 386 sqlform += " ORDER BY session, line"
383 387 cur = self._run_sql(sqlform, params, raw=raw, output=output)
384 388 if n is not None:
385 389 return reversed(list(cur))
386 390 return cur
387
391
388 392 @catch_corrupt_db
389 393 def get_range(self, session, start=1, stop=None, raw=True,output=False):
390 394 """Retrieve input by session.
391 395
392 396 Parameters
393 397 ----------
394 398 session : int
395 399 Session number to retrieve.
396 400 start : int
397 401 First line to retrieve.
398 402 stop : int
399 403 End of line range (excluded from output itself). If None, retrieve
400 404 to the end of the session.
401 405 raw : bool
402 406 If True, return untranslated input
403 407 output : bool
404 408 If True, attempt to include output. This will be 'real' Python
405 409 objects for the current session, or text reprs from previous
406 410 sessions if db_log_output was enabled at the time. Where no output
407 411 is found, None is used.
408 412
409 413 Returns
410 414 -------
411 415 entries
412 416 An iterator over the desired lines. Each line is a 3-tuple, either
413 417 (session, line, input) if output is False, or
414 418 (session, line, (input, output)) if output is True.
415 419 """
416 420 if stop:
417 421 lineclause = "line >= ? AND line < ?"
418 422 params = (session, start, stop)
419 423 else:
420 424 lineclause = "line>=?"
421 425 params = (session, start)
422 426
423 427 return self._run_sql("WHERE session==? AND %s" % lineclause,
424 428 params, raw=raw, output=output)
425 429
426 430 def get_range_by_str(self, rangestr, raw=True, output=False):
427 431 """Get lines of history from a string of ranges, as used by magic
428 432 commands %hist, %save, %macro, etc.
429 433
430 434 Parameters
431 435 ----------
432 436 rangestr : str
433 437 A string specifying ranges, e.g. "5 ~2/1-4". See
434 438 :func:`magic_history` for full details.
435 439 raw, output : bool
436 440 As :meth:`get_range`
437 441
438 442 Returns
439 443 -------
440 444 Tuples as :meth:`get_range`
441 445 """
442 446 for sess, s, e in extract_hist_ranges(rangestr):
443 447 for line in self.get_range(sess, s, e, raw=raw, output=output):
444 448 yield line
445 449
446 450
447 451 class HistoryManager(HistoryAccessor):
448 452 """A class to organize all history-related functionality in one place.
449 453 """
450 454 # Public interface
451 455
452 456 # An instance of the IPython shell we are attached to
453 457 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
454 458 allow_none=True)
455 459 # Lists to hold processed and raw history. These start with a blank entry
456 460 # so that we can index them starting from 1
457 461 input_hist_parsed = List([""])
458 462 input_hist_raw = List([""])
459 463 # A list of directories visited during session
460 464 dir_hist = List()
461 465 @default('dir_hist')
462 466 def _dir_hist_default(self):
463 467 try:
464 return [os.getcwd()]
468 return [Path.cwd()]
465 469 except OSError:
466 470 return []
467 471
468 472 # A dict of output history, keyed with ints from the shell's
469 473 # execution count.
470 474 output_hist = Dict()
471 475 # The text/plain repr of outputs.
472 476 output_hist_reprs = Dict()
473 477
474 478 # The number of the current session in the history database
475 479 session_number = Integer()
476
480
477 481 db_log_output = Bool(False,
478 482 help="Should the history database include output? (default: no)"
479 483 ).tag(config=True)
480 484 db_cache_size = Integer(0,
481 485 help="Write to database every x commands (higher values save disk access & power).\n"
482 486 "Values of 1 or less effectively disable caching."
483 487 ).tag(config=True)
484 488 # The input and output caches
485 489 db_input_cache = List()
486 490 db_output_cache = List()
487
491
488 492 # History saving in separate thread
489 493 save_thread = Instance('IPython.core.history.HistorySavingThread',
490 494 allow_none=True)
491 495 save_flag = Instance(threading.Event, allow_none=True)
492
496
493 497 # Private interface
494 498 # Variables used to store the three last inputs from the user. On each new
495 499 # history update, we populate the user's namespace with these, shifted as
496 500 # necessary.
497 501 _i00 = Unicode(u'')
498 502 _i = Unicode(u'')
499 503 _ii = Unicode(u'')
500 504 _iii = Unicode(u'')
501 505
502 506 # A regex matching all forms of the exit command, so that we don't store
503 507 # them in the history (it's annoying to rewind the first entry and land on
504 508 # an exit call).
505 509 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
506 510
507 511 def __init__(self, shell=None, config=None, **traits):
508 512 """Create a new history manager associated with a shell instance.
509 513 """
510 514 # We need a pointer back to the shell for various tasks.
511 515 super(HistoryManager, self).__init__(shell=shell, config=config,
512 516 **traits)
513 517 self.save_flag = threading.Event()
514 518 self.db_input_cache_lock = threading.Lock()
515 519 self.db_output_cache_lock = threading.Lock()
516
520
517 521 try:
518 522 self.new_session()
519 523 except sqlite3.OperationalError:
520 524 self.log.error("Failed to create history session in %s. History will not be saved.",
521 525 self.hist_file, exc_info=True)
522 526 self.hist_file = ':memory:'
523
527
524 528 if self.enabled and self.hist_file != ':memory:':
525 529 self.save_thread = HistorySavingThread(self)
526 530 self.save_thread.start()
527 531
528 532 def _get_hist_file_name(self, profile=None):
529 533 """Get default history file name based on the Shell's profile.
530
534
531 535 The profile parameter is ignored, but must exist for compatibility with
532 536 the parent class."""
533 537 profile_dir = self.shell.profile_dir.location
534 return os.path.join(profile_dir, 'history.sqlite')
535
538 return Path(profile_dir)/'history.sqlite'
539
536 540 @only_when_enabled
537 541 def new_session(self, conn=None):
538 542 """Get a new session number."""
539 543 if conn is None:
540 544 conn = self.db
541
545
542 546 with conn:
543 547 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
544 548 NULL, "") """, (datetime.datetime.now(),))
545 549 self.session_number = cur.lastrowid
546
550
547 551 def end_session(self):
548 552 """Close the database session, filling in the end time and line count."""
549 553 self.writeout_cache()
550 554 with self.db:
551 555 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
552 556 session==?""", (datetime.datetime.now(),
553 557 len(self.input_hist_parsed)-1, self.session_number))
554 558 self.session_number = 0
555
559
556 560 def name_session(self, name):
557 561 """Give the current session a name in the history database."""
558 562 with self.db:
559 563 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
560 564 (name, self.session_number))
561
565
562 566 def reset(self, new_session=True):
563 567 """Clear the session history, releasing all object references, and
564 568 optionally open a new session."""
565 569 self.output_hist.clear()
566 570 # The directory history can't be completely empty
567 self.dir_hist[:] = [os.getcwd()]
568
571 self.dir_hist[:] = [Path.cwd()]
572
569 573 if new_session:
570 574 if self.session_number:
571 575 self.end_session()
572 576 self.input_hist_parsed[:] = [""]
573 577 self.input_hist_raw[:] = [""]
574 578 self.new_session()
575 579
576 580 # ------------------------------
577 581 # Methods for retrieving history
578 582 # ------------------------------
579 583 def get_session_info(self, session=0):
580 584 """Get info about a session.
581 585
582 586 Parameters
583 587 ----------
584 588
585 589 session : int
586 590 Session number to retrieve. The current session is 0, and negative
587 591 numbers count back from current session, so -1 is the previous session.
588 592
589 593 Returns
590 594 -------
591
595
592 596 session_id : int
593 597 Session ID number
594 598 start : datetime
595 599 Timestamp for the start of the session.
596 600 end : datetime
597 601 Timestamp for the end of the session, or None if IPython crashed.
598 602 num_cmds : int
599 603 Number of commands run, or None if IPython crashed.
600 604 remark : unicode
601 605 A manually set description.
602 606 """
603 607 if session <= 0:
604 608 session += self.session_number
605 609
606 610 return super(HistoryManager, self).get_session_info(session=session)
607 611
608 612 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
609 613 """Get input and output history from the current session. Called by
610 614 get_range, and takes similar parameters."""
611 615 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
612
616
613 617 n = len(input_hist)
614 618 if start < 0:
615 619 start += n
616 620 if not stop or (stop > n):
617 621 stop = n
618 622 elif stop < 0:
619 623 stop += n
620
624
621 625 for i in range(start, stop):
622 626 if output:
623 627 line = (input_hist[i], self.output_hist_reprs.get(i))
624 628 else:
625 629 line = input_hist[i]
626 630 yield (0, i, line)
627
631
628 632 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
629 633 """Retrieve input by session.
630
634
631 635 Parameters
632 636 ----------
633 637 session : int
634 638 Session number to retrieve. The current session is 0, and negative
635 639 numbers count back from current session, so -1 is previous session.
636 640 start : int
637 641 First line to retrieve.
638 642 stop : int
639 643 End of line range (excluded from output itself). If None, retrieve
640 644 to the end of the session.
641 645 raw : bool
642 646 If True, return untranslated input
643 647 output : bool
644 648 If True, attempt to include output. This will be 'real' Python
645 649 objects for the current session, or text reprs from previous
646 650 sessions if db_log_output was enabled at the time. Where no output
647 651 is found, None is used.
648
652
649 653 Returns
650 654 -------
651 655 entries
652 656 An iterator over the desired lines. Each line is a 3-tuple, either
653 657 (session, line, input) if output is False, or
654 658 (session, line, (input, output)) if output is True.
655 659 """
656 660 if session <= 0:
657 661 session += self.session_number
658 662 if session==self.session_number: # Current session
659 663 return self._get_range_session(start, stop, raw, output)
660 664 return super(HistoryManager, self).get_range(session, start, stop, raw,
661 665 output)
662 666
663 667 ## ----------------------------
664 668 ## Methods for storing history:
665 669 ## ----------------------------
666 670 def store_inputs(self, line_num, source, source_raw=None):
667 671 """Store source and raw input in history and create input cache
668 672 variables ``_i*``.
669 673
670 674 Parameters
671 675 ----------
672 676 line_num : int
673 677 The prompt number of this input.
674 678
675 679 source : str
676 680 Python input.
677 681
678 682 source_raw : str, optional
679 683 If given, this is the raw input without any IPython transformations
680 684 applied to it. If not given, ``source`` is used.
681 685 """
682 686 if source_raw is None:
683 687 source_raw = source
684 688 source = source.rstrip('\n')
685 689 source_raw = source_raw.rstrip('\n')
686 690
687 691 # do not store exit/quit commands
688 692 if self._exit_re.match(source_raw.strip()):
689 693 return
690 694
691 695 self.input_hist_parsed.append(source)
692 696 self.input_hist_raw.append(source_raw)
693 697
694 698 with self.db_input_cache_lock:
695 699 self.db_input_cache.append((line_num, source, source_raw))
696 700 # Trigger to flush cache and write to DB.
697 701 if len(self.db_input_cache) >= self.db_cache_size:
698 702 self.save_flag.set()
699 703
700 704 # update the auto _i variables
701 705 self._iii = self._ii
702 706 self._ii = self._i
703 707 self._i = self._i00
704 708 self._i00 = source_raw
705 709
706 710 # hackish access to user namespace to create _i1,_i2... dynamically
707 711 new_i = '_i%s' % line_num
708 712 to_main = {'_i': self._i,
709 713 '_ii': self._ii,
710 714 '_iii': self._iii,
711 715 new_i : self._i00 }
712
716
713 717 if self.shell is not None:
714 718 self.shell.push(to_main, interactive=False)
715 719
716 720 def store_output(self, line_num):
717 721 """If database output logging is enabled, this saves all the
718 722 outputs from the indicated prompt number to the database. It's
719 723 called by run_cell after code has been executed.
720 724
721 725 Parameters
722 726 ----------
723 727 line_num : int
724 728 The line number from which to save outputs
725 729 """
726 730 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
727 731 return
728 732 output = self.output_hist_reprs[line_num]
729 733
730 734 with self.db_output_cache_lock:
731 735 self.db_output_cache.append((line_num, output))
732 736 if self.db_cache_size <= 1:
733 737 self.save_flag.set()
734 738
735 739 def _writeout_input_cache(self, conn):
736 740 with conn:
737 741 for line in self.db_input_cache:
738 742 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
739 743 (self.session_number,)+line)
740 744
741 745 def _writeout_output_cache(self, conn):
742 746 with conn:
743 747 for line in self.db_output_cache:
744 748 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
745 749 (self.session_number,)+line)
746 750
747 751 @only_when_enabled
748 752 def writeout_cache(self, conn=None):
749 753 """Write any entries in the cache to the database."""
750 754 if conn is None:
751 755 conn = self.db
752 756
753 757 with self.db_input_cache_lock:
754 758 try:
755 759 self._writeout_input_cache(conn)
756 760 except sqlite3.IntegrityError:
757 761 self.new_session(conn)
758 762 print("ERROR! Session/line number was not unique in",
759 763 "database. History logging moved to new session",
760 764 self.session_number)
761 765 try:
762 766 # Try writing to the new session. If this fails, don't
763 767 # recurse
764 768 self._writeout_input_cache(conn)
765 769 except sqlite3.IntegrityError:
766 770 pass
767 771 finally:
768 772 self.db_input_cache = []
769 773
770 774 with self.db_output_cache_lock:
771 775 try:
772 776 self._writeout_output_cache(conn)
773 777 except sqlite3.IntegrityError:
774 778 print("!! Session/line number for output was not unique",
775 779 "in database. Output will not be stored.")
776 780 finally:
777 781 self.db_output_cache = []
778 782
779 783
780 784 class HistorySavingThread(threading.Thread):
781 785 """This thread takes care of writing history to the database, so that
782 786 the UI isn't held up while that happens.
783 787
784 788 It waits for the HistoryManager's save_flag to be set, then writes out
785 789 the history cache. The main thread is responsible for setting the flag when
786 790 the cache size reaches a defined threshold."""
787 791 daemon = True
788 792 stop_now = False
789 793 enabled = True
790 794 def __init__(self, history_manager):
791 795 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
792 796 self.history_manager = history_manager
793 797 self.enabled = history_manager.enabled
794 798 atexit.register(self.stop)
795 799
796 800 @only_when_enabled
797 801 def run(self):
798 802 # We need a separate db connection per thread:
799 803 try:
800 804 self.db = sqlite3.connect(self.history_manager.hist_file,
801 805 **self.history_manager.connection_options
802 806 )
803 807 while True:
804 808 self.history_manager.save_flag.wait()
805 809 if self.stop_now:
806 810 self.db.close()
807 811 return
808 812 self.history_manager.save_flag.clear()
809 813 self.history_manager.writeout_cache(self.db)
810 814 except Exception as e:
811 815 print(("The history saving thread hit an unexpected error (%s)."
812 816 "History will not be written to the database.") % repr(e))
813 817
814 818 def stop(self):
815 819 """This can be called from the main thread to safely stop this thread.
816 820
817 821 Note that it does not attempt to write out remaining history before
818 822 exiting. That should be done by calling the HistoryManager's
819 823 end_session method."""
820 824 self.stop_now = True
821 825 self.history_manager.save_flag.set()
822 826 self.join()
823 827
824 828
825 829 # To match, e.g. ~5/8-~2/3
826 830 range_re = re.compile(r"""
827 831 ((?P<startsess>~?\d+)/)?
828 832 (?P<start>\d+)?
829 833 ((?P<sep>[\-:])
830 834 ((?P<endsess>~?\d+)/)?
831 835 (?P<end>\d+))?
832 836 $""", re.VERBOSE)
833 837
834 838
835 839 def extract_hist_ranges(ranges_str):
836 840 """Turn a string of history ranges into 3-tuples of (session, start, stop).
837 841
838 842 Examples
839 843 --------
840 844 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
841 845 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
842 846 """
843 847 for range_str in ranges_str.split():
844 848 rmatch = range_re.match(range_str)
845 849 if not rmatch:
846 850 continue
847 851 start = rmatch.group("start")
848 852 if start:
849 853 start = int(start)
850 854 end = rmatch.group("end")
851 855 # If no end specified, get (a, a + 1)
852 856 end = int(end) if end else start + 1
853 857 else: # start not specified
854 858 if not rmatch.group('startsess'): # no startsess
855 859 continue
856 860 start = 1
857 861 end = None # provide the entire session hist
858 862
859 863 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
860 864 end += 1
861 865 startsess = rmatch.group("startsess") or "0"
862 866 endsess = rmatch.group("endsess") or startsess
863 867 startsess = int(startsess.replace("~","-"))
864 868 endsess = int(endsess.replace("~","-"))
865 869 assert endsess >= startsess, "start session must be earlier than end session"
866 870
867 871 if endsess == startsess:
868 872 yield (startsess, start, end)
869 873 continue
870 874 # Multiple sessions in one range:
871 875 yield (startsess, start, None)
872 876 for sess in range(startsess+1, endsess):
873 877 yield (sess, 1, None)
874 878 yield (endsess, 1, end)
875 879
876 880
877 881 def _format_lineno(session, line):
878 882 """Helper function to format line numbers properly."""
879 883 if session == 0:
880 884 return str(line)
881 885 return "%s#%s" % (session, line)
@@ -1,214 +1,215 b''
1 1 # coding: utf-8
2 2 """Tests for the IPython tab-completion machinery.
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Module imports
6 6 #-----------------------------------------------------------------------------
7 7
8 8 # stdlib
9 9 import io
10 import os
10 from pathlib import Path
11 11 import sys
12 12 import tempfile
13 13 from datetime import datetime
14 14 import sqlite3
15 15
16 16 # third party
17 17 import nose.tools as nt
18 18
19 19 # our own packages
20 20 from traitlets.config.loader import Config
21 21 from IPython.utils.tempdir import TemporaryDirectory
22 22 from IPython.core.history import HistoryManager, extract_hist_ranges
23 23 from IPython.testing.decorators import skipif
24 24
25 25 def test_proper_default_encoding():
26 26 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
27 27
28 28 @skipif(sqlite3.sqlite_version_info > (3,24,0))
29 29 def test_history():
30 30 ip = get_ipython()
31 31 with TemporaryDirectory() as tmpdir:
32 tmp_path = Path(tmpdir)
32 33 hist_manager_ori = ip.history_manager
33 hist_file = os.path.join(tmpdir, 'history.sqlite')
34 hist_file = tmp_path / 'history.sqlite'
34 35 try:
35 36 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
36 37 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
37 38 for i, h in enumerate(hist, start=1):
38 39 ip.history_manager.store_inputs(i, h)
39 40
40 41 ip.history_manager.db_log_output = True
41 42 # Doesn't match the input, but we'll just check it's stored.
42 43 ip.history_manager.output_hist_reprs[3] = "spam"
43 44 ip.history_manager.store_output(3)
44 45
45 46 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
46 47
47 48 # Detailed tests for _get_range_session
48 49 grs = ip.history_manager._get_range_session
49 50 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
50 51 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
51 52 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
52 53
53 54 # Check whether specifying a range beyond the end of the current
54 55 # session results in an error (gh-804)
55 56 ip.magic('%hist 2-500')
56 57
57 58 # Check that we can write non-ascii characters to a file
58 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
59 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
60 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
61 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
59 ip.magic("%%hist -f %s" % (tmp_path / "test1"))
60 ip.magic("%%hist -pf %s" % (tmp_path / "test2"))
61 ip.magic("%%hist -nf %s" % (tmp_path / "test3"))
62 ip.magic("%%save %s 1-10" % (tmp_path / "test4"))
62 63
63 64 # New session
64 65 ip.history_manager.reset()
65 66 newcmds = [u"z=5",
66 67 u"class X(object):\n pass",
67 68 u"k='p'",
68 69 u"z=5"]
69 70 for i, cmd in enumerate(newcmds, start=1):
70 71 ip.history_manager.store_inputs(i, cmd)
71 72 gothist = ip.history_manager.get_range(start=1, stop=4)
72 73 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
73 74 # Previous session:
74 75 gothist = ip.history_manager.get_range(-1, 1, 4)
75 76 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
76 77
77 78 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
78 79
79 80 # Check get_hist_tail
80 81 gothist = ip.history_manager.get_tail(5, output=True,
81 82 include_latest=True)
82 83 expected = [(1, 3, (hist[-1], "spam"))] \
83 84 + [(s, n, (c, None)) for (s, n, c) in newhist]
84 85 nt.assert_equal(list(gothist), expected)
85 86
86 87 gothist = ip.history_manager.get_tail(2)
87 88 expected = newhist[-3:-1]
88 89 nt.assert_equal(list(gothist), expected)
89 90
90 91 # Check get_hist_search
91 92
92 93 gothist = ip.history_manager.search("*test*")
93 94 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
94 95
95 96 gothist = ip.history_manager.search("*=*")
96 97 nt.assert_equal(list(gothist),
97 98 [(1, 1, hist[0]),
98 99 (1, 2, hist[1]),
99 100 (1, 3, hist[2]),
100 101 newhist[0],
101 102 newhist[2],
102 103 newhist[3]])
103 104
104 105 gothist = ip.history_manager.search("*=*", n=4)
105 106 nt.assert_equal(list(gothist),
106 107 [(1, 3, hist[2]),
107 108 newhist[0],
108 109 newhist[2],
109 110 newhist[3]])
110 111
111 112 gothist = ip.history_manager.search("*=*", unique=True)
112 113 nt.assert_equal(list(gothist),
113 114 [(1, 1, hist[0]),
114 115 (1, 2, hist[1]),
115 116 (1, 3, hist[2]),
116 117 newhist[2],
117 118 newhist[3]])
118 119
119 120 gothist = ip.history_manager.search("*=*", unique=True, n=3)
120 121 nt.assert_equal(list(gothist),
121 122 [(1, 3, hist[2]),
122 123 newhist[2],
123 124 newhist[3]])
124 125
125 126 gothist = ip.history_manager.search("b*", output=True)
126 127 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
127 128
128 129 # Cross testing: check that magic %save can get previous session.
129 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
130 ip.magic("save " + testfilename + " ~1/1-3")
130 testfilename = (tmp_path /"test.py").resolve()
131 ip.magic("save " + str(testfilename) + " ~1/1-3")
131 132 with io.open(testfilename, encoding='utf-8') as testfile:
132 133 nt.assert_equal(testfile.read(),
133 134 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
134 135
135 136 # Duplicate line numbers - check that it doesn't crash, and
136 137 # gets a new session
137 138 ip.history_manager.store_inputs(1, "rogue")
138 139 ip.history_manager.writeout_cache()
139 140 nt.assert_equal(ip.history_manager.session_number, 3)
140 141 finally:
141 142 # Ensure saving thread is shut down before we try to clean up the files
142 143 ip.history_manager.save_thread.stop()
143 144 # Forcibly close database rather than relying on garbage collection
144 145 ip.history_manager.db.close()
145 146 # Restore history manager
146 147 ip.history_manager = hist_manager_ori
147 148
148 149
149 150 def test_extract_hist_ranges():
150 151 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
151 152 expected = [(0, 1, 2), # 0 == current session
152 153 (2, 3, 4),
153 154 (-4, 5, 7),
154 155 (-4, 7, 10),
155 156 (-9, 2, None), # None == to end
156 157 (-8, 1, None),
157 158 (-7, 1, 6),
158 159 (-10, 1, None)]
159 160 actual = list(extract_hist_ranges(instr))
160 161 nt.assert_equal(actual, expected)
161 162
162 163 def test_magic_rerun():
163 164 """Simple test for %rerun (no args -> rerun last line)"""
164 165 ip = get_ipython()
165 166 ip.run_cell("a = 10", store_history=True)
166 167 ip.run_cell("a += 1", store_history=True)
167 168 nt.assert_equal(ip.user_ns["a"], 11)
168 169 ip.run_cell("%rerun", store_history=True)
169 170 nt.assert_equal(ip.user_ns["a"], 12)
170 171
171 172 def test_timestamp_type():
172 173 ip = get_ipython()
173 174 info = ip.history_manager.get_session_info()
174 175 nt.assert_true(isinstance(info[1], datetime))
175 176
176 177 def test_hist_file_config():
177 178 cfg = Config()
178 179 tfile = tempfile.NamedTemporaryFile(delete=False)
179 cfg.HistoryManager.hist_file = tfile.name
180 cfg.HistoryManager.hist_file = Path(tfile.name)
180 181 try:
181 182 hm = HistoryManager(shell=get_ipython(), config=cfg)
182 183 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
183 184 finally:
184 185 try:
185 os.remove(tfile.name)
186 Path(tfile.name).unlink()
186 187 except OSError:
187 188 # same catch as in testing.tools.TempFileMixin
188 189 # On Windows, even though we close the file, we still can't
189 190 # delete it. I have no clue why
190 191 pass
191 192
192 193 def test_histmanager_disabled():
193 194 """Ensure that disabling the history manager doesn't create a database."""
194 195 cfg = Config()
195 196 cfg.HistoryAccessor.enabled = False
196 197
197 198 ip = get_ipython()
198 199 with TemporaryDirectory() as tmpdir:
199 200 hist_manager_ori = ip.history_manager
200 hist_file = os.path.join(tmpdir, 'history.sqlite')
201 hist_file = Path(tmpdir) / 'history.sqlite'
201 202 cfg.HistoryManager.hist_file = hist_file
202 203 try:
203 204 ip.history_manager = HistoryManager(shell=ip, config=cfg)
204 205 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='β‚¬Γ†ΒΎΓ·ΓŸ'"]
205 206 for i, h in enumerate(hist, start=1):
206 207 ip.history_manager.store_inputs(i, h)
207 208 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
208 209 ip.history_manager.reset()
209 210 ip.history_manager.end_session()
210 211 finally:
211 212 ip.history_manager = hist_manager_ori
212 213
213 214 # hist_file should not be created
214 nt.assert_false(os.path.exists(hist_file))
215 nt.assert_false(hist_file.exists())
@@ -1,471 +1,472 b''
1 1 """Generic testing tools.
2 2
3 3 Authors
4 4 -------
5 5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 6 """
7 7
8 8
9 9 # Copyright (c) IPython Development Team.
10 10 # Distributed under the terms of the Modified BSD License.
11 11
12 12 import os
13 from pathlib import Path
13 14 import re
14 15 import sys
15 16 import tempfile
16 17 import unittest
17 18
18 19 from contextlib import contextmanager
19 20 from io import StringIO
20 21 from subprocess import Popen, PIPE
21 22 from unittest.mock import patch
22 23
23 24 try:
24 25 # These tools are used by parts of the runtime, so we make the nose
25 26 # dependency optional at this point. Nose is a hard dependency to run the
26 27 # test suite, but NOT to use ipython itself.
27 28 import nose.tools as nt
28 29 has_nose = True
29 30 except ImportError:
30 31 has_nose = False
31 32
32 33 from traitlets.config.loader import Config
33 34 from IPython.utils.process import get_output_error_code
34 35 from IPython.utils.text import list_strings
35 36 from IPython.utils.io import temp_pyfile, Tee
36 37 from IPython.utils import py3compat
37 38
38 39 from . import decorators as dec
39 40 from . import skipdoctest
40 41
41 42
42 43 # The docstring for full_path doctests differently on win32 (different path
43 44 # separator) so just skip the doctest there. The example remains informative.
44 45 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45 46
46 47 @doctest_deco
47 48 def full_path(startPath,files):
48 49 """Make full paths for all the listed files, based on startPath.
49 50
50 51 Only the base part of startPath is kept, since this routine is typically
51 52 used with a script's ``__file__`` variable as startPath. The base of startPath
52 53 is then prepended to all the listed files, forming the output list.
53 54
54 55 Parameters
55 56 ----------
56 57 startPath : string
57 58 Initial path to use as the base for the results. This path is split
58 59 using os.path.split() and only its first component is kept.
59 60
60 61 files : string or list
61 62 One or more files.
62 63
63 64 Examples
64 65 --------
65 66
66 67 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 68 ['/foo/a.txt', '/foo/b.txt']
68 69
69 70 >>> full_path('/foo',['a.txt','b.txt'])
70 71 ['/a.txt', '/b.txt']
71 72
72 73 If a single file is given, the output is still a list::
73 74
74 75 >>> full_path('/foo','a.txt')
75 76 ['/a.txt']
76 77 """
77 78
78 79 files = list_strings(files)
79 80 base = os.path.split(startPath)[0]
80 81 return [ os.path.join(base,f) for f in files ]
81 82
82 83
83 84 def parse_test_output(txt):
84 85 """Parse the output of a test run and return errors, failures.
85 86
86 87 Parameters
87 88 ----------
88 89 txt : str
89 90 Text output of a test run, assumed to contain a line of one of the
90 91 following forms::
91 92
92 93 'FAILED (errors=1)'
93 94 'FAILED (failures=1)'
94 95 'FAILED (errors=1, failures=1)'
95 96
96 97 Returns
97 98 -------
98 99 nerr, nfail
99 100 number of errors and failures.
100 101 """
101 102
102 103 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 104 if err_m:
104 105 nerr = int(err_m.group(1))
105 106 nfail = 0
106 107 return nerr, nfail
107 108
108 109 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 110 if fail_m:
110 111 nerr = 0
111 112 nfail = int(fail_m.group(1))
112 113 return nerr, nfail
113 114
114 115 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 116 re.MULTILINE)
116 117 if both_m:
117 118 nerr = int(both_m.group(1))
118 119 nfail = int(both_m.group(2))
119 120 return nerr, nfail
120 121
121 122 # If the input didn't match any of these forms, assume no error/failures
122 123 return 0, 0
123 124
124 125
125 126 # So nose doesn't think this is a test
126 127 parse_test_output.__test__ = False
127 128
128 129
129 130 def default_argv():
130 131 """Return a valid default argv for creating testing instances of ipython"""
131 132
132 133 return ['--quick', # so no config file is loaded
133 134 # Other defaults to minimize side effects on stdout
134 135 '--colors=NoColor', '--no-term-title','--no-banner',
135 136 '--autocall=0']
136 137
137 138
138 139 def default_config():
139 140 """Return a config object with good defaults for testing."""
140 141 config = Config()
141 142 config.TerminalInteractiveShell.colors = 'NoColor'
142 143 config.TerminalTerminalInteractiveShell.term_title = False,
143 144 config.TerminalInteractiveShell.autocall = 0
144 145 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 config.HistoryManager.hist_file = f.name
146 config.HistoryManager.hist_file = Path(f.name)
146 147 f.close()
147 148 config.HistoryManager.db_cache_size = 10000
148 149 return config
149 150
150 151
151 152 def get_ipython_cmd(as_string=False):
152 153 """
153 154 Return appropriate IPython command line name. By default, this will return
154 155 a list that can be used with subprocess.Popen, for example, but passing
155 156 `as_string=True` allows for returning the IPython command as a string.
156 157
157 158 Parameters
158 159 ----------
159 160 as_string: bool
160 161 Flag to allow to return the command as a string.
161 162 """
162 163 ipython_cmd = [sys.executable, "-m", "IPython"]
163 164
164 165 if as_string:
165 166 ipython_cmd = " ".join(ipython_cmd)
166 167
167 168 return ipython_cmd
168 169
169 170 def ipexec(fname, options=None, commands=()):
170 171 """Utility to call 'ipython filename'.
171 172
172 173 Starts IPython with a minimal and safe configuration to make startup as fast
173 174 as possible.
174 175
175 176 Note that this starts IPython in a subprocess!
176 177
177 178 Parameters
178 179 ----------
179 180 fname : str
180 181 Name of file to be executed (should have .py or .ipy extension).
181 182
182 183 options : optional, list
183 184 Extra command-line flags to be passed to IPython.
184 185
185 186 commands : optional, list
186 187 Commands to send in on stdin
187 188
188 189 Returns
189 190 -------
190 191 ``(stdout, stderr)`` of ipython subprocess.
191 192 """
192 193 if options is None: options = []
193 194
194 195 cmdargs = default_argv() + options
195 196
196 197 test_dir = os.path.dirname(__file__)
197 198
198 199 ipython_cmd = get_ipython_cmd()
199 200 # Absolute path for filename
200 201 full_fname = os.path.join(test_dir, fname)
201 202 full_cmd = ipython_cmd + cmdargs + ['--', full_fname]
202 203 env = os.environ.copy()
203 204 # FIXME: ignore all warnings in ipexec while we have shims
204 205 # should we keep suppressing warnings here, even after removing shims?
205 206 env['PYTHONWARNINGS'] = 'ignore'
206 207 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 208 for k, v in env.items():
208 209 # Debug a bizarre failure we've seen on Windows:
209 210 # TypeError: environment can only contain strings
210 211 if not isinstance(v, str):
211 212 print(k, v)
212 213 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 214 out, err = p.communicate(input=py3compat.encode('\n'.join(commands)) or None)
214 215 out, err = py3compat.decode(out), py3compat.decode(err)
215 216 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 217 # so strip that out before doing comparisons
217 218 if out:
218 219 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 220 return out, err
220 221
221 222
222 223 def ipexec_validate(fname, expected_out, expected_err='',
223 224 options=None, commands=()):
224 225 """Utility to call 'ipython filename' and validate output/error.
225 226
226 227 This function raises an AssertionError if the validation fails.
227 228
228 229 Note that this starts IPython in a subprocess!
229 230
230 231 Parameters
231 232 ----------
232 233 fname : str
233 234 Name of the file to be executed (should have .py or .ipy extension).
234 235
235 236 expected_out : str
236 237 Expected stdout of the process.
237 238
238 239 expected_err : optional, str
239 240 Expected stderr of the process.
240 241
241 242 options : optional, list
242 243 Extra command-line flags to be passed to IPython.
243 244
244 245 Returns
245 246 -------
246 247 None
247 248 """
248 249
249 250 import nose.tools as nt
250 251
251 252 out, err = ipexec(fname, options, commands)
252 253 #print 'OUT', out # dbg
253 254 #print 'ERR', err # dbg
254 255 # If there are any errors, we must check those before stdout, as they may be
255 256 # more informative than simply having an empty stdout.
256 257 if err:
257 258 if expected_err:
258 259 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 260 else:
260 261 raise ValueError('Running file %r produced error: %r' %
261 262 (fname, err))
262 263 # If no errors or output on stderr was expected, match stdout
263 264 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264 265
265 266
266 267 class TempFileMixin(unittest.TestCase):
267 268 """Utility class to create temporary Python/IPython files.
268 269
269 270 Meant as a mixin class for test cases."""
270 271
271 272 def mktmp(self, src, ext='.py'):
272 273 """Make a valid python temp file."""
273 274 fname = temp_pyfile(src, ext)
274 275 if not hasattr(self, 'tmps'):
275 276 self.tmps=[]
276 277 self.tmps.append(fname)
277 278 self.fname = fname
278 279
279 280 def tearDown(self):
280 281 # If the tmpfile wasn't made because of skipped tests, like in
281 282 # win32, there's nothing to cleanup.
282 283 if hasattr(self, 'tmps'):
283 284 for fname in self.tmps:
284 285 # If the tmpfile wasn't made because of skipped tests, like in
285 286 # win32, there's nothing to cleanup.
286 287 try:
287 288 os.unlink(fname)
288 289 except:
289 290 # On Windows, even though we close the file, we still can't
290 291 # delete it. I have no clue why
291 292 pass
292 293
293 294 def __enter__(self):
294 295 return self
295 296
296 297 def __exit__(self, exc_type, exc_value, traceback):
297 298 self.tearDown()
298 299
299 300
300 301 pair_fail_msg = ("Testing {0}\n\n"
301 302 "In:\n"
302 303 " {1!r}\n"
303 304 "Expected:\n"
304 305 " {2!r}\n"
305 306 "Got:\n"
306 307 " {3!r}\n")
307 308 def check_pairs(func, pairs):
308 309 """Utility function for the common case of checking a function with a
309 310 sequence of input/output pairs.
310 311
311 312 Parameters
312 313 ----------
313 314 func : callable
314 315 The function to be tested. Should accept a single argument.
315 316 pairs : iterable
316 317 A list of (input, expected_output) tuples.
317 318
318 319 Returns
319 320 -------
320 321 None. Raises an AssertionError if any output does not match the expected
321 322 value.
322 323 """
323 324 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
324 325 for inp, expected in pairs:
325 326 out = func(inp)
326 327 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
327 328
328 329
329 330 MyStringIO = StringIO
330 331
331 332 _re_type = type(re.compile(r''))
332 333
333 334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 335 -------
335 336 {2!s}
336 337 -------
337 338 """
338 339
339 340 class AssertPrints(object):
340 341 """Context manager for testing that code prints certain text.
341 342
342 343 Examples
343 344 --------
344 345 >>> with AssertPrints("abc", suppress=False):
345 346 ... print("abcd")
346 347 ... print("def")
347 348 ...
348 349 abcd
349 350 def
350 351 """
351 352 def __init__(self, s, channel='stdout', suppress=True):
352 353 self.s = s
353 354 if isinstance(self.s, (str, _re_type)):
354 355 self.s = [self.s]
355 356 self.channel = channel
356 357 self.suppress = suppress
357 358
358 359 def __enter__(self):
359 360 self.orig_stream = getattr(sys, self.channel)
360 361 self.buffer = MyStringIO()
361 362 self.tee = Tee(self.buffer, channel=self.channel)
362 363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363 364
364 365 def __exit__(self, etype, value, traceback):
365 366 try:
366 367 if value is not None:
367 368 # If an error was raised, don't check anything else
368 369 return False
369 370 self.tee.flush()
370 371 setattr(sys, self.channel, self.orig_stream)
371 372 printed = self.buffer.getvalue()
372 373 for s in self.s:
373 374 if isinstance(s, _re_type):
374 375 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
375 376 else:
376 377 assert s in printed, notprinted_msg.format(s, self.channel, printed)
377 378 return False
378 379 finally:
379 380 self.tee.close()
380 381
381 382 printed_msg = """Found {0!r} in printed output (on {1}):
382 383 -------
383 384 {2!s}
384 385 -------
385 386 """
386 387
387 388 class AssertNotPrints(AssertPrints):
388 389 """Context manager for checking that certain output *isn't* produced.
389 390
390 391 Counterpart of AssertPrints"""
391 392 def __exit__(self, etype, value, traceback):
392 393 try:
393 394 if value is not None:
394 395 # If an error was raised, don't check anything else
395 396 self.tee.close()
396 397 return False
397 398 self.tee.flush()
398 399 setattr(sys, self.channel, self.orig_stream)
399 400 printed = self.buffer.getvalue()
400 401 for s in self.s:
401 402 if isinstance(s, _re_type):
402 403 assert not s.search(printed),printed_msg.format(
403 404 s.pattern, self.channel, printed)
404 405 else:
405 406 assert s not in printed, printed_msg.format(
406 407 s, self.channel, printed)
407 408 return False
408 409 finally:
409 410 self.tee.close()
410 411
411 412 @contextmanager
412 413 def mute_warn():
413 414 from IPython.utils import warn
414 415 save_warn = warn.warn
415 416 warn.warn = lambda *a, **kw: None
416 417 try:
417 418 yield
418 419 finally:
419 420 warn.warn = save_warn
420 421
421 422 @contextmanager
422 423 def make_tempfile(name):
423 424 """ Create an empty, named, temporary file for the duration of the context.
424 425 """
425 426 open(name, 'w').close()
426 427 try:
427 428 yield
428 429 finally:
429 430 os.unlink(name)
430 431
431 432 def fake_input(inputs):
432 433 """Temporarily replace the input() function to return the given values
433 434
434 435 Use as a context manager:
435 436
436 437 with fake_input(['result1', 'result2']):
437 438 ...
438 439
439 440 Values are returned in order. If input() is called again after the last value
440 441 was used, EOFError is raised.
441 442 """
442 443 it = iter(inputs)
443 444 def mock_input(prompt=''):
444 445 try:
445 446 return next(it)
446 447 except StopIteration as e:
447 448 raise EOFError('No more inputs given') from e
448 449
449 450 return patch('builtins.input', mock_input)
450 451
451 452 def help_output_test(subcommand=''):
452 453 """test that `ipython [subcommand] -h` works"""
453 454 cmd = get_ipython_cmd() + [subcommand, '-h']
454 455 out, err, rc = get_output_error_code(cmd)
455 456 nt.assert_equal(rc, 0, err)
456 457 nt.assert_not_in("Traceback", err)
457 458 nt.assert_in("Options", out)
458 459 nt.assert_in("--help-all", out)
459 460 return out, err
460 461
461 462
462 463 def help_all_output_test(subcommand=''):
463 464 """test that `ipython [subcommand] --help-all` works"""
464 465 cmd = get_ipython_cmd() + [subcommand, '--help-all']
465 466 out, err, rc = get_output_error_code(cmd)
466 467 nt.assert_equal(rc, 0, err)
467 468 nt.assert_not_in("Traceback", err)
468 469 nt.assert_in("Options", out)
469 470 nt.assert_in("Class", out)
470 471 return out, err
471 472
General Comments 0
You need to be logged in to leave comments. Login now