cleanup per review...
MinRK
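The substantive change in this commit is a rename: the ensure_bytes helper from IPython.parallel.util becomes asbytes, and each call site in the controller and engine apps is updated (the client hunk below only shows a comment typo fix in the portion visible here). Judging from the call sites in the diff, where exec_key strings and uuid4() strings pass through it and b'' is assigned directly when security is off, the helper presumably coerces unicode to ascii bytes. A minimal sketch of that assumed behavior, not the actual IPython.parallel.util implementation:

def asbytes(s):
    # Sketch of assumed behavior (Python 2): encode unicode to ascii
    # bytes, pass bytes through unchanged.
    if isinstance(s, unicode):
        s = s.encode('ascii')
    return s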
@@ -1,422 +1,422 @@
#!/usr/bin/env python
# encoding: utf-8
"""
The IPython controller application.

Authors:

* Brian Granger
* MinRK

"""

#-----------------------------------------------------------------------------
# Copyright (C) 2008-2011 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

from __future__ import with_statement

import os
import socket
import stat
import sys
import uuid

from multiprocessing import Process

import zmq
from zmq.devices import ProcessMonitoredQueue
from zmq.log.handlers import PUBHandler
from zmq.utils import jsonapi as json

from IPython.config.application import boolean_flag
from IPython.core.profiledir import ProfileDir

from IPython.parallel.apps.baseapp import (
    BaseParallelApplication,
    base_aliases,
    base_flags,
)
from IPython.utils.importstring import import_item
from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict

# from IPython.parallel.controller.controller import ControllerFactory
from IPython.zmq.session import Session
from IPython.parallel.controller.heartmonitor import HeartMonitor
from IPython.parallel.controller.hub import HubFactory
from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
from IPython.parallel.controller.sqlitedb import SQLiteDB

-from IPython.parallel.util import signal_children, split_url, ensure_bytes
+from IPython.parallel.util import signal_children, split_url, asbytes

# conditional import of MongoDB backend class

try:
    from IPython.parallel.controller.mongodb import MongoDB
except ImportError:
    maybe_mongo = []
else:
    maybe_mongo = [MongoDB]


#-----------------------------------------------------------------------------
# Module level variables
#-----------------------------------------------------------------------------


#: The default config file name for this application
default_config_file_name = u'ipcontroller_config.py'


_description = """Start the IPython controller for parallel computing.

The IPython controller provides a gateway between the IPython engines and
clients. The controller needs to be started before the engines and can be
configured using command line options or using a cluster directory. Cluster
directories contain config, log and security files and are usually located in
your ipython directory and named as "profile_name". See the `profile`
and `profile_dir` options for details.
"""




#-----------------------------------------------------------------------------
# The main application
#-----------------------------------------------------------------------------
flags = {}
flags.update(base_flags)
flags.update({
    'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
        'Use threads instead of processes for the schedulers'),
    'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
        'use the SQLiteDB backend'),
    'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
        'use the MongoDB backend'),
    'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
        'use the in-memory DictDB backend'),
    'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
        'reuse existing json connection files')
})

flags.update(boolean_flag('secure', 'IPControllerApp.secure',
    "Use HMAC digests for authentication of messages.",
    "Don't authenticate messages."
))
aliases = dict(
    reuse_files = 'IPControllerApp.reuse_files',
    secure = 'IPControllerApp.secure',
    ssh = 'IPControllerApp.ssh_server',
    use_threads = 'IPControllerApp.use_threads',
    location = 'IPControllerApp.location',

    ident = 'Session.session',
    user = 'Session.username',
    exec_key = 'Session.keyfile',

    url = 'HubFactory.url',
    ip = 'HubFactory.ip',
    transport = 'HubFactory.transport',
    port = 'HubFactory.regport',

    ping = 'HeartMonitor.period',

    scheme = 'TaskScheduler.scheme_name',
    hwm = 'TaskScheduler.hwm',
)
aliases.update(base_aliases)

class IPControllerApp(BaseParallelApplication):

    name = u'ipcontroller'
    description = _description
    config_file_name = Unicode(default_config_file_name)
    classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo

    # change default to True
    auto_create = Bool(True, config=True,
        help="""Whether to create profile dir if it doesn't exist.""")

    reuse_files = Bool(False, config=True,
        help='Whether to reuse existing json connection files.'
    )
    secure = Bool(True, config=True,
        help='Whether to use HMAC digests for extra message authentication.'
    )
    ssh_server = Unicode(u'', config=True,
        help="""ssh url for clients to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    location = Unicode(u'', config=True,
        help="""The external IP or domain name of the Controller, used for disambiguating
        engine and client connections.""",
    )
    import_statements = List([], config=True,
        help="import statements to be run at startup. Necessary in some environments"
    )

    use_threads = Bool(False, config=True,
        help='Use threads instead of processes for the schedulers',
    )

    # internal
    children = List()
    mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')

    def _use_threads_changed(self, name, old, new):
        self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')

    aliases = Dict(aliases)
    flags = Dict(flags)


    def save_connection_dict(self, fname, cdict):
        """save a connection dict to json file."""
        c = self.config
        url = cdict['url']
        location = cdict['location']
        if not location:
            try:
                proto,ip,port = split_url(url)
            except AssertionError:
                pass
            else:
                location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
            cdict['location'] = location
        fname = os.path.join(self.profile_dir.security_dir, fname)
        with open(fname, 'wb') as f:
            f.write(json.dumps(cdict, indent=2))
        os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)

    def load_config_from_json(self):
        """load config from existing json connector files."""
        c = self.config
        # load from engine config
        with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
            cfg = json.loads(f.read())
-        key = c.Session.key = ensure_bytes(cfg['exec_key'])
+        key = c.Session.key = asbytes(cfg['exec_key'])
        xport,addr = cfg['url'].split('://')
        c.HubFactory.engine_transport = xport
        ip,ports = addr.split(':')
        c.HubFactory.engine_ip = ip
        c.HubFactory.regport = int(ports)
        self.location = cfg['location']
        # load client config
        with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
            cfg = json.loads(f.read())
        assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
        xport,addr = cfg['url'].split('://')
        c.HubFactory.client_transport = xport
        ip,ports = addr.split(':')
        c.HubFactory.client_ip = ip
        self.ssh_server = cfg['ssh']
        assert int(ports) == c.HubFactory.regport, "regport mismatch"

    def init_hub(self):
        c = self.config

        self.do_import_statements()
        reusing = self.reuse_files
        if reusing:
            try:
                self.load_config_from_json()
            except (AssertionError,IOError):
                reusing=False
        # check again, because reusing may have failed:
        if reusing:
            pass
        elif self.secure:
            key = str(uuid.uuid4())
            # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
            # with open(keyfile, 'w') as f:
            #     f.write(key)
            # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
-            c.Session.key = ensure_bytes(key)
+            c.Session.key = asbytes(key)
        else:
            key = c.Session.key = b''

        try:
            self.factory = HubFactory(config=c, log=self.log)
            # self.start_logging()
            self.factory.init_hub()
        except:
            self.log.error("Couldn't construct the Controller", exc_info=True)
            self.exit(1)

        if not reusing:
            # save to new json config files
            f = self.factory
            cdict = {'exec_key' : key,
                     'ssh' : self.ssh_server,
                     'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
                     'location' : self.location
                     }
            self.save_connection_dict('ipcontroller-client.json', cdict)
            edict = cdict
            edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
            self.save_connection_dict('ipcontroller-engine.json', edict)

    #
    def init_schedulers(self):
        children = self.children
        mq = import_item(str(self.mq_class))

        hub = self.factory
        # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
        # IOPub relay (in a Process)
        q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
        q.bind_in(hub.client_info['iopub'])
        q.bind_out(hub.engine_info['iopub'])
        q.setsockopt_out(zmq.SUBSCRIBE, b'')
        q.connect_mon(hub.monitor_url)
        q.daemon=True
        children.append(q)

        # Multiplexer Queue (in a Process)
        q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'in', b'out')
        q.bind_in(hub.client_info['mux'])
        q.setsockopt_in(zmq.IDENTITY, b'mux')
        q.bind_out(hub.engine_info['mux'])
        q.connect_mon(hub.monitor_url)
        q.daemon=True
        children.append(q)

        # Control Queue (in a Process)
        q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'incontrol', b'outcontrol')
        q.bind_in(hub.client_info['control'])
        q.setsockopt_in(zmq.IDENTITY, b'control')
        q.bind_out(hub.engine_info['control'])
        q.connect_mon(hub.monitor_url)
        q.daemon=True
        children.append(q)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            scheme = TaskScheduler.scheme_name.get_default_value()
        # Task Queue (in a Process)
        if scheme == 'pure':
            self.log.warn("task::using pure XREQ Task scheduler")
            q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, b'intask', b'outtask')
            # q.setsockopt_out(zmq.HWM, hub.hwm)
            q.bind_in(hub.client_info['task'][1])
            q.setsockopt_in(zmq.IDENTITY, b'task')
            q.bind_out(hub.engine_info['task'])
            q.connect_mon(hub.monitor_url)
            q.daemon=True
            children.append(q)
        elif scheme == 'none':
            self.log.warn("task::using no Task scheduler")

        else:
            self.log.info("task::using Python %s Task scheduler"%scheme)
            sargs = (hub.client_info['task'][1], hub.engine_info['task'],
                     hub.monitor_url, hub.client_info['notification'])
            kwargs = dict(logname='scheduler', loglevel=self.log_level,
                          log_url = self.log_url, config=dict(self.config))
            if 'Process' in self.mq_class:
                # run the Python scheduler in a Process
                q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
                q.daemon=True
                children.append(q)
            else:
                # single-threaded Controller
                kwargs['in_thread'] = True
                launch_scheduler(*sargs, **kwargs)


    def save_urls(self):
        """save the registration urls to files."""
        c = self.config

        sec_dir = self.profile_dir.security_dir
        cf = self.factory

        with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
            f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))

        with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
            f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))


    def do_import_statements(self):
        statements = self.import_statements
        for s in statements:
            try:
                self.log.msg("Executing statement: '%s'" % s)
                exec s in globals(), locals()
            except:
                self.log.msg("Error running statement: %s" % s)

    def forward_logging(self):
        if self.log_url:
            self.log.info("Forwarding logging to %s"%self.log_url)
            context = zmq.Context.instance()
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            handler = PUBHandler(lsock)
            self.log.removeHandler(self._log_handler)
            handler.root_topic = 'controller'
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)
            self._log_handler = handler
    # #

    def initialize(self, argv=None):
        super(IPControllerApp, self).initialize(argv)
        self.forward_logging()
        self.init_hub()
        self.init_schedulers()

    def start(self):
        # Start the subprocesses:
        self.factory.start()
        child_procs = []
        for child in self.children:
            child.start()
            if isinstance(child, ProcessMonitoredQueue):
                child_procs.append(child.launcher)
            elif isinstance(child, Process):
                child_procs.append(child)
        if child_procs:
            signal_children(child_procs)

        self.write_pid_file(overwrite=True)

        try:
            self.factory.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Interrupted, Exiting...\n")



def launch_new_instance():
    """Create and run the IPython controller"""
    if sys.platform == 'win32':
        # make sure we don't get called from a multiprocessing subprocess
        # this can result in infinite Controllers being started on Windows
        # which doesn't have a proper fork, so multiprocessing is wonky

        # this only comes up when IPython has been installed using vanilla
        # setuptools, and *not* distribute.
        import multiprocessing
        p = multiprocessing.current_process()
        # the main process has name 'MainProcess'
        # subprocesses will have names like 'Process-1'
        if p.name != 'MainProcess':
            # we are a subprocess, don't start another Controller!
            return
    app = IPControllerApp.instance()
    app.initialize()
    app.start()


if __name__ == '__main__':
    launch_new_instance()
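For reference, init_hub() above assembles cdict and save_connection_dict() serializes it with json.dumps(cdict, indent=2) into the profile's security directory, chmod'ed to owner read/write only. A hypothetical ipcontroller-client.json payload, with every value made up for illustration:

import json

# all values below are illustrative; a running controller writes real ones
cdict = {
    'exec_key' : '0f01e402-4bbf-4cfd-b545-0b0a72c1f4a8',  # uuid4 when secure=True, '' otherwise
    'ssh' : '',                                           # ssh_server; empty unless tunneling
    'url' : 'tcp://127.0.0.1:55555',                      # client-facing registration endpoint
    'location' : '10.0.0.5',                              # external IP, filled in if unset
}
print json.dumps(cdict, indent=2)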
@@ -1,301 +1,301 @@
#!/usr/bin/env python
# encoding: utf-8
"""
The IPython engine application

Authors:

* Brian Granger
* MinRK

"""

#-----------------------------------------------------------------------------
# Copyright (C) 2008-2011 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-----------------------------------------------------------------------------

#-----------------------------------------------------------------------------
# Imports
#-----------------------------------------------------------------------------

import json
import os
import sys
import time

import zmq
from zmq.eventloop import ioloop

from IPython.core.profiledir import ProfileDir
from IPython.parallel.apps.baseapp import (
    BaseParallelApplication,
    base_aliases,
    base_flags,
)
from IPython.zmq.log import EnginePUBHandler

from IPython.config.configurable import Configurable
from IPython.zmq.session import Session
from IPython.parallel.engine.engine import EngineFactory
from IPython.parallel.engine.streamkernel import Kernel
-from IPython.parallel.util import disambiguate_url, ensure_bytes
+from IPython.parallel.util import disambiguate_url, asbytes

from IPython.utils.importstring import import_item
from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float


#-----------------------------------------------------------------------------
# Module level variables
#-----------------------------------------------------------------------------

#: The default config file name for this application
default_config_file_name = u'ipengine_config.py'

_description = """Start an IPython engine for parallel computing.

IPython engines run in parallel and perform computations on behalf of a client
and controller. A controller needs to be started before the engines. The
engine can be configured using command line options or using a cluster
directory. Cluster directories contain config, log and security files and are
usually located in your ipython directory and named as "profile_name".
See the `profile` and `profile_dir` options for details.
"""


#-----------------------------------------------------------------------------
# MPI configuration
#-----------------------------------------------------------------------------

mpi4py_init = """from mpi4py import MPI as mpi
mpi.size = mpi.COMM_WORLD.Get_size()
mpi.rank = mpi.COMM_WORLD.Get_rank()
"""


pytrilinos_init = """from PyTrilinos import Epetra
class SimpleStruct:
    pass
mpi = SimpleStruct()
mpi.rank = 0
mpi.size = 0
"""

class MPI(Configurable):
    """Configurable for MPI initialization"""
    use = Unicode('', config=True,
        help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
    )

    def _on_use_changed(self, old, new):
        # load default init script if it's not set
        if not self.init_script:
            self.init_script = self.default_inits.get(new, '')

    init_script = Unicode('', config=True,
        help="Initialization code for MPI")

    default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
        config=True)


#-----------------------------------------------------------------------------
# Main application
#-----------------------------------------------------------------------------
aliases = dict(
    file = 'IPEngineApp.url_file',
    c = 'IPEngineApp.startup_command',
    s = 'IPEngineApp.startup_script',

    ident = 'Session.session',
    user = 'Session.username',
    exec_key = 'Session.keyfile',

    url = 'EngineFactory.url',
    ip = 'EngineFactory.ip',
    transport = 'EngineFactory.transport',
    port = 'EngineFactory.regport',
    location = 'EngineFactory.location',

    timeout = 'EngineFactory.timeout',

    mpi = 'MPI.use',

)
aliases.update(base_aliases)

class IPEngineApp(BaseParallelApplication):

    name = Unicode(u'ipengine')
    description = Unicode(_description)
    config_file_name = Unicode(default_config_file_name)
    classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])

    startup_script = Unicode(u'', config=True,
        help='specify a script to be run at startup')
    startup_command = Unicode('', config=True,
        help='specify a command to be run at startup')

    url_file = Unicode(u'', config=True,
        help="""The full location of the file containing the connection information for
        the controller. If this is not given, the file must be in the
        security directory of the cluster directory. This location is
        resolved using the `profile` or `profile_dir` options.""",
    )
    wait_for_url_file = Float(5, config=True,
        help="""The maximum number of seconds to wait for url_file to exist.
        This is useful for batch-systems and shared-filesystems where the
        controller and engine are started at the same time and it
        may take a moment for the controller to write the connector files.""")

    url_file_name = Unicode(u'ipcontroller-engine.json')
    log_url = Unicode('', config=True,
        help="""The URL for the iploggerapp instance, for forwarding
        logging to a central location.""")

    aliases = Dict(aliases)

    # def find_key_file(self):
    #     """Set the key file.
    #
    #     Here we don't try to actually see if it exists or is valid, as that
    #     is handled by the connection logic.
    #     """
    #     config = self.master_config
    #     # Find the actual controller key file
    #     if not config.Global.key_file:
    #         try_this = os.path.join(
    #             config.Global.profile_dir,
    #             config.Global.security_dir,
    #             config.Global.key_file_name
    #         )
    #         config.Global.key_file = try_this

    def find_url_file(self):
        """Set the url file.

        Here we don't try to actually see if it exists or is valid, as that
        is handled by the connection logic.
        """
        config = self.config
        # Find the actual controller key file
        if not self.url_file:
            self.url_file = os.path.join(
                self.profile_dir.security_dir,
                self.url_file_name
            )
    def init_engine(self):
        # This is the working dir by now.
        sys.path.insert(0, '')
        config = self.config
        # print config
        self.find_url_file()

        # was the url manually specified?
        keys = set(self.config.EngineFactory.keys())
        keys = keys.union(set(self.config.RegistrationFactory.keys()))

        if keys.intersection(set(['ip', 'url', 'port'])):
            # Connection info was specified, don't wait for the file
            url_specified = True
            self.wait_for_url_file = 0
        else:
            url_specified = False

        if self.wait_for_url_file and not os.path.exists(self.url_file):
            self.log.warn("url_file %r not found"%self.url_file)
            self.log.warn("Waiting up to %.1f seconds for it to arrive."%self.wait_for_url_file)
            tic = time.time()
            while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
                # wait for url_file to exist, for up to 10 seconds
                time.sleep(0.1)

        if os.path.exists(self.url_file):
            self.log.info("Loading url_file %r"%self.url_file)
            with open(self.url_file) as f:
                d = json.loads(f.read())
            if d['exec_key']:
-                config.Session.key = ensure_bytes(d['exec_key'])
+                config.Session.key = asbytes(d['exec_key'])
            d['url'] = disambiguate_url(d['url'], d['location'])
            config.EngineFactory.url = d['url']
            config.EngineFactory.location = d['location']
        elif not url_specified:
            self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
            self.exit(1)


        try:
            exec_lines = config.Kernel.exec_lines
        except AttributeError:
            config.Kernel.exec_lines = []
            exec_lines = config.Kernel.exec_lines

        if self.startup_script:
            enc = sys.getfilesystemencoding() or 'utf8'
            cmd="execfile(%r)"%self.startup_script.encode(enc)
            exec_lines.append(cmd)
        if self.startup_command:
            exec_lines.append(self.startup_command)

        # Create the underlying shell class and Engine
        # shell_class = import_item(self.master_config.Global.shell_class)
        # print self.config
        try:
            self.engine = EngineFactory(config=config, log=self.log)
        except:
            self.log.error("Couldn't start the Engine", exc_info=True)
            self.exit(1)

    def forward_logging(self):
        if self.log_url:
            self.log.info("Forwarding logging to %s"%self.log_url)
            context = self.engine.context
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            self.log.removeHandler(self._log_handler)
            handler = EnginePUBHandler(self.engine, lsock)
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)
            self._log_handler = handler
    #
    def init_mpi(self):
        global mpi
        self.mpi = MPI(config=self.config)

        mpi_import_statement = self.mpi.init_script
        if mpi_import_statement:
            try:
                self.log.info("Initializing MPI:")
                self.log.info(mpi_import_statement)
                exec mpi_import_statement in globals()
            except:
                mpi = None
        else:
            mpi = None

    def initialize(self, argv=None):
        super(IPEngineApp, self).initialize(argv)
        self.init_mpi()
        self.init_engine()
        self.forward_logging()

    def start(self):
        self.engine.start()
        try:
            self.engine.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")


def launch_new_instance():
    """Create and run the IPython engine"""
    app = IPEngineApp.instance()
    app.initialize()
    app.start()


if __name__ == '__main__':
    launch_new_instance()

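The final hunk covers the docstring and setup portion of the parallel Client; the only change visible in this excerpt is a comment typo fix (coupe to couple). For context, a client consumes the ipcontroller-client.json written by the controller above. A minimal, hypothetical connection, with the file path purely illustrative:

from IPython.parallel import Client

# default: locate ipcontroller-client.json via the active (or 'default') profile
rc = Client()
# or point at a connection file explicitly:
# rc = Client('/home/user/.ipython/profile_default/security/ipcontroller-client.json')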
@@ -1,1428 +1,1428 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 import time
21 import time
22 import warnings
22 import warnings
23 from datetime import datetime
23 from datetime import datetime
24 from getpass import getpass
24 from getpass import getpass
25 from pprint import pprint
25 from pprint import pprint
26
26
27 pjoin = os.path.join
27 pjoin = os.path.join
28
28
29 import zmq
29 import zmq
30 # from zmq.eventloop import ioloop, zmqstream
30 # from zmq.eventloop import ioloop, zmqstream
31
31
32 from IPython.config.configurable import MultipleInstanceError
32 from IPython.config.configurable import MultipleInstanceError
33 from IPython.core.application import BaseIPythonApplication
33 from IPython.core.application import BaseIPythonApplication
34
34
35 from IPython.utils.jsonutil import rekey
35 from IPython.utils.jsonutil import rekey
36 from IPython.utils.localinterfaces import LOCAL_IPS
36 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.path import get_ipython_dir
37 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 Dict, List, Bool, Set)
39 Dict, List, Bool, Set)
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.external.ssh import tunnel
41 from IPython.external.ssh import tunnel
42
42
43 from IPython.parallel import error
43 from IPython.parallel import error
44 from IPython.parallel import util
44 from IPython.parallel import util
45
45
46 from IPython.zmq.session import Session, Message
46 from IPython.zmq.session import Session, Message
47
47
48 from .asyncresult import AsyncResult, AsyncHubResult
48 from .asyncresult import AsyncResult, AsyncHubResult
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 from .view import DirectView, LoadBalancedView
50 from .view import DirectView, LoadBalancedView
51
51
52 if sys.version_info[0] >= 3:
52 if sys.version_info[0] >= 3:
53 # xrange is used in a coupe 'isinstance' tests in py2
53 # xrange is used in a couple 'isinstance' tests in py2
54 # should be just 'range' in 3k
54 # should be just 'range' in 3k
55 xrange = range
55 xrange = range
56
56
57 #--------------------------------------------------------------------------
57 #--------------------------------------------------------------------------
58 # Decorators for Client methods
58 # Decorators for Client methods
59 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
60
60
61 @decorator
61 @decorator
62 def spin_first(f, self, *args, **kwargs):
62 def spin_first(f, self, *args, **kwargs):
63 """Call spin() to sync state prior to calling the method."""
63 """Call spin() to sync state prior to calling the method."""
64 self.spin()
64 self.spin()
65 return f(self, *args, **kwargs)
65 return f(self, *args, **kwargs)
66
66
67
67
68 #--------------------------------------------------------------------------
68 #--------------------------------------------------------------------------
69 # Classes
69 # Classes
70 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
71
71
72 class Metadata(dict):
72 class Metadata(dict):
73 """Subclass of dict for initializing metadata values.
73 """Subclass of dict for initializing metadata values.
74
74
75 Attribute access works on keys.
75 Attribute access works on keys.
76
76
77 These objects have a strict set of keys - errors will raise if you try
77 These objects have a strict set of keys - errors will raise if you try
78 to add new keys.
78 to add new keys.
79 """
79 """
80 def __init__(self, *args, **kwargs):
80 def __init__(self, *args, **kwargs):
81 dict.__init__(self)
81 dict.__init__(self)
82 md = {'msg_id' : None,
82 md = {'msg_id' : None,
83 'submitted' : None,
83 'submitted' : None,
84 'started' : None,
84 'started' : None,
85 'completed' : None,
85 'completed' : None,
86 'received' : None,
86 'received' : None,
87 'engine_uuid' : None,
87 'engine_uuid' : None,
88 'engine_id' : None,
88 'engine_id' : None,
89 'follow' : None,
89 'follow' : None,
90 'after' : None,
90 'after' : None,
91 'status' : None,
91 'status' : None,
92
92
93 'pyin' : None,
93 'pyin' : None,
94 'pyout' : None,
94 'pyout' : None,
95 'pyerr' : None,
95 'pyerr' : None,
96 'stdout' : '',
96 'stdout' : '',
97 'stderr' : '',
97 'stderr' : '',
98 }
98 }
99 self.update(md)
99 self.update(md)
100 self.update(dict(*args, **kwargs))
100 self.update(dict(*args, **kwargs))
101
101
102 def __getattr__(self, key):
102 def __getattr__(self, key):
103 """getattr aliased to getitem"""
103 """getattr aliased to getitem"""
104 if key in self.iterkeys():
104 if key in self.iterkeys():
105 return self[key]
105 return self[key]
106 else:
106 else:
107 raise AttributeError(key)
107 raise AttributeError(key)
108
108
109 def __setattr__(self, key, value):
109 def __setattr__(self, key, value):
110 """setattr aliased to setitem, with strict"""
110 """setattr aliased to setitem, with strict"""
111 if key in self.iterkeys():
111 if key in self.iterkeys():
112 self[key] = value
112 self[key] = value
113 else:
113 else:
114 raise AttributeError(key)
114 raise AttributeError(key)
115
115
116 def __setitem__(self, key, value):
116 def __setitem__(self, key, value):
117 """strict static key enforcement"""
117 """strict static key enforcement"""
118 if key in self.iterkeys():
118 if key in self.iterkeys():
119 dict.__setitem__(self, key, value)
119 dict.__setitem__(self, key, value)
120 else:
120 else:
121 raise KeyError(key)
121 raise KeyError(key)
122
122
123
123

class Client(HasTraits):
    """A semi-synchronous client to the IPython ZMQ cluster.

    Parameters
    ----------

    url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
        Connection information for the Hub's registration.  If a json connector
        file is given, then likely no further configuration is necessary.
        [Default: use profile]
    profile : bytes
        The name of the Cluster profile to be used to find connector information.
        If run from an IPython application, the default profile will be the same
        as the running application, otherwise it will be 'default'.
    context : zmq.Context
        Pass an existing zmq.Context instance, otherwise the client will create its own.
    debug : bool
        flag for lots of message printing for debug purposes
    timeout : int/float
        time (in seconds) to wait for connection replies from the Hub
        [Default: 10]

    #-------------- session related args ----------------

    config : Config object
        If specified, this will be relayed to the Session for configuration
    username : str
        set username for the session object
    packer : str (import_string) or callable
        Can be either the simple keyword 'json' or 'pickle', or an import_string
        to a function to serialize messages.  Must support the same input as
        JSON, and output must be bytes.
        You can pass a callable directly as `pack`
    unpacker : str (import_string) or callable
        The inverse of packer.  Only necessary if packer is specified as *not*
        one of 'json' or 'pickle'.

    #-------------- ssh related args ----------------
    # These are args for configuring the ssh tunnel to be used
    # credentials are used to forward connections over ssh to the Controller
    # Note that the ip given in `addr` needs to be relative to sshserver
    # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
    # and set sshserver as the same machine the Controller is on. However,
    # the only requirement is that sshserver is able to see the Controller
    # (i.e. is within the same trusted network).

    sshserver : str
        A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
        If keyfile or password is specified, and this is not, it will default to
        the ip given in addr.
    sshkey : str; path to private ssh key file
        This specifies a key to be used in ssh login, default None.
        Regular default ssh keys will be used without specifying this argument.
    password : str
        Your ssh password to sshserver. Note that if this is left None,
        you will be prompted for it if passwordless key based login is unavailable.
    paramiko : bool
        flag for whether to use paramiko instead of shell ssh for tunneling.
        [default: True on win32, False otherwise]

    ------- exec authentication args -------
    If even localhost is untrusted, you can have some protection against
    unauthorized execution by signing messages with HMAC digests.
    Messages are still sent as cleartext, so if someone can snoop your
    loopback traffic this will not protect your privacy, but will prevent
    unauthorized execution.

    exec_key : str
        an authentication key or file containing a key
        default: None


    Attributes
    ----------

    ids : list of int engine IDs
        requesting the ids attribute always synchronizes
        the registration state. To request ids without synchronization,
        use the semi-private _ids attribute.

    history : list of msg_ids
        a list of msg_ids, keeping track of all the execution
        messages you have submitted in order.

    outstanding : set of msg_ids
        a set of msg_ids that have been submitted, but whose
        results have not yet been received.

    results : dict
        a dict of all our results, keyed by msg_id

    block : bool
        determines default behavior when block is not specified
        in execution methods

    Methods
    -------

    spin
        flushes incoming results and registration state changes
        control methods spin, and requesting `ids` also ensures up to date

    wait
        wait on one or more msg_ids

    execution methods
        apply
        legacy: execute, run

    data movement
        push, pull, scatter, gather

    query methods
        queue_status, get_result, purge, result_status

    control methods
        abort, shutdown

    """

    block = Bool(False)
    outstanding = Set()
    results = Instance('collections.defaultdict', (dict,))
    metadata = Instance('collections.defaultdict', (Metadata,))
    history = List()
    debug = Bool(False)

    profile = Unicode()
    def _profile_default(self):
        if BaseIPythonApplication.initialized():
            # an IPython app *might* be running, try to get its profile
            try:
                return BaseIPythonApplication.instance().profile
            except (AttributeError, MultipleInstanceError):
                # could be a *different* subclass of config.Application,
                # which would raise one of these two errors.
                return u'default'
        else:
            return u'default'

    _outstanding_dict = Instance('collections.defaultdict', (set,))
    _ids = List()
    _connected = Bool(False)
    _ssh = Bool(False)
    _context = Instance('zmq.Context')
    _config = Dict()
    _engines = Instance(util.ReverseDict, (), {})
    # _hub_socket = Instance('zmq.Socket')
    _query_socket = Instance('zmq.Socket')
    _control_socket = Instance('zmq.Socket')
    _iopub_socket = Instance('zmq.Socket')
    _notification_socket = Instance('zmq.Socket')
    _mux_socket = Instance('zmq.Socket')
    _task_socket = Instance('zmq.Socket')
    _task_scheme = Unicode()
    _closed = False
    _ignored_control_replies = Int(0)
    _ignored_hub_replies = Int(0)

    def __new__(self, *args, **kw):
        # don't raise on positional args
        return HasTraits.__new__(self, **kw)

    def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
            context=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            timeout=10, **extra_args
            ):
        if profile:
            super(Client, self).__init__(debug=debug, profile=profile)
        else:
            super(Client, self).__init__(debug=debug)
        if context is None:
            context = zmq.Context.instance()
        self._context = context

        self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
            " Please specify at least one of url_or_file or profile."

        try:
            util.validate_url(url_or_file)
        except AssertionError:
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r" % url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        location = cfg.setdefault('location', None)
        cfg['url'] = util.disambiguate_url(cfg['url'], location)
        url = cfg['url']
        proto, addr, port = util.split_url(url)
        if location is not None and addr == '127.0.0.1':
            # location specified, and connection is expected to be local
            if location not in LOCAL_IPS and not sshserver:
                # load ssh from JSON *only* if the controller is not on
                # this machine
                sshserver = cfg['ssh']
            if location not in LOCAL_IPS and not sshserver:
                # warn if no ssh specified, but SSH is probably needed
                # This is only a warning, because the most likely cause
                # is a local Controller on a laptop whose IP is dynamic
                warnings.warn("""
            Controller appears to be listening on localhost, but not on this machine.
            If this is true, you should specify Client(...,sshserver='you@%s')
            or instruct your controller to listen on an external IP.""" % location,
                RuntimeWarning)

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password = False
            else:
                password = getpass("SSH Password for %s: " % sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)

        # configure and construct the session
        if exec_key is not None:
            if os.path.isfile(exec_key):
                extra_args['keyfile'] = exec_key
            else:
                exec_key = util.asbytes(exec_key)
                extra_args['key'] = exec_key
        self.session = Session(**extra_args)

        self._query_socket = self._context.socket(zmq.XREQ)
        self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        if self._ssh:
            tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
        else:
            self._query_socket.connect(url)

        self.session.debug = self.debug

        self._notification_handlers = {'registration_notification' : self._register_engine,
                                       'unregistration_notification' : self._unregister_engine,
                                       'shutdown_notification' : lambda msg: self.close(),
                                       }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs, timeout)

    def __del__(self):
        """cleanup sockets, but _not_ context."""
        self.close()

    def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
        if ipython_dir is None:
            ipython_dir = get_ipython_dir()
        if profile_dir is not None:
            try:
                self._cd = ProfileDir.find_profile_dir(profile_dir)
                return
            except ProfileDirError:
                pass
        elif profile is not None:
            try:
                self._cd = ProfileDir.find_profile_dir_by_name(
                    ipython_dir, profile)
                return
            except ProfileDirError:
                pass
        self._cd = None

    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k, v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = v
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()

    def _stop_scheduling_tasks(self):
        """Stop scheduling tasks because an engine has been unregistered
        from a pure ZMQ scheduler.
        """
        self._task_socket.close()
        self._task_socket = None
        msg = "An engine has been unregistered, and we are using pure " +\
              "ZMQ task scheduling. Task farming will be disabled."
        if self.outstanding:
            msg += " If you were running tasks when this happened, " +\
                   "some `outstanding` msg_ids may never resolve."
        warnings.warn(msg, RuntimeWarning)

    def _build_targets(self, targets):
        """Turn valid target IDs or 'all' into two lists:
        (int_ids, uuids).
        """
        if not self._ids:
            # flush notification socket if no engines yet, just in case
            if not self.ids:
                raise error.NoEnginesRegistered("Can't build targets without any engines")

        if targets is None:
            targets = self._ids
        elif isinstance(targets, str):
            if targets.lower() == 'all':
                targets = self._ids
            else:
                raise TypeError("%r not valid str target, must be 'all'" % (targets))
        elif isinstance(targets, int):
            if targets < 0:
                targets = self.ids[targets]
            if targets not in self._ids:
                raise IndexError("No such engine: %i" % targets)
            targets = [targets]

        if isinstance(targets, slice):
            indices = range(len(self._ids))[targets]
            ids = self.ids
            targets = [ ids[i] for i in indices ]

        if not isinstance(targets, (tuple, list, xrange)):
            raise TypeError("targets by int/slice/collection of ints only, not %s" % (type(targets)))

        return [util.asbytes(self._engines[t]) for t in targets], list(targets)
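
    # Accepted `targets` forms, per the checks above (illustrative):
    #
    #     self._build_targets(None)         # all registered engine ids
    #     self._build_targets('all')        # same, via the 'all' keyword
    #     self._build_targets(0)            # a single engine id
    #     self._build_targets(-1)           # negative ints index self.ids
    #     self._build_targets(slice(0, 4))  # a slice of self.ids
    #     self._build_targets([0, 2, 3])    # an explicit list of ids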

    def _connect(self, sshserver, ssh_kwargs, timeout):
        """setup all our socket connections to the cluster. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected = True

        def connect_socket(s, url):
            url = util.disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._query_socket, 'connection_request')
        # use Poller because zmq.select has wrong units in pyzmq 2.1.7
        poller = zmq.Poller()
        poller.register(self._query_socket, zmq.POLLIN)
        # poll expects milliseconds, timeout is seconds
        evts = poller.poll(timeout*1000)
        if not evts:
            raise error.TimeoutError("Hub connection request timed out")
        idents, msg = self.session.recv(self._query_socket, mode=0)
        if self.debug:
            pprint(msg)
        msg = Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            ident = util.asbytes(self.session.session)
            if content.mux:
                self._mux_socket = self._context.socket(zmq.XREQ)
                self._mux_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                self._task_scheme, task_addr = content.task
                self._task_socket = self._context.socket(zmq.XREQ)
                self._task_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self._context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
            # if content.query:
            #     self._query_socket = self._context.socket(zmq.XREQ)
            #     self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
            #     connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self._context.socket(zmq.XREQ)
                self._control_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self._context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
                self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
                connect_socket(self._iopub_socket, content.iopub)
            self._update_engines(dict(content.engines))
        else:
            self._connected = False
            raise Exception("Failed to connect!")

    #--------------------------------------------------------------------------
    # handlers and callbacks for incoming messages
    #--------------------------------------------------------------------------

    def _unwrap_exception(self, content):
        """unwrap exception, and remap engine_id to int."""
        e = error.unwrap_exception(content)
        # print e.traceback
        if e.engine_info:
            e_uuid = e.engine_info['engine_uuid']
            eid = self._engines[e_uuid]
            e.engine_info['engine_id'] = eid
        return e

    def _extract_metadata(self, header, parent, content):
        md = {'msg_id' : parent['msg_id'],
              'received' : datetime.now(),
              'engine_uuid' : header.get('engine', None),
              'follow' : parent.get('follow', []),
              'after' : parent.get('after', []),
              'status' : content['status'],
            }

        if md['engine_uuid'] is not None:
            md['engine_id'] = self._engines.get(md['engine_uuid'], None)

        if 'date' in parent:
            md['submitted'] = parent['date']
        if 'started' in header:
            md['started'] = header['started']
        if 'date' in header:
            md['completed'] = header['date']
        return md

    def _register_engine(self, msg):
        """Register a new engine, and update our connection info."""
        content = msg['content']
        eid = content['id']
        d = {eid : content['queue']}
        self._update_engines(d)

    def _unregister_engine(self, msg):
        """Unregister an engine that has died."""
        content = msg['content']
        eid = int(content['id'])
        if eid in self._ids:
            self._ids.remove(eid)
            uuid = self._engines.pop(eid)

            self._handle_stranded_msgs(eid, uuid)

        if self._task_socket and self._task_scheme == 'pure':
            self._stop_scheduling_tasks()

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        of the unregistration and later receive the successful result.
        """

        outstanding = self._outstanding_dict[uuid]

        for msg_id in list(outstanding):
            if msg_id in self.results:
                # we already have the result, so this msg_id is not stranded
                continue
            try:
                raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake message:
            parent = {}
            header = {}
            parent['msg_id'] = msg_id
            header['engine'] = uuid
            header['date'] = datetime.now()
            msg = dict(parent_header=parent, header=header, content=content)
            self._handle_apply_reply(msg)

    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        self.results[msg_id] = self._unwrap_exception(msg['content'])

    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            if msg_id in self.history:
                print("got stale result: %s" % msg_id)
                print self.results[msg_id]
                print msg
            else:
                print("got unknown result: %s" % msg_id)
        else:
            self.outstanding.remove(msg_id)
        content = msg['content']
        header = msg['header']

        # construct metadata:
        md = self.metadata[msg_id]
        md.update(self._extract_metadata(header, parent, content))
        # is this redundant?
        self.metadata[msg_id] = md

        e_outstanding = self._outstanding_dict[md['engine_uuid']]
        if msg_id in e_outstanding:
            e_outstanding.remove(msg_id)

        # construct result:
        if content['status'] == 'ok':
            self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
        elif content['status'] == 'aborted':
            self.results[msg_id] = error.TaskAborted(msg_id)
        elif content['status'] == 'resubmitted':
            # TODO: handle resubmission
            pass
        else:
            self.results[msg_id] = self._unwrap_exception(content)

    def _flush_notifications(self):
        """Flush notifications of engine registrations waiting
        in ZMQ queue."""
        idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._notification_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg_type)
            else:
                handler(msg)
            idents, msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)

    def _flush_results(self, sock):
        """Flush task or queue results waiting in ZMQ queue."""
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg_type = msg['msg_type']
            handler = self._queue_handlers.get(msg_type, None)
            if handler is None:
                raise Exception("Unhandled message type: %s" % msg_type)
            else:
                handler(msg)
            idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_control(self, sock):
        """Flush replies from the control channel waiting
        in the ZMQ queue.

        Currently: ignore them."""
        if self._ignored_control_replies <= 0:
            return
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            self._ignored_control_replies -= 1
            if self.debug:
                pprint(msg)
            idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    def _flush_ignored_control(self):
        """flush ignored control replies"""
        while self._ignored_control_replies > 0:
            self.session.recv(self._control_socket)
            self._ignored_control_replies -= 1

    def _flush_ignored_hub_replies(self):
        ident, msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
        while msg is not None:
            ident, msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)

    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata[msg_id]

            if msg_type == 'stream':
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : self._unwrap_exception(content)})
            elif msg_type == 'pyin':
                md.update({'pyin' : content['code']})
            else:
                md.update({msg_type : content.get('data', '')})

            # redundant?
            self.metadata[msg_id] = md

            idents, msg = self.session.recv(sock, mode=zmq.NOBLOCK)

    #--------------------------------------------------------------------------
    # len, getitem
    #--------------------------------------------------------------------------

    def __len__(self):
        """len(client) returns # of engines."""
        return len(self.ids)

    def __getitem__(self, key):
        """index access returns DirectView multiplexer objects

        Must be int, slice, or list/tuple/xrange of ints"""
        if not isinstance(key, (int, slice, tuple, list, xrange)):
            raise TypeError("key by int/slice/iterable of ints only, not %s" % (type(key)))
        else:
            return self.direct_view(key)
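
    # Index-access sketch (illustrative): each form below hands the key to
    # direct_view() and returns a DirectView:
    #
    #     rc[0]        # DirectView on engine 0
    #     rc[::2]      # DirectView on every other engine
    #     rc[[0, 2]]   # DirectView on engines 0 and 2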

    #--------------------------------------------------------------------------
    # Begin public methods
    #--------------------------------------------------------------------------

    @property
    def ids(self):
        """Always up-to-date ids property."""
        self._flush_notifications()
        # always copy:
        return list(self._ids)

    def close(self):
        if self._closed:
            return
        snames = filter(lambda n: n.endswith('socket'), dir(self))
        for socket in map(lambda name: getattr(self, name), snames):
            if isinstance(socket, zmq.Socket) and not socket.closed:
                socket.close()
        self._closed = True

    def spin(self):
        """Flush any registration notifications and execution results
        waiting in the ZMQ queue.
        """
        if self._notification_socket:
            self._flush_notifications()
        if self._mux_socket:
            self._flush_results(self._mux_socket)
        if self._task_socket:
            self._flush_results(self._task_socket)
        if self._control_socket:
            self._flush_control(self._control_socket)
        if self._iopub_socket:
            self._flush_iopub(self._iopub_socket)
        if self._query_socket:
            self._flush_ignored_hub_replies()

    def wait(self, jobs=None, timeout=-1):
        """waits on one or more `jobs`, for up to `timeout` seconds.

        Parameters
        ----------

        jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------

        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if jobs is None:
            theids = self.outstanding
        else:
            if isinstance(jobs, (int, str, AsyncResult)):
                jobs = [jobs]
            theids = set()
            for job in jobs:
                if isinstance(job, int):
                    # index access
                    job = self.history[job]
                elif isinstance(job, AsyncResult):
                    map(theids.add, job.msg_ids)
                    continue
                theids.add(job)
        if not theids.intersection(self.outstanding):
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
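
    # wait() usage sketch (illustrative; `ar` stands for an AsyncResult from
    # an earlier submission, e.g. `ar = rc[:].apply_async(f, x)`):
    #
    #     rc.wait()           # block until all outstanding msg_ids are done
    #     rc.wait([ar], 5)    # wait up to 5s for ar's msg_ids; False on timeout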

    #--------------------------------------------------------------------------
    # Control methods
    #--------------------------------------------------------------------------

    @spin_first
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s)."""
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    def abort(self, jobs=None, targets=None, block=None):
        """Abort specific jobs from the execution queues of target(s).

        This is a mechanism to prevent jobs that have already been submitted
        from executing.

        Parameters
        ----------

        jobs : msg_id, list of msg_ids, or AsyncResult
            The jobs to be aborted
        """
        block = self.block if block is None else block
        targets = self._build_targets(targets)[0]
        msg_ids = []
        if isinstance(jobs, (basestring, AsyncResult)):
            jobs = [jobs]
        bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
        if bad_ids:
            raise TypeError("Invalid msg_id type %r, expected str or AsyncResult" % bad_ids[0])
        for j in jobs:
            if isinstance(j, AsyncResult):
                msg_ids.extend(j.msg_ids)
            else:
                msg_ids.append(j)
        content = dict(msg_ids=msg_ids)
        for t in targets:
            self.session.send(self._control_socket, 'abort_request',
                    content=content, ident=t)
        error = False
        if block:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)
        if error:
            raise error

    @spin_first
    def shutdown(self, targets=None, restart=False, hub=False, block=None):
        """Terminates one or more engine processes, optionally including the hub."""
        block = self.block if block is None else block
        if hub:
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                    content={'restart':restart}, ident=t)
        error = False
        if block or hub:
            self._flush_ignored_control()
            for i in range(len(targets)):
                idents, msg = self.session.recv(self._control_socket, 0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    error = self._unwrap_exception(msg['content'])
        else:
            self._ignored_control_replies += len(targets)

        if hub:
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents, msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = self._unwrap_exception(msg['content'])

        if error:
            raise error
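
    # Control-method usage sketch (illustrative; `ar` is an AsyncResult from
    # an earlier submission):
    #
    #     rc.clear(targets=[0, 1])   # wipe the namespace on engines 0 and 1
    #     rc.abort(jobs=ar)          # abort ar's msg_ids before they execute
    #     rc.shutdown(hub=True)      # stop all engines, then the hub itself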

    #--------------------------------------------------------------------------
    # Execution related methods
    #--------------------------------------------------------------------------

    def _maybe_raise(self, result):
        """wrapper for maybe raising an exception if apply failed."""
        if isinstance(result, error.RemoteError):
            raise result

        return result
966 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
966 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
967 ident=None):
967 ident=None):
968 """construct and send an apply message via a socket.
968 """construct and send an apply message via a socket.
969
969
970 This is the principal method with which all engine execution is performed by views.
970 This is the principal method with which all engine execution is performed by views.
971 """
971 """
972
972
973 assert not self._closed, "cannot use me anymore, I'm closed!"
973 assert not self._closed, "cannot use me anymore, I'm closed!"
974 # defaults:
974 # defaults:
975 args = args if args is not None else []
975 args = args if args is not None else []
976 kwargs = kwargs if kwargs is not None else {}
976 kwargs = kwargs if kwargs is not None else {}
977 subheader = subheader if subheader is not None else {}
977 subheader = subheader if subheader is not None else {}
978
978
979 # validate arguments
979 # validate arguments
980 if not callable(f):
980 if not callable(f):
981 raise TypeError("f must be callable, not %s"%type(f))
981 raise TypeError("f must be callable, not %s"%type(f))
982 if not isinstance(args, (tuple, list)):
982 if not isinstance(args, (tuple, list)):
983 raise TypeError("args must be tuple or list, not %s"%type(args))
983 raise TypeError("args must be tuple or list, not %s"%type(args))
984 if not isinstance(kwargs, dict):
984 if not isinstance(kwargs, dict):
985 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
985 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
986 if not isinstance(subheader, dict):
986 if not isinstance(subheader, dict):
987 raise TypeError("subheader must be dict, not %s"%type(subheader))
987 raise TypeError("subheader must be dict, not %s"%type(subheader))
988
988
989 bufs = util.pack_apply_message(f,args,kwargs)
989 bufs = util.pack_apply_message(f,args,kwargs)
990
990
991 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
991 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
992 subheader=subheader, track=track)
992 subheader=subheader, track=track)
993
993
994 msg_id = msg['msg_id']
994 msg_id = msg['msg_id']
995 self.outstanding.add(msg_id)
995 self.outstanding.add(msg_id)
996 if ident:
996 if ident:
997 # possibly routed to a specific engine
997 # possibly routed to a specific engine
998 if isinstance(ident, list):
998 if isinstance(ident, list):
999 ident = ident[-1]
999 ident = ident[-1]
1000 if ident in self._engines.values():
1000 if ident in self._engines.values():
1001 # save for later, in case of engine death
1001 # save for later, in case of engine death
1002 self._outstanding_dict[ident].add(msg_id)
1002 self._outstanding_dict[ident].add(msg_id)
1003 self.history.append(msg_id)
1003 self.history.append(msg_id)
1004 self.metadata[msg_id]['submitted'] = datetime.now()
1004 self.metadata[msg_id]['submitted'] = datetime.now()
1005
1005
1006 return msg
1006 return msg
1007
1007
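Views normally call send_apply_message on the user's behalf; a rough sketch of a direct call, assuming `rc` is a connected Client with at least one registered engine (note that `_mux_socket` and `_engines` are private attributes, used here only to illustrate the argument validation above):

    def double(x):
        return 2 * x

    ident = rc._engines[0]   # queue identity of engine 0
    # args must be a tuple or list, kwargs a dict, per the checks above
    msg = rc.send_apply_message(rc._mux_socket, double, args=(5,), kwargs={}, ident=ident)
    msg_id = msg['msg_id']   # now tracked in rc.outstanding
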
1008 #--------------------------------------------------------------------------
1008 #--------------------------------------------------------------------------
1009 # construct a View object
1009 # construct a View object
1010 #--------------------------------------------------------------------------
1010 #--------------------------------------------------------------------------
1011
1011
1012 def load_balanced_view(self, targets=None):
1012 def load_balanced_view(self, targets=None):
1013 """construct a DirectView object.
1013 """construct a DirectView object.
1014
1014
1015 If no arguments are specified, create a LoadBalancedView
1015 If no arguments are specified, create a LoadBalancedView
1016 using all engines.
1016 using all engines.
1017
1017
1018 Parameters
1018 Parameters
1019 ----------
1019 ----------
1020
1020
1021 targets: list,slice,int,etc. [default: use all engines]
1021 targets: list,slice,int,etc. [default: use all engines]
1022 The subset of engines across which to load-balance
1022 The subset of engines across which to load-balance
1023 """
1023 """
1024 if targets is not None:
1024 if targets is not None:
1025 targets = self._build_targets(targets)[1]
1025 targets = self._build_targets(targets)[1]
1026 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1026 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1027
1027
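A minimal usage sketch, assuming a connected Client `rc`:

    v = rc.load_balanced_view()                # balance across all engines
    v4 = rc.load_balanced_view([0, 1, 2, 3])   # restrict to four engines
    ar = v.map_async(lambda x: x ** 2, range(32))
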
1028 def direct_view(self, targets='all'):
1028 def direct_view(self, targets='all'):
1029 """construct a DirectView object.
1029 """construct a DirectView object.
1030
1030
1031 If no targets are specified, create a DirectView
1031 If no targets are specified, create a DirectView
1032 using all engines.
1032 using all engines.
1033
1033
1034 Parameters
1034 Parameters
1035 ----------
1035 ----------
1036
1036
1037 targets: list,slice,int,etc. [default: use all engines]
1037 targets: list,slice,int,etc. [default: use all engines]
1038 The engines to use for the View
1038 The engines to use for the View
1039 """
1039 """
1040 single = isinstance(targets, int)
1040 single = isinstance(targets, int)
1041 targets = self._build_targets(targets)[1]
1041 targets = self._build_targets(targets)[1]
1042 if single:
1042 if single:
1043 targets = targets[0]
1043 targets = targets[0]
1044 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1044 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1045
1045
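A minimal usage sketch, assuming a connected Client `rc`:

    dv = rc.direct_view()        # all engines
    e0 = rc.direct_view(0)       # a single engine (int target)
    dv.execute('import os')      # run a statement on every engine
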
1046 #--------------------------------------------------------------------------
1046 #--------------------------------------------------------------------------
1047 # Query methods
1047 # Query methods
1048 #--------------------------------------------------------------------------
1048 #--------------------------------------------------------------------------
1049
1049
1050 @spin_first
1050 @spin_first
1051 def get_result(self, indices_or_msg_ids=None, block=None):
1051 def get_result(self, indices_or_msg_ids=None, block=None):
1052 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1052 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1053
1053
1054 If the client already has the results, no request to the Hub will be made.
1054 If the client already has the results, no request to the Hub will be made.
1055
1055
1056 This is a convenient way to construct AsyncResult objects, which are wrappers
1056 This is a convenient way to construct AsyncResult objects, which are wrappers
1057 that include metadata about execution, and allow for awaiting results that
1057 that include metadata about execution, and allow for awaiting results that
1058 were not submitted by this Client.
1058 were not submitted by this Client.
1059
1059
1060 It can also be a convenient way to retrieve the metadata associated with
1060 It can also be a convenient way to retrieve the metadata associated with
1061 blocking execution, since it always retrieves the metadata along with the result.
1061 blocking execution, since it always retrieves the metadata along with the result.
1062
1062
1063 Examples
1063 Examples
1064 --------
1064 --------
1065 ::
1065 ::
1066
1066
1067 In [10]: ar = client.get_result(msg_id)
1067 In [10]: ar = client.get_result(msg_id)
1068
1068
1069 Parameters
1069 Parameters
1070 ----------
1070 ----------
1071
1071
1072 indices_or_msg_ids : integer history index, str msg_id, or list of either
1072 indices_or_msg_ids : integer history index, str msg_id, or list of either
1073 The indices or msg_ids of the results to be retrieved
1073 The indices or msg_ids of the results to be retrieved
1074
1074
1075 block : bool
1075 block : bool
1076 Whether to wait for the result to be done
1076 Whether to wait for the result to be done
1077
1077
1078 Returns
1078 Returns
1079 -------
1079 -------
1080
1080
1081 AsyncResult
1081 AsyncResult
1082 A single AsyncResult object will always be returned.
1082 A single AsyncResult object will always be returned.
1083
1083
1084 AsyncHubResult
1084 AsyncHubResult
1085 A subclass of AsyncResult that retrieves results from the Hub
1085 A subclass of AsyncResult that retrieves results from the Hub
1086
1086
1087 """
1087 """
1088 block = self.block if block is None else block
1088 block = self.block if block is None else block
1089 if indices_or_msg_ids is None:
1089 if indices_or_msg_ids is None:
1090 indices_or_msg_ids = -1
1090 indices_or_msg_ids = -1
1091
1091
1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1093 indices_or_msg_ids = [indices_or_msg_ids]
1093 indices_or_msg_ids = [indices_or_msg_ids]
1094
1094
1095 theids = []
1095 theids = []
1096 for id in indices_or_msg_ids:
1096 for id in indices_or_msg_ids:
1097 if isinstance(id, int):
1097 if isinstance(id, int):
1098 id = self.history[id]
1098 id = self.history[id]
1099 if not isinstance(id, basestring):
1099 if not isinstance(id, basestring):
1100 raise TypeError("indices must be str or int, not %r"%id)
1100 raise TypeError("indices must be str or int, not %r"%id)
1101 theids.append(id)
1101 theids.append(id)
1102
1102
1103 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1103 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1104 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1104 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1105
1105
1106 if remote_ids:
1106 if remote_ids:
1107 ar = AsyncHubResult(self, msg_ids=theids)
1107 ar = AsyncHubResult(self, msg_ids=theids)
1108 else:
1108 else:
1109 ar = AsyncResult(self, msg_ids=theids)
1109 ar = AsyncResult(self, msg_ids=theids)
1110
1110
1111 if block:
1111 if block:
1112 ar.wait()
1112 ar.wait()
1113
1113
1114 return ar
1114 return ar
1115
1115
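A minimal sketch of the two ways to identify results here, assuming a connected Client `rc` with prior submissions:

    ar = rc.get_result(-1)        # most recent entry in rc.history
    ar = rc.get_result(msg_id)    # a specific msg_id, possibly from another client
    ar.get()                      # wait for, and return, the value
    ar.metadata                   # execution metadata, as described above
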
1116 @spin_first
1116 @spin_first
1117 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1117 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1118 """Resubmit one or more tasks.
1118 """Resubmit one or more tasks.
1119
1119
1120 In-flight tasks may not be resubmitted.
1120 In-flight tasks may not be resubmitted.
1121
1121
1122 Parameters
1122 Parameters
1123 ----------
1123 ----------
1124
1124
1125 indices_or_msg_ids : integer history index, str msg_id, or list of either
1125 indices_or_msg_ids : integer history index, str msg_id, or list of either
1126 The indices or msg_ids of the results to be retrieved
1126 The indices or msg_ids of the results to be retrieved
1127
1127
1128 block : bool
1128 block : bool
1129 Whether to wait for the result to be done
1129 Whether to wait for the result to be done
1130
1130
1131 Returns
1131 Returns
1132 -------
1132 -------
1133
1133
1134 AsyncHubResult
1134 AsyncHubResult
1135 A subclass of AsyncResult that retrieves results from the Hub
1135 A subclass of AsyncResult that retrieves results from the Hub
1136
1136
1137 """
1137 """
1138 block = self.block if block is None else block
1138 block = self.block if block is None else block
1139 if indices_or_msg_ids is None:
1139 if indices_or_msg_ids is None:
1140 indices_or_msg_ids = -1
1140 indices_or_msg_ids = -1
1141
1141
1142 if not isinstance(indices_or_msg_ids, (list,tuple)):
1142 if not isinstance(indices_or_msg_ids, (list,tuple)):
1143 indices_or_msg_ids = [indices_or_msg_ids]
1143 indices_or_msg_ids = [indices_or_msg_ids]
1144
1144
1145 theids = []
1145 theids = []
1146 for id in indices_or_msg_ids:
1146 for id in indices_or_msg_ids:
1147 if isinstance(id, int):
1147 if isinstance(id, int):
1148 id = self.history[id]
1148 id = self.history[id]
1149 if not isinstance(id, basestring):
1149 if not isinstance(id, basestring):
1150 raise TypeError("indices must be str or int, not %r"%id)
1150 raise TypeError("indices must be str or int, not %r"%id)
1151 theids.append(id)
1151 theids.append(id)
1152
1152
1153 for msg_id in theids:
1153 for msg_id in theids:
1154 self.outstanding.discard(msg_id)
1154 self.outstanding.discard(msg_id)
1155 if msg_id in self.history:
1155 if msg_id in self.history:
1156 self.history.remove(msg_id)
1156 self.history.remove(msg_id)
1157 self.results.pop(msg_id, None)
1157 self.results.pop(msg_id, None)
1158 self.metadata.pop(msg_id, None)
1158 self.metadata.pop(msg_id, None)
1159 content = dict(msg_ids = theids)
1159 content = dict(msg_ids = theids)
1160
1160
1161 self.session.send(self._query_socket, 'resubmit_request', content)
1161 self.session.send(self._query_socket, 'resubmit_request', content)
1162
1162
1163 zmq.select([self._query_socket], [], [])
1163 zmq.select([self._query_socket], [], [])
1164 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1164 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1165 if self.debug:
1165 if self.debug:
1166 pprint(msg)
1166 pprint(msg)
1167 content = msg['content']
1167 content = msg['content']
1168 if content['status'] != 'ok':
1168 if content['status'] != 'ok':
1169 raise self._unwrap_exception(content)
1169 raise self._unwrap_exception(content)
1170
1170
1171 ar = AsyncHubResult(self, msg_ids=theids)
1171 ar = AsyncHubResult(self, msg_ids=theids)
1172
1172
1173 if block:
1173 if block:
1174 ar.wait()
1174 ar.wait()
1175
1175
1176 return ar
1176 return ar
1177
1177
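A minimal sketch, assuming a connected Client `rc` and finished (not in-flight) tasks:

    ar = rc.resubmit(-1, block=True)   # re-run the most recent history entry
    msg_ids = rc.hub_history()[-3:]
    ar = rc.resubmit(msg_ids)          # re-run several finished tasks by msg_id
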
1178 @spin_first
1178 @spin_first
1179 def result_status(self, msg_ids, status_only=True):
1179 def result_status(self, msg_ids, status_only=True):
1180 """Check on the status of the result(s) of the apply request with `msg_ids`.
1180 """Check on the status of the result(s) of the apply request with `msg_ids`.
1181
1181
1182 If status_only is False, then the actual results will be retrieved, else
1182 If status_only is False, then the actual results will be retrieved, else
1183 only the status of the results will be checked.
1183 only the status of the results will be checked.
1184
1184
1185 Parameters
1185 Parameters
1186 ----------
1186 ----------
1187
1187
1188 msg_ids : list of msg_ids
1188 msg_ids : list of msg_ids
1189 if int:
1189 if int:
1190 Passed as index to self.history for convenience.
1190 Passed as index to self.history for convenience.
1191 status_only : bool (default: True)
1191 status_only : bool (default: True)
1192 if False:
1192 if False:
1193 Retrieve the actual results of completed tasks.
1193 Retrieve the actual results of completed tasks.
1194
1194
1195 Returns
1195 Returns
1196 -------
1196 -------
1197
1197
1198 results : dict
1198 results : dict
1199 There will always be the keys 'pending' and 'completed', which will
1199 There will always be the keys 'pending' and 'completed', which will
1200 be lists of msg_ids that are incomplete or complete. If `status_only`
1200 be lists of msg_ids that are incomplete or complete. If `status_only`
1201 is False, then completed results will be keyed by their `msg_id`.
1201 is False, then completed results will be keyed by their `msg_id`.
1202 """
1202 """
1203 if not isinstance(msg_ids, (list,tuple)):
1203 if not isinstance(msg_ids, (list,tuple)):
1204 msg_ids = [msg_ids]
1204 msg_ids = [msg_ids]
1205
1205
1206 theids = []
1206 theids = []
1207 for msg_id in msg_ids:
1207 for msg_id in msg_ids:
1208 if isinstance(msg_id, int):
1208 if isinstance(msg_id, int):
1209 msg_id = self.history[msg_id]
1209 msg_id = self.history[msg_id]
1210 if not isinstance(msg_id, basestring):
1210 if not isinstance(msg_id, basestring):
1211 raise TypeError("msg_ids must be str, not %r"%msg_id)
1211 raise TypeError("msg_ids must be str, not %r"%msg_id)
1212 theids.append(msg_id)
1212 theids.append(msg_id)
1213
1213
1214 completed = []
1214 completed = []
1215 local_results = {}
1215 local_results = {}
1216
1216
1217 # comment this block out to temporarily disable local shortcut:
1217 # comment this block out to temporarily disable local shortcut:
1218 for msg_id in list(theids): # iterate over a copy; theids is mutated below
1218 for msg_id in list(theids): # iterate over a copy; theids is mutated below
1219 if msg_id in self.results:
1219 if msg_id in self.results:
1220 completed.append(msg_id)
1220 completed.append(msg_id)
1221 local_results[msg_id] = self.results[msg_id]
1221 local_results[msg_id] = self.results[msg_id]
1222 theids.remove(msg_id)
1222 theids.remove(msg_id)
1223
1223
1224 if theids: # some not locally cached
1224 if theids: # some not locally cached
1225 content = dict(msg_ids=theids, status_only=status_only)
1225 content = dict(msg_ids=theids, status_only=status_only)
1226 msg = self.session.send(self._query_socket, "result_request", content=content)
1226 msg = self.session.send(self._query_socket, "result_request", content=content)
1227 zmq.select([self._query_socket], [], [])
1227 zmq.select([self._query_socket], [], [])
1228 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1228 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1229 if self.debug:
1229 if self.debug:
1230 pprint(msg)
1230 pprint(msg)
1231 content = msg['content']
1231 content = msg['content']
1232 if content['status'] != 'ok':
1232 if content['status'] != 'ok':
1233 raise self._unwrap_exception(content)
1233 raise self._unwrap_exception(content)
1234 buffers = msg['buffers']
1234 buffers = msg['buffers']
1235 else:
1235 else:
1236 content = dict(completed=[],pending=[])
1236 content = dict(completed=[],pending=[])
1237
1237
1238 content['completed'].extend(completed)
1238 content['completed'].extend(completed)
1239
1239
1240 if status_only:
1240 if status_only:
1241 return content
1241 return content
1242
1242
1243 failures = []
1243 failures = []
1244 # load cached results into result:
1244 # load cached results into result:
1245 content.update(local_results)
1245 content.update(local_results)
1246
1246
1247 # update cache with results:
1247 # update cache with results:
1248 for msg_id in sorted(theids):
1248 for msg_id in sorted(theids):
1249 if msg_id in content['completed']:
1249 if msg_id in content['completed']:
1250 rec = content[msg_id]
1250 rec = content[msg_id]
1251 parent = rec['header']
1251 parent = rec['header']
1252 header = rec['result_header']
1252 header = rec['result_header']
1253 rcontent = rec['result_content']
1253 rcontent = rec['result_content']
1254 iodict = rec['io']
1254 iodict = rec['io']
1255 if isinstance(rcontent, str):
1255 if isinstance(rcontent, str):
1256 rcontent = self.session.unpack(rcontent)
1256 rcontent = self.session.unpack(rcontent)
1257
1257
1258 md = self.metadata[msg_id]
1258 md = self.metadata[msg_id]
1259 md.update(self._extract_metadata(header, parent, rcontent))
1259 md.update(self._extract_metadata(header, parent, rcontent))
1260 md.update(iodict)
1260 md.update(iodict)
1261
1261
1262 if rcontent['status'] == 'ok':
1262 if rcontent['status'] == 'ok':
1263 res,buffers = util.unserialize_object(buffers)
1263 res,buffers = util.unserialize_object(buffers)
1264 else:
1264 else:
1265 print rcontent
1265 print rcontent
1266 res = self._unwrap_exception(rcontent)
1266 res = self._unwrap_exception(rcontent)
1267 failures.append(res)
1267 failures.append(res)
1268
1268
1269 self.results[msg_id] = res
1269 self.results[msg_id] = res
1270 content[msg_id] = res
1270 content[msg_id] = res
1271
1271
1272 if len(theids) == 1 and failures:
1272 if len(theids) == 1 and failures:
1273 raise failures[0]
1273 raise failures[0]
1274
1274
1275 error.collect_exceptions(failures, "result_status")
1275 error.collect_exceptions(failures, "result_status")
1276 return content
1276 return content
1277
1277
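A minimal sketch, assuming a connected Client `rc` and a list `msg_ids`:

    status = rc.result_status(msg_ids)                    # {'pending': [...], 'completed': [...]}
    full = rc.result_status(msg_ids, status_only=False)   # completed results also keyed by msg_id
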
1278 @spin_first
1278 @spin_first
1279 def queue_status(self, targets='all', verbose=False):
1279 def queue_status(self, targets='all', verbose=False):
1280 """Fetch the status of engine queues.
1280 """Fetch the status of engine queues.
1281
1281
1282 Parameters
1282 Parameters
1283 ----------
1283 ----------
1284
1284
1285 targets : int/str/list of ints/strs
1285 targets : int/str/list of ints/strs
1286 The engines whose states are to be queried.
1286 The engines whose states are to be queried.
1287 default : all
1287 default : all
1288 verbose : bool
1288 verbose : bool
1289 Whether to return lengths only, or lists of ids for each element
1289 Whether to return lengths only, or lists of ids for each element
1290 """
1290 """
1291 engine_ids = self._build_targets(targets)[1]
1291 engine_ids = self._build_targets(targets)[1]
1292 content = dict(targets=engine_ids, verbose=verbose)
1292 content = dict(targets=engine_ids, verbose=verbose)
1293 self.session.send(self._query_socket, "queue_request", content=content)
1293 self.session.send(self._query_socket, "queue_request", content=content)
1294 idents,msg = self.session.recv(self._query_socket, 0)
1294 idents,msg = self.session.recv(self._query_socket, 0)
1295 if self.debug:
1295 if self.debug:
1296 pprint(msg)
1296 pprint(msg)
1297 content = msg['content']
1297 content = msg['content']
1298 status = content.pop('status')
1298 status = content.pop('status')
1299 if status != 'ok':
1299 if status != 'ok':
1300 raise self._unwrap_exception(content)
1300 raise self._unwrap_exception(content)
1301 content = rekey(content)
1301 content = rekey(content)
1302 if isinstance(targets, int):
1302 if isinstance(targets, int):
1303 return content[targets]
1303 return content[targets]
1304 else:
1304 else:
1305 return content
1305 return content
1306
1306
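A minimal sketch, assuming a connected Client `rc` (the exact per-engine keys come from the Hub's reply):

    rc.queue_status()               # dict of counts keyed by engine id
    rc.queue_status(0)              # just engine 0, returned unwrapped (int target)
    rc.queue_status(verbose=True)   # lists of msg_ids instead of counts
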
1307 @spin_first
1307 @spin_first
1308 def purge_results(self, jobs=[], targets=[]):
1308 def purge_results(self, jobs=[], targets=[]):
1309 """Tell the Hub to forget results.
1309 """Tell the Hub to forget results.
1310
1310
1311 Individual results can be purged by msg_id, or the entire
1311 Individual results can be purged by msg_id, or the entire
1312 history of specific targets can be purged.
1312 history of specific targets can be purged.
1313
1313
1314 Use `purge_results('all')` to scrub everything from the Hub's db.
1314 Use `purge_results('all')` to scrub everything from the Hub's db.
1315
1315
1316 Parameters
1316 Parameters
1317 ----------
1317 ----------
1318
1318
1319 jobs : str or list of str or AsyncResult objects
1319 jobs : str or list of str or AsyncResult objects
1320 the msg_ids whose results should be forgotten.
1320 the msg_ids whose results should be forgotten.
1321 targets : int/str/list of ints/strs
1321 targets : int/str/list of ints/strs
1322 The targets, by int_id, whose entire history is to be purged.
1322 The targets, by int_id, whose entire history is to be purged.
1323
1323
1324 default : None
1324 default : None
1325 """
1325 """
1326 if not targets and not jobs:
1326 if not targets and not jobs:
1327 raise ValueError("Must specify at least one of `targets` and `jobs`")
1327 raise ValueError("Must specify at least one of `targets` and `jobs`")
1328 if targets:
1328 if targets:
1329 targets = self._build_targets(targets)[1]
1329 targets = self._build_targets(targets)[1]
1330
1330
1331 # construct msg_ids from jobs
1331 # construct msg_ids from jobs
1332 if jobs == 'all':
1332 if jobs == 'all':
1333 msg_ids = jobs
1333 msg_ids = jobs
1334 else:
1334 else:
1335 msg_ids = []
1335 msg_ids = []
1336 if isinstance(jobs, (basestring,AsyncResult)):
1336 if isinstance(jobs, (basestring,AsyncResult)):
1337 jobs = [jobs]
1337 jobs = [jobs]
1338 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1338 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1339 if bad_ids:
1339 if bad_ids:
1340 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1340 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1341 for j in jobs:
1341 for j in jobs:
1342 if isinstance(j, AsyncResult):
1342 if isinstance(j, AsyncResult):
1343 msg_ids.extend(j.msg_ids)
1343 msg_ids.extend(j.msg_ids)
1344 else:
1344 else:
1345 msg_ids.append(j)
1345 msg_ids.append(j)
1346
1346
1347 content = dict(engine_ids=targets, msg_ids=msg_ids)
1347 content = dict(engine_ids=targets, msg_ids=msg_ids)
1348 self.session.send(self._query_socket, "purge_request", content=content)
1348 self.session.send(self._query_socket, "purge_request", content=content)
1349 idents, msg = self.session.recv(self._query_socket, 0)
1349 idents, msg = self.session.recv(self._query_socket, 0)
1350 if self.debug:
1350 if self.debug:
1351 pprint(msg)
1351 pprint(msg)
1352 content = msg['content']
1352 content = msg['content']
1353 if content['status'] != 'ok':
1353 if content['status'] != 'ok':
1354 raise self._unwrap_exception(content)
1354 raise self._unwrap_exception(content)
1355
1355
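A minimal sketch, assuming a connected Client `rc` and an AsyncResult `ar`:

    rc.purge_results(jobs=ar)          # forget one submission's records
    rc.purge_results(targets=[0, 1])   # forget everything engines 0 and 1 ran
    rc.purge_results('all')            # scrub the Hub's entire db
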
1356 @spin_first
1356 @spin_first
1357 def hub_history(self):
1357 def hub_history(self):
1358 """Get the Hub's history
1358 """Get the Hub's history
1359
1359
1360 Just like the Client, the Hub has a history, which is a list of msg_ids.
1360 Just like the Client, the Hub has a history, which is a list of msg_ids.
1361 This will contain the history of all clients, and, depending on configuration,
1361 This will contain the history of all clients, and, depending on configuration,
1362 may contain history across multiple cluster sessions.
1362 may contain history across multiple cluster sessions.
1363
1363
1364 Any msg_id returned here is a valid argument to `get_result`.
1364 Any msg_id returned here is a valid argument to `get_result`.
1365
1365
1366 Returns
1366 Returns
1367 -------
1367 -------
1368
1368
1369 msg_ids : list of strs
1369 msg_ids : list of strs
1370 list of all msg_ids, ordered by task submission time.
1370 list of all msg_ids, ordered by task submission time.
1371 """
1371 """
1372
1372
1373 self.session.send(self._query_socket, "history_request", content={})
1373 self.session.send(self._query_socket, "history_request", content={})
1374 idents, msg = self.session.recv(self._query_socket, 0)
1374 idents, msg = self.session.recv(self._query_socket, 0)
1375
1375
1376 if self.debug:
1376 if self.debug:
1377 pprint(msg)
1377 pprint(msg)
1378 content = msg['content']
1378 content = msg['content']
1379 if content['status'] != 'ok':
1379 if content['status'] != 'ok':
1380 raise self._unwrap_exception(content)
1380 raise self._unwrap_exception(content)
1381 else:
1381 else:
1382 return content['history']
1382 return content['history']
1383
1383
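A minimal sketch, assuming a connected Client `rc`:

    msg_ids = rc.hub_history()        # every msg_id the Hub knows about
    ar = rc.get_result(msg_ids[-1])   # any of them is a valid get_result argument
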
1384 @spin_first
1384 @spin_first
1385 def db_query(self, query, keys=None):
1385 def db_query(self, query, keys=None):
1386 """Query the Hub's TaskRecord database
1386 """Query the Hub's TaskRecord database
1387
1387
1388 This will return a list of task record dicts that match `query`
1388 This will return a list of task record dicts that match `query`
1389
1389
1390 Parameters
1390 Parameters
1391 ----------
1391 ----------
1392
1392
1393 query : mongodb query dict
1393 query : mongodb query dict
1394 The search dict. See mongodb query docs for details.
1394 The search dict. See mongodb query docs for details.
1395 keys : list of strs [optional]
1395 keys : list of strs [optional]
1396 The subset of keys to be returned. The default is to fetch everything but buffers.
1396 The subset of keys to be returned. The default is to fetch everything but buffers.
1397 'msg_id' will *always* be included.
1397 'msg_id' will *always* be included.
1398 """
1398 """
1399 if isinstance(keys, basestring):
1399 if isinstance(keys, basestring):
1400 keys = [keys]
1400 keys = [keys]
1401 content = dict(query=query, keys=keys)
1401 content = dict(query=query, keys=keys)
1402 self.session.send(self._query_socket, "db_request", content=content)
1402 self.session.send(self._query_socket, "db_request", content=content)
1403 idents, msg = self.session.recv(self._query_socket, 0)
1403 idents, msg = self.session.recv(self._query_socket, 0)
1404 if self.debug:
1404 if self.debug:
1405 pprint(msg)
1405 pprint(msg)
1406 content = msg['content']
1406 content = msg['content']
1407 if content['status'] != 'ok':
1407 if content['status'] != 'ok':
1408 raise self._unwrap_exception(content)
1408 raise self._unwrap_exception(content)
1409
1409
1410 records = content['records']
1410 records = content['records']
1411
1411
1412 buffer_lens = content['buffer_lens']
1412 buffer_lens = content['buffer_lens']
1413 result_buffer_lens = content['result_buffer_lens']
1413 result_buffer_lens = content['result_buffer_lens']
1414 buffers = msg['buffers']
1414 buffers = msg['buffers']
1415 has_bufs = buffer_lens is not None
1415 has_bufs = buffer_lens is not None
1416 has_rbufs = result_buffer_lens is not None
1416 has_rbufs = result_buffer_lens is not None
1417 for i,rec in enumerate(records):
1417 for i,rec in enumerate(records):
1418 # relink buffers
1418 # relink buffers
1419 if has_bufs:
1419 if has_bufs:
1420 blen = buffer_lens[i]
1420 blen = buffer_lens[i]
1421 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1421 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1422 if has_rbufs:
1422 if has_rbufs:
1423 blen = result_buffer_lens[i]
1423 blen = result_buffer_lens[i]
1424 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1424 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1425
1425
1426 return records
1426 return records
1427
1427
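A minimal sketch, assuming a connected Client `rc`; the record keys ('completed', 'engine_uuid', ...) are the TaskRecord fields defined in hub.py further down this diff:

    recs = rc.db_query({'completed': {'$ne': None}},
                       keys=['msg_id', 'completed', 'engine_uuid'])
    pending = rc.db_query({'completed': None}, keys=['msg_id'])
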
1428 __all__ = [ 'Client' ]
1428 __all__ = [ 'Client' ]
@@ -1,173 +1,173 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
4 and hearts are tracked based on their XREQ identities.
4 and hearts are tracked based on their XREQ identities.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 from __future__ import print_function
17 from __future__ import print_function
18 import time
18 import time
19 import uuid
19 import uuid
20
20
21 import zmq
21 import zmq
22 from zmq.devices import ThreadDevice
22 from zmq.devices import ThreadDevice
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 from IPython.config.configurable import LoggingConfigurable
25 from IPython.config.configurable import LoggingConfigurable
26 from IPython.utils.traitlets import Set, Instance, CFloat
26 from IPython.utils.traitlets import Set, Instance, CFloat
27
27
28 from IPython.parallel.util import ensure_bytes
28 from IPython.parallel.util import asbytes
29
29
30 class Heart(object):
30 class Heart(object):
31 """A basic heart object for responding to a HeartMonitor.
31 """A basic heart object for responding to a HeartMonitor.
32 This is a simple wrapper with defaults for the most common
32 This is a simple wrapper with defaults for the most common
33 Device model for responding to heartbeats.
33 Device model for responding to heartbeats.
34
34
35 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
35 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
36 SUB/XREQ for in/out.
36 SUB/XREQ for in/out.
37
37
38 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
38 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
39 device=None
39 device=None
40 id=None
40 id=None
41 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
41 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
42 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
42 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
43 self.device.daemon=True
43 self.device.daemon=True
44 self.device.connect_in(in_addr)
44 self.device.connect_in(in_addr)
45 self.device.connect_out(out_addr)
45 self.device.connect_out(out_addr)
46 if in_type == zmq.SUB:
46 if in_type == zmq.SUB:
47 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
47 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
48 if heart_id is None:
48 if heart_id is None:
49 heart_id = uuid.uuid4().bytes
49 heart_id = uuid.uuid4().bytes
50 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
50 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
51 self.id = heart_id
51 self.id = heart_id
52
52
53 def start(self):
53 def start(self):
54 return self.device.start()
54 return self.device.start()
55
55
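A minimal sketch of standing up a Heart against a running monitor, assuming the monitor's PUB/XREP pair is bound to the two addresses used in the __main__ block at the bottom of this file:

    heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556')
    heart.start()   # the forwarder thread now echoes pings back under its IDENTITY
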
56 class HeartMonitor(LoggingConfigurable):
56 class HeartMonitor(LoggingConfigurable):
57 """A basic HeartMonitor class
57 """A basic HeartMonitor class
58 pingstream: a PUB stream
58 pingstream: a PUB stream
59 pongstream: an XREP stream
59 pongstream: an XREP stream
60 period: the period of the heartbeat in milliseconds"""
60 period: the period of the heartbeat in milliseconds"""
61
61
62 period=CFloat(1000, config=True,
62 period=CFloat(1000, config=True,
63 help='The period (in ms) at which the Hub pings the engines for heartbeats '
63 help='The period (in ms) at which the Hub pings the engines for heartbeats '
64 '[default: 1000]',
64 '[default: 1000]',
65 )
65 )
66
66
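Since period is a configurable trait, the usual way to adjust it is the controller's config file; a sketch, assuming the standard profile layout:

    # in ipcontroller_config.py
    c = get_config()
    c.HeartMonitor.period = 2000.0   # ping every 2 seconds instead of 1
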
67 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
67 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
68 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
68 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
69 loop = Instance('zmq.eventloop.ioloop.IOLoop')
69 loop = Instance('zmq.eventloop.ioloop.IOLoop')
70 def _loop_default(self):
70 def _loop_default(self):
71 return ioloop.IOLoop.instance()
71 return ioloop.IOLoop.instance()
72
72
73 # not settable:
73 # not settable:
74 hearts=Set()
74 hearts=Set()
75 responses=Set()
75 responses=Set()
76 on_probation=Set()
76 on_probation=Set()
77 last_ping=CFloat(0)
77 last_ping=CFloat(0)
78 _new_handlers = Set()
78 _new_handlers = Set()
79 _failure_handlers = Set()
79 _failure_handlers = Set()
80 lifetime = CFloat(0)
80 lifetime = CFloat(0)
81 tic = CFloat(0)
81 tic = CFloat(0)
82
82
83 def __init__(self, **kwargs):
83 def __init__(self, **kwargs):
84 super(HeartMonitor, self).__init__(**kwargs)
84 super(HeartMonitor, self).__init__(**kwargs)
85
85
86 self.pongstream.on_recv(self.handle_pong)
86 self.pongstream.on_recv(self.handle_pong)
87
87
88 def start(self):
88 def start(self):
89 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
89 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
90 self.caller.start()
90 self.caller.start()
91
91
92 def add_new_heart_handler(self, handler):
92 def add_new_heart_handler(self, handler):
93 """add a new handler for new hearts"""
93 """add a new handler for new hearts"""
94 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
94 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
95 self._new_handlers.add(handler)
95 self._new_handlers.add(handler)
96
96
97 def add_heart_failure_handler(self, handler):
97 def add_heart_failure_handler(self, handler):
98 """add a new handler for heart failure"""
98 """add a new handler for heart failure"""
99 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
99 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
100 self._failure_handlers.add(handler)
100 self._failure_handlers.add(handler)
101
101
102 def beat(self):
102 def beat(self):
103 self.pongstream.flush()
103 self.pongstream.flush()
104 self.last_ping = self.lifetime
104 self.last_ping = self.lifetime
105
105
106 toc = time.time()
106 toc = time.time()
107 self.lifetime += toc-self.tic
107 self.lifetime += toc-self.tic
108 self.tic = toc
108 self.tic = toc
109 # self.log.debug("heartbeat::%s"%self.lifetime)
109 # self.log.debug("heartbeat::%s"%self.lifetime)
110 goodhearts = self.hearts.intersection(self.responses)
110 goodhearts = self.hearts.intersection(self.responses)
111 missed_beats = self.hearts.difference(goodhearts)
111 missed_beats = self.hearts.difference(goodhearts)
112 heartfailures = self.on_probation.intersection(missed_beats)
112 heartfailures = self.on_probation.intersection(missed_beats)
113 newhearts = self.responses.difference(goodhearts)
113 newhearts = self.responses.difference(goodhearts)
114 map(self.handle_new_heart, newhearts)
114 map(self.handle_new_heart, newhearts)
115 map(self.handle_heart_failure, heartfailures)
115 map(self.handle_heart_failure, heartfailures)
116 self.on_probation = missed_beats.intersection(self.hearts)
116 self.on_probation = missed_beats.intersection(self.hearts)
117 self.responses = set()
117 self.responses = set()
118 # print self.on_probation, self.hearts
118 # print self.on_probation, self.hearts
119 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
119 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
120 self.pingstream.send(ensure_bytes(str(self.lifetime)))
120 self.pingstream.send(asbytes(str(self.lifetime)))
121
121
122 def handle_new_heart(self, heart):
122 def handle_new_heart(self, heart):
123 if self._new_handlers:
123 if self._new_handlers:
124 for handler in self._new_handlers:
124 for handler in self._new_handlers:
125 handler(heart)
125 handler(heart)
126 else:
126 else:
127 self.log.info("heartbeat::yay, got new heart %s!"%heart)
127 self.log.info("heartbeat::yay, got new heart %s!"%heart)
128 self.hearts.add(heart)
128 self.hearts.add(heart)
129
129
130 def handle_heart_failure(self, heart):
130 def handle_heart_failure(self, heart):
131 if self._failure_handlers:
131 if self._failure_handlers:
132 for handler in self._failure_handlers:
132 for handler in self._failure_handlers:
133 try:
133 try:
134 handler(heart)
134 handler(heart)
135 except Exception as e:
135 except Exception as e:
136 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
136 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
137 pass
137 pass
138 else:
138 else:
139 self.log.info("heartbeat::Heart %s failed :("%heart)
139 self.log.info("heartbeat::Heart %s failed :("%heart)
140 self.hearts.remove(heart)
140 self.hearts.remove(heart)
141
141
142
142
143 def handle_pong(self, msg):
143 def handle_pong(self, msg):
144 "a heart just beat"
144 "a heart just beat"
145 current = ensure_bytes(str(self.lifetime))
145 current = asbytes(str(self.lifetime))
146 last = ensure_bytes(str(self.last_ping))
146 last = asbytes(str(self.last_ping))
147 if msg[1] == current:
147 if msg[1] == current:
148 delta = time.time()-self.tic
148 delta = time.time()-self.tic
149 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
149 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
150 self.responses.add(msg[0])
150 self.responses.add(msg[0])
151 elif msg[1] == last:
151 elif msg[1] == last:
152 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
152 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
153 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
153 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
154 self.responses.add(msg[0])
154 self.responses.add(msg[0])
155 else:
155 else:
156 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
156 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
157 (msg[1],self.lifetime))
157 (msg[1],self.lifetime))
158
158
159
159
160 if __name__ == '__main__':
160 if __name__ == '__main__':
161 loop = ioloop.IOLoop.instance()
161 loop = ioloop.IOLoop.instance()
162 context = zmq.Context()
162 context = zmq.Context()
163 pub = context.socket(zmq.PUB)
163 pub = context.socket(zmq.PUB)
164 pub.bind('tcp://127.0.0.1:5555')
164 pub.bind('tcp://127.0.0.1:5555')
165 xrep = context.socket(zmq.XREP)
165 xrep = context.socket(zmq.XREP)
166 xrep.bind('tcp://127.0.0.1:5556')
166 xrep.bind('tcp://127.0.0.1:5556')
167
167
168 outstream = zmqstream.ZMQStream(pub, loop)
168 outstream = zmqstream.ZMQStream(pub, loop)
169 instream = zmqstream.ZMQStream(xrep, loop)
169 instream = zmqstream.ZMQStream(xrep, loop)
170
170
171 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
171 hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
172 hb.start()
172 hb.start()
173
173
174 loop.start()
174 loop.start()
@@ -1,1291 +1,1291 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010 The IPython Development Team
11 # Copyright (C) 2010 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Imports
18 # Imports
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 from __future__ import print_function
20 from __future__ import print_function
21
21
22 import sys
22 import sys
23 import time
23 import time
24 from datetime import datetime
24 from datetime import datetime
25
25
26 import zmq
26 import zmq
27 from zmq.eventloop import ioloop
27 from zmq.eventloop import ioloop
28 from zmq.eventloop.zmqstream import ZMQStream
28 from zmq.eventloop.zmqstream import ZMQStream
29
29
30 # internal:
30 # internal:
31 from IPython.utils.importstring import import_item
31 from IPython.utils.importstring import import_item
32 from IPython.utils.traitlets import (
32 from IPython.utils.traitlets import (
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 )
34 )
35
35
36 from IPython.parallel import error, util
36 from IPython.parallel import error, util
37 from IPython.parallel.factory import RegistrationFactory
37 from IPython.parallel.factory import RegistrationFactory
38
38
39 from IPython.zmq.session import SessionFactory
39 from IPython.zmq.session import SessionFactory
40
40
41 from .heartmonitor import HeartMonitor
41 from .heartmonitor import HeartMonitor
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Code
44 # Code
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 def _passer(*args, **kwargs):
47 def _passer(*args, **kwargs):
48 return
48 return
49
49
50 def _printer(*args, **kwargs):
50 def _printer(*args, **kwargs):
51 print (args)
51 print (args)
52 print (kwargs)
52 print (kwargs)
53
53
54 def empty_record():
54 def empty_record():
55 """Return an empty dict with all record keys."""
55 """Return an empty dict with all record keys."""
56 return {
56 return {
57 'msg_id' : None,
57 'msg_id' : None,
58 'header' : None,
58 'header' : None,
59 'content': None,
59 'content': None,
60 'buffers': None,
60 'buffers': None,
61 'submitted': None,
61 'submitted': None,
62 'client_uuid' : None,
62 'client_uuid' : None,
63 'engine_uuid' : None,
63 'engine_uuid' : None,
64 'started': None,
64 'started': None,
65 'completed': None,
65 'completed': None,
66 'resubmitted': None,
66 'resubmitted': None,
67 'result_header' : None,
67 'result_header' : None,
68 'result_content' : None,
68 'result_content' : None,
69 'result_buffers' : None,
69 'result_buffers' : None,
70 'queue' : None,
70 'queue' : None,
71 'pyin' : None,
71 'pyin' : None,
72 'pyout': None,
72 'pyout': None,
73 'pyerr': None,
73 'pyerr': None,
74 'stdout': '',
74 'stdout': '',
75 'stderr': '',
75 'stderr': '',
76 }
76 }
77
77
78 def init_record(msg):
78 def init_record(msg):
79 """Initialize a TaskRecord based on a request."""
79 """Initialize a TaskRecord based on a request."""
80 header = msg['header']
80 header = msg['header']
81 return {
81 return {
82 'msg_id' : header['msg_id'],
82 'msg_id' : header['msg_id'],
83 'header' : header,
83 'header' : header,
84 'content': msg['content'],
84 'content': msg['content'],
85 'buffers': msg['buffers'],
85 'buffers': msg['buffers'],
86 'submitted': header['date'],
86 'submitted': header['date'],
87 'client_uuid' : None,
87 'client_uuid' : None,
88 'engine_uuid' : None,
88 'engine_uuid' : None,
89 'started': None,
89 'started': None,
90 'completed': None,
90 'completed': None,
91 'resubmitted': None,
91 'resubmitted': None,
92 'result_header' : None,
92 'result_header' : None,
93 'result_content' : None,
93 'result_content' : None,
94 'result_buffers' : None,
94 'result_buffers' : None,
95 'queue' : None,
95 'queue' : None,
96 'pyin' : None,
96 'pyin' : None,
97 'pyout': None,
97 'pyout': None,
98 'pyerr': None,
98 'pyerr': None,
99 'stdout': '',
99 'stdout': '',
100 'stderr': '',
100 'stderr': '',
101 }
101 }
102
102
103
103
104 class EngineConnector(HasTraits):
104 class EngineConnector(HasTraits):
105 """A simple object for accessing the various zmq connections of an object.
105 """A simple object for accessing the various zmq connections of an object.
106 Attributes are:
106 Attributes are:
107 id (int): engine ID
107 id (int): engine ID
108 uuid (str): uuid (unused?)
108 uuid (str): uuid (unused?)
109 queue (str): identity of queue's XREQ socket
109 queue (str): identity of queue's XREQ socket
110 registration (str): identity of registration XREQ socket
110 registration (str): identity of registration XREQ socket
111 heartbeat (str): identity of heartbeat XREQ socket
111 heartbeat (str): identity of heartbeat XREQ socket
112 """
112 """
113 id=Int(0)
113 id=Int(0)
114 queue=CBytes()
114 queue=CBytes()
115 control=CBytes()
115 control=CBytes()
116 registration=CBytes()
116 registration=CBytes()
117 heartbeat=CBytes()
117 heartbeat=CBytes()
118 pending=Set()
118 pending=Set()
119
119
120 class HubFactory(RegistrationFactory):
120 class HubFactory(RegistrationFactory):
121 """The Configurable for setting up a Hub."""
121 """The Configurable for setting up a Hub."""
122
122
123 # port-pairs for monitoredqueues:
123 # port-pairs for monitoredqueues:
124 hb = Tuple(Int,Int,config=True,
124 hb = Tuple(Int,Int,config=True,
125 help="""XREQ/SUB Port pair for Engine heartbeats""")
125 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 def _hb_default(self):
126 def _hb_default(self):
127 return tuple(util.select_random_ports(2))
127 return tuple(util.select_random_ports(2))
128
128
129 mux = Tuple(Int,Int,config=True,
129 mux = Tuple(Int,Int,config=True,
130 help="""Engine/Client Port pair for MUX queue""")
130 help="""Engine/Client Port pair for MUX queue""")
131
131
132 def _mux_default(self):
132 def _mux_default(self):
133 return tuple(util.select_random_ports(2))
133 return tuple(util.select_random_ports(2))
134
134
135 task = Tuple(Int,Int,config=True,
135 task = Tuple(Int,Int,config=True,
136 help="""Engine/Client Port pair for Task queue""")
136 help="""Engine/Client Port pair for Task queue""")
137 def _task_default(self):
137 def _task_default(self):
138 return tuple(util.select_random_ports(2))
138 return tuple(util.select_random_ports(2))
139
139
140 control = Tuple(Int,Int,config=True,
140 control = Tuple(Int,Int,config=True,
141 help="""Engine/Client Port pair for Control queue""")
141 help="""Engine/Client Port pair for Control queue""")
142
142
143 def _control_default(self):
143 def _control_default(self):
144 return tuple(util.select_random_ports(2))
144 return tuple(util.select_random_ports(2))
145
145
146 iopub = Tuple(Int,Int,config=True,
146 iopub = Tuple(Int,Int,config=True,
147 help="""Engine/Client Port pair for IOPub relay""")
147 help="""Engine/Client Port pair for IOPub relay""")
148
148
149 def _iopub_default(self):
149 def _iopub_default(self):
150 return tuple(util.select_random_ports(2))
150 return tuple(util.select_random_ports(2))
151
151
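All of these port pairs default to random free ports; a config sketch for pinning them instead, assuming ipcontroller_config.py:

    c = get_config()
    c.HubFactory.hb = (10100, 10101)    # heartbeat pair
    c.HubFactory.mux = (10200, 10201)   # client/engine MUX pair
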
152 # single ports:
152 # single ports:
153 mon_port = Int(config=True,
153 mon_port = Int(config=True,
154 help="""Monitor (SUB) port for queue traffic""")
154 help="""Monitor (SUB) port for queue traffic""")
155
155
156 def _mon_port_default(self):
156 def _mon_port_default(self):
157 return util.select_random_ports(1)[0]
157 return util.select_random_ports(1)[0]
158
158
159 notifier_port = Int(config=True,
159 notifier_port = Int(config=True,
160 help="""PUB port for sending engine status notifications""")
160 help="""PUB port for sending engine status notifications""")
161
161
162 def _notifier_port_default(self):
162 def _notifier_port_default(self):
163 return util.select_random_ports(1)[0]
163 return util.select_random_ports(1)[0]
164
164
165 engine_ip = Unicode('127.0.0.1', config=True,
165 engine_ip = Unicode('127.0.0.1', config=True,
166 help="IP on which to listen for engine connections. [default: loopback]")
166 help="IP on which to listen for engine connections. [default: loopback]")
167 engine_transport = Unicode('tcp', config=True,
167 engine_transport = Unicode('tcp', config=True,
168 help="0MQ transport for engine connections. [default: tcp]")
168 help="0MQ transport for engine connections. [default: tcp]")
169
169
170 client_ip = Unicode('127.0.0.1', config=True,
170 client_ip = Unicode('127.0.0.1', config=True,
171 help="IP on which to listen for client connections. [default: loopback]")
171 help="IP on which to listen for client connections. [default: loopback]")
172 client_transport = Unicode('tcp', config=True,
172 client_transport = Unicode('tcp', config=True,
173 help="0MQ transport for client connections. [default : tcp]")
173 help="0MQ transport for client connections. [default : tcp]")
174
174
175 monitor_ip = Unicode('127.0.0.1', config=True,
175 monitor_ip = Unicode('127.0.0.1', config=True,
176 help="IP on which to listen for monitor messages. [default: loopback]")
176 help="IP on which to listen for monitor messages. [default: loopback]")
177 monitor_transport = Unicode('tcp', config=True,
177 monitor_transport = Unicode('tcp', config=True,
178 help="0MQ transport for monitor messages. [default : tcp]")
178 help="0MQ transport for monitor messages. [default : tcp]")
179
179
180 monitor_url = Unicode('')
180 monitor_url = Unicode('')
181
181
182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 config=True, help="""The class to use for the DB backend""")
183 config=True, help="""The class to use for the DB backend""")
184
184
185 # not configurable
185 # not configurable
186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188
188
189 def _ip_changed(self, name, old, new):
189 def _ip_changed(self, name, old, new):
190 self.engine_ip = new
190 self.engine_ip = new
191 self.client_ip = new
191 self.client_ip = new
192 self.monitor_ip = new
192 self.monitor_ip = new
193 self._update_monitor_url()
193 self._update_monitor_url()
194
194
195 def _update_monitor_url(self):
195 def _update_monitor_url(self):
196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197
197
198 def _transport_changed(self, name, old, new):
198 def _transport_changed(self, name, old, new):
199 self.engine_transport = new
199 self.engine_transport = new
200 self.client_transport = new
200 self.client_transport = new
201 self.monitor_transport = new
201 self.monitor_transport = new
202 self._update_monitor_url()
202 self._update_monitor_url()
203
203
204 def __init__(self, **kwargs):
204 def __init__(self, **kwargs):
205 super(HubFactory, self).__init__(**kwargs)
205 super(HubFactory, self).__init__(**kwargs)
206 self._update_monitor_url()
206 self._update_monitor_url()
207
207
208
208
209 def construct(self):
209 def construct(self):
210 self.init_hub()
210 self.init_hub()
211
211
212 def start(self):
212 def start(self):
213 self.heartmonitor.start()
213 self.heartmonitor.start()
214 self.log.info("Heartmonitor started")
214 self.log.info("Heartmonitor started")
215
215
216 def init_hub(self):
216 def init_hub(self):
217 """construct"""
217 """construct"""
218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220
220
221 ctx = self.context
221 ctx = self.context
222 loop = self.loop
222 loop = self.loop
223
223
224 # Registrar socket
224 # Registrar socket
225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 q.bind(client_iface % self.regport)
226 q.bind(client_iface % self.regport)
227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 if self.client_ip != self.engine_ip:
228 if self.client_ip != self.engine_ip:
229 q.bind(engine_iface % self.regport)
        q.bind(engine_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.XREP)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                                  config=self.config, log=self.log)
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream
        r = ZMQStream(ctx.socket(zmq.XREQ), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)

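# A rough sketch of the engine-side "heart" that the HeartMonitor wired up
# above expects (equivalent in spirit to the Heart device in heartmonitor):
# subscribe to pings on the hb PUB port and echo each one back on the hb
# XREP port under a fixed identity. Names and URLs here are illustrative,
# not part of this module:
#
#   import zmq
#   ctx = zmq.Context()
#   sub = ctx.socket(zmq.SUB)
#   sub.setsockopt(zmq.SUBSCRIBE, b"")
#   sub.connect(ping_url)                       # engine_iface % hb[0]
#   xreq = ctx.socket(zmq.XREQ)
#   xreq.setsockopt(zmq.IDENTITY, b"heart-id")  # must stay stable per engine
#   xreq.connect(pong_url)                      # engine_iface % hb[1]
#   while True:
#       xreq.send(sub.recv())  # echo the ping; the Hub records the beat
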
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client query requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
        NotImplemented
    engine_info: dict of zmq connection information for engines to connect
        to the queues.
    client_info: dict of zmq connection information for clients to connect
        to the queues.
    """
    # internal data structures:
    ids=Set() # engine IDs
    keytable=Dict()
    by_ident=Dict()
    engines=Dict()
    clients=Dict()
    hearts=Dict()
    pending=Set()
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by engine_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids
    dead_engines=Set() # uuids of engines that have unregistered or died
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict()
    registration_timeout=Int()
    _idcounter=Int(0)

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()

    def __init__(self, **kwargs):
        """
        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
        }

        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
        }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")

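    # The two dicts above are the Hub's complete routing tables. A monitor
    # message arrives as a multipart [topic, idents..., msg parts...] and the
    # topic (e.g. b'in') selects the handler (save_queue_request); a client
    # query is a Session message whose msg_type selects the handler.
    # Illustrative shapes only:
    #
    #   monitor: [b'in', b'<queue id>', b'<client id>', <serialized msg>...]
    #   query:   {'msg_type': 'queue_request', 'content': {'targets': None}}
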
    @property
    def _next_id(self):
        """generate a new ID.

        No longer reuse old ids, just count from 0."""
        newid = self._idcounter
        self._idcounter += 1
        return newid
        # newid = 0
        # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
        # # print newid, self.ids, self.incoming_registrations
        # while newid in self.ids or newid in incoming:
        #     newid += 1
        # return newid

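    # Ids are therefore monotonically increasing per Hub: the first engine to
    # register gets id 0, the next id 1, and ids of engines that later die
    # are never reissued, so a given eid always names a single engine.
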
    #-----------------------------------------------------------------------------
    # message validation
    #-----------------------------------------------------------------------------

    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids"""
        if targets is None:
            # default to all
            targets = self.ids

        if isinstance(targets, (int,str,unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str,unicode)):
                t = self.by_ident.get(t, t)
            _targets.append(t)
        targets = _targets
        bad_targets = [ t for t in targets if t not in self.ids ]
        if bad_targets:
            raise IndexError("No Such Engine: %r"%bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets

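    # Accepted forms, sketched under the assumption that engines 0..3 are
    # registered (values illustrative):
    #
    #   self._validate_targets(None)           # -> [0, 1, 2, 3] (all engines)
    #   self._validate_targets(2)              # -> [2]
    #   self._validate_targets(['<queue id>']) # raw identity mapped via by_ident
    #   self._validate_targets([99])           # IndexError: No Such Engine: [99]
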
    #-----------------------------------------------------------------------------
    # dispatch methods (1 per stream)
    #-----------------------------------------------------------------------------


    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic."""
        self.log.debug("monitor traffic: %r"%msg[:2])
        switch = msg[0]
        try:
            idents, msg = self.session.feed_identities(msg[1:])
        except ValueError:
            idents=[]
        if not idents:
            self.log.error("Bad Monitor Message: %r"%msg)
            return
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Invalid monitor topic: %r"%switch)


    def dispatch_query(self, msg):
        """Route registration requests and queries from clients."""
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r"%msg)
            return
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r"%msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        # switch on message type:
        msg_type = msg['msg_type']
        self.log.info("client::client %r requested %r"%(client_id, msg_type))
        handler = self.query_handlers.get(msg_type, None)
        try:
            assert handler is not None, "Bad Message Type: %r"%msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        else:
            handler(idents, msg)

    def dispatch_db(self, msg):
        """not implemented"""
        raise NotImplementedError

    #---------------------------------------------------------------------------
    # handler methods (1 per event)
    #---------------------------------------------------------------------------

    #----------------------- Heartbeat --------------------------------------

    def handle_new_heart(self, heart):
        """handler to attach to heartbeater.
        Called when a new heart starts to beat.
        Triggers completion of registration."""
        self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
        if heart not in self.incoming_registrations:
            self.log.info("heartbeat::ignoring new heart: %r"%heart)
        else:
            self.finish_registration(heart)


    def handle_heart_failure(self, heart):
        """handler to attach to heartbeater.
        called when a previously registered heart fails to respond to beat request.
        triggers unregistration"""
        self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
        eid = self.hearts.get(heart, None)
        if eid is None:
            self.log.info("heartbeat::ignoring heart failure %r"%heart)
        else:
            # only look up the engine's queue id once we know it is registered
            queue = self.engines[eid].queue
            self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))

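    # Note that heart failure synthesizes the same message shape a voluntary
    # unregistration request would carry, e.g.
    # {'content': {'id': 3, 'queue': '<engine uuid>'}}, so both paths funnel
    # through unregister_engine.
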
    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered"%queue_id)
            self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)

        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)

    def save_queue_result(self, idents, msg):
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r"%idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r"%(
                    queue_id, client_id, msg), exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r"%msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)

    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task."""
        client_id = idents[0]

        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r"%(
                    client_id, msg), exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content and header, which should not change
                # and are not as expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
        except KeyError:
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

    def save_task_result(self, idents, msg):
        """save the result of a completed task."""
        client_id = idents[0]
        try:
            msg = self.session.unpack_message(msg)
        except Exception:
            self.log.error("task::invalid task result message sent to %r: %r"%(
                    client_id, msg), exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!"%msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        engine_uuid = header.get('engine', None)
        eid = self.by_ident.get(engine_uuid, None)

        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'engine_uuid': engine_uuid
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished"%msg_id)

    def save_task_destination(self, idents, msg):
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[util.asbytes(engine_uuid)]

        self.log.info("task::task %r arrived on %r"%(msg_id, eid))
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)

    def mia_task_request(self, idents, msg):
        raise NotImplementedError
        client_id = idents[0]
        # content = dict(mia=self.mia,status='ok')
        # self.session.send('mia_reply', content=content, idents=client_id)


    #--------------------- IOPub Traffic ------------------------------

    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db"""
        # print (topics)
        try:
            msg = self.session.unpack_message(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.error("iopub::invalid IOPub message: %r"%msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        else:
            d[msg_type] = content.get('data', '')

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)

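    # Stream output accumulates per record: two successive 'stream' messages
    # for one msg_id, with content {'name': 'stdout', 'data': 'a'} and then
    # {'name': 'stdout', 'data': 'b'}, leave the record's stdout == 'ab'.
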
    #-------------------------------------------------------------------------
    # Registration requests
    #-------------------------------------------------------------------------

    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %r connected"%client_id)
        content = dict(status='ok')
        content.update(self.client_info)
        jsonable = {}
        for k,v in self.keytable.iteritems():
            if v not in self.dead_engines:
                jsonable[str(k)] = v.decode('ascii')
        content['engines'] = jsonable
        self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)

    def register_engine(self, reg, msg):
        """Register a new engine, and create the socket(s) necessary."""
        content = msg['content']
        try:
            queue = util.asbytes(content['queue'])
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return
        heart = content.get('heartbeat', None)
        if heart:
            heart = util.asbytes(heart)
        eid = self._next_id
        # print (eid, queue, reg, heart)

        self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))

        content = dict(id=eid,status='ok')
        content.update(self.engine_info)
        # check if requesting available IDs:
        if queue in self.by_ident:
            try:
                raise KeyError("queue_id %r in use"%queue)
            except:
                content = error.wrap_exception()
                self.log.error("queue_id %r in use"%queue, exc_info=True)
        elif heart in self.hearts: # need to check unique hearts?
            try:
                raise KeyError("heart_id %r in use"%heart)
            except:
                self.log.error("heart_id %r in use"%heart, exc_info=True)
                content = error.wrap_exception()
        else:
            for h, pack in self.incoming_registrations.iteritems():
                if heart == h:
                    try:
                        raise KeyError("heart_id %r in use"%heart)
                    except:
                        self.log.error("heart_id %r in use"%heart, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif queue == pack[1]:
                    try:
                        raise KeyError("queue_id %r in use"%queue)
                    except:
                        self.log.error("queue_id %r in use"%queue, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = (eid,queue,reg[0],None)
                self.finish_registration(heart)
            else:
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
        else:
            self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
        return eid

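    # Expected shapes, for reference (values illustrative, not normative):
    #
    #   request content: {'queue': '<engine uuid>', 'heartbeat': '<heart id>'}
    #   ok reply:        {'id': 0, 'status': 'ok'} plus the engine connection
    #                    dict built in init_hub (control/mux/heartbeat/task/iopub)
    #   error reply:     wrap_exception() content, e.g. {'status': 'error',
    #                    'ename': 'KeyError', 'evalue': "queue_id ... in use", ...}
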
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave."""
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)"%eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, queue=uuid.decode('ascii'))
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)

    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            try:
                raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)

    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart."""
        try:
            (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
        if purge is not None:
            purge.stop()
        control = queue
        self.ids.add(eid)
        self.keytable[eid] = queue
        self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
                                control=control, heartbeat=heart)
        self.by_ident[queue] = eid
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i"%eid)

    def _purge_stalled_registration(self, heart):
        if heart in self.incoming_registrations:
            eid = self.incoming_registrations.pop(heart)[0]
            self.log.info("registration::purging stalled registration: %i"%eid)
        else:
            pass

    #-------------------------------------------------------------------------
    # Client Requests
    #-------------------------------------------------------------------------

    def shutdown_request(self, client_id, msg):
        """handle shutdown request."""
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()

    def _shutdown(self):
        self.log.info("hub::hub shutting down.")
        time.sleep(0.1)
        sys.exit(0)


    def check_load(self, client_id, msg):
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)

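    # Assuming engines 0 and 1 are registered, a load_reply might look like
    # {'status': 'ok', '0': 3, '1': 0}: per-engine counts of queued MUX jobs
    # plus pending tasks, keyed by the stringified engine id.
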
    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.
        if verbose: return the msg_ids
        else: return len of each type.
        keys: queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)"""
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
        # print (content)
        self.session.send(self.query, "queue_reply", content=content, ident=client_id)

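    # A non-verbose queue_reply, with illustrative counts:
    #
    #   {'status': 'ok',
    #    '0': {'queue': 2, 'completed': 10, 'tasks': 1},
    #    '1': {'queue': 0, 'completed': 12, 'tasks': 3},
    #    'unassigned': 0}
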
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism."""
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r"%pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

            if reply['status'] == 'ok':
                eids = content.get('engine_ids', [])
                for eid in eids:
                    if eid not in self.engines:
                        try:
                            raise IndexError("No such engine: %i"%eid)
                        except:
                            reply = error.wrap_exception()
                        break
                    uid = self.engines[eid].queue
                    try:
                        self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                    except Exception:
                        reply = error.wrap_exception()
                        break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)

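    # Illustrative request contents:
    #
    #   {'msg_ids': 'all'}                     # drop every completed record
    #   {'msg_ids': ['<id1>', '<id2>']}        # drop specific (non-pending) msgs
    #   {'msg_ids': [], 'engine_ids': [0, 1]}  # drop completed records per engine
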
1108 def resubmit_task(self, client_id, msg):
1108 def resubmit_task(self, client_id, msg):
1109 """Resubmit one or more tasks."""
1109 """Resubmit one or more tasks."""
1110 def finish(reply):
1110 def finish(reply):
1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112
1112
1113 content = msg['content']
1113 content = msg['content']
1114 msg_ids = content['msg_ids']
1114 msg_ids = content['msg_ids']
1115 reply = dict(status='ok')
1115 reply = dict(status='ok')
1116 try:
1116 try:
1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 'header', 'content', 'buffers'])
1118 'header', 'content', 'buffers'])
1119 except Exception:
1119 except Exception:
1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 return finish(error.wrap_exception())
1121 return finish(error.wrap_exception())
1122
1122
1123 # validate msg_ids
1123 # validate msg_ids
1124 found_ids = [ rec['msg_id'] for rec in records ]
1124 found_ids = [ rec['msg_id'] for rec in records ]
1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 if len(records) > len(msg_ids):
1126 if len(records) > len(msg_ids):
1127 try:
1127 try:
1128 raise RuntimeError("DB appears to be in an inconsistent state. "
1128 raise RuntimeError("DB appears to be in an inconsistent state. "
1129 "More matching records were found than should exist")
1129 "More matching records were found than should exist")
1130 except Exception:
1130 except Exception:
1131 return finish(error.wrap_exception())
1131 return finish(error.wrap_exception())
1132 elif len(records) < len(msg_ids):
1132 elif len(records) < len(msg_ids):
1133 missing = [ m for m in msg_ids if m not in found_ids ]
1133 missing = [ m for m in msg_ids if m not in found_ids ]
1134 try:
1134 try:
1135 raise KeyError("No such msg(s): %r"%missing)
1135 raise KeyError("No such msg(s): %r"%missing)
1136 except KeyError:
1136 except KeyError:
1137 return finish(error.wrap_exception())
1137 return finish(error.wrap_exception())
1138 elif invalid_ids:
1138 elif invalid_ids:
1139 msg_id = invalid_ids[0]
1139 msg_id = invalid_ids[0]
1140 try:
1140 try:
1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 except Exception:
1142 except Exception:
1143 return finish(error.wrap_exception())
1143 return finish(error.wrap_exception())
1144
1144
1145 # clear the existing records
1145 # clear the existing records
1146 now = datetime.now()
1146 now = datetime.now()
1147 rec = empty_record()
1147 rec = empty_record()
1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 rec['resubmitted'] = now
1149 rec['resubmitted'] = now
1150 rec['queue'] = 'task'
1150 rec['queue'] = 'task'
1151 rec['client_uuid'] = client_id[0]
1151 rec['client_uuid'] = client_id[0]
1152 try:
1152 try:
1153 for msg_id in msg_ids:
1153 for msg_id in msg_ids:
1154 self.all_completed.discard(msg_id)
1154 self.all_completed.discard(msg_id)
1155 self.db.update_record(msg_id, rec)
1155 self.db.update_record(msg_id, rec)
1156 except Exception:
1156 except Exception:
1157 self.log.error('db::db error updating record', exc_info=True)
1157 self.log.error('db::db error updating record', exc_info=True)
1158 reply = error.wrap_exception()
1158 reply = error.wrap_exception()
1159 else:
1159 else:
1160 # send the messages
1160 # send the messages
1161 for rec in records:
1161 for rec in records:
1162 header = rec['header']
1162 header = rec['header']
1163 # include resubmitted in header to prevent digest collision
1163 # include resubmitted in header to prevent digest collision
1164 header['resubmitted'] = now
1164 header['resubmitted'] = now
1165 msg = self.session.msg(header['msg_type'])
1165 msg = self.session.msg(header['msg_type'])
1166 msg['content'] = rec['content']
1166 msg['content'] = rec['content']
1167 msg['header'] = header
1167 msg['header'] = header
1168 msg['msg_id'] = rec['msg_id']
1168 msg['msg_id'] = rec['msg_id']
1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170
1170
1171 finish(dict(status='ok'))
1171 finish(dict(status='ok'))
1172
1172
1173
1173
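# Note on the 'resubmitted' header field stamped in above: a resubmitted
# task reuses the original msg_id, content, and buffers, so some changed
# header entry is needed to keep the re-sent message from hashing
# identically to the original submission (the "digest collision" the
# inline comment refers to).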
1174 def _extract_record(self, rec):
1174 def _extract_record(self, rec):
1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 io_dict = {}
1176 io_dict = {}
1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 io_dict[key] = rec[key]
1178 io_dict[key] = rec[key]
1179 content = { 'result_content': rec['result_content'],
1179 content = { 'result_content': rec['result_content'],
1180 'header': rec['header'],
1180 'header': rec['header'],
1181 'result_header' : rec['result_header'],
1181 'result_header' : rec['result_header'],
1182 'io' : io_dict,
1182 'io' : io_dict,
1183 }
1183 }
1184 if rec['result_buffers']:
1184 if rec['result_buffers']:
1185 buffers = map(bytes, rec['result_buffers'])
1185 buffers = map(bytes, rec['result_buffers'])
1186 else:
1186 else:
1187 buffers = []
1187 buffers = []
1188
1188
1189 return content, buffers
1189 return content, buffers
1190
1190
1191 def get_results(self, client_id, msg):
1191 def get_results(self, client_id, msg):
1192 """Get the result of 1 or more messages."""
1192 """Get the result of 1 or more messages."""
1193 content = msg['content']
1193 content = msg['content']
1194 msg_ids = sorted(set(content['msg_ids']))
1194 msg_ids = sorted(set(content['msg_ids']))
1195 statusonly = content.get('status_only', False)
1195 statusonly = content.get('status_only', False)
1196 pending = []
1196 pending = []
1197 completed = []
1197 completed = []
1198 content = dict(status='ok')
1198 content = dict(status='ok')
1199 content['pending'] = pending
1199 content['pending'] = pending
1200 content['completed'] = completed
1200 content['completed'] = completed
1201 buffers = []
1201 buffers = []
1202 if not statusonly:
1202 if not statusonly:
1203 try:
1203 try:
1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 # turn match list into dict, for faster lookup
1205 # turn match list into dict, for faster lookup
1206 records = {}
1206 records = {}
1207 for rec in matches:
1207 for rec in matches:
1208 records[rec['msg_id']] = rec
1208 records[rec['msg_id']] = rec
1209 except Exception:
1209 except Exception:
1210 content = error.wrap_exception()
1210 content = error.wrap_exception()
1211 self.session.send(self.query, "result_reply", content=content,
1211 self.session.send(self.query, "result_reply", content=content,
1212 parent=msg, ident=client_id)
1212 parent=msg, ident=client_id)
1213 return
1213 return
1214 else:
1214 else:
1215 records = {}
1215 records = {}
1216 for msg_id in msg_ids:
1216 for msg_id in msg_ids:
1217 if msg_id in self.pending:
1217 if msg_id in self.pending:
1218 pending.append(msg_id)
1218 pending.append(msg_id)
1219 elif msg_id in self.all_completed:
1219 elif msg_id in self.all_completed:
1220 completed.append(msg_id)
1220 completed.append(msg_id)
1221 if not statusonly:
1221 if not statusonly:
1222 c,bufs = self._extract_record(records[msg_id])
1222 c,bufs = self._extract_record(records[msg_id])
1223 content[msg_id] = c
1223 content[msg_id] = c
1224 buffers.extend(bufs)
1224 buffers.extend(bufs)
1225 elif msg_id in records:
1225 elif msg_id in records:
1226 if records[msg_id]['completed']:
1226 if records[msg_id]['completed']:
1227 completed.append(msg_id)
1227 completed.append(msg_id)
1228 c,bufs = self._extract_record(records[msg_id])
1228 c,bufs = self._extract_record(records[msg_id])
1229 content[msg_id] = c
1229 content[msg_id] = c
1230 buffers.extend(bufs)
1230 buffers.extend(bufs)
1231 else:
1231 else:
1232 pending.append(msg_id)
1232 pending.append(msg_id)
1233 else:
1233 else:
1234 try:
1234 try:
1235 raise KeyError('No such message: '+msg_id)
1235 raise KeyError('No such message: '+msg_id)
1236 except:
1236 except:
1237 content = error.wrap_exception()
1237 content = error.wrap_exception()
1238 break
1238 break
1239 self.session.send(self.query, "result_reply", content=content,
1239 self.session.send(self.query, "result_reply", content=content,
1240 parent=msg, ident=client_id,
1240 parent=msg, ident=client_id,
1241 buffers=buffers)
1241 buffers=buffers)
1242
1242
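# A hedged sketch of the result_reply content for a status-only query
# (keys taken from get_results above; the msg_ids are hypothetical):
#
#   {'status': 'ok', 'pending': ['msg_3'], 'completed': ['msg_1', 'msg_2']}
#
# When status_only is False, each completed msg_id additionally maps to
# the dict built by _extract_record, and its raw result buffers are
# appended, in order, to the reply's buffer list.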
1243 def get_history(self, client_id, msg):
1243 def get_history(self, client_id, msg):
1244 """Get a list of all msg_ids in our DB records"""
1244 """Get a list of all msg_ids in our DB records"""
1245 try:
1245 try:
1246 msg_ids = self.db.get_history()
1246 msg_ids = self.db.get_history()
1247 except Exception as e:
1247 except Exception as e:
1248 content = error.wrap_exception()
1248 content = error.wrap_exception()
1249 else:
1249 else:
1250 content = dict(status='ok', history=msg_ids)
1250 content = dict(status='ok', history=msg_ids)
1251
1251
1252 self.session.send(self.query, "history_reply", content=content,
1252 self.session.send(self.query, "history_reply", content=content,
1253 parent=msg, ident=client_id)
1253 parent=msg, ident=client_id)
1254
1254
1255 def db_query(self, client_id, msg):
1255 def db_query(self, client_id, msg):
1256 """Perform a raw query on the task record database."""
1256 """Perform a raw query on the task record database."""
1257 content = msg['content']
1257 content = msg['content']
1258 query = content.get('query', {})
1258 query = content.get('query', {})
1259 keys = content.get('keys', None)
1259 keys = content.get('keys', None)
1260 buffers = []
1260 buffers = []
1261 empty = list()
1261 empty = list()
1262 try:
1262 try:
1263 records = self.db.find_records(query, keys)
1263 records = self.db.find_records(query, keys)
1264 except Exception as e:
1264 except Exception as e:
1265 content = error.wrap_exception()
1265 content = error.wrap_exception()
1266 else:
1266 else:
1267 # extract buffers from reply content:
1267 # extract buffers from reply content:
1268 if keys is not None:
1268 if keys is not None:
1269 buffer_lens = [] if 'buffers' in keys else None
1269 buffer_lens = [] if 'buffers' in keys else None
1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1271 else:
1271 else:
1272 buffer_lens = []
1272 buffer_lens = []
1273 result_buffer_lens = []
1273 result_buffer_lens = []
1274
1274
1275 for rec in records:
1275 for rec in records:
1276 # buffers may be None, so double check
1276 # buffers may be None, so double check
1277 if buffer_lens is not None:
1277 if buffer_lens is not None:
1278 b = rec.pop('buffers', empty) or empty
1278 b = rec.pop('buffers', empty) or empty
1279 buffer_lens.append(len(b))
1279 buffer_lens.append(len(b))
1280 buffers.extend(b)
1280 buffers.extend(b)
1281 if result_buffer_lens is not None:
1281 if result_buffer_lens is not None:
1282 rb = rec.pop('result_buffers', empty) or empty
1282 rb = rec.pop('result_buffers', empty) or empty
1283 result_buffer_lens.append(len(rb))
1283 result_buffer_lens.append(len(rb))
1284 buffers.extend(rb)
1284 buffers.extend(rb)
1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1286 result_buffer_lens=result_buffer_lens)
1286 result_buffer_lens=result_buffer_lens)
1287 # self.log.debug (content)
1287 # self.log.debug (content)
1288 self.session.send(self.query, "db_reply", content=content,
1288 self.session.send(self.query, "db_reply", content=content,
1289 parent=msg, ident=client_id,
1289 parent=msg, ident=client_id,
1290 buffers=buffers)
1290 buffers=buffers)
1291
1291
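# Client-side reassembly sketch: db_reply flattens every record's buffers
# into one list, so a receiver needs buffer_lens to split them back per
# record. This helper is illustrative only, not part of the IPython API.
def _split_buffers(buffers, buffer_lens):
    per_record = []
    i = 0
    for n in buffer_lens:
        per_record.append(buffers[i:i+n])  # this record's n buffers
        i += n
    return per_record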
@@ -1,714 +1,714 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 from __future__ import print_function
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26
26
27 from datetime import datetime, timedelta
27 from datetime import datetime, timedelta
28 from random import randint, random
28 from random import randint, random
29 from types import FunctionType
29 from types import FunctionType
30
30
31 try:
31 try:
32 import numpy
32 import numpy
33 except ImportError:
33 except ImportError:
34 numpy = None
34 numpy = None
35
35
36 import zmq
36 import zmq
37 from zmq.eventloop import ioloop, zmqstream
37 from zmq.eventloop import ioloop, zmqstream
38
38
39 # local imports
39 # local imports
40 from IPython.external.decorator import decorator
40 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
41 from IPython.config.application import Application
42 from IPython.config.loader import Config
42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
44
44
45 from IPython.parallel import error
45 from IPython.parallel import error
46 from IPython.parallel.factory import SessionFactory
46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger, ensure_bytes
47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48
48
49 from .dependency import Dependency
49 from .dependency import Dependency
50
50
51 @decorator
51 @decorator
52 def logged(f,self,*args,**kwargs):
52 def logged(f,self,*args,**kwargs):
53 # print ("#--------------------")
53 # print ("#--------------------")
54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 # print ("#--")
55 # print ("#--")
56 return f(self,*args, **kwargs)
56 return f(self,*args, **kwargs)
57
57
58 #----------------------------------------------------------------------
58 #----------------------------------------------------------------------
59 # Chooser functions
59 # Chooser functions
60 #----------------------------------------------------------------------
60 #----------------------------------------------------------------------
61
61
62 def plainrandom(loads):
62 def plainrandom(loads):
63 """Plain random pick."""
63 """Plain random pick."""
64 n = len(loads)
64 n = len(loads)
65 return randint(0,n-1)
65 return randint(0,n-1)
66
66
67 def lru(loads):
67 def lru(loads):
68 """Always pick the front of the line.
68 """Always pick the front of the line.
69
69
70 The content of `loads` is ignored.
70 The content of `loads` is ignored.
71
71
72 Assumes LRU ordering of loads, with oldest first.
72 Assumes LRU ordering of loads, with oldest first.
73 """
73 """
74 return 0
74 return 0
75
75
76 def twobin(loads):
76 def twobin(loads):
77 """Pick two at random, use the LRU of the two.
77 """Pick two at random, use the LRU of the two.
78
78
79 The content of loads is ignored.
79 The content of loads is ignored.
80
80
81 Assumes LRU ordering of loads, with oldest first.
81 Assumes LRU ordering of loads, with oldest first.
82 """
82 """
83 n = len(loads)
83 n = len(loads)
84 a = randint(0,n-1)
84 a = randint(0,n-1)
85 b = randint(0,n-1)
85 b = randint(0,n-1)
86 return min(a,b)
86 return min(a,b)
87
87
88 def weighted(loads):
88 def weighted(loads):
89 """Pick two at random using inverse load as weight.
89 """Pick two at random using inverse load as weight.
90
90
91 Return the less loaded of the two.
91 Return the less loaded of the two.
92 """
92 """
93 # weight 0 a million times more than 1:
93 # weight 0 a million times more than 1:
94 weights = 1./(1e-6+numpy.array(loads))
94 weights = 1./(1e-6+numpy.array(loads))
95 sums = weights.cumsum()
95 sums = weights.cumsum()
96 t = sums[-1]
96 t = sums[-1]
97 x = random()*t
97 x = random()*t
98 y = random()*t
98 y = random()*t
99 idx = 0
99 idx = 0
100 idy = 0
100 idy = 0
101 while sums[idx] < x:
101 while sums[idx] < x:
102 idx += 1
102 idx += 1
103 while sums[idy] < y:
103 while sums[idy] < y:
104 idy += 1
104 idy += 1
105 if weights[idy] > weights[idx]:
105 if weights[idy] > weights[idx]:
106 return idy
106 return idy
107 else:
107 else:
108 return idx
108 return idx
109
109
110 def leastload(loads):
110 def leastload(loads):
111 """Always choose the lowest load.
111 """Always choose the lowest load.
112
112
113 If the lowest load occurs more than once, the first
113 If the lowest load occurs more than once, the first
114 occurrence will be used. If loads has LRU ordering, this means
114 occurrence will be used. If loads has LRU ordering, this means
115 the LRU of those with the lowest load is chosen.
115 the LRU of those with the lowest load is chosen.
116 """
116 """
117 return loads.index(min(loads))
117 return loads.index(min(loads))
118
118
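# All chooser functions share one signature, loads -> index, which is why
# they are interchangeable as the scheduler's `scheme`. A minimal sanity
# check (plain Python; numpy is only needed by `weighted`):
#
#   loads = [3, 0, 2, 5]
#   assert leastload(loads) == 1            # lowest load wins
#   assert lru(loads) == 0                  # always the head of the line
#   assert 0 <= twobin(loads) < len(loads)  # random, but always in range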
119 #---------------------------------------------------------------------
119 #---------------------------------------------------------------------
120 # Classes
120 # Classes
121 #---------------------------------------------------------------------
121 #---------------------------------------------------------------------
122 # store empty default dependency:
122 # store empty default dependency:
123 MET = Dependency([])
123 MET = Dependency([])
124
124
125 class TaskScheduler(SessionFactory):
125 class TaskScheduler(SessionFactory):
126 """Python TaskScheduler object.
126 """Python TaskScheduler object.
127
127
128 This is the simplest object that supports msg_id based
128 This is the simplest object that supports msg_id based
129 DAG dependencies. *Only* task msg_ids are checked, not
129 DAG dependencies. *Only* task msg_ids are checked, not
130 msg_ids of jobs submitted via the MUX queue.
130 msg_ids of jobs submitted via the MUX queue.
131
131
132 """
132 """
133
133
134 hwm = Int(0, config=True, shortname='hwm',
134 hwm = Int(0, config=True, shortname='hwm',
135 help="""specify the High Water Mark (HWM) for the downstream
135 help="""specify the High Water Mark (HWM) for the downstream
136 socket in the Task scheduler. This is the maximum number
136 socket in the Task scheduler. This is the maximum number
137 of allowed outstanding tasks on each engine."""
137 of allowed outstanding tasks on each engine."""
138 )
138 )
139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 'leastload', config=True, shortname='scheme', allow_none=False,
140 'leastload', config=True, shortname='scheme', allow_none=False,
141 help="""select the task scheduler scheme [default: Python LRU]
141 help="""select the task scheduler scheme [default: Python LRU]
142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
143 )
143 )
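# Hedged configuration sketch (standard traitlets config syntax; the exact
# profile file the user edits is an assumption here):
#
#   c.TaskScheduler.hwm = 1                   # at most 1 outstanding task per engine
#   c.TaskScheduler.scheme_name = 'weighted'  # any scheme listed above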
144 def _scheme_name_changed(self, old, new):
144 def _scheme_name_changed(self, old, new):
145 self.log.debug("Using scheme %r"%new)
145 self.log.debug("Using scheme %r"%new)
146 self.scheme = globals()[new]
146 self.scheme = globals()[new]
147
147
148 # input arguments:
148 # input arguments:
149 scheme = Instance(FunctionType) # function for determining the destination
149 scheme = Instance(FunctionType) # function for determining the destination
150 def _scheme_default(self):
150 def _scheme_default(self):
151 return leastload
151 return leastload
152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
156
156
157 # internals:
157 # internals:
158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
162 pending = Dict() # dict by engine_uuid of submitted tasks
162 pending = Dict() # dict by engine_uuid of submitted tasks
163 completed = Dict() # dict by engine_uuid of completed tasks
163 completed = Dict() # dict by engine_uuid of completed tasks
164 failed = Dict() # dict by engine_uuid of failed tasks
164 failed = Dict() # dict by engine_uuid of failed tasks
165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
166 clients = Dict() # dict by msg_id for who submitted the task
166 clients = Dict() # dict by msg_id for who submitted the task
167 targets = List() # list of target IDENTs
167 targets = List() # list of target IDENTs
168 loads = List() # list of engine loads
168 loads = List() # list of engine loads
169 # full = Set() # set of IDENTs that have HWM outstanding tasks
169 # full = Set() # set of IDENTs that have HWM outstanding tasks
170 all_completed = Set() # set of all completed tasks
170 all_completed = Set() # set of all completed tasks
171 all_failed = Set() # set of all failed tasks
171 all_failed = Set() # set of all failed tasks
172 all_done = Set() # set of all finished tasks=union(completed,failed)
172 all_done = Set() # set of all finished tasks=union(completed,failed)
173 all_ids = Set() # set of all submitted task IDs
173 all_ids = Set() # set of all submitted task IDs
174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176
176
177 ident = CBytes() # ZMQ identity. This should just be self.session.session
177 ident = CBytes() # ZMQ identity. This should just be self.session.session
178 # but ensure Bytes
178 # but ensure Bytes
179 def _ident_default(self):
179 def _ident_default(self):
180 return ensure_bytes(self.session.session)
180 return asbytes(self.session.session)
181
181
182 def start(self):
182 def start(self):
183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
184 self._notification_handlers = dict(
184 self._notification_handlers = dict(
185 registration_notification = self._register_engine,
185 registration_notification = self._register_engine,
186 unregistration_notification = self._unregister_engine
186 unregistration_notification = self._unregister_engine
187 )
187 )
188 self.notifier_stream.on_recv(self.dispatch_notification)
188 self.notifier_stream.on_recv(self.dispatch_notification)
189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s (0.5 Hz)
190 self.auditor.start()
190 self.auditor.start()
191 self.log.info("Scheduler started [%s]"%self.scheme_name)
191 self.log.info("Scheduler started [%s]"%self.scheme_name)
192
192
193 def resume_receiving(self):
193 def resume_receiving(self):
194 """Resume accepting jobs."""
194 """Resume accepting jobs."""
195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
196
196
197 def stop_receiving(self):
197 def stop_receiving(self):
198 """Stop accepting jobs while there are no engines.
198 """Stop accepting jobs while there are no engines.
199 Leave them in the ZMQ queue."""
199 Leave them in the ZMQ queue."""
200 self.client_stream.on_recv(None)
200 self.client_stream.on_recv(None)
201
201
202 #-----------------------------------------------------------------------
202 #-----------------------------------------------------------------------
203 # [Un]Registration Handling
203 # [Un]Registration Handling
204 #-----------------------------------------------------------------------
204 #-----------------------------------------------------------------------
205
205
206 def dispatch_notification(self, msg):
206 def dispatch_notification(self, msg):
207 """dispatch register/unregister events."""
207 """dispatch register/unregister events."""
208 try:
208 try:
209 idents,msg = self.session.feed_identities(msg)
209 idents,msg = self.session.feed_identities(msg)
210 except ValueError:
210 except ValueError:
211 self.log.warn("task::Invalid Message: %r",msg)
211 self.log.warn("task::Invalid Message: %r",msg)
212 return
212 return
213 try:
213 try:
214 msg = self.session.unpack_message(msg)
214 msg = self.session.unpack_message(msg)
215 except ValueError:
215 except ValueError:
216 self.log.warn("task::Unauthorized message from: %r"%idents)
216 self.log.warn("task::Unauthorized message from: %r"%idents)
217 return
217 return
218
218
219 msg_type = msg['msg_type']
219 msg_type = msg['msg_type']
220
220
221 handler = self._notification_handlers.get(msg_type, None)
221 handler = self._notification_handlers.get(msg_type, None)
222 if handler is None:
222 if handler is None:
223 self.log.error("Unhandled message type: %r"%msg_type)
223 self.log.error("Unhandled message type: %r"%msg_type)
224 else:
224 else:
225 try:
225 try:
226 handler(ensure_bytes(msg['content']['queue']))
226 handler(asbytes(msg['content']['queue']))
227 except Exception:
227 except Exception:
228 self.log.error("task::Invalid notification msg: %r",msg)
228 self.log.error("task::Invalid notification msg: %r",msg)
229
229
230 def _register_engine(self, uid):
230 def _register_engine(self, uid):
231 """New engine with ident `uid` became available."""
231 """New engine with ident `uid` became available."""
232 # head of the line:
232 # head of the line:
233 self.targets.insert(0,uid)
233 self.targets.insert(0,uid)
234 self.loads.insert(0,0)
234 self.loads.insert(0,0)
235
235
236 # initialize sets
236 # initialize sets
237 self.completed[uid] = set()
237 self.completed[uid] = set()
238 self.failed[uid] = set()
238 self.failed[uid] = set()
239 self.pending[uid] = {}
239 self.pending[uid] = {}
240 if len(self.targets) == 1:
240 if len(self.targets) == 1:
241 self.resume_receiving()
241 self.resume_receiving()
242 # rescan the graph:
242 # rescan the graph:
243 self.update_graph(None)
243 self.update_graph(None)
244
244
245 def _unregister_engine(self, uid):
245 def _unregister_engine(self, uid):
246 """Existing engine with ident `uid` became unavailable."""
246 """Existing engine with ident `uid` became unavailable."""
247 if len(self.targets) == 1:
247 if len(self.targets) == 1:
248 # this was our only engine
248 # this was our only engine
249 self.stop_receiving()
249 self.stop_receiving()
250
250
251 # handle any potentially finished tasks:
251 # handle any potentially finished tasks:
252 self.engine_stream.flush()
252 self.engine_stream.flush()
253
253
254 # don't pop destinations, because they might be used later
254 # don't pop destinations, because they might be used later
255 # map(self.destinations.pop, self.completed.pop(uid))
255 # map(self.destinations.pop, self.completed.pop(uid))
256 # map(self.destinations.pop, self.failed.pop(uid))
256 # map(self.destinations.pop, self.failed.pop(uid))
257
257
258 # prevent this engine from receiving work
258 # prevent this engine from receiving work
259 idx = self.targets.index(uid)
259 idx = self.targets.index(uid)
260 self.targets.pop(idx)
260 self.targets.pop(idx)
261 self.loads.pop(idx)
261 self.loads.pop(idx)
262
262
263 # wait 5 seconds before cleaning up pending jobs, since the results might
263 # wait 5 seconds before cleaning up pending jobs, since the results might
264 # still be incoming
264 # still be incoming
265 if self.pending[uid]:
265 if self.pending[uid]:
266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
267 dc.start()
267 dc.start()
268 else:
268 else:
269 self.completed.pop(uid)
269 self.completed.pop(uid)
270 self.failed.pop(uid)
270 self.failed.pop(uid)
271
271
272
272
273 def handle_stranded_tasks(self, engine):
273 def handle_stranded_tasks(self, engine):
274 """Deal with jobs resident in an engine that died."""
274 """Deal with jobs resident in an engine that died."""
275 lost = self.pending[engine]
275 lost = self.pending[engine]
276 for msg_id in lost.keys():
276 for msg_id in lost.keys():
277 if msg_id not in self.pending[engine]:
277 if msg_id not in self.pending[engine]:
278 # prevent double-handling of messages
278 # prevent double-handling of messages
279 continue
279 continue
280
280
281 raw_msg = lost[msg_id][0]
281 raw_msg = lost[msg_id][0]
282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
283 parent = self.session.unpack(msg[1].bytes)
283 parent = self.session.unpack(msg[1].bytes)
284 idents = [engine, idents[0]]
284 idents = [engine, idents[0]]
285
285
286 # build fake error reply
286 # build fake error reply
287 try:
287 try:
288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
289 except:
289 except:
290 content = error.wrap_exception()
290 content = error.wrap_exception()
291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
292 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
292 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
293 # and dispatch it
293 # and dispatch it
294 self.dispatch_result(raw_reply)
294 self.dispatch_result(raw_reply)
295
295
296 # finally scrub completed/failed lists
296 # finally scrub completed/failed lists
297 self.completed.pop(engine)
297 self.completed.pop(engine)
298 self.failed.pop(engine)
298 self.failed.pop(engine)
299
299
300
300
301 #-----------------------------------------------------------------------
301 #-----------------------------------------------------------------------
302 # Job Submission
302 # Job Submission
303 #-----------------------------------------------------------------------
303 #-----------------------------------------------------------------------
304 def dispatch_submission(self, raw_msg):
304 def dispatch_submission(self, raw_msg):
305 """Dispatch job submission to appropriate handlers."""
305 """Dispatch job submission to appropriate handlers."""
306 # ensure targets up to date:
306 # ensure targets up to date:
307 self.notifier_stream.flush()
307 self.notifier_stream.flush()
308 try:
308 try:
309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
310 msg = self.session.unpack_message(msg, content=False, copy=False)
310 msg = self.session.unpack_message(msg, content=False, copy=False)
311 except Exception:
311 except Exception:
312 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
312 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
313 return
313 return
314
314
315
315
316 # send to monitor
316 # send to monitor
317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
318
318
319 header = msg['header']
319 header = msg['header']
320 msg_id = header['msg_id']
320 msg_id = header['msg_id']
321 self.all_ids.add(msg_id)
321 self.all_ids.add(msg_id)
322
322
323 # get targets as a set of bytes objects
323 # get targets as a set of bytes objects
324 # from a list of unicode objects
324 # from a list of unicode objects
325 targets = header.get('targets', [])
325 targets = header.get('targets', [])
326 targets = map(ensure_bytes, targets)
326 targets = map(asbytes, targets)
327 targets = set(targets)
327 targets = set(targets)
328
328
329 retries = header.get('retries', 0)
329 retries = header.get('retries', 0)
330 self.retries[msg_id] = retries
330 self.retries[msg_id] = retries
331
331
332 # time dependencies
332 # time dependencies
333 after = header.get('after', None)
333 after = header.get('after', None)
334 if after:
334 if after:
335 after = Dependency(after)
335 after = Dependency(after)
336 if after.all:
336 if after.all:
337 if after.success:
337 if after.success:
338 after = Dependency(after.difference(self.all_completed),
338 after = Dependency(after.difference(self.all_completed),
339 success=after.success,
339 success=after.success,
340 failure=after.failure,
340 failure=after.failure,
341 all=after.all,
341 all=after.all,
342 )
342 )
343 if after.failure:
343 if after.failure:
344 after = Dependency(after.difference(self.all_failed),
344 after = Dependency(after.difference(self.all_failed),
345 success=after.success,
345 success=after.success,
346 failure=after.failure,
346 failure=after.failure,
347 all=after.all,
347 all=after.all,
348 )
348 )
349 if after.check(self.all_completed, self.all_failed):
349 if after.check(self.all_completed, self.all_failed):
350 # recast as empty set, if `after` already met,
350 # recast as empty set, if `after` already met,
351 # to prevent unnecessary set comparisons
351 # to prevent unnecessary set comparisons
352 after = MET
352 after = MET
353 else:
353 else:
354 after = MET
354 after = MET
355
355
356 # location dependencies
356 # location dependencies
357 follow = Dependency(header.get('follow', []))
357 follow = Dependency(header.get('follow', []))
358
358
359 # turn timeouts into datetime objects:
359 # turn timeouts into datetime objects:
360 timeout = header.get('timeout', None)
360 timeout = header.get('timeout', None)
361 if timeout:
361 if timeout:
362 timeout = datetime.now() + timedelta(0,timeout,0)
362 timeout = datetime.now() + timedelta(0,timeout,0)
363
363
364 args = [raw_msg, targets, after, follow, timeout]
364 args = [raw_msg, targets, after, follow, timeout]
365
365
366 # validate and reduce dependencies:
366 # validate and reduce dependencies:
367 for dep in after,follow:
367 for dep in after,follow:
368 if not dep: # empty dependency
368 if not dep: # empty dependency
369 continue
369 continue
370 # check valid:
370 # check valid:
371 if msg_id in dep or dep.difference(self.all_ids):
371 if msg_id in dep or dep.difference(self.all_ids):
372 self.depending[msg_id] = args
372 self.depending[msg_id] = args
373 return self.fail_unreachable(msg_id, error.InvalidDependency)
373 return self.fail_unreachable(msg_id, error.InvalidDependency)
374 # check if unreachable:
374 # check if unreachable:
375 if dep.unreachable(self.all_completed, self.all_failed):
375 if dep.unreachable(self.all_completed, self.all_failed):
376 self.depending[msg_id] = args
376 self.depending[msg_id] = args
377 return self.fail_unreachable(msg_id)
377 return self.fail_unreachable(msg_id)
378
378
379 if after.check(self.all_completed, self.all_failed):
379 if after.check(self.all_completed, self.all_failed):
380 # time deps already met, try to run
380 # time deps already met, try to run
381 if not self.maybe_run(msg_id, *args):
381 if not self.maybe_run(msg_id, *args):
382 # can't run yet
382 # can't run yet
383 if msg_id not in self.all_failed:
383 if msg_id not in self.all_failed:
384 # could have failed as unreachable
384 # could have failed as unreachable
385 self.save_unmet(msg_id, *args)
385 self.save_unmet(msg_id, *args)
386 else:
386 else:
387 self.save_unmet(msg_id, *args)
387 self.save_unmet(msg_id, *args)
388
388
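# A hedged sketch of the header fields dispatch_submission consumes above
# (key names from the code; the values are hypothetical):
#
#   header = {'msg_id' : 'abc123',
#             'targets': ['engine-uuid-1'],  # restrict to these engines
#             'after'  : ['msg_a'],          # time dependency: run once msg_a finishes
#             'follow' : ['msg_b'],          # location dependency: run where msg_b ran
#             'timeout': 30,                 # seconds before failing as unreachable
#             'retries': 2}                  # resubmissions allowed on failure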
389 def audit_timeouts(self):
389 def audit_timeouts(self):
390 """Audit all waiting tasks for expired timeouts."""
390 """Audit all waiting tasks for expired timeouts."""
391 now = datetime.now()
391 now = datetime.now()
392 for msg_id in self.depending.keys():
392 for msg_id in self.depending.keys():
393 # must recheck, in case one failure cascaded to another:
393 # must recheck, in case one failure cascaded to another:
394 if msg_id in self.depending:
394 if msg_id in self.depending:
395 raw,targets,after,follow,timeout = self.depending[msg_id]
395 raw,targets,after,follow,timeout = self.depending[msg_id]
396 if timeout and timeout < now:
396 if timeout and timeout < now:
397 self.fail_unreachable(msg_id, error.TaskTimeout)
397 self.fail_unreachable(msg_id, error.TaskTimeout)
398
398
399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
400 """a task has become unreachable, send a reply with an ImpossibleDependency
400 """a task has become unreachable, send a reply with an ImpossibleDependency
401 error."""
401 error."""
402 if msg_id not in self.depending:
402 if msg_id not in self.depending:
403 self.log.error("msg %r already failed!", msg_id)
403 self.log.error("msg %r already failed!", msg_id)
404 return
404 return
405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
406 for mid in follow.union(after):
406 for mid in follow.union(after):
407 if mid in self.graph:
407 if mid in self.graph:
408 self.graph[mid].remove(msg_id)
408 self.graph[mid].remove(msg_id)
409
409
410 # FIXME: unpacking a message I've already unpacked, but didn't save:
410 # FIXME: unpacking a message I've already unpacked, but didn't save:
411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
412 header = self.session.unpack(msg[1].bytes)
412 header = self.session.unpack(msg[1].bytes)
413
413
414 try:
414 try:
415 raise why()
415 raise why()
416 except:
416 except:
417 content = error.wrap_exception()
417 content = error.wrap_exception()
418
418
419 self.all_done.add(msg_id)
419 self.all_done.add(msg_id)
420 self.all_failed.add(msg_id)
420 self.all_failed.add(msg_id)
421
421
422 msg = self.session.send(self.client_stream, 'apply_reply', content,
422 msg = self.session.send(self.client_stream, 'apply_reply', content,
423 parent=header, ident=idents)
423 parent=header, ident=idents)
424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
425
425
426 self.update_graph(msg_id, success=False)
426 self.update_graph(msg_id, success=False)
427
427
428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
429 """check location dependencies, and run if they are met."""
429 """check location dependencies, and run if they are met."""
430 blacklist = self.blacklist.setdefault(msg_id, set())
430 blacklist = self.blacklist.setdefault(msg_id, set())
431 if follow or targets or blacklist or self.hwm:
431 if follow or targets or blacklist or self.hwm:
432 # we need a can_run filter
432 # we need a can_run filter
433 def can_run(idx):
433 def can_run(idx):
434 # check hwm
434 # check hwm
435 if self.hwm and self.loads[idx] == self.hwm:
435 if self.hwm and self.loads[idx] == self.hwm:
436 return False
436 return False
437 target = self.targets[idx]
437 target = self.targets[idx]
438 # check blacklist
438 # check blacklist
439 if target in blacklist:
439 if target in blacklist:
440 return False
440 return False
441 # check targets
441 # check targets
442 if targets and target not in targets:
442 if targets and target not in targets:
443 return False
443 return False
444 # check follow
444 # check follow
445 return follow.check(self.completed[target], self.failed[target])
445 return follow.check(self.completed[target], self.failed[target])
446
446
447 indices = filter(can_run, range(len(self.targets)))
447 indices = filter(can_run, range(len(self.targets)))
448
448
449 if not indices:
449 if not indices:
450 # couldn't run
450 # couldn't run
451 if follow.all:
451 if follow.all:
452 # check follow for impossibility
452 # check follow for impossibility
453 dests = set()
453 dests = set()
454 relevant = set()
454 relevant = set()
455 if follow.success:
455 if follow.success:
456 relevant = self.all_completed
456 relevant = self.all_completed
457 if follow.failure:
457 if follow.failure:
458 relevant = relevant.union(self.all_failed)
458 relevant = relevant.union(self.all_failed)
459 for m in follow.intersection(relevant):
459 for m in follow.intersection(relevant):
460 dests.add(self.destinations[m])
460 dests.add(self.destinations[m])
461 if len(dests) > 1:
461 if len(dests) > 1:
462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
463 self.fail_unreachable(msg_id)
463 self.fail_unreachable(msg_id)
464 return False
464 return False
465 if targets:
465 if targets:
466 # check blacklist+targets for impossibility
466 # check blacklist+targets for impossibility
467 targets.difference_update(blacklist)
467 targets.difference_update(blacklist)
468 if not targets or not targets.intersection(self.targets):
468 if not targets or not targets.intersection(self.targets):
469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
470 self.fail_unreachable(msg_id)
470 self.fail_unreachable(msg_id)
471 return False
471 return False
472 return False
472 return False
473 else:
473 else:
474 indices = None
474 indices = None
475
475
476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
477 return True
477 return True
478
478
479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
480 """Save a message for later submission when its dependencies are met."""
480 """Save a message for later submission when its dependencies are met."""
481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
482 # track the ids in follow or after, but not those already finished
482 # track the ids in follow or after, but not those already finished
483 for dep_id in after.union(follow).difference(self.all_done):
483 for dep_id in after.union(follow).difference(self.all_done):
484 if dep_id not in self.graph:
484 if dep_id not in self.graph:
485 self.graph[dep_id] = set()
485 self.graph[dep_id] = set()
486 self.graph[dep_id].add(msg_id)
486 self.graph[dep_id].add(msg_id)
487
487
488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
489 """Submit a task to any of a subset of our targets."""
489 """Submit a task to any of a subset of our targets."""
490 if indices:
490 if indices:
491 loads = [self.loads[i] for i in indices]
491 loads = [self.loads[i] for i in indices]
492 else:
492 else:
493 loads = self.loads
493 loads = self.loads
494 idx = self.scheme(loads)
494 idx = self.scheme(loads)
495 if indices:
495 if indices:
496 idx = indices[idx]
496 idx = indices[idx]
497 target = self.targets[idx]
497 target = self.targets[idx]
498 # print (target, map(str, msg[:3]))
498 # print (target, map(str, msg[:3]))
499 # send job to the engine
499 # send job to the engine
500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
501 self.engine_stream.send_multipart(raw_msg, copy=False)
501 self.engine_stream.send_multipart(raw_msg, copy=False)
502 # update load
502 # update load
503 self.add_job(idx)
503 self.add_job(idx)
504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
505 # notify Hub
505 # notify Hub
506 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
506 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
507 self.session.send(self.mon_stream, 'task_destination', content=content,
507 self.session.send(self.mon_stream, 'task_destination', content=content,
508 ident=[b'tracktask',self.ident])
508 ident=[b'tracktask',self.ident])
509
509
510
510
511 #-----------------------------------------------------------------------
511 #-----------------------------------------------------------------------
512 # Result Handling
512 # Result Handling
513 #-----------------------------------------------------------------------
513 #-----------------------------------------------------------------------
514 def dispatch_result(self, raw_msg):
514 def dispatch_result(self, raw_msg):
515 """dispatch method for result replies"""
515 """dispatch method for result replies"""
516 try:
516 try:
517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
518 msg = self.session.unpack_message(msg, content=False, copy=False)
518 msg = self.session.unpack_message(msg, content=False, copy=False)
519 engine = idents[0]
519 engine = idents[0]
520 try:
520 try:
521 idx = self.targets.index(engine)
521 idx = self.targets.index(engine)
522 except ValueError:
522 except ValueError:
523 pass # skip load-update for dead engines
523 pass # skip load-update for dead engines
524 else:
524 else:
525 self.finish_job(idx)
525 self.finish_job(idx)
526 except Exception:
526 except Exception:
527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
528 return
528 return
529
529
530 header = msg['header']
530 header = msg['header']
531 parent = msg['parent_header']
531 parent = msg['parent_header']
532 if header.get('dependencies_met', True):
532 if header.get('dependencies_met', True):
533 success = (header['status'] == 'ok')
533 success = (header['status'] == 'ok')
534 msg_id = parent['msg_id']
534 msg_id = parent['msg_id']
535 retries = self.retries[msg_id]
535 retries = self.retries[msg_id]
536 if not success and retries > 0:
536 if not success and retries > 0:
537 # failed
537 # failed
538 self.retries[msg_id] = retries - 1
538 self.retries[msg_id] = retries - 1
539 self.handle_unmet_dependency(idents, parent)
539 self.handle_unmet_dependency(idents, parent)
540 else:
540 else:
541 del self.retries[msg_id]
541 del self.retries[msg_id]
542 # relay to client and update graph
542 # relay to client and update graph
543 self.handle_result(idents, parent, raw_msg, success)
543 self.handle_result(idents, parent, raw_msg, success)
544 # send to Hub monitor
544 # send to Hub monitor
545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
546 else:
546 else:
547 self.handle_unmet_dependency(idents, parent)
547 self.handle_unmet_dependency(idents, parent)
548
548
549 def handle_result(self, idents, parent, raw_msg, success=True):
549 def handle_result(self, idents, parent, raw_msg, success=True):
550 """handle a real task result, either success or failure"""
550 """handle a real task result, either success or failure"""
551 # first, relay result to client
551 # first, relay result to client
552 engine = idents[0]
552 engine = idents[0]
553 client = idents[1]
553 client = idents[1]
554 # swap_ids for XREP-XREP mirror
554 # swap_ids for XREP-XREP mirror
555 raw_msg[:2] = [client,engine]
555 raw_msg[:2] = [client,engine]
556 # print (map(str, raw_msg[:4]))
556 # print (map(str, raw_msg[:4]))
557 self.client_stream.send_multipart(raw_msg, copy=False)
557 self.client_stream.send_multipart(raw_msg, copy=False)
558 # now, update our data structures
558 # now, update our data structures
559 msg_id = parent['msg_id']
559 msg_id = parent['msg_id']
560 self.blacklist.pop(msg_id, None)
560 self.blacklist.pop(msg_id, None)
561 self.pending[engine].pop(msg_id)
561 self.pending[engine].pop(msg_id)
562 if success:
562 if success:
563 self.completed[engine].add(msg_id)
563 self.completed[engine].add(msg_id)
564 self.all_completed.add(msg_id)
564 self.all_completed.add(msg_id)
565 else:
565 else:
566 self.failed[engine].add(msg_id)
566 self.failed[engine].add(msg_id)
567 self.all_failed.add(msg_id)
567 self.all_failed.add(msg_id)
568 self.all_done.add(msg_id)
568 self.all_done.add(msg_id)
569 self.destinations[msg_id] = engine
569 self.destinations[msg_id] = engine
570
570
571 self.update_graph(msg_id, success)
571 self.update_graph(msg_id, success)
572
572
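# Identity-swap illustration for the XREP-XREP mirror above: an engine
# reply arrives prefixed [engine_ident, client_ident, ...frames...] and
# must leave the client-facing socket as [client_ident, engine_ident,
# ...frames...], so ZMQ routes it to the submitting client while the
# engine identity is preserved for bookkeeping.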
573 def handle_unmet_dependency(self, idents, parent):
573 def handle_unmet_dependency(self, idents, parent):
574 """handle an unmet dependency"""
574 """handle an unmet dependency"""
575 engine = idents[0]
575 engine = idents[0]
576 msg_id = parent['msg_id']
576 msg_id = parent['msg_id']
577
577
578 if msg_id not in self.blacklist:
578 if msg_id not in self.blacklist:
579 self.blacklist[msg_id] = set()
579 self.blacklist[msg_id] = set()
580 self.blacklist[msg_id].add(engine)
580 self.blacklist[msg_id].add(engine)
581
581
582 args = self.pending[engine].pop(msg_id)
582 args = self.pending[engine].pop(msg_id)
583 raw,targets,after,follow,timeout = args
583 raw,targets,after,follow,timeout = args
584
584
585 if self.blacklist[msg_id] == targets:
585 if self.blacklist[msg_id] == targets:
586 self.depending[msg_id] = args
586 self.depending[msg_id] = args
587 self.fail_unreachable(msg_id)
587 self.fail_unreachable(msg_id)
588 elif not self.maybe_run(msg_id, *args):
588 elif not self.maybe_run(msg_id, *args):
589 # resubmit failed
589 # resubmit failed
590 if msg_id not in self.all_failed:
590 if msg_id not in self.all_failed:
591 # put it back in our dependency tree
591 # put it back in our dependency tree
592 self.save_unmet(msg_id, *args)
592 self.save_unmet(msg_id, *args)
593
593
594 if self.hwm:
594 if self.hwm:
595 try:
595 try:
596 idx = self.targets.index(engine)
596 idx = self.targets.index(engine)
597 except ValueError:
597 except ValueError:
598 pass # skip load-update for dead engines
598 pass # skip load-update for dead engines
599 else:
599 else:
600 if self.loads[idx] == self.hwm-1:
600 if self.loads[idx] == self.hwm-1:
601 self.update_graph(None)
601 self.update_graph(None)
602
602
603
603
604
604
605 def update_graph(self, dep_id=None, success=True):
605 def update_graph(self, dep_id=None, success=True):
606 """dep_id just finished. Update our dependency
606 """dep_id just finished. Update our dependency
607 graph and submit any jobs that just became runnable.
607 graph and submit any jobs that just became runnable.
608
608
609 Called with dep_id=None to update entire graph for hwm, but without finishing
609 Called with dep_id=None to update entire graph for hwm, but without finishing
610 a task.
610 a task.
611 """
611 """
612 # print ("\n\n***********")
612 # print ("\n\n***********")
613 # pprint (dep_id)
613 # pprint (dep_id)
614 # pprint (self.graph)
614 # pprint (self.graph)
615 # pprint (self.depending)
615 # pprint (self.depending)
616 # pprint (self.all_completed)
616 # pprint (self.all_completed)
617 # pprint (self.all_failed)
617 # pprint (self.all_failed)
618 # print ("\n\n***********\n\n")
618 # print ("\n\n***********\n\n")
619 # update any jobs that depended on the dependency
619 # update any jobs that depended on the dependency
620 jobs = self.graph.pop(dep_id, [])
620 jobs = self.graph.pop(dep_id, [])
621
621
622 # recheck *all* jobs if
622 # recheck *all* jobs if
623 # a) we have HWM and an engine just became no longer full
623 # a) we have HWM and an engine just became no longer full
624 # or b) dep_id was given as None
624 # or b) dep_id was given as None
625 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
625 if dep_id is None or (self.hwm and any(load == self.hwm-1 for load in self.loads)):
626 jobs = self.depending.keys()
626 jobs = self.depending.keys()
627
627
628 for msg_id in jobs:
628 for msg_id in jobs:
629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
630
630
631 if after.unreachable(self.all_completed, self.all_failed)\
631 if after.unreachable(self.all_completed, self.all_failed)\
632 or follow.unreachable(self.all_completed, self.all_failed):
632 or follow.unreachable(self.all_completed, self.all_failed):
633 self.fail_unreachable(msg_id)
633 self.fail_unreachable(msg_id)
634
634
635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
637
637
638 self.depending.pop(msg_id)
638 self.depending.pop(msg_id)
639 for mid in follow.union(after):
639 for mid in follow.union(after):
640 if mid in self.graph:
640 if mid in self.graph:
641 self.graph[mid].remove(msg_id)
641 self.graph[mid].remove(msg_id)
642
642
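# Shape of the structures update_graph walks (illustrative ids):
#   graph     = {'dep_id': set(['waiting_1', 'waiting_2'])}  # who waits on whom
#   depending = {'waiting_1': [raw_msg, targets, after, follow, timeout]}
# When 'dep_id' finishes, its graph entry is popped and each waiting job
# is retried via maybe_run(); with a HWM set, every waiting job is
# rechecked whenever an engine drops below the high-water mark.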
643 #----------------------------------------------------------------------
643 #----------------------------------------------------------------------
644 # methods to be overridden by subclasses
644 # methods to be overridden by subclasses
645 #----------------------------------------------------------------------
645 #----------------------------------------------------------------------
646
646
647 def add_job(self, idx):
647 def add_job(self, idx):
648 """Called after self.targets[idx] just got the job with header.
648 """Called after self.targets[idx] just got the job with header.
649 Override with subclasses. The default ordering is simple LRU.
649 Override with subclasses. The default ordering is simple LRU.
650 The default loads are the number of outstanding jobs."""
650 The default loads are the number of outstanding jobs."""
651 self.loads[idx] += 1
651 self.loads[idx] += 1
652 for lis in (self.targets, self.loads):
652 for lis in (self.targets, self.loads):
653 lis.append(lis.pop(idx))
653 lis.append(lis.pop(idx))
654
654
655
655
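# Rotation example for add_job: with targets [A, B, C], loads [0, 0, 0]
# and idx=0, the pop/append pair yields targets [B, C, A], loads [0, 0, 1],
# so lru() (which always returns index 0) next picks B: simple LRU order.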
656 def finish_job(self, idx):
656 def finish_job(self, idx):
657 """Called after self.targets[idx] just finished a job.
657 """Called after self.targets[idx] just finished a job.
658 Override with subclasses."""
658 Override with subclasses."""
659 self.loads[idx] -= 1
659 self.loads[idx] -= 1
660
660
661
661
662
662
663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
664 logname='root', log_url=None, loglevel=logging.DEBUG,
664 logname='root', log_url=None, loglevel=logging.DEBUG,
665 identity=b'task', in_thread=False):
665 identity=b'task', in_thread=False):
666
666
667 ZMQStream = zmqstream.ZMQStream
667 ZMQStream = zmqstream.ZMQStream
668
668
669 if config:
669 if config:
670 # unwrap dict back into Config
670 # unwrap dict back into Config
671 config = Config(config)
671 config = Config(config)
672
672
673 if in_thread:
673 if in_thread:
674 # use instance() to get the same Context/Loop as our parent
674 # use instance() to get the same Context/Loop as our parent
675 ctx = zmq.Context.instance()
675 ctx = zmq.Context.instance()
676 loop = ioloop.IOLoop.instance()
676 loop = ioloop.IOLoop.instance()
677 else:
677 else:
678 # in a process, don't use instance()
678 # in a process, don't use instance()
679 # for safety with multiprocessing
679 # for safety with multiprocessing
680 ctx = zmq.Context()
680 ctx = zmq.Context()
681 loop = ioloop.IOLoop()
681 loop = ioloop.IOLoop()
682 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
682 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
683 ins.setsockopt(zmq.IDENTITY, identity)
683 ins.setsockopt(zmq.IDENTITY, identity)
684 ins.bind(in_addr)
684 ins.bind(in_addr)
685
685
686 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
686 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
687 outs.setsockopt(zmq.IDENTITY, identity)
687 outs.setsockopt(zmq.IDENTITY, identity)
688 outs.bind(out_addr)
688 outs.bind(out_addr)
689 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
689 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
690 mons.connect(mon_addr)
690 mons.connect(mon_addr)
691 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
691 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
692 nots.setsockopt(zmq.SUBSCRIBE, b'')
692 nots.setsockopt(zmq.SUBSCRIBE, b'')
693 nots.connect(not_addr)
693 nots.connect(not_addr)
694
694
695 # setup logging.
695 # setup logging.
696 if in_thread:
696 if in_thread:
697 log = Application.instance().log
697 log = Application.instance().log
698 else:
698 else:
699 if log_url:
699 if log_url:
700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
701 else:
701 else:
702 log = local_logger(logname, loglevel)
702 log = local_logger(logname, loglevel)
703
703
704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
705 mon_stream=mons, notifier_stream=nots,
705 mon_stream=mons, notifier_stream=nots,
706 loop=loop, log=log,
706 loop=loop, log=log,
707 config=config)
707 config=config)
708 scheduler.start()
708 scheduler.start()
709 if not in_thread:
709 if not in_thread:
710 try:
710 try:
711 loop.start()
711 loop.start()
712 except KeyboardInterrupt:
712 except KeyboardInterrupt:
713 print ("interrupted, exiting...", file=sys.__stderr__)
713 print ("interrupted, exiting...", file=sys.__stderr__)
714
714
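A minimal sketch of driving this entry point from a child process, as the controller app does (the addresses here are hypothetical placeholders):

    from multiprocessing import Process

    p = Process(target=launch_scheduler,
                args=('tcp://127.0.0.1:5555',   # in_addr: client-facing XREP
                      'tcp://127.0.0.1:5556',   # out_addr: engine-facing XREP
                      'tcp://127.0.0.1:5557',   # mon_addr: PUB to the monitor
                      'tcp://127.0.0.1:5558'),  # not_addr: SUB for notifications
                kwargs=dict(identity=b'task', in_thread=False))
    p.daemon = True
    p.start()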
@@ -1,401 +1,400 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import json
14 import json
15 import os
15 import os
16 import cPickle as pickle
16 import cPickle as pickle
17 from datetime import datetime
17 from datetime import datetime
18
18
19 import sqlite3
19 import sqlite3
20
20
21 from zmq.eventloop import ioloop
21 from zmq.eventloop import ioloop
22
22
23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
24 from .dictdb import BaseDB
24 from .dictdb import BaseDB
25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # SQLite operators, adapters, and converters
28 # SQLite operators, adapters, and converters
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30
30
31 try:
31 try:
32 buffer
32 buffer
33 except NameError:
33 except NameError:
34 # py3k
34 # py3k
35 buffer = memoryview
35 buffer = memoryview
36
36
37 operators = {
37 operators = {
38 '$lt' : "<",
38 '$lt' : "<",
39 '$gt' : ">",
39 '$gt' : ">",
40 # null is handled weird with ==,!=
40 # null is handled weird with ==,!=
41 '$eq' : "=",
41 '$eq' : "=",
42 '$ne' : "!=",
42 '$ne' : "!=",
43 '$lte': "<=",
43 '$lte': "<=",
44 '$gte': ">=",
44 '$gte': ">=",
45 '$in' : ('=', ' OR '),
45 '$in' : ('=', ' OR '),
46 '$nin': ('!=', ' AND '),
46 '$nin': ('!=', ' AND '),
47 # '$all': None,
47 # '$all': None,
48 # '$mod': None,
48 # '$mod': None,
49 # '$exists' : None
49 # '$exists' : None
50 }
50 }
51 null_operators = {
51 null_operators = {
52 '=' : "IS NULL",
52 '=' : "IS NULL",
53 '!=' : "IS NOT NULL",
53 '!=' : "IS NOT NULL",
54 }
54 }
55
55
56 def _adapt_dict(d):
56 def _adapt_dict(d):
57 return json.dumps(d, default=date_default)
57 return json.dumps(d, default=date_default)
58
58
59 def _convert_dict(ds):
59 def _convert_dict(ds):
60 if ds is None:
60 if ds is None:
61 return ds
61 return ds
62 else:
62 else:
63 if isinstance(ds, bytes):
63 if isinstance(ds, bytes):
64 # If I understand the sqlite doc correctly, this will always be utf8
64 # If I understand the sqlite doc correctly, this will always be utf8
65 ds = ds.decode('utf8')
65 ds = ds.decode('utf8')
66 d = json.loads(ds)
66 return extract_dates(json.loads(ds))
67 return extract_dates(d)
68
67
69 def _adapt_bufs(bufs):
68 def _adapt_bufs(bufs):
70 # this is *horrible*
69 # this is *horrible*
71 # copy buffers into single list and pickle it:
70 # copy buffers into single list and pickle it:
72 if bufs and isinstance(bufs[0], (bytes, buffer)):
71 if bufs and isinstance(bufs[0], (bytes, buffer)):
73 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
72 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
74 elif bufs:
73 elif bufs:
75 return bufs
74 return bufs
76 else:
75 else:
77 return None
76 return None
78
77
79 def _convert_bufs(bs):
78 def _convert_bufs(bs):
80 if bs is None:
79 if bs is None:
81 return []
80 return []
82 else:
81 else:
83 return pickle.loads(bytes(bs))
82 return pickle.loads(bytes(bs))
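To see how these adapter/converter pairs plug into sqlite3, a self-contained sketch (the 'demo' table is hypothetical; PARSE_DECLTYPES keys the converter off the first word of the declared column type, here 'dict'):

    import sqlite3
    sqlite3.register_adapter(dict, _adapt_dict)
    sqlite3.register_converter('dict', _convert_dict)
    db = sqlite3.connect(':memory:', detect_types=sqlite3.PARSE_DECLTYPES)
    db.execute("CREATE TABLE demo (header dict text)")
    db.execute("INSERT INTO demo VALUES (?)", ({'msg_id': 'abc'},))
    print(db.execute("SELECT header FROM demo").fetchone()[0])  # {u'msg_id': u'abc'}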
84
83
85 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
86 # SQLiteDB class
85 # SQLiteDB class
87 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
88
87
89 class SQLiteDB(BaseDB):
88 class SQLiteDB(BaseDB):
90 """SQLite3 TaskRecord backend."""
89 """SQLite3 TaskRecord backend."""
91
90
92 filename = Unicode('tasks.db', config=True,
91 filename = Unicode('tasks.db', config=True,
93 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
92 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
94 location = Unicode('', config=True,
93 location = Unicode('', config=True,
95 help="""The directory containing the sqlite task database. The default
94 help="""The directory containing the sqlite task database. The default
96 is to use the cluster_dir location.""")
95 is to use the cluster_dir location.""")
97 table = Unicode("", config=True,
96 table = Unicode("", config=True,
98 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
97 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
99 a new table will be created with the Hub's IDENT. Specifying the table will result
98 a new table will be created with the Hub's IDENT. Specifying the table will result
100 in tasks from previous sessions being available via Clients' db_query and
99 in tasks from previous sessions being available via Clients' db_query and
101 get_result methods.""")
100 get_result methods.""")
102
101
103 _db = Instance('sqlite3.Connection')
102 _db = Instance('sqlite3.Connection')
104 # the ordered list of column names
103 # the ordered list of column names
105 _keys = List(['msg_id' ,
104 _keys = List(['msg_id' ,
106 'header' ,
105 'header' ,
107 'content',
106 'content',
108 'buffers',
107 'buffers',
109 'submitted',
108 'submitted',
110 'client_uuid' ,
109 'client_uuid' ,
111 'engine_uuid' ,
110 'engine_uuid' ,
112 'started',
111 'started',
113 'completed',
112 'completed',
114 'resubmitted',
113 'resubmitted',
115 'result_header' ,
114 'result_header' ,
116 'result_content' ,
115 'result_content' ,
117 'result_buffers' ,
116 'result_buffers' ,
118 'queue' ,
117 'queue' ,
119 'pyin' ,
118 'pyin' ,
120 'pyout',
119 'pyout',
121 'pyerr',
120 'pyerr',
122 'stdout',
121 'stdout',
123 'stderr',
122 'stderr',
124 ])
123 ])
125 # sqlite datatypes for checking that db is current format
124 # sqlite datatypes for checking that db is current format
126 _types = Dict({'msg_id' : 'text' ,
125 _types = Dict({'msg_id' : 'text' ,
127 'header' : 'dict text',
126 'header' : 'dict text',
128 'content' : 'dict text',
127 'content' : 'dict text',
129 'buffers' : 'bufs blob',
128 'buffers' : 'bufs blob',
130 'submitted' : 'timestamp',
129 'submitted' : 'timestamp',
131 'client_uuid' : 'text',
130 'client_uuid' : 'text',
132 'engine_uuid' : 'text',
131 'engine_uuid' : 'text',
133 'started' : 'timestamp',
132 'started' : 'timestamp',
134 'completed' : 'timestamp',
133 'completed' : 'timestamp',
135 'resubmitted' : 'timestamp',
134 'resubmitted' : 'timestamp',
136 'result_header' : 'dict text',
135 'result_header' : 'dict text',
137 'result_content' : 'dict text',
136 'result_content' : 'dict text',
138 'result_buffers' : 'bufs blob',
137 'result_buffers' : 'bufs blob',
139 'queue' : 'text',
138 'queue' : 'text',
140 'pyin' : 'text',
139 'pyin' : 'text',
141 'pyout' : 'text',
140 'pyout' : 'text',
142 'pyerr' : 'text',
141 'pyerr' : 'text',
143 'stdout' : 'text',
142 'stdout' : 'text',
144 'stderr' : 'text',
143 'stderr' : 'text',
145 })
144 })
146
145
147 def __init__(self, **kwargs):
146 def __init__(self, **kwargs):
148 super(SQLiteDB, self).__init__(**kwargs)
147 super(SQLiteDB, self).__init__(**kwargs)
149 if not self.table:
148 if not self.table:
150 # use the session id, prefixed with _, since a table name may not start with a digit
149 # use the session id, prefixed with _, since a table name may not start with a digit
151 self.table = '_'+self.session.replace('-','_')
150 self.table = '_'+self.session.replace('-','_')
152 if not self.location:
151 if not self.location:
153 # get current profile
152 # get current profile
154 from IPython.core.application import BaseIPythonApplication
153 from IPython.core.application import BaseIPythonApplication
155 if BaseIPythonApplication.initialized():
154 if BaseIPythonApplication.initialized():
156 app = BaseIPythonApplication.instance()
155 app = BaseIPythonApplication.instance()
157 if app.profile_dir is not None:
156 if app.profile_dir is not None:
158 self.location = app.profile_dir.location
157 self.location = app.profile_dir.location
159 else:
158 else:
160 self.location = u'.'
159 self.location = u'.'
161 else:
160 else:
162 self.location = u'.'
161 self.location = u'.'
163 self._init_db()
162 self._init_db()
164
163
165 # register db commit as 2s periodic callback
164 # register db commit as 2s periodic callback
166 # to prevent clogging pipes
165 # to prevent clogging pipes
167 # assumes we are being run in a zmq ioloop app
166 # assumes we are being run in a zmq ioloop app
168 loop = ioloop.IOLoop.instance()
167 loop = ioloop.IOLoop.instance()
169 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
168 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
170 pc.start()
169 pc.start()
171
170
172 def _defaults(self, keys=None):
171 def _defaults(self, keys=None):
173 """create an empty record"""
172 """create an empty record"""
174 d = {}
173 d = {}
175 keys = self._keys if keys is None else keys
174 keys = self._keys if keys is None else keys
176 for key in keys:
175 for key in keys:
177 d[key] = None
176 d[key] = None
178 return d
177 return d
179
178
180 def _check_table(self):
179 def _check_table(self):
181 """Ensure that an incorrect table doesn't exist
180 """Ensure that an incorrect table doesn't exist
182
181
183 If a bad (old) table does exist, return False
182 If a bad (old) table does exist, return False
184 """
183 """
185 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
184 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
186 lines = cursor.fetchall()
185 lines = cursor.fetchall()
187 if not lines:
186 if not lines:
188 # table does not exist
187 # table does not exist
189 return True
188 return True
190 types = {}
189 types = {}
191 keys = []
190 keys = []
192 for line in lines:
191 for line in lines:
193 keys.append(line[1])
192 keys.append(line[1])
194 types[line[1]] = line[2]
193 types[line[1]] = line[2]
195 if self._keys != keys:
194 if self._keys != keys:
196 # key mismatch
195 # key mismatch
197 self.log.warn('keys mismatch')
196 self.log.warn('keys mismatch')
198 return False
197 return False
199 for key in self._keys:
198 for key in self._keys:
200 if types[key] != self._types[key]:
199 if types[key] != self._types[key]:
201 self.log.warn(
200 self.log.warn(
202 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
201 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
203 )
202 )
204 return False
203 return False
205 return True
204 return True
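For reference, each row returned by PRAGMA table_info has the shape (cid, name, type, notnull, dflt_value, pk), so line[1] is the column name and line[2] its declared type; an illustrative first row:

    # (0, u'msg_id', u'text', 0, None, 1)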
206
205
207 def _init_db(self):
206 def _init_db(self):
208 """Connect to the database and get new session number."""
207 """Connect to the database and get new session number."""
209 # register adapters
208 # register adapters
210 sqlite3.register_adapter(dict, _adapt_dict)
209 sqlite3.register_adapter(dict, _adapt_dict)
211 sqlite3.register_converter('dict', _convert_dict)
210 sqlite3.register_converter('dict', _convert_dict)
212 sqlite3.register_adapter(list, _adapt_bufs)
211 sqlite3.register_adapter(list, _adapt_bufs)
213 sqlite3.register_converter('bufs', _convert_bufs)
212 sqlite3.register_converter('bufs', _convert_bufs)
214 # connect to the db
213 # connect to the db
215 dbfile = os.path.join(self.location, self.filename)
214 dbfile = os.path.join(self.location, self.filename)
216 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
215 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
217 # isolation_level = None)#,
216 # isolation_level = None)#,
218 cached_statements=64)
217 cached_statements=64)
219 # print dir(self._db)
218 # print dir(self._db)
220 first_table = self.table
219 first_table = self.table
221 i=0
220 i=0
222 while not self._check_table():
221 while not self._check_table():
223 i+=1
222 i+=1
224 self.table = first_table+'_%i'%i
223 self.table = first_table+'_%i'%i
225 self.log.warn(
224 self.log.warn(
226 "Table %s exists and doesn't match db format, trying %s"%
225 "Table %s exists and doesn't match db format, trying %s"%
227 (first_table,self.table)
226 (first_table,self.table)
228 )
227 )
229
228
230 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
229 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
231 (msg_id text PRIMARY KEY,
230 (msg_id text PRIMARY KEY,
232 header dict text,
231 header dict text,
233 content dict text,
232 content dict text,
234 buffers bufs blob,
233 buffers bufs blob,
235 submitted timestamp,
234 submitted timestamp,
236 client_uuid text,
235 client_uuid text,
237 engine_uuid text,
236 engine_uuid text,
238 started timestamp,
237 started timestamp,
239 completed timestamp,
238 completed timestamp,
240 resubmitted timestamp,
239 resubmitted timestamp,
241 result_header dict text,
240 result_header dict text,
242 result_content dict text,
241 result_content dict text,
243 result_buffers bufs blob,
242 result_buffers bufs blob,
244 queue text,
243 queue text,
245 pyin text,
244 pyin text,
246 pyout text,
245 pyout text,
247 pyerr text,
246 pyerr text,
248 stdout text,
247 stdout text,
249 stderr text)
248 stderr text)
250 """%self.table)
249 """%self.table)
251 self._db.commit()
250 self._db.commit()
252
251
253 def _dict_to_list(self, d):
252 def _dict_to_list(self, d):
254 """turn a mongodb-style record dict into a list."""
253 """turn a mongodb-style record dict into a list."""
255
254
256 return [ d[key] for key in self._keys ]
255 return [ d[key] for key in self._keys ]
257
256
258 def _list_to_dict(self, line, keys=None):
257 def _list_to_dict(self, line, keys=None):
259 """Inverse of dict_to_list"""
258 """Inverse of dict_to_list"""
260 keys = self._keys if keys is None else keys
259 keys = self._keys if keys is None else keys
261 d = self._defaults(keys)
260 d = self._defaults(keys)
262 for key,value in zip(keys, line):
261 for key,value in zip(keys, line):
263 d[key] = value
262 d[key] = value
264
263
265 return d
264 return d
266
265
267 def _render_expression(self, check):
266 def _render_expression(self, check):
268 """Turn a mongodb-style search dict into an SQL query."""
267 """Turn a mongodb-style search dict into an SQL query."""
269 expressions = []
268 expressions = []
270 args = []
269 args = []
271
270
272 skeys = set(check.keys())
271 skeys = set(check.keys())
273 skeys.difference_update(set(self._keys))
272 skeys.difference_update(set(self._keys))
274 skeys.difference_update(set(['buffers', 'result_buffers']))
273 skeys.difference_update(set(['buffers', 'result_buffers']))
275 if skeys:
274 if skeys:
276 raise KeyError("Illegal testing key(s): %s"%skeys)
275 raise KeyError("Illegal testing key(s): %s"%skeys)
277
276
278 for name,sub_check in check.iteritems():
277 for name,sub_check in check.iteritems():
279 if isinstance(sub_check, dict):
278 if isinstance(sub_check, dict):
280 for test,value in sub_check.iteritems():
279 for test,value in sub_check.iteritems():
281 try:
280 try:
282 op = operators[test]
281 op = operators[test]
283 except KeyError:
282 except KeyError:
284 raise KeyError("Unsupported operator: %r"%test)
283 raise KeyError("Unsupported operator: %r"%test)
285 if isinstance(op, tuple):
284 if isinstance(op, tuple):
286 op, join = op
285 op, join = op
287
286
288 if value is None and op in null_operators:
287 if value is None and op in null_operators:
289 expr = "%s %s"%(name, null_operators[op])
288 expr = "%s %s"%(name, null_operators[op])
290 else:
289 else:
291 expr = "%s %s ?"%(name, op)
290 expr = "%s %s ?"%(name, op)
292 if isinstance(value, (tuple,list)):
291 if isinstance(value, (tuple,list)):
293 if op in null_operators and any([v is None for v in value]):
292 if op in null_operators and any([v is None for v in value]):
294 # equality tests don't work with NULL
293 # equality tests don't work with NULL
295 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
294 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
296 expr = '( %s )'%( join.join([expr]*len(value)) )
295 expr = '( %s )'%( join.join([expr]*len(value)) )
297 args.extend(value)
296 args.extend(value)
298 else:
297 else:
299 args.append(value)
298 args.append(value)
300 expressions.append(expr)
299 expressions.append(expr)
301 else:
300 else:
302 # it's an equality check
301 # it's an equality check
303 if sub_check is None:
302 if sub_check is None:
304 expressions.append("%s IS NULL"%name)
303 expressions.append("%s IS NULL"%name)
305 else:
304 else:
306 expressions.append("%s = ?"%name)
305 expressions.append("%s = ?"%name)
307 args.append(sub_check)
306 args.append(sub_check)
308
307
309 expr = " AND ".join(expressions)
308 expr = " AND ".join(expressions)
310 return expr, args
309 return expr, args
311
310
312 def add_record(self, msg_id, rec):
311 def add_record(self, msg_id, rec):
313 """Add a new Task Record, by msg_id."""
312 """Add a new Task Record, by msg_id."""
314 d = self._defaults()
313 d = self._defaults()
315 d.update(rec)
314 d.update(rec)
316 d['msg_id'] = msg_id
315 d['msg_id'] = msg_id
317 line = self._dict_to_list(d)
316 line = self._dict_to_list(d)
318 tups = '(%s)'%(','.join(['?']*len(line)))
317 tups = '(%s)'%(','.join(['?']*len(line)))
319 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
318 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
320 # self._db.commit()
319 # self._db.commit()
321
320
322 def get_record(self, msg_id):
321 def get_record(self, msg_id):
323 """Get a specific Task Record, by msg_id."""
322 """Get a specific Task Record, by msg_id."""
324 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
323 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
325 line = cursor.fetchone()
324 line = cursor.fetchone()
326 if line is None:
325 if line is None:
327 raise KeyError("No such msg: %r"%msg_id)
326 raise KeyError("No such msg: %r"%msg_id)
328 return self._list_to_dict(line)
327 return self._list_to_dict(line)
329
328
330 def update_record(self, msg_id, rec):
329 def update_record(self, msg_id, rec):
331 """Update the data in an existing record."""
330 """Update the data in an existing record."""
332 query = "UPDATE %s SET "%self.table
331 query = "UPDATE %s SET "%self.table
333 sets = []
332 sets = []
334 keys = sorted(rec.keys())
333 keys = sorted(rec.keys())
335 values = []
334 values = []
336 for key in keys:
335 for key in keys:
337 sets.append('%s = ?'%key)
336 sets.append('%s = ?'%key)
338 values.append(rec[key])
337 values.append(rec[key])
339 query += ', '.join(sets)
338 query += ', '.join(sets)
340 query += ' WHERE msg_id == ?'
339 query += ' WHERE msg_id == ?'
341 values.append(msg_id)
340 values.append(msg_id)
342 self._db.execute(query, values)
341 self._db.execute(query, values)
343 # self._db.commit()
342 # self._db.commit()
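Concretely, for a hypothetical rec with the (sorted) keys 'completed' and 'stdout', the statement built above is:

    # "UPDATE <table> SET completed = ?, stdout = ? WHERE msg_id == ?"
    # with values = [rec['completed'], rec['stdout'], msg_id]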
344
343
345 def drop_record(self, msg_id):
344 def drop_record(self, msg_id):
346 """Remove a record from the DB."""
345 """Remove a record from the DB."""
347 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
346 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
348 # self._db.commit()
347 # self._db.commit()
349
348
350 def drop_matching_records(self, check):
349 def drop_matching_records(self, check):
351 """Remove a record from the DB."""
350 """Remove a record from the DB."""
352 expr,args = self._render_expression(check)
351 expr,args = self._render_expression(check)
353 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
352 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
354 self._db.execute(query,args)
353 self._db.execute(query,args)
355 # self._db.commit()
354 # self._db.commit()
356
355
357 def find_records(self, check, keys=None):
356 def find_records(self, check, keys=None):
358 """Find records matching a query dict, optionally extracting subset of keys.
357 """Find records matching a query dict, optionally extracting subset of keys.
359
358
360 Returns list of matching records.
359 Returns list of matching records.
361
360
362 Parameters
361 Parameters
363 ----------
362 ----------
364
363
365 check: dict
364 check: dict
366 mongodb-style query argument
365 mongodb-style query argument
367 keys: list of strs [optional]
366 keys: list of strs [optional]
368 if specified, the subset of keys to extract. msg_id will *always* be
367 if specified, the subset of keys to extract. msg_id will *always* be
369 included.
368 included.
370 """
369 """
371 if keys:
370 if keys:
372 bad_keys = [ key for key in keys if key not in self._keys ]
371 bad_keys = [ key for key in keys if key not in self._keys ]
373 if bad_keys:
372 if bad_keys:
374 raise KeyError("Bad record key(s): %s"%bad_keys)
373 raise KeyError("Bad record key(s): %s"%bad_keys)
375
374
376 if keys:
375 if keys:
377 # ensure msg_id is present and first:
376 # ensure msg_id is present and first:
378 if 'msg_id' in keys:
377 if 'msg_id' in keys:
379 keys.remove('msg_id')
378 keys.remove('msg_id')
380 keys.insert(0, 'msg_id')
379 keys.insert(0, 'msg_id')
381 req = ', '.join(keys)
380 req = ', '.join(keys)
382 else:
381 else:
383 req = '*'
382 req = '*'
384 expr,args = self._render_expression(check)
383 expr,args = self._render_expression(check)
385 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
384 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
386 cursor = self._db.execute(query, args)
385 cursor = self._db.execute(query, args)
387 matches = cursor.fetchall()
386 matches = cursor.fetchall()
388 records = []
387 records = []
389 for line in matches:
388 for line in matches:
390 rec = self._list_to_dict(line, keys)
389 rec = self._list_to_dict(line, keys)
391 records.append(rec)
390 records.append(rec)
392 return records
391 return records
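A minimal usage sketch, assuming a SQLiteDB instance db (values hypothetical):

    # fetch completed tasks for one engine
    recs = db.find_records({'engine_uuid': 'abc',
                            'completed': {'$ne': None}},
                           keys=['completed', 'stdout'])
    for rec in recs:
        print(rec['msg_id'])  # msg_id is included even though not requested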
393
392
394 def get_history(self):
393 def get_history(self):
395 """get all msg_ids, ordered by time submitted."""
394 """get all msg_ids, ordered by time submitted."""
396 query = """SELECT msg_id FROM %s ORDER BY submitted ASC"""%self.table
395 query = """SELECT msg_id FROM %s ORDER BY submitted ASC"""%self.table
397 cursor = self._db.execute(query)
396 cursor = self._db.execute(query)
398 # will be a list of length 1 tuples
397 # will be a list of length 1 tuples
399 return [ tup[0] for tup in cursor.fetchall()]
398 return [ tup[0] for tup in cursor.fetchall()]
400
399
401 __all__ = ['SQLiteDB'] No newline at end of file
400 __all__ = ['SQLiteDB']
@@ -1,174 +1,174 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple engine that talks to a controller over 0MQ.
2 """A simple engine that talks to a controller over 0MQ.
3 It handles registration, etc., and launches a kernel
3 It handles registration, etc., and launches a kernel
4 connected to the Controller's Schedulers.
4 connected to the Controller's Schedulers.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 from __future__ import print_function
17 from __future__ import print_function
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 # internal
25 # internal
26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 # from IPython.utils.localinterfaces import LOCALHOST
27 # from IPython.utils.localinterfaces import LOCALHOST
28
28
29 from IPython.parallel.controller.heartmonitor import Heart
29 from IPython.parallel.controller.heartmonitor import Heart
30 from IPython.parallel.factory import RegistrationFactory
30 from IPython.parallel.factory import RegistrationFactory
31 from IPython.parallel.util import disambiguate_url, ensure_bytes
31 from IPython.parallel.util import disambiguate_url, asbytes
32
32
33 from IPython.zmq.session import Message
33 from IPython.zmq.session import Message
34
34
35 from .streamkernel import Kernel
35 from .streamkernel import Kernel
36
36
37 class EngineFactory(RegistrationFactory):
37 class EngineFactory(RegistrationFactory):
38 """IPython engine"""
38 """IPython engine"""
39
39
40 # configurables:
40 # configurables:
41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 help="""The OutStream for handling stdout/err.
42 help="""The OutStream for handling stdout/err.
43 Typically 'IPython.zmq.iostream.OutStream'""")
43 Typically 'IPython.zmq.iostream.OutStream'""")
44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 help="""The class for handling displayhook.
45 help="""The class for handling displayhook.
46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 location=Unicode(config=True,
47 location=Unicode(config=True,
48 help="""The location (an IP address) of the controller. This is
48 help="""The location (an IP address) of the controller. This is
49 used for disambiguating URLs, to determine whether
49 used for disambiguating URLs, to determine whether
50 loopback should be used to connect or the public address.""")
50 loopback should be used to connect or the public address.""")
51 timeout=CFloat(2,config=True,
51 timeout=CFloat(2,config=True,
52 help="""The time (in seconds) to wait for the Controller to respond
52 help="""The time (in seconds) to wait for the Controller to respond
53 to registration requests before giving up.""")
53 to registration requests before giving up.""")
54
54
55 # not configurable:
55 # not configurable:
56 user_ns=Dict()
56 user_ns=Dict()
57 id=Int(allow_none=True)
57 id=Int(allow_none=True)
58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 kernel=Instance(Kernel)
59 kernel=Instance(Kernel)
60
60
61 bident = CBytes()
61 bident = CBytes()
62 ident = Unicode()
62 ident = Unicode()
63 def _ident_changed(self, name, old, new):
63 def _ident_changed(self, name, old, new):
64 self.bident = ensure_bytes(new)
64 self.bident = asbytes(new)
65
65
66
66
67 def __init__(self, **kwargs):
67 def __init__(self, **kwargs):
68 super(EngineFactory, self).__init__(**kwargs)
68 super(EngineFactory, self).__init__(**kwargs)
69 self.ident = self.session.session
69 self.ident = self.session.session
70 ctx = self.context
70 ctx = self.context
71
71
72 reg = ctx.socket(zmq.XREQ)
72 reg = ctx.socket(zmq.XREQ)
73 reg.setsockopt(zmq.IDENTITY, self.bident)
73 reg.setsockopt(zmq.IDENTITY, self.bident)
74 reg.connect(self.url)
74 reg.connect(self.url)
75 self.registrar = zmqstream.ZMQStream(reg, self.loop)
75 self.registrar = zmqstream.ZMQStream(reg, self.loop)
76
76
77 def register(self):
77 def register(self):
78 """send the registration_request"""
78 """send the registration_request"""
79
79
80 self.log.info("Registering with controller at %s"%self.url)
80 self.log.info("Registering with controller at %s"%self.url)
81 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
81 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
82 self.registrar.on_recv(self.complete_registration)
82 self.registrar.on_recv(self.complete_registration)
83 # print (self.session.key)
83 # print (self.session.key)
84 self.session.send(self.registrar, "registration_request",content=content)
84 self.session.send(self.registrar, "registration_request",content=content)
85
85
86 def complete_registration(self, msg):
86 def complete_registration(self, msg):
87 # print msg
87 # print msg
88 self._abort_dc.stop()
88 self._abort_dc.stop()
89 ctx = self.context
89 ctx = self.context
90 loop = self.loop
90 loop = self.loop
91 identity = self.bident
91 identity = self.bident
92 idents,msg = self.session.feed_identities(msg)
92 idents,msg = self.session.feed_identities(msg)
93 msg = Message(self.session.unpack_message(msg))
93 msg = Message(self.session.unpack_message(msg))
94
94
95 if msg.content.status == 'ok':
95 if msg.content.status == 'ok':
96 self.id = int(msg.content.id)
96 self.id = int(msg.content.id)
97
97
98 # create Shell Streams (MUX, Task, etc.):
98 # create Shell Streams (MUX, Task, etc.):
99 queue_addr = msg.content.mux
99 queue_addr = msg.content.mux
100 shell_addrs = [ str(queue_addr) ]
100 shell_addrs = [ str(queue_addr) ]
101 task_addr = msg.content.task
101 task_addr = msg.content.task
102 if task_addr:
102 if task_addr:
103 shell_addrs.append(str(task_addr))
103 shell_addrs.append(str(task_addr))
104
104
105 # Uncomment this to go back to two-socket model
105 # Uncomment this to go back to two-socket model
106 # shell_streams = []
106 # shell_streams = []
107 # for addr in shell_addrs:
107 # for addr in shell_addrs:
108 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
108 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
109 # stream.setsockopt(zmq.IDENTITY, identity)
109 # stream.setsockopt(zmq.IDENTITY, identity)
110 # stream.connect(disambiguate_url(addr, self.location))
110 # stream.connect(disambiguate_url(addr, self.location))
111 # shell_streams.append(stream)
111 # shell_streams.append(stream)
112
112
113 # Now use only one shell stream for mux and tasks
113 # Now use only one shell stream for mux and tasks
114 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
114 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 stream.setsockopt(zmq.IDENTITY, identity)
115 stream.setsockopt(zmq.IDENTITY, identity)
116 shell_streams = [stream]
116 shell_streams = [stream]
117 for addr in shell_addrs:
117 for addr in shell_addrs:
118 stream.connect(disambiguate_url(addr, self.location))
118 stream.connect(disambiguate_url(addr, self.location))
119 # end single stream-socket
119 # end single stream-socket
120
120
121 # control stream:
121 # control stream:
122 control_addr = str(msg.content.control)
122 control_addr = str(msg.content.control)
123 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
123 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
124 control_stream.setsockopt(zmq.IDENTITY, identity)
124 control_stream.setsockopt(zmq.IDENTITY, identity)
125 control_stream.connect(disambiguate_url(control_addr, self.location))
125 control_stream.connect(disambiguate_url(control_addr, self.location))
126
126
127 # create iopub stream:
127 # create iopub stream:
128 iopub_addr = msg.content.iopub
128 iopub_addr = msg.content.iopub
129 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
129 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
130 iopub_stream.setsockopt(zmq.IDENTITY, identity)
130 iopub_stream.setsockopt(zmq.IDENTITY, identity)
131 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
131 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
132
132
133 # launch heartbeat
133 # launch heartbeat
134 hb_addrs = msg.content.heartbeat
134 hb_addrs = msg.content.heartbeat
135 # print (hb_addrs)
135 # print (hb_addrs)
136
136
137 # # Redirect input streams and set a display hook.
137 # # Redirect input streams and set a display hook.
138 if self.out_stream_factory:
138 if self.out_stream_factory:
139 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
139 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
140 sys.stdout.topic = 'engine.%i.stdout'%self.id
140 sys.stdout.topic = 'engine.%i.stdout'%self.id
141 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
141 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
142 sys.stderr.topic = 'engine.%i.stderr'%self.id
142 sys.stderr.topic = 'engine.%i.stderr'%self.id
143 if self.display_hook_factory:
143 if self.display_hook_factory:
144 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
144 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
145 sys.displayhook.topic = 'engine.%i.pyout'%self.id
145 sys.displayhook.topic = 'engine.%i.pyout'%self.id
146
146
147 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
147 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
148 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
148 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
149 loop=loop, user_ns = self.user_ns, log=self.log)
149 loop=loop, user_ns = self.user_ns, log=self.log)
150 self.kernel.start()
150 self.kernel.start()
151 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
151 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
152 heart = Heart(*map(str, hb_addrs), heart_id=identity)
152 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 heart.start()
153 heart.start()
154
154
155
155
156 else:
156 else:
157 self.log.fatal("Registration Failed: %s"%msg)
157 self.log.fatal("Registration Failed: %s"%msg)
158 raise Exception("Registration Failed: %s"%msg)
158 raise Exception("Registration Failed: %s"%msg)
159
159
160 self.log.info("Completed registration with id %i"%self.id)
160 self.log.info("Completed registration with id %i"%self.id)
161
161
162
162
163 def abort(self):
163 def abort(self):
164 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
164 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
165 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
165 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
166 time.sleep(1)
166 time.sleep(1)
167 sys.exit(255)
167 sys.exit(255)
168
168
169 def start(self):
169 def start(self):
170 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
170 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
171 dc.start()
171 dc.start()
172 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
172 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
173 self._abort_dc.start()
173 self._abort_dc.start()
174
174
@@ -1,438 +1,438 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 Kernel adapted from kernel.py to use ZMQ Streams
3 Kernel adapted from kernel.py to use ZMQ Streams
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 * Brian Granger
8 * Brian Granger
9 * Fernando Perez
9 * Fernando Perez
10 * Evan Patterson
10 * Evan Patterson
11 """
11 """
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2010-2011 The IPython Development Team
13 # Copyright (C) 2010-2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 # Standard library imports.
23 # Standard library imports.
24 from __future__ import print_function
24 from __future__ import print_function
25
25
26 import sys
26 import sys
27 import time
27 import time
28
28
29 from code import CommandCompiler
29 from code import CommandCompiler
30 from datetime import datetime
30 from datetime import datetime
31 from pprint import pprint
31 from pprint import pprint
32
32
33 # System library imports.
33 # System library imports.
34 import zmq
34 import zmq
35 from zmq.eventloop import ioloop, zmqstream
35 from zmq.eventloop import ioloop, zmqstream
36
36
37 # Local imports.
37 # Local imports.
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 from IPython.zmq.completer import KernelCompleter
39 from IPython.zmq.completer import KernelCompleter
40
40
41 from IPython.parallel.error import wrap_exception
41 from IPython.parallel.error import wrap_exception
42 from IPython.parallel.factory import SessionFactory
42 from IPython.parallel.factory import SessionFactory
43 from IPython.parallel.util import serialize_object, unpack_apply_message, ensure_bytes
43 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44
44
45 def printer(*args):
45 def printer(*args):
46 pprint(args, stream=sys.__stdout__)
46 pprint(args, stream=sys.__stdout__)
47
47
48
48
49 class _Passer(zmqstream.ZMQStream):
49 class _Passer(zmqstream.ZMQStream):
50 """Empty class that implements `send()` that does nothing.
50 """Empty class that implements `send()` that does nothing.
51
51
52 Subclass ZMQStream for Session typechecking
52 Subclass ZMQStream for Session typechecking
53
53
54 """
54 """
55 def __init__(self, *args, **kwargs):
55 def __init__(self, *args, **kwargs):
56 pass
56 pass
57
57
58 def send(self, *args, **kwargs):
58 def send(self, *args, **kwargs):
59 pass
59 pass
60 send_multipart = send
60 send_multipart = send
61
61
62
62
63 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
64 # Main kernel class
64 # Main kernel class
65 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
66
66
67 class Kernel(SessionFactory):
67 class Kernel(SessionFactory):
68
68
69 #---------------------------------------------------------------------------
69 #---------------------------------------------------------------------------
70 # Kernel interface
70 # Kernel interface
71 #---------------------------------------------------------------------------
71 #---------------------------------------------------------------------------
72
72
73 # kwargs:
73 # kwargs:
74 exec_lines = List(Unicode, config=True,
74 exec_lines = List(Unicode, config=True,
75 help="List of lines to execute")
75 help="List of lines to execute")
76
76
77 # identities:
77 # identities:
78 int_id = Int(-1)
78 int_id = Int(-1)
79 bident = CBytes()
79 bident = CBytes()
80 ident = Unicode()
80 ident = Unicode()
81 def _ident_changed(self, name, old, new):
81 def _ident_changed(self, name, old, new):
82 self.bident = ensure_bytes(new)
82 self.bident = asbytes(new)
83
83
84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85
85
86 control_stream = Instance(zmqstream.ZMQStream)
86 control_stream = Instance(zmqstream.ZMQStream)
87 task_stream = Instance(zmqstream.ZMQStream)
87 task_stream = Instance(zmqstream.ZMQStream)
88 iopub_stream = Instance(zmqstream.ZMQStream)
88 iopub_stream = Instance(zmqstream.ZMQStream)
89 client = Instance('IPython.parallel.Client')
89 client = Instance('IPython.parallel.Client')
90
90
91 # internals
91 # internals
92 shell_streams = List()
92 shell_streams = List()
93 compiler = Instance(CommandCompiler, (), {})
93 compiler = Instance(CommandCompiler, (), {})
94 completer = Instance(KernelCompleter)
94 completer = Instance(KernelCompleter)
95
95
96 aborted = Set()
96 aborted = Set()
97 shell_handlers = Dict()
97 shell_handlers = Dict()
98 control_handlers = Dict()
98 control_handlers = Dict()
99
99
100 def _set_prefix(self):
100 def _set_prefix(self):
101 self.prefix = "engine.%s"%self.int_id
101 self.prefix = "engine.%s"%self.int_id
102
102
103 def _connect_completer(self):
103 def _connect_completer(self):
104 self.completer = KernelCompleter(self.user_ns)
104 self.completer = KernelCompleter(self.user_ns)
105
105
106 def __init__(self, **kwargs):
106 def __init__(self, **kwargs):
107 super(Kernel, self).__init__(**kwargs)
107 super(Kernel, self).__init__(**kwargs)
108 self._set_prefix()
108 self._set_prefix()
109 self._connect_completer()
109 self._connect_completer()
110
110
111 self.on_trait_change(self._set_prefix, 'id')
111 self.on_trait_change(self._set_prefix, 'id')
112 self.on_trait_change(self._connect_completer, 'user_ns')
112 self.on_trait_change(self._connect_completer, 'user_ns')
113
113
114 # Build dict of handlers for message types
114 # Build dict of handlers for message types
115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 'clear_request']:
116 'clear_request']:
117 self.shell_handlers[msg_type] = getattr(self, msg_type)
117 self.shell_handlers[msg_type] = getattr(self, msg_type)
118
118
119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 self.control_handlers[msg_type] = getattr(self, msg_type)
120 self.control_handlers[msg_type] = getattr(self, msg_type)
121
121
122 self._initial_exec_lines()
122 self._initial_exec_lines()
123
123
124 def _wrap_exception(self, method=None):
124 def _wrap_exception(self, method=None):
125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 content=wrap_exception(e_info)
126 content=wrap_exception(e_info)
127 return content
127 return content
128
128
129 def _initial_exec_lines(self):
129 def _initial_exec_lines(self):
130 s = _Passer()
130 s = _Passer()
131 content = dict(silent=True, user_variables=[], user_expressions=[])
131 content = dict(silent=True, user_variables=[], user_expressions=[])
132 for line in self.exec_lines:
132 for line in self.exec_lines:
133 self.log.debug("executing initialization: %s"%line)
133 self.log.debug("executing initialization: %s"%line)
134 content.update({'code':line})
134 content.update({'code':line})
135 msg = self.session.msg('execute_request', content)
135 msg = self.session.msg('execute_request', content)
136 self.execute_request(s, [], msg)
136 self.execute_request(s, [], msg)
137
137
138
138
139 #-------------------- control handlers -----------------------------
139 #-------------------- control handlers -----------------------------
140 def abort_queues(self):
140 def abort_queues(self):
141 for stream in self.shell_streams:
141 for stream in self.shell_streams:
142 if stream:
142 if stream:
143 self.abort_queue(stream)
143 self.abort_queue(stream)
144
144
145 def abort_queue(self, stream):
145 def abort_queue(self, stream):
146 while True:
146 while True:
147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 if msg is None:
148 if msg is None:
149 return
149 return
150
150
151 self.log.info("Aborting:")
151 self.log.info("Aborting:")
152 self.log.info(str(msg))
152 self.log.info(str(msg))
153 msg_type = msg['msg_type']
153 msg_type = msg['msg_type']
154 reply_type = msg_type.split('_')[0] + '_reply'
154 reply_type = msg_type.split('_')[0] + '_reply'
155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 # self.reply_socket.send(ident,zmq.SNDMORE)
156 # self.reply_socket.send(ident,zmq.SNDMORE)
157 # self.reply_socket.send_json(reply_msg)
157 # self.reply_socket.send_json(reply_msg)
158 reply_msg = self.session.send(stream, reply_type,
158 reply_msg = self.session.send(stream, reply_type,
159 content={'status' : 'aborted'}, parent=msg, ident=idents)
159 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 self.log.debug(str(reply_msg))
160 self.log.debug(str(reply_msg))
161 # We need to wait a bit for requests to come in. This can probably
161 # We need to wait a bit for requests to come in. This can probably
162 # be set shorter for true asynchronous clients.
162 # be set shorter for true asynchronous clients.
163 time.sleep(0.05)
163 time.sleep(0.05)
164
164
165 def abort_request(self, stream, ident, parent):
165 def abort_request(self, stream, ident, parent):
166 """abort a specifig msg by id"""
166 """abort a specifig msg by id"""
167 msg_ids = parent['content'].get('msg_ids', None)
167 msg_ids = parent['content'].get('msg_ids', None)
168 if isinstance(msg_ids, basestring):
168 if isinstance(msg_ids, basestring):
169 msg_ids = [msg_ids]
169 msg_ids = [msg_ids]
170 if not msg_ids:
170 if not msg_ids:
171 self.abort_queues()
171 self.abort_queues()
172 for mid in msg_ids:
172 for mid in msg_ids:
173 self.aborted.add(str(mid))
173 self.aborted.add(str(mid))
174
174
175 content = dict(status='ok')
175 content = dict(status='ok')
176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 parent=parent, ident=ident)
177 parent=parent, ident=ident)
178 self.log.debug(str(reply_msg))
178 self.log.debug(str(reply_msg))
179
179
180 def shutdown_request(self, stream, ident, parent):
180 def shutdown_request(self, stream, ident, parent):
181 """kill ourself. This should really be handled in an external process"""
181 """kill ourself. This should really be handled in an external process"""
182 try:
182 try:
183 self.abort_queues()
183 self.abort_queues()
184 except:
184 except:
185 content = self._wrap_exception('shutdown')
185 content = self._wrap_exception('shutdown')
186 else:
186 else:
187 content = dict(parent['content'])
187 content = dict(parent['content'])
188 content['status'] = 'ok'
188 content['status'] = 'ok'
189 msg = self.session.send(stream, 'shutdown_reply',
189 msg = self.session.send(stream, 'shutdown_reply',
190 content=content, parent=parent, ident=ident)
190 content=content, parent=parent, ident=ident)
191 self.log.debug(str(msg))
191 self.log.debug(str(msg))
192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 dc.start()
193 dc.start()
194
194
195 def dispatch_control(self, msg):
195 def dispatch_control(self, msg):
196 idents,msg = self.session.feed_identities(msg, copy=False)
196 idents,msg = self.session.feed_identities(msg, copy=False)
197 try:
197 try:
198 msg = self.session.unpack_message(msg, content=True, copy=False)
198 msg = self.session.unpack_message(msg, content=True, copy=False)
199 except:
199 except:
200 self.log.error("Invalid Message", exc_info=True)
200 self.log.error("Invalid Message", exc_info=True)
201 return
201 return
202 else:
202 else:
203 self.log.debug("Control received, %s", msg)
203 self.log.debug("Control received, %s", msg)
204
204
205 header = msg['header']
205 header = msg['header']
206 msg_id = header['msg_id']
206 msg_id = header['msg_id']
207
207
208 handler = self.control_handlers.get(msg['msg_type'], None)
208 handler = self.control_handlers.get(msg['msg_type'], None)
209 if handler is None:
209 if handler is None:
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 else:
211 else:
212 handler(self.control_stream, idents, msg)
212 handler(self.control_stream, idents, msg)
213
213
214
214
215 #-------------------- queue helpers ------------------------------
215 #-------------------- queue helpers ------------------------------
216
216
217 def check_dependencies(self, dependencies):
217 def check_dependencies(self, dependencies):
218 if not dependencies:
218 if not dependencies:
219 return True
219 return True
220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 anyorall = dependencies[0]
221 anyorall = dependencies[0]
222 dependencies = dependencies[1]
222 dependencies = dependencies[1]
223 else:
223 else:
224 anyorall = 'all'
224 anyorall = 'all'
225 results = self.client.get_results(dependencies,status_only=True)
225 results = self.client.get_results(dependencies,status_only=True)
226 if results['status'] != 'ok':
226 if results['status'] != 'ok':
227 return False
227 return False
228
228
229 if anyorall == 'any':
229 if anyorall == 'any':
230 if not results['completed']:
230 if not results['completed']:
231 return False
231 return False
232 else:
232 else:
233 if results['pending']:
233 if results['pending']:
234 return False
234 return False
235
235
236 return True
236 return True
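For reference, a sketch of the dependency specs this accepts (msg ids hypothetical):

    # kernel.check_dependencies([])                          # no deps -> True
    # kernel.check_dependencies(['msg_a', 'msg_b'])          # implicit 'all'
    # kernel.check_dependencies(['any', ['msg_a', 'msg_b']]) # met if any completed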
237
237
    def check_aborted(self, msg_id):
        return msg_id in self.aborted

    #-------------------- queue handlers -----------------------------

    def clear_request(self, stream, idents, parent):
        """Clear our namespace."""
        self.user_ns = {}
        msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
                content = dict(status='ok'))
        self._initial_exec_lines()

    def execute_request(self, stream, ident, parent):
        self.log.debug('execute request %s'%parent)
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
-                           ident=ensure_bytes('%s.pyin'%self.prefix))
+                           ident=asbytes('%s.pyin'%self.prefix))
        started = datetime.now()
        try:
            comp_code = self.compiler(code, '<zmq-kernel>')
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            exec comp_code in self.user_ns, self.user_ns
        except:
            exc_content = self._wrap_exception('execute')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
-                           ident=ensure_bytes('%s.pyerr'%self.prefix))
+                           ident=asbytes('%s.pyerr'%self.prefix))
            reply_content = exc_content
        else:
            reply_content = {'status' : 'ok'}

        reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
                    ident=ident, subheader = dict(started=started))
        self.log.debug(str(reply_msg))
        if reply_msg['content']['status'] == u'error':
            self.abort_queues()

    def complete_request(self, stream, ident, parent):
        matches = {'matches' : self.complete(parent),
                   'status' : 'ok'}
        completion_msg = self.session.send(stream, 'complete_reply',
                        matches, parent, ident)
        # print >> sys.__stdout__, completion_msg

    def complete(self, msg):
        return self.completer.complete(msg.content.line, msg.content.text)

    def apply_request(self, stream, ident, parent):
        # flush previous reply, so this request won't block it
        stream.flush(zmq.POLLOUT)
        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
            # bound = parent['header'].get('bound', False)
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_stream.send(pyin_msg)
        # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
        sub = {'dependencies_met' : True, 'engine' : self.ident,
               'started': datetime.now()}
        try:
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
            working = self.user_ns
            # suffix =
            prefix = "_"+str(msg_id).replace("-","")+"_"

            f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
            # if bound:
            #     bound_ns = Namespace(working)
            #     args = [bound_ns]+list(args)

            fname = getattr(f, '__name__', 'f')

            fname = prefix+"f"
            argname = prefix+"args"
            kwargname = prefix+"kwargs"
            resultname = prefix+"result"

            ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
            # print ns
            working.update(ns)
            code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
            try:
                exec code in working,working
                result = working.get(resultname)
            finally:
                for key in ns.iterkeys():
                    working.pop(key)
            # if bound:
            #     working.update(bound_ns)

            packed_result,buf = serialize_object(result)
            result_buf = [packed_result]+buf
        except:
            exc_content = self._wrap_exception('apply')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
-                           ident=ensure_bytes('%s.pyerr'%self.prefix))
+                           ident=asbytes('%s.pyerr'%self.prefix))
            reply_content = exc_content
            result_buf = []

            if exc_content['ename'] == 'UnmetDependency':
                sub['dependencies_met'] = False
        else:
            reply_content = {'status' : 'ok'}

        # put 'ok'/'error' status in header, for scheduler introspection:
        sub['status'] = reply_content['status']

        reply_msg = self.session.send(stream, u'apply_reply', reply_content,
                    parent=parent, ident=ident,buffers=result_buf, subheader=sub)

        # flush i/o
        # should this be before reply_msg is sent, like in the single-kernel code,
        # or should nothing get in the way of real results?
        sys.stdout.flush()
        sys.stderr.flush()

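    # Illustrative sketch (not part of the original file) of the name-mangling
    # pattern apply_request uses above: the function and its arguments are
    # injected into the working namespace under msg_id-derived names so
    # concurrent requests cannot collide, then removed once the result is
    # extracted. The prefix below is hypothetical.
    #
    #   working = {}
    #   prefix = "_deadbeef_"            # derived from msg_id in the real code
    #   ns = {prefix+"f": max, prefix+"args": ((1, 5, 3),),
    #         prefix+"kwargs": {}, prefix+"result": None}
    #   working.update(ns)
    #   exec "%sresult = %sf(*%sargs, **%skwargs)" % ((prefix,)*4) in working, working
    #   print working[prefix+"result"]   # 5
    #   for key in ns.iterkeys():        # clean up the temporary names
    #       working.pop(key)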
    def dispatch_queue(self, stream, msg):
        self.control_stream.flush()
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unpack_message(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return
        else:
            self.log.debug("Message received, %s", msg)


        header = msg['header']
        msg_id = header['msg_id']
        if self.check_aborted(msg_id):
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg['msg_type'].split('_')[0] + '_reply'
            status = {'status' : 'aborted'}
            reply_msg = self.session.send(stream, reply_type, subheader=status,
                        content=status, parent=msg, ident=idents)
            return
        handler = self.shell_handlers.get(msg['msg_type'], None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
        else:
            handler(stream, idents, msg)

    def start(self):
        #### stream mode:
        if self.control_stream:
            self.control_stream.on_recv(self.dispatch_control, copy=False)
            self.control_stream.on_err(printer)

        def make_dispatcher(stream):
            def dispatcher(msg):
                return self.dispatch_queue(stream, msg)
            return dispatcher

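        # Aside (not in the original source): the factory above exists because
        # Python closures capture variables, not values. A bare lambda in the
        # loop below would late-bind `s`, so every dispatcher would see the
        # last stream:
        #
        #   callbacks = [lambda: s for s in ('a', 'b')]
        #   [cb() for cb in callbacks]     # ['b', 'b']
        #
        # Routing through make_dispatcher(s) freezes each stream at call time:
        #
        #   def make(s): return lambda: s
        #   callbacks = [make(s) for s in ('a', 'b')]
        #   [cb() for cb in callbacks]     # ['a', 'b']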
        for s in self.shell_streams:
            s.on_recv(make_dispatcher(s), copy=False)
            s.on_err(printer)

        if self.iopub_stream:
            self.iopub_stream.on_err(printer)

        #### while True mode:
        # while True:
        #     idle = True
        #     try:
        #         msg = self.shell_stream.socket.recv_multipart(
        #                     zmq.NOBLOCK, copy=False)
        #     except zmq.ZMQError, e:
        #         if e.errno != zmq.EAGAIN:
        #             raise e
        #     else:
        #         idle=False
        #         self.dispatch_queue(self.shell_stream, msg)
        #
        #     if not self.task_stream.empty():
        #         idle=False
        #         msg = self.task_stream.recv_multipart()
        #         self.dispatch_queue(self.task_stream, msg)
        #     if idle:
        #         # don't busywait
        #         time.sleep(1e-3)

@@ -1,136 +1,137 @@
1 """base class for parallel client tests
1 """base class for parallel client tests
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 import sys
15 import sys
16 import tempfile
16 import tempfile
17 import time
17 import time
18
18
19 from nose import SkipTest
19 from nose import SkipTest
20
20
21 import zmq
21 import zmq
22 from zmq.tests import BaseZMQTestCase
22 from zmq.tests import BaseZMQTestCase
23
23
24 from IPython.external.decorator import decorator
24 from IPython.external.decorator import decorator
25
25
26 from IPython.parallel import error
26 from IPython.parallel import error
27 from IPython.parallel import Client
27 from IPython.parallel import Client
28
28
29 from IPython.parallel.tests import launchers, add_engines
29 from IPython.parallel.tests import launchers, add_engines
30
30
31 # simple tasks for use in apply tests
31 # simple tasks for use in apply tests
32
32
33 def segfault():
33 def segfault():
34 """this will segfault"""
34 """this will segfault"""
35 import ctypes
35 import ctypes
36 ctypes.memset(-1,0,1)
36 ctypes.memset(-1,0,1)
37
37
38 def crash():
38 def crash():
39 """from stdlib crashers in the test suite"""
39 """from stdlib crashers in the test suite"""
40 import types
40 import types
41 if sys.platform.startswith('win'):
41 if sys.platform.startswith('win'):
42 import ctypes
42 import ctypes
43 ctypes.windll.kernel32.SetErrorMode(0x0002);
43 ctypes.windll.kernel32.SetErrorMode(0x0002);
44 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
44 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 if sys.version_info[0] >= 3:
45 if sys.version_info[0] >= 3:
46 # Python3 adds 'kwonlyargcount' as the second argument to Code
46 args.insert(1, 0)
47 args.insert(1, 0)
47
48
48 co = types.CodeType(*args)
49 co = types.CodeType(*args)
49 exec(co)
50 exec(co)
50
51
51 def wait(n):
52 def wait(n):
52 """sleep for a time"""
53 """sleep for a time"""
53 import time
54 import time
54 time.sleep(n)
55 time.sleep(n)
55 return n
56 return n
56
57
57 def raiser(eclass):
58 def raiser(eclass):
58 """raise an exception"""
59 """raise an exception"""
59 raise eclass()
60 raise eclass()
60
61
61 # test decorator for skipping tests when libraries are unavailable
62 # test decorator for skipping tests when libraries are unavailable
62 def skip_without(*names):
63 def skip_without(*names):
63 """skip a test if some names are not importable"""
64 """skip a test if some names are not importable"""
64 @decorator
65 @decorator
65 def skip_without_names(f, *args, **kwargs):
66 def skip_without_names(f, *args, **kwargs):
66 """decorator to skip tests in the absence of numpy."""
67 """decorator to skip tests in the absence of numpy."""
67 for name in names:
68 for name in names:
68 try:
69 try:
69 __import__(name)
70 __import__(name)
70 except ImportError:
71 except ImportError:
71 raise SkipTest
72 raise SkipTest
72 return f(*args, **kwargs)
73 return f(*args, **kwargs)
73 return skip_without_names
74 return skip_without_names
74
75
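# A hypothetical usage sketch for skip_without (this test is illustrative and
# not part of the original suite): the body runs only when numpy imports
# cleanly, and raises SkipTest otherwise.
@skip_without('numpy')
def _example_skip_without_test():
    import numpy
    assert numpy.arange(3).sum() == 3
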
class ClusterTestCase(BaseZMQTestCase):

    def add_engines(self, n=1, block=True):
        """add multiple engines to our cluster"""
        self.engines.extend(add_engines(n))
        if block:
            self.wait_on_engines()

    def wait_on_engines(self, timeout=5):
        """wait for our engines to connect."""
        n = len(self.engines)+self.base_engine_count
        tic = time.time()
        while time.time()-tic < timeout and len(self.client.ids) < n:
            time.sleep(0.1)

        assert not len(self.client.ids) < n, "waiting for engines timed out"

    def connect_client(self):
        """connect a client with my Context, and track its sockets for cleanup"""
        c = Client(profile='iptest', context=self.context)
        for name in filter(lambda n:n.endswith('socket'), dir(c)):
            s = getattr(c, name)
            s.setsockopt(zmq.LINGER, 0)
            self.sockets.append(s)
        return c

    def assertRaisesRemote(self, etype, f, *args, **kwargs):
        try:
            try:
                f(*args, **kwargs)
            except error.CompositeError as e:
                e.raise_exception()
        except error.RemoteError as e:
            self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
        else:
            self.fail("should have raised a RemoteError")

    def setUp(self):
        BaseZMQTestCase.setUp(self)
        self.client = self.connect_client()
        # start every test with clean engine namespaces:
        self.client.clear(block=True)
        self.base_engine_count=len(self.client.ids)
        self.engines=[]

    def tearDown(self):
        # self.client.clear(block=True)
        # close fds:
        for e in filter(lambda e: e.poll() is not None, launchers):
            launchers.remove(e)

        # allow flushing of incoming messages to prevent crash on socket close
        self.client.wait(timeout=2)
        # time.sleep(2)
        self.client.spin()
        self.client.close()
        BaseZMQTestCase.tearDown(self)
        # this will be redundant when pyzmq merges PR #88
        # self.context.term()
        # print tempfile.TemporaryFile().fileno(),
-       # sys.stdout.flush()
\ No newline at end of file
+       # sys.stdout.flush()
@@ -1,456 +1,456 @@
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 # IPython imports
42 # IPython imports
43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 from IPython.utils.newserialized import serialize, unserialize
44 from IPython.utils.newserialized import serialize, unserialize
45 from IPython.zmq.log import EnginePUBHandler
45 from IPython.zmq.log import EnginePUBHandler
46
46
47 #-----------------------------------------------------------------------------
47 #-----------------------------------------------------------------------------
48 # Classes
48 # Classes
49 #-----------------------------------------------------------------------------
49 #-----------------------------------------------------------------------------
50
50
51 class Namespace(dict):
51 class Namespace(dict):
52 """Subclass of dict for attribute access to keys."""
52 """Subclass of dict for attribute access to keys."""
53
53
54 def __getattr__(self, key):
54 def __getattr__(self, key):
55 """getattr aliased to getitem"""
55 """getattr aliased to getitem"""
56 if key in self.iterkeys():
56 if key in self.iterkeys():
57 return self[key]
57 return self[key]
58 else:
58 else:
59 raise NameError(key)
59 raise NameError(key)
60
60
61 def __setattr__(self, key, value):
61 def __setattr__(self, key, value):
62 """setattr aliased to setitem, with strict"""
62 """setattr aliased to setitem, with strict"""
63 if hasattr(dict, key):
63 if hasattr(dict, key):
64 raise KeyError("Cannot override dict keys %r"%key)
64 raise KeyError("Cannot override dict keys %r"%key)
65 self[key] = value
65 self[key] = value
66
66
67
67
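# An illustrative sketch of Namespace (not part of the original module; the
# names and values below are hypothetical): keys double as attributes, unknown
# attributes raise NameError, and dict method names are protected.
def _namespace_example():
    ns = Namespace()
    ns.engine_id = 4          # stored as ns['engine_id']
    assert ns.engine_id == 4
    try:
        ns.missing
    except NameError:
        pass                  # unknown attributes raise NameError
    try:
        ns.keys = 'oops'
    except KeyError:
        pass                  # cannot shadow dict.keys
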
class ReverseDict(dict):
    """simple double-keyed subset of dict methods."""

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        self._reverse = dict()
        for key, value in self.iteritems():
            self._reverse[value] = key

    def __getitem__(self, key):
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default

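# An illustrative sketch of ReverseDict (not part of the original module; the
# entries are hypothetical): values can be looked up by key and keys by value,
# so a record is addressable from either side.
def _reversedict_example():
    d = ReverseDict()
    d[0] = 'engine-uuid'
    assert d[0] == 'engine-uuid'       # forward lookup
    assert d['engine-uuid'] == 0       # reverse lookup
    assert d.get('nope', -1) == -1
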
#-----------------------------------------------------------------------------
# Functions
#-----------------------------------------------------------------------------

-def ensure_bytes(s):
+def asbytes(s):
    """ensure that an object is ascii bytes"""
    if isinstance(s, unicode):
        s = s.encode('ascii')
    return s

def validate_url(url):
    """validate a url for zeromq"""
    if not isinstance(url, basestring):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            port = int(s_port)
        except ValueError:
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True


def validate_url_container(container):
    """validate a potentially nested collection of urls."""
    if isinstance(container, basestring):
        url = container
        return validate_url(url)
    elif isinstance(container, dict):
        container = container.itervalues()

    for element in container:
        validate_url_container(element)


def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    lis = addr.split(':')
    assert len(lis) == 2, 'Invalid url: %r'%url
    addr,s_port = lis
    return proto,addr,s_port

def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation of location is localhost)."""
    if ip in ('0.0.0.0', '*'):
        external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
        if location is None or location in external_ips:
            ip='127.0.0.1'
        elif location:
            return location
    return ip

def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as tcp://*:10101."""
    try:
        proto,ip,port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc.
        return url

    ip = disambiguate_ip_address(ip,location)

    return "%s://%s:%s"%(proto,ip,port)

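# An illustrative usage sketch (not part of the original module; the addresses
# are hypothetical): bind-style urls become urls a client can connect to,
# while non-tcp urls pass through untouched.
def _disambiguate_example():
    assert disambiguate_url('tcp://*:10101') == 'tcp://127.0.0.1:10101'
    assert disambiguate_url('ipc:///tmp/engine.sock') == 'ipc:///tmp/engine.sock'
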
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------

    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.


    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        clist = canSequence(obj)
        slist = map(serialize, clist)
        for s in slist:
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
        return pickle.dumps(slist,-1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        for k in sorted(obj.iterkeys()):
            s = serialize(can(obj[k]))
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
            sobj[k] = s
        return pickle.dumps(sobj,-1),databuffers
    else:
        s = serialize(can(obj))
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return pickle.dumps(s,-1),databuffers


def unserialize_object(bufs):
    """reconstruct an object serialized by serialize_object from data buffers."""
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        return uncanSequence(map(unserialize, sobj)), bufs
    elif isinstance(sobj, dict):
        newobj = {}
        for k in sorted(sobj.iterkeys()):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj, bufs
    else:
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj)), bufs

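# A round-trip sketch for the two functions above (not part of the original
# module; the payload is hypothetical): the first element is the pickled
# metadata wrapper, any extra buffers carry large or zero-copy data, and
# unserialize_object reverses the process.
def _serialize_roundtrip_example():
    pmd, bufs = serialize_object(['hello', range(10)])
    obj, remaining = unserialize_object([pmd] + bufs)
    assert obj == ['hello', range(10)]
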
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have its data copied (currently only numpy arrays support zero-copy)"""
    msg = [pickle.dumps(can(f),-1)]
    databuffers = [] # for large objects
    sargs, bufs = serialize_object(args,threshold)
    msg.append(sargs)
    databuffers.extend(bufs)
    skwargs, bufs = serialize_object(kwargs,threshold)
    msg.append(skwargs)
    databuffers.extend(bufs)
    msg.extend(databuffers)
    return msg

def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    # print sargs, skwargs
    f = uncan(cf, g)
    for sa in sargs:
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                # always use a buffer, until memoryviews get sorted out
                sa.data = buffer(m)
                # disable memoryview support
                # if copy:
                #     sa.data = buffer(m)
                # else:
                #     sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                # always use a buffer, until memoryviews get sorted out
                sa.data = buffer(m)
                # disable memoryview support
                # if copy:
                #     sa.data = buffer(m)
                # else:
                #     sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs

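# A round-trip sketch (not part of the original module; the call is
# hypothetical): pack a call on one side of the wire and recover and run it
# on the other.
def _apply_roundtrip_example():
    bufs = pack_apply_message(len, (range(5),), {})
    f, args, kwargs = unpack_apply_message(bufs)
    assert f(*args, **kwargs) == 5
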
#--------------------------------------------------------------------------
# helpers for implementing old MEC API via view.apply
#--------------------------------------------------------------------------

def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """
    f.__module__ = '__main__'
    return f

@interactive
def _push(ns):
    """helper method for implementing `client.push` via `client.apply`"""
    globals().update(ns)

@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`"""
    user_ns = globals()
    if isinstance(keys, (list,tuple, set)):
        for key in keys:
            if not user_ns.has_key(key):
                raise NameError("name '%s' is not defined"%key)
        return map(user_ns.get, keys)
    else:
        if not user_ns.has_key(keys):
            raise NameError("name '%s' is not defined"%keys)
        return user_ns.get(keys)

@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    exec code in globals()

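# An illustrative sketch (the `view` object is hypothetical and requires a
# running cluster): these helpers are what a client passes to view.apply to
# reproduce the old MultiEngineClient push/pull/execute behavior.
#
#   view.apply(_push, dict(a=5))       # remote: globals().update({'a': 5})
#   view.apply(_pull, 'a')             # remote: returns globals()['a'] -> 5
#   view.apply(_execute, 'b = a * 2')  # remote: exec in the engine namespace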
#--------------------------------------------------------------------------
# extra process management utilities
#--------------------------------------------------------------------------

_random_ports = set()

def select_random_ports(n):
    """Selects and return n random ports that are available."""
    ports = []
    for i in xrange(n):
        sock = socket.socket()
        sock.bind(('', 0))
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        ports.append(sock)
    for i, sock in enumerate(ports):
        port = sock.getsockname()[1]
        sock.close()
        ports[i] = port
        _random_ports.add(port)
    return ports

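# An illustrative usage sketch (not part of the original module): grab three
# currently-free ports, e.g. for a set of controller sockets; the actual
# numbers vary per run, and _random_ports prevents repeats across calls.
def _ports_example():
    ports = select_random_ports(3)
    assert len(ports) == 3
    assert len(set(ports)) == 3
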
def signal_children(children):
    """Relay interrupt/term signals to children, for more solid process cleanup."""
    def terminate_children(sig, frame):
        logging.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()

        sys.exit(sig != SIGINT)
        # sys.exit(sig)
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, terminate_children)

def generate_exec_key(keyfile):
    import uuid
    newkey = str(uuid.uuid4())
    with open(keyfile, 'w') as f:
        # f.write('ipython-key ')
        f.write(newkey+'\n')
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)


def integer_loglevel(loglevel):
    try:
        loglevel = int(loglevel)
    except ValueError:
        if isinstance(loglevel, str):
            loglevel = getattr(logging, loglevel)
    return loglevel

def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    logger = logging.getLogger(logname)
    if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = handlers.PUBHandler(lsock)
    handler.setLevel(loglevel)
    handler.root_topic = root
    logger.addHandler(handler)
    logger.setLevel(loglevel)

def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    logger = logging.getLogger()
    if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
        # don't add a second PUBHandler
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger

def local_logger(logname, loglevel=logging.DEBUG):
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
        # don't add a second StreamHandler
        return
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger

@@ -1,43 +1,43 @@
# encoding: utf-8

"""Utilities to enable code objects to be pickled.

Any process that imports this module will be able to pickle code objects. This
includes the func_code attribute of any function. Once unpickled, new
functions can be built using new.function(code, globals()). Eventually
we need to automate all of this so that functions themselves can be pickled.

Reference: A. Tremols, P Cogolo, "Python Cookbook," p 302-305
"""

__docformat__ = "restructuredtext en"

#-------------------------------------------------------------------------------
# Copyright (C) 2008 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

import sys
-import new, types, copy_reg
+import types, copy_reg

def code_ctor(*args):
-    return new.code(*args)
+    return types.CodeType(*args)

def reduce_code(co):
    if co.co_freevars or co.co_cellvars:
        raise ValueError("Sorry, cannot pickle code objects with closures")
    args = [co.co_argcount, co.co_nlocals, co.co_stacksize,
            co.co_flags, co.co_code, co.co_consts, co.co_names,
            co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno,
            co.co_lnotab]
    if sys.version_info[0] >= 3:
        args.insert(1, co.co_kwonlyargcount)
    return code_ctor, tuple(args)

-copy_reg.pickle(types.CodeType, reduce_code)
\ No newline at end of file
+copy_reg.pickle(types.CodeType, reduce_code)
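
# With the reducer registered above, code objects round-trip through pickle.
# An illustrative sketch (not part of the original module; the function is
# hypothetical):
def _pickle_code_example():
    import pickle

    def double(x):
        return 2 * x
    # reduce_code serializes the code object; code_ctor rebuilds it
    co = pickle.loads(pickle.dumps(double.func_code))
    assert types.FunctionType(co, globals())(21) == 42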