diff --git a/rust/hg-pyo3/src/ancestors.rs b/rust/hg-pyo3/src/ancestors.rs
new file mode 100644
--- /dev/null
+++ b/rust/hg-pyo3/src/ancestors.rs
@@ -0,0 +1,108 @@
+// ancestors.rs
+//
+// Copyright 2024 Georges Racinet
+//
+// This software may be used and distributed according to the terms of the
+// GNU General Public License version 2 or any later version.
+
+//! Bindings for the `hg::ancestors` module provided by the
+//! `hg-core` crate. From Python, this will be seen as `pyo3_rustext.ancestor`
+//! and can be used as replacement for the pure `ancestor` Python module.
+use cpython::UnsafePyLeaked;
+use pyo3::prelude::*;
+
+use std::sync::RwLock;
+
+use vcsgraph::lazy_ancestors::AncestorsIterator as VCGAncestorsIterator;
+
+use crate::convert_cpython::{
+    proxy_index_extract, proxy_index_py_leak, py_leaked_borrow_mut,
+    py_leaked_or_map_err,
+};
+use crate::exceptions::{map_lock_error, GraphError};
+use crate::revision::{rev_pyiter_collect, PyRevision};
+use crate::util::new_submodule;
+use rusthg::revlog::PySharedIndex;
+
+/// Python iterator over the ancestors of a set of revisions.
+///
+/// Wraps a `vcsgraph` lazy ancestors iterator over the index leaked from
+/// the Python index proxy object. `__next__` needs mutable access to that
+/// iterator, which the `RwLock` provides.
+#[pyclass]
+struct AncestorsIterator {
+    inner: RwLock<UnsafePyLeaked<VCGAncestorsIterator<PySharedIndex>>>,
+}
+
+#[pymethods]
+impl AncestorsIterator {
+    /// Create an iterator over the ancestors of the revisions in `initrevs`.
+    ///
+    /// `index_proxy` is the Python index proxy object and `initrevs` a
+    /// Python iterable of revisions; `stoprev` and `inclusive` are passed
+    /// unchanged to `vcsgraph`'s `AncestorsIterator::new`.
+    ///
+    /// `vcsgraph` errors building the iterator are raised as `GraphError`.
+    #[new]
+    fn new(
+        index_proxy: &Bound<'_, PyAny>,
+        initrevs: &Bound<'_, PyAny>,
+        stoprev: PyRevision,
+        inclusive: bool,
+    ) -> PyResult<Self> {
+        let initvec: Vec<_> = {
+            // Safety: we don't leak the "faked" reference out of
+            // `UnsafePyLeaked`
+            let borrowed_idx = unsafe { proxy_index_extract(index_proxy)? };
+            rev_pyiter_collect(initrevs, borrowed_idx)?
+        };
+        let (py, leaked_idx) = proxy_index_py_leak(index_proxy)?;
+        // Safety: the "faked" index reference is only moved into the
+        // `vcsgraph` iterator, which itself stays inside `UnsafePyLeaked`
+        let res_ait = unsafe {
+            leaked_idx.map(py, |idx| {
+                VCGAncestorsIterator::new(
+                    idx,
+                    initvec.into_iter().map(|r| r.0),
+                    stoprev.0,
+                    inclusive,
+                )
+            })
+        };
+        let ait =
+            py_leaked_or_map_err(py, res_ait, GraphError::from_vcsgraph)?;
+        let inner = ait.into();
+        Ok(Self { inner })
+    }
+
+    /// Return the iterator itself, as the Python iterator protocol requires.
+    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
+        slf
+    }
+
+    /// Return the next ancestor revision, or `None` (i.e. `StopIteration`
+    /// on the Python side) once the iteration is exhausted.
+    ///
+    /// Raises `GraphError` for errors reported by `vcsgraph`, and
+    /// `RuntimeError` if the lock around the inner iterator is poisoned.
+    fn __next__(slf: PyRefMut<'_, Self>) -> PyResult<Option<PyRevision>> {
+        let mut leaked = slf.inner.write().map_err(map_lock_error)?;
+        // Safety: we don't leak the inner 'static ref out of UnsafePyLeaked
+        let mut inner = unsafe { py_leaked_borrow_mut(&slf, &mut leaked)? };
+        match inner.next() {
+            Some(Err(e)) => Err(GraphError::from_vcsgraph(e)),
+            None => Ok(None),
+            Some(Ok(r)) => Ok(Some(PyRevision(r))),
+        }
+    }
+}
+
+/// Build the `ancestor` submodule of `package`, exposing `AncestorsIterator`.
+pub fn init_module<'py>(
+    py: Python<'py>,
+    package: &str,
+) -> PyResult<Bound<'py, PyModule>> {
+    let m = new_submodule(py, package, "ancestor")?;
+    m.add_class::<AncestorsIterator>()?;
+    Ok(m)
+}
diff --git a/rust/hg-pyo3/src/convert_cpython.rs b/rust/hg-pyo3/src/convert_cpython.rs
--- a/rust/hg-pyo3/src/convert_cpython.rs
+++ b/rust/hg-pyo3/src/convert_cpython.rs
@@ -263,7 +263,6 @@ pub(crate) unsafe fn py_leaked_borrow_mu
 /// This would spare users of the `cpython` crate the additional `unsafe` deref
 /// to inspect the error and return it outside `UnsafePyLeaked`, and the
 /// subsequent unwrapping that this function performs.
-#[allow(dead_code)]
 pub(crate) fn py_leaked_or_map_err<T, E: std::fmt::Debug + Copy>(
     py: cpython::Python,
     leaked: cpython::UnsafePyLeaked<Result<T, E>>,
diff --git a/rust/hg-pyo3/src/exceptions.rs b/rust/hg-pyo3/src/exceptions.rs
--- a/rust/hg-pyo3/src/exceptions.rs
+++ b/rust/hg-pyo3/src/exceptions.rs
@@ -1,4 +1,4 @@
-use pyo3::exceptions::PyValueError;
+use pyo3::exceptions::{PyRuntimeError, PyValueError};
 use pyo3::import_exception;
 use pyo3::{create_exception, PyErr};
 
@@ -32,3 +32,8 @@ impl GraphError {
         }
     }
 }
+
+/// Convert a poisoned-lock error into a Python `RuntimeError`.
+pub fn map_lock_error<T>(e: std::sync::PoisonError<T>) -> PyErr {
+    PyRuntimeError::new_err(format!("In Rust PyO3 bindings: {e}"))
+}
diff --git a/rust/hg-pyo3/src/lib.rs b/rust/hg-pyo3/src/lib.rs
--- a/rust/hg-pyo3/src/lib.rs
+++ b/rust/hg-pyo3/src/lib.rs
@@ -1,5 +1,6 @@
 use pyo3::prelude::*;
 
+mod ancestors;
 mod convert_cpython;
 mod dagops;
 mod exceptions;
@@ -17,6 +18,7 @@ fn pyo3_rustext(py: Python<'_>, m: &Boun
     let name: String = m.getattr("__name__")?.extract()?;
     let dotted_name = format!("mercurial.{}", name);
 
+    m.add_submodule(&ancestors::init_module(py, &dotted_name)?)?;
     m.add_submodule(&dagops::init_module(py, &dotted_name)?)?;
     m.add("GraphError", py.get_type::<GraphError>())?;
     Ok(())