diff --git a/rust/hg-cpython/src/ref_sharing.rs b/rust/hg-cpython/src/ref_sharing.rs --- a/rust/hg-cpython/src/ref_sharing.rs +++ b/rust/hg-cpython/src/ref_sharing.rs @@ -25,6 +25,7 @@ use crate::exceptions::AlreadyBorrowed; use cpython::{PyClone, PyObject, PyResult, Python}; use std::cell::{Cell, Ref, RefCell, RefMut}; +use std::ops::{Deref, DerefMut}; /// Manages the shared state between Python and Rust #[derive(Debug, Default)] @@ -333,17 +334,29 @@ impl PyLeaked { } } - /// Returns an immutable reference to the inner value. - pub fn get_ref<'a>(&'a self, _py: Python<'a>) -> &'a T { - self.data.as_ref().unwrap() + /// Immutably borrows the wrapped value. + pub fn try_borrow<'a>( + &'a self, + py: Python<'a>, + ) -> PyResult> { + Ok(PyLeakedRef { + _py: py, + data: self.data.as_ref().unwrap(), + }) } - /// Returns a mutable reference to the inner value. + /// Mutably borrows the wrapped value. /// /// Typically `T` is an iterator. If `T` is an immutable reference, /// `get_mut()` is useless since the inner value can't be mutated. - pub fn get_mut<'a>(&'a mut self, _py: Python<'a>) -> &'a mut T { - self.data.as_mut().unwrap() + pub fn try_borrow_mut<'a>( + &'a mut self, + py: Python<'a>, + ) -> PyResult> { + Ok(PyLeakedRefMut { + _py: py, + data: self.data.as_mut().unwrap(), + }) } /// Converts the inner value by the given function. @@ -389,6 +402,40 @@ impl Drop for PyLeaked { } } +/// Immutably borrowed reference to a leaked value. +pub struct PyLeakedRef<'a, T> { + _py: Python<'a>, + data: &'a T, +} + +impl Deref for PyLeakedRef<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.data + } +} + +/// Mutably borrowed reference to a leaked value. +pub struct PyLeakedRefMut<'a, T> { + _py: Python<'a>, + data: &'a mut T, +} + +impl Deref for PyLeakedRefMut<'_, T> { + type Target = T; + + fn deref(&self) -> &T { + self.data + } +} + +impl DerefMut for PyLeakedRefMut<'_, T> { + fn deref_mut(&mut self) -> &mut T { + self.data + } +} + /// Defines a `py_class!` that acts as a Python iterator over a Rust iterator. /// /// TODO: this is a bit awkward to use, and a better (more complicated) @@ -457,7 +504,8 @@ macro_rules! py_shared_iterator { def __next__(&self) -> PyResult<$success_type> { let mut inner_opt = self.inner(py).borrow_mut(); if let Some(leaked) = inner_opt.as_mut() { - match leaked.get_mut(py).next() { + let mut iter = leaked.try_borrow_mut(py)?; + match iter.next() { None => { // replace Some(inner) by None, drop $leaked inner_opt.take(); @@ -512,6 +560,28 @@ mod test { } #[test] + fn test_leaked_borrow() { + let (gil, owner) = prepare_env(); + let py = gil.python(); + let leaked = owner.string_shared(py).leak_immutable().unwrap(); + let leaked_ref = leaked.try_borrow(py).unwrap(); + assert_eq!(*leaked_ref, "new"); + } + + #[test] + fn test_leaked_borrow_mut() { + let (gil, owner) = prepare_env(); + let py = gil.python(); + let leaked = owner.string_shared(py).leak_immutable().unwrap(); + let mut leaked_iter = unsafe { leaked.map(py, |s| s.chars()) }; + let mut leaked_ref = leaked_iter.try_borrow_mut(py).unwrap(); + assert_eq!(leaked_ref.next(), Some('n')); + assert_eq!(leaked_ref.next(), Some('e')); + assert_eq!(leaked_ref.next(), Some('w')); + assert_eq!(leaked_ref.next(), None); + } + + #[test] fn test_borrow_mut_while_leaked() { let (gil, owner) = prepare_env(); let py = gil.python();