From 89d672ee48b0e0998159e85619a986837573ee40 2024-12-13 10:31:04 From: M Bussonnier Date: 2024-12-13 10:31:04 Subject: [PATCH] Add more type in history.py (#14616) This is to help to write a history handler when we do not have sqlite3 for pyodide. --- diff --git a/IPython/core/history.py b/IPython/core/history.py index f59ca11..505d9aa 100644 --- a/IPython/core/history.py +++ b/IPython/core/history.py @@ -1,5 +1,7 @@ """ History related magics and functionality """ +from __future__ import annotations + # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. @@ -7,7 +9,8 @@ import atexit import datetime import re -import sqlite3 + + import threading from pathlib import Path @@ -29,31 +32,55 @@ from traitlets.config.configurable import LoggingConfigurable from IPython.paths import locate_profile from IPython.utils.decorators import undoc +from typing import Iterable, Tuple, Optional, TYPE_CHECKING +import typing + +if TYPE_CHECKING: + from IPython.core.interactiveshell import InteractiveShell + from IPython.config.Configuration import Configuration + +try: + from sqlite3 import DatabaseError, OperationalError + import sqlite3 + + sqlite3_found = True +except ModuleNotFoundError: + sqlite3_found = False + + class DatabaseError(Exception): # type: ignore [no-redef] + pass + + class OperationalError(Exception): # type: ignore [no-redef] + pass + + +InOrInOut = typing.Union[str, Tuple[str, Optional[str]]] #----------------------------------------------------------------------------- # Classes and functions #----------------------------------------------------------------------------- @undoc -class DummyDB(object): +class DummyDB: """Dummy DB that will act as a black hole for history. Only used in the absence of sqlite""" - def execute(*args, **kwargs): + + def execute(*args: typing.Any, **kwargs: typing.Any) -> typing.List: return [] - def commit(self, *args, **kwargs): + def commit(self, *args, **kwargs): # type: ignore [no-untyped-def] pass - def __enter__(self, *args, **kwargs): + def __enter__(self, *args, **kwargs): # type: ignore [no-untyped-def] pass - def __exit__(self, *args, **kwargs): + def __exit__(self, *args, **kwargs): # type: ignore [no-untyped-def] pass @decorator -def only_when_enabled(f, self, *a, **kw): +def only_when_enabled(f, self, *a, **kw): # type: ignore [no-untyped-def] """Decorator: return an empty list in the absence of sqlite.""" if not self.enabled: return [] @@ -66,7 +93,7 @@ def only_when_enabled(f, self, *a, **kw): _SAVE_DB_SIZE = 16384 @decorator -def catch_corrupt_db(f, self, *a, **kw): +def catch_corrupt_db(f, self, *a, **kw): # type: ignore [no-untyped-def] """A decorator which wraps HistoryAccessor method calls to catch errors from a corrupt SQLite database, move the old database out of the way, and create a new one. @@ -76,7 +103,7 @@ def catch_corrupt_db(f, self, *a, **kw): """ try: return f(self, *a, **kw) - except (sqlite3.DatabaseError, sqlite3.OperationalError) as e: + except (DatabaseError, OperationalError) as e: self._corrupt_db_counter += 1 self.log.error("Failed to open SQLite history %s (%s).", self.hist_file, e) if self.hist_file != ':memory:': @@ -114,17 +141,39 @@ def catch_corrupt_db(f, self, *a, **kw): class HistoryAccessorBase(LoggingConfigurable): """An abstract class for History Accessors """ - def get_tail(self, n=10, raw=True, output=False, include_latest=False): + def get_tail( + self, + n: int = 10, + raw: bool = True, + output: bool = False, + include_latest: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: raise NotImplementedError - def search(self, pattern="*", raw=True, search_raw=True, - output=False, n=None, unique=False): + def search( + self, + pattern: str = "*", + raw: bool = True, + search_raw: bool = True, + output: bool = False, + n: Optional[int] = None, + unique: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: raise NotImplementedError - def get_range(self, session, start=1, stop=None, raw=True,output=False): + def get_range( + self, + session: int, + start: int = 1, + stop: Optional[int] = None, + raw: bool = True, + output: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: raise NotImplementedError - def get_range_by_str(self, rangestr, raw=True, output=False): + def get_range_by_str( + self, rangestr: str, raw: bool = True, output: bool = False + ) -> Iterable[Tuple[int, int, InOrInOut]]: raise NotImplementedError @@ -160,7 +209,8 @@ class HistoryAccessor(HistoryAccessorBase): """, ).tag(config=True) - enabled = Bool(True, + enabled = Bool( + sqlite3_found, help="""enable the SQLite history set enabled=False to disable the SQLite history, @@ -179,13 +229,14 @@ class HistoryAccessor(HistoryAccessorBase): ).tag(config=True) @default("connection_options") - def _default_connection_options(self): + def _default_connection_options(self) -> typing.Dict[str, bool]: return dict(check_same_thread=False) # The SQLite database db = Any() @observe('db') - def _db_changed(self, change): + @only_when_enabled + def _db_changed(self, change): # type: ignore [no-untyped-def] """validate the db, since it can be an Instance of two different types""" new = change['new'] connection_types = (DummyDB, sqlite3.Connection) @@ -194,7 +245,9 @@ class HistoryAccessor(HistoryAccessorBase): (self.__class__.__name__, new) raise TraitError(msg) - def __init__(self, profile="default", hist_file="", **traits): + def __init__( + self, profile: str = "default", hist_file: str = "", **traits: typing.Any + ) -> None: """Create a new history accessor. Parameters @@ -222,7 +275,7 @@ class HistoryAccessor(HistoryAccessorBase): self.init_db() - def _get_hist_file_name(self, profile='default'): + def _get_hist_file_name(self, profile: str = "default") -> Path: """Find the history file for the given profile name. This is overridden by the HistoryManager subclass, to use the shell's @@ -236,7 +289,7 @@ class HistoryAccessor(HistoryAccessorBase): return Path(locate_profile(profile)) / "history.sqlite" @catch_corrupt_db - def init_db(self): + def init_db(self) -> None: """Connect to the database, and create tables if necessary.""" if not self.enabled: self.db = DummyDB() @@ -245,7 +298,7 @@ class HistoryAccessor(HistoryAccessorBase): # use detect_types so that timestamps return datetime objects kwargs = dict(detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES) kwargs.update(self.connection_options) - self.db = sqlite3.connect(str(self.hist_file), **kwargs) + self.db = sqlite3.connect(str(self.hist_file), **kwargs) # type: ignore [call-overload] with self.db: self.db.execute( """CREATE TABLE IF NOT EXISTS sessions (session integer @@ -267,7 +320,7 @@ class HistoryAccessor(HistoryAccessorBase): # success! reset corrupt db count self._corrupt_db_counter = 0 - def writeout_cache(self): + def writeout_cache(self) -> None: """Overridden by HistoryManager to dump the cache before certain database lookups.""" pass @@ -275,7 +328,14 @@ class HistoryAccessor(HistoryAccessorBase): ## ------------------------------- ## Methods for retrieving history: ## ------------------------------- - def _run_sql(self, sql, params, raw=True, output=False, latest=False): + def _run_sql( + self, + sql: str, + params: typing.Tuple, + raw: bool = True, + output: bool = False, + latest: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Prepares and runs an SQL query for the history database. Parameters @@ -310,7 +370,9 @@ class HistoryAccessor(HistoryAccessorBase): @only_when_enabled @catch_corrupt_db - def get_session_info(self, session): + def get_session_info( + self, session: int + ) -> Tuple[int, datetime.datetime, Optional[datetime.datetime], Optional[int], str]: """Get info about a session. Parameters @@ -335,7 +397,7 @@ class HistoryAccessor(HistoryAccessorBase): return self.db.execute(query, (session,)).fetchone() @catch_corrupt_db - def get_last_session_id(self): + def get_last_session_id(self) -> Optional[int]: """Get the last session ID currently in the database. Within IPython, this should be the same as the value stored in @@ -343,9 +405,16 @@ class HistoryAccessor(HistoryAccessorBase): """ for record in self.get_tail(n=1, include_latest=True): return record[0] + return None @catch_corrupt_db - def get_tail(self, n=10, raw=True, output=False, include_latest=False): + def get_tail( + self, + n: int = 10, + raw: bool = True, + output: bool = False, + include_latest: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Get the last n lines from the history database. Parameters @@ -374,8 +443,15 @@ class HistoryAccessor(HistoryAccessorBase): return reversed(list(cur)) @catch_corrupt_db - def search(self, pattern="*", raw=True, search_raw=True, - output=False, n=None, unique=False): + def search( + self, + pattern: str = "*", + raw: bool = True, + search_raw: bool = True, + output: bool = False, + n: Optional[int] = None, + unique: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Search the database using unix glob-style matching (wildcards * and ?). @@ -402,7 +478,7 @@ class HistoryAccessor(HistoryAccessorBase): tosearch = "history." + tosearch self.writeout_cache() sqlform = "WHERE %s GLOB ?" % tosearch - params = (pattern,) + params: typing.Tuple[typing.Any, ...] = (pattern,) if unique: sqlform += ' GROUP BY {0}'.format(tosearch) if n is not None: @@ -416,7 +492,14 @@ class HistoryAccessor(HistoryAccessorBase): return cur @catch_corrupt_db - def get_range(self, session, start=1, stop=None, raw=True,output=False): + def get_range( + self, + session: int, + start: int = 1, + stop: Optional[int] = None, + raw: bool = True, + output: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Retrieve input by session. Parameters @@ -443,6 +526,7 @@ class HistoryAccessor(HistoryAccessorBase): (session, line, input) if output is False, or (session, line, (input, output)) if output is True. """ + params: typing.Tuple[typing.Any, ...] if stop: lineclause = "line >= ? AND line < ?" params = (session, start, stop) @@ -453,7 +537,9 @@ class HistoryAccessor(HistoryAccessorBase): return self._run_sql("WHERE session==? AND %s" % lineclause, params, raw=raw, output=output) - def get_range_by_str(self, rangestr, raw=True, output=False): + def get_range_by_str( + self, rangestr: str, raw: bool = True, output: bool = False + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Get lines of history from a string of ranges, as used by magic commands %hist, %save, %macro, etc. @@ -483,8 +569,9 @@ class HistoryManager(HistoryAccessor): # Public interface # An instance of the IPython shell we are attached to - shell = Instance('IPython.core.interactiveshell.InteractiveShellABC', - allow_none=True) + shell = Instance( + "IPython.core.interactiveshell.InteractiveShellABC", allow_none=False + ) # Lists to hold processed and raw history. These start with a blank entry # so that we can index them starting from 1 input_hist_parsed = List([""]) @@ -493,7 +580,7 @@ class HistoryManager(HistoryAccessor): dir_hist: List = List() @default("dir_hist") - def _dir_hist_default(self): + def _dir_hist_default(self) -> typing.List[Path]: try: return [Path.cwd()] except OSError: @@ -503,10 +590,10 @@ class HistoryManager(HistoryAccessor): # execution count. output_hist = Dict() # The text/plain repr of outputs. - output_hist_reprs = Dict() + output_hist_reprs: typing.Dict[int, str] = Dict() # type: ignore [assignment] # The number of the current session in the history database - session_number = Integer() + session_number: int = Integer() # type: ignore [assignment] db_log_output = Bool(False, help="Should the history database include output? (default: no)" @@ -516,13 +603,13 @@ class HistoryManager(HistoryAccessor): "Values of 1 or less effectively disable caching." ).tag(config=True) # The input and output caches - db_input_cache: List = List() - db_output_cache: List = List() + db_input_cache: List[Tuple[int, str, str]] = List() + db_output_cache: List[Tuple[int, str]] = List() # History saving in separate thread save_thread = Instance('IPython.core.history.HistorySavingThread', allow_none=True) - save_flag = Instance(threading.Event, allow_none=True) + save_flag = Instance(threading.Event, allow_none=False) # Private interface # Variables used to store the three last inputs from the user. On each new @@ -538,18 +625,21 @@ class HistoryManager(HistoryAccessor): # an exit call). _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$") - def __init__(self, shell=None, config=None, **traits): - """Create a new history manager associated with a shell instance. - """ - super(HistoryManager, self).__init__(shell=shell, config=config, - **traits) + def __init__( + self, + shell: InteractiveShell, + config: Optional[Configuration] = None, + **traits: typing.Any, + ): + """Create a new history manager associated with a shell instance.""" + super().__init__(shell=shell, config=config, **traits) self.save_flag = threading.Event() self.db_input_cache_lock = threading.Lock() self.db_output_cache_lock = threading.Lock() try: self.new_session() - except sqlite3.OperationalError: + except OperationalError: self.log.error("Failed to create history session in %s. History will not be saved.", self.hist_file, exc_info=True) self.hist_file = ':memory:' @@ -565,7 +655,7 @@ class HistoryManager(HistoryAccessor): ) self.hist_file = ":memory:" - def _get_hist_file_name(self, profile=None): + def _get_hist_file_name(self, profile: Optional[str] = None) -> Path: """Get default history file name based on the Shell's profile. The profile parameter is ignored, but must exist for compatibility with @@ -574,7 +664,7 @@ class HistoryManager(HistoryAccessor): return Path(profile_dir) / "history.sqlite" @only_when_enabled - def new_session(self, conn=None): + def new_session(self, conn: Optional[sqlite3.Connection] = None) -> None: """Get a new session number.""" if conn is None: conn = self.db @@ -585,9 +675,10 @@ class HistoryManager(HistoryAccessor): NULL, '') """, (datetime.datetime.now().isoformat(" "),), ) + assert isinstance(cur.lastrowid, int) self.session_number = cur.lastrowid - def end_session(self): + def end_session(self) -> None: """Close the database session, filling in the end time and line count.""" self.writeout_cache() with self.db: @@ -602,13 +693,13 @@ class HistoryManager(HistoryAccessor): ) self.session_number = 0 - def name_session(self, name): + def name_session(self, name: str) -> None: """Give the current session a name in the history database.""" with self.db: self.db.execute("UPDATE sessions SET remark=? WHERE session==?", (name, self.session_number)) - def reset(self, new_session=True): + def reset(self, new_session: bool = True) -> None: """Clear the session history, releasing all object references, and optionally open a new session.""" self.output_hist.clear() @@ -625,7 +716,9 @@ class HistoryManager(HistoryAccessor): # ------------------------------ # Methods for retrieving history # ------------------------------ - def get_session_info(self, session=0): + def get_session_info( + self, session: int = 0 + ) -> Tuple[int, datetime.datetime, Optional[datetime.datetime], Optional[int], str]: """Get info about a session. Parameters @@ -653,7 +746,13 @@ class HistoryManager(HistoryAccessor): return super(HistoryManager, self).get_session_info(session=session) @catch_corrupt_db - def get_tail(self, n=10, raw=True, output=False, include_latest=False): + def get_tail( + self, + n: int = 10, + raw: bool = True, + output: bool = False, + include_latest: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Get the last n lines from the history database. Most recent entry last. @@ -697,7 +796,7 @@ class HistoryManager(HistoryAccessor): ) ) - everything = this_cur + other_cur + everything: typing.List[Tuple[int, int, InOrInOut]] = this_cur + other_cur everything = everything[:n] @@ -705,7 +804,13 @@ class HistoryManager(HistoryAccessor): return list(everything)[:0:-1] return list(everything)[::-1] - def _get_range_session(self, start=1, stop=None, raw=True, output=False): + def _get_range_session( + self, + start: int = 1, + stop: Optional[int] = None, + raw: bool = True, + output: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Get input and output history from the current session. Called by get_range, and takes similar parameters.""" input_hist = self.input_hist_raw if raw else self.input_hist_parsed @@ -717,7 +822,7 @@ class HistoryManager(HistoryAccessor): stop = n elif stop < 0: stop += n - + line: InOrInOut for i in range(start, stop): if output: line = (input_hist[i], self.output_hist_reprs.get(i)) @@ -725,7 +830,14 @@ class HistoryManager(HistoryAccessor): line = input_hist[i] yield (0, i, line) - def get_range(self, session=0, start=1, stop=None, raw=True,output=False): + def get_range( + self, + session: int = 0, + start: int = 1, + stop: Optional[int] = None, + raw: bool = True, + output: bool = False, + ) -> Iterable[Tuple[int, int, InOrInOut]]: """Retrieve input by session. Parameters @@ -763,7 +875,9 @@ class HistoryManager(HistoryAccessor): ## ---------------------------- ## Methods for storing history: ## ---------------------------- - def store_inputs(self, line_num, source, source_raw=None): + def store_inputs( + self, line_num: int, source: str, source_raw: Optional[str] = None + ) -> None: """Store source and raw input in history and create input cache variables ``_i*``. @@ -811,7 +925,7 @@ class HistoryManager(HistoryAccessor): if self.shell is not None: self.shell.push(to_main, interactive=False) - def store_output(self, line_num): + def store_output(self, line_num: int) -> None: """If database output logging is enabled, this saves all the outputs from the indicated prompt number to the database. It's called by run_cell after code has been executed. @@ -823,6 +937,7 @@ class HistoryManager(HistoryAccessor): """ if (not self.db_log_output) or (line_num not in self.output_hist_reprs): return + lnum: int = line_num output = self.output_hist_reprs[line_num] with self.db_output_cache_lock: @@ -830,20 +945,20 @@ class HistoryManager(HistoryAccessor): if self.db_cache_size <= 1: self.save_flag.set() - def _writeout_input_cache(self, conn): + def _writeout_input_cache(self, conn: sqlite3.Connection) -> None: with conn: for line in self.db_input_cache: conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)", (self.session_number,)+line) - def _writeout_output_cache(self, conn): + def _writeout_output_cache(self, conn: sqlite3.Connection) -> None: with conn: for line in self.db_output_cache: conn.execute("INSERT INTO output_history VALUES (?, ?, ?)", (self.session_number,)+line) @only_when_enabled - def writeout_cache(self, conn=None): + def writeout_cache(self, conn: Optional[sqlite3.Connection] = None) -> None: """Write any entries in the cache to the database.""" if conn is None: conn = self.db @@ -885,13 +1000,15 @@ class HistorySavingThread(threading.Thread): daemon = True stop_now = False enabled = True - def __init__(self, history_manager): + history_manager: HistoryManager + + def __init__(self, history_manager: HistoryManager) -> None: super(HistorySavingThread, self).__init__(name="IPythonHistorySavingThread") self.history_manager = history_manager self.enabled = history_manager.enabled @only_when_enabled - def run(self): + def run(self) -> None: atexit.register(self.stop) # We need a separate db connection per thread: try: @@ -912,7 +1029,7 @@ class HistorySavingThread(threading.Thread): finally: atexit.unregister(self.stop) - def stop(self): + def stop(self) -> None: """This can be called from the main thread to safely stop this thread. Note that it does not attempt to write out remaining history before @@ -933,7 +1050,7 @@ range_re = re.compile(r""" $""", re.VERBOSE) -def extract_hist_ranges(ranges_str): +def extract_hist_ranges(ranges_str: str) -> Iterable[Tuple[int, int, Optional[int]]]: """Turn a string of history ranges into 3-tuples of (session, start, stop). Empty string results in a `[(0, 1, None)]`, i.e. "everything from current @@ -965,6 +1082,7 @@ def extract_hist_ranges(ranges_str): end = None # provide the entire session hist if rmatch.group("sep") == "-": # 1-3 == 1:4 --> [1, 2, 3] + assert end is not None end += 1 startsess = rmatch.group("startsess") or "0" endsess = rmatch.group("endsess") or startsess @@ -982,7 +1100,7 @@ def extract_hist_ranges(ranges_str): yield (endsess, 1, end) -def _format_lineno(session, line): +def _format_lineno(session: int, line: int) -> str: """Helper function to format line numbers properly.""" if session == 0: return str(line) diff --git a/pyproject.toml b/pyproject.toml index 3f8aaf3..2c11b60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -199,7 +199,6 @@ module = [ "IPython.core.formatters", "IPython.core.getipython", "IPython.core.guarded_eval", - "IPython.core.history", "IPython.core.historyapp", "IPython.core.hooks", "IPython.core.inputtransformer",