##// END OF EJS Templates
fix(encryptor): use a failsafe mechanism of detecting old algo for encryption to NOT crash the app when switching to fernet
super-admin -
r5363:7bfb02ec default
parent child Browse files
Show More
@@ -1,49 +1,49 b''
1 from rhodecode.lib.str_utils import safe_bytes
1 from rhodecode.lib.str_utils import safe_bytes
2 from rhodecode.lib.encrypt import encrypt_data, validate_and_decrypt_data
2 from rhodecode.lib.encrypt import encrypt_data, validate_and_decrypt_data
3 from rhodecode.lib.encrypt2 import Encryptor
3 from rhodecode.lib.encrypt2 import Encryptor
4
4
5 ALLOWED_ALGOS = ['aes', 'fernet']
5 ALLOWED_ALGOS = ['aes', 'fernet']
6
6
7
7
8 def get_default_algo():
8 def get_default_algo():
9 import rhodecode
9 import rhodecode
10 return rhodecode.CONFIG.get('rhodecode.encrypted_values.algorithm') or 'aes'
10 return rhodecode.CONFIG.get('rhodecode.encrypted_values.algorithm') or 'aes'
11
11
12
12
13 def encrypt_value(value: bytes, enc_key: bytes, algo: str = ''):
13 def encrypt_value(value: bytes, enc_key: bytes, algo: str = ''):
14 if not algo:
14 if not algo:
15 # not explicit algo, just use what's set by config
15 # not explicit algo, just use what's set by config
16 algo = get_default_algo()
16 algo = get_default_algo()
17
17
18 if algo not in ALLOWED_ALGOS:
18 if algo not in ALLOWED_ALGOS:
19 ValueError(f'Bad encryption algorithm, should be {ALLOWED_ALGOS}, got: {algo}')
19 ValueError(f'Bad encryption algorithm, should be {ALLOWED_ALGOS}, got: {algo}')
20
20
21 enc_key = safe_bytes(enc_key)
21 enc_key = safe_bytes(enc_key)
22 value = safe_bytes(value)
22 value = safe_bytes(value)
23
23
24 if algo == 'aes':
24 if algo == 'aes':
25 return encrypt_data(value, enc_key=enc_key)
25 return encrypt_data(value, enc_key=enc_key)
26 if algo == 'fernet':
26 if algo == 'fernet':
27 return Encryptor(enc_key).encrypt(value)
27 return Encryptor(enc_key).encrypt(value)
28
28
29 return value
29 return value
30
30
31
31
32 def decrypt_value(value: bytes, enc_key: bytes, algo: str = '', strict_mode: bool = False):
32 def decrypt_value(value: bytes, enc_key: bytes, algo: str = '', strict_mode: bool = False):
33 enc_key = safe_bytes(enc_key)
34 value = safe_bytes(value)
33
35
34 if not algo:
36 if not algo:
35 # not explicit algo, just use what's set by config
37 # not explicit algo, just use what's set by config
36 algo = get_default_algo()
38 algo = Encryptor.detect_enc_algo(value) or get_default_algo()
37 if algo not in ALLOWED_ALGOS:
39 if algo not in ALLOWED_ALGOS:
38 ValueError(f'Bad encryption algorithm, should be {ALLOWED_ALGOS}, got: {algo}')
40 ValueError(f'Bad encryption algorithm, should be {ALLOWED_ALGOS}, got: {algo}')
39
41
40 enc_key = safe_bytes(enc_key)
41 value = safe_bytes(value)
42 safe = not strict_mode
42 safe = not strict_mode
43
43
44 if algo == 'aes':
44 if algo == 'aes':
45 return validate_and_decrypt_data(value, enc_key, safe=safe)
45 return validate_and_decrypt_data(value, enc_key, safe=safe)
46 if algo == 'fernet':
46 if algo == 'fernet':
47 return Encryptor(enc_key).decrypt(value, safe=safe)
47 return Encryptor(enc_key).decrypt(value, safe=safe)
48
48
49 return value
49 return value
@@ -1,84 +1,97 b''
1 import os
1 import os
2 import base64
2 import base64
3 from cryptography.fernet import Fernet, InvalidToken
3 from cryptography.fernet import Fernet, InvalidToken
4 from cryptography.hazmat.backends import default_backend
4 from cryptography.hazmat.backends import default_backend
5 from cryptography.hazmat.primitives import hashes
5 from cryptography.hazmat.primitives import hashes
6 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
6 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
7
7
8 from rhodecode.lib.str_utils import safe_str
8 from rhodecode.lib.str_utils import safe_str
9 from rhodecode.lib.exceptions import signature_verification_error
9 from rhodecode.lib.exceptions import signature_verification_error
10
10
11
11
12 class InvalidDecryptedValue(str):
12 class InvalidDecryptedValue(str):
13
13
14 def __new__(cls, content):
14 def __new__(cls, content):
15 """
15 """
16 This will generate something like this::
16 This will generate something like this::
17 <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>
17 <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>
18 And represent a safe indicator that encryption key is broken
18 And represent a safe indicator that encryption key is broken
19 """
19 """
20 content = f'<{cls.__name__}({content[:16]}...)>'
20 content = f'<{cls.__name__}({content[:16]}...)>'
21 return str.__new__(cls, content)
21 return str.__new__(cls, content)
22
22
23
23
24 class Encryptor(object):
24 class Encryptor(object):
25 key_format = b'enc2$salt:{1}$data:{2}'
25 key_format = b'enc2$salt:{1}$data:{2}'
26
26 pref_len = 5 # salt:, data:
27 pref_len = 5 # salt:, data:
27
28
29 @classmethod
30 def detect_enc_algo(cls, enc_data: bytes):
31 parts = enc_data.split(b'$', 3)
32 if len(parts) != 3:
33 raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got {parts}')
34
35 if b'enc$aes_hmac$' in enc_data:
36 return 'aes'
37 elif b'enc2$salt' in enc_data:
38 return 'fernet'
39 return None
40
28 def __init__(self, enc_key: bytes):
41 def __init__(self, enc_key: bytes):
29 self.enc_key = enc_key
42 self.enc_key = enc_key
30
43
31 def b64_encode(self, data):
44 def b64_encode(self, data):
32 return base64.urlsafe_b64encode(data)
45 return base64.urlsafe_b64encode(data)
33
46
34 def b64_decode(self, data):
47 def b64_decode(self, data):
35 return base64.urlsafe_b64decode(data)
48 return base64.urlsafe_b64decode(data)
36
49
37 def get_encryptor(self, salt):
50 def get_encryptor(self, salt):
38 """
51 """
39 Uses Fernet as encryptor with HMAC signature
52 Uses Fernet as encryptor with HMAC signature
40 :param salt: random salt used for encrypting the data
53 :param salt: random salt used for encrypting the data
41 """
54 """
42 kdf = PBKDF2HMAC(
55 kdf = PBKDF2HMAC(
43 algorithm=hashes.SHA512(),
56 algorithm=hashes.SHA512(),
44 length=32,
57 length=32,
45 salt=salt,
58 salt=salt,
46 iterations=100000,
59 iterations=100000,
47 backend=default_backend()
60 backend=default_backend()
48 )
61 )
49 key = self.b64_encode(kdf.derive(self.enc_key))
62 key = self.b64_encode(kdf.derive(self.enc_key))
50 return Fernet(key)
63 return Fernet(key)
51
64
52 def _get_parts(self, enc_data):
65 def _get_parts(self, enc_data):
53 parts = enc_data.split(b'$', 3)
66 parts = enc_data.split(b'$', 3)
54 if len(parts) != 3:
67 if len(parts) != 3:
55 raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got {parts}')
68 raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got {parts}')
56 prefix, salt, enc_data = parts
69 prefix, salt, enc_data = parts
57
70
58 try:
71 try:
59 salt = self.b64_decode(salt[self.pref_len:])
72 salt = self.b64_decode(salt[self.pref_len:])
60 except TypeError:
73 except TypeError:
61 # bad base64
74 # bad base64
62 raise ValueError('Encrypted Data salt invalid format, expected base64 format')
75 raise ValueError('Encrypted Data salt invalid format, expected base64 format')
63
76
64 enc_data = enc_data[self.pref_len:]
77 enc_data = enc_data[self.pref_len:]
65 return prefix, salt, enc_data
78 return prefix, salt, enc_data
66
79
67 def encrypt(self, data) -> bytes:
80 def encrypt(self, data) -> bytes:
68 salt = os.urandom(64)
81 salt = os.urandom(64)
69 encryptor = self.get_encryptor(salt)
82 encryptor = self.get_encryptor(salt)
70 enc_data = encryptor.encrypt(data)
83 enc_data = encryptor.encrypt(data)
71 return self.key_format.replace(b'{1}', self.b64_encode(salt)).replace(b'{2}', enc_data)
84 return self.key_format.replace(b'{1}', self.b64_encode(salt)).replace(b'{2}', enc_data)
72
85
73 def decrypt(self, data, safe=True) -> bytes | InvalidDecryptedValue:
86 def decrypt(self, data, safe=True) -> bytes | InvalidDecryptedValue:
74 parts = self._get_parts(data)
87 parts = self._get_parts(data)
75 salt = parts[1]
88 salt = parts[1]
76 enc_data = parts[2]
89 enc_data = parts[2]
77 encryptor = self.get_encryptor(salt)
90 encryptor = self.get_encryptor(salt)
78 try:
91 try:
79 return encryptor.decrypt(enc_data)
92 return encryptor.decrypt(enc_data)
80 except (InvalidToken,):
93 except (InvalidToken,):
81 decrypt_fail = InvalidDecryptedValue(safe_str(data))
94 decrypt_fail = InvalidDecryptedValue(safe_str(data))
82 if safe:
95 if safe:
83 return decrypt_fail
96 return decrypt_fail
84 raise signature_verification_error(decrypt_fail)
97 raise signature_verification_error(decrypt_fail)
General Comments 0
You need to be logged in to leave comments. Login now