|
|
import os
|
|
|
import base64
|
|
|
from cryptography.fernet import Fernet, InvalidToken
|
|
|
from cryptography.hazmat.backends import default_backend
|
|
|
from cryptography.hazmat.primitives import hashes
|
|
|
from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
|
|
|
|
|
|
from rhodecode.lib.str_utils import safe_str
|
|
|
from rhodecode.lib.exceptions import signature_verification_error
|
|
|
|
|
|
|
|
|
class InvalidDecryptedValue(str):
    """
    String marker used in place of a value that failed to decrypt.

    Only the first 16 characters of the original content are kept, so the
    marker identifies the broken value without echoing all of it.
    """

    def __new__(cls, content):
        """
        Build the marker text, which looks like this::

            <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>

        and acts as a safe indicator that the encryption key is broken.
        """
        preview = content[:16]
        marker = f'<{cls.__name__}({preview}...)>'
        return str.__new__(cls, marker)
|
|
|
|
|
|
|
|
|
class Encryptor(object):
    """
    Salted Fernet encryptor keyed by a shared secret.

    Encrypted values are serialized as ``enc2$salt:<b64 salt>$data:<token>``
    (see ``key_format``). The per-value Fernet key is derived from
    ``enc_key`` and a random salt with PBKDF2-HMAC-SHA512.
    """

    # serialization template: {1} is the urlsafe-base64 salt, {2} the Fernet token
    key_format = b'enc2$salt:{1}$data:{2}'

    # length of the "salt:" / "data:" field prefixes stripped when parsing
    pref_len = 5  # salt:, data:

    @classmethod
    def detect_enc_algo(cls, enc_data: bytes):
        """
        Detect which encryption scheme produced ``enc_data``.

        :param enc_data: serialized value to inspect
        :returns: ``'aes'`` if the value contains ``enc$aes_hmac$``,
            ``'fernet'`` if it contains ``enc2$salt``, otherwise ``None``
        :raises ValueError: if a marker is found but the value does not split
            into exactly three ``$``-separated parts
        """
        parts = enc_data.split(b'$', 3)

        if b'enc$aes_hmac$' in enc_data:
            # we expect this data is encrypted, so validate the header
            if len(parts) != 3:
                raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got `{parts}`')
            return 'aes'
        elif b'enc2$salt' in enc_data:
            # we expect this data is encrypted, so validate the header
            if len(parts) != 3:
                raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got `{parts}`')
            return 'fernet'
        return None

    def __init__(self, enc_key: bytes):
        """
        :param enc_key: secret the per-value Fernet keys are derived from
        """
        self.enc_key = enc_key

    def b64_encode(self, data):
        """Return the URL-safe base64 encoding of ``data``."""
        return base64.urlsafe_b64encode(data)

    def b64_decode(self, data):
        """Decode URL-safe base64 ``data``."""
        return base64.urlsafe_b64decode(data)

    def get_encryptor(self, salt):
        """
        Uses Fernet as encryptor with HMAC signature.

        The Fernet key is derived from ``self.enc_key`` with PBKDF2-HMAC-SHA512
        (100000 iterations, 32-byte output). The output is then urlsafe-base64
        encoded, which is the key format ``Fernet`` accepts.

        :param salt: random salt used for encrypting the data
        """
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA512(),
            length=32,
            salt=salt,
            iterations=100000,
            backend=default_backend()
        )
        key = self.b64_encode(kdf.derive(self.enc_key))
        return Fernet(key)

    def _get_parts(self, enc_data):
        """
        Split a serialized ``enc2`` value into its components.

        :param enc_data: value laid out as ``key_format``
        :returns: ``(prefix, salt, token)``. The salt is already
            base64-decoded, and the ``data:`` prefix is stripped from the token.
        :raises ValueError: if the value does not have exactly three
            ``$``-separated parts, or if the salt is not valid base64
        """
        parts = enc_data.split(b'$', 3)
        if len(parts) != 3:
            raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got `{parts}`')
        prefix, salt, enc_data = parts

        try:
            salt = self.b64_decode(salt[self.pref_len:])
        except (TypeError, ValueError) as exc:
            # bad base64. On Python 3, urlsafe_b64decode reports it with
            # binascii.Error (a ValueError subclass), not TypeError.
            raise ValueError('Encrypted Data salt invalid format, expected base64 format') from exc

        enc_data = enc_data[self.pref_len:]
        return prefix, salt, enc_data

    def encrypt(self, data) -> bytes:
        """
        Encrypt ``data`` with a fresh 64-byte random salt.

        :param data: plaintext; it is passed unchanged to ``Fernet.encrypt``,
            which expects bytes
        :returns: serialized value in ``key_format`` layout
        """
        salt = os.urandom(64)
        encryptor = self.get_encryptor(salt)
        enc_data = encryptor.encrypt(data)
        return self.key_format.replace(b'{1}', self.b64_encode(salt)).replace(b'{2}', enc_data)

    # return annotation is quoted so it is not evaluated when the class body runs
    def decrypt(self, data, safe=True) -> 'bytes | InvalidDecryptedValue':
        """
        Decrypt a value produced by :meth:`encrypt`.

        :param data: serialized value in ``key_format`` layout
        :param safe: when True, a token that fails verification returns an
            ``InvalidDecryptedValue`` marker instead of raising
        :returns: decrypted bytes, or the marker (see ``safe``)
        :raises ValueError: if ``data`` is malformed (raised regardless of ``safe``)
        :raises: the exception built by ``signature_verification_error`` when
            verification fails and ``safe`` is False
        """
        parts = self._get_parts(data)
        salt = parts[1]
        enc_data = parts[2]
        encryptor = self.get_encryptor(salt)
        try:
            return encryptor.decrypt(enc_data)
        except (InvalidToken,):
            decrypt_fail = InvalidDecryptedValue(safe_str(data))
            if safe:
                return decrypt_fail
            raise signature_verification_error(decrypt_fail)
|
|
|
|