From 3df473c58741eadf468cfb761ac9b7d9db576be4 2022-08-30 09:44:39 From: Matthias Bussonnier Date: 2022-08-30 09:44:39 Subject: [PATCH] Merge pull request #13719 from balval/gh-13666 Fix HistoryAccessor.get_tail bug (#13666) --- diff --git a/IPython/core/history.py b/IPython/core/history.py index 9b0b2cb..1a89060 100644 --- a/IPython/core/history.py +++ b/IPython/core/history.py @@ -202,7 +202,6 @@ class HistoryAccessor(HistoryAccessorBase): config : :class:`~traitlets.config.loader.Config` Config object. hist_file can also be set through this. """ - # We need a pointer back to the shell for various tasks. super(HistoryAccessor, self).__init__(**traits) # defer setting hist_file from kwarg until after init, # otherwise the default kwarg value would clobber any value @@ -344,11 +343,6 @@ class HistoryAccessor(HistoryAccessorBase): def get_tail(self, n=10, raw=True, output=False, include_latest=False): """Get the last n lines from the history database. - Most recent entry last. - - Completion will be reordered so that that the last ones are when - possible from current session. - Parameters ---------- n : int @@ -367,31 +361,12 @@ class HistoryAccessor(HistoryAccessorBase): self.writeout_cache() if not include_latest: n += 1 - # cursor/line/entry - this_cur = list( - self._run_sql( - "WHERE session == ? ORDER BY line DESC LIMIT ? ", - (self.session_number, n), - raw=raw, - output=output, - ) - ) - other_cur = list( - self._run_sql( - "WHERE session != ? ORDER BY session DESC, line DESC LIMIT ?", - (self.session_number, n), - raw=raw, - output=output, - ) + cur = self._run_sql( + "ORDER BY session DESC, line DESC LIMIT ?", (n,), raw=raw, output=output ) - - everything = this_cur + other_cur - - everything = everything[:n] - if not include_latest: - return list(everything)[:0:-1] - return list(everything)[::-1] + return reversed(list(cur)[1:]) + return reversed(list(cur)) @catch_corrupt_db def search(self, pattern="*", raw=True, search_raw=True, @@ -560,7 +535,6 @@ class HistoryManager(HistoryAccessor): def __init__(self, shell=None, config=None, **traits): """Create a new history manager associated with a shell instance. """ - # We need a pointer back to the shell for various tasks. super(HistoryManager, self).__init__(shell=shell, config=config, **traits) self.save_flag = threading.Event() @@ -656,6 +630,59 @@ class HistoryManager(HistoryAccessor): return super(HistoryManager, self).get_session_info(session=session) + @catch_corrupt_db + def get_tail(self, n=10, raw=True, output=False, include_latest=False): + """Get the last n lines from the history database. + + Most recent entry last. + + Completion will be reordered so that that the last ones are when + possible from current session. + + Parameters + ---------- + n : int + The number of lines to get + raw, output : bool + See :meth:`get_range` + include_latest : bool + If False (default), n+1 lines are fetched, and the latest one + is discarded. This is intended to be used where the function + is called by a user command, which it should not return. + + Returns + ------- + Tuples as :meth:`get_range` + """ + self.writeout_cache() + if not include_latest: + n += 1 + # cursor/line/entry + this_cur = list( + self._run_sql( + "WHERE session == ? ORDER BY line DESC LIMIT ? ", + (self.session_number, n), + raw=raw, + output=output, + ) + ) + other_cur = list( + self._run_sql( + "WHERE session != ? ORDER BY session DESC, line DESC LIMIT ?", + (self.session_number, n), + raw=raw, + output=output, + ) + ) + + everything = this_cur + other_cur + + everything = everything[:n] + + if not include_latest: + return list(everything)[:0:-1] + return list(everything)[::-1] + def _get_range_session(self, start=1, stop=None, raw=True, output=False): """Get input and output history from the current session. Called by get_range, and takes similar parameters.""" diff --git a/IPython/core/tests/test_history.py b/IPython/core/tests/test_history.py index 73d50c8..a9ebafd 100644 --- a/IPython/core/tests/test_history.py +++ b/IPython/core/tests/test_history.py @@ -17,7 +17,7 @@ from tempfile import TemporaryDirectory # our own packages from traitlets.config.loader import Config -from IPython.core.history import HistoryManager, extract_hist_ranges +from IPython.core.history import HistoryAccessor, HistoryManager, extract_hist_ranges def test_proper_default_encoding(): @@ -227,3 +227,81 @@ def test_histmanager_disabled(): # hist_file should not be created assert hist_file.exists() is False + + +def test_get_tail_session_awareness(): + """Test .get_tail() is: + - session specific in HistoryManager + - session agnostic in HistoryAccessor + same for .get_last_session_id() + """ + ip = get_ipython() + with TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) + hist_file = tmp_path / "history.sqlite" + get_source = lambda x: x[2] + hm1 = None + hm2 = None + ha = None + try: + # hm1 creates a new session and adds history entries, + # ha catches up + hm1 = HistoryManager(shell=ip, hist_file=hist_file) + hm1_last_sid = hm1.get_last_session_id + ha = HistoryAccessor(hist_file=hist_file) + ha_last_sid = ha.get_last_session_id + + hist1 = ["a=1", "b=1", "c=1"] + for i, h in enumerate(hist1 + [""], start=1): + hm1.store_inputs(i, h) + assert list(map(get_source, hm1.get_tail())) == hist1 + assert list(map(get_source, ha.get_tail())) == hist1 + sid1 = hm1_last_sid() + assert sid1 is not None + assert ha_last_sid() == sid1 + + # hm2 creates a new session and adds entries, + # ha catches up + hm2 = HistoryManager(shell=ip, hist_file=hist_file) + hm2_last_sid = hm2.get_last_session_id + + hist2 = ["a=2", "b=2", "c=2"] + for i, h in enumerate(hist2 + [""], start=1): + hm2.store_inputs(i, h) + tail = hm2.get_tail(n=3) + assert list(map(get_source, tail)) == hist2 + tail = ha.get_tail(n=3) + assert list(map(get_source, tail)) == hist2 + sid2 = hm2_last_sid() + assert sid2 is not None + assert ha_last_sid() == sid2 + assert sid2 != sid1 + + # but hm1 still maintains its point of reference + # and adding more entries to it doesn't change others + # immediate perspective + assert hm1_last_sid() == sid1 + tail = hm1.get_tail(n=3) + assert list(map(get_source, tail)) == hist1 + + hist3 = ["a=3", "b=3", "c=3"] + for i, h in enumerate(hist3 + [""], start=5): + hm1.store_inputs(i, h) + tail = hm1.get_tail(n=7) + assert list(map(get_source, tail)) == hist1 + [""] + hist3 + tail = hm2.get_tail(n=3) + assert list(map(get_source, tail)) == hist2 + tail = ha.get_tail(n=3) + assert list(map(get_source, tail)) == hist2 + assert hm1_last_sid() == sid1 + assert hm2_last_sid() == sid2 + assert ha_last_sid() == sid2 + finally: + if hm1: + hm1.save_thread.stop() + hm1.db.close() + if hm2: + hm2.save_thread.stop() + hm2.db.close() + if ha: + ha.db.close()