##// END OF EJS Templates
fix(encryption): don't be strict on enc format when no enc headers are missing....
super-admin -
r5381:0a39631e default
parent child Browse files
Show More
@@ -1,154 +1,155 b''
1 1 # Copyright (C) 2014-2023 RhodeCode GmbH
2 2 #
3 3 # This program is free software: you can redistribute it and/or modify
4 4 # it under the terms of the GNU Affero General Public License, version 3
5 5 # (only), as published by the Free Software Foundation.
6 6 #
7 7 # This program is distributed in the hope that it will be useful,
8 8 # but WITHOUT ANY WARRANTY; without even the implied warranty of
9 9 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
10 10 # GNU General Public License for more details.
11 11 #
12 12 # You should have received a copy of the GNU Affero General Public License
13 13 # along with this program. If not, see <http://www.gnu.org/licenses/>.
14 14 #
15 15 # This program is dual-licensed. If you wish to learn more about the
16 16 # RhodeCode Enterprise Edition, including its added features, Support services,
17 17 # and proprietary license terms, please see https://rhodecode.com/licenses/
18 18
19 19
20 20 """
21 21 Generic encryption library for RhodeCode
22 22 """
23 23
24 24 import base64
25 25 import logging
26 26
27 27 from Crypto.Cipher import AES
28 28 from Crypto import Random
29 29 from Crypto.Hash import HMAC, SHA256
30 30
31 31 from rhodecode.lib.str_utils import safe_bytes, safe_str
32 32 from rhodecode.lib.exceptions import signature_verification_error
33 33
34 34
35 35 class InvalidDecryptedValue(str):
36 36
37 37 def __new__(cls, content):
38 38 """
39 39 This will generate something like this::
40 40 <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>
41 41 And represent a safe indicator that encryption key is broken
42 42 """
43 43 content = f'<{cls.__name__}({content[:16]}...)>'
44 44 return str.__new__(cls, content)
45 45
46
46 47 KEY_FORMAT = b'enc$aes_hmac${1}'
47 48
48 49
49 50 class AESCipher(object):
50 51
51 52 def __init__(self, key: bytes, hmac=False, strict_verification=True):
52 53
53 54 if not key:
54 55 raise ValueError('passed key variable is empty')
55 56 self.strict_verification = strict_verification
56 57 self.block_size = 32
57 58 self.hmac_size = 32
58 59 self.hmac = hmac
59 60
60 61 self.key = SHA256.new(safe_bytes(key)).digest()
61 62 self.hmac_key = SHA256.new(self.key).digest()
62 63
63 64 def verify_hmac_signature(self, raw_data):
64 65 org_hmac_signature = raw_data[-self.hmac_size:]
65 66 data_without_sig = raw_data[:-self.hmac_size]
66 67 recomputed_hmac = HMAC.new(
67 68 self.hmac_key, data_without_sig, digestmod=SHA256).digest()
68 69 return org_hmac_signature == recomputed_hmac
69 70
70 71 def encrypt(self, raw: bytes):
71 72 raw = self._pad(raw)
72 73 iv = Random.new().read(AES.block_size)
73 74 cipher = AES.new(self.key, AES.MODE_CBC, iv)
74 75 enc_value = cipher.encrypt(raw)
75 76
76 77 hmac_signature = b''
77 78 if self.hmac:
78 79 # compute hmac+sha256 on iv + enc text, we use
79 80 # encrypt then mac method to create the signature
80 81 hmac_signature = HMAC.new(
81 82 self.hmac_key, iv + enc_value, digestmod=SHA256).digest()
82 83
83 84 return base64.b64encode(iv + enc_value + hmac_signature)
84 85
85 86 def decrypt(self, enc, safe=True) -> bytes | InvalidDecryptedValue:
86 87 enc_org = enc
87 88 try:
88 89 enc = base64.b64decode(enc)
89 90 except Exception:
90 91 logging.exception('Failed Base64 decode')
91 92 raise signature_verification_error('Failed Base64 decode')
92 93
93 94 if self.hmac and len(enc) > self.hmac_size:
94 95 if self.verify_hmac_signature(enc):
95 96 # cut off the HMAC verification digest
96 97 enc = enc[:-self.hmac_size]
97 98 else:
98 99
99 100 decrypt_fail = InvalidDecryptedValue(safe_str(enc_org))
100 101 if safe:
101 102 return decrypt_fail
102 103 raise signature_verification_error(decrypt_fail)
103 104
104 105 iv = enc[:AES.block_size]
105 106 cipher = AES.new(self.key, AES.MODE_CBC, iv)
106 107 return self._unpad(cipher.decrypt(enc[AES.block_size:]))
107 108
108 109 def _pad(self, s):
109 110 block_pad = (self.block_size - len(s) % self.block_size)
110 111 return s + block_pad * safe_bytes(chr(block_pad))
111 112
112 113 @staticmethod
113 114 def _unpad(s):
114 115 return s[:-ord(s[len(s)-1:])]
115 116
116 117
117 118 def validate_and_decrypt_data(enc_data, enc_key, enc_strict_mode=False, safe=True):
118 119 enc_data = safe_str(enc_data)
119 120
120 121 if '$' not in enc_data:
121 122 # probably not encrypted values
122 123 return enc_data
123 124
124 125 parts = enc_data.split('$', 3)
125 126 if len(parts) != 3:
126 127 raise ValueError(f'Encrypted Data has invalid format, expected {KEY_FORMAT}, got {parts}, org value: {enc_data}')
127 128
128 129 enc_type = parts[1]
129 130 enc_data_part = parts[2]
130 131
131 132 if parts[0] != 'enc':
132 133 # parts ok but without our header?
133 134 return enc_data
134 135
135 136 # at that stage we know it's our encryption
136 137 if enc_type == 'aes':
137 138 decrypted_data = AESCipher(enc_key).decrypt(enc_data_part, safe=safe)
138 139 elif enc_type == 'aes_hmac':
139 140 decrypted_data = AESCipher(
140 141 enc_key, hmac=True,
141 142 strict_verification=enc_strict_mode).decrypt(enc_data_part, safe=safe)
142 143
143 144 else:
144 145 raise ValueError(
145 146 f'Encryption type part is wrong, must be `aes` '
146 147 f'or `aes_hmac`, got `{enc_type}` instead')
147 148
148 149 return decrypted_data
149 150
150 151
151 152 def encrypt_data(data, enc_key: bytes):
152 153 enc_key = safe_bytes(enc_key)
153 154 enc_value = AESCipher(enc_key, hmac=True).encrypt(safe_bytes(data))
154 155 return KEY_FORMAT.replace(b'{1}', enc_value)
@@ -1,97 +1,101 b''
1 1 import os
2 2 import base64
3 3 from cryptography.fernet import Fernet, InvalidToken
4 4 from cryptography.hazmat.backends import default_backend
5 5 from cryptography.hazmat.primitives import hashes
6 6 from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC
7 7
8 8 from rhodecode.lib.str_utils import safe_str
9 9 from rhodecode.lib.exceptions import signature_verification_error
10 10
11 11
12 12 class InvalidDecryptedValue(str):
13 13
14 14 def __new__(cls, content):
15 15 """
16 16 This will generate something like this::
17 17 <InvalidDecryptedValue(QkWusFgLJXR6m42v...)>
18 18 And represent a safe indicator that encryption key is broken
19 19 """
20 20 content = f'<{cls.__name__}({content[:16]}...)>'
21 21 return str.__new__(cls, content)
22 22
23 23
24 24 class Encryptor(object):
25 25 key_format = b'enc2$salt:{1}$data:{2}'
26 26
27 27 pref_len = 5 # salt:, data:
28 28
29 29 @classmethod
30 30 def detect_enc_algo(cls, enc_data: bytes):
31 31 parts = enc_data.split(b'$', 3)
32 if len(parts) != 3:
33 raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got {parts}')
34 32
35 33 if b'enc$aes_hmac$' in enc_data:
34 # we expect this data is encrypted, so validate the header
35 if len(parts) != 3:
36 raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got `{parts}`')
36 37 return 'aes'
37 38 elif b'enc2$salt' in enc_data:
39 # we expect this data is encrypted, so validate the header
40 if len(parts) != 3:
41 raise ValueError(f'Encrypted Data has invalid format, expected {cls.key_format}, got `{parts}`')
38 42 return 'fernet'
39 43 return None
40 44
41 45 def __init__(self, enc_key: bytes):
42 46 self.enc_key = enc_key
43 47
44 48 def b64_encode(self, data):
45 49 return base64.urlsafe_b64encode(data)
46 50
47 51 def b64_decode(self, data):
48 52 return base64.urlsafe_b64decode(data)
49 53
50 54 def get_encryptor(self, salt):
51 55 """
52 56 Uses Fernet as encryptor with HMAC signature
53 57 :param salt: random salt used for encrypting the data
54 58 """
55 59 kdf = PBKDF2HMAC(
56 60 algorithm=hashes.SHA512(),
57 61 length=32,
58 62 salt=salt,
59 63 iterations=100000,
60 64 backend=default_backend()
61 65 )
62 66 key = self.b64_encode(kdf.derive(self.enc_key))
63 67 return Fernet(key)
64 68
65 69 def _get_parts(self, enc_data):
66 70 parts = enc_data.split(b'$', 3)
67 71 if len(parts) != 3:
68 raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got {parts}')
72 raise ValueError(f'Encrypted Data has invalid format, expected {self.key_format}, got `{parts}`')
69 73 prefix, salt, enc_data = parts
70 74
71 75 try:
72 76 salt = self.b64_decode(salt[self.pref_len:])
73 77 except TypeError:
74 78 # bad base64
75 79 raise ValueError('Encrypted Data salt invalid format, expected base64 format')
76 80
77 81 enc_data = enc_data[self.pref_len:]
78 82 return prefix, salt, enc_data
79 83
80 84 def encrypt(self, data) -> bytes:
81 85 salt = os.urandom(64)
82 86 encryptor = self.get_encryptor(salt)
83 87 enc_data = encryptor.encrypt(data)
84 88 return self.key_format.replace(b'{1}', self.b64_encode(salt)).replace(b'{2}', enc_data)
85 89
86 90 def decrypt(self, data, safe=True) -> bytes | InvalidDecryptedValue:
87 91 parts = self._get_parts(data)
88 92 salt = parts[1]
89 93 enc_data = parts[2]
90 94 encryptor = self.get_encryptor(salt)
91 95 try:
92 96 return encryptor.decrypt(enc_data)
93 97 except (InvalidToken,):
94 98 decrypt_fail = InvalidDecryptedValue(safe_str(data))
95 99 if safe:
96 100 return decrypt_fail
97 101 raise signature_verification_error(decrypt_fail)
General Comments 0
You need to be logged in to leave comments. Login now