##// END OF EJS Templates
Merge pull request #12644 from deep-jkl/fix-pathlib-in-tests
Matthias Bussonnier -
r26187:4967ec2a merge
parent child Browse files
Show More
@@ -1,881 +1,897
1 1 """ History related magics and functionality """
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6
7 7 import atexit
8 8 import datetime
9 import os
9 from pathlib import Path
10 10 import re
11 11 import sqlite3
12 12 import threading
13 13
14 14 from traitlets.config.configurable import LoggingConfigurable
15 15 from decorator import decorator
16 16 from IPython.utils.decorators import undoc
17 17 from IPython.paths import locate_profile
18 18 from traitlets import (
19 Any, Bool, Dict, Instance, Integer, List, Unicode, TraitError,
20 default, observe,
19 Any,
20 Bool,
21 Dict,
22 Instance,
23 Integer,
24 List,
25 Unicode,
26 Union,
27 TraitError,
28 default,
29 observe,
21 30 )
22 31
23 32 #-----------------------------------------------------------------------------
24 33 # Classes and functions
25 34 #-----------------------------------------------------------------------------
26 35
27 36 @undoc
28 37 class DummyDB(object):
29 38 """Dummy DB that will act as a black hole for history.
30
39
31 40 Only used in the absence of sqlite"""
32 41 def execute(*args, **kwargs):
33 42 return []
34
43
35 44 def commit(self, *args, **kwargs):
36 45 pass
37
46
38 47 def __enter__(self, *args, **kwargs):
39 48 pass
40
49
41 50 def __exit__(self, *args, **kwargs):
42 51 pass
43 52
44 53
45 54 @decorator
46 55 def only_when_enabled(f, self, *a, **kw):
47 56 """Decorator: return an empty list in the absence of sqlite."""
48 57 if not self.enabled:
49 58 return []
50 59 else:
51 60 return f(self, *a, **kw)
52 61
53 62
54 63 # use 16kB as threshold for whether a corrupt history db should be saved
55 64 # that should be at least 100 entries or so
56 65 _SAVE_DB_SIZE = 16384
57 66
58 67 @decorator
59 68 def catch_corrupt_db(f, self, *a, **kw):
60 69 """A decorator which wraps HistoryAccessor method calls to catch errors from
61 70 a corrupt SQLite database, move the old database out of the way, and create
62 71 a new one.
63 72
64 73 We avoid clobbering larger databases because this may be triggered due to filesystem issues,
65 74 not just a corrupt file.
66 75 """
67 76 try:
68 77 return f(self, *a, **kw)
69 78 except (sqlite3.DatabaseError, sqlite3.OperationalError) as e:
70 79 self._corrupt_db_counter += 1
71 80 self.log.error("Failed to open SQLite history %s (%s).", self.hist_file, e)
72 81 if self.hist_file != ':memory:':
73 82 if self._corrupt_db_counter > self._corrupt_db_limit:
74 83 self.hist_file = ':memory:'
75 84 self.log.error("Failed to load history too many times, history will not be saved.")
76 elif os.path.isfile(self.hist_file):
85 elif self.hist_file.is_file():
77 86 # move the file out of the way
78 base, ext = os.path.splitext(self.hist_file)
79 size = os.stat(self.hist_file).st_size
87 base = str(self.hist_file.parent / self.hist_file.stem)
88 ext = self.hist_file.suffix
89 size = self.hist_file.stat().st_size
80 90 if size >= _SAVE_DB_SIZE:
81 91 # if there's significant content, avoid clobbering
82 92 now = datetime.datetime.now().isoformat().replace(':', '.')
83 93 newpath = base + '-corrupt-' + now + ext
84 94 # don't clobber previous corrupt backups
85 95 for i in range(100):
86 if not os.path.isfile(newpath):
96 if not Path(newpath).exists():
87 97 break
88 98 else:
89 99 newpath = base + '-corrupt-' + now + (u'-%i' % i) + ext
90 100 else:
91 101 # not much content, possibly empty; don't worry about clobbering
92 102 # maybe we should just delete it?
93 103 newpath = base + '-corrupt' + ext
94 os.rename(self.hist_file, newpath)
104 self.hist_file.rename(newpath)
95 105 self.log.error("History file was moved to %s and a new file created.", newpath)
96 106 self.init_db()
97 107 return []
98 108 else:
99 109 # Failed with :memory:, something serious is wrong
100 110 raise
101
111
112
102 113 class HistoryAccessorBase(LoggingConfigurable):
103 114 """An abstract class for History Accessors """
104 115
105 116 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
106 117 raise NotImplementedError
107 118
108 119 def search(self, pattern="*", raw=True, search_raw=True,
109 120 output=False, n=None, unique=False):
110 121 raise NotImplementedError
111 122
112 123 def get_range(self, session, start=1, stop=None, raw=True,output=False):
113 124 raise NotImplementedError
114 125
115 126 def get_range_by_str(self, rangestr, raw=True, output=False):
116 127 raise NotImplementedError
117 128
118 129
119 130 class HistoryAccessor(HistoryAccessorBase):
120 131 """Access the history database without adding to it.
121
132
122 133 This is intended for use by standalone history tools. IPython shells use
123 134 HistoryManager, below, which is a subclass of this."""
124 135
125 136 # counter for init_db retries, so we don't keep trying over and over
126 137 _corrupt_db_counter = 0
127 138 # after two failures, fallback on :memory:
128 139 _corrupt_db_limit = 2
129 140
130 141 # String holding the path to the history file
131 hist_file = Unicode(
142 hist_file = Union(
143 [Instance(Path), Unicode()],
132 144 help="""Path to file to use for SQLite history database.
133
145
134 146 By default, IPython will put the history database in the IPython
135 147 profile directory. If you would rather share one history among
136 148 profiles, you can set this value in each, so that they are consistent.
137
149
138 150 Due to an issue with fcntl, SQLite is known to misbehave on some NFS
139 151 mounts. If you see IPython hanging, try setting this to something on a
140 152 local disk, e.g::
141
153
142 154 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
143 155
144 156 you can also use the specific value `:memory:` (including the colon
145 157 at both end but not the back ticks), to avoid creating an history file.
146
147 """).tag(config=True)
148
158
159 """,
160 ).tag(config=True)
161
149 162 enabled = Bool(True,
150 163 help="""enable the SQLite history
151
164
152 165 set enabled=False to disable the SQLite history,
153 166 in which case there will be no stored history, no SQLite connection,
154 167 and no background saving thread. This may be necessary in some
155 168 threaded environments where IPython is embedded.
156 169 """
157 170 ).tag(config=True)
158
171
159 172 connection_options = Dict(
160 173 help="""Options for configuring the SQLite connection
161
174
162 175 These options are passed as keyword args to sqlite3.connect
163 176 when establishing database connections.
164 177 """
165 178 ).tag(config=True)
166 179
167 180 # The SQLite database
168 181 db = Any()
169 182 @observe('db')
170 183 def _db_changed(self, change):
171 184 """validate the db, since it can be an Instance of two different types"""
172 185 new = change['new']
173 186 connection_types = (DummyDB, sqlite3.Connection)
174 187 if not isinstance(new, connection_types):
175 188 msg = "%s.db must be sqlite3 Connection or DummyDB, not %r" % \
176 189 (self.__class__.__name__, new)
177 190 raise TraitError(msg)
178
179 def __init__(self, profile='default', hist_file=u'', **traits):
191
192 def __init__(self, profile="default", hist_file="", **traits):
180 193 """Create a new history accessor.
181
194
182 195 Parameters
183 196 ----------
184 197 profile : str
185 198 The name of the profile from which to open history.
186 199 hist_file : str
187 200 Path to an SQLite history database stored by IPython. If specified,
188 201 hist_file overrides profile.
189 202 config : :class:`~traitlets.config.loader.Config`
190 203 Config object. hist_file can also be set through this.
191 204 """
192 205 # We need a pointer back to the shell for various tasks.
193 206 super(HistoryAccessor, self).__init__(**traits)
194 207 # defer setting hist_file from kwarg until after init,
195 208 # otherwise the default kwarg value would clobber any value
196 209 # set by config
197 210 if hist_file:
198 211 self.hist_file = hist_file
199
200 if self.hist_file == u'':
212
213 try:
214 self.hist_file
215 except TraitError:
201 216 # No one has set the hist_file, yet.
202 217 self.hist_file = self._get_hist_file_name(profile)
203
218
204 219 self.init_db()
205
220
206 221 def _get_hist_file_name(self, profile='default'):
207 222 """Find the history file for the given profile name.
208
223
209 224 This is overridden by the HistoryManager subclass, to use the shell's
210 225 active profile.
211
226
212 227 Parameters
213 228 ----------
214 229 profile : str
215 230 The name of a profile which has a history file.
216 231 """
217 return os.path.join(locate_profile(profile), 'history.sqlite')
218
232 return Path(locate_profile(profile)) / "history.sqlite"
233
219 234 @catch_corrupt_db
220 235 def init_db(self):
221 236 """Connect to the database, and create tables if necessary."""
222 237 if not self.enabled:
223 238 self.db = DummyDB()
224 239 return
225
240
226 241 # use detect_types so that timestamps return datetime objects
227 242 kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
228 243 kwargs.update(self.connection_options)
229 self.db = sqlite3.connect(self.hist_file, **kwargs)
244 self.db = sqlite3.connect(str(self.hist_file), **kwargs)
230 245 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
231 246 primary key autoincrement, start timestamp,
232 247 end timestamp, num_cmds integer, remark text)""")
233 248 self.db.execute("""CREATE TABLE IF NOT EXISTS history
234 249 (session integer, line integer, source text, source_raw text,
235 250 PRIMARY KEY (session, line))""")
236 251 # Output history is optional, but ensure the table's there so it can be
237 252 # enabled later.
238 253 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
239 254 (session integer, line integer, output text,
240 255 PRIMARY KEY (session, line))""")
241 256 self.db.commit()
242 257 # success! reset corrupt db count
243 258 self._corrupt_db_counter = 0
244 259
245 260 def writeout_cache(self):
246 261 """Overridden by HistoryManager to dump the cache before certain
247 262 database lookups."""
248 263 pass
249 264
250 265 ## -------------------------------
251 266 ## Methods for retrieving history:
252 267 ## -------------------------------
253 268 def _run_sql(self, sql, params, raw=True, output=False):
254 269 """Prepares and runs an SQL query for the history database.
255 270
256 271 Parameters
257 272 ----------
258 273 sql : str
259 274 Any filtering expressions to go after SELECT ... FROM ...
260 275 params : tuple
261 276 Parameters passed to the SQL query (to replace "?")
262 277 raw, output : bool
263 278 See :meth:`get_range`
264 279
265 280 Returns
266 281 -------
267 282 Tuples as :meth:`get_range`
268 283 """
269 284 toget = 'source_raw' if raw else 'source'
270 285 sqlfrom = "history"
271 286 if output:
272 287 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
273 288 toget = "history.%s, output_history.output" % toget
274 289 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
275 290 (toget, sqlfrom) + sql, params)
276 291 if output: # Regroup into 3-tuples, and parse JSON
277 292 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
278 293 return cur
279 294
280 295 @only_when_enabled
281 296 @catch_corrupt_db
282 297 def get_session_info(self, session):
283 298 """Get info about a session.
284 299
285 300 Parameters
286 301 ----------
287 302
288 303 session : int
289 304 Session number to retrieve.
290 305
291 306 Returns
292 307 -------
293
308
294 309 session_id : int
295 310 Session ID number
296 311 start : datetime
297 312 Timestamp for the start of the session.
298 313 end : datetime
299 314 Timestamp for the end of the session, or None if IPython crashed.
300 315 num_cmds : int
301 316 Number of commands run, or None if IPython crashed.
302 317 remark : unicode
303 318 A manually set description.
304 319 """
305 320 query = "SELECT * from sessions where session == ?"
306 321 return self.db.execute(query, (session,)).fetchone()
307 322
308 323 @catch_corrupt_db
309 324 def get_last_session_id(self):
310 325 """Get the last session ID currently in the database.
311
326
312 327 Within IPython, this should be the same as the value stored in
313 328 :attr:`HistoryManager.session_number`.
314 329 """
315 330 for record in self.get_tail(n=1, include_latest=True):
316 331 return record[0]
317 332
318 333 @catch_corrupt_db
319 334 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
320 335 """Get the last n lines from the history database.
321 336
322 337 Parameters
323 338 ----------
324 339 n : int
325 340 The number of lines to get
326 341 raw, output : bool
327 342 See :meth:`get_range`
328 343 include_latest : bool
329 344 If False (default), n+1 lines are fetched, and the latest one
330 345 is discarded. This is intended to be used where the function
331 346 is called by a user command, which it should not return.
332 347
333 348 Returns
334 349 -------
335 350 Tuples as :meth:`get_range`
336 351 """
337 352 self.writeout_cache()
338 353 if not include_latest:
339 354 n += 1
340 355 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
341 356 (n,), raw=raw, output=output)
342 357 if not include_latest:
343 358 return reversed(list(cur)[1:])
344 359 return reversed(list(cur))
345 360
346 361 @catch_corrupt_db
347 362 def search(self, pattern="*", raw=True, search_raw=True,
348 363 output=False, n=None, unique=False):
349 364 """Search the database using unix glob-style matching (wildcards
350 365 * and ?).
351 366
352 367 Parameters
353 368 ----------
354 369 pattern : str
355 370 The wildcarded pattern to match when searching
356 371 search_raw : bool
357 372 If True, search the raw input, otherwise, the parsed input
358 373 raw, output : bool
359 374 See :meth:`get_range`
360 375 n : None or int
361 376 If an integer is given, it defines the limit of
362 377 returned entries.
363 378 unique : bool
364 379 When it is true, return only unique entries.
365 380
366 381 Returns
367 382 -------
368 383 Tuples as :meth:`get_range`
369 384 """
370 385 tosearch = "source_raw" if search_raw else "source"
371 386 if output:
372 387 tosearch = "history." + tosearch
373 388 self.writeout_cache()
374 389 sqlform = "WHERE %s GLOB ?" % tosearch
375 390 params = (pattern,)
376 391 if unique:
377 392 sqlform += ' GROUP BY {0}'.format(tosearch)
378 393 if n is not None:
379 394 sqlform += " ORDER BY session DESC, line DESC LIMIT ?"
380 395 params += (n,)
381 396 elif unique:
382 397 sqlform += " ORDER BY session, line"
383 398 cur = self._run_sql(sqlform, params, raw=raw, output=output)
384 399 if n is not None:
385 400 return reversed(list(cur))
386 401 return cur
387
402
388 403 @catch_corrupt_db
389 404 def get_range(self, session, start=1, stop=None, raw=True,output=False):
390 405 """Retrieve input by session.
391 406
392 407 Parameters
393 408 ----------
394 409 session : int
395 410 Session number to retrieve.
396 411 start : int
397 412 First line to retrieve.
398 413 stop : int
399 414 End of line range (excluded from output itself). If None, retrieve
400 415 to the end of the session.
401 416 raw : bool
402 417 If True, return untranslated input
403 418 output : bool
404 419 If True, attempt to include output. This will be 'real' Python
405 420 objects for the current session, or text reprs from previous
406 421 sessions if db_log_output was enabled at the time. Where no output
407 422 is found, None is used.
408 423
409 424 Returns
410 425 -------
411 426 entries
412 427 An iterator over the desired lines. Each line is a 3-tuple, either
413 428 (session, line, input) if output is False, or
414 429 (session, line, (input, output)) if output is True.
415 430 """
416 431 if stop:
417 432 lineclause = "line >= ? AND line < ?"
418 433 params = (session, start, stop)
419 434 else:
420 435 lineclause = "line>=?"
421 436 params = (session, start)
422 437
423 438 return self._run_sql("WHERE session==? AND %s" % lineclause,
424 439 params, raw=raw, output=output)
425 440
426 441 def get_range_by_str(self, rangestr, raw=True, output=False):
427 442 """Get lines of history from a string of ranges, as used by magic
428 443 commands %hist, %save, %macro, etc.
429 444
430 445 Parameters
431 446 ----------
432 447 rangestr : str
433 448 A string specifying ranges, e.g. "5 ~2/1-4". See
434 449 :func:`magic_history` for full details.
435 450 raw, output : bool
436 451 As :meth:`get_range`
437 452
438 453 Returns
439 454 -------
440 455 Tuples as :meth:`get_range`
441 456 """
442 457 for sess, s, e in extract_hist_ranges(rangestr):
443 458 for line in self.get_range(sess, s, e, raw=raw, output=output):
444 459 yield line
445 460
446 461
447 462 class HistoryManager(HistoryAccessor):
448 463 """A class to organize all history-related functionality in one place.
449 464 """
450 465 # Public interface
451 466
452 467 # An instance of the IPython shell we are attached to
453 468 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC',
454 469 allow_none=True)
455 470 # Lists to hold processed and raw history. These start with a blank entry
456 471 # so that we can index them starting from 1
457 472 input_hist_parsed = List([""])
458 473 input_hist_raw = List([""])
459 474 # A list of directories visited during session
460 475 dir_hist = List()
461 476 @default('dir_hist')
462 477 def _dir_hist_default(self):
463 478 try:
464 return [os.getcwd()]
479 return [Path.cwd()]
465 480 except OSError:
466 481 return []
467 482
468 483 # A dict of output history, keyed with ints from the shell's
469 484 # execution count.
470 485 output_hist = Dict()
471 486 # The text/plain repr of outputs.
472 487 output_hist_reprs = Dict()
473 488
474 489 # The number of the current session in the history database
475 490 session_number = Integer()
476
491
477 492 db_log_output = Bool(False,
478 493 help="Should the history database include output? (default: no)"
479 494 ).tag(config=True)
480 495 db_cache_size = Integer(0,
481 496 help="Write to database every x commands (higher values save disk access & power).\n"
482 497 "Values of 1 or less effectively disable caching."
483 498 ).tag(config=True)
484 499 # The input and output caches
485 500 db_input_cache = List()
486 501 db_output_cache = List()
487
502
488 503 # History saving in separate thread
489 504 save_thread = Instance('IPython.core.history.HistorySavingThread',
490 505 allow_none=True)
491 506 save_flag = Instance(threading.Event, allow_none=True)
492
507
493 508 # Private interface
494 509 # Variables used to store the three last inputs from the user. On each new
495 510 # history update, we populate the user's namespace with these, shifted as
496 511 # necessary.
497 512 _i00 = Unicode(u'')
498 513 _i = Unicode(u'')
499 514 _ii = Unicode(u'')
500 515 _iii = Unicode(u'')
501 516
502 517 # A regex matching all forms of the exit command, so that we don't store
503 518 # them in the history (it's annoying to rewind the first entry and land on
504 519 # an exit call).
505 520 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
506 521
507 522 def __init__(self, shell=None, config=None, **traits):
508 523 """Create a new history manager associated with a shell instance.
509 524 """
510 525 # We need a pointer back to the shell for various tasks.
511 526 super(HistoryManager, self).__init__(shell=shell, config=config,
512 527 **traits)
513 528 self.save_flag = threading.Event()
514 529 self.db_input_cache_lock = threading.Lock()
515 530 self.db_output_cache_lock = threading.Lock()
516
531
517 532 try:
518 533 self.new_session()
519 534 except sqlite3.OperationalError:
520 535 self.log.error("Failed to create history session in %s. History will not be saved.",
521 536 self.hist_file, exc_info=True)
522 537 self.hist_file = ':memory:'
523
538
524 539 if self.enabled and self.hist_file != ':memory:':
525 540 self.save_thread = HistorySavingThread(self)
526 541 self.save_thread.start()
527 542
528 543 def _get_hist_file_name(self, profile=None):
529 544 """Get default history file name based on the Shell's profile.
530
545
531 546 The profile parameter is ignored, but must exist for compatibility with
532 547 the parent class."""
533 548 profile_dir = self.shell.profile_dir.location
534 return os.path.join(profile_dir, 'history.sqlite')
535
549 return Path(profile_dir) / "history.sqlite"
550
536 551 @only_when_enabled
537 552 def new_session(self, conn=None):
538 553 """Get a new session number."""
539 554 if conn is None:
540 555 conn = self.db
541
556
542 557 with conn:
543 558 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
544 559 NULL, "") """, (datetime.datetime.now(),))
545 560 self.session_number = cur.lastrowid
546
561
547 562 def end_session(self):
548 563 """Close the database session, filling in the end time and line count."""
549 564 self.writeout_cache()
550 565 with self.db:
551 566 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
552 567 session==?""", (datetime.datetime.now(),
553 568 len(self.input_hist_parsed)-1, self.session_number))
554 569 self.session_number = 0
555
570
556 571 def name_session(self, name):
557 572 """Give the current session a name in the history database."""
558 573 with self.db:
559 574 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
560 575 (name, self.session_number))
561
576
562 577 def reset(self, new_session=True):
563 578 """Clear the session history, releasing all object references, and
564 579 optionally open a new session."""
565 580 self.output_hist.clear()
566 581 # The directory history can't be completely empty
567 self.dir_hist[:] = [os.getcwd()]
568
582 self.dir_hist[:] = [Path.cwd()]
583
569 584 if new_session:
570 585 if self.session_number:
571 586 self.end_session()
572 587 self.input_hist_parsed[:] = [""]
573 588 self.input_hist_raw[:] = [""]
574 589 self.new_session()
575 590
576 591 # ------------------------------
577 592 # Methods for retrieving history
578 593 # ------------------------------
579 594 def get_session_info(self, session=0):
580 595 """Get info about a session.
581 596
582 597 Parameters
583 598 ----------
584 599
585 600 session : int
586 601 Session number to retrieve. The current session is 0, and negative
587 602 numbers count back from current session, so -1 is the previous session.
588 603
589 604 Returns
590 605 -------
591
606
592 607 session_id : int
593 608 Session ID number
594 609 start : datetime
595 610 Timestamp for the start of the session.
596 611 end : datetime
597 612 Timestamp for the end of the session, or None if IPython crashed.
598 613 num_cmds : int
599 614 Number of commands run, or None if IPython crashed.
600 615 remark : unicode
601 616 A manually set description.
602 617 """
603 618 if session <= 0:
604 619 session += self.session_number
605 620
606 621 return super(HistoryManager, self).get_session_info(session=session)
607 622
608 623 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
609 624 """Get input and output history from the current session. Called by
610 625 get_range, and takes similar parameters."""
611 626 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
612
627
613 628 n = len(input_hist)
614 629 if start < 0:
615 630 start += n
616 631 if not stop or (stop > n):
617 632 stop = n
618 633 elif stop < 0:
619 634 stop += n
620
635
621 636 for i in range(start, stop):
622 637 if output:
623 638 line = (input_hist[i], self.output_hist_reprs.get(i))
624 639 else:
625 640 line = input_hist[i]
626 641 yield (0, i, line)
627
642
628 643 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
629 644 """Retrieve input by session.
630
645
631 646 Parameters
632 647 ----------
633 648 session : int
634 649 Session number to retrieve. The current session is 0, and negative
635 650 numbers count back from current session, so -1 is previous session.
636 651 start : int
637 652 First line to retrieve.
638 653 stop : int
639 654 End of line range (excluded from output itself). If None, retrieve
640 655 to the end of the session.
641 656 raw : bool
642 657 If True, return untranslated input
643 658 output : bool
644 659 If True, attempt to include output. This will be 'real' Python
645 660 objects for the current session, or text reprs from previous
646 661 sessions if db_log_output was enabled at the time. Where no output
647 662 is found, None is used.
648
663
649 664 Returns
650 665 -------
651 666 entries
652 667 An iterator over the desired lines. Each line is a 3-tuple, either
653 668 (session, line, input) if output is False, or
654 669 (session, line, (input, output)) if output is True.
655 670 """
656 671 if session <= 0:
657 672 session += self.session_number
658 673 if session==self.session_number: # Current session
659 674 return self._get_range_session(start, stop, raw, output)
660 675 return super(HistoryManager, self).get_range(session, start, stop, raw,
661 676 output)
662 677
663 678 ## ----------------------------
664 679 ## Methods for storing history:
665 680 ## ----------------------------
666 681 def store_inputs(self, line_num, source, source_raw=None):
667 682 """Store source and raw input in history and create input cache
668 683 variables ``_i*``.
669 684
670 685 Parameters
671 686 ----------
672 687 line_num : int
673 688 The prompt number of this input.
674 689
675 690 source : str
676 691 Python input.
677 692
678 693 source_raw : str, optional
679 694 If given, this is the raw input without any IPython transformations
680 695 applied to it. If not given, ``source`` is used.
681 696 """
682 697 if source_raw is None:
683 698 source_raw = source
684 699 source = source.rstrip('\n')
685 700 source_raw = source_raw.rstrip('\n')
686 701
687 702 # do not store exit/quit commands
688 703 if self._exit_re.match(source_raw.strip()):
689 704 return
690 705
691 706 self.input_hist_parsed.append(source)
692 707 self.input_hist_raw.append(source_raw)
693 708
694 709 with self.db_input_cache_lock:
695 710 self.db_input_cache.append((line_num, source, source_raw))
696 711 # Trigger to flush cache and write to DB.
697 712 if len(self.db_input_cache) >= self.db_cache_size:
698 713 self.save_flag.set()
699 714
700 715 # update the auto _i variables
701 716 self._iii = self._ii
702 717 self._ii = self._i
703 718 self._i = self._i00
704 719 self._i00 = source_raw
705 720
706 721 # hackish access to user namespace to create _i1,_i2... dynamically
707 722 new_i = '_i%s' % line_num
708 723 to_main = {'_i': self._i,
709 724 '_ii': self._ii,
710 725 '_iii': self._iii,
711 726 new_i : self._i00 }
712
727
713 728 if self.shell is not None:
714 729 self.shell.push(to_main, interactive=False)
715 730
716 731 def store_output(self, line_num):
717 732 """If database output logging is enabled, this saves all the
718 733 outputs from the indicated prompt number to the database. It's
719 734 called by run_cell after code has been executed.
720 735
721 736 Parameters
722 737 ----------
723 738 line_num : int
724 739 The line number from which to save outputs
725 740 """
726 741 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
727 742 return
728 743 output = self.output_hist_reprs[line_num]
729 744
730 745 with self.db_output_cache_lock:
731 746 self.db_output_cache.append((line_num, output))
732 747 if self.db_cache_size <= 1:
733 748 self.save_flag.set()
734 749
735 750 def _writeout_input_cache(self, conn):
736 751 with conn:
737 752 for line in self.db_input_cache:
738 753 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
739 754 (self.session_number,)+line)
740 755
741 756 def _writeout_output_cache(self, conn):
742 757 with conn:
743 758 for line in self.db_output_cache:
744 759 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
745 760 (self.session_number,)+line)
746 761
747 762 @only_when_enabled
748 763 def writeout_cache(self, conn=None):
749 764 """Write any entries in the cache to the database."""
750 765 if conn is None:
751 766 conn = self.db
752 767
753 768 with self.db_input_cache_lock:
754 769 try:
755 770 self._writeout_input_cache(conn)
756 771 except sqlite3.IntegrityError:
757 772 self.new_session(conn)
758 773 print("ERROR! Session/line number was not unique in",
759 774 "database. History logging moved to new session",
760 775 self.session_number)
761 776 try:
762 777 # Try writing to the new session. If this fails, don't
763 778 # recurse
764 779 self._writeout_input_cache(conn)
765 780 except sqlite3.IntegrityError:
766 781 pass
767 782 finally:
768 783 self.db_input_cache = []
769 784
770 785 with self.db_output_cache_lock:
771 786 try:
772 787 self._writeout_output_cache(conn)
773 788 except sqlite3.IntegrityError:
774 789 print("!! Session/line number for output was not unique",
775 790 "in database. Output will not be stored.")
776 791 finally:
777 792 self.db_output_cache = []
778 793
779 794
780 795 class HistorySavingThread(threading.Thread):
781 796 """This thread takes care of writing history to the database, so that
782 797 the UI isn't held up while that happens.
783 798
784 799 It waits for the HistoryManager's save_flag to be set, then writes out
785 800 the history cache. The main thread is responsible for setting the flag when
786 801 the cache size reaches a defined threshold."""
787 802 daemon = True
788 803 stop_now = False
789 804 enabled = True
790 805 def __init__(self, history_manager):
791 806 super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread")
792 807 self.history_manager = history_manager
793 808 self.enabled = history_manager.enabled
794 809 atexit.register(self.stop)
795 810
796 811 @only_when_enabled
797 812 def run(self):
798 813 # We need a separate db connection per thread:
799 814 try:
800 self.db = sqlite3.connect(self.history_manager.hist_file,
801 **self.history_manager.connection_options
815 self.db = sqlite3.connect(
816 str(self.history_manager.hist_file),
817 **self.history_manager.connection_options
802 818 )
803 819 while True:
804 820 self.history_manager.save_flag.wait()
805 821 if self.stop_now:
806 822 self.db.close()
807 823 return
808 824 self.history_manager.save_flag.clear()
809 825 self.history_manager.writeout_cache(self.db)
810 826 except Exception as e:
811 827 print(("The history saving thread hit an unexpected error (%s)."
812 828 "History will not be written to the database.") % repr(e))
813 829
814 830 def stop(self):
815 831 """This can be called from the main thread to safely stop this thread.
816 832
817 833 Note that it does not attempt to write out remaining history before
818 834 exiting. That should be done by calling the HistoryManager's
819 835 end_session method."""
820 836 self.stop_now = True
821 837 self.history_manager.save_flag.set()
822 838 self.join()
823 839
824 840
825 841 # To match, e.g. ~5/8-~2/3
826 842 range_re = re.compile(r"""
827 843 ((?P<startsess>~?\d+)/)?
828 844 (?P<start>\d+)?
829 845 ((?P<sep>[\-:])
830 846 ((?P<endsess>~?\d+)/)?
831 847 (?P<end>\d+))?
832 848 $""", re.VERBOSE)
833 849
834 850
835 851 def extract_hist_ranges(ranges_str):
836 852 """Turn a string of history ranges into 3-tuples of (session, start, stop).
837 853
838 854 Examples
839 855 --------
840 856 >>> list(extract_hist_ranges("~8/5-~7/4 2"))
841 857 [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
842 858 """
843 859 for range_str in ranges_str.split():
844 860 rmatch = range_re.match(range_str)
845 861 if not rmatch:
846 862 continue
847 863 start = rmatch.group("start")
848 864 if start:
849 865 start = int(start)
850 866 end = rmatch.group("end")
851 867 # If no end specified, get (a, a + 1)
852 868 end = int(end) if end else start + 1
853 869 else: # start not specified
854 870 if not rmatch.group('startsess'): # no startsess
855 871 continue
856 872 start = 1
857 873 end = None # provide the entire session hist
858 874
859 875 if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3]
860 876 end += 1
861 877 startsess = rmatch.group("startsess") or "0"
862 878 endsess = rmatch.group("endsess") or startsess
863 879 startsess = int(startsess.replace("~","-"))
864 880 endsess = int(endsess.replace("~","-"))
865 881 assert endsess >= startsess, "start session must be earlier than end session"
866 882
867 883 if endsess == startsess:
868 884 yield (startsess, start, end)
869 885 continue
870 886 # Multiple sessions in one range:
871 887 yield (startsess, start, None)
872 888 for sess in range(startsess+1, endsess):
873 889 yield (sess, 1, None)
874 890 yield (endsess, 1, end)
875 891
876 892
877 893 def _format_lineno(session, line):
878 894 """Helper function to format line numbers properly."""
879 895 if session == 0:
880 896 return str(line)
881 897 return "%s#%s" % (session, line)
@@ -1,214 +1,215
1 1 # coding: utf-8
2 2 """Tests for the IPython tab-completion machinery.
3 3 """
4 4 #-----------------------------------------------------------------------------
5 5 # Module imports
6 6 #-----------------------------------------------------------------------------
7 7
8 8 # stdlib
9 9 import io
10 import os
10 from pathlib import Path
11 11 import sys
12 12 import tempfile
13 13 from datetime import datetime
14 14 import sqlite3
15 15
16 16 # third party
17 17 import nose.tools as nt
18 18
19 19 # our own packages
20 20 from traitlets.config.loader import Config
21 21 from IPython.utils.tempdir import TemporaryDirectory
22 22 from IPython.core.history import HistoryManager, extract_hist_ranges
23 23 from IPython.testing.decorators import skipif
24 24
25 25 def test_proper_default_encoding():
26 26 nt.assert_equal(sys.getdefaultencoding(), "utf-8")
27 27
28 28 @skipif(sqlite3.sqlite_version_info > (3,24,0))
29 29 def test_history():
30 30 ip = get_ipython()
31 31 with TemporaryDirectory() as tmpdir:
32 tmp_path = Path(tmpdir)
32 33 hist_manager_ori = ip.history_manager
33 hist_file = os.path.join(tmpdir, 'history.sqlite')
34 hist_file = tmp_path / "history.sqlite"
34 35 try:
35 36 ip.history_manager = HistoryManager(shell=ip, hist_file=hist_file)
36 37 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
37 38 for i, h in enumerate(hist, start=1):
38 39 ip.history_manager.store_inputs(i, h)
39 40
40 41 ip.history_manager.db_log_output = True
41 42 # Doesn't match the input, but we'll just check it's stored.
42 43 ip.history_manager.output_hist_reprs[3] = "spam"
43 44 ip.history_manager.store_output(3)
44 45
45 46 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
46 47
47 48 # Detailed tests for _get_range_session
48 49 grs = ip.history_manager._get_range_session
49 50 nt.assert_equal(list(grs(start=2,stop=-1)), list(zip([0], [2], hist[1:-1])))
50 51 nt.assert_equal(list(grs(start=-2)), list(zip([0,0], [2,3], hist[-2:])))
51 52 nt.assert_equal(list(grs(output=True)), list(zip([0,0,0], [1,2,3], zip(hist, [None,None,'spam']))))
52 53
53 54 # Check whether specifying a range beyond the end of the current
54 55 # session results in an error (gh-804)
55 56 ip.magic('%hist 2-500')
56 57
57 58 # Check that we can write non-ascii characters to a file
58 ip.magic("%%hist -f %s" % os.path.join(tmpdir, "test1"))
59 ip.magic("%%hist -pf %s" % os.path.join(tmpdir, "test2"))
60 ip.magic("%%hist -nf %s" % os.path.join(tmpdir, "test3"))
61 ip.magic("%%save %s 1-10" % os.path.join(tmpdir, "test4"))
59 ip.magic("%%hist -f %s" % (tmp_path / "test1"))
60 ip.magic("%%hist -pf %s" % (tmp_path / "test2"))
61 ip.magic("%%hist -nf %s" % (tmp_path / "test3"))
62 ip.magic("%%save %s 1-10" % (tmp_path / "test4"))
62 63
63 64 # New session
64 65 ip.history_manager.reset()
65 66 newcmds = [u"z=5",
66 67 u"class X(object):\n pass",
67 68 u"k='p'",
68 69 u"z=5"]
69 70 for i, cmd in enumerate(newcmds, start=1):
70 71 ip.history_manager.store_inputs(i, cmd)
71 72 gothist = ip.history_manager.get_range(start=1, stop=4)
72 73 nt.assert_equal(list(gothist), list(zip([0,0,0],[1,2,3], newcmds)))
73 74 # Previous session:
74 75 gothist = ip.history_manager.get_range(-1, 1, 4)
75 76 nt.assert_equal(list(gothist), list(zip([1,1,1],[1,2,3], hist)))
76 77
77 78 newhist = [(2, i, c) for (i, c) in enumerate(newcmds, 1)]
78 79
79 80 # Check get_hist_tail
80 81 gothist = ip.history_manager.get_tail(5, output=True,
81 82 include_latest=True)
82 83 expected = [(1, 3, (hist[-1], "spam"))] \
83 84 + [(s, n, (c, None)) for (s, n, c) in newhist]
84 85 nt.assert_equal(list(gothist), expected)
85 86
86 87 gothist = ip.history_manager.get_tail(2)
87 88 expected = newhist[-3:-1]
88 89 nt.assert_equal(list(gothist), expected)
89 90
90 91 # Check get_hist_search
91 92
92 93 gothist = ip.history_manager.search("*test*")
93 94 nt.assert_equal(list(gothist), [(1,2,hist[1])] )
94 95
95 96 gothist = ip.history_manager.search("*=*")
96 97 nt.assert_equal(list(gothist),
97 98 [(1, 1, hist[0]),
98 99 (1, 2, hist[1]),
99 100 (1, 3, hist[2]),
100 101 newhist[0],
101 102 newhist[2],
102 103 newhist[3]])
103 104
104 105 gothist = ip.history_manager.search("*=*", n=4)
105 106 nt.assert_equal(list(gothist),
106 107 [(1, 3, hist[2]),
107 108 newhist[0],
108 109 newhist[2],
109 110 newhist[3]])
110 111
111 112 gothist = ip.history_manager.search("*=*", unique=True)
112 113 nt.assert_equal(list(gothist),
113 114 [(1, 1, hist[0]),
114 115 (1, 2, hist[1]),
115 116 (1, 3, hist[2]),
116 117 newhist[2],
117 118 newhist[3]])
118 119
119 120 gothist = ip.history_manager.search("*=*", unique=True, n=3)
120 121 nt.assert_equal(list(gothist),
121 122 [(1, 3, hist[2]),
122 123 newhist[2],
123 124 newhist[3]])
124 125
125 126 gothist = ip.history_manager.search("b*", output=True)
126 127 nt.assert_equal(list(gothist), [(1,3,(hist[2],"spam"))] )
127 128
128 129 # Cross testing: check that magic %save can get previous session.
129 testfilename = os.path.realpath(os.path.join(tmpdir, "test.py"))
130 ip.magic("save " + testfilename + " ~1/1-3")
130 testfilename = (tmp_path / "test.py").resolve()
131 ip.magic("save " + str(testfilename) + " ~1/1-3")
131 132 with io.open(testfilename, encoding='utf-8') as testfile:
132 133 nt.assert_equal(testfile.read(),
133 134 u"# coding: utf-8\n" + u"\n".join(hist)+u"\n")
134 135
135 136 # Duplicate line numbers - check that it doesn't crash, and
136 137 # gets a new session
137 138 ip.history_manager.store_inputs(1, "rogue")
138 139 ip.history_manager.writeout_cache()
139 140 nt.assert_equal(ip.history_manager.session_number, 3)
140 141 finally:
141 142 # Ensure saving thread is shut down before we try to clean up the files
142 143 ip.history_manager.save_thread.stop()
143 144 # Forcibly close database rather than relying on garbage collection
144 145 ip.history_manager.db.close()
145 146 # Restore history manager
146 147 ip.history_manager = hist_manager_ori
147 148
148 149
149 150 def test_extract_hist_ranges():
150 151 instr = "1 2/3 ~4/5-6 ~4/7-~4/9 ~9/2-~7/5 ~10/"
151 152 expected = [(0, 1, 2), # 0 == current session
152 153 (2, 3, 4),
153 154 (-4, 5, 7),
154 155 (-4, 7, 10),
155 156 (-9, 2, None), # None == to end
156 157 (-8, 1, None),
157 158 (-7, 1, 6),
158 159 (-10, 1, None)]
159 160 actual = list(extract_hist_ranges(instr))
160 161 nt.assert_equal(actual, expected)
161 162
162 163 def test_magic_rerun():
163 164 """Simple test for %rerun (no args -> rerun last line)"""
164 165 ip = get_ipython()
165 166 ip.run_cell("a = 10", store_history=True)
166 167 ip.run_cell("a += 1", store_history=True)
167 168 nt.assert_equal(ip.user_ns["a"], 11)
168 169 ip.run_cell("%rerun", store_history=True)
169 170 nt.assert_equal(ip.user_ns["a"], 12)
170 171
171 172 def test_timestamp_type():
172 173 ip = get_ipython()
173 174 info = ip.history_manager.get_session_info()
174 175 nt.assert_true(isinstance(info[1], datetime))
175 176
176 177 def test_hist_file_config():
177 178 cfg = Config()
178 179 tfile = tempfile.NamedTemporaryFile(delete=False)
179 cfg.HistoryManager.hist_file = tfile.name
180 cfg.HistoryManager.hist_file = Path(tfile.name)
180 181 try:
181 182 hm = HistoryManager(shell=get_ipython(), config=cfg)
182 183 nt.assert_equal(hm.hist_file, cfg.HistoryManager.hist_file)
183 184 finally:
184 185 try:
185 os.remove(tfile.name)
186 Path(tfile.name).unlink()
186 187 except OSError:
187 188 # same catch as in testing.tools.TempFileMixin
188 189 # On Windows, even though we close the file, we still can't
189 190 # delete it. I have no clue why
190 191 pass
191 192
192 193 def test_histmanager_disabled():
193 194 """Ensure that disabling the history manager doesn't create a database."""
194 195 cfg = Config()
195 196 cfg.HistoryAccessor.enabled = False
196 197
197 198 ip = get_ipython()
198 199 with TemporaryDirectory() as tmpdir:
199 200 hist_manager_ori = ip.history_manager
200 hist_file = os.path.join(tmpdir, 'history.sqlite')
201 hist_file = Path(tmpdir) / "history.sqlite"
201 202 cfg.HistoryManager.hist_file = hist_file
202 203 try:
203 204 ip.history_manager = HistoryManager(shell=ip, config=cfg)
204 205 hist = [u'a=1', u'def f():\n test = 1\n return test', u"b='€Æ¾÷ß'"]
205 206 for i, h in enumerate(hist, start=1):
206 207 ip.history_manager.store_inputs(i, h)
207 208 nt.assert_equal(ip.history_manager.input_hist_raw, [''] + hist)
208 209 ip.history_manager.reset()
209 210 ip.history_manager.end_session()
210 211 finally:
211 212 ip.history_manager = hist_manager_ori
212 213
213 214 # hist_file should not be created
214 nt.assert_false(os.path.exists(hist_file))
215 nt.assert_false(hist_file.exists())
@@ -1,471 +1,472
1 1 """Generic testing tools.
2 2
3 3 Authors
4 4 -------
5 5 - Fernando Perez <Fernando.Perez@berkeley.edu>
6 6 """
7 7
8 8
9 9 # Copyright (c) IPython Development Team.
10 10 # Distributed under the terms of the Modified BSD License.
11 11
12 12 import os
13 from pathlib import Path
13 14 import re
14 15 import sys
15 16 import tempfile
16 17 import unittest
17 18
18 19 from contextlib import contextmanager
19 20 from io import StringIO
20 21 from subprocess import Popen, PIPE
21 22 from unittest.mock import patch
22 23
23 24 try:
24 25 # These tools are used by parts of the runtime, so we make the nose
25 26 # dependency optional at this point. Nose is a hard dependency to run the
26 27 # test suite, but NOT to use ipython itself.
27 28 import nose.tools as nt
28 29 has_nose = True
29 30 except ImportError:
30 31 has_nose = False
31 32
32 33 from traitlets.config.loader import Config
33 34 from IPython.utils.process import get_output_error_code
34 35 from IPython.utils.text import list_strings
35 36 from IPython.utils.io import temp_pyfile, Tee
36 37 from IPython.utils import py3compat
37 38
38 39 from . import decorators as dec
39 40 from . import skipdoctest
40 41
41 42
42 43 # The docstring for full_path doctests differently on win32 (different path
43 44 # separator) so just skip the doctest there. The example remains informative.
44 45 doctest_deco = skipdoctest.skip_doctest if sys.platform == 'win32' else dec.null_deco
45 46
46 47 @doctest_deco
47 48 def full_path(startPath,files):
48 49 """Make full paths for all the listed files, based on startPath.
49 50
50 51 Only the base part of startPath is kept, since this routine is typically
51 52 used with a script's ``__file__`` variable as startPath. The base of startPath
52 53 is then prepended to all the listed files, forming the output list.
53 54
54 55 Parameters
55 56 ----------
56 57 startPath : string
57 58 Initial path to use as the base for the results. This path is split
58 59 using os.path.split() and only its first component is kept.
59 60
60 61 files : string or list
61 62 One or more files.
62 63
63 64 Examples
64 65 --------
65 66
66 67 >>> full_path('/foo/bar.py',['a.txt','b.txt'])
67 68 ['/foo/a.txt', '/foo/b.txt']
68 69
69 70 >>> full_path('/foo',['a.txt','b.txt'])
70 71 ['/a.txt', '/b.txt']
71 72
72 73 If a single file is given, the output is still a list::
73 74
74 75 >>> full_path('/foo','a.txt')
75 76 ['/a.txt']
76 77 """
77 78
78 79 files = list_strings(files)
79 80 base = os.path.split(startPath)[0]
80 81 return [ os.path.join(base,f) for f in files ]
81 82
82 83
83 84 def parse_test_output(txt):
84 85 """Parse the output of a test run and return errors, failures.
85 86
86 87 Parameters
87 88 ----------
88 89 txt : str
89 90 Text output of a test run, assumed to contain a line of one of the
90 91 following forms::
91 92
92 93 'FAILED (errors=1)'
93 94 'FAILED (failures=1)'
94 95 'FAILED (errors=1, failures=1)'
95 96
96 97 Returns
97 98 -------
98 99 nerr, nfail
99 100 number of errors and failures.
100 101 """
101 102
102 103 err_m = re.search(r'^FAILED \(errors=(\d+)\)', txt, re.MULTILINE)
103 104 if err_m:
104 105 nerr = int(err_m.group(1))
105 106 nfail = 0
106 107 return nerr, nfail
107 108
108 109 fail_m = re.search(r'^FAILED \(failures=(\d+)\)', txt, re.MULTILINE)
109 110 if fail_m:
110 111 nerr = 0
111 112 nfail = int(fail_m.group(1))
112 113 return nerr, nfail
113 114
114 115 both_m = re.search(r'^FAILED \(errors=(\d+), failures=(\d+)\)', txt,
115 116 re.MULTILINE)
116 117 if both_m:
117 118 nerr = int(both_m.group(1))
118 119 nfail = int(both_m.group(2))
119 120 return nerr, nfail
120 121
121 122 # If the input didn't match any of these forms, assume no error/failures
122 123 return 0, 0
123 124
124 125
125 126 # So nose doesn't think this is a test
126 127 parse_test_output.__test__ = False
127 128
128 129
129 130 def default_argv():
130 131 """Return a valid default argv for creating testing instances of ipython"""
131 132
132 133 return ['--quick', # so no config file is loaded
133 134 # Other defaults to minimize side effects on stdout
134 135 '--colors=NoColor', '--no-term-title','--no-banner',
135 136 '--autocall=0']
136 137
137 138
138 139 def default_config():
139 140 """Return a config object with good defaults for testing."""
140 141 config = Config()
141 142 config.TerminalInteractiveShell.colors = 'NoColor'
142 143 config.TerminalTerminalInteractiveShell.term_title = False,
143 144 config.TerminalInteractiveShell.autocall = 0
144 145 f = tempfile.NamedTemporaryFile(suffix=u'test_hist.sqlite', delete=False)
145 config.HistoryManager.hist_file = f.name
146 config.HistoryManager.hist_file = Path(f.name)
146 147 f.close()
147 148 config.HistoryManager.db_cache_size = 10000
148 149 return config
149 150
150 151
151 152 def get_ipython_cmd(as_string=False):
152 153 """
153 154 Return appropriate IPython command line name. By default, this will return
154 155 a list that can be used with subprocess.Popen, for example, but passing
155 156 `as_string=True` allows for returning the IPython command as a string.
156 157
157 158 Parameters
158 159 ----------
159 160 as_string: bool
160 161 Flag to allow to return the command as a string.
161 162 """
162 163 ipython_cmd = [sys.executable, "-m", "IPython"]
163 164
164 165 if as_string:
165 166 ipython_cmd = " ".join(ipython_cmd)
166 167
167 168 return ipython_cmd
168 169
169 170 def ipexec(fname, options=None, commands=()):
170 171 """Utility to call 'ipython filename'.
171 172
172 173 Starts IPython with a minimal and safe configuration to make startup as fast
173 174 as possible.
174 175
175 176 Note that this starts IPython in a subprocess!
176 177
177 178 Parameters
178 179 ----------
179 180 fname : str
180 181 Name of file to be executed (should have .py or .ipy extension).
181 182
182 183 options : optional, list
183 184 Extra command-line flags to be passed to IPython.
184 185
185 186 commands : optional, list
186 187 Commands to send in on stdin
187 188
188 189 Returns
189 190 -------
190 191 ``(stdout, stderr)`` of ipython subprocess.
191 192 """
192 193 if options is None: options = []
193 194
194 195 cmdargs = default_argv() + options
195 196
196 197 test_dir = os.path.dirname(__file__)
197 198
198 199 ipython_cmd = get_ipython_cmd()
199 200 # Absolute path for filename
200 201 full_fname = os.path.join(test_dir, fname)
201 202 full_cmd = ipython_cmd + cmdargs + ['--', full_fname]
202 203 env = os.environ.copy()
203 204 # FIXME: ignore all warnings in ipexec while we have shims
204 205 # should we keep suppressing warnings here, even after removing shims?
205 206 env['PYTHONWARNINGS'] = 'ignore'
206 207 # env.pop('PYTHONWARNINGS', None) # Avoid extraneous warnings appearing on stderr
207 208 for k, v in env.items():
208 209 # Debug a bizarre failure we've seen on Windows:
209 210 # TypeError: environment can only contain strings
210 211 if not isinstance(v, str):
211 212 print(k, v)
212 213 p = Popen(full_cmd, stdout=PIPE, stderr=PIPE, stdin=PIPE, env=env)
213 214 out, err = p.communicate(input=py3compat.encode('\n'.join(commands)) or None)
214 215 out, err = py3compat.decode(out), py3compat.decode(err)
215 216 # `import readline` causes 'ESC[?1034h' to be output sometimes,
216 217 # so strip that out before doing comparisons
217 218 if out:
218 219 out = re.sub(r'\x1b\[[^h]+h', '', out)
219 220 return out, err
220 221
221 222
222 223 def ipexec_validate(fname, expected_out, expected_err='',
223 224 options=None, commands=()):
224 225 """Utility to call 'ipython filename' and validate output/error.
225 226
226 227 This function raises an AssertionError if the validation fails.
227 228
228 229 Note that this starts IPython in a subprocess!
229 230
230 231 Parameters
231 232 ----------
232 233 fname : str
233 234 Name of the file to be executed (should have .py or .ipy extension).
234 235
235 236 expected_out : str
236 237 Expected stdout of the process.
237 238
238 239 expected_err : optional, str
239 240 Expected stderr of the process.
240 241
241 242 options : optional, list
242 243 Extra command-line flags to be passed to IPython.
243 244
244 245 Returns
245 246 -------
246 247 None
247 248 """
248 249
249 250 import nose.tools as nt
250 251
251 252 out, err = ipexec(fname, options, commands)
252 253 #print 'OUT', out # dbg
253 254 #print 'ERR', err # dbg
254 255 # If there are any errors, we must check those before stdout, as they may be
255 256 # more informative than simply having an empty stdout.
256 257 if err:
257 258 if expected_err:
258 259 nt.assert_equal("\n".join(err.strip().splitlines()), "\n".join(expected_err.strip().splitlines()))
259 260 else:
260 261 raise ValueError('Running file %r produced error: %r' %
261 262 (fname, err))
262 263 # If no errors or output on stderr was expected, match stdout
263 264 nt.assert_equal("\n".join(out.strip().splitlines()), "\n".join(expected_out.strip().splitlines()))
264 265
265 266
266 267 class TempFileMixin(unittest.TestCase):
267 268 """Utility class to create temporary Python/IPython files.
268 269
269 270 Meant as a mixin class for test cases."""
270 271
271 272 def mktmp(self, src, ext='.py'):
272 273 """Make a valid python temp file."""
273 274 fname = temp_pyfile(src, ext)
274 275 if not hasattr(self, 'tmps'):
275 276 self.tmps=[]
276 277 self.tmps.append(fname)
277 278 self.fname = fname
278 279
279 280 def tearDown(self):
280 281 # If the tmpfile wasn't made because of skipped tests, like in
281 282 # win32, there's nothing to cleanup.
282 283 if hasattr(self, 'tmps'):
283 284 for fname in self.tmps:
284 285 # If the tmpfile wasn't made because of skipped tests, like in
285 286 # win32, there's nothing to cleanup.
286 287 try:
287 288 os.unlink(fname)
288 289 except:
289 290 # On Windows, even though we close the file, we still can't
290 291 # delete it. I have no clue why
291 292 pass
292 293
293 294 def __enter__(self):
294 295 return self
295 296
296 297 def __exit__(self, exc_type, exc_value, traceback):
297 298 self.tearDown()
298 299
299 300
300 301 pair_fail_msg = ("Testing {0}\n\n"
301 302 "In:\n"
302 303 " {1!r}\n"
303 304 "Expected:\n"
304 305 " {2!r}\n"
305 306 "Got:\n"
306 307 " {3!r}\n")
307 308 def check_pairs(func, pairs):
308 309 """Utility function for the common case of checking a function with a
309 310 sequence of input/output pairs.
310 311
311 312 Parameters
312 313 ----------
313 314 func : callable
314 315 The function to be tested. Should accept a single argument.
315 316 pairs : iterable
316 317 A list of (input, expected_output) tuples.
317 318
318 319 Returns
319 320 -------
320 321 None. Raises an AssertionError if any output does not match the expected
321 322 value.
322 323 """
323 324 name = getattr(func, "func_name", getattr(func, "__name__", "<unknown>"))
324 325 for inp, expected in pairs:
325 326 out = func(inp)
326 327 assert out == expected, pair_fail_msg.format(name, inp, expected, out)
327 328
328 329
329 330 MyStringIO = StringIO
330 331
331 332 _re_type = type(re.compile(r''))
332 333
333 334 notprinted_msg = """Did not find {0!r} in printed output (on {1}):
334 335 -------
335 336 {2!s}
336 337 -------
337 338 """
338 339
339 340 class AssertPrints(object):
340 341 """Context manager for testing that code prints certain text.
341 342
342 343 Examples
343 344 --------
344 345 >>> with AssertPrints("abc", suppress=False):
345 346 ... print("abcd")
346 347 ... print("def")
347 348 ...
348 349 abcd
349 350 def
350 351 """
351 352 def __init__(self, s, channel='stdout', suppress=True):
352 353 self.s = s
353 354 if isinstance(self.s, (str, _re_type)):
354 355 self.s = [self.s]
355 356 self.channel = channel
356 357 self.suppress = suppress
357 358
358 359 def __enter__(self):
359 360 self.orig_stream = getattr(sys, self.channel)
360 361 self.buffer = MyStringIO()
361 362 self.tee = Tee(self.buffer, channel=self.channel)
362 363 setattr(sys, self.channel, self.buffer if self.suppress else self.tee)
363 364
364 365 def __exit__(self, etype, value, traceback):
365 366 try:
366 367 if value is not None:
367 368 # If an error was raised, don't check anything else
368 369 return False
369 370 self.tee.flush()
370 371 setattr(sys, self.channel, self.orig_stream)
371 372 printed = self.buffer.getvalue()
372 373 for s in self.s:
373 374 if isinstance(s, _re_type):
374 375 assert s.search(printed), notprinted_msg.format(s.pattern, self.channel, printed)
375 376 else:
376 377 assert s in printed, notprinted_msg.format(s, self.channel, printed)
377 378 return False
378 379 finally:
379 380 self.tee.close()
380 381
381 382 printed_msg = """Found {0!r} in printed output (on {1}):
382 383 -------
383 384 {2!s}
384 385 -------
385 386 """
386 387
387 388 class AssertNotPrints(AssertPrints):
388 389 """Context manager for checking that certain output *isn't* produced.
389 390
390 391 Counterpart of AssertPrints"""
391 392 def __exit__(self, etype, value, traceback):
392 393 try:
393 394 if value is not None:
394 395 # If an error was raised, don't check anything else
395 396 self.tee.close()
396 397 return False
397 398 self.tee.flush()
398 399 setattr(sys, self.channel, self.orig_stream)
399 400 printed = self.buffer.getvalue()
400 401 for s in self.s:
401 402 if isinstance(s, _re_type):
402 403 assert not s.search(printed),printed_msg.format(
403 404 s.pattern, self.channel, printed)
404 405 else:
405 406 assert s not in printed, printed_msg.format(
406 407 s, self.channel, printed)
407 408 return False
408 409 finally:
409 410 self.tee.close()
410 411
411 412 @contextmanager
412 413 def mute_warn():
413 414 from IPython.utils import warn
414 415 save_warn = warn.warn
415 416 warn.warn = lambda *a, **kw: None
416 417 try:
417 418 yield
418 419 finally:
419 420 warn.warn = save_warn
420 421
421 422 @contextmanager
422 423 def make_tempfile(name):
423 424 """ Create an empty, named, temporary file for the duration of the context.
424 425 """
425 426 open(name, 'w').close()
426 427 try:
427 428 yield
428 429 finally:
429 430 os.unlink(name)
430 431
431 432 def fake_input(inputs):
432 433 """Temporarily replace the input() function to return the given values
433 434
434 435 Use as a context manager:
435 436
436 437 with fake_input(['result1', 'result2']):
437 438 ...
438 439
439 440 Values are returned in order. If input() is called again after the last value
440 441 was used, EOFError is raised.
441 442 """
442 443 it = iter(inputs)
443 444 def mock_input(prompt=''):
444 445 try:
445 446 return next(it)
446 447 except StopIteration as e:
447 448 raise EOFError('No more inputs given') from e
448 449
449 450 return patch('builtins.input', mock_input)
450 451
451 452 def help_output_test(subcommand=''):
452 453 """test that `ipython [subcommand] -h` works"""
453 454 cmd = get_ipython_cmd() + [subcommand, '-h']
454 455 out, err, rc = get_output_error_code(cmd)
455 456 nt.assert_equal(rc, 0, err)
456 457 nt.assert_not_in("Traceback", err)
457 458 nt.assert_in("Options", out)
458 459 nt.assert_in("--help-all", out)
459 460 return out, err
460 461
461 462
462 463 def help_all_output_test(subcommand=''):
463 464 """test that `ipython [subcommand] --help-all` works"""
464 465 cmd = get_ipython_cmd() + [subcommand, '--help-all']
465 466 out, err, rc = get_output_error_code(cmd)
466 467 nt.assert_equal(rc, 0, err)
467 468 nt.assert_not_in("Traceback", err)
468 469 nt.assert_in("Options", out)
469 470 nt.assert_in("Class", out)
470 471 return out, err
471 472
General Comments 0
You need to be logged in to leave comments. Login now