diff --git a/IPython/kernel/zmq/session.py b/IPython/kernel/zmq/session.py
index d4b8b13..a85988e 100644
--- a/IPython/kernel/zmq/session.py
+++ b/IPython/kernel/zmq/session.py
@@ -28,6 +28,7 @@ import hmac
 import logging
 import os
 import pprint
+import random
 import uuid
 
 from datetime import datetime
@@ -310,8 +311,16 @@ class Session(Configurable):
             self.auth = hmac.HMAC(new)
         else:
             self.auth = None
+
     auth = Instance(hmac.HMAC)
+
     digest_history = Set()
+    digest_history_size = Integer(2**16, config=True,
+        help="""The maximum number of digests to remember.
+
+        The digest history will be culled when it exceeds this value.
+        """
+    )
 
     keyfile = Unicode('', config=True,
         help="""path to file containing execution key.""")
@@ -699,6 +708,30 @@ class Session(Configurable):
         idents, msg_list = msg_list[:idx], msg_list[idx+1:]
         return [m.bytes for m in idents], msg_list
 
+    def _add_digest(self, signature):
+        """add a digest to history to protect against replay attacks"""
+        if self.digest_history_size == 0:
+            # history disabled: nothing is remembered, so replays go undetected
+            return
+        self.digest_history.add(signature)
+        if len(self.digest_history) > self.digest_history_size:
+            # threshold reached, cull at least 10%
+            self._cull_digest_history()
+
+    def _cull_digest_history(self):
+        """cull the digest history
+
+        Removes a random 10% of the digests, or the whole overflow if larger.
+        """
+        current = len(self.digest_history)
+        n_to_cull = max(current // 10, current - self.digest_history_size)
+        if n_to_cull >= current:
+            self.digest_history = set()
+            return
+        # random.sample() needs a sequence; sets are rejected on Python 3.11+
+        to_cull = random.sample(list(self.digest_history), n_to_cull)
+        self.digest_history.difference_update(to_cull)
+
     def unserialize(self, msg_list, content=True, copy=True):
         """Unserialize a msg_list to a nested message dict.
 
@@ -734,8 +767,8 @@ class Session(Configurable):
             if not signature:
                 raise ValueError("Unsigned Message")
             if signature in self.digest_history:
-                raise ValueError("Duplicate Signature: %r"%signature)
-            self.digest_history.add(signature)
+                raise ValueError("Duplicate Signature: %r" % signature)
+            self._add_digest(signature)
             check = self.sign(msg_list[1:5])
             if not signature == check:
                 raise ValueError("Invalid Signature: %r" % signature)
diff --git a/IPython/kernel/zmq/tests/test_session.py b/IPython/kernel/zmq/tests/test_session.py
index d039462..df7b77b 100644
--- a/IPython/kernel/zmq/tests/test_session.py
+++ b/IPython/kernel/zmq/tests/test_session.py
@@ -204,4 +204,31 @@ class TestSession(SessionTestCase):
         self.assertEqual(session.bsession, session.session.encode('ascii'))
         self.assertEqual(b'stuff', session.bsession)
 
+    def test_zero_digest_history(self):
+        session = ss.Session(digest_history_size=0)
+        for i in range(11):
+            session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 0)
+
+    def test_cull_digest_history(self):
+        session = ss.Session(digest_history_size=100)
+        for i in range(100):
+            session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 100)
+        session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 91)
+        for i in range(9):
+            session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 100)
+        session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 91)
+
+    def test_cull_digest_history_over_limit(self):
+        session = ss.Session(digest_history_size=100)
+        for i in range(100):
+            session._add_digest(uuid.uuid4().bytes)
+        # lowering the limit forces culling of the whole overflow (51), not 10%
+        session.digest_history_size = 50
+        session._add_digest(uuid.uuid4().bytes)
+        self.assertEqual(len(session.digest_history), 50)
 