import os
import base64
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

from rhodecode.lib.str_utils import safe_str
from rhodecode.lib.exceptions import signature_verification_error


class InvalidDecryptedValue(str):
    """
    String marker returned in place of plaintext when decryption fails.
    """

    def __new__(cls, content):
        """
        This will generate something like this::

            <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>

        And represent a safe indicator that the value could not be decrypted
        (e.g. the encryption key is wrong or the data was tampered with).
        Only the first 16 characters of ``content`` are kept in the marker.
        """
        content = f'<{cls.__name__}({content[:16]}...)>'
        return str.__new__(cls, content)


class Encryptor:
    """
    Symmetric encryption of byte strings using Fernet, with a Fernet key
    derived per value from ``enc_key`` and a random salt via PBKDF2.

    Encrypted values are serialized as::

        enc2$salt:<urlsafe-base64 salt>$data:<fernet token>
    """
    # Serialization template; ``{1}`` and ``{2}`` are replaced literally
    # (bytes.replace) in encrypt(), not via str.format.
    key_format = b'enc2$salt:{1}$data:{2}'
    pref_len = 5  # salt:, data:  (both labels are 5 bytes long)

    def __init__(self, enc_key: bytes):
        # Secret fed to PBKDF2HMAC.derive(); must be bytes.
        self.enc_key = enc_key

    def b64_encode(self, data):
        """Return URL-safe base64 encoding of ``data``."""
        return base64.urlsafe_b64encode(data)

    def b64_decode(self, data):
        """Return URL-safe base64 decoding of ``data``."""
        return base64.urlsafe_b64decode(data)

    def get_encryptor(self, salt):
        """
        Uses Fernet as encryptor with HMAC signature.

        The Fernet key is derived from ``self.enc_key`` with
        PBKDF2-HMAC-SHA512 (100000 iterations, 32-byte output), so the same
        ``enc_key`` and ``salt`` always produce the same Fernet key.

        :param salt: random salt used for encrypting the data
        :returns: a ``Fernet`` instance bound to the derived key
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA512(),
            length=32,
            salt=salt,
            iterations=100000,
            backend=default_backend(),
        )
        # Fernet expects its 32-byte key encoded as URL-safe base64.
        key = self.b64_encode(kdf.derive(self.enc_key))
        return Fernet(key)

    def _get_parts(self, enc_data):
        """
        Split a serialized value into ``(prefix, salt, token)``.

        :param enc_data: bytes laid out as described by ``key_format``
        :returns: tuple of the raw prefix field, the base64-decoded salt
            bytes, and the Fernet token with its ``data:`` label removed
        :raises ValueError: if the value does not split into exactly three
            ``$``-separated fields, or the salt is not valid base64
        """
        parts = enc_data.split(b'$', 3)
        if len(parts) != 3:
            raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got {parts}')
        prefix, salt_field, data_field = parts
        # NOTE(review): the 'salt:' / 'data:' labels and the prefix are not
        # verified; the first pref_len bytes of each field are simply skipped.

        try:
            salt = self.b64_decode(salt_field[self.pref_len:])
        except ValueError as err:
            # Python 3's base64 decoders raise binascii.Error (a ValueError
            # subclass) on malformed input such as bad padding; the former
            # `except TypeError` never matched, so this message was unreachable.
            raise ValueError('Encrypted Data salt invalid format, expected base64 format') from err

        return prefix, salt, data_field[self.pref_len:]

    def encrypt(self, data) -> bytes:
        """
        Encrypt ``data`` and serialize it in the ``key_format`` layout.

        A fresh 64-byte random salt is generated on every call and stored,
        base64-encoded, alongside the Fernet token.

        :param data: plaintext bytes (Fernet only accepts bytes)
        :returns: serialized encrypted value as bytes
        """
        salt = os.urandom(64)
        token = self.get_encryptor(salt).encrypt(data)
        return self.key_format.replace(b'{1}', self.b64_encode(salt)).replace(b'{2}', token)

    def decrypt(self, data, safe=True) -> bytes | InvalidDecryptedValue:
        """
        Decrypt a value produced by :meth:`encrypt`.

        :param data: bytes laid out as described by ``key_format``
        :param safe: when True, a token that fails verification yields an
            :class:`InvalidDecryptedValue` marker instead of an exception
        :returns: the decrypted plaintext bytes, or an
            ``InvalidDecryptedValue`` when ``safe`` is True and the token
            cannot be verified
        :raises ValueError: if ``data`` is structurally malformed (raised
            regardless of ``safe``)
        :raises signature_verification_error: when ``safe`` is False and the
            token cannot be verified
        """
        _prefix, salt, enc_data = self._get_parts(data)
        encryptor = self.get_encryptor(salt)
        try:
            return encryptor.decrypt(enc_data)
        except InvalidToken:
            decrypt_fail = InvalidDecryptedValue(safe_str(data))
            if safe:
                return decrypt_fail
            raise signature_verification_error(decrypt_fail)