import threading
import weakref
from base64 import b64encode
from logging import getLogger
from os import urandom
from typing import Union

from redis import StrictRedis

__version__ = '4.0.0'

# One logger per lifecycle event so each can be tuned independently.
loggers = {
    k: getLogger("rhodecode." + ".".join((__name__, k)))
    for k in [
        "acquire",
        "refresh.thread.start",
        "refresh.thread.stop",
        "refresh.thread.exit",
        "refresh.start",
        "refresh.shutdown",
        "refresh.exit",
        "release",
    ]
}

text_type = str
binary_type = bytes


# Delete the lock only if its value equals our id (ARGV[1]); otherwise return
# error code 1. On success, recreate the signal list (KEYS[2]) with a single
# element and a short TTL (ARGV[2], milliseconds) so that one waiter blocked in
# BLPOP inside `Lock.acquire` wakes up and retries.
UNLOCK_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    else
        redis.call("del", KEYS[2])
        redis.call("lpush", KEYS[2], 1)
        redis.call("pexpire", KEYS[2], ARGV[2])
        redis.call("del", KEYS[1])
        return 0
    end
"""

# Return 1 when the key is missing or holds another id, 2 when the key has no
# TTL, otherwise set a new TTL (ARGV[2], seconds) and return 0.
EXTEND_SCRIPT = b"""
    if redis.call("get", KEYS[1]) ~= ARGV[1] then
        return 1
    elseif redis.call("ttl", KEYS[1]) < 0 then
        return 2
    else
        redis.call("expire", KEYS[1], ARGV[2])
        return 0
    end
"""

# Unconditionally delete the lock and signal one waiter; returns the result of
# DEL on the lock key.
RESET_SCRIPT = b"""
    redis.call('del', KEYS[2])
    redis.call('lpush', KEYS[2], 1)
    redis.call('pexpire', KEYS[2], ARGV[2])
    return redis.call('del', KEYS[1])
"""

# Delete every 'lock:*' key, signalling its matching 'lock-signal:*' list
# (1 second TTL); returns how many lock keys were found.
RESET_ALL_SCRIPT = b"""
    local locks = redis.call('keys', 'lock:*')
    local signal
    for _, lock in pairs(locks) do
        signal = 'lock-signal:' .. string.sub(lock, 6)
        redis.call('del', signal)
        redis.call('lpush', signal, 1)
        redis.call('expire', signal, 1)
        redis.call('del', lock)
    end
    return #locks
"""


class AlreadyAcquired(RuntimeError):
    pass


class NotAcquired(RuntimeError):
    pass


class AlreadyStarted(RuntimeError):
    pass


class TimeoutNotUsable(RuntimeError):
    pass


class InvalidTimeout(RuntimeError):
    pass


class TimeoutTooLarge(RuntimeError):
    pass


class NotExpirable(RuntimeError):
    pass


class Lock:
    """
    A Lock context manager implemented via redis SETNX/BLPOP.
    """

    # Registered redis Script objects, filled in by `register_scripts`.
    unlock_script = None
    extend_script = None
    reset_script = None
    reset_all_script = None
    # Default used by `__enter__`; overwritten per instance in `__init__`.
    blocking = None

    _lock_renewal_interval: float
    _lock_renewal_thread: Union[threading.Thread, None]

    def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False,
                 strict=True, signal_expire=1000, blocking=True):
        """
        :param redis_client:
            An instance of :class:`~StrictRedis`.
        :param name:
            The name (redis key) the lock should have.
        :param expire:
            The lock expiry time in seconds. If left at the default (None)
            the lock will not expire.
        :param id:
            The ID (redis value) the lock should have. A random value is
            generated when left at the default.

            Note that if you specify this then the lock is marked as "held". Acquires
            won't be possible.
        :param auto_renewal:
            If set to ``True``, Lock will automatically renew the lock so that it
            doesn't expire for as long as the lock is held (acquire() called
            or running in a context manager).

            Implementation note: Renewal will happen using a daemon thread with
            an interval of ``expire*2/3``. If wishing to use a different renewal
            time, subclass Lock, call ``super().__init__()`` then set
            ``self._lock_renewal_interval`` to your desired interval.
        :param strict:
            If set ``True`` then the ``redis_client`` needs to be an instance of ``redis.StrictRedis``.
        :param signal_expire:
            Advanced option to override signal list expiration in milliseconds. Increase it for very slow clients. Default: ``1000``.
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
            Used in `__enter__` method.
        :raises ValueError: on a non-StrictRedis client with ``strict=True``,
            a negative ``expire``, or ``auto_renewal`` without a usable ``expire``.
        :raises TypeError: if ``id`` is neither bytes nor str.
        """
        if strict and not isinstance(redis_client, StrictRedis):
            raise ValueError("redis_client must be instance of StrictRedis. "
                             "Use strict=False if you know what you're doing.")

        self._client = redis_client

        # A falsy expire (None or 0) means "never expires".
        if expire:
            expire = int(expire)
            if expire < 0:
                raise ValueError("A negative expire is not acceptable.")
        else:
            expire = None
        # Validate after normalisation so expire=0 is rejected here rather than
        # failing later with a TypeError in float(None).
        if auto_renewal and expire is None:
            raise ValueError("Expire may not be None when auto_renewal is set")
        self._expire = expire

        self._signal_expire = signal_expire
        if id is None:
            self._id = b64encode(urandom(18)).decode('ascii')
        elif isinstance(id, binary_type):
            try:
                self._id = id.decode('ascii')
            except UnicodeDecodeError:
                # Non-ASCII bytes: store a reversible ASCII representation.
                self._id = b64encode(id).decode('ascii')
        elif isinstance(id, text_type):
            self._id = id
        else:
            raise TypeError(f"Incorrect type for `id`. Must be bytes/str not {type(id)}.")
        self._name = 'lock:' + name
        self._signal = 'lock-signal:' + name
        self._lock_renewal_interval = (float(expire) * 2 / 3
                                       if auto_renewal
                                       else None)
        self._lock_renewal_thread = None

        self.blocking = blocking

        self.register_scripts(redis_client)

    @classmethod
    def register_scripts(cls, redis_client):
        """
        Register the Lua scripts with ``redis_client`` once per class.

        The scripts are stored on ``cls`` (subclasses inherit them from
        ``Lock``), and the module-level ``reset_all_script`` used by
        :func:`reset_all` is set independently. Guarding on ``cls`` rather than
        only on the module global ensures that instantiating a subclass first
        does not leave the base ``Lock`` without registered scripts.
        """
        global reset_all_script
        if reset_all_script is None:
            reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)
        if cls.unlock_script is None:
            cls.unlock_script = redis_client.register_script(UNLOCK_SCRIPT)
            cls.extend_script = redis_client.register_script(EXTEND_SCRIPT)
            cls.reset_script = redis_client.register_script(RESET_SCRIPT)
            cls.reset_all_script = redis_client.register_script(RESET_ALL_SCRIPT)

    @property
    def _held(self):
        """True when the value stored under the lock key equals this instance's id."""
        return self.id == self.get_owner_id()

    def reset(self):
        """
        Forcibly deletes the lock. Use this with care.
        """
        self.reset_script(client=self._client, keys=(self._name, self._signal), args=(self.id, self._signal_expire))

    @property
    def id(self):
        return self._id

    def get_owner_id(self):
        """Return the id stored under the lock key (decoded to str), or None if unset."""
        owner_id = self._client.get(self._name)
        if isinstance(owner_id, binary_type):
            owner_id = owner_id.decode('ascii', 'replace')
        return owner_id

    def acquire(self, blocking=True, timeout=None):
        """
        :param blocking:
            Boolean value specifying whether lock should be blocking or not.
        :param timeout:
            An integer value specifying the maximum number of seconds to block.
        :returns: True if the lock was acquired; False if ``blocking=False``
            and the lock is busy, or if ``timeout`` elapsed.
        :raises AlreadyAcquired: if this instance's id already holds the lock.
        :raises TimeoutNotUsable: if ``timeout`` is given with ``blocking=False``.
        :raises InvalidTimeout: if ``timeout`` is negative.
        :raises TimeoutTooLarge: if ``timeout`` exceeds ``expire`` without auto-renewal.
        """
        logger = loggers["acquire"]

        logger.debug("Getting blocking: %s acquire on %r ...", blocking, self._name)

        if self._held:
            owner_id = self.get_owner_id()
            raise AlreadyAcquired("Already acquired from this Lock instance. Lock id: {}".format(owner_id))

        if not blocking and timeout is not None:
            raise TimeoutNotUsable("Timeout cannot be used if blocking=False")

        if timeout:
            timeout = int(timeout)
            if timeout < 0:
                raise InvalidTimeout(f"Timeout ({timeout}) cannot be less than or equal to 0")

            if self._expire and not self._lock_renewal_interval and timeout > self._expire:
                raise TimeoutTooLarge(f"Timeout ({timeout}) cannot be greater than expire ({self._expire})")

        busy = True
        # Without an explicit timeout, wake up every `expire` seconds (or block
        # indefinitely when 0) so an expired lock can be retaken even if no
        # release signal is ever pushed.
        blpop_timeout = timeout or self._expire or 0
        timed_out = False
        while busy:
            busy = not self._client.set(self._name, self._id, nx=True, ex=self._expire)
            if busy:
                if timed_out:
                    return False
                elif blocking:
                    # Only counts as a timeout when the caller asked for one;
                    # otherwise keep looping.
                    timed_out = not self._client.blpop(self._signal, blpop_timeout) and timeout
                else:
                    logger.warning("Failed to acquire Lock(%r).", self._name)
                    return False

        logger.debug("Acquired Lock(%r).", self._name)
        if self._lock_renewal_interval is not None:
            self._start_lock_renewer()
        return True

    def extend(self, expire=None):
        """
        Extends expiration time of the lock.

        :param expire:
            New expiration time. If ``None`` - `expire` provided during
            lock initialization will be taken.
        :raises NotAcquired: if the lock is missing or held by another id.
        :raises NotExpirable: if the lock key has no TTL.
        :raises TypeError: if no expire is available from either source.
        """
        if expire:
            expire = int(expire)
            if expire < 0:
                raise ValueError("A negative expire is not acceptable.")
        elif self._expire is not None:
            expire = self._expire
        else:
            raise TypeError(
                "To extend a lock 'expire' must be provided as an "
                "argument to extend() method or at initialization time."
            )

        error = self.extend_script(client=self._client, keys=(self._name, self._signal), args=(self._id, expire))
        if error == 1:
            raise NotAcquired(f"Lock {self._name} is not acquired or it already expired.")
        elif error == 2:
            raise NotExpirable(f"Lock {self._name} has no assigned expiration time")
        elif error:
            raise RuntimeError(f"Unsupported error code {error} from EXTEND script")

    @staticmethod
    def _lock_renewer(name, lockref, interval, stop):
        """
        Renew the lock key in redis every `interval` seconds until the `stop`
        event is set or the Lock object referenced by `lockref` is garbage
        collected.

        Only a weak reference is held so this thread does not keep the Lock
        alive. If ``extend`` raises (e.g. the lock already expired), the
        exception is not caught and ends the thread.
        """
        while not stop.wait(timeout=interval):
            loggers["refresh.thread.start"].debug("Refreshing Lock(%r).", name)
            lock: "Lock" = lockref()
            if lock is None:
                loggers["refresh.thread.stop"].debug(
                    "Stopping loop because Lock(%r) was garbage collected.", name
                )
                break
            lock.extend(expire=lock._expire)
            # Drop the strong reference before sleeping again.
            del lock
        loggers["refresh.thread.exit"].debug("Exiting renewal thread for Lock(%r).", name)

    def _start_lock_renewer(self):
        """
        Starts the lock refresher thread.

        :raises AlreadyStarted: if a renewal thread reference is still set.
        """
        if self._lock_renewal_thread is not None:
            raise AlreadyStarted("Lock refresh thread already started")

        loggers["refresh.start"].debug(
            "Starting renewal thread for Lock(%r). Refresh interval: %s seconds.",
            self._name, self._lock_renewal_interval
        )
        self._lock_renewal_stop = threading.Event()
        self._lock_renewal_thread = threading.Thread(
            group=None,
            target=self._lock_renewer,
            kwargs={
                'name': self._name,
                'lockref': weakref.ref(self),
                'interval': self._lock_renewal_interval,
                'stop': self._lock_renewal_stop,
            },
        )
        self._lock_renewal_thread.daemon = True
        self._lock_renewal_thread.start()

    def _stop_lock_renewer(self):
        """
        Stop the lock renewer.

        This signals the renewal thread and waits for its exit. The thread
        reference is always cleared, including when the thread had already
        died on its own, so a later ``acquire()`` can start a new renewer
        instead of raising ``AlreadyStarted``.
        """
        thread = self._lock_renewal_thread
        if thread is None:
            return
        if thread.is_alive():
            loggers["refresh.shutdown"].debug("Signaling renewal thread for Lock(%r) to exit.", self._name)
            self._lock_renewal_stop.set()
            thread.join()
            self._lock_renewal_thread = None
            loggers["refresh.exit"].debug("Renewal thread for Lock(%r) exited.", self._name)
        else:
            self._lock_renewal_thread = None

    def __enter__(self):
        acquired = self.acquire(blocking=self.blocking)
        if not acquired:
            if self.blocking:
                raise AssertionError(f"Lock({self._name}) wasn't acquired, but blocking=True was used!")
            raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
        return self

    def __exit__(self, exc_type=None, exc_value=None, traceback=None):
        self.release()

    def release(self):
        """Releases the lock, that was acquired with the same object.

        .. note::

            If you want to release a lock that you acquired in a different place you have two choices:

            * Use ``Lock("name", id=id_from_other_place).release()``
            * Use ``Lock("name").reset()``

        :raises NotAcquired: if the lock is missing or held by another id.
        """
        if self._lock_renewal_thread is not None:
            self._stop_lock_renewer()
        loggers["release"].debug("Releasing Lock(%r).", self._name)
        error = self.unlock_script(client=self._client, keys=(self._name, self._signal), args=(self._id, self._signal_expire))
        if error == 1:
            raise NotAcquired(f"Lock({self._name}) is not acquired or it already expired.")
        elif error:
            raise RuntimeError(f"Unsupported error code {error} from UNLOCK script.")

    def locked(self):
        """
        Return true if the lock is acquired.

        Checks that lock with same name already exists. This method returns true, even if
        lock have another id.
        """
        return self._client.exists(self._name) == 1


# Registered lazily by `Lock.register_scripts`.
reset_all_script = None


def reset_all(redis_client):
    """
    Forcibly deletes all locks, e.g. leftovers after a crash. Use this with care.

    :param redis_client:
        An instance of :class:`~StrictRedis`.
    """
    Lock.register_scripts(redis_client)

    reset_all_script(client=redis_client)  # noqa