Merge pull request #5098 from minrk/parallel-debug...
Thomas Kluyver
r15313:9d27bffe merge
@@ -1,317 +1,317 @@ IPython/parallel/controller/dictdb.py
1 1 """A Task logger that presents our DB interface,
2 2 but exists entirely in memory and implemented with dicts.
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7
8 8
9 9 TaskRecords are dicts of the form::
10 10
11 11 {
12 12 'msg_id' : str(uuid),
13 13 'client_uuid' : str(uuid),
14 14 'engine_uuid' : str(uuid) or None,
15 15 'header' : dict(header),
16 16 'content': dict(content),
17 17 'buffers': list(buffers),
18 'submitted': datetime,
18 'submitted': datetime or None,
19 19 'started': datetime or None,
20 20 'completed': datetime or None,
21 'resubmitted': datetime or None,
21 'received': datetime or None,
22 'resubmitted': str(uuid) or None,
22 23 'result_header' : dict(header) or None,
23 24 'result_content' : dict(content) or None,
24 25 'result_buffers' : list(buffers) or None,
25 26 }
26 27
27 28 With this info, many of the special categories of tasks can be defined by query,
28 29 e.g.:
29 30
30 31 * pending: completed is None
31 32 * client's outstanding: client_uuid = uuid && completed is None
32 33 * MIA: received is None (and completed is None)
33 34
34 EngineRecords are dicts of the form::
35
36 {
37 'eid' : int(id),
38 'uuid': str(uuid)
39 }
40
41 This may be extended, but is currently minimal.
42
43 We support a subset of mongodb operators::
35 DictDB supports a subset of mongodb operators::
44 36
45 37 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
46 38 """
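For orientation, a query against this operator subset is just a (possibly nested) dict. A minimal sketch of DictDB usage under in-memory defaults (the msg_id and field values here are illustrative, not part of this diff)::

    from datetime import datetime, timedelta
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('abc', {'msg_id': 'abc', 'submitted': datetime.now(),
                          'completed': None, 'buffers': [], 'result_buffers': []})

    pending = db.find_records({'completed': None})            # pending: completed is None
    cutoff = datetime.now() - timedelta(hours=1)
    recent = db.find_records({'submitted': {'$gt': cutoff}})  # operator query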
47 39 #-----------------------------------------------------------------------------
48 40 # Copyright (C) 2010-2011 The IPython Development Team
49 41 #
50 42 # Distributed under the terms of the BSD License. The full license is in
51 43 # the file COPYING, distributed as part of this software.
52 44 #-----------------------------------------------------------------------------
53 45
54 46 from copy import deepcopy as copy
55 47 from datetime import datetime
56 48
57 49 from IPython.config.configurable import LoggingConfigurable
58 50
59 51 from IPython.utils.py3compat import iteritems, itervalues
60 52 from IPython.utils.traitlets import Dict, Unicode, Integer, Float
61 53
62 54 filters = {
63 55 '$lt' : lambda a,b: a < b,
64 56 '$gt' : lambda a,b: a > b,
65 57 '$eq' : lambda a,b: a == b,
66 58 '$ne' : lambda a,b: a != b,
67 59 '$lte': lambda a,b: a <= b,
68 60 '$gte': lambda a,b: a >= b,
69 61 '$in' : lambda a,b: a in b,
70 62 '$nin': lambda a,b: a not in b,
71 63 '$all': lambda a,b: all([ a in bb for bb in b ]),
72 64 '$mod': lambda a,b: a%b[0] == b[1],
73 65 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
74 66 }
75 67
76 68
77 69 class CompositeFilter(object):
78 70 """Composite filter for matching multiple properties."""
79 71
80 72 def __init__(self, dikt):
81 73 self.tests = []
82 74 self.values = []
83 75 for key, value in iteritems(dikt):
84 76 self.tests.append(filters[key])
85 77 self.values.append(value)
86 78
87 79 def __call__(self, value):
88 80 for test,check in zip(self.tests, self.values):
89 81 if not test(value, check):
90 82 return False
91 83 return True
92 84
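CompositeFilter simply ANDs every operator test against a single field, so a dict value in a query behaves as a conjunction. A quick illustration against the filters table above::

    f = CompositeFilter({'$gt': 2, '$lt': 10})
    f(5)    # True: 5 > 2 and 5 < 10
    f(1)    # False: fails $gt
    f(12)   # False: fails $lt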
93 85 class BaseDB(LoggingConfigurable):
94 86 """Empty Parent class so traitlets work on DB."""
95 87 # base configurable traits:
96 88 session = Unicode("")
97 89
98 90 class DictDB(BaseDB):
99 91 """Basic in-memory dict-based object for saving Task Records.
100 92
101 93 This is the first object to present the DB interface
102 94 for logging tasks out of memory.
103 95
104 96 The interface is based on MongoDB, so adding a MongoDB
105 97 backend should be straightforward.
106 98 """
107 99
108 100 _records = Dict()
109 101 _culled_ids = set() # set of ids which have been culled
110 102 _buffer_bytes = Integer(0) # running total of the bytes in the DB
111 103
112 104 size_limit = Integer(1024**3, config=True,
113 105 help="""The maximum total size (in bytes) of the buffers stored in the db
114 106
115 107 When the db exceeds this size, the oldest records will be culled until
116 108 the total size is under size_limit * (1-cull_fraction).
117 109 default: 1 GB
118 110 """
119 111 )
120 112 record_limit = Integer(1024, config=True,
121 113 help="""The maximum number of records in the db
122 114
123 115 When the history exceeds this size, the first record_limit * cull_fraction
124 116 records will be culled.
125 117 """
126 118 )
127 119 cull_fraction = Float(0.1, config=True,
128 120 help="""The fraction by which the db should be culled when one of the limits is exceeded
129 121
130 122 In general, the db will spend most of its time with a size in the range:
131 123
132 124 [limit * (1-cull_fraction), limit]
133 125
134 126 for each of size_limit and record_limit.
135 127 """
136 128 )
137 129
138 130 def _match_one(self, rec, tests):
139 131 """Check if a specific record matches tests."""
140 132 for key,test in iteritems(tests):
141 133 if not test(rec.get(key, None)):
142 134 return False
143 135 return True
144 136
145 137 def _match(self, check):
146 138 """Find all the matches for a check dict."""
147 139 matches = []
148 140 tests = {}
149 141 for k,v in iteritems(check):
150 142 if isinstance(v, dict):
151 143 tests[k] = CompositeFilter(v)
152 144 else:
153 145 tests[k] = lambda o, v=v: o == v  # bind v now; a bare closure would test only the loop's last value
154 146
155 147 for rec in itervalues(self._records):
156 148 if self._match_one(rec, tests):
157 149 matches.append(copy(rec))
158 150 return matches
159 151
160 152 def _extract_subdict(self, rec, keys):
161 153 """extract subdict of keys"""
162 154 d = {}
163 155 d['msg_id'] = rec['msg_id']
164 156 for key in keys:
165 157 d[key] = rec[key]
166 158 return copy(d)
167 159
168 160 # methods for monitoring size / culling history
169 161
170 162 def _add_bytes(self, rec):
171 163 for key in ('buffers', 'result_buffers'):
172 164 for buf in rec.get(key) or []:
173 165 self._buffer_bytes += len(buf)
174 166
175 167 self._maybe_cull()
176 168
177 169 def _drop_bytes(self, rec):
178 170 for key in ('buffers', 'result_buffers'):
179 171 for buf in rec.get(key) or []:
180 172 self._buffer_bytes -= len(buf)
181 173
182 174 def _cull_oldest(self, n=1):
183 175 """cull the oldest N records"""
184 176 for msg_id in self.get_history()[:n]:
185 177 self.log.debug("Culling record: %r", msg_id)
186 178 self._culled_ids.add(msg_id)
187 179 self.drop_record(msg_id)
188 180
189 181 def _maybe_cull(self):
190 182 # cull by count:
191 183 if len(self._records) > self.record_limit:
192 184 to_cull = int(self.cull_fraction * self.record_limit)
193 185 self.log.info("%i records exceeds limit of %i, culling oldest %i",
194 186 len(self._records), self.record_limit, to_cull
195 187 )
196 188 self._cull_oldest(to_cull)
197 189
198 190 # cull by size:
199 191 if self._buffer_bytes > self.size_limit:
200 192 limit = self.size_limit * (1 - self.cull_fraction)
201 193
202 194 before = self._buffer_bytes
203 195 before_count = len(self._records)
204 196 culled = 0
205 197 while self._buffer_bytes > limit:
206 198 self._cull_oldest(1)
207 199 culled += 1
208 200
209 201 self.log.info("%i records with total buffer size %i exceeds limit: %i. Culled oldest %i records.",
210 202 before_count, before, self.size_limit, culled
211 203 )
212 204
205 def _check_dates(self, rec):
206 for key in ('submitted', 'started', 'completed'):
207 value = rec.get(key, None)
208 if value is not None and not isinstance(value, datetime):
209 raise ValueError("%s must be None or datetime, not %r" % (key, value))
210
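The new _check_dates guard exists because timestamps that have round-tripped through JSON can arrive as ISO8601 strings rather than datetimes; storing those would silently corrupt date queries and history ordering. A hypothetical failing call::

    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    rec = {'msg_id': 'xyz', 'submitted': '2014-02-14T12:00:00'}  # str, not datetime
    db.add_record('xyz', rec)
    # ValueError: submitted must be None or datetime, not '2014-02-14T12:00:00'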
213 211 # public API methods:
214 212
215 213 def add_record(self, msg_id, rec):
216 214 """Add a new Task Record, by msg_id."""
217 215 if msg_id in self._records:
218 216 raise KeyError("Already have msg_id %r"%(msg_id))
217 self._check_dates(rec)
219 218 self._records[msg_id] = rec
220 219 self._add_bytes(rec)
221 220 self._maybe_cull()
222 221
223 222 def get_record(self, msg_id):
224 223 """Get a specific Task Record, by msg_id."""
225 224 if msg_id in self._culled_ids:
226 225 raise KeyError("Record %r has been culled for size" % msg_id)
227 226 if msg_id not in self._records:
228 227 raise KeyError("No such msg_id %r"%(msg_id))
229 228 return copy(self._records[msg_id])
230 229
231 230 def update_record(self, msg_id, rec):
232 231 """Update the data in an existing record."""
233 232 if msg_id in self._culled_ids:
234 233 raise KeyError("Record %r has been culled for size" % msg_id)
234 self._check_dates(rec)
235 235 _rec = self._records[msg_id]
236 236 self._drop_bytes(_rec)
237 237 _rec.update(rec)
238 238 self._add_bytes(_rec)
239 239
240 240 def drop_matching_records(self, check):
241 241 """Remove a record from the DB."""
242 242 matches = self._match(check)
243 243 for rec in matches:
244 244 self._drop_bytes(rec)
245 245 del self._records[rec['msg_id']]
246 246
247 247 def drop_record(self, msg_id):
248 248 """Remove a record from the DB."""
249 249 rec = self._records[msg_id]
250 250 self._drop_bytes(rec)
251 251 del self._records[msg_id]
252 252
253 253 def find_records(self, check, keys=None):
254 254 """Find records matching a query dict, optionally extracting subset of keys.
255 255
256 256 Returns list of matching records.
257 257
258 258 Parameters
259 259 ----------
260 260
261 261 check: dict
262 262 mongodb-style query argument
263 263 keys: list of strs [optional]
264 264 if specified, the subset of keys to extract. msg_id will *always* be
265 265 included.
266 266 """
267 267 matches = self._match(check)
268 268 if keys:
269 269 return [ self._extract_subdict(rec, keys) for rec in matches ]
270 270 else:
271 271 return matches
272 272
273 273 def get_history(self):
274 274 """get all msg_ids, ordered by time submitted."""
275 275 msg_ids = self._records.keys()
276 276 # Remove any that do not have a submitted timestamp.
277 277 # This is extremely unlikely to happen,
278 278 # but it seems to come up in some tests on VMs.
279 279 msg_ids = [ m for m in msg_ids if self._records[m]['submitted'] is not None ]
280 280 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
281 281
282 282
283 283 NODATA = KeyError("NoDB backend doesn't store any data. "
284 284 "Start the Controller with a DB backend to enable resubmission / result persistence."
285 285 )
286 286
287 287
288 288 class NoDB(BaseDB):
289 289 """A blackhole db backend that actually stores no information.
290 290
291 291 Provides the full DB interface, but raises KeyErrors on any
292 292 method that tries to access the records. This can be used to
293 293 minimize the memory footprint of the Hub when its record-keeping
294 294 functionality is not required.
295 295 """
296 296
297 297 def add_record(self, msg_id, record):
298 298 pass
299 299
300 300 def get_record(self, msg_id):
301 301 raise NODATA
302 302
303 303 def update_record(self, msg_id, record):
304 304 pass
305 305
306 306 def drop_matching_records(self, check):
307 307 pass
308 308
309 309 def drop_record(self, msg_id):
310 310 pass
311 311
312 312 def find_records(self, check, keys=None):
313 313 raise NODATA
314 314
315 315 def get_history(self):
316 316 raise NODATA
317 317
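Since NoDB is the default backend, resubmission and result persistence require starting the controller with a real backend. A typical configuration sketch (profile layout assumed, value illustrative)::

    # ipcontroller_config.py
    c = get_config()
    # any shortcut ('dictdb', 'sqlitedb', 'mongodb', 'nodb') or a full dotted path
    c.HubFactory.db_class = 'sqlitedb'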
@@ -1,1426 +1,1436 @@ IPython/parallel/controller/hub.py
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import json
22 22 import os
23 23 import sys
24 24 import time
25 25 from datetime import datetime
26 26
27 27 import zmq
28 28 from zmq.eventloop import ioloop
29 29 from zmq.eventloop.zmqstream import ZMQStream
30 30
31 31 # internal:
32 32 from IPython.utils.importstring import import_item
33 33 from IPython.utils.jsonutil import extract_dates
34 34 from IPython.utils.localinterfaces import localhost
35 35 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
36 36 from IPython.utils.traitlets import (
37 37 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
38 38 )
39 39
40 40 from IPython.parallel import error, util
41 41 from IPython.parallel.factory import RegistrationFactory
42 42
43 43 from IPython.kernel.zmq.session import SessionFactory
44 44
45 45 from .heartmonitor import HeartMonitor
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Code
49 49 #-----------------------------------------------------------------------------
50 50
51 51 def _passer(*args, **kwargs):
52 52 return
53 53
54 54 def _printer(*args, **kwargs):
55 55 print (args)
56 56 print (kwargs)
57 57
58 58 def empty_record():
59 59 """Return an empty dict with all record keys."""
60 60 return {
61 61 'msg_id' : None,
62 62 'header' : None,
63 63 'metadata' : None,
64 64 'content': None,
65 65 'buffers': None,
66 66 'submitted': None,
67 67 'client_uuid' : None,
68 68 'engine_uuid' : None,
69 69 'started': None,
70 70 'completed': None,
71 71 'resubmitted': None,
72 72 'received': None,
73 73 'result_header' : None,
74 74 'result_metadata' : None,
75 75 'result_content' : None,
76 76 'result_buffers' : None,
77 77 'queue' : None,
78 78 'pyin' : None,
79 79 'pyout': None,
80 80 'pyerr': None,
81 81 'stdout': '',
82 82 'stderr': '',
83 83 }
84 84
85 85 def init_record(msg):
86 86 """Initialize a TaskRecord based on a request."""
87 87 header = msg['header']
88 88 return {
89 89 'msg_id' : header['msg_id'],
90 90 'header' : header,
91 91 'content': msg['content'],
92 92 'metadata': msg['metadata'],
93 93 'buffers': msg['buffers'],
94 94 'submitted': header['date'],
95 95 'client_uuid' : None,
96 96 'engine_uuid' : None,
97 97 'started': None,
98 98 'completed': None,
99 99 'resubmitted': None,
100 100 'received': None,
101 101 'result_header' : None,
102 102 'result_metadata': None,
103 103 'result_content' : None,
104 104 'result_buffers' : None,
105 105 'queue' : None,
106 106 'pyin' : None,
107 107 'pyout': None,
108 108 'pyerr': None,
109 109 'stdout': '',
110 110 'stderr': '',
111 111 }
112 112
113 113
114 114 class EngineConnector(HasTraits):
115 115 """A simple object for accessing the various zmq connections of an object.
116 116 Attributes are:
117 117 id (int): engine ID
118 118 uuid (unicode): engine UUID
119 119 pending: set of msg_ids
120 120 stallback: DelayedCallback for stalled registration
121 121 """
122 122
123 123 id = Integer(0)
124 124 uuid = Unicode()
125 125 pending = Set()
126 126 stallback = Instance(ioloop.DelayedCallback)
127 127
128 128
129 129 _db_shortcuts = {
130 130 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
131 131 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
132 132 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
133 133 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
134 134 }
135 135
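The shortcut table above lets db_class be given case-insensitively by nickname or as a full dotted import path; init_hub resolves it with a plain dict lookup that falls through for unknown names::

    db_class = _db_shortcuts.get('DictDB'.lower(), 'DictDB')
    # -> 'IPython.parallel.controller.dictdb.DictDB'
    db_class = _db_shortcuts.get('my.custom.DB'.lower(), 'my.custom.DB')
    # -> 'my.custom.DB' (unknown names pass through to import_item unchanged)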
136 136 class HubFactory(RegistrationFactory):
137 137 """The Configurable for setting up a Hub."""
138 138
139 139 # port-pairs for monitoredqueues:
140 140 hb = Tuple(Integer,Integer,config=True,
141 141 help="""PUB/ROUTER Port pair for Engine heartbeats""")
142 142 def _hb_default(self):
143 143 return tuple(util.select_random_ports(2))
144 144
145 145 mux = Tuple(Integer,Integer,config=True,
146 146 help="""Client/Engine Port pair for MUX queue""")
147 147
148 148 def _mux_default(self):
149 149 return tuple(util.select_random_ports(2))
150 150
151 151 task = Tuple(Integer,Integer,config=True,
152 152 help="""Client/Engine Port pair for Task queue""")
153 153 def _task_default(self):
154 154 return tuple(util.select_random_ports(2))
155 155
156 156 control = Tuple(Integer,Integer,config=True,
157 157 help="""Client/Engine Port pair for Control queue""")
158 158
159 159 def _control_default(self):
160 160 return tuple(util.select_random_ports(2))
161 161
162 162 iopub = Tuple(Integer,Integer,config=True,
163 163 help="""Client/Engine Port pair for IOPub relay""")
164 164
165 165 def _iopub_default(self):
166 166 return tuple(util.select_random_ports(2))
167 167
168 168 # single ports:
169 169 mon_port = Integer(config=True,
170 170 help="""Monitor (SUB) port for queue traffic""")
171 171
172 172 def _mon_port_default(self):
173 173 return util.select_random_ports(1)[0]
174 174
175 175 notifier_port = Integer(config=True,
176 176 help="""PUB port for sending engine status notifications""")
177 177
178 178 def _notifier_port_default(self):
179 179 return util.select_random_ports(1)[0]
180 180
181 181 engine_ip = Unicode(config=True,
182 182 help="IP on which to listen for engine connections. [default: loopback]")
183 183 def _engine_ip_default(self):
184 184 return localhost()
185 185 engine_transport = Unicode('tcp', config=True,
186 186 help="0MQ transport for engine connections. [default: tcp]")
187 187
188 188 client_ip = Unicode(config=True,
189 189 help="IP on which to listen for client connections. [default: loopback]")
190 190 client_transport = Unicode('tcp', config=True,
191 191 help="0MQ transport for client connections. [default : tcp]")
192 192
193 193 monitor_ip = Unicode(config=True,
194 194 help="IP on which to listen for monitor messages. [default: loopback]")
195 195 monitor_transport = Unicode('tcp', config=True,
196 196 help="0MQ transport for monitor messages. [default : tcp]")
197 197
198 198 _client_ip_default = _monitor_ip_default = _engine_ip_default
199 199
200 200
201 201 monitor_url = Unicode('')
202 202
203 203 db_class = DottedObjectName('NoDB',
204 204 config=True, help="""The class to use for the DB backend
205 205
206 206 Options include:
207 207
208 208 SQLiteDB: SQLite
209 209 MongoDB : use MongoDB
210 210 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
211 211 NoDB : disable database altogether (default)
212 212
213 213 """)
214 214
215 215 # not configurable
216 216 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
217 217 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
218 218
219 219 def _ip_changed(self, name, old, new):
220 220 self.engine_ip = new
221 221 self.client_ip = new
222 222 self.monitor_ip = new
223 223 self._update_monitor_url()
224 224
225 225 def _update_monitor_url(self):
226 226 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
227 227
228 228 def _transport_changed(self, name, old, new):
229 229 self.engine_transport = new
230 230 self.client_transport = new
231 231 self.monitor_transport = new
232 232 self._update_monitor_url()
233 233
234 234 def __init__(self, **kwargs):
235 235 super(HubFactory, self).__init__(**kwargs)
236 236 self._update_monitor_url()
237 237
238 238
239 239 def construct(self):
240 240 self.init_hub()
241 241
242 242 def start(self):
243 243 self.heartmonitor.start()
244 244 self.log.info("Heartmonitor started")
245 245
246 246 def client_url(self, channel):
247 247 """return full zmq url for a named client channel"""
248 248 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
249 249
250 250 def engine_url(self, channel):
251 251 """return full zmq url for a named engine channel"""
252 252 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
253 253
254 254 def init_hub(self):
255 255 """construct Hub object"""
256 256
257 257 ctx = self.context
258 258 loop = self.loop
259 259 if 'TaskScheduler.scheme_name' in self.config:
260 260 scheme = self.config.TaskScheduler.scheme_name
261 261 else:
262 262 from .scheduler import TaskScheduler
263 263 scheme = TaskScheduler.scheme_name.get_default_value()
264 264
265 265 # build connection dicts
266 266 engine = self.engine_info = {
267 267 'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
268 268 'registration' : self.regport,
269 269 'control' : self.control[1],
270 270 'mux' : self.mux[1],
271 271 'hb_ping' : self.hb[0],
272 272 'hb_pong' : self.hb[1],
273 273 'task' : self.task[1],
274 274 'iopub' : self.iopub[1],
275 275 }
276 276
277 277 client = self.client_info = {
278 278 'interface' : "%s://%s" % (self.client_transport, self.client_ip),
279 279 'registration' : self.regport,
280 280 'control' : self.control[0],
281 281 'mux' : self.mux[0],
282 282 'task' : self.task[0],
283 283 'task_scheme' : scheme,
284 284 'iopub' : self.iopub[0],
285 285 'notification' : self.notifier_port,
286 286 }
287 287
288 288 self.log.debug("Hub engine addrs: %s", self.engine_info)
289 289 self.log.debug("Hub client addrs: %s", self.client_info)
290 290
291 291 # Registrar socket
292 292 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
293 293 util.set_hwm(q, 0)
294 294 q.bind(self.client_url('registration'))
295 295 self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
296 296 if self.client_ip != self.engine_ip:
297 297 q.bind(self.engine_url('registration'))
298 298 self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))
299 299
300 300 ### Engine connections ###
301 301
302 302 # heartbeat
303 303 hpub = ctx.socket(zmq.PUB)
304 304 hpub.bind(self.engine_url('hb_ping'))
305 305 hrep = ctx.socket(zmq.ROUTER)
306 306 util.set_hwm(hrep, 0)
307 307 hrep.bind(self.engine_url('hb_pong'))
308 308 self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
309 309 pingstream=ZMQStream(hpub,loop),
310 310 pongstream=ZMQStream(hrep,loop)
311 311 )
312 312
313 313 ### Client connections ###
314 314
315 315 # Notifier socket
316 316 n = ZMQStream(ctx.socket(zmq.PUB), loop)
317 317 n.bind(self.client_url('notification'))
318 318
319 319 ### build and launch the queues ###
320 320
321 321 # monitor socket
322 322 sub = ctx.socket(zmq.SUB)
323 323 sub.setsockopt(zmq.SUBSCRIBE, b"")
324 324 sub.bind(self.monitor_url)
325 325 sub.bind('inproc://monitor')
326 326 sub = ZMQStream(sub, loop)
327 327
328 328 # connect the db
329 329 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
330 330 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
331 331 self.db = import_item(str(db_class))(session=self.session.session,
332 332 parent=self, log=self.log)
333 333 time.sleep(.25)
334 334
335 335 # resubmit stream
336 336 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
337 337 url = util.disambiguate_url(self.client_url('task'))
338 338 r.connect(url)
339 339
340 340 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
341 341 query=q, notifier=n, resubmit=r, db=self.db,
342 342 engine_info=self.engine_info, client_info=self.client_info,
343 343 log=self.log)
344 344
345 345
346 346 class Hub(SessionFactory):
347 347 """The IPython Controller Hub with 0MQ connections
348 348
349 349 Parameters
350 350 ==========
351 351 loop: zmq IOLoop instance
352 352 session: Session object
353 353 <removed> context: zmq context for creating new connections (?)
354 354 queue: ZMQStream for monitoring the command queue (SUB)
355 355 query: ZMQStream for engine registration and client queries requests (ROUTER)
356 356 heartbeat: HeartMonitor object checking the pulse of the engines
357 357 notifier: ZMQStream for broadcasting engine registration changes (PUB)
358 358 db: connection to db for out of memory logging of commands
359 359 NotImplemented
360 360 engine_info: dict of zmq connection information for engines to connect
361 361 to the queues.
362 362 client_info: dict of zmq connection information for clients to connect
363 363 to the queues.
364 364 """
365 365
366 366 engine_state_file = Unicode()
367 367
368 368 # internal data structures:
369 369 ids=Set() # engine IDs
370 370 keytable=Dict()
371 371 by_ident=Dict()
372 372 engines=Dict()
373 373 clients=Dict()
374 374 hearts=Dict()
375 375 pending=Set()
376 376 queues=Dict() # pending msg_ids keyed by engine_id
377 377 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
378 378 completed=Dict() # completed msg_ids keyed by engine_id
379 379 all_completed=Set() # set of all completed msg_ids
380 380 dead_engines=Set() # set of uuids of engines that have died
381 381 unassigned=Set() # set of task msg_ids not yet assigned a destination
382 382 incoming_registrations=Dict()
383 383 registration_timeout=Integer()
384 384 _idcounter=Integer(0)
385 385
386 386 # objects from constructor:
387 387 query=Instance(ZMQStream)
388 388 monitor=Instance(ZMQStream)
389 389 notifier=Instance(ZMQStream)
390 390 resubmit=Instance(ZMQStream)
391 391 heartmonitor=Instance(HeartMonitor)
392 392 db=Instance(object)
393 393 client_info=Dict()
394 394 engine_info=Dict()
395 395
396 396
397 397 def __init__(self, **kwargs):
398 398 """
399 399 # universal:
400 400 loop: IOLoop for creating future connections
401 401 session: streamsession for sending serialized data
402 402 # engine:
403 403 queue: ZMQStream for monitoring queue messages
404 404 query: ZMQStream for engine+client registration and client requests
405 405 heartbeat: HeartMonitor object for tracking engines
406 406 # extra:
407 407 db: ZMQStream for db connection (NotImplemented)
408 408 engine_info: zmq address/protocol dict for engine connections
409 409 client_info: zmq address/protocol dict for client connections
410 410 """
411 411
412 412 super(Hub, self).__init__(**kwargs)
413 413 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
414 414
415 415 # register our callbacks
416 416 self.query.on_recv(self.dispatch_query)
417 417 self.monitor.on_recv(self.dispatch_monitor_traffic)
418 418
419 419 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
420 420 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
421 421
422 422 self.monitor_handlers = {b'in' : self.save_queue_request,
423 423 b'out': self.save_queue_result,
424 424 b'intask': self.save_task_request,
425 425 b'outtask': self.save_task_result,
426 426 b'tracktask': self.save_task_destination,
427 427 b'incontrol': _passer,
428 428 b'outcontrol': _passer,
429 429 b'iopub': self.save_iopub_message,
430 430 }
431 431
432 432 self.query_handlers = {'queue_request': self.queue_status,
433 433 'result_request': self.get_results,
434 434 'history_request': self.get_history,
435 435 'db_request': self.db_query,
436 436 'purge_request': self.purge_results,
437 437 'load_request': self.check_load,
438 438 'resubmit_request': self.resubmit_task,
439 439 'shutdown_request': self.shutdown_request,
440 440 'registration_request' : self.register_engine,
441 441 'unregistration_request' : self.unregister_engine,
442 442 'connection_request': self.connection_request,
443 443 }
444 444
445 445 # ignore resubmit replies
446 446 self.resubmit.on_recv(lambda msg: None, copy=False)
447 447
448 448 self.log.info("hub::created hub")
449 449
450 450 @property
451 451 def _next_id(self):
452 452 """gemerate a new ID.
453 453
454 454 No longer reuse old ids, just count from 0."""
455 455 newid = self._idcounter
456 456 self._idcounter += 1
457 457 return newid
458 458 # newid = 0
459 459 # incoming = [id[0] for id in itervalues(self.incoming_registrations)]
460 460 # # print newid, self.ids, self.incoming_registrations
461 461 # while newid in self.ids or newid in incoming:
462 462 # newid += 1
463 463 # return newid
464 464
465 465 #-----------------------------------------------------------------------------
466 466 # message validation
467 467 #-----------------------------------------------------------------------------
468 468
469 469 def _validate_targets(self, targets):
470 470 """turn any valid targets argument into a list of integer ids"""
471 471 if targets is None:
472 472 # default to all
473 473 return self.ids
474 474
475 475 if isinstance(targets, (int,str,unicode_type)):
476 476 # only one target specified
477 477 targets = [targets]
478 478 _targets = []
479 479 for t in targets:
480 480 # map raw identities to ids
481 481 if isinstance(t, (str,unicode_type)):
482 482 t = self.by_ident.get(cast_bytes(t), t)
483 483 _targets.append(t)
484 484 targets = _targets
485 485 bad_targets = [ t for t in targets if t not in self.ids ]
486 486 if bad_targets:
487 487 raise IndexError("No Such Engine: %r" % bad_targets)
488 488 if not targets:
489 489 raise IndexError("No Engines Registered")
490 490 return targets
491 491
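_validate_targets normalizes every accepted targets form to a list of integer engine ids. Assuming engines 0 and 1 are registered (the hub instance and uuid here are illustrative)::

    hub._validate_targets(None)       # -> all registered ids
    hub._validate_targets(0)          # -> [0]
    hub._validate_targets([0, 1])     # -> [0, 1]
    hub._validate_targets('<uuid>')   # raw identity, mapped through by_ident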
492 492 #-----------------------------------------------------------------------------
493 493 # dispatch methods (1 per stream)
494 494 #-----------------------------------------------------------------------------
495 495
496 496
497 497 @util.log_errors
498 498 def dispatch_monitor_traffic(self, msg):
499 499 """all ME and Task queue messages come through here, as well as
500 500 IOPub traffic."""
501 501 self.log.debug("monitor traffic: %r", msg[0])
502 502 switch = msg[0]
503 503 try:
504 504 idents, msg = self.session.feed_identities(msg[1:])
505 505 except ValueError:
506 506 idents=[]
507 507 if not idents:
508 508 self.log.error("Monitor message without topic: %r", msg)
509 509 return
510 510 handler = self.monitor_handlers.get(switch, None)
511 511 if handler is not None:
512 512 handler(idents, msg)
513 513 else:
514 514 self.log.error("Unrecognized monitor topic: %r", switch)
515 515
516 516
517 517 @util.log_errors
518 518 def dispatch_query(self, msg):
519 519 """Route registration requests and queries from clients."""
520 520 try:
521 521 idents, msg = self.session.feed_identities(msg)
522 522 except ValueError:
523 523 idents = []
524 524 if not idents:
525 525 self.log.error("Bad Query Message: %r", msg)
526 526 return
527 527 client_id = idents[0]
528 528 try:
529 529 msg = self.session.unserialize(msg, content=True)
530 530 except Exception:
531 531 content = error.wrap_exception()
532 532 self.log.error("Bad Query Message: %r", msg, exc_info=True)
533 533 self.session.send(self.query, "hub_error", ident=client_id,
534 534 content=content)
535 535 return
536 536 # print client_id, header, parent, content
537 537 #switch on message type:
538 538 msg_type = msg['header']['msg_type']
539 539 self.log.info("client::client %r requested %r", client_id, msg_type)
540 540 handler = self.query_handlers.get(msg_type, None)
541 541 try:
542 542 assert handler is not None, "Bad Message Type: %r" % msg_type
543 543 except:
544 544 content = error.wrap_exception()
545 545 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
546 546 self.session.send(self.query, "hub_error", ident=client_id,
547 547 content=content)
548 548 return
549 549
550 550 else:
551 551 handler(idents, msg)
552 552
553 553 def dispatch_db(self, msg):
554 554 """"""
555 555 raise NotImplementedError
556 556
557 557 #---------------------------------------------------------------------------
558 558 # handler methods (1 per event)
559 559 #---------------------------------------------------------------------------
560 560
561 561 #----------------------- Heartbeat --------------------------------------
562 562
563 563 def handle_new_heart(self, heart):
564 564 """handler to attach to heartbeater.
565 565 Called when a new heart starts to beat.
566 566 Triggers completion of registration."""
567 567 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
568 568 if heart not in self.incoming_registrations:
569 569 self.log.info("heartbeat::ignoring new heart: %r", heart)
570 570 else:
571 571 self.finish_registration(heart)
572 572
573 573
574 574 def handle_heart_failure(self, heart):
575 575 """handler to attach to heartbeater.
576 576 called when a previously registered heart fails to respond to beat request.
577 577 triggers unregistration"""
578 578 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
579 579 eid = self.hearts.get(heart, None)
580 580 uuid = self.engines[eid].uuid if eid is not None else None  # eid may be None for an unknown heart
581 581 if eid is None or self.keytable[eid] in self.dead_engines:
582 582 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
583 583 else:
584 584 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
585 585
586 586 #----------------------- MUX Queue Traffic ------------------------------
587 587
588 588 def save_queue_request(self, idents, msg):
589 589 if len(idents) < 2:
590 590 self.log.error("invalid identity prefix: %r", idents)
591 591 return
592 592 queue_id, client_id = idents[:2]
593 593 try:
594 594 msg = self.session.unserialize(msg)
595 595 except Exception:
596 596 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
597 597 return
598 598
599 599 eid = self.by_ident.get(queue_id, None)
600 600 if eid is None:
601 601 self.log.error("queue::target %r not registered", queue_id)
602 602 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
603 603 return
604 604 record = init_record(msg)
605 605 msg_id = record['msg_id']
606 606 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
607 607 # Unicode in records
608 608 record['engine_uuid'] = queue_id.decode('ascii')
609 609 record['client_uuid'] = msg['header']['session']
610 610 record['queue'] = 'mux'
611 611
612 612 try:
613 613 # it's possible iopub arrived first:
614 614 existing = self.db.get_record(msg_id)
615 615 for key,evalue in iteritems(existing):
616 616 rvalue = record.get(key, None)
617 617 if evalue and rvalue and evalue != rvalue:
618 618 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
619 619 elif evalue and not rvalue:
620 620 record[key] = evalue
621 621 try:
622 622 self.db.update_record(msg_id, record)
623 623 except Exception:
624 624 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
625 625 except KeyError:
626 626 try:
627 627 self.db.add_record(msg_id, record)
628 628 except Exception:
629 629 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
630 630
631 631
632 632 self.pending.add(msg_id)
633 633 self.queues[eid].append(msg_id)
634 634
635 635 def save_queue_result(self, idents, msg):
636 636 if len(idents) < 2:
637 637 self.log.error("invalid identity prefix: %r", idents)
638 638 return
639 639
640 640 client_id, queue_id = idents[:2]
641 641 try:
642 642 msg = self.session.unserialize(msg)
643 643 except Exception:
644 644 self.log.error("queue::engine %r sent invalid message to %r: %r",
645 645 queue_id, client_id, msg, exc_info=True)
646 646 return
647 647
648 648 eid = self.by_ident.get(queue_id, None)
649 649 if eid is None:
650 650 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
651 651 return
652 652
653 653 parent = msg['parent_header']
654 654 if not parent:
655 655 return
656 656 msg_id = parent['msg_id']
657 657 if msg_id in self.pending:
658 658 self.pending.remove(msg_id)
659 659 self.all_completed.add(msg_id)
660 660 self.queues[eid].remove(msg_id)
661 661 self.completed[eid].append(msg_id)
662 662 self.log.info("queue::request %r completed on %s", msg_id, eid)
663 663 elif msg_id not in self.all_completed:
664 664 # it could be a result from a dead engine that died before delivering the
665 665 # result
666 666 self.log.warn("queue:: unknown msg finished %r", msg_id)
667 667 return
668 668 # update record anyway, because the unregistration could have been premature
669 669 rheader = msg['header']
670 670 md = msg['metadata']
671 671 completed = rheader['date']
672 started = md.get('started', None)
672 started = extract_dates(md.get('started', None))
673 673 result = {
674 674 'result_header' : rheader,
675 675 'result_metadata': md,
676 676 'result_content': msg['content'],
677 677 'received': datetime.now(),
678 678 'started' : started,
679 679 'completed' : completed
680 680 }
681 681
682 682 result['result_buffers'] = msg['buffers']
683 683 try:
684 684 self.db.update_record(msg_id, result)
685 685 except Exception:
686 686 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
687 687
688 688
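The switch to extract_dates above is the crux of this fix: metadata arrives JSON-deserialized, so md['started'] may be an ISO8601 string rather than a datetime, which the stricter DictDB._check_dates would now reject. extract_dates recursively parses such strings back to datetimes (behavior sketch of the imported helper)::

    from IPython.utils.jsonutil import extract_dates

    extract_dates('2014-02-14T12:00:00.000000')
    # -> datetime.datetime(2014, 2, 14, 12, 0)
    extract_dates(None)
    # -> None (non-date values pass through unchanged)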
689 689 #--------------------- Task Queue Traffic ------------------------------
690 690
691 691 def save_task_request(self, idents, msg):
692 692 """Save the submission of a task."""
693 693 client_id = idents[0]
694 694
695 695 try:
696 696 msg = self.session.unserialize(msg)
697 697 except Exception:
698 698 self.log.error("task::client %r sent invalid task message: %r",
699 699 client_id, msg, exc_info=True)
700 700 return
701 701 record = init_record(msg)
702 702
703 703 record['client_uuid'] = msg['header']['session']
704 704 record['queue'] = 'task'
705 705 header = msg['header']
706 706 msg_id = header['msg_id']
707 707 self.pending.add(msg_id)
708 708 self.unassigned.add(msg_id)
709 709 try:
710 710 # it's possible iopub arrived first:
711 711 existing = self.db.get_record(msg_id)
712 712 if existing['resubmitted']:
713 713 for key in ('submitted', 'client_uuid', 'buffers'):
714 714 # don't clobber these keys on resubmit
715 715 # submitted and client_uuid should be different
716 716 # and buffers might be big, and shouldn't have changed
717 717 record.pop(key)
718 718 # still check content and header, which should not change
719 719 # and are not as expensive to compare as buffers
720 720
721 721 for key,evalue in iteritems(existing):
722 722 if key.endswith('buffers'):
723 723 # don't compare buffers
724 724 continue
725 725 rvalue = record.get(key, None)
726 726 if evalue and rvalue and evalue != rvalue:
727 727 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
728 728 elif evalue and not rvalue:
729 729 record[key] = evalue
730 730 try:
731 731 self.db.update_record(msg_id, record)
732 732 except Exception:
733 733 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
734 734 except KeyError:
735 735 try:
736 736 self.db.add_record(msg_id, record)
737 737 except Exception:
738 738 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
739 739 except Exception:
740 740 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
741 741
742 742 def save_task_result(self, idents, msg):
743 743 """save the result of a completed task."""
744 744 client_id = idents[0]
745 745 try:
746 746 msg = self.session.unserialize(msg)
747 747 except Exception:
748 748 self.log.error("task::invalid task result message sent to %r: %r",
749 749 client_id, msg, exc_info=True)
750 750 return
751 751
752 752 parent = msg['parent_header']
753 753 if not parent:
754 754 # print msg
755 755 self.log.warn("Task %r had no parent!", msg)
756 756 return
757 757 msg_id = parent['msg_id']
758 758 if msg_id in self.unassigned:
759 759 self.unassigned.remove(msg_id)
760 760
761 761 header = msg['header']
762 762 md = msg['metadata']
763 763 engine_uuid = md.get('engine', u'')
764 764 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
765 765
766 766 status = md.get('status', None)
767 767
768 768 if msg_id in self.pending:
769 769 self.log.info("task::task %r finished on %s", msg_id, eid)
770 770 self.pending.remove(msg_id)
771 771 self.all_completed.add(msg_id)
772 772 if eid is not None:
773 773 if status != 'aborted':
774 774 self.completed[eid].append(msg_id)
775 775 if msg_id in self.tasks[eid]:
776 776 self.tasks[eid].remove(msg_id)
777 777 completed = header['date']
778 started = md.get('started', None)
778 started = extract_dates(md.get('started', None))
779 779 result = {
780 780 'result_header' : header,
781 781 'result_metadata': msg['metadata'],
782 782 'result_content': msg['content'],
783 783 'started' : started,
784 784 'completed' : completed,
785 785 'received' : datetime.now(),
786 786 'engine_uuid': engine_uuid,
787 787 }
788 788
789 789 result['result_buffers'] = msg['buffers']
790 790 try:
791 791 self.db.update_record(msg_id, result)
792 792 except Exception:
793 793 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
794 794
795 795 else:
796 796 self.log.debug("task::unknown task %r finished", msg_id)
797 797
798 798 def save_task_destination(self, idents, msg):
799 799 try:
800 800 msg = self.session.unserialize(msg, content=True)
801 801 except Exception:
802 802 self.log.error("task::invalid task tracking message", exc_info=True)
803 803 return
804 804 content = msg['content']
805 805 # print (content)
806 806 msg_id = content['msg_id']
807 807 engine_uuid = content['engine_id']
808 808 eid = self.by_ident[cast_bytes(engine_uuid)]
809 809
810 810 self.log.info("task::task %r arrived on %r", msg_id, eid)
811 811 if msg_id in self.unassigned:
812 812 self.unassigned.remove(msg_id)
813 813 # else:
814 814 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
815 815
816 816 self.tasks[eid].append(msg_id)
817 817 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
818 818 try:
819 819 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
820 820 except Exception:
821 821 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
822 822
823 823
824 824 def mia_task_request(self, idents, msg):
825 825 raise NotImplementedError
826 826 client_id = idents[0]
827 827 # content = dict(mia=self.mia,status='ok')
828 828 # self.session.send('mia_reply', content=content, idents=client_id)
829 829
830 830
831 831 #--------------------- IOPub Traffic ------------------------------
832 832
833 833 def save_iopub_message(self, topics, msg):
834 834 """save an iopub message into the db"""
835 835 # print (topics)
836 836 try:
837 837 msg = self.session.unserialize(msg, content=True)
838 838 except Exception:
839 839 self.log.error("iopub::invalid IOPub message", exc_info=True)
840 840 return
841 841
842 842 parent = msg['parent_header']
843 843 if not parent:
844 844 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
845 845 return
846 846 msg_id = parent['msg_id']
847 847 msg_type = msg['header']['msg_type']
848 848 content = msg['content']
849 849
850 850 # ensure msg_id is in db
851 851 try:
852 852 rec = self.db.get_record(msg_id)
853 853 except KeyError:
854 854 rec = empty_record()
855 855 rec['msg_id'] = msg_id
856 856 self.db.add_record(msg_id, rec)
857 857 # stream
858 858 d = {}
859 859 if msg_type == 'stream':
860 860 name = content['name']
861 861 s = rec[name] or ''
862 862 d[name] = s + content['data']
863 863
864 864 elif msg_type == 'pyerr':
865 865 d['pyerr'] = content
866 866 elif msg_type == 'pyin':
867 867 d['pyin'] = content['code']
868 868 elif msg_type in ('display_data', 'pyout'):
869 869 d[msg_type] = content
870 870 elif msg_type == 'status':
871 871 pass
872 872 elif msg_type == 'data_pub':
873 873 self.log.info("ignored data_pub message for %s" % msg_id)
874 874 else:
875 875 self.log.warn("unhandled iopub msg_type: %r", msg_type)
876 876
877 877 if not d:
878 878 return
879 879
880 880 try:
881 881 self.db.update_record(msg_id, d)
882 882 except Exception:
883 883 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
884 884
885 885
886 886
887 887 #-------------------------------------------------------------------------
888 888 # Registration requests
889 889 #-------------------------------------------------------------------------
890 890
891 891 def connection_request(self, client_id, msg):
892 892 """Reply with connection addresses for clients."""
893 893 self.log.info("client::client %r connected", client_id)
894 894 content = dict(status='ok')
895 895 jsonable = {}
896 896 for k,v in iteritems(self.keytable):
897 897 if v not in self.dead_engines:
898 898 jsonable[str(k)] = v
899 899 content['engines'] = jsonable
900 900 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
901 901
902 902 def register_engine(self, reg, msg):
903 903 """Register a new engine."""
904 904 content = msg['content']
905 905 try:
906 906 uuid = content['uuid']
907 907 except KeyError:
908 908 self.log.error("registration::uuid not specified", exc_info=True)
909 909 return
910 910
911 911 eid = self._next_id
912 912
913 913 self.log.debug("registration::register_engine(%i, %r)", eid, uuid)
914 914
915 915 content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
916 916 # check if requesting available IDs:
917 917 if cast_bytes(uuid) in self.by_ident:
918 918 try:
919 919 raise KeyError("uuid %r in use" % uuid)
920 920 except:
921 921 content = error.wrap_exception()
922 922 self.log.error("uuid %r in use", uuid, exc_info=True)
923 923 else:
924 924 for h, ec in iteritems(self.incoming_registrations):
925 925 if uuid == h:
926 926 try:
927 927 raise KeyError("heart_id %r in use" % uuid)
928 928 except:
929 929 self.log.error("heart_id %r in use", uuid, exc_info=True)
930 930 content = error.wrap_exception()
931 931 break
932 932 elif uuid == ec.uuid:
933 933 try:
934 934 raise KeyError("uuid %r in use" % uuid)
935 935 except:
936 936 self.log.error("uuid %r in use", uuid, exc_info=True)
937 937 content = error.wrap_exception()
938 938 break
939 939
940 940 msg = self.session.send(self.query, "registration_reply",
941 941 content=content,
942 942 ident=reg)
943 943
944 944 heart = cast_bytes(uuid)
945 945
946 946 if content['status'] == 'ok':
947 947 if heart in self.heartmonitor.hearts:
948 948 # already beating
949 949 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
950 950 self.finish_registration(heart)
951 951 else:
952 952 purge = lambda : self._purge_stalled_registration(heart)
953 953 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
954 954 dc.start()
955 955 self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
956 956 else:
957 957 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
958 958
959 959 return eid
960 960
961 961 def unregister_engine(self, ident, msg):
962 962 """Unregister an engine that explicitly requested to leave."""
963 963 try:
964 964 eid = msg['content']['id']
965 965 except:
966 966 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
967 967 return
968 968 self.log.info("registration::unregister_engine(%r)", eid)
969 969 # print (eid)
970 970 uuid = self.keytable[eid]
971 971 content=dict(id=eid, uuid=uuid)
972 972 self.dead_engines.add(uuid)
973 973 # self.ids.remove(eid)
974 974 # uuid = self.keytable.pop(eid)
975 975 #
976 976 # ec = self.engines.pop(eid)
977 977 # self.hearts.pop(ec.heartbeat)
978 978 # self.by_ident.pop(ec.queue)
979 979 # self.completed.pop(eid)
980 980 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
981 981 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
982 982 dc.start()
983 983 ############## TODO: HANDLE IT ################
984 984
985 985 self._save_engine_state()
986 986
987 987 if self.notifier:
988 988 self.session.send(self.notifier, "unregistration_notification", content=content)
989 989
990 990 def _handle_stranded_msgs(self, eid, uuid):
991 991 """Handle messages known to be on an engine when the engine unregisters.
992 992
993 993 It is possible that this will fire prematurely - that is, an engine will
994 994 go down after completing a result, and the client will be notified
995 995 that the result failed and later receive the actual result.
996 996 """
997 997
998 998 outstanding = self.queues[eid]
999 999
1000 1000 for msg_id in outstanding:
1001 1001 self.pending.remove(msg_id)
1002 1002 self.all_completed.add(msg_id)
1003 1003 try:
1004 1004 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
1005 1005 except:
1006 1006 content = error.wrap_exception()
1007 1007 # build a fake header:
1008 1008 header = {}
1009 1009 header['engine'] = uuid
1010 1010 header['date'] = datetime.now()
1011 1011 rec = dict(result_content=content, result_header=header, result_buffers=[])
1012 1012 rec['completed'] = header['date']
1013 1013 rec['engine_uuid'] = uuid
1014 1014 try:
1015 1015 self.db.update_record(msg_id, rec)
1016 1016 except Exception:
1017 1017 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1018 1018
1019 1019
1020 1020 def finish_registration(self, heart):
1021 1021 """Second half of engine registration, called after our HeartMonitor
1022 1022 has received a beat from the Engine's Heart."""
1023 1023 try:
1024 1024 ec = self.incoming_registrations.pop(heart)
1025 1025 except KeyError:
1026 1026 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
1027 1027 return
1028 1028 self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
1029 1029 if ec.stallback is not None:
1030 1030 ec.stallback.stop()
1031 1031 eid = ec.id
1032 1032 self.ids.add(eid)
1033 1033 self.keytable[eid] = ec.uuid
1034 1034 self.engines[eid] = ec
1035 1035 self.by_ident[cast_bytes(ec.uuid)] = ec.id
1036 1036 self.queues[eid] = list()
1037 1037 self.tasks[eid] = list()
1038 1038 self.completed[eid] = list()
1039 1039 self.hearts[heart] = eid
1040 1040 content = dict(id=eid, uuid=self.engines[eid].uuid)
1041 1041 if self.notifier:
1042 1042 self.session.send(self.notifier, "registration_notification", content=content)
1043 1043 self.log.info("engine::Engine Connected: %i", eid)
1044 1044
1045 1045 self._save_engine_state()
1046 1046
1047 1047 def _purge_stalled_registration(self, heart):
1048 1048 if heart in self.incoming_registrations:
1049 1049 ec = self.incoming_registrations.pop(heart)
1050 1050 self.log.info("registration::purging stalled registration: %i", ec.id)
1051 1051 else:
1052 1052 pass
1053 1053
1054 1054 #-------------------------------------------------------------------------
1055 1055 # Engine State
1056 1056 #-------------------------------------------------------------------------
1057 1057
1058 1058
1059 1059 def _cleanup_engine_state_file(self):
1060 1060 """cleanup engine state mapping"""
1061 1061
1062 1062 if os.path.exists(self.engine_state_file):
1063 1063 self.log.debug("cleaning up engine state: %s", self.engine_state_file)
1064 1064 try:
1065 1065 os.remove(self.engine_state_file)
1066 1066 except IOError:
1067 1067 self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1068 1068
1069 1069
1070 1070 def _save_engine_state(self):
1071 1071 """save engine mapping to JSON file"""
1072 1072 if not self.engine_state_file:
1073 1073 return
1074 1074 self.log.debug("save engine state to %s" % self.engine_state_file)
1075 1075 state = {}
1076 1076 engines = {}
1077 1077 for eid, ec in iteritems(self.engines):
1078 1078 if ec.uuid not in self.dead_engines:
1079 1079 engines[eid] = ec.uuid
1080 1080
1081 1081 state['engines'] = engines
1082 1082
1083 1083 state['next_id'] = self._idcounter
1084 1084
1085 1085 with open(self.engine_state_file, 'w') as f:
1086 1086 json.dump(state, f)
1087 1087
1088 1088
1089 1089 def _load_engine_state(self):
1090 1090 """load engine mapping from JSON file"""
1091 1091 if not os.path.exists(self.engine_state_file):
1092 1092 return
1093 1093
1094 1094 self.log.info("loading engine state from %s" % self.engine_state_file)
1095 1095
1096 1096 with open(self.engine_state_file) as f:
1097 1097 state = json.load(f)
1098 1098
1099 1099 save_notifier = self.notifier
1100 1100 self.notifier = None
1101 1101 for eid, uuid in iteritems(state['engines']):
1102 1102 heart = uuid.encode('ascii')
1103 1103 # start with this heart as current and beating:
1104 1104 self.heartmonitor.responses.add(heart)
1105 1105 self.heartmonitor.hearts.add(heart)
1106 1106
1107 1107 self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
1108 1108 self.finish_registration(heart)
1109 1109
1110 1110 self.notifier = save_notifier
1111 1111
1112 1112 self._idcounter = state['next_id']
1113 1113
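For reference, the engine-state file written by _save_engine_state and read back above is a small JSON document; a sketch of reading it by hand (the path stands in for the configured engine_state_file, values illustrative)::

    import json
    with open('engine_state.json') as f:
        state = json.load(f)
    # state == {'engines': {'0': '<engine-uuid>', '1': '<engine-uuid>'}, 'next_id': 2}
    # JSON forces engine ids to strings; _load_engine_state casts them back with int()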
1114 1114 #-------------------------------------------------------------------------
1115 1115 # Client Requests
1116 1116 #-------------------------------------------------------------------------
1117 1117
1118 1118 def shutdown_request(self, client_id, msg):
1119 1119 """handle shutdown request."""
1120 1120 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1121 1121 # also notify other clients of shutdown
1122 1122 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1123 1123 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1124 1124 dc.start()
1125 1125
1126 1126 def _shutdown(self):
1127 1127 self.log.info("hub::hub shutting down.")
1128 1128 time.sleep(0.1)
1129 1129 sys.exit(0)
1130 1130
1131 1131
1132 1132 def check_load(self, client_id, msg):
1133 1133 content = msg['content']
1134 1134 try:
1135 1135 targets = content['targets']
1136 1136 targets = self._validate_targets(targets)
1137 1137 except:
1138 1138 content = error.wrap_exception()
1139 1139 self.session.send(self.query, "hub_error",
1140 1140 content=content, ident=client_id)
1141 1141 return
1142 1142
1143 1143 content = dict(status='ok')
1144 1144 # loads = {}
1145 1145 for t in targets:
1146 1146 content[str(t)] = len(self.queues[t])+len(self.tasks[t])  # str keys keep the reply JSON-serializable
1147 1147 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1148 1148
1149 1149
1150 1150 def queue_status(self, client_id, msg):
1151 1151 """Return the Queue status of one or more targets.
1152 1152
1153 1153 If verbose, return the msg_ids, else return len of each type.
1154 1154
1155 1155 Keys:
1156 1156
1157 1157 * queue (pending MUX jobs)
1158 1158 * tasks (pending Task jobs)
1159 1159 * completed (finished jobs from both queues)
1160 1160 """
1161 1161 content = msg['content']
1162 1162 targets = content['targets']
1163 1163 try:
1164 1164 targets = self._validate_targets(targets)
1165 1165 except:
1166 1166 content = error.wrap_exception()
1167 1167 self.session.send(self.query, "hub_error",
1168 1168 content=content, ident=client_id)
1169 1169 return
1170 1170 verbose = content.get('verbose', False)
1171 1171 content = dict(status='ok')
1172 1172 for t in targets:
1173 1173 queue = self.queues[t]
1174 1174 completed = self.completed[t]
1175 1175 tasks = self.tasks[t]
1176 1176 if not verbose:
1177 1177 queue = len(queue)
1178 1178 completed = len(completed)
1179 1179 tasks = len(tasks)
1180 1180 content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1181 1181 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1182 1182 # print (content)
1183 1183 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1184 1184
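On the client side, the reply built here surfaces through Client.queue_status(), which rekeys the string engine ids back to ints; the counts below are illustrative::

    from IPython.parallel import Client

    rc = Client()
    rc.queue_status()
    # -> {0: {'queue': 0, 'completed': 4, 'tasks': 1}, 'unassigned': 0}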
1185 1185 def purge_results(self, client_id, msg):
1186 1186 """Purge results from memory. This method is more valuable before we move
1187 1187 to a DB based message storage mechanism."""
1188 1188 content = msg['content']
1189 1189 self.log.info("Dropping records with %s", content)
1190 1190 msg_ids = content.get('msg_ids', [])
1191 1191 reply = dict(status='ok')
1192 1192 if msg_ids == 'all':
1193 1193 try:
1194 1194 self.db.drop_matching_records(dict(completed={'$ne':None}))
1195 1195 except Exception:
1196 1196 reply = error.wrap_exception()
1197 self.log.exception("Error dropping records")
1197 1198 else:
1198 1199 pending = [m for m in msg_ids if (m in self.pending)]
1199 1200 if pending:
1200 1201 try:
1201 1202 raise IndexError("msg pending: %r" % pending[0])
1202 1203 except:
1203 1204 reply = error.wrap_exception()
1205 self.log.exception("Error dropping records")
1204 1206 else:
1205 1207 try:
1206 1208 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1207 1209 except Exception:
1208 1210 reply = error.wrap_exception()
1211 self.log.exception("Error dropping records")
1209 1212
1210 1213 if reply['status'] == 'ok':
1211 1214 eids = content.get('engine_ids', [])
1212 1215 for eid in eids:
1213 1216 if eid not in self.engines:
1214 1217 try:
1215 1218 raise IndexError("No such engine: %i" % eid)
1216 1219 except:
1217 1220 reply = error.wrap_exception()
1221 self.log.exception("Error dropping records")
1218 1222 break
1219 1223 uid = self.engines[eid].uuid
1220 1224 try:
1221 1225 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1222 1226 except Exception:
1223 1227 reply = error.wrap_exception()
1228 self.log.exception("Error dropping records")
1224 1229 break
1225 1230
1226 1231 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1227 1232
1228 1233 def resubmit_task(self, client_id, msg):
1229 1234 """Resubmit one or more tasks."""
1230 1235 def finish(reply):
1231 1236 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1232 1237
1233 1238 content = msg['content']
1234 1239 msg_ids = content['msg_ids']
1235 1240 reply = dict(status='ok')
1236 1241 try:
1237 1242 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1238 1243 'header', 'content', 'buffers'])
1239 1244 except Exception:
1240 1245 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1241 1246 return finish(error.wrap_exception())
1242 1247
1243 1248 # validate msg_ids
1244 1249 found_ids = [ rec['msg_id'] for rec in records ]
1245 1250 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1246 1251 if len(records) > len(msg_ids):
1247 1252 try:
1248 1253 raise RuntimeError("DB appears to be in an inconsistent state. "
1249 1254 "More matching records were found than should exist")
1250 1255 except Exception:
1256 self.log.exception("Failed to resubmit task")
1251 1257 return finish(error.wrap_exception())
1252 1258 elif len(records) < len(msg_ids):
1253 1259 missing = [ m for m in msg_ids if m not in found_ids ]
1254 1260 try:
1255 1261 raise KeyError("No such msg(s): %r" % missing)
1256 1262 except KeyError:
1263 self.log.exception("Failed to resubmit task")
1257 1264 return finish(error.wrap_exception())
1258 1265 elif pending_ids:
1259 1266 pass
1260 1267 # no need to raise on resubmit of pending task, now that we
1261 1268 # resubmit under new ID, but do we want to raise anyway?
1262 1269 # msg_id = invalid_ids[0]
1263 1270 # try:
1264 1271 # raise ValueError("Task(s) %r appears to be inflight" % )
1265 1272 # except Exception:
1266 1273 # return finish(error.wrap_exception())
1267 1274
1268 1275 # mapping of original IDs to resubmitted IDs
1269 1276 resubmitted = {}
1270 1277
1271 1278 # send the messages
1272 1279 for rec in records:
1273 1280 header = rec['header']
1274 1281 msg = self.session.msg(header['msg_type'], parent=header)
1275 1282 msg_id = msg['msg_id']
1276 1283 msg['content'] = rec['content']
1277 1284
1278 1285 # use the old header, but update msg_id and timestamp
1279 1286 fresh = msg['header']
1280 1287 header['msg_id'] = fresh['msg_id']
1281 1288 header['date'] = fresh['date']
1282 1289 msg['header'] = header
1283 1290
1284 1291 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1285 1292
1286 1293 resubmitted[rec['msg_id']] = msg_id
1287 1294 self.pending.add(msg_id)
1288 1295 msg['buffers'] = rec['buffers']
1289 1296 try:
1290 1297 self.db.add_record(msg_id, init_record(msg))
1291 1298 except Exception:
1292 1299 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1293 1300 return finish(error.wrap_exception())
1294 1301
1295 1302 finish(dict(status='ok', resubmitted=resubmitted))
1296 1303
1297 1304 # store the new IDs in the Task DB
1298 1305 for msg_id, resubmit_id in iteritems(resubmitted):
1299 1306 try:
1300 1307 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1301 1308 except Exception:
1302 1309 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1303 1310
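Resubmission replays the stored request verbatim but under a fresh msg_id, and the hub records the old-to-new mapping in each record's 'resubmitted' field. From the client side this looks roughly like the following sketch (assuming Client.resubmit accepts a list of msg_ids, as the reply mapping above suggests):

    import os
    from IPython.parallel import Client

    rc = Client()
    view = rc.load_balanced_view()
    ar = view.apply_async(os.getpid)
    ar.get()
    # replay the same task; the new AsyncResult tracks the fresh msg_ids
    ar2 = rc.resubmit(ar.msg_ids)
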
1304 1311
1305 1312 def _extract_record(self, rec):
1306 1313 """decompose a TaskRecord dict into the subsections of a get_results reply"""
1307 1314 io_dict = {}
1308 1315 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1309 1316 io_dict[key] = rec[key]
1310 1317 content = {
1311 1318 'header': rec['header'],
1312 1319 'metadata': rec['metadata'],
1313 1320 'result_metadata': rec['result_metadata'],
1314 1321 'result_header' : rec['result_header'],
1315 1322 'result_content': rec['result_content'],
1316 1323 'received' : rec['received'],
1317 1324 'io' : io_dict,
1318 1325 }
1319 1326 if rec['result_buffers']:
1320 1327 buffers = list(map(bytes, rec['result_buffers']))
1321 1328 else:
1322 1329 buffers = []
1323 1330
1324 1331 return content, buffers
1325 1332
1326 1333 def get_results(self, client_id, msg):
1327 1334 """Get the result of 1 or more messages."""
1328 1335 content = msg['content']
1329 1336 msg_ids = sorted(set(content['msg_ids']))
1330 1337 statusonly = content.get('status_only', False)
1331 1338 pending = []
1332 1339 completed = []
1333 1340 content = dict(status='ok')
1334 1341 content['pending'] = pending
1335 1342 content['completed'] = completed
1336 1343 buffers = []
1337 1344 if not statusonly:
1338 1345 try:
1339 1346 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1340 1347 # turn match list into dict, for faster lookup
1341 1348 records = {}
1342 1349 for rec in matches:
1343 1350 records[rec['msg_id']] = rec
1344 1351 except Exception:
1345 1352 content = error.wrap_exception()
1353 self.log.exception("Failed to get results")
1346 1354 self.session.send(self.query, "result_reply", content=content,
1347 1355 parent=msg, ident=client_id)
1348 1356 return
1349 1357 else:
1350 1358 records = {}
1351 1359 for msg_id in msg_ids:
1352 1360 if msg_id in self.pending:
1353 1361 pending.append(msg_id)
1354 1362 elif msg_id in self.all_completed:
1355 1363 completed.append(msg_id)
1356 1364 if not statusonly:
1357 1365 c,bufs = self._extract_record(records[msg_id])
1358 1366 content[msg_id] = c
1359 1367 buffers.extend(bufs)
1360 1368 elif msg_id in records:
1361 1369 if records[msg_id]['completed']:
1362 1370 completed.append(msg_id)
1363 1371 c,bufs = self._extract_record(records[msg_id])
1364 1372 content[msg_id] = c
1365 1373 buffers.extend(bufs)
1366 1374 else:
1367 1375 pending.append(msg_id)
1368 1376 else:
1369 1377 try:
1370 1378 raise KeyError('No such message: '+msg_id)
1371 1379 except:
1372 1380 content = error.wrap_exception()
1373 1381 break
1374 1382 self.session.send(self.query, "result_reply", content=content,
1375 1383 parent=msg, ident=client_id,
1376 1384 buffers=buffers)
1377 1385
1378 1386 def get_history(self, client_id, msg):
1379 1387 """Get a list of all msg_ids in our DB records"""
1380 1388 try:
1381 1389 msg_ids = self.db.get_history()
1382 1390 except Exception:
1383 1391 content = error.wrap_exception()
1392 self.log.exception("Failed to get history")
1384 1393 else:
1385 1394 content = dict(status='ok', history=msg_ids)
1386 1395
1387 1396 self.session.send(self.query, "history_reply", content=content,
1388 1397 parent=msg, ident=client_id)
1389 1398
1390 1399 def db_query(self, client_id, msg):
1391 1400 """Perform a raw query on the task record database."""
1392 1401 content = msg['content']
1393 1402 query = extract_dates(content.get('query', {}))
1394 1403 keys = content.get('keys', None)
1395 1404 buffers = []
1396 1405 empty = list()
1397 1406 try:
1398 1407 records = self.db.find_records(query, keys)
1399 1408 except Exception:
1400 1409 content = error.wrap_exception()
1410 self.log.exception("DB query failed")
1401 1411 else:
1402 1412 # extract buffers from reply content:
1403 1413 if keys is not None:
1404 1414 buffer_lens = [] if 'buffers' in keys else None
1405 1415 result_buffer_lens = [] if 'result_buffers' in keys else None
1406 1416 else:
1407 1417 buffer_lens = None
1408 1418 result_buffer_lens = None
1409 1419
1410 1420 for rec in records:
1411 1421 # buffers may be None, so double check
1412 1422 b = rec.pop('buffers', empty) or empty
1413 1423 if buffer_lens is not None:
1414 1424 buffer_lens.append(len(b))
1415 1425 buffers.extend(b)
1416 1426 rb = rec.pop('result_buffers', empty) or empty
1417 1427 if result_buffer_lens is not None:
1418 1428 result_buffer_lens.append(len(rb))
1419 1429 buffers.extend(rb)
1420 1430 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1421 1431 result_buffer_lens=result_buffer_lens)
1422 1432 # self.log.debug (content)
1423 1433 self.session.send(self.query, "db_reply", content=content,
1424 1434 parent=msg, ident=client_id,
1425 1435 buffers=buffers)
1426 1436
@@ -1,130 +1,145
1 1 """toplevel setup/teardown for parallel tests."""
2 2 from __future__ import print_function
3 3
4 4 #-------------------------------------------------------------------------------
5 5 # Copyright (C) 2011 The IPython Development Team
6 6 #
7 7 # Distributed under the terms of the BSD License. The full license is in
8 8 # the file COPYING, distributed as part of this software.
9 9 #-------------------------------------------------------------------------------
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Imports
13 13 #-------------------------------------------------------------------------------
14 14
15 15 import os
16 16 import tempfile
17 17 import time
18 18 from subprocess import Popen, PIPE, STDOUT
19 19
20 20 import nose
21 21
22 22 from IPython.utils.path import get_ipython_dir
23 from IPython.parallel import Client
23 from IPython.parallel import Client, error
24 24 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
25 25 ipengine_cmd_argv,
26 26 ipcontroller_cmd_argv,
27 27 SIGKILL,
28 28 ProcessStateError,
29 29 )
30 30
31 31 # globals
32 32 launchers = []
33 33 blackhole = open(os.devnull, 'w')
34 34
35 35 # Launcher class
36 36 class TestProcessLauncher(LocalProcessLauncher):
37 37 """Subclass of LocalProcessLauncher that prevents extra sockets and threads from being created on Windows."""
38 38 def start(self):
39 39 if self.state == 'before':
40 40 # Store stdout & stderr to show with failing tests.
41 41 # This is defined in IPython.testing.iptest
42 42 self.process = Popen(self.args,
43 43 stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT,
44 44 env=os.environ,
45 45 cwd=self.work_dir
46 46 )
47 47 self.notify_start(self.process.pid)
48 48 self.poll = self.process.poll
49 49 else:
50 50 s = 'The process was already started and has state: %r' % self.state
51 51 raise ProcessStateError(s)
52 52
53 53 # nose setup/teardown
54 54
55 55 def setup():
56
57 # show tracebacks for RemoteErrors
58 class RemoteErrorWithTB(error.RemoteError):
59 def __str__(self):
60 s = super(RemoteErrorWithTB, self).__str__()
61 return '\n'.join([s, self.traceback or ''])
62
63 error.RemoteError = RemoteErrorWithTB
64
56 65 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
57 66 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
58 67 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
59 68 for json in (engine_json, client_json):
60 69 if os.path.exists(json):
61 70 os.remove(json)
62 71
63 72 cp = TestProcessLauncher()
64 73 cp.cmd_and_args = ipcontroller_cmd_argv + \
65 74 ['--profile=iptest', '--log-level=20', '--ping=250', '--dictdb']
66 75 cp.start()
67 76 launchers.append(cp)
68 77 tic = time.time()
69 78 while not os.path.exists(engine_json) or not os.path.exists(client_json):
70 79 if cp.poll() is not None:
71 80 raise RuntimeError("The test controller exited with status %s" % cp.poll())
72 81 elif time.time()-tic > 15:
73 82 raise RuntimeError("Timeout waiting for the test controller to start.")
74 83 time.sleep(0.1)
75 84 add_engines(1)
76 85
77 86 def add_engines(n=1, profile='iptest', total=False):
78 87 """add a number of engines to a given profile.
79 88
80 89 If total is True, then already running engines are counted, and only
81 90 the additional engines necessary (if any) are started.
82 91 """
83 92 rc = Client(profile=profile)
84 93 base = len(rc)
85 94
86 95 if total:
87 96 n = max(n - base, 0)
88 97
89 98 eps = []
90 99 for i in range(n):
91 100 ep = TestProcessLauncher()
92 101 ep.cmd_and_args = ipengine_cmd_argv + [
93 102 '--profile=%s' % profile,
94 103 '--log-level=50',
95 104 '--InteractiveShell.colors=nocolor'
96 105 ]
97 106 ep.start()
98 107 launchers.append(ep)
99 108 eps.append(ep)
100 109 tic = time.time()
101 110 while len(rc) < base+n:
102 111 if any([ ep.poll() is not None for ep in eps ]):
103 112 raise RuntimeError("A test engine failed to start.")
104 113 elif time.time()-tic > 15:
105 114 raise RuntimeError("Timeout waiting for engines to connect.")
106 115 time.sleep(.1)
107 116 rc.spin()
108 117 rc.close()
109 118 return eps
110 119
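For example, a test that needs at least three engines regardless of how many earlier tests left running would call (usage sketch):

    add_engines(3, total=True)  # starts only as many engines as needed to reach 3
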
111 120 def teardown():
121 try:
112 122 time.sleep(1)
123 except KeyboardInterrupt:
124 return
113 125 while launchers:
114 126 p = launchers.pop()
115 127 if p.poll() is None:
116 128 try:
117 129 p.stop()
118 130 except Exception as e:
119 131 print(e)
121 133 if p.poll() is None:
134 try:
122 135 time.sleep(.25)
136 except KeyboardInterrupt:
137 return
123 138 if p.poll() is None:
124 139 try:
125 140 print('cleaning up test process...')
126 141 p.signal(SIGKILL)
127 142 except:
128 143 print("couldn't shut down process: ", p)
129 144 blackhole.close()
130 145
@@ -1,314 +1,314
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 from __future__ import division
20 20
21 21 import logging
22 22 import os
23 23 import tempfile
24 24 import time
25 25
26 26 from datetime import datetime, timedelta
27 27 from unittest import TestCase
28 28
29 29 from IPython.parallel import error
30 30 from IPython.parallel.controller.dictdb import DictDB
31 31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 32 from IPython.parallel.controller.hub import init_record, empty_record
33 33
34 34 from IPython.testing import decorators as dec
35 35 from IPython.kernel.zmq.session import Session
36 36
37 37
38 38 #-------------------------------------------------------------------------------
39 39 # TestCases
40 40 #-------------------------------------------------------------------------------
41 41
42 42
43 43 def setup():
44 44 global temp_db
45 45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46 46
47 47
48 48 class TaskDBTest:
49 49 def setUp(self):
50 50 self.session = Session()
51 51 self.db = self.create_db()
52 52 self.load_records(16)
53 53
54 54 def create_db(self):
55 55 raise NotImplementedError
56 56
57 57 def load_records(self, n=1, buffer_size=100):
58 58 """load n records for testing"""
59 59 # sleep 1/10 s, to ensure the timestamp differs from previous calls
60 60 time.sleep(0.1)
61 61 msg_ids = []
62 62 for i in range(n):
63 63 msg = self.session.msg('apply_request', content=dict(a=5))
64 64 msg['buffers'] = [os.urandom(buffer_size)]
65 65 rec = init_record(msg)
66 66 msg_id = msg['header']['msg_id']
67 67 msg_ids.append(msg_id)
68 68 self.db.add_record(msg_id, rec)
69 69 return msg_ids
70 70
71 71 def test_add_record(self):
72 72 before = self.db.get_history()
73 73 self.load_records(5)
74 74 after = self.db.get_history()
75 75 self.assertEqual(len(after), len(before)+5)
76 76 self.assertEqual(after[:-5],before)
77 77
78 78 def test_drop_record(self):
79 79 msg_id = self.load_records()[-1]
80 80 rec = self.db.get_record(msg_id)
81 81 self.db.drop_record(msg_id)
82 82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83 83
84 84 def _round_to_millisecond(self, dt):
85 85 """necessary because mongodb only stores timestamps with millisecond precision"""
86 86 micro = dt.microsecond
87 87 extra = int(str(micro)[-3:])
88 88 return dt - timedelta(microseconds=extra)
89 89
90 90 def test_update_record(self):
91 91 now = self._round_to_millisecond(datetime.now())
92 92 #
93 93 msg_id = self.db.get_history()[-1]
94 94 rec1 = self.db.get_record(msg_id)
95 95 data = {'stdout': 'hello there', 'completed' : now}
96 96 self.db.update_record(msg_id, data)
97 97 rec2 = self.db.get_record(msg_id)
98 98 self.assertEqual(rec2['stdout'], 'hello there')
99 99 self.assertEqual(rec2['completed'], now)
100 100 rec1.update(data)
101 101 self.assertEqual(rec1, rec2)
102 102
103 103 # def test_update_record_bad(self):
104 104 # """test updating nonexistent records"""
105 105 # msg_id = str(uuid.uuid4())
106 106 # data = {'stdout': 'hello there'}
107 107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108 108
109 109 def test_find_records_dt(self):
110 110 """test finding records by date"""
111 111 hist = self.db.get_history()
112 112 middle = self.db.get_record(hist[len(hist)//2])
113 113 tic = middle['submitted']
114 114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 116 self.assertEqual(len(before)+len(after),len(hist))
117 117 for b in before:
118 118 self.assertTrue(b['submitted'] < tic)
119 119 for a in after:
120 120 self.assertTrue(a['submitted'] >= tic)
121 121 same = self.db.find_records({'submitted' : tic})
122 122 for s in same:
123 123 self.assertTrue(s['submitted'] == tic)
124 124
125 125 def test_find_records_keys(self):
126 126 """test extracting subset of record keys"""
127 127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 128 for rec in found:
129 129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130 130
131 131 def test_find_records_msg_id(self):
132 132 """ensure msg_id is always in found records"""
133 133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 134 for rec in found:
135 135 self.assertTrue('msg_id' in rec.keys())
136 136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 137 for rec in found:
138 138 self.assertTrue('msg_id' in rec.keys())
139 139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 140 for rec in found:
141 141 self.assertTrue('msg_id' in rec.keys())
142 142
143 143 def test_find_records_in(self):
144 144 """test finding records with '$in','$nin' operators"""
145 145 hist = self.db.get_history()
146 146 even = hist[::2]
147 147 odd = hist[1::2]
148 148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 149 found = [ r['msg_id'] for r in recs ]
150 150 self.assertEqual(set(even), set(found))
151 151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 152 found = [ r['msg_id'] for r in recs ]
153 153 self.assertEqual(set(odd), set(found))
154 154
155 155 def test_get_history(self):
156 156 msg_ids = self.db.get_history()
157 157 latest = datetime(1984,1,1)
158 158 for msg_id in msg_ids:
159 159 rec = self.db.get_record(msg_id)
160 160 newt = rec['submitted']
161 161 self.assertTrue(newt >= latest)
162 162 latest = newt
163 163 msg_id = self.load_records(1)[-1]
164 164 self.assertEqual(self.db.get_history()[-1],msg_id)
165 165
166 166 def test_datetime(self):
167 167 """get/set timestamps with datetime objects"""
168 168 msg_id = self.db.get_history()[-1]
169 169 rec = self.db.get_record(msg_id)
170 170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 172 rec = self.db.get_record(msg_id)
173 173 self.assertTrue(isinstance(rec['completed'], datetime))
174 174
175 175 def test_drop_matching(self):
176 176 msg_ids = self.load_records(10)
177 177 query = {'msg_id' : {'$in':msg_ids}}
178 178 self.db.drop_matching_records(query)
179 179 recs = self.db.find_records(query)
180 180 self.assertEqual(len(recs), 0)
181 181
182 182 def test_null(self):
183 183 """test None comparison queries"""
184 184 msg_ids = self.load_records(10)
185 185
186 186 query = {'msg_id' : None}
187 187 recs = self.db.find_records(query)
188 188 self.assertEqual(len(recs), 0)
189 189
190 190 query = {'msg_id' : {'$ne' : None}}
191 191 recs = self.db.find_records(query)
192 192 self.assertTrue(len(recs) >= 10)
193 193
194 194 def test_pop_safe_get(self):
195 195 """editing query results shouldn't affect record [get]"""
196 196 msg_id = self.db.get_history()[-1]
197 197 rec = self.db.get_record(msg_id)
198 198 rec.pop('buffers')
199 199 rec['garbage'] = 'hello'
200 200 rec['header']['msg_id'] = 'fubar'
201 201 rec2 = self.db.get_record(msg_id)
202 202 self.assertTrue('buffers' in rec2)
203 203 self.assertFalse('garbage' in rec2)
204 204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205 205
206 206 def test_pop_safe_find(self):
207 207 """editing query results shouldn't affect record [find]"""
208 208 msg_id = self.db.get_history()[-1]
209 209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 210 rec.pop('buffers')
211 211 rec['garbage'] = 'hello'
212 212 rec['header']['msg_id'] = 'fubar'
213 213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 214 self.assertTrue('buffers' in rec2)
215 215 self.assertFalse('garbage' in rec2)
216 216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217 217
218 218 def test_pop_safe_find_keys(self):
219 219 """editing query results shouldn't affect record [find+keys]"""
220 220 msg_id = self.db.get_history()[-1]
221 221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 222 rec.pop('buffers')
223 223 rec['garbage'] = 'hello'
224 224 rec['header']['msg_id'] = 'fubar'
225 225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 226 self.assertTrue('buffers' in rec2)
227 227 self.assertFalse('garbage' in rec2)
228 228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229 229
230 230
231 231 class TestDictBackend(TaskDBTest, TestCase):
232 232
233 233 def create_db(self):
234 234 return DictDB()
235 235
236 236 def test_cull_count(self):
237 237 self.db = self.create_db() # skip the load-records init from setUp
238 238 self.db.record_limit = 20
239 239 self.db.cull_fraction = 0.2
240 240 self.load_records(20)
241 241 self.assertEqual(len(self.db.get_history()), 20)
242 242 self.load_records(1)
243 243 # 0.2 * 20 = 4, 21 - 4 = 17
244 244 self.assertEqual(len(self.db.get_history()), 17)
245 245 self.load_records(3)
246 246 self.assertEqual(len(self.db.get_history()), 20)
247 247 self.load_records(1)
248 248 self.assertEqual(len(self.db.get_history()), 17)
249 249
250 for i in range(100):
250 for i in range(25):
251 251 self.load_records(1)
252 252 self.assertTrue(len(self.db.get_history()) >= 17)
253 253 self.assertTrue(len(self.db.get_history()) <= 20)
254 254
255 255 def test_cull_size(self):
256 256 self.db = self.create_db() # skip the load-records init from setUp
257 257 self.db.size_limit = 1000
258 258 self.db.cull_fraction = 0.2
259 259 self.load_records(100, buffer_size=10)
260 260 self.assertEqual(len(self.db.get_history()), 100)
261 261 self.load_records(1, buffer_size=0)
262 262 self.assertEqual(len(self.db.get_history()), 101)
263 263 self.load_records(1, buffer_size=1)
264 264 # 0.2 * 100 = 20, 101 - 20 = 81
265 265 self.assertEqual(len(self.db.get_history()), 81)
266 266
267 267 def test_cull_size_drop(self):
268 268 """dropping records updates tracked buffer size"""
269 269 self.db = self.create_db() # skip the load-records init from setUp
270 270 self.db.size_limit = 1000
271 271 self.db.cull_fraction = 0.2
272 272 self.load_records(100, buffer_size=10)
273 273 self.assertEqual(len(self.db.get_history()), 100)
274 274 self.db.drop_record(self.db.get_history()[-1])
275 275 self.assertEqual(len(self.db.get_history()), 99)
276 276 self.load_records(1, buffer_size=5)
277 277 self.assertEqual(len(self.db.get_history()), 100)
278 278 self.load_records(1, buffer_size=5)
279 279 self.assertEqual(len(self.db.get_history()), 101)
280 280 self.load_records(1, buffer_size=1)
281 281 self.assertEqual(len(self.db.get_history()), 81)
282 282
283 283 def test_cull_size_update(self):
284 284 """updating records updates tracked buffer size"""
285 285 self.db = self.create_db() # skip the load-records init from setUp
286 286 self.db.size_limit = 1000
287 287 self.db.cull_fraction = 0.2
288 288 self.load_records(100, buffer_size=10)
289 289 self.assertEqual(len(self.db.get_history()), 100)
290 290 msg_id = self.db.get_history()[-1]
291 291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
292 292 self.assertEqual(len(self.db.get_history()), 100)
293 293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
294 294 self.assertEqual(len(self.db.get_history()), 79)
295 295
296 296 class TestSQLiteBackend(TaskDBTest, TestCase):
297 297
298 298 @dec.skip_without('sqlite3')
299 299 def create_db(self):
300 300 location, fname = os.path.split(temp_db)
301 301 log = logging.getLogger('test')
302 302 log.setLevel(logging.CRITICAL)
303 303 return SQLiteDB(location=location, fname=fname, log=log)
304 304
305 305 def tearDown(self):
306 306 self.db._db.close()
307 307
308 308
309 309 def teardown():
310 310 """cleanup task db file after all tests have run"""
311 311 try:
312 312 os.remove(temp_db)
313 313 except:
314 314 pass
@@ -1,256 +1,259
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010-2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import math
15 15 import re
16 16 import types
17 17 from datetime import datetime
18 18
19 19 try:
20 20 # base64.encodestring is deprecated in Python 3.x
21 21 from base64 import encodebytes
22 22 except ImportError:
23 23 # Python 2.x
24 24 from base64 import encodestring as encodebytes
25 25
26 26 from IPython.utils import py3compat
27 27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
28 28 from IPython.utils.encoding import DEFAULT_ENCODING
29 29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Globals and constants
33 33 #-----------------------------------------------------------------------------
34 34
35 35 # timestamp formats
36 36 ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6})Z?([\+\-]\d{2}:?\d{2})?$")
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?Z?([\+\-]\d{2}:?\d{2})?$")
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Classes and functions
41 41 #-----------------------------------------------------------------------------
42 42
43 43 def rekey(dikt):
44 44 """Rekey a dict whose keys json has coerced to str, restoring the
45 45 int (or float) keys they should be."""
46 46 for k in list(dikt): # copy the keys; we mutate dikt while iterating
47 47 if isinstance(k, string_types):
48 48 ik=fk=None
49 49 try:
50 50 ik = int(k)
51 51 except ValueError:
52 52 try:
53 53 fk = float(k)
54 54 except ValueError:
55 55 continue
56 56 if ik is not None:
57 57 nk = ik
58 58 else:
59 59 nk = fk
60 60 if nk in dikt:
61 61 raise KeyError("already have key %r"%nk)
62 62 dikt[nk] = dikt.pop(k)
63 63 return dikt
64 64
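A quick sketch of rekey in action; JSON coerces all keys to strings, and rekey restores the numeric ones while leaving genuinely string keys alone:

    rekey({'0': 'a', '1.5': 'b', 'x': 'c'})
    # -> {0: 'a', 1.5: 'b', 'x': 'c'}
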
65 65 def parse_date(s):
66 66 """parse an ISO8601 date string
67 67
68 68 If it is None or not a valid ISO8601 timestamp,
69 69 it will be returned unmodified.
70 70 Otherwise, it will return a datetime object.
71 71 """
72 72 if s is None:
73 73 return s
74 74 m = ISO8601_PAT.match(s)
75 75 if m:
76 76 # FIXME: add actual timezone support
77 77 # this just drops the timezone info
78 notz = m.groups()[0]
78 notz, ms, tz = m.groups()
79 if not ms:
80 ms = '.0'
81 notz = notz + ms
79 82 return datetime.strptime(notz, ISO8601)
80 83 return s
81 84
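With the relaxed pattern above, the fractional-second part is now optional and a trailing timezone offset is tolerated (and, per the FIXME, still discarded). A few illustrative calls:

    parse_date('2013-07-03T16:34:52')           # -> datetime(2013, 7, 3, 16, 34, 52)
    parse_date('2013-07-03T16:34:52.249482Z')   # -> datetime with microseconds, tz dropped
    parse_date('not a timestamp')               # -> returned unchanged, as a str
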
82 85 def extract_dates(obj):
83 86 """extract ISO8601 dates from unpacked JSON"""
84 87 if isinstance(obj, dict):
85 88 new_obj = {} # don't clobber
86 89 for k,v in iteritems(obj):
87 90 new_obj[k] = extract_dates(v)
88 91 obj = new_obj
89 92 elif isinstance(obj, (list, tuple)):
90 93 obj = [ extract_dates(o) for o in obj ]
91 94 elif isinstance(obj, string_types):
92 95 obj = parse_date(obj)
93 96 return obj
94 97
95 98 def squash_dates(obj):
96 99 """squash datetime objects into ISO8601 strings"""
97 100 if isinstance(obj, dict):
98 101 obj = dict(obj) # don't clobber
99 102 for k,v in iteritems(obj):
100 103 obj[k] = squash_dates(v)
101 104 elif isinstance(obj, (list, tuple)):
102 105 obj = [ squash_dates(o) for o in obj ]
103 106 elif isinstance(obj, datetime):
104 107 obj = obj.isoformat()
105 108 return obj
106 109
107 110 def date_default(obj):
108 111 """default function for packing datetime objects in JSON."""
109 112 if isinstance(obj, datetime):
110 113 return obj.isoformat()
111 114 else:
112 115 raise TypeError("%r is not JSON serializable"%obj)
113 116
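Taken together, date_default and extract_dates give a JSON round-trip for timestamps; a sketch, runnable inside this module:

    import json
    from datetime import datetime

    data = {'submitted': datetime.now()}
    wire = json.dumps(data, default=date_default)  # datetime -> ISO8601 string
    back = extract_dates(json.loads(wire))         # ISO8601 string -> datetime
    assert isinstance(back['submitted'], datetime)
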
114 117
115 118 # constants for identifying png/jpeg data
116 119 PNG = b'\x89PNG\r\n\x1a\n'
117 120 # front of PNG base64-encoded
118 121 PNG64 = b'iVBORw0KG'
119 122 JPEG = b'\xff\xd8'
120 123 # front of JPEG base64-encoded
121 124 JPEG64 = b'/9'
122 125 # front of PDF base64-encoded
123 126 PDF64 = b'JVBER'
124 127
125 128 def encode_images(format_dict):
126 129 """b64-encodes images in a displaypub format dict
127 130
128 131 Perhaps this should be handled in json_clean itself?
129 132
130 133 Parameters
131 134 ----------
132 135
133 136 format_dict : dict
134 137 A dictionary of display data keyed by mime-type
135 138
136 139 Returns
137 140 -------
138 141
139 142 format_dict : dict
140 143 A copy of the same dictionary,
141 144 but binary image data ('image/png', 'image/jpeg' or 'application/pdf')
142 145 is base64-encoded.
143 146
144 147 """
145 148 encoded = format_dict.copy()
146 149
147 150 pngdata = format_dict.get('image/png')
148 151 if isinstance(pngdata, bytes):
149 152 # make sure we don't double-encode
150 153 if not pngdata.startswith(PNG64):
151 154 pngdata = encodebytes(pngdata)
152 155 encoded['image/png'] = pngdata.decode('ascii')
153 156
154 157 jpegdata = format_dict.get('image/jpeg')
155 158 if isinstance(jpegdata, bytes):
156 159 # make sure we don't double-encode
157 160 if not jpegdata.startswith(JPEG64):
158 161 jpegdata = encodebytes(jpegdata)
159 162 encoded['image/jpeg'] = jpegdata.decode('ascii')
160 163
161 164 pdfdata = format_dict.get('application/pdf')
162 165 if isinstance(pdfdata, bytes):
163 166 # make sure we don't double-encode
164 167 if not pdfdata.startswith(PDF64):
165 168 pdfdata = encodebytes(pdfdata)
166 169 encoded['application/pdf'] = pdfdata.decode('ascii')
167 170
168 171 return encoded
169 172
170 173
171 174 def json_clean(obj):
172 175 """Clean an object to ensure it's safe to encode in JSON.
173 176
174 177 Atomic, immutable objects are returned unmodified. Sets and tuples are
175 178 converted to lists, lists are copied and dicts are also copied.
176 179
177 180 Note: dicts whose keys could cause collisions upon encoding (such as a dict
178 181 with both the number 1 and the string '1' as keys) will cause a ValueError
179 182 to be raised.
180 183
181 184 Parameters
182 185 ----------
183 186 obj : any python object
184 187
185 188 Returns
186 189 -------
187 190 out : object
188 191
189 192 A version of the input which will not cause an encoding error when
190 193 encoded as JSON. Note that this function does not *encode* its inputs;
191 194 it simply sanitizes them so that there will be no encoding errors later.
192 195
193 196 Examples
194 197 --------
195 198 >>> json_clean(4)
196 199 4
197 200 >>> json_clean(list(range(10)))
198 201 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
199 202 >>> sorted(json_clean(dict(x=1, y=2)).items())
200 203 [('x', 1), ('y', 2)]
201 204 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
202 205 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
203 206 >>> json_clean(True)
204 207 True
205 208 """
206 209 # types that are 'atomic' and ok in json as-is.
207 210 atomic_ok = (unicode_type, type(None))
208 211
209 212 # containers that we need to convert into lists
210 213 container_to_list = (tuple, set, types.GeneratorType)
211 214
212 215 if isinstance(obj, float):
213 216 # cast out-of-range floats to their reprs
214 217 if math.isnan(obj) or math.isinf(obj):
215 218 return repr(obj)
216 219 return float(obj)
217 220
218 221 if isinstance(obj, int):
219 222 # cast int to int, in case subclasses override __str__ (e.g. boost enum, #4598)
220 223 if isinstance(obj, bool):
221 224 # bools are ints, but we don't want to cast them to 0,1
222 225 return obj
223 226 return int(obj)
224 227
225 228 if isinstance(obj, atomic_ok):
226 229 return obj
227 230
228 231 if isinstance(obj, bytes):
229 232 return obj.decode(DEFAULT_ENCODING, 'replace')
230 233
231 234 if isinstance(obj, container_to_list) or (
232 235 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
233 236 obj = list(obj)
234 237
235 238 if isinstance(obj, list):
236 239 return [json_clean(x) for x in obj]
237 240
238 241 if isinstance(obj, dict):
239 242 # First, validate that the dict won't lose data in conversion due to
240 243 # key collisions after stringification. This can happen with keys like
241 244 # True and 'true' or 1 and '1', which collide in JSON.
242 245 nkeys = len(obj)
243 246 nkeys_collapsed = len(set(map(str, obj)))
244 247 if nkeys != nkeys_collapsed:
245 248 raise ValueError('dict can not be safely converted to JSON: '
246 249 'key collision would lead to dropped values')
247 250 # If all OK, proceed by making the new dict that will be json-safe
248 251 out = {}
249 252 for k,v in iteritems(obj):
250 253 out[str(k)] = json_clean(v)
251 254 return out
252 255
253 256 # If we get here, we don't know how to handle the object, so we just get
254 257 # its repr and return that. This will catch lambdas, open sockets, class
255 258 # objects, and any other complicated contraption that json can't encode
256 259 return repr(obj)
@@ -1,149 +1,151
1 1 """Test suite for our JSON utilities.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010-2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import datetime
15 15 import json
16 16 from base64 import decodestring
17 17
18 18 # third party
19 19 import nose.tools as nt
20 20
21 21 # our own
22 22 from IPython.utils import jsonutil, tz
23 23 from ..jsonutil import json_clean, encode_images
24 24 from ..py3compat import unicode_to_str, str_to_bytes, iteritems
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # Test functions
28 28 #-----------------------------------------------------------------------------
29 29 class Int(int):
30 30 def __str__(self):
31 31 return 'Int(%i)' % self
32 32
33 33 def test():
34 34 # list of input/expected output. Use None for the expected output if it
35 35 # can be the same as the input.
36 36 pairs = [(1, None), # start with scalars
37 37 (1.0, None),
38 38 ('a', None),
39 39 (True, None),
40 40 (False, None),
41 41 (None, None),
42 42 # complex numbers for now just go to strings, as otherwise they
43 43 # are unserializable
44 44 (1j, '1j'),
45 45 # Containers
46 46 ([1, 2], None),
47 47 ((1, 2), [1, 2]),
48 48 (set([1, 2]), [1, 2]),
49 49 (dict(x=1), None),
50 50 ({'x': 1, 'y':[1,2,3], '1':'int'}, None),
51 51 # More exotic objects
52 52 ((x for x in range(3)), [0, 1, 2]),
53 53 (iter([1, 2]), [1, 2]),
54 54 (Int(5), 5),
55 55 ]
56 56
57 57 for val, jval in pairs:
58 58 if jval is None:
59 59 jval = val
60 60 out = json_clean(val)
61 61 # validate our cleanup
62 62 nt.assert_equal(out, jval)
63 63 # and ensure that what we return, indeed encodes cleanly
64 64 json.loads(json.dumps(out))
65 65
66 66
67 67
68 68 def test_encode_images():
69 69 # invalid data, but the header and footer are from real files
70 70 pngdata = b'\x89PNG\r\n\x1a\nblahblahnotactuallyvalidIEND\xaeB`\x82'
71 71 jpegdata = b'\xff\xd8\xff\xe0\x00\x10JFIFblahblahjpeg(\xa0\x0f\xff\xd9'
72 72 pdfdata = b'%PDF-1.\ntrailer<</Root<</Pages<</Kids[<</MediaBox[0 0 3 3]>>]>>>>>>'
73 73
74 74 fmt = {
75 75 'image/png' : pngdata,
76 76 'image/jpeg' : jpegdata,
77 77 'application/pdf' : pdfdata
78 78 }
79 79 encoded = encode_images(fmt)
80 80 for key, value in iteritems(fmt):
81 81 # encoded has unicode, want bytes
82 82 decoded = decodestring(encoded[key].encode('ascii'))
83 83 nt.assert_equal(decoded, value)
84 84 encoded2 = encode_images(encoded)
85 85 nt.assert_equal(encoded, encoded2)
86 86
87 87 b64_str = {}
88 88 for key, encoded in iteritems(encoded):
89 89 b64_str[key] = unicode_to_str(encoded)
90 90 encoded3 = encode_images(b64_str)
91 91 nt.assert_equal(encoded3, b64_str)
92 92 for key, value in iteritems(fmt):
93 93 # encoded3 has str, want bytes
94 94 decoded = decodestring(str_to_bytes(encoded3[key]))
95 95 nt.assert_equal(decoded, value)
96 96
97 97 def test_lambda():
98 98 jc = json_clean(lambda : 1)
99 assert isinstance(jc, str)
100 assert '<lambda>' in jc
99 nt.assert_is_instance(jc, str)
100 nt.assert_in('<lambda>', jc)
101 101 json.dumps(jc)
102 102
103 103 def test_extract_dates():
104 104 timestamps = [
105 105 '2013-07-03T16:34:52.249482',
106 106 '2013-07-03T16:34:52.249482Z',
107 107 '2013-07-03T16:34:52.249482Z-0800',
108 108 '2013-07-03T16:34:52.249482Z+0800',
109 109 '2013-07-03T16:34:52.249482Z+08:00',
110 110 '2013-07-03T16:34:52.249482Z-08:00',
111 111 '2013-07-03T16:34:52.249482-0800',
112 112 '2013-07-03T16:34:52.249482+0800',
113 113 '2013-07-03T16:34:52.249482+08:00',
114 114 '2013-07-03T16:34:52.249482-08:00',
115 115 ]
116 116 extracted = jsonutil.extract_dates(timestamps)
117 117 ref = extracted[0]
118 118 for dt in extracted:
119 119 nt.assert_true(isinstance(dt, datetime.datetime))
120 120 nt.assert_equal(dt, ref)
121 121
122 122 def test_parse_ms_precision():
123 base = '2013-07-03T16:34:52.'
123 base = '2013-07-03T16:34:52'
124 124 digits = '1234567890'
125 125
126 parsed = jsonutil.parse_date(base)
127 nt.assert_is_instance(parsed, datetime.datetime)
126 128 for i in range(len(digits)):
127 ts = base + digits[:i]
129 ts = base + '.' + digits[:i]
128 130 parsed = jsonutil.parse_date(ts)
129 131 if i >= 1 and i <= 6:
130 assert isinstance(parsed, datetime.datetime)
132 nt.assert_is_instance(parsed, datetime.datetime)
131 133 else:
132 assert isinstance(parsed, str)
134 nt.assert_is_instance(parsed, str)
133 135
134 136 def test_date_default():
135 137 data = dict(today=datetime.datetime.now(), utcnow=tz.utcnow())
136 138 jsondata = json.dumps(data, default=jsonutil.date_default)
137 139 nt.assert_in("+00", jsondata)
138 140 nt.assert_equal(jsondata.count("+00"), 1)
139 141 extracted = jsonutil.extract_dates(json.loads(jsondata))
140 142 for dt in extracted.values():
141 nt.assert_true(isinstance(dt, datetime.datetime))
143 nt.assert_is_instance(dt, datetime.datetime)
142 144
143 145 def test_exception():
144 146 bad_dicts = [{1:'number', '1':'string'},
145 147 {True:'bool', 'True':'string'},
146 148 ]
147 149 for d in bad_dicts:
148 150 nt.assert_raises(ValueError, json_clean, d)
149 151