##// END OF EJS Templates
various db backend fixes...
MinRK -
Show More
@@ -0,0 +1,37 b''
1 """Tests for mongodb backend"""
2
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
5 #
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
9
10 #-------------------------------------------------------------------------------
11 # Imports
12 #-------------------------------------------------------------------------------
13
14 from nose import SkipTest
15
16 from pymongo import Connection
17 from IPython.parallel.controller.mongodb import MongoDB
18
19 from . import test_db
20
# Try to connect to a local MongoDB server once, at import time.
# If no server is reachable, c stays None: create_db() will then fail and
# skip the tests, and teardown() becomes a no-op.  The broad Exception
# catch is deliberate best-effort probing, not error handling.
try:
    c = Connection()
except Exception:
    c=None
class TestMongoBackend(test_db.TestDictBackend):
    """MongoDB backend tests.

    Reuses all test cases from test_db.TestDictBackend, only swapping in a
    MongoDB-backed database via create_db().
    """

    def create_db(self):
        # Reuse the module-level connection probe `c`; any failure creating
        # the backend (no server, auth error, ...) means mongodb is not
        # usable in this environment, so skip rather than fail.
        try:
            return MongoDB(database='iptestdb', _connection=c)
        except Exception:
            raise SkipTest("Couldn't connect to mongodb")

    def teardown(self):
        # Drop the scratch database so test runs don't leak state into each
        # other; only possible when the import-time connection succeeded.
        if c is not None:
            c.drop_database('iptestdb')
@@ -1,180 +1,180 b''
1 """A Task logger that presents our DB interface,
1 """A Task logger that presents our DB interface,
2 but exists entirely in memory and implemented with dicts.
2 but exists entirely in memory and implemented with dicts.
3
3
4 TaskRecords are dicts of the form:
4 TaskRecords are dicts of the form:
5 {
5 {
6 'msg_id' : str(uuid),
6 'msg_id' : str(uuid),
7 'client_uuid' : str(uuid),
7 'client_uuid' : str(uuid),
8 'engine_uuid' : str(uuid) or None,
8 'engine_uuid' : str(uuid) or None,
9 'header' : dict(header),
9 'header' : dict(header),
10 'content': dict(content),
10 'content': dict(content),
11 'buffers': list(buffers),
11 'buffers': list(buffers),
12 'submitted': datetime,
12 'submitted': datetime,
13 'started': datetime or None,
13 'started': datetime or None,
14 'completed': datetime or None,
14 'completed': datetime or None,
15 'resubmitted': datetime or None,
15 'resubmitted': datetime or None,
16 'result_header' : dict(header) or None,
16 'result_header' : dict(header) or None,
17 'result_content' : dict(content) or None,
17 'result_content' : dict(content) or None,
18 'result_buffers' : list(buffers) or None,
18 'result_buffers' : list(buffers) or None,
19 }
19 }
20 With this info, many of the special categories of tasks can be defined by query:
20 With this info, many of the special categories of tasks can be defined by query:
21
21
22 pending: completed is None
22 pending: completed is None
23 client's outstanding: client_uuid = uuid && completed is None
23 client's outstanding: client_uuid = uuid && completed is None
24 MIA: arrived is None (and completed is None)
24 MIA: arrived is None (and completed is None)
25 etc.
25 etc.
26
26
27 EngineRecords are dicts of the form:
27 EngineRecords are dicts of the form:
28 {
28 {
29 'eid' : int(id),
29 'eid' : int(id),
30 'uuid': str(uuid)
30 'uuid': str(uuid)
31 }
31 }
32 This may be extended, but is not currently.
32 This may be extended, but is not currently.
33
33
34 We support a subset of mongodb operators:
34 We support a subset of mongodb operators:
35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 """
36 """
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38 # Copyright (C) 2010 The IPython Development Team
38 # Copyright (C) 2010 The IPython Development Team
39 #
39 #
40 # Distributed under the terms of the BSD License. The full license is in
40 # Distributed under the terms of the BSD License. The full license is in
41 # the file COPYING, distributed as part of this software.
41 # the file COPYING, distributed as part of this software.
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43
43
44
44
45 from datetime import datetime
45 from datetime import datetime
46
46
47 from IPython.config.configurable import Configurable
47 from IPython.config.configurable import Configurable
48
48
49 from IPython.utils.traitlets import Dict, CUnicode
49 from IPython.utils.traitlets import Dict, CUnicode
50
50
# MongoDB-style comparison operators: operator name -> two-argument
# predicate test(attribute_value, operand).
filters = {
    '$lt' : lambda a,b: a < b,
    # fixed: was `b > a`, which is the same as `a < b` and made $gt
    # behave identically to $lt
    '$gt' : lambda a,b: a > b,
    '$eq' : lambda a,b: a == b,
    '$ne' : lambda a,b: a != b,
    '$lte': lambda a,b: a <= b,
    '$gte': lambda a,b: a >= b,
    '$in' : lambda a,b: a in b,
    '$nin': lambda a,b: a not in b,
    # $all: operand is a list of containers; value must appear in every one
    '$all': lambda a,b: all([ a in bb for bb in b ]),
    # $mod: operand is [divisor, remainder]
    '$mod': lambda a,b: a%b[0] == b[1],
    # $exists: "exists" is modeled as "is not None" in this dict backend
    '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
}
64
64
65
65
class CompositeFilter(object):
    """Composite filter for matching multiple properties.

    Wraps a dict of mongodb-style operator/operand pairs, e.g.
    ``{'$gt': 1, '$lt': 5}``.  Calling the instance with a value returns
    True only if the value passes *every* operator test.
    """

    def __init__(self, dikt):
        """Build (test, operand) pairs from an operator dict.

        Raises KeyError if an operator is not in the module-level `filters`
        table.
        """
        self.tests = []
        self.values = []
        # .items() instead of Python-2-only .iteritems(): identical behavior
        # here, and keeps this class working on Python 3 as well.
        for key, value in dikt.items():
            self.tests.append(filters[key])
            self.values.append(value)

    def __call__(self, value):
        """Return True if `value` satisfies every stored test."""
        for test, check in zip(self.tests, self.values):
            if not test(value, check):
                return False
        return True
81
81
class BaseDB(Configurable):
    """Empty Parent class so traitlets work on DB."""
    # base configurable traits:
    # session: coerced-unicode session identifier; presumably the controller
    # session this DB instance belongs to -- TODO confirm against HubFactory,
    # which passes session=self.session.session when constructing the db.
    session = CUnicode("")
86
86
class DictDB(BaseDB):
    """Basic in-memory dict-based object for saving Task Records.

    This is the first object to present the DB interface
    for logging tasks out of memory.

    The interface is based on MongoDB, so adding a MongoDB
    backend should be straightforward.
    """

    # all TaskRecords, keyed by msg_id
    _records = Dict()

    def _match_one(self, rec, tests):
        """Return True if record `rec` passes every test in `tests`.

        `tests` maps record key -> one-argument predicate; missing keys are
        tested as None.
        """
        for key, test in tests.items():
            if not test(rec.get(key, None)):
                return False
        return True

    def _match(self, check):
        """Return a list of all records matching a mongodb-style check dict."""
        tests = {}
        for k, v in check.items():
            if isinstance(v, dict):
                # operator dict, e.g. {'$lt': 5}
                tests[k] = CompositeFilter(v)
            else:
                # Bind v as a default argument: a plain `lambda o: o == v`
                # late-binds v, so every plain-equality key would compare
                # against the value from the *last* loop iteration.
                tests[k] = lambda o, val=v: o == val

        return [rec for rec in self._records.values()
                if self._match_one(rec, tests)]

    def _extract_subdict(self, rec, keys):
        """Extract a subdict of `keys` from `rec`; msg_id is always included."""
        d = {}
        d['msg_id'] = rec['msg_id']
        for key in keys:
            d[key] = rec[key]
        return d

    def add_record(self, msg_id, rec):
        """Add a new Task Record, by msg_id.

        Raises KeyError if msg_id is already present.
        """
        # `in` instead of the deprecated/py2-only dict.has_key()
        if msg_id in self._records:
            raise KeyError("Already have msg_id %r"%(msg_id))
        self._records[msg_id] = rec

    def get_record(self, msg_id):
        """Get a specific Task Record, by msg_id.

        Raises KeyError if msg_id is unknown.
        """
        if msg_id not in self._records:
            raise KeyError("No such msg_id %r"%(msg_id))
        return self._records[msg_id]

    def update_record(self, msg_id, rec):
        """Update the data in an existing record with dict.update semantics."""
        self._records[msg_id].update(rec)

    def drop_matching_records(self, check):
        """Remove all records matching a check dict from the DB."""
        # records are stored keyed by msg_id, so delete by the record's
        # msg_id field, not by the record dict itself
        for m in self._match(check):
            del self._records[m['msg_id']]

    def drop_record(self, msg_id):
        """Remove a single record from the DB, by msg_id."""
        del self._records[msg_id]

    def find_records(self, check, keys=None):
        """Find records matching a query dict, optionally extracting subset of keys.

        Returns a list of matching records (full record dicts, or subdicts
        restricted to `keys` plus msg_id).

        Parameters
        ----------

        check: dict
            mongodb-style query argument
        keys: list of strs [optional]
            if specified, the subset of keys to extract. msg_id will *always* be
            included.
        """
        matches = self._match(check)
        if keys:
            return [ self._extract_subdict(rec, keys) for rec in matches ]
        else:
            return matches

    def get_history(self):
        """Get all msg_ids, ordered by time submitted."""
        # iterating a dict yields its keys; sort them by submission time
        return sorted(self._records, key=lambda m: self._records[m]['submitted'])
@@ -1,1284 +1,1282 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5 """
5 """
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010 The IPython Development Team
7 # Copyright (C) 2010 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 import sys
18 import sys
19 import time
19 import time
20 from datetime import datetime
20 from datetime import datetime
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop
23 from zmq.eventloop import ioloop
24 from zmq.eventloop.zmqstream import ZMQStream
24 from zmq.eventloop.zmqstream import ZMQStream
25
25
26 # internal:
26 # internal:
27 from IPython.utils.importstring import import_item
27 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29
29
30 from IPython.parallel import error, util
30 from IPython.parallel import error, util
31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32
32
33 from .heartmonitor import HeartMonitor
33 from .heartmonitor import HeartMonitor
34
34
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36 # Code
36 # Code
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38
38
39 def _passer(*args, **kwargs):
39 def _passer(*args, **kwargs):
40 return
40 return
41
41
42 def _printer(*args, **kwargs):
42 def _printer(*args, **kwargs):
43 print (args)
43 print (args)
44 print (kwargs)
44 print (kwargs)
45
45
def empty_record():
    """Return an empty dict with all record keys.

    Every TaskRecord field is present; all are None except the stream
    captures, which start as empty strings.
    """
    record = dict.fromkeys([
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ])
    # captured output streams accumulate text, so they start empty, not None
    record['stdout'] = ''
    record['stderr'] = ''
    return record
69
69
def init_record(msg):
    """Initialize a TaskRecord based on a request message.

    Fields known from the request (ids, header, content, buffers, submission
    time) are filled in; everything that only becomes known later starts as
    None, and the captured streams start as empty strings.
    """
    header = msg['header']
    record = {
        'msg_id': header['msg_id'],
        'header': header,
        'content': msg['content'],
        'buffers': msg['buffers'],
        # header['date'] is an ISO8601-formatted string; parse to datetime
        'submitted': datetime.strptime(header['date'], util.ISO8601),
        'stdout': '',
        'stderr': '',
    }
    # not-yet-known fields all start as None
    for key in ('client_uuid', 'engine_uuid', 'started', 'completed',
                'resubmitted', 'result_header', 'result_content',
                'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr'):
        record[key] = None
    return record
94
94
95
95
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    # NOTE(review): the docstring lists a 'uuid' attribute but no such trait
    # is declared below -- confirm whether it was removed or never added.
    id=Int(0)  # integer engine ID
    queue=Str()  # zmq identity of the engine's queue (XREQ) socket
    control=Str()  # zmq identity of the engine's control socket (undocumented above)
    registration=Str()  # zmq identity of the registration XREQ socket
    heartbeat=Str()  # zmq identity of the heartbeat XREQ socket
    pending=Set()  # presumably msg_ids outstanding on this engine -- TODO confirm
111
111
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds all the (configurable) port/address traits, and builds the zmq
    sockets, HeartMonitor, db backend, and finally the Hub itself in
    construct_hub().
    """

    # name of a scheduler scheme
    scheme = Str('leastload', config=True)

    # port-pairs for monitoredqueues (each default is a fresh pair of free
    # ports picked at trait-access time):
    hb = Instance(list, config=True)
    def _hb_default(self):
        return util.select_random_ports(2)

    mux = Instance(list, config=True)
    def _mux_default(self):
        return util.select_random_ports(2)

    task = Instance(list, config=True)
    def _task_default(self):
        return util.select_random_ports(2)

    control = Instance(list, config=True)
    def _control_default(self):
        return util.select_random_ports(2)

    iopub = Instance(list, config=True)
    def _iopub_default(self):
        return util.select_random_ports(2)

    # single ports:
    mon_port = Instance(int, config=True)
    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Instance(int, config=True)
    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    ping = Int(1000, config=True) # ping frequency

    # transport/IP triplets for the three classes of connections; all of
    # these are kept in sync by _ip_changed/_transport_changed below.
    engine_ip = CStr('127.0.0.1', config=True)
    engine_transport = CStr('tcp', config=True)

    client_ip = CStr('127.0.0.1', config=True)
    client_transport = CStr('tcp', config=True)

    monitor_ip = CStr('127.0.0.1', config=True)
    monitor_transport = CStr('tcp', config=True)

    # derived "transport://ip:port" url; recomputed by _update_monitor_url
    monitor_url = CStr('')

    # dotted import path of the db backend class, instantiated in construct_hub
    db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
    subconstructors = List()
    _constructed = Bool(False)

    def _ip_changed(self, name, old, new):
        # setting `ip` fans out to all three specific IPs
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # recompute the derived monitor url from its parts
        self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # setting `transport` fans out to all three specific transports
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()
        # self.on_trait_change(self._sync_ips, 'ip')
        # self.on_trait_change(self._sync_transports, 'transport')
        # construct() (presumably defined on a base factory class) runs each
        # registered subconstructor; register ours here
        self.subconstructors.append(self.construct_hub)


    def construct(self):
        """Run every registered subconstructor exactly once."""
        assert not self._constructed, "already constructed!"

        for subc in self.subconstructors:
            subc()

        self._constructed = True


    def start(self):
        """Start the constructed components (currently just the heartmonitor)."""
        assert self._constructed, "must be constructed by self.construct() first!"
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def construct_hub(self):
        """construct

        Builds all the zmq sockets/streams, the HeartMonitor, the db
        backend, and the Hub instance itself (stored as self.hub).
        """
        # iface templates with a trailing %i placeholder for the port
        client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.XREP), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
        # bind the engine side too, but only when it's a distinct address
        if self.client_ip != self.engine_ip:
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat: PUB pings out on hb[0], XREP collects pongs on hb[1]
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
                                period=self.ping, logname=self.log.name)

        ### Client connections ###
        # Notifier socket (PUB broadcast of registration changes)
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: SUB to everything, bound both externally and
        # in-process so local components can publish to it
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, "")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
        time.sleep(.25)

        # build connection dicts
        # NOTE(review): by convention here, index [1] of each port pair is
        # the engine-facing port and [0] the client-facing one -- confirm
        # against the queue launcher.
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (self.scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream: XREQ connected to the client-facing task address,
        # identified as this session so replies route correctly
        r = ZMQStream(ctx.socket(zmq.XREQ), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.session)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                logname=self.log.name)
282
282
283
283
284 class Hub(LoggingFactory):
284 class Hub(LoggingFactory):
285 """The IPython Controller Hub with 0MQ connections
285 """The IPython Controller Hub with 0MQ connections
286
286
287 Parameters
287 Parameters
288 ==========
288 ==========
289 loop: zmq IOLoop instance
289 loop: zmq IOLoop instance
290 session: StreamSession object
290 session: StreamSession object
291 <removed> context: zmq context for creating new connections (?)
291 <removed> context: zmq context for creating new connections (?)
292 queue: ZMQStream for monitoring the command queue (SUB)
292 queue: ZMQStream for monitoring the command queue (SUB)
293 query: ZMQStream for engine registration and client queries requests (XREP)
293 query: ZMQStream for engine registration and client queries requests (XREP)
294 heartbeat: HeartMonitor object checking the pulse of the engines
294 heartbeat: HeartMonitor object checking the pulse of the engines
295 notifier: ZMQStream for broadcasting engine registration changes (PUB)
295 notifier: ZMQStream for broadcasting engine registration changes (PUB)
296 db: connection to db for out of memory logging of commands
296 db: connection to db for out of memory logging of commands
297 NotImplemented
297 NotImplemented
298 engine_info: dict of zmq connection information for engines to connect
298 engine_info: dict of zmq connection information for engines to connect
299 to the queues.
299 to the queues.
300 client_info: dict of zmq connection information for engines to connect
300 client_info: dict of zmq connection information for engines to connect
301 to the queues.
301 to the queues.
302 """
302 """
    # internal data structures:
    ids=Set() # set of registered engine integer IDs
    keytable=Dict() # map: engine id -> engine uuid
    by_ident=Dict() # map: engine uuid -> engine id
    engines=Dict() # map: engine id -> EngineConnector-like record (has .queue)
    clients=Dict() # connected client records
    hearts=Dict() # map: heart ident -> engine id
    pending=Set() # msg_ids submitted but not yet completed
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending task msg_ids keyed by engine_id (see save_task_destination)
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids, regardless of engine
    dead_engines=Set() # uuids of engines that have been unregistered/died
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict() # registrations awaiting their first heartbeat
    registration_timeout=Int() # ms to wait for a heartbeat before dropping a registration
    _idcounter=Int(0) # monotonically increasing source of engine ids (see _next_id)

    # objects from constructor:
    loop=Instance(ioloop.IOLoop)
    query=Instance(ZMQStream) # XREP stream for client/engine queries and registration
    monitor=Instance(ZMQStream) # SUB-side stream carrying queue/task/iopub traffic
    notifier=Instance(ZMQStream) # PUB stream broadcasting (un)registration events
    resubmit=Instance(ZMQStream) # stream for resubmitting tasks; replies are ignored
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object) # backend with get_record/add_record/update_record
    client_info=Dict() # zmq connection info handed to clients
    engine_info=Dict() # zmq connection info handed to engines
331
331
332
332
    def __init__(self, **kwargs):
        """Wire up the Hub.

        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least 5s, or two heartbeat periods, to register
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # 'task' entry is (scheme, url); only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # dispatch table for monitor traffic, keyed by topic prefix
        self.monitor_handlers = { 'in' : self.save_queue_request,
                                'out': self.save_queue_result,
                                'intask': self.save_task_request,
                                'outtask': self.save_task_result,
                                'tracktask': self.save_task_destination,
                                'incontrol': _passer,
                                'outcontrol': _passer,
                                'iopub': self.save_iopub_message,
                                }

        # dispatch table for client queries, keyed by msg_type
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")

395 @property
395 @property
396 def _next_id(self):
396 def _next_id(self):
397 """gemerate a new ID.
397 """gemerate a new ID.
398
398
399 No longer reuse old ids, just count from 0."""
399 No longer reuse old ids, just count from 0."""
400 newid = self._idcounter
400 newid = self._idcounter
401 self._idcounter += 1
401 self._idcounter += 1
402 return newid
402 return newid
403 # newid = 0
403 # newid = 0
404 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
404 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
405 # # print newid, self.ids, self.incoming_registrations
405 # # print newid, self.ids, self.incoming_registrations
406 # while newid in self.ids or newid in incoming:
406 # while newid in self.ids or newid in incoming:
407 # newid += 1
407 # newid += 1
408 # return newid
408 # return newid
409
409
410 #-----------------------------------------------------------------------------
410 #-----------------------------------------------------------------------------
411 # message validation
411 # message validation
412 #-----------------------------------------------------------------------------
412 #-----------------------------------------------------------------------------
413
413
414 def _validate_targets(self, targets):
414 def _validate_targets(self, targets):
415 """turn any valid targets argument into a list of integer ids"""
415 """turn any valid targets argument into a list of integer ids"""
416 if targets is None:
416 if targets is None:
417 # default to all
417 # default to all
418 targets = self.ids
418 targets = self.ids
419
419
420 if isinstance(targets, (int,str,unicode)):
420 if isinstance(targets, (int,str,unicode)):
421 # only one target specified
421 # only one target specified
422 targets = [targets]
422 targets = [targets]
423 _targets = []
423 _targets = []
424 for t in targets:
424 for t in targets:
425 # map raw identities to ids
425 # map raw identities to ids
426 if isinstance(t, (str,unicode)):
426 if isinstance(t, (str,unicode)):
427 t = self.by_ident.get(t, t)
427 t = self.by_ident.get(t, t)
428 _targets.append(t)
428 _targets.append(t)
429 targets = _targets
429 targets = _targets
430 bad_targets = [ t for t in targets if t not in self.ids ]
430 bad_targets = [ t for t in targets if t not in self.ids ]
431 if bad_targets:
431 if bad_targets:
432 raise IndexError("No Such Engine: %r"%bad_targets)
432 raise IndexError("No Such Engine: %r"%bad_targets)
433 if not targets:
433 if not targets:
434 raise IndexError("No Engines Registered")
434 raise IndexError("No Engines Registered")
435 return targets
435 return targets
436
436
437 #-----------------------------------------------------------------------------
437 #-----------------------------------------------------------------------------
438 # dispatch methods (1 per stream)
438 # dispatch methods (1 per stream)
439 #-----------------------------------------------------------------------------
439 #-----------------------------------------------------------------------------
440
440
441 # def dispatch_registration_request(self, msg):
441 # def dispatch_registration_request(self, msg):
442 # """"""
442 # """"""
443 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
443 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
444 # idents,msg = self.session.feed_identities(msg)
444 # idents,msg = self.session.feed_identities(msg)
445 # if not idents:
445 # if not idents:
446 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
446 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
447 # return
447 # return
448 # try:
448 # try:
449 # msg = self.session.unpack_message(msg,content=True)
449 # msg = self.session.unpack_message(msg,content=True)
450 # except:
450 # except:
451 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
451 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
452 # return
452 # return
453 #
453 #
454 # msg_type = msg['msg_type']
454 # msg_type = msg['msg_type']
455 # content = msg['content']
455 # content = msg['content']
456 #
456 #
457 # handler = self.query_handlers.get(msg_type, None)
457 # handler = self.query_handlers.get(msg_type, None)
458 # if handler is None:
458 # if handler is None:
459 # self.log.error("registration::got bad registration message: %s"%msg)
459 # self.log.error("registration::got bad registration message: %s"%msg)
460 # else:
460 # else:
461 # handler(idents, msg)
461 # handler(idents, msg)
462
462
463 def dispatch_monitor_traffic(self, msg):
463 def dispatch_monitor_traffic(self, msg):
464 """all ME and Task queue messages come through here, as well as
464 """all ME and Task queue messages come through here, as well as
465 IOPub traffic."""
465 IOPub traffic."""
466 self.log.debug("monitor traffic: %r"%msg[:2])
466 self.log.debug("monitor traffic: %r"%msg[:2])
467 switch = msg[0]
467 switch = msg[0]
468 idents, msg = self.session.feed_identities(msg[1:])
468 idents, msg = self.session.feed_identities(msg[1:])
469 if not idents:
469 if not idents:
470 self.log.error("Bad Monitor Message: %r"%msg)
470 self.log.error("Bad Monitor Message: %r"%msg)
471 return
471 return
472 handler = self.monitor_handlers.get(switch, None)
472 handler = self.monitor_handlers.get(switch, None)
473 if handler is not None:
473 if handler is not None:
474 handler(idents, msg)
474 handler(idents, msg)
475 else:
475 else:
476 self.log.error("Invalid monitor topic: %r"%switch)
476 self.log.error("Invalid monitor topic: %r"%switch)
477
477
478
478
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        idents, msg = self.session.feed_identities(msg)
        if not idents:
            self.log.error("Bad Query Message: %r"%msg)
            return
        # first identity is the requesting client; used to address the reply
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            # malformed message: report the error back to the client
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r"%msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %r requested %r"%(client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            # assert-in-try so wrap_exception can capture a traceback
            assert handler is not None, "Bad Message Type: %r"%msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            # known msg_type: hand off to the registered handler
            handler(idents, msg)
511
511
    def dispatch_db(self, msg):
        """Placeholder for dispatching messages on a db stream; not implemented."""
        raise NotImplementedError
515
515
516 #---------------------------------------------------------------------------
516 #---------------------------------------------------------------------------
517 # handler methods (1 per event)
517 # handler methods (1 per event)
518 #---------------------------------------------------------------------------
518 #---------------------------------------------------------------------------
519
519
520 #----------------------- Heartbeat --------------------------------------
520 #----------------------- Heartbeat --------------------------------------
521
521
522 def handle_new_heart(self, heart):
522 def handle_new_heart(self, heart):
523 """handler to attach to heartbeater.
523 """handler to attach to heartbeater.
524 Called when a new heart starts to beat.
524 Called when a new heart starts to beat.
525 Triggers completion of registration."""
525 Triggers completion of registration."""
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 if heart not in self.incoming_registrations:
527 if heart not in self.incoming_registrations:
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 else:
529 else:
530 self.finish_registration(heart)
530 self.finish_registration(heart)
531
531
532
532
533 def handle_heart_failure(self, heart):
533 def handle_heart_failure(self, heart):
534 """handler to attach to heartbeater.
534 """handler to attach to heartbeater.
535 called when a previously registered heart fails to respond to beat request.
535 called when a previously registered heart fails to respond to beat request.
536 triggers unregistration"""
536 triggers unregistration"""
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 eid = self.hearts.get(heart, None)
538 eid = self.hearts.get(heart, None)
539 queue = self.engines[eid].queue
539 queue = self.engines[eid].queue
540 if eid is None:
540 if eid is None:
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 else:
542 else:
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544
544
545 #----------------------- MUX Queue Traffic ------------------------------
545 #----------------------- MUX Queue Traffic ------------------------------
546
546
547 def save_queue_request(self, idents, msg):
547 def save_queue_request(self, idents, msg):
548 if len(idents) < 2:
548 if len(idents) < 2:
549 self.log.error("invalid identity prefix: %s"%idents)
549 self.log.error("invalid identity prefix: %s"%idents)
550 return
550 return
551 queue_id, client_id = idents[:2]
551 queue_id, client_id = idents[:2]
552 try:
552 try:
553 msg = self.session.unpack_message(msg, content=False)
553 msg = self.session.unpack_message(msg, content=False)
554 except:
554 except:
555 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
555 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
556 return
556 return
557
557
558 eid = self.by_ident.get(queue_id, None)
558 eid = self.by_ident.get(queue_id, None)
559 if eid is None:
559 if eid is None:
560 self.log.error("queue::target %r not registered"%queue_id)
560 self.log.error("queue::target %r not registered"%queue_id)
561 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
561 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
562 return
562 return
563
563
564 header = msg['header']
564 header = msg['header']
565 msg_id = header['msg_id']
565 msg_id = header['msg_id']
566 record = init_record(msg)
566 record = init_record(msg)
567 record['engine_uuid'] = queue_id
567 record['engine_uuid'] = queue_id
568 record['client_uuid'] = client_id
568 record['client_uuid'] = client_id
569 record['queue'] = 'mux'
569 record['queue'] = 'mux'
570
570
571 try:
571 try:
572 # it's posible iopub arrived first:
572 # it's posible iopub arrived first:
573 existing = self.db.get_record(msg_id)
573 existing = self.db.get_record(msg_id)
574 for key,evalue in existing.iteritems():
574 for key,evalue in existing.iteritems():
575 rvalue = record.get(key, None)
575 rvalue = record.get(key, None)
576 if evalue and rvalue and evalue != rvalue:
576 if evalue and rvalue and evalue != rvalue:
577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
577 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
578 elif evalue and not rvalue:
578 elif evalue and not rvalue:
579 record[key] = evalue
579 record[key] = evalue
580 self.db.update_record(msg_id, record)
580 self.db.update_record(msg_id, record)
581 except KeyError:
581 except KeyError:
582 self.db.add_record(msg_id, record)
582 self.db.add_record(msg_id, record)
583
583
584 self.pending.add(msg_id)
584 self.pending.add(msg_id)
585 self.queues[eid].append(msg_id)
585 self.queues[eid].append(msg_id)
586
586
    def save_queue_result(self, idents, msg):
        """Save the reply to a MUX-queue request into the db.

        Moves the msg_id from pending to completed and stores the result
        header, content, buffers and timing information.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %s"%idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("queue::engine %r sent invalid message to %r: %s"%(
                    queue_id,client_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
            # self.log.debug("queue:: %s"%msg[2:])
            return

        parent = msg['parent_header']
        if not parent:
            # a reply with no parent cannot be matched to a request
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            # normal path: move from pending to completed bookkeeping
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %s"%msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        # timestamps arrive as ISO8601 strings; parse to datetime
        completed = datetime.strptime(rheader['date'], util.ISO8601)
        started = rheader.get('started', None)
        if started is not None:
            started = datetime.strptime(started, util.ISO8601)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            # db failures are logged, not raised, so the Hub keeps running
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)

638
638
639
639
640 #--------------------- Task Queue Traffic ------------------------------
640 #--------------------- Task Queue Traffic ------------------------------
641
641
    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg, content=False)
        except:
            self.log.error("task::client %r sent invalid task message: %s"%(
                    client_id, msg), exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        # tasks start out pending and with no engine assigned
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # NOTE(review): existing['resubmitted'] raises KeyError if a
            # backend record lacks the key, which falls through to
            # add_record below — confirm all backends store 'resubmitted'.
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    # keep any data already present in the db record
                    record[key] = evalue
            self.db.update_record(msg_id, record)
        except KeyError:
            # no existing record: fresh task submission
            self.db.add_record(msg_id, record)
        except Exception:
            # db failures are logged, not raised, so the Hub keeps running
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

686
686
687 def save_task_result(self, idents, msg):
687 def save_task_result(self, idents, msg):
688 """save the result of a completed task."""
688 """save the result of a completed task."""
689 client_id = idents[0]
689 client_id = idents[0]
690 try:
690 try:
691 msg = self.session.unpack_message(msg, content=False)
691 msg = self.session.unpack_message(msg, content=False)
692 except:
692 except:
693 self.log.error("task::invalid task result message send to %r: %s"%(
693 self.log.error("task::invalid task result message send to %r: %s"%(
694 client_id, msg), exc_info=True)
694 client_id, msg), exc_info=True)
695 raise
695 raise
696 return
696 return
697
697
698 parent = msg['parent_header']
698 parent = msg['parent_header']
699 if not parent:
699 if not parent:
700 # print msg
700 # print msg
701 self.log.warn("Task %r had no parent!"%msg)
701 self.log.warn("Task %r had no parent!"%msg)
702 return
702 return
703 msg_id = parent['msg_id']
703 msg_id = parent['msg_id']
704 if msg_id in self.unassigned:
704 if msg_id in self.unassigned:
705 self.unassigned.remove(msg_id)
705 self.unassigned.remove(msg_id)
706
706
707 header = msg['header']
707 header = msg['header']
708 engine_uuid = header.get('engine', None)
708 engine_uuid = header.get('engine', None)
709 eid = self.by_ident.get(engine_uuid, None)
709 eid = self.by_ident.get(engine_uuid, None)
710
710
711 if msg_id in self.pending:
711 if msg_id in self.pending:
712 self.pending.remove(msg_id)
712 self.pending.remove(msg_id)
713 self.all_completed.add(msg_id)
713 self.all_completed.add(msg_id)
714 if eid is not None:
714 if eid is not None:
715 self.completed[eid].append(msg_id)
715 self.completed[eid].append(msg_id)
716 if msg_id in self.tasks[eid]:
716 if msg_id in self.tasks[eid]:
717 self.tasks[eid].remove(msg_id)
717 self.tasks[eid].remove(msg_id)
718 completed = datetime.strptime(header['date'], util.ISO8601)
718 completed = datetime.strptime(header['date'], util.ISO8601)
719 started = header.get('started', None)
719 started = header.get('started', None)
720 if started is not None:
720 if started is not None:
721 started = datetime.strptime(started, util.ISO8601)
721 started = datetime.strptime(started, util.ISO8601)
722 result = {
722 result = {
723 'result_header' : header,
723 'result_header' : header,
724 'result_content': msg['content'],
724 'result_content': msg['content'],
725 'started' : started,
725 'started' : started,
726 'completed' : completed,
726 'completed' : completed,
727 'engine_uuid': engine_uuid
727 'engine_uuid': engine_uuid
728 }
728 }
729
729
730 result['result_buffers'] = msg['buffers']
730 result['result_buffers'] = msg['buffers']
731 try:
731 try:
732 self.db.update_record(msg_id, result)
732 self.db.update_record(msg_id, result)
733 except Exception:
733 except Exception:
734 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
734 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
735
735
736 else:
736 else:
737 self.log.debug("task::unknown task %s finished"%msg_id)
737 self.log.debug("task::unknown task %s finished"%msg_id)
738
738
739 def save_task_destination(self, idents, msg):
739 def save_task_destination(self, idents, msg):
740 try:
740 try:
741 msg = self.session.unpack_message(msg, content=True)
741 msg = self.session.unpack_message(msg, content=True)
742 except:
742 except:
743 self.log.error("task::invalid task tracking message", exc_info=True)
743 self.log.error("task::invalid task tracking message", exc_info=True)
744 return
744 return
745 content = msg['content']
745 content = msg['content']
746 # print (content)
746 # print (content)
747 msg_id = content['msg_id']
747 msg_id = content['msg_id']
748 engine_uuid = content['engine_id']
748 engine_uuid = content['engine_id']
749 eid = self.by_ident[engine_uuid]
749 eid = self.by_ident[engine_uuid]
750
750
751 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
751 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
752 if msg_id in self.unassigned:
752 if msg_id in self.unassigned:
753 self.unassigned.remove(msg_id)
753 self.unassigned.remove(msg_id)
754 # else:
754 # else:
755 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
755 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
756
756
757 self.tasks[eid].append(msg_id)
757 self.tasks[eid].append(msg_id)
758 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
758 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
759 try:
759 try:
760 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
760 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
761 except Exception:
761 except Exception:
762 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
762 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
763
763
764
764
    def mia_task_request(self, idents, msg):
        """Reply with the set of missing-in-action tasks (disabled)."""
        # NOTE: this handler is deliberately disabled; the code below the
        # raise is kept only as a sketch of the intended reply and never runs.
        raise NotImplementedError
        client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)
770
770
771
771
772 #--------------------- IOPub Traffic ------------------------------
772 #--------------------- IOPub Traffic ------------------------------
773
773
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # output with no parent cannot be attributed to a request
            self.log.error("iopub::invalid IOPub message: %s"%msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            # iopub can arrive before the request itself; create a stub record
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append stream output (stdout/stderr) to what is already stored
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            # store the full exception content
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            # any other iopub type: store its data payload under the type name
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            # db failures are logged, not raised, so the Hub keeps running
            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)

816
816
817
817
818
818
819 #-------------------------------------------------------------------------
819 #-------------------------------------------------------------------------
820 # Registration requests
820 # Registration requests
821 #-------------------------------------------------------------------------
821 #-------------------------------------------------------------------------
822
822
823 def connection_request(self, client_id, msg):
823 def connection_request(self, client_id, msg):
824 """Reply with connection addresses for clients."""
824 """Reply with connection addresses for clients."""
825 self.log.info("client::client %s connected"%client_id)
825 self.log.info("client::client %s connected"%client_id)
826 content = dict(status='ok')
826 content = dict(status='ok')
827 content.update(self.client_info)
827 content.update(self.client_info)
828 jsonable = {}
828 jsonable = {}
829 for k,v in self.keytable.iteritems():
829 for k,v in self.keytable.iteritems():
830 if v not in self.dead_engines:
830 if v not in self.dead_engines:
831 jsonable[str(k)] = v
831 jsonable[str(k)] = v
832 content['engines'] = jsonable
832 content['engines'] = jsonable
833 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
833 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
834
834
835 def register_engine(self, reg, msg):
835 def register_engine(self, reg, msg):
836 """Register a new engine."""
836 """Register a new engine."""
837 content = msg['content']
837 content = msg['content']
838 try:
838 try:
839 queue = content['queue']
839 queue = content['queue']
840 except KeyError:
840 except KeyError:
841 self.log.error("registration::queue not specified", exc_info=True)
841 self.log.error("registration::queue not specified", exc_info=True)
842 return
842 return
843 heart = content.get('heartbeat', None)
843 heart = content.get('heartbeat', None)
844 """register a new engine, and create the socket(s) necessary"""
844 """register a new engine, and create the socket(s) necessary"""
845 eid = self._next_id
845 eid = self._next_id
846 # print (eid, queue, reg, heart)
846 # print (eid, queue, reg, heart)
847
847
848 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
848 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
849
849
850 content = dict(id=eid,status='ok')
850 content = dict(id=eid,status='ok')
851 content.update(self.engine_info)
851 content.update(self.engine_info)
852 # check if requesting available IDs:
852 # check if requesting available IDs:
853 if queue in self.by_ident:
853 if queue in self.by_ident:
854 try:
854 try:
855 raise KeyError("queue_id %r in use"%queue)
855 raise KeyError("queue_id %r in use"%queue)
856 except:
856 except:
857 content = error.wrap_exception()
857 content = error.wrap_exception()
858 self.log.error("queue_id %r in use"%queue, exc_info=True)
858 self.log.error("queue_id %r in use"%queue, exc_info=True)
859 elif heart in self.hearts: # need to check unique hearts?
859 elif heart in self.hearts: # need to check unique hearts?
860 try:
860 try:
861 raise KeyError("heart_id %r in use"%heart)
861 raise KeyError("heart_id %r in use"%heart)
862 except:
862 except:
863 self.log.error("heart_id %r in use"%heart, exc_info=True)
863 self.log.error("heart_id %r in use"%heart, exc_info=True)
864 content = error.wrap_exception()
864 content = error.wrap_exception()
865 else:
865 else:
866 for h, pack in self.incoming_registrations.iteritems():
866 for h, pack in self.incoming_registrations.iteritems():
867 if heart == h:
867 if heart == h:
868 try:
868 try:
869 raise KeyError("heart_id %r in use"%heart)
869 raise KeyError("heart_id %r in use"%heart)
870 except:
870 except:
871 self.log.error("heart_id %r in use"%heart, exc_info=True)
871 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 content = error.wrap_exception()
872 content = error.wrap_exception()
873 break
873 break
874 elif queue == pack[1]:
874 elif queue == pack[1]:
875 try:
875 try:
876 raise KeyError("queue_id %r in use"%queue)
876 raise KeyError("queue_id %r in use"%queue)
877 except:
877 except:
878 self.log.error("queue_id %r in use"%queue, exc_info=True)
878 self.log.error("queue_id %r in use"%queue, exc_info=True)
879 content = error.wrap_exception()
879 content = error.wrap_exception()
880 break
880 break
881
881
882 msg = self.session.send(self.query, "registration_reply",
882 msg = self.session.send(self.query, "registration_reply",
883 content=content,
883 content=content,
884 ident=reg)
884 ident=reg)
885
885
886 if content['status'] == 'ok':
886 if content['status'] == 'ok':
887 if heart in self.heartmonitor.hearts:
887 if heart in self.heartmonitor.hearts:
888 # already beating
888 # already beating
889 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
889 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
890 self.finish_registration(heart)
890 self.finish_registration(heart)
891 else:
891 else:
892 purge = lambda : self._purge_stalled_registration(heart)
892 purge = lambda : self._purge_stalled_registration(heart)
893 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
893 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
894 dc.start()
894 dc.start()
895 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
895 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
896 else:
896 else:
897 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
897 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
898 return eid
898 return eid
899
899
900 def unregister_engine(self, ident, msg):
900 def unregister_engine(self, ident, msg):
901 """Unregister an engine that explicitly requested to leave."""
901 """Unregister an engine that explicitly requested to leave."""
902 try:
902 try:
903 eid = msg['content']['id']
903 eid = msg['content']['id']
904 except:
904 except:
905 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
905 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
906 return
906 return
907 self.log.info("registration::unregister_engine(%s)"%eid)
907 self.log.info("registration::unregister_engine(%s)"%eid)
908 # print (eid)
908 # print (eid)
909 uuid = self.keytable[eid]
909 uuid = self.keytable[eid]
910 content=dict(id=eid, queue=uuid)
910 content=dict(id=eid, queue=uuid)
911 self.dead_engines.add(uuid)
911 self.dead_engines.add(uuid)
912 # self.ids.remove(eid)
912 # self.ids.remove(eid)
913 # uuid = self.keytable.pop(eid)
913 # uuid = self.keytable.pop(eid)
914 #
914 #
915 # ec = self.engines.pop(eid)
915 # ec = self.engines.pop(eid)
916 # self.hearts.pop(ec.heartbeat)
916 # self.hearts.pop(ec.heartbeat)
917 # self.by_ident.pop(ec.queue)
917 # self.by_ident.pop(ec.queue)
918 # self.completed.pop(eid)
918 # self.completed.pop(eid)
919 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
919 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
920 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
920 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
921 dc.start()
921 dc.start()
922 ############## TODO: HANDLE IT ################
922 ############## TODO: HANDLE IT ################
923
923
924 if self.notifier:
924 if self.notifier:
925 self.session.send(self.notifier, "unregistration_notification", content=content)
925 self.session.send(self.notifier, "unregistration_notification", content=content)
926
926
927 def _handle_stranded_msgs(self, eid, uuid):
927 def _handle_stranded_msgs(self, eid, uuid):
928 """Handle messages known to be on an engine when the engine unregisters.
928 """Handle messages known to be on an engine when the engine unregisters.
929
929
930 It is possible that this will fire prematurely - that is, an engine will
930 It is possible that this will fire prematurely - that is, an engine will
931 go down after completing a result, and the client will be notified
931 go down after completing a result, and the client will be notified
932 that the result failed and later receive the actual result.
932 that the result failed and later receive the actual result.
933 """
933 """
934
934
935 outstanding = self.queues[eid]
935 outstanding = self.queues[eid]
936
936
937 for msg_id in outstanding:
937 for msg_id in outstanding:
938 self.pending.remove(msg_id)
938 self.pending.remove(msg_id)
939 self.all_completed.add(msg_id)
939 self.all_completed.add(msg_id)
940 try:
940 try:
941 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
941 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
942 except:
942 except:
943 content = error.wrap_exception()
943 content = error.wrap_exception()
944 # build a fake header:
944 # build a fake header:
945 header = {}
945 header = {}
946 header['engine'] = uuid
946 header['engine'] = uuid
947 header['date'] = datetime.now()
947 header['date'] = datetime.now()
948 rec = dict(result_content=content, result_header=header, result_buffers=[])
948 rec = dict(result_content=content, result_header=header, result_buffers=[])
949 rec['completed'] = header['date']
949 rec['completed'] = header['date']
950 rec['engine_uuid'] = uuid
950 rec['engine_uuid'] = uuid
951 try:
951 try:
952 self.db.update_record(msg_id, rec)
952 self.db.update_record(msg_id, rec)
953 except Exception:
953 except Exception:
954 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
954 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
955
955
956
956
957 def finish_registration(self, heart):
957 def finish_registration(self, heart):
958 """Second half of engine registration, called after our HeartMonitor
958 """Second half of engine registration, called after our HeartMonitor
959 has received a beat from the Engine's Heart."""
959 has received a beat from the Engine's Heart."""
960 try:
960 try:
961 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
961 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
962 except KeyError:
962 except KeyError:
963 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
963 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
964 return
964 return
965 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
965 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
966 if purge is not None:
966 if purge is not None:
967 purge.stop()
967 purge.stop()
968 control = queue
968 control = queue
969 self.ids.add(eid)
969 self.ids.add(eid)
970 self.keytable[eid] = queue
970 self.keytable[eid] = queue
971 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
971 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
972 control=control, heartbeat=heart)
972 control=control, heartbeat=heart)
973 self.by_ident[queue] = eid
973 self.by_ident[queue] = eid
974 self.queues[eid] = list()
974 self.queues[eid] = list()
975 self.tasks[eid] = list()
975 self.tasks[eid] = list()
976 self.completed[eid] = list()
976 self.completed[eid] = list()
977 self.hearts[heart] = eid
977 self.hearts[heart] = eid
978 content = dict(id=eid, queue=self.engines[eid].queue)
978 content = dict(id=eid, queue=self.engines[eid].queue)
979 if self.notifier:
979 if self.notifier:
980 self.session.send(self.notifier, "registration_notification", content=content)
980 self.session.send(self.notifier, "registration_notification", content=content)
981 self.log.info("engine::Engine Connected: %i"%eid)
981 self.log.info("engine::Engine Connected: %i"%eid)
982
982
983 def _purge_stalled_registration(self, heart):
983 def _purge_stalled_registration(self, heart):
984 if heart in self.incoming_registrations:
984 if heart in self.incoming_registrations:
985 eid = self.incoming_registrations.pop(heart)[0]
985 eid = self.incoming_registrations.pop(heart)[0]
986 self.log.info("registration::purging stalled registration: %i"%eid)
986 self.log.info("registration::purging stalled registration: %i"%eid)
987 else:
987 else:
988 pass
988 pass
989
989
990 #-------------------------------------------------------------------------
990 #-------------------------------------------------------------------------
991 # Client Requests
991 # Client Requests
992 #-------------------------------------------------------------------------
992 #-------------------------------------------------------------------------
993
993
994 def shutdown_request(self, client_id, msg):
994 def shutdown_request(self, client_id, msg):
995 """handle shutdown request."""
995 """handle shutdown request."""
996 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
996 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
997 # also notify other clients of shutdown
997 # also notify other clients of shutdown
998 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
998 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
999 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
999 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1000 dc.start()
1000 dc.start()
1001
1001
1002 def _shutdown(self):
1002 def _shutdown(self):
1003 self.log.info("hub::hub shutting down.")
1003 self.log.info("hub::hub shutting down.")
1004 time.sleep(0.1)
1004 time.sleep(0.1)
1005 sys.exit(0)
1005 sys.exit(0)
1006
1006
1007
1007
1008 def check_load(self, client_id, msg):
1008 def check_load(self, client_id, msg):
1009 content = msg['content']
1009 content = msg['content']
1010 try:
1010 try:
1011 targets = content['targets']
1011 targets = content['targets']
1012 targets = self._validate_targets(targets)
1012 targets = self._validate_targets(targets)
1013 except:
1013 except:
1014 content = error.wrap_exception()
1014 content = error.wrap_exception()
1015 self.session.send(self.query, "hub_error",
1015 self.session.send(self.query, "hub_error",
1016 content=content, ident=client_id)
1016 content=content, ident=client_id)
1017 return
1017 return
1018
1018
1019 content = dict(status='ok')
1019 content = dict(status='ok')
1020 # loads = {}
1020 # loads = {}
1021 for t in targets:
1021 for t in targets:
1022 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1022 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1023 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1023 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1024
1024
1025
1025
1026 def queue_status(self, client_id, msg):
1026 def queue_status(self, client_id, msg):
1027 """Return the Queue status of one or more targets.
1027 """Return the Queue status of one or more targets.
1028 if verbose: return the msg_ids
1028 if verbose: return the msg_ids
1029 else: return len of each type.
1029 else: return len of each type.
1030 keys: queue (pending MUX jobs)
1030 keys: queue (pending MUX jobs)
1031 tasks (pending Task jobs)
1031 tasks (pending Task jobs)
1032 completed (finished jobs from both queues)"""
1032 completed (finished jobs from both queues)"""
1033 content = msg['content']
1033 content = msg['content']
1034 targets = content['targets']
1034 targets = content['targets']
1035 try:
1035 try:
1036 targets = self._validate_targets(targets)
1036 targets = self._validate_targets(targets)
1037 except:
1037 except:
1038 content = error.wrap_exception()
1038 content = error.wrap_exception()
1039 self.session.send(self.query, "hub_error",
1039 self.session.send(self.query, "hub_error",
1040 content=content, ident=client_id)
1040 content=content, ident=client_id)
1041 return
1041 return
1042 verbose = content.get('verbose', False)
1042 verbose = content.get('verbose', False)
1043 content = dict(status='ok')
1043 content = dict(status='ok')
1044 for t in targets:
1044 for t in targets:
1045 queue = self.queues[t]
1045 queue = self.queues[t]
1046 completed = self.completed[t]
1046 completed = self.completed[t]
1047 tasks = self.tasks[t]
1047 tasks = self.tasks[t]
1048 if not verbose:
1048 if not verbose:
1049 queue = len(queue)
1049 queue = len(queue)
1050 completed = len(completed)
1050 completed = len(completed)
1051 tasks = len(tasks)
1051 tasks = len(tasks)
1052 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1052 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1053 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1053 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1054
1054
1055 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1055 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1056
1056
1057 def purge_results(self, client_id, msg):
1057 def purge_results(self, client_id, msg):
1058 """Purge results from memory. This method is more valuable before we move
1058 """Purge results from memory. This method is more valuable before we move
1059 to a DB based message storage mechanism."""
1059 to a DB based message storage mechanism."""
1060 content = msg['content']
1060 content = msg['content']
1061 msg_ids = content.get('msg_ids', [])
1061 msg_ids = content.get('msg_ids', [])
1062 reply = dict(status='ok')
1062 reply = dict(status='ok')
1063 if msg_ids == 'all':
1063 if msg_ids == 'all':
1064 try:
1064 try:
1065 self.db.drop_matching_records(dict(completed={'$ne':None}))
1065 self.db.drop_matching_records(dict(completed={'$ne':None}))
1066 except Exception:
1066 except Exception:
1067 reply = error.wrap_exception()
1067 reply = error.wrap_exception()
1068 else:
1068 else:
1069 for msg_id in msg_ids:
1069 pending = filter(lambda m: m in self.pending, msg_ids)
1070 if msg_id in self.all_completed:
1070 if pending:
1071 self.db.drop_record(msg_id)
1071 try:
1072 else:
1072 raise IndexError("msg pending: %r"%pending[0])
1073 if msg_id in self.pending:
1073 except:
1074 try:
1074 reply = error.wrap_exception()
1075 raise IndexError("msg pending: %r"%msg_id)
1075 else:
1076 except:
1076 try:
1077 reply = error.wrap_exception()
1077 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1078 else:
1078 except Exception:
1079 reply = error.wrap_exception()
1080
1081 if reply['status'] == 'ok':
1082 eids = content.get('engine_ids', [])
1083 for eid in eids:
1084 if eid not in self.engines:
1079 try:
1085 try:
1080 raise IndexError("No such msg: %r"%msg_id)
1086 raise IndexError("No such engine: %i"%eid)
1081 except:
1087 except:
1082 reply = error.wrap_exception()
1088 reply = error.wrap_exception()
1083 break
1089 break
1084 eids = content.get('engine_ids', [])
1090 msg_ids = self.completed.pop(eid)
1085 for eid in eids:
1091 uid = self.engines[eid].queue
1086 if eid not in self.engines:
1087 try:
1092 try:
1088 raise IndexError("No such engine: %i"%eid)
1093 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1089 except:
1094 except Exception:
1090 reply = error.wrap_exception()
1095 reply = error.wrap_exception()
1091 break
1096 break
1092 msg_ids = self.completed.pop(eid)
1093 uid = self.engines[eid].queue
1094 try:
1095 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1096 except Exception:
1097 reply = error.wrap_exception()
1098 break
1099
1097
1100 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1098 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1101
1099
1102 def resubmit_task(self, client_id, msg):
1100 def resubmit_task(self, client_id, msg):
1103 """Resubmit one or more tasks."""
1101 """Resubmit one or more tasks."""
1104 def finish(reply):
1102 def finish(reply):
1105 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1103 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1106
1104
1107 content = msg['content']
1105 content = msg['content']
1108 msg_ids = content['msg_ids']
1106 msg_ids = content['msg_ids']
1109 reply = dict(status='ok')
1107 reply = dict(status='ok')
1110 try:
1108 try:
1111 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1109 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1112 'header', 'content', 'buffers'])
1110 'header', 'content', 'buffers'])
1113 except Exception:
1111 except Exception:
1114 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1112 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1115 return finish(error.wrap_exception())
1113 return finish(error.wrap_exception())
1116
1114
1117 # validate msg_ids
1115 # validate msg_ids
1118 found_ids = [ rec['msg_id'] for rec in records ]
1116 found_ids = [ rec['msg_id'] for rec in records ]
1119 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1117 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1120 if len(records) > len(msg_ids):
1118 if len(records) > len(msg_ids):
1121 try:
1119 try:
1122 raise RuntimeError("DB appears to be in an inconsistent state."
1120 raise RuntimeError("DB appears to be in an inconsistent state."
1123 "More matching records were found than should exist")
1121 "More matching records were found than should exist")
1124 except Exception:
1122 except Exception:
1125 return finish(error.wrap_exception())
1123 return finish(error.wrap_exception())
1126 elif len(records) < len(msg_ids):
1124 elif len(records) < len(msg_ids):
1127 missing = [ m for m in msg_ids if m not in found_ids ]
1125 missing = [ m for m in msg_ids if m not in found_ids ]
1128 try:
1126 try:
1129 raise KeyError("No such msg(s): %s"%missing)
1127 raise KeyError("No such msg(s): %s"%missing)
1130 except KeyError:
1128 except KeyError:
1131 return finish(error.wrap_exception())
1129 return finish(error.wrap_exception())
1132 elif invalid_ids:
1130 elif invalid_ids:
1133 msg_id = invalid_ids[0]
1131 msg_id = invalid_ids[0]
1134 try:
1132 try:
1135 raise ValueError("Task %r appears to be inflight"%(msg_id))
1133 raise ValueError("Task %r appears to be inflight"%(msg_id))
1136 except Exception:
1134 except Exception:
1137 return finish(error.wrap_exception())
1135 return finish(error.wrap_exception())
1138
1136
1139 # clear the existing records
1137 # clear the existing records
1140 rec = empty_record()
1138 rec = empty_record()
1141 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1139 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1142 rec['resubmitted'] = datetime.now()
1140 rec['resubmitted'] = datetime.now()
1143 rec['queue'] = 'task'
1141 rec['queue'] = 'task'
1144 rec['client_uuid'] = client_id[0]
1142 rec['client_uuid'] = client_id[0]
1145 try:
1143 try:
1146 for msg_id in msg_ids:
1144 for msg_id in msg_ids:
1147 self.all_completed.discard(msg_id)
1145 self.all_completed.discard(msg_id)
1148 self.db.update_record(msg_id, rec)
1146 self.db.update_record(msg_id, rec)
1149 except Exception:
1147 except Exception:
1150 self.log.error('db::db error upating record', exc_info=True)
1148 self.log.error('db::db error upating record', exc_info=True)
1151 reply = error.wrap_exception()
1149 reply = error.wrap_exception()
1152 else:
1150 else:
1153 # send the messages
1151 # send the messages
1154 for rec in records:
1152 for rec in records:
1155 header = rec['header']
1153 header = rec['header']
1156 msg = self.session.msg(header['msg_type'])
1154 msg = self.session.msg(header['msg_type'])
1157 msg['content'] = rec['content']
1155 msg['content'] = rec['content']
1158 msg['header'] = header
1156 msg['header'] = header
1159 msg['msg_id'] = rec['msg_id']
1157 msg['msg_id'] = rec['msg_id']
1160 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1158 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1161
1159
1162 finish(dict(status='ok'))
1160 finish(dict(status='ok'))
1163
1161
1164
1162
1165 def _extract_record(self, rec):
1163 def _extract_record(self, rec):
1166 """decompose a TaskRecord dict into subsection of reply for get_result"""
1164 """decompose a TaskRecord dict into subsection of reply for get_result"""
1167 io_dict = {}
1165 io_dict = {}
1168 for key in 'pyin pyout pyerr stdout stderr'.split():
1166 for key in 'pyin pyout pyerr stdout stderr'.split():
1169 io_dict[key] = rec[key]
1167 io_dict[key] = rec[key]
1170 content = { 'result_content': rec['result_content'],
1168 content = { 'result_content': rec['result_content'],
1171 'header': rec['header'],
1169 'header': rec['header'],
1172 'result_header' : rec['result_header'],
1170 'result_header' : rec['result_header'],
1173 'io' : io_dict,
1171 'io' : io_dict,
1174 }
1172 }
1175 if rec['result_buffers']:
1173 if rec['result_buffers']:
1176 buffers = map(str, rec['result_buffers'])
1174 buffers = map(str, rec['result_buffers'])
1177 else:
1175 else:
1178 buffers = []
1176 buffers = []
1179
1177
1180 return content, buffers
1178 return content, buffers
1181
1179
1182 def get_results(self, client_id, msg):
1180 def get_results(self, client_id, msg):
1183 """Get the result of 1 or more messages."""
1181 """Get the result of 1 or more messages."""
1184 content = msg['content']
1182 content = msg['content']
1185 msg_ids = sorted(set(content['msg_ids']))
1183 msg_ids = sorted(set(content['msg_ids']))
1186 statusonly = content.get('status_only', False)
1184 statusonly = content.get('status_only', False)
1187 pending = []
1185 pending = []
1188 completed = []
1186 completed = []
1189 content = dict(status='ok')
1187 content = dict(status='ok')
1190 content['pending'] = pending
1188 content['pending'] = pending
1191 content['completed'] = completed
1189 content['completed'] = completed
1192 buffers = []
1190 buffers = []
1193 if not statusonly:
1191 if not statusonly:
1194 try:
1192 try:
1195 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1193 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1196 # turn match list into dict, for faster lookup
1194 # turn match list into dict, for faster lookup
1197 records = {}
1195 records = {}
1198 for rec in matches:
1196 for rec in matches:
1199 records[rec['msg_id']] = rec
1197 records[rec['msg_id']] = rec
1200 except Exception:
1198 except Exception:
1201 content = error.wrap_exception()
1199 content = error.wrap_exception()
1202 self.session.send(self.query, "result_reply", content=content,
1200 self.session.send(self.query, "result_reply", content=content,
1203 parent=msg, ident=client_id)
1201 parent=msg, ident=client_id)
1204 return
1202 return
1205 else:
1203 else:
1206 records = {}
1204 records = {}
1207 for msg_id in msg_ids:
1205 for msg_id in msg_ids:
1208 if msg_id in self.pending:
1206 if msg_id in self.pending:
1209 pending.append(msg_id)
1207 pending.append(msg_id)
1210 elif msg_id in self.all_completed:
1208 elif msg_id in self.all_completed:
1211 completed.append(msg_id)
1209 completed.append(msg_id)
1212 if not statusonly:
1210 if not statusonly:
1213 c,bufs = self._extract_record(records[msg_id])
1211 c,bufs = self._extract_record(records[msg_id])
1214 content[msg_id] = c
1212 content[msg_id] = c
1215 buffers.extend(bufs)
1213 buffers.extend(bufs)
1216 elif msg_id in records:
1214 elif msg_id in records:
1217 if rec['completed']:
1215 if rec['completed']:
1218 completed.append(msg_id)
1216 completed.append(msg_id)
1219 c,bufs = self._extract_record(records[msg_id])
1217 c,bufs = self._extract_record(records[msg_id])
1220 content[msg_id] = c
1218 content[msg_id] = c
1221 buffers.extend(bufs)
1219 buffers.extend(bufs)
1222 else:
1220 else:
1223 pending.append(msg_id)
1221 pending.append(msg_id)
1224 else:
1222 else:
1225 try:
1223 try:
1226 raise KeyError('No such message: '+msg_id)
1224 raise KeyError('No such message: '+msg_id)
1227 except:
1225 except:
1228 content = error.wrap_exception()
1226 content = error.wrap_exception()
1229 break
1227 break
1230 self.session.send(self.query, "result_reply", content=content,
1228 self.session.send(self.query, "result_reply", content=content,
1231 parent=msg, ident=client_id,
1229 parent=msg, ident=client_id,
1232 buffers=buffers)
1230 buffers=buffers)
1233
1231
1234 def get_history(self, client_id, msg):
1232 def get_history(self, client_id, msg):
1235 """Get a list of all msg_ids in our DB records"""
1233 """Get a list of all msg_ids in our DB records"""
1236 try:
1234 try:
1237 msg_ids = self.db.get_history()
1235 msg_ids = self.db.get_history()
1238 except Exception as e:
1236 except Exception as e:
1239 content = error.wrap_exception()
1237 content = error.wrap_exception()
1240 else:
1238 else:
1241 content = dict(status='ok', history=msg_ids)
1239 content = dict(status='ok', history=msg_ids)
1242
1240
1243 self.session.send(self.query, "history_reply", content=content,
1241 self.session.send(self.query, "history_reply", content=content,
1244 parent=msg, ident=client_id)
1242 parent=msg, ident=client_id)
1245
1243
1246 def db_query(self, client_id, msg):
1244 def db_query(self, client_id, msg):
1247 """Perform a raw query on the task record database."""
1245 """Perform a raw query on the task record database."""
1248 content = msg['content']
1246 content = msg['content']
1249 query = content.get('query', {})
1247 query = content.get('query', {})
1250 keys = content.get('keys', None)
1248 keys = content.get('keys', None)
1251 query = util.extract_dates(query)
1249 query = util.extract_dates(query)
1252 buffers = []
1250 buffers = []
1253 empty = list()
1251 empty = list()
1254
1252
1255 try:
1253 try:
1256 records = self.db.find_records(query, keys)
1254 records = self.db.find_records(query, keys)
1257 except Exception as e:
1255 except Exception as e:
1258 content = error.wrap_exception()
1256 content = error.wrap_exception()
1259 else:
1257 else:
1260 # extract buffers from reply content:
1258 # extract buffers from reply content:
1261 if keys is not None:
1259 if keys is not None:
1262 buffer_lens = [] if 'buffers' in keys else None
1260 buffer_lens = [] if 'buffers' in keys else None
1263 result_buffer_lens = [] if 'result_buffers' in keys else None
1261 result_buffer_lens = [] if 'result_buffers' in keys else None
1264 else:
1262 else:
1265 buffer_lens = []
1263 buffer_lens = []
1266 result_buffer_lens = []
1264 result_buffer_lens = []
1267
1265
1268 for rec in records:
1266 for rec in records:
1269 # buffers may be None, so double check
1267 # buffers may be None, so double check
1270 if buffer_lens is not None:
1268 if buffer_lens is not None:
1271 b = rec.pop('buffers', empty) or empty
1269 b = rec.pop('buffers', empty) or empty
1272 buffer_lens.append(len(b))
1270 buffer_lens.append(len(b))
1273 buffers.extend(b)
1271 buffers.extend(b)
1274 if result_buffer_lens is not None:
1272 if result_buffer_lens is not None:
1275 rb = rec.pop('result_buffers', empty) or empty
1273 rb = rec.pop('result_buffers', empty) or empty
1276 result_buffer_lens.append(len(rb))
1274 result_buffer_lens.append(len(rb))
1277 buffers.extend(rb)
1275 buffers.extend(rb)
1278 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1276 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1279 result_buffer_lens=result_buffer_lens)
1277 result_buffer_lens=result_buffer_lens)
1280
1278
1281 self.session.send(self.query, "db_reply", content=content,
1279 self.session.send(self.query, "db_reply", content=content,
1282 parent=msg, ident=client_id,
1280 parent=msg, ident=client_id,
1283 buffers=buffers)
1281 buffers=buffers)
1284
1282
@@ -1,96 +1,101 b''
1 """A TaskRecord backend using mongodb"""
1 """A TaskRecord backend using mongodb"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 from datetime import datetime
10
11 from pymongo import Connection
9 from pymongo import Connection
12 from pymongo.binary import Binary
10 from pymongo.binary import Binary
13
11
14 from IPython.utils.traitlets import Dict, List, CUnicode
12 from IPython.utils.traitlets import Dict, List, CUnicode, CStr, Instance
15
13
16 from .dictdb import BaseDB
14 from .dictdb import BaseDB
17
15
18 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
19 # MongoDB class
17 # MongoDB class
20 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
21
19
22 class MongoDB(BaseDB):
20 class MongoDB(BaseDB):
23 """MongoDB TaskRecord backend."""
21 """MongoDB TaskRecord backend."""
24
22
25 connection_args = List(config=True) # args passed to pymongo.Connection
23 connection_args = List(config=True) # args passed to pymongo.Connection
26 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
24 connection_kwargs = Dict(config=True) # kwargs passed to pymongo.Connection
27 database = CUnicode(config=True) # name of the mongodb database
25 database = CUnicode(config=True) # name of the mongodb database
28 _table = Dict()
26
27 _connection = Instance(Connection) # pymongo connection
29
28
30 def __init__(self, **kwargs):
29 def __init__(self, **kwargs):
31 super(MongoDB, self).__init__(**kwargs)
30 super(MongoDB, self).__init__(**kwargs)
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
31 if self._connection is None:
32 self._connection = Connection(*self.connection_args, **self.connection_kwargs)
33 if not self.database:
33 if not self.database:
34 self.database = self.session
34 self.database = self.session
35 self._db = self._connection[self.database]
35 self._db = self._connection[self.database]
36 self._records = self._db['task_records']
36 self._records = self._db['task_records']
37 self._records.ensure_index('msg_id', unique=True)
38 self._records.ensure_index('submitted') # for sorting history
39 # for rec in self._records.find
37
40
38 def _binary_buffers(self, rec):
41 def _binary_buffers(self, rec):
39 for key in ('buffers', 'result_buffers'):
42 for key in ('buffers', 'result_buffers'):
40 if rec.get(key, None):
43 if rec.get(key, None):
41 rec[key] = map(Binary, rec[key])
44 rec[key] = map(Binary, rec[key])
42 return rec
45 return rec
43
46
44 def add_record(self, msg_id, rec):
47 def add_record(self, msg_id, rec):
45 """Add a new Task Record, by msg_id."""
48 """Add a new Task Record, by msg_id."""
46 # print rec
49 # print rec
47 rec = self._binary_buffers(rec)
50 rec = self._binary_buffers(rec)
48 obj_id = self._records.insert(rec)
51 self._records.insert(rec)
49 self._table[msg_id] = obj_id
50
52
51 def get_record(self, msg_id):
53 def get_record(self, msg_id):
52 """Get a specific Task Record, by msg_id."""
54 """Get a specific Task Record, by msg_id."""
53 return self._records.find_one(self._table[msg_id])
55 r = self._records.find_one({'msg_id': msg_id})
56 if not r:
57 # r will be '' if nothing is found
58 raise KeyError(msg_id)
59 return r
54
60
55 def update_record(self, msg_id, rec):
61 def update_record(self, msg_id, rec):
56 """Update the data in an existing record."""
62 """Update the data in an existing record."""
57 rec = self._binary_buffers(rec)
63 rec = self._binary_buffers(rec)
58 obj_id = self._table[msg_id]
64
59 self._records.update({'_id':obj_id}, {'$set': rec})
65 self._records.update({'msg_id':msg_id}, {'$set': rec})
60
66
61 def drop_matching_records(self, check):
67 def drop_matching_records(self, check):
62 """Remove a record from the DB."""
68 """Remove a record from the DB."""
63 self._records.remove(check)
69 self._records.remove(check)
64
70
65 def drop_record(self, msg_id):
71 def drop_record(self, msg_id):
66 """Remove a record from the DB."""
72 """Remove a record from the DB."""
67 obj_id = self._table.pop(msg_id)
73 self._records.remove({'msg_id':msg_id})
68 self._records.remove(obj_id)
69
74
70 def find_records(self, check, keys=None):
75 def find_records(self, check, keys=None):
71 """Find records matching a query dict, optionally extracting subset of keys.
76 """Find records matching a query dict, optionally extracting subset of keys.
72
77
73 Returns list of matching records.
78 Returns list of matching records.
74
79
75 Parameters
80 Parameters
76 ----------
81 ----------
77
82
78 check: dict
83 check: dict
79 mongodb-style query argument
84 mongodb-style query argument
80 keys: list of strs [optional]
85 keys: list of strs [optional]
81 if specified, the subset of keys to extract. msg_id will *always* be
86 if specified, the subset of keys to extract. msg_id will *always* be
82 included.
87 included.
83 """
88 """
84 if keys and 'msg_id' not in keys:
89 if keys and 'msg_id' not in keys:
85 keys.append('msg_id')
90 keys.append('msg_id')
86 matches = list(self._records.find(check,keys))
91 matches = list(self._records.find(check,keys))
87 for rec in matches:
92 for rec in matches:
88 rec.pop('_id')
93 rec.pop('_id')
89 return matches
94 return matches
90
95
91 def get_history(self):
96 def get_history(self):
92 """get all msg_ids, ordered by time submitted."""
97 """get all msg_ids, ordered by time submitted."""
93 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
98 cursor = self._records.find({},{'msg_id':1}).sort('submitted')
94 return [ rec['msg_id'] for rec in cursor ]
99 return [ rec['msg_id'] for rec in cursor ]
95
100
96
101
@@ -1,312 +1,326 b''
1 """A TaskRecord backend using sqlite3"""
1 """A TaskRecord backend using sqlite3"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 import json
9 import json
10 import os
10 import os
11 import cPickle as pickle
11 import cPickle as pickle
12 from datetime import datetime
12 from datetime import datetime
13
13
14 import sqlite3
14 import sqlite3
15
15
16 from zmq.eventloop import ioloop
16 from zmq.eventloop import ioloop
17
17
18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
18 from IPython.utils.traitlets import CUnicode, CStr, Instance, List
19 from .dictdb import BaseDB
19 from .dictdb import BaseDB
20 from IPython.parallel.util import ISO8601
20 from IPython.parallel.util import ISO8601
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # SQLite operators, adapters, and converters
23 # SQLite operators, adapters, and converters
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25
25
26 operators = {
26 operators = {
27 '$lt' : "<",
27 '$lt' : "<",
28 '$gt' : ">",
28 '$gt' : ">",
29 # null is handled weird with ==,!=
29 # null is handled weird with ==,!=
30 '$eq' : "IS",
30 '$eq' : "=",
31 '$ne' : "IS NOT",
31 '$ne' : "!=",
32 '$lte': "<=",
32 '$lte': "<=",
33 '$gte': ">=",
33 '$gte': ">=",
34 '$in' : ('IS', ' OR '),
34 '$in' : ('=', ' OR '),
35 '$nin': ('IS NOT', ' AND '),
35 '$nin': ('!=', ' AND '),
36 # '$all': None,
36 # '$all': None,
37 # '$mod': None,
37 # '$mod': None,
38 # '$exists' : None
38 # '$exists' : None
39 }
39 }
40 null_operators = {
41 '=' : "IS NULL",
42 '!=' : "IS NOT NULL",
43 }
40
44
41 def _adapt_datetime(dt):
45 def _adapt_datetime(dt):
42 return dt.strftime(ISO8601)
46 return dt.strftime(ISO8601)
43
47
44 def _convert_datetime(ds):
48 def _convert_datetime(ds):
45 if ds is None:
49 if ds is None:
46 return ds
50 return ds
47 else:
51 else:
48 return datetime.strptime(ds, ISO8601)
52 return datetime.strptime(ds, ISO8601)
49
53
50 def _adapt_dict(d):
54 def _adapt_dict(d):
51 return json.dumps(d)
55 return json.dumps(d)
52
56
53 def _convert_dict(ds):
57 def _convert_dict(ds):
54 if ds is None:
58 if ds is None:
55 return ds
59 return ds
56 else:
60 else:
57 return json.loads(ds)
61 return json.loads(ds)
58
62
59 def _adapt_bufs(bufs):
63 def _adapt_bufs(bufs):
60 # this is *horrible*
64 # this is *horrible*
61 # copy buffers into single list and pickle it:
65 # copy buffers into single list and pickle it:
62 if bufs and isinstance(bufs[0], (bytes, buffer)):
66 if bufs and isinstance(bufs[0], (bytes, buffer)):
63 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
64 elif bufs:
68 elif bufs:
65 return bufs
69 return bufs
66 else:
70 else:
67 return None
71 return None
68
72
69 def _convert_bufs(bs):
73 def _convert_bufs(bs):
70 if bs is None:
74 if bs is None:
71 return []
75 return []
72 else:
76 else:
73 return pickle.loads(bytes(bs))
77 return pickle.loads(bytes(bs))
74
78
75 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
76 # SQLiteDB class
80 # SQLiteDB class
77 #-----------------------------------------------------------------------------
81 #-----------------------------------------------------------------------------
78
82
79 class SQLiteDB(BaseDB):
83 class SQLiteDB(BaseDB):
80 """SQLite3 TaskRecord backend."""
84 """SQLite3 TaskRecord backend."""
81
85
82 filename = CUnicode('tasks.db', config=True)
86 filename = CUnicode('tasks.db', config=True)
83 location = CUnicode('', config=True)
87 location = CUnicode('', config=True)
84 table = CUnicode("", config=True)
88 table = CUnicode("", config=True)
85
89
86 _db = Instance('sqlite3.Connection')
90 _db = Instance('sqlite3.Connection')
87 _keys = List(['msg_id' ,
91 _keys = List(['msg_id' ,
88 'header' ,
92 'header' ,
89 'content',
93 'content',
90 'buffers',
94 'buffers',
91 'submitted',
95 'submitted',
92 'client_uuid' ,
96 'client_uuid' ,
93 'engine_uuid' ,
97 'engine_uuid' ,
94 'started',
98 'started',
95 'completed',
99 'completed',
96 'resubmitted',
100 'resubmitted',
97 'result_header' ,
101 'result_header' ,
98 'result_content' ,
102 'result_content' ,
99 'result_buffers' ,
103 'result_buffers' ,
100 'queue' ,
104 'queue' ,
101 'pyin' ,
105 'pyin' ,
102 'pyout',
106 'pyout',
103 'pyerr',
107 'pyerr',
104 'stdout',
108 'stdout',
105 'stderr',
109 'stderr',
106 ])
110 ])
107
111
108 def __init__(self, **kwargs):
112 def __init__(self, **kwargs):
109 super(SQLiteDB, self).__init__(**kwargs)
113 super(SQLiteDB, self).__init__(**kwargs)
110 if not self.table:
114 if not self.table:
111 # use session, and prefix _, since starting with # is illegal
115 # use session, and prefix _, since starting with # is illegal
112 self.table = '_'+self.session.replace('-','_')
116 self.table = '_'+self.session.replace('-','_')
113 if not self.location:
117 if not self.location:
114 if hasattr(self.config.Global, 'cluster_dir'):
118 if hasattr(self.config.Global, 'cluster_dir'):
115 self.location = self.config.Global.cluster_dir
119 self.location = self.config.Global.cluster_dir
116 else:
120 else:
117 self.location = '.'
121 self.location = '.'
118 self._init_db()
122 self._init_db()
119
123
120 # register db commit as 2s periodic callback
124 # register db commit as 2s periodic callback
121 # to prevent clogging pipes
125 # to prevent clogging pipes
122 # assumes we are being run in a zmq ioloop app
126 # assumes we are being run in a zmq ioloop app
123 loop = ioloop.IOLoop.instance()
127 loop = ioloop.IOLoop.instance()
124 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
128 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
125 pc.start()
129 pc.start()
126
130
127 def _defaults(self, keys=None):
131 def _defaults(self, keys=None):
128 """create an empty record"""
132 """create an empty record"""
129 d = {}
133 d = {}
130 keys = self._keys if keys is None else keys
134 keys = self._keys if keys is None else keys
131 for key in keys:
135 for key in keys:
132 d[key] = None
136 d[key] = None
133 return d
137 return d
134
138
135 def _init_db(self):
139 def _init_db(self):
136 """Connect to the database and get new session number."""
140 """Connect to the database and get new session number."""
137 # register adapters
141 # register adapters
138 sqlite3.register_adapter(datetime, _adapt_datetime)
142 sqlite3.register_adapter(datetime, _adapt_datetime)
139 sqlite3.register_converter('datetime', _convert_datetime)
143 sqlite3.register_converter('datetime', _convert_datetime)
140 sqlite3.register_adapter(dict, _adapt_dict)
144 sqlite3.register_adapter(dict, _adapt_dict)
141 sqlite3.register_converter('dict', _convert_dict)
145 sqlite3.register_converter('dict', _convert_dict)
142 sqlite3.register_adapter(list, _adapt_bufs)
146 sqlite3.register_adapter(list, _adapt_bufs)
143 sqlite3.register_converter('bufs', _convert_bufs)
147 sqlite3.register_converter('bufs', _convert_bufs)
144 # connect to the db
148 # connect to the db
145 dbfile = os.path.join(self.location, self.filename)
149 dbfile = os.path.join(self.location, self.filename)
146 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
150 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
147 # isolation_level = None)#,
151 # isolation_level = None)#,
148 cached_statements=64)
152 cached_statements=64)
149 # print dir(self._db)
153 # print dir(self._db)
150
154
151 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
155 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
152 (msg_id text PRIMARY KEY,
156 (msg_id text PRIMARY KEY,
153 header dict text,
157 header dict text,
154 content dict text,
158 content dict text,
155 buffers bufs blob,
159 buffers bufs blob,
156 submitted datetime text,
160 submitted datetime text,
157 client_uuid text,
161 client_uuid text,
158 engine_uuid text,
162 engine_uuid text,
159 started datetime text,
163 started datetime text,
160 completed datetime text,
164 completed datetime text,
161 resubmitted datetime text,
165 resubmitted datetime text,
162 result_header dict text,
166 result_header dict text,
163 result_content dict text,
167 result_content dict text,
164 result_buffers bufs blob,
168 result_buffers bufs blob,
165 queue text,
169 queue text,
166 pyin text,
170 pyin text,
167 pyout text,
171 pyout text,
168 pyerr text,
172 pyerr text,
169 stdout text,
173 stdout text,
170 stderr text)
174 stderr text)
171 """%self.table)
175 """%self.table)
172 self._db.commit()
176 self._db.commit()
173
177
174 def _dict_to_list(self, d):
178 def _dict_to_list(self, d):
175 """turn a mongodb-style record dict into a list."""
179 """turn a mongodb-style record dict into a list."""
176
180
177 return [ d[key] for key in self._keys ]
181 return [ d[key] for key in self._keys ]
178
182
179 def _list_to_dict(self, line, keys=None):
183 def _list_to_dict(self, line, keys=None):
180 """Inverse of dict_to_list"""
184 """Inverse of dict_to_list"""
181 keys = self._keys if keys is None else keys
185 keys = self._keys if keys is None else keys
182 d = self._defaults(keys)
186 d = self._defaults(keys)
183 for key,value in zip(keys, line):
187 for key,value in zip(keys, line):
184 d[key] = value
188 d[key] = value
185
189
186 return d
190 return d
187
191
188 def _render_expression(self, check):
192 def _render_expression(self, check):
189 """Turn a mongodb-style search dict into an SQL query."""
193 """Turn a mongodb-style search dict into an SQL query."""
190 expressions = []
194 expressions = []
191 args = []
195 args = []
192
196
193 skeys = set(check.keys())
197 skeys = set(check.keys())
194 skeys.difference_update(set(self._keys))
198 skeys.difference_update(set(self._keys))
195 skeys.difference_update(set(['buffers', 'result_buffers']))
199 skeys.difference_update(set(['buffers', 'result_buffers']))
196 if skeys:
200 if skeys:
197 raise KeyError("Illegal testing key(s): %s"%skeys)
201 raise KeyError("Illegal testing key(s): %s"%skeys)
198
202
199 for name,sub_check in check.iteritems():
203 for name,sub_check in check.iteritems():
200 if isinstance(sub_check, dict):
204 if isinstance(sub_check, dict):
201 for test,value in sub_check.iteritems():
205 for test,value in sub_check.iteritems():
202 try:
206 try:
203 op = operators[test]
207 op = operators[test]
204 except KeyError:
208 except KeyError:
205 raise KeyError("Unsupported operator: %r"%test)
209 raise KeyError("Unsupported operator: %r"%test)
206 if isinstance(op, tuple):
210 if isinstance(op, tuple):
207 op, join = op
211 op, join = op
208 expr = "%s %s ?"%(name, op)
212
209 if isinstance(value, (tuple,list)):
213 if value is None and op in null_operators:
210 expr = '( %s )'%( join.join([expr]*len(value)) )
214 expr = "%s %s"%null_operators[op]
211 args.extend(value)
212 else:
215 else:
213 args.append(value)
216 expr = "%s %s ?"%(name, op)
217 if isinstance(value, (tuple,list)):
218 if op in null_operators and any([v is None for v in value]):
219 # equality tests don't work with NULL
220 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
221 expr = '( %s )'%( join.join([expr]*len(value)) )
222 args.extend(value)
223 else:
224 args.append(value)
214 expressions.append(expr)
225 expressions.append(expr)
215 else:
226 else:
216 # it's an equality check
227 # it's an equality check
217 expressions.append("%s IS ?"%name)
228 if sub_check is None:
218 args.append(sub_check)
229 expressions.append("%s IS NULL")
230 else:
231 expressions.append("%s = ?"%name)
232 args.append(sub_check)
219
233
220 expr = " AND ".join(expressions)
234 expr = " AND ".join(expressions)
221 return expr, args
235 return expr, args
222
236
223 def add_record(self, msg_id, rec):
237 def add_record(self, msg_id, rec):
224 """Add a new Task Record, by msg_id."""
238 """Add a new Task Record, by msg_id."""
225 d = self._defaults()
239 d = self._defaults()
226 d.update(rec)
240 d.update(rec)
227 d['msg_id'] = msg_id
241 d['msg_id'] = msg_id
228 line = self._dict_to_list(d)
242 line = self._dict_to_list(d)
229 tups = '(%s)'%(','.join(['?']*len(line)))
243 tups = '(%s)'%(','.join(['?']*len(line)))
230 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
244 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
231 # self._db.commit()
245 # self._db.commit()
232
246
233 def get_record(self, msg_id):
247 def get_record(self, msg_id):
234 """Get a specific Task Record, by msg_id."""
248 """Get a specific Task Record, by msg_id."""
235 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
249 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
236 line = cursor.fetchone()
250 line = cursor.fetchone()
237 if line is None:
251 if line is None:
238 raise KeyError("No such msg: %r"%msg_id)
252 raise KeyError("No such msg: %r"%msg_id)
239 return self._list_to_dict(line)
253 return self._list_to_dict(line)
240
254
241 def update_record(self, msg_id, rec):
255 def update_record(self, msg_id, rec):
242 """Update the data in an existing record."""
256 """Update the data in an existing record."""
243 query = "UPDATE %s SET "%self.table
257 query = "UPDATE %s SET "%self.table
244 sets = []
258 sets = []
245 keys = sorted(rec.keys())
259 keys = sorted(rec.keys())
246 values = []
260 values = []
247 for key in keys:
261 for key in keys:
248 sets.append('%s = ?'%key)
262 sets.append('%s = ?'%key)
249 values.append(rec[key])
263 values.append(rec[key])
250 query += ', '.join(sets)
264 query += ', '.join(sets)
251 query += ' WHERE msg_id == ?'
265 query += ' WHERE msg_id == ?'
252 values.append(msg_id)
266 values.append(msg_id)
253 self._db.execute(query, values)
267 self._db.execute(query, values)
254 # self._db.commit()
268 # self._db.commit()
255
269
256 def drop_record(self, msg_id):
270 def drop_record(self, msg_id):
257 """Remove a record from the DB."""
271 """Remove a record from the DB."""
258 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
272 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
259 # self._db.commit()
273 # self._db.commit()
260
274
261 def drop_matching_records(self, check):
275 def drop_matching_records(self, check):
262 """Remove a record from the DB."""
276 """Remove a record from the DB."""
263 expr,args = self._render_expression(check)
277 expr,args = self._render_expression(check)
264 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
278 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
265 self._db.execute(query,args)
279 self._db.execute(query,args)
266 # self._db.commit()
280 # self._db.commit()
267
281
268 def find_records(self, check, keys=None):
282 def find_records(self, check, keys=None):
269 """Find records matching a query dict, optionally extracting subset of keys.
283 """Find records matching a query dict, optionally extracting subset of keys.
270
284
271 Returns list of matching records.
285 Returns list of matching records.
272
286
273 Parameters
287 Parameters
274 ----------
288 ----------
275
289
276 check: dict
290 check: dict
277 mongodb-style query argument
291 mongodb-style query argument
278 keys: list of strs [optional]
292 keys: list of strs [optional]
279 if specified, the subset of keys to extract. msg_id will *always* be
293 if specified, the subset of keys to extract. msg_id will *always* be
280 included.
294 included.
281 """
295 """
282 if keys:
296 if keys:
283 bad_keys = [ key for key in keys if key not in self._keys ]
297 bad_keys = [ key for key in keys if key not in self._keys ]
284 if bad_keys:
298 if bad_keys:
285 raise KeyError("Bad record key(s): %s"%bad_keys)
299 raise KeyError("Bad record key(s): %s"%bad_keys)
286
300
287 if keys:
301 if keys:
288 # ensure msg_id is present and first:
302 # ensure msg_id is present and first:
289 if 'msg_id' in keys:
303 if 'msg_id' in keys:
290 keys.remove('msg_id')
304 keys.remove('msg_id')
291 keys.insert(0, 'msg_id')
305 keys.insert(0, 'msg_id')
292 req = ', '.join(keys)
306 req = ', '.join(keys)
293 else:
307 else:
294 req = '*'
308 req = '*'
295 expr,args = self._render_expression(check)
309 expr,args = self._render_expression(check)
296 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
310 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
297 cursor = self._db.execute(query, args)
311 cursor = self._db.execute(query, args)
298 matches = cursor.fetchall()
312 matches = cursor.fetchall()
299 records = []
313 records = []
300 for line in matches:
314 for line in matches:
301 rec = self._list_to_dict(line, keys)
315 rec = self._list_to_dict(line, keys)
302 records.append(rec)
316 records.append(rec)
303 return records
317 return records
304
318
305 def get_history(self):
319 def get_history(self):
306 """get all msg_ids, ordered by time submitted."""
320 """get all msg_ids, ordered by time submitted."""
307 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
321 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
308 cursor = self._db.execute(query)
322 cursor = self._db.execute(query)
309 # will be a list of length 1 tuples
323 # will be a list of length 1 tuples
310 return [ tup[0] for tup in cursor.fetchall()]
324 return [ tup[0] for tup in cursor.fetchall()]
311
325
312 __all__ = ['SQLiteDB'] No newline at end of file
326 __all__ = ['SQLiteDB']
@@ -1,237 +1,244 b''
1 """Tests for parallel client.py"""
1 """Tests for parallel client.py"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import time
14 import time
15 from datetime import datetime
15 from datetime import datetime
16 from tempfile import mktemp
16 from tempfile import mktemp
17
17
18 import zmq
18 import zmq
19
19
20 from IPython.parallel.client import client as clientmod
20 from IPython.parallel.client import client as clientmod
21 from IPython.parallel import error
21 from IPython.parallel import error
22 from IPython.parallel import AsyncResult, AsyncHubResult
22 from IPython.parallel import AsyncResult, AsyncHubResult
23 from IPython.parallel import LoadBalancedView, DirectView
23 from IPython.parallel import LoadBalancedView, DirectView
24
24
25 from clienttest import ClusterTestCase, segfault, wait, add_engines
25 from clienttest import ClusterTestCase, segfault, wait, add_engines
26
26
27 def setup():
27 def setup():
28 add_engines(4)
28 add_engines(4)
29
29
30 class TestClient(ClusterTestCase):
30 class TestClient(ClusterTestCase):
31
31
32 def test_ids(self):
32 def test_ids(self):
33 n = len(self.client.ids)
33 n = len(self.client.ids)
34 self.add_engines(3)
34 self.add_engines(3)
35 self.assertEquals(len(self.client.ids), n+3)
35 self.assertEquals(len(self.client.ids), n+3)
36
36
37 def test_view_indexing(self):
37 def test_view_indexing(self):
38 """test index access for views"""
38 """test index access for views"""
39 self.add_engines(2)
39 self.add_engines(2)
40 targets = self.client._build_targets('all')[-1]
40 targets = self.client._build_targets('all')[-1]
41 v = self.client[:]
41 v = self.client[:]
42 self.assertEquals(v.targets, targets)
42 self.assertEquals(v.targets, targets)
43 t = self.client.ids[2]
43 t = self.client.ids[2]
44 v = self.client[t]
44 v = self.client[t]
45 self.assert_(isinstance(v, DirectView))
45 self.assert_(isinstance(v, DirectView))
46 self.assertEquals(v.targets, t)
46 self.assertEquals(v.targets, t)
47 t = self.client.ids[2:4]
47 t = self.client.ids[2:4]
48 v = self.client[t]
48 v = self.client[t]
49 self.assert_(isinstance(v, DirectView))
49 self.assert_(isinstance(v, DirectView))
50 self.assertEquals(v.targets, t)
50 self.assertEquals(v.targets, t)
51 v = self.client[::2]
51 v = self.client[::2]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, targets[::2])
53 self.assertEquals(v.targets, targets[::2])
54 v = self.client[1::3]
54 v = self.client[1::3]
55 self.assert_(isinstance(v, DirectView))
55 self.assert_(isinstance(v, DirectView))
56 self.assertEquals(v.targets, targets[1::3])
56 self.assertEquals(v.targets, targets[1::3])
57 v = self.client[:-3]
57 v = self.client[:-3]
58 self.assert_(isinstance(v, DirectView))
58 self.assert_(isinstance(v, DirectView))
59 self.assertEquals(v.targets, targets[:-3])
59 self.assertEquals(v.targets, targets[:-3])
60 v = self.client[-1]
60 v = self.client[-1]
61 self.assert_(isinstance(v, DirectView))
61 self.assert_(isinstance(v, DirectView))
62 self.assertEquals(v.targets, targets[-1])
62 self.assertEquals(v.targets, targets[-1])
63 self.assertRaises(TypeError, lambda : self.client[None])
63 self.assertRaises(TypeError, lambda : self.client[None])
64
64
65 def test_lbview_targets(self):
65 def test_lbview_targets(self):
66 """test load_balanced_view targets"""
66 """test load_balanced_view targets"""
67 v = self.client.load_balanced_view()
67 v = self.client.load_balanced_view()
68 self.assertEquals(v.targets, None)
68 self.assertEquals(v.targets, None)
69 v = self.client.load_balanced_view(-1)
69 v = self.client.load_balanced_view(-1)
70 self.assertEquals(v.targets, [self.client.ids[-1]])
70 self.assertEquals(v.targets, [self.client.ids[-1]])
71 v = self.client.load_balanced_view('all')
71 v = self.client.load_balanced_view('all')
72 self.assertEquals(v.targets, self.client.ids)
72 self.assertEquals(v.targets, self.client.ids)
73
73
74 def test_targets(self):
74 def test_targets(self):
75 """test various valid targets arguments"""
75 """test various valid targets arguments"""
76 build = self.client._build_targets
76 build = self.client._build_targets
77 ids = self.client.ids
77 ids = self.client.ids
78 idents,targets = build(None)
78 idents,targets = build(None)
79 self.assertEquals(ids, targets)
79 self.assertEquals(ids, targets)
80
80
81 def test_clear(self):
81 def test_clear(self):
82 """test clear behavior"""
82 """test clear behavior"""
83 # self.add_engines(2)
83 # self.add_engines(2)
84 v = self.client[:]
84 v = self.client[:]
85 v.block=True
85 v.block=True
86 v.push(dict(a=5))
86 v.push(dict(a=5))
87 v.pull('a')
87 v.pull('a')
88 id0 = self.client.ids[-1]
88 id0 = self.client.ids[-1]
89 self.client.clear(targets=id0, block=True)
89 self.client.clear(targets=id0, block=True)
90 a = self.client[:-1].get('a')
90 a = self.client[:-1].get('a')
91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
91 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
92 self.client.clear(block=True)
92 self.client.clear(block=True)
93 for i in self.client.ids:
93 for i in self.client.ids:
94 # print i
94 # print i
95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
95 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
96
96
97 def test_get_result(self):
97 def test_get_result(self):
98 """test getting results from the Hub."""
98 """test getting results from the Hub."""
99 c = clientmod.Client(profile='iptest')
99 c = clientmod.Client(profile='iptest')
100 # self.add_engines(1)
100 # self.add_engines(1)
101 t = c.ids[-1]
101 t = c.ids[-1]
102 ar = c[t].apply_async(wait, 1)
102 ar = c[t].apply_async(wait, 1)
103 # give the monitor time to notice the message
103 # give the monitor time to notice the message
104 time.sleep(.25)
104 time.sleep(.25)
105 ahr = self.client.get_result(ar.msg_ids)
105 ahr = self.client.get_result(ar.msg_ids)
106 self.assertTrue(isinstance(ahr, AsyncHubResult))
106 self.assertTrue(isinstance(ahr, AsyncHubResult))
107 self.assertEquals(ahr.get(), ar.get())
107 self.assertEquals(ahr.get(), ar.get())
108 ar2 = self.client.get_result(ar.msg_ids)
108 ar2 = self.client.get_result(ar.msg_ids)
109 self.assertFalse(isinstance(ar2, AsyncHubResult))
109 self.assertFalse(isinstance(ar2, AsyncHubResult))
110 c.close()
110 c.close()
111
111
112 def test_ids_list(self):
112 def test_ids_list(self):
113 """test client.ids"""
113 """test client.ids"""
114 # self.add_engines(2)
114 # self.add_engines(2)
115 ids = self.client.ids
115 ids = self.client.ids
116 self.assertEquals(ids, self.client._ids)
116 self.assertEquals(ids, self.client._ids)
117 self.assertFalse(ids is self.client._ids)
117 self.assertFalse(ids is self.client._ids)
118 ids.remove(ids[-1])
118 ids.remove(ids[-1])
119 self.assertNotEquals(ids, self.client._ids)
119 self.assertNotEquals(ids, self.client._ids)
120
120
121 def test_queue_status(self):
121 def test_queue_status(self):
122 # self.addEngine(4)
122 # self.addEngine(4)
123 ids = self.client.ids
123 ids = self.client.ids
124 id0 = ids[0]
124 id0 = ids[0]
125 qs = self.client.queue_status(targets=id0)
125 qs = self.client.queue_status(targets=id0)
126 self.assertTrue(isinstance(qs, dict))
126 self.assertTrue(isinstance(qs, dict))
127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
127 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
128 allqs = self.client.queue_status()
128 allqs = self.client.queue_status()
129 self.assertTrue(isinstance(allqs, dict))
129 self.assertTrue(isinstance(allqs, dict))
130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
130 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
131 unassigned = allqs.pop('unassigned')
131 unassigned = allqs.pop('unassigned')
132 for eid,qs in allqs.items():
132 for eid,qs in allqs.items():
133 self.assertTrue(isinstance(qs, dict))
133 self.assertTrue(isinstance(qs, dict))
134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
135
135
136 def test_shutdown(self):
136 def test_shutdown(self):
137 # self.addEngine(4)
137 # self.addEngine(4)
138 ids = self.client.ids
138 ids = self.client.ids
139 id0 = ids[0]
139 id0 = ids[0]
140 self.client.shutdown(id0, block=True)
140 self.client.shutdown(id0, block=True)
141 while id0 in self.client.ids:
141 while id0 in self.client.ids:
142 time.sleep(0.1)
142 time.sleep(0.1)
143 self.client.spin()
143 self.client.spin()
144
144
145 self.assertRaises(IndexError, lambda : self.client[id0])
145 self.assertRaises(IndexError, lambda : self.client[id0])
146
146
147 def test_result_status(self):
147 def test_result_status(self):
148 pass
148 pass
149 # to be written
149 # to be written
150
150
151 def test_db_query_dt(self):
151 def test_db_query_dt(self):
152 """test db query by date"""
152 """test db query by date"""
153 hist = self.client.hub_history()
153 hist = self.client.hub_history()
154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
154 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
155 tic = middle['submitted']
155 tic = middle['submitted']
156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
156 before = self.client.db_query({'submitted' : {'$lt' : tic}})
157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
157 after = self.client.db_query({'submitted' : {'$gte' : tic}})
158 self.assertEquals(len(before)+len(after),len(hist))
158 self.assertEquals(len(before)+len(after),len(hist))
159 for b in before:
159 for b in before:
160 self.assertTrue(b['submitted'] < tic)
160 self.assertTrue(b['submitted'] < tic)
161 for a in after:
161 for a in after:
162 self.assertTrue(a['submitted'] >= tic)
162 self.assertTrue(a['submitted'] >= tic)
163 same = self.client.db_query({'submitted' : tic})
163 same = self.client.db_query({'submitted' : tic})
164 for s in same:
164 for s in same:
165 self.assertTrue(s['submitted'] == tic)
165 self.assertTrue(s['submitted'] == tic)
166
166
167 def test_db_query_keys(self):
167 def test_db_query_keys(self):
168 """test extracting subset of record keys"""
168 """test extracting subset of record keys"""
169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
169 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
170 for rec in found:
170 for rec in found:
171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
171 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
172
172
173 def test_db_query_msg_id(self):
173 def test_db_query_msg_id(self):
174 """ensure msg_id is always in db queries"""
174 """ensure msg_id is always in db queries"""
175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
175 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
176 for rec in found:
176 for rec in found:
177 self.assertTrue('msg_id' in rec.keys())
177 self.assertTrue('msg_id' in rec.keys())
178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
179 for rec in found:
179 for rec in found:
180 self.assertTrue('msg_id' in rec.keys())
180 self.assertTrue('msg_id' in rec.keys())
181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
181 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
182 for rec in found:
182 for rec in found:
183 self.assertTrue('msg_id' in rec.keys())
183 self.assertTrue('msg_id' in rec.keys())
184
184
185 def test_db_query_in(self):
185 def test_db_query_in(self):
186 """test db query with '$in','$nin' operators"""
186 """test db query with '$in','$nin' operators"""
187 hist = self.client.hub_history()
187 hist = self.client.hub_history()
188 even = hist[::2]
188 even = hist[::2]
189 odd = hist[1::2]
189 odd = hist[1::2]
190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
190 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
191 found = [ r['msg_id'] for r in recs ]
191 found = [ r['msg_id'] for r in recs ]
192 self.assertEquals(set(even), set(found))
192 self.assertEquals(set(even), set(found))
193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
193 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
194 found = [ r['msg_id'] for r in recs ]
194 found = [ r['msg_id'] for r in recs ]
195 self.assertEquals(set(odd), set(found))
195 self.assertEquals(set(odd), set(found))
196
196
197 def test_hub_history(self):
197 def test_hub_history(self):
198 hist = self.client.hub_history()
198 hist = self.client.hub_history()
199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
199 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
200 recdict = {}
200 recdict = {}
201 for rec in recs:
201 for rec in recs:
202 recdict[rec['msg_id']] = rec
202 recdict[rec['msg_id']] = rec
203
203
204 latest = datetime(1984,1,1)
204 latest = datetime(1984,1,1)
205 for msg_id in hist:
205 for msg_id in hist:
206 rec = recdict[msg_id]
206 rec = recdict[msg_id]
207 newt = rec['submitted']
207 newt = rec['submitted']
208 self.assertTrue(newt >= latest)
208 self.assertTrue(newt >= latest)
209 latest = newt
209 latest = newt
210 ar = self.client[-1].apply_async(lambda : 1)
210 ar = self.client[-1].apply_async(lambda : 1)
211 ar.get()
211 ar.get()
212 time.sleep(0.25)
212 time.sleep(0.25)
213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
213 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
214
214
215 def test_resubmit(self):
215 def test_resubmit(self):
216 def f():
216 def f():
217 import random
217 import random
218 return random.random()
218 return random.random()
219 v = self.client.load_balanced_view()
219 v = self.client.load_balanced_view()
220 ar = v.apply_async(f)
220 ar = v.apply_async(f)
221 r1 = ar.get(1)
221 r1 = ar.get(1)
222 ahr = self.client.resubmit(ar.msg_ids)
222 ahr = self.client.resubmit(ar.msg_ids)
223 r2 = ahr.get(1)
223 r2 = ahr.get(1)
224 self.assertFalse(r1 == r2)
224 self.assertFalse(r1 == r2)
225
225
226 def test_resubmit_inflight(self):
226 def test_resubmit_inflight(self):
227 """ensure ValueError on resubmit of inflight task"""
227 """ensure ValueError on resubmit of inflight task"""
228 v = self.client.load_balanced_view()
228 v = self.client.load_balanced_view()
229 ar = v.apply_async(time.sleep,1)
229 ar = v.apply_async(time.sleep,1)
230 # give the message a chance to arrive
230 # give the message a chance to arrive
231 time.sleep(0.2)
231 time.sleep(0.2)
232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
232 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
233 ar.get(2)
233 ar.get(2)
234
234
235 def test_resubmit_badkey(self):
235 def test_resubmit_badkey(self):
236 """ensure KeyError on resubmit of nonexistant task"""
236 """ensure KeyError on resubmit of nonexistant task"""
237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
237 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
238
239 def test_purge_results(self):
240 hist = self.client.hub_history()
241 self.client.purge_results(hist)
242 newhist = self.client.hub_history()
243 self.assertTrue(len(newhist) == 0)
244
@@ -1,182 +1,170 b''
1 """Tests for db backends"""
1 """Tests for db backends"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14
14
15 import tempfile
15 import tempfile
16 import time
16 import time
17
17
18 import uuid
19
20 from datetime import datetime, timedelta
18 from datetime import datetime, timedelta
21 from random import choice, randint
22 from unittest import TestCase
19 from unittest import TestCase
23
20
24 from nose import SkipTest
21 from nose import SkipTest
25
22
26 from IPython.parallel import error, streamsession as ss
23 from IPython.parallel import error, streamsession as ss
27 from IPython.parallel.controller.dictdb import DictDB
24 from IPython.parallel.controller.dictdb import DictDB
28 from IPython.parallel.controller.sqlitedb import SQLiteDB
25 from IPython.parallel.controller.sqlitedb import SQLiteDB
29 from IPython.parallel.controller.hub import init_record, empty_record
26 from IPython.parallel.controller.hub import init_record, empty_record
30
27
31 #-------------------------------------------------------------------------------
28 #-------------------------------------------------------------------------------
32 # TestCases
29 # TestCases
33 #-------------------------------------------------------------------------------
30 #-------------------------------------------------------------------------------
34
31
class TestDictBackend(TestCase):
    """Exercise the DB-backend interface against the in-memory DictDB.

    Subclasses override create_db() (and tearDown()) to run the same
    suite against other backends (SQLite, MongoDB).
    """

    def setUp(self):
        self.session = ss.StreamSession()
        self.db = self.create_db()
        self.load_records(16)

    def create_db(self):
        """Return a fresh backend instance; overridden by subclasses."""
        return DictDB()

    def load_records(self, n=1):
        """Load n records for testing; return their msg_ids."""
        # sleep 1/10 s, to ensure timestamp is different to previous calls
        time.sleep(0.1)
        msg_ids = []
        for i in range(n):
            msg = self.session.msg('apply_request', content=dict(a=5))
            msg['buffers'] = []
            rec = init_record(msg)
            msg_ids.append(msg['msg_id'])
            self.db.add_record(msg['msg_id'], rec)
        return msg_ids

    def test_add_record(self):
        before = self.db.get_history()
        self.load_records(5)
        after = self.db.get_history()
        self.assertEquals(len(after), len(before)+5)
        # new records are appended; existing history is untouched
        self.assertEquals(after[:-5], before)

    def test_drop_record(self):
        msg_id = self.load_records()[-1]
        # record must be retrievable before the drop...
        self.db.get_record(msg_id)
        self.db.drop_record(msg_id)
        # ...and gone afterwards
        self.assertRaises(KeyError, self.db.get_record, msg_id)

    def _round_to_millisecond(self, dt):
        """necessary because mongodb rounds microseconds"""
        # strip the sub-millisecond digits of the microsecond field
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    def test_update_record(self):
        now = self._round_to_millisecond(datetime.now())
        msg_id = self.db.get_history()[-1]
        rec1 = self.db.get_record(msg_id)
        data = {'stdout': 'hello there', 'completed' : now}
        self.db.update_record(msg_id, data)
        rec2 = self.db.get_record(msg_id)
        self.assertEquals(rec2['stdout'], 'hello there')
        self.assertEquals(rec2['completed'], now)
        # the update must touch only the given keys
        rec1.update(data)
        self.assertEquals(rec1, rec2)

    # def test_update_record_bad(self):
    #     """test updating nonexistent records"""
    #     msg_id = str(uuid.uuid4())
    #     data = {'stdout': 'hello there'}
    #     self.assertRaises(KeyError, self.db.update_record, msg_id, data)

    def test_find_records_dt(self):
        """test finding records by date"""
        hist = self.db.get_history()
        # floor division keeps the index an int on Python 3 as well
        middle = self.db.get_record(hist[len(hist)//2])
        tic = middle['submitted']
        before = self.db.find_records({'submitted' : {'$lt' : tic}})
        after = self.db.find_records({'submitted' : {'$gte' : tic}})
        # every record falls on exactly one side of the pivot
        self.assertEquals(len(before)+len(after), len(hist))
        for b in before:
            self.assertTrue(b['submitted'] < tic)
        for a in after:
            self.assertTrue(a['submitted'] >= tic)
        same = self.db.find_records({'submitted' : tic})
        for s in same:
            self.assertTrue(s['submitted'] == tic)

    def test_find_records_keys(self):
        """test extracting subset of record keys"""
        found = self.db.find_records({'msg_id': {'$ne' : ''}}, keys=['submitted', 'completed'])
        for rec in found:
            self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))

    def test_find_records_msg_id(self):
        """ensure msg_id is always in found records"""
        # msg_id is included whether or not it was requested
        for keys in (['submitted', 'completed'], ['submitted'], ['msg_id']):
            found = self.db.find_records({'msg_id': {'$ne' : ''}}, keys=keys)
            for rec in found:
                self.assertTrue('msg_id' in rec.keys())

    def test_find_records_in(self):
        """test finding records with '$in','$nin' operators"""
        hist = self.db.get_history()
        even = hist[::2]
        odd = hist[1::2]
        recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(even), set(found))
        recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
        found = [ r['msg_id'] for r in recs ]
        self.assertEquals(set(odd), set(found))

    def test_get_history(self):
        msg_ids = self.db.get_history()
        # history is ordered by non-decreasing submission time
        latest = datetime(1984,1,1)
        for msg_id in msg_ids:
            rec = self.db.get_record(msg_id)
            newt = rec['submitted']
            self.assertTrue(newt >= latest)
            latest = newt
        # a newly-added record lands at the end
        msg_id = self.load_records(1)[-1]
        self.assertEquals(self.db.get_history()[-1], msg_id)

    def test_datetime(self):
        """get/set timestamps with datetime objects"""
        msg_id = self.db.get_history()[-1]
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['submitted'], datetime))
        self.db.update_record(msg_id, dict(completed=datetime.now()))
        rec = self.db.get_record(msg_id)
        self.assertTrue(isinstance(rec['completed'], datetime))

    def test_drop_matching(self):
        msg_ids = self.load_records(10)
        query = {'msg_id' : {'$in':msg_ids}}
        self.db.drop_matching_records(query)
        # none of the dropped records should be findable any more
        recs = self.db.find_records(query)
        self.assertTrue(len(recs)==0)
160
164
class TestSQLiteBackend(TestDictBackend):
    """Repeat the DictDB test suite against the SQLite backend."""

    def create_db(self):
        tmpdir = tempfile.gettempdir()
        return SQLiteDB(location=tmpdir)

    def tearDown(self):
        # release the sqlite connection so the temp file can be reused
        self.db._db.close()
167
# optional MongoDB test -- only defined when pymongo/mongodb support imports
try:
    from IPython.parallel.controller.mongodb import MongoDB
except ImportError:
    pass
else:
    class TestMongoBackend(TestDictBackend):
        """Repeat the DictDB test suite against a local MongoDB instance."""

        def create_db(self):
            try:
                return MongoDB(database='iptestdb')
            except Exception:
                # importable but no server running: skip, don't fail
                raise SkipTest("Couldn't connect to mongodb instance")

        def tearDown(self):
            # drop the scratch database so runs don't pollute each other
            self.db._connection.drop_database('iptestdb')
@@ -1,440 +1,441 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """IPython Test Suite Runner.
2 """IPython Test Suite Runner.
3
3
4 This module provides a main entry point to a user script to test IPython
4 This module provides a main entry point to a user script to test IPython
5 itself from the command line. There are two ways of running this script:
5 itself from the command line. There are two ways of running this script:
6
6
7 1. With the syntax `iptest all`. This runs our entire test suite by
7 1. With the syntax `iptest all`. This runs our entire test suite by
8 calling this script (with different arguments) recursively. This
8 calling this script (with different arguments) recursively. This
9 causes modules and package to be tested in different processes, using nose
9 causes modules and package to be tested in different processes, using nose
10 or trial where appropriate.
10 or trial where appropriate.
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
11 2. With the regular nose syntax, like `iptest -vvs IPython`. In this form
12 the script simply calls nose, but with special command line flags and
12 the script simply calls nose, but with special command line flags and
13 plugins loaded.
13 plugins loaded.
14
14
15 """
15 """
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Copyright (C) 2009 The IPython Development Team
18 # Copyright (C) 2009 The IPython Development Team
19 #
19 #
20 # Distributed under the terms of the BSD License. The full license is in
20 # Distributed under the terms of the BSD License. The full license is in
21 # the file COPYING, distributed as part of this software.
21 # the file COPYING, distributed as part of this software.
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25 # Imports
25 # Imports
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27
27
28 # Stdlib
28 # Stdlib
29 import os
29 import os
30 import os.path as path
30 import os.path as path
31 import signal
31 import signal
32 import sys
32 import sys
33 import subprocess
33 import subprocess
34 import tempfile
34 import tempfile
35 import time
35 import time
36 import warnings
36 import warnings
37
37
38 # Note: monkeypatch!
38 # Note: monkeypatch!
39 # We need to monkeypatch a small problem in nose itself first, before importing
39 # We need to monkeypatch a small problem in nose itself first, before importing
40 # it for actual use. This should get into nose upstream, but its release cycle
40 # it for actual use. This should get into nose upstream, but its release cycle
41 # is slow and we need it for our parametric tests to work correctly.
41 # is slow and we need it for our parametric tests to work correctly.
42 from IPython.testing import nosepatch
42 from IPython.testing import nosepatch
43 # Now, proceed to import nose itself
43 # Now, proceed to import nose itself
44 import nose.plugins.builtin
44 import nose.plugins.builtin
45 from nose.core import TestProgram
45 from nose.core import TestProgram
46
46
47 # Our own imports
47 # Our own imports
48 from IPython.utils.path import get_ipython_module_path
48 from IPython.utils.path import get_ipython_module_path
49 from IPython.utils.process import find_cmd, pycmd2argv
49 from IPython.utils.process import find_cmd, pycmd2argv
50 from IPython.utils.sysinfo import sys_info
50 from IPython.utils.sysinfo import sys_info
51
51
52 from IPython.testing import globalipapp
52 from IPython.testing import globalipapp
53 from IPython.testing.plugin.ipdoctest import IPythonDoctest
53 from IPython.testing.plugin.ipdoctest import IPythonDoctest
54 from IPython.external.decorators import KnownFailure
54 from IPython.external.decorators import KnownFailure
55
55
56 pjoin = path.join
56 pjoin = path.join
57
57
58
58
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60 # Globals
60 # Globals
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62
62
63
63
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65 # Warnings control
65 # Warnings control
66 #-----------------------------------------------------------------------------
66 #-----------------------------------------------------------------------------
67
67
68 # Twisted generates annoying warnings with Python 2.6, as will do other code
68 # Twisted generates annoying warnings with Python 2.6, as will do other code
69 # that imports 'sets' as of today
69 # that imports 'sets' as of today
70 warnings.filterwarnings('ignore', 'the sets module is deprecated',
70 warnings.filterwarnings('ignore', 'the sets module is deprecated',
71 DeprecationWarning )
71 DeprecationWarning )
72
72
73 # This one also comes from Twisted
73 # This one also comes from Twisted
74 warnings.filterwarnings('ignore', 'the sha module is deprecated',
74 warnings.filterwarnings('ignore', 'the sha module is deprecated',
75 DeprecationWarning)
75 DeprecationWarning)
76
76
77 # Wx on Fedora11 spits these out
77 # Wx on Fedora11 spits these out
78 warnings.filterwarnings('ignore', 'wxPython/wxWidgets release number mismatch',
78 warnings.filterwarnings('ignore', 'wxPython/wxWidgets release number mismatch',
79 UserWarning)
79 UserWarning)
80
80
81 #-----------------------------------------------------------------------------
81 #-----------------------------------------------------------------------------
82 # Logic for skipping doctests
82 # Logic for skipping doctests
83 #-----------------------------------------------------------------------------
83 #-----------------------------------------------------------------------------
84
84
85 def test_for(mod, min_version=None):
85 def test_for(mod, min_version=None):
86 """Test to see if mod is importable."""
86 """Test to see if mod is importable."""
87 try:
87 try:
88 __import__(mod)
88 __import__(mod)
89 except (ImportError, RuntimeError):
89 except (ImportError, RuntimeError):
90 # GTK reports Runtime error if it can't be initialized even if it's
90 # GTK reports Runtime error if it can't be initialized even if it's
91 # importable.
91 # importable.
92 return False
92 return False
93 else:
93 else:
94 if min_version:
94 if min_version:
95 return sys.modules[mod].__version__ >= min_version
95 return sys.modules[mod].__version__ >= min_version
96 else:
96 else:
97 return True
97 return True
98
98
# Global dict where we can store information on what we have and what we don't
# have available at test run time
have = {
    'curses': test_for('_curses'),
    'matplotlib': test_for('matplotlib'),
    'pexpect': test_for('pexpect'),
    'pymongo': test_for('pymongo'),
    'wx': test_for('wx'),
    'wx.aui': test_for('wx.aui'),
    'qt': test_for('IPython.external.qt'),
}
# Windows needs a newer pyzmq than the other platforms
have['zmq'] = test_for('zmq', '2.1.7' if os.name == 'nt' else '2.1.4')
114
114
115 #-----------------------------------------------------------------------------
115 #-----------------------------------------------------------------------------
116 # Functions and classes
116 # Functions and classes
117 #-----------------------------------------------------------------------------
117 #-----------------------------------------------------------------------------
118
118
def report():
    """Return a string with a summary report of test-related variables."""

    out = [sys_info(), '\n']

    # split the feature table into what's present and what's missing,
    # each alphabetically sorted for stable output
    avail = sorted(k for k, is_avail in have.items() if is_avail)
    not_avail = sorted(k for k, is_avail in have.items() if not is_avail)

    if avail:
        out.append('\nTools and libraries available at test time:\n')
        out.append(' ' + ' '.join(avail)+'\n')

    if not_avail:
        out.append('\nTools and libraries NOT available at test time:\n')
        out.append(' ' + ' '.join(not_avail)+'\n')

    return ''.join(out)
143 return ''.join(out)
144
144
145
145
def make_exclude():
    """Make patterns of modules and packages to exclude from testing.

    For the IPythonDoctest plugin, we need to exclude certain patterns that
    cause testing problems.  We should strive to minimize the number of
    skipped modules, since this means untested code.

    These modules and packages will NOT get scanned by nose at all for tests.
    """
    def ipjoin(*parts):
        # Shorthand for paths rooted at the IPython package directory.
        return pjoin('IPython', *parts)

    # Unconditional exclusions: externals, dead code, and things that are
    # not importable stand-alone (config files) or that lock up the test
    # run (gui inputhooks).
    exclusions = [
        ipjoin('external'),
        pjoin('IPython_doctest_plugin'),
        ipjoin('quarantine'),
        ipjoin('deathrow'),
        ipjoin('testing', 'attic'),
        # This guy is probably attic material
        ipjoin('testing', 'mkdoctests'),
        # Testing inputhook will need a lot of thought, to figure out
        # how to have tests that don't lock up with the gui event
        # loops in the picture
        ipjoin('lib', 'inputhook'),
        # Config files aren't really importable stand-alone
        ipjoin('config', 'default'),
        ipjoin('config', 'profile'),
    ]

    if not have['wx']:
        exclusions.append(ipjoin('lib', 'inputhookwx'))

    # We do this unconditionally, so that the test suite doesn't import
    # gtk, changing the default encoding and masking some unicode bugs.
    exclusions.append(ipjoin('lib', 'inputhookgtk'))

    # These have to be skipped on win32 because the use echo, rm, cd, etc.
    # See ticket https://bugs.launchpad.net/bugs/366982
    if sys.platform == 'win32':
        exclusions.append(ipjoin('testing', 'plugin', 'test_exampleip'))
        exclusions.append(ipjoin('testing', 'plugin', 'dtexample'))

    if not have['pexpect']:
        exclusions.extend([ipjoin('scripts', 'irunner'),
                           ipjoin('lib', 'irunner'),
                           ipjoin('lib', 'tests', 'test_irunner')])

    # zmq is required by both the parallel machinery and the qt frontend.
    if not have['zmq']:
        exclusions.append(ipjoin('zmq'))
        exclusions.append(ipjoin('frontend', 'qt'))
        exclusions.append(ipjoin('parallel'))
    elif not have['qt']:
        exclusions.append(ipjoin('frontend', 'qt'))

    if not have['pymongo']:
        exclusions.append(ipjoin('parallel', 'controller', 'mongodb'))
        exclusions.append(ipjoin('parallel', 'tests', 'test_mongodb'))

    if not have['matplotlib']:
        exclusions.extend([ipjoin('lib', 'pylabtools'),
                           ipjoin('lib', 'tests', 'test_pylabtools')])

    # This is needed for the reg-exp to match on win32 in the ipdoctest plugin.
    if sys.platform == 'win32':
        exclusions = [s.replace('\\', '\\\\') for s in exclusions]

    return exclusions
212
213
213
214
214 class IPTester(object):
215 class IPTester(object):
215 """Call that calls iptest or trial in a subprocess.
216 """Call that calls iptest or trial in a subprocess.
216 """
217 """
217 #: string, name of test runner that will be called
218 #: string, name of test runner that will be called
218 runner = None
219 runner = None
219 #: list, parameters for test runner
220 #: list, parameters for test runner
220 params = None
221 params = None
221 #: list, arguments of system call to be made to call test runner
222 #: list, arguments of system call to be made to call test runner
222 call_args = None
223 call_args = None
223 #: list, process ids of subprocesses we start (for cleanup)
224 #: list, process ids of subprocesses we start (for cleanup)
224 pids = None
225 pids = None
225
226
226 def __init__(self, runner='iptest', params=None):
227 def __init__(self, runner='iptest', params=None):
227 """Create new test runner."""
228 """Create new test runner."""
228 p = os.path
229 p = os.path
229 if runner == 'iptest':
230 if runner == 'iptest':
230 iptest_app = get_ipython_module_path('IPython.testing.iptest')
231 iptest_app = get_ipython_module_path('IPython.testing.iptest')
231 self.runner = pycmd2argv(iptest_app) + sys.argv[1:]
232 self.runner = pycmd2argv(iptest_app) + sys.argv[1:]
232 else:
233 else:
233 raise Exception('Not a valid test runner: %s' % repr(runner))
234 raise Exception('Not a valid test runner: %s' % repr(runner))
234 if params is None:
235 if params is None:
235 params = []
236 params = []
236 if isinstance(params, str):
237 if isinstance(params, str):
237 params = [params]
238 params = [params]
238 self.params = params
239 self.params = params
239
240
240 # Assemble call
241 # Assemble call
241 self.call_args = self.runner+self.params
242 self.call_args = self.runner+self.params
242
243
243 # Store pids of anything we start to clean up on deletion, if possible
244 # Store pids of anything we start to clean up on deletion, if possible
244 # (on posix only, since win32 has no os.kill)
245 # (on posix only, since win32 has no os.kill)
245 self.pids = []
246 self.pids = []
246
247
247 if sys.platform == 'win32':
248 if sys.platform == 'win32':
248 def _run_cmd(self):
249 def _run_cmd(self):
249 # On Windows, use os.system instead of subprocess.call, because I
250 # On Windows, use os.system instead of subprocess.call, because I
250 # was having problems with subprocess and I just don't know enough
251 # was having problems with subprocess and I just don't know enough
251 # about win32 to debug this reliably. Os.system may be the 'old
252 # about win32 to debug this reliably. Os.system may be the 'old
252 # fashioned' way to do it, but it works just fine. If someone
253 # fashioned' way to do it, but it works just fine. If someone
253 # later can clean this up that's fine, as long as the tests run
254 # later can clean this up that's fine, as long as the tests run
254 # reliably in win32.
255 # reliably in win32.
255 # What types of problems are you having. They may be related to
256 # What types of problems are you having. They may be related to
256 # running Python in unboffered mode. BG.
257 # running Python in unboffered mode. BG.
257 return os.system(' '.join(self.call_args))
258 return os.system(' '.join(self.call_args))
258 else:
259 else:
259 def _run_cmd(self):
260 def _run_cmd(self):
260 # print >> sys.stderr, '*** CMD:', ' '.join(self.call_args) # dbg
261 # print >> sys.stderr, '*** CMD:', ' '.join(self.call_args) # dbg
261 subp = subprocess.Popen(self.call_args)
262 subp = subprocess.Popen(self.call_args)
262 self.pids.append(subp.pid)
263 self.pids.append(subp.pid)
263 # If this fails, the pid will be left in self.pids and cleaned up
264 # If this fails, the pid will be left in self.pids and cleaned up
264 # later, but if the wait call succeeds, then we can clear the
265 # later, but if the wait call succeeds, then we can clear the
265 # stored pid.
266 # stored pid.
266 retcode = subp.wait()
267 retcode = subp.wait()
267 self.pids.pop()
268 self.pids.pop()
268 return retcode
269 return retcode
269
270
270 def run(self):
271 def run(self):
271 """Run the stored commands"""
272 """Run the stored commands"""
272 try:
273 try:
273 return self._run_cmd()
274 return self._run_cmd()
274 except:
275 except:
275 import traceback
276 import traceback
276 traceback.print_exc()
277 traceback.print_exc()
277 return 1 # signal failure
278 return 1 # signal failure
278
279
279 def __del__(self):
280 def __del__(self):
280 """Cleanup on exit by killing any leftover processes."""
281 """Cleanup on exit by killing any leftover processes."""
281
282
282 if not hasattr(os, 'kill'):
283 if not hasattr(os, 'kill'):
283 return
284 return
284
285
285 for pid in self.pids:
286 for pid in self.pids:
286 try:
287 try:
287 print 'Cleaning stale PID:', pid
288 print 'Cleaning stale PID:', pid
288 os.kill(pid, signal.SIGKILL)
289 os.kill(pid, signal.SIGKILL)
289 except OSError:
290 except OSError:
290 # This is just a best effort, if we fail or the process was
291 # This is just a best effort, if we fail or the process was
291 # really gone, ignore it.
292 # really gone, ignore it.
292 pass
293 pass
293
294
294
295
def make_runners():
    """Define the top-level packages that need to be tested.
    """
    # Packages to be tested via nose, that only depend on the stdlib
    package_names = ['config', 'core', 'extensions', 'frontend', 'lib',
                     'scripts', 'testing', 'utils']

    if have['zmq']:
        package_names.append('parallel')

    # For debugging this code, only load quick stuff
    #package_names = ['core', 'extensions'] # dbg

    # Fully qualify each name and pair it with a runner for that package.
    qualified = ['IPython.%s' % name for name in package_names]
    return [(name, IPTester('iptest', params=name)) for name in qualified]
316
317
317
318
def run_iptest():
    """Run the IPython test suite using nose.

    This function is called when this script is **not** called with the form
    `iptest all`.  It simply calls nose with appropriate command line flags
    and accepts all of the standard nose arguments.
    """

    warnings.filterwarnings('ignore',
        'This will be removed soon. Use IPython.testing.util instead')

    argv = sys.argv + [ '--detailed-errors',  # extra info in tracebacks

                        # Loading ipdoctest causes problems with Twisted, but
                        # our test suite runner now separates things and runs
                        # all Twisted tests with trial.
                        '--with-ipdoctest',
                        '--ipdoctest-tests','--ipdoctest-extension=txt',

                        # We add --exe because of setuptools' imbecility (it
                        # blindly does chmod +x on ALL files).  Nose does the
                        # right thing and it tries to avoid executables,
                        # setuptools unfortunately forces our hand here.  This
                        # has been discussed on the distutils list and the
                        # setuptools devs refuse to fix this problem!
                        '--exe',
                        ]

    # BUG FIX: nose.__version__ is a string, and comparing version strings
    # lexicographically mis-orders them (e.g. '0.9' >= '0.11' is True even
    # though 0.9 < 0.11).  Compare tuples of the numeric components instead.
    nose_version = []
    for part in nose.__version__.split('.'):
        if part.isdigit():
            nose_version.append(int(part))
        else:
            break

    if tuple(nose_version) >= (0, 11):
        # I don't fully understand why we need this one, but depending on what
        # directory the test suite is run from, if we don't give it, 0 tests
        # get run.  Specifically, if the test suite is run from the source dir
        # with an argument (like 'iptest.py IPython.core', 0 tests are run,
        # even if the same call done in this directory works fine).  It appears
        # that if the requested package is in the current dir, nose bails early
        # by default.  Since it's otherwise harmless, leave it in by default
        # for nose >= 0.11, though unfortunately nose 0.10 doesn't support it.
        argv.append('--traverse-namespace')

    # Construct list of plugins, omitting the existing doctest plugin, which
    # ours replaces (and extends).
    plugins = [IPythonDoctest(make_exclude()), KnownFailure()]
    for p in nose.plugins.builtin.plugins:
        plug = p()
        if plug.name == 'doctest':
            continue
        plugins.append(plug)

    # We need a global ipython running in this process
    globalipapp.start_ipython()
    # Now nose can run
    TestProgram(argv=argv, plugins=plugins)
370
371
371
372
372 def run_iptestall():
373 def run_iptestall():
373 """Run the entire IPython test suite by calling nose and trial.
374 """Run the entire IPython test suite by calling nose and trial.
374
375
375 This function constructs :class:`IPTester` instances for all IPython
376 This function constructs :class:`IPTester` instances for all IPython
376 modules and package and then runs each of them. This causes the modules
377 modules and package and then runs each of them. This causes the modules
377 and packages of IPython to be tested each in their own subprocess using
378 and packages of IPython to be tested each in their own subprocess using
378 nose or twisted.trial appropriately.
379 nose or twisted.trial appropriately.
379 """
380 """
380
381
381 runners = make_runners()
382 runners = make_runners()
382
383
383 # Run the test runners in a temporary dir so we can nuke it when finished
384 # Run the test runners in a temporary dir so we can nuke it when finished
384 # to clean up any junk files left over by accident. This also makes it
385 # to clean up any junk files left over by accident. This also makes it
385 # robust against being run in non-writeable directories by mistake, as the
386 # robust against being run in non-writeable directories by mistake, as the
386 # temp dir will always be user-writeable.
387 # temp dir will always be user-writeable.
387 curdir = os.getcwd()
388 curdir = os.getcwd()
388 testdir = tempfile.gettempdir()
389 testdir = tempfile.gettempdir()
389 os.chdir(testdir)
390 os.chdir(testdir)
390
391
391 # Run all test runners, tracking execution time
392 # Run all test runners, tracking execution time
392 failed = []
393 failed = []
393 t_start = time.time()
394 t_start = time.time()
394 try:
395 try:
395 for (name, runner) in runners:
396 for (name, runner) in runners:
396 print '*'*70
397 print '*'*70
397 print 'IPython test group:',name
398 print 'IPython test group:',name
398 res = runner.run()
399 res = runner.run()
399 if res:
400 if res:
400 failed.append( (name, runner) )
401 failed.append( (name, runner) )
401 finally:
402 finally:
402 os.chdir(curdir)
403 os.chdir(curdir)
403 t_end = time.time()
404 t_end = time.time()
404 t_tests = t_end - t_start
405 t_tests = t_end - t_start
405 nrunners = len(runners)
406 nrunners = len(runners)
406 nfail = len(failed)
407 nfail = len(failed)
407 # summarize results
408 # summarize results
408 print
409 print
409 print '*'*70
410 print '*'*70
410 print 'Test suite completed for system with the following information:'
411 print 'Test suite completed for system with the following information:'
411 print report()
412 print report()
412 print 'Ran %s test groups in %.3fs' % (nrunners, t_tests)
413 print 'Ran %s test groups in %.3fs' % (nrunners, t_tests)
413 print
414 print
414 print 'Status:'
415 print 'Status:'
415 if not failed:
416 if not failed:
416 print 'OK'
417 print 'OK'
417 else:
418 else:
418 # If anything went wrong, point out what command to rerun manually to
419 # If anything went wrong, point out what command to rerun manually to
419 # see the actual errors and individual summary
420 # see the actual errors and individual summary
420 print 'ERROR - %s out of %s test groups failed.' % (nfail, nrunners)
421 print 'ERROR - %s out of %s test groups failed.' % (nfail, nrunners)
421 for name, failed_runner in failed:
422 for name, failed_runner in failed:
422 print '-'*40
423 print '-'*40
423 print 'Runner failed:',name
424 print 'Runner failed:',name
424 print 'You may wish to rerun this one individually, with:'
425 print 'You may wish to rerun this one individually, with:'
425 print ' '.join(failed_runner.call_args)
426 print ' '.join(failed_runner.call_args)
426 print
427 print
427
428
428
429
def main():
    """Dispatch each command-line argument to the appropriate runner."""
    for argument in sys.argv[1:]:
        if argument.startswith('IPython'):
            # Explicit package name: run nose in this process.
            run_iptest()
        else:
            # Anything else: run the full suite via subprocesses.
            run_iptestall()
437
438
438
439
# Script entry point.
if __name__ == '__main__':
    main()
General Comments 0
You need to be logged in to leave comments. Login now