Show More
@@ -28,6 +28,7 b' import hmac' | |||||
28 | import logging |
|
28 | import logging | |
29 | import os |
|
29 | import os | |
30 | import pprint |
|
30 | import pprint | |
|
31 | import random | |||
31 | import uuid |
|
32 | import uuid | |
32 | from datetime import datetime |
|
33 | from datetime import datetime | |
33 |
|
34 | |||
@@ -310,8 +311,16 b' class Session(Configurable):' | |||||
310 | self.auth = hmac.HMAC(new) |
|
311 | self.auth = hmac.HMAC(new) | |
311 | else: |
|
312 | else: | |
312 | self.auth = None |
|
313 | self.auth = None | |
|
314 | ||||
313 | auth = Instance(hmac.HMAC) |
|
315 | auth = Instance(hmac.HMAC) | |
|
316 | ||||
314 | digest_history = Set() |
|
317 | digest_history = Set() | |
|
318 | digest_history_size = Integer(2**16, config=True, | |||
|
319 | help="""The maximum number of digests to remember. | |||
|
320 | ||||
|
321 | The digest history will be culled when it exceeds this value. | |||
|
322 | """ | |||
|
323 | ) | |||
315 |
|
324 | |||
316 | keyfile = Unicode('', config=True, |
|
325 | keyfile = Unicode('', config=True, | |
317 | help="""path to file containing execution key.""") |
|
326 | help="""path to file containing execution key.""") | |
@@ -699,6 +708,30 b' class Session(Configurable):' | |||||
699 | idents, msg_list = msg_list[:idx], msg_list[idx+1:] |
|
708 | idents, msg_list = msg_list[:idx], msg_list[idx+1:] | |
700 | return [m.bytes for m in idents], msg_list |
|
709 | return [m.bytes for m in idents], msg_list | |
701 |
|
710 | |||
|
711 | def _add_digest(self, signature): | |||
|
712 | """add a digest to history to protect against replay attacks""" | |||
|
713 | if self.digest_history_size == 0: | |||
|
714 | # no history, never add digests | |||
|
715 | return | |||
|
716 | ||||
|
717 | self.digest_history.add(signature) | |||
|
718 | if len(self.digest_history) > self.digest_history_size: | |||
|
719 | # threshold reached, cull 10% | |||
|
720 | self._cull_digest_history() | |||
|
721 | ||||
|
722 | def _cull_digest_history(self): | |||
|
723 | """cull the digest history | |||
|
724 | ||||
|
725 | Removes a randomly selected 10% of the digest history | |||
|
726 | """ | |||
|
727 | current = len(self.digest_history) | |||
|
728 | n_to_cull = max(int(current // 10), current - self.digest_history_size) | |||
|
729 | if n_to_cull >= current: | |||
|
730 | self.digest_history = set() | |||
|
731 | return | |||
|
732 | to_cull = random.sample(self.digest_history, n_to_cull) | |||
|
733 | self.digest_history.difference_update(to_cull) | |||
|
734 | ||||
702 | def unserialize(self, msg_list, content=True, copy=True): |
|
735 | def unserialize(self, msg_list, content=True, copy=True): | |
703 | """Unserialize a msg_list to a nested message dict. |
|
736 | """Unserialize a msg_list to a nested message dict. | |
704 |
|
737 | |||
@@ -735,7 +768,7 b' class Session(Configurable):' | |||||
735 | raise ValueError("Unsigned Message") |
|
768 | raise ValueError("Unsigned Message") | |
736 | if signature in self.digest_history: |
|
769 | if signature in self.digest_history: | |
737 | raise ValueError("Duplicate Signature: %r"%signature) |
|
770 | raise ValueError("Duplicate Signature: %r" % signature) | |
738 |
self.digest |
|
771 | self._add_digest(signature) | |
739 | check = self.sign(msg_list[1:5]) |
|
772 | check = self.sign(msg_list[1:5]) | |
740 | if not signature == check: |
|
773 | if not signature == check: | |
741 | raise ValueError("Invalid Signature: %r" % signature) |
|
774 | raise ValueError("Invalid Signature: %r" % signature) |
@@ -204,4 +204,22 b' class TestSession(SessionTestCase):' | |||||
204 | self.assertEqual(session.bsession, session.session.encode('ascii')) |
|
204 | self.assertEqual(session.bsession, session.session.encode('ascii')) | |
205 | self.assertEqual(b'stuff', session.bsession) |
|
205 | self.assertEqual(b'stuff', session.bsession) | |
206 |
|
206 | |||
|
207 | def test_zero_digest_history(self): | |||
|
208 | session = ss.Session(digest_history_size=0) | |||
|
209 | for i in range(11): | |||
|
210 | session._add_digest(uuid.uuid4().bytes) | |||
|
211 | self.assertEqual(len(session.digest_history), 0) | |||
|
212 | ||||
|
213 | def test_cull_digest_history(self): | |||
|
214 | session = ss.Session(digest_history_size=100) | |||
|
215 | for i in range(100): | |||
|
216 | session._add_digest(uuid.uuid4().bytes) | |||
|
217 | self.assertTrue(len(session.digest_history) == 100) | |||
|
218 | session._add_digest(uuid.uuid4().bytes) | |||
|
219 | self.assertTrue(len(session.digest_history) == 91) | |||
|
220 | for i in range(9): | |||
|
221 | session._add_digest(uuid.uuid4().bytes) | |||
|
222 | self.assertTrue(len(session.digest_history) == 100) | |||
|
223 | session._add_digest(uuid.uuid4().bytes) | |||
|
224 | self.assertTrue(len(session.digest_history) == 91) | |||
207 |
|
225 |
General Comments 0
You need to be logged in to leave comments.
Login now