##// END OF EJS Templates
Tweaks to improve automated conversion to Python 3 code.
Thomas Kluyver -
Show More
@@ -1,428 +1,428 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 from __future__ import with_statement
24 from __future__ import with_statement
25
25
26 import os
26 import os
27 import socket
27 import socket
28 import stat
28 import stat
29 import sys
29 import sys
30 import uuid
30 import uuid
31
31
32 from multiprocessing import Process
32 from multiprocessing import Process
33
33
34 import zmq
34 import zmq
35 from zmq.devices import ProcessMonitoredQueue
35 from zmq.devices import ProcessMonitoredQueue
36 from zmq.log.handlers import PUBHandler
36 from zmq.log.handlers import PUBHandler
37 from zmq.utils import jsonapi as json
37 from zmq.utils import jsonapi as json
38
38
39 from IPython.config.application import boolean_flag
39 from IPython.config.application import boolean_flag
40 from IPython.core.profiledir import ProfileDir
40 from IPython.core.profiledir import ProfileDir
41
41
42 from IPython.parallel.apps.baseapp import (
42 from IPython.parallel.apps.baseapp import (
43 BaseParallelApplication,
43 BaseParallelApplication,
44 base_flags
44 base_flags
45 )
45 )
46 from IPython.utils.importstring import import_item
46 from IPython.utils.importstring import import_item
47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
48
48
49 # from IPython.parallel.controller.controller import ControllerFactory
49 # from IPython.parallel.controller.controller import ControllerFactory
50 from IPython.zmq.session import Session
50 from IPython.zmq.session import Session
51 from IPython.parallel.controller.heartmonitor import HeartMonitor
51 from IPython.parallel.controller.heartmonitor import HeartMonitor
52 from IPython.parallel.controller.hub import HubFactory
52 from IPython.parallel.controller.hub import HubFactory
53 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
53 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
54 from IPython.parallel.controller.sqlitedb import SQLiteDB
54 from IPython.parallel.controller.sqlitedb import SQLiteDB
55
55
56 from IPython.parallel.util import signal_children, split_url
56 from IPython.parallel.util import signal_children, split_url
57
57
58 # conditional import of MongoDB backend class
58 # conditional import of MongoDB backend class
59
59
60 try:
60 try:
61 from IPython.parallel.controller.mongodb import MongoDB
61 from IPython.parallel.controller.mongodb import MongoDB
62 except ImportError:
62 except ImportError:
63 maybe_mongo = []
63 maybe_mongo = []
64 else:
64 else:
65 maybe_mongo = [MongoDB]
65 maybe_mongo = [MongoDB]
66
66
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Module level variables
69 # Module level variables
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71
71
72
72
73 #: The default config file name for this application
73 #: The default config file name for this application
74 default_config_file_name = u'ipcontroller_config.py'
74 default_config_file_name = u'ipcontroller_config.py'
75
75
76
76
77 _description = """Start the IPython controller for parallel computing.
77 _description = """Start the IPython controller for parallel computing.
78
78
79 The IPython controller provides a gateway between the IPython engines and
79 The IPython controller provides a gateway between the IPython engines and
80 clients. The controller needs to be started before the engines and can be
80 clients. The controller needs to be started before the engines and can be
81 configured using command line options or using a cluster directory. Cluster
81 configured using command line options or using a cluster directory. Cluster
82 directories contain config, log and security files and are usually located in
82 directories contain config, log and security files and are usually located in
83 your ipython directory and named as "profile_name". See the `profile`
83 your ipython directory and named as "profile_name". See the `profile`
84 and `profile_dir` options for details.
84 and `profile_dir` options for details.
85 """
85 """
86
86
87
87
88
88
89
89
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91 # The main application
91 # The main application
92 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
93 flags = {}
93 flags = {}
94 flags.update(base_flags)
94 flags.update(base_flags)
95 flags.update({
95 flags.update({
96 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
96 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
97 'Use threads instead of processes for the schedulers'),
97 'Use threads instead of processes for the schedulers'),
98 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
98 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
99 'use the SQLiteDB backend'),
99 'use the SQLiteDB backend'),
100 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
100 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
101 'use the MongoDB backend'),
101 'use the MongoDB backend'),
102 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
102 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
103 'use the in-memory DictDB backend'),
103 'use the in-memory DictDB backend'),
104 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
104 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
105 'reuse existing json connection files')
105 'reuse existing json connection files')
106 })
106 })
107
107
108 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
108 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
109 "Use HMAC digests for authentication of messages.",
109 "Use HMAC digests for authentication of messages.",
110 "Don't authenticate messages."
110 "Don't authenticate messages."
111 ))
111 ))
112
112
113 class IPControllerApp(BaseParallelApplication):
113 class IPControllerApp(BaseParallelApplication):
114
114
115 name = u'ipcontroller'
115 name = u'ipcontroller'
116 description = _description
116 description = _description
117 config_file_name = Unicode(default_config_file_name)
117 config_file_name = Unicode(default_config_file_name)
118 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
118 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
119
119
120 # change default to True
120 # change default to True
121 auto_create = Bool(True, config=True,
121 auto_create = Bool(True, config=True,
122 help="""Whether to create profile dir if it doesn't exist.""")
122 help="""Whether to create profile dir if it doesn't exist.""")
123
123
124 reuse_files = Bool(False, config=True,
124 reuse_files = Bool(False, config=True,
125 help='Whether to reuse existing json connection files.'
125 help='Whether to reuse existing json connection files.'
126 )
126 )
127 secure = Bool(True, config=True,
127 secure = Bool(True, config=True,
128 help='Whether to use HMAC digests for extra message authentication.'
128 help='Whether to use HMAC digests for extra message authentication.'
129 )
129 )
130 ssh_server = Unicode(u'', config=True,
130 ssh_server = Unicode(u'', config=True,
131 help="""ssh url for clients to use when connecting to the Controller
131 help="""ssh url for clients to use when connecting to the Controller
132 processes. It should be of the form: [user@]server[:port]. The
132 processes. It should be of the form: [user@]server[:port]. The
133 Controller's listening addresses must be accessible from the ssh server""",
133 Controller's listening addresses must be accessible from the ssh server""",
134 )
134 )
135 location = Unicode(u'', config=True,
135 location = Unicode(u'', config=True,
136 help="""The external IP or domain name of the Controller, used for disambiguating
136 help="""The external IP or domain name of the Controller, used for disambiguating
137 engine and client connections.""",
137 engine and client connections.""",
138 )
138 )
139 import_statements = List([], config=True,
139 import_statements = List([], config=True,
140 help="import statements to be run at startup. Necessary in some environments"
140 help="import statements to be run at startup. Necessary in some environments"
141 )
141 )
142
142
143 use_threads = Bool(False, config=True,
143 use_threads = Bool(False, config=True,
144 help='Use threads instead of processes for the schedulers',
144 help='Use threads instead of processes for the schedulers',
145 )
145 )
146
146
147 # internal
147 # internal
148 children = List()
148 children = List()
149 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
149 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
150
150
151 def _use_threads_changed(self, name, old, new):
151 def _use_threads_changed(self, name, old, new):
152 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
152 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
153
153
154 aliases = Dict(dict(
154 aliases = Dict(dict(
155 log_level = 'IPControllerApp.log_level',
155 log_level = 'IPControllerApp.log_level',
156 log_url = 'IPControllerApp.log_url',
156 log_url = 'IPControllerApp.log_url',
157 reuse_files = 'IPControllerApp.reuse_files',
157 reuse_files = 'IPControllerApp.reuse_files',
158 secure = 'IPControllerApp.secure',
158 secure = 'IPControllerApp.secure',
159 ssh = 'IPControllerApp.ssh_server',
159 ssh = 'IPControllerApp.ssh_server',
160 use_threads = 'IPControllerApp.use_threads',
160 use_threads = 'IPControllerApp.use_threads',
161 import_statements = 'IPControllerApp.import_statements',
161 import_statements = 'IPControllerApp.import_statements',
162 location = 'IPControllerApp.location',
162 location = 'IPControllerApp.location',
163
163
164 ident = 'Session.session',
164 ident = 'Session.session',
165 user = 'Session.username',
165 user = 'Session.username',
166 exec_key = 'Session.keyfile',
166 exec_key = 'Session.keyfile',
167
167
168 url = 'HubFactory.url',
168 url = 'HubFactory.url',
169 ip = 'HubFactory.ip',
169 ip = 'HubFactory.ip',
170 transport = 'HubFactory.transport',
170 transport = 'HubFactory.transport',
171 port = 'HubFactory.regport',
171 port = 'HubFactory.regport',
172
172
173 ping = 'HeartMonitor.period',
173 ping = 'HeartMonitor.period',
174
174
175 scheme = 'TaskScheduler.scheme_name',
175 scheme = 'TaskScheduler.scheme_name',
176 hwm = 'TaskScheduler.hwm',
176 hwm = 'TaskScheduler.hwm',
177
177
178
178
179 profile = "BaseIPythonApplication.profile",
179 profile = "BaseIPythonApplication.profile",
180 profile_dir = 'ProfileDir.location',
180 profile_dir = 'ProfileDir.location',
181
181
182 ))
182 ))
183 flags = Dict(flags)
183 flags = Dict(flags)
184
184
185
185
186 def save_connection_dict(self, fname, cdict):
186 def save_connection_dict(self, fname, cdict):
187 """save a connection dict to json file."""
187 """save a connection dict to json file."""
188 c = self.config
188 c = self.config
189 url = cdict['url']
189 url = cdict['url']
190 location = cdict['location']
190 location = cdict['location']
191 if not location:
191 if not location:
192 try:
192 try:
193 proto,ip,port = split_url(url)
193 proto,ip,port = split_url(url)
194 except AssertionError:
194 except AssertionError:
195 pass
195 pass
196 else:
196 else:
197 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
197 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
198 cdict['location'] = location
198 cdict['location'] = location
199 fname = os.path.join(self.profile_dir.security_dir, fname)
199 fname = os.path.join(self.profile_dir.security_dir, fname)
200 with open(fname, 'w') as f:
200 with open(fname, 'wb') as f:
201 f.write(json.dumps(cdict, indent=2))
201 f.write(json.dumps(cdict, indent=2))
202 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
202 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
203
203
204 def load_config_from_json(self):
204 def load_config_from_json(self):
205 """load config from existing json connector files."""
205 """load config from existing json connector files."""
206 c = self.config
206 c = self.config
207 # load from engine config
207 # load from engine config
208 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
208 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
209 cfg = json.loads(f.read())
209 cfg = json.loads(f.read())
210 key = c.Session.key = cfg['exec_key']
210 key = c.Session.key = cfg['exec_key']
211 xport,addr = cfg['url'].split('://')
211 xport,addr = cfg['url'].split('://')
212 c.HubFactory.engine_transport = xport
212 c.HubFactory.engine_transport = xport
213 ip,ports = addr.split(':')
213 ip,ports = addr.split(':')
214 c.HubFactory.engine_ip = ip
214 c.HubFactory.engine_ip = ip
215 c.HubFactory.regport = int(ports)
215 c.HubFactory.regport = int(ports)
216 self.location = cfg['location']
216 self.location = cfg['location']
217
217
218 # load client config
218 # load client config
219 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
219 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
220 cfg = json.loads(f.read())
220 cfg = json.loads(f.read())
221 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
221 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
222 xport,addr = cfg['url'].split('://')
222 xport,addr = cfg['url'].split('://')
223 c.HubFactory.client_transport = xport
223 c.HubFactory.client_transport = xport
224 ip,ports = addr.split(':')
224 ip,ports = addr.split(':')
225 c.HubFactory.client_ip = ip
225 c.HubFactory.client_ip = ip
226 self.ssh_server = cfg['ssh']
226 self.ssh_server = cfg['ssh']
227 assert int(ports) == c.HubFactory.regport, "regport mismatch"
227 assert int(ports) == c.HubFactory.regport, "regport mismatch"
228
228
229 def init_hub(self):
229 def init_hub(self):
230 c = self.config
230 c = self.config
231
231
232 self.do_import_statements()
232 self.do_import_statements()
233 reusing = self.reuse_files
233 reusing = self.reuse_files
234 if reusing:
234 if reusing:
235 try:
235 try:
236 self.load_config_from_json()
236 self.load_config_from_json()
237 except (AssertionError,IOError):
237 except (AssertionError,IOError):
238 reusing=False
238 reusing=False
239 # check again, because reusing may have failed:
239 # check again, because reusing may have failed:
240 if reusing:
240 if reusing:
241 pass
241 pass
242 elif self.secure:
242 elif self.secure:
243 key = str(uuid.uuid4())
243 key = str(uuid.uuid4())
244 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
244 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
245 # with open(keyfile, 'w') as f:
245 # with open(keyfile, 'w') as f:
246 # f.write(key)
246 # f.write(key)
247 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
247 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
248 c.Session.key = key
248 c.Session.key = key
249 else:
249 else:
250 key = c.Session.key = ''
250 key = c.Session.key = ''
251
251
252 try:
252 try:
253 self.factory = HubFactory(config=c, log=self.log)
253 self.factory = HubFactory(config=c, log=self.log)
254 # self.start_logging()
254 # self.start_logging()
255 self.factory.init_hub()
255 self.factory.init_hub()
256 except:
256 except:
257 self.log.error("Couldn't construct the Controller", exc_info=True)
257 self.log.error("Couldn't construct the Controller", exc_info=True)
258 self.exit(1)
258 self.exit(1)
259
259
260 if not reusing:
260 if not reusing:
261 # save to new json config files
261 # save to new json config files
262 f = self.factory
262 f = self.factory
263 cdict = {'exec_key' : key,
263 cdict = {'exec_key' : key,
264 'ssh' : self.ssh_server,
264 'ssh' : self.ssh_server,
265 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
265 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
266 'location' : self.location
266 'location' : self.location
267 }
267 }
268 self.save_connection_dict('ipcontroller-client.json', cdict)
268 self.save_connection_dict('ipcontroller-client.json', cdict)
269 edict = cdict
269 edict = cdict
270 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
270 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
271 self.save_connection_dict('ipcontroller-engine.json', edict)
271 self.save_connection_dict('ipcontroller-engine.json', edict)
272
272
273 #
273 #
274 def init_schedulers(self):
274 def init_schedulers(self):
275 children = self.children
275 children = self.children
276 mq = import_item(str(self.mq_class))
276 mq = import_item(str(self.mq_class))
277
277
278 hub = self.factory
278 hub = self.factory
279 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
279 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
280 # IOPub relay (in a Process)
280 # IOPub relay (in a Process)
281 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
281 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
282 q.bind_in(hub.client_info['iopub'])
282 q.bind_in(hub.client_info['iopub'])
283 q.bind_out(hub.engine_info['iopub'])
283 q.bind_out(hub.engine_info['iopub'])
284 q.setsockopt_out(zmq.SUBSCRIBE, '')
284 q.setsockopt_out(zmq.SUBSCRIBE, '')
285 q.connect_mon(hub.monitor_url)
285 q.connect_mon(hub.monitor_url)
286 q.daemon=True
286 q.daemon=True
287 children.append(q)
287 children.append(q)
288
288
289 # Multiplexer Queue (in a Process)
289 # Multiplexer Queue (in a Process)
290 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
290 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
291 q.bind_in(hub.client_info['mux'])
291 q.bind_in(hub.client_info['mux'])
292 q.setsockopt_in(zmq.IDENTITY, 'mux')
292 q.setsockopt_in(zmq.IDENTITY, 'mux')
293 q.bind_out(hub.engine_info['mux'])
293 q.bind_out(hub.engine_info['mux'])
294 q.connect_mon(hub.monitor_url)
294 q.connect_mon(hub.monitor_url)
295 q.daemon=True
295 q.daemon=True
296 children.append(q)
296 children.append(q)
297
297
298 # Control Queue (in a Process)
298 # Control Queue (in a Process)
299 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
299 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
300 q.bind_in(hub.client_info['control'])
300 q.bind_in(hub.client_info['control'])
301 q.setsockopt_in(zmq.IDENTITY, 'control')
301 q.setsockopt_in(zmq.IDENTITY, 'control')
302 q.bind_out(hub.engine_info['control'])
302 q.bind_out(hub.engine_info['control'])
303 q.connect_mon(hub.monitor_url)
303 q.connect_mon(hub.monitor_url)
304 q.daemon=True
304 q.daemon=True
305 children.append(q)
305 children.append(q)
306 try:
306 try:
307 scheme = self.config.TaskScheduler.scheme_name
307 scheme = self.config.TaskScheduler.scheme_name
308 except AttributeError:
308 except AttributeError:
309 scheme = TaskScheduler.scheme_name.get_default_value()
309 scheme = TaskScheduler.scheme_name.get_default_value()
310 # Task Queue (in a Process)
310 # Task Queue (in a Process)
311 if scheme == 'pure':
311 if scheme == 'pure':
312 self.log.warn("task::using pure XREQ Task scheduler")
312 self.log.warn("task::using pure XREQ Task scheduler")
313 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
313 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
314 # q.setsockopt_out(zmq.HWM, hub.hwm)
314 # q.setsockopt_out(zmq.HWM, hub.hwm)
315 q.bind_in(hub.client_info['task'][1])
315 q.bind_in(hub.client_info['task'][1])
316 q.setsockopt_in(zmq.IDENTITY, 'task')
316 q.setsockopt_in(zmq.IDENTITY, 'task')
317 q.bind_out(hub.engine_info['task'])
317 q.bind_out(hub.engine_info['task'])
318 q.connect_mon(hub.monitor_url)
318 q.connect_mon(hub.monitor_url)
319 q.daemon=True
319 q.daemon=True
320 children.append(q)
320 children.append(q)
321 elif scheme == 'none':
321 elif scheme == 'none':
322 self.log.warn("task::using no Task scheduler")
322 self.log.warn("task::using no Task scheduler")
323
323
324 else:
324 else:
325 self.log.info("task::using Python %s Task scheduler"%scheme)
325 self.log.info("task::using Python %s Task scheduler"%scheme)
326 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
326 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
327 hub.monitor_url, hub.client_info['notification'])
327 hub.monitor_url, hub.client_info['notification'])
328 kwargs = dict(logname='scheduler', loglevel=self.log_level,
328 kwargs = dict(logname='scheduler', loglevel=self.log_level,
329 log_url = self.log_url, config=dict(self.config))
329 log_url = self.log_url, config=dict(self.config))
330 if 'Process' in self.mq_class:
330 if 'Process' in self.mq_class:
331 # run the Python scheduler in a Process
331 # run the Python scheduler in a Process
332 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
332 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
333 q.daemon=True
333 q.daemon=True
334 children.append(q)
334 children.append(q)
335 else:
335 else:
336 # single-threaded Controller
336 # single-threaded Controller
337 kwargs['in_thread'] = True
337 kwargs['in_thread'] = True
338 launch_scheduler(*sargs, **kwargs)
338 launch_scheduler(*sargs, **kwargs)
339
339
340
340
341 def save_urls(self):
341 def save_urls(self):
342 """save the registration urls to files."""
342 """save the registration urls to files."""
343 c = self.config
343 c = self.config
344
344
345 sec_dir = self.profile_dir.security_dir
345 sec_dir = self.profile_dir.security_dir
346 cf = self.factory
346 cf = self.factory
347
347
348 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
348 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
349 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
349 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
350
350
351 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
351 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
352 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
352 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
353
353
354
354
355 def do_import_statements(self):
355 def do_import_statements(self):
356 statements = self.import_statements
356 statements = self.import_statements
357 for s in statements:
357 for s in statements:
358 try:
358 try:
359 self.log.msg("Executing statement: '%s'" % s)
359 self.log.msg("Executing statement: '%s'" % s)
360 exec s in globals(), locals()
360 exec s in globals(), locals()
361 except:
361 except:
362 self.log.msg("Error running statement: %s" % s)
362 self.log.msg("Error running statement: %s" % s)
363
363
364 def forward_logging(self):
364 def forward_logging(self):
365 if self.log_url:
365 if self.log_url:
366 self.log.info("Forwarding logging to %s"%self.log_url)
366 self.log.info("Forwarding logging to %s"%self.log_url)
367 context = zmq.Context.instance()
367 context = zmq.Context.instance()
368 lsock = context.socket(zmq.PUB)
368 lsock = context.socket(zmq.PUB)
369 lsock.connect(self.log_url)
369 lsock.connect(self.log_url)
370 handler = PUBHandler(lsock)
370 handler = PUBHandler(lsock)
371 self.log.removeHandler(self._log_handler)
371 self.log.removeHandler(self._log_handler)
372 handler.root_topic = 'controller'
372 handler.root_topic = 'controller'
373 handler.setLevel(self.log_level)
373 handler.setLevel(self.log_level)
374 self.log.addHandler(handler)
374 self.log.addHandler(handler)
375 self._log_handler = handler
375 self._log_handler = handler
376 # #
376 # #
377
377
378 def initialize(self, argv=None):
378 def initialize(self, argv=None):
379 super(IPControllerApp, self).initialize(argv)
379 super(IPControllerApp, self).initialize(argv)
380 self.forward_logging()
380 self.forward_logging()
381 self.init_hub()
381 self.init_hub()
382 self.init_schedulers()
382 self.init_schedulers()
383
383
384 def start(self):
384 def start(self):
385 # Start the subprocesses:
385 # Start the subprocesses:
386 self.factory.start()
386 self.factory.start()
387 child_procs = []
387 child_procs = []
388 for child in self.children:
388 for child in self.children:
389 child.start()
389 child.start()
390 if isinstance(child, ProcessMonitoredQueue):
390 if isinstance(child, ProcessMonitoredQueue):
391 child_procs.append(child.launcher)
391 child_procs.append(child.launcher)
392 elif isinstance(child, Process):
392 elif isinstance(child, Process):
393 child_procs.append(child)
393 child_procs.append(child)
394 if child_procs:
394 if child_procs:
395 signal_children(child_procs)
395 signal_children(child_procs)
396
396
397 self.write_pid_file(overwrite=True)
397 self.write_pid_file(overwrite=True)
398
398
399 try:
399 try:
400 self.factory.loop.start()
400 self.factory.loop.start()
401 except KeyboardInterrupt:
401 except KeyboardInterrupt:
402 self.log.critical("Interrupted, Exiting...\n")
402 self.log.critical("Interrupted, Exiting...\n")
403
403
404
404
405
405
406 def launch_new_instance():
406 def launch_new_instance():
407 """Create and run the IPython controller"""
407 """Create and run the IPython controller"""
408 if sys.platform == 'win32':
408 if sys.platform == 'win32':
409 # make sure we don't get called from a multiprocessing subprocess
409 # make sure we don't get called from a multiprocessing subprocess
410 # this can result in infinite Controllers being started on Windows
410 # this can result in infinite Controllers being started on Windows
411 # which doesn't have a proper fork, so multiprocessing is wonky
411 # which doesn't have a proper fork, so multiprocessing is wonky
412
412
413 # this only comes up when IPython has been installed using vanilla
413 # this only comes up when IPython has been installed using vanilla
414 # setuptools, and *not* distribute.
414 # setuptools, and *not* distribute.
415 import multiprocessing
415 import multiprocessing
416 p = multiprocessing.current_process()
416 p = multiprocessing.current_process()
417 # the main process has name 'MainProcess'
417 # the main process has name 'MainProcess'
418 # subprocesses will have names like 'Process-1'
418 # subprocesses will have names like 'Process-1'
419 if p.name != 'MainProcess':
419 if p.name != 'MainProcess':
420 # we are a subprocess, don't start another Controller!
420 # we are a subprocess, don't start another Controller!
421 return
421 return
422 app = IPControllerApp.instance()
422 app = IPControllerApp.instance()
423 app.initialize()
423 app.initialize()
424 app.start()
424 app.start()
425
425
426
426
427 if __name__ == '__main__':
427 if __name__ == '__main__':
428 launch_new_instance()
428 launch_new_instance()
@@ -1,1288 +1,1288 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010 The IPython Development Team
11 # Copyright (C) 2010 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 from __future__ import print_function
20 from __future__ import print_function
21
21
22 import sys
22 import sys
23 import time
23 import time
24 from datetime import datetime
24 from datetime import datetime
25
25
26 import zmq
26 import zmq
27 from zmq.eventloop import ioloop
27 from zmq.eventloop import ioloop
28 from zmq.eventloop.zmqstream import ZMQStream
28 from zmq.eventloop.zmqstream import ZMQStream
29
29
30 # internal:
30 # internal:
31 from IPython.utils.importstring import import_item
31 from IPython.utils.importstring import import_item
32 from IPython.utils.traitlets import (
32 from IPython.utils.traitlets import (
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 )
34 )
35
35
36 from IPython.parallel import error, util
36 from IPython.parallel import error, util
37 from IPython.parallel.factory import RegistrationFactory
37 from IPython.parallel.factory import RegistrationFactory
38
38
39 from IPython.zmq.session import SessionFactory
39 from IPython.zmq.session import SessionFactory
40
40
41 from .heartmonitor import HeartMonitor
41 from .heartmonitor import HeartMonitor
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Code
44 # Code
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 def _passer(*args, **kwargs):
47 def _passer(*args, **kwargs):
48 return
48 return
49
49
50 def _printer(*args, **kwargs):
50 def _printer(*args, **kwargs):
51 print (args)
51 print (args)
52 print (kwargs)
52 print (kwargs)
53
53
def empty_record():
    """Return an empty dict with all record keys.

    Every key a TaskRecord can hold is present: all fields default to
    None except the output streams, which default to empty strings.
    """
    record = dict.fromkeys((
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    record['stdout'] = ''
    record['stderr'] = ''
    return record
77
77
def init_record(msg):
    """Initialize a TaskRecord based on a request.

    The identifying fields (msg_id, header, content, buffers, submission
    time) are copied from *msg*; everything that is filled in later in the
    task's life cycle starts out as None, and the captured output streams
    start out empty.
    """
    header = msg['header']
    record = {
        'msg_id': header['msg_id'],
        'header': header,
        'content': msg['content'],
        'buffers': msg['buffers'],
        'submitted': header['date'],
    }
    # fields populated as the task progresses
    for key in ('client_uuid', 'engine_uuid', 'started', 'completed',
                'resubmitted', 'result_header', 'result_content',
                'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr'):
        record[key] = None
    record['stdout'] = ''
    record['stderr'] = ''
    return record
102
102
103
103
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    # NOTE(review): the docstring above lists a `uuid` attribute, but no such
    # trait is declared below, while `control` is declared but undocumented --
    # confirm which is intended before relying on either.
    id=Int(0)               # integer engine id, 0 until assigned
    queue=CBytes()          # zmq identity of the engine's queue (XREQ) socket
    control=CBytes()        # zmq identity of the engine's control socket
    registration=CBytes()   # zmq identity of the registration (XREQ) socket
    heartbeat=CBytes()      # zmq identity of the heartbeat (XREQ) socket
    pending=Set()           # presumably msg_ids outstanding on this engine -- confirm
119
119
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the (mostly randomly chosen) port assignments and transport/IP
    settings for every socket the Hub needs, and builds the Hub itself in
    :meth:`init_hub`.
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Int,Int,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        # dynamic default: two free ports picked at random
        return tuple(util.select_random_ports(2))

    mux = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Int,Int,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Int(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Int(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived from monitor_transport/ip/port; kept in sync by the
    # _*_changed handlers below
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # a generic `ip` setting fans out to all three specific IPs
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # recompute the cached monitor URL from its components
        self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # a generic `transport` setting fans out to all three transports
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        self.init_hub()

    def start(self):
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """Construct the Hub: bind its sockets, start the heartbeat
        monitor, connect the DB backend, and build the engine/client
        connection-info dicts."""
        # interface templates with a trailing %i slot for the port number
        client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.XREP), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
        if self.client_ip != self.engine_ip:
            # engines connect on a different interface than clients
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat: PUB pings out, XREP collects pongs
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: subscribe to everything (topic must be bytes)
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        # brief pause to let the DB backend finish connecting
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            # no scheme configured; fall back to the TaskScheduler default
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts: index [1] of each port pair is the
        # engine-facing port, index [0] the client-facing one
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream: connects back to the task queue as if a client
        r = ZMQStream(ctx.socket(zmq.XREQ), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.session)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
299
299
300
300
301 class Hub(SessionFactory):
301 class Hub(SessionFactory):
302 """The IPython Controller Hub with 0MQ connections
302 """The IPython Controller Hub with 0MQ connections
303
303
304 Parameters
304 Parameters
305 ==========
305 ==========
306 loop: zmq IOLoop instance
306 loop: zmq IOLoop instance
307 session: Session object
307 session: Session object
308 <removed> context: zmq context for creating new connections (?)
308 <removed> context: zmq context for creating new connections (?)
309 queue: ZMQStream for monitoring the command queue (SUB)
309 queue: ZMQStream for monitoring the command queue (SUB)
310 query: ZMQStream for engine registration and client queries requests (XREP)
310 query: ZMQStream for engine registration and client queries requests (XREP)
311 heartbeat: HeartMonitor object checking the pulse of the engines
311 heartbeat: HeartMonitor object checking the pulse of the engines
312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 db: connection to db for out of memory logging of commands
313 db: connection to db for out of memory logging of commands
314 NotImplemented
314 NotImplemented
315 engine_info: dict of zmq connection information for engines to connect
315 engine_info: dict of zmq connection information for engines to connect
316 to the queues.
316 to the queues.
317 client_info: dict of zmq connection information for engines to connect
317 client_info: dict of zmq connection information for engines to connect
318 to the queues.
318 to the queues.
319 """
319 """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict()   # engine_id -> zmq identity
    by_ident=Dict()   # zmq identity -> engine_id (reverse of keytable)
    engines=Dict()    # engine_id -> EngineConnector
    clients=Dict()
    hearts=Dict()     # heartbeat identity -> engine_id
    pending=Set()
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids, across all engines
    dead_engines=Set() # engines that have died (NOTE(review): original comment
                       # said "completed msg_ids keyed by engine_id", which is a
                       # copy-paste error -- confirm the element type)
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict()
    registration_timeout=Int()
    _idcounter=Int(0) # monotonically increasing counter backing _next_id

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
347
347
348
348
349 def __init__(self, **kwargs):
349 def __init__(self, **kwargs):
350 """
350 """
351 # universal:
351 # universal:
352 loop: IOLoop for creating future connections
352 loop: IOLoop for creating future connections
353 session: streamsession for sending serialized data
353 session: streamsession for sending serialized data
354 # engine:
354 # engine:
355 queue: ZMQStream for monitoring queue messages
355 queue: ZMQStream for monitoring queue messages
356 query: ZMQStream for engine+client registration and client requests
356 query: ZMQStream for engine+client registration and client requests
357 heartbeat: HeartMonitor object for tracking engines
357 heartbeat: HeartMonitor object for tracking engines
358 # extra:
358 # extra:
359 db: ZMQStream for db connection (NotImplemented)
359 db: ZMQStream for db connection (NotImplemented)
360 engine_info: zmq address/protocol dict for engine connections
360 engine_info: zmq address/protocol dict for engine connections
361 client_info: zmq address/protocol dict for client connections
361 client_info: zmq address/protocol dict for client connections
362 """
362 """
363
363
364 super(Hub, self).__init__(**kwargs)
364 super(Hub, self).__init__(**kwargs)
365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366
366
367 # validate connection dicts:
367 # validate connection dicts:
368 for k,v in self.client_info.iteritems():
368 for k,v in self.client_info.iteritems():
369 if k == 'task':
369 if k == 'task':
370 util.validate_url_container(v[1])
370 util.validate_url_container(v[1])
371 else:
371 else:
372 util.validate_url_container(v)
372 util.validate_url_container(v)
373 # util.validate_url_container(self.client_info)
373 # util.validate_url_container(self.client_info)
374 util.validate_url_container(self.engine_info)
374 util.validate_url_container(self.engine_info)
375
375
376 # register our callbacks
376 # register our callbacks
377 self.query.on_recv(self.dispatch_query)
377 self.query.on_recv(self.dispatch_query)
378 self.monitor.on_recv(self.dispatch_monitor_traffic)
378 self.monitor.on_recv(self.dispatch_monitor_traffic)
379
379
380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382
382
383 self.monitor_handlers = { 'in' : self.save_queue_request,
383 self.monitor_handlers = { 'in' : self.save_queue_request,
384 'out': self.save_queue_result,
384 'out': self.save_queue_result,
385 'intask': self.save_task_request,
385 'intask': self.save_task_request,
386 'outtask': self.save_task_result,
386 'outtask': self.save_task_result,
387 'tracktask': self.save_task_destination,
387 'tracktask': self.save_task_destination,
388 'incontrol': _passer,
388 'incontrol': _passer,
389 'outcontrol': _passer,
389 'outcontrol': _passer,
390 'iopub': self.save_iopub_message,
390 'iopub': self.save_iopub_message,
391 }
391 }
392
392
393 self.query_handlers = {'queue_request': self.queue_status,
393 self.query_handlers = {'queue_request': self.queue_status,
394 'result_request': self.get_results,
394 'result_request': self.get_results,
395 'history_request': self.get_history,
395 'history_request': self.get_history,
396 'db_request': self.db_query,
396 'db_request': self.db_query,
397 'purge_request': self.purge_results,
397 'purge_request': self.purge_results,
398 'load_request': self.check_load,
398 'load_request': self.check_load,
399 'resubmit_request': self.resubmit_task,
399 'resubmit_request': self.resubmit_task,
400 'shutdown_request': self.shutdown_request,
400 'shutdown_request': self.shutdown_request,
401 'registration_request' : self.register_engine,
401 'registration_request' : self.register_engine,
402 'unregistration_request' : self.unregister_engine,
402 'unregistration_request' : self.unregister_engine,
403 'connection_request': self.connection_request,
403 'connection_request': self.connection_request,
404 }
404 }
405
405
406 # ignore resubmit replies
406 # ignore resubmit replies
407 self.resubmit.on_recv(lambda msg: None, copy=False)
407 self.resubmit.on_recv(lambda msg: None, copy=False)
408
408
409 self.log.info("hub::created hub")
409 self.log.info("hub::created hub")
410
410
411 @property
411 @property
412 def _next_id(self):
412 def _next_id(self):
413 """gemerate a new ID.
413 """gemerate a new ID.
414
414
415 No longer reuse old ids, just count from 0."""
415 No longer reuse old ids, just count from 0."""
416 newid = self._idcounter
416 newid = self._idcounter
417 self._idcounter += 1
417 self._idcounter += 1
418 return newid
418 return newid
419 # newid = 0
419 # newid = 0
420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 # # print newid, self.ids, self.incoming_registrations
421 # # print newid, self.ids, self.incoming_registrations
422 # while newid in self.ids or newid in incoming:
422 # while newid in self.ids or newid in incoming:
423 # newid += 1
423 # newid += 1
424 # return newid
424 # return newid
425
425
426 #-----------------------------------------------------------------------------
426 #-----------------------------------------------------------------------------
427 # message validation
427 # message validation
428 #-----------------------------------------------------------------------------
428 #-----------------------------------------------------------------------------
429
429
430 def _validate_targets(self, targets):
430 def _validate_targets(self, targets):
431 """turn any valid targets argument into a list of integer ids"""
431 """turn any valid targets argument into a list of integer ids"""
432 if targets is None:
432 if targets is None:
433 # default to all
433 # default to all
434 targets = self.ids
434 targets = self.ids
435
435
436 if isinstance(targets, (int,str,unicode)):
436 if isinstance(targets, (int,str,unicode)):
437 # only one target specified
437 # only one target specified
438 targets = [targets]
438 targets = [targets]
439 _targets = []
439 _targets = []
440 for t in targets:
440 for t in targets:
441 # map raw identities to ids
441 # map raw identities to ids
442 if isinstance(t, (str,unicode)):
442 if isinstance(t, (str,unicode)):
443 t = self.by_ident.get(t, t)
443 t = self.by_ident.get(t, t)
444 _targets.append(t)
444 _targets.append(t)
445 targets = _targets
445 targets = _targets
446 bad_targets = [ t for t in targets if t not in self.ids ]
446 bad_targets = [ t for t in targets if t not in self.ids ]
447 if bad_targets:
447 if bad_targets:
448 raise IndexError("No Such Engine: %r"%bad_targets)
448 raise IndexError("No Such Engine: %r"%bad_targets)
449 if not targets:
449 if not targets:
450 raise IndexError("No Engines Registered")
450 raise IndexError("No Engines Registered")
451 return targets
451 return targets
452
452
453 #-----------------------------------------------------------------------------
453 #-----------------------------------------------------------------------------
454 # dispatch methods (1 per stream)
454 # dispatch methods (1 per stream)
455 #-----------------------------------------------------------------------------
455 #-----------------------------------------------------------------------------
456
456
457
457
458 def dispatch_monitor_traffic(self, msg):
458 def dispatch_monitor_traffic(self, msg):
459 """all ME and Task queue messages come through here, as well as
459 """all ME and Task queue messages come through here, as well as
460 IOPub traffic."""
460 IOPub traffic."""
461 self.log.debug("monitor traffic: %r"%msg[:2])
461 self.log.debug("monitor traffic: %r"%msg[:2])
462 switch = msg[0]
462 switch = msg[0]
463 try:
463 try:
464 idents, msg = self.session.feed_identities(msg[1:])
464 idents, msg = self.session.feed_identities(msg[1:])
465 except ValueError:
465 except ValueError:
466 idents=[]
466 idents=[]
467 if not idents:
467 if not idents:
468 self.log.error("Bad Monitor Message: %r"%msg)
468 self.log.error("Bad Monitor Message: %r"%msg)
469 return
469 return
470 handler = self.monitor_handlers.get(switch, None)
470 handler = self.monitor_handlers.get(switch, None)
471 if handler is not None:
471 if handler is not None:
472 handler(idents, msg)
472 handler(idents, msg)
473 else:
473 else:
474 self.log.error("Invalid monitor topic: %r"%switch)
474 self.log.error("Invalid monitor topic: %r"%switch)
475
475
476
476
477 def dispatch_query(self, msg):
477 def dispatch_query(self, msg):
478 """Route registration requests and queries from clients."""
478 """Route registration requests and queries from clients."""
479 try:
479 try:
480 idents, msg = self.session.feed_identities(msg)
480 idents, msg = self.session.feed_identities(msg)
481 except ValueError:
481 except ValueError:
482 idents = []
482 idents = []
483 if not idents:
483 if not idents:
484 self.log.error("Bad Query Message: %r"%msg)
484 self.log.error("Bad Query Message: %r"%msg)
485 return
485 return
486 client_id = idents[0]
486 client_id = idents[0]
487 try:
487 try:
488 msg = self.session.unpack_message(msg, content=True)
488 msg = self.session.unpack_message(msg, content=True)
489 except Exception:
489 except Exception:
490 content = error.wrap_exception()
490 content = error.wrap_exception()
491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 self.session.send(self.query, "hub_error", ident=client_id,
492 self.session.send(self.query, "hub_error", ident=client_id,
493 content=content)
493 content=content)
494 return
494 return
495 # print client_id, header, parent, content
495 # print client_id, header, parent, content
496 #switch on message type:
496 #switch on message type:
497 msg_type = msg['msg_type']
497 msg_type = msg['msg_type']
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 handler = self.query_handlers.get(msg_type, None)
499 handler = self.query_handlers.get(msg_type, None)
500 try:
500 try:
501 assert handler is not None, "Bad Message Type: %r"%msg_type
501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 except:
502 except:
503 content = error.wrap_exception()
503 content = error.wrap_exception()
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 self.session.send(self.query, "hub_error", ident=client_id,
505 self.session.send(self.query, "hub_error", ident=client_id,
506 content=content)
506 content=content)
507 return
507 return
508
508
509 else:
509 else:
510 handler(idents, msg)
510 handler(idents, msg)
511
511
512 def dispatch_db(self, msg):
512 def dispatch_db(self, msg):
513 """"""
513 """"""
514 raise NotImplementedError
514 raise NotImplementedError
515
515
516 #---------------------------------------------------------------------------
516 #---------------------------------------------------------------------------
517 # handler methods (1 per event)
517 # handler methods (1 per event)
518 #---------------------------------------------------------------------------
518 #---------------------------------------------------------------------------
519
519
520 #----------------------- Heartbeat --------------------------------------
520 #----------------------- Heartbeat --------------------------------------
521
521
522 def handle_new_heart(self, heart):
522 def handle_new_heart(self, heart):
523 """handler to attach to heartbeater.
523 """handler to attach to heartbeater.
524 Called when a new heart starts to beat.
524 Called when a new heart starts to beat.
525 Triggers completion of registration."""
525 Triggers completion of registration."""
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 if heart not in self.incoming_registrations:
527 if heart not in self.incoming_registrations:
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 else:
529 else:
530 self.finish_registration(heart)
530 self.finish_registration(heart)
531
531
532
532
533 def handle_heart_failure(self, heart):
533 def handle_heart_failure(self, heart):
534 """handler to attach to heartbeater.
534 """handler to attach to heartbeater.
535 called when a previously registered heart fails to respond to beat request.
535 called when a previously registered heart fails to respond to beat request.
536 triggers unregistration"""
536 triggers unregistration"""
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 eid = self.hearts.get(heart, None)
538 eid = self.hearts.get(heart, None)
539 queue = self.engines[eid].queue
539 queue = self.engines[eid].queue
540 if eid is None:
540 if eid is None:
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 else:
542 else:
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544
544
545 #----------------------- MUX Queue Traffic ------------------------------
545 #----------------------- MUX Queue Traffic ------------------------------
546
546
547 def save_queue_request(self, idents, msg):
547 def save_queue_request(self, idents, msg):
548 if len(idents) < 2:
548 if len(idents) < 2:
549 self.log.error("invalid identity prefix: %r"%idents)
549 self.log.error("invalid identity prefix: %r"%idents)
550 return
550 return
551 queue_id, client_id = idents[:2]
551 queue_id, client_id = idents[:2]
552 try:
552 try:
553 msg = self.session.unpack_message(msg)
553 msg = self.session.unpack_message(msg)
554 except Exception:
554 except Exception:
555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 return
556 return
557
557
558 eid = self.by_ident.get(queue_id, None)
558 eid = self.by_ident.get(queue_id, None)
559 if eid is None:
559 if eid is None:
560 self.log.error("queue::target %r not registered"%queue_id)
560 self.log.error("queue::target %r not registered"%queue_id)
561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 return
562 return
563 record = init_record(msg)
563 record = init_record(msg)
564 msg_id = record['msg_id']
564 msg_id = record['msg_id']
565 record['engine_uuid'] = queue_id
565 record['engine_uuid'] = queue_id
566 record['client_uuid'] = client_id
566 record['client_uuid'] = client_id
567 record['queue'] = 'mux'
567 record['queue'] = 'mux'
568
568
569 try:
569 try:
570 # it's posible iopub arrived first:
570 # it's posible iopub arrived first:
571 existing = self.db.get_record(msg_id)
571 existing = self.db.get_record(msg_id)
572 for key,evalue in existing.iteritems():
572 for key,evalue in existing.iteritems():
573 rvalue = record.get(key, None)
573 rvalue = record.get(key, None)
574 if evalue and rvalue and evalue != rvalue:
574 if evalue and rvalue and evalue != rvalue:
575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
576 elif evalue and not rvalue:
576 elif evalue and not rvalue:
577 record[key] = evalue
577 record[key] = evalue
578 try:
578 try:
579 self.db.update_record(msg_id, record)
579 self.db.update_record(msg_id, record)
580 except Exception:
580 except Exception:
581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
582 except KeyError:
582 except KeyError:
583 try:
583 try:
584 self.db.add_record(msg_id, record)
584 self.db.add_record(msg_id, record)
585 except Exception:
585 except Exception:
586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
587
587
588
588
589 self.pending.add(msg_id)
589 self.pending.add(msg_id)
590 self.queues[eid].append(msg_id)
590 self.queues[eid].append(msg_id)
591
591
592 def save_queue_result(self, idents, msg):
592 def save_queue_result(self, idents, msg):
593 if len(idents) < 2:
593 if len(idents) < 2:
594 self.log.error("invalid identity prefix: %r"%idents)
594 self.log.error("invalid identity prefix: %r"%idents)
595 return
595 return
596
596
597 client_id, queue_id = idents[:2]
597 client_id, queue_id = idents[:2]
598 try:
598 try:
599 msg = self.session.unpack_message(msg)
599 msg = self.session.unpack_message(msg)
600 except Exception:
600 except Exception:
601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
602 queue_id,client_id, msg), exc_info=True)
602 queue_id,client_id, msg), exc_info=True)
603 return
603 return
604
604
605 eid = self.by_ident.get(queue_id, None)
605 eid = self.by_ident.get(queue_id, None)
606 if eid is None:
606 if eid is None:
607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
608 return
608 return
609
609
610 parent = msg['parent_header']
610 parent = msg['parent_header']
611 if not parent:
611 if not parent:
612 return
612 return
613 msg_id = parent['msg_id']
613 msg_id = parent['msg_id']
614 if msg_id in self.pending:
614 if msg_id in self.pending:
615 self.pending.remove(msg_id)
615 self.pending.remove(msg_id)
616 self.all_completed.add(msg_id)
616 self.all_completed.add(msg_id)
617 self.queues[eid].remove(msg_id)
617 self.queues[eid].remove(msg_id)
618 self.completed[eid].append(msg_id)
618 self.completed[eid].append(msg_id)
619 elif msg_id not in self.all_completed:
619 elif msg_id not in self.all_completed:
620 # it could be a result from a dead engine that died before delivering the
620 # it could be a result from a dead engine that died before delivering the
621 # result
621 # result
622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
623 return
623 return
624 # update record anyway, because the unregistration could have been premature
624 # update record anyway, because the unregistration could have been premature
625 rheader = msg['header']
625 rheader = msg['header']
626 completed = rheader['date']
626 completed = rheader['date']
627 started = rheader.get('started', None)
627 started = rheader.get('started', None)
628 result = {
628 result = {
629 'result_header' : rheader,
629 'result_header' : rheader,
630 'result_content': msg['content'],
630 'result_content': msg['content'],
631 'started' : started,
631 'started' : started,
632 'completed' : completed
632 'completed' : completed
633 }
633 }
634
634
635 result['result_buffers'] = msg['buffers']
635 result['result_buffers'] = msg['buffers']
636 try:
636 try:
637 self.db.update_record(msg_id, result)
637 self.db.update_record(msg_id, result)
638 except Exception:
638 except Exception:
639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
640
640
641
641
642 #--------------------- Task Queue Traffic ------------------------------
642 #--------------------- Task Queue Traffic ------------------------------
643
643
644 def save_task_request(self, idents, msg):
644 def save_task_request(self, idents, msg):
645 """Save the submission of a task."""
645 """Save the submission of a task."""
646 client_id = idents[0]
646 client_id = idents[0]
647
647
648 try:
648 try:
649 msg = self.session.unpack_message(msg)
649 msg = self.session.unpack_message(msg)
650 except Exception:
650 except Exception:
651 self.log.error("task::client %r sent invalid task message: %r"%(
651 self.log.error("task::client %r sent invalid task message: %r"%(
652 client_id, msg), exc_info=True)
652 client_id, msg), exc_info=True)
653 return
653 return
654 record = init_record(msg)
654 record = init_record(msg)
655
655
656 record['client_uuid'] = client_id
656 record['client_uuid'] = client_id
657 record['queue'] = 'task'
657 record['queue'] = 'task'
658 header = msg['header']
658 header = msg['header']
659 msg_id = header['msg_id']
659 msg_id = header['msg_id']
660 self.pending.add(msg_id)
660 self.pending.add(msg_id)
661 self.unassigned.add(msg_id)
661 self.unassigned.add(msg_id)
662 try:
662 try:
663 # it's posible iopub arrived first:
663 # it's posible iopub arrived first:
664 existing = self.db.get_record(msg_id)
664 existing = self.db.get_record(msg_id)
665 if existing['resubmitted']:
665 if existing['resubmitted']:
666 for key in ('submitted', 'client_uuid', 'buffers'):
666 for key in ('submitted', 'client_uuid', 'buffers'):
667 # don't clobber these keys on resubmit
667 # don't clobber these keys on resubmit
668 # submitted and client_uuid should be different
668 # submitted and client_uuid should be different
669 # and buffers might be big, and shouldn't have changed
669 # and buffers might be big, and shouldn't have changed
670 record.pop(key)
670 record.pop(key)
671 # still check content,header which should not change
671 # still check content,header which should not change
672 # but are not expensive to compare as buffers
672 # but are not expensive to compare as buffers
673
673
674 for key,evalue in existing.iteritems():
674 for key,evalue in existing.iteritems():
675 if key.endswith('buffers'):
675 if key.endswith('buffers'):
676 # don't compare buffers
676 # don't compare buffers
677 continue
677 continue
678 rvalue = record.get(key, None)
678 rvalue = record.get(key, None)
679 if evalue and rvalue and evalue != rvalue:
679 if evalue and rvalue and evalue != rvalue:
680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
681 elif evalue and not rvalue:
681 elif evalue and not rvalue:
682 record[key] = evalue
682 record[key] = evalue
683 try:
683 try:
684 self.db.update_record(msg_id, record)
684 self.db.update_record(msg_id, record)
685 except Exception:
685 except Exception:
686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
687 except KeyError:
687 except KeyError:
688 try:
688 try:
689 self.db.add_record(msg_id, record)
689 self.db.add_record(msg_id, record)
690 except Exception:
690 except Exception:
691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
692 except Exception:
692 except Exception:
693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
694
694
695 def save_task_result(self, idents, msg):
695 def save_task_result(self, idents, msg):
696 """save the result of a completed task."""
696 """save the result of a completed task."""
697 client_id = idents[0]
697 client_id = idents[0]
698 try:
698 try:
699 msg = self.session.unpack_message(msg)
699 msg = self.session.unpack_message(msg)
700 except Exception:
700 except Exception:
701 self.log.error("task::invalid task result message send to %r: %r"%(
701 self.log.error("task::invalid task result message send to %r: %r"%(
702 client_id, msg), exc_info=True)
702 client_id, msg), exc_info=True)
703 return
703 return
704
704
705 parent = msg['parent_header']
705 parent = msg['parent_header']
706 if not parent:
706 if not parent:
707 # print msg
707 # print msg
708 self.log.warn("Task %r had no parent!"%msg)
708 self.log.warn("Task %r had no parent!"%msg)
709 return
709 return
710 msg_id = parent['msg_id']
710 msg_id = parent['msg_id']
711 if msg_id in self.unassigned:
711 if msg_id in self.unassigned:
712 self.unassigned.remove(msg_id)
712 self.unassigned.remove(msg_id)
713
713
714 header = msg['header']
714 header = msg['header']
715 engine_uuid = header.get('engine', None)
715 engine_uuid = header.get('engine', None)
716 eid = self.by_ident.get(engine_uuid, None)
716 eid = self.by_ident.get(engine_uuid, None)
717
717
718 if msg_id in self.pending:
718 if msg_id in self.pending:
719 self.pending.remove(msg_id)
719 self.pending.remove(msg_id)
720 self.all_completed.add(msg_id)
720 self.all_completed.add(msg_id)
721 if eid is not None:
721 if eid is not None:
722 self.completed[eid].append(msg_id)
722 self.completed[eid].append(msg_id)
723 if msg_id in self.tasks[eid]:
723 if msg_id in self.tasks[eid]:
724 self.tasks[eid].remove(msg_id)
724 self.tasks[eid].remove(msg_id)
725 completed = header['date']
725 completed = header['date']
726 started = header.get('started', None)
726 started = header.get('started', None)
727 result = {
727 result = {
728 'result_header' : header,
728 'result_header' : header,
729 'result_content': msg['content'],
729 'result_content': msg['content'],
730 'started' : started,
730 'started' : started,
731 'completed' : completed,
731 'completed' : completed,
732 'engine_uuid': engine_uuid
732 'engine_uuid': engine_uuid
733 }
733 }
734
734
735 result['result_buffers'] = msg['buffers']
735 result['result_buffers'] = msg['buffers']
736 try:
736 try:
737 self.db.update_record(msg_id, result)
737 self.db.update_record(msg_id, result)
738 except Exception:
738 except Exception:
739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
740
740
741 else:
741 else:
742 self.log.debug("task::unknown task %r finished"%msg_id)
742 self.log.debug("task::unknown task %r finished"%msg_id)
743
743
744 def save_task_destination(self, idents, msg):
744 def save_task_destination(self, idents, msg):
745 try:
745 try:
746 msg = self.session.unpack_message(msg, content=True)
746 msg = self.session.unpack_message(msg, content=True)
747 except Exception:
747 except Exception:
748 self.log.error("task::invalid task tracking message", exc_info=True)
748 self.log.error("task::invalid task tracking message", exc_info=True)
749 return
749 return
750 content = msg['content']
750 content = msg['content']
751 # print (content)
751 # print (content)
752 msg_id = content['msg_id']
752 msg_id = content['msg_id']
753 engine_uuid = content['engine_id']
753 engine_uuid = content['engine_id']
754 eid = self.by_ident[engine_uuid]
754 eid = self.by_ident[engine_uuid]
755
755
756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
757 if msg_id in self.unassigned:
757 if msg_id in self.unassigned:
758 self.unassigned.remove(msg_id)
758 self.unassigned.remove(msg_id)
759 # else:
759 # else:
760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
761
761
762 self.tasks[eid].append(msg_id)
762 self.tasks[eid].append(msg_id)
763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
764 try:
764 try:
765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
766 except Exception:
766 except Exception:
767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
768
768
769
769
770 def mia_task_request(self, idents, msg):
770 def mia_task_request(self, idents, msg):
771 raise NotImplementedError
771 raise NotImplementedError
772 client_id = idents[0]
772 client_id = idents[0]
773 # content = dict(mia=self.mia,status='ok')
773 # content = dict(mia=self.mia,status='ok')
774 # self.session.send('mia_reply', content=content, idents=client_id)
774 # self.session.send('mia_reply', content=content, idents=client_id)
775
775
776
776
777 #--------------------- IOPub Traffic ------------------------------
777 #--------------------- IOPub Traffic ------------------------------
778
778
779 def save_iopub_message(self, topics, msg):
779 def save_iopub_message(self, topics, msg):
780 """save an iopub message into the db"""
780 """save an iopub message into the db"""
781 # print (topics)
781 # print (topics)
782 try:
782 try:
783 msg = self.session.unpack_message(msg, content=True)
783 msg = self.session.unpack_message(msg, content=True)
784 except Exception:
784 except Exception:
785 self.log.error("iopub::invalid IOPub message", exc_info=True)
785 self.log.error("iopub::invalid IOPub message", exc_info=True)
786 return
786 return
787
787
788 parent = msg['parent_header']
788 parent = msg['parent_header']
789 if not parent:
789 if not parent:
790 self.log.error("iopub::invalid IOPub message: %r"%msg)
790 self.log.error("iopub::invalid IOPub message: %r"%msg)
791 return
791 return
792 msg_id = parent['msg_id']
792 msg_id = parent['msg_id']
793 msg_type = msg['msg_type']
793 msg_type = msg['msg_type']
794 content = msg['content']
794 content = msg['content']
795
795
796 # ensure msg_id is in db
796 # ensure msg_id is in db
797 try:
797 try:
798 rec = self.db.get_record(msg_id)
798 rec = self.db.get_record(msg_id)
799 except KeyError:
799 except KeyError:
800 rec = empty_record()
800 rec = empty_record()
801 rec['msg_id'] = msg_id
801 rec['msg_id'] = msg_id
802 self.db.add_record(msg_id, rec)
802 self.db.add_record(msg_id, rec)
803 # stream
803 # stream
804 d = {}
804 d = {}
805 if msg_type == 'stream':
805 if msg_type == 'stream':
806 name = content['name']
806 name = content['name']
807 s = rec[name] or ''
807 s = rec[name] or ''
808 d[name] = s + content['data']
808 d[name] = s + content['data']
809
809
810 elif msg_type == 'pyerr':
810 elif msg_type == 'pyerr':
811 d['pyerr'] = content
811 d['pyerr'] = content
812 elif msg_type == 'pyin':
812 elif msg_type == 'pyin':
813 d['pyin'] = content['code']
813 d['pyin'] = content['code']
814 else:
814 else:
815 d[msg_type] = content.get('data', '')
815 d[msg_type] = content.get('data', '')
816
816
817 try:
817 try:
818 self.db.update_record(msg_id, d)
818 self.db.update_record(msg_id, d)
819 except Exception:
819 except Exception:
820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
821
821
822
822
823
823
824 #-------------------------------------------------------------------------
824 #-------------------------------------------------------------------------
825 # Registration requests
825 # Registration requests
826 #-------------------------------------------------------------------------
826 #-------------------------------------------------------------------------
827
827
828 def connection_request(self, client_id, msg):
828 def connection_request(self, client_id, msg):
829 """Reply with connection addresses for clients."""
829 """Reply with connection addresses for clients."""
830 self.log.info("client::client %r connected"%client_id)
830 self.log.info("client::client %r connected"%client_id)
831 content = dict(status='ok')
831 content = dict(status='ok')
832 content.update(self.client_info)
832 content.update(self.client_info)
833 jsonable = {}
833 jsonable = {}
834 for k,v in self.keytable.iteritems():
834 for k,v in self.keytable.iteritems():
835 if v not in self.dead_engines:
835 if v not in self.dead_engines:
836 jsonable[str(k)] = v
836 jsonable[str(k)] = v
837 content['engines'] = jsonable
837 content['engines'] = jsonable
838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
839
839
840 def register_engine(self, reg, msg):
840 def register_engine(self, reg, msg):
841 """Register a new engine."""
841 """Register a new engine."""
842 content = msg['content']
842 content = msg['content']
843 try:
843 try:
844 queue = content['queue']
844 queue = content['queue']
845 except KeyError:
845 except KeyError:
846 self.log.error("registration::queue not specified", exc_info=True)
846 self.log.error("registration::queue not specified", exc_info=True)
847 return
847 return
848 heart = content.get('heartbeat', None)
848 heart = content.get('heartbeat', None)
849 """register a new engine, and create the socket(s) necessary"""
849 """register a new engine, and create the socket(s) necessary"""
850 eid = self._next_id
850 eid = self._next_id
851 # print (eid, queue, reg, heart)
851 # print (eid, queue, reg, heart)
852
852
853 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
853 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
854
854
855 content = dict(id=eid,status='ok')
855 content = dict(id=eid,status='ok')
856 content.update(self.engine_info)
856 content.update(self.engine_info)
857 # check if requesting available IDs:
857 # check if requesting available IDs:
858 if queue in self.by_ident:
858 if queue in self.by_ident:
859 try:
859 try:
860 raise KeyError("queue_id %r in use"%queue)
860 raise KeyError("queue_id %r in use"%queue)
861 except:
861 except:
862 content = error.wrap_exception()
862 content = error.wrap_exception()
863 self.log.error("queue_id %r in use"%queue, exc_info=True)
863 self.log.error("queue_id %r in use"%queue, exc_info=True)
864 elif heart in self.hearts: # need to check unique hearts?
864 elif heart in self.hearts: # need to check unique hearts?
865 try:
865 try:
866 raise KeyError("heart_id %r in use"%heart)
866 raise KeyError("heart_id %r in use"%heart)
867 except:
867 except:
868 self.log.error("heart_id %r in use"%heart, exc_info=True)
868 self.log.error("heart_id %r in use"%heart, exc_info=True)
869 content = error.wrap_exception()
869 content = error.wrap_exception()
870 else:
870 else:
871 for h, pack in self.incoming_registrations.iteritems():
871 for h, pack in self.incoming_registrations.iteritems():
872 if heart == h:
872 if heart == h:
873 try:
873 try:
874 raise KeyError("heart_id %r in use"%heart)
874 raise KeyError("heart_id %r in use"%heart)
875 except:
875 except:
876 self.log.error("heart_id %r in use"%heart, exc_info=True)
876 self.log.error("heart_id %r in use"%heart, exc_info=True)
877 content = error.wrap_exception()
877 content = error.wrap_exception()
878 break
878 break
879 elif queue == pack[1]:
879 elif queue == pack[1]:
880 try:
880 try:
881 raise KeyError("queue_id %r in use"%queue)
881 raise KeyError("queue_id %r in use"%queue)
882 except:
882 except:
883 self.log.error("queue_id %r in use"%queue, exc_info=True)
883 self.log.error("queue_id %r in use"%queue, exc_info=True)
884 content = error.wrap_exception()
884 content = error.wrap_exception()
885 break
885 break
886
886
887 msg = self.session.send(self.query, "registration_reply",
887 msg = self.session.send(self.query, "registration_reply",
888 content=content,
888 content=content,
889 ident=reg)
889 ident=reg)
890
890
891 if content['status'] == 'ok':
891 if content['status'] == 'ok':
892 if heart in self.heartmonitor.hearts:
892 if heart in self.heartmonitor.hearts:
893 # already beating
893 # already beating
894 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
894 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
895 self.finish_registration(heart)
895 self.finish_registration(heart)
896 else:
896 else:
897 purge = lambda : self._purge_stalled_registration(heart)
897 purge = lambda : self._purge_stalled_registration(heart)
898 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
898 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
899 dc.start()
899 dc.start()
900 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
900 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
901 else:
901 else:
902 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
902 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
903 return eid
903 return eid
904
904
905 def unregister_engine(self, ident, msg):
905 def unregister_engine(self, ident, msg):
906 """Unregister an engine that explicitly requested to leave."""
906 """Unregister an engine that explicitly requested to leave."""
907 try:
907 try:
908 eid = msg['content']['id']
908 eid = msg['content']['id']
909 except:
909 except:
910 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
910 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
911 return
911 return
912 self.log.info("registration::unregister_engine(%r)"%eid)
912 self.log.info("registration::unregister_engine(%r)"%eid)
913 # print (eid)
913 # print (eid)
914 uuid = self.keytable[eid]
914 uuid = self.keytable[eid]
915 content=dict(id=eid, queue=uuid)
915 content=dict(id=eid, queue=uuid)
916 self.dead_engines.add(uuid)
916 self.dead_engines.add(uuid)
917 # self.ids.remove(eid)
917 # self.ids.remove(eid)
918 # uuid = self.keytable.pop(eid)
918 # uuid = self.keytable.pop(eid)
919 #
919 #
920 # ec = self.engines.pop(eid)
920 # ec = self.engines.pop(eid)
921 # self.hearts.pop(ec.heartbeat)
921 # self.hearts.pop(ec.heartbeat)
922 # self.by_ident.pop(ec.queue)
922 # self.by_ident.pop(ec.queue)
923 # self.completed.pop(eid)
923 # self.completed.pop(eid)
924 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
924 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
925 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
925 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
926 dc.start()
926 dc.start()
927 ############## TODO: HANDLE IT ################
927 ############## TODO: HANDLE IT ################
928
928
929 if self.notifier:
929 if self.notifier:
930 self.session.send(self.notifier, "unregistration_notification", content=content)
930 self.session.send(self.notifier, "unregistration_notification", content=content)
931
931
932 def _handle_stranded_msgs(self, eid, uuid):
932 def _handle_stranded_msgs(self, eid, uuid):
933 """Handle messages known to be on an engine when the engine unregisters.
933 """Handle messages known to be on an engine when the engine unregisters.
934
934
935 It is possible that this will fire prematurely - that is, an engine will
935 It is possible that this will fire prematurely - that is, an engine will
936 go down after completing a result, and the client will be notified
936 go down after completing a result, and the client will be notified
937 that the result failed and later receive the actual result.
937 that the result failed and later receive the actual result.
938 """
938 """
939
939
940 outstanding = self.queues[eid]
940 outstanding = self.queues[eid]
941
941
942 for msg_id in outstanding:
942 for msg_id in outstanding:
943 self.pending.remove(msg_id)
943 self.pending.remove(msg_id)
944 self.all_completed.add(msg_id)
944 self.all_completed.add(msg_id)
945 try:
945 try:
946 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
946 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
947 except:
947 except:
948 content = error.wrap_exception()
948 content = error.wrap_exception()
949 # build a fake header:
949 # build a fake header:
950 header = {}
950 header = {}
951 header['engine'] = uuid
951 header['engine'] = uuid
952 header['date'] = datetime.now()
952 header['date'] = datetime.now()
953 rec = dict(result_content=content, result_header=header, result_buffers=[])
953 rec = dict(result_content=content, result_header=header, result_buffers=[])
954 rec['completed'] = header['date']
954 rec['completed'] = header['date']
955 rec['engine_uuid'] = uuid
955 rec['engine_uuid'] = uuid
956 try:
956 try:
957 self.db.update_record(msg_id, rec)
957 self.db.update_record(msg_id, rec)
958 except Exception:
958 except Exception:
959 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
959 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
960
960
961
961
962 def finish_registration(self, heart):
962 def finish_registration(self, heart):
963 """Second half of engine registration, called after our HeartMonitor
963 """Second half of engine registration, called after our HeartMonitor
964 has received a beat from the Engine's Heart."""
964 has received a beat from the Engine's Heart."""
965 try:
965 try:
966 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
966 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
967 except KeyError:
967 except KeyError:
968 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
968 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
969 return
969 return
970 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
970 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
971 if purge is not None:
971 if purge is not None:
972 purge.stop()
972 purge.stop()
973 control = queue
973 control = queue
974 self.ids.add(eid)
974 self.ids.add(eid)
975 self.keytable[eid] = queue
975 self.keytable[eid] = queue
976 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
976 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
977 control=control, heartbeat=heart)
977 control=control, heartbeat=heart)
978 self.by_ident[queue] = eid
978 self.by_ident[queue] = eid
979 self.queues[eid] = list()
979 self.queues[eid] = list()
980 self.tasks[eid] = list()
980 self.tasks[eid] = list()
981 self.completed[eid] = list()
981 self.completed[eid] = list()
982 self.hearts[heart] = eid
982 self.hearts[heart] = eid
983 content = dict(id=eid, queue=self.engines[eid].queue)
983 content = dict(id=eid, queue=self.engines[eid].queue)
984 if self.notifier:
984 if self.notifier:
985 self.session.send(self.notifier, "registration_notification", content=content)
985 self.session.send(self.notifier, "registration_notification", content=content)
986 self.log.info("engine::Engine Connected: %i"%eid)
986 self.log.info("engine::Engine Connected: %i"%eid)
987
987
988 def _purge_stalled_registration(self, heart):
988 def _purge_stalled_registration(self, heart):
989 if heart in self.incoming_registrations:
989 if heart in self.incoming_registrations:
990 eid = self.incoming_registrations.pop(heart)[0]
990 eid = self.incoming_registrations.pop(heart)[0]
991 self.log.info("registration::purging stalled registration: %i"%eid)
991 self.log.info("registration::purging stalled registration: %i"%eid)
992 else:
992 else:
993 pass
993 pass
994
994
995 #-------------------------------------------------------------------------
995 #-------------------------------------------------------------------------
996 # Client Requests
996 # Client Requests
997 #-------------------------------------------------------------------------
997 #-------------------------------------------------------------------------
998
998
999 def shutdown_request(self, client_id, msg):
999 def shutdown_request(self, client_id, msg):
1000 """handle shutdown request."""
1000 """handle shutdown request."""
1001 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1001 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1002 # also notify other clients of shutdown
1002 # also notify other clients of shutdown
1003 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1003 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1004 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1004 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1005 dc.start()
1005 dc.start()
1006
1006
1007 def _shutdown(self):
1007 def _shutdown(self):
1008 self.log.info("hub::hub shutting down.")
1008 self.log.info("hub::hub shutting down.")
1009 time.sleep(0.1)
1009 time.sleep(0.1)
1010 sys.exit(0)
1010 sys.exit(0)
1011
1011
1012
1012
1013 def check_load(self, client_id, msg):
1013 def check_load(self, client_id, msg):
1014 content = msg['content']
1014 content = msg['content']
1015 try:
1015 try:
1016 targets = content['targets']
1016 targets = content['targets']
1017 targets = self._validate_targets(targets)
1017 targets = self._validate_targets(targets)
1018 except:
1018 except:
1019 content = error.wrap_exception()
1019 content = error.wrap_exception()
1020 self.session.send(self.query, "hub_error",
1020 self.session.send(self.query, "hub_error",
1021 content=content, ident=client_id)
1021 content=content, ident=client_id)
1022 return
1022 return
1023
1023
1024 content = dict(status='ok')
1024 content = dict(status='ok')
1025 # loads = {}
1025 # loads = {}
1026 for t in targets:
1026 for t in targets:
1027 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1027 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1028 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1028 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1029
1029
1030
1030
1031 def queue_status(self, client_id, msg):
1031 def queue_status(self, client_id, msg):
1032 """Return the Queue status of one or more targets.
1032 """Return the Queue status of one or more targets.
1033 if verbose: return the msg_ids
1033 if verbose: return the msg_ids
1034 else: return len of each type.
1034 else: return len of each type.
1035 keys: queue (pending MUX jobs)
1035 keys: queue (pending MUX jobs)
1036 tasks (pending Task jobs)
1036 tasks (pending Task jobs)
1037 completed (finished jobs from both queues)"""
1037 completed (finished jobs from both queues)"""
1038 content = msg['content']
1038 content = msg['content']
1039 targets = content['targets']
1039 targets = content['targets']
1040 try:
1040 try:
1041 targets = self._validate_targets(targets)
1041 targets = self._validate_targets(targets)
1042 except:
1042 except:
1043 content = error.wrap_exception()
1043 content = error.wrap_exception()
1044 self.session.send(self.query, "hub_error",
1044 self.session.send(self.query, "hub_error",
1045 content=content, ident=client_id)
1045 content=content, ident=client_id)
1046 return
1046 return
1047 verbose = content.get('verbose', False)
1047 verbose = content.get('verbose', False)
1048 content = dict(status='ok')
1048 content = dict(status='ok')
1049 for t in targets:
1049 for t in targets:
1050 queue = self.queues[t]
1050 queue = self.queues[t]
1051 completed = self.completed[t]
1051 completed = self.completed[t]
1052 tasks = self.tasks[t]
1052 tasks = self.tasks[t]
1053 if not verbose:
1053 if not verbose:
1054 queue = len(queue)
1054 queue = len(queue)
1055 completed = len(completed)
1055 completed = len(completed)
1056 tasks = len(tasks)
1056 tasks = len(tasks)
1057 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1057 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1058 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1058 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1059
1059
1060 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1060 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1061
1061
1062 def purge_results(self, client_id, msg):
1062 def purge_results(self, client_id, msg):
1063 """Purge results from memory. This method is more valuable before we move
1063 """Purge results from memory. This method is more valuable before we move
1064 to a DB based message storage mechanism."""
1064 to a DB based message storage mechanism."""
1065 content = msg['content']
1065 content = msg['content']
1066 msg_ids = content.get('msg_ids', [])
1066 msg_ids = content.get('msg_ids', [])
1067 reply = dict(status='ok')
1067 reply = dict(status='ok')
1068 if msg_ids == 'all':
1068 if msg_ids == 'all':
1069 try:
1069 try:
1070 self.db.drop_matching_records(dict(completed={'$ne':None}))
1070 self.db.drop_matching_records(dict(completed={'$ne':None}))
1071 except Exception:
1071 except Exception:
1072 reply = error.wrap_exception()
1072 reply = error.wrap_exception()
1073 else:
1073 else:
1074 pending = filter(lambda m: m in self.pending, msg_ids)
1074 pending = filter(lambda m: m in self.pending, msg_ids)
1075 if pending:
1075 if pending:
1076 try:
1076 try:
1077 raise IndexError("msg pending: %r"%pending[0])
1077 raise IndexError("msg pending: %r"%pending[0])
1078 except:
1078 except:
1079 reply = error.wrap_exception()
1079 reply = error.wrap_exception()
1080 else:
1080 else:
1081 try:
1081 try:
1082 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1082 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1083 except Exception:
1083 except Exception:
1084 reply = error.wrap_exception()
1084 reply = error.wrap_exception()
1085
1085
1086 if reply['status'] == 'ok':
1086 if reply['status'] == 'ok':
1087 eids = content.get('engine_ids', [])
1087 eids = content.get('engine_ids', [])
1088 for eid in eids:
1088 for eid in eids:
1089 if eid not in self.engines:
1089 if eid not in self.engines:
1090 try:
1090 try:
1091 raise IndexError("No such engine: %i"%eid)
1091 raise IndexError("No such engine: %i"%eid)
1092 except:
1092 except:
1093 reply = error.wrap_exception()
1093 reply = error.wrap_exception()
1094 break
1094 break
1095 msg_ids = self.completed.pop(eid)
1095 msg_ids = self.completed.pop(eid)
1096 uid = self.engines[eid].queue
1096 uid = self.engines[eid].queue
1097 try:
1097 try:
1098 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1098 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1099 except Exception:
1099 except Exception:
1100 reply = error.wrap_exception()
1100 reply = error.wrap_exception()
1101 break
1101 break
1102
1102
1103 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1103 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1104
1104
1105 def resubmit_task(self, client_id, msg):
1105 def resubmit_task(self, client_id, msg):
1106 """Resubmit one or more tasks."""
1106 """Resubmit one or more tasks."""
1107 def finish(reply):
1107 def finish(reply):
1108 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1108 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1109
1109
1110 content = msg['content']
1110 content = msg['content']
1111 msg_ids = content['msg_ids']
1111 msg_ids = content['msg_ids']
1112 reply = dict(status='ok')
1112 reply = dict(status='ok')
1113 try:
1113 try:
1114 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1114 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1115 'header', 'content', 'buffers'])
1115 'header', 'content', 'buffers'])
1116 except Exception:
1116 except Exception:
1117 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1117 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1118 return finish(error.wrap_exception())
1118 return finish(error.wrap_exception())
1119
1119
1120 # validate msg_ids
1120 # validate msg_ids
1121 found_ids = [ rec['msg_id'] for rec in records ]
1121 found_ids = [ rec['msg_id'] for rec in records ]
1122 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1122 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1123 if len(records) > len(msg_ids):
1123 if len(records) > len(msg_ids):
1124 try:
1124 try:
1125 raise RuntimeError("DB appears to be in an inconsistent state."
1125 raise RuntimeError("DB appears to be in an inconsistent state."
1126 "More matching records were found than should exist")
1126 "More matching records were found than should exist")
1127 except Exception:
1127 except Exception:
1128 return finish(error.wrap_exception())
1128 return finish(error.wrap_exception())
1129 elif len(records) < len(msg_ids):
1129 elif len(records) < len(msg_ids):
1130 missing = [ m for m in msg_ids if m not in found_ids ]
1130 missing = [ m for m in msg_ids if m not in found_ids ]
1131 try:
1131 try:
1132 raise KeyError("No such msg(s): %r"%missing)
1132 raise KeyError("No such msg(s): %r"%missing)
1133 except KeyError:
1133 except KeyError:
1134 return finish(error.wrap_exception())
1134 return finish(error.wrap_exception())
1135 elif invalid_ids:
1135 elif invalid_ids:
1136 msg_id = invalid_ids[0]
1136 msg_id = invalid_ids[0]
1137 try:
1137 try:
1138 raise ValueError("Task %r appears to be inflight"%(msg_id))
1138 raise ValueError("Task %r appears to be inflight"%(msg_id))
1139 except Exception:
1139 except Exception:
1140 return finish(error.wrap_exception())
1140 return finish(error.wrap_exception())
1141
1141
1142 # clear the existing records
1142 # clear the existing records
1143 now = datetime.now()
1143 now = datetime.now()
1144 rec = empty_record()
1144 rec = empty_record()
1145 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1145 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1146 rec['resubmitted'] = now
1146 rec['resubmitted'] = now
1147 rec['queue'] = 'task'
1147 rec['queue'] = 'task'
1148 rec['client_uuid'] = client_id[0]
1148 rec['client_uuid'] = client_id[0]
1149 try:
1149 try:
1150 for msg_id in msg_ids:
1150 for msg_id in msg_ids:
1151 self.all_completed.discard(msg_id)
1151 self.all_completed.discard(msg_id)
1152 self.db.update_record(msg_id, rec)
1152 self.db.update_record(msg_id, rec)
1153 except Exception:
1153 except Exception:
1154 self.log.error('db::db error upating record', exc_info=True)
1154 self.log.error('db::db error upating record', exc_info=True)
1155 reply = error.wrap_exception()
1155 reply = error.wrap_exception()
1156 else:
1156 else:
1157 # send the messages
1157 # send the messages
1158 for rec in records:
1158 for rec in records:
1159 header = rec['header']
1159 header = rec['header']
1160 # include resubmitted in header to prevent digest collision
1160 # include resubmitted in header to prevent digest collision
1161 header['resubmitted'] = now
1161 header['resubmitted'] = now
1162 msg = self.session.msg(header['msg_type'])
1162 msg = self.session.msg(header['msg_type'])
1163 msg['content'] = rec['content']
1163 msg['content'] = rec['content']
1164 msg['header'] = header
1164 msg['header'] = header
1165 msg['msg_id'] = rec['msg_id']
1165 msg['msg_id'] = rec['msg_id']
1166 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1166 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1167
1167
1168 finish(dict(status='ok'))
1168 finish(dict(status='ok'))
1169
1169
1170
1170
1171 def _extract_record(self, rec):
1171 def _extract_record(self, rec):
1172 """decompose a TaskRecord dict into subsection of reply for get_result"""
1172 """decompose a TaskRecord dict into subsection of reply for get_result"""
1173 io_dict = {}
1173 io_dict = {}
1174 for key in 'pyin pyout pyerr stdout stderr'.split():
1174 for key in 'pyin pyout pyerr stdout stderr'.split():
1175 io_dict[key] = rec[key]
1175 io_dict[key] = rec[key]
1176 content = { 'result_content': rec['result_content'],
1176 content = { 'result_content': rec['result_content'],
1177 'header': rec['header'],
1177 'header': rec['header'],
1178 'result_header' : rec['result_header'],
1178 'result_header' : rec['result_header'],
1179 'io' : io_dict,
1179 'io' : io_dict,
1180 }
1180 }
1181 if rec['result_buffers']:
1181 if rec['result_buffers']:
1182 buffers = map(str, rec['result_buffers'])
1182 buffers = map(str, rec['result_buffers'])
1183 else:
1183 else:
1184 buffers = []
1184 buffers = []
1185
1185
1186 return content, buffers
1186 return content, buffers
1187
1187
1188 def get_results(self, client_id, msg):
1188 def get_results(self, client_id, msg):
1189 """Get the result of 1 or more messages."""
1189 """Get the result of 1 or more messages."""
1190 content = msg['content']
1190 content = msg['content']
1191 msg_ids = sorted(set(content['msg_ids']))
1191 msg_ids = sorted(set(content['msg_ids']))
1192 statusonly = content.get('status_only', False)
1192 statusonly = content.get('status_only', False)
1193 pending = []
1193 pending = []
1194 completed = []
1194 completed = []
1195 content = dict(status='ok')
1195 content = dict(status='ok')
1196 content['pending'] = pending
1196 content['pending'] = pending
1197 content['completed'] = completed
1197 content['completed'] = completed
1198 buffers = []
1198 buffers = []
1199 if not statusonly:
1199 if not statusonly:
1200 try:
1200 try:
1201 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1201 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1202 # turn match list into dict, for faster lookup
1202 # turn match list into dict, for faster lookup
1203 records = {}
1203 records = {}
1204 for rec in matches:
1204 for rec in matches:
1205 records[rec['msg_id']] = rec
1205 records[rec['msg_id']] = rec
1206 except Exception:
1206 except Exception:
1207 content = error.wrap_exception()
1207 content = error.wrap_exception()
1208 self.session.send(self.query, "result_reply", content=content,
1208 self.session.send(self.query, "result_reply", content=content,
1209 parent=msg, ident=client_id)
1209 parent=msg, ident=client_id)
1210 return
1210 return
1211 else:
1211 else:
1212 records = {}
1212 records = {}
1213 for msg_id in msg_ids:
1213 for msg_id in msg_ids:
1214 if msg_id in self.pending:
1214 if msg_id in self.pending:
1215 pending.append(msg_id)
1215 pending.append(msg_id)
1216 elif msg_id in self.all_completed:
1216 elif msg_id in self.all_completed:
1217 completed.append(msg_id)
1217 completed.append(msg_id)
1218 if not statusonly:
1218 if not statusonly:
1219 c,bufs = self._extract_record(records[msg_id])
1219 c,bufs = self._extract_record(records[msg_id])
1220 content[msg_id] = c
1220 content[msg_id] = c
1221 buffers.extend(bufs)
1221 buffers.extend(bufs)
1222 elif msg_id in records:
1222 elif msg_id in records:
1223 if rec['completed']:
1223 if rec['completed']:
1224 completed.append(msg_id)
1224 completed.append(msg_id)
1225 c,bufs = self._extract_record(records[msg_id])
1225 c,bufs = self._extract_record(records[msg_id])
1226 content[msg_id] = c
1226 content[msg_id] = c
1227 buffers.extend(bufs)
1227 buffers.extend(bufs)
1228 else:
1228 else:
1229 pending.append(msg_id)
1229 pending.append(msg_id)
1230 else:
1230 else:
1231 try:
1231 try:
1232 raise KeyError('No such message: '+msg_id)
1232 raise KeyError('No such message: '+msg_id)
1233 except:
1233 except:
1234 content = error.wrap_exception()
1234 content = error.wrap_exception()
1235 break
1235 break
1236 self.session.send(self.query, "result_reply", content=content,
1236 self.session.send(self.query, "result_reply", content=content,
1237 parent=msg, ident=client_id,
1237 parent=msg, ident=client_id,
1238 buffers=buffers)
1238 buffers=buffers)
1239
1239
1240 def get_history(self, client_id, msg):
1240 def get_history(self, client_id, msg):
1241 """Get a list of all msg_ids in our DB records"""
1241 """Get a list of all msg_ids in our DB records"""
1242 try:
1242 try:
1243 msg_ids = self.db.get_history()
1243 msg_ids = self.db.get_history()
1244 except Exception as e:
1244 except Exception as e:
1245 content = error.wrap_exception()
1245 content = error.wrap_exception()
1246 else:
1246 else:
1247 content = dict(status='ok', history=msg_ids)
1247 content = dict(status='ok', history=msg_ids)
1248
1248
1249 self.session.send(self.query, "history_reply", content=content,
1249 self.session.send(self.query, "history_reply", content=content,
1250 parent=msg, ident=client_id)
1250 parent=msg, ident=client_id)
1251
1251
1252 def db_query(self, client_id, msg):
1252 def db_query(self, client_id, msg):
1253 """Perform a raw query on the task record database."""
1253 """Perform a raw query on the task record database."""
1254 content = msg['content']
1254 content = msg['content']
1255 query = content.get('query', {})
1255 query = content.get('query', {})
1256 keys = content.get('keys', None)
1256 keys = content.get('keys', None)
1257 buffers = []
1257 buffers = []
1258 empty = list()
1258 empty = list()
1259 try:
1259 try:
1260 records = self.db.find_records(query, keys)
1260 records = self.db.find_records(query, keys)
1261 except Exception as e:
1261 except Exception as e:
1262 content = error.wrap_exception()
1262 content = error.wrap_exception()
1263 else:
1263 else:
1264 # extract buffers from reply content:
1264 # extract buffers from reply content:
1265 if keys is not None:
1265 if keys is not None:
1266 buffer_lens = [] if 'buffers' in keys else None
1266 buffer_lens = [] if 'buffers' in keys else None
1267 result_buffer_lens = [] if 'result_buffers' in keys else None
1267 result_buffer_lens = [] if 'result_buffers' in keys else None
1268 else:
1268 else:
1269 buffer_lens = []
1269 buffer_lens = []
1270 result_buffer_lens = []
1270 result_buffer_lens = []
1271
1271
1272 for rec in records:
1272 for rec in records:
1273 # buffers may be None, so double check
1273 # buffers may be None, so double check
1274 if buffer_lens is not None:
1274 if buffer_lens is not None:
1275 b = rec.pop('buffers', empty) or empty
1275 b = rec.pop('buffers', empty) or empty
1276 buffer_lens.append(len(b))
1276 buffer_lens.append(len(b))
1277 buffers.extend(b)
1277 buffers.extend(b)
1278 if result_buffer_lens is not None:
1278 if result_buffer_lens is not None:
1279 rb = rec.pop('result_buffers', empty) or empty
1279 rb = rec.pop('result_buffers', empty) or empty
1280 result_buffer_lens.append(len(rb))
1280 result_buffer_lens.append(len(rb))
1281 buffers.extend(rb)
1281 buffers.extend(rb)
1282 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1282 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1283 result_buffer_lens=result_buffer_lens)
1283 result_buffer_lens=result_buffer_lens)
1284
1284
1285 self.session.send(self.query, "db_reply", content=content,
1285 self.session.send(self.query, "db_reply", content=content,
1286 parent=msg, ident=client_id,
1286 parent=msg, ident=client_id,
1287 buffers=buffers)
1287 buffers=buffers)
1288
1288
@@ -1,703 +1,703 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 from __future__ import print_function
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26
26
27 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
44
44
45 from IPython.parallel import error
45 from IPython.parallel import error
46 from IPython.parallel.factory import SessionFactory
46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger
47 from IPython.parallel.util import connect_logger, local_logger
48
48
49 from .dependency import Dependency
49 from .dependency import Dependency
50
50
@decorator
def logged(f,self,*args,**kwargs):
    """Decorator: log every call to the wrapped scheduler method at DEBUG level."""
    # f.__name__ works on both Python 2 and 3; f.func_name is Python-2-only
    self.log.debug("scheduler::%s(*%s,**%s)"%(f.__name__, args, kwargs))
    return f(self,*args, **kwargs)
57
57
58 #----------------------------------------------------------------------
58 #----------------------------------------------------------------------
59 # Chooser functions
59 # Chooser functions
60 #----------------------------------------------------------------------
60 #----------------------------------------------------------------------
61
61
def plainrandom(loads):
    """Plain random pick.

    Chooses a uniformly random index into `loads`; the values are ignored.
    """
    return randint(0, len(loads) - 1)
66
66
def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    # the least-recently-used engine is by construction at index 0
    return 0
75
75
def twobin(loads):
    """Pick two at random, use the LRU of the two.

    The content of loads is ignored.

    Assumes LRU ordering of loads, with oldest first.
    """
    last = len(loads) - 1
    # two independent uniform draws; the smaller index is the older engine
    return min(randint(0, last), randint(0, last))
87
87
def weighted(loads):
    """Pick two at random using inverse load as weight.

    Return the less loaded of the two.
    """
    # weight 0 a million times more than 1:
    weights = 1./(1e-6+numpy.array(loads))
    cumulative = weights.cumsum()
    total = cumulative[-1]
    # two independent weighted draws: each uniform point in [0, total)
    # falls into the bin whose cumulative weight first exceeds it
    first = int(cumulative.searchsorted(random()*total))
    second = int(cumulative.searchsorted(random()*total))
    if weights[second] > weights[first]:
        return second
    return first
109
109
def leastload(loads):
    """Always choose the lowest load.

    If the lowest load occurs more than once, the first
    occurance will be used. If loads has LRU ordering, this means
    the LRU of those with the lowest load is chosen.
    """
    lightest = min(loads)
    for pos, load in enumerate(loads):
        if load == lightest:
            return pos
118
118
119 #---------------------------------------------------------------------
119 #---------------------------------------------------------------------
120 # Classes
120 # Classes
121 #---------------------------------------------------------------------
121 #---------------------------------------------------------------------
122 # store empty default dependency:
122 # store empty default dependency:
123 MET = Dependency([])
123 MET = Dependency([])
124
124
125 class TaskScheduler(SessionFactory):
125 class TaskScheduler(SessionFactory):
126 """Python TaskScheduler object.
126 """Python TaskScheduler object.
127
127
128 This is the simplest object that supports msg_id based
128 This is the simplest object that supports msg_id based
129 DAG dependencies. *Only* task msg_ids are checked, not
129 DAG dependencies. *Only* task msg_ids are checked, not
130 msg_ids of jobs submitted via the MUX queue.
130 msg_ids of jobs submitted via the MUX queue.
131
131
132 """
132 """
133
133
134 hwm = Int(0, config=True, shortname='hwm',
134 hwm = Int(0, config=True, shortname='hwm',
135 help="""specify the High Water Mark (HWM) for the downstream
135 help="""specify the High Water Mark (HWM) for the downstream
136 socket in the Task scheduler. This is the maximum number
136 socket in the Task scheduler. This is the maximum number
137 of allowed outstanding tasks on each engine."""
137 of allowed outstanding tasks on each engine."""
138 )
138 )
139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 'leastload', config=True, shortname='scheme', allow_none=False,
140 'leastload', config=True, shortname='scheme', allow_none=False,
141 help="""select the task scheduler scheme [default: Python LRU]
141 help="""select the task scheduler scheme [default: Python LRU]
142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
143 )
143 )
144 def _scheme_name_changed(self, old, new):
144 def _scheme_name_changed(self, old, new):
145 self.log.debug("Using scheme %r"%new)
145 self.log.debug("Using scheme %r"%new)
146 self.scheme = globals()[new]
146 self.scheme = globals()[new]
147
147
    # input arguments:
    scheme = Instance(FunctionType) # function for determining the destination
    def _scheme_default(self):
        # default load-balancing scheme: pick the least-loaded engine
        return leastload
    client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
    engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
    notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
    mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream

    # internals:
    graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
    retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
    # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
    depending = Dict() # dict by msg_id of [raw_msg, targets, after, follow, timeout] (see save_unmet)
    pending = Dict() # dict by engine_uuid of submitted tasks
    completed = Dict() # dict by engine_uuid of completed tasks
    failed = Dict() # dict by engine_uuid of failed tasks
    destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
    clients = Dict() # dict by msg_id for who submitted the task
    # NOTE: targets and loads are kept index-aligned; see add_job/finish_job
    targets = List() # list of target IDENTs
    loads = List() # list of engine loads
    # full = Set() # set of IDENTs that have HWM outstanding tasks
    all_completed = Set() # set of all completed tasks
    all_failed = Set() # set of all failed tasks
    all_done = Set() # set of all finished tasks=union(completed,failed)
    all_ids = Set() # set of all submitted task IDs
    blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
    auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176
176
177
177
178 def start(self):
178 def start(self):
179 self.engine_stream.on_recv(self.dispatch_result, copy=False)
179 self.engine_stream.on_recv(self.dispatch_result, copy=False)
180 self._notification_handlers = dict(
180 self._notification_handlers = dict(
181 registration_notification = self._register_engine,
181 registration_notification = self._register_engine,
182 unregistration_notification = self._unregister_engine
182 unregistration_notification = self._unregister_engine
183 )
183 )
184 self.notifier_stream.on_recv(self.dispatch_notification)
184 self.notifier_stream.on_recv(self.dispatch_notification)
185 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
185 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
186 self.auditor.start()
186 self.auditor.start()
187 self.log.info("Scheduler started [%s]"%self.scheme_name)
187 self.log.info("Scheduler started [%s]"%self.scheme_name)
188
188
    def resume_receiving(self):
        """Resume accepting jobs."""
        # re-attach the submission handler; frames are not copied (copy=False)
        self.client_stream.on_recv(self.dispatch_submission, copy=False)
192
192
    def stop_receiving(self):
        """Stop accepting jobs while there are no engines.
        Leave them in the ZMQ queue."""
        # detach the handler; submissions simply queue up inside ZMQ until
        # resume_receiving() re-attaches it
        self.client_stream.on_recv(None)
197
197
198 #-----------------------------------------------------------------------
198 #-----------------------------------------------------------------------
199 # [Un]Registration Handling
199 # [Un]Registration Handling
200 #-----------------------------------------------------------------------
200 #-----------------------------------------------------------------------
201
201
202 def dispatch_notification(self, msg):
202 def dispatch_notification(self, msg):
203 """dispatch register/unregister events."""
203 """dispatch register/unregister events."""
204 try:
204 try:
205 idents,msg = self.session.feed_identities(msg)
205 idents,msg = self.session.feed_identities(msg)
206 except ValueError:
206 except ValueError:
207 self.log.warn("task::Invalid Message: %r"%msg)
207 self.log.warn("task::Invalid Message: %r"%msg)
208 return
208 return
209 try:
209 try:
210 msg = self.session.unpack_message(msg)
210 msg = self.session.unpack_message(msg)
211 except ValueError:
211 except ValueError:
212 self.log.warn("task::Unauthorized message from: %r"%idents)
212 self.log.warn("task::Unauthorized message from: %r"%idents)
213 return
213 return
214
214
215 msg_type = msg['msg_type']
215 msg_type = msg['msg_type']
216
216
217 handler = self._notification_handlers.get(msg_type, None)
217 handler = self._notification_handlers.get(msg_type, None)
218 if handler is None:
218 if handler is None:
219 self.log.error("Unhandled message type: %r"%msg_type)
219 self.log.error("Unhandled message type: %r"%msg_type)
220 else:
220 else:
221 try:
221 try:
222 handler(str(msg['content']['queue']))
222 handler(str(msg['content']['queue']))
223 except KeyError:
223 except KeyError:
224 self.log.error("task::Invalid notification msg: %r"%msg)
224 self.log.error("task::Invalid notification msg: %r"%msg)
225
225
226 @logged
226 @logged
227 def _register_engine(self, uid):
227 def _register_engine(self, uid):
228 """New engine with ident `uid` became available."""
228 """New engine with ident `uid` became available."""
229 # head of the line:
229 # head of the line:
230 self.targets.insert(0,uid)
230 self.targets.insert(0,uid)
231 self.loads.insert(0,0)
231 self.loads.insert(0,0)
232 # initialize sets
232 # initialize sets
233 self.completed[uid] = set()
233 self.completed[uid] = set()
234 self.failed[uid] = set()
234 self.failed[uid] = set()
235 self.pending[uid] = {}
235 self.pending[uid] = {}
236 if len(self.targets) == 1:
236 if len(self.targets) == 1:
237 self.resume_receiving()
237 self.resume_receiving()
238 # rescan the graph:
238 # rescan the graph:
239 self.update_graph(None)
239 self.update_graph(None)
240
240
241 def _unregister_engine(self, uid):
241 def _unregister_engine(self, uid):
242 """Existing engine with ident `uid` became unavailable."""
242 """Existing engine with ident `uid` became unavailable."""
243 if len(self.targets) == 1:
243 if len(self.targets) == 1:
244 # this was our only engine
244 # this was our only engine
245 self.stop_receiving()
245 self.stop_receiving()
246
246
247 # handle any potentially finished tasks:
247 # handle any potentially finished tasks:
248 self.engine_stream.flush()
248 self.engine_stream.flush()
249
249
250 # don't pop destinations, because they might be used later
250 # don't pop destinations, because they might be used later
251 # map(self.destinations.pop, self.completed.pop(uid))
251 # map(self.destinations.pop, self.completed.pop(uid))
252 # map(self.destinations.pop, self.failed.pop(uid))
252 # map(self.destinations.pop, self.failed.pop(uid))
253
253
254 # prevent this engine from receiving work
254 # prevent this engine from receiving work
255 idx = self.targets.index(uid)
255 idx = self.targets.index(uid)
256 self.targets.pop(idx)
256 self.targets.pop(idx)
257 self.loads.pop(idx)
257 self.loads.pop(idx)
258
258
259 # wait 5 seconds before cleaning up pending jobs, since the results might
259 # wait 5 seconds before cleaning up pending jobs, since the results might
260 # still be incoming
260 # still be incoming
261 if self.pending[uid]:
261 if self.pending[uid]:
262 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
262 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
263 dc.start()
263 dc.start()
264 else:
264 else:
265 self.completed.pop(uid)
265 self.completed.pop(uid)
266 self.failed.pop(uid)
266 self.failed.pop(uid)
267
267
268
268
269 @logged
269 @logged
270 def handle_stranded_tasks(self, engine):
270 def handle_stranded_tasks(self, engine):
271 """Deal with jobs resident in an engine that died."""
271 """Deal with jobs resident in an engine that died."""
272 lost = self.pending[engine]
272 lost = self.pending[engine]
273 for msg_id in lost.keys():
273 for msg_id in lost.keys():
274 if msg_id not in self.pending[engine]:
274 if msg_id not in self.pending[engine]:
275 # prevent double-handling of messages
275 # prevent double-handling of messages
276 continue
276 continue
277
277
278 raw_msg = lost[msg_id][0]
278 raw_msg = lost[msg_id][0]
279 idents,msg = self.session.feed_identities(raw_msg, copy=False)
279 idents,msg = self.session.feed_identities(raw_msg, copy=False)
280 parent = self.session.unpack(msg[1].bytes)
280 parent = self.session.unpack(msg[1].bytes)
281 idents = [engine, idents[0]]
281 idents = [engine, idents[0]]
282
282
283 # build fake error reply
283 # build fake error reply
284 try:
284 try:
285 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
285 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
286 except:
286 except:
287 content = error.wrap_exception()
287 content = error.wrap_exception()
288 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
288 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
289 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
289 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
290 # and dispatch it
290 # and dispatch it
291 self.dispatch_result(raw_reply)
291 self.dispatch_result(raw_reply)
292
292
293 # finally scrub completed/failed lists
293 # finally scrub completed/failed lists
294 self.completed.pop(engine)
294 self.completed.pop(engine)
295 self.failed.pop(engine)
295 self.failed.pop(engine)
296
296
297
297
298 #-----------------------------------------------------------------------
298 #-----------------------------------------------------------------------
299 # Job Submission
299 # Job Submission
300 #-----------------------------------------------------------------------
300 #-----------------------------------------------------------------------
    @logged
    def dispatch_submission(self, raw_msg):
        """Dispatch job submission to appropriate handlers.

        Parses the submission, registers its dependencies, and either runs it
        immediately (maybe_run), parks it (save_unmet), or fails it outright
        (fail_unreachable).
        """
        # ensure targets up to date:
        self.notifier_stream.flush()
        try:
            idents, msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
        except Exception:
            self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
            return


        # send to monitor
        self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)

        header = msg['header']
        msg_id = header['msg_id']
        self.all_ids.add(msg_id)

        # targets
        targets = set(header.get('targets', []))
        retries = header.get('retries', 0)
        self.retries[msg_id] = retries

        # time dependencies
        after = Dependency(header.get('after', []))
        if after.all:
            # drop already-satisfied entries so later checks stay cheap
            if after.success:
                after.difference_update(self.all_completed)
            if after.failure:
                after.difference_update(self.all_failed)
            if after.check(self.all_completed, self.all_failed):
                # recast as empty set, if `after` already met,
                # to prevent unnecessary set comparisons
                after = MET

        # location dependencies
        follow = Dependency(header.get('follow', []))

        # turn timeouts into datetime objects:
        timeout = header.get('timeout', None)
        if timeout:
            timeout = datetime.now() + timedelta(0,timeout,0)

        # NB: this ordering is the contract relied on by maybe_run, save_unmet,
        # fail_unreachable, and audit_timeouts
        args = [raw_msg, targets, after, follow, timeout]

        # validate and reduce dependencies:
        for dep in after,follow:
            # check valid: a task cannot depend on itself, nor on unknown ids
            if msg_id in dep or dep.difference(self.all_ids):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id, error.InvalidDependency)
            # check if unreachable:
            if dep.unreachable(self.all_completed, self.all_failed):
                self.depending[msg_id] = args
                return self.fail_unreachable(msg_id)

        if after.check(self.all_completed, self.all_failed):
            # time deps already met, try to run
            if not self.maybe_run(msg_id, *args):
                # can't run yet
                if msg_id not in self.all_failed:
                    # could have failed as unreachable
                    self.save_unmet(msg_id, *args)
        else:
            # time deps not yet met: park until update_graph re-checks
            self.save_unmet(msg_id, *args)
368
368
369 # @logged
369 # @logged
370 def audit_timeouts(self):
370 def audit_timeouts(self):
371 """Audit all waiting tasks for expired timeouts."""
371 """Audit all waiting tasks for expired timeouts."""
372 now = datetime.now()
372 now = datetime.now()
373 for msg_id in self.depending.keys():
373 for msg_id in self.depending.keys():
374 # must recheck, in case one failure cascaded to another:
374 # must recheck, in case one failure cascaded to another:
375 if msg_id in self.depending:
375 if msg_id in self.depending:
376 raw,after,targets,follow,timeout = self.depending[msg_id]
376 raw,after,targets,follow,timeout = self.depending[msg_id]
377 if timeout and timeout < now:
377 if timeout and timeout < now:
378 self.fail_unreachable(msg_id, error.TaskTimeout)
378 self.fail_unreachable(msg_id, error.TaskTimeout)
379
379
    @logged
    def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
        """a task has become unreachable, send a reply with an ImpossibleDependency
        error."""
        if msg_id not in self.depending:
            self.log.error("msg %r already failed!"%msg_id)
            return
        # stored order: (raw_msg, targets, after, follow, timeout)
        raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
        # remove this task from the reverse-dependency graph
        for mid in follow.union(after):
            if mid in self.graph:
                self.graph[mid].remove(msg_id)

        # FIXME: unpacking a message I've already unpacked, but didn't save:
        idents,msg = self.session.feed_identities(raw_msg, copy=False)
        header = self.session.unpack(msg[1].bytes)

        # raise/catch `why` so wrap_exception() can capture a real traceback
        try:
            raise why()
        except:
            content = error.wrap_exception()

        # record as finished and failed
        self.all_done.add(msg_id)
        self.all_failed.add(msg_id)

        # reply to the client, and echo the reply to the Hub monitor
        msg = self.session.send(self.client_stream, 'apply_reply', content,
                                parent=header, ident=idents)
        self.session.send(self.mon_stream, msg, ident=['outtask']+idents)

        # this failure may unblock (or cascade-fail) dependent tasks
        self.update_graph(msg_id, success=False)
409
409
410 @logged
410 @logged
411 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
411 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
412 """check location dependencies, and run if they are met."""
412 """check location dependencies, and run if they are met."""
413 blacklist = self.blacklist.setdefault(msg_id, set())
413 blacklist = self.blacklist.setdefault(msg_id, set())
414 if follow or targets or blacklist or self.hwm:
414 if follow or targets or blacklist or self.hwm:
415 # we need a can_run filter
415 # we need a can_run filter
416 def can_run(idx):
416 def can_run(idx):
417 # check hwm
417 # check hwm
418 if self.hwm and self.loads[idx] == self.hwm:
418 if self.hwm and self.loads[idx] == self.hwm:
419 return False
419 return False
420 target = self.targets[idx]
420 target = self.targets[idx]
421 # check blacklist
421 # check blacklist
422 if target in blacklist:
422 if target in blacklist:
423 return False
423 return False
424 # check targets
424 # check targets
425 if targets and target not in targets:
425 if targets and target not in targets:
426 return False
426 return False
427 # check follow
427 # check follow
428 return follow.check(self.completed[target], self.failed[target])
428 return follow.check(self.completed[target], self.failed[target])
429
429
430 indices = filter(can_run, range(len(self.targets)))
430 indices = filter(can_run, range(len(self.targets)))
431
431
432 if not indices:
432 if not indices:
433 # couldn't run
433 # couldn't run
434 if follow.all:
434 if follow.all:
435 # check follow for impossibility
435 # check follow for impossibility
436 dests = set()
436 dests = set()
437 relevant = set()
437 relevant = set()
438 if follow.success:
438 if follow.success:
439 relevant = self.all_completed
439 relevant = self.all_completed
440 if follow.failure:
440 if follow.failure:
441 relevant = relevant.union(self.all_failed)
441 relevant = relevant.union(self.all_failed)
442 for m in follow.intersection(relevant):
442 for m in follow.intersection(relevant):
443 dests.add(self.destinations[m])
443 dests.add(self.destinations[m])
444 if len(dests) > 1:
444 if len(dests) > 1:
445 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
445 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
446 self.fail_unreachable(msg_id)
446 self.fail_unreachable(msg_id)
447 return False
447 return False
448 if targets:
448 if targets:
449 # check blacklist+targets for impossibility
449 # check blacklist+targets for impossibility
450 targets.difference_update(blacklist)
450 targets.difference_update(blacklist)
451 if not targets or not targets.intersection(self.targets):
451 if not targets or not targets.intersection(self.targets):
452 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
452 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
453 self.fail_unreachable(msg_id)
453 self.fail_unreachable(msg_id)
454 return False
454 return False
455 return False
455 return False
456 else:
456 else:
457 indices = None
457 indices = None
458
458
459 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
459 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
460 return True
460 return True
461
461
462 @logged
462 @logged
463 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
463 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
464 """Save a message for later submission when its dependencies are met."""
464 """Save a message for later submission when its dependencies are met."""
465 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
465 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
466 # track the ids in follow or after, but not those already finished
466 # track the ids in follow or after, but not those already finished
467 for dep_id in after.union(follow).difference(self.all_done):
467 for dep_id in after.union(follow).difference(self.all_done):
468 if dep_id not in self.graph:
468 if dep_id not in self.graph:
469 self.graph[dep_id] = set()
469 self.graph[dep_id] = set()
470 self.graph[dep_id].add(msg_id)
470 self.graph[dep_id].add(msg_id)
471
471
472 @logged
472 @logged
473 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
473 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
474 """Submit a task to any of a subset of our targets."""
474 """Submit a task to any of a subset of our targets."""
475 if indices:
475 if indices:
476 loads = [self.loads[i] for i in indices]
476 loads = [self.loads[i] for i in indices]
477 else:
477 else:
478 loads = self.loads
478 loads = self.loads
479 idx = self.scheme(loads)
479 idx = self.scheme(loads)
480 if indices:
480 if indices:
481 idx = indices[idx]
481 idx = indices[idx]
482 target = self.targets[idx]
482 target = self.targets[idx]
483 # print (target, map(str, msg[:3]))
483 # print (target, map(str, msg[:3]))
484 # send job to the engine
484 # send job to the engine
485 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
485 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
486 self.engine_stream.send_multipart(raw_msg, copy=False)
486 self.engine_stream.send_multipart(raw_msg, copy=False)
487 # update load
487 # update load
488 self.add_job(idx)
488 self.add_job(idx)
489 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
489 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
490 # notify Hub
490 # notify Hub
491 content = dict(msg_id=msg_id, engine_id=target)
491 content = dict(msg_id=msg_id, engine_id=target)
492 self.session.send(self.mon_stream, 'task_destination', content=content,
492 self.session.send(self.mon_stream, 'task_destination', content=content,
493 ident=['tracktask',self.session.session])
493 ident=['tracktask',self.session.session])
494
494
495
495
496 #-----------------------------------------------------------------------
496 #-----------------------------------------------------------------------
497 # Result Handling
497 # Result Handling
498 #-----------------------------------------------------------------------
498 #-----------------------------------------------------------------------
    @logged
    def dispatch_result(self, raw_msg):
        """dispatch method for result replies

        Distinguishes real results (relayed to the client) from engine-side
        unmet-dependency refusals (resubmitted or failed).
        """
        try:
            idents,msg = self.session.feed_identities(raw_msg, copy=False)
            msg = self.session.unpack_message(msg, content=False, copy=False)
            engine = idents[0]
            try:
                idx = self.targets.index(engine)
            except ValueError:
                pass # skip load-update for dead engines
            else:
                # the engine freed a slot either way
                self.finish_job(idx)
        except Exception:
            self.log.error("task::Invaid result: %r"%raw_msg, exc_info=True)
            return

        header = msg['header']
        parent = msg['parent_header']
        # dependencies_met=False means the engine refused the task rather
        # than producing a real result
        if header.get('dependencies_met', True):
            success = (header['status'] == 'ok')
            msg_id = parent['msg_id']
            retries = self.retries[msg_id]
            if not success and retries > 0:
                # failed, but retries remain: treat like an unmet dependency
                # and resubmit
                self.retries[msg_id] = retries - 1
                self.handle_unmet_dependency(idents, parent)
            else:
                del self.retries[msg_id]
                # relay to client and update graph
                self.handle_result(idents, parent, raw_msg, success)
                # send to Hub monitor
                self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
        else:
            self.handle_unmet_dependency(idents, parent)
534
534
535 @logged
535 @logged
536 def handle_result(self, idents, parent, raw_msg, success=True):
536 def handle_result(self, idents, parent, raw_msg, success=True):
537 """handle a real task result, either success or failure"""
537 """handle a real task result, either success or failure"""
538 # first, relay result to client
538 # first, relay result to client
539 engine = idents[0]
539 engine = idents[0]
540 client = idents[1]
540 client = idents[1]
541 # swap_ids for XREP-XREP mirror
541 # swap_ids for XREP-XREP mirror
542 raw_msg[:2] = [client,engine]
542 raw_msg[:2] = [client,engine]
543 # print (map(str, raw_msg[:4]))
543 # print (map(str, raw_msg[:4]))
544 self.client_stream.send_multipart(raw_msg, copy=False)
544 self.client_stream.send_multipart(raw_msg, copy=False)
545 # now, update our data structures
545 # now, update our data structures
546 msg_id = parent['msg_id']
546 msg_id = parent['msg_id']
547 self.blacklist.pop(msg_id, None)
547 self.blacklist.pop(msg_id, None)
548 self.pending[engine].pop(msg_id)
548 self.pending[engine].pop(msg_id)
549 if success:
549 if success:
550 self.completed[engine].add(msg_id)
550 self.completed[engine].add(msg_id)
551 self.all_completed.add(msg_id)
551 self.all_completed.add(msg_id)
552 else:
552 else:
553 self.failed[engine].add(msg_id)
553 self.failed[engine].add(msg_id)
554 self.all_failed.add(msg_id)
554 self.all_failed.add(msg_id)
555 self.all_done.add(msg_id)
555 self.all_done.add(msg_id)
556 self.destinations[msg_id] = engine
556 self.destinations[msg_id] = engine
557
557
558 self.update_graph(msg_id, success)
558 self.update_graph(msg_id, success)
559
559
560 @logged
560 @logged
561 def handle_unmet_dependency(self, idents, parent):
561 def handle_unmet_dependency(self, idents, parent):
562 """handle an unmet dependency"""
562 """handle an unmet dependency"""
563 engine = idents[0]
563 engine = idents[0]
564 msg_id = parent['msg_id']
564 msg_id = parent['msg_id']
565
565
566 if msg_id not in self.blacklist:
566 if msg_id not in self.blacklist:
567 self.blacklist[msg_id] = set()
567 self.blacklist[msg_id] = set()
568 self.blacklist[msg_id].add(engine)
568 self.blacklist[msg_id].add(engine)
569
569
570 args = self.pending[engine].pop(msg_id)
570 args = self.pending[engine].pop(msg_id)
571 raw,targets,after,follow,timeout = args
571 raw,targets,after,follow,timeout = args
572
572
573 if self.blacklist[msg_id] == targets:
573 if self.blacklist[msg_id] == targets:
574 self.depending[msg_id] = args
574 self.depending[msg_id] = args
575 self.fail_unreachable(msg_id)
575 self.fail_unreachable(msg_id)
576 elif not self.maybe_run(msg_id, *args):
576 elif not self.maybe_run(msg_id, *args):
577 # resubmit failed
577 # resubmit failed
578 if msg_id not in self.all_failed:
578 if msg_id not in self.all_failed:
579 # put it back in our dependency tree
579 # put it back in our dependency tree
580 self.save_unmet(msg_id, *args)
580 self.save_unmet(msg_id, *args)
581
581
582 if self.hwm:
582 if self.hwm:
583 try:
583 try:
584 idx = self.targets.index(engine)
584 idx = self.targets.index(engine)
585 except ValueError:
585 except ValueError:
586 pass # skip load-update for dead engines
586 pass # skip load-update for dead engines
587 else:
587 else:
588 if self.loads[idx] == self.hwm-1:
588 if self.loads[idx] == self.hwm-1:
589 self.update_graph(None)
589 self.update_graph(None)
590
590
591
591
592
592
593 @logged
593 @logged
594 def update_graph(self, dep_id=None, success=True):
594 def update_graph(self, dep_id=None, success=True):
595 """dep_id just finished. Update our dependency
595 """dep_id just finished. Update our dependency
596 graph and submit any jobs that just became runable.
596 graph and submit any jobs that just became runable.
597
597
598 Called with dep_id=None to update entire graph for hwm, but without finishing
598 Called with dep_id=None to update entire graph for hwm, but without finishing
599 a task.
599 a task.
600 """
600 """
601 # print ("\n\n***********")
601 # print ("\n\n***********")
602 # pprint (dep_id)
602 # pprint (dep_id)
603 # pprint (self.graph)
603 # pprint (self.graph)
604 # pprint (self.depending)
604 # pprint (self.depending)
605 # pprint (self.all_completed)
605 # pprint (self.all_completed)
606 # pprint (self.all_failed)
606 # pprint (self.all_failed)
607 # print ("\n\n***********\n\n")
607 # print ("\n\n***********\n\n")
608 # update any jobs that depended on the dependency
608 # update any jobs that depended on the dependency
609 jobs = self.graph.pop(dep_id, [])
609 jobs = self.graph.pop(dep_id, [])
610
610
611 # recheck *all* jobs if
611 # recheck *all* jobs if
612 # a) we have HWM and an engine just become no longer full
612 # a) we have HWM and an engine just become no longer full
613 # or b) dep_id was given as None
613 # or b) dep_id was given as None
614 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
614 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
615 jobs = self.depending.keys()
615 jobs = self.depending.keys()
616
616
617 for msg_id in jobs:
617 for msg_id in jobs:
618 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
618 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
619
619
620 if after.unreachable(self.all_completed, self.all_failed)\
620 if after.unreachable(self.all_completed, self.all_failed)\
621 or follow.unreachable(self.all_completed, self.all_failed):
621 or follow.unreachable(self.all_completed, self.all_failed):
622 self.fail_unreachable(msg_id)
622 self.fail_unreachable(msg_id)
623
623
624 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
624 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
625 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
625 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
626
626
627 self.depending.pop(msg_id)
627 self.depending.pop(msg_id)
628 for mid in follow.union(after):
628 for mid in follow.union(after):
629 if mid in self.graph:
629 if mid in self.graph:
630 self.graph[mid].remove(msg_id)
630 self.graph[mid].remove(msg_id)
631
631
632 #----------------------------------------------------------------------
632 #----------------------------------------------------------------------
633 # methods to be overridden by subclasses
633 # methods to be overridden by subclasses
634 #----------------------------------------------------------------------
634 #----------------------------------------------------------------------
635
635
def add_job(self, idx):
    """Called after self.targets[idx] just got the job with header.

    Override in subclasses.  The default ordering is simple LRU.
    The default loads are the number of outstanding jobs.
    """
    # one more outstanding job on that engine
    self.loads[idx] = self.loads[idx] + 1
    # rotate the engine to the back of both parallel lists, so the
    # least-recently-assigned engine ends up at the front (LRU order)
    self.targets.append(self.targets.pop(idx))
    self.loads.append(self.loads.pop(idx))
643
643
644
644
def finish_job(self, idx):
    """Called after self.targets[idx] just finished a job.

    Override in subclasses.
    """
    # one fewer outstanding job on that engine
    self.loads[idx] = self.loads[idx] - 1
649
649
650
650
651
651
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
                        logname='root', log_url=None, loglevel=logging.DEBUG,
                        identity=b'task', in_thread=False):
    """Build a TaskScheduler wired to the given addresses and start it.

    Creates the four streams the scheduler needs (client-facing,
    engine-facing, monitor publisher, notification subscriber), sets up
    logging, and — unless running in a thread — blocks in the event loop
    until interrupted.
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap a plain dict back into a Config object
        config = Config(config)

    if in_thread:
        # use instance() so we share the Context/IOLoop with our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a separate process, build a private Context/IOLoop —
        # sharing instance() across processes is unsafe with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()

    # stream facing the clients that submit tasks
    ins = ZMQStream(ctx.socket(zmq.XREP), loop)
    ins.setsockopt(zmq.IDENTITY, identity)
    ins.bind(in_addr)

    # stream facing the engines that run tasks
    outs = ZMQStream(ctx.socket(zmq.XREP), loop)
    outs.setsockopt(zmq.IDENTITY, identity)
    outs.bind(out_addr)

    # publish scheduler events to the monitor
    mons = ZMQStream(ctx.socket(zmq.PUB), loop)
    mons.connect(mon_addr)

    # subscribe to all engine (un)registration notifications
    nots = ZMQStream(ctx.socket(zmq.SUB), loop)
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    # setup logging
    if in_thread:
        log = Application.instance().log
    elif log_url:
        log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
    else:
        log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            print ("interrupted, exiting...", file=sys.__stderr__)
703
703
@@ -1,847 +1,847 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 Tests for IPython.utils.traitlets.
4 Tests for IPython.utils.traitlets.
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * Enthought, Inc. Some of the code in this file comes from enthought.traits
9 * Enthought, Inc. Some of the code in this file comes from enthought.traits
10 and is licensed under the BSD license. Also, many of the ideas also come
10 and is licensed under the BSD license. Also, many of the ideas also come
11 from enthought.traits even though our implementation is very different.
11 from enthought.traits even though our implementation is very different.
12 """
12 """
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Copyright (C) 2008-2009 The IPython Development Team
15 # Copyright (C) 2008-2009 The IPython Development Team
16 #
16 #
17 # Distributed under the terms of the BSD License. The full license is in
17 # Distributed under the terms of the BSD License. The full license is in
18 # the file COPYING, distributed as part of this software.
18 # the file COPYING, distributed as part of this software.
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # Imports
22 # Imports
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24
24
25 import sys
25 import sys
26 from unittest import TestCase
26 from unittest import TestCase
27
27
28 from IPython.utils.traitlets import (
28 from IPython.utils.traitlets import (
29 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
29 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
30 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
30 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
31 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
31 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
32 ObjectName, DottedObjectName
32 ObjectName, DottedObjectName
33 )
33 )
34
34
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Helper classes for testing
37 # Helper classes for testing
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40
40
class HasTraitsStub(HasTraits):
    """HasTraits subclass that records the most recent trait-change
    notification so tests can assert on it."""

    def _notify_trait(self, name, old, new):
        # stash the last notification instead of dispatching it
        self._notify_name, self._notify_old, self._notify_new = name, old, new
47
47
48
48
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50 # Test classes
50 # Test classes
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52
52
53
53
class TestTraitType(TestCase):
    """Tests for the base TraitType class: defaults, validation hooks,
    info/error reporting, and dynamic default initializers."""

    def test_get_undefined(self):
        # a TraitType with no default reads as Undefined
        class A(HasTraits):
            a = TraitType
        a = A()
        self.assertEqual(a.a, Undefined)

    def test_set(self):
        class A(HasTraitsStub):
            a = TraitType

        a = A()
        a.a = 10
        self.assertEqual(a.a, 10)
        # the stub records the notification details
        self.assertEqual(a._notify_name, 'a')
        self.assertEqual(a._notify_old, Undefined)
        self.assertEqual(a._notify_new, 10)

    def test_validate(self):
        class MyTT(TraitType):
            def validate(self, inst, value):
                return -1
        class A(HasTraitsStub):
            tt = MyTT

        a = A()
        a.tt = 10
        # the stored value is whatever validate() returns
        self.assertEqual(a.tt, -1)

    def test_default_validate(self):
        class MyIntTT(TraitType):
            def validate(self, obj, value):
                if isinstance(value, int):
                    return value
                self.error(obj, value)
        class A(HasTraits):
            tt = MyIntTT(10)
        a = A()
        self.assertEqual(a.tt, 10)

        # Defaults are validated when the HasTraits is instantiated
        class B(HasTraits):
            tt = MyIntTT('bad default')
        self.assertRaises(TraitError, B)

    def test_is_valid_for(self):
        class MyTT(TraitType):
            def is_valid_for(self, value):
                return True
        class A(HasTraits):
            tt = MyTT

        a = A()
        a.tt = 10
        self.assertEqual(a.tt, 10)

    def test_value_for(self):
        class MyTT(TraitType):
            def value_for(self, value):
                return 20
        class A(HasTraits):
            tt = MyTT

        a = A()
        a.tt = 10
        self.assertEqual(a.tt, 20)

    def test_info(self):
        class A(HasTraits):
            tt = TraitType
        a = A()
        self.assertEqual(A.tt.info(), 'any value')

    def test_error(self):
        class A(HasTraits):
            tt = TraitType
        a = A()
        self.assertRaises(TraitError, A.tt.error, a, 10)

    def test_dynamic_initializer(self):
        class A(HasTraits):
            x = Int(10)
            def _x_default(self):
                return 11
        class B(A):
            x = Int(20)
        class C(A):
            def _x_default(self):
                return 21

        a = A()
        self.assertEqual(a._trait_values, {})
        # wrap keys() in list() so the comparison also holds where keys()
        # returns a view instead of a list (Python 3)
        self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {'x': 11})
        b = B()
        self.assertEqual(b._trait_values, {'x': 20})
        self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
        self.assertEqual(b.x, 20)
        c = C()
        self.assertEqual(c._trait_values, {})
        self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
        self.assertEqual(c.x, 21)
        self.assertEqual(c._trait_values, {'x': 21})
        # Ensure that the base class remains unmolested when the _default
        # initializer gets overridden in a subclass.
        a = A()
        c = C()
        self.assertEqual(a._trait_values, {})
        self.assertEqual(list(a._trait_dyn_inits.keys()), ['x'])
        self.assertEqual(a.x, 11)
        self.assertEqual(a._trait_values, {'x': 11})
167
167
168
168
169
169
class TestHasTraitsMeta(TestCase):
    """Tests for the MetaHasTraits metaclass and This.this_class binding."""

    def test_metaclass(self):
        self.assertEqual(type(HasTraits), MetaHasTraits)

        class A(HasTraits):
            a = Int

        a = A()
        self.assertEqual(type(a.__class__), MetaHasTraits)
        self.assertEqual(a.a, 0)
        a.a = 10
        self.assertEqual(a.a, 10)

        class B(HasTraits):
            b = Int()

        b = B()
        self.assertEqual(b.b, 0)
        b.b = 10
        self.assertEqual(b.b, 10)

        class C(HasTraits):
            c = Int(30)

        c = C()
        self.assertEqual(c.c, 30)
        c.c = 10
        self.assertEqual(c.c, 10)

    def test_this_class(self):
        # this_class should point at the class where the trait was defined,
        # even when looked up through a subclass
        class A(HasTraits):
            t = This()
            tt = This()
        class B(A):
            tt = This()
            ttt = This()
        self.assertEqual(A.t.this_class, A)
        self.assertEqual(B.t.this_class, A)
        self.assertEqual(B.tt.this_class, B)
        self.assertEqual(B.ttt.this_class, B)
211
211
class TestHasTraitsNotify(TestCase):
    """Tests for trait-change notification, both dynamic
    (on_trait_change) and static (_<name>_changed methods)."""

    def setUp(self):
        self._notify1 = []
        self._notify2 = []

    def notify1(self, name, old, new):
        self._notify1.append((name, old, new))

    def notify2(self, name, old, new):
        self._notify2.append((name, old, new))

    def test_notify_all(self):

        class A(HasTraits):
            a = Int
            b = Float

        a = A()
        a.on_trait_change(self.notify1)
        # assigning the default value does not fire a notification
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        self.assertTrue(('a', 0, 10) in self._notify1)
        a.b = 10.0
        self.assertTrue(('b', 0.0, 10.0) in self._notify1)
        self.assertRaises(TraitError, setattr, a, 'a', 'bad string')
        self.assertRaises(TraitError, setattr, a, 'b', 'bad string')
        self._notify1 = []
        # once removed, the handler no longer receives notifications
        a.on_trait_change(self.notify1, remove=True)
        a.a = 20
        a.b = 20.0
        self.assertEqual(len(self._notify1), 0)

    def test_notify_one(self):

        class A(HasTraits):
            a = Int
            b = Float

        a = A()
        a.on_trait_change(self.notify1, 'a')
        a.a = 0
        self.assertEqual(len(self._notify1), 0)
        a.a = 10
        self.assertTrue(('a', 0, 10) in self._notify1)
        self.assertRaises(TraitError, setattr, a, 'a', 'bad string')

    def test_subclass(self):

        class A(HasTraits):
            a = Int

        class B(A):
            b = Float

        b = B()
        self.assertEqual(b.a, 0)
        self.assertEqual(b.b, 0.0)
        b.a = 100
        b.b = 100.0
        self.assertEqual(b.a, 100)
        self.assertEqual(b.b, 100.0)

    def test_notify_subclass(self):

        class A(HasTraits):
            a = Int

        class B(A):
            b = Float

        b = B()
        b.on_trait_change(self.notify1, 'a')
        b.on_trait_change(self.notify2, 'b')
        b.a = 0
        b.b = 0.0
        self.assertEqual(len(self._notify1), 0)
        self.assertEqual(len(self._notify2), 0)
        b.a = 10
        b.b = 10.0
        self.assertTrue(('a', 0, 10) in self._notify1)
        self.assertTrue(('b', 0.0, 10.0) in self._notify2)

    def test_static_notify(self):

        class A(HasTraits):
            a = Int
            _notify1 = []
            def _a_changed(self, name, old, new):
                self._notify1.append((name, old, new))

        a = A()
        a.a = 0
        # This is broken!!!
        self.assertEqual(len(a._notify1), 0)
        a.a = 10
        self.assertTrue(('a', 0, 10) in a._notify1)

        class B(A):
            b = Float
            _notify2 = []
            def _b_changed(self, name, old, new):
                self._notify2.append((name, old, new))

        b = B()
        b.a = 10
        b.b = 10.0
        self.assertTrue(('a', 0, 10) in b._notify1)
        self.assertTrue(('b', 0.0, 10.0) in b._notify2)

    def test_notify_args(self):

        def callback0():
            self.cb = ()
        def callback1(name):
            self.cb = (name,)
        def callback2(name, new):
            self.cb = (name, new)
        def callback3(name, old, new):
            self.cb = (name, old, new)

        class A(HasTraits):
            a = Int

        a = A()
        # handlers are called with as many arguments as they accept
        a.on_trait_change(callback0, 'a')
        a.a = 10
        self.assertEqual(self.cb, ())
        a.on_trait_change(callback0, 'a', remove=True)

        a.on_trait_change(callback1, 'a')
        a.a = 100
        self.assertEqual(self.cb, ('a',))
        a.on_trait_change(callback1, 'a', remove=True)

        a.on_trait_change(callback2, 'a')
        a.a = 1000
        self.assertEqual(self.cb, ('a', 1000))
        a.on_trait_change(callback2, 'a', remove=True)

        a.on_trait_change(callback3, 'a')
        a.a = 10000
        self.assertEqual(self.cb, ('a', 1000, 10000))
        a.on_trait_change(callback3, 'a', remove=True)

        self.assertEqual(len(a._trait_notifiers['a']), 0)
361
361
362
362
class TestHasTraits(TestCase):
    """Tests for HasTraits introspection helpers (trait_names, traits,
    trait_metadata) and keyword-argument initialization."""

    def test_trait_names(self):
        class A(HasTraits):
            i = Int
            f = Float
        a = A()
        self.assertEqual(a.trait_names(), ['i', 'f'])
        self.assertEqual(A.class_trait_names(), ['i', 'f'])

    def test_trait_metadata(self):
        class A(HasTraits):
            i = Int(config_key='MY_VALUE')
        a = A()
        self.assertEqual(a.trait_metadata('i', 'config_key'), 'MY_VALUE')

    def test_traits(self):
        class A(HasTraits):
            i = Int
            f = Float
        a = A()
        self.assertEqual(a.traits(), dict(i=A.i, f=A.f))
        self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f))

    def test_traits_metadata(self):
        class A(HasTraits):
            i = Int(config_key='VALUE1', other_thing='VALUE2')
            f = Float(config_key='VALUE3', other_thing='VALUE2')
            j = Int(0)
        a = A()
        self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j))
        # filtering by metadata selects only matching traits
        traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
        self.assertEqual(traits, dict(i=A.i))

        # This passes, but it shouldn't because I am replicating a bug in
        # traits.
        traits = a.traits(config_key=lambda v: True)
        self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j))

    def test_init(self):
        class A(HasTraits):
            i = Int()
            x = Float()
        a = A(i=1, x=10.0)
        self.assertEqual(a.i, 1)
        self.assertEqual(a.x, 10.0)
409
409
410 #-----------------------------------------------------------------------------
410 #-----------------------------------------------------------------------------
411 # Tests for specific trait types
411 # Tests for specific trait types
412 #-----------------------------------------------------------------------------
412 #-----------------------------------------------------------------------------
413
413
414
414
class TestType(TestCase):
    """Tests for the Type trait (class-valued traits), including
    string-based lazy imports and allow_none handling."""

    def test_default(self):

        class B(object): pass
        class A(HasTraits):
            klass = Type

        a = A()
        self.assertEqual(a.klass, None)

        a.klass = B
        self.assertEqual(a.klass, B)
        self.assertRaises(TraitError, setattr, a, 'klass', 10)

    def test_value(self):

        class B(object): pass
        class C(object): pass
        class A(HasTraits):
            klass = Type(B)

        a = A()
        self.assertEqual(a.klass, B)
        # only B itself or subclasses of B are accepted
        self.assertRaises(TraitError, setattr, a, 'klass', C)
        self.assertRaises(TraitError, setattr, a, 'klass', object)
        a.klass = B

    def test_allow_none(self):

        class B(object): pass
        class C(B): pass
        class A(HasTraits):
            klass = Type(B, allow_none=False)

        a = A()
        self.assertEqual(a.klass, B)
        self.assertRaises(TraitError, setattr, a, 'klass', None)
        a.klass = C
        self.assertEqual(a.klass, C)

    def test_validate_klass(self):
        # a klass given as a string must name an importable object

        class A(HasTraits):
            klass = Type('no strings allowed')

        self.assertRaises(ImportError, A)

        class A(HasTraits):
            klass = Type('rub.adub.Duck')

        self.assertRaises(ImportError, A)

    def test_validate_default(self):

        class B(object): pass
        class A(HasTraits):
            klass = Type('bad default', B)

        self.assertRaises(ImportError, A)

        class C(HasTraits):
            klass = Type(None, B, allow_none=False)

        self.assertRaises(TraitError, C)

    def test_str_klass(self):

        class A(HasTraits):
            klass = Type('IPython.utils.ipstruct.Struct')

        from IPython.utils.ipstruct import Struct
        a = A()
        a.klass = Struct
        self.assertEqual(a.klass, Struct)

        self.assertRaises(TraitError, setattr, a, 'klass', 10)
492
492
class TestInstance(TestCase):
    """Tests for the Instance trait: defaults, args/kw construction,
    and rejection of invalid values."""

    def test_basic(self):
        class Foo(object): pass
        class Bar(Foo): pass
        class Bah(object): pass

        class A(HasTraits):
            inst = Instance(Foo)

        a = A()
        self.assertTrue(a.inst is None)
        a.inst = Foo()
        self.assertTrue(isinstance(a.inst, Foo))
        a.inst = Bar()
        self.assertTrue(isinstance(a.inst, Foo))
        # classes (not instances) and unrelated instances are rejected
        self.assertRaises(TraitError, setattr, a, 'inst', Foo)
        self.assertRaises(TraitError, setattr, a, 'inst', Bar)
        self.assertRaises(TraitError, setattr, a, 'inst', Bah())

    def test_unique_default_value(self):
        class Foo(object): pass
        class A(HasTraits):
            inst = Instance(Foo, (), {})

        a = A()
        b = A()
        # each HasTraits instance must get its own default Foo
        self.assertTrue(a.inst is not b.inst)

    def test_args_kw(self):
        class Foo(object):
            def __init__(self, c): self.c = c
        class Bar(object): pass
        class Bah(object):
            def __init__(self, c, d):
                self.c = c; self.d = d

        class A(HasTraits):
            inst = Instance(Foo, (10,))
        a = A()
        self.assertEqual(a.inst.c, 10)

        class B(HasTraits):
            inst = Instance(Bah, args=(10,), kw=dict(d=20))
        b = B()
        self.assertEqual(b.inst.c, 10)
        self.assertEqual(b.inst.d, 20)

        class C(HasTraits):
            inst = Instance(Foo)
        c = C()
        self.assertTrue(c.inst is None)

    def test_bad_default(self):
        class Foo(object): pass

        class A(HasTraits):
            inst = Instance(Foo, allow_none=False)

        self.assertRaises(TraitError, A)

    def test_instance(self):
        class Foo(object): pass

        def inner():
            class A(HasTraits):
                inst = Instance(Foo())

        # Instance must be given a class, not an instance
        self.assertRaises(TraitError, inner)
562
562
563
563
564 class TestThis(TestCase):
564 class TestThis(TestCase):
565
565
566 def test_this_class(self):
566 def test_this_class(self):
567 class Foo(HasTraits):
567 class Foo(HasTraits):
568 this = This
568 this = This
569
569
570 f = Foo()
570 f = Foo()
571 self.assertEquals(f.this, None)
571 self.assertEquals(f.this, None)
572 g = Foo()
572 g = Foo()
573 f.this = g
573 f.this = g
574 self.assertEquals(f.this, g)
574 self.assertEquals(f.this, g)
575 self.assertRaises(TraitError, setattr, f, 'this', 10)
575 self.assertRaises(TraitError, setattr, f, 'this', 10)
576
576
577 def test_this_inst(self):
577 def test_this_inst(self):
578 class Foo(HasTraits):
578 class Foo(HasTraits):
579 this = This()
579 this = This()
580
580
581 f = Foo()
581 f = Foo()
582 f.this = Foo()
582 f.this = Foo()
583 self.assert_(isinstance(f.this, Foo))
583 self.assert_(isinstance(f.this, Foo))
584
584
585 def test_subclass(self):
585 def test_subclass(self):
586 class Foo(HasTraits):
586 class Foo(HasTraits):
587 t = This()
587 t = This()
588 class Bar(Foo):
588 class Bar(Foo):
589 pass
589 pass
590 f = Foo()
590 f = Foo()
591 b = Bar()
591 b = Bar()
592 f.t = b
592 f.t = b
593 b.t = f
593 b.t = f
594 self.assertEquals(f.t, b)
594 self.assertEquals(f.t, b)
595 self.assertEquals(b.t, f)
595 self.assertEquals(b.t, f)
596
596
597 def test_subclass_override(self):
597 def test_subclass_override(self):
598 class Foo(HasTraits):
598 class Foo(HasTraits):
599 t = This()
599 t = This()
600 class Bar(Foo):
600 class Bar(Foo):
601 t = This()
601 t = This()
602 f = Foo()
602 f = Foo()
603 b = Bar()
603 b = Bar()
604 f.t = b
604 f.t = b
605 self.assertEquals(f.t, b)
605 self.assertEquals(f.t, b)
606 self.assertRaises(TraitError, setattr, b, 't', f)
606 self.assertRaises(TraitError, setattr, b, 't', f)
607
607
608 class TraitTestBase(TestCase):
608 class TraitTestBase(TestCase):
609 """A best testing class for basic trait types."""
609 """A best testing class for basic trait types."""
610
610
611 def assign(self, value):
611 def assign(self, value):
612 self.obj.value = value
612 self.obj.value = value
613
613
614 def coerce(self, value):
614 def coerce(self, value):
615 return value
615 return value
616
616
617 def test_good_values(self):
617 def test_good_values(self):
618 if hasattr(self, '_good_values'):
618 if hasattr(self, '_good_values'):
619 for value in self._good_values:
619 for value in self._good_values:
620 self.assign(value)
620 self.assign(value)
621 self.assertEquals(self.obj.value, self.coerce(value))
621 self.assertEquals(self.obj.value, self.coerce(value))
622
622
623 def test_bad_values(self):
623 def test_bad_values(self):
624 if hasattr(self, '_bad_values'):
624 if hasattr(self, '_bad_values'):
625 for value in self._bad_values:
625 for value in self._bad_values:
626 self.assertRaises(TraitError, self.assign, value)
626 self.assertRaises(TraitError, self.assign, value)
627
627
628 def test_default_value(self):
628 def test_default_value(self):
629 if hasattr(self, '_default_value'):
629 if hasattr(self, '_default_value'):
630 self.assertEquals(self._default_value, self.obj.value)
630 self.assertEquals(self._default_value, self.obj.value)
631
631
632
632
633 class AnyTrait(HasTraits):
633 class AnyTrait(HasTraits):
634
634
635 value = Any
635 value = Any
636
636
637 class AnyTraitTest(TraitTestBase):
637 class AnyTraitTest(TraitTestBase):
638
638
639 obj = AnyTrait()
639 obj = AnyTrait()
640
640
641 _default_value = None
641 _default_value = None
642 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
642 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
643 _bad_values = []
643 _bad_values = []
644
644
645
645
646 class IntTrait(HasTraits):
646 class IntTrait(HasTraits):
647
647
648 value = Int(99)
648 value = Int(99)
649
649
650 class TestInt(TraitTestBase):
650 class TestInt(TraitTestBase):
651
651
652 obj = IntTrait()
652 obj = IntTrait()
653 _default_value = 99
653 _default_value = 99
654 _good_values = [10, -10]
654 _good_values = [10, -10]
655 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
655 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
656 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
656 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
657 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
657 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
658
658
659
659
660 class LongTrait(HasTraits):
660 class LongTrait(HasTraits):
661
661
662 value = Long(99L)
662 value = Long(99L)
663
663
664 class TestLong(TraitTestBase):
664 class TestLong(TraitTestBase):
665
665
666 obj = LongTrait()
666 obj = LongTrait()
667
667
668 _default_value = 99L
668 _default_value = 99L
669 _good_values = [10, -10, 10L, -10L]
669 _good_values = [10, -10, 10L, -10L]
670 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
670 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
671 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
671 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
672 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
672 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
673 u'-10.1']
673 u'-10.1']
674
674
675
675
676 class FloatTrait(HasTraits):
676 class FloatTrait(HasTraits):
677
677
678 value = Float(99.0)
678 value = Float(99.0)
679
679
680 class TestFloat(TraitTestBase):
680 class TestFloat(TraitTestBase):
681
681
682 obj = FloatTrait()
682 obj = FloatTrait()
683
683
684 _default_value = 99.0
684 _default_value = 99.0
685 _good_values = [10, -10, 10.1, -10.1]
685 _good_values = [10, -10, 10.1, -10.1]
686 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
686 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
687 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
687 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
688 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
688 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
689
689
690
690
691 class ComplexTrait(HasTraits):
691 class ComplexTrait(HasTraits):
692
692
693 value = Complex(99.0-99.0j)
693 value = Complex(99.0-99.0j)
694
694
695 class TestComplex(TraitTestBase):
695 class TestComplex(TraitTestBase):
696
696
697 obj = ComplexTrait()
697 obj = ComplexTrait()
698
698
699 _default_value = 99.0-99.0j
699 _default_value = 99.0-99.0j
700 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
700 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
701 10.1j, 10.1+10.1j, 10.1-10.1j]
701 10.1j, 10.1+10.1j, 10.1-10.1j]
702 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
702 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
703
703
704
704
705 class BytesTrait(HasTraits):
705 class BytesTrait(HasTraits):
706
706
707 value = Bytes('string')
707 value = Bytes(b'string')
708
708
709 class TestBytes(TraitTestBase):
709 class TestBytes(TraitTestBase):
710
710
711 obj = BytesTrait()
711 obj = BytesTrait()
712
712
713 _default_value = 'string'
713 _default_value = b'string'
714 _good_values = ['10', '-10', '10L',
714 _good_values = [b'10', b'-10', b'10L',
715 '-10L', '10.1', '-10.1', 'string']
715 b'-10L', b'10.1', b'-10.1', b'string']
716 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
716 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
717 ['ten'],{'ten': 10},(10,), None, u'string']
717 ['ten'],{'ten': 10},(10,), None, u'string']
718
718
719
719
720 class UnicodeTrait(HasTraits):
720 class UnicodeTrait(HasTraits):
721
721
722 value = Unicode(u'unicode')
722 value = Unicode(u'unicode')
723
723
724 class TestUnicode(TraitTestBase):
724 class TestUnicode(TraitTestBase):
725
725
726 obj = UnicodeTrait()
726 obj = UnicodeTrait()
727
727
728 _default_value = u'unicode'
728 _default_value = u'unicode'
729 _good_values = ['10', '-10', '10L', '-10L', '10.1',
729 _good_values = ['10', '-10', '10L', '-10L', '10.1',
730 '-10.1', '', u'', 'string', u'string', u"€"]
730 '-10.1', '', u'', 'string', u'string', u"€"]
731 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
731 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
732 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
732 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
733
733
734
734
735 class ObjectNameTrait(HasTraits):
735 class ObjectNameTrait(HasTraits):
736 value = ObjectName("abc")
736 value = ObjectName("abc")
737
737
738 class TestObjectName(TraitTestBase):
738 class TestObjectName(TraitTestBase):
739 obj = ObjectNameTrait()
739 obj = ObjectNameTrait()
740
740
741 _default_value = "abc"
741 _default_value = "abc"
742 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
742 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
743 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
743 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
744 object(), object]
744 object(), object]
745 if sys.version_info[0] < 3:
745 if sys.version_info[0] < 3:
746 _bad_values.append(u"ΓΎ")
746 _bad_values.append(u"ΓΎ")
747 else:
747 else:
748 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
748 _good_values.append(u"ΓΎ") # ΓΎ=1 is valid in Python 3 (PEP 3131).
749
749
750
750
751 class DottedObjectNameTrait(HasTraits):
751 class DottedObjectNameTrait(HasTraits):
752 value = DottedObjectName("a.b")
752 value = DottedObjectName("a.b")
753
753
754 class TestDottedObjectName(TraitTestBase):
754 class TestDottedObjectName(TraitTestBase):
755 obj = DottedObjectNameTrait()
755 obj = DottedObjectNameTrait()
756
756
757 _default_value = "a.b"
757 _default_value = "a.b"
758 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
758 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
759 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
759 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
760 if sys.version_info[0] < 3:
760 if sys.version_info[0] < 3:
761 _bad_values.append(u"t.ΓΎ")
761 _bad_values.append(u"t.ΓΎ")
762 else:
762 else:
763 _good_values.append(u"t.ΓΎ")
763 _good_values.append(u"t.ΓΎ")
764
764
765
765
766 class TCPAddressTrait(HasTraits):
766 class TCPAddressTrait(HasTraits):
767
767
768 value = TCPAddress()
768 value = TCPAddress()
769
769
770 class TestTCPAddress(TraitTestBase):
770 class TestTCPAddress(TraitTestBase):
771
771
772 obj = TCPAddressTrait()
772 obj = TCPAddressTrait()
773
773
774 _default_value = ('127.0.0.1',0)
774 _default_value = ('127.0.0.1',0)
775 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
775 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
776 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
776 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
777
777
778 class ListTrait(HasTraits):
778 class ListTrait(HasTraits):
779
779
780 value = List(Int)
780 value = List(Int)
781
781
782 class TestList(TraitTestBase):
782 class TestList(TraitTestBase):
783
783
784 obj = ListTrait()
784 obj = ListTrait()
785
785
786 _default_value = []
786 _default_value = []
787 _good_values = [[], [1], range(10)]
787 _good_values = [[], [1], range(10)]
788 _bad_values = [10, [1,'a'], 'a', (1,2)]
788 _bad_values = [10, [1,'a'], 'a', (1,2)]
789
789
790 class LenListTrait(HasTraits):
790 class LenListTrait(HasTraits):
791
791
792 value = List(Int, [0], minlen=1, maxlen=2)
792 value = List(Int, [0], minlen=1, maxlen=2)
793
793
794 class TestLenList(TraitTestBase):
794 class TestLenList(TraitTestBase):
795
795
796 obj = LenListTrait()
796 obj = LenListTrait()
797
797
798 _default_value = [0]
798 _default_value = [0]
799 _good_values = [[1], range(2)]
799 _good_values = [[1], range(2)]
800 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
800 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
801
801
802 class TupleTrait(HasTraits):
802 class TupleTrait(HasTraits):
803
803
804 value = Tuple(Int)
804 value = Tuple(Int)
805
805
806 class TestTupleTrait(TraitTestBase):
806 class TestTupleTrait(TraitTestBase):
807
807
808 obj = TupleTrait()
808 obj = TupleTrait()
809
809
810 _default_value = None
810 _default_value = None
811 _good_values = [(1,), None,(0,)]
811 _good_values = [(1,), None,(0,)]
812 _bad_values = [10, (1,2), [1],('a'), ()]
812 _bad_values = [10, (1,2), [1],('a'), ()]
813
813
814 def test_invalid_args(self):
814 def test_invalid_args(self):
815 self.assertRaises(TypeError, Tuple, 5)
815 self.assertRaises(TypeError, Tuple, 5)
816 self.assertRaises(TypeError, Tuple, default_value='hello')
816 self.assertRaises(TypeError, Tuple, default_value='hello')
817 t = Tuple(Int, CBytes, default_value=(1,5))
817 t = Tuple(Int, CBytes, default_value=(1,5))
818
818
819 class LooseTupleTrait(HasTraits):
819 class LooseTupleTrait(HasTraits):
820
820
821 value = Tuple((1,2,3))
821 value = Tuple((1,2,3))
822
822
823 class TestLooseTupleTrait(TraitTestBase):
823 class TestLooseTupleTrait(TraitTestBase):
824
824
825 obj = LooseTupleTrait()
825 obj = LooseTupleTrait()
826
826
827 _default_value = (1,2,3)
827 _default_value = (1,2,3)
828 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
828 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
829 _bad_values = [10, 'hello', [1], []]
829 _bad_values = [10, 'hello', [1], []]
830
830
831 def test_invalid_args(self):
831 def test_invalid_args(self):
832 self.assertRaises(TypeError, Tuple, 5)
832 self.assertRaises(TypeError, Tuple, 5)
833 self.assertRaises(TypeError, Tuple, default_value='hello')
833 self.assertRaises(TypeError, Tuple, default_value='hello')
834 t = Tuple(Int, CBytes, default_value=(1,5))
834 t = Tuple(Int, CBytes, default_value=(1,5))
835
835
836
836
837 class MultiTupleTrait(HasTraits):
837 class MultiTupleTrait(HasTraits):
838
838
839 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
839 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
840
840
841 class TestMultiTuple(TraitTestBase):
841 class TestMultiTuple(TraitTestBase):
842
842
843 obj = MultiTupleTrait()
843 obj = MultiTupleTrait()
844
844
845 _default_value = (99,'bottles')
845 _default_value = (99,'bottles')
846 _good_values = [(1,'a'), (2,'b')]
846 _good_values = [(1,'a'), (2,'b')]
847 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
847 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
@@ -1,628 +1,628 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """Session object for building, serializing, sending, and receiving messages in
2 """Session object for building, serializing, sending, and receiving messages in
3 IPython. The Session object supports serialization, HMAC signatures, and
3 IPython. The Session object supports serialization, HMAC signatures, and
4 metadata on messages.
4 metadata on messages.
5
5
6 Also defined here are utilities for working with Sessions:
6 Also defined here are utilities for working with Sessions:
7 * A SessionFactory to be used as a base class for configurables that work with
7 * A SessionFactory to be used as a base class for configurables that work with
8 Sessions.
8 Sessions.
9 * A Message object for convenience that allows attribute-access to the msg dict.
9 * A Message object for convenience that allows attribute-access to the msg dict.
10
10
11 Authors:
11 Authors:
12
12
13 * Min RK
13 * Min RK
14 * Brian Granger
14 * Brian Granger
15 * Fernando Perez
15 * Fernando Perez
16 """
16 """
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Copyright (C) 2010-2011 The IPython Development Team
18 # Copyright (C) 2010-2011 The IPython Development Team
19 #
19 #
20 # Distributed under the terms of the BSD License. The full license is in
20 # Distributed under the terms of the BSD License. The full license is in
21 # the file COPYING, distributed as part of this software.
21 # the file COPYING, distributed as part of this software.
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25 # Imports
25 # Imports
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27
27
28 import hmac
28 import hmac
29 import logging
29 import logging
30 import os
30 import os
31 import pprint
31 import pprint
32 import uuid
32 import uuid
33 from datetime import datetime
33 from datetime import datetime
34
34
35 try:
35 try:
36 import cPickle
36 import cPickle
37 pickle = cPickle
37 pickle = cPickle
38 except:
38 except:
39 cPickle = None
39 cPickle = None
40 import pickle
40 import pickle
41
41
42 import zmq
42 import zmq
43 from zmq.utils import jsonapi
43 from zmq.utils import jsonapi
44 from zmq.eventloop.ioloop import IOLoop
44 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.zmqstream import ZMQStream
45 from zmq.eventloop.zmqstream import ZMQStream
46
46
47 from IPython.config.configurable import Configurable, LoggingConfigurable
47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.utils.importstring import import_item
48 from IPython.utils.importstring import import_item
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 DottedObjectName)
51 DottedObjectName)
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # utility functions
54 # utility functions
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 def squash_unicode(obj):
57 def squash_unicode(obj):
58 """coerce unicode back to bytestrings."""
58 """coerce unicode back to bytestrings."""
59 if isinstance(obj,dict):
59 if isinstance(obj,dict):
60 for key in obj.keys():
60 for key in obj.keys():
61 obj[key] = squash_unicode(obj[key])
61 obj[key] = squash_unicode(obj[key])
62 if isinstance(key, unicode):
62 if isinstance(key, unicode):
63 obj[squash_unicode(key)] = obj.pop(key)
63 obj[squash_unicode(key)] = obj.pop(key)
64 elif isinstance(obj, list):
64 elif isinstance(obj, list):
65 for i,v in enumerate(obj):
65 for i,v in enumerate(obj):
66 obj[i] = squash_unicode(v)
66 obj[i] = squash_unicode(v)
67 elif isinstance(obj, unicode):
67 elif isinstance(obj, unicode):
68 obj = obj.encode('utf8')
68 obj = obj.encode('utf8')
69 return obj
69 return obj
70
70
71 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
72 # globals and defaults
72 # globals and defaults
73 #-----------------------------------------------------------------------------
73 #-----------------------------------------------------------------------------
74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77
77
78 pickle_packer = lambda o: pickle.dumps(o,-1)
78 pickle_packer = lambda o: pickle.dumps(o,-1)
79 pickle_unpacker = pickle.loads
79 pickle_unpacker = pickle.loads
80
80
81 default_packer = json_packer
81 default_packer = json_packer
82 default_unpacker = json_unpacker
82 default_unpacker = json_unpacker
83
83
84
84
85 DELIM="<IDS|MSG>"
85 DELIM=b"<IDS|MSG>"
86
86
87 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
88 # Classes
88 # Classes
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90
90
91 class SessionFactory(LoggingConfigurable):
91 class SessionFactory(LoggingConfigurable):
92 """The Base class for configurables that have a Session, Context, logger,
92 """The Base class for configurables that have a Session, Context, logger,
93 and IOLoop.
93 and IOLoop.
94 """
94 """
95
95
96 logname = Unicode('')
96 logname = Unicode('')
97 def _logname_changed(self, name, old, new):
97 def _logname_changed(self, name, old, new):
98 self.log = logging.getLogger(new)
98 self.log = logging.getLogger(new)
99
99
100 # not configurable:
100 # not configurable:
101 context = Instance('zmq.Context')
101 context = Instance('zmq.Context')
102 def _context_default(self):
102 def _context_default(self):
103 return zmq.Context.instance()
103 return zmq.Context.instance()
104
104
105 session = Instance('IPython.zmq.session.Session')
105 session = Instance('IPython.zmq.session.Session')
106
106
107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 def _loop_default(self):
108 def _loop_default(self):
109 return IOLoop.instance()
109 return IOLoop.instance()
110
110
111 def __init__(self, **kwargs):
111 def __init__(self, **kwargs):
112 super(SessionFactory, self).__init__(**kwargs)
112 super(SessionFactory, self).__init__(**kwargs)
113
113
114 if self.session is None:
114 if self.session is None:
115 # construct the session
115 # construct the session
116 self.session = Session(**kwargs)
116 self.session = Session(**kwargs)
117
117
118
118
119 class Message(object):
119 class Message(object):
120 """A simple message object that maps dict keys to attributes.
120 """A simple message object that maps dict keys to attributes.
121
121
122 A Message can be created from a dict and a dict from a Message instance
122 A Message can be created from a dict and a dict from a Message instance
123 simply by calling dict(msg_obj)."""
123 simply by calling dict(msg_obj)."""
124
124
125 def __init__(self, msg_dict):
125 def __init__(self, msg_dict):
126 dct = self.__dict__
126 dct = self.__dict__
127 for k, v in dict(msg_dict).iteritems():
127 for k, v in dict(msg_dict).iteritems():
128 if isinstance(v, dict):
128 if isinstance(v, dict):
129 v = Message(v)
129 v = Message(v)
130 dct[k] = v
130 dct[k] = v
131
131
132 # Having this iterator lets dict(msg_obj) work out of the box.
132 # Having this iterator lets dict(msg_obj) work out of the box.
133 def __iter__(self):
133 def __iter__(self):
134 return iter(self.__dict__.iteritems())
134 return iter(self.__dict__.iteritems())
135
135
136 def __repr__(self):
136 def __repr__(self):
137 return repr(self.__dict__)
137 return repr(self.__dict__)
138
138
139 def __str__(self):
139 def __str__(self):
140 return pprint.pformat(self.__dict__)
140 return pprint.pformat(self.__dict__)
141
141
142 def __contains__(self, k):
142 def __contains__(self, k):
143 return k in self.__dict__
143 return k in self.__dict__
144
144
145 def __getitem__(self, k):
145 def __getitem__(self, k):
146 return self.__dict__[k]
146 return self.__dict__[k]
147
147
148
148
149 def msg_header(msg_id, msg_type, username, session):
149 def msg_header(msg_id, msg_type, username, session):
150 date = datetime.now()
150 date = datetime.now()
151 return locals()
151 return locals()
152
152
153 def extract_header(msg_or_header):
153 def extract_header(msg_or_header):
154 """Given a message or header, return the header."""
154 """Given a message or header, return the header."""
155 if not msg_or_header:
155 if not msg_or_header:
156 return {}
156 return {}
157 try:
157 try:
158 # See if msg_or_header is the entire message.
158 # See if msg_or_header is the entire message.
159 h = msg_or_header['header']
159 h = msg_or_header['header']
160 except KeyError:
160 except KeyError:
161 try:
161 try:
162 # See if msg_or_header is just the header
162 # See if msg_or_header is just the header
163 h = msg_or_header['msg_id']
163 h = msg_or_header['msg_id']
164 except KeyError:
164 except KeyError:
165 raise
165 raise
166 else:
166 else:
167 h = msg_or_header
167 h = msg_or_header
168 if not isinstance(h, dict):
168 if not isinstance(h, dict):
169 h = dict(h)
169 h = dict(h)
170 return h
170 return h
171
171
172 class Session(Configurable):
172 class Session(Configurable):
173 """Object for handling serialization and sending of messages.
173 """Object for handling serialization and sending of messages.
174
174
175 The Session object handles building messages and sending them
175 The Session object handles building messages and sending them
176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 other over the network via Session objects, and only need to work with the
177 other over the network via Session objects, and only need to work with the
178 dict-based IPython message spec. The Session will handle
178 dict-based IPython message spec. The Session will handle
179 serialization/deserialization, security, and metadata.
179 serialization/deserialization, security, and metadata.
180
180
181 Sessions support configurable serialiization via packer/unpacker traits,
181 Sessions support configurable serialiization via packer/unpacker traits,
182 and signing with HMAC digests via the key/keyfile traits.
182 and signing with HMAC digests via the key/keyfile traits.
183
183
184 Parameters
184 Parameters
185 ----------
185 ----------
186
186
187 debug : bool
187 debug : bool
188 whether to trigger extra debugging statements
188 whether to trigger extra debugging statements
189 packer/unpacker : str : 'json', 'pickle' or import_string
189 packer/unpacker : str : 'json', 'pickle' or import_string
190 importstrings for methods to serialize message parts. If just
190 importstrings for methods to serialize message parts. If just
191 'json' or 'pickle', predefined JSON and pickle packers will be used.
191 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 Otherwise, the entire importstring must be used.
192 Otherwise, the entire importstring must be used.
193
193
194 The functions must accept at least valid JSON input, and output *bytes*.
194 The functions must accept at least valid JSON input, and output *bytes*.
195
195
196 For example, to use msgpack:
196 For example, to use msgpack:
197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 pack/unpack : callables
198 pack/unpack : callables
199 You can also set the pack/unpack callables for serialization directly.
199 You can also set the pack/unpack callables for serialization directly.
200 session : bytes
200 session : bytes
201 the ID of this Session object. The default is to generate a new UUID.
201 the ID of this Session object. The default is to generate a new UUID.
202 username : unicode
202 username : unicode
203 username added to message headers. The default is to ask the OS.
203 username added to message headers. The default is to ask the OS.
204 key : bytes
204 key : bytes
205 The key used to initialize an HMAC signature. If unset, messages
205 The key used to initialize an HMAC signature. If unset, messages
206 will not be signed or checked.
206 will not be signed or checked.
207 keyfile : filepath
207 keyfile : filepath
208 The file containing a key. If this is set, `key` will be initialized
208 The file containing a key. If this is set, `key` will be initialized
209 to the contents of the file.
209 to the contents of the file.
210
210
211 """
211 """
212
212
213 debug=Bool(False, config=True, help="""Debug output in the Session""")
213 debug=Bool(False, config=True, help="""Debug output in the Session""")
214
214
215 packer = DottedObjectName('json',config=True,
215 packer = DottedObjectName('json',config=True,
216 help="""The name of the packer for serializing messages.
216 help="""The name of the packer for serializing messages.
217 Should be one of 'json', 'pickle', or an import name
217 Should be one of 'json', 'pickle', or an import name
218 for a custom callable serializer.""")
218 for a custom callable serializer.""")
219 def _packer_changed(self, name, old, new):
219 def _packer_changed(self, name, old, new):
220 if new.lower() == 'json':
220 if new.lower() == 'json':
221 self.pack = json_packer
221 self.pack = json_packer
222 self.unpack = json_unpacker
222 self.unpack = json_unpacker
223 elif new.lower() == 'pickle':
223 elif new.lower() == 'pickle':
224 self.pack = pickle_packer
224 self.pack = pickle_packer
225 self.unpack = pickle_unpacker
225 self.unpack = pickle_unpacker
226 else:
226 else:
227 self.pack = import_item(str(new))
227 self.pack = import_item(str(new))
228
228
229 unpacker = DottedObjectName('json', config=True,
229 unpacker = DottedObjectName('json', config=True,
230 help="""The name of the unpacker for unserializing messages.
230 help="""The name of the unpacker for unserializing messages.
231 Only used with custom functions for `packer`.""")
231 Only used with custom functions for `packer`.""")
232 def _unpacker_changed(self, name, old, new):
232 def _unpacker_changed(self, name, old, new):
233 if new.lower() == 'json':
233 if new.lower() == 'json':
234 self.pack = json_packer
234 self.pack = json_packer
235 self.unpack = json_unpacker
235 self.unpack = json_unpacker
236 elif new.lower() == 'pickle':
236 elif new.lower() == 'pickle':
237 self.pack = pickle_packer
237 self.pack = pickle_packer
238 self.unpack = pickle_unpacker
238 self.unpack = pickle_unpacker
239 else:
239 else:
240 self.unpack = import_item(str(new))
240 self.unpack = import_item(str(new))
241
241
242 session = CBytes(b'', config=True,
242 session = CBytes(b'', config=True,
243 help="""The UUID identifying this session.""")
243 help="""The UUID identifying this session.""")
244 def _session_default(self):
244 def _session_default(self):
245 return bytes(uuid.uuid4())
245 return bytes(uuid.uuid4())
246
246
247 username = Unicode(os.environ.get('USER','username'), config=True,
247 username = Unicode(os.environ.get('USER','username'), config=True,
248 help="""Username for the Session. Default is your system username.""")
248 help="""Username for the Session. Default is your system username.""")
249
249
250 # message signature related traits:
250 # message signature related traits:
251 key = CBytes(b'', config=True,
251 key = CBytes(b'', config=True,
252 help="""execution key, for extra authentication.""")
252 help="""execution key, for extra authentication.""")
253 def _key_changed(self, name, old, new):
253 def _key_changed(self, name, old, new):
254 if new:
254 if new:
255 self.auth = hmac.HMAC(new)
255 self.auth = hmac.HMAC(new)
256 else:
256 else:
257 self.auth = None
257 self.auth = None
258 auth = Instance(hmac.HMAC)
258 auth = Instance(hmac.HMAC)
259 digest_history = Set()
259 digest_history = Set()
260
260
# Optional path whose contents initialize `key` (see _keyfile_changed).
keyfile = Unicode('', config=True,
    help="""path to file containing execution key.""")
263 def _keyfile_changed(self, name, old, new):
263 def _keyfile_changed(self, name, old, new):
264 with open(new, 'rb') as f:
264 with open(new, 'rb') as f:
265 self.key = f.read().strip()
265 self.key = f.read().strip()
266
266
pack = Any(default_packer)  # the callable that serializes message parts
268 def _pack_changed(self, name, old, new):
268 def _pack_changed(self, name, old, new):
269 if not callable(new):
269 if not callable(new):
270 raise TypeError("packer must be callable, not %s"%type(new))
270 raise TypeError("packer must be callable, not %s"%type(new))
271
271
unpack = Any(default_unpacker)  # the callable that deserializes message parts
273 def _unpack_changed(self, name, old, new):
273 def _unpack_changed(self, name, old, new):
274 # unpacker is not checked - it is assumed to be
274 # unpacker is not checked - it is assumed to be
275 if not callable(new):
275 if not callable(new):
276 raise TypeError("unpacker must be callable, not %s"%type(new))
276 raise TypeError("unpacker must be callable, not %s"%type(new))
277
277
def __init__(self, **kwargs):
    """Create a Session object.

    Parameters
    ----------
    debug : bool
        whether to trigger extra debugging statements
    packer/unpacker : str : 'json', 'pickle' or import_string
        import strings for the serialization methods. The shortcuts
        'json' and 'pickle' select the predefined packer pairs; any other
        value must be a full import string (e.g. for msgpack:
        packer='msgpack.packb', unpacker='msgpack.unpackb'). The
        functions must accept at least valid JSON input and output
        *bytes*.
    pack/unpack : callables
        set the serialization callables directly instead of by name.
    session : bytes
        the ID of this Session object; defaults to a freshly generated
        UUID.
    username : unicode
        username added to message headers; defaults to the OS username.
    key : bytes
        key used to initialize an HMAC signature. If unset, messages are
        neither signed nor checked.
    keyfile : filepath
        a file containing a key; when set, `key` is initialized from its
        contents.
    """
    super(Session, self).__init__(**kwargs)
    # validate/adapt the configured pack/unpack pair, then cache the
    # serialized empty dict used for messages with no content
    self._check_packers()
    self.none = self.pack({})
314
314
@property
def msg_id(self):
    """A fresh uuid4 string on every access — message IDs are never reused."""
    return str(uuid.uuid4())
319
319
320 def _check_packers(self):
320 def _check_packers(self):
321 """check packers for binary data and datetime support."""
321 """check packers for binary data and datetime support."""
322 pack = self.pack
322 pack = self.pack
323 unpack = self.unpack
323 unpack = self.unpack
324
324
325 # check simple serialization
325 # check simple serialization
326 msg = dict(a=[1,'hi'])
326 msg = dict(a=[1,'hi'])
327 try:
327 try:
328 packed = pack(msg)
328 packed = pack(msg)
329 except Exception:
329 except Exception:
330 raise ValueError("packer could not serialize a simple message")
330 raise ValueError("packer could not serialize a simple message")
331
331
332 # ensure packed message is bytes
332 # ensure packed message is bytes
333 if not isinstance(packed, bytes):
333 if not isinstance(packed, bytes):
334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335
335
336 # check that unpack is pack's inverse
336 # check that unpack is pack's inverse
337 try:
337 try:
338 unpacked = unpack(packed)
338 unpacked = unpack(packed)
339 except Exception:
339 except Exception:
340 raise ValueError("unpacker could not handle the packer's output")
340 raise ValueError("unpacker could not handle the packer's output")
341
341
342 # check datetime support
342 # check datetime support
343 msg = dict(t=datetime.now())
343 msg = dict(t=datetime.now())
344 try:
344 try:
345 unpacked = unpack(pack(msg))
345 unpacked = unpack(pack(msg))
346 except Exception:
346 except Exception:
347 self.pack = lambda o: pack(squash_dates(o))
347 self.pack = lambda o: pack(squash_dates(o))
348 self.unpack = lambda s: extract_dates(unpack(s))
348 self.unpack = lambda s: extract_dates(unpack(s))
349
349
def msg_header(self, msg_type):
    # Build a header via the module-level msg_header helper, stamped with
    # this session's identity (fresh msg_id, username, session uuid).
    return msg_header(self.msg_id, msg_type, self.username, self.session)
352
352
def msg(self, msg_type, content=None, parent=None, subheader=None):
    """Assemble a standard message dict (header, parent_header, content)."""
    header = self.msg_header(msg_type)
    # capture msg_id before subheader can touch the header
    msg_id = header['msg_id']
    header.update({} if subheader is None else subheader)
    return {
        'header': header,
        'msg_id': msg_id,
        'parent_header': {} if parent is None else extract_header(parent),
        'msg_type': msg_type,
        'content': {} if content is None else content,
    }
363
363
def sign(self, msg):
    """Sign a message with HMAC digest. If no auth, return b''."""
    if self.auth is None:
        return b''
    # never mutate the stored HMAC: work on a copy
    digest = self.auth.copy()
    for part in msg:
        digest.update(part)
    return digest.hexdigest()
372
372
def serialize(self, msg, ident=None):
    """Serialize the message components to bytes.

    Returns
    -------
    list of bytes objects
        [ident..., DELIM, signature, header, parent_header, content]
    """
    # normalize content to packed bytes
    content = msg.get('content', {})
    if content is None:
        content = self.none          # cached pre-packed empty dict
    elif isinstance(content, dict):
        content = self.pack(content)
    elif isinstance(content, bytes):
        pass                         # already packed (relayed message)
    elif isinstance(content, unicode):
        # JSON packers often emit unicode; the wire wants bytes
        content = content.encode('utf8')
    else:
        raise TypeError("Content incorrect type: %s"%type(content))

    payload = [self.pack(msg['header']),
               self.pack(msg['parent_header']),
               content]

    to_send = []
    if isinstance(ident, list):
        to_send.extend(ident)        # accept a list of idents
    elif ident is not None:
        to_send.append(ident)
    to_send.append(DELIM)
    # the signature covers header, parent_header and content
    to_send.append(self.sign(payload))
    to_send.extend(payload)

    return to_send
416
416
def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
        buffers=None, subheader=None, track=False):
    """Build and send a message via stream or socket.

    Parameters
    ----------
    stream : zmq.Socket or ZMQStream
        the socket-like object used to send the data
    msg_or_type : str or Message/dict
        a msg_type string, or an already-built message (when a message is
        being sent more than once)
    content : dict or None
        the content of the message (ignored if msg_or_type is a message)
    parent : Message or dict or None
        the parent or parent header describing the parent of this message
    ident : bytes or list of bytes
        the zmq.IDENTITY routing path
    subheader : dict or None
        extra header keys for this message's header
    buffers : list or None
        already-serialized buffers appended after the message
    track : bool
        whether to track. Only for use with Sockets, because ZMQStream
        objects cannot track messages.

    Returns
    -------
    msg : message dict
        the constructed message, with msg['tracker'] set to the zmq
        tracker (or the last send's return value)
    """
    if not isinstance(stream, (zmq.Socket, ZMQStream)):
        raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
    elif track and isinstance(stream, ZMQStream):
        raise TypeError("ZMQStream cannot track messages")

    if isinstance(msg_or_type, (Message, dict)):
        # already a message: reuse it rather than building a fresh one
        msg = msg_or_type
    else:
        msg = self.msg(msg_or_type, content, parent, subheader)

    buffers = [] if buffers is None else buffers
    to_send = self.serialize(msg, ident)

    if buffers:
        # more frames follow: only the FINAL buffer send may be tracked
        flag = zmq.SNDMORE
        _track = False
    else:
        flag = 0
        _track = track

    if track:
        tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
    else:
        tracker = stream.send_multipart(to_send, flag, copy=False)

    # all but the last buffer keep SNDMORE set
    for b in buffers[:-1]:
        stream.send(b, flag, copy=False)
    if buffers:
        if track:
            tracker = stream.send(buffers[-1], copy=False, track=track)
        else:
            tracker = stream.send(buffers[-1], copy=False)

    if self.debug:
        pprint.pprint(msg)
        pprint.pprint(to_send)
        pprint.pprint(buffers)

    msg['tracker'] = tracker

    return msg
496
496
def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
    """Send a raw (pre-serialized) message via ident path.

    Parameters
    ----------
    stream : socket-like object
        the stream or socket to send on
    msg : list of sendable buffers
        the already-serialized message frames
    flags : int
        zmq send flags
    copy : bool
        whether zmq should copy the frames
    ident : bytes or list of bytes or None
        the zmq.IDENTITY routing prefix
    """
    to_send = []
    if isinstance(ident, bytes):
        ident = [ident]
    if ident is not None:
        to_send.extend(ident)

    to_send.append(DELIM)
    to_send.append(self.sign(msg))
    to_send.extend(msg)
    # BUG FIX: previously sent `msg` here, silently dropping the idents,
    # DELIM and signature that were just assembled into `to_send`.
    stream.send_multipart(to_send, flags, copy=copy)
513
513
def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
    """Receive and unpack a message.

    Returns (idents, msg), or (None, None) when nothing is waiting.
    """
    if isinstance(socket, ZMQStream):
        socket = socket.socket
    try:
        msg = socket.recv_multipart(mode)
    except zmq.ZMQError as e:
        if e.errno != zmq.EAGAIN:
            raise
        # EAGAIN just means no message is waiting; recv_multipart itself
        # never returns None, so (None, None) is unambiguous.
        return None,None
    # split multipart message into identity list and message dict
    # invalid large messages can cause very expensive string comparisons
    idents, msg = self.feed_identities(msg, copy)
    try:
        return idents, self.unpack_message(msg, content=content, copy=copy)
    except Exception as e:
        print (idents, msg)
        # TODO: handle it
        raise e
537
537
def feed_identities(self, msg, copy=True):
    """feed until DELIM is reached, then return the prefix as idents and
    remainder as msg. This is easily broken by setting an IDENT to DELIM,
    but that would be silly.

    Parameters
    ----------
    msg : a list of Message or bytes objects
        the message to be split
    copy : bool
        flag determining whether the arguments are bytes or Messages

    Returns
    -------
    (idents, msg) : two lists
        idents is always a list of bytes (the identity prefix); msg is a
        list of bytes or Messages, unchanged from input, and should be
        unpackable via self.unpack_message.
    """
    if copy:
        # raw bytes frames: locate the delimiter directly
        i = msg.index(DELIM)
        return msg[:i], msg[i+1:]
    # zero-copy Frame objects: compare the .bytes payloads
    for i, frame in enumerate(msg):
        if frame.bytes == DELIM:
            break
    else:
        raise ValueError("DELIM not in msg")
    return [frame.bytes for frame in msg[:i]], msg[i+1:]
570
570
def unpack_message(self, msg, content=True, copy=True):
    """Return a message object from the format sent by self.send.

    Parameters
    ----------
    msg : list
        message frames: [signature, header, parent_header, content, buffers...]
    content : bool (True)
        whether to unpack the content dict (True),
        or leave it serialized (False)
    copy : bool (True)
        whether the frames are bytes (True),
        or non-copying Message/Frame objects (False)

    Raises
    ------
    TypeError
        if the message has fewer than 4 frames
    ValueError
        on a duplicate or invalid signature (when auth is enabled)
    """
    minlen = 4
    # FIX: validate length up front. Previously this check ran AFTER the
    # frames were indexed, so a short message in the copy=False path
    # raised IndexError instead of the intended TypeError.
    if not len(msg) >= minlen:
        raise TypeError("malformed message, must have at least %i elements"%minlen)
    message = {}
    if not copy:
        # extract raw bytes from the non-copying frame objects
        for i in range(minlen):
            msg[i] = msg[i].bytes
    if self.auth is not None:
        signature = msg[0]
        # replay protection: each signature may be seen only once
        if signature in self.digest_history:
            raise ValueError("Duplicate Signature: %r"%signature)
        self.digest_history.add(signature)
        check = self.sign(msg[1:4])
        if not signature == check:
            raise ValueError("Invalid Signature: %r"%signature)
    message['header'] = self.unpack(msg[1])
    message['msg_type'] = message['header']['msg_type']
    message['parent_header'] = self.unpack(msg[2])
    if content:
        message['content'] = self.unpack(msg[3])
    else:
        # leave content serialized (e.g. for relaying)
        message['content'] = msg[3]

    message['buffers'] = msg[4:]
    return message
612
612
def test_msg2obj():
    """Smoke-test the attribute/dict duality of Message objects."""
    source = dict(x=1)
    m = Message(source)
    assert m.x == source['x']

    # nested dicts become nested Messages
    source['y'] = dict(z=1)
    m = Message(source)
    assert m.y.z == source['y']['z']

    k1, k2 = 'y', 'z'
    assert m[k1][k2] == source[k1][k2]

    # round-trip back to a plain dict
    roundtrip = dict(m)
    assert source['x'] == roundtrip['x']
    assert source['y']['z'] == roundtrip['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now