##// END OF EJS Templates
rev msg_id to avoid signature collisions...
Min RK -
Show More
@@ -1,318 +1,318 b''
1 1 """test building messages with Session"""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import hmac
7 7 import os
8 8 import uuid
9 9 from datetime import datetime
10 10
11 11 import zmq
12 12
13 13 from zmq.tests import BaseZMQTestCase
14 14 from zmq.eventloop.zmqstream import ZMQStream
15 15
16 16 from IPython.kernel.zmq import session as ss
17 17
18 18 from IPython.testing.decorators import skipif, module_not_available
19 19 from IPython.utils.py3compat import string_types
20 20 from IPython.utils import jsonutil
21 21
22 22 def _bad_packer(obj):
23 23 raise TypeError("I don't work")
24 24
25 25 def _bad_unpacker(bytes):
26 26 raise TypeError("I don't work either")
27 27
28 28 class SessionTestCase(BaseZMQTestCase):
29 29
30 30 def setUp(self):
31 31 BaseZMQTestCase.setUp(self)
32 32 self.session = ss.Session()
33 33
34 34
35 35 class TestSession(SessionTestCase):
36 36
37 37 def test_msg(self):
38 38 """message format"""
39 39 msg = self.session.msg('execute')
40 40 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
41 41 s = set(msg.keys())
42 42 self.assertEqual(s, thekeys)
43 43 self.assertTrue(isinstance(msg['content'],dict))
44 44 self.assertTrue(isinstance(msg['metadata'],dict))
45 45 self.assertTrue(isinstance(msg['header'],dict))
46 46 self.assertTrue(isinstance(msg['parent_header'],dict))
47 47 self.assertTrue(isinstance(msg['msg_id'],str))
48 48 self.assertTrue(isinstance(msg['msg_type'],str))
49 49 self.assertEqual(msg['header']['msg_type'], 'execute')
50 50 self.assertEqual(msg['msg_type'], 'execute')
51 51
52 52 def test_serialize(self):
53 53 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
54 54 msg_list = self.session.serialize(msg, ident=b'foo')
55 55 ident, msg_list = self.session.feed_identities(msg_list)
56 56 new_msg = self.session.deserialize(msg_list)
57 57 self.assertEqual(ident[0], b'foo')
58 58 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
59 59 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
60 60 self.assertEqual(new_msg['header'],msg['header'])
61 61 self.assertEqual(new_msg['content'],msg['content'])
62 62 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
63 63 self.assertEqual(new_msg['metadata'],msg['metadata'])
64 64 # ensure floats don't come out as Decimal:
65 65 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
66 66
67 67 def test_default_secure(self):
68 68 self.assertIsInstance(self.session.key, bytes)
69 69 self.assertIsInstance(self.session.auth, hmac.HMAC)
70 70
71 71 def test_send(self):
72 72 ctx = zmq.Context.instance()
73 73 A = ctx.socket(zmq.PAIR)
74 74 B = ctx.socket(zmq.PAIR)
75 75 A.bind("inproc://test")
76 76 B.connect("inproc://test")
77 77
78 78 msg = self.session.msg('execute', content=dict(a=10))
79 79 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
80 80
81 81 ident, msg_list = self.session.feed_identities(B.recv_multipart())
82 82 new_msg = self.session.deserialize(msg_list)
83 83 self.assertEqual(ident[0], b'foo')
84 84 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
85 85 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
86 86 self.assertEqual(new_msg['header'],msg['header'])
87 87 self.assertEqual(new_msg['content'],msg['content'])
88 88 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
89 89 self.assertEqual(new_msg['metadata'],msg['metadata'])
90 90 self.assertEqual(new_msg['buffers'],[b'bar'])
91 91
92 92 content = msg['content']
93 93 header = msg['header']
94 header['date'] = datetime.now()
94 header['msg_id'] = self.session.msg_id
95 95 parent = msg['parent_header']
96 96 metadata = msg['metadata']
97 97 msg_type = header['msg_type']
98 98 self.session.send(A, None, content=content, parent=parent,
99 99 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
100 100 ident, msg_list = self.session.feed_identities(B.recv_multipart())
101 101 new_msg = self.session.deserialize(msg_list)
102 102 self.assertEqual(ident[0], b'foo')
103 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
103 self.assertEqual(new_msg['msg_id'],header['msg_id'])
104 104 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
105 105 self.assertEqual(new_msg['header'],msg['header'])
106 106 self.assertEqual(new_msg['content'],msg['content'])
107 107 self.assertEqual(new_msg['metadata'],msg['metadata'])
108 108 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
109 109 self.assertEqual(new_msg['buffers'],[b'bar'])
110 110
111 header['date'] = datetime.now()
111 header['msg_id'] = self.session.msg_id
112 112
113 113 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
114 114 ident, new_msg = self.session.recv(B)
115 115 self.assertEqual(ident[0], b'foo')
116 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
116 self.assertEqual(new_msg['msg_id'],header['msg_id'])
117 117 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
118 118 self.assertEqual(new_msg['header'],msg['header'])
119 119 self.assertEqual(new_msg['content'],msg['content'])
120 120 self.assertEqual(new_msg['metadata'],msg['metadata'])
121 121 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
122 122 self.assertEqual(new_msg['buffers'],[b'bar'])
123 123
124 124 A.close()
125 125 B.close()
126 126 ctx.term()
127 127
128 128 def test_args(self):
129 129 """initialization arguments for Session"""
130 130 s = self.session
131 131 self.assertTrue(s.pack is ss.default_packer)
132 132 self.assertTrue(s.unpack is ss.default_unpacker)
133 133 self.assertEqual(s.username, os.environ.get('USER', u'username'))
134 134
135 135 s = ss.Session()
136 136 self.assertEqual(s.username, os.environ.get('USER', u'username'))
137 137
138 138 self.assertRaises(TypeError, ss.Session, pack='hi')
139 139 self.assertRaises(TypeError, ss.Session, unpack='hi')
140 140 u = str(uuid.uuid4())
141 141 s = ss.Session(username=u'carrot', session=u)
142 142 self.assertEqual(s.session, u)
143 143 self.assertEqual(s.username, u'carrot')
144 144
145 145 def test_tracking(self):
146 146 """test tracking messages"""
147 147 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
148 148 s = self.session
149 149 s.copy_threshold = 1
150 150 stream = ZMQStream(a)
151 151 msg = s.send(a, 'hello', track=False)
152 152 self.assertTrue(msg['tracker'] is ss.DONE)
153 153 msg = s.send(a, 'hello', track=True)
154 154 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
155 155 M = zmq.Message(b'hi there', track=True)
156 156 msg = s.send(a, 'hello', buffers=[M], track=True)
157 157 t = msg['tracker']
158 158 self.assertTrue(isinstance(t, zmq.MessageTracker))
159 159 self.assertRaises(zmq.NotDone, t.wait, .1)
160 160 del M
161 161 t.wait(1) # this will raise
162 162
163 163
164 164 def test_unique_msg_ids(self):
165 165 """test that messages receive unique ids"""
166 166 ids = set()
167 167 for i in range(2**12):
168 168 h = self.session.msg_header('test')
169 169 msg_id = h['msg_id']
170 170 self.assertTrue(msg_id not in ids)
171 171 ids.add(msg_id)
172 172
173 173 def test_feed_identities(self):
174 174 """scrub the front for zmq IDENTITIES"""
175 175 theids = "engine client other".split()
176 176 content = dict(code='whoda',stuff=object())
177 177 themsg = self.session.msg('execute',content=content)
178 178 pmsg = theids
179 179
180 180 def test_session_id(self):
181 181 session = ss.Session()
182 182 # get bs before us
183 183 bs = session.bsession
184 184 us = session.session
185 185 self.assertEqual(us.encode('ascii'), bs)
186 186 session = ss.Session()
187 187 # get us before bs
188 188 us = session.session
189 189 bs = session.bsession
190 190 self.assertEqual(us.encode('ascii'), bs)
191 191 # change propagates:
192 192 session.session = 'something else'
193 193 bs = session.bsession
194 194 us = session.session
195 195 self.assertEqual(us.encode('ascii'), bs)
196 196 session = ss.Session(session='stuff')
197 197 # get us before bs
198 198 self.assertEqual(session.bsession, session.session.encode('ascii'))
199 199 self.assertEqual(b'stuff', session.bsession)
200 200
201 201 def test_zero_digest_history(self):
202 202 session = ss.Session(digest_history_size=0)
203 203 for i in range(11):
204 204 session._add_digest(uuid.uuid4().bytes)
205 205 self.assertEqual(len(session.digest_history), 0)
206 206
207 207 def test_cull_digest_history(self):
208 208 session = ss.Session(digest_history_size=100)
209 209 for i in range(100):
210 210 session._add_digest(uuid.uuid4().bytes)
211 211 self.assertTrue(len(session.digest_history) == 100)
212 212 session._add_digest(uuid.uuid4().bytes)
213 213 self.assertTrue(len(session.digest_history) == 91)
214 214 for i in range(9):
215 215 session._add_digest(uuid.uuid4().bytes)
216 216 self.assertTrue(len(session.digest_history) == 100)
217 217 session._add_digest(uuid.uuid4().bytes)
218 218 self.assertTrue(len(session.digest_history) == 91)
219 219
220 220 def test_bad_pack(self):
221 221 try:
222 222 session = ss.Session(pack=_bad_packer)
223 223 except ValueError as e:
224 224 self.assertIn("could not serialize", str(e))
225 225 self.assertIn("don't work", str(e))
226 226 else:
227 227 self.fail("Should have raised ValueError")
228 228
229 229 def test_bad_unpack(self):
230 230 try:
231 231 session = ss.Session(unpack=_bad_unpacker)
232 232 except ValueError as e:
233 233 self.assertIn("could not handle output", str(e))
234 234 self.assertIn("don't work either", str(e))
235 235 else:
236 236 self.fail("Should have raised ValueError")
237 237
238 238 def test_bad_packer(self):
239 239 try:
240 240 session = ss.Session(packer=__name__ + '._bad_packer')
241 241 except ValueError as e:
242 242 self.assertIn("could not serialize", str(e))
243 243 self.assertIn("don't work", str(e))
244 244 else:
245 245 self.fail("Should have raised ValueError")
246 246
247 247 def test_bad_unpacker(self):
248 248 try:
249 249 session = ss.Session(unpacker=__name__ + '._bad_unpacker')
250 250 except ValueError as e:
251 251 self.assertIn("could not handle output", str(e))
252 252 self.assertIn("don't work either", str(e))
253 253 else:
254 254 self.fail("Should have raised ValueError")
255 255
256 256 def test_bad_roundtrip(self):
257 257 with self.assertRaises(ValueError):
258 258 session = ss.Session(unpack=lambda b: 5)
259 259
260 260 def _datetime_test(self, session):
261 261 content = dict(t=datetime.now())
262 262 metadata = dict(t=datetime.now())
263 263 p = session.msg('msg')
264 264 msg = session.msg('msg', content=content, metadata=metadata, parent=p['header'])
265 265 smsg = session.serialize(msg)
266 266 msg2 = session.deserialize(session.feed_identities(smsg)[1])
267 267 assert isinstance(msg2['header']['date'], datetime)
268 268 self.assertEqual(msg['header'], msg2['header'])
269 269 self.assertEqual(msg['parent_header'], msg2['parent_header'])
270 270 self.assertEqual(msg['parent_header'], msg2['parent_header'])
271 271 assert isinstance(msg['content']['t'], datetime)
272 272 assert isinstance(msg['metadata']['t'], datetime)
273 273 assert isinstance(msg2['content']['t'], string_types)
274 274 assert isinstance(msg2['metadata']['t'], string_types)
275 275 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
276 276 self.assertEqual(msg['content'], jsonutil.extract_dates(msg2['content']))
277 277
278 278 def test_datetimes(self):
279 279 self._datetime_test(self.session)
280 280
281 281 def test_datetimes_pickle(self):
282 282 session = ss.Session(packer='pickle')
283 283 self._datetime_test(session)
284 284
285 285 @skipif(module_not_available('msgpack'))
286 286 def test_datetimes_msgpack(self):
287 287 import msgpack
288 288
289 289 session = ss.Session(
290 290 pack=msgpack.packb,
291 291 unpack=lambda buf: msgpack.unpackb(buf, encoding='utf8'),
292 292 )
293 293 self._datetime_test(session)
294 294
295 295 def test_send_raw(self):
296 296 ctx = zmq.Context.instance()
297 297 A = ctx.socket(zmq.PAIR)
298 298 B = ctx.socket(zmq.PAIR)
299 299 A.bind("inproc://test")
300 300 B.connect("inproc://test")
301 301
302 302 msg = self.session.msg('execute', content=dict(a=10))
303 303 msg_list = [self.session.pack(msg[part]) for part in
304 304 ['header', 'parent_header', 'metadata', 'content']]
305 305 self.session.send_raw(A, msg_list, ident=b'foo')
306 306
307 307 ident, new_msg_list = self.session.feed_identities(B.recv_multipart())
308 308 new_msg = self.session.deserialize(new_msg_list)
309 309 self.assertEqual(ident[0], b'foo')
310 310 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
311 311 self.assertEqual(new_msg['header'],msg['header'])
312 312 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
313 313 self.assertEqual(new_msg['content'],msg['content'])
314 314 self.assertEqual(new_msg['metadata'],msg['metadata'])
315 315
316 316 A.close()
317 317 B.close()
318 318 ctx.term()
General Comments 0
You need to be logged in to leave comments. Login now