diff --git a/rust/hg-pyo3/src/ancestors.rs b/rust/hg-pyo3/src/ancestors.rs --- a/rust/hg-pyo3/src/ancestors.rs +++ b/rust/hg-pyo3/src/ancestors.rs @@ -10,9 +10,12 @@ //! and can be used as replacement for the the pure `ancestor` Python module. use cpython::UnsafePyLeaked; use pyo3::prelude::*; +use pyo3::types::PyTuple; +use std::collections::HashSet; use std::sync::RwLock; +use hg::MissingAncestors as CoreMissing; use vcsgraph::lazy_ancestors::{ AncestorsIterator as VCGAncestorsIterator, LazyAncestors as VCGLazyAncestors, @@ -153,6 +156,147 @@ impl LazyAncestors { } } +/// Python binding for the core `MissingAncestors` (`CoreMissing`), working +/// on a leaked reference to the index behind `proxy_index`. +#[pyclass] +struct MissingAncestors { + inner: RwLock<UnsafePyLeaked<CoreMissing<PySharedIndex>>>, + proxy_index: PyObject, +} + +#[pymethods] +impl MissingAncestors { + /// Build from an index proxy and an iterable of base revisions. + #[new] + fn new( + index_proxy: &Bound<'_, PyAny>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<Self> { + let cloned_proxy = index_proxy.clone().unbind(); + let bases_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + let (py, leaked_idx) = proxy_index_py_leak(index_proxy)?; + + // Safety: we don't leak the "faked" reference out of + // `UnsafePyLeaked` + let inner = unsafe { + leaked_idx.map(py, |idx| CoreMissing::new(idx, bases_vec)) + }; + Ok(Self { + inner: inner.into(), + proxy_index: cloned_proxy, + }) + } + + /// Tell whether any bases are set, as reported by the core `has_bases()`. + fn hasbases(slf: PyRef<'_, Self>) -> PyResult<bool> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner.has_bases()) + } + + /// Add the revisions of the `bases` iterable to the current bases. + fn addbases( + slf: PyRefMut<'_, Self>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<()> { + let index_proxy = slf.proxy_index.bind(slf.py()); + let bases_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + inner.add_bases(bases_vec); + Ok(()) + }
+ + /// Return the current bases, as a set of revisions. + fn bases(slf: PyRef<'_, Self>) -> PyResult<HashSet<PyRevision>> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner.get_bases().iter().map(|r| PyRevision(r.0)).collect()) + } + + /// Return the heads of the current bases, as a set of revisions. + fn basesheads(slf: PyRef<'_, Self>) -> PyResult<HashSet<PyRevision>> { + let leaked = slf.inner.read().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let inner = unsafe { py_leaked_borrow(&slf, &leaked) }?; + Ok(inner + .bases_heads() + .map_err(GraphError::from_hg)? + .iter() + .map(|r| PyRevision(r.0)) + .collect()) + } + + /// Update `revs` in place, keeping only the revisions that the core + /// `remove_ancestors_from` retains; `revs` needs `intersection_update`. + fn removeancestorsfrom( + slf: PyRef<'_, Self>, + revs: &Bound<'_, PyAny>, + ) -> PyResult<()> { + // Original comment from hg-cpython: + // this is very lame: we convert to a Rust set, update it in place + // and then convert back to Python, only to have Python remove the + // excess (thankfully, Python is happy with a list or even an + // iterator) + // Ideas to improve this: + // - have the CoreMissing instead emit the revisions to + // discard + // - define a trait for sets of revisions in the core and implement + // it for a Python set rewrapped with the GIL marker + // PyO3 additional comment: the trait approach would probably be + // simpler because we can implement it without a Py wrapper, just + // on &Bound<'py, PySet> + let index_proxy = slf.proxy_index.bind(slf.py()); + let mut revs_set: HashSet<_> = + rev_pyiter_collect_with_py_index(revs, index_proxy)?; + + { + // The write lock is released at the end of this scope, before + // calling back into Python below: that Python code could + // otherwise re-enter this object while the lock is still held. + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of + // `UnsafePyLeaked` + let mut inner = + unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + inner + .remove_ancestors_from(&mut revs_set) + .map_err(GraphError::from_hg)?; + } + // convert as Python tuple and discard from original `revs` + let remaining_tuple = + PyTuple::new(slf.py(), revs_set.iter().map(|r| PyRevision(r.0)))?; + revs.call_method1("intersection_update", (remaining_tuple,))?; + Ok(()) + }
+ + /// Return, as a list, the core `missing_ancestors` result for the + /// revisions of the `bases` iterable. + fn missingancestors( + slf: PyRefMut<'_, Self>, + bases: &Bound<'_, PyAny>, + ) -> PyResult<Vec<PyRevision>> { + let index_proxy = slf.proxy_index.bind(slf.py()); + let revs_vec: Vec<_> = + rev_pyiter_collect_with_py_index(bases, index_proxy)?; + + let mut leaked = slf.inner.write().map_err(map_lock_error)?; + // Safety: we don't leak the "faked" reference out of `UnsafePyLeaked` + let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked) }?; + + let missing_vec = inner + .missing_ancestors(revs_vec) + .map_err(GraphError::from_hg)?; + Ok(missing_vec.iter().map(|r| PyRevision(r.0)).collect()) + } +} + pub fn init_module<'py>( py: Python<'py>, package: &str, @@ -160,5 +304,6 @@ pub fn init_module<'py>( let m = new_submodule(py, package, "ancestor")?; m.add_class::<AncestorsIterator>()?; m.add_class::<LazyAncestors>()?; + m.add_class::<MissingAncestors>()?; Ok(m) } diff --git a/tests/test-rust-ancestor.py b/tests/test-rust-ancestor.py --- a/tests/test-rust-ancestor.py +++ b/tests/test-rust-ancestor.py @@ -172,12 +172,6 @@ class RustAncestorsTestMixin: idx = self.parserustindex() self.assertEqual(dagop.headrevs(idx, [1, 2, 3]), {3}) - -class RustCPythonAncestorsTest( - revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin -): - rustext_pkg = rustext - def testmissingancestors(self): MissingAncestors = self.ancestors_mod().MissingAncestors @@ -200,6 +194,12 @@ class RustCPythonAncestorsTest( self.assertEqual(revs, {2, 3}) +class RustCPythonAncestorsTest( + revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin +): + rustext_pkg = rustext + + class PyO3AncestorsTest( revlogtesting.RustRevlogBasedTestBase, RustAncestorsTestMixin ):