##// END OF EJS Templates
Merge pull request #5098 from minrk/parallel-debug...
Thomas Kluyver -
r15313:9d27bffe merge
parent child Browse files
Show More
@@ -1,317 +1,317 b''
1 """A Task logger that presents our DB interface,
1 """A Task logger that presents our DB interface,
2 but exists entirely in memory and implemented with dicts.
2 but exists entirely in memory and implemented with dicts.
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7
7
8
8
9 TaskRecords are dicts of the form::
9 TaskRecords are dicts of the form::
10
10
11 {
11 {
12 'msg_id' : str(uuid),
12 'msg_id' : str(uuid),
13 'client_uuid' : str(uuid),
13 'client_uuid' : str(uuid),
14 'engine_uuid' : str(uuid) or None,
14 'engine_uuid' : str(uuid) or None,
15 'header' : dict(header),
15 'header' : dict(header),
16 'content': dict(content),
16 'content': dict(content),
17 'buffers': list(buffers),
17 'buffers': list(buffers),
18 'submitted': datetime,
18 'submitted': datetime or None,
19 'started': datetime or None,
19 'started': datetime or None,
20 'completed': datetime or None,
20 'completed': datetime or None,
21 'resubmitted': datetime or None,
21 'received': datetime or None,
22 'resubmitted': str(uuid) or None,
22 'result_header' : dict(header) or None,
23 'result_header' : dict(header) or None,
23 'result_content' : dict(content) or None,
24 'result_content' : dict(content) or None,
24 'result_buffers' : list(buffers) or None,
25 'result_buffers' : list(buffers) or None,
25 }
26 }
26
27
27 With this info, many of the special categories of tasks can be defined by query,
28 With this info, many of the special categories of tasks can be defined by query,
28 e.g.:
29 e.g.:
29
30
30 * pending: completed is None
31 * pending: completed is None
31 * client's outstanding: client_uuid = uuid && completed is None
32 * client's outstanding: client_uuid = uuid && completed is None
32 * MIA: arrived is None (and completed is None)
33 * MIA: arrived is None (and completed is None)
33
34
34 EngineRecords are dicts of the form::
35 DictDB supports a subset of mongodb operators::
35
36 {
37 'eid' : int(id),
38 'uuid': str(uuid)
39 }
40
41 This may be extended, but is currently.
42
43 We support a subset of mongodb operators::
44
36
45 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
37 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
46 """
38 """
47 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
48 # Copyright (C) 2010-2011 The IPython Development Team
40 # Copyright (C) 2010-2011 The IPython Development Team
49 #
41 #
50 # Distributed under the terms of the BSD License. The full license is in
42 # Distributed under the terms of the BSD License. The full license is in
51 # the file COPYING, distributed as part of this software.
43 # the file COPYING, distributed as part of this software.
52 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
53
45
54 from copy import deepcopy as copy
46 from copy import deepcopy as copy
55 from datetime import datetime
47 from datetime import datetime
56
48
57 from IPython.config.configurable import LoggingConfigurable
49 from IPython.config.configurable import LoggingConfigurable
58
50
59 from IPython.utils.py3compat import iteritems, itervalues
51 from IPython.utils.py3compat import iteritems, itervalues
60 from IPython.utils.traitlets import Dict, Unicode, Integer, Float
52 from IPython.utils.traitlets import Dict, Unicode, Integer, Float
61
53
# mongodb-style comparison operators, keyed by operator name.
# Each entry takes (record_value, query_value) and returns a bool.
filters = {
    '$lt' : lambda a, b: a < b,
    '$gt' : lambda a, b: a > b,  # BUG FIX: was `b > a`, which is `a < b` — made $gt behave like $lt
    '$eq' : lambda a, b: a == b,
    '$ne' : lambda a, b: a != b,
    '$lte': lambda a, b: a <= b,
    '$gte': lambda a, b: a >= b,
    '$in' : lambda a, b: a in b,
    '$nin': lambda a, b: a not in b,
    # true iff the value is contained in every element of the query list
    '$all': lambda a, b: all(a in bb for bb in b),
    # b is a (divisor, remainder) pair, as in mongodb's $mod
    '$mod': lambda a, b: a % b[0] == b[1],
    # existence is modeled as "is not None" in these records
    '$exists' : lambda a, b: (b and a is not None) or (a is None and not b),
}
75
67
76
68
class CompositeFilter(object):
    """Match a value against several operator tests at once.

    Constructed from a dict mapping mongodb-style operator names
    (keys of ``filters``) to the values they should be compared against.
    Calling the instance applies every test to a candidate value.
    """

    def __init__(self, dikt):
        # parallel lists: tests[i] is applied with values[i] as its
        # comparison argument
        self.tests = []
        self.values = []
        for op, comparand in iteritems(dikt):
            self.tests.append(filters[op])
            self.values.append(comparand)

    def __call__(self, value):
        """Return True only if *value* passes every registered test."""
        return all(
            test(value, comparand)
            for test, comparand in zip(self.tests, self.values)
        )
92
84
class BaseDB(LoggingConfigurable):
    """Empty parent class so traitlets work on DB backends."""
    # base configurable traits shared by all DB backends:
    session = Unicode("")
97
89
class DictDB(BaseDB):
    """Basic in-memory dict-based object for saving Task Records.

    This is the first object to present the DB interface
    for logging tasks out of memory.

    The interface is based on MongoDB, so adding a MongoDB
    backend should be straightforward.
    """

    _records = Dict()           # map: msg_id -> TaskRecord dict
    _culled_ids = set()         # set of ids which have been culled
    _buffer_bytes = Integer(0)  # running total of the bytes in the DB

    size_limit = Integer(1024**3, config=True,
        help="""The maximum total size (in bytes) of the buffers stored in the db

        When the db exceeds this size, the oldest records will be culled until
        the total size is under size_limit * (1-cull_fraction).
        default: 1 GB
        """
    )
    record_limit = Integer(1024, config=True,
        help="""The maximum number of records in the db

        When the history exceeds this size, the first record_limit * cull_fraction
        records will be culled.
        """
    )
    cull_fraction = Float(0.1, config=True,
        help="""The fraction by which the db should culled when one of the limits is exceeded

        In general, the db size will spend most of its time with a size in the range:

        [limit * (1-cull_fraction), limit]

        for each of size_limit and record_limit.
        """
    )

    def _match_one(self, rec, tests):
        """Check if a specific record matches tests."""
        for key, test in iteritems(tests):
            if not test(rec.get(key, None)):
                return False
        return True

    def _match(self, check):
        """Find all the matches for a check dict."""
        matches = []
        tests = {}
        for k, v in iteritems(check):
            if isinstance(v, dict):
                tests[k] = CompositeFilter(v)
            else:
                # BUG FIX: bind v through a default argument.  A plain
                # `lambda o: o == v` late-binds v, so with multiple scalar
                # criteria every key would be compared against whichever
                # value the loop happened to see last.
                tests[k] = lambda o, v=v: o == v

        for rec in itervalues(self._records):
            if self._match_one(rec, tests):
                matches.append(copy(rec))
        return matches

    def _extract_subdict(self, rec, keys):
        """extract subdict of keys (msg_id is always included)"""
        d = {}
        d['msg_id'] = rec['msg_id']
        for key in keys:
            d[key] = rec[key]
        return copy(d)

    # methods for monitoring size / culling history

    def _add_bytes(self, rec):
        """Add a record's buffer bytes to the running total, then cull if needed."""
        for key in ('buffers', 'result_buffers'):
            for buf in rec.get(key) or []:
                self._buffer_bytes += len(buf)

        self._maybe_cull()

    def _drop_bytes(self, rec):
        """Subtract a record's buffer bytes from the running total."""
        for key in ('buffers', 'result_buffers'):
            for buf in rec.get(key) or []:
                self._buffer_bytes -= len(buf)

    def _cull_oldest(self, n=1):
        """cull the oldest N records"""
        for msg_id in self.get_history()[:n]:
            self.log.debug("Culling record: %r", msg_id)
            self._culled_ids.add(msg_id)
            self.drop_record(msg_id)

    def _maybe_cull(self):
        """Cull old records if either the count or size limit is exceeded."""
        # cull by count:
        if len(self._records) > self.record_limit:
            to_cull = int(self.cull_fraction * self.record_limit)
            self.log.info("%i records exceeds limit of %i, culling oldest %i",
                len(self._records), self.record_limit, to_cull
            )
            self._cull_oldest(to_cull)

        # cull by size:
        if self._buffer_bytes > self.size_limit:
            limit = self.size_limit * (1 - self.cull_fraction)

            before = self._buffer_bytes
            before_count = len(self._records)
            culled = 0
            while self._buffer_bytes > limit:
                self._cull_oldest(1)
                culled += 1

            self.log.info("%i records with total buffer size %i exceeds limit: %i. Culled oldest %i records.",
                before_count, before, self.size_limit, culled
            )

    def _check_dates(self, rec):
        """Raise ValueError if any date field is neither None nor a datetime."""
        for key in ('submitted', 'started', 'completed'):
            value = rec.get(key, None)
            if value is not None and not isinstance(value, datetime):
                raise ValueError("%s must be None or datetime, not %r" % (key, value))

    # public API methods:

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id.

        Raises KeyError if msg_id already exists, ValueError on bad dates.
        """
        if msg_id in self._records:
            raise KeyError("Already have msg_id %r"%(msg_id))
        self._check_dates(rec)
        self._records[msg_id] = rec
        # _add_bytes already triggers a cull check, so no separate
        # _maybe_cull() call is needed here (the original called it twice).
        self._add_bytes(rec)

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id.

        Returns a deep copy so callers cannot mutate stored state.
        """
        if msg_id in self._culled_ids:
            raise KeyError("Record %r has been culled for size" % msg_id)
        if msg_id not in self._records:
            raise KeyError("No such msg_id %r"%(msg_id))
        return copy(self._records[msg_id])

    def update_record(self, msg_id, rec):
        """Update the data in an existing record."""
        if msg_id in self._culled_ids:
            raise KeyError("Record %r has been culled for size" % msg_id)
        self._check_dates(rec)
        _rec = self._records[msg_id]
        # re-account buffer bytes around the in-place update
        self._drop_bytes(_rec)
        _rec.update(rec)
        self._add_bytes(_rec)

    def drop_matching_records(self, check):
        """Remove all records matching a query dict from the DB."""
        matches = self._match(check)
        for rec in matches:
            self._drop_bytes(rec)
            del self._records[rec['msg_id']]

    def drop_record(self, msg_id):
        """Remove a record from the DB."""
        rec = self._records[msg_id]
        self._drop_bytes(rec)
        del self._records[msg_id]

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns dict keyed by msg_id of matching records.

        Parameters
        ----------

        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract.  msg_id will *always* be
            included.
        """
        matches = self._match(check)
        if keys:
            return [ self._extract_subdict(rec, keys) for rec in matches ]
        else:
            return matches

    def get_history(self):
        """get all msg_ids, ordered by time submitted."""
        msg_ids = self._records.keys()
        # Remove any that do not have a submitted timestamp.
        # This is extremely unlikely to happen,
        # but it seems to come up in some tests on VMs.
        msg_ids = [ m for m in msg_ids if self._records[m]['submitted'] is not None ]
        return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
281
281
282
282
# Shared exception raised by every NoDB accessor that would need stored data.
NODATA = KeyError("NoDB backend doesn't store any data. "
    "Start the Controller with a DB backend to enable resubmission / result persistence."
)


class NoDB(BaseDB):
    """A blackhole db backend that actually stores no information.

    Provides the full DB interface, but raises KeyErrors on any
    method that tries to access the records.  This can be used to
    minimize the memory footprint of the Hub when its record-keeping
    functionality is not required.
    """

    def add_record(self, msg_id, record):
        # deliberately discard the record
        pass

    def update_record(self, msg_id, record):
        # nothing stored, so nothing to update
        pass

    def drop_matching_records(self, check):
        # nothing stored, so nothing to drop
        pass

    def drop_record(self, msg_id):
        # nothing stored, so nothing to drop
        pass

    def get_record(self, msg_id):
        raise NODATA

    def find_records(self, check, keys=None):
        raise NODATA

    def get_history(self):
        raise NODATA
317
317
@@ -1,1426 +1,1436 b''
1 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
2 This is the master object that handles connections from engines and clients,
2 This is the master object that handles connections from engines and clients,
3 and monitors traffic through the various queues.
3 and monitors traffic through the various queues.
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 """
8 """
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 from __future__ import print_function
19 from __future__ import print_function
20
20
21 import json
21 import json
22 import os
22 import os
23 import sys
23 import sys
24 import time
24 import time
25 from datetime import datetime
25 from datetime import datetime
26
26
27 import zmq
27 import zmq
28 from zmq.eventloop import ioloop
28 from zmq.eventloop import ioloop
29 from zmq.eventloop.zmqstream import ZMQStream
29 from zmq.eventloop.zmqstream import ZMQStream
30
30
31 # internal:
31 # internal:
32 from IPython.utils.importstring import import_item
32 from IPython.utils.importstring import import_item
33 from IPython.utils.jsonutil import extract_dates
33 from IPython.utils.jsonutil import extract_dates
34 from IPython.utils.localinterfaces import localhost
34 from IPython.utils.localinterfaces import localhost
35 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
35 from IPython.utils.py3compat import cast_bytes, unicode_type, iteritems
36 from IPython.utils.traitlets import (
36 from IPython.utils.traitlets import (
37 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
38 )
38 )
39
39
40 from IPython.parallel import error, util
40 from IPython.parallel import error, util
41 from IPython.parallel.factory import RegistrationFactory
41 from IPython.parallel.factory import RegistrationFactory
42
42
43 from IPython.kernel.zmq.session import SessionFactory
43 from IPython.kernel.zmq.session import SessionFactory
44
44
45 from .heartmonitor import HeartMonitor
45 from .heartmonitor import HeartMonitor
46
46
47 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
48 # Code
48 # Code
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50
50
51 def _passer(*args, **kwargs):
51 def _passer(*args, **kwargs):
52 return
52 return
53
53
54 def _printer(*args, **kwargs):
54 def _printer(*args, **kwargs):
55 print (args)
55 print (args)
56 print (kwargs)
56 print (kwargs)
57
57
def empty_record():
    """Return an empty dict with all record keys."""
    # every field starts out unknown except the stream captures,
    # which accumulate text and therefore start as empty strings
    unset_keys = (
        'msg_id', 'header', 'metadata', 'content', 'buffers',
        'submitted', 'client_uuid', 'engine_uuid', 'started',
        'completed', 'resubmitted', 'received',
        'result_header', 'result_metadata', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    )
    rec = {key: None for key in unset_keys}
    rec['stdout'] = ''
    rec['stderr'] = ''
    return rec
84
84
def init_record(msg):
    """Initialize a TaskRecord based on a request message.

    Fields known at submission time are copied from *msg*; everything
    learned later (engine, results, streams) starts out empty.
    """
    header = msg['header']
    record = {
        'msg_id' : header['msg_id'],
        'header' : header,
        'content': msg['content'],
        'metadata': msg['metadata'],
        'buffers': msg['buffers'],
        'submitted': header['date'],
        'stdout': '',
        'stderr': '',
    }
    # fields that are only filled in as the task progresses
    for key in ('client_uuid', 'engine_uuid', 'started', 'completed',
                'resubmitted', 'received', 'result_header',
                'result_metadata', 'result_content', 'result_buffers',
                'queue', 'pyin', 'pyout', 'pyerr'):
        record[key] = None
    return record
112
112
113
113
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (unicode): engine UUID
    pending: set of msg_ids
    stallback: DelayedCallback for stalled registration
    """

    id = Integer(0)         # engine ID
    uuid = Unicode()        # engine UUID
    pending = Set()         # msg_ids outstanding on this engine
    # fires when a registration appears to have stalled
    stallback = Instance(ioloop.DelayedCallback)
127
127
128
128
129 _db_shortcuts = {
129 _db_shortcuts = {
130 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
130 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
131 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
131 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
132 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
132 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
133 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
133 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
134 }
134 }
135
135
136 class HubFactory(RegistrationFactory):
136 class HubFactory(RegistrationFactory):
137 """The Configurable for setting up a Hub."""
137 """The Configurable for setting up a Hub."""
138
138
139 # port-pairs for monitoredqueues:
139 # port-pairs for monitoredqueues:
140 hb = Tuple(Integer,Integer,config=True,
140 hb = Tuple(Integer,Integer,config=True,
141 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 help="""PUB/ROUTER Port pair for Engine heartbeats""")
142 def _hb_default(self):
142 def _hb_default(self):
143 return tuple(util.select_random_ports(2))
143 return tuple(util.select_random_ports(2))
144
144
145 mux = Tuple(Integer,Integer,config=True,
145 mux = Tuple(Integer,Integer,config=True,
146 help="""Client/Engine Port pair for MUX queue""")
146 help="""Client/Engine Port pair for MUX queue""")
147
147
148 def _mux_default(self):
148 def _mux_default(self):
149 return tuple(util.select_random_ports(2))
149 return tuple(util.select_random_ports(2))
150
150
151 task = Tuple(Integer,Integer,config=True,
151 task = Tuple(Integer,Integer,config=True,
152 help="""Client/Engine Port pair for Task queue""")
152 help="""Client/Engine Port pair for Task queue""")
153 def _task_default(self):
153 def _task_default(self):
154 return tuple(util.select_random_ports(2))
154 return tuple(util.select_random_ports(2))
155
155
156 control = Tuple(Integer,Integer,config=True,
156 control = Tuple(Integer,Integer,config=True,
157 help="""Client/Engine Port pair for Control queue""")
157 help="""Client/Engine Port pair for Control queue""")
158
158
159 def _control_default(self):
159 def _control_default(self):
160 return tuple(util.select_random_ports(2))
160 return tuple(util.select_random_ports(2))
161
161
162 iopub = Tuple(Integer,Integer,config=True,
162 iopub = Tuple(Integer,Integer,config=True,
163 help="""Client/Engine Port pair for IOPub relay""")
163 help="""Client/Engine Port pair for IOPub relay""")
164
164
165 def _iopub_default(self):
165 def _iopub_default(self):
166 return tuple(util.select_random_ports(2))
166 return tuple(util.select_random_ports(2))
167
167
168 # single ports:
168 # single ports:
169 mon_port = Integer(config=True,
169 mon_port = Integer(config=True,
170 help="""Monitor (SUB) port for queue traffic""")
170 help="""Monitor (SUB) port for queue traffic""")
171
171
172 def _mon_port_default(self):
172 def _mon_port_default(self):
173 return util.select_random_ports(1)[0]
173 return util.select_random_ports(1)[0]
174
174
175 notifier_port = Integer(config=True,
175 notifier_port = Integer(config=True,
176 help="""PUB port for sending engine status notifications""")
176 help="""PUB port for sending engine status notifications""")
177
177
178 def _notifier_port_default(self):
178 def _notifier_port_default(self):
179 return util.select_random_ports(1)[0]
179 return util.select_random_ports(1)[0]
180
180
181 engine_ip = Unicode(config=True,
181 engine_ip = Unicode(config=True,
182 help="IP on which to listen for engine connections. [default: loopback]")
182 help="IP on which to listen for engine connections. [default: loopback]")
183 def _engine_ip_default(self):
183 def _engine_ip_default(self):
184 return localhost()
184 return localhost()
# --- connection configuration traits (continued) ---

engine_transport = Unicode('tcp', config=True,
    help="0MQ transport for engine connections. [default: tcp]")

client_ip = Unicode(config=True,
    help="IP on which to listen for client connections. [default: loopback]")
client_transport = Unicode('tcp', config=True,
    help="0MQ transport for client connections. [default : tcp]")

monitor_ip = Unicode(config=True,
    help="IP on which to listen for monitor messages. [default: loopback]")
monitor_transport = Unicode('tcp', config=True,
    help="0MQ transport for monitor messages. [default : tcp]")

# clients and the monitor default to the same interface as engines
_client_ip_default = _monitor_ip_default = _engine_ip_default


monitor_url = Unicode('')

db_class = DottedObjectName('NoDB',
    config=True, help="""The class to use for the DB backend

    Options include:

    SQLiteDB: SQLite
    MongoDB : use MongoDB
    DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
    NoDB : disable database altogether (default)

    """)

# not configurable
db = Instance('IPython.parallel.controller.dictdb.BaseDB')
heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
218
218
def _ip_changed(self, name, old, new):
    """Fan a change of the generic `ip` trait out to all three channel IPs."""
    for attr in ('engine_ip', 'client_ip', 'monitor_ip'):
        setattr(self, attr, new)
    self._update_monitor_url()
224
224
def _update_monitor_url(self):
    """Recompute `monitor_url` from the current transport/ip/port traits."""
    self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
227
227
def _transport_changed(self, name, old, new):
    """Fan a change of the generic `transport` trait out to all three channels."""
    for attr in ('engine_transport', 'client_transport', 'monitor_transport'):
        setattr(self, attr, new)
    self._update_monitor_url()
233
233
def __init__(self, **kwargs):
    """Initialize the factory, then derive the monitor URL from the traits."""
    super(HubFactory, self).__init__(**kwargs)
    self._update_monitor_url()
237
237
238
238
def construct(self):
    """Build the Hub (delegates to init_hub)."""
    self.init_hub()
241
241
def start(self):
    """Start the heartbeat monitor and log that it is running."""
    self.heartmonitor.start()
    self.log.info("Heartmonitor started")
245
245
def client_url(self, channel):
    """return full zmq url for a named client channel"""
    port = self.client_info[channel]
    return "%s://%s:%i" % (self.client_transport, self.client_ip, port)
249
249
def engine_url(self, channel):
    """return full zmq url for a named engine channel"""
    port = self.engine_info[channel]
    return "%s://%s:%i" % (self.engine_transport, self.engine_ip, port)
253
253
def init_hub(self):
    """construct Hub object"""

    zctx = self.context
    loop = self.loop

    # determine the task scheduler scheme: explicit config wins,
    # otherwise fall back to the TaskScheduler class default
    if 'TaskScheduler.scheme_name' in self.config:
        scheme = self.config.TaskScheduler.scheme_name
    else:
        from .scheduler import TaskScheduler
        scheme = TaskScheduler.scheme_name.get_default_value()

    # connection dicts handed to engines and clients respectively
    engine = self.engine_info = {
        'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
        'registration' : self.regport,
        'control' : self.control[1],
        'mux' : self.mux[1],
        'hb_ping' : self.hb[0],
        'hb_pong' : self.hb[1],
        'task' : self.task[1],
        'iopub' : self.iopub[1],
    }

    client = self.client_info = {
        'interface' : "%s://%s" % (self.client_transport, self.client_ip),
        'registration' : self.regport,
        'control' : self.control[0],
        'mux' : self.mux[0],
        'task' : self.task[0],
        'task_scheme' : scheme,
        'iopub' : self.iopub[0],
        'notification' : self.notifier_port,
    }

    self.log.debug("Hub engine addrs: %s", self.engine_info)
    self.log.debug("Hub client addrs: %s", self.client_info)

    # registrar socket: one ROUTER shared by clients and engines
    registrar = ZMQStream(zctx.socket(zmq.ROUTER), loop)
    util.set_hwm(registrar, 0)
    registrar.bind(self.client_url('registration'))
    self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
    if self.client_ip != self.engine_ip:
        # engines connect on a different interface, so bind there as well
        registrar.bind(self.engine_url('registration'))
        self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))

    ### Engine connections ###

    # heartbeat: PUB pings out, ROUTER collects pongs
    hpub = zctx.socket(zmq.PUB)
    hpub.bind(self.engine_url('hb_ping'))
    hrep = zctx.socket(zmq.ROUTER)
    util.set_hwm(hrep, 0)
    hrep.bind(self.engine_url('hb_pong'))
    self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
        pingstream=ZMQStream(hpub, loop),
        pongstream=ZMQStream(hrep, loop),
    )

    ### Client connections ###

    # notifier socket: PUB broadcasting engine (un)registration to clients
    notifier = ZMQStream(zctx.socket(zmq.PUB), loop)
    notifier.bind(self.client_url('notification'))

    ### build and launch the queues ###

    # monitor socket: SUB on all topics
    sub = zctx.socket(zmq.SUB)
    sub.setsockopt(zmq.SUBSCRIBE, b"")
    sub.bind(self.monitor_url)
    sub.bind('inproc://monitor')
    sub = ZMQStream(sub, loop)

    # connect the db backend (shortcut names resolve to full dotted paths)
    db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
    self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
    self.db = import_item(str(db_class))(session=self.session.session,
                                         parent=self, log=self.log)
    time.sleep(.25)

    # resubmit stream: DEALER feeding back into the task scheduler
    resubmit = ZMQStream(zctx.socket(zmq.DEALER), loop)
    url = util.disambiguate_url(self.client_url('task'))
    resubmit.connect(url)

    self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
        query=registrar, notifier=notifier, resubmit=resubmit, db=self.db,
        engine_info=self.engine_info, client_info=self.client_info,
        log=self.log)
344
344
345
345
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (ROUTER)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
    NotImplemented
    engine_info: dict of zmq connection information for engines to connect
    to the queues.
    client_info: dict of zmq connection information for engines to connect
    to the queues.
    """

    engine_state_file = Unicode()

    # internal data structures:
    ids = Set()                      # registered engine IDs
    keytable = Dict()                # eid -> engine uuid
    by_ident = Dict()                # engine identity bytes -> eid
    engines = Dict()                 # eid -> EngineConnector
    clients = Dict()
    hearts = Dict()                  # heart identity -> eid
    pending = Set()                  # msg_ids submitted but not yet completed
    queues = Dict()                  # pending msg_ids keyed by engine_id
    tasks = Dict()                   # pending msg_ids submitted as tasks, keyed by client_id
    completed = Dict()               # completed msg_ids keyed by engine_id
    all_completed = Set()            # set of all completed msg_ids
    dead_engines = Set()             # uuids of engines that have died/unregistered
    unassigned = Set()               # task msg_ids not yet assigned a destination
    incoming_registrations = Dict()  # hearts awaiting registration completion
    registration_timeout = Integer()
    _idcounter = Integer(0)

    # objects from constructor:
    query = Instance(ZMQStream)
    monitor = Instance(ZMQStream)
    notifier = Instance(ZMQStream)
    resubmit = Instance(ZMQStream)
    heartmonitor = Instance(HeartMonitor)
    db = Instance(object)
    client_info = Dict()
    engine_info = Dict()
396
def __init__(self, **kwargs):
    """
    # universal:
    loop: IOLoop for creating future connections
    session: streamsession for sending serialized data
    # engine:
    queue: ZMQStream for monitoring queue messages
    query: ZMQStream for engine+client registration and client requests
    heartbeat: HeartMonitor object for tracking engines
    # extra:
    db: ZMQStream for db connection (NotImplemented)
    engine_info: zmq address/protocol dict for engine connections
    client_info: zmq address/protocol dict for client connections
    """

    super(Hub, self).__init__(**kwargs)

    # give slow engines ample time: at least 10s, or 5 heartbeat periods
    self.registration_timeout = max(10000, 5 * self.heartmonitor.period)

    # wire up stream callbacks
    self.query.on_recv(self.dispatch_query)
    self.monitor.on_recv(self.dispatch_monitor_traffic)

    self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
    self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

    # monitor topics -> handlers (topics arrive as bytes)
    self.monitor_handlers = {
        b'in': self.save_queue_request,
        b'out': self.save_queue_result,
        b'intask': self.save_task_request,
        b'outtask': self.save_task_result,
        b'tracktask': self.save_task_destination,
        b'incontrol': _passer,
        b'outcontrol': _passer,
        b'iopub': self.save_iopub_message,
    }

    # client msg_type -> handler
    self.query_handlers = {
        'queue_request': self.queue_status,
        'result_request': self.get_results,
        'history_request': self.get_history,
        'db_request': self.db_query,
        'purge_request': self.purge_results,
        'load_request': self.check_load,
        'resubmit_request': self.resubmit_task,
        'shutdown_request': self.shutdown_request,
        'registration_request': self.register_engine,
        'unregistration_request': self.unregister_engine,
        'connection_request': self.connection_request,
    }

    # ignore resubmit replies
    self.resubmit.on_recv(lambda msg: None, copy=False)

    self.log.info("hub::created hub")
449
449
@property
def _next_id(self):
    """Generate a new engine ID.

    IDs are never reused; simply count up from 0.
    """
    # fixed typo ("gemerate") and removed long-dead commented-out
    # id-reuse implementation
    newid = self._idcounter
    self._idcounter += 1
    return newid
463 # return newid
464
464
465 #-----------------------------------------------------------------------------
465 #-----------------------------------------------------------------------------
466 # message validation
466 # message validation
467 #-----------------------------------------------------------------------------
467 #-----------------------------------------------------------------------------
468
468
def _validate_targets(self, targets):
    """turn any valid targets argument into a list of integer ids"""
    if targets is None:
        # default to all registered engines
        return self.ids

    if isinstance(targets, (int, str, unicode_type)):
        # a single target was given; normalize to a list
        targets = [targets]

    # map raw string identities to integer engine ids
    resolved = []
    for target in targets:
        if isinstance(target, (str, unicode_type)):
            target = self.by_ident.get(cast_bytes(target), target)
        resolved.append(target)
    targets = resolved

    unknown = [t for t in targets if t not in self.ids]
    if unknown:
        raise IndexError("No Such Engine: %r" % unknown)
    if not targets:
        raise IndexError("No Engines Registered")
    return targets
490 return targets
491
491
492 #-----------------------------------------------------------------------------
492 #-----------------------------------------------------------------------------
493 # dispatch methods (1 per stream)
493 # dispatch methods (1 per stream)
494 #-----------------------------------------------------------------------------
494 #-----------------------------------------------------------------------------
495
495
496
496
@util.log_errors
def dispatch_monitor_traffic(self, msg):
    """all ME and Task queue messages come through here, as well as
    IOPub traffic."""
    self.log.debug("monitor traffic: %r", msg[0])
    switch = msg[0]
    try:
        idents, msg = self.session.feed_identities(msg[1:])
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Monitor message without topic: %r", msg)
        return
    handler = self.monitor_handlers.get(switch, None)
    if handler is None:
        self.log.error("Unrecognized monitor topic: %r", switch)
    else:
        handler(idents, msg)
514 self.log.error("Unrecognized monitor topic: %r", switch)
515
515
516
516
@util.log_errors
def dispatch_query(self, msg):
    """Route registration requests and queries from clients."""
    try:
        idents, msg = self.session.feed_identities(msg)
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Bad Query Message: %r", msg)
        return
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        # reply with the wrapped exception so the client sees the failure
        content = error.wrap_exception()
        self.log.error("Bad Query Message: %r", msg, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                content=content)
        return
    # switch on message type:
    msg_type = msg['header']['msg_type']
    self.log.info("client::client %r requested %r", client_id, msg_type)
    handler = self.query_handlers.get(msg_type, None)
    try:
        # raise (rather than assert, which is stripped under `python -O`)
        # so unknown message types always produce a hub_error reply
        if handler is None:
            raise KeyError("Bad Message Type: %r" % msg_type)
    except KeyError:
        content = error.wrap_exception()
        self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                content=content)
        return
    else:
        handler(idents, msg)
551 handler(idents, msg)
552
552
def dispatch_db(self, msg):
    """Handler for db traffic; not implemented."""
    raise NotImplementedError
555 raise NotImplementedError
556
556
557 #---------------------------------------------------------------------------
557 #---------------------------------------------------------------------------
558 # handler methods (1 per event)
558 # handler methods (1 per event)
559 #---------------------------------------------------------------------------
559 #---------------------------------------------------------------------------
560
560
561 #----------------------- Heartbeat --------------------------------------
561 #----------------------- Heartbeat --------------------------------------
562
562
def handle_new_heart(self, heart):
    """handler to attach to heartbeater.
    Called when a new heart starts to beat.
    Triggers completion of registration."""
    self.log.debug("heartbeat::handle_new_heart(%r)", heart)
    if heart in self.incoming_registrations:
        self.finish_registration(heart)
    else:
        # a heart we never saw a registration request for
        self.log.info("heartbeat::ignoring new heart: %r", heart)
571 self.finish_registration(heart)
572
572
573
573
def handle_heart_failure(self, heart):
    """handler to attach to heartbeater.
    called when a previously registered heart fails to respond to beat request.
    triggers unregistration"""
    self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
    eid = self.hearts.get(heart, None)
    # BUGFIX: look up the engine uuid only after the guard below —
    # the old code did `self.engines[eid].uuid` first, which raised
    # KeyError when the heart was unknown (eid is None)
    if eid is None or self.keytable[eid] in self.dead_engines:
        self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
    else:
        uuid = self.engines[eid].uuid
        self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
584 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
585
585
586 #----------------------- MUX Queue Traffic ------------------------------
586 #----------------------- MUX Queue Traffic ------------------------------
587
587
def save_queue_request(self, idents, msg):
    """Record the submission of a request to an engine via the MUX queue."""
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return
    queue_id, client_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::target %r not registered", queue_id)
        self.log.debug("queue:: valid are: %r", self.by_ident.keys())
        return

    record = init_record(msg)
    msg_id = record['msg_id']
    self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
    # store uuids as unicode in records
    record['engine_uuid'] = queue_id.decode('ascii')
    record['client_uuid'] = msg['header']['session']
    record['queue'] = 'mux'

    try:
        # it's possible iopub arrived first:
        existing = self.db.get_record(msg_id)
        # merge: keep already-stored values, warn on conflicts
        for key, evalue in iteritems(existing):
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
            elif evalue and not rvalue:
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
    except KeyError:
        # no prior record: create one
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)

    self.pending.add(msg_id)
    self.queues[eid].append(msg_id)
633 self.queues[eid].append(msg_id)
634
634
def save_queue_result(self, idents, msg):
    """Record the completion of a MUX-queue request."""
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r", idents)
        return

    client_id, queue_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::engine %r sent invalid message to %r: %r",
                queue_id, client_id, msg, exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
        return

    parent = msg['parent_header']
    if not parent:
        return
    msg_id = parent['msg_id']
    if msg_id in self.pending:
        # normal completion: move from pending to completed
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        self.queues[eid].remove(msg_id)
        self.completed[eid].append(msg_id)
        self.log.info("queue::request %r completed on %s", msg_id, eid)
    elif msg_id not in self.all_completed:
        # it could be a result from a dead engine that died before delivering the
        # result
        self.log.warn("queue:: unknown msg finished %r", msg_id)
        return
    # update record anyway, because the unregistration could have been premature
    rheader = msg['header']
    md = msg['metadata']
    completed = rheader['date']
    started = extract_dates(md.get('started', None))
    result = {
        'result_header': rheader,
        'result_metadata': md,
        'result_content': msg['content'],
        'received': datetime.now(),
        'started': started,
        'completed': completed,
    }
    result['result_buffers'] = msg['buffers']
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error updating record %r", msg_id, exc_info=True)
686 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
687
687
688
688
689 #--------------------- Task Queue Traffic ------------------------------
689 #--------------------- Task Queue Traffic ------------------------------
690
690
def save_task_request(self, idents, msg):
    """Save the submission of a task.

    Unserializes the message, records the task in the DB, and adds the
    msg_id to the ``pending`` and ``unassigned`` sets.  Merges carefully
    with any pre-existing record, because iopub traffic or a resubmission
    may have created one before the request arrived here.
    """
    client_id = idents[0]

    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::client %r sent invalid task message: %r",
                        client_id, msg, exc_info=True)
        return
    record = init_record(msg)

    record['client_uuid'] = msg['header']['session']
    record['queue'] = 'task'
    header = msg['header']
    msg_id = header['msg_id']
    self.pending.add(msg_id)
    self.unassigned.add(msg_id)
    try:
        # it's possible iopub arrived first:
        existing = self.db.get_record(msg_id)
        if existing['resubmitted']:
            for key in ('submitted', 'client_uuid', 'buffers'):
                # don't clobber these keys on resubmit
                # submitted and client_uuid should be different
                # and buffers might be big, and shouldn't have changed
                record.pop(key)
            # still check content,header which should not change
            # but are not expensive to compare as buffers

        for key,evalue in iteritems(existing):
            if key.endswith('buffers'):
                # don't compare buffers
                continue
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                # both records claim a value and they disagree: warn, keep the new one
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
            elif evalue and not rvalue:
                # fill gaps in the new record from the existing one
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
    except KeyError:
        # no existing record: this is the first we've heard of the task
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r", msg_id, exc_info=True)
    except Exception:
        self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
741
741
def save_task_result(self, idents, msg):
    """save the result of a completed task."""
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::invalid task result message send to %r: %r",
                client_id, msg, exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.warn("Task %r had no parent!", msg)
        return
    msg_id = parent['msg_id']
    # the task is no longer waiting for an engine, whatever happened
    self.unassigned.discard(msg_id)

    header = msg['header']
    metadata = msg['metadata']
    engine_uuid = metadata.get('engine', u'')
    eid = self.by_ident.get(cast_bytes(engine_uuid), None)
    status = metadata.get('status', None)

    if msg_id not in self.pending:
        # result for a task we never saw submitted (or already handled)
        self.log.debug("task::unknown task %r finished", msg_id)
        return

    self.log.info("task::task %r finished on %s", msg_id, eid)
    self.pending.remove(msg_id)
    self.all_completed.add(msg_id)
    if eid is not None and status != 'aborted':
        # credit the engine and clear the task from its queue
        self.completed[eid].append(msg_id)
        if msg_id in self.tasks[eid]:
            self.tasks[eid].remove(msg_id)

    record = {
        'result_header' : header,
        'result_metadata': msg['metadata'],
        'result_content': msg['content'],
        'started' : extract_dates(metadata.get('started', None)),
        'completed' : header['date'],
        'received' : datetime.now(),
        'engine_uuid': engine_uuid,
        'result_buffers' : msg['buffers'],
    }
    try:
        self.db.update_record(msg_id, record)
    except Exception:
        self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
797
797
def save_task_destination(self, idents, msg):
    """Record which engine a task was assigned to."""
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("task::invalid task tracking message", exc_info=True)
        return
    content = msg['content']
    msg_id = content['msg_id']
    engine_uuid = content['engine_id']
    eid = self.by_ident[cast_bytes(engine_uuid)]

    self.log.info("task::task %r arrived on %r", msg_id, eid)
    # no longer unassigned (discard is a no-op if it was already removed)
    self.unassigned.discard(msg_id)
    self.tasks[eid].append(msg_id)

    try:
        self.db.update_record(msg_id, {'engine_uuid': engine_uuid})
    except Exception:
        self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
822
822
823
823
def mia_task_request(self, idents, msg):
    """Report tasks that are missing-in-action (not implemented).

    Raises
    ------
    NotImplementedError
        Always; this handler is a placeholder.  The statements that
        previously followed the raise were unreachable, so the intended
        behavior is preserved below only as comments.
    """
    raise NotImplementedError
    # intended behavior once implemented:
    # client_id = idents[0]
    # content = dict(mia=self.mia,status='ok')
    # self.session.send('mia_reply', content=content, idents=client_id)
829
829
830
830
831 #--------------------- IOPub Traffic ------------------------------
831 #--------------------- IOPub Traffic ------------------------------
832
832
def save_iopub_message(self, topics, msg):
    """Save an iopub message into the db.

    Streams are appended to any existing output; pyin/pyerr/pyout and
    display_data replace their field; status and data_pub messages are
    ignored.
    """
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("iopub::invalid IOPub message", exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.warn("iopub::IOPub message lacks parent: %r", msg)
        return
    msg_id = parent['msg_id']
    msg_type = msg['header']['msg_type']
    content = msg['content']

    # ensure msg_id is in db (iopub may arrive before the request itself)
    try:
        rec = self.db.get_record(msg_id)
    except KeyError:
        rec = empty_record()
        rec['msg_id'] = msg_id
        self.db.add_record(msg_id, rec)

    d = {}
    if msg_type == 'stream':
        # append to any stream output already recorded
        name = content['name']
        s = rec[name] or ''
        d[name] = s + content['data']
    elif msg_type == 'pyerr':
        d['pyerr'] = content
    elif msg_type == 'pyin':
        d['pyin'] = content['code']
    elif msg_type in ('display_data', 'pyout'):
        d[msg_type] = content
    elif msg_type == 'status':
        pass
    elif msg_type == 'data_pub':
        # fix: use lazy %-args like the rest of the file, not eager '%'
        self.log.info("ignored data_pub message for %s", msg_id)
    else:
        self.log.warn("unhandled iopub msg_type: %r", msg_type)

    if not d:
        return

    try:
        self.db.update_record(msg_id, d)
    except Exception:
        self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
884
884
885
885
886
886
887 #-------------------------------------------------------------------------
887 #-------------------------------------------------------------------------
888 # Registration requests
888 # Registration requests
889 #-------------------------------------------------------------------------
889 #-------------------------------------------------------------------------
890
890
def connection_request(self, client_id, msg):
    """Reply with connection addresses for clients."""
    self.log.info("client::client %r connected", client_id)
    # map engine id (as str, for JSON) -> uuid, skipping dead engines
    live_engines = dict(
        (str(eid), uuid)
        for eid, uuid in iteritems(self.keytable)
        if uuid not in self.dead_engines
    )
    content = dict(status='ok')
    content['engines'] = live_engines
    self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
901
901
def register_engine(self, reg, msg):
    """Register a new engine.

    Allocates an engine id, rejects uuids/hearts already in use, replies
    with ``registration_reply``, and either finishes registration
    immediately (heart already beating) or schedules a purge of the
    stalled registration after ``registration_timeout``.

    Returns the allocated engine id.
    """
    content = msg['content']
    try:
        uuid = content['uuid']
    except KeyError:
        self.log.error("registration::queue not specified", exc_info=True)
        return

    eid = self._next_id

    self.log.debug("registration::register_engine(%i, %r)", eid, uuid)

    content = dict(id=eid, status='ok', hb_period=self.heartmonitor.period)
    # check if requesting available IDs:
    if cast_bytes(uuid) in self.by_ident:
        # fix: narrow bare `except:` to Exception (raised type is KeyError)
        try:
            raise KeyError("uuid %r in use" % uuid)
        except Exception:
            content = error.wrap_exception()
            self.log.error("uuid %r in use", uuid, exc_info=True)
    else:
        for h, ec in iteritems(self.incoming_registrations):
            if uuid == h:
                try:
                    raise KeyError("heart_id %r in use" % uuid)
                except Exception:
                    self.log.error("heart_id %r in use", uuid, exc_info=True)
                    content = error.wrap_exception()
                break
            elif uuid == ec.uuid:
                try:
                    raise KeyError("uuid %r in use" % uuid)
                except Exception:
                    self.log.error("uuid %r in use", uuid, exc_info=True)
                    content = error.wrap_exception()
                break

    msg = self.session.send(self.query, "registration_reply",
            content=content,
            ident=reg)

    heart = cast_bytes(uuid)

    if content['status'] == 'ok':
        if heart in self.heartmonitor.hearts:
            # already beating
            self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid)
            self.finish_registration(heart)
        else:
            # wait for the heart to beat; purge if it never does
            purge = lambda : self._purge_stalled_registration(heart)
            dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
            dc.start()
            self.incoming_registrations[heart] = EngineConnector(id=eid, uuid=uuid, stallback=dc)
    else:
        self.log.error("registration::registration %i failed: %r", eid, content['evalue'])

    return eid
960
960
def unregister_engine(self, ident, msg):
    """Unregister an engine that explicitly requested to leave.

    Marks the engine dead, schedules handling of its stranded messages
    after ``registration_timeout``, persists engine state, and notifies
    clients.
    """
    # fix: narrow bare `except:` (lookup raises KeyError/TypeError)
    try:
        eid = msg['content']['id']
    except Exception:
        self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
        return
    self.log.info("registration::unregister_engine(%r)", eid)

    uuid = self.keytable[eid]
    content = dict(id=eid, uuid=uuid)
    self.dead_engines.add(uuid)

    # give in-flight results a grace period before declaring them stranded
    handleit = lambda : self._handle_stranded_msgs(eid, uuid)
    dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
    dc.start()
    ############## TODO: HANDLE IT ################

    self._save_engine_state()

    if self.notifier:
        self.session.send(self.notifier, "unregistration_notification", content=content)
989
989
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    that the result failed and later receive the actual result.
    """
    outstanding = self.queues[eid]

    for msg_id in outstanding:
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        # fix: narrow bare `except:` to Exception; raise/catch is just to
        # give error.wrap_exception() a real traceback to serialize
        try:
            raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
        except Exception:
            content = error.wrap_exception()
        # build a fake header:
        header = {}
        header['engine'] = uuid
        header['date'] = datetime.now()
        rec = dict(result_content=content, result_header=header, result_buffers=[])
        rec['completed'] = header['date']
        rec['engine_uuid'] = uuid
        try:
            self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1018
1018
1019
1019
def finish_registration(self, heart):
    """Second half of engine registration, called after our HeartMonitor
    has received a beat from the Engine's Heart."""
    try:
        connector = self.incoming_registrations.pop(heart)
    except KeyError:
        self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
        return
    eid = connector.id
    self.log.info("registration::finished registering engine %i:%s", connector.id, connector.uuid)
    # cancel the pending stalled-registration purge, if one was scheduled
    if connector.stallback is not None:
        connector.stallback.stop()
    # wire the engine into all of the hub's bookkeeping tables
    self.ids.add(eid)
    self.keytable[eid] = connector.uuid
    self.engines[eid] = connector
    self.by_ident[cast_bytes(connector.uuid)] = eid
    self.queues[eid] = []
    self.tasks[eid] = []
    self.completed[eid] = []
    self.hearts[heart] = eid
    if self.notifier:
        payload = dict(id=eid, uuid=self.engines[eid].uuid)
        self.session.send(self.notifier, "registration_notification", content=payload)
    self.log.info("engine::Engine Connected: %i", eid)

    self._save_engine_state()
1046
1046
def _purge_stalled_registration(self, heart):
    """Drop a pending registration whose heart never started beating."""
    stalled = self.incoming_registrations.pop(heart, None)
    if stalled is not None:
        self.log.info("registration::purging stalled registration: %i", stalled.id)
1053
1053
1054 #-------------------------------------------------------------------------
1054 #-------------------------------------------------------------------------
1055 # Engine State
1055 # Engine State
1056 #-------------------------------------------------------------------------
1056 #-------------------------------------------------------------------------
1057
1057
1058
1058
def _cleanup_engine_state_file(self):
    """cleanup engine state mapping"""

    if os.path.exists(self.engine_state_file):
        self.log.debug("cleaning up engine state: %s", self.engine_state_file)
        try:
            os.remove(self.engine_state_file)
        except OSError:
            # fix: os.remove raises OSError; catching IOError missed it on
            # Python 2 (on Python 3 IOError is an alias of OSError)
            self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1068
1068
1069
1069
def _save_engine_state(self):
    """save engine mapping to JSON file"""
    if not self.engine_state_file:
        return
    # fix: lazy %-args instead of eager '%' formatting, as elsewhere in the file
    self.log.debug("save engine state to %s", self.engine_state_file)

    # only persist engines that are still alive
    engines = {}
    for eid, ec in iteritems(self.engines):
        if ec.uuid not in self.dead_engines:
            engines[eid] = ec.uuid

    state = {
        'engines': engines,
        'next_id': self._idcounter,
    }

    with open(self.engine_state_file, 'w') as f:
        json.dump(state, f)
1087
1087
1088
1088
def _load_engine_state(self):
    """load engine mapping from JSON file"""
    if not os.path.exists(self.engine_state_file):
        return

    # fix: lazy %-args instead of eager '%' formatting
    self.log.info("loading engine state from %s", self.engine_state_file)

    with open(self.engine_state_file) as f:
        state = json.load(f)

    # suppress registration notifications while replaying saved engines
    save_notifier = self.notifier
    self.notifier = None
    try:
        for eid, uuid in iteritems(state['engines']):
            heart = uuid.encode('ascii')
            # start with this heart as current and beating:
            self.heartmonitor.responses.add(heart)
            self.heartmonitor.hearts.add(heart)

            self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
            self.finish_registration(heart)
    finally:
        # fix: restore the notifier even if a registration raises
        self.notifier = save_notifier

    self._idcounter = state['next_id']
1113
1113
1114 #-------------------------------------------------------------------------
1114 #-------------------------------------------------------------------------
1115 # Client Requests
1115 # Client Requests
1116 #-------------------------------------------------------------------------
1116 #-------------------------------------------------------------------------
1117
1117
def shutdown_request(self, client_id, msg):
    """handle shutdown request."""
    self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
    # also notify other clients of shutdown
    self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
    # let the replies flush before actually exiting
    later = ioloop.DelayedCallback(self._shutdown, 1000, self.loop)
    later.start()
1125
1125
def _shutdown(self):
    """Terminate the hub process.

    Sleeps briefly so pending sends can flush before the process exits.
    """
    self.log.info("hub::hub shutting down.")
    time.sleep(0.1)
    sys.exit(0)
1130
1130
1131
1131
def check_load(self, client_id, msg):
    """Reply with the current load (queued + pending tasks) per target engine."""
    content = msg['content']
    # fix: narrow bare `except:` to Exception (validation errors)
    try:
        targets = content['targets']
        targets = self._validate_targets(targets)
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return

    content = dict(status='ok')
    for t in targets:
        # NOTE(review): bytes(t) for int t is a py2-ism (on Python 3 it
        # yields t zero bytes) — confirm against clients before changing
        content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
    self.session.send(self.query, "load_reply", content=content, ident=client_id)
1148
1148
1149
1149
def queue_status(self, client_id, msg):
    """Return the Queue status of one or more targets.

    If verbose, return the msg_ids, else return len of each type.

    Keys:

    * queue (pending MUX jobs)
    * tasks (pending Task jobs)
    * completed (finished jobs from both queues)
    """
    content = msg['content']
    targets = content['targets']
    # fix: narrow bare `except:` to Exception (validation errors)
    try:
        targets = self._validate_targets(targets)
    except Exception:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return
    verbose = content.get('verbose', False)
    content = dict(status='ok')
    for t in targets:
        queue = self.queues[t]
        completed = self.completed[t]
        tasks = self.tasks[t]
        if not verbose:
            # summarize as counts instead of full msg_id lists
            queue = len(queue)
            completed = len(completed)
            tasks = len(tasks)
        content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
    content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
    self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1184
1184
1185 def purge_results(self, client_id, msg):
1185 def purge_results(self, client_id, msg):
1186 """Purge results from memory. This method is more valuable before we move
1186 """Purge results from memory. This method is more valuable before we move
1187 to a DB based message storage mechanism."""
1187 to a DB based message storage mechanism."""
1188 content = msg['content']
1188 content = msg['content']
1189 self.log.info("Dropping records with %s", content)
1189 self.log.info("Dropping records with %s", content)
1190 msg_ids = content.get('msg_ids', [])
1190 msg_ids = content.get('msg_ids', [])
1191 reply = dict(status='ok')
1191 reply = dict(status='ok')
1192 if msg_ids == 'all':
1192 if msg_ids == 'all':
1193 try:
1193 try:
1194 self.db.drop_matching_records(dict(completed={'$ne':None}))
1194 self.db.drop_matching_records(dict(completed={'$ne':None}))
1195 except Exception:
1195 except Exception:
1196 reply = error.wrap_exception()
1196 reply = error.wrap_exception()
1197 self.log.exception("Error dropping records")
1197 else:
1198 else:
1198 pending = [m for m in msg_ids if (m in self.pending)]
1199 pending = [m for m in msg_ids if (m in self.pending)]
1199 if pending:
1200 if pending:
1200 try:
1201 try:
1201 raise IndexError("msg pending: %r" % pending[0])
1202 raise IndexError("msg pending: %r" % pending[0])
1202 except:
1203 except:
1203 reply = error.wrap_exception()
1204 reply = error.wrap_exception()
1205 self.log.exception("Error dropping records")
1204 else:
1206 else:
1205 try:
1207 try:
1206 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1208 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1207 except Exception:
1209 except Exception:
1208 reply = error.wrap_exception()
1210 reply = error.wrap_exception()
1211 self.log.exception("Error dropping records")
1209
1212
1210 if reply['status'] == 'ok':
1213 if reply['status'] == 'ok':
1211 eids = content.get('engine_ids', [])
1214 eids = content.get('engine_ids', [])
1212 for eid in eids:
1215 for eid in eids:
1213 if eid not in self.engines:
1216 if eid not in self.engines:
1214 try:
1217 try:
1215 raise IndexError("No such engine: %i" % eid)
1218 raise IndexError("No such engine: %i" % eid)
1216 except:
1219 except:
1217 reply = error.wrap_exception()
1220 reply = error.wrap_exception()
1221 self.log.exception("Error dropping records")
1218 break
1222 break
1219 uid = self.engines[eid].uuid
1223 uid = self.engines[eid].uuid
1220 try:
1224 try:
1221 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1225 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1222 except Exception:
1226 except Exception:
1223 reply = error.wrap_exception()
1227 reply = error.wrap_exception()
1228 self.log.exception("Error dropping records")
1224 break
1229 break
1225
1230
1226 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1231 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1227
1232
1228 def resubmit_task(self, client_id, msg):
1233 def resubmit_task(self, client_id, msg):
1229 """Resubmit one or more tasks."""
1234 """Resubmit one or more tasks."""
1230 def finish(reply):
1235 def finish(reply):
1231 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1236 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1232
1237
1233 content = msg['content']
1238 content = msg['content']
1234 msg_ids = content['msg_ids']
1239 msg_ids = content['msg_ids']
1235 reply = dict(status='ok')
1240 reply = dict(status='ok')
1236 try:
1241 try:
1237 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1242 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1238 'header', 'content', 'buffers'])
1243 'header', 'content', 'buffers'])
1239 except Exception:
1244 except Exception:
1240 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1245 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1241 return finish(error.wrap_exception())
1246 return finish(error.wrap_exception())
1242
1247
1243 # validate msg_ids
1248 # validate msg_ids
1244 found_ids = [ rec['msg_id'] for rec in records ]
1249 found_ids = [ rec['msg_id'] for rec in records ]
1245 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1250 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1246 if len(records) > len(msg_ids):
1251 if len(records) > len(msg_ids):
1247 try:
1252 try:
1248 raise RuntimeError("DB appears to be in an inconsistent state."
1253 raise RuntimeError("DB appears to be in an inconsistent state."
1249 "More matching records were found than should exist")
1254 "More matching records were found than should exist")
1250 except Exception:
1255 except Exception:
1256 self.log.exception("Failed to resubmit task")
1251 return finish(error.wrap_exception())
1257 return finish(error.wrap_exception())
1252 elif len(records) < len(msg_ids):
1258 elif len(records) < len(msg_ids):
1253 missing = [ m for m in msg_ids if m not in found_ids ]
1259 missing = [ m for m in msg_ids if m not in found_ids ]
1254 try:
1260 try:
1255 raise KeyError("No such msg(s): %r" % missing)
1261 raise KeyError("No such msg(s): %r" % missing)
1256 except KeyError:
1262 except KeyError:
1263 self.log.exception("Failed to resubmit task")
1257 return finish(error.wrap_exception())
1264 return finish(error.wrap_exception())
1258 elif pending_ids:
1265 elif pending_ids:
1259 pass
1266 pass
1260 # no need to raise on resubmit of pending task, now that we
1267 # no need to raise on resubmit of pending task, now that we
1261 # resubmit under new ID, but do we want to raise anyway?
1268 # resubmit under new ID, but do we want to raise anyway?
1262 # msg_id = invalid_ids[0]
1269 # msg_id = invalid_ids[0]
1263 # try:
1270 # try:
1264 # raise ValueError("Task(s) %r appears to be inflight" % )
1271 # raise ValueError("Task(s) %r appears to be inflight" % )
1265 # except Exception:
1272 # except Exception:
1266 # return finish(error.wrap_exception())
1273 # return finish(error.wrap_exception())
1267
1274
1268 # mapping of original IDs to resubmitted IDs
1275 # mapping of original IDs to resubmitted IDs
1269 resubmitted = {}
1276 resubmitted = {}
1270
1277
1271 # send the messages
1278 # send the messages
1272 for rec in records:
1279 for rec in records:
1273 header = rec['header']
1280 header = rec['header']
1274 msg = self.session.msg(header['msg_type'], parent=header)
1281 msg = self.session.msg(header['msg_type'], parent=header)
1275 msg_id = msg['msg_id']
1282 msg_id = msg['msg_id']
1276 msg['content'] = rec['content']
1283 msg['content'] = rec['content']
1277
1284
1278 # use the old header, but update msg_id and timestamp
1285 # use the old header, but update msg_id and timestamp
1279 fresh = msg['header']
1286 fresh = msg['header']
1280 header['msg_id'] = fresh['msg_id']
1287 header['msg_id'] = fresh['msg_id']
1281 header['date'] = fresh['date']
1288 header['date'] = fresh['date']
1282 msg['header'] = header
1289 msg['header'] = header
1283
1290
1284 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1291 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1285
1292
1286 resubmitted[rec['msg_id']] = msg_id
1293 resubmitted[rec['msg_id']] = msg_id
1287 self.pending.add(msg_id)
1294 self.pending.add(msg_id)
1288 msg['buffers'] = rec['buffers']
1295 msg['buffers'] = rec['buffers']
1289 try:
1296 try:
1290 self.db.add_record(msg_id, init_record(msg))
1297 self.db.add_record(msg_id, init_record(msg))
1291 except Exception:
1298 except Exception:
1292 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1299 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1293 return finish(error.wrap_exception())
1300 return finish(error.wrap_exception())
1294
1301
1295 finish(dict(status='ok', resubmitted=resubmitted))
1302 finish(dict(status='ok', resubmitted=resubmitted))
1296
1303
1297 # store the new IDs in the Task DB
1304 # store the new IDs in the Task DB
1298 for msg_id, resubmit_id in iteritems(resubmitted):
1305 for msg_id, resubmit_id in iteritems(resubmitted):
1299 try:
1306 try:
1300 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1307 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1301 except Exception:
1308 except Exception:
1302 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1309 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1303
1310
1304
1311
1305 def _extract_record(self, rec):
1312 def _extract_record(self, rec):
1306 """decompose a TaskRecord dict into subsection of reply for get_result"""
1313 """decompose a TaskRecord dict into subsection of reply for get_result"""
1307 io_dict = {}
1314 io_dict = {}
1308 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1315 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1309 io_dict[key] = rec[key]
1316 io_dict[key] = rec[key]
1310 content = {
1317 content = {
1311 'header': rec['header'],
1318 'header': rec['header'],
1312 'metadata': rec['metadata'],
1319 'metadata': rec['metadata'],
1313 'result_metadata': rec['result_metadata'],
1320 'result_metadata': rec['result_metadata'],
1314 'result_header' : rec['result_header'],
1321 'result_header' : rec['result_header'],
1315 'result_content': rec['result_content'],
1322 'result_content': rec['result_content'],
1316 'received' : rec['received'],
1323 'received' : rec['received'],
1317 'io' : io_dict,
1324 'io' : io_dict,
1318 }
1325 }
1319 if rec['result_buffers']:
1326 if rec['result_buffers']:
1320 buffers = list(map(bytes, rec['result_buffers']))
1327 buffers = list(map(bytes, rec['result_buffers']))
1321 else:
1328 else:
1322 buffers = []
1329 buffers = []
1323
1330
1324 return content, buffers
1331 return content, buffers
1325
1332
1326 def get_results(self, client_id, msg):
1333 def get_results(self, client_id, msg):
1327 """Get the result of 1 or more messages."""
1334 """Get the result of 1 or more messages."""
1328 content = msg['content']
1335 content = msg['content']
1329 msg_ids = sorted(set(content['msg_ids']))
1336 msg_ids = sorted(set(content['msg_ids']))
1330 statusonly = content.get('status_only', False)
1337 statusonly = content.get('status_only', False)
1331 pending = []
1338 pending = []
1332 completed = []
1339 completed = []
1333 content = dict(status='ok')
1340 content = dict(status='ok')
1334 content['pending'] = pending
1341 content['pending'] = pending
1335 content['completed'] = completed
1342 content['completed'] = completed
1336 buffers = []
1343 buffers = []
1337 if not statusonly:
1344 if not statusonly:
1338 try:
1345 try:
1339 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1346 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1340 # turn match list into dict, for faster lookup
1347 # turn match list into dict, for faster lookup
1341 records = {}
1348 records = {}
1342 for rec in matches:
1349 for rec in matches:
1343 records[rec['msg_id']] = rec
1350 records[rec['msg_id']] = rec
1344 except Exception:
1351 except Exception:
1345 content = error.wrap_exception()
1352 content = error.wrap_exception()
1353 self.log.exception("Failed to get results")
1346 self.session.send(self.query, "result_reply", content=content,
1354 self.session.send(self.query, "result_reply", content=content,
1347 parent=msg, ident=client_id)
1355 parent=msg, ident=client_id)
1348 return
1356 return
1349 else:
1357 else:
1350 records = {}
1358 records = {}
1351 for msg_id in msg_ids:
1359 for msg_id in msg_ids:
1352 if msg_id in self.pending:
1360 if msg_id in self.pending:
1353 pending.append(msg_id)
1361 pending.append(msg_id)
1354 elif msg_id in self.all_completed:
1362 elif msg_id in self.all_completed:
1355 completed.append(msg_id)
1363 completed.append(msg_id)
1356 if not statusonly:
1364 if not statusonly:
1357 c,bufs = self._extract_record(records[msg_id])
1365 c,bufs = self._extract_record(records[msg_id])
1358 content[msg_id] = c
1366 content[msg_id] = c
1359 buffers.extend(bufs)
1367 buffers.extend(bufs)
1360 elif msg_id in records:
1368 elif msg_id in records:
1361 if rec['completed']:
1369 if rec['completed']:
1362 completed.append(msg_id)
1370 completed.append(msg_id)
1363 c,bufs = self._extract_record(records[msg_id])
1371 c,bufs = self._extract_record(records[msg_id])
1364 content[msg_id] = c
1372 content[msg_id] = c
1365 buffers.extend(bufs)
1373 buffers.extend(bufs)
1366 else:
1374 else:
1367 pending.append(msg_id)
1375 pending.append(msg_id)
1368 else:
1376 else:
1369 try:
1377 try:
1370 raise KeyError('No such message: '+msg_id)
1378 raise KeyError('No such message: '+msg_id)
1371 except:
1379 except:
1372 content = error.wrap_exception()
1380 content = error.wrap_exception()
1373 break
1381 break
1374 self.session.send(self.query, "result_reply", content=content,
1382 self.session.send(self.query, "result_reply", content=content,
1375 parent=msg, ident=client_id,
1383 parent=msg, ident=client_id,
1376 buffers=buffers)
1384 buffers=buffers)
1377
1385
1378 def get_history(self, client_id, msg):
1386 def get_history(self, client_id, msg):
1379 """Get a list of all msg_ids in our DB records"""
1387 """Get a list of all msg_ids in our DB records"""
1380 try:
1388 try:
1381 msg_ids = self.db.get_history()
1389 msg_ids = self.db.get_history()
1382 except Exception as e:
1390 except Exception as e:
1383 content = error.wrap_exception()
1391 content = error.wrap_exception()
1392 self.log.exception("Failed to get history")
1384 else:
1393 else:
1385 content = dict(status='ok', history=msg_ids)
1394 content = dict(status='ok', history=msg_ids)
1386
1395
1387 self.session.send(self.query, "history_reply", content=content,
1396 self.session.send(self.query, "history_reply", content=content,
1388 parent=msg, ident=client_id)
1397 parent=msg, ident=client_id)
1389
1398
1390 def db_query(self, client_id, msg):
1399 def db_query(self, client_id, msg):
1391 """Perform a raw query on the task record database."""
1400 """Perform a raw query on the task record database."""
1392 content = msg['content']
1401 content = msg['content']
1393 query = extract_dates(content.get('query', {}))
1402 query = extract_dates(content.get('query', {}))
1394 keys = content.get('keys', None)
1403 keys = content.get('keys', None)
1395 buffers = []
1404 buffers = []
1396 empty = list()
1405 empty = list()
1397 try:
1406 try:
1398 records = self.db.find_records(query, keys)
1407 records = self.db.find_records(query, keys)
1399 except Exception as e:
1408 except Exception as e:
1400 content = error.wrap_exception()
1409 content = error.wrap_exception()
1410 self.log.exception("DB query failed")
1401 else:
1411 else:
1402 # extract buffers from reply content:
1412 # extract buffers from reply content:
1403 if keys is not None:
1413 if keys is not None:
1404 buffer_lens = [] if 'buffers' in keys else None
1414 buffer_lens = [] if 'buffers' in keys else None
1405 result_buffer_lens = [] if 'result_buffers' in keys else None
1415 result_buffer_lens = [] if 'result_buffers' in keys else None
1406 else:
1416 else:
1407 buffer_lens = None
1417 buffer_lens = None
1408 result_buffer_lens = None
1418 result_buffer_lens = None
1409
1419
1410 for rec in records:
1420 for rec in records:
1411 # buffers may be None, so double check
1421 # buffers may be None, so double check
1412 b = rec.pop('buffers', empty) or empty
1422 b = rec.pop('buffers', empty) or empty
1413 if buffer_lens is not None:
1423 if buffer_lens is not None:
1414 buffer_lens.append(len(b))
1424 buffer_lens.append(len(b))
1415 buffers.extend(b)
1425 buffers.extend(b)
1416 rb = rec.pop('result_buffers', empty) or empty
1426 rb = rec.pop('result_buffers', empty) or empty
1417 if result_buffer_lens is not None:
1427 if result_buffer_lens is not None:
1418 result_buffer_lens.append(len(rb))
1428 result_buffer_lens.append(len(rb))
1419 buffers.extend(rb)
1429 buffers.extend(rb)
1420 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1430 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1421 result_buffer_lens=result_buffer_lens)
1431 result_buffer_lens=result_buffer_lens)
1422 # self.log.debug (content)
1432 # self.log.debug (content)
1423 self.session.send(self.query, "db_reply", content=content,
1433 self.session.send(self.query, "db_reply", content=content,
1424 parent=msg, ident=client_id,
1434 parent=msg, ident=client_id,
1425 buffers=buffers)
1435 buffers=buffers)
1426
1436
@@ -1,130 +1,145 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2 from __future__ import print_function
2 from __future__ import print_function
3
3
4 #-------------------------------------------------------------------------------
4 #-------------------------------------------------------------------------------
5 # Copyright (C) 2011 The IPython Development Team
5 # Copyright (C) 2011 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-------------------------------------------------------------------------------
9 #-------------------------------------------------------------------------------
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 import os
15 import os
16 import tempfile
16 import tempfile
17 import time
17 import time
18 from subprocess import Popen, PIPE, STDOUT
18 from subprocess import Popen, PIPE, STDOUT
19
19
20 import nose
20 import nose
21
21
22 from IPython.utils.path import get_ipython_dir
22 from IPython.utils.path import get_ipython_dir
23 from IPython.parallel import Client
23 from IPython.parallel import Client, error
24 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
24 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
25 ipengine_cmd_argv,
25 ipengine_cmd_argv,
26 ipcontroller_cmd_argv,
26 ipcontroller_cmd_argv,
27 SIGKILL,
27 SIGKILL,
28 ProcessStateError,
28 ProcessStateError,
29 )
29 )
30
30
31 # globals
31 # globals
32 launchers = []
32 launchers = []
33 blackhole = open(os.devnull, 'w')
33 blackhole = open(os.devnull, 'w')
34
34
35 # Launcher class
35 # Launcher class
36 class TestProcessLauncher(LocalProcessLauncher):
36 class TestProcessLauncher(LocalProcessLauncher):
37 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
37 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
38 def start(self):
38 def start(self):
39 if self.state == 'before':
39 if self.state == 'before':
40 # Store stdout & stderr to show with failing tests.
40 # Store stdout & stderr to show with failing tests.
41 # This is defined in IPython.testing.iptest
41 # This is defined in IPython.testing.iptest
42 self.process = Popen(self.args,
42 self.process = Popen(self.args,
43 stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT,
43 stdout=nose.iptest_stdstreams_fileno(), stderr=STDOUT,
44 env=os.environ,
44 env=os.environ,
45 cwd=self.work_dir
45 cwd=self.work_dir
46 )
46 )
47 self.notify_start(self.process.pid)
47 self.notify_start(self.process.pid)
48 self.poll = self.process.poll
48 self.poll = self.process.poll
49 else:
49 else:
50 s = 'The process was already started and has state: %r' % self.state
50 s = 'The process was already started and has state: %r' % self.state
51 raise ProcessStateError(s)
51 raise ProcessStateError(s)
52
52
53 # nose setup/teardown
53 # nose setup/teardown
54
54
55 def setup():
55 def setup():
56
57 # show tracebacks for RemoteErrors
58 class RemoteErrorWithTB(error.RemoteError):
59 def __str__(self):
60 s = super(RemoteErrorWithTB, self).__str__()
61 return '\n'.join([s, self.traceback or ''])
62
63 error.RemoteError = RemoteErrorWithTB
64
56 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
65 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
57 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
66 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
58 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
67 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
59 for json in (engine_json, client_json):
68 for json in (engine_json, client_json):
60 if os.path.exists(json):
69 if os.path.exists(json):
61 os.remove(json)
70 os.remove(json)
62
71
63 cp = TestProcessLauncher()
72 cp = TestProcessLauncher()
64 cp.cmd_and_args = ipcontroller_cmd_argv + \
73 cp.cmd_and_args = ipcontroller_cmd_argv + \
65 ['--profile=iptest', '--log-level=20', '--ping=250', '--dictdb']
74 ['--profile=iptest', '--log-level=20', '--ping=250', '--dictdb']
66 cp.start()
75 cp.start()
67 launchers.append(cp)
76 launchers.append(cp)
68 tic = time.time()
77 tic = time.time()
69 while not os.path.exists(engine_json) or not os.path.exists(client_json):
78 while not os.path.exists(engine_json) or not os.path.exists(client_json):
70 if cp.poll() is not None:
79 if cp.poll() is not None:
71 raise RuntimeError("The test controller exited with status %s" % cp.poll())
80 raise RuntimeError("The test controller exited with status %s" % cp.poll())
72 elif time.time()-tic > 15:
81 elif time.time()-tic > 15:
73 raise RuntimeError("Timeout waiting for the test controller to start.")
82 raise RuntimeError("Timeout waiting for the test controller to start.")
74 time.sleep(0.1)
83 time.sleep(0.1)
75 add_engines(1)
84 add_engines(1)
76
85
77 def add_engines(n=1, profile='iptest', total=False):
86 def add_engines(n=1, profile='iptest', total=False):
78 """add a number of engines to a given profile.
87 """add a number of engines to a given profile.
79
88
80 If total is True, then already running engines are counted, and only
89 If total is True, then already running engines are counted, and only
81 the additional engines necessary (if any) are started.
90 the additional engines necessary (if any) are started.
82 """
91 """
83 rc = Client(profile=profile)
92 rc = Client(profile=profile)
84 base = len(rc)
93 base = len(rc)
85
94
86 if total:
95 if total:
87 n = max(n - base, 0)
96 n = max(n - base, 0)
88
97
89 eps = []
98 eps = []
90 for i in range(n):
99 for i in range(n):
91 ep = TestProcessLauncher()
100 ep = TestProcessLauncher()
92 ep.cmd_and_args = ipengine_cmd_argv + [
101 ep.cmd_and_args = ipengine_cmd_argv + [
93 '--profile=%s' % profile,
102 '--profile=%s' % profile,
94 '--log-level=50',
103 '--log-level=50',
95 '--InteractiveShell.colors=nocolor'
104 '--InteractiveShell.colors=nocolor'
96 ]
105 ]
97 ep.start()
106 ep.start()
98 launchers.append(ep)
107 launchers.append(ep)
99 eps.append(ep)
108 eps.append(ep)
100 tic = time.time()
109 tic = time.time()
101 while len(rc) < base+n:
110 while len(rc) < base+n:
102 if any([ ep.poll() is not None for ep in eps ]):
111 if any([ ep.poll() is not None for ep in eps ]):
103 raise RuntimeError("A test engine failed to start.")
112 raise RuntimeError("A test engine failed to start.")
104 elif time.time()-tic > 15:
113 elif time.time()-tic > 15:
105 raise RuntimeError("Timeout waiting for engines to connect.")
114 raise RuntimeError("Timeout waiting for engines to connect.")
106 time.sleep(.1)
115 time.sleep(.1)
107 rc.spin()
116 rc.spin()
108 rc.close()
117 rc.close()
109 return eps
118 return eps
110
119
111 def teardown():
120 def teardown():
112 time.sleep(1)
121 try:
122 time.sleep(1)
123 except KeyboardInterrupt:
124 return
113 while launchers:
125 while launchers:
114 p = launchers.pop()
126 p = launchers.pop()
115 if p.poll() is None:
127 if p.poll() is None:
116 try:
128 try:
117 p.stop()
129 p.stop()
118 except Exception as e:
130 except Exception as e:
119 print(e)
131 print(e)
120 pass
132 pass
121 if p.poll() is None:
133 if p.poll() is None:
122 time.sleep(.25)
134 try:
135 time.sleep(.25)
136 except KeyboardInterrupt:
137 return
123 if p.poll() is None:
138 if p.poll() is None:
124 try:
139 try:
125 print('cleaning up test process...')
140 print('cleaning up test process...')
126 p.signal(SIGKILL)
141 p.signal(SIGKILL)
127 except:
142 except:
128 print("couldn't shutdown process: ", p)
143 print("couldn't shutdown process: ", p)
129 blackhole.close()
144 blackhole.close()
130
145
@@ -1,314 +1,314 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import logging
22 import os
22 import os
23 import tempfile
23 import tempfile
24 import time
24 import time
25
25
26 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
27 from unittest import TestCase
27 from unittest import TestCase
28
28
29 from IPython.parallel import error
29 from IPython.parallel import error
30 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.dictdb import DictDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
32 from IPython.parallel.controller.hub import init_record, empty_record
32 from IPython.parallel.controller.hub import init_record, empty_record
33
33
34 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
35 from IPython.kernel.zmq.session import Session
35 from IPython.kernel.zmq.session import Session
36
36
37
37
38 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
39 # TestCases
39 # TestCases
40 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
41
41
42
42
43 def setup():
43 def setup():
44 global temp_db
44 global temp_db
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
46
46
47
47
48 class TaskDBTest:
48 class TaskDBTest:
49 def setUp(self):
49 def setUp(self):
50 self.session = Session()
50 self.session = Session()
51 self.db = self.create_db()
51 self.db = self.create_db()
52 self.load_records(16)
52 self.load_records(16)
53
53
54 def create_db(self):
54 def create_db(self):
55 raise NotImplementedError
55 raise NotImplementedError
56
56
57 def load_records(self, n=1, buffer_size=100):
57 def load_records(self, n=1, buffer_size=100):
58 """load n records for testing"""
58 """load n records for testing"""
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
60 time.sleep(0.1)
60 time.sleep(0.1)
61 msg_ids = []
61 msg_ids = []
62 for i in range(n):
62 for i in range(n):
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
64 msg['buffers'] = [os.urandom(buffer_size)]
64 msg['buffers'] = [os.urandom(buffer_size)]
65 rec = init_record(msg)
65 rec = init_record(msg)
66 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
67 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
68 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
69 return msg_ids
69 return msg_ids
70
70
71 def test_add_record(self):
71 def test_add_record(self):
72 before = self.db.get_history()
72 before = self.db.get_history()
73 self.load_records(5)
73 self.load_records(5)
74 after = self.db.get_history()
74 after = self.db.get_history()
75 self.assertEqual(len(after), len(before)+5)
75 self.assertEqual(len(after), len(before)+5)
76 self.assertEqual(after[:-5],before)
76 self.assertEqual(after[:-5],before)
77
77
78 def test_drop_record(self):
78 def test_drop_record(self):
79 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
80 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
83
83
84 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
85 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
86 micro = dt.microsecond
86 micro = dt.microsecond
87 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
88 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
89
89
90 def test_update_record(self):
90 def test_update_record(self):
91 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
92 #
92 #
93 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
94 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
95 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
96 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
97 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
98 self.assertEqual(rec2['stdout'], 'hello there')
98 self.assertEqual(rec2['stdout'], 'hello there')
99 self.assertEqual(rec2['completed'], now)
99 self.assertEqual(rec2['completed'], now)
100 rec1.update(data)
100 rec1.update(data)
101 self.assertEqual(rec1, rec2)
101 self.assertEqual(rec1, rec2)
102
102
103 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
104 # """test updating nonexistant records"""
104 # """test updating nonexistant records"""
105 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
106 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
108
108
109 def test_find_records_dt(self):
109 def test_find_records_dt(self):
110 """test finding records by date"""
110 """test finding records by date"""
111 hist = self.db.get_history()
111 hist = self.db.get_history()
112 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
113 tic = middle['submitted']
113 tic = middle['submitted']
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
116 self.assertEqual(len(before)+len(after),len(hist))
116 self.assertEqual(len(before)+len(after),len(hist))
117 for b in before:
117 for b in before:
118 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
119 for a in after:
119 for a in after:
120 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
121 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
122 for s in same:
122 for s in same:
123 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
124
124
125 def test_find_records_keys(self):
125 def test_find_records_keys(self):
126 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
128 for rec in found:
128 for rec in found:
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEqual(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
130
130
131 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
132 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
134 for rec in found:
134 for rec in found:
135 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
137 for rec in found:
137 for rec in found:
138 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
140 for rec in found:
140 for rec in found:
141 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
142
142
143 def test_find_records_in(self):
143 def test_find_records_in(self):
144 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
145 hist = self.db.get_history()
145 hist = self.db.get_history()
146 even = hist[::2]
146 even = hist[::2]
147 odd = hist[1::2]
147 odd = hist[1::2]
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
149 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
150 self.assertEqual(set(even), set(found))
150 self.assertEqual(set(even), set(found))
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
152 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
153 self.assertEqual(set(odd), set(found))
153 self.assertEqual(set(odd), set(found))
154
154
155 def test_get_history(self):
155 def test_get_history(self):
156 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
157 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
158 for msg_id in msg_ids:
158 for msg_id in msg_ids:
159 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
160 newt = rec['submitted']
160 newt = rec['submitted']
161 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
162 latest = newt
162 latest = newt
163 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
164 self.assertEqual(self.db.get_history()[-1],msg_id)
164 self.assertEqual(self.db.get_history()[-1],msg_id)
165
165
166 def test_datetime(self):
166 def test_datetime(self):
167 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
168 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
169 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
172 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
173 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
174
174
175 def test_drop_matching(self):
175 def test_drop_matching(self):
176 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
177 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
178 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
179 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
180 self.assertEqual(len(recs), 0)
180 self.assertEqual(len(recs), 0)
181
181
182 def test_null(self):
182 def test_null(self):
183 """test None comparison queries"""
183 """test None comparison queries"""
184 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
185
185
186 query = {'msg_id' : None}
186 query = {'msg_id' : None}
187 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
188 self.assertEqual(len(recs), 0)
188 self.assertEqual(len(recs), 0)
189
189
190 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
191 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
192 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
193
193
194 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
195 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
196 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
197 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
198 rec.pop('buffers')
198 rec.pop('buffers')
199 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
200 rec['header']['msg_id'] = 'fubar'
200 rec['header']['msg_id'] = 'fubar'
201 rec2 = self.db.get_record(msg_id)
201 rec2 = self.db.get_record(msg_id)
202 self.assertTrue('buffers' in rec2)
202 self.assertTrue('buffers' in rec2)
203 self.assertFalse('garbage' in rec2)
203 self.assertFalse('garbage' in rec2)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
204 self.assertEqual(rec2['header']['msg_id'], msg_id)
205
205
206 def test_pop_safe_find(self):
206 def test_pop_safe_find(self):
207 """editing query results shouldn't affect record [find]"""
207 """editing query results shouldn't affect record [find]"""
208 msg_id = self.db.get_history()[-1]
208 msg_id = self.db.get_history()[-1]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
209 rec = self.db.find_records({'msg_id' : msg_id})[0]
210 rec.pop('buffers')
210 rec.pop('buffers')
211 rec['garbage'] = 'hello'
211 rec['garbage'] = 'hello'
212 rec['header']['msg_id'] = 'fubar'
212 rec['header']['msg_id'] = 'fubar'
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
213 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
214 self.assertTrue('buffers' in rec2)
214 self.assertTrue('buffers' in rec2)
215 self.assertFalse('garbage' in rec2)
215 self.assertFalse('garbage' in rec2)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
216 self.assertEqual(rec2['header']['msg_id'], msg_id)
217
217
218 def test_pop_safe_find_keys(self):
218 def test_pop_safe_find_keys(self):
219 """editing query results shouldn't affect record [find+keys]"""
219 """editing query results shouldn't affect record [find+keys]"""
220 msg_id = self.db.get_history()[-1]
220 msg_id = self.db.get_history()[-1]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
221 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers', 'header'])[0]
222 rec.pop('buffers')
222 rec.pop('buffers')
223 rec['garbage'] = 'hello'
223 rec['garbage'] = 'hello'
224 rec['header']['msg_id'] = 'fubar'
224 rec['header']['msg_id'] = 'fubar'
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
225 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
226 self.assertTrue('buffers' in rec2)
226 self.assertTrue('buffers' in rec2)
227 self.assertFalse('garbage' in rec2)
227 self.assertFalse('garbage' in rec2)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
228 self.assertEqual(rec2['header']['msg_id'], msg_id)
229
229
230
230
231 class TestDictBackend(TaskDBTest, TestCase):
231 class TestDictBackend(TaskDBTest, TestCase):
232
232
233 def create_db(self):
233 def create_db(self):
234 return DictDB()
234 return DictDB()
235
235
236 def test_cull_count(self):
236 def test_cull_count(self):
237 self.db = self.create_db() # skip the load-records init from setUp
237 self.db = self.create_db() # skip the load-records init from setUp
238 self.db.record_limit = 20
238 self.db.record_limit = 20
239 self.db.cull_fraction = 0.2
239 self.db.cull_fraction = 0.2
240 self.load_records(20)
240 self.load_records(20)
241 self.assertEqual(len(self.db.get_history()), 20)
241 self.assertEqual(len(self.db.get_history()), 20)
242 self.load_records(1)
242 self.load_records(1)
243 # 0.2 * 20 = 4, 21 - 4 = 17
243 # 0.2 * 20 = 4, 21 - 4 = 17
244 self.assertEqual(len(self.db.get_history()), 17)
244 self.assertEqual(len(self.db.get_history()), 17)
245 self.load_records(3)
245 self.load_records(3)
246 self.assertEqual(len(self.db.get_history()), 20)
246 self.assertEqual(len(self.db.get_history()), 20)
247 self.load_records(1)
247 self.load_records(1)
248 self.assertEqual(len(self.db.get_history()), 17)
248 self.assertEqual(len(self.db.get_history()), 17)
249
249
250 for i in range(100):
250 for i in range(25):
251 self.load_records(1)
251 self.load_records(1)
252 self.assertTrue(len(self.db.get_history()) >= 17)
252 self.assertTrue(len(self.db.get_history()) >= 17)
253 self.assertTrue(len(self.db.get_history()) <= 20)
253 self.assertTrue(len(self.db.get_history()) <= 20)
254
254
255 def test_cull_size(self):
255 def test_cull_size(self):
256 self.db = self.create_db() # skip the load-records init from setUp
256 self.db = self.create_db() # skip the load-records init from setUp
257 self.db.size_limit = 1000
257 self.db.size_limit = 1000
258 self.db.cull_fraction = 0.2
258 self.db.cull_fraction = 0.2
259 self.load_records(100, buffer_size=10)
259 self.load_records(100, buffer_size=10)
260 self.assertEqual(len(self.db.get_history()), 100)
260 self.assertEqual(len(self.db.get_history()), 100)
261 self.load_records(1, buffer_size=0)
261 self.load_records(1, buffer_size=0)
262 self.assertEqual(len(self.db.get_history()), 101)
262 self.assertEqual(len(self.db.get_history()), 101)
263 self.load_records(1, buffer_size=1)
263 self.load_records(1, buffer_size=1)
264 # 0.2 * 100 = 20, 101 - 20 = 81
264 # 0.2 * 100 = 20, 101 - 20 = 81
265 self.assertEqual(len(self.db.get_history()), 81)
265 self.assertEqual(len(self.db.get_history()), 81)
266
266
267 def test_cull_size_drop(self):
267 def test_cull_size_drop(self):
268 """dropping records updates tracked buffer size"""
268 """dropping records updates tracked buffer size"""
269 self.db = self.create_db() # skip the load-records init from setUp
269 self.db = self.create_db() # skip the load-records init from setUp
270 self.db.size_limit = 1000
270 self.db.size_limit = 1000
271 self.db.cull_fraction = 0.2
271 self.db.cull_fraction = 0.2
272 self.load_records(100, buffer_size=10)
272 self.load_records(100, buffer_size=10)
273 self.assertEqual(len(self.db.get_history()), 100)
273 self.assertEqual(len(self.db.get_history()), 100)
274 self.db.drop_record(self.db.get_history()[-1])
274 self.db.drop_record(self.db.get_history()[-1])
275 self.assertEqual(len(self.db.get_history()), 99)
275 self.assertEqual(len(self.db.get_history()), 99)
276 self.load_records(1, buffer_size=5)
276 self.load_records(1, buffer_size=5)
277 self.assertEqual(len(self.db.get_history()), 100)
277 self.assertEqual(len(self.db.get_history()), 100)
278 self.load_records(1, buffer_size=5)
278 self.load_records(1, buffer_size=5)
279 self.assertEqual(len(self.db.get_history()), 101)
279 self.assertEqual(len(self.db.get_history()), 101)
280 self.load_records(1, buffer_size=1)
280 self.load_records(1, buffer_size=1)
281 self.assertEqual(len(self.db.get_history()), 81)
281 self.assertEqual(len(self.db.get_history()), 81)
282
282
283 def test_cull_size_update(self):
283 def test_cull_size_update(self):
284 """updating records updates tracked buffer size"""
284 """updating records updates tracked buffer size"""
285 self.db = self.create_db() # skip the load-records init from setUp
285 self.db = self.create_db() # skip the load-records init from setUp
286 self.db.size_limit = 1000
286 self.db.size_limit = 1000
287 self.db.cull_fraction = 0.2
287 self.db.cull_fraction = 0.2
288 self.load_records(100, buffer_size=10)
288 self.load_records(100, buffer_size=10)
289 self.assertEqual(len(self.db.get_history()), 100)
289 self.assertEqual(len(self.db.get_history()), 100)
290 msg_id = self.db.get_history()[-1]
290 msg_id = self.db.get_history()[-1]
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
291 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(10)], buffers=[]))
292 self.assertEqual(len(self.db.get_history()), 100)
292 self.assertEqual(len(self.db.get_history()), 100)
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
293 self.db.update_record(msg_id, dict(result_buffers = [os.urandom(11)], buffers=[]))
294 self.assertEqual(len(self.db.get_history()), 79)
294 self.assertEqual(len(self.db.get_history()), 79)
295
295
296 class TestSQLiteBackend(TaskDBTest, TestCase):
296 class TestSQLiteBackend(TaskDBTest, TestCase):
297
297
298 @dec.skip_without('sqlite3')
298 @dec.skip_without('sqlite3')
299 def create_db(self):
299 def create_db(self):
300 location, fname = os.path.split(temp_db)
300 location, fname = os.path.split(temp_db)
301 log = logging.getLogger('test')
301 log = logging.getLogger('test')
302 log.setLevel(logging.CRITICAL)
302 log.setLevel(logging.CRITICAL)
303 return SQLiteDB(location=location, fname=fname, log=log)
303 return SQLiteDB(location=location, fname=fname, log=log)
304
304
305 def tearDown(self):
305 def tearDown(self):
306 self.db._db.close()
306 self.db._db.close()
307
307
308
308
309 def teardown():
309 def teardown():
310 """cleanup task db file after all tests have run"""
310 """cleanup task db file after all tests have run"""
311 try:
311 try:
312 os.remove(temp_db)
312 os.remove(temp_db)
313 except:
313 except:
314 pass
314 pass
@@ -1,256 +1,259 b''
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import math
14 import math
15 import re
15 import re
16 import types
16 import types
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 # base64.encodestring is deprecated in Python 3.x
20 # base64.encodestring is deprecated in Python 3.x
21 from base64 import encodebytes
21 from base64 import encodebytes
22 except ImportError:
22 except ImportError:
23 # Python 2.x
23 # Python 2.x
24 from base64 import encodestring as encodebytes
24 from base64 import encodestring as encodebytes
25
25
26 from IPython.utils import py3compat
26 from IPython.utils import py3compat
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
27 from IPython.utils.py3compat import string_types, unicode_type, iteritems
28 from IPython.utils.encoding import DEFAULT_ENCODING
28 from IPython.utils.encoding import DEFAULT_ENCODING
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Globals and constants
32 # Globals and constants
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34
34
35 # timestamp formats
35 # timestamp formats
36 ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
36 ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6})Z?([\+\-]\d{2}:?\d{2})?$")
37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2})(\.\d{1,6})?Z?([\+\-]\d{2}:?\d{2})?$")
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Classes and functions
40 # Classes and functions
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 def rekey(dikt):
43 def rekey(dikt):
44 """Rekey a dict that has been forced to use str keys where there should be
44 """Rekey a dict that has been forced to use str keys where there should be
45 ints by json."""
45 ints by json."""
46 for k in dikt:
46 for k in dikt:
47 if isinstance(k, string_types):
47 if isinstance(k, string_types):
48 ik=fk=None
48 ik=fk=None
49 try:
49 try:
50 ik = int(k)
50 ik = int(k)
51 except ValueError:
51 except ValueError:
52 try:
52 try:
53 fk = float(k)
53 fk = float(k)
54 except ValueError:
54 except ValueError:
55 continue
55 continue
56 if ik is not None:
56 if ik is not None:
57 nk = ik
57 nk = ik
58 else:
58 else:
59 nk = fk
59 nk = fk
60 if nk in dikt:
60 if nk in dikt:
61 raise KeyError("already have key %r"%nk)
61 raise KeyError("already have key %r"%nk)
62 dikt[nk] = dikt.pop(k)
62 dikt[nk] = dikt.pop(k)
63 return dikt
63 return dikt
64
64
65 def parse_date(s):
65 def parse_date(s):
66 """parse an ISO8601 date string
66 """parse an ISO8601 date string
67
67
68 If it is None or not a valid ISO8601 timestamp,
68 If it is None or not a valid ISO8601 timestamp,
69 it will be returned unmodified.
69 it will be returned unmodified.
70 Otherwise, it will return a datetime object.
70 Otherwise, it will return a datetime object.
71 """
71 """
72 if s is None:
72 if s is None:
73 return s
73 return s
74 m = ISO8601_PAT.match(s)
74 m = ISO8601_PAT.match(s)
75 if m:
75 if m:
76 # FIXME: add actual timezone support
76 # FIXME: add actual timezone support
77 # this just drops the timezone info
77 # this just drops the timezone info
78 notz = m.groups()[0]
78 notz, ms, tz = m.groups()
79 if not ms:
80 ms = '.0'
81 notz = notz + ms
79 return datetime.strptime(notz, ISO8601)
82 return datetime.strptime(notz, ISO8601)
80 return s
83 return s
81
84
82 def extract_dates(obj):
85 def extract_dates(obj):
83 """extract ISO8601 dates from unpacked JSON"""
86 """extract ISO8601 dates from unpacked JSON"""
84 if isinstance(obj, dict):
87 if isinstance(obj, dict):
85 new_obj = {} # don't clobber
88 new_obj = {} # don't clobber
86 for k,v in iteritems(obj):
89 for k,v in iteritems(obj):
87 new_obj[k] = extract_dates(v)
90 new_obj[k] = extract_dates(v)
88 obj = new_obj
91 obj = new_obj
89 elif isinstance(obj, (list, tuple)):
92 elif isinstance(obj, (list, tuple)):
90 obj = [ extract_dates(o) for o in obj ]
93 obj = [ extract_dates(o) for o in obj ]
91 elif isinstance(obj, string_types):
94 elif isinstance(obj, string_types):
92 obj = parse_date(obj)
95 obj = parse_date(obj)
93 return obj
96 return obj
94
97
95 def squash_dates(obj):
98 def squash_dates(obj):
96 """squash datetime objects into ISO8601 strings"""
99 """squash datetime objects into ISO8601 strings"""
97 if isinstance(obj, dict):
100 if isinstance(obj, dict):
98 obj = dict(obj) # don't clobber
101 obj = dict(obj) # don't clobber
99 for k,v in iteritems(obj):
102 for k,v in iteritems(obj):
100 obj[k] = squash_dates(v)
103 obj[k] = squash_dates(v)
101 elif isinstance(obj, (list, tuple)):
104 elif isinstance(obj, (list, tuple)):
102 obj = [ squash_dates(o) for o in obj ]
105 obj = [ squash_dates(o) for o in obj ]
103 elif isinstance(obj, datetime):
106 elif isinstance(obj, datetime):
104 obj = obj.isoformat()
107 obj = obj.isoformat()
105 return obj
108 return obj
106
109
107 def date_default(obj):
110 def date_default(obj):
108 """default function for packing datetime objects in JSON."""
111 """default function for packing datetime objects in JSON."""
109 if isinstance(obj, datetime):
112 if isinstance(obj, datetime):
110 return obj.isoformat()
113 return obj.isoformat()
111 else:
114 else:
112 raise TypeError("%r is not JSON serializable"%obj)
115 raise TypeError("%r is not JSON serializable"%obj)
113
116
114
117
115 # constants for identifying png/jpeg data
118 # constants for identifying png/jpeg data
116 PNG = b'\x89PNG\r\n\x1a\n'
119 PNG = b'\x89PNG\r\n\x1a\n'
117 # front of PNG base64-encoded
120 # front of PNG base64-encoded
118 PNG64 = b'iVBORw0KG'
121 PNG64 = b'iVBORw0KG'
119 JPEG = b'\xff\xd8'
122 JPEG = b'\xff\xd8'
120 # front of JPEG base64-encoded
123 # front of JPEG base64-encoded
121 JPEG64 = b'/9'
124 JPEG64 = b'/9'
122 # front of PDF base64-encoded
125 # front of PDF base64-encoded
123 PDF64 = b'JVBER'
126 PDF64 = b'JVBER'
124
127
125 def encode_images(format_dict):
128 def encode_images(format_dict):
126 """b64-encodes images in a displaypub format dict
129 """b64-encodes images in a displaypub format dict
127
130
128 Perhaps this should be handled in json_clean itself?
131 Perhaps this should be handled in json_clean itself?
129
132
130 Parameters
133 Parameters
131 ----------
134 ----------
132
135
133 format_dict : dict
136 format_dict : dict
134 A dictionary of display data keyed by mime-type
137 A dictionary of display data keyed by mime-type
135
138
136 Returns
139 Returns
137 -------
140 -------
138
141
139 format_dict : dict
142 format_dict : dict
140 A copy of the same dictionary,
143 A copy of the same dictionary,
141 but binary image data ('image/png', 'image/jpeg' or 'application/pdf')
144 but binary image data ('image/png', 'image/jpeg' or 'application/pdf')
142 is base64-encoded.
145 is base64-encoded.
143
146
144 """
147 """
145 encoded = format_dict.copy()
148 encoded = format_dict.copy()
146
149
147 pngdata = format_dict.get('image/png')
150 pngdata = format_dict.get('image/png')
148 if isinstance(pngdata, bytes):
151 if isinstance(pngdata, bytes):
149 # make sure we don't double-encode
152 # make sure we don't double-encode
150 if not pngdata.startswith(PNG64):
153 if not pngdata.startswith(PNG64):
151 pngdata = encodebytes(pngdata)
154 pngdata = encodebytes(pngdata)
152 encoded['image/png'] = pngdata.decode('ascii')
155 encoded['image/png'] = pngdata.decode('ascii')
153
156
154 jpegdata = format_dict.get('image/jpeg')
157 jpegdata = format_dict.get('image/jpeg')
155 if isinstance(jpegdata, bytes):
158 if isinstance(jpegdata, bytes):
156 # make sure we don't double-encode
159 # make sure we don't double-encode
157 if not jpegdata.startswith(JPEG64):
160 if not jpegdata.startswith(JPEG64):
158 jpegdata = encodebytes(jpegdata)
161 jpegdata = encodebytes(jpegdata)
159 encoded['image/jpeg'] = jpegdata.decode('ascii')
162 encoded['image/jpeg'] = jpegdata.decode('ascii')
160
163
161 pdfdata = format_dict.get('application/pdf')
164 pdfdata = format_dict.get('application/pdf')
162 if isinstance(pdfdata, bytes):
165 if isinstance(pdfdata, bytes):
163 # make sure we don't double-encode
166 # make sure we don't double-encode
164 if not pdfdata.startswith(PDF64):
167 if not pdfdata.startswith(PDF64):
165 pdfdata = encodebytes(pdfdata)
168 pdfdata = encodebytes(pdfdata)
166 encoded['application/pdf'] = pdfdata.decode('ascii')
169 encoded['application/pdf'] = pdfdata.decode('ascii')
167
170
168 return encoded
171 return encoded
169
172
170
173
171 def json_clean(obj):
174 def json_clean(obj):
172 """Clean an object to ensure it's safe to encode in JSON.
175 """Clean an object to ensure it's safe to encode in JSON.
173
176
174 Atomic, immutable objects are returned unmodified. Sets and tuples are
177 Atomic, immutable objects are returned unmodified. Sets and tuples are
175 converted to lists, lists are copied and dicts are also copied.
178 converted to lists, lists are copied and dicts are also copied.
176
179
177 Note: dicts whose keys could cause collisions upon encoding (such as a dict
180 Note: dicts whose keys could cause collisions upon encoding (such as a dict
178 with both the number 1 and the string '1' as keys) will cause a ValueError
181 with both the number 1 and the string '1' as keys) will cause a ValueError
179 to be raised.
182 to be raised.
180
183
181 Parameters
184 Parameters
182 ----------
185 ----------
183 obj : any python object
186 obj : any python object
184
187
185 Returns
188 Returns
186 -------
189 -------
187 out : object
190 out : object
188
191
189 A version of the input which will not cause an encoding error when
192 A version of the input which will not cause an encoding error when
190 encoded as JSON. Note that this function does not *encode* its inputs,
193 encoded as JSON. Note that this function does not *encode* its inputs,
191 it simply sanitizes it so that there will be no encoding errors later.
194 it simply sanitizes it so that there will be no encoding errors later.
192
195
193 Examples
196 Examples
194 --------
197 --------
195 >>> json_clean(4)
198 >>> json_clean(4)
196 4
199 4
197 >>> json_clean(list(range(10)))
200 >>> json_clean(list(range(10)))
198 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
201 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
199 >>> sorted(json_clean(dict(x=1, y=2)).items())
202 >>> sorted(json_clean(dict(x=1, y=2)).items())
200 [('x', 1), ('y', 2)]
203 [('x', 1), ('y', 2)]
201 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
204 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
202 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
205 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
203 >>> json_clean(True)
206 >>> json_clean(True)
204 True
207 True
205 """
208 """
206 # types that are 'atomic' and ok in json as-is.
209 # types that are 'atomic' and ok in json as-is.
207 atomic_ok = (unicode_type, type(None))
210 atomic_ok = (unicode_type, type(None))
208
211
209 # containers that we need to convert into lists
212 # containers that we need to convert into lists
210 container_to_list = (tuple, set, types.GeneratorType)
213 container_to_list = (tuple, set, types.GeneratorType)
211
214
212 if isinstance(obj, float):
215 if isinstance(obj, float):
213 # cast out-of-range floats to their reprs
216 # cast out-of-range floats to their reprs
214 if math.isnan(obj) or math.isinf(obj):
217 if math.isnan(obj) or math.isinf(obj):
215 return repr(obj)
218 return repr(obj)
216 return float(obj)
219 return float(obj)
217
220
218 if isinstance(obj, int):
221 if isinstance(obj, int):
219 # cast int to int, in case subclasses override __str__ (e.g. boost enum, #4598)
222 # cast int to int, in case subclasses override __str__ (e.g. boost enum, #4598)
220 if isinstance(obj, bool):
223 if isinstance(obj, bool):
221 # bools are ints, but we don't want to cast them to 0,1
224 # bools are ints, but we don't want to cast them to 0,1
222 return obj
225 return obj
223 return int(obj)
226 return int(obj)
224
227
225 if isinstance(obj, atomic_ok):
228 if isinstance(obj, atomic_ok):
226 return obj
229 return obj
227
230
228 if isinstance(obj, bytes):
231 if isinstance(obj, bytes):
229 return obj.decode(DEFAULT_ENCODING, 'replace')
232 return obj.decode(DEFAULT_ENCODING, 'replace')
230
233
231 if isinstance(obj, container_to_list) or (
234 if isinstance(obj, container_to_list) or (
232 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
235 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
233 obj = list(obj)
236 obj = list(obj)
234
237
235 if isinstance(obj, list):
238 if isinstance(obj, list):
236 return [json_clean(x) for x in obj]
239 return [json_clean(x) for x in obj]
237
240
238 if isinstance(obj, dict):
241 if isinstance(obj, dict):
239 # First, validate that the dict won't lose data in conversion due to
242 # First, validate that the dict won't lose data in conversion due to
240 # key collisions after stringification. This can happen with keys like
243 # key collisions after stringification. This can happen with keys like
241 # True and 'true' or 1 and '1', which collide in JSON.
244 # True and 'true' or 1 and '1', which collide in JSON.
242 nkeys = len(obj)
245 nkeys = len(obj)
243 nkeys_collapsed = len(set(map(str, obj)))
246 nkeys_collapsed = len(set(map(str, obj)))
244 if nkeys != nkeys_collapsed:
247 if nkeys != nkeys_collapsed:
245 raise ValueError('dict can not be safely converted to JSON: '
248 raise ValueError('dict can not be safely converted to JSON: '
246 'key collision would lead to dropped values')
249 'key collision would lead to dropped values')
247 # If all OK, proceed by making the new dict that will be json-safe
250 # If all OK, proceed by making the new dict that will be json-safe
248 out = {}
251 out = {}
249 for k,v in iteritems(obj):
252 for k,v in iteritems(obj):
250 out[str(k)] = json_clean(v)
253 out[str(k)] = json_clean(v)
251 return out
254 return out
252
255
253 # If we get here, we don't know how to handle the object, so we just get
256 # If we get here, we don't know how to handle the object, so we just get
254 # its repr and return that. This will catch lambdas, open sockets, class
257 # its repr and return that. This will catch lambdas, open sockets, class
255 # objects, and any other complicated contraption that json can't encode
258 # objects, and any other complicated contraption that json can't encode
256 return repr(obj)
259 return repr(obj)
@@ -1,149 +1,151 b''
1 """Test suite for our JSON utilities.
1 """Test suite for our JSON utilities.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010-2011 The IPython Development Team
4 # Copyright (C) 2010-2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import datetime
14 import datetime
15 import json
15 import json
16 from base64 import decodestring
16 from base64 import decodestring
17
17
18 # third party
18 # third party
19 import nose.tools as nt
19 import nose.tools as nt
20
20
21 # our own
21 # our own
22 from IPython.utils import jsonutil, tz
22 from IPython.utils import jsonutil, tz
23 from ..jsonutil import json_clean, encode_images
23 from ..jsonutil import json_clean, encode_images
24 from ..py3compat import unicode_to_str, str_to_bytes, iteritems
24 from ..py3compat import unicode_to_str, str_to_bytes, iteritems
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Test functions
27 # Test functions
28 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
class Int(int):
    """An int subclass with a custom str(), used to exercise json_clean's
    key-collision detection (str(Int(5)) != str(5))."""

    def __str__(self):
        # Render as 'Int(<value>)' rather than the bare number.
        return 'Int({0:d})'.format(int(self))
32
32
def test():
    """Run json_clean over a spread of inputs and check both the cleaned
    value and that the result round-trips through json.dumps/loads."""
    # Each entry is (input, expected). A None expected value means the
    # input should come through json_clean unchanged.
    cases = [
        # start with scalars
        (1, None),
        (1.0, None),
        ('a', None),
        (True, None),
        (False, None),
        (None, None),
        # complex numbers for now just go to strings, as otherwise they
        # are unserializable
        (1j, '1j'),
        # Containers
        ([1, 2], None),
        ((1, 2), [1, 2]),
        (set([1, 2]), [1, 2]),
        (dict(x=1), None),
        ({'x': 1, 'y': [1, 2, 3], '1': 'int'}, None),
        # More exotic objects
        ((x for x in range(3)), [0, 1, 2]),
        (iter([1, 2]), [1, 2]),
        (Int(5), 5),
    ]

    for value, expected in cases:
        cleaned = json_clean(value)
        # validate our cleanup
        nt.assert_equal(cleaned, value if expected is None else expected)
        # and ensure that what we return, indeed encodes cleanly
        json.loads(json.dumps(cleaned))
65
65
66
66
67
67
def test_encode_images():
    """encode_images should base64-encode raw image/pdf bytes, and leave
    already-encoded payloads (unicode or native str) untouched."""
    # invalid data, but the header and footer are from real files
    pngdata = b'\x89PNG\r\n\x1a\nblahblahnotactuallyvalidIEND\xaeB`\x82'
    jpegdata = b'\xff\xd8\xff\xe0\x00\x10JFIFblahblahjpeg(\xa0\x0f\xff\xd9'
    pdfdata = b'%PDF-1.\ntrailer<</Root<</Pages<</Kids[<</MediaBox[0 0 3 3]>>]>>>>>>'

    mimebundle = {
        'image/png' : pngdata,
        'image/jpeg' : jpegdata,
        'application/pdf' : pdfdata
    }
    encoded = encode_images(mimebundle)
    for mime, raw in iteritems(mimebundle):
        # encoded values are unicode; decode back to bytes for comparison
        nt.assert_equal(decodestring(encoded[mime].encode('ascii')), raw)
    # encoding an already-encoded bundle must be a no-op
    nt.assert_equal(encode_images(encoded), encoded)

    # repeat with native-str (not unicode) base64 payloads
    b64_str = {}
    for mime, b64 in iteritems(encoded):
        b64_str[mime] = unicode_to_str(b64)
    encoded3 = encode_images(b64_str)
    nt.assert_equal(encoded3, b64_str)
    for mime, raw in iteritems(mimebundle):
        # encoded3 has str, want bytes before base64-decoding
        nt.assert_equal(decodestring(str_to_bytes(encoded3[mime])), raw)
96
96
def test_lambda():
    """Lambdas can't be serialized, so json_clean falls back to their repr."""
    cleaned = json_clean(lambda : 1)
    nt.assert_is_instance(cleaned, str)
    nt.assert_in('<lambda>', cleaned)
    # the repr string itself must be JSON-serializable
    json.dumps(cleaned)
102
102
def test_extract_dates():
    """Every supported ISO8601 timestamp variant should be extracted as a
    datetime, and all of these spellings should compare equal."""
    timestamps = [
        '2013-07-03T16:34:52.249482',
        '2013-07-03T16:34:52.249482Z',
        '2013-07-03T16:34:52.249482Z-0800',
        '2013-07-03T16:34:52.249482Z+0800',
        '2013-07-03T16:34:52.249482Z+08:00',
        '2013-07-03T16:34:52.249482Z-08:00',
        '2013-07-03T16:34:52.249482-0800',
        '2013-07-03T16:34:52.249482+0800',
        '2013-07-03T16:34:52.249482+08:00',
        '2013-07-03T16:34:52.249482-08:00',
    ]
    parsed = jsonutil.extract_dates(timestamps)
    reference = parsed[0]
    for stamp in parsed:
        nt.assert_true(isinstance(stamp, datetime.datetime))
        nt.assert_equal(stamp, reference)
121
121
def test_parse_ms_precision():
    """parse_date accepts timestamps with no fractional seconds, or with
    1-6 fractional digits; anything else stays a plain string."""
    base = '2013-07-03T16:34:52'
    digits = '1234567890'

    # no fractional part at all parses fine
    nt.assert_is_instance(jsonutil.parse_date(base), datetime.datetime)
    for ndigits in range(len(digits)):
        stamp = base + '.' + digits[:ndigits]
        result = jsonutil.parse_date(stamp)
        if 1 <= ndigits <= 6:
            # up to microsecond precision parses to a datetime
            nt.assert_is_instance(result, datetime.datetime)
        else:
            # a bare trailing dot (0 digits) or >6 digits is rejected:
            # the original string is returned unchanged
            nt.assert_is_instance(result, str)
133
135
def test_date_default():
    """date_default serializes datetimes; only the tz-aware one should
    carry a '+00' UTC offset, and both should round-trip back to datetimes."""
    data = dict(today=datetime.datetime.now(), utcnow=tz.utcnow())
    jsondata = json.dumps(data, default=jsonutil.date_default)
    nt.assert_in("+00", jsondata)
    # naive datetime.now() must not grow an offset: exactly one '+00'
    nt.assert_equal(jsondata.count("+00"), 1)
    roundtripped = jsonutil.extract_dates(json.loads(jsondata))
    for value in roundtripped.values():
        nt.assert_is_instance(value, datetime.datetime)
142
144
def test_exception():
    """Dicts whose keys would collide after stringification (e.g. 1 vs '1')
    must be rejected by json_clean with a ValueError."""
    colliding_dicts = [
        {1:'number', '1':'string'},
        {True:'bool', 'True':'string'},
    ]
    for bad in colliding_dicts:
        nt.assert_raises(ValueError, json_clean, bad)
149
151
General Comments 0
You need to be logged in to leave comments. Login now