diff --git a/mercurial/lock.py b/mercurial/lock.py --- a/mercurial/lock.py +++ b/mercurial/lock.py @@ -40,7 +40,7 @@ class lock(object): _host = None def __init__(self, vfs, file, timeout=-1, releasefn=None, acquirefn=None, - desc=None, parentlock=None): + desc=None, inheritchecker=None, parentlock=None): self.vfs = vfs self.f = file self.held = 0 @@ -48,6 +48,7 @@ class lock(object): self.releasefn = releasefn self.acquirefn = acquirefn self.desc = desc + self._inheritchecker = inheritchecker self.parentlock = parentlock self._parentheld = False self._inherited = False @@ -186,6 +187,8 @@ class lock(object): if self._inherited: raise error.LockInheritanceContractViolation( 'inherit cannot be called while lock is already inherited') + if self._inheritchecker is not None: + self._inheritchecker() if self.releasefn: self.releasefn() if self._parentheld: diff --git a/tests/test-lock.py b/tests/test-lock.py --- a/tests/test-lock.py +++ b/tests/test-lock.py @@ -8,6 +8,7 @@ import types import unittest from mercurial import ( + error, lock, scmutil, ) @@ -250,5 +251,21 @@ class testlock(unittest.TestCase): parentlock.release() + def testinheritcheck(self): + d = tempfile.mkdtemp(dir=os.getcwd()) + state = teststate(self, d) + def check(): + raise error.LockInheritanceContractViolation('check failed') + lock = state.makelock(inheritchecker=check) + state.assertacquirecalled(True) + + def tryinherit(): + with lock.inherit() as lockname: + pass + + self.assertRaises(error.LockInheritanceContractViolation, tryinherit) + + lock.release() + if __name__ == '__main__': silenttestrunner.main(__name__)