##// END OF EJS Templates
deps: bumped zope.interface==7.1.1
deps: bumped zope.interface==7.1.1

File last commit:

r5431:2e534b16 default
r5595:e7e5db4e default
Show More
__init__.py
402 lines | 13.3 KiB | text/x-python | PythonLexer
import threading
import weakref
from base64 import b64encode
from logging import getLogger
from os import urandom
from typing import Union
from redis import StrictRedis
__version__ = '4.0.0'
loggers = {
k: getLogger("rhodecode." + ".".join((__name__, k)))
for k in [
"acquire",
"refresh.thread.start",
"refresh.thread.stop",
"refresh.thread.exit",
"refresh.start",
"refresh.shutdown",
"refresh.exit",
"release",
]
}
text_type = str
binary_type = bytes
# Check if the id match. If not, return an error code.
UNLOCK_SCRIPT = b"""
if redis.call("get", KEYS[1]) ~= ARGV[1] then
return 1
else
redis.call("del", KEYS[2])
redis.call("lpush", KEYS[2], 1)
redis.call("pexpire", KEYS[2], ARGV[2])
redis.call("del", KEYS[1])
return 0
end
"""
# Covers both cases when key doesn't exist and doesn't equal to lock's id
EXTEND_SCRIPT = b"""
if redis.call("get", KEYS[1]) ~= ARGV[1] then
return 1
elseif redis.call("ttl", KEYS[1]) < 0 then
return 2
else
redis.call("expire", KEYS[1], ARGV[2])
return 0
end
"""
RESET_SCRIPT = b"""
redis.call('del', KEYS[2])
redis.call('lpush', KEYS[2], 1)
redis.call('pexpire', KEYS[2], ARGV[2])
return redis.call('del', KEYS[1])
"""
RESET_ALL_SCRIPT = b"""
local locks = redis.call('keys', 'lock:*')
local signal
for _, lock in pairs(locks) do
signal = 'lock-signal:' .. string.sub(lock, 6)
redis.call('del', signal)
redis.call('lpush', signal, 1)
redis.call('expire', signal, 1)
redis.call('del', lock)
end
return #locks
"""
class AlreadyAcquired(RuntimeError):
pass
class NotAcquired(RuntimeError):
pass
class AlreadyStarted(RuntimeError):
pass
class TimeoutNotUsable(RuntimeError):
pass
class InvalidTimeout(RuntimeError):
pass
class TimeoutTooLarge(RuntimeError):
pass
class NotExpirable(RuntimeError):
pass
class Lock:
"""
A Lock context manager implemented via redis SETNX/BLPOP.
"""
unlock_script = None
extend_script = None
reset_script = None
reset_all_script = None
blocking = None
_lock_renewal_interval: float
_lock_renewal_thread: Union[threading.Thread, None]
def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False, strict=True, signal_expire=1000, blocking=True):
"""
:param redis_client:
An instance of :class:`~StrictRedis`.
:param name:
The name (redis key) the lock should have.
:param expire:
The lock expiry time in seconds. If left at the default (None)
the lock will not expire.
:param id:
The ID (redis value) the lock should have. A random value is
generated when left at the default.
Note that if you specify this then the lock is marked as "held". Acquires
won't be possible.
:param auto_renewal:
If set to ``True``, Lock will automatically renew the lock so that it
doesn't expire for as long as the lock is held (acquire() called
or running in a context manager).
Implementation note: Renewal will happen using a daemon thread with
an interval of ``expire*2/3``. If wishing to use a different renewal
time, subclass Lock, call ``super().__init__()`` then set
``self._lock_renewal_interval`` to your desired interval.
:param strict:
If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
:param signal_expire:
Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
:param blocking:
Boolean value specifying whether lock should be blocking or not.
Used in `__enter__` method.
"""
if strict and not isinstance(redis_client, StrictRedis):
raise ValueError("redis_client must be instance of StrictRedis. "
"Use strict=False if you know what you're doing.")
if auto_renewal and expire is None:
raise ValueError("Expire may not be None when auto_renewal is set")
self._client = redis_client
if expire:
expire = int(expire)
if expire < 0:
raise ValueError("A negative expire is not acceptable.")
else:
expire = None
self._expire = expire
self._signal_expire = signal_expire
if id is None:
self._id = b64encode(urandom(18)).decode('ascii')
elif isinstance(id, binary_type):
try:
self._id = id.decode('ascii')
except UnicodeDecodeError:
self._id = b64encode(id).decode('ascii')
elif isinstance(id, text_type):
self._id = id
else:
raise TypeError(f"Incorrect type for `id`. Must be bytes/str not {type(id)}.")
self._name = 'lock:' + name
self._signal = 'lock-signal:' + name
self._lock_renewal_interval = (float(expire) * 2 / 3
if auto_renewal
else None)
self._lock_renewal_thread = None
self.blocking = blocking
self.register_scripts(redis_client)
@classmethod
def register_scripts(cls, redis_client):
global reset_all_script
if reset_all_script is None:
cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
cls.reset_script = redis_client.register_script(RESET_SCRIPT)
cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
@property
def _held(self):
return self.id == self.get_owner_id()
def reset(self):
"""
Forcibly deletes the lock. Use this with care.
"""
self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))
@property
def id(self):
return self._id
def get_owner_id(self):
owner_id = self._client.get(self._name)
if isinstance(owner_id, binary_type):
owner_id = owner_id.decode('ascii', 'replace')
return owner_id
def acquire(self, blocking=True, timeout=None):
"""
:param blocking:
Boolean value specifying whether lock should be blocking or not.
:param timeout:
An integer value specifying the maximum number of seconds to block.
"""
logger = loggers["acquire"]
logger.debug("Getting blocking: %s acquire on %r ...", blocking, self._name)
if self._held:
owner_id = self.get_owner_id()
raise AlreadyAcquired("Already acquired from this Lock instance. Lock id: {}".format(owner_id))
if not blocking and timeout is not None:
raise TimeoutNotUsable("Timeout cannot be used if blocking=False")
if timeout:
timeout = int(timeout)
if timeout < 0:
raise InvalidTimeout(f"Timeout ({timeout}) cannot be less than or equal to 0")
if self._expire and not self._lock_renewal_interval and timeout > self._expire:
raise TimeoutTooLarge(f"Timeout ({timeout}) cannot be greater than expire ({self._expire})")
busy = True
blpop_timeout = timeout or self._expire or 0
timed_out = False
while busy:
busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
if busy:
if timed_out:
return False
elif blocking:
timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
else:
logger.warning("Failed to acquire Lock(%r).", self._name)
return False
logger.debug("Acquired Lock(%r).", self._name)
if self._lock_renewal_interval is not None:
self._start_lock_renewer()
return True
def extend(self, expire=None):
"""
Extends expiration time of the lock.
:param expire:
New expiration time. If ``None`` - `expire` provided during
lock initialization will be taken.
"""
if expire:
expire = int(expire)
if expire < 0:
raise ValueError("A negative expire is not acceptable.")
elif self._expire is not None:
expire = self._expire
else:
raise TypeError(
"To extend a lock 'expire' must be provided as an "
"argument to extend() method or at initialization time."
)
error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
if error == 1:
raise NotAcquired(f"Lock {self._name} is not acquired or it already expired.")
elif error == 2:
raise NotExpirable(f"Lock {self._name} has no assigned expiration time")
elif error:
raise RuntimeError(f"Unsupported error code {error} from EXTEND script")
@staticmethod
def _lock_renewer(name, lockref, interval, stop):
"""
Renew the lock key in redis every `interval` seconds for as long
as `self._lock_renewal_thread.should_exit` is False.
"""
while not stop.wait(timeout=interval):
loggers["refresh.thread.start"].debug("Refreshing Lock(%r).", name)
lock: "Lock" = lockref()
if lock is None:
loggers["refresh.thread.stop"].debug(
"Stopping loop because Lock(%r) was garbage collected.", name
)
break
lock.extend(expire=lock._expire)
del lock
loggers["refresh.thread.exit"].debug("Exiting renewal thread for Lock(%r).", name)
def _start_lock_renewer(self):
"""
Starts the lock refresher thread.
"""
if self._lock_renewal_thread is not None:
raise AlreadyStarted("Lock refresh thread already started")
loggers["refresh.start"].debug(
"Starting renewal thread for Lock(%r). Refresh interval: %s seconds.",
self._name, self._lock_renewal_interval
)
self._lock_renewal_stop = threading.Event()
self._lock_renewal_thread = threading.Thread(
group=None,
target=self._lock_renewer,
kwargs={
'name': self._name,
'lockref': weakref.ref(self),
'interval': self._lock_renewal_interval,
'stop': self._lock_renewal_stop,
},
)
self._lock_renewal_thread.daemon = True
self._lock_renewal_thread.start()
def _stop_lock_renewer(self):
"""
Stop the lock renewer.
This signals the renewal thread and waits for its exit.
"""
if self._lock_renewal_thread is None or not self._lock_renewal_thread.is_alive():
return
loggers["refresh.shutdown"].debug("Signaling renewal thread for Lock(%r) to exit.", self._name)
self._lock_renewal_stop.set()
self._lock_renewal_thread.join()
self._lock_renewal_thread = None
loggers["refresh.exit"].debug("Renewal thread for Lock(%r) exited.", self._name)
def __enter__(self):
acquired = self.acquire(blocking=self.blocking)
if not acquired:
if self.blocking:
raise AssertionError(f"Lock({self._name}) wasn't acquired, but blocking=True was used!")
raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
return self
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
self.release()
def release(self):
"""Releases the lock, that was acquired with the same object.
.. note::
If you want to release a lock that you acquired in a different place you have two choices:
* Use ``Lock("name", id=id_from_other_place).release()``
* Use ``Lock("name").reset()``
"""
if self._lock_renewal_thread is not None:
self._stop_lock_renewer()
loggers["release"].debug("Releasing Lock(%r).", self._name)
error = self.unlock_script(client=self._client, keys=(self._name, self._signal), args=(self._id, self._signal_expire))
if error == 1:
raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
elif error:
raise RuntimeError(f"Unsupported error code {error} from EXTEND script.")
def locked(self):
"""
Return true if the lock is acquired.
Checks that lock with same name already exists. This method returns true, even if
lock have another id.
"""
return self._client.exists(self._name) == 1
reset_all_script = None
def reset_all(redis_client):
"""
Forcibly deletes all locks if its remains (like a crash reason). Use this with care.
:param redis_client:
An instance of :class:`~StrictRedis`.
"""
Lock.register_scripts(redis_client)
reset_all_script(client=redis_client) # noqa