##// END OF EJS Templates
remove MockSocket in test_session...
MinRK -
Show More
@@ -1,220 +1,207 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 20
21 21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 27 self.session = ss.Session()
28 28
29 29
30 class MockSocket(zmq.Socket):
31 data = None
32
33 def __init__(self, *args, **kwargs):
34 super(MockSocket,self).__init__(*args,**kwargs)
35 self.data = []
36
37 def send_multipart(self, msgparts, *args, **kwargs):
38 self.data.extend(msgparts)
39
40 def send(self, part, *args, **kwargs):
41 self.data.append(part)
42
43 def recv_multipart(self, *args, **kwargs):
44 return self.data
45
46 30 class TestSession(SessionTestCase):
47 31
48 32 def test_msg(self):
49 33 """message format"""
50 34 msg = self.session.msg('execute')
51 35 thekeys = set('header parent_header metadata content msg_type msg_id'.split())
52 36 s = set(msg.keys())
53 37 self.assertEqual(s, thekeys)
54 38 self.assertTrue(isinstance(msg['content'],dict))
55 39 self.assertTrue(isinstance(msg['metadata'],dict))
56 40 self.assertTrue(isinstance(msg['header'],dict))
57 41 self.assertTrue(isinstance(msg['parent_header'],dict))
58 42 self.assertTrue(isinstance(msg['msg_id'],str))
59 43 self.assertTrue(isinstance(msg['msg_type'],str))
60 44 self.assertEqual(msg['header']['msg_type'], 'execute')
61 45 self.assertEqual(msg['msg_type'], 'execute')
62 46
63 47 def test_serialize(self):
64 48 msg = self.session.msg('execute', content=dict(a=10, b=1.1))
65 49 msg_list = self.session.serialize(msg, ident=b'foo')
66 50 ident, msg_list = self.session.feed_identities(msg_list)
67 51 new_msg = self.session.unserialize(msg_list)
68 52 self.assertEqual(ident[0], b'foo')
69 53 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
70 54 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
71 55 self.assertEqual(new_msg['header'],msg['header'])
72 56 self.assertEqual(new_msg['content'],msg['content'])
73 57 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
74 58 self.assertEqual(new_msg['metadata'],msg['metadata'])
75 59 # ensure floats don't come out as Decimal:
76 60 self.assertEqual(type(new_msg['content']['b']),type(new_msg['content']['b']))
77 61
78 62 def test_send(self):
79 socket = MockSocket(zmq.Context.instance(),zmq.PAIR)
63 ctx = zmq.Context.instance()
64 A = ctx.socket(zmq.PAIR)
65 B = ctx.socket(zmq.PAIR)
66 A.bind("inproc://test")
67 B.connect("inproc://test")
80 68
81 69 msg = self.session.msg('execute', content=dict(a=10))
82 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
83 ident, msg_list = self.session.feed_identities(socket.data)
70 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
71
72 ident, msg_list = self.session.feed_identities(B.recv_multipart())
84 73 new_msg = self.session.unserialize(msg_list)
85 74 self.assertEqual(ident[0], b'foo')
86 75 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
87 76 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
88 77 self.assertEqual(new_msg['header'],msg['header'])
89 78 self.assertEqual(new_msg['content'],msg['content'])
90 79 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
91 80 self.assertEqual(new_msg['metadata'],msg['metadata'])
92 81 self.assertEqual(new_msg['buffers'],[b'bar'])
93 82
94 socket.data = []
95
96 83 content = msg['content']
97 84 header = msg['header']
98 85 parent = msg['parent_header']
99 86 metadata = msg['metadata']
100 87 msg_type = header['msg_type']
101 self.session.send(socket, None, content=content, parent=parent,
88 self.session.send(A, None, content=content, parent=parent,
102 89 header=header, metadata=metadata, ident=b'foo', buffers=[b'bar'])
103 ident, msg_list = self.session.feed_identities(socket.data)
90 ident, msg_list = self.session.feed_identities(B.recv_multipart())
104 91 new_msg = self.session.unserialize(msg_list)
105 92 self.assertEqual(ident[0], b'foo')
106 93 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
107 94 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
108 95 self.assertEqual(new_msg['header'],msg['header'])
109 96 self.assertEqual(new_msg['content'],msg['content'])
110 97 self.assertEqual(new_msg['metadata'],msg['metadata'])
111 98 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
112 99 self.assertEqual(new_msg['buffers'],[b'bar'])
113 100
114 socket.data = []
115
116 self.session.send(socket, msg, ident=b'foo', buffers=[b'bar'])
117 ident, new_msg = self.session.recv(socket)
101 self.session.send(A, msg, ident=b'foo', buffers=[b'bar'])
102 ident, new_msg = self.session.recv(B)
118 103 self.assertEqual(ident[0], b'foo')
119 104 self.assertEqual(new_msg['msg_id'],msg['msg_id'])
120 105 self.assertEqual(new_msg['msg_type'],msg['msg_type'])
121 106 self.assertEqual(new_msg['header'],msg['header'])
122 107 self.assertEqual(new_msg['content'],msg['content'])
123 108 self.assertEqual(new_msg['metadata'],msg['metadata'])
124 109 self.assertEqual(new_msg['parent_header'],msg['parent_header'])
125 110 self.assertEqual(new_msg['buffers'],[b'bar'])
126 111
127 socket.close()
112 A.close()
113 B.close()
114 ctx.term()
128 115
129 116 def test_args(self):
130 117 """initialization arguments for Session"""
131 118 s = self.session
132 119 self.assertTrue(s.pack is ss.default_packer)
133 120 self.assertTrue(s.unpack is ss.default_unpacker)
134 121 self.assertEqual(s.username, os.environ.get('USER', u'username'))
135 122
136 123 s = ss.Session()
137 124 self.assertEqual(s.username, os.environ.get('USER', u'username'))
138 125
139 126 self.assertRaises(TypeError, ss.Session, pack='hi')
140 127 self.assertRaises(TypeError, ss.Session, unpack='hi')
141 128 u = str(uuid.uuid4())
142 129 s = ss.Session(username=u'carrot', session=u)
143 130 self.assertEqual(s.session, u)
144 131 self.assertEqual(s.username, u'carrot')
145 132
146 133 def test_tracking(self):
147 134 """test tracking messages"""
148 135 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
149 136 s = self.session
150 137 s.copy_threshold = 1
151 138 stream = ZMQStream(a)
152 139 msg = s.send(a, 'hello', track=False)
153 140 self.assertTrue(msg['tracker'] is ss.DONE)
154 141 msg = s.send(a, 'hello', track=True)
155 142 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
156 143 M = zmq.Message(b'hi there', track=True)
157 144 msg = s.send(a, 'hello', buffers=[M], track=True)
158 145 t = msg['tracker']
159 146 self.assertTrue(isinstance(t, zmq.MessageTracker))
160 147 self.assertRaises(zmq.NotDone, t.wait, .1)
161 148 del M
162 149 t.wait(1) # this will raise
163 150
164 151
165 152 # def test_rekey(self):
166 153 # """rekeying dict around json str keys"""
167 154 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
168 155 # self.assertRaises(KeyError, ss.rekey, d)
169 156 #
170 157 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
171 158 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
172 159 # rd = ss.rekey(d)
173 160 # self.assertEqual(d2,rd)
174 161 #
175 162 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
176 163 # d2 = {1.5:d['1.5'],1:d['1']}
177 164 # rd = ss.rekey(d)
178 165 # self.assertEqual(d2,rd)
179 166 #
180 167 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
181 168 # self.assertRaises(KeyError, ss.rekey, d)
182 169 #
183 170 def test_unique_msg_ids(self):
184 171 """test that messages receive unique ids"""
185 172 ids = set()
186 173 for i in range(2**12):
187 174 h = self.session.msg_header('test')
188 175 msg_id = h['msg_id']
189 176 self.assertTrue(msg_id not in ids)
190 177 ids.add(msg_id)
191 178
192 179 def test_feed_identities(self):
193 180 """scrub the front for zmq IDENTITIES"""
194 181 theids = "engine client other".split()
195 182 content = dict(code='whoda',stuff=object())
196 183 themsg = self.session.msg('execute',content=content)
197 184 pmsg = theids
198 185
199 186 def test_session_id(self):
200 187 session = ss.Session()
201 188 # get bs before us
202 189 bs = session.bsession
203 190 us = session.session
204 191 self.assertEqual(us.encode('ascii'), bs)
205 192 session = ss.Session()
206 193 # get us before bs
207 194 us = session.session
208 195 bs = session.bsession
209 196 self.assertEqual(us.encode('ascii'), bs)
210 197 # change propagates:
211 198 session.session = 'something else'
212 199 bs = session.bsession
213 200 us = session.session
214 201 self.assertEqual(us.encode('ascii'), bs)
215 202 session = ss.Session(session='stuff')
216 203 # get us before bs
217 204 self.assertEqual(session.bsession, session.session.encode('ascii'))
218 205 self.assertEqual(b'stuff', session.bsession)
219 206
220 207
General Comments 0
You need to be logged in to leave comments. Login now