various db backend fixes...
MinRK
@@ -0,0 +1,37 b''
1 """Tests for mongodb backend"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 from nose import SkipTest
15
16 from pymongo import Connection
17 from IPython.parallel.controller.mongodb import MongoDB
18
19 from . import test_db
20
21 try:
22 c = Connection()
23 except Exception:
24 c=None
25
26 class TestMongoBackend(test_db.TestDictBackend):
27 """MongoDB backend tests"""
28
29 def create_db(self):
30 try:
31 return MongoDB(database='iptestdb', _connection=c)
32 except Exception:
33 raise SkipTest("Couldn't connect to mongodb")
34
35 def teardown(self):
36 if c is not None:
37 c.drop_database('iptestdb')
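
The create_db/teardown overrides above sit on top of test_db.TestDictBackend, so the same subclass-and-override pattern can host any backend. A minimal sketch for a SQLite variant (the SQLiteDB 'location' trait name is an assumption for illustration, not taken from this diff):

    import tempfile

    from IPython.parallel.controller.sqlitedb import SQLiteDB

    from . import test_db

    class TestSQLiteBackend(test_db.TestDictBackend):
        """SQLite backend tests (illustrative sketch)."""

        def create_db(self):
            # write the db file somewhere disposable
            return SQLiteDB(location=tempfile.gettempdir())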
@@ -1,180 +1,180 b''
1 1 """A Task logger that presents our DB interface,
2 2 but exists entirely in memory and is implemented with dicts.
3 3
4 4 TaskRecords are dicts of the form:
5 5 {
6 6 'msg_id' : str(uuid),
7 7 'client_uuid' : str(uuid),
8 8 'engine_uuid' : str(uuid) or None,
9 9 'header' : dict(header),
10 10 'content': dict(content),
11 11 'buffers': list(buffers),
12 12 'submitted': datetime,
13 13 'started': datetime or None,
14 14 'completed': datetime or None,
15 15 'resubmitted': datetime or None,
16 16 'result_header' : dict(header) or None,
17 17 'result_content' : dict(content) or None,
18 18 'result_buffers' : list(buffers) or None,
19 19 }
20 20 With this info, many of the special categories of tasks can be defined by query:
21 21
22 22 pending: completed is None
23 23 client's outstanding: client_uuid = uuid && completed is None
24 24 MIA: engine_uuid is None (and completed is None)
25 25 etc.
26 26
27 27 EngineRecords are dicts of the form:
28 28 {
29 29 'eid' : int(id),
30 30 'uuid': str(uuid)
31 31 }
32 32 This may be extended, but is currently minimal.
33 33
34 34 We support a subset of mongodb operators:
35 35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 36 """
37 37 #-----------------------------------------------------------------------------
38 38 # Copyright (C) 2010 The IPython Development Team
39 39 #
40 40 # Distributed under the terms of the BSD License. The full license is in
41 41 # the file COPYING, distributed as part of this software.
42 42 #-----------------------------------------------------------------------------
43 43
44 44
45 45 from datetime import datetime
46 46
47 47 from IPython.config.configurable import Configurable
48 48
49 49 from IPython.utils.traitlets import Dict, CUnicode
50 50
51 51 filters = {
52 52 '$lt' : lambda a,b: a < b,
53 53 '$gt' : lambda a,b: a > b,
54 54 '$eq' : lambda a,b: a == b,
55 55 '$ne' : lambda a,b: a != b,
56 56 '$lte': lambda a,b: a <= b,
57 57 '$gte': lambda a,b: a >= b,
58 58 '$in' : lambda a,b: a in b,
59 59 '$nin': lambda a,b: a not in b,
60 60 '$all': lambda a,b: all([ bb in a for bb in b ]), # every query element must be in the record value
61 61 '$mod': lambda a,b: a%b[0] == b[1],
62 62 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
63 63 }
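# A quick illustration of the operator lambdas above (values made up):
#
#     filters['$lt'](3, 5)              -> True   # 3 < 5
#     filters['$in']('a', ['a', 'b'])   -> True
#     filters['$mod'](10, [3, 1])       -> True   # 10 % 3 == 1
#     filters['$exists']('x', True)     -> True   # value is present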
64 64
65 65
66 66 class CompositeFilter(object):
67 67 """Composite filter for matching multiple properties."""
68 68
69 69 def __init__(self, dikt):
70 70 self.tests = []
71 71 self.values = []
72 72 for key, value in dikt.iteritems():
73 73 self.tests.append(filters[key])
74 74 self.values.append(value)
75 75
76 76 def __call__(self, value):
77 77 for test,check in zip(self.tests, self.values):
78 78 if not test(value, check):
79 79 return False
80 80 return True
81 81
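# CompositeFilter turns a multi-operator query fragment into one callable,
# e.g. (values made up):
#
#     f = CompositeFilter({'$gte': 2, '$lt': 10})
#     f(5)   -> True
#     f(12)  -> False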
82 82 class BaseDB(Configurable):
83 83 """Empty Parent class so traitlets work on DB."""
84 84 # base configurable traits:
85 85 session = CUnicode("")
86 86
87 87 class DictDB(BaseDB):
88 88 """Basic in-memory dict-based object for saving Task Records.
89 89
90 90 This is the first object to present the DB interface
91 91 for logging tasks out of memory.
92 92
93 93 The interface is based on MongoDB, so adding a MongoDB
94 94 backend should be straightforward.
95 95 """
96 96
97 97 _records = Dict()
98 98
99 99 def _match_one(self, rec, tests):
100 100 """Check if a specific record matches tests."""
101 101 for key,test in tests.iteritems():
102 102 if not test(rec.get(key, None)):
103 103 return False
104 104 return True
105 105
106 106 def _match(self, check):
107 107 """Find all the matches for a check dict."""
108 108 matches = []
109 109 tests = {}
110 110 for k,v in check.iteritems():
111 111 if isinstance(v, dict):
112 112 tests[k] = CompositeFilter(v)
113 113 else:
114 114 tests[k] = lambda o, v=v: o==v # bind v now; a plain closure would capture the loop variable
115 115
116 116 for rec in self._records.itervalues():
117 117 if self._match_one(rec, tests):
118 118 matches.append(rec)
119 119 return matches
120 120
121 121 def _extract_subdict(self, rec, keys):
122 122 """extract subdict of keys"""
123 123 d = {}
124 124 d['msg_id'] = rec['msg_id']
125 125 for key in keys:
126 126 d[key] = rec[key]
127 127 return d
128 128
129 129 def add_record(self, msg_id, rec):
130 130 """Add a new Task Record, by msg_id."""
131 131 if self._records.has_key(msg_id):
132 132 raise KeyError("Already have msg_id %r"%(msg_id))
133 133 self._records[msg_id] = rec
134 134
135 135 def get_record(self, msg_id):
136 136 """Get a specific Task Record, by msg_id."""
137 137 if not self._records.has_key(msg_id):
138 138 raise KeyError("No such msg_id %r"%(msg_id))
139 139 return self._records[msg_id]
140 140
141 141 def update_record(self, msg_id, rec):
142 142 """Update the data in an existing record."""
143 143 self._records[msg_id].update(rec)
144 144
145 145 def drop_matching_records(self, check):
146 146 """Remove a record from the DB."""
147 147 matches = self._match(check)
148 148 for m in matches:
149 del self._records[m]
149 del self._records[m['msg_id']]
150 150
151 151 def drop_record(self, msg_id):
152 152 """Remove a record from the DB."""
153 153 del self._records[msg_id]
154 154
155 155
156 156 def find_records(self, check, keys=None):
157 157 """Find records matching a query dict, optionally extracting subset of keys.
158 158
159 159 Returns list of matching records.
160 160
161 161 Parameters
162 162 ----------
163 163
164 164 check: dict
165 165 mongodb-style query argument
166 166 keys: list of strs [optional]
167 167 if specified, the subset of keys to extract. msg_id will *always* be
168 168 included.
169 169 """
170 170 matches = self._match(check)
171 171 if keys:
172 172 return [ self._extract_subdict(rec, keys) for rec in matches ]
173 173 else:
174 174 return matches
175 175
176 176
177 177 def get_history(self):
178 178 """get all msg_ids, ordered by time submitted."""
179 179 msg_ids = self._records.keys()
180 180 return sorted(msg_ids, key=lambda m: self._records[m]['submitted'])
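
To make the interface concrete, a short usage sketch of DictDB with a mongodb-style query (record values invented for illustration; real records carry the full key set from the module docstring):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('msg-1', {'msg_id': 'msg-1',
                            'submitted': datetime(2011, 1, 1), 'completed': None})
    db.add_record('msg-2', {'msg_id': 'msg-2',
                            'submitted': datetime(2011, 1, 2),
                            'completed': datetime(2011, 1, 3)})

    pending = db.find_records({'completed': None})        # [msg-1 record]
    done = db.find_records({'completed': {'$ne': None}})  # [msg-2 record]
    print(db.get_history())  # ['msg-1', 'msg-2'], ordered by 'submitted'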
@@ -1,1284 +1,1282 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29 29
30 30 from IPython.parallel import error, util
31 31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32 32
33 33 from .heartmonitor import HeartMonitor
34 34
35 35 #-----------------------------------------------------------------------------
36 36 # Code
37 37 #-----------------------------------------------------------------------------
38 38
39 39 def _passer(*args, **kwargs):
40 40 return
41 41
42 42 def _printer(*args, **kwargs):
43 43 print (args)
44 44 print (kwargs)
45 45
46 46 def empty_record():
47 47 """Return an empty dict with all record keys."""
48 48 return {
49 49 'msg_id' : None,
50 50 'header' : None,
51 51 'content': None,
52 52 'buffers': None,
53 53 'submitted': None,
54 54 'client_uuid' : None,
55 55 'engine_uuid' : None,
56 56 'started': None,
57 57 'completed': None,
58 58 'resubmitted': None,
59 59 'result_header' : None,
60 60 'result_content' : None,
61 61 'result_buffers' : None,
62 62 'queue' : None,
63 63 'pyin' : None,
64 64 'pyout': None,
65 65 'pyerr': None,
66 66 'stdout': '',
67 67 'stderr': '',
68 68 }
69 69
70 70 def init_record(msg):
71 71 """Initialize a TaskRecord based on a request."""
72 72 header = msg['header']
73 73 return {
74 74 'msg_id' : header['msg_id'],
75 75 'header' : header,
76 76 'content': msg['content'],
77 77 'buffers': msg['buffers'],
78 78 'submitted': datetime.strptime(header['date'], util.ISO8601),
79 79 'client_uuid' : None,
80 80 'engine_uuid' : None,
81 81 'started': None,
82 82 'completed': None,
83 83 'resubmitted': None,
84 84 'result_header' : None,
85 85 'result_content' : None,
86 86 'result_buffers' : None,
87 87 'queue' : None,
88 88 'pyin' : None,
89 89 'pyout': None,
90 90 'pyerr': None,
91 91 'stdout': '',
92 92 'stderr': '',
93 93 }
94 94
95 95
96 96 class EngineConnector(HasTraits):
97 97 """A simple object for accessing the various zmq connections of an object.
98 98 Attributes are:
99 99 id (int): engine ID
100 100 uuid (str): uuid (unused?)
101 101 queue (str): identity of queue's XREQ socket
102 102 registration (str): identity of registration XREQ socket
103 103 heartbeat (str): identity of heartbeat XREQ socket
104 104 """
105 105 id=Int(0)
106 106 queue=Str()
107 107 control=Str()
108 108 registration=Str()
109 109 heartbeat=Str()
110 110 pending=Set()
111 111
112 112 class HubFactory(RegistrationFactory):
113 113 """The Configurable for setting up a Hub."""
114 114
115 115 # name of a scheduler scheme
116 116 scheme = Str('leastload', config=True)
117 117
118 118 # port-pairs for monitoredqueues:
119 119 hb = Instance(list, config=True)
120 120 def _hb_default(self):
121 121 return util.select_random_ports(2)
122 122
123 123 mux = Instance(list, config=True)
124 124 def _mux_default(self):
125 125 return util.select_random_ports(2)
126 126
127 127 task = Instance(list, config=True)
128 128 def _task_default(self):
129 129 return util.select_random_ports(2)
130 130
131 131 control = Instance(list, config=True)
132 132 def _control_default(self):
133 133 return util.select_random_ports(2)
134 134
135 135 iopub = Instance(list, config=True)
136 136 def _iopub_default(self):
137 137 return util.select_random_ports(2)
138 138
139 139 # single ports:
140 140 mon_port = Instance(int, config=True)
141 141 def _mon_port_default(self):
142 142 return util.select_random_ports(1)[0]
143 143
144 144 notifier_port = Instance(int, config=True)
145 145 def _notifier_port_default(self):
146 146 return util.select_random_ports(1)[0]
147 147
148 148 ping = Int(1000, config=True) # ping frequency
149 149
150 150 engine_ip = CStr('127.0.0.1', config=True)
151 151 engine_transport = CStr('tcp', config=True)
152 152
153 153 client_ip = CStr('127.0.0.1', config=True)
154 154 client_transport = CStr('tcp', config=True)
155 155
156 156 monitor_ip = CStr('127.0.0.1', config=True)
157 157 monitor_transport = CStr('tcp', config=True)
158 158
159 159 monitor_url = CStr('')
160 160
161 161 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
162 162
163 163 # not configurable
164 164 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
165 165 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
166 166 subconstructors = List()
167 167 _constructed = Bool(False)
168 168
169 169 def _ip_changed(self, name, old, new):
170 170 self.engine_ip = new
171 171 self.client_ip = new
172 172 self.monitor_ip = new
173 173 self._update_monitor_url()
174 174
175 175 def _update_monitor_url(self):
176 176 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
177 177
178 178 def _transport_changed(self, name, old, new):
179 179 self.engine_transport = new
180 180 self.client_transport = new
181 181 self.monitor_transport = new
182 182 self._update_monitor_url()
183 183
184 184 def __init__(self, **kwargs):
185 185 super(HubFactory, self).__init__(**kwargs)
186 186 self._update_monitor_url()
187 187 # self.on_trait_change(self._sync_ips, 'ip')
188 188 # self.on_trait_change(self._sync_transports, 'transport')
189 189 self.subconstructors.append(self.construct_hub)
190 190
191 191
192 192 def construct(self):
193 193 assert not self._constructed, "already constructed!"
194 194
195 195 for subc in self.subconstructors:
196 196 subc()
197 197
198 198 self._constructed = True
199 199
200 200
201 201 def start(self):
202 202 assert self._constructed, "must be constructed by self.construct() first!"
203 203 self.heartmonitor.start()
204 204 self.log.info("Heartmonitor started")
205 205
206 206 def construct_hub(self):
207 207 """construct"""
208 208 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
209 209 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
210 210
211 211 ctx = self.context
212 212 loop = self.loop
213 213
214 214 # Registrar socket
215 215 q = ZMQStream(ctx.socket(zmq.XREP), loop)
216 216 q.bind(client_iface % self.regport)
217 217 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
218 218 if self.client_ip != self.engine_ip:
219 219 q.bind(engine_iface % self.regport)
220 220 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
221 221
222 222 ### Engine connections ###
223 223
224 224 # heartbeat
225 225 hpub = ctx.socket(zmq.PUB)
226 226 hpub.bind(engine_iface % self.hb[0])
227 227 hrep = ctx.socket(zmq.XREP)
228 228 hrep.bind(engine_iface % self.hb[1])
229 229 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
230 230 period=self.ping, logname=self.log.name)
231 231
232 232 ### Client connections ###
233 233 # Notifier socket
234 234 n = ZMQStream(ctx.socket(zmq.PUB), loop)
235 235 n.bind(client_iface%self.notifier_port)
236 236
237 237 ### build and launch the queues ###
238 238
239 239 # monitor socket
240 240 sub = ctx.socket(zmq.SUB)
241 241 sub.setsockopt(zmq.SUBSCRIBE, "")
242 242 sub.bind(self.monitor_url)
243 243 sub.bind('inproc://monitor')
244 244 sub = ZMQStream(sub, loop)
245 245
246 246 # connect the db
247 247 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
248 248 # cdir = self.config.Global.cluster_dir
249 249 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
250 250 time.sleep(.25)
251 251
252 252 # build connection dicts
253 253 self.engine_info = {
254 254 'control' : engine_iface%self.control[1],
255 255 'mux': engine_iface%self.mux[1],
256 256 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
257 257 'task' : engine_iface%self.task[1],
258 258 'iopub' : engine_iface%self.iopub[1],
259 259 # 'monitor' : engine_iface%self.mon_port,
260 260 }
261 261
262 262 self.client_info = {
263 263 'control' : client_iface%self.control[0],
264 264 'mux': client_iface%self.mux[0],
265 265 'task' : (self.scheme, client_iface%self.task[0]),
266 266 'iopub' : client_iface%self.iopub[0],
267 267 'notification': client_iface%self.notifier_port
268 268 }
269 269 self.log.debug("Hub engine addrs: %s"%self.engine_info)
270 270 self.log.debug("Hub client addrs: %s"%self.client_info)
271 271
272 272 # resubmit stream
273 273 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
274 274 url = util.disambiguate_url(self.client_info['task'][-1])
275 275 r.setsockopt(zmq.IDENTITY, self.session.session)
276 276 r.connect(url)
277 277
278 278 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
279 279 query=q, notifier=n, resubmit=r, db=self.db,
280 280 engine_info=self.engine_info, client_info=self.client_info,
281 281 logname=self.log.name)
282 282
283 283
284 284 class Hub(LoggingFactory):
285 285 """The IPython Controller Hub with 0MQ connections
286 286
287 287 Parameters
288 288 ==========
289 289 loop: zmq IOLoop instance
290 290 session: StreamSession object
291 291 <removed> context: zmq context for creating new connections (?)
292 292 queue: ZMQStream for monitoring the command queue (SUB)
293 293 query: ZMQStream for engine registration and client queries requests (XREP)
294 294 heartbeat: HeartMonitor object checking the pulse of the engines
295 295 notifier: ZMQStream for broadcasting engine registration changes (PUB)
296 296 db: connection to db for out of memory logging of commands
297 297 NotImplemented
298 298 engine_info: dict of zmq connection information for engines to connect
299 299 to the queues.
300 300 client_info: dict of zmq connection information for engines to connect
301 301 to the queues.
302 302 """
303 303 # internal data structures:
304 304 ids=Set() # engine IDs
305 305 keytable=Dict()
306 306 by_ident=Dict()
307 307 engines=Dict()
308 308 clients=Dict()
309 309 hearts=Dict()
310 310 pending=Set()
311 311 queues=Dict() # pending msg_ids keyed by engine_id
312 312 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
313 313 completed=Dict() # completed msg_ids keyed by engine_id
314 314 all_completed=Set() # completed msg_ids (across all engines)
315 315 dead_engines=Set() # uuids of engines that have unregistered
316 316 unassigned=Set() # set of task msg_ids not yet assigned a destination
317 317 incoming_registrations=Dict()
318 318 registration_timeout=Int()
319 319 _idcounter=Int(0)
320 320
321 321 # objects from constructor:
322 322 loop=Instance(ioloop.IOLoop)
323 323 query=Instance(ZMQStream)
324 324 monitor=Instance(ZMQStream)
325 325 notifier=Instance(ZMQStream)
326 326 resubmit=Instance(ZMQStream)
327 327 heartmonitor=Instance(HeartMonitor)
328 328 db=Instance(object)
329 329 client_info=Dict()
330 330 engine_info=Dict()
331 331
332 332
333 333 def __init__(self, **kwargs):
334 334 """
335 335 # universal:
336 336 loop: IOLoop for creating future connections
337 337 session: streamsession for sending serialized data
338 338 # engine:
339 339 queue: ZMQStream for monitoring queue messages
340 340 query: ZMQStream for engine+client registration and client requests
341 341 heartbeat: HeartMonitor object for tracking engines
342 342 # extra:
343 343 db: ZMQStream for db connection (NotImplemented)
344 344 engine_info: zmq address/protocol dict for engine connections
345 345 client_info: zmq address/protocol dict for client connections
346 346 """
347 347
348 348 super(Hub, self).__init__(**kwargs)
349 349 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
350 350
351 351 # validate connection dicts:
352 352 for k,v in self.client_info.iteritems():
353 353 if k == 'task':
354 354 util.validate_url_container(v[1])
355 355 else:
356 356 util.validate_url_container(v)
357 357 # util.validate_url_container(self.client_info)
358 358 util.validate_url_container(self.engine_info)
359 359
360 360 # register our callbacks
361 361 self.query.on_recv(self.dispatch_query)
362 362 self.monitor.on_recv(self.dispatch_monitor_traffic)
363 363
364 364 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
365 365 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
366 366
367 367 self.monitor_handlers = { 'in' : self.save_queue_request,
368 368 'out': self.save_queue_result,
369 369 'intask': self.save_task_request,
370 370 'outtask': self.save_task_result,
371 371 'tracktask': self.save_task_destination,
372 372 'incontrol': _passer,
373 373 'outcontrol': _passer,
374 374 'iopub': self.save_iopub_message,
375 375 }
376 376
377 377 self.query_handlers = {'queue_request': self.queue_status,
378 378 'result_request': self.get_results,
379 379 'history_request': self.get_history,
380 380 'db_request': self.db_query,
381 381 'purge_request': self.purge_results,
382 382 'load_request': self.check_load,
383 383 'resubmit_request': self.resubmit_task,
384 384 'shutdown_request': self.shutdown_request,
385 385 'registration_request' : self.register_engine,
386 386 'unregistration_request' : self.unregister_engine,
387 387 'connection_request': self.connection_request,
388 388 }
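# a client query is a session message whose msg_type selects one of the
# handlers above, e.g. msg_type 'queue_request' with content
# {'targets': [0, 1], 'verbose': False} routes to self.queue_status
# (ids and flags here are illustrative).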
389 389
390 390 # ignore resubmit replies
391 391 self.resubmit.on_recv(lambda msg: None, copy=False)
392 392
393 393 self.log.info("hub::created hub")
394 394
395 395 @property
396 396 def _next_id(self):
397 397 """gemerate a new ID.
398 398
399 399 No longer reuse old ids, just count from 0."""
400 400 newid = self._idcounter
401 401 self._idcounter += 1
402 402 return newid
403 403 # newid = 0
404 404 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
405 405 # # print newid, self.ids, self.incoming_registrations
406 406 # while newid in self.ids or newid in incoming:
407 407 # newid += 1
408 408 # return newid
409 409
410 410 #-----------------------------------------------------------------------------
411 411 # message validation
412 412 #-----------------------------------------------------------------------------
413 413
414 414 def _validate_targets(self, targets):
415 415 """turn any valid targets argument into a list of integer ids"""
416 416 if targets is None:
417 417 # default to all
418 418 targets = self.ids
419 419
420 420 if isinstance(targets, (int,str,unicode)):
421 421 # only one target specified
422 422 targets = [targets]
423 423 _targets = []
424 424 for t in targets:
425 425 # map raw identities to ids
426 426 if isinstance(t, (str,unicode)):
427 427 t = self.by_ident.get(t, t)
428 428 _targets.append(t)
429 429 targets = _targets
430 430 bad_targets = [ t for t in targets if t not in self.ids ]
431 431 if bad_targets:
432 432 raise IndexError("No Such Engine: %r"%bad_targets)
433 433 if not targets:
434 434 raise IndexError("No Engines Registered")
435 435 return targets
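# e.g. (made-up ids): _validate_targets(None) returns all registered ids;
# _validate_targets('<queue-identity>') maps the raw identity to its eid via
# self.by_ident; _validate_targets(0) returns [0], or raises IndexError if
# engine 0 is not registered.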
436 436
437 437 #-----------------------------------------------------------------------------
438 438 # dispatch methods (1 per stream)
439 439 #-----------------------------------------------------------------------------
440 440
441 441 # def dispatch_registration_request(self, msg):
442 442 # """"""
443 443 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
444 444 # idents,msg = self.session.feed_identities(msg)
445 445 # if not idents:
446 446 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
447 447 # return
448 448 # try:
449 449 # msg = self.session.unpack_message(msg,content=True)
450 450 # except:
451 451 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
452 452 # return
453 453 #
454 454 # msg_type = msg['msg_type']
455 455 # content = msg['content']
456 456 #
457 457 # handler = self.query_handlers.get(msg_type, None)
458 458 # if handler is None:
459 459 # self.log.error("registration::got bad registration message: %s"%msg)
460 460 # else:
461 461 # handler(idents, msg)
462 462
463 463 def dispatch_monitor_traffic(self, msg):
464 464 """all ME and Task queue messages come through here, as well as
465 465 IOPub traffic."""
466 466 self.log.debug("monitor traffic: %r"%msg[:2])
467 467 switch = msg[0]
468 468 idents, msg = self.session.feed_identities(msg[1:])
469 469 if not idents:
470 470 self.log.error("Bad Monitor Message: %r"%msg)
471 471 return
472 472 handler = self.monitor_handlers.get(switch, None)
473 473 if handler is not None:
474 474 handler(idents, msg)
475 475 else:
476 476 self.log.error("Invalid monitor topic: %r"%switch)
477 477
478 478
479 479 def dispatch_query(self, msg):
480 480 """Route registration requests and queries from clients."""
481 481 idents, msg = self.session.feed_identities(msg)
482 482 if not idents:
483 483 self.log.error("Bad Query Message: %r"%msg)
484 484 return
485 485 client_id = idents[0]
486 486 try:
487 487 msg = self.session.unpack_message(msg, content=True)
488 488 except:
489 489 content = error.wrap_exception()
490 490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
491 491 self.session.send(self.query, "hub_error", ident=client_id,
492 492 content=content)
493 493 return
494 494
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 497 msg_type = msg['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
501 501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 505 self.session.send(self.query, "hub_error", ident=client_id,
506 506 content=content)
507 507 return
508 508
509 509 else:
510 510 handler(idents, msg)
511 511
512 512 def dispatch_db(self, msg):
513 513 """"""
514 514 raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538 eid = self.hearts.get(heart, None)
539 539 queue = self.engines[eid].queue if eid is not None else None # engines[eid] would KeyError for unknown hearts
540 540 if eid is None:
541 541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 542 else:
543 543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
547 547 def save_queue_request(self, idents, msg):
548 548 if len(idents) < 2:
549 549 self.log.error("invalid identity prefix: %s"%idents)
550 550 return
551 551 queue_id, client_id = idents[:2]
552 552 try:
553 553 msg = self.session.unpack_message(msg, content=False)
554 554 except:
555 555 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
556 556 return
557 557
558 558 eid = self.by_ident.get(queue_id, None)
559 559 if eid is None:
560 560 self.log.error("queue::target %r not registered"%queue_id)
561 561 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
562 562 return
563 563
564 564 header = msg['header']
565 565 msg_id = header['msg_id']
566 566 record = init_record(msg)
567 567 record['engine_uuid'] = queue_id
568 568 record['client_uuid'] = client_id
569 569 record['queue'] = 'mux'
570 570
571 571 try:
572 572 # it's possible iopub arrived first:
573 573 existing = self.db.get_record(msg_id)
574 574 for key,evalue in existing.iteritems():
575 575 rvalue = record.get(key, None)
576 576 if evalue and rvalue and evalue != rvalue:
577 577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
578 578 elif evalue and not rvalue:
579 579 record[key] = evalue
580 580 self.db.update_record(msg_id, record)
581 581 except KeyError:
582 582 self.db.add_record(msg_id, record)
583 583
584 584 self.pending.add(msg_id)
585 585 self.queues[eid].append(msg_id)
586 586
587 587 def save_queue_result(self, idents, msg):
588 588 if len(idents) < 2:
589 589 self.log.error("invalid identity prefix: %s"%idents)
590 590 return
591 591
592 592 client_id, queue_id = idents[:2]
593 593 try:
594 594 msg = self.session.unpack_message(msg, content=False)
595 595 except:
596 596 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
597 597 queue_id,client_id, msg), exc_info=True)
598 598 return
599 599
600 600 eid = self.by_ident.get(queue_id, None)
601 601 if eid is None:
602 602 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
603 603 # self.log.debug("queue:: %s"%msg[2:])
604 604 return
605 605
606 606 parent = msg['parent_header']
607 607 if not parent:
608 608 return
609 609 msg_id = parent['msg_id']
610 610 if msg_id in self.pending:
611 611 self.pending.remove(msg_id)
612 612 self.all_completed.add(msg_id)
613 613 self.queues[eid].remove(msg_id)
614 614 self.completed[eid].append(msg_id)
615 615 elif msg_id not in self.all_completed:
616 616 # it could be a result from a dead engine that died before delivering the
617 617 # result
618 618 self.log.warn("queue:: unknown msg finished %s"%msg_id)
619 619 return
620 620 # update record anyway, because the unregistration could have been premature
621 621 rheader = msg['header']
622 622 completed = datetime.strptime(rheader['date'], util.ISO8601)
623 623 started = rheader.get('started', None)
624 624 if started is not None:
625 625 started = datetime.strptime(started, util.ISO8601)
626 626 result = {
627 627 'result_header' : rheader,
628 628 'result_content': msg['content'],
629 629 'started' : started,
630 630 'completed' : completed
631 631 }
632 632
633 633 result['result_buffers'] = msg['buffers']
634 634 try:
635 635 self.db.update_record(msg_id, result)
636 636 except Exception:
637 637 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
638 638
639 639
640 640 #--------------------- Task Queue Traffic ------------------------------
641 641
642 642 def save_task_request(self, idents, msg):
643 643 """Save the submission of a task."""
644 644 client_id = idents[0]
645 645
646 646 try:
647 647 msg = self.session.unpack_message(msg, content=False)
648 648 except:
649 649 self.log.error("task::client %r sent invalid task message: %s"%(
650 650 client_id, msg), exc_info=True)
651 651 return
652 652 record = init_record(msg)
653 653
654 654 record['client_uuid'] = client_id
655 655 record['queue'] = 'task'
656 656 header = msg['header']
657 657 msg_id = header['msg_id']
658 658 self.pending.add(msg_id)
659 659 self.unassigned.add(msg_id)
660 660 try:
661 661 # it's possible iopub arrived first:
662 662 existing = self.db.get_record(msg_id)
663 663 if existing['resubmitted']:
664 664 for key in ('submitted', 'client_uuid', 'buffers'):
665 665 # don't clobber these keys on resubmit
666 666 # submitted and client_uuid should be different
667 667 # and buffers might be big, and shouldn't have changed
668 668 record.pop(key)
669 669 # still check content,header which should not change
670 670 # but are not as expensive to compare as buffers
671 671
672 672 for key,evalue in existing.iteritems():
673 673 if key.endswith('buffers'):
674 674 # don't compare buffers
675 675 continue
676 676 rvalue = record.get(key, None)
677 677 if evalue and rvalue and evalue != rvalue:
678 678 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
679 679 elif evalue and not rvalue:
680 680 record[key] = evalue
681 681 self.db.update_record(msg_id, record)
682 682 except KeyError:
683 683 self.db.add_record(msg_id, record)
684 684 except Exception:
685 685 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
686 686
687 687 def save_task_result(self, idents, msg):
688 688 """save the result of a completed task."""
689 689 client_id = idents[0]
690 690 try:
691 691 msg = self.session.unpack_message(msg, content=False)
692 692 except:
693 693 self.log.error("task::invalid task result message send to %r: %s"%(
694 694 client_id, msg), exc_info=True)
696 696 return
697 697
698 698 parent = msg['parent_header']
699 699 if not parent:
700 700 # print msg
701 701 self.log.warn("Task %r had no parent!"%msg)
702 702 return
703 703 msg_id = parent['msg_id']
704 704 if msg_id in self.unassigned:
705 705 self.unassigned.remove(msg_id)
706 706
707 707 header = msg['header']
708 708 engine_uuid = header.get('engine', None)
709 709 eid = self.by_ident.get(engine_uuid, None)
710 710
711 711 if msg_id in self.pending:
712 712 self.pending.remove(msg_id)
713 713 self.all_completed.add(msg_id)
714 714 if eid is not None:
715 715 self.completed[eid].append(msg_id)
716 716 if msg_id in self.tasks[eid]:
717 717 self.tasks[eid].remove(msg_id)
718 718 completed = datetime.strptime(header['date'], util.ISO8601)
719 719 started = header.get('started', None)
720 720 if started is not None:
721 721 started = datetime.strptime(started, util.ISO8601)
722 722 result = {
723 723 'result_header' : header,
724 724 'result_content': msg['content'],
725 725 'started' : started,
726 726 'completed' : completed,
727 727 'engine_uuid': engine_uuid
728 728 }
729 729
730 730 result['result_buffers'] = msg['buffers']
731 731 try:
732 732 self.db.update_record(msg_id, result)
733 733 except Exception:
734 734 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
735 735
736 736 else:
737 737 self.log.debug("task::unknown task %s finished"%msg_id)
738 738
739 739 def save_task_destination(self, idents, msg):
740 740 try:
741 741 msg = self.session.unpack_message(msg, content=True)
742 742 except:
743 743 self.log.error("task::invalid task tracking message", exc_info=True)
744 744 return
745 745 content = msg['content']
746 746 # print (content)
747 747 msg_id = content['msg_id']
748 748 engine_uuid = content['engine_id']
749 749 eid = self.by_ident[engine_uuid]
750 750
751 751 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
752 752 if msg_id in self.unassigned:
753 753 self.unassigned.remove(msg_id)
754 754 # else:
755 755 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
756 756
757 757 self.tasks[eid].append(msg_id)
758 758 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
759 759 try:
760 760 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
761 761 except Exception:
762 762 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
763 763
764 764
765 765 def mia_task_request(self, idents, msg):
766 766 raise NotImplementedError
767 767 client_id = idents[0]
768 768 # content = dict(mia=self.mia,status='ok')
769 769 # self.session.send('mia_reply', content=content, idents=client_id)
770 770
771 771
772 772 #--------------------- IOPub Traffic ------------------------------
773 773
774 774 def save_iopub_message(self, topics, msg):
775 775 """save an iopub message into the db"""
776 776 # print (topics)
777 777 try:
778 778 msg = self.session.unpack_message(msg, content=True)
779 779 except:
780 780 self.log.error("iopub::invalid IOPub message", exc_info=True)
781 781 return
782 782
783 783 parent = msg['parent_header']
784 784 if not parent:
785 785 self.log.error("iopub::invalid IOPub message: %s"%msg)
786 786 return
787 787 msg_id = parent['msg_id']
788 788 msg_type = msg['msg_type']
789 789 content = msg['content']
790 790
791 791 # ensure msg_id is in db
792 792 try:
793 793 rec = self.db.get_record(msg_id)
794 794 except KeyError:
795 795 rec = empty_record()
796 796 rec['msg_id'] = msg_id
797 797 self.db.add_record(msg_id, rec)
798 798 # stream
799 799 d = {}
800 800 if msg_type == 'stream':
801 801 name = content['name']
802 802 s = rec[name] or ''
803 803 d[name] = s + content['data']
804 804
805 805 elif msg_type == 'pyerr':
806 806 d['pyerr'] = content
807 807 elif msg_type == 'pyin':
808 808 d['pyin'] = content['code']
809 809 else:
810 810 d[msg_type] = content.get('data', '')
811 811
812 812 try:
813 813 self.db.update_record(msg_id, d)
814 814 except Exception:
815 815 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
816 816
817 817
818 818
819 819 #-------------------------------------------------------------------------
820 820 # Registration requests
821 821 #-------------------------------------------------------------------------
822 822
823 823 def connection_request(self, client_id, msg):
824 824 """Reply with connection addresses for clients."""
825 825 self.log.info("client::client %s connected"%client_id)
826 826 content = dict(status='ok')
827 827 content.update(self.client_info)
828 828 jsonable = {}
829 829 for k,v in self.keytable.iteritems():
830 830 if v not in self.dead_engines:
831 831 jsonable[str(k)] = v
832 832 content['engines'] = jsonable
833 833 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
834 834
835 835 def register_engine(self, reg, msg):
836 836 """Register a new engine."""
837 837 content = msg['content']
838 838 try:
839 839 queue = content['queue']
840 840 except KeyError:
841 841 self.log.error("registration::queue not specified", exc_info=True)
842 842 return
843 843 heart = content.get('heartbeat', None)
844 844 """register a new engine, and create the socket(s) necessary"""
845 845 eid = self._next_id
846 846 # print (eid, queue, reg, heart)
847 847
848 848 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
849 849
850 850 content = dict(id=eid,status='ok')
851 851 content.update(self.engine_info)
852 852 # check if requesting available IDs:
853 853 if queue in self.by_ident:
854 854 try:
855 855 raise KeyError("queue_id %r in use"%queue)
856 856 except:
857 857 content = error.wrap_exception()
858 858 self.log.error("queue_id %r in use"%queue, exc_info=True)
859 859 elif heart in self.hearts: # need to check unique hearts?
860 860 try:
861 861 raise KeyError("heart_id %r in use"%heart)
862 862 except:
863 863 self.log.error("heart_id %r in use"%heart, exc_info=True)
864 864 content = error.wrap_exception()
865 865 else:
866 866 for h, pack in self.incoming_registrations.iteritems():
867 867 if heart == h:
868 868 try:
869 869 raise KeyError("heart_id %r in use"%heart)
870 870 except:
871 871 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 872 content = error.wrap_exception()
873 873 break
874 874 elif queue == pack[1]:
875 875 try:
876 876 raise KeyError("queue_id %r in use"%queue)
877 877 except:
878 878 self.log.error("queue_id %r in use"%queue, exc_info=True)
879 879 content = error.wrap_exception()
880 880 break
881 881
882 882 msg = self.session.send(self.query, "registration_reply",
883 883 content=content,
884 884 ident=reg)
885 885
886 886 if content['status'] == 'ok':
887 887 if heart in self.heartmonitor.hearts:
888 888 # already beating
889 889 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
890 890 self.finish_registration(heart)
891 891 else:
892 892 purge = lambda : self._purge_stalled_registration(heart)
893 893 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
894 894 dc.start()
895 895 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
896 896 else:
897 897 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
898 898 return eid
899 899
900 900 def unregister_engine(self, ident, msg):
901 901 """Unregister an engine that explicitly requested to leave."""
902 902 try:
903 903 eid = msg['content']['id']
904 904 except:
905 905 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
906 906 return
907 907 self.log.info("registration::unregister_engine(%s)"%eid)
908 908 # print (eid)
909 909 uuid = self.keytable[eid]
910 910 content=dict(id=eid, queue=uuid)
911 911 self.dead_engines.add(uuid)
912 912 # self.ids.remove(eid)
913 913 # uuid = self.keytable.pop(eid)
914 914 #
915 915 # ec = self.engines.pop(eid)
916 916 # self.hearts.pop(ec.heartbeat)
917 917 # self.by_ident.pop(ec.queue)
918 918 # self.completed.pop(eid)
919 919 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
920 920 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
921 921 dc.start()
922 922 ############## TODO: HANDLE IT ################
923 923
924 924 if self.notifier:
925 925 self.session.send(self.notifier, "unregistration_notification", content=content)
926 926
927 927 def _handle_stranded_msgs(self, eid, uuid):
928 928 """Handle messages known to be on an engine when the engine unregisters.
929 929
930 930 It is possible that this will fire prematurely - that is, an engine will
931 931 go down after completing a result, and the client will be notified
932 932 that the result failed and later receive the actual result.
933 933 """
934 934
935 935 outstanding = self.queues[eid]
936 936
937 937 for msg_id in outstanding:
938 938 self.pending.remove(msg_id)
939 939 self.all_completed.add(msg_id)
940 940 try:
941 941 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
942 942 except:
943 943 content = error.wrap_exception()
944 944 # build a fake header:
945 945 header = {}
946 946 header['engine'] = uuid
947 947 header['date'] = datetime.now()
948 948 rec = dict(result_content=content, result_header=header, result_buffers=[])
949 949 rec['completed'] = header['date']
950 950 rec['engine_uuid'] = uuid
951 951 try:
952 952 self.db.update_record(msg_id, rec)
953 953 except Exception:
954 954 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
955 955
956 956
957 957 def finish_registration(self, heart):
958 958 """Second half of engine registration, called after our HeartMonitor
959 959 has received a beat from the Engine's Heart."""
960 960 try:
961 961 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
962 962 except KeyError:
963 963 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
964 964 return
965 965 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
966 966 if purge is not None:
967 967 purge.stop()
968 968 control = queue
969 969 self.ids.add(eid)
970 970 self.keytable[eid] = queue
971 971 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
972 972 control=control, heartbeat=heart)
973 973 self.by_ident[queue] = eid
974 974 self.queues[eid] = list()
975 975 self.tasks[eid] = list()
976 976 self.completed[eid] = list()
977 977 self.hearts[heart] = eid
978 978 content = dict(id=eid, queue=self.engines[eid].queue)
979 979 if self.notifier:
980 980 self.session.send(self.notifier, "registration_notification", content=content)
981 981 self.log.info("engine::Engine Connected: %i"%eid)
982 982
983 983 def _purge_stalled_registration(self, heart):
984 984 if heart in self.incoming_registrations:
985 985 eid = self.incoming_registrations.pop(heart)[0]
986 986 self.log.info("registration::purging stalled registration: %i"%eid)
987 987 else:
988 988 pass
989 989
990 990 #-------------------------------------------------------------------------
991 991 # Client Requests
992 992 #-------------------------------------------------------------------------
993 993
994 994 def shutdown_request(self, client_id, msg):
995 995 """handle shutdown request."""
996 996 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
997 997 # also notify other clients of shutdown
998 998 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
999 999 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1000 1000 dc.start()
1001 1001
1002 1002 def _shutdown(self):
1003 1003 self.log.info("hub::hub shutting down.")
1004 1004 time.sleep(0.1)
1005 1005 sys.exit(0)
1006 1006
1007 1007
1008 1008 def check_load(self, client_id, msg):
1009 1009 content = msg['content']
1010 1010 try:
1011 1011 targets = content['targets']
1012 1012 targets = self._validate_targets(targets)
1013 1013 except:
1014 1014 content = error.wrap_exception()
1015 1015 self.session.send(self.query, "hub_error",
1016 1016 content=content, ident=client_id)
1017 1017 return
1018 1018
1019 1019 content = dict(status='ok')
1020 1020 # loads = {}
1021 1021 for t in targets:
1022 1022 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1023 1023 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1024 1024
1025 1025
1026 1026 def queue_status(self, client_id, msg):
1027 1027 """Return the Queue status of one or more targets.
1028 1028 if verbose: return the msg_ids
1029 1029 else: return len of each type.
1030 1030 keys: queue (pending MUX jobs)
1031 1031 tasks (pending Task jobs)
1032 1032 completed (finished jobs from both queues)"""
1033 1033 content = msg['content']
1034 1034 targets = content['targets']
1035 1035 try:
1036 1036 targets = self._validate_targets(targets)
1037 1037 except:
1038 1038 content = error.wrap_exception()
1039 1039 self.session.send(self.query, "hub_error",
1040 1040 content=content, ident=client_id)
1041 1041 return
1042 1042 verbose = content.get('verbose', False)
1043 1043 content = dict(status='ok')
1044 1044 for t in targets:
1045 1045 queue = self.queues[t]
1046 1046 completed = self.completed[t]
1047 1047 tasks = self.tasks[t]
1048 1048 if not verbose:
1049 1049 queue = len(queue)
1050 1050 completed = len(completed)
1051 1051 tasks = len(tasks)
1052 1052 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1053 1053 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1054 1054
1055 1055 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1056 1056
1057 1057 def purge_results(self, client_id, msg):
1058 1058 """Purge results from memory. This method is more valuable before we move
1059 1059 to a DB based message storage mechanism."""
1060 1060 content = msg['content']
1061 1061 msg_ids = content.get('msg_ids', [])
1062 1062 reply = dict(status='ok')
1063 1063 if msg_ids == 'all':
1064 1064 try:
1065 1065 self.db.drop_matching_records(dict(completed={'$ne':None}))
1066 1066 except Exception:
1067 1067 reply = error.wrap_exception()
1068 1068 else:
1069 for msg_id in msg_ids:
1070 if msg_id in self.all_completed:
1071 self.db.drop_record(msg_id)
1072 else:
1073 if msg_id in self.pending:
1074 try:
1075 raise IndexError("msg pending: %r"%msg_id)
1076 except:
1077 reply = error.wrap_exception()
1078 else:
1069 pending = filter(lambda m: m in self.pending, msg_ids)
1070 if pending:
1071 try:
1072 raise IndexError("msg pending: %r"%pending[0])
1073 except:
1074 reply = error.wrap_exception()
1075 else:
1076 try:
1077 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1078 except Exception:
1079 reply = error.wrap_exception()
1080
1081 if reply['status'] == 'ok':
1082 eids = content.get('engine_ids', [])
1083 for eid in eids:
1084 if eid not in self.engines:
1079 1085 try:
1080 raise IndexError("No such msg: %r"%msg_id)
1086 raise IndexError("No such engine: %i"%eid)
1081 1087 except:
1082 1088 reply = error.wrap_exception()
1083 break
1084 eids = content.get('engine_ids', [])
1085 for eid in eids:
1086 if eid not in self.engines:
1089 break
1090 msg_ids = self.completed.pop(eid)
1091 uid = self.engines[eid].queue
1087 1092 try:
1088 raise IndexError("No such engine: %i"%eid)
1089 except:
1093 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1094 except Exception:
1090 1095 reply = error.wrap_exception()
1091 break
1092 msg_ids = self.completed.pop(eid)
1093 uid = self.engines[eid].queue
1094 try:
1095 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1096 except Exception:
1097 reply = error.wrap_exception()
1098 break
1096 break
1099 1097
1100 1098 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1101 1099
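# The purge request content this handler expects looks like (illustrative):
#     {'msg_ids': 'all'}                     # drop every completed record
#     {'msg_ids': ['msg-1', 'msg-2']}        # drop specific completed msgs
#     {'msg_ids': [], 'engine_ids': [0, 1]}  # drop completed records per engine
# pending msg_ids and unknown engine ids produce an error reply instead.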
1102 1100 def resubmit_task(self, client_id, msg):
1103 1101 """Resubmit one or more tasks."""
1104 1102 def finish(reply):
1105 1103 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1106 1104
1107 1105 content = msg['content']
1108 1106 msg_ids = content['msg_ids']
1109 1107 reply = dict(status='ok')
1110 1108 try:
1111 1109 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1112 1110 'header', 'content', 'buffers'])
1113 1111 except Exception:
1114 1112 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1115 1113 return finish(error.wrap_exception())
1116 1114
1117 1115 # validate msg_ids
1118 1116 found_ids = [ rec['msg_id'] for rec in records ]
1119 1117 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1120 1118 if len(records) > len(msg_ids):
1121 1119 try:
1122 1120 raise RuntimeError("DB appears to be in an inconsistent state. "
1123 1121 "More matching records were found than should exist")
1124 1122 except Exception:
1125 1123 return finish(error.wrap_exception())
1126 1124 elif len(records) < len(msg_ids):
1127 1125 missing = [ m for m in msg_ids if m not in found_ids ]
1128 1126 try:
1129 1127 raise KeyError("No such msg(s): %s"%missing)
1130 1128 except KeyError:
1131 1129 return finish(error.wrap_exception())
1132 1130 elif invalid_ids:
1133 1131 msg_id = invalid_ids[0]
1134 1132 try:
1135 1133 raise ValueError("Task %r appears to be inflight"%(msg_id))
1136 1134 except Exception:
1137 1135 return finish(error.wrap_exception())
1138 1136
1139 1137 # clear the existing records
1140 1138 rec = empty_record()
1141 1139 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 1140 rec['resubmitted'] = datetime.now()
1143 1141 rec['queue'] = 'task'
1144 1142 rec['client_uuid'] = client_id[0]
1145 1143 try:
1146 1144 for msg_id in msg_ids:
1147 1145 self.all_completed.discard(msg_id)
1148 1146 self.db.update_record(msg_id, rec)
1149 1147 except Exception:
1150 1148 self.log.error('db::db error updating record', exc_info=True)
1151 1149 reply = error.wrap_exception()
1152 1150 else:
1153 1151 # send the messages
1154 1152 for rec in records:
1155 1153 header = rec['header']
1156 1154 msg = self.session.msg(header['msg_type'])
1157 1155 msg['content'] = rec['content']
1158 1156 msg['header'] = header
1159 1157 msg['msg_id'] = rec['msg_id']
1160 1158 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1161 1159
1162 1160 finish(dict(status='ok'))
1163 1161
1164 1162
1165 1163 def _extract_record(self, rec):
1166 1164 """decompose a TaskRecord dict into subsection of reply for get_result"""
1167 1165 io_dict = {}
1168 1166 for key in 'pyin pyout pyerr stdout stderr'.split():
1169 1167 io_dict[key] = rec[key]
1170 1168 content = { 'result_content': rec['result_content'],
1171 1169 'header': rec['header'],
1172 1170 'result_header' : rec['result_header'],
1173 1171 'io' : io_dict,
1174 1172 }
1175 1173 if rec['result_buffers']:
1176 1174 buffers = map(str, rec['result_buffers'])
1177 1175 else:
1178 1176 buffers = []
1179 1177
1180 1178 return content, buffers
1181 1179
1182 1180 def get_results(self, client_id, msg):
1183 1181 """Get the result of 1 or more messages."""
1184 1182 content = msg['content']
1185 1183 msg_ids = sorted(set(content['msg_ids']))
1186 1184 statusonly = content.get('status_only', False)
1187 1185 pending = []
1188 1186 completed = []
1189 1187 content = dict(status='ok')
1190 1188 content['pending'] = pending
1191 1189 content['completed'] = completed
1192 1190 buffers = []
1193 1191 if not statusonly:
1194 1192 try:
1195 1193 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1196 1194 # turn match list into dict, for faster lookup
1197 1195 records = {}
1198 1196 for rec in matches:
1199 1197 records[rec['msg_id']] = rec
1200 1198 except Exception:
1201 1199 content = error.wrap_exception()
1202 1200 self.session.send(self.query, "result_reply", content=content,
1203 1201 parent=msg, ident=client_id)
1204 1202 return
1205 1203 else:
1206 1204 records = {}
1207 1205 for msg_id in msg_ids:
1208 1206 if msg_id in self.pending:
1209 1207 pending.append(msg_id)
1210 1208 elif msg_id in self.all_completed:
1211 1209 completed.append(msg_id)
1212 1210 if not statusonly:
1213 1211 c,bufs = self._extract_record(records[msg_id])
1214 1212 content[msg_id] = c
1215 1213 buffers.extend(bufs)
1216 1214 elif msg_id in records:
1217 1215 if records[msg_id]['completed']:
1218 1216 completed.append(msg_id)
1219 1217 c,bufs = self._extract_record(records[msg_id])
1220 1218 content[msg_id] = c
1221 1219 buffers.extend(bufs)
1222 1220 else:
1223 1221 pending.append(msg_id)
1224 1222 else:
1225 1223 try:
1226 1224 raise KeyError('No such message: '+msg_id)
1227 1225 except:
1228 1226 content = error.wrap_exception()
1229 1227 break
1230 1228 self.session.send(self.query, "result_reply", content=content,
1231 1229 parent=msg, ident=client_id,
1232 1230 buffers=buffers)
1233 1231
1234 1232 def get_history(self, client_id, msg):
1235 1233 """Get a list of all msg_ids in our DB records"""
1236 1234 try:
1237 1235 msg_ids = self.db.get_history()
1238 1236 except Exception as e:
1239 1237 content = error.wrap_exception()
1240 1238 else:
1241 1239 content = dict(status='ok', history=msg_ids)
1242 1240
1243 1241 self.session.send(self.query, "history_reply", content=content,
1244 1242 parent=msg, ident=client_id)
1245 1243
1246 1244 def db_query(self, client_id, msg):
1247 1245 """Perform a raw query on the task record database."""
1248 1246 content = msg['content']
1249 1247 query = content.get('query', {})
1250 1248 keys = content.get('keys', None)
1251 1249 query = util.extract_dates(query)
1252 1250 buffers = []
1253 1251 empty = list()
1254 1252
1255 1253 try:
1256 1254 records = self.db.find_records(query, keys)
1257 1255 except Exception as e:
1258 1256 content = error.wrap_exception()
1259 1257 else:
1260 1258 # extract buffers from reply content:
1261 1259 if keys is not None:
1262 1260 buffer_lens = [] if 'buffers' in keys else None
1263 1261 result_buffer_lens = [] if 'result_buffers' in keys else None
1264 1262 else:
1265 1263 buffer_lens = []
1266 1264 result_buffer_lens = []
1267 1265
1268 1266 for rec in records:
1269 1267 # buffers may be None, so double check
1270 1268 if buffer_lens is not None:
1271 1269 b = rec.pop('buffers', empty) or empty
1272 1270 buffer_lens.append(len(b))
1273 1271 buffers.extend(b)
1274 1272 if result_buffer_lens is not None:
1275 1273 rb = rec.pop('result_buffers', empty) or empty
1276 1274 result_buffer_lens.append(len(rb))
1277 1275 buffers.extend(rb)
1278 1276 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1279 1277 result_buffer_lens=result_buffer_lens)
1280 1278
1281 1279 self.session.send(self.query, "db_reply", content=content,
1282 1280 parent=msg, ident=client_id,
1283 1281 buffers=buffers)
1284 1282
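
The db_reply sent above flattens buffers out of the records; buffer_lens and result_buffer_lens record how many buffers belong to each record, in order. A minimal client-side sketch of reassembly (the helper name is invented):

    def attach_buffers(records, buffer_lens, result_buffer_lens, buffers):
        """Re-attach flattened buffers to their records, in order."""
        buffers = list(buffers)
        for i, rec in enumerate(records):
            if buffer_lens is not None:
                n = buffer_lens[i]
                rec['buffers'], buffers = buffers[:n], buffers[n:]
            if result_buffer_lens is not None:
                n = result_buffer_lens[i]
                rec['result_buffers'], buffers = buffers[:n], buffers[n:]
        return records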
@@ -1,96 +1,101 b''
1 1 """A TaskRecord backend using mongodb"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 from datetime import datetime
10
11 9 from pymongo import Connection
12 10 from pymongo.binary import Binary
13 11
14 from IPython.utils.traitlets import Dict, List, CUnicode
12 from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance
15 13
16 14 from .dictdb import BaseDB
17 15
18 16 #-----------------------------------------------------------------------------
19 17 # MongoDB class
20 18 #-----------------------------------------------------------------------------
21 19
22 20 class MongoDB(BaseDB):
23 21 """MongoDB TaskRecord backend."""
24 22
25 23 connection_args = List(config=True) # args passed to pymongo.Connection
26 24 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
27 25 database = CUnicode(config=True) # name of the mongodb database
28 _table = Dict()
26
27 _connection = Instance(Connection) # pymongo connection
29 28
30 29 def __init__(self, **kwargs):
31 30 super(MongoDB, self).__init__(**kwargs)
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
31 if self._connection is None:
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
33 33 if not self.database:
34 34 self.database = self.session
35 35 self._db = self._connection[self.database]
36 36 self._records = self._db['task_records']
37 self._records.ensure_index('msg_id', unique=True)
38 self._records.ensure_index('submitted') # for sorting history
37 40
38 41 def _binary_buffers(self, rec):
39 42 for key in ('buffers', 'result_buffers'):
40 43 if rec.get(key, None):
41 44 rec[key] = map(Binary, rec[key])
42 45 return rec
43 46
44 47 def add_record(self, msg_id, rec):
45 48 """Add a new Task Record, by msg_id."""
46 49 # print rec
47 50 rec = self._binary_buffers(rec)
48 obj_id = self._records.insert(rec)
49 self._table[msg_id] = obj_id
51 self._records.insert(rec)
50 52
51 53 def get_record(self, msg_id):
52 54 """Get a specific Task Record, by msg_id."""
53 return self._records.find_one(self._table[msg_id])
55 r = self._records.find_one({'msg_id': msg_id})
56 if not r:
57 # r will be None if nothing is found
58 raise KeyError(msg_id)
59 return r
54 60
55 61 def update_record(self, msg_id, rec):
56 62 """Update the data in an existing record."""
57 63 rec = self._binary_buffers(rec)
58 obj_id = self._table[msg_id]
59 self._records.update({'_id':obj_id}, {'$set': rec})
64
65 self._records.update({'msg_id':msg_id}, {'$set': rec})
60 66
61 67 def drop_matching_records(self, check):
62 68 """Remove a record from the DB."""
63 69 self._records.remove(check)
64 70
65 71 def drop_record(self, msg_id):
66 72 """Remove a record from the DB."""
67 obj_id = self._table.pop(msg_id)
68 self._records.remove(obj_id)
73 self._records.remove({'msg_id':msg_id})
69 74
70 75 def find_records(self, check, keys=None):
71 76 """Find records matching a query dict, optionally extracting subset of keys.
72 77
73 78 Returns list of matching records.
74 79
75 80 Parameters
76 81 ----------
77 82
78 83 check: dict
79 84 mongodb-style query argument
80 85 keys: list of strs [optional]
81 86 if specified, the subset of keys to extract. msg_id will *always* be
82 87 included.
83 88 """
84 89 if keys and 'msg_id' not in keys:
85 90 keys.append('msg_id')
86 91 matches = list(self._records.find(check,keys))
87 92 for rec in matches:
88 93 rec.pop('_id')
89 94 return matches
90 95
91 96 def get_history(self):
92 97 """get all msg_ids, ordered by time submitted."""
93 98 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
94 99 return [ rec['msg_id'] for rec in cursor ]
95 100
96 101
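With records keyed by msg_id (enforced by the unique index above) instead of an in-memory map of mongo ObjectIds, every operation is a plain pymongo query, so records survive controller restarts. A rough sketch of the resulting interface, assuming a local mongod is reachable; the database name and msg_id are made up:

    from datetime import datetime
    from IPython.parallel.controller.mongodb import MongoDB

    db = MongoDB(database='taskdb')
    db.add_record('my-msg-id', {'msg_id': 'my-msg-id', 'completed': None})
    db.update_record('my-msg-id', {'completed': datetime.now()})
    rec = db.get_record('my-msg-id')    # raises KeyError if no such record
    done = db.find_records({'completed': {'$ne': None}}, keys=['msg_id'])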
@@ -1,312 +1,326 b''
1 1 """A TaskRecord backend using sqlite3"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 import json
10 10 import os
11 11 import cPickle as pickle
12 12 from datetime import datetime
13 13
14 14 import sqlite3
15 15
16 16 from zmq.eventloop import ioloop
17 17
18 18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
19 19 from .dictdb import BaseDB
20 20 from IPython.parallel.util import ISO8601
21 21
22 22 #-----------------------------------------------------------------------------
23 23 # SQLite operators, adapters, and converters
24 24 #-----------------------------------------------------------------------------
25 25
26 26 operators = {
27 27 '$lt' : "<",
28 28 '$gt' : ">",
29 29 # null is handled weird with ==,!=
30 '$eq' : "IS",
31 '$ne' : "IS NOT",
30 '$eq' : "=",
31 '$ne' : "!=",
32 32 '$lte': "<=",
33 33 '$gte': ">=",
34 '$in' : ('IS', ' OR '),
35 '$nin': ('IS NOT', ' AND '),
34 '$in' : ('=', ' OR '),
35 '$nin': ('!=', ' AND '),
36 36 # '$all': None,
37 37 # '$mod': None,
38 38 # '$exists' : None
39 39 }
40 null_operators = {
41 '=' : "IS NULL",
42 '!=' : "IS NOT NULL",
43 }
40 44
41 45 def _adapt_datetime(dt):
42 46 return dt.strftime(ISO8601)
43 47
44 48 def _convert_datetime(ds):
45 49 if ds is None:
46 50 return ds
47 51 else:
48 52 return datetime.strptime(ds, ISO8601)
49 53
50 54 def _adapt_dict(d):
51 55 return json.dumps(d)
52 56
53 57 def _convert_dict(ds):
54 58 if ds is None:
55 59 return ds
56 60 else:
57 61 return json.loads(ds)
58 62
59 63 def _adapt_bufs(bufs):
60 64 # this is *horrible*
61 65 # copy buffers into single list and pickle it:
62 66 if bufs and isinstance(bufs[0], (bytes, buffer)):
63 67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
64 68 elif bufs:
65 69 return bufs
66 70 else:
67 71 return None
68 72
69 73 def _convert_bufs(bs):
70 74 if bs is None:
71 75 return []
72 76 else:
73 77 return pickle.loads(bytes(bs))
74 78
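sqlite3 applies these adapter/converter pairs automatically once the connection is opened with detect_types=sqlite3.PARSE_DECLTYPES: the adapter runs on the way in, and the converter named by the first word of the declared column type runs on the way out. A self-contained illustration of the mechanism (the table and column are made up):

    import json
    import sqlite3

    sqlite3.register_adapter(dict, json.dumps)
    sqlite3.register_converter('dict', json.loads)
    db = sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES)
    db.execute("CREATE TABLE t (header dict text)")
    db.execute("INSERT INTO t VALUES (?)", ({'a': 1},))
    print db.execute("SELECT header FROM t").fetchone()[0]  # -> {u'a': 1}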
75 79 #-----------------------------------------------------------------------------
76 80 # SQLiteDB class
77 81 #-----------------------------------------------------------------------------
78 82
79 83 class SQLiteDB(BaseDB):
80 84 """SQLite3 TaskRecord backend."""
81 85
82 86 filename = CUnicode('tasks.db', config=True)
83 87 location = CUnicode('', config=True)
84 88 table = CUnicode("", config=True)
85 89
86 90 _db = Instance('sqlite3.Connection')
87 91 _keys = List(['msg_id' ,
88 92 'header' ,
89 93 'content',
90 94 'buffers',
91 95 'submitted',
92 96 'client_uuid' ,
93 97 'engine_uuid' ,
94 98 'started',
95 99 'completed',
96 100 'resubmitted',
97 101 'result_header' ,
98 102 'result_content' ,
99 103 'result_buffers' ,
100 104 'queue' ,
101 105 'pyin' ,
102 106 'pyout',
103 107 'pyerr',
104 108 'stdout',
105 109 'stderr',
106 110 ])
107 111
108 112 def __init__(self, **kwargs):
109 113 super(SQLiteDB, self).__init__(**kwargs)
110 114 if not self.table:
111 115 # use session, and prefix _, since table names can't start with a digit
112 116 self.table = '_'+self.session.replace('-','_')
113 117 if not self.location:
114 118 if hasattr(self.config.Global, 'cluster_dir'):
115 119 self.location = self.config.Global.cluster_dir
116 120 else:
117 121 self.location = '.'
118 122 self._init_db()
119 123
120 124 # register db commit as 2s periodic callback
121 125 # to prevent clogging pipes
122 126 # assumes we are being run in a zmq ioloop app
123 127 loop = ioloop.IOLoop.instance()
124 128 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
125 129 pc.start()
126 130
127 131 def _defaults(self, keys=None):
128 132 """create an empty record"""
129 133 d = {}
130 134 keys = self._keys if keys is None else keys
131 135 for key in keys:
132 136 d[key] = None
133 137 return d
134 138
135 139 def _init_db(self):
136 140 """Connect to the database and get new session number."""
137 141 # register adapters
138 142 sqlite3.register_adapter(datetime, _adapt_datetime)
139 143 sqlite3.register_converter('datetime', _convert_datetime)
140 144 sqlite3.register_adapter(dict, _adapt_dict)
141 145 sqlite3.register_converter('dict', _convert_dict)
142 146 sqlite3.register_adapter(list, _adapt_bufs)
143 147 sqlite3.register_converter('bufs', _convert_bufs)
144 148 # connect to the db
145 149 dbfile = os.path.join(self.location, self.filename)
146 150 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
147 151 # isolation_level = None)#,
148 152 cached_statements=64)
149 153 # print dir(self._db)
150 154
151 155 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
152 156 (msg_id text PRIMARY KEY,
153 157 header dict text,
154 158 content dict text,
155 159 buffers bufs blob,
156 160 submitted datetime text,
157 161 client_uuid text,
158 162 engine_uuid text,
159 163 started datetime text,
160 164 completed datetime text,
161 165 resubmitted datetime text,
162 166 result_header dict text,
163 167 result_content dict text,
164 168 result_buffers bufs blob,
165 169 queue text,
166 170 pyin text,
167 171 pyout text,
168 172 pyerr text,
169 173 stdout text,
170 174 stderr text)
171 175 """%self.table)
172 176 self._db.commit()
173 177
174 178 def _dict_to_list(self, d):
175 179 """turn a mongodb-style record dict into a list."""
176 180
177 181 return [ d[key] for key in self._keys ]
178 182
179 183 def _list_to_dict(self, line, keys=None):
180 184 """Inverse of dict_to_list"""
181 185 keys = self._keys if keys is None else keys
182 186 d = self._defaults(keys)
183 187 for key,value in zip(keys, line):
184 188 d[key] = value
185 189
186 190 return d
187 191
188 192 def _render_expression(self, check):
189 193 """Turn a mongodb-style search dict into an SQL query."""
190 194 expressions = []
191 195 args = []
192 196
193 197 skeys = set(check.keys())
194 198 skeys.difference_update(set(self._keys))
195 199 skeys.difference_update(set(['buffers', 'result_buffers']))
196 200 if skeys:
197 201 raise KeyError("Illegal testing key(s): %s"%skeys)
198 202
199 203 for name,sub_check in check.iteritems():
200 204 if isinstance(sub_check, dict):
201 205 for test,value in sub_check.iteritems():
202 206 try:
203 207 op = operators[test]
204 208 except KeyError:
205 209 raise KeyError("Unsupported operator: %r"%test)
206 210 if isinstance(op, tuple):
207 211 op, join = op
208 expr = "%s %s ?"%(name, op)
209 if isinstance(value, (tuple,list)):
210 expr = '( %s )'%( join.join([expr]*len(value)) )
211 args.extend(value)
212
213 if value is None and op in null_operators:
214 expr = "%s %s"%(name, null_operators[op])
212 215 else:
213 args.append(value)
216 expr = "%s %s ?"%(name, op)
217 if isinstance(value, (tuple,list)):
218 if op in null_operators and any([v is None for v in value]):
219 # equality tests don't work with NULL
220 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
221 expr = '( %s )'%( join.join([expr]*len(value)) )
222 args.extend(value)
223 else:
224 args.append(value)
214 225 expressions.append(expr)
215 226 else:
216 227 # it's an equality check
217 expressions.append("%s IS ?"%name)
218 args.append(sub_check)
228 if sub_check is None:
229 expressions.append("%s IS NULL"%name)
230 else:
231 expressions.append("%s = ?"%name)
232 args.append(sub_check)
219 233
220 234 expr = " AND ".join(expressions)
221 235 return expr, args
222 236
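The net effect of the NULL handling on rendered queries, sketched informally (tic stands for any datetime value):

    # mongodb-style check             ->  SQL fragment, args
    # {'completed': None}             ->  "completed IS NULL", []
    # {'completed': {'$ne': None}}    ->  "completed IS NOT NULL", []
    # {'completed': {'$lt': tic}}     ->  "completed < ?", [tic]
    # {'msg_id': {'$in': ['a','b']}}  ->  "( msg_id = ? OR msg_id = ? )", ['a', 'b']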
223 237 def add_record(self, msg_id, rec):
224 238 """Add a new Task Record, by msg_id."""
225 239 d = self._defaults()
226 240 d.update(rec)
227 241 d['msg_id'] = msg_id
228 242 line = self._dict_to_list(d)
229 243 tups = '(%s)'%(','.join(['?']*len(line)))
230 244 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
231 245 # self._db.commit()
232 246
233 247 def get_record(self, msg_id):
234 248 """Get a specific Task Record, by msg_id."""
235 249 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
236 250 line = cursor.fetchone()
237 251 if line is None:
238 252 raise KeyError("No such msg: %r"%msg_id)
239 253 return self._list_to_dict(line)
240 254
241 255 def update_record(self, msg_id, rec):
242 256 """Update the data in an existing record."""
243 257 query = "UPDATE %s SET "%self.table
244 258 sets = []
245 259 keys = sorted(rec.keys())
246 260 values = []
247 261 for key in keys:
248 262 sets.append('%s = ?'%key)
249 263 values.append(rec[key])
250 264 query += ', '.join(sets)
251 265 query += ' WHERE msg_id == ?'
252 266 values.append(msg_id)
253 267 self._db.execute(query, values)
254 268 # self._db.commit()
255 269
256 270 def drop_record(self, msg_id):
257 271 """Remove a record from the DB."""
258 272 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
259 273 # self._db.commit()
260 274
261 275 def drop_matching_records(self, check):
262 276 """Remove a record from the DB."""
263 277 expr,args = self._render_expression(check)
264 278 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
265 279 self._db.execute(query,args)
266 280 # self._db.commit()
267 281
268 282 def find_records(self, check, keys=None):
269 283 """Find records matching a query dict, optionally extracting subset of keys.
270 284
271 285 Returns list of matching records.
272 286
273 287 Parameters
274 288 ----------
275 289
276 290 check: dict
277 291 mongodb-style query argument
278 292 keys: list of strs [optional]
279 293 if specified, the subset of keys to extract. msg_id will *always* be
280 294 included.
281 295 """
282 296 if keys:
283 297 bad_keys = [ key for key in keys if key not in self._keys ]
284 298 if bad_keys:
285 299 raise KeyError("Bad record key(s): %s"%bad_keys)
286 300
287 301 if keys:
288 302 # ensure msg_id is present and first:
289 303 if 'msg_id' in keys:
290 304 keys.remove('msg_id')
291 305 keys.insert(0, 'msg_id')
292 306 req = ', '.join(keys)
293 307 else:
294 308 req = '*'
295 309 expr,args = self._render_expression(check)
296 310 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
297 311 cursor = self._db.execute(query, args)
298 312 matches = cursor.fetchall()
299 313 records = []
300 314 for line in matches:
301 315 rec = self._list_to_dict(line, keys)
302 316 records.append(rec)
303 317 return records
304 318
305 319 def get_history(self):
306 320 """get all msg_ids, ordered by time submitted."""
307 321 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
308 322 cursor = self._db.execute(query)
309 323 # will be a list of length 1 tuples
310 324 return [ tup[0] for tup in cursor.fetchall()]
311 325
312 326 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,237 +1,244 b''
1 1 """Tests for parallel client.py"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import time
15 15 from datetime import datetime
16 16 from tempfile import mktemp
17 17
18 18 import zmq
19 19
20 20 from IPython.parallel.client import client as clientmod
21 21 from IPython.parallel import error
22 22 from IPython.parallel import AsyncResult, AsyncHubResult
23 23 from IPython.parallel import LoadBalancedView, DirectView
24 24
25 25 from clienttest import ClusterTestCase, segfault, wait, add_engines
26 26
27 27 def setup():
28 28 add_engines(4)
29 29
30 30 class TestClient(ClusterTestCase):
31 31
32 32 def test_ids(self):
33 33 n = len(self.client.ids)
34 34 self.add_engines(3)
35 35 self.assertEquals(len(self.client.ids), n+3)
36 36
37 37 def test_view_indexing(self):
38 38 """test index access for views"""
39 39 self.add_engines(2)
40 40 targets = self.client._build_targets('all')[-1]
41 41 v = self.client[:]
42 42 self.assertEquals(v.targets, targets)
43 43 t = self.client.ids[2]
44 44 v = self.client[t]
45 45 self.assert_(isinstance(v, DirectView))
46 46 self.assertEquals(v.targets, t)
47 47 t = self.client.ids[2:4]
48 48 v = self.client[t]
49 49 self.assert_(isinstance(v, DirectView))
50 50 self.assertEquals(v.targets, t)
51 51 v = self.client[::2]
52 52 self.assert_(isinstance(v, DirectView))
53 53 self.assertEquals(v.targets, targets[::2])
54 54 v = self.client[1::3]
55 55 self.assert_(isinstance(v, DirectView))
56 56 self.assertEquals(v.targets, targets[1::3])
57 57 v = self.client[:-3]
58 58 self.assert_(isinstance(v, DirectView))
59 59 self.assertEquals(v.targets, targets[:-3])
60 60 v = self.client[-1]
61 61 self.assert_(isinstance(v, DirectView))
62 62 self.assertEquals(v.targets, targets[-1])
63 63 self.assertRaises(TypeError, lambda : self.client[None])
64 64
65 65 def test_lbview_targets(self):
66 66 """test load_balanced_view targets"""
67 67 v = self.client.load_balanced_view()
68 68 self.assertEquals(v.targets, None)
69 69 v = self.client.load_balanced_view(-1)
70 70 self.assertEquals(v.targets, [self.client.ids[-1]])
71 71 v = self.client.load_balanced_view('all')
72 72 self.assertEquals(v.targets, self.client.ids)
73 73
74 74 def test_targets(self):
75 75 """test various valid targets arguments"""
76 76 build = self.client._build_targets
77 77 ids = self.client.ids
78 78 idents,targets = build(None)
79 79 self.assertEquals(ids, targets)
80 80
81 81 def test_clear(self):
82 82 """test clear behavior"""
83 83 # self.add_engines(2)
84 84 v = self.client[:]
85 85 v.block=True
86 86 v.push(dict(a=5))
87 87 v.pull('a')
88 88 id0 = self.client.ids[-1]
89 89 self.client.clear(targets=id0, block=True)
90 90 a = self.client[:-1].get('a')
91 91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
92 92 self.client.clear(block=True)
93 93 for i in self.client.ids:
94 94 # print i
95 95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
96 96
97 97 def test_get_result(self):
98 98 """test getting results from the Hub."""
99 99 c = clientmod.Client(profile='iptest')
100 100 # self.add_engines(1)
101 101 t = c.ids[-1]
102 102 ar = c[t].apply_async(wait, 1)
103 103 # give the monitor time to notice the message
104 104 time.sleep(.25)
105 105 ahr = self.client.get_result(ar.msg_ids)
106 106 self.assertTrue(isinstance(ahr, AsyncHubResult))
107 107 self.assertEquals(ahr.get(), ar.get())
108 108 ar2 = self.client.get_result(ar.msg_ids)
109 109 self.assertFalse(isinstance(ar2, AsyncHubResult))
110 110 c.close()
111 111
112 112 def test_ids_list(self):
113 113 """test client.ids"""
114 114 # self.add_engines(2)
115 115 ids = self.client.ids
116 116 self.assertEquals(ids, self.client._ids)
117 117 self.assertFalse(ids is self.client._ids)
118 118 ids.remove(ids[-1])
119 119 self.assertNotEquals(ids, self.client._ids)
120 120
121 121 def test_queue_status(self):
122 122 # self.addEngine(4)
123 123 ids = self.client.ids
124 124 id0 = ids[0]
125 125 qs = self.client.queue_status(targets=id0)
126 126 self.assertTrue(isinstance(qs, dict))
127 127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
128 128 allqs = self.client.queue_status()
129 129 self.assertTrue(isinstance(allqs, dict))
130 130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
131 131 unassigned = allqs.pop('unassigned')
132 132 for eid,qs in allqs.items():
133 133 self.assertTrue(isinstance(qs, dict))
134 134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135 135
136 136 def test_shutdown(self):
137 137 # self.addEngine(4)
138 138 ids = self.client.ids
139 139 id0 = ids[0]
140 140 self.client.shutdown(id0, block=True)
141 141 while id0 in self.client.ids:
142 142 time.sleep(0.1)
143 143 self.client.spin()
144 144
145 145 self.assertRaises(IndexError, lambda : self.client[id0])
146 146
147 147 def test_result_status(self):
148 148 pass
149 149 # to be written
150 150
151 151 def test_db_query_dt(self):
152 152 """test db query by date"""
153 153 hist = self.client.hub_history()
154 154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
155 155 tic = middle['submitted']
156 156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
157 157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
158 158 self.assertEquals(len(before)+len(after),len(hist))
159 159 for b in before:
160 160 self.assertTrue(b['submitted'] < tic)
161 161 for a in after:
162 162 self.assertTrue(a['submitted'] >= tic)
163 163 same = self.client.db_query({'submitted' : tic})
164 164 for s in same:
165 165 self.assertTrue(s['submitted'] == tic)
166 166
167 167 def test_db_query_keys(self):
168 168 """test extracting subset of record keys"""
169 169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
170 170 for rec in found:
171 171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
172 172
173 173 def test_db_query_msg_id(self):
174 174 """ensure msg_id is always in db queries"""
175 175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
176 176 for rec in found:
177 177 self.assertTrue('msg_id' in rec.keys())
178 178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
179 179 for rec in found:
180 180 self.assertTrue('msg_id' in rec.keys())
181 181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
182 182 for rec in found:
183 183 self.assertTrue('msg_id' in rec.keys())
184 184
185 185 def test_db_query_in(self):
186 186 """test db query with '$in','$nin' operators"""
187 187 hist = self.client.hub_history()
188 188 even = hist[::2]
189 189 odd = hist[1::2]
190 190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
191 191 found = [ r['msg_id'] for r in recs ]
192 192 self.assertEquals(set(even), set(found))
193 193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
194 194 found = [ r['msg_id'] for r in recs ]
195 195 self.assertEquals(set(odd), set(found))
196 196
197 197 def test_hub_history(self):
198 198 hist = self.client.hub_history()
199 199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
200 200 recdict = {}
201 201 for rec in recs:
202 202 recdict[rec['msg_id']] = rec
203 203
204 204 latest = datetime(1984,1,1)
205 205 for msg_id in hist:
206 206 rec = recdict[msg_id]
207 207 newt = rec['submitted']
208 208 self.assertTrue(newt >= latest)
209 209 latest = newt
210 210 ar = self.client[-1].apply_async(lambda : 1)
211 211 ar.get()
212 212 time.sleep(0.25)
213 213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214 214
215 215 def test_resubmit(self):
216 216 def f():
217 217 import random
218 218 return random.random()
219 219 v = self.client.load_balanced_view()
220 220 ar = v.apply_async(f)
221 221 r1 = ar.get(1)
222 222 ahr = self.client.resubmit(ar.msg_ids)
223 223 r2 = ahr.get(1)
224 224 self.assertFalse(r1 == r2)
225 225
226 226 def test_resubmit_inflight(self):
227 227 """ensure ValueError on resubmit of inflight task"""
228 228 v = self.client.load_balanced_view()
229 229 ar = v.apply_async(time.sleep,1)
230 230 # give the message a chance to arrive
231 231 time.sleep(0.2)
232 232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 233 ar.get(2)
234 234
235 235 def test_resubmit_badkey(self):
236 236 """ensure KeyError on resubmit of nonexistant task"""
237 237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
238
239 def test_purge_results(self):
240 hist = self.client.hub_history()
241 self.client.purge_results(hist)
242 newhist = self.client.hub_history()
243 self.assertTrue(len(newhist) == 0)
244
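The Hub-database calls exercised above (hub_history, db_query, resubmit, purge_results) make up the client-side API for the task database. A condensed sketch of typical use, assuming a running cluster under the iptest profile:

    from IPython.parallel import Client

    rc = Client(profile='iptest')
    hist = rc.hub_history()                 # msg_ids, oldest first
    recs = rc.db_query({'msg_id': {'$in': hist[:5]}},
                       keys=['msg_id', 'completed'])
    rc.resubmit(hist[-1:])                  # rerun a completed task
    rc.purge_results(hist)                  # drop the Hub's records of these tasks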
@@ -1,182 +1,170 b''
1 1 """Tests for db backends"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14
15 15 import tempfile
16 16 import time
17 17
18 import uuid
19
20 18 from datetime import datetime, timedelta
21 from random import choice, randint
22 19 from unittest import TestCase
23 20
24 21 from nose import SkipTest
25 22
26 23 from IPython.parallel import error, streamsession as ss
27 24 from IPython.parallel.controller.dictdb import DictDB
28 25 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 26 from IPython.parallel.controller.hub import init_record, empty_record
30 27
31 28 #-------------------------------------------------------------------------------
32 29 # TestCases
33 30 #-------------------------------------------------------------------------------
34 31
35 32 class TestDictBackend(TestCase):
36 33 def setUp(self):
37 34 self.session = ss.StreamSession()
38 35 self.db = self.create_db()
39 36 self.load_records(16)
40 37
41 38 def create_db(self):
42 39 return DictDB()
43 40
44 41 def load_records(self, n=1):
45 42 """load n records for testing"""
46 43 #sleep 1/10 s, to ensure the timestamp differs from previous calls
47 44 time.sleep(0.1)
48 45 msg_ids = []
49 46 for i in range(n):
50 47 msg = self.session.msg('apply_request', content=dict(a=5))
51 48 msg['buffers'] = []
52 49 rec = init_record(msg)
53 50 msg_ids.append(msg['msg_id'])
54 51 self.db.add_record(msg['msg_id'], rec)
55 52 return msg_ids
56 53
57 54 def test_add_record(self):
58 55 before = self.db.get_history()
59 56 self.load_records(5)
60 57 after = self.db.get_history()
61 58 self.assertEquals(len(after), len(before)+5)
62 59 self.assertEquals(after[:-5],before)
63 60
64 61 def test_drop_record(self):
65 62 msg_id = self.load_records()[-1]
66 63 rec = self.db.get_record(msg_id)
67 64 self.db.drop_record(msg_id)
68 65 self.assertRaises(KeyError,self.db.get_record, msg_id)
69 66
70 67 def _round_to_millisecond(self, dt):
71 68 """necessary because mongodb rounds microseconds"""
72 69 micro = dt.microsecond
73 70 extra = int(str(micro)[-3:])
74 71 return dt - timedelta(microseconds=extra)
75 72
76 73 def test_update_record(self):
77 74 now = self._round_to_millisecond(datetime.now())
78 75 #
79 76 msg_id = self.db.get_history()[-1]
80 77 rec1 = self.db.get_record(msg_id)
81 78 data = {'stdout': 'hello there', 'completed' : now}
82 79 self.db.update_record(msg_id, data)
83 80 rec2 = self.db.get_record(msg_id)
84 81 self.assertEquals(rec2['stdout'], 'hello there')
85 82 self.assertEquals(rec2['completed'], now)
86 83 rec1.update(data)
87 84 self.assertEquals(rec1, rec2)
88 85
89 86 # def test_update_record_bad(self):
90 87 # """test updating nonexistant records"""
91 88 # msg_id = str(uuid.uuid4())
92 89 # data = {'stdout': 'hello there'}
93 90 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
94 91
95 92 def test_find_records_dt(self):
96 93 """test finding records by date"""
97 94 hist = self.db.get_history()
98 95 middle = self.db.get_record(hist[len(hist)/2])
99 96 tic = middle['submitted']
100 97 before = self.db.find_records({'submitted' : {'$lt' : tic}})
101 98 after = self.db.find_records({'submitted' : {'$gte' : tic}})
102 99 self.assertEquals(len(before)+len(after),len(hist))
103 100 for b in before:
104 101 self.assertTrue(b['submitted'] < tic)
105 102 for a in after:
106 103 self.assertTrue(a['submitted'] >= tic)
107 104 same = self.db.find_records({'submitted' : tic})
108 105 for s in same:
109 106 self.assertTrue(s['submitted'] == tic)
110 107
111 108 def test_find_records_keys(self):
112 109 """test extracting subset of record keys"""
113 110 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
114 111 for rec in found:
115 112 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
116 113
117 114 def test_find_records_msg_id(self):
118 115 """ensure msg_id is always in found records"""
119 116 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
120 117 for rec in found:
121 118 self.assertTrue('msg_id' in rec.keys())
122 119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
123 120 for rec in found:
124 121 self.assertTrue('msg_id' in rec.keys())
125 122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
126 123 for rec in found:
127 124 self.assertTrue('msg_id' in rec.keys())
128 125
129 126 def test_find_records_in(self):
130 127 """test finding records with '$in','$nin' operators"""
131 128 hist = self.db.get_history()
132 129 even = hist[::2]
133 130 odd = hist[1::2]
134 131 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
135 132 found = [ r['msg_id'] for r in recs ]
136 133 self.assertEquals(set(even), set(found))
137 134 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
138 135 found = [ r['msg_id'] for r in recs ]
139 136 self.assertEquals(set(odd), set(found))
140 137
141 138 def test_get_history(self):
142 139 msg_ids = self.db.get_history()
143 140 latest = datetime(1984,1,1)
144 141 for msg_id in msg_ids:
145 142 rec = self.db.get_record(msg_id)
146 143 newt = rec['submitted']
147 144 self.assertTrue(newt >= latest)
148 145 latest = newt
149 146 msg_id = self.load_records(1)[-1]
150 147 self.assertEquals(self.db.get_history()[-1],msg_id)
151 148
152 149 def test_datetime(self):
153 150 """get/set timestamps with datetime objects"""
154 151 msg_id = self.db.get_history()[-1]
155 152 rec = self.db.get_record(msg_id)
156 153 self.assertTrue(isinstance(rec['submitted'], datetime))
157 154 self.db.update_record(msg_id, dict(completed=datetime.now()))
158 155 rec = self.db.get_record(msg_id)
159 156 self.assertTrue(isinstance(rec['completed'], datetime))
157
158 def test_drop_matching(self):
159 msg_ids = self.load_records(10)
160 query = {'msg_id' : {'$in':msg_ids}}
161 self.db.drop_matching_records(query)
162 recs = self.db.find_records(query)
163 self.assertTrue(len(recs)==0)
160 164
161 165 class TestSQLiteBackend(TestDictBackend):
162 166 def create_db(self):
163 167 return SQLiteDB(location=tempfile.gettempdir())
164 168
165 169 def tearDown(self):
166 170 self.db._db.close()
167
168 # optional MongoDB test
169 try:
170 from IPython.parallel.controller.mongodb import MongoDB
171 except ImportError:
172 pass
173 else:
174 class TestMongoBackend(TestDictBackend):
175 def create_db(self):
176 try:
177 return MongoDB(database='iptestdb')
178 except Exception:
179 raise SkipTest("Couldn't connect to mongodb instance")
180
181 def tearDown(self):
182 self.db._connection.drop_database('iptestdb')
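All of the backends answer the same interface these tests run against, so a new backend only needs to override create_db. A minimal standalone sketch of that shared interface against the in-memory DictDB (record contents abbreviated):

    from datetime import datetime
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    db.add_record('m1', {'msg_id': 'm1', 'submitted': datetime.now(),
                         'completed': None})
    pending = db.find_records({'completed': None})   # plain equality check
    db.drop_matching_records({'msg_id': {'$in': ['m1']}})
    assert db.get_history() == []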
@@ -1,440 +1,441 b''
1 1 # -*- coding: utf-8 -*-
2 2 """IPython Test Suite Runner.
3 3
4 4 This module provides a main entry point to a user script to test IPython
5 5 itself from the command line. There are two ways of running this script:
6 6
7 7 1. With the syntax `iptest all`. This runs our entire test suite by
8 8 calling this script (with different arguments) recursively. This
9 9 causes modules and packages to be tested in different processes, using nose
10 10 or trial where appropriate.
11 11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
12 12 the script simply calls nose, but with special command line flags and
13 13 plugins loaded.
14 14
15 15 """
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2009 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 # Stdlib
29 29 import os
30 30 import os.path as path
31 31 import signal
32 32 import sys
33 33 import subprocess
34 34 import tempfile
35 35 import time
36 36 import warnings
37 37
38 38 # Note: monkeypatch!
39 39 # We need to monkeypatch a small problem in nose itself first, before importing
40 40 # it for actual use. This should get into nose upstream, but its release cycle
41 41 # is slow and we need it for our parametric tests to work correctly.
42 42 from IPython.testing import nosepatch
43 43 # Now, proceed to import nose itself
44 44 import nose.plugins.builtin
45 45 from nose.core import TestProgram
46 46
47 47 # Our own imports
48 48 from IPython.utils.path import get_ipython_module_path
49 49 from IPython.utils.process import find_cmd, pycmd2argv
50 50 from IPython.utils.sysinfo import sys_info
51 51
52 52 from IPython.testing import globalipapp
53 53 from IPython.testing.plugin.ipdoctest import IPythonDoctest
54 54 from IPython.external.decorators import KnownFailure
55 55
56 56 pjoin = path.join
57 57
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # Globals
61 61 #-----------------------------------------------------------------------------
62 62
63 63
64 64 #-----------------------------------------------------------------------------
65 65 # Warnings control
66 66 #-----------------------------------------------------------------------------
67 67
68 68 # Twisted generates annoying warnings with Python 2.6, as does other code
69 69 # that imports 'sets' as of today
70 70 warnings.filterwarnings('ignore', 'the sets module is deprecated',
71 71 DeprecationWarning )
72 72
73 73 # This one also comes from Twisted
74 74 warnings.filterwarnings('ignore', 'the sha module is deprecated',
75 75 DeprecationWarning)
76 76
77 77 # Wx on Fedora11 spits these out
78 78 warnings.filterwarnings('ignore', 'wxPython/wxWidgets release number mismatch',
79 79 UserWarning)
80 80
81 81 #-----------------------------------------------------------------------------
82 82 # Logic for skipping doctests
83 83 #-----------------------------------------------------------------------------
84 84
85 85 def test_for(mod, min_version=None):
86 86 """Test to see if mod is importable."""
87 87 try:
88 88 __import__(mod)
89 89 except (ImportError, RuntimeError):
90 90 # GTK reports Runtime error if it can't be initialized even if it's
91 91 # importable.
92 92 return False
93 93 else:
94 94 if min_version:
95 95 return sys.modules[mod].__version__ >= min_version
96 96 else:
97 97 return True
98 98
99 99 # Global dict where we can store information on what we have and what we don't
100 100 # have available at test run time
101 101 have = {}
102 102
103 103 have['curses'] = test_for('_curses')
104 104 have['matplotlib'] = test_for('matplotlib')
105 105 have['pexpect'] = test_for('pexpect')
106 106 have['pymongo'] = test_for('pymongo')
107 107 have['wx'] = test_for('wx')
108 108 have['wx.aui'] = test_for('wx.aui')
109 109 if os.name == 'nt':
110 110 have['zmq'] = test_for('zmq', '2.1.7')
111 111 else:
112 112 have['zmq'] = test_for('zmq', '2.1.4')
113 113 have['qt'] = test_for('IPython.external.qt')
114 114
115 115 #-----------------------------------------------------------------------------
116 116 # Functions and classes
117 117 #-----------------------------------------------------------------------------
118 118
119 119 def report():
120 120 """Return a string with a summary report of test-related variables."""
121 121
122 122 out = [ sys_info(), '\n']
123 123
124 124 avail = []
125 125 not_avail = []
126 126
127 127 for k, is_avail in have.items():
128 128 if is_avail:
129 129 avail.append(k)
130 130 else:
131 131 not_avail.append(k)
132 132
133 133 if avail:
134 134 out.append('\nTools and libraries available at test time:\n')
135 135 avail.sort()
136 136 out.append(' ' + ' '.join(avail)+'\n')
137 137
138 138 if not_avail:
139 139 out.append('\nTools and libraries NOT available at test time:\n')
140 140 not_avail.sort()
141 141 out.append(' ' + ' '.join(not_avail)+'\n')
142 142
143 143 return ''.join(out)
144 144
145 145
146 146 def make_exclude():
147 147 """Make patterns of modules and packages to exclude from testing.
148 148
149 149 For the IPythonDoctest plugin, we need to exclude certain patterns that
150 150 cause testing problems. We should strive to minimize the number of
151 151 skipped modules, since this means untested code.
152 152
153 153 These modules and packages will NOT get scanned by nose at all for tests.
154 154 """
155 155 # Simple utility to make IPython paths more readable; we need a lot of
156 156 # these below
157 157 ipjoin = lambda *paths: pjoin('IPython', *paths)
158 158
159 159 exclusions = [ipjoin('external'),
160 160 pjoin('IPython_doctest_plugin'),
161 161 ipjoin('quarantine'),
162 162 ipjoin('deathrow'),
163 163 ipjoin('testing', 'attic'),
164 164 # This guy is probably attic material
165 165 ipjoin('testing', 'mkdoctests'),
166 166 # Testing inputhook will need a lot of thought, to figure out
167 167 # how to have tests that don't lock up with the gui event
168 168 # loops in the picture
169 169 ipjoin('lib', 'inputhook'),
170 170 # Config files aren't really importable stand-alone
171 171 ipjoin('config', 'default'),
172 172 ipjoin('config', 'profile'),
173 173 ]
174 174
175 175 if not have['wx']:
176 176 exclusions.append(ipjoin('lib', 'inputhookwx'))
177 177
178 178 # We do this unconditionally, so that the test suite doesn't import
179 179 # gtk, changing the default encoding and masking some unicode bugs.
180 180 exclusions.append(ipjoin('lib', 'inputhookgtk'))
181 181
182 182 # These have to be skipped on win32 because they use echo, rm, cd, etc.
183 183 # See ticket https://bugs.launchpad.net/bugs/366982
184 184 if sys.platform == 'win32':
185 185 exclusions.append(ipjoin('testing', 'plugin', 'test_exampleip'))
186 186 exclusions.append(ipjoin('testing', 'plugin', 'dtexample'))
187 187
188 188 if not have['pexpect']:
189 189 exclusions.extend([ipjoin('scripts', 'irunner'),
190 190 ipjoin('lib', 'irunner'),
191 191 ipjoin('lib', 'tests', 'test_irunner')])
192 192
193 193 if not have['zmq']:
194 194 exclusions.append(ipjoin('zmq'))
195 195 exclusions.append(ipjoin('frontend', 'qt'))
196 196 exclusions.append(ipjoin('parallel'))
197 197 elif not have['qt']:
198 198 exclusions.append(ipjoin('frontend', 'qt'))
199 199
200 200 if not have['pymongo']:
201 201 exclusions.append(ipjoin('parallel', 'controller', 'mongodb'))
202 exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb'))
202 203
203 204 if not have['matplotlib']:
204 205 exclusions.extend([ipjoin('lib', 'pylabtools'),
205 206 ipjoin('lib', 'tests', 'test_pylabtools')])
206 207
207 208 # This is needed for the reg-exp to match on win32 in the ipdoctest plugin.
208 209 if sys.platform == 'win32':
209 210 exclusions = [s.replace('\\','\\\\') for s in exclusions]
210 211
211 212 return exclusions
212 213
213 214
214 215 class IPTester(object):
215 216 """Call that calls iptest or trial in a subprocess.
216 217 """
217 218 #: string, name of test runner that will be called
218 219 runner = None
219 220 #: list, parameters for test runner
220 221 params = None
221 222 #: list, arguments of system call to be made to call test runner
222 223 call_args = None
223 224 #: list, process ids of subprocesses we start (for cleanup)
224 225 pids = None
225 226
226 227 def __init__(self, runner='iptest', params=None):
227 228 """Create new test runner."""
228 229 p = os.path
229 230 if runner == 'iptest':
230 231 iptest_app = get_ipython_module_path('IPython.testing.iptest')
231 232 self.runner = pycmd2argv(iptest_app) + sys.argv[1:]
232 233 else:
233 234 raise Exception('Not a valid test runner: %s' % repr(runner))
234 235 if params is None:
235 236 params = []
236 237 if isinstance(params, str):
237 238 params = [params]
238 239 self.params = params
239 240
240 241 # Assemble call
241 242 self.call_args = self.runner+self.params
242 243
243 244 # Store pids of anything we start to clean up on deletion, if possible
244 245 # (on posix only, since win32 has no os.kill)
245 246 self.pids = []
246 247
247 248 if sys.platform == 'win32':
248 249 def _run_cmd(self):
249 250 # On Windows, use os.system instead of subprocess.call, because I
250 251 # was having problems with subprocess and I just don't know enough
251 252 # about win32 to debug this reliably. Os.system may be the 'old
252 253 # fashioned' way to do it, but it works just fine. If someone
253 254 # later can clean this up that's fine, as long as the tests run
254 255 # reliably in win32.
255 256 # What types of problems are you having? They may be related to
256 257 # running Python in unbuffered mode. BG.
257 258 return os.system(' '.join(self.call_args))
258 259 else:
259 260 def _run_cmd(self):
260 261 # print >> sys.stderr, '*** CMD:', ' '.join(self.call_args) # dbg
261 262 subp = subprocess.Popen(self.call_args)
262 263 self.pids.append(subp.pid)
263 264 # If this fails, the pid will be left in self.pids and cleaned up
264 265 # later, but if the wait call succeeds, then we can clear the
265 266 # stored pid.
266 267 retcode = subp.wait()
267 268 self.pids.pop()
268 269 return retcode
269 270
270 271 def run(self):
271 272 """Run the stored commands"""
272 273 try:
273 274 return self._run_cmd()
274 275 except:
275 276 import traceback
276 277 traceback.print_exc()
277 278 return 1 # signal failure
278 279
279 280 def __del__(self):
280 281 """Cleanup on exit by killing any leftover processes."""
281 282
282 283 if not hasattr(os, 'kill'):
283 284 return
284 285
285 286 for pid in self.pids:
286 287 try:
287 288 print 'Cleaning stale PID:', pid
288 289 os.kill(pid, signal.SIGKILL)
289 290 except OSError:
290 291 # This is just a best effort, if we fail or the process was
291 292 # really gone, ignore it.
292 293 pass
293 294
294 295
295 296 def make_runners():
296 297 """Define the top-level packages that need to be tested.
297 298 """
298 299
299 300 # Packages to be tested via nose, that only depend on the stdlib
300 301 nose_pkg_names = ['config', 'core', 'extensions', 'frontend', 'lib',
301 302 'scripts', 'testing', 'utils' ]
302 303
303 304 if have['zmq']:
304 305 nose_pkg_names.append('parallel')
305 306
306 307 # For debugging this code, only load quick stuff
307 308 #nose_pkg_names = ['core', 'extensions'] # dbg
308 309
309 310 # Make fully qualified package names prepending 'IPython.' to our name lists
310 311 nose_packages = ['IPython.%s' % m for m in nose_pkg_names ]
311 312
312 313 # Make runners
313 314 runners = [ (v, IPTester('iptest', params=v)) for v in nose_packages ]
314 315
315 316 return runners
316 317
317 318
318 319 def run_iptest():
319 320 """Run the IPython test suite using nose.
320 321
321 322 This function is called when this script is **not** called with the form
322 323 `iptest all`. It simply calls nose with appropriate command line flags
323 324 and accepts all of the standard nose arguments.
324 325 """
325 326
326 327 warnings.filterwarnings('ignore',
327 328 'This will be removed soon. Use IPython.testing.util instead')
328 329
329 330 argv = sys.argv + [ '--detailed-errors', # extra info in tracebacks
330 331
331 332 # Loading ipdoctest causes problems with Twisted, but
332 333 # our test suite runner now separates things and runs
333 334 # all Twisted tests with trial.
334 335 '--with-ipdoctest',
335 336 '--ipdoctest-tests','--ipdoctest-extension=txt',
336 337
337 338 # We add --exe because of setuptools' imbecility (it
338 339 # blindly does chmod +x on ALL files). Nose does the
339 340 # right thing and it tries to avoid executables,
340 341 # setuptools unfortunately forces our hand here. This
341 342 # has been discussed on the distutils list and the
342 343 # setuptools devs refuse to fix this problem!
343 344 '--exe',
344 345 ]
345 346
346 347 if nose.__version__ >= '0.11':
347 348 # I don't fully understand why we need this one, but depending on what
348 349 # directory the test suite is run from, if we don't give it, 0 tests
349 350 # get run. Specifically, if the test suite is run from the source dir
351 352 with an argument (like 'iptest.py IPython.core'), 0 tests are run,
352 353 even if the same call done in this directory works fine. It appears
352 353 # that if the requested package is in the current dir, nose bails early
353 354 # by default. Since it's otherwise harmless, leave it in by default
354 355 # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
355 356 argv.append('--traverse-namespace')
356 357
357 358 # Construct list of plugins, omitting the existing doctest plugin, which
358 359 # ours replaces (and extends).
359 360 plugins = [IPythonDoctest(make_exclude()), KnownFailure()]
360 361 for p in nose.plugins.builtin.plugins:
361 362 plug = p()
362 363 if plug.name == 'doctest':
363 364 continue
364 365 plugins.append(plug)
365 366
366 367 # We need a global ipython running in this process
367 368 globalipapp.start_ipython()
368 369 # Now nose can run
369 370 TestProgram(argv=argv, plugins=plugins)
370 371
371 372
372 373 def run_iptestall():
373 374 """Run the entire IPython test suite by calling nose and trial.
374 375
375 376 This function constructs :class:`IPTester` instances for all IPython
376 377 modules and packages and then runs each of them. This causes the modules
377 378 and packages of IPython to be tested each in their own subprocess using
378 379 nose or twisted.trial appropriately.
379 380 """
380 381
381 382 runners = make_runners()
382 383
383 384 # Run the test runners in a temporary dir so we can nuke it when finished
384 385 # to clean up any junk files left over by accident. This also makes it
385 386 # robust against being run in non-writeable directories by mistake, as the
386 387 # temp dir will always be user-writeable.
387 388 curdir = os.getcwd()
388 389 testdir = tempfile.gettempdir()
389 390 os.chdir(testdir)
390 391
391 392 # Run all test runners, tracking execution time
392 393 failed = []
393 394 t_start = time.time()
394 395 try:
395 396 for (name, runner) in runners:
396 397 print '*'*70
397 398 print 'IPython test group:',name
398 399 res = runner.run()
399 400 if res:
400 401 failed.append( (name, runner) )
401 402 finally:
402 403 os.chdir(curdir)
403 404 t_end = time.time()
404 405 t_tests = t_end - t_start
405 406 nrunners = len(runners)
406 407 nfail = len(failed)
407 408 # summarize results
408 409 print
409 410 print '*'*70
410 411 print 'Test suite completed for system with the following information:'
411 412 print report()
412 413 print 'Ran %s test groups in %.3fs' % (nrunners, t_tests)
413 414 print
414 415 print 'Status:'
415 416 if not failed:
416 417 print 'OK'
417 418 else:
418 419 # If anything went wrong, point out what command to rerun manually to
419 420 # see the actual errors and individual summary
420 421 print 'ERROR - %s out of %s test groups failed.' % (nfail, nrunners)
421 422 for name, failed_runner in failed:
422 423 print '-'*40
423 424 print 'Runner failed:',name
424 425 print 'You may wish to rerun this one individually, with:'
425 426 print ' '.join(failed_runner.call_args)
426 427 print
427 428
428 429
429 430 def main():
430 431 for arg in sys.argv[1:]:
431 432 if arg.startswith('IPython'):
432 433 # This is in-process
433 434 run_iptest()
434 435 else:
435 436 # This starts subprocesses
436 437 run_iptestall()
437 438
438 439
439 440 if __name__ == '__main__':
440 441 main()