##// END OF EJS Templates
merge IPython.parallel.streamsession into IPython.zmq.session...
MinRK -
Show More
@@ -1,402 +1,402 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5 """
5 """
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
8 # Copyright (C) 2008-2009 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 from __future__ import with_statement
18 from __future__ import with_statement
19
19
20 import os
20 import os
21 import socket
21 import socket
22 import stat
22 import stat
23 import sys
23 import sys
24 import uuid
24 import uuid
25
25
26 from multiprocessing import Process
26 from multiprocessing import Process
27
27
28 import zmq
28 import zmq
29 from zmq.devices import ProcessMonitoredQueue
29 from zmq.devices import ProcessMonitoredQueue
30 from zmq.log.handlers import PUBHandler
30 from zmq.log.handlers import PUBHandler
31 from zmq.utils import jsonapi as json
31 from zmq.utils import jsonapi as json
32
32
33 from IPython.config.application import boolean_flag
33 from IPython.config.application import boolean_flag
34 from IPython.core.newapplication import ProfileDir
34 from IPython.core.newapplication import ProfileDir
35
35
36 from IPython.parallel.apps.baseapp import (
36 from IPython.parallel.apps.baseapp import (
37 BaseParallelApplication,
37 BaseParallelApplication,
38 base_flags
38 base_flags
39 )
39 )
40 from IPython.utils.importstring import import_item
40 from IPython.utils.importstring import import_item
41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
42
42
43 # from IPython.parallel.controller.controller import ControllerFactory
43 # from IPython.parallel.controller.controller import ControllerFactory
44 from IPython.parallel.streamsession import StreamSession
44 from IPython.zmq.session import Session
45 from IPython.parallel.controller.heartmonitor import HeartMonitor
45 from IPython.parallel.controller.heartmonitor import HeartMonitor
46 from IPython.parallel.controller.hub import HubFactory
46 from IPython.parallel.controller.hub import HubFactory
47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
48 from IPython.parallel.controller.sqlitedb import SQLiteDB
48 from IPython.parallel.controller.sqlitedb import SQLiteDB
49
49
50 from IPython.parallel.util import signal_children, split_url
50 from IPython.parallel.util import signal_children, split_url
51
51
52 # conditional import of MongoDB backend class
52 # conditional import of MongoDB backend class
53
53
54 try:
54 try:
55 from IPython.parallel.controller.mongodb import MongoDB
55 from IPython.parallel.controller.mongodb import MongoDB
56 except ImportError:
56 except ImportError:
57 maybe_mongo = []
57 maybe_mongo = []
58 else:
58 else:
59 maybe_mongo = [MongoDB]
59 maybe_mongo = [MongoDB]
60
60
61
61
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63 # Module level variables
63 # Module level variables
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65
65
66
66
67 #: The default config file name for this application
67 #: The default config file name for this application
68 default_config_file_name = u'ipcontroller_config.py'
68 default_config_file_name = u'ipcontroller_config.py'
69
69
70
70
71 _description = """Start the IPython controller for parallel computing.
71 _description = """Start the IPython controller for parallel computing.
72
72
73 The IPython controller provides a gateway between the IPython engines and
73 The IPython controller provides a gateway between the IPython engines and
74 clients. The controller needs to be started before the engines and can be
74 clients. The controller needs to be started before the engines and can be
75 configured using command line options or using a cluster directory. Cluster
75 configured using command line options or using a cluster directory. Cluster
76 directories contain config, log and security files and are usually located in
76 directories contain config, log and security files and are usually located in
77 your ipython directory and named as "cluster_<profile>". See the `profile`
77 your ipython directory and named as "cluster_<profile>". See the `profile`
78 and `profile_dir` options for details.
78 and `profile_dir` options for details.
79 """
79 """
80
80
81
81
82
82
83
83
84 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
85 # The main application
85 # The main application
86 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
87 flags = {}
87 flags = {}
88 flags.update(base_flags)
88 flags.update(base_flags)
89 flags.update({
89 flags.update({
90 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
90 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
91 'Use threads instead of processes for the schedulers'),
91 'Use threads instead of processes for the schedulers'),
92 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
92 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
93 'use the SQLiteDB backend'),
93 'use the SQLiteDB backend'),
94 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
94 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
95 'use the MongoDB backend'),
95 'use the MongoDB backend'),
96 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
96 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
97 'use the in-memory DictDB backend'),
97 'use the in-memory DictDB backend'),
98 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
98 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
99 'reuse existing json connection files')
99 'reuse existing json connection files')
100 })
100 })
101
101
102 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
102 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
103 "Use HMAC digests for authentication of messages.",
103 "Use HMAC digests for authentication of messages.",
104 "Don't authenticate messages."
104 "Don't authenticate messages."
105 ))
105 ))
106
106
107 class IPControllerApp(BaseParallelApplication):
107 class IPControllerApp(BaseParallelApplication):
108
108
109 name = u'ipcontroller'
109 name = u'ipcontroller'
110 description = _description
110 description = _description
111 config_file_name = Unicode(default_config_file_name)
111 config_file_name = Unicode(default_config_file_name)
112 classes = [ProfileDir, StreamSession, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
112 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
113
113
114 # change default to True
114 # change default to True
115 auto_create = Bool(True, config=True,
115 auto_create = Bool(True, config=True,
116 help="""Whether to create profile dir if it doesn't exist.""")
116 help="""Whether to create profile dir if it doesn't exist.""")
117
117
118 reuse_files = Bool(False, config=True,
118 reuse_files = Bool(False, config=True,
119 help='Whether to reuse existing json connection files.'
119 help='Whether to reuse existing json connection files.'
120 )
120 )
121 secure = Bool(True, config=True,
121 secure = Bool(True, config=True,
122 help='Whether to use HMAC digests for extra message authentication.'
122 help='Whether to use HMAC digests for extra message authentication.'
123 )
123 )
124 ssh_server = Unicode(u'', config=True,
124 ssh_server = Unicode(u'', config=True,
125 help="""ssh url for clients to use when connecting to the Controller
125 help="""ssh url for clients to use when connecting to the Controller
126 processes. It should be of the form: [user@]server[:port]. The
126 processes. It should be of the form: [user@]server[:port]. The
127 Controller's listening addresses must be accessible from the ssh server""",
127 Controller's listening addresses must be accessible from the ssh server""",
128 )
128 )
129 location = Unicode(u'', config=True,
129 location = Unicode(u'', config=True,
130 help="""The external IP or domain name of the Controller, used for disambiguating
130 help="""The external IP or domain name of the Controller, used for disambiguating
131 engine and client connections.""",
131 engine and client connections.""",
132 )
132 )
133 import_statements = List([], config=True,
133 import_statements = List([], config=True,
134 help="import statements to be run at startup. Necessary in some environments"
134 help="import statements to be run at startup. Necessary in some environments"
135 )
135 )
136
136
137 use_threads = Bool(False, config=True,
137 use_threads = Bool(False, config=True,
138 help='Use threads instead of processes for the schedulers',
138 help='Use threads instead of processes for the schedulers',
139 )
139 )
140
140
141 # internal
141 # internal
142 children = List()
142 children = List()
143 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
143 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
144
144
145 def _use_threads_changed(self, name, old, new):
145 def _use_threads_changed(self, name, old, new):
146 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
146 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
147
147
148 aliases = Dict(dict(
148 aliases = Dict(dict(
149 log_level = 'IPControllerApp.log_level',
149 log_level = 'IPControllerApp.log_level',
150 log_url = 'IPControllerApp.log_url',
150 log_url = 'IPControllerApp.log_url',
151 reuse_files = 'IPControllerApp.reuse_files',
151 reuse_files = 'IPControllerApp.reuse_files',
152 secure = 'IPControllerApp.secure',
152 secure = 'IPControllerApp.secure',
153 ssh = 'IPControllerApp.ssh_server',
153 ssh = 'IPControllerApp.ssh_server',
154 use_threads = 'IPControllerApp.use_threads',
154 use_threads = 'IPControllerApp.use_threads',
155 import_statements = 'IPControllerApp.import_statements',
155 import_statements = 'IPControllerApp.import_statements',
156 location = 'IPControllerApp.location',
156 location = 'IPControllerApp.location',
157
157
158 ident = 'StreamSession.session',
158 ident = 'Session.session',
159 user = 'StreamSession.username',
159 user = 'Session.username',
160 exec_key = 'StreamSession.keyfile',
160 exec_key = 'Session.keyfile',
161
161
162 url = 'HubFactory.url',
162 url = 'HubFactory.url',
163 ip = 'HubFactory.ip',
163 ip = 'HubFactory.ip',
164 transport = 'HubFactory.transport',
164 transport = 'HubFactory.transport',
165 port = 'HubFactory.regport',
165 port = 'HubFactory.regport',
166
166
167 ping = 'HeartMonitor.period',
167 ping = 'HeartMonitor.period',
168
168
169 scheme = 'TaskScheduler.scheme_name',
169 scheme = 'TaskScheduler.scheme_name',
170 hwm = 'TaskScheduler.hwm',
170 hwm = 'TaskScheduler.hwm',
171
171
172
172
173 profile = "BaseIPythonApplication.profile",
173 profile = "BaseIPythonApplication.profile",
174 profile_dir = 'ProfileDir.location',
174 profile_dir = 'ProfileDir.location',
175
175
176 ))
176 ))
177 flags = Dict(flags)
177 flags = Dict(flags)
178
178
179
179
180 def save_connection_dict(self, fname, cdict):
180 def save_connection_dict(self, fname, cdict):
181 """save a connection dict to json file."""
181 """save a connection dict to json file."""
182 c = self.config
182 c = self.config
183 url = cdict['url']
183 url = cdict['url']
184 location = cdict['location']
184 location = cdict['location']
185 if not location:
185 if not location:
186 try:
186 try:
187 proto,ip,port = split_url(url)
187 proto,ip,port = split_url(url)
188 except AssertionError:
188 except AssertionError:
189 pass
189 pass
190 else:
190 else:
191 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
191 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
192 cdict['location'] = location
192 cdict['location'] = location
193 fname = os.path.join(self.profile_dir.security_dir, fname)
193 fname = os.path.join(self.profile_dir.security_dir, fname)
194 with open(fname, 'w') as f:
194 with open(fname, 'w') as f:
195 f.write(json.dumps(cdict, indent=2))
195 f.write(json.dumps(cdict, indent=2))
196 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
196 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
197
197
198 def load_config_from_json(self):
198 def load_config_from_json(self):
199 """load config from existing json connector files."""
199 """load config from existing json connector files."""
200 c = self.config
200 c = self.config
201 # load from engine config
201 # load from engine config
202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
203 cfg = json.loads(f.read())
203 cfg = json.loads(f.read())
204 key = c.StreamSession.key = cfg['exec_key']
204 key = c.Session.key = cfg['exec_key']
205 xport,addr = cfg['url'].split('://')
205 xport,addr = cfg['url'].split('://')
206 c.HubFactory.engine_transport = xport
206 c.HubFactory.engine_transport = xport
207 ip,ports = addr.split(':')
207 ip,ports = addr.split(':')
208 c.HubFactory.engine_ip = ip
208 c.HubFactory.engine_ip = ip
209 c.HubFactory.regport = int(ports)
209 c.HubFactory.regport = int(ports)
210 self.location = cfg['location']
210 self.location = cfg['location']
211
211
212 # load client config
212 # load client config
213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
214 cfg = json.loads(f.read())
214 cfg = json.loads(f.read())
215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
216 xport,addr = cfg['url'].split('://')
216 xport,addr = cfg['url'].split('://')
217 c.HubFactory.client_transport = xport
217 c.HubFactory.client_transport = xport
218 ip,ports = addr.split(':')
218 ip,ports = addr.split(':')
219 c.HubFactory.client_ip = ip
219 c.HubFactory.client_ip = ip
220 self.ssh_server = cfg['ssh']
220 self.ssh_server = cfg['ssh']
221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
222
222
223 def init_hub(self):
223 def init_hub(self):
224 c = self.config
224 c = self.config
225
225
226 self.do_import_statements()
226 self.do_import_statements()
227 reusing = self.reuse_files
227 reusing = self.reuse_files
228 if reusing:
228 if reusing:
229 try:
229 try:
230 self.load_config_from_json()
230 self.load_config_from_json()
231 except (AssertionError,IOError):
231 except (AssertionError,IOError):
232 reusing=False
232 reusing=False
233 # check again, because reusing may have failed:
233 # check again, because reusing may have failed:
234 if reusing:
234 if reusing:
235 pass
235 pass
236 elif self.secure:
236 elif self.secure:
237 key = str(uuid.uuid4())
237 key = str(uuid.uuid4())
238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
239 # with open(keyfile, 'w') as f:
239 # with open(keyfile, 'w') as f:
240 # f.write(key)
240 # f.write(key)
241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 c.StreamSession.key = key
242 c.Session.key = key
243 else:
243 else:
244 key = c.StreamSession.key = ''
244 key = c.Session.key = ''
245
245
246 try:
246 try:
247 self.factory = HubFactory(config=c, log=self.log)
247 self.factory = HubFactory(config=c, log=self.log)
248 # self.start_logging()
248 # self.start_logging()
249 self.factory.init_hub()
249 self.factory.init_hub()
250 except:
250 except:
251 self.log.error("Couldn't construct the Controller", exc_info=True)
251 self.log.error("Couldn't construct the Controller", exc_info=True)
252 self.exit(1)
252 self.exit(1)
253
253
254 if not reusing:
254 if not reusing:
255 # save to new json config files
255 # save to new json config files
256 f = self.factory
256 f = self.factory
257 cdict = {'exec_key' : key,
257 cdict = {'exec_key' : key,
258 'ssh' : self.ssh_server,
258 'ssh' : self.ssh_server,
259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
260 'location' : self.location
260 'location' : self.location
261 }
261 }
262 self.save_connection_dict('ipcontroller-client.json', cdict)
262 self.save_connection_dict('ipcontroller-client.json', cdict)
263 edict = cdict
263 edict = cdict
264 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
264 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
265 self.save_connection_dict('ipcontroller-engine.json', edict)
265 self.save_connection_dict('ipcontroller-engine.json', edict)
266
266
267 #
267 #
268 def init_schedulers(self):
268 def init_schedulers(self):
269 children = self.children
269 children = self.children
270 mq = import_item(str(self.mq_class))
270 mq = import_item(str(self.mq_class))
271
271
272 hub = self.factory
272 hub = self.factory
273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
274 # IOPub relay (in a Process)
274 # IOPub relay (in a Process)
275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
276 q.bind_in(hub.client_info['iopub'])
276 q.bind_in(hub.client_info['iopub'])
277 q.bind_out(hub.engine_info['iopub'])
277 q.bind_out(hub.engine_info['iopub'])
278 q.setsockopt_out(zmq.SUBSCRIBE, '')
278 q.setsockopt_out(zmq.SUBSCRIBE, '')
279 q.connect_mon(hub.monitor_url)
279 q.connect_mon(hub.monitor_url)
280 q.daemon=True
280 q.daemon=True
281 children.append(q)
281 children.append(q)
282
282
283 # Multiplexer Queue (in a Process)
283 # Multiplexer Queue (in a Process)
284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
285 q.bind_in(hub.client_info['mux'])
285 q.bind_in(hub.client_info['mux'])
286 q.setsockopt_in(zmq.IDENTITY, 'mux')
286 q.setsockopt_in(zmq.IDENTITY, 'mux')
287 q.bind_out(hub.engine_info['mux'])
287 q.bind_out(hub.engine_info['mux'])
288 q.connect_mon(hub.monitor_url)
288 q.connect_mon(hub.monitor_url)
289 q.daemon=True
289 q.daemon=True
290 children.append(q)
290 children.append(q)
291
291
292 # Control Queue (in a Process)
292 # Control Queue (in a Process)
293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
294 q.bind_in(hub.client_info['control'])
294 q.bind_in(hub.client_info['control'])
295 q.setsockopt_in(zmq.IDENTITY, 'control')
295 q.setsockopt_in(zmq.IDENTITY, 'control')
296 q.bind_out(hub.engine_info['control'])
296 q.bind_out(hub.engine_info['control'])
297 q.connect_mon(hub.monitor_url)
297 q.connect_mon(hub.monitor_url)
298 q.daemon=True
298 q.daemon=True
299 children.append(q)
299 children.append(q)
300 try:
300 try:
301 scheme = self.config.TaskScheduler.scheme_name
301 scheme = self.config.TaskScheduler.scheme_name
302 except AttributeError:
302 except AttributeError:
303 scheme = TaskScheduler.scheme_name.get_default_value()
303 scheme = TaskScheduler.scheme_name.get_default_value()
304 # Task Queue (in a Process)
304 # Task Queue (in a Process)
305 if scheme == 'pure':
305 if scheme == 'pure':
306 self.log.warn("task::using pure XREQ Task scheduler")
306 self.log.warn("task::using pure XREQ Task scheduler")
307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
308 # q.setsockopt_out(zmq.HWM, hub.hwm)
308 # q.setsockopt_out(zmq.HWM, hub.hwm)
309 q.bind_in(hub.client_info['task'][1])
309 q.bind_in(hub.client_info['task'][1])
310 q.setsockopt_in(zmq.IDENTITY, 'task')
310 q.setsockopt_in(zmq.IDENTITY, 'task')
311 q.bind_out(hub.engine_info['task'])
311 q.bind_out(hub.engine_info['task'])
312 q.connect_mon(hub.monitor_url)
312 q.connect_mon(hub.monitor_url)
313 q.daemon=True
313 q.daemon=True
314 children.append(q)
314 children.append(q)
315 elif scheme == 'none':
315 elif scheme == 'none':
316 self.log.warn("task::using no Task scheduler")
316 self.log.warn("task::using no Task scheduler")
317
317
318 else:
318 else:
319 self.log.info("task::using Python %s Task scheduler"%scheme)
319 self.log.info("task::using Python %s Task scheduler"%scheme)
320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
321 hub.monitor_url, hub.client_info['notification'])
321 hub.monitor_url, hub.client_info['notification'])
322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
323 log_url = self.log_url, config=dict(self.config))
323 log_url = self.log_url, config=dict(self.config))
324 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
324 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
325 q.daemon=True
325 q.daemon=True
326 children.append(q)
326 children.append(q)
327
327
328
328
329 def save_urls(self):
329 def save_urls(self):
330 """save the registration urls to files."""
330 """save the registration urls to files."""
331 c = self.config
331 c = self.config
332
332
333 sec_dir = self.profile_dir.security_dir
333 sec_dir = self.profile_dir.security_dir
334 cf = self.factory
334 cf = self.factory
335
335
336 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
336 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
337 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
337 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
338
338
339 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
339 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
340 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
340 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
341
341
342
342
343 def do_import_statements(self):
343 def do_import_statements(self):
344 statements = self.import_statements
344 statements = self.import_statements
345 for s in statements:
345 for s in statements:
346 try:
346 try:
347 self.log.msg("Executing statement: '%s'" % s)
347 self.log.msg("Executing statement: '%s'" % s)
348 exec s in globals(), locals()
348 exec s in globals(), locals()
349 except:
349 except:
350 self.log.msg("Error running statement: %s" % s)
350 self.log.msg("Error running statement: %s" % s)
351
351
352 def forward_logging(self):
352 def forward_logging(self):
353 if self.log_url:
353 if self.log_url:
354 self.log.info("Forwarding logging to %s"%self.log_url)
354 self.log.info("Forwarding logging to %s"%self.log_url)
355 context = zmq.Context.instance()
355 context = zmq.Context.instance()
356 lsock = context.socket(zmq.PUB)
356 lsock = context.socket(zmq.PUB)
357 lsock.connect(self.log_url)
357 lsock.connect(self.log_url)
358 handler = PUBHandler(lsock)
358 handler = PUBHandler(lsock)
359 self.log.removeHandler(self._log_handler)
359 self.log.removeHandler(self._log_handler)
360 handler.root_topic = 'controller'
360 handler.root_topic = 'controller'
361 handler.setLevel(self.log_level)
361 handler.setLevel(self.log_level)
362 self.log.addHandler(handler)
362 self.log.addHandler(handler)
363 self._log_handler = handler
363 self._log_handler = handler
364 # #
364 # #
365
365
366 def initialize(self, argv=None):
366 def initialize(self, argv=None):
367 super(IPControllerApp, self).initialize(argv)
367 super(IPControllerApp, self).initialize(argv)
368 self.forward_logging()
368 self.forward_logging()
369 self.init_hub()
369 self.init_hub()
370 self.init_schedulers()
370 self.init_schedulers()
371
371
372 def start(self):
372 def start(self):
373 # Start the subprocesses:
373 # Start the subprocesses:
374 self.factory.start()
374 self.factory.start()
375 child_procs = []
375 child_procs = []
376 for child in self.children:
376 for child in self.children:
377 child.start()
377 child.start()
378 if isinstance(child, ProcessMonitoredQueue):
378 if isinstance(child, ProcessMonitoredQueue):
379 child_procs.append(child.launcher)
379 child_procs.append(child.launcher)
380 elif isinstance(child, Process):
380 elif isinstance(child, Process):
381 child_procs.append(child)
381 child_procs.append(child)
382 if child_procs:
382 if child_procs:
383 signal_children(child_procs)
383 signal_children(child_procs)
384
384
385 self.write_pid_file(overwrite=True)
385 self.write_pid_file(overwrite=True)
386
386
387 try:
387 try:
388 self.factory.loop.start()
388 self.factory.loop.start()
389 except KeyboardInterrupt:
389 except KeyboardInterrupt:
390 self.log.critical("Interrupted, Exiting...\n")
390 self.log.critical("Interrupted, Exiting...\n")
391
391
392
392
393
393
394 def launch_new_instance():
394 def launch_new_instance():
395 """Create and run the IPython controller"""
395 """Create and run the IPython controller"""
396 app = IPControllerApp.instance()
396 app = IPControllerApp.instance()
397 app.initialize()
397 app.initialize()
398 app.start()
398 app.start()
399
399
400
400
401 if __name__ == '__main__':
401 if __name__ == '__main__':
402 launch_new_instance()
402 launch_new_instance()
@@ -1,270 +1,270 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython engine application
4 The IPython engine application
5 """
5 """
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
8 # Copyright (C) 2008-2009 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import json
18 import json
19 import os
19 import os
20 import sys
20 import sys
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop
23 from zmq.eventloop import ioloop
24
24
25 from IPython.core.newapplication import ProfileDir
25 from IPython.core.newapplication import ProfileDir
26 from IPython.parallel.apps.baseapp import BaseParallelApplication
26 from IPython.parallel.apps.baseapp import BaseParallelApplication
27 from IPython.zmq.log import EnginePUBHandler
27 from IPython.zmq.log import EnginePUBHandler
28
28
29 from IPython.config.configurable import Configurable
29 from IPython.config.configurable import Configurable
30 from IPython.parallel.streamsession import StreamSession
30 from IPython.zmq.session import Session
31 from IPython.parallel.engine.engine import EngineFactory
31 from IPython.parallel.engine.engine import EngineFactory
32 from IPython.parallel.engine.streamkernel import Kernel
32 from IPython.parallel.engine.streamkernel import Kernel
33 from IPython.parallel.util import disambiguate_url
33 from IPython.parallel.util import disambiguate_url
34
34
35 from IPython.utils.importstring import import_item
35 from IPython.utils.importstring import import_item
36 from IPython.utils.traitlets import Bool, Unicode, Dict, List
36 from IPython.utils.traitlets import Bool, Unicode, Dict, List
37
37
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Module level variables
40 # Module level variables
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 #: The default config file name for this application
43 #: The default config file name for this application
44 default_config_file_name = u'ipengine_config.py'
44 default_config_file_name = u'ipengine_config.py'
45
45
46 _description = """Start an IPython engine for parallel computing.
46 _description = """Start an IPython engine for parallel computing.
47
47
48 IPython engines run in parallel and perform computations on behalf of a client
48 IPython engines run in parallel and perform computations on behalf of a client
49 and controller. A controller needs to be started before the engines. The
49 and controller. A controller needs to be started before the engines. The
50 engine can be configured using command line options or using a cluster
50 engine can be configured using command line options or using a cluster
51 directory. Cluster directories contain config, log and security files and are
51 directory. Cluster directories contain config, log and security files and are
52 usually located in your ipython directory and named as "cluster_<profile>".
52 usually located in your ipython directory and named as "cluster_<profile>".
53 See the `profile` and `profile_dir` options for details.
53 See the `profile` and `profile_dir` options for details.
54 """
54 """
55
55
56
56
57 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
58 # MPI configuration
58 # MPI configuration
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60
60
61 mpi4py_init = """from mpi4py import MPI as mpi
61 mpi4py_init = """from mpi4py import MPI as mpi
62 mpi.size = mpi.COMM_WORLD.Get_size()
62 mpi.size = mpi.COMM_WORLD.Get_size()
63 mpi.rank = mpi.COMM_WORLD.Get_rank()
63 mpi.rank = mpi.COMM_WORLD.Get_rank()
64 """
64 """
65
65
66
66
67 pytrilinos_init = """from PyTrilinos import Epetra
67 pytrilinos_init = """from PyTrilinos import Epetra
68 class SimpleStruct:
68 class SimpleStruct:
69 pass
69 pass
70 mpi = SimpleStruct()
70 mpi = SimpleStruct()
71 mpi.rank = 0
71 mpi.rank = 0
72 mpi.size = 0
72 mpi.size = 0
73 """
73 """
74
74
75 class MPI(Configurable):
75 class MPI(Configurable):
76 """Configurable for MPI initialization"""
76 """Configurable for MPI initialization"""
77 use = Unicode('', config=True,
77 use = Unicode('', config=True,
78 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
78 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
79 )
79 )
80
80
81 def _on_use_changed(self, old, new):
81 def _on_use_changed(self, old, new):
82 # load default init script if it's not set
82 # load default init script if it's not set
83 if not self.init_script:
83 if not self.init_script:
84 self.init_script = self.default_inits.get(new, '')
84 self.init_script = self.default_inits.get(new, '')
85
85
86 init_script = Unicode('', config=True,
86 init_script = Unicode('', config=True,
87 help="Initialization code for MPI")
87 help="Initialization code for MPI")
88
88
89 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
89 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
90 config=True)
90 config=True)
91
91
92
92
93 #-----------------------------------------------------------------------------
93 #-----------------------------------------------------------------------------
94 # Main application
94 # Main application
95 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
96
96
97
97
98 class IPEngineApp(BaseParallelApplication):
98 class IPEngineApp(BaseParallelApplication):
99
99
100 app_name = Unicode(u'ipengine')
100 app_name = Unicode(u'ipengine')
101 description = Unicode(_description)
101 description = Unicode(_description)
102 config_file_name = Unicode(default_config_file_name)
102 config_file_name = Unicode(default_config_file_name)
103 classes = List([ProfileDir, StreamSession, EngineFactory, Kernel, MPI])
103 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
104
104
105 startup_script = Unicode(u'', config=True,
105 startup_script = Unicode(u'', config=True,
106 help='specify a script to be run at startup')
106 help='specify a script to be run at startup')
107 startup_command = Unicode('', config=True,
107 startup_command = Unicode('', config=True,
108 help='specify a command to be run at startup')
108 help='specify a command to be run at startup')
109
109
110 url_file = Unicode(u'', config=True,
110 url_file = Unicode(u'', config=True,
111 help="""The full location of the file containing the connection information for
111 help="""The full location of the file containing the connection information for
112 the controller. If this is not given, the file must be in the
112 the controller. If this is not given, the file must be in the
113 security directory of the cluster directory. This location is
113 security directory of the cluster directory. This location is
114 resolved using the `profile` or `profile_dir` options.""",
114 resolved using the `profile` or `profile_dir` options.""",
115 )
115 )
116
116
117 url_file_name = Unicode(u'ipcontroller-engine.json')
117 url_file_name = Unicode(u'ipcontroller-engine.json')
118 log_url = Unicode('', config=True,
118 log_url = Unicode('', config=True,
119 help="""The URL for the iploggerapp instance, for forwarding
119 help="""The URL for the iploggerapp instance, for forwarding
120 logging to a central location.""")
120 logging to a central location.""")
121
121
122 aliases = Dict(dict(
122 aliases = Dict(dict(
123 file = 'IPEngineApp.url_file',
123 file = 'IPEngineApp.url_file',
124 c = 'IPEngineApp.startup_command',
124 c = 'IPEngineApp.startup_command',
125 s = 'IPEngineApp.startup_script',
125 s = 'IPEngineApp.startup_script',
126
126
127 ident = 'StreamSession.session',
127 ident = 'Session.session',
128 user = 'StreamSession.username',
128 user = 'Session.username',
129 exec_key = 'StreamSession.keyfile',
129 exec_key = 'Session.keyfile',
130
130
131 url = 'EngineFactory.url',
131 url = 'EngineFactory.url',
132 ip = 'EngineFactory.ip',
132 ip = 'EngineFactory.ip',
133 transport = 'EngineFactory.transport',
133 transport = 'EngineFactory.transport',
134 port = 'EngineFactory.regport',
134 port = 'EngineFactory.regport',
135 location = 'EngineFactory.location',
135 location = 'EngineFactory.location',
136
136
137 timeout = 'EngineFactory.timeout',
137 timeout = 'EngineFactory.timeout',
138
138
139 profile = "IPEngineApp.profile",
139 profile = "IPEngineApp.profile",
140 profile_dir = 'ProfileDir.location',
140 profile_dir = 'ProfileDir.location',
141
141
142 mpi = 'MPI.use',
142 mpi = 'MPI.use',
143
143
144 log_level = 'IPEngineApp.log_level',
144 log_level = 'IPEngineApp.log_level',
145 log_url = 'IPEngineApp.log_url'
145 log_url = 'IPEngineApp.log_url'
146 ))
146 ))
147
147
148 # def find_key_file(self):
148 # def find_key_file(self):
149 # """Set the key file.
149 # """Set the key file.
150 #
150 #
151 # Here we don't try to actually see if it exists for is valid as that
151 # Here we don't try to actually see if it exists for is valid as that
152 # is hadled by the connection logic.
152 # is hadled by the connection logic.
153 # """
153 # """
154 # config = self.master_config
154 # config = self.master_config
155 # # Find the actual controller key file
155 # # Find the actual controller key file
156 # if not config.Global.key_file:
156 # if not config.Global.key_file:
157 # try_this = os.path.join(
157 # try_this = os.path.join(
158 # config.Global.profile_dir,
158 # config.Global.profile_dir,
159 # config.Global.security_dir,
159 # config.Global.security_dir,
160 # config.Global.key_file_name
160 # config.Global.key_file_name
161 # )
161 # )
162 # config.Global.key_file = try_this
162 # config.Global.key_file = try_this
163
163
164 def find_url_file(self):
164 def find_url_file(self):
165 """Set the key file.
165 """Set the key file.
166
166
167 Here we don't try to actually see if it exists for is valid as that
167 Here we don't try to actually see if it exists for is valid as that
168 is hadled by the connection logic.
168 is hadled by the connection logic.
169 """
169 """
170 config = self.config
170 config = self.config
171 # Find the actual controller key file
171 # Find the actual controller key file
172 if not self.url_file:
172 if not self.url_file:
173 self.url_file = os.path.join(
173 self.url_file = os.path.join(
174 self.profile_dir.security_dir,
174 self.profile_dir.security_dir,
175 self.url_file_name
175 self.url_file_name
176 )
176 )
177 def init_engine(self):
177 def init_engine(self):
178 # This is the working dir by now.
178 # This is the working dir by now.
179 sys.path.insert(0, '')
179 sys.path.insert(0, '')
180 config = self.config
180 config = self.config
181 # print config
181 # print config
182 self.find_url_file()
182 self.find_url_file()
183
183
184 # if os.path.exists(config.Global.key_file) and config.Global.secure:
184 # if os.path.exists(config.Global.key_file) and config.Global.secure:
185 # config.SessionFactory.exec_key = config.Global.key_file
185 # config.SessionFactory.exec_key = config.Global.key_file
186 if os.path.exists(self.url_file):
186 if os.path.exists(self.url_file):
187 with open(self.url_file) as f:
187 with open(self.url_file) as f:
188 d = json.loads(f.read())
188 d = json.loads(f.read())
189 for k,v in d.iteritems():
189 for k,v in d.iteritems():
190 if isinstance(v, unicode):
190 if isinstance(v, unicode):
191 d[k] = v.encode()
191 d[k] = v.encode()
192 if d['exec_key']:
192 if d['exec_key']:
193 config.StreamSession.key = d['exec_key']
193 config.Session.key = d['exec_key']
194 d['url'] = disambiguate_url(d['url'], d['location'])
194 d['url'] = disambiguate_url(d['url'], d['location'])
195 config.EngineFactory.url = d['url']
195 config.EngineFactory.url = d['url']
196 config.EngineFactory.location = d['location']
196 config.EngineFactory.location = d['location']
197
197
198 try:
198 try:
199 exec_lines = config.Kernel.exec_lines
199 exec_lines = config.Kernel.exec_lines
200 except AttributeError:
200 except AttributeError:
201 config.Kernel.exec_lines = []
201 config.Kernel.exec_lines = []
202 exec_lines = config.Kernel.exec_lines
202 exec_lines = config.Kernel.exec_lines
203
203
204 if self.startup_script:
204 if self.startup_script:
205 enc = sys.getfilesystemencoding() or 'utf8'
205 enc = sys.getfilesystemencoding() or 'utf8'
206 cmd="execfile(%r)"%self.startup_script.encode(enc)
206 cmd="execfile(%r)"%self.startup_script.encode(enc)
207 exec_lines.append(cmd)
207 exec_lines.append(cmd)
208 if self.startup_command:
208 if self.startup_command:
209 exec_lines.append(self.startup_command)
209 exec_lines.append(self.startup_command)
210
210
211 # Create the underlying shell class and Engine
211 # Create the underlying shell class and Engine
212 # shell_class = import_item(self.master_config.Global.shell_class)
212 # shell_class = import_item(self.master_config.Global.shell_class)
213 # print self.config
213 # print self.config
214 try:
214 try:
215 self.engine = EngineFactory(config=config, log=self.log)
215 self.engine = EngineFactory(config=config, log=self.log)
216 except:
216 except:
217 self.log.error("Couldn't start the Engine", exc_info=True)
217 self.log.error("Couldn't start the Engine", exc_info=True)
218 self.exit(1)
218 self.exit(1)
219
219
220 def forward_logging(self):
220 def forward_logging(self):
221 if self.log_url:
221 if self.log_url:
222 self.log.info("Forwarding logging to %s"%self.log_url)
222 self.log.info("Forwarding logging to %s"%self.log_url)
223 context = self.engine.context
223 context = self.engine.context
224 lsock = context.socket(zmq.PUB)
224 lsock = context.socket(zmq.PUB)
225 lsock.connect(self.log_url)
225 lsock.connect(self.log_url)
226 self.log.removeHandler(self._log_handler)
226 self.log.removeHandler(self._log_handler)
227 handler = EnginePUBHandler(self.engine, lsock)
227 handler = EnginePUBHandler(self.engine, lsock)
228 handler.setLevel(self.log_level)
228 handler.setLevel(self.log_level)
229 self.log.addHandler(handler)
229 self.log.addHandler(handler)
230 self._log_handler = handler
230 self._log_handler = handler
231 #
231 #
232 def init_mpi(self):
232 def init_mpi(self):
233 global mpi
233 global mpi
234 self.mpi = MPI(config=self.config)
234 self.mpi = MPI(config=self.config)
235
235
236 mpi_import_statement = self.mpi.init_script
236 mpi_import_statement = self.mpi.init_script
237 if mpi_import_statement:
237 if mpi_import_statement:
238 try:
238 try:
239 self.log.info("Initializing MPI:")
239 self.log.info("Initializing MPI:")
240 self.log.info(mpi_import_statement)
240 self.log.info(mpi_import_statement)
241 exec mpi_import_statement in globals()
241 exec mpi_import_statement in globals()
242 except:
242 except:
243 mpi = None
243 mpi = None
244 else:
244 else:
245 mpi = None
245 mpi = None
246
246
247 def initialize(self, argv=None):
247 def initialize(self, argv=None):
248 super(IPEngineApp, self).initialize(argv)
248 super(IPEngineApp, self).initialize(argv)
249 self.init_mpi()
249 self.init_mpi()
250 self.init_engine()
250 self.init_engine()
251 self.forward_logging()
251 self.forward_logging()
252
252
253 def start(self):
253 def start(self):
254 self.engine.start()
254 self.engine.start()
255 try:
255 try:
256 self.engine.loop.start()
256 self.engine.loop.start()
257 except KeyboardInterrupt:
257 except KeyboardInterrupt:
258 self.log.critical("Engine Interrupted, shutting down...\n")
258 self.log.critical("Engine Interrupted, shutting down...\n")
259
259
260
260
261 def launch_new_instance():
261 def launch_new_instance():
262 """Create and run the IPython engine"""
262 """Create and run the IPython engine"""
263 app = IPEngineApp.instance()
263 app = IPEngineApp.instance()
264 app.initialize()
264 app.initialize()
265 app.start()
265 app.start()
266
266
267
267
268 if __name__ == '__main__':
268 if __name__ == '__main__':
269 launch_new_instance()
269 launch_new_instance()
270
270
@@ -1,1356 +1,1353 b''
1 """A semi-synchronous Client for the ZMQ cluster"""
1 """A semi-synchronous Client for the ZMQ cluster"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import json
14 import json
15 import time
15 import time
16 import warnings
16 import warnings
17 from datetime import datetime
17 from datetime import datetime
18 from getpass import getpass
18 from getpass import getpass
19 from pprint import pprint
19 from pprint import pprint
20
20
21 pjoin = os.path.join
21 pjoin = os.path.join
22
22
23 import zmq
23 import zmq
24 # from zmq.eventloop import ioloop, zmqstream
24 # from zmq.eventloop import ioloop, zmqstream
25
25
26 from IPython.utils.jsonutil import extract_dates
26 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.path import get_ipython_dir
27 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 Dict, List, Bool, Set)
29 Dict, List, Bool, Set)
29 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
30 from IPython.external.ssh import tunnel
31 from IPython.external.ssh import tunnel
31
32
32 from IPython.parallel import error
33 from IPython.parallel import error
33 from IPython.parallel import streamsession as ss
34 from IPython.parallel import util
34 from IPython.parallel import util
35
35
36 from IPython.zmq.session import Session, Message
37
36 from .asyncresult import AsyncResult, AsyncHubResult
38 from .asyncresult import AsyncResult, AsyncHubResult
37 from IPython.core.newapplication import ProfileDir, ProfileDirError
39 from IPython.core.newapplication import ProfileDir, ProfileDirError
38 from .view import DirectView, LoadBalancedView
40 from .view import DirectView, LoadBalancedView
39
41
40 #--------------------------------------------------------------------------
42 #--------------------------------------------------------------------------
41 # Decorators for Client methods
43 # Decorators for Client methods
42 #--------------------------------------------------------------------------
44 #--------------------------------------------------------------------------
43
45
44 @decorator
46 @decorator
45 def spin_first(f, self, *args, **kwargs):
47 def spin_first(f, self, *args, **kwargs):
46 """Call spin() to sync state prior to calling the method."""
48 """Call spin() to sync state prior to calling the method."""
47 self.spin()
49 self.spin()
48 return f(self, *args, **kwargs)
50 return f(self, *args, **kwargs)
49
51
50
52
51 #--------------------------------------------------------------------------
53 #--------------------------------------------------------------------------
52 # Classes
54 # Classes
53 #--------------------------------------------------------------------------
55 #--------------------------------------------------------------------------
54
56
55 class Metadata(dict):
57 class Metadata(dict):
56 """Subclass of dict for initializing metadata values.
58 """Subclass of dict for initializing metadata values.
57
59
58 Attribute access works on keys.
60 Attribute access works on keys.
59
61
60 These objects have a strict set of keys - errors will raise if you try
62 These objects have a strict set of keys - errors will raise if you try
61 to add new keys.
63 to add new keys.
62 """
64 """
63 def __init__(self, *args, **kwargs):
65 def __init__(self, *args, **kwargs):
64 dict.__init__(self)
66 dict.__init__(self)
65 md = {'msg_id' : None,
67 md = {'msg_id' : None,
66 'submitted' : None,
68 'submitted' : None,
67 'started' : None,
69 'started' : None,
68 'completed' : None,
70 'completed' : None,
69 'received' : None,
71 'received' : None,
70 'engine_uuid' : None,
72 'engine_uuid' : None,
71 'engine_id' : None,
73 'engine_id' : None,
72 'follow' : None,
74 'follow' : None,
73 'after' : None,
75 'after' : None,
74 'status' : None,
76 'status' : None,
75
77
76 'pyin' : None,
78 'pyin' : None,
77 'pyout' : None,
79 'pyout' : None,
78 'pyerr' : None,
80 'pyerr' : None,
79 'stdout' : '',
81 'stdout' : '',
80 'stderr' : '',
82 'stderr' : '',
81 }
83 }
82 self.update(md)
84 self.update(md)
83 self.update(dict(*args, **kwargs))
85 self.update(dict(*args, **kwargs))
84
86
85 def __getattr__(self, key):
87 def __getattr__(self, key):
86 """getattr aliased to getitem"""
88 """getattr aliased to getitem"""
87 if key in self.iterkeys():
89 if key in self.iterkeys():
88 return self[key]
90 return self[key]
89 else:
91 else:
90 raise AttributeError(key)
92 raise AttributeError(key)
91
93
92 def __setattr__(self, key, value):
94 def __setattr__(self, key, value):
93 """setattr aliased to setitem, with strict"""
95 """setattr aliased to setitem, with strict"""
94 if key in self.iterkeys():
96 if key in self.iterkeys():
95 self[key] = value
97 self[key] = value
96 else:
98 else:
97 raise AttributeError(key)
99 raise AttributeError(key)
98
100
99 def __setitem__(self, key, value):
101 def __setitem__(self, key, value):
100 """strict static key enforcement"""
102 """strict static key enforcement"""
101 if key in self.iterkeys():
103 if key in self.iterkeys():
102 dict.__setitem__(self, key, value)
104 dict.__setitem__(self, key, value)
103 else:
105 else:
104 raise KeyError(key)
106 raise KeyError(key)
105
107
106
108
107 class Client(HasTraits):
109 class Client(HasTraits):
108 """A semi-synchronous client to the IPython ZMQ cluster
110 """A semi-synchronous client to the IPython ZMQ cluster
109
111
110 Parameters
112 Parameters
111 ----------
113 ----------
112
114
113 url_or_file : bytes; zmq url or path to ipcontroller-client.json
115 url_or_file : bytes; zmq url or path to ipcontroller-client.json
114 Connection information for the Hub's registration. If a json connector
116 Connection information for the Hub's registration. If a json connector
115 file is given, then likely no further configuration is necessary.
117 file is given, then likely no further configuration is necessary.
116 [Default: use profile]
118 [Default: use profile]
117 profile : bytes
119 profile : bytes
118 The name of the Cluster profile to be used to find connector information.
120 The name of the Cluster profile to be used to find connector information.
119 [Default: 'default']
121 [Default: 'default']
120 context : zmq.Context
122 context : zmq.Context
121 Pass an existing zmq.Context instance, otherwise the client will create its own.
123 Pass an existing zmq.Context instance, otherwise the client will create its own.
122 username : bytes
124 username : bytes
123 set username to be passed to the Session object
125 set username to be passed to the Session object
124 debug : bool
126 debug : bool
125 flag for lots of message printing for debug purposes
127 flag for lots of message printing for debug purposes
126
128
127 #-------------- ssh related args ----------------
129 #-------------- ssh related args ----------------
128 # These are args for configuring the ssh tunnel to be used
130 # These are args for configuring the ssh tunnel to be used
129 # credentials are used to forward connections over ssh to the Controller
131 # credentials are used to forward connections over ssh to the Controller
130 # Note that the ip given in `addr` needs to be relative to sshserver
132 # Note that the ip given in `addr` needs to be relative to sshserver
131 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
133 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
132 # and set sshserver as the same machine the Controller is on. However,
134 # and set sshserver as the same machine the Controller is on. However,
133 # the only requirement is that sshserver is able to see the Controller
135 # the only requirement is that sshserver is able to see the Controller
134 # (i.e. is within the same trusted network).
136 # (i.e. is within the same trusted network).
135
137
136 sshserver : str
138 sshserver : str
137 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
139 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
138 If keyfile or password is specified, and this is not, it will default to
140 If keyfile or password is specified, and this is not, it will default to
139 the ip given in addr.
141 the ip given in addr.
140 sshkey : str; path to public ssh key file
142 sshkey : str; path to public ssh key file
141 This specifies a key to be used in ssh login, default None.
143 This specifies a key to be used in ssh login, default None.
142 Regular default ssh keys will be used without specifying this argument.
144 Regular default ssh keys will be used without specifying this argument.
143 password : str
145 password : str
144 Your ssh password to sshserver. Note that if this is left None,
146 Your ssh password to sshserver. Note that if this is left None,
145 you will be prompted for it if passwordless key based login is unavailable.
147 you will be prompted for it if passwordless key based login is unavailable.
146 paramiko : bool
148 paramiko : bool
147 flag for whether to use paramiko instead of shell ssh for tunneling.
149 flag for whether to use paramiko instead of shell ssh for tunneling.
148 [default: True on win32, False else]
150 [default: True on win32, False else]
149
151
150 ------- exec authentication args -------
152 ------- exec authentication args -------
151 If even localhost is untrusted, you can have some protection against
153 If even localhost is untrusted, you can have some protection against
152 unauthorized execution by using a key. Messages are still sent
154 unauthorized execution by using a key. Messages are still sent
153 as cleartext, so if someone can snoop your loopback traffic this will
155 as cleartext, so if someone can snoop your loopback traffic this will
154 not help against malicious attacks.
156 not help against malicious attacks.
155
157
156 exec_key : str
158 exec_key : str
157 an authentication key or file containing a key
159 an authentication key or file containing a key
158 default: None
160 default: None
159
161
160
162
161 Attributes
163 Attributes
162 ----------
164 ----------
163
165
164 ids : list of int engine IDs
166 ids : list of int engine IDs
165 requesting the ids attribute always synchronizes
167 requesting the ids attribute always synchronizes
166 the registration state. To request ids without synchronization,
168 the registration state. To request ids without synchronization,
167 use semi-private _ids attributes.
169 use semi-private _ids attributes.
168
170
169 history : list of msg_ids
171 history : list of msg_ids
170 a list of msg_ids, keeping track of all the execution
172 a list of msg_ids, keeping track of all the execution
171 messages you have submitted in order.
173 messages you have submitted in order.
172
174
173 outstanding : set of msg_ids
175 outstanding : set of msg_ids
174 a set of msg_ids that have been submitted, but whose
176 a set of msg_ids that have been submitted, but whose
175 results have not yet been received.
177 results have not yet been received.
176
178
177 results : dict
179 results : dict
178 a dict of all our results, keyed by msg_id
180 a dict of all our results, keyed by msg_id
179
181
180 block : bool
182 block : bool
181 determines default behavior when block not specified
183 determines default behavior when block not specified
182 in execution methods
184 in execution methods
183
185
184 Methods
186 Methods
185 -------
187 -------
186
188
187 spin
189 spin
188 flushes incoming results and registration state changes
190 flushes incoming results and registration state changes
189 control methods spin, and requesting `ids` also ensures up to date
191 control methods spin, and requesting `ids` also ensures up to date
190
192
191 wait
193 wait
192 wait on one or more msg_ids
194 wait on one or more msg_ids
193
195
194 execution methods
196 execution methods
195 apply
197 apply
196 legacy: execute, run
198 legacy: execute, run
197
199
198 data movement
200 data movement
199 push, pull, scatter, gather
201 push, pull, scatter, gather
200
202
201 query methods
203 query methods
202 queue_status, get_result, purge, result_status
204 queue_status, get_result, purge, result_status
203
205
204 control methods
206 control methods
205 abort, shutdown
207 abort, shutdown
206
208
207 """
209 """
208
210
209
211
210 block = Bool(False)
212 block = Bool(False)
211 outstanding = Set()
213 outstanding = Set()
212 results = Instance('collections.defaultdict', (dict,))
214 results = Instance('collections.defaultdict', (dict,))
213 metadata = Instance('collections.defaultdict', (Metadata,))
215 metadata = Instance('collections.defaultdict', (Metadata,))
214 history = List()
216 history = List()
215 debug = Bool(False)
217 debug = Bool(False)
216 profile=Unicode('default')
218 profile=Unicode('default')
217
219
218 _outstanding_dict = Instance('collections.defaultdict', (set,))
220 _outstanding_dict = Instance('collections.defaultdict', (set,))
219 _ids = List()
221 _ids = List()
220 _connected=Bool(False)
222 _connected=Bool(False)
221 _ssh=Bool(False)
223 _ssh=Bool(False)
222 _context = Instance('zmq.Context')
224 _context = Instance('zmq.Context')
223 _config = Dict()
225 _config = Dict()
224 _engines=Instance(util.ReverseDict, (), {})
226 _engines=Instance(util.ReverseDict, (), {})
225 # _hub_socket=Instance('zmq.Socket')
227 # _hub_socket=Instance('zmq.Socket')
226 _query_socket=Instance('zmq.Socket')
228 _query_socket=Instance('zmq.Socket')
227 _control_socket=Instance('zmq.Socket')
229 _control_socket=Instance('zmq.Socket')
228 _iopub_socket=Instance('zmq.Socket')
230 _iopub_socket=Instance('zmq.Socket')
229 _notification_socket=Instance('zmq.Socket')
231 _notification_socket=Instance('zmq.Socket')
230 _mux_socket=Instance('zmq.Socket')
232 _mux_socket=Instance('zmq.Socket')
231 _task_socket=Instance('zmq.Socket')
233 _task_socket=Instance('zmq.Socket')
232 _task_scheme=Unicode()
234 _task_scheme=Unicode()
233 _closed = False
235 _closed = False
234 _ignored_control_replies=Int(0)
236 _ignored_control_replies=Int(0)
235 _ignored_hub_replies=Int(0)
237 _ignored_hub_replies=Int(0)
236
238
237 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
239 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
238 context=None, username=None, debug=False, exec_key=None,
240 context=None, username=None, debug=False, exec_key=None,
239 sshserver=None, sshkey=None, password=None, paramiko=None,
241 sshserver=None, sshkey=None, password=None, paramiko=None,
240 timeout=10
242 timeout=10
241 ):
243 ):
242 super(Client, self).__init__(debug=debug, profile=profile)
244 super(Client, self).__init__(debug=debug, profile=profile)
243 if context is None:
245 if context is None:
244 context = zmq.Context.instance()
246 context = zmq.Context.instance()
245 self._context = context
247 self._context = context
246
248
247
249
248 self._setup_profile_dir(profile, profile_dir, ipython_dir)
250 self._setup_profile_dir(profile, profile_dir, ipython_dir)
249 if self._cd is not None:
251 if self._cd is not None:
250 if url_or_file is None:
252 if url_or_file is None:
251 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
253 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
252 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
254 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
253 " Please specify at least one of url_or_file or profile."
255 " Please specify at least one of url_or_file or profile."
254
256
255 try:
257 try:
256 util.validate_url(url_or_file)
258 util.validate_url(url_or_file)
257 except AssertionError:
259 except AssertionError:
258 if not os.path.exists(url_or_file):
260 if not os.path.exists(url_or_file):
259 if self._cd:
261 if self._cd:
260 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
262 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
261 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
263 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
262 with open(url_or_file) as f:
264 with open(url_or_file) as f:
263 cfg = json.loads(f.read())
265 cfg = json.loads(f.read())
264 else:
266 else:
265 cfg = {'url':url_or_file}
267 cfg = {'url':url_or_file}
266
268
267 # sync defaults from args, json:
269 # sync defaults from args, json:
268 if sshserver:
270 if sshserver:
269 cfg['ssh'] = sshserver
271 cfg['ssh'] = sshserver
270 if exec_key:
272 if exec_key:
271 cfg['exec_key'] = exec_key
273 cfg['exec_key'] = exec_key
272 exec_key = cfg['exec_key']
274 exec_key = cfg['exec_key']
273 sshserver=cfg['ssh']
275 sshserver=cfg['ssh']
274 url = cfg['url']
276 url = cfg['url']
275 location = cfg.setdefault('location', None)
277 location = cfg.setdefault('location', None)
276 cfg['url'] = util.disambiguate_url(cfg['url'], location)
278 cfg['url'] = util.disambiguate_url(cfg['url'], location)
277 url = cfg['url']
279 url = cfg['url']
278
280
279 self._config = cfg
281 self._config = cfg
280
282
281 self._ssh = bool(sshserver or sshkey or password)
283 self._ssh = bool(sshserver or sshkey or password)
282 if self._ssh and sshserver is None:
284 if self._ssh and sshserver is None:
283 # default to ssh via localhost
285 # default to ssh via localhost
284 sshserver = url.split('://')[1].split(':')[0]
286 sshserver = url.split('://')[1].split(':')[0]
285 if self._ssh and password is None:
287 if self._ssh and password is None:
286 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
288 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
287 password=False
289 password=False
288 else:
290 else:
289 password = getpass("SSH Password for %s: "%sshserver)
291 password = getpass("SSH Password for %s: "%sshserver)
290 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
292 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
291 if exec_key is not None and os.path.isfile(exec_key):
293 if exec_key is not None and os.path.isfile(exec_key):
292 arg = 'keyfile'
294 arg = 'keyfile'
293 else:
295 else:
294 arg = 'key'
296 arg = 'key'
295 key_arg = {arg:exec_key}
297 key_arg = {arg:exec_key}
296 if username is None:
298 if username is None:
297 self.session = ss.StreamSession(**key_arg)
299 self.session = Session(**key_arg)
298 else:
300 else:
299 self.session = ss.StreamSession(username=username, **key_arg)
301 self.session = Session(username=username, **key_arg)
300 self._query_socket = self._context.socket(zmq.XREQ)
302 self._query_socket = self._context.socket(zmq.XREQ)
301 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
303 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 if self._ssh:
304 if self._ssh:
303 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
305 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
304 else:
306 else:
305 self._query_socket.connect(url)
307 self._query_socket.connect(url)
306
308
307 self.session.debug = self.debug
309 self.session.debug = self.debug
308
310
309 self._notification_handlers = {'registration_notification' : self._register_engine,
311 self._notification_handlers = {'registration_notification' : self._register_engine,
310 'unregistration_notification' : self._unregister_engine,
312 'unregistration_notification' : self._unregister_engine,
311 'shutdown_notification' : lambda msg: self.close(),
313 'shutdown_notification' : lambda msg: self.close(),
312 }
314 }
313 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
315 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
314 'apply_reply' : self._handle_apply_reply}
316 'apply_reply' : self._handle_apply_reply}
315 self._connect(sshserver, ssh_kwargs, timeout)
317 self._connect(sshserver, ssh_kwargs, timeout)
316
318
317 def __del__(self):
319 def __del__(self):
318 """cleanup sockets, but _not_ context."""
320 """cleanup sockets, but _not_ context."""
319 self.close()
321 self.close()
320
322
321 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
323 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
322 if ipython_dir is None:
324 if ipython_dir is None:
323 ipython_dir = get_ipython_dir()
325 ipython_dir = get_ipython_dir()
324 if profile_dir is not None:
326 if profile_dir is not None:
325 try:
327 try:
326 self._cd = ProfileDir.find_profile_dir(profile_dir)
328 self._cd = ProfileDir.find_profile_dir(profile_dir)
327 return
329 return
328 except ProfileDirError:
330 except ProfileDirError:
329 pass
331 pass
330 elif profile is not None:
332 elif profile is not None:
331 try:
333 try:
332 self._cd = ProfileDir.find_profile_dir_by_name(
334 self._cd = ProfileDir.find_profile_dir_by_name(
333 ipython_dir, profile)
335 ipython_dir, profile)
334 return
336 return
335 except ProfileDirError:
337 except ProfileDirError:
336 pass
338 pass
337 self._cd = None
339 self._cd = None
338
340
339 def _update_engines(self, engines):
341 def _update_engines(self, engines):
340 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
342 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
341 for k,v in engines.iteritems():
343 for k,v in engines.iteritems():
342 eid = int(k)
344 eid = int(k)
343 self._engines[eid] = bytes(v) # force not unicode
345 self._engines[eid] = bytes(v) # force not unicode
344 self._ids.append(eid)
346 self._ids.append(eid)
345 self._ids = sorted(self._ids)
347 self._ids = sorted(self._ids)
346 if sorted(self._engines.keys()) != range(len(self._engines)) and \
348 if sorted(self._engines.keys()) != range(len(self._engines)) and \
347 self._task_scheme == 'pure' and self._task_socket:
349 self._task_scheme == 'pure' and self._task_socket:
348 self._stop_scheduling_tasks()
350 self._stop_scheduling_tasks()
349
351
350 def _stop_scheduling_tasks(self):
352 def _stop_scheduling_tasks(self):
351 """Stop scheduling tasks because an engine has been unregistered
353 """Stop scheduling tasks because an engine has been unregistered
352 from a pure ZMQ scheduler.
354 from a pure ZMQ scheduler.
353 """
355 """
354 self._task_socket.close()
356 self._task_socket.close()
355 self._task_socket = None
357 self._task_socket = None
356 msg = "An engine has been unregistered, and we are using pure " +\
358 msg = "An engine has been unregistered, and we are using pure " +\
357 "ZMQ task scheduling. Task farming will be disabled."
359 "ZMQ task scheduling. Task farming will be disabled."
358 if self.outstanding:
360 if self.outstanding:
359 msg += " If you were running tasks when this happened, " +\
361 msg += " If you were running tasks when this happened, " +\
360 "some `outstanding` msg_ids may never resolve."
362 "some `outstanding` msg_ids may never resolve."
361 warnings.warn(msg, RuntimeWarning)
363 warnings.warn(msg, RuntimeWarning)
362
364
363 def _build_targets(self, targets):
365 def _build_targets(self, targets):
364 """Turn valid target IDs or 'all' into two lists:
366 """Turn valid target IDs or 'all' into two lists:
365 (int_ids, uuids).
367 (int_ids, uuids).
366 """
368 """
367 if not self._ids:
369 if not self._ids:
368 # flush notification socket if no engines yet, just in case
370 # flush notification socket if no engines yet, just in case
369 if not self.ids:
371 if not self.ids:
370 raise error.NoEnginesRegistered("Can't build targets without any engines")
372 raise error.NoEnginesRegistered("Can't build targets without any engines")
371
373
372 if targets is None:
374 if targets is None:
373 targets = self._ids
375 targets = self._ids
374 elif isinstance(targets, str):
376 elif isinstance(targets, str):
375 if targets.lower() == 'all':
377 if targets.lower() == 'all':
376 targets = self._ids
378 targets = self._ids
377 else:
379 else:
378 raise TypeError("%r not valid str target, must be 'all'"%(targets))
380 raise TypeError("%r not valid str target, must be 'all'"%(targets))
379 elif isinstance(targets, int):
381 elif isinstance(targets, int):
380 if targets < 0:
382 if targets < 0:
381 targets = self.ids[targets]
383 targets = self.ids[targets]
382 if targets not in self._ids:
384 if targets not in self._ids:
383 raise IndexError("No such engine: %i"%targets)
385 raise IndexError("No such engine: %i"%targets)
384 targets = [targets]
386 targets = [targets]
385
387
386 if isinstance(targets, slice):
388 if isinstance(targets, slice):
387 indices = range(len(self._ids))[targets]
389 indices = range(len(self._ids))[targets]
388 ids = self.ids
390 ids = self.ids
389 targets = [ ids[i] for i in indices ]
391 targets = [ ids[i] for i in indices ]
390
392
391 if not isinstance(targets, (tuple, list, xrange)):
393 if not isinstance(targets, (tuple, list, xrange)):
392 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
394 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
393
395
394 return [self._engines[t] for t in targets], list(targets)
396 return [self._engines[t] for t in targets], list(targets)
395
397
396 def _connect(self, sshserver, ssh_kwargs, timeout):
398 def _connect(self, sshserver, ssh_kwargs, timeout):
397 """setup all our socket connections to the cluster. This is called from
399 """setup all our socket connections to the cluster. This is called from
398 __init__."""
400 __init__."""
399
401
400 # Maybe allow reconnecting?
402 # Maybe allow reconnecting?
401 if self._connected:
403 if self._connected:
402 return
404 return
403 self._connected=True
405 self._connected=True
404
406
405 def connect_socket(s, url):
407 def connect_socket(s, url):
406 url = util.disambiguate_url(url, self._config['location'])
408 url = util.disambiguate_url(url, self._config['location'])
407 if self._ssh:
409 if self._ssh:
408 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
410 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
409 else:
411 else:
410 return s.connect(url)
412 return s.connect(url)
411
413
412 self.session.send(self._query_socket, 'connection_request')
414 self.session.send(self._query_socket, 'connection_request')
413 r,w,x = zmq.select([self._query_socket],[],[], timeout)
415 r,w,x = zmq.select([self._query_socket],[],[], timeout)
414 if not r:
416 if not r:
415 raise error.TimeoutError("Hub connection request timed out")
417 raise error.TimeoutError("Hub connection request timed out")
416 idents,msg = self.session.recv(self._query_socket,mode=0)
418 idents,msg = self.session.recv(self._query_socket,mode=0)
417 if self.debug:
419 if self.debug:
418 pprint(msg)
420 pprint(msg)
419 msg = ss.Message(msg)
421 msg = Message(msg)
420 content = msg.content
422 content = msg.content
421 self._config['registration'] = dict(content)
423 self._config['registration'] = dict(content)
422 if content.status == 'ok':
424 if content.status == 'ok':
423 if content.mux:
425 if content.mux:
424 self._mux_socket = self._context.socket(zmq.XREQ)
426 self._mux_socket = self._context.socket(zmq.XREQ)
425 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
427 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
426 connect_socket(self._mux_socket, content.mux)
428 connect_socket(self._mux_socket, content.mux)
427 if content.task:
429 if content.task:
428 self._task_scheme, task_addr = content.task
430 self._task_scheme, task_addr = content.task
429 self._task_socket = self._context.socket(zmq.XREQ)
431 self._task_socket = self._context.socket(zmq.XREQ)
430 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
432 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
431 connect_socket(self._task_socket, task_addr)
433 connect_socket(self._task_socket, task_addr)
432 if content.notification:
434 if content.notification:
433 self._notification_socket = self._context.socket(zmq.SUB)
435 self._notification_socket = self._context.socket(zmq.SUB)
434 connect_socket(self._notification_socket, content.notification)
436 connect_socket(self._notification_socket, content.notification)
435 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
437 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
436 # if content.query:
438 # if content.query:
437 # self._query_socket = self._context.socket(zmq.XREQ)
439 # self._query_socket = self._context.socket(zmq.XREQ)
438 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
440 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
439 # connect_socket(self._query_socket, content.query)
441 # connect_socket(self._query_socket, content.query)
440 if content.control:
442 if content.control:
441 self._control_socket = self._context.socket(zmq.XREQ)
443 self._control_socket = self._context.socket(zmq.XREQ)
442 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
444 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
443 connect_socket(self._control_socket, content.control)
445 connect_socket(self._control_socket, content.control)
444 if content.iopub:
446 if content.iopub:
445 self._iopub_socket = self._context.socket(zmq.SUB)
447 self._iopub_socket = self._context.socket(zmq.SUB)
446 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
448 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
447 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
449 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
448 connect_socket(self._iopub_socket, content.iopub)
450 connect_socket(self._iopub_socket, content.iopub)
449 self._update_engines(dict(content.engines))
451 self._update_engines(dict(content.engines))
450 else:
452 else:
451 self._connected = False
453 self._connected = False
452 raise Exception("Failed to connect!")
454 raise Exception("Failed to connect!")
453
455
454 #--------------------------------------------------------------------------
456 #--------------------------------------------------------------------------
455 # handlers and callbacks for incoming messages
457 # handlers and callbacks for incoming messages
456 #--------------------------------------------------------------------------
458 #--------------------------------------------------------------------------
457
459
458 def _unwrap_exception(self, content):
460 def _unwrap_exception(self, content):
459 """unwrap exception, and remap engine_id to int."""
461 """unwrap exception, and remap engine_id to int."""
460 e = error.unwrap_exception(content)
462 e = error.unwrap_exception(content)
461 # print e.traceback
463 # print e.traceback
462 if e.engine_info:
464 if e.engine_info:
463 e_uuid = e.engine_info['engine_uuid']
465 e_uuid = e.engine_info['engine_uuid']
464 eid = self._engines[e_uuid]
466 eid = self._engines[e_uuid]
465 e.engine_info['engine_id'] = eid
467 e.engine_info['engine_id'] = eid
466 return e
468 return e
467
469
468 def _extract_metadata(self, header, parent, content):
470 def _extract_metadata(self, header, parent, content):
469 md = {'msg_id' : parent['msg_id'],
471 md = {'msg_id' : parent['msg_id'],
470 'received' : datetime.now(),
472 'received' : datetime.now(),
471 'engine_uuid' : header.get('engine', None),
473 'engine_uuid' : header.get('engine', None),
472 'follow' : parent.get('follow', []),
474 'follow' : parent.get('follow', []),
473 'after' : parent.get('after', []),
475 'after' : parent.get('after', []),
474 'status' : content['status'],
476 'status' : content['status'],
475 }
477 }
476
478
477 if md['engine_uuid'] is not None:
479 if md['engine_uuid'] is not None:
478 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
480 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479
481
480 if 'date' in parent:
482 if 'date' in parent:
481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
483 md['submitted'] = parent['date']
482 if 'started' in header:
484 if 'started' in header:
483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
485 md['started'] = header['started']
484 if 'date' in header:
486 if 'date' in header:
485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
487 md['completed'] = header['date']
486 return md
488 return md
487
489
488 def _register_engine(self, msg):
490 def _register_engine(self, msg):
489 """Register a new engine, and update our connection info."""
491 """Register a new engine, and update our connection info."""
490 content = msg['content']
492 content = msg['content']
491 eid = content['id']
493 eid = content['id']
492 d = {eid : content['queue']}
494 d = {eid : content['queue']}
493 self._update_engines(d)
495 self._update_engines(d)
494
496
495 def _unregister_engine(self, msg):
497 def _unregister_engine(self, msg):
496 """Unregister an engine that has died."""
498 """Unregister an engine that has died."""
497 content = msg['content']
499 content = msg['content']
498 eid = int(content['id'])
500 eid = int(content['id'])
499 if eid in self._ids:
501 if eid in self._ids:
500 self._ids.remove(eid)
502 self._ids.remove(eid)
501 uuid = self._engines.pop(eid)
503 uuid = self._engines.pop(eid)
502
504
503 self._handle_stranded_msgs(eid, uuid)
505 self._handle_stranded_msgs(eid, uuid)
504
506
505 if self._task_socket and self._task_scheme == 'pure':
507 if self._task_socket and self._task_scheme == 'pure':
506 self._stop_scheduling_tasks()
508 self._stop_scheduling_tasks()
507
509
508 def _handle_stranded_msgs(self, eid, uuid):
510 def _handle_stranded_msgs(self, eid, uuid):
509 """Handle messages known to be on an engine when the engine unregisters.
511 """Handle messages known to be on an engine when the engine unregisters.
510
512
511 It is possible that this will fire prematurely - that is, an engine will
513 It is possible that this will fire prematurely - that is, an engine will
512 go down after completing a result, and the client will be notified
514 go down after completing a result, and the client will be notified
513 of the unregistration and later receive the successful result.
515 of the unregistration and later receive the successful result.
514 """
516 """
515
517
516 outstanding = self._outstanding_dict[uuid]
518 outstanding = self._outstanding_dict[uuid]
517
519
518 for msg_id in list(outstanding):
520 for msg_id in list(outstanding):
519 if msg_id in self.results:
521 if msg_id in self.results:
520 # we already
522 # we already
521 continue
523 continue
522 try:
524 try:
523 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
525 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
524 except:
526 except:
525 content = error.wrap_exception()
527 content = error.wrap_exception()
526 # build a fake message:
528 # build a fake message:
527 parent = {}
529 parent = {}
528 header = {}
530 header = {}
529 parent['msg_id'] = msg_id
531 parent['msg_id'] = msg_id
530 header['engine'] = uuid
532 header['engine'] = uuid
531 header['date'] = datetime.now().strftime(util.ISO8601)
533 header['date'] = datetime.now()
532 msg = dict(parent_header=parent, header=header, content=content)
534 msg = dict(parent_header=parent, header=header, content=content)
533 self._handle_apply_reply(msg)
535 self._handle_apply_reply(msg)
534
536
535 def _handle_execute_reply(self, msg):
537 def _handle_execute_reply(self, msg):
536 """Save the reply to an execute_request into our results.
538 """Save the reply to an execute_request into our results.
537
539
538 execute messages are never actually used. apply is used instead.
540 execute messages are never actually used. apply is used instead.
539 """
541 """
540
542
541 parent = msg['parent_header']
543 parent = msg['parent_header']
542 msg_id = parent['msg_id']
544 msg_id = parent['msg_id']
543 if msg_id not in self.outstanding:
545 if msg_id not in self.outstanding:
544 if msg_id in self.history:
546 if msg_id in self.history:
545 print ("got stale result: %s"%msg_id)
547 print ("got stale result: %s"%msg_id)
546 else:
548 else:
547 print ("got unknown result: %s"%msg_id)
549 print ("got unknown result: %s"%msg_id)
548 else:
550 else:
549 self.outstanding.remove(msg_id)
551 self.outstanding.remove(msg_id)
550 self.results[msg_id] = self._unwrap_exception(msg['content'])
552 self.results[msg_id] = self._unwrap_exception(msg['content'])
551
553
552 def _handle_apply_reply(self, msg):
554 def _handle_apply_reply(self, msg):
553 """Save the reply to an apply_request into our results."""
555 """Save the reply to an apply_request into our results."""
554 parent = msg['parent_header']
556 parent = extract_dates(msg['parent_header'])
555 msg_id = parent['msg_id']
557 msg_id = parent['msg_id']
556 if msg_id not in self.outstanding:
558 if msg_id not in self.outstanding:
557 if msg_id in self.history:
559 if msg_id in self.history:
558 print ("got stale result: %s"%msg_id)
560 print ("got stale result: %s"%msg_id)
559 print self.results[msg_id]
561 print self.results[msg_id]
560 print msg
562 print msg
561 else:
563 else:
562 print ("got unknown result: %s"%msg_id)
564 print ("got unknown result: %s"%msg_id)
563 else:
565 else:
564 self.outstanding.remove(msg_id)
566 self.outstanding.remove(msg_id)
565 content = msg['content']
567 content = msg['content']
566 header = msg['header']
568 header = extract_dates(msg['header'])
567
569
568 # construct metadata:
570 # construct metadata:
569 md = self.metadata[msg_id]
571 md = self.metadata[msg_id]
570 md.update(self._extract_metadata(header, parent, content))
572 md.update(self._extract_metadata(header, parent, content))
571 # is this redundant?
573 # is this redundant?
572 self.metadata[msg_id] = md
574 self.metadata[msg_id] = md
573
575
574 e_outstanding = self._outstanding_dict[md['engine_uuid']]
576 e_outstanding = self._outstanding_dict[md['engine_uuid']]
575 if msg_id in e_outstanding:
577 if msg_id in e_outstanding:
576 e_outstanding.remove(msg_id)
578 e_outstanding.remove(msg_id)
577
579
578 # construct result:
580 # construct result:
579 if content['status'] == 'ok':
581 if content['status'] == 'ok':
580 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
582 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
581 elif content['status'] == 'aborted':
583 elif content['status'] == 'aborted':
582 self.results[msg_id] = error.TaskAborted(msg_id)
584 self.results[msg_id] = error.TaskAborted(msg_id)
583 elif content['status'] == 'resubmitted':
585 elif content['status'] == 'resubmitted':
584 # TODO: handle resubmission
586 # TODO: handle resubmission
585 pass
587 pass
586 else:
588 else:
587 self.results[msg_id] = self._unwrap_exception(content)
589 self.results[msg_id] = self._unwrap_exception(content)
588
590
589 def _flush_notifications(self):
591 def _flush_notifications(self):
590 """Flush notifications of engine registrations waiting
592 """Flush notifications of engine registrations waiting
591 in ZMQ queue."""
593 in ZMQ queue."""
592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
594 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 while msg is not None:
595 while msg is not None:
594 if self.debug:
596 if self.debug:
595 pprint(msg)
597 pprint(msg)
596 msg = msg[-1]
597 msg_type = msg['msg_type']
598 msg_type = msg['msg_type']
598 handler = self._notification_handlers.get(msg_type, None)
599 handler = self._notification_handlers.get(msg_type, None)
599 if handler is None:
600 if handler is None:
600 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 raise Exception("Unhandled message type: %s"%msg.msg_type)
601 else:
602 else:
602 handler(msg)
603 handler(msg)
603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604
605
605 def _flush_results(self, sock):
606 def _flush_results(self, sock):
606 """Flush task or queue results waiting in ZMQ queue."""
607 """Flush task or queue results waiting in ZMQ queue."""
607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 while msg is not None:
609 while msg is not None:
609 if self.debug:
610 if self.debug:
610 pprint(msg)
611 pprint(msg)
611 msg = msg[-1]
612 msg_type = msg['msg_type']
612 msg_type = msg['msg_type']
613 handler = self._queue_handlers.get(msg_type, None)
613 handler = self._queue_handlers.get(msg_type, None)
614 if handler is None:
614 if handler is None:
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
615 raise Exception("Unhandled message type: %s"%msg.msg_type)
616 else:
616 else:
617 handler(msg)
617 handler(msg)
618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
618 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619
619
620 def _flush_control(self, sock):
620 def _flush_control(self, sock):
621 """Flush replies from the control channel waiting
621 """Flush replies from the control channel waiting
622 in the ZMQ queue.
622 in the ZMQ queue.
623
623
624 Currently: ignore them."""
624 Currently: ignore them."""
625 if self._ignored_control_replies <= 0:
625 if self._ignored_control_replies <= 0:
626 return
626 return
627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 while msg is not None:
628 while msg is not None:
629 self._ignored_control_replies -= 1
629 self._ignored_control_replies -= 1
630 if self.debug:
630 if self.debug:
631 pprint(msg)
631 pprint(msg)
632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633
633
634 def _flush_ignored_control(self):
634 def _flush_ignored_control(self):
635 """flush ignored control replies"""
635 """flush ignored control replies"""
636 while self._ignored_control_replies > 0:
636 while self._ignored_control_replies > 0:
637 self.session.recv(self._control_socket)
637 self.session.recv(self._control_socket)
638 self._ignored_control_replies -= 1
638 self._ignored_control_replies -= 1
639
639
640 def _flush_ignored_hub_replies(self):
640 def _flush_ignored_hub_replies(self):
641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
641 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 while msg is not None:
642 while msg is not None:
643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
643 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644
644
645 def _flush_iopub(self, sock):
645 def _flush_iopub(self, sock):
646 """Flush replies from the iopub channel waiting
646 """Flush replies from the iopub channel waiting
647 in the ZMQ queue.
647 in the ZMQ queue.
648 """
648 """
649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 while msg is not None:
650 while msg is not None:
651 if self.debug:
651 if self.debug:
652 pprint(msg)
652 pprint(msg)
653 msg = msg[-1]
654 parent = msg['parent_header']
653 parent = msg['parent_header']
655 msg_id = parent['msg_id']
654 msg_id = parent['msg_id']
656 content = msg['content']
655 content = msg['content']
657 header = msg['header']
656 header = msg['header']
658 msg_type = msg['msg_type']
657 msg_type = msg['msg_type']
659
658
660 # init metadata:
659 # init metadata:
661 md = self.metadata[msg_id]
660 md = self.metadata[msg_id]
662
661
663 if msg_type == 'stream':
662 if msg_type == 'stream':
664 name = content['name']
663 name = content['name']
665 s = md[name] or ''
664 s = md[name] or ''
666 md[name] = s + content['data']
665 md[name] = s + content['data']
667 elif msg_type == 'pyerr':
666 elif msg_type == 'pyerr':
668 md.update({'pyerr' : self._unwrap_exception(content)})
667 md.update({'pyerr' : self._unwrap_exception(content)})
669 elif msg_type == 'pyin':
668 elif msg_type == 'pyin':
670 md.update({'pyin' : content['code']})
669 md.update({'pyin' : content['code']})
671 else:
670 else:
672 md.update({msg_type : content.get('data', '')})
671 md.update({msg_type : content.get('data', '')})
673
672
674 # reduntant?
673 # reduntant?
675 self.metadata[msg_id] = md
674 self.metadata[msg_id] = md
676
675
677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
676 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678
677
679 #--------------------------------------------------------------------------
678 #--------------------------------------------------------------------------
680 # len, getitem
679 # len, getitem
681 #--------------------------------------------------------------------------
680 #--------------------------------------------------------------------------
682
681
683 def __len__(self):
682 def __len__(self):
684 """len(client) returns # of engines."""
683 """len(client) returns # of engines."""
685 return len(self.ids)
684 return len(self.ids)
686
685
687 def __getitem__(self, key):
686 def __getitem__(self, key):
688 """index access returns DirectView multiplexer objects
687 """index access returns DirectView multiplexer objects
689
688
690 Must be int, slice, or list/tuple/xrange of ints"""
689 Must be int, slice, or list/tuple/xrange of ints"""
691 if not isinstance(key, (int, slice, tuple, list, xrange)):
690 if not isinstance(key, (int, slice, tuple, list, xrange)):
692 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
691 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
693 else:
692 else:
694 return self.direct_view(key)
693 return self.direct_view(key)
695
694
696 #--------------------------------------------------------------------------
695 #--------------------------------------------------------------------------
697 # Begin public methods
696 # Begin public methods
698 #--------------------------------------------------------------------------
697 #--------------------------------------------------------------------------
699
698
700 @property
699 @property
701 def ids(self):
700 def ids(self):
702 """Always up-to-date ids property."""
701 """Always up-to-date ids property."""
703 self._flush_notifications()
702 self._flush_notifications()
704 # always copy:
703 # always copy:
705 return list(self._ids)
704 return list(self._ids)
706
705
707 def close(self):
706 def close(self):
708 if self._closed:
707 if self._closed:
709 return
708 return
710 snames = filter(lambda n: n.endswith('socket'), dir(self))
709 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 for socket in map(lambda name: getattr(self, name), snames):
710 for socket in map(lambda name: getattr(self, name), snames):
712 if isinstance(socket, zmq.Socket) and not socket.closed:
711 if isinstance(socket, zmq.Socket) and not socket.closed:
713 socket.close()
712 socket.close()
714 self._closed = True
713 self._closed = True
715
714
716 def spin(self):
715 def spin(self):
717 """Flush any registration notifications and execution results
716 """Flush any registration notifications and execution results
718 waiting in the ZMQ queue.
717 waiting in the ZMQ queue.
719 """
718 """
720 if self._notification_socket:
719 if self._notification_socket:
721 self._flush_notifications()
720 self._flush_notifications()
722 if self._mux_socket:
721 if self._mux_socket:
723 self._flush_results(self._mux_socket)
722 self._flush_results(self._mux_socket)
724 if self._task_socket:
723 if self._task_socket:
725 self._flush_results(self._task_socket)
724 self._flush_results(self._task_socket)
726 if self._control_socket:
725 if self._control_socket:
727 self._flush_control(self._control_socket)
726 self._flush_control(self._control_socket)
728 if self._iopub_socket:
727 if self._iopub_socket:
729 self._flush_iopub(self._iopub_socket)
728 self._flush_iopub(self._iopub_socket)
730 if self._query_socket:
729 if self._query_socket:
731 self._flush_ignored_hub_replies()
730 self._flush_ignored_hub_replies()
732
731
733 def wait(self, jobs=None, timeout=-1):
732 def wait(self, jobs=None, timeout=-1):
734 """waits on one or more `jobs`, for up to `timeout` seconds.
733 """waits on one or more `jobs`, for up to `timeout` seconds.
735
734
736 Parameters
735 Parameters
737 ----------
736 ----------
738
737
739 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
738 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 ints are indices to self.history
739 ints are indices to self.history
741 strs are msg_ids
740 strs are msg_ids
742 default: wait on all outstanding messages
741 default: wait on all outstanding messages
743 timeout : float
742 timeout : float
744 a time in seconds, after which to give up.
743 a time in seconds, after which to give up.
745 default is -1, which means no timeout
744 default is -1, which means no timeout
746
745
747 Returns
746 Returns
748 -------
747 -------
749
748
750 True : when all msg_ids are done
749 True : when all msg_ids are done
751 False : timeout reached, some msg_ids still outstanding
750 False : timeout reached, some msg_ids still outstanding
752 """
751 """
753 tic = time.time()
752 tic = time.time()
754 if jobs is None:
753 if jobs is None:
755 theids = self.outstanding
754 theids = self.outstanding
756 else:
755 else:
757 if isinstance(jobs, (int, str, AsyncResult)):
756 if isinstance(jobs, (int, str, AsyncResult)):
758 jobs = [jobs]
757 jobs = [jobs]
759 theids = set()
758 theids = set()
760 for job in jobs:
759 for job in jobs:
761 if isinstance(job, int):
760 if isinstance(job, int):
762 # index access
761 # index access
763 job = self.history[job]
762 job = self.history[job]
764 elif isinstance(job, AsyncResult):
763 elif isinstance(job, AsyncResult):
765 map(theids.add, job.msg_ids)
764 map(theids.add, job.msg_ids)
766 continue
765 continue
767 theids.add(job)
766 theids.add(job)
768 if not theids.intersection(self.outstanding):
767 if not theids.intersection(self.outstanding):
769 return True
768 return True
770 self.spin()
769 self.spin()
771 while theids.intersection(self.outstanding):
770 while theids.intersection(self.outstanding):
772 if timeout >= 0 and ( time.time()-tic ) > timeout:
771 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 break
772 break
774 time.sleep(1e-3)
773 time.sleep(1e-3)
775 self.spin()
774 self.spin()
776 return len(theids.intersection(self.outstanding)) == 0
775 return len(theids.intersection(self.outstanding)) == 0
777
776
778 #--------------------------------------------------------------------------
777 #--------------------------------------------------------------------------
779 # Control methods
778 # Control methods
780 #--------------------------------------------------------------------------
779 #--------------------------------------------------------------------------
781
780
782 @spin_first
781 @spin_first
783 def clear(self, targets=None, block=None):
782 def clear(self, targets=None, block=None):
784 """Clear the namespace in target(s)."""
783 """Clear the namespace in target(s)."""
785 block = self.block if block is None else block
784 block = self.block if block is None else block
786 targets = self._build_targets(targets)[0]
785 targets = self._build_targets(targets)[0]
787 for t in targets:
786 for t in targets:
788 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
787 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 error = False
788 error = False
790 if block:
789 if block:
791 self._flush_ignored_control()
790 self._flush_ignored_control()
792 for i in range(len(targets)):
791 for i in range(len(targets)):
793 idents,msg = self.session.recv(self._control_socket,0)
792 idents,msg = self.session.recv(self._control_socket,0)
794 if self.debug:
793 if self.debug:
795 pprint(msg)
794 pprint(msg)
796 if msg['content']['status'] != 'ok':
795 if msg['content']['status'] != 'ok':
797 error = self._unwrap_exception(msg['content'])
796 error = self._unwrap_exception(msg['content'])
798 else:
797 else:
799 self._ignored_control_replies += len(targets)
798 self._ignored_control_replies += len(targets)
800 if error:
799 if error:
801 raise error
800 raise error
802
801
803
802
804 @spin_first
803 @spin_first
805 def abort(self, jobs=None, targets=None, block=None):
804 def abort(self, jobs=None, targets=None, block=None):
806 """Abort specific jobs from the execution queues of target(s).
805 """Abort specific jobs from the execution queues of target(s).
807
806
808 This is a mechanism to prevent jobs that have already been submitted
807 This is a mechanism to prevent jobs that have already been submitted
809 from executing.
808 from executing.
810
809
811 Parameters
810 Parameters
812 ----------
811 ----------
813
812
814 jobs : msg_id, list of msg_ids, or AsyncResult
813 jobs : msg_id, list of msg_ids, or AsyncResult
815 The jobs to be aborted
814 The jobs to be aborted
816
815
817
816
818 """
817 """
819 block = self.block if block is None else block
818 block = self.block if block is None else block
820 targets = self._build_targets(targets)[0]
819 targets = self._build_targets(targets)[0]
821 msg_ids = []
820 msg_ids = []
822 if isinstance(jobs, (basestring,AsyncResult)):
821 if isinstance(jobs, (basestring,AsyncResult)):
823 jobs = [jobs]
822 jobs = [jobs]
824 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
823 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 if bad_ids:
824 if bad_ids:
826 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
825 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 for j in jobs:
826 for j in jobs:
828 if isinstance(j, AsyncResult):
827 if isinstance(j, AsyncResult):
829 msg_ids.extend(j.msg_ids)
828 msg_ids.extend(j.msg_ids)
830 else:
829 else:
831 msg_ids.append(j)
830 msg_ids.append(j)
832 content = dict(msg_ids=msg_ids)
831 content = dict(msg_ids=msg_ids)
833 for t in targets:
832 for t in targets:
834 self.session.send(self._control_socket, 'abort_request',
833 self.session.send(self._control_socket, 'abort_request',
835 content=content, ident=t)
834 content=content, ident=t)
836 error = False
835 error = False
837 if block:
836 if block:
838 self._flush_ignored_control()
837 self._flush_ignored_control()
839 for i in range(len(targets)):
838 for i in range(len(targets)):
840 idents,msg = self.session.recv(self._control_socket,0)
839 idents,msg = self.session.recv(self._control_socket,0)
841 if self.debug:
840 if self.debug:
842 pprint(msg)
841 pprint(msg)
843 if msg['content']['status'] != 'ok':
842 if msg['content']['status'] != 'ok':
844 error = self._unwrap_exception(msg['content'])
843 error = self._unwrap_exception(msg['content'])
845 else:
844 else:
846 self._ignored_control_replies += len(targets)
845 self._ignored_control_replies += len(targets)
847 if error:
846 if error:
848 raise error
847 raise error
849
848
850 @spin_first
849 @spin_first
851 def shutdown(self, targets=None, restart=False, hub=False, block=None):
850 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 """Terminates one or more engine processes, optionally including the hub."""
851 """Terminates one or more engine processes, optionally including the hub."""
853 block = self.block if block is None else block
852 block = self.block if block is None else block
854 if hub:
853 if hub:
855 targets = 'all'
854 targets = 'all'
856 targets = self._build_targets(targets)[0]
855 targets = self._build_targets(targets)[0]
857 for t in targets:
856 for t in targets:
858 self.session.send(self._control_socket, 'shutdown_request',
857 self.session.send(self._control_socket, 'shutdown_request',
859 content={'restart':restart},ident=t)
858 content={'restart':restart},ident=t)
860 error = False
859 error = False
861 if block or hub:
860 if block or hub:
862 self._flush_ignored_control()
861 self._flush_ignored_control()
863 for i in range(len(targets)):
862 for i in range(len(targets)):
864 idents,msg = self.session.recv(self._control_socket, 0)
863 idents,msg = self.session.recv(self._control_socket, 0)
865 if self.debug:
864 if self.debug:
866 pprint(msg)
865 pprint(msg)
867 if msg['content']['status'] != 'ok':
866 if msg['content']['status'] != 'ok':
868 error = self._unwrap_exception(msg['content'])
867 error = self._unwrap_exception(msg['content'])
869 else:
868 else:
870 self._ignored_control_replies += len(targets)
869 self._ignored_control_replies += len(targets)
871
870
872 if hub:
871 if hub:
873 time.sleep(0.25)
872 time.sleep(0.25)
874 self.session.send(self._query_socket, 'shutdown_request')
873 self.session.send(self._query_socket, 'shutdown_request')
875 idents,msg = self.session.recv(self._query_socket, 0)
874 idents,msg = self.session.recv(self._query_socket, 0)
876 if self.debug:
875 if self.debug:
877 pprint(msg)
876 pprint(msg)
878 if msg['content']['status'] != 'ok':
877 if msg['content']['status'] != 'ok':
879 error = self._unwrap_exception(msg['content'])
878 error = self._unwrap_exception(msg['content'])
880
879
881 if error:
880 if error:
882 raise error
881 raise error
883
882
884 #--------------------------------------------------------------------------
883 #--------------------------------------------------------------------------
885 # Execution related methods
884 # Execution related methods
886 #--------------------------------------------------------------------------
885 #--------------------------------------------------------------------------
887
886
888 def _maybe_raise(self, result):
887 def _maybe_raise(self, result):
889 """wrapper for maybe raising an exception if apply failed."""
888 """wrapper for maybe raising an exception if apply failed."""
890 if isinstance(result, error.RemoteError):
889 if isinstance(result, error.RemoteError):
891 raise result
890 raise result
892
891
893 return result
892 return result
894
893
895 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
894 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 ident=None):
895 ident=None):
897 """construct and send an apply message via a socket.
896 """construct and send an apply message via a socket.
898
897
899 This is the principal method with which all engine execution is performed by views.
898 This is the principal method with which all engine execution is performed by views.
900 """
899 """
901
900
902 assert not self._closed, "cannot use me anymore, I'm closed!"
901 assert not self._closed, "cannot use me anymore, I'm closed!"
903 # defaults:
902 # defaults:
904 args = args if args is not None else []
903 args = args if args is not None else []
905 kwargs = kwargs if kwargs is not None else {}
904 kwargs = kwargs if kwargs is not None else {}
906 subheader = subheader if subheader is not None else {}
905 subheader = subheader if subheader is not None else {}
907
906
908 # validate arguments
907 # validate arguments
909 if not callable(f):
908 if not callable(f):
910 raise TypeError("f must be callable, not %s"%type(f))
909 raise TypeError("f must be callable, not %s"%type(f))
911 if not isinstance(args, (tuple, list)):
910 if not isinstance(args, (tuple, list)):
912 raise TypeError("args must be tuple or list, not %s"%type(args))
911 raise TypeError("args must be tuple or list, not %s"%type(args))
913 if not isinstance(kwargs, dict):
912 if not isinstance(kwargs, dict):
914 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
913 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 if not isinstance(subheader, dict):
914 if not isinstance(subheader, dict):
916 raise TypeError("subheader must be dict, not %s"%type(subheader))
915 raise TypeError("subheader must be dict, not %s"%type(subheader))
917
916
918 bufs = util.pack_apply_message(f,args,kwargs)
917 bufs = util.pack_apply_message(f,args,kwargs)
919
918
920 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
919 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 subheader=subheader, track=track)
920 subheader=subheader, track=track)
922
921
923 msg_id = msg['msg_id']
922 msg_id = msg['msg_id']
924 self.outstanding.add(msg_id)
923 self.outstanding.add(msg_id)
925 if ident:
924 if ident:
926 # possibly routed to a specific engine
925 # possibly routed to a specific engine
927 if isinstance(ident, list):
926 if isinstance(ident, list):
928 ident = ident[-1]
927 ident = ident[-1]
929 if ident in self._engines.values():
928 if ident in self._engines.values():
930 # save for later, in case of engine death
929 # save for later, in case of engine death
931 self._outstanding_dict[ident].add(msg_id)
930 self._outstanding_dict[ident].add(msg_id)
932 self.history.append(msg_id)
931 self.history.append(msg_id)
933 self.metadata[msg_id]['submitted'] = datetime.now()
932 self.metadata[msg_id]['submitted'] = datetime.now()
934
933
935 return msg
934 return msg
936
935
937 #--------------------------------------------------------------------------
936 #--------------------------------------------------------------------------
938 # construct a View object
937 # construct a View object
939 #--------------------------------------------------------------------------
938 #--------------------------------------------------------------------------
940
939
941 def load_balanced_view(self, targets=None):
940 def load_balanced_view(self, targets=None):
942 """construct a DirectView object.
941 """construct a DirectView object.
943
942
944 If no arguments are specified, create a LoadBalancedView
943 If no arguments are specified, create a LoadBalancedView
945 using all engines.
944 using all engines.
946
945
947 Parameters
946 Parameters
948 ----------
947 ----------
949
948
950 targets: list,slice,int,etc. [default: use all engines]
949 targets: list,slice,int,etc. [default: use all engines]
951 The subset of engines across which to load-balance
950 The subset of engines across which to load-balance
952 """
951 """
953 if targets is not None:
952 if targets is not None:
954 targets = self._build_targets(targets)[1]
953 targets = self._build_targets(targets)[1]
955 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
954 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956
955
957 def direct_view(self, targets='all'):
956 def direct_view(self, targets='all'):
958 """construct a DirectView object.
957 """construct a DirectView object.
959
958
960 If no targets are specified, create a DirectView
959 If no targets are specified, create a DirectView
961 using all engines.
960 using all engines.
962
961
963 Parameters
962 Parameters
964 ----------
963 ----------
965
964
966 targets: list,slice,int,etc. [default: use all engines]
965 targets: list,slice,int,etc. [default: use all engines]
967 The engines to use for the View
966 The engines to use for the View
968 """
967 """
969 single = isinstance(targets, int)
968 single = isinstance(targets, int)
970 targets = self._build_targets(targets)[1]
969 targets = self._build_targets(targets)[1]
971 if single:
970 if single:
972 targets = targets[0]
971 targets = targets[0]
973 return DirectView(client=self, socket=self._mux_socket, targets=targets)
972 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974
973
975 #--------------------------------------------------------------------------
974 #--------------------------------------------------------------------------
976 # Query methods
975 # Query methods
977 #--------------------------------------------------------------------------
976 #--------------------------------------------------------------------------
978
977
979 @spin_first
978 @spin_first
980 def get_result(self, indices_or_msg_ids=None, block=None):
979 def get_result(self, indices_or_msg_ids=None, block=None):
981 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
980 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982
981
983 If the client already has the results, no request to the Hub will be made.
982 If the client already has the results, no request to the Hub will be made.
984
983
985 This is a convenient way to construct AsyncResult objects, which are wrappers
984 This is a convenient way to construct AsyncResult objects, which are wrappers
986 that include metadata about execution, and allow for awaiting results that
985 that include metadata about execution, and allow for awaiting results that
987 were not submitted by this Client.
986 were not submitted by this Client.
988
987
989 It can also be a convenient way to retrieve the metadata associated with
988 It can also be a convenient way to retrieve the metadata associated with
990 blocking execution, since it always retrieves
989 blocking execution, since it always retrieves
991
990
992 Examples
991 Examples
993 --------
992 --------
994 ::
993 ::
995
994
996 In [10]: r = client.apply()
995 In [10]: r = client.apply()
997
996
998 Parameters
997 Parameters
999 ----------
998 ----------
1000
999
1001 indices_or_msg_ids : integer history index, str msg_id, or list of either
1000 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 The indices or msg_ids of indices to be retrieved
1001 The indices or msg_ids of indices to be retrieved
1003
1002
1004 block : bool
1003 block : bool
1005 Whether to wait for the result to be done
1004 Whether to wait for the result to be done
1006
1005
1007 Returns
1006 Returns
1008 -------
1007 -------
1009
1008
1010 AsyncResult
1009 AsyncResult
1011 A single AsyncResult object will always be returned.
1010 A single AsyncResult object will always be returned.
1012
1011
1013 AsyncHubResult
1012 AsyncHubResult
1014 A subclass of AsyncResult that retrieves results from the Hub
1013 A subclass of AsyncResult that retrieves results from the Hub
1015
1014
1016 """
1015 """
1017 block = self.block if block is None else block
1016 block = self.block if block is None else block
1018 if indices_or_msg_ids is None:
1017 if indices_or_msg_ids is None:
1019 indices_or_msg_ids = -1
1018 indices_or_msg_ids = -1
1020
1019
1021 if not isinstance(indices_or_msg_ids, (list,tuple)):
1020 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 indices_or_msg_ids = [indices_or_msg_ids]
1021 indices_or_msg_ids = [indices_or_msg_ids]
1023
1022
1024 theids = []
1023 theids = []
1025 for id in indices_or_msg_ids:
1024 for id in indices_or_msg_ids:
1026 if isinstance(id, int):
1025 if isinstance(id, int):
1027 id = self.history[id]
1026 id = self.history[id]
1028 if not isinstance(id, str):
1027 if not isinstance(id, str):
1029 raise TypeError("indices must be str or int, not %r"%id)
1028 raise TypeError("indices must be str or int, not %r"%id)
1030 theids.append(id)
1029 theids.append(id)
1031
1030
1032 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1031 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1032 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034
1033
1035 if remote_ids:
1034 if remote_ids:
1036 ar = AsyncHubResult(self, msg_ids=theids)
1035 ar = AsyncHubResult(self, msg_ids=theids)
1037 else:
1036 else:
1038 ar = AsyncResult(self, msg_ids=theids)
1037 ar = AsyncResult(self, msg_ids=theids)
1039
1038
1040 if block:
1039 if block:
1041 ar.wait()
1040 ar.wait()
1042
1041
1043 return ar
1042 return ar
1044
1043
1045 @spin_first
1044 @spin_first
1046 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1045 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 """Resubmit one or more tasks.
1046 """Resubmit one or more tasks.
1048
1047
1049 in-flight tasks may not be resubmitted.
1048 in-flight tasks may not be resubmitted.
1050
1049
1051 Parameters
1050 Parameters
1052 ----------
1051 ----------
1053
1052
1054 indices_or_msg_ids : integer history index, str msg_id, or list of either
1053 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 The indices or msg_ids of indices to be retrieved
1054 The indices or msg_ids of indices to be retrieved
1056
1055
1057 block : bool
1056 block : bool
1058 Whether to wait for the result to be done
1057 Whether to wait for the result to be done
1059
1058
1060 Returns
1059 Returns
1061 -------
1060 -------
1062
1061
1063 AsyncHubResult
1062 AsyncHubResult
1064 A subclass of AsyncResult that retrieves results from the Hub
1063 A subclass of AsyncResult that retrieves results from the Hub
1065
1064
1066 """
1065 """
1067 block = self.block if block is None else block
1066 block = self.block if block is None else block
1068 if indices_or_msg_ids is None:
1067 if indices_or_msg_ids is None:
1069 indices_or_msg_ids = -1
1068 indices_or_msg_ids = -1
1070
1069
1071 if not isinstance(indices_or_msg_ids, (list,tuple)):
1070 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 indices_or_msg_ids = [indices_or_msg_ids]
1071 indices_or_msg_ids = [indices_or_msg_ids]
1073
1072
1074 theids = []
1073 theids = []
1075 for id in indices_or_msg_ids:
1074 for id in indices_or_msg_ids:
1076 if isinstance(id, int):
1075 if isinstance(id, int):
1077 id = self.history[id]
1076 id = self.history[id]
1078 if not isinstance(id, str):
1077 if not isinstance(id, str):
1079 raise TypeError("indices must be str or int, not %r"%id)
1078 raise TypeError("indices must be str or int, not %r"%id)
1080 theids.append(id)
1079 theids.append(id)
1081
1080
1082 for msg_id in theids:
1081 for msg_id in theids:
1083 self.outstanding.discard(msg_id)
1082 self.outstanding.discard(msg_id)
1084 if msg_id in self.history:
1083 if msg_id in self.history:
1085 self.history.remove(msg_id)
1084 self.history.remove(msg_id)
1086 self.results.pop(msg_id, None)
1085 self.results.pop(msg_id, None)
1087 self.metadata.pop(msg_id, None)
1086 self.metadata.pop(msg_id, None)
1088 content = dict(msg_ids = theids)
1087 content = dict(msg_ids = theids)
1089
1088
1090 self.session.send(self._query_socket, 'resubmit_request', content)
1089 self.session.send(self._query_socket, 'resubmit_request', content)
1091
1090
1092 zmq.select([self._query_socket], [], [])
1091 zmq.select([self._query_socket], [], [])
1093 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1092 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 if self.debug:
1093 if self.debug:
1095 pprint(msg)
1094 pprint(msg)
1096 content = msg['content']
1095 content = msg['content']
1097 if content['status'] != 'ok':
1096 if content['status'] != 'ok':
1098 raise self._unwrap_exception(content)
1097 raise self._unwrap_exception(content)
1099
1098
1100 ar = AsyncHubResult(self, msg_ids=theids)
1099 ar = AsyncHubResult(self, msg_ids=theids)
1101
1100
1102 if block:
1101 if block:
1103 ar.wait()
1102 ar.wait()
1104
1103
1105 return ar
1104 return ar
1106
1105
1107 @spin_first
1106 @spin_first
1108 def result_status(self, msg_ids, status_only=True):
1107 def result_status(self, msg_ids, status_only=True):
1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1108 """Check on the status of the result(s) of the apply request with `msg_ids`.
1110
1109
1111 If status_only is False, then the actual results will be retrieved, else
1110 If status_only is False, then the actual results will be retrieved, else
1112 only the status of the results will be checked.
1111 only the status of the results will be checked.
1113
1112
1114 Parameters
1113 Parameters
1115 ----------
1114 ----------
1116
1115
1117 msg_ids : list of msg_ids
1116 msg_ids : list of msg_ids
1118 if int:
1117 if int:
1119 Passed as index to self.history for convenience.
1118 Passed as index to self.history for convenience.
1120 status_only : bool (default: True)
1119 status_only : bool (default: True)
1121 if False:
1120 if False:
1122 Retrieve the actual results of completed tasks.
1121 Retrieve the actual results of completed tasks.
1123
1122
1124 Returns
1123 Returns
1125 -------
1124 -------
1126
1125
1127 results : dict
1126 results : dict
1128 There will always be the keys 'pending' and 'completed', which will
1127 There will always be the keys 'pending' and 'completed', which will
1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1128 be lists of msg_ids that are incomplete or complete. If `status_only`
1130 is False, then completed results will be keyed by their `msg_id`.
1129 is False, then completed results will be keyed by their `msg_id`.
1131 """
1130 """
1132 if not isinstance(msg_ids, (list,tuple)):
1131 if not isinstance(msg_ids, (list,tuple)):
1133 msg_ids = [msg_ids]
1132 msg_ids = [msg_ids]
1134
1133
1135 theids = []
1134 theids = []
1136 for msg_id in msg_ids:
1135 for msg_id in msg_ids:
1137 if isinstance(msg_id, int):
1136 if isinstance(msg_id, int):
1138 msg_id = self.history[msg_id]
1137 msg_id = self.history[msg_id]
1139 if not isinstance(msg_id, basestring):
1138 if not isinstance(msg_id, basestring):
1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1139 raise TypeError("msg_ids must be str, not %r"%msg_id)
1141 theids.append(msg_id)
1140 theids.append(msg_id)
1142
1141
1143 completed = []
1142 completed = []
1144 local_results = {}
1143 local_results = {}
1145
1144
1146 # comment this block out to temporarily disable local shortcut:
1145 # comment this block out to temporarily disable local shortcut:
1147 for msg_id in theids:
1146 for msg_id in theids:
1148 if msg_id in self.results:
1147 if msg_id in self.results:
1149 completed.append(msg_id)
1148 completed.append(msg_id)
1150 local_results[msg_id] = self.results[msg_id]
1149 local_results[msg_id] = self.results[msg_id]
1151 theids.remove(msg_id)
1150 theids.remove(msg_id)
1152
1151
1153 if theids: # some not locally cached
1152 if theids: # some not locally cached
1154 content = dict(msg_ids=theids, status_only=status_only)
1153 content = dict(msg_ids=theids, status_only=status_only)
1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1154 msg = self.session.send(self._query_socket, "result_request", content=content)
1156 zmq.select([self._query_socket], [], [])
1155 zmq.select([self._query_socket], [], [])
1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1156 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1158 if self.debug:
1157 if self.debug:
1159 pprint(msg)
1158 pprint(msg)
1160 content = msg['content']
1159 content = msg['content']
1161 if content['status'] != 'ok':
1160 if content['status'] != 'ok':
1162 raise self._unwrap_exception(content)
1161 raise self._unwrap_exception(content)
1163 buffers = msg['buffers']
1162 buffers = msg['buffers']
1164 else:
1163 else:
1165 content = dict(completed=[],pending=[])
1164 content = dict(completed=[],pending=[])
1166
1165
1167 content['completed'].extend(completed)
1166 content['completed'].extend(completed)
1168
1167
1169 if status_only:
1168 if status_only:
1170 return content
1169 return content
1171
1170
1172 failures = []
1171 failures = []
1173 # load cached results into result:
1172 # load cached results into result:
1174 content.update(local_results)
1173 content.update(local_results)
1174 content = extract_dates(content)
1175 # update cache with results:
1175 # update cache with results:
1176 for msg_id in sorted(theids):
1176 for msg_id in sorted(theids):
1177 if msg_id in content['completed']:
1177 if msg_id in content['completed']:
1178 rec = content[msg_id]
1178 rec = content[msg_id]
1179 parent = rec['header']
1179 parent = rec['header']
1180 header = rec['result_header']
1180 header = rec['result_header']
1181 rcontent = rec['result_content']
1181 rcontent = rec['result_content']
1182 iodict = rec['io']
1182 iodict = rec['io']
1183 if isinstance(rcontent, str):
1183 if isinstance(rcontent, str):
1184 rcontent = self.session.unpack(rcontent)
1184 rcontent = self.session.unpack(rcontent)
1185
1185
1186 md = self.metadata[msg_id]
1186 md = self.metadata[msg_id]
1187 md.update(self._extract_metadata(header, parent, rcontent))
1187 md.update(self._extract_metadata(header, parent, rcontent))
1188 md.update(iodict)
1188 md.update(iodict)
1189
1189
1190 if rcontent['status'] == 'ok':
1190 if rcontent['status'] == 'ok':
1191 res,buffers = util.unserialize_object(buffers)
1191 res,buffers = util.unserialize_object(buffers)
1192 else:
1192 else:
1193 print rcontent
1193 print rcontent
1194 res = self._unwrap_exception(rcontent)
1194 res = self._unwrap_exception(rcontent)
1195 failures.append(res)
1195 failures.append(res)
1196
1196
1197 self.results[msg_id] = res
1197 self.results[msg_id] = res
1198 content[msg_id] = res
1198 content[msg_id] = res
1199
1199
1200 if len(theids) == 1 and failures:
1200 if len(theids) == 1 and failures:
1201 raise failures[0]
1201 raise failures[0]
1202
1202
1203 error.collect_exceptions(failures, "result_status")
1203 error.collect_exceptions(failures, "result_status")
1204 return content
1204 return content
1205
1205
1206 @spin_first
1206 @spin_first
1207 def queue_status(self, targets='all', verbose=False):
1207 def queue_status(self, targets='all', verbose=False):
1208 """Fetch the status of engine queues.
1208 """Fetch the status of engine queues.
1209
1209
1210 Parameters
1210 Parameters
1211 ----------
1211 ----------
1212
1212
1213 targets : int/str/list of ints/strs
1213 targets : int/str/list of ints/strs
1214 the engines whose states are to be queried.
1214 the engines whose states are to be queried.
1215 default : all
1215 default : all
1216 verbose : bool
1216 verbose : bool
1217 Whether to return lengths only, or lists of ids for each element
1217 Whether to return lengths only, or lists of ids for each element
1218 """
1218 """
1219 engine_ids = self._build_targets(targets)[1]
1219 engine_ids = self._build_targets(targets)[1]
1220 content = dict(targets=engine_ids, verbose=verbose)
1220 content = dict(targets=engine_ids, verbose=verbose)
1221 self.session.send(self._query_socket, "queue_request", content=content)
1221 self.session.send(self._query_socket, "queue_request", content=content)
1222 idents,msg = self.session.recv(self._query_socket, 0)
1222 idents,msg = self.session.recv(self._query_socket, 0)
1223 if self.debug:
1223 if self.debug:
1224 pprint(msg)
1224 pprint(msg)
1225 content = msg['content']
1225 content = msg['content']
1226 status = content.pop('status')
1226 status = content.pop('status')
1227 if status != 'ok':
1227 if status != 'ok':
1228 raise self._unwrap_exception(content)
1228 raise self._unwrap_exception(content)
1229 content = util.rekey(content)
1229 content = util.rekey(content)
1230 if isinstance(targets, int):
1230 if isinstance(targets, int):
1231 return content[targets]
1231 return content[targets]
1232 else:
1232 else:
1233 return content
1233 return content
1234
1234
1235 @spin_first
1235 @spin_first
1236 def purge_results(self, jobs=[], targets=[]):
1236 def purge_results(self, jobs=[], targets=[]):
1237 """Tell the Hub to forget results.
1237 """Tell the Hub to forget results.
1238
1238
1239 Individual results can be purged by msg_id, or the entire
1239 Individual results can be purged by msg_id, or the entire
1240 history of specific targets can be purged.
1240 history of specific targets can be purged.
1241
1241
1242 Parameters
1242 Parameters
1243 ----------
1243 ----------
1244
1244
1245 jobs : str or list of str or AsyncResult objects
1245 jobs : str or list of str or AsyncResult objects
1246 the msg_ids whose results should be forgotten.
1246 the msg_ids whose results should be forgotten.
1247 targets : int/str/list of ints/strs
1247 targets : int/str/list of ints/strs
1248 The targets, by uuid or int_id, whose entire history is to be purged.
1248 The targets, by uuid or int_id, whose entire history is to be purged.
1249 Use `targets='all'` to scrub everything from the Hub's memory.
1249 Use `targets='all'` to scrub everything from the Hub's memory.
1250
1250
1251 default : None
1251 default : None
1252 """
1252 """
1253 if not targets and not jobs:
1253 if not targets and not jobs:
1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1255 if targets:
1255 if targets:
1256 targets = self._build_targets(targets)[1]
1256 targets = self._build_targets(targets)[1]
1257
1257
1258 # construct msg_ids from jobs
1258 # construct msg_ids from jobs
1259 msg_ids = []
1259 msg_ids = []
1260 if isinstance(jobs, (basestring,AsyncResult)):
1260 if isinstance(jobs, (basestring,AsyncResult)):
1261 jobs = [jobs]
1261 jobs = [jobs]
1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1263 if bad_ids:
1263 if bad_ids:
1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1265 for j in jobs:
1265 for j in jobs:
1266 if isinstance(j, AsyncResult):
1266 if isinstance(j, AsyncResult):
1267 msg_ids.extend(j.msg_ids)
1267 msg_ids.extend(j.msg_ids)
1268 else:
1268 else:
1269 msg_ids.append(j)
1269 msg_ids.append(j)
1270
1270
1271 content = dict(targets=targets, msg_ids=msg_ids)
1271 content = dict(targets=targets, msg_ids=msg_ids)
1272 self.session.send(self._query_socket, "purge_request", content=content)
1272 self.session.send(self._query_socket, "purge_request", content=content)
1273 idents, msg = self.session.recv(self._query_socket, 0)
1273 idents, msg = self.session.recv(self._query_socket, 0)
1274 if self.debug:
1274 if self.debug:
1275 pprint(msg)
1275 pprint(msg)
1276 content = msg['content']
1276 content = msg['content']
1277 if content['status'] != 'ok':
1277 if content['status'] != 'ok':
1278 raise self._unwrap_exception(content)
1278 raise self._unwrap_exception(content)
1279
1279
1280 @spin_first
1280 @spin_first
1281 def hub_history(self):
1281 def hub_history(self):
1282 """Get the Hub's history
1282 """Get the Hub's history
1283
1283
1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1285 This will contain the history of all clients, and, depending on configuration,
1285 This will contain the history of all clients, and, depending on configuration,
1286 may contain history across multiple cluster sessions.
1286 may contain history across multiple cluster sessions.
1287
1287
1288 Any msg_id returned here is a valid argument to `get_result`.
1288 Any msg_id returned here is a valid argument to `get_result`.
1289
1289
1290 Returns
1290 Returns
1291 -------
1291 -------
1292
1292
1293 msg_ids : list of strs
1293 msg_ids : list of strs
1294 list of all msg_ids, ordered by task submission time.
1294 list of all msg_ids, ordered by task submission time.
1295 """
1295 """
1296
1296
1297 self.session.send(self._query_socket, "history_request", content={})
1297 self.session.send(self._query_socket, "history_request", content={})
1298 idents, msg = self.session.recv(self._query_socket, 0)
1298 idents, msg = self.session.recv(self._query_socket, 0)
1299
1299
1300 if self.debug:
1300 if self.debug:
1301 pprint(msg)
1301 pprint(msg)
1302 content = msg['content']
1302 content = msg['content']
1303 if content['status'] != 'ok':
1303 if content['status'] != 'ok':
1304 raise self._unwrap_exception(content)
1304 raise self._unwrap_exception(content)
1305 else:
1305 else:
1306 return content['history']
1306 return content['history']
1307
1307
1308 @spin_first
1308 @spin_first
1309 def db_query(self, query, keys=None):
1309 def db_query(self, query, keys=None):
1310 """Query the Hub's TaskRecord database
1310 """Query the Hub's TaskRecord database
1311
1311
1312 This will return a list of task record dicts that match `query`
1312 This will return a list of task record dicts that match `query`
1313
1313
1314 Parameters
1314 Parameters
1315 ----------
1315 ----------
1316
1316
1317 query : mongodb query dict
1317 query : mongodb query dict
1318 The search dict. See mongodb query docs for details.
1318 The search dict. See mongodb query docs for details.
1319 keys : list of strs [optional]
1319 keys : list of strs [optional]
1320 The subset of keys to be returned. The default is to fetch everything but buffers.
1320 The subset of keys to be returned. The default is to fetch everything but buffers.
1321 'msg_id' will *always* be included.
1321 'msg_id' will *always* be included.
1322 """
1322 """
1323 if isinstance(keys, basestring):
1323 if isinstance(keys, basestring):
1324 keys = [keys]
1324 keys = [keys]
1325 content = dict(query=query, keys=keys)
1325 content = dict(query=query, keys=keys)
1326 self.session.send(self._query_socket, "db_request", content=content)
1326 self.session.send(self._query_socket, "db_request", content=content)
1327 idents, msg = self.session.recv(self._query_socket, 0)
1327 idents, msg = self.session.recv(self._query_socket, 0)
1328 if self.debug:
1328 if self.debug:
1329 pprint(msg)
1329 pprint(msg)
1330 content = msg['content']
1330 content = msg['content']
1331 if content['status'] != 'ok':
1331 if content['status'] != 'ok':
1332 raise self._unwrap_exception(content)
1332 raise self._unwrap_exception(content)
1333
1333
1334 records = content['records']
1334 records = content['records']
1335 buffer_lens = content['buffer_lens']
1335 buffer_lens = content['buffer_lens']
1336 result_buffer_lens = content['result_buffer_lens']
1336 result_buffer_lens = content['result_buffer_lens']
1337 buffers = msg['buffers']
1337 buffers = msg['buffers']
1338 has_bufs = buffer_lens is not None
1338 has_bufs = buffer_lens is not None
1339 has_rbufs = result_buffer_lens is not None
1339 has_rbufs = result_buffer_lens is not None
1340 for i,rec in enumerate(records):
1340 for i,rec in enumerate(records):
1341 # unpack timestamps
1342 rec = extract_dates(rec)
1341 # relink buffers
1343 # relink buffers
1342 if has_bufs:
1344 if has_bufs:
1343 blen = buffer_lens[i]
1345 blen = buffer_lens[i]
1344 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1346 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1345 if has_rbufs:
1347 if has_rbufs:
1346 blen = result_buffer_lens[i]
1348 blen = result_buffer_lens[i]
1347 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1349 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1348 # turn timestamps back into times
1349 for key in 'submitted started completed resubmitted'.split():
1350 maybedate = rec.get(key, None)
1351 if maybedate and util.ISO8601_RE.match(maybedate):
1352 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1353
1350
1354 return records
1351 return records
1355
1352
1356 __all__ = [ 'Client' ]
1353 __all__ = [ 'Client' ]
@@ -1,1277 +1,1274 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
2 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
3 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
4 and monitors traffic through the various queues.
5 """
5 """
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010 The IPython Development Team
7 # Copyright (C) 2010 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 import sys
18 import sys
19 import time
19 import time
20 from datetime import datetime
20 from datetime import datetime
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop
23 from zmq.eventloop import ioloop
24 from zmq.eventloop.zmqstream import ZMQStream
24 from zmq.eventloop.zmqstream import ZMQStream
25
25
26 # internal:
26 # internal:
27 from IPython.utils.importstring import import_item
27 from IPython.utils.importstring import import_item
28 from IPython.utils.traitlets import (
28 from IPython.utils.traitlets import (
29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 )
30 )
31 from IPython.utils.jsonutil import ISO8601, extract_dates
31
32
32 from IPython.parallel import error, util
33 from IPython.parallel import error, util
33 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
34 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
34
35
35 from .heartmonitor import HeartMonitor
36 from .heartmonitor import HeartMonitor
36
37
37 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
38 # Code
39 # Code
39 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
40
41
41 def _passer(*args, **kwargs):
42 def _passer(*args, **kwargs):
42 return
43 return
43
44
44 def _printer(*args, **kwargs):
45 def _printer(*args, **kwargs):
45 print (args)
46 print (args)
46 print (kwargs)
47 print (kwargs)
47
48
48 def empty_record():
49 def empty_record():
49 """Return an empty dict with all record keys."""
50 """Return an empty dict with all record keys."""
50 return {
51 return {
51 'msg_id' : None,
52 'msg_id' : None,
52 'header' : None,
53 'header' : None,
53 'content': None,
54 'content': None,
54 'buffers': None,
55 'buffers': None,
55 'submitted': None,
56 'submitted': None,
56 'client_uuid' : None,
57 'client_uuid' : None,
57 'engine_uuid' : None,
58 'engine_uuid' : None,
58 'started': None,
59 'started': None,
59 'completed': None,
60 'completed': None,
60 'resubmitted': None,
61 'resubmitted': None,
61 'result_header' : None,
62 'result_header' : None,
62 'result_content' : None,
63 'result_content' : None,
63 'result_buffers' : None,
64 'result_buffers' : None,
64 'queue' : None,
65 'queue' : None,
65 'pyin' : None,
66 'pyin' : None,
66 'pyout': None,
67 'pyout': None,
67 'pyerr': None,
68 'pyerr': None,
68 'stdout': '',
69 'stdout': '',
69 'stderr': '',
70 'stderr': '',
70 }
71 }
71
72
72 def init_record(msg):
73 def init_record(msg):
73 """Initialize a TaskRecord based on a request."""
74 """Initialize a TaskRecord based on a request."""
74 header = msg['header']
75 header = extract_dates(msg['header'])
75 return {
76 return {
76 'msg_id' : header['msg_id'],
77 'msg_id' : header['msg_id'],
77 'header' : header,
78 'header' : header,
78 'content': msg['content'],
79 'content': msg['content'],
79 'buffers': msg['buffers'],
80 'buffers': msg['buffers'],
80 'submitted': datetime.strptime(header['date'], util.ISO8601),
81 'submitted': header['date'],
81 'client_uuid' : None,
82 'client_uuid' : None,
82 'engine_uuid' : None,
83 'engine_uuid' : None,
83 'started': None,
84 'started': None,
84 'completed': None,
85 'completed': None,
85 'resubmitted': None,
86 'resubmitted': None,
86 'result_header' : None,
87 'result_header' : None,
87 'result_content' : None,
88 'result_content' : None,
88 'result_buffers' : None,
89 'result_buffers' : None,
89 'queue' : None,
90 'queue' : None,
90 'pyin' : None,
91 'pyin' : None,
91 'pyout': None,
92 'pyout': None,
92 'pyerr': None,
93 'pyerr': None,
93 'stdout': '',
94 'stdout': '',
94 'stderr': '',
95 'stderr': '',
95 }
96 }
96
97
97
98
98 class EngineConnector(HasTraits):
99 class EngineConnector(HasTraits):
99 """A simple object for accessing the various zmq connections of an object.
100 """A simple object for accessing the various zmq connections of an object.
100 Attributes are:
101 Attributes are:
101 id (int): engine ID
102 id (int): engine ID
102 uuid (str): uuid (unused?)
103 uuid (str): uuid (unused?)
103 queue (str): identity of queue's XREQ socket
104 queue (str): identity of queue's XREQ socket
104 registration (str): identity of registration XREQ socket
105 registration (str): identity of registration XREQ socket
105 heartbeat (str): identity of heartbeat XREQ socket
106 heartbeat (str): identity of heartbeat XREQ socket
106 """
107 """
107 id=Int(0)
108 id=Int(0)
108 queue=CStr()
109 queue=CStr()
109 control=CStr()
110 control=CStr()
110 registration=CStr()
111 registration=CStr()
111 heartbeat=CStr()
112 heartbeat=CStr()
112 pending=Set()
113 pending=Set()
113
114
114 class HubFactory(RegistrationFactory):
115 class HubFactory(RegistrationFactory):
115 """The Configurable for setting up a Hub."""
116 """The Configurable for setting up a Hub."""
116
117
117 # port-pairs for monitoredqueues:
118 # port-pairs for monitoredqueues:
118 hb = Tuple(Int,Int,config=True,
119 hb = Tuple(Int,Int,config=True,
119 help="""XREQ/SUB Port pair for Engine heartbeats""")
120 help="""XREQ/SUB Port pair for Engine heartbeats""")
120 def _hb_default(self):
121 def _hb_default(self):
121 return tuple(util.select_random_ports(2))
122 return tuple(util.select_random_ports(2))
122
123
123 mux = Tuple(Int,Int,config=True,
124 mux = Tuple(Int,Int,config=True,
124 help="""Engine/Client Port pair for MUX queue""")
125 help="""Engine/Client Port pair for MUX queue""")
125
126
126 def _mux_default(self):
127 def _mux_default(self):
127 return tuple(util.select_random_ports(2))
128 return tuple(util.select_random_ports(2))
128
129
129 task = Tuple(Int,Int,config=True,
130 task = Tuple(Int,Int,config=True,
130 help="""Engine/Client Port pair for Task queue""")
131 help="""Engine/Client Port pair for Task queue""")
131 def _task_default(self):
132 def _task_default(self):
132 return tuple(util.select_random_ports(2))
133 return tuple(util.select_random_ports(2))
133
134
134 control = Tuple(Int,Int,config=True,
135 control = Tuple(Int,Int,config=True,
135 help="""Engine/Client Port pair for Control queue""")
136 help="""Engine/Client Port pair for Control queue""")
136
137
137 def _control_default(self):
138 def _control_default(self):
138 return tuple(util.select_random_ports(2))
139 return tuple(util.select_random_ports(2))
139
140
140 iopub = Tuple(Int,Int,config=True,
141 iopub = Tuple(Int,Int,config=True,
141 help="""Engine/Client Port pair for IOPub relay""")
142 help="""Engine/Client Port pair for IOPub relay""")
142
143
143 def _iopub_default(self):
144 def _iopub_default(self):
144 return tuple(util.select_random_ports(2))
145 return tuple(util.select_random_ports(2))
145
146
146 # single ports:
147 # single ports:
147 mon_port = Int(config=True,
148 mon_port = Int(config=True,
148 help="""Monitor (SUB) port for queue traffic""")
149 help="""Monitor (SUB) port for queue traffic""")
149
150
150 def _mon_port_default(self):
151 def _mon_port_default(self):
151 return util.select_random_ports(1)[0]
152 return util.select_random_ports(1)[0]
152
153
153 notifier_port = Int(config=True,
154 notifier_port = Int(config=True,
154 help="""PUB port for sending engine status notifications""")
155 help="""PUB port for sending engine status notifications""")
155
156
156 def _notifier_port_default(self):
157 def _notifier_port_default(self):
157 return util.select_random_ports(1)[0]
158 return util.select_random_ports(1)[0]
158
159
159 engine_ip = Unicode('127.0.0.1', config=True,
160 engine_ip = Unicode('127.0.0.1', config=True,
160 help="IP on which to listen for engine connections. [default: loopback]")
161 help="IP on which to listen for engine connections. [default: loopback]")
161 engine_transport = Unicode('tcp', config=True,
162 engine_transport = Unicode('tcp', config=True,
162 help="0MQ transport for engine connections. [default: tcp]")
163 help="0MQ transport for engine connections. [default: tcp]")
163
164
164 client_ip = Unicode('127.0.0.1', config=True,
165 client_ip = Unicode('127.0.0.1', config=True,
165 help="IP on which to listen for client connections. [default: loopback]")
166 help="IP on which to listen for client connections. [default: loopback]")
166 client_transport = Unicode('tcp', config=True,
167 client_transport = Unicode('tcp', config=True,
167 help="0MQ transport for client connections. [default : tcp]")
168 help="0MQ transport for client connections. [default : tcp]")
168
169
169 monitor_ip = Unicode('127.0.0.1', config=True,
170 monitor_ip = Unicode('127.0.0.1', config=True,
170 help="IP on which to listen for monitor messages. [default: loopback]")
171 help="IP on which to listen for monitor messages. [default: loopback]")
171 monitor_transport = Unicode('tcp', config=True,
172 monitor_transport = Unicode('tcp', config=True,
172 help="0MQ transport for monitor messages. [default : tcp]")
173 help="0MQ transport for monitor messages. [default : tcp]")
173
174
174 monitor_url = Unicode('')
175 monitor_url = Unicode('')
175
176
176 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
177 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
177 help="""The class to use for the DB backend""")
178 help="""The class to use for the DB backend""")
178
179
179 # not configurable
180 # not configurable
180 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
181 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
181 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
182 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
182
183
183 def _ip_changed(self, name, old, new):
184 def _ip_changed(self, name, old, new):
184 self.engine_ip = new
185 self.engine_ip = new
185 self.client_ip = new
186 self.client_ip = new
186 self.monitor_ip = new
187 self.monitor_ip = new
187 self._update_monitor_url()
188 self._update_monitor_url()
188
189
189 def _update_monitor_url(self):
190 def _update_monitor_url(self):
190 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
191 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
191
192
192 def _transport_changed(self, name, old, new):
193 def _transport_changed(self, name, old, new):
193 self.engine_transport = new
194 self.engine_transport = new
194 self.client_transport = new
195 self.client_transport = new
195 self.monitor_transport = new
196 self.monitor_transport = new
196 self._update_monitor_url()
197 self._update_monitor_url()
197
198
198 def __init__(self, **kwargs):
199 def __init__(self, **kwargs):
199 super(HubFactory, self).__init__(**kwargs)
200 super(HubFactory, self).__init__(**kwargs)
200 self._update_monitor_url()
201 self._update_monitor_url()
201
202
202
203
203 def construct(self):
204 def construct(self):
204 self.init_hub()
205 self.init_hub()
205
206
206 def start(self):
207 def start(self):
207 self.heartmonitor.start()
208 self.heartmonitor.start()
208 self.log.info("Heartmonitor started")
209 self.log.info("Heartmonitor started")
209
210
210 def init_hub(self):
211 def init_hub(self):
211 """construct"""
212 """construct"""
212 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
213 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
213 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
214 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
214
215
215 ctx = self.context
216 ctx = self.context
216 loop = self.loop
217 loop = self.loop
217
218
218 # Registrar socket
219 # Registrar socket
219 q = ZMQStream(ctx.socket(zmq.XREP), loop)
220 q = ZMQStream(ctx.socket(zmq.XREP), loop)
220 q.bind(client_iface % self.regport)
221 q.bind(client_iface % self.regport)
221 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
222 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
222 if self.client_ip != self.engine_ip:
223 if self.client_ip != self.engine_ip:
223 q.bind(engine_iface % self.regport)
224 q.bind(engine_iface % self.regport)
224 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
225 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
225
226
226 ### Engine connections ###
227 ### Engine connections ###
227
228
228 # heartbeat
229 # heartbeat
229 hpub = ctx.socket(zmq.PUB)
230 hpub = ctx.socket(zmq.PUB)
230 hpub.bind(engine_iface % self.hb[0])
231 hpub.bind(engine_iface % self.hb[0])
231 hrep = ctx.socket(zmq.XREP)
232 hrep = ctx.socket(zmq.XREP)
232 hrep.bind(engine_iface % self.hb[1])
233 hrep.bind(engine_iface % self.hb[1])
233 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
234 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
234 config=self.config)
235 config=self.config)
235
236
236 ### Client connections ###
237 ### Client connections ###
237 # Notifier socket
238 # Notifier socket
238 n = ZMQStream(ctx.socket(zmq.PUB), loop)
239 n = ZMQStream(ctx.socket(zmq.PUB), loop)
239 n.bind(client_iface%self.notifier_port)
240 n.bind(client_iface%self.notifier_port)
240
241
241 ### build and launch the queues ###
242 ### build and launch the queues ###
242
243
243 # monitor socket
244 # monitor socket
244 sub = ctx.socket(zmq.SUB)
245 sub = ctx.socket(zmq.SUB)
245 sub.setsockopt(zmq.SUBSCRIBE, "")
246 sub.setsockopt(zmq.SUBSCRIBE, "")
246 sub.bind(self.monitor_url)
247 sub.bind(self.monitor_url)
247 sub.bind('inproc://monitor')
248 sub.bind('inproc://monitor')
248 sub = ZMQStream(sub, loop)
249 sub = ZMQStream(sub, loop)
249
250
250 # connect the db
251 # connect the db
251 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
252 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
252 # cdir = self.config.Global.cluster_dir
253 # cdir = self.config.Global.cluster_dir
253 self.db = import_item(str(self.db_class))(session=self.session.session, config=self.config)
254 self.db = import_item(str(self.db_class))(session=self.session.session, config=self.config)
254 time.sleep(.25)
255 time.sleep(.25)
255 try:
256 try:
256 scheme = self.config.TaskScheduler.scheme_name
257 scheme = self.config.TaskScheduler.scheme_name
257 except AttributeError:
258 except AttributeError:
258 from .scheduler import TaskScheduler
259 from .scheduler import TaskScheduler
259 scheme = TaskScheduler.scheme_name.get_default_value()
260 scheme = TaskScheduler.scheme_name.get_default_value()
260 # build connection dicts
261 # build connection dicts
261 self.engine_info = {
262 self.engine_info = {
262 'control' : engine_iface%self.control[1],
263 'control' : engine_iface%self.control[1],
263 'mux': engine_iface%self.mux[1],
264 'mux': engine_iface%self.mux[1],
264 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
265 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
265 'task' : engine_iface%self.task[1],
266 'task' : engine_iface%self.task[1],
266 'iopub' : engine_iface%self.iopub[1],
267 'iopub' : engine_iface%self.iopub[1],
267 # 'monitor' : engine_iface%self.mon_port,
268 # 'monitor' : engine_iface%self.mon_port,
268 }
269 }
269
270
270 self.client_info = {
271 self.client_info = {
271 'control' : client_iface%self.control[0],
272 'control' : client_iface%self.control[0],
272 'mux': client_iface%self.mux[0],
273 'mux': client_iface%self.mux[0],
273 'task' : (scheme, client_iface%self.task[0]),
274 'task' : (scheme, client_iface%self.task[0]),
274 'iopub' : client_iface%self.iopub[0],
275 'iopub' : client_iface%self.iopub[0],
275 'notification': client_iface%self.notifier_port
276 'notification': client_iface%self.notifier_port
276 }
277 }
277 self.log.debug("Hub engine addrs: %s"%self.engine_info)
278 self.log.debug("Hub engine addrs: %s"%self.engine_info)
278 self.log.debug("Hub client addrs: %s"%self.client_info)
279 self.log.debug("Hub client addrs: %s"%self.client_info)
279
280
280 # resubmit stream
281 # resubmit stream
281 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
282 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
282 url = util.disambiguate_url(self.client_info['task'][-1])
283 url = util.disambiguate_url(self.client_info['task'][-1])
283 r.setsockopt(zmq.IDENTITY, self.session.session)
284 r.setsockopt(zmq.IDENTITY, self.session.session)
284 r.connect(url)
285 r.connect(url)
285
286
286 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
287 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
287 query=q, notifier=n, resubmit=r, db=self.db,
288 query=q, notifier=n, resubmit=r, db=self.db,
288 engine_info=self.engine_info, client_info=self.client_info,
289 engine_info=self.engine_info, client_info=self.client_info,
289 logname=self.log.name)
290 logname=self.log.name)
290
291
291
292
292 class Hub(LoggingFactory):
293 class Hub(LoggingFactory):
293 """The IPython Controller Hub with 0MQ connections
294 """The IPython Controller Hub with 0MQ connections
294
295
295 Parameters
296 Parameters
296 ==========
297 ==========
297 loop: zmq IOLoop instance
298 loop: zmq IOLoop instance
298 session: StreamSession object
299 session: Session object
299 <removed> context: zmq context for creating new connections (?)
300 <removed> context: zmq context for creating new connections (?)
300 queue: ZMQStream for monitoring the command queue (SUB)
301 queue: ZMQStream for monitoring the command queue (SUB)
301 query: ZMQStream for engine registration and client queries requests (XREP)
302 query: ZMQStream for engine registration and client queries requests (XREP)
302 heartbeat: HeartMonitor object checking the pulse of the engines
303 heartbeat: HeartMonitor object checking the pulse of the engines
303 notifier: ZMQStream for broadcasting engine registration changes (PUB)
304 notifier: ZMQStream for broadcasting engine registration changes (PUB)
304 db: connection to db for out of memory logging of commands
305 db: connection to db for out of memory logging of commands
305 NotImplemented
306 NotImplemented
306 engine_info: dict of zmq connection information for engines to connect
307 engine_info: dict of zmq connection information for engines to connect
307 to the queues.
308 to the queues.
308 client_info: dict of zmq connection information for engines to connect
309 client_info: dict of zmq connection information for engines to connect
309 to the queues.
310 to the queues.
310 """
311 """
311 # internal data structures:
312 # internal data structures:
312 ids=Set() # engine IDs
313 ids=Set() # engine IDs
313 keytable=Dict()
314 keytable=Dict()
314 by_ident=Dict()
315 by_ident=Dict()
315 engines=Dict()
316 engines=Dict()
316 clients=Dict()
317 clients=Dict()
317 hearts=Dict()
318 hearts=Dict()
318 pending=Set()
319 pending=Set()
319 queues=Dict() # pending msg_ids keyed by engine_id
320 queues=Dict() # pending msg_ids keyed by engine_id
320 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
321 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
321 completed=Dict() # completed msg_ids keyed by engine_id
322 completed=Dict() # completed msg_ids keyed by engine_id
322 all_completed=Set() # completed msg_ids keyed by engine_id
323 all_completed=Set() # completed msg_ids keyed by engine_id
323 dead_engines=Set() # completed msg_ids keyed by engine_id
324 dead_engines=Set() # completed msg_ids keyed by engine_id
324 unassigned=Set() # set of task msg_ds not yet assigned a destination
325 unassigned=Set() # set of task msg_ds not yet assigned a destination
325 incoming_registrations=Dict()
326 incoming_registrations=Dict()
326 registration_timeout=Int()
327 registration_timeout=Int()
327 _idcounter=Int(0)
328 _idcounter=Int(0)
328
329
329 # objects from constructor:
330 # objects from constructor:
330 loop=Instance(ioloop.IOLoop)
331 loop=Instance(ioloop.IOLoop)
331 query=Instance(ZMQStream)
332 query=Instance(ZMQStream)
332 monitor=Instance(ZMQStream)
333 monitor=Instance(ZMQStream)
333 notifier=Instance(ZMQStream)
334 notifier=Instance(ZMQStream)
334 resubmit=Instance(ZMQStream)
335 resubmit=Instance(ZMQStream)
335 heartmonitor=Instance(HeartMonitor)
336 heartmonitor=Instance(HeartMonitor)
336 db=Instance(object)
337 db=Instance(object)
337 client_info=Dict()
338 client_info=Dict()
338 engine_info=Dict()
339 engine_info=Dict()
339
340
340
341
341 def __init__(self, **kwargs):
342 def __init__(self, **kwargs):
342 """
343 """
343 # universal:
344 # universal:
344 loop: IOLoop for creating future connections
345 loop: IOLoop for creating future connections
345 session: streamsession for sending serialized data
346 session: streamsession for sending serialized data
346 # engine:
347 # engine:
347 queue: ZMQStream for monitoring queue messages
348 queue: ZMQStream for monitoring queue messages
348 query: ZMQStream for engine+client registration and client requests
349 query: ZMQStream for engine+client registration and client requests
349 heartbeat: HeartMonitor object for tracking engines
350 heartbeat: HeartMonitor object for tracking engines
350 # extra:
351 # extra:
351 db: ZMQStream for db connection (NotImplemented)
352 db: ZMQStream for db connection (NotImplemented)
352 engine_info: zmq address/protocol dict for engine connections
353 engine_info: zmq address/protocol dict for engine connections
353 client_info: zmq address/protocol dict for client connections
354 client_info: zmq address/protocol dict for client connections
354 """
355 """
355
356
356 super(Hub, self).__init__(**kwargs)
357 super(Hub, self).__init__(**kwargs)
357 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
358 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
358
359
359 # validate connection dicts:
360 # validate connection dicts:
360 for k,v in self.client_info.iteritems():
361 for k,v in self.client_info.iteritems():
361 if k == 'task':
362 if k == 'task':
362 util.validate_url_container(v[1])
363 util.validate_url_container(v[1])
363 else:
364 else:
364 util.validate_url_container(v)
365 util.validate_url_container(v)
365 # util.validate_url_container(self.client_info)
366 # util.validate_url_container(self.client_info)
366 util.validate_url_container(self.engine_info)
367 util.validate_url_container(self.engine_info)
367
368
368 # register our callbacks
369 # register our callbacks
369 self.query.on_recv(self.dispatch_query)
370 self.query.on_recv(self.dispatch_query)
370 self.monitor.on_recv(self.dispatch_monitor_traffic)
371 self.monitor.on_recv(self.dispatch_monitor_traffic)
371
372
372 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
373 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
373 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
374 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
374
375
375 self.monitor_handlers = { 'in' : self.save_queue_request,
376 self.monitor_handlers = { 'in' : self.save_queue_request,
376 'out': self.save_queue_result,
377 'out': self.save_queue_result,
377 'intask': self.save_task_request,
378 'intask': self.save_task_request,
378 'outtask': self.save_task_result,
379 'outtask': self.save_task_result,
379 'tracktask': self.save_task_destination,
380 'tracktask': self.save_task_destination,
380 'incontrol': _passer,
381 'incontrol': _passer,
381 'outcontrol': _passer,
382 'outcontrol': _passer,
382 'iopub': self.save_iopub_message,
383 'iopub': self.save_iopub_message,
383 }
384 }
384
385
385 self.query_handlers = {'queue_request': self.queue_status,
386 self.query_handlers = {'queue_request': self.queue_status,
386 'result_request': self.get_results,
387 'result_request': self.get_results,
387 'history_request': self.get_history,
388 'history_request': self.get_history,
388 'db_request': self.db_query,
389 'db_request': self.db_query,
389 'purge_request': self.purge_results,
390 'purge_request': self.purge_results,
390 'load_request': self.check_load,
391 'load_request': self.check_load,
391 'resubmit_request': self.resubmit_task,
392 'resubmit_request': self.resubmit_task,
392 'shutdown_request': self.shutdown_request,
393 'shutdown_request': self.shutdown_request,
393 'registration_request' : self.register_engine,
394 'registration_request' : self.register_engine,
394 'unregistration_request' : self.unregister_engine,
395 'unregistration_request' : self.unregister_engine,
395 'connection_request': self.connection_request,
396 'connection_request': self.connection_request,
396 }
397 }
397
398
398 # ignore resubmit replies
399 # ignore resubmit replies
399 self.resubmit.on_recv(lambda msg: None, copy=False)
400 self.resubmit.on_recv(lambda msg: None, copy=False)
400
401
401 self.log.info("hub::created hub")
402 self.log.info("hub::created hub")
402
403
403 @property
404 @property
404 def _next_id(self):
405 def _next_id(self):
405 """gemerate a new ID.
406 """gemerate a new ID.
406
407
407 No longer reuse old ids, just count from 0."""
408 No longer reuse old ids, just count from 0."""
408 newid = self._idcounter
409 newid = self._idcounter
409 self._idcounter += 1
410 self._idcounter += 1
410 return newid
411 return newid
411 # newid = 0
412 # newid = 0
412 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
413 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
413 # # print newid, self.ids, self.incoming_registrations
414 # # print newid, self.ids, self.incoming_registrations
414 # while newid in self.ids or newid in incoming:
415 # while newid in self.ids or newid in incoming:
415 # newid += 1
416 # newid += 1
416 # return newid
417 # return newid
417
418
418 #-----------------------------------------------------------------------------
419 #-----------------------------------------------------------------------------
419 # message validation
420 # message validation
420 #-----------------------------------------------------------------------------
421 #-----------------------------------------------------------------------------
421
422
422 def _validate_targets(self, targets):
423 def _validate_targets(self, targets):
423 """turn any valid targets argument into a list of integer ids"""
424 """turn any valid targets argument into a list of integer ids"""
424 if targets is None:
425 if targets is None:
425 # default to all
426 # default to all
426 targets = self.ids
427 targets = self.ids
427
428
428 if isinstance(targets, (int,str,unicode)):
429 if isinstance(targets, (int,str,unicode)):
429 # only one target specified
430 # only one target specified
430 targets = [targets]
431 targets = [targets]
431 _targets = []
432 _targets = []
432 for t in targets:
433 for t in targets:
433 # map raw identities to ids
434 # map raw identities to ids
434 if isinstance(t, (str,unicode)):
435 if isinstance(t, (str,unicode)):
435 t = self.by_ident.get(t, t)
436 t = self.by_ident.get(t, t)
436 _targets.append(t)
437 _targets.append(t)
437 targets = _targets
438 targets = _targets
438 bad_targets = [ t for t in targets if t not in self.ids ]
439 bad_targets = [ t for t in targets if t not in self.ids ]
439 if bad_targets:
440 if bad_targets:
440 raise IndexError("No Such Engine: %r"%bad_targets)
441 raise IndexError("No Such Engine: %r"%bad_targets)
441 if not targets:
442 if not targets:
442 raise IndexError("No Engines Registered")
443 raise IndexError("No Engines Registered")
443 return targets
444 return targets
444
445
445 #-----------------------------------------------------------------------------
446 #-----------------------------------------------------------------------------
446 # dispatch methods (1 per stream)
447 # dispatch methods (1 per stream)
447 #-----------------------------------------------------------------------------
448 #-----------------------------------------------------------------------------
448
449
449
450
450 def dispatch_monitor_traffic(self, msg):
451 def dispatch_monitor_traffic(self, msg):
451 """all ME and Task queue messages come through here, as well as
452 """all ME and Task queue messages come through here, as well as
452 IOPub traffic."""
453 IOPub traffic."""
453 self.log.debug("monitor traffic: %r"%msg[:2])
454 self.log.debug("monitor traffic: %r"%msg[:2])
454 switch = msg[0]
455 switch = msg[0]
455 try:
456 try:
456 idents, msg = self.session.feed_identities(msg[1:])
457 idents, msg = self.session.feed_identities(msg[1:])
457 except ValueError:
458 except ValueError:
458 idents=[]
459 idents=[]
459 if not idents:
460 if not idents:
460 self.log.error("Bad Monitor Message: %r"%msg)
461 self.log.error("Bad Monitor Message: %r"%msg)
461 return
462 return
462 handler = self.monitor_handlers.get(switch, None)
463 handler = self.monitor_handlers.get(switch, None)
463 if handler is not None:
464 if handler is not None:
464 handler(idents, msg)
465 handler(idents, msg)
465 else:
466 else:
466 self.log.error("Invalid monitor topic: %r"%switch)
467 self.log.error("Invalid monitor topic: %r"%switch)
467
468
468
469
469 def dispatch_query(self, msg):
470 def dispatch_query(self, msg):
470 """Route registration requests and queries from clients."""
471 """Route registration requests and queries from clients."""
471 try:
472 try:
472 idents, msg = self.session.feed_identities(msg)
473 idents, msg = self.session.feed_identities(msg)
473 except ValueError:
474 except ValueError:
474 idents = []
475 idents = []
475 if not idents:
476 if not idents:
476 self.log.error("Bad Query Message: %r"%msg)
477 self.log.error("Bad Query Message: %r"%msg)
477 return
478 return
478 client_id = idents[0]
479 client_id = idents[0]
479 try:
480 try:
480 msg = self.session.unpack_message(msg, content=True)
481 msg = self.session.unpack_message(msg, content=True)
481 except Exception:
482 except Exception:
482 content = error.wrap_exception()
483 content = error.wrap_exception()
483 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
484 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
484 self.session.send(self.query, "hub_error", ident=client_id,
485 self.session.send(self.query, "hub_error", ident=client_id,
485 content=content)
486 content=content)
486 return
487 return
487 print( idents, msg)
488 print( idents, msg)
488 # print client_id, header, parent, content
489 # print client_id, header, parent, content
489 #switch on message type:
490 #switch on message type:
490 msg_type = msg['msg_type']
491 msg_type = msg['msg_type']
491 self.log.info("client::client %r requested %r"%(client_id, msg_type))
492 self.log.info("client::client %r requested %r"%(client_id, msg_type))
492 handler = self.query_handlers.get(msg_type, None)
493 handler = self.query_handlers.get(msg_type, None)
493 try:
494 try:
494 assert handler is not None, "Bad Message Type: %r"%msg_type
495 assert handler is not None, "Bad Message Type: %r"%msg_type
495 except:
496 except:
496 content = error.wrap_exception()
497 content = error.wrap_exception()
497 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
498 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
498 self.session.send(self.query, "hub_error", ident=client_id,
499 self.session.send(self.query, "hub_error", ident=client_id,
499 content=content)
500 content=content)
500 return
501 return
501
502
502 else:
503 else:
503 handler(idents, msg)
504 handler(idents, msg)
504
505
505 def dispatch_db(self, msg):
506 def dispatch_db(self, msg):
506 """"""
507 """"""
507 raise NotImplementedError
508 raise NotImplementedError
508
509
509 #---------------------------------------------------------------------------
510 #---------------------------------------------------------------------------
510 # handler methods (1 per event)
511 # handler methods (1 per event)
511 #---------------------------------------------------------------------------
512 #---------------------------------------------------------------------------
512
513
513 #----------------------- Heartbeat --------------------------------------
514 #----------------------- Heartbeat --------------------------------------
514
515
515 def handle_new_heart(self, heart):
516 def handle_new_heart(self, heart):
516 """handler to attach to heartbeater.
517 """handler to attach to heartbeater.
517 Called when a new heart starts to beat.
518 Called when a new heart starts to beat.
518 Triggers completion of registration."""
519 Triggers completion of registration."""
519 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
520 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
520 if heart not in self.incoming_registrations:
521 if heart not in self.incoming_registrations:
521 self.log.info("heartbeat::ignoring new heart: %r"%heart)
522 self.log.info("heartbeat::ignoring new heart: %r"%heart)
522 else:
523 else:
523 self.finish_registration(heart)
524 self.finish_registration(heart)
524
525
525
526
526 def handle_heart_failure(self, heart):
527 def handle_heart_failure(self, heart):
527 """handler to attach to heartbeater.
528 """handler to attach to heartbeater.
528 called when a previously registered heart fails to respond to beat request.
529 called when a previously registered heart fails to respond to beat request.
529 triggers unregistration"""
530 triggers unregistration"""
530 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
531 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
531 eid = self.hearts.get(heart, None)
532 eid = self.hearts.get(heart, None)
532 queue = self.engines[eid].queue
533 queue = self.engines[eid].queue
533 if eid is None:
534 if eid is None:
534 self.log.info("heartbeat::ignoring heart failure %r"%heart)
535 self.log.info("heartbeat::ignoring heart failure %r"%heart)
535 else:
536 else:
536 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
537 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
537
538
538 #----------------------- MUX Queue Traffic ------------------------------
539 #----------------------- MUX Queue Traffic ------------------------------
539
540
540 def save_queue_request(self, idents, msg):
541 def save_queue_request(self, idents, msg):
541 if len(idents) < 2:
542 if len(idents) < 2:
542 self.log.error("invalid identity prefix: %r"%idents)
543 self.log.error("invalid identity prefix: %r"%idents)
543 return
544 return
544 queue_id, client_id = idents[:2]
545 queue_id, client_id = idents[:2]
545 try:
546 try:
546 msg = self.session.unpack_message(msg, content=False)
547 msg = self.session.unpack_message(msg, content=False)
547 except Exception:
548 except Exception:
548 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
549 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
549 return
550 return
550
551
551 eid = self.by_ident.get(queue_id, None)
552 eid = self.by_ident.get(queue_id, None)
552 if eid is None:
553 if eid is None:
553 self.log.error("queue::target %r not registered"%queue_id)
554 self.log.error("queue::target %r not registered"%queue_id)
554 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
555 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
555 return
556 return
556
557
557 header = msg['header']
558 header = msg['header']
558 msg_id = header['msg_id']
559 msg_id = header['msg_id']
559 record = init_record(msg)
560 record = init_record(msg)
560 record['engine_uuid'] = queue_id
561 record['engine_uuid'] = queue_id
561 record['client_uuid'] = client_id
562 record['client_uuid'] = client_id
562 record['queue'] = 'mux'
563 record['queue'] = 'mux'
563
564
564 try:
565 try:
565 # it's posible iopub arrived first:
566 # it's posible iopub arrived first:
566 existing = self.db.get_record(msg_id)
567 existing = self.db.get_record(msg_id)
567 for key,evalue in existing.iteritems():
568 for key,evalue in existing.iteritems():
568 rvalue = record.get(key, None)
569 rvalue = record.get(key, None)
569 if evalue and rvalue and evalue != rvalue:
570 if evalue and rvalue and evalue != rvalue:
570 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
571 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
571 elif evalue and not rvalue:
572 elif evalue and not rvalue:
572 record[key] = evalue
573 record[key] = evalue
573 self.db.update_record(msg_id, record)
574 self.db.update_record(msg_id, record)
574 except KeyError:
575 except KeyError:
575 self.db.add_record(msg_id, record)
576 self.db.add_record(msg_id, record)
576
577
577 self.pending.add(msg_id)
578 self.pending.add(msg_id)
578 self.queues[eid].append(msg_id)
579 self.queues[eid].append(msg_id)
579
580
580 def save_queue_result(self, idents, msg):
581 def save_queue_result(self, idents, msg):
581 if len(idents) < 2:
582 if len(idents) < 2:
582 self.log.error("invalid identity prefix: %r"%idents)
583 self.log.error("invalid identity prefix: %r"%idents)
583 return
584 return
584
585
585 client_id, queue_id = idents[:2]
586 client_id, queue_id = idents[:2]
586 try:
587 try:
587 msg = self.session.unpack_message(msg, content=False)
588 msg = self.session.unpack_message(msg, content=False)
588 except Exception:
589 except Exception:
589 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
590 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
590 queue_id,client_id, msg), exc_info=True)
591 queue_id,client_id, msg), exc_info=True)
591 return
592 return
592
593
593 eid = self.by_ident.get(queue_id, None)
594 eid = self.by_ident.get(queue_id, None)
594 if eid is None:
595 if eid is None:
595 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
596 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
596 return
597 return
597
598
598 parent = msg['parent_header']
599 parent = msg['parent_header']
599 if not parent:
600 if not parent:
600 return
601 return
601 msg_id = parent['msg_id']
602 msg_id = parent['msg_id']
602 if msg_id in self.pending:
603 if msg_id in self.pending:
603 self.pending.remove(msg_id)
604 self.pending.remove(msg_id)
604 self.all_completed.add(msg_id)
605 self.all_completed.add(msg_id)
605 self.queues[eid].remove(msg_id)
606 self.queues[eid].remove(msg_id)
606 self.completed[eid].append(msg_id)
607 self.completed[eid].append(msg_id)
607 elif msg_id not in self.all_completed:
608 elif msg_id not in self.all_completed:
608 # it could be a result from a dead engine that died before delivering the
609 # it could be a result from a dead engine that died before delivering the
609 # result
610 # result
610 self.log.warn("queue:: unknown msg finished %r"%msg_id)
611 self.log.warn("queue:: unknown msg finished %r"%msg_id)
611 return
612 return
612 # update record anyway, because the unregistration could have been premature
613 # update record anyway, because the unregistration could have been premature
613 rheader = msg['header']
614 rheader = extract_dates(msg['header'])
614 completed = datetime.strptime(rheader['date'], util.ISO8601)
615 completed = rheader['date']
615 started = rheader.get('started', None)
616 started = rheader.get('started', None)
616 if started is not None:
617 started = datetime.strptime(started, util.ISO8601)
618 result = {
617 result = {
619 'result_header' : rheader,
618 'result_header' : rheader,
620 'result_content': msg['content'],
619 'result_content': msg['content'],
621 'started' : started,
620 'started' : started,
622 'completed' : completed
621 'completed' : completed
623 }
622 }
624
623
625 result['result_buffers'] = msg['buffers']
624 result['result_buffers'] = msg['buffers']
626 try:
625 try:
627 self.db.update_record(msg_id, result)
626 self.db.update_record(msg_id, result)
628 except Exception:
627 except Exception:
629 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
628 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
630
629
631
630
632 #--------------------- Task Queue Traffic ------------------------------
631 #--------------------- Task Queue Traffic ------------------------------
633
632
634 def save_task_request(self, idents, msg):
633 def save_task_request(self, idents, msg):
635 """Save the submission of a task."""
634 """Save the submission of a task."""
636 client_id = idents[0]
635 client_id = idents[0]
637
636
638 try:
637 try:
639 msg = self.session.unpack_message(msg, content=False)
638 msg = self.session.unpack_message(msg, content=False)
640 except Exception:
639 except Exception:
641 self.log.error("task::client %r sent invalid task message: %r"%(
640 self.log.error("task::client %r sent invalid task message: %r"%(
642 client_id, msg), exc_info=True)
641 client_id, msg), exc_info=True)
643 return
642 return
644 record = init_record(msg)
643 record = init_record(msg)
645
644
646 record['client_uuid'] = client_id
645 record['client_uuid'] = client_id
647 record['queue'] = 'task'
646 record['queue'] = 'task'
648 header = msg['header']
647 header = msg['header']
649 msg_id = header['msg_id']
648 msg_id = header['msg_id']
650 self.pending.add(msg_id)
649 self.pending.add(msg_id)
651 self.unassigned.add(msg_id)
650 self.unassigned.add(msg_id)
652 try:
651 try:
653 # it's posible iopub arrived first:
652 # it's posible iopub arrived first:
654 existing = self.db.get_record(msg_id)
653 existing = self.db.get_record(msg_id)
655 if existing['resubmitted']:
654 if existing['resubmitted']:
656 for key in ('submitted', 'client_uuid', 'buffers'):
655 for key in ('submitted', 'client_uuid', 'buffers'):
657 # don't clobber these keys on resubmit
656 # don't clobber these keys on resubmit
658 # submitted and client_uuid should be different
657 # submitted and client_uuid should be different
659 # and buffers might be big, and shouldn't have changed
658 # and buffers might be big, and shouldn't have changed
660 record.pop(key)
659 record.pop(key)
661 # still check content,header which should not change
660 # still check content,header which should not change
662 # but are not expensive to compare as buffers
661 # but are not expensive to compare as buffers
663
662
664 for key,evalue in existing.iteritems():
663 for key,evalue in existing.iteritems():
665 if key.endswith('buffers'):
664 if key.endswith('buffers'):
666 # don't compare buffers
665 # don't compare buffers
667 continue
666 continue
668 rvalue = record.get(key, None)
667 rvalue = record.get(key, None)
669 if evalue and rvalue and evalue != rvalue:
668 if evalue and rvalue and evalue != rvalue:
670 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
669 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
671 elif evalue and not rvalue:
670 elif evalue and not rvalue:
672 record[key] = evalue
671 record[key] = evalue
673 self.db.update_record(msg_id, record)
672 self.db.update_record(msg_id, record)
674 except KeyError:
673 except KeyError:
675 self.db.add_record(msg_id, record)
674 self.db.add_record(msg_id, record)
676 except Exception:
675 except Exception:
677 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
676 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
678
677
679 def save_task_result(self, idents, msg):
678 def save_task_result(self, idents, msg):
680 """save the result of a completed task."""
679 """save the result of a completed task."""
681 client_id = idents[0]
680 client_id = idents[0]
682 try:
681 try:
683 msg = self.session.unpack_message(msg, content=False)
682 msg = self.session.unpack_message(msg, content=False)
684 except Exception:
683 except Exception:
685 self.log.error("task::invalid task result message send to %r: %r"%(
684 self.log.error("task::invalid task result message send to %r: %r"%(
686 client_id, msg), exc_info=True)
685 client_id, msg), exc_info=True)
687 return
686 return
688
687
689 parent = msg['parent_header']
688 parent = msg['parent_header']
690 if not parent:
689 if not parent:
691 # print msg
690 # print msg
692 self.log.warn("Task %r had no parent!"%msg)
691 self.log.warn("Task %r had no parent!"%msg)
693 return
692 return
694 msg_id = parent['msg_id']
693 msg_id = parent['msg_id']
695 if msg_id in self.unassigned:
694 if msg_id in self.unassigned:
696 self.unassigned.remove(msg_id)
695 self.unassigned.remove(msg_id)
697
696
698 header = msg['header']
697 header = extract_dates(msg['header'])
699 engine_uuid = header.get('engine', None)
698 engine_uuid = header.get('engine', None)
700 eid = self.by_ident.get(engine_uuid, None)
699 eid = self.by_ident.get(engine_uuid, None)
701
700
702 if msg_id in self.pending:
701 if msg_id in self.pending:
703 self.pending.remove(msg_id)
702 self.pending.remove(msg_id)
704 self.all_completed.add(msg_id)
703 self.all_completed.add(msg_id)
705 if eid is not None:
704 if eid is not None:
706 self.completed[eid].append(msg_id)
705 self.completed[eid].append(msg_id)
707 if msg_id in self.tasks[eid]:
706 if msg_id in self.tasks[eid]:
708 self.tasks[eid].remove(msg_id)
707 self.tasks[eid].remove(msg_id)
709 completed = datetime.strptime(header['date'], util.ISO8601)
708 completed = header['date']
710 started = header.get('started', None)
709 started = header.get('started', None)
711 if started is not None:
712 started = datetime.strptime(started, util.ISO8601)
713 result = {
710 result = {
714 'result_header' : header,
711 'result_header' : header,
715 'result_content': msg['content'],
712 'result_content': msg['content'],
716 'started' : started,
713 'started' : started,
717 'completed' : completed,
714 'completed' : completed,
718 'engine_uuid': engine_uuid
715 'engine_uuid': engine_uuid
719 }
716 }
720
717
721 result['result_buffers'] = msg['buffers']
718 result['result_buffers'] = msg['buffers']
722 try:
719 try:
723 self.db.update_record(msg_id, result)
720 self.db.update_record(msg_id, result)
724 except Exception:
721 except Exception:
725 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
722 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
726
723
727 else:
724 else:
728 self.log.debug("task::unknown task %r finished"%msg_id)
725 self.log.debug("task::unknown task %r finished"%msg_id)
729
726
730 def save_task_destination(self, idents, msg):
727 def save_task_destination(self, idents, msg):
731 try:
728 try:
732 msg = self.session.unpack_message(msg, content=True)
729 msg = self.session.unpack_message(msg, content=True)
733 except Exception:
730 except Exception:
734 self.log.error("task::invalid task tracking message", exc_info=True)
731 self.log.error("task::invalid task tracking message", exc_info=True)
735 return
732 return
736 content = msg['content']
733 content = msg['content']
737 # print (content)
734 # print (content)
738 msg_id = content['msg_id']
735 msg_id = content['msg_id']
739 engine_uuid = content['engine_id']
736 engine_uuid = content['engine_id']
740 eid = self.by_ident[engine_uuid]
737 eid = self.by_ident[engine_uuid]
741
738
742 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
739 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
743 if msg_id in self.unassigned:
740 if msg_id in self.unassigned:
744 self.unassigned.remove(msg_id)
741 self.unassigned.remove(msg_id)
745 # else:
742 # else:
746 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
743 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
747
744
748 self.tasks[eid].append(msg_id)
745 self.tasks[eid].append(msg_id)
749 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
746 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
750 try:
747 try:
751 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
748 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
752 except Exception:
749 except Exception:
753 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
750 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
754
751
755
752
756 def mia_task_request(self, idents, msg):
753 def mia_task_request(self, idents, msg):
757 raise NotImplementedError
754 raise NotImplementedError
758 client_id = idents[0]
755 client_id = idents[0]
759 # content = dict(mia=self.mia,status='ok')
756 # content = dict(mia=self.mia,status='ok')
760 # self.session.send('mia_reply', content=content, idents=client_id)
757 # self.session.send('mia_reply', content=content, idents=client_id)
761
758
762
759
763 #--------------------- IOPub Traffic ------------------------------
760 #--------------------- IOPub Traffic ------------------------------
764
761
765 def save_iopub_message(self, topics, msg):
762 def save_iopub_message(self, topics, msg):
766 """save an iopub message into the db"""
763 """save an iopub message into the db"""
767 # print (topics)
764 # print (topics)
768 try:
765 try:
769 msg = self.session.unpack_message(msg, content=True)
766 msg = self.session.unpack_message(msg, content=True)
770 except Exception:
767 except Exception:
771 self.log.error("iopub::invalid IOPub message", exc_info=True)
768 self.log.error("iopub::invalid IOPub message", exc_info=True)
772 return
769 return
773
770
774 parent = msg['parent_header']
771 parent = msg['parent_header']
775 if not parent:
772 if not parent:
776 self.log.error("iopub::invalid IOPub message: %r"%msg)
773 self.log.error("iopub::invalid IOPub message: %r"%msg)
777 return
774 return
778 msg_id = parent['msg_id']
775 msg_id = parent['msg_id']
779 msg_type = msg['msg_type']
776 msg_type = msg['msg_type']
780 content = msg['content']
777 content = msg['content']
781
778
782 # ensure msg_id is in db
779 # ensure msg_id is in db
783 try:
780 try:
784 rec = self.db.get_record(msg_id)
781 rec = self.db.get_record(msg_id)
785 except KeyError:
782 except KeyError:
786 rec = empty_record()
783 rec = empty_record()
787 rec['msg_id'] = msg_id
784 rec['msg_id'] = msg_id
788 self.db.add_record(msg_id, rec)
785 self.db.add_record(msg_id, rec)
789 # stream
786 # stream
790 d = {}
787 d = {}
791 if msg_type == 'stream':
788 if msg_type == 'stream':
792 name = content['name']
789 name = content['name']
793 s = rec[name] or ''
790 s = rec[name] or ''
794 d[name] = s + content['data']
791 d[name] = s + content['data']
795
792
796 elif msg_type == 'pyerr':
793 elif msg_type == 'pyerr':
797 d['pyerr'] = content
794 d['pyerr'] = content
798 elif msg_type == 'pyin':
795 elif msg_type == 'pyin':
799 d['pyin'] = content['code']
796 d['pyin'] = content['code']
800 else:
797 else:
801 d[msg_type] = content.get('data', '')
798 d[msg_type] = content.get('data', '')
802
799
803 try:
800 try:
804 self.db.update_record(msg_id, d)
801 self.db.update_record(msg_id, d)
805 except Exception:
802 except Exception:
806 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
803 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
807
804
808
805
809
806
810 #-------------------------------------------------------------------------
807 #-------------------------------------------------------------------------
811 # Registration requests
808 # Registration requests
812 #-------------------------------------------------------------------------
809 #-------------------------------------------------------------------------
813
810
814 def connection_request(self, client_id, msg):
811 def connection_request(self, client_id, msg):
815 """Reply with connection addresses for clients."""
812 """Reply with connection addresses for clients."""
816 self.log.info("client::client %r connected"%client_id)
813 self.log.info("client::client %r connected"%client_id)
817 content = dict(status='ok')
814 content = dict(status='ok')
818 content.update(self.client_info)
815 content.update(self.client_info)
819 jsonable = {}
816 jsonable = {}
820 for k,v in self.keytable.iteritems():
817 for k,v in self.keytable.iteritems():
821 if v not in self.dead_engines:
818 if v not in self.dead_engines:
822 jsonable[str(k)] = v
819 jsonable[str(k)] = v
823 content['engines'] = jsonable
820 content['engines'] = jsonable
824 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
821 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
825
822
826 def register_engine(self, reg, msg):
823 def register_engine(self, reg, msg):
827 """Register a new engine."""
824 """Register a new engine."""
828 content = msg['content']
825 content = msg['content']
829 try:
826 try:
830 queue = content['queue']
827 queue = content['queue']
831 except KeyError:
828 except KeyError:
832 self.log.error("registration::queue not specified", exc_info=True)
829 self.log.error("registration::queue not specified", exc_info=True)
833 return
830 return
834 heart = content.get('heartbeat', None)
831 heart = content.get('heartbeat', None)
835 """register a new engine, and create the socket(s) necessary"""
832 """register a new engine, and create the socket(s) necessary"""
836 eid = self._next_id
833 eid = self._next_id
837 # print (eid, queue, reg, heart)
834 # print (eid, queue, reg, heart)
838
835
839 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
836 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
840
837
841 content = dict(id=eid,status='ok')
838 content = dict(id=eid,status='ok')
842 content.update(self.engine_info)
839 content.update(self.engine_info)
843 # check if requesting available IDs:
840 # check if requesting available IDs:
844 if queue in self.by_ident:
841 if queue in self.by_ident:
845 try:
842 try:
846 raise KeyError("queue_id %r in use"%queue)
843 raise KeyError("queue_id %r in use"%queue)
847 except:
844 except:
848 content = error.wrap_exception()
845 content = error.wrap_exception()
849 self.log.error("queue_id %r in use"%queue, exc_info=True)
846 self.log.error("queue_id %r in use"%queue, exc_info=True)
850 elif heart in self.hearts: # need to check unique hearts?
847 elif heart in self.hearts: # need to check unique hearts?
851 try:
848 try:
852 raise KeyError("heart_id %r in use"%heart)
849 raise KeyError("heart_id %r in use"%heart)
853 except:
850 except:
854 self.log.error("heart_id %r in use"%heart, exc_info=True)
851 self.log.error("heart_id %r in use"%heart, exc_info=True)
855 content = error.wrap_exception()
852 content = error.wrap_exception()
856 else:
853 else:
857 for h, pack in self.incoming_registrations.iteritems():
854 for h, pack in self.incoming_registrations.iteritems():
858 if heart == h:
855 if heart == h:
859 try:
856 try:
860 raise KeyError("heart_id %r in use"%heart)
857 raise KeyError("heart_id %r in use"%heart)
861 except:
858 except:
862 self.log.error("heart_id %r in use"%heart, exc_info=True)
859 self.log.error("heart_id %r in use"%heart, exc_info=True)
863 content = error.wrap_exception()
860 content = error.wrap_exception()
864 break
861 break
865 elif queue == pack[1]:
862 elif queue == pack[1]:
866 try:
863 try:
867 raise KeyError("queue_id %r in use"%queue)
864 raise KeyError("queue_id %r in use"%queue)
868 except:
865 except:
869 self.log.error("queue_id %r in use"%queue, exc_info=True)
866 self.log.error("queue_id %r in use"%queue, exc_info=True)
870 content = error.wrap_exception()
867 content = error.wrap_exception()
871 break
868 break
872
869
873 msg = self.session.send(self.query, "registration_reply",
870 msg = self.session.send(self.query, "registration_reply",
874 content=content,
871 content=content,
875 ident=reg)
872 ident=reg)
876
873
877 if content['status'] == 'ok':
874 if content['status'] == 'ok':
878 if heart in self.heartmonitor.hearts:
875 if heart in self.heartmonitor.hearts:
879 # already beating
876 # already beating
880 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
877 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
881 self.finish_registration(heart)
878 self.finish_registration(heart)
882 else:
879 else:
883 purge = lambda : self._purge_stalled_registration(heart)
880 purge = lambda : self._purge_stalled_registration(heart)
884 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
881 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
885 dc.start()
882 dc.start()
886 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
883 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
887 else:
884 else:
888 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
885 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
889 return eid
886 return eid
890
887
891 def unregister_engine(self, ident, msg):
888 def unregister_engine(self, ident, msg):
892 """Unregister an engine that explicitly requested to leave."""
889 """Unregister an engine that explicitly requested to leave."""
893 try:
890 try:
894 eid = msg['content']['id']
891 eid = msg['content']['id']
895 except:
892 except:
896 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
893 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
897 return
894 return
898 self.log.info("registration::unregister_engine(%r)"%eid)
895 self.log.info("registration::unregister_engine(%r)"%eid)
899 # print (eid)
896 # print (eid)
900 uuid = self.keytable[eid]
897 uuid = self.keytable[eid]
901 content=dict(id=eid, queue=uuid)
898 content=dict(id=eid, queue=uuid)
902 self.dead_engines.add(uuid)
899 self.dead_engines.add(uuid)
903 # self.ids.remove(eid)
900 # self.ids.remove(eid)
904 # uuid = self.keytable.pop(eid)
901 # uuid = self.keytable.pop(eid)
905 #
902 #
906 # ec = self.engines.pop(eid)
903 # ec = self.engines.pop(eid)
907 # self.hearts.pop(ec.heartbeat)
904 # self.hearts.pop(ec.heartbeat)
908 # self.by_ident.pop(ec.queue)
905 # self.by_ident.pop(ec.queue)
909 # self.completed.pop(eid)
906 # self.completed.pop(eid)
910 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
907 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
911 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
908 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
912 dc.start()
909 dc.start()
913 ############## TODO: HANDLE IT ################
910 ############## TODO: HANDLE IT ################
914
911
915 if self.notifier:
912 if self.notifier:
916 self.session.send(self.notifier, "unregistration_notification", content=content)
913 self.session.send(self.notifier, "unregistration_notification", content=content)
917
914
918 def _handle_stranded_msgs(self, eid, uuid):
915 def _handle_stranded_msgs(self, eid, uuid):
919 """Handle messages known to be on an engine when the engine unregisters.
916 """Handle messages known to be on an engine when the engine unregisters.
920
917
921 It is possible that this will fire prematurely - that is, an engine will
918 It is possible that this will fire prematurely - that is, an engine will
922 go down after completing a result, and the client will be notified
919 go down after completing a result, and the client will be notified
923 that the result failed and later receive the actual result.
920 that the result failed and later receive the actual result.
924 """
921 """
925
922
926 outstanding = self.queues[eid]
923 outstanding = self.queues[eid]
927
924
928 for msg_id in outstanding:
925 for msg_id in outstanding:
929 self.pending.remove(msg_id)
926 self.pending.remove(msg_id)
930 self.all_completed.add(msg_id)
927 self.all_completed.add(msg_id)
931 try:
928 try:
932 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
929 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
933 except:
930 except:
934 content = error.wrap_exception()
931 content = error.wrap_exception()
935 # build a fake header:
932 # build a fake header:
936 header = {}
933 header = {}
937 header['engine'] = uuid
934 header['engine'] = uuid
938 header['date'] = datetime.now()
935 header['date'] = datetime.now()
939 rec = dict(result_content=content, result_header=header, result_buffers=[])
936 rec = dict(result_content=content, result_header=header, result_buffers=[])
940 rec['completed'] = header['date']
937 rec['completed'] = header['date']
941 rec['engine_uuid'] = uuid
938 rec['engine_uuid'] = uuid
942 try:
939 try:
943 self.db.update_record(msg_id, rec)
940 self.db.update_record(msg_id, rec)
944 except Exception:
941 except Exception:
945 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
942 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
946
943
947
944
948 def finish_registration(self, heart):
945 def finish_registration(self, heart):
949 """Second half of engine registration, called after our HeartMonitor
946 """Second half of engine registration, called after our HeartMonitor
950 has received a beat from the Engine's Heart."""
947 has received a beat from the Engine's Heart."""
951 try:
948 try:
952 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
949 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
953 except KeyError:
950 except KeyError:
954 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
951 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
955 return
952 return
956 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
953 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
957 if purge is not None:
954 if purge is not None:
958 purge.stop()
955 purge.stop()
959 control = queue
956 control = queue
960 self.ids.add(eid)
957 self.ids.add(eid)
961 self.keytable[eid] = queue
958 self.keytable[eid] = queue
962 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
959 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
963 control=control, heartbeat=heart)
960 control=control, heartbeat=heart)
964 self.by_ident[queue] = eid
961 self.by_ident[queue] = eid
965 self.queues[eid] = list()
962 self.queues[eid] = list()
966 self.tasks[eid] = list()
963 self.tasks[eid] = list()
967 self.completed[eid] = list()
964 self.completed[eid] = list()
968 self.hearts[heart] = eid
965 self.hearts[heart] = eid
969 content = dict(id=eid, queue=self.engines[eid].queue)
966 content = dict(id=eid, queue=self.engines[eid].queue)
970 if self.notifier:
967 if self.notifier:
971 self.session.send(self.notifier, "registration_notification", content=content)
968 self.session.send(self.notifier, "registration_notification", content=content)
972 self.log.info("engine::Engine Connected: %i"%eid)
969 self.log.info("engine::Engine Connected: %i"%eid)
973
970
974 def _purge_stalled_registration(self, heart):
971 def _purge_stalled_registration(self, heart):
975 if heart in self.incoming_registrations:
972 if heart in self.incoming_registrations:
976 eid = self.incoming_registrations.pop(heart)[0]
973 eid = self.incoming_registrations.pop(heart)[0]
977 self.log.info("registration::purging stalled registration: %i"%eid)
974 self.log.info("registration::purging stalled registration: %i"%eid)
978 else:
975 else:
979 pass
976 pass
980
977
981 #-------------------------------------------------------------------------
978 #-------------------------------------------------------------------------
982 # Client Requests
979 # Client Requests
983 #-------------------------------------------------------------------------
980 #-------------------------------------------------------------------------
984
981
985 def shutdown_request(self, client_id, msg):
982 def shutdown_request(self, client_id, msg):
986 """handle shutdown request."""
983 """handle shutdown request."""
987 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
984 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
988 # also notify other clients of shutdown
985 # also notify other clients of shutdown
989 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
986 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
990 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
987 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
991 dc.start()
988 dc.start()
992
989
993 def _shutdown(self):
990 def _shutdown(self):
994 self.log.info("hub::hub shutting down.")
991 self.log.info("hub::hub shutting down.")
995 time.sleep(0.1)
992 time.sleep(0.1)
996 sys.exit(0)
993 sys.exit(0)
997
994
998
995
999 def check_load(self, client_id, msg):
996 def check_load(self, client_id, msg):
1000 content = msg['content']
997 content = msg['content']
1001 try:
998 try:
1002 targets = content['targets']
999 targets = content['targets']
1003 targets = self._validate_targets(targets)
1000 targets = self._validate_targets(targets)
1004 except:
1001 except:
1005 content = error.wrap_exception()
1002 content = error.wrap_exception()
1006 self.session.send(self.query, "hub_error",
1003 self.session.send(self.query, "hub_error",
1007 content=content, ident=client_id)
1004 content=content, ident=client_id)
1008 return
1005 return
1009
1006
1010 content = dict(status='ok')
1007 content = dict(status='ok')
1011 # loads = {}
1008 # loads = {}
1012 for t in targets:
1009 for t in targets:
1013 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1010 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1014 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1011 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1015
1012
1016
1013
1017 def queue_status(self, client_id, msg):
1014 def queue_status(self, client_id, msg):
1018 """Return the Queue status of one or more targets.
1015 """Return the Queue status of one or more targets.
1019 if verbose: return the msg_ids
1016 if verbose: return the msg_ids
1020 else: return len of each type.
1017 else: return len of each type.
1021 keys: queue (pending MUX jobs)
1018 keys: queue (pending MUX jobs)
1022 tasks (pending Task jobs)
1019 tasks (pending Task jobs)
1023 completed (finished jobs from both queues)"""
1020 completed (finished jobs from both queues)"""
1024 content = msg['content']
1021 content = msg['content']
1025 targets = content['targets']
1022 targets = content['targets']
1026 try:
1023 try:
1027 targets = self._validate_targets(targets)
1024 targets = self._validate_targets(targets)
1028 except:
1025 except:
1029 content = error.wrap_exception()
1026 content = error.wrap_exception()
1030 self.session.send(self.query, "hub_error",
1027 self.session.send(self.query, "hub_error",
1031 content=content, ident=client_id)
1028 content=content, ident=client_id)
1032 return
1029 return
1033 verbose = content.get('verbose', False)
1030 verbose = content.get('verbose', False)
1034 content = dict(status='ok')
1031 content = dict(status='ok')
1035 for t in targets:
1032 for t in targets:
1036 queue = self.queues[t]
1033 queue = self.queues[t]
1037 completed = self.completed[t]
1034 completed = self.completed[t]
1038 tasks = self.tasks[t]
1035 tasks = self.tasks[t]
1039 if not verbose:
1036 if not verbose:
1040 queue = len(queue)
1037 queue = len(queue)
1041 completed = len(completed)
1038 completed = len(completed)
1042 tasks = len(tasks)
1039 tasks = len(tasks)
1043 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1040 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1044 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1041 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1045
1042
1046 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1043 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1047
1044
1048 def purge_results(self, client_id, msg):
1045 def purge_results(self, client_id, msg):
1049 """Purge results from memory. This method is more valuable before we move
1046 """Purge results from memory. This method is more valuable before we move
1050 to a DB based message storage mechanism."""
1047 to a DB based message storage mechanism."""
1051 content = msg['content']
1048 content = msg['content']
1052 msg_ids = content.get('msg_ids', [])
1049 msg_ids = content.get('msg_ids', [])
1053 reply = dict(status='ok')
1050 reply = dict(status='ok')
1054 if msg_ids == 'all':
1051 if msg_ids == 'all':
1055 try:
1052 try:
1056 self.db.drop_matching_records(dict(completed={'$ne':None}))
1053 self.db.drop_matching_records(dict(completed={'$ne':None}))
1057 except Exception:
1054 except Exception:
1058 reply = error.wrap_exception()
1055 reply = error.wrap_exception()
1059 else:
1056 else:
1060 pending = filter(lambda m: m in self.pending, msg_ids)
1057 pending = filter(lambda m: m in self.pending, msg_ids)
1061 if pending:
1058 if pending:
1062 try:
1059 try:
1063 raise IndexError("msg pending: %r"%pending[0])
1060 raise IndexError("msg pending: %r"%pending[0])
1064 except:
1061 except:
1065 reply = error.wrap_exception()
1062 reply = error.wrap_exception()
1066 else:
1063 else:
1067 try:
1064 try:
1068 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1065 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1069 except Exception:
1066 except Exception:
1070 reply = error.wrap_exception()
1067 reply = error.wrap_exception()
1071
1068
1072 if reply['status'] == 'ok':
1069 if reply['status'] == 'ok':
1073 eids = content.get('engine_ids', [])
1070 eids = content.get('engine_ids', [])
1074 for eid in eids:
1071 for eid in eids:
1075 if eid not in self.engines:
1072 if eid not in self.engines:
1076 try:
1073 try:
1077 raise IndexError("No such engine: %i"%eid)
1074 raise IndexError("No such engine: %i"%eid)
1078 except:
1075 except:
1079 reply = error.wrap_exception()
1076 reply = error.wrap_exception()
1080 break
1077 break
1081 msg_ids = self.completed.pop(eid)
1078 msg_ids = self.completed.pop(eid)
1082 uid = self.engines[eid].queue
1079 uid = self.engines[eid].queue
1083 try:
1080 try:
1084 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1081 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1085 except Exception:
1082 except Exception:
1086 reply = error.wrap_exception()
1083 reply = error.wrap_exception()
1087 break
1084 break
1088
1085
1089 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1086 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1090
1087
1091 def resubmit_task(self, client_id, msg):
1088 def resubmit_task(self, client_id, msg):
1092 """Resubmit one or more tasks."""
1089 """Resubmit one or more tasks."""
1093 def finish(reply):
1090 def finish(reply):
1094 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1091 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1095
1092
1096 content = msg['content']
1093 content = msg['content']
1097 msg_ids = content['msg_ids']
1094 msg_ids = content['msg_ids']
1098 reply = dict(status='ok')
1095 reply = dict(status='ok')
1099 try:
1096 try:
1100 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1097 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1101 'header', 'content', 'buffers'])
1098 'header', 'content', 'buffers'])
1102 except Exception:
1099 except Exception:
1103 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1100 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1104 return finish(error.wrap_exception())
1101 return finish(error.wrap_exception())
1105
1102
1106 # validate msg_ids
1103 # validate msg_ids
1107 found_ids = [ rec['msg_id'] for rec in records ]
1104 found_ids = [ rec['msg_id'] for rec in records ]
1108 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1105 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1109 if len(records) > len(msg_ids):
1106 if len(records) > len(msg_ids):
1110 try:
1107 try:
1111 raise RuntimeError("DB appears to be in an inconsistent state."
1108 raise RuntimeError("DB appears to be in an inconsistent state."
1112 "More matching records were found than should exist")
1109 "More matching records were found than should exist")
1113 except Exception:
1110 except Exception:
1114 return finish(error.wrap_exception())
1111 return finish(error.wrap_exception())
1115 elif len(records) < len(msg_ids):
1112 elif len(records) < len(msg_ids):
1116 missing = [ m for m in msg_ids if m not in found_ids ]
1113 missing = [ m for m in msg_ids if m not in found_ids ]
1117 try:
1114 try:
1118 raise KeyError("No such msg(s): %r"%missing)
1115 raise KeyError("No such msg(s): %r"%missing)
1119 except KeyError:
1116 except KeyError:
1120 return finish(error.wrap_exception())
1117 return finish(error.wrap_exception())
1121 elif invalid_ids:
1118 elif invalid_ids:
1122 msg_id = invalid_ids[0]
1119 msg_id = invalid_ids[0]
1123 try:
1120 try:
1124 raise ValueError("Task %r appears to be inflight"%(msg_id))
1121 raise ValueError("Task %r appears to be inflight"%(msg_id))
1125 except Exception:
1122 except Exception:
1126 return finish(error.wrap_exception())
1123 return finish(error.wrap_exception())
1127
1124
1128 # clear the existing records
1125 # clear the existing records
1129 now = datetime.now()
1126 now = datetime.now()
1130 rec = empty_record()
1127 rec = empty_record()
1131 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1128 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1132 rec['resubmitted'] = now
1129 rec['resubmitted'] = now
1133 rec['queue'] = 'task'
1130 rec['queue'] = 'task'
1134 rec['client_uuid'] = client_id[0]
1131 rec['client_uuid'] = client_id[0]
1135 try:
1132 try:
1136 for msg_id in msg_ids:
1133 for msg_id in msg_ids:
1137 self.all_completed.discard(msg_id)
1134 self.all_completed.discard(msg_id)
1138 self.db.update_record(msg_id, rec)
1135 self.db.update_record(msg_id, rec)
1139 except Exception:
1136 except Exception:
1140 self.log.error('db::db error upating record', exc_info=True)
1137 self.log.error('db::db error upating record', exc_info=True)
1141 reply = error.wrap_exception()
1138 reply = error.wrap_exception()
1142 else:
1139 else:
1143 # send the messages
1140 # send the messages
1144 now_s = now.strftime(util.ISO8601)
1141 now_s = now.strftime(ISO8601)
1145 for rec in records:
1142 for rec in records:
1146 header = rec['header']
1143 header = rec['header']
1147 # include resubmitted in header to prevent digest collision
1144 # include resubmitted in header to prevent digest collision
1148 header['resubmitted'] = now_s
1145 header['resubmitted'] = now_s
1149 msg = self.session.msg(header['msg_type'])
1146 msg = self.session.msg(header['msg_type'])
1150 msg['content'] = rec['content']
1147 msg['content'] = rec['content']
1151 msg['header'] = header
1148 msg['header'] = header
1152 msg['msg_id'] = rec['msg_id']
1149 msg['msg_id'] = rec['msg_id']
1153 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1150 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1154
1151
1155 finish(dict(status='ok'))
1152 finish(dict(status='ok'))
1156
1153
1157
1154
1158 def _extract_record(self, rec):
1155 def _extract_record(self, rec):
1159 """decompose a TaskRecord dict into subsection of reply for get_result"""
1156 """decompose a TaskRecord dict into subsection of reply for get_result"""
1160 io_dict = {}
1157 io_dict = {}
1161 for key in 'pyin pyout pyerr stdout stderr'.split():
1158 for key in 'pyin pyout pyerr stdout stderr'.split():
1162 io_dict[key] = rec[key]
1159 io_dict[key] = rec[key]
1163 content = { 'result_content': rec['result_content'],
1160 content = { 'result_content': rec['result_content'],
1164 'header': rec['header'],
1161 'header': rec['header'],
1165 'result_header' : rec['result_header'],
1162 'result_header' : rec['result_header'],
1166 'io' : io_dict,
1163 'io' : io_dict,
1167 }
1164 }
1168 if rec['result_buffers']:
1165 if rec['result_buffers']:
1169 buffers = map(str, rec['result_buffers'])
1166 buffers = map(str, rec['result_buffers'])
1170 else:
1167 else:
1171 buffers = []
1168 buffers = []
1172
1169
1173 return content, buffers
1170 return content, buffers
1174
1171
1175 def get_results(self, client_id, msg):
1172 def get_results(self, client_id, msg):
1176 """Get the result of 1 or more messages."""
1173 """Get the result of 1 or more messages."""
1177 content = msg['content']
1174 content = msg['content']
1178 msg_ids = sorted(set(content['msg_ids']))
1175 msg_ids = sorted(set(content['msg_ids']))
1179 statusonly = content.get('status_only', False)
1176 statusonly = content.get('status_only', False)
1180 pending = []
1177 pending = []
1181 completed = []
1178 completed = []
1182 content = dict(status='ok')
1179 content = dict(status='ok')
1183 content['pending'] = pending
1180 content['pending'] = pending
1184 content['completed'] = completed
1181 content['completed'] = completed
1185 buffers = []
1182 buffers = []
1186 if not statusonly:
1183 if not statusonly:
1187 try:
1184 try:
1188 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1185 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1189 # turn match list into dict, for faster lookup
1186 # turn match list into dict, for faster lookup
1190 records = {}
1187 records = {}
1191 for rec in matches:
1188 for rec in matches:
1192 records[rec['msg_id']] = rec
1189 records[rec['msg_id']] = rec
1193 except Exception:
1190 except Exception:
1194 content = error.wrap_exception()
1191 content = error.wrap_exception()
1195 self.session.send(self.query, "result_reply", content=content,
1192 self.session.send(self.query, "result_reply", content=content,
1196 parent=msg, ident=client_id)
1193 parent=msg, ident=client_id)
1197 return
1194 return
1198 else:
1195 else:
1199 records = {}
1196 records = {}
1200 for msg_id in msg_ids:
1197 for msg_id in msg_ids:
1201 if msg_id in self.pending:
1198 if msg_id in self.pending:
1202 pending.append(msg_id)
1199 pending.append(msg_id)
1203 elif msg_id in self.all_completed:
1200 elif msg_id in self.all_completed:
1204 completed.append(msg_id)
1201 completed.append(msg_id)
1205 if not statusonly:
1202 if not statusonly:
1206 c,bufs = self._extract_record(records[msg_id])
1203 c,bufs = self._extract_record(records[msg_id])
1207 content[msg_id] = c
1204 content[msg_id] = c
1208 buffers.extend(bufs)
1205 buffers.extend(bufs)
1209 elif msg_id in records:
1206 elif msg_id in records:
1210 if rec['completed']:
1207 if rec['completed']:
1211 completed.append(msg_id)
1208 completed.append(msg_id)
1212 c,bufs = self._extract_record(records[msg_id])
1209 c,bufs = self._extract_record(records[msg_id])
1213 content[msg_id] = c
1210 content[msg_id] = c
1214 buffers.extend(bufs)
1211 buffers.extend(bufs)
1215 else:
1212 else:
1216 pending.append(msg_id)
1213 pending.append(msg_id)
1217 else:
1214 else:
1218 try:
1215 try:
1219 raise KeyError('No such message: '+msg_id)
1216 raise KeyError('No such message: '+msg_id)
1220 except:
1217 except:
1221 content = error.wrap_exception()
1218 content = error.wrap_exception()
1222 break
1219 break
1223 self.session.send(self.query, "result_reply", content=content,
1220 self.session.send(self.query, "result_reply", content=content,
1224 parent=msg, ident=client_id,
1221 parent=msg, ident=client_id,
1225 buffers=buffers)
1222 buffers=buffers)
1226
1223
1227 def get_history(self, client_id, msg):
1224 def get_history(self, client_id, msg):
1228 """Get a list of all msg_ids in our DB records"""
1225 """Get a list of all msg_ids in our DB records"""
1229 try:
1226 try:
1230 msg_ids = self.db.get_history()
1227 msg_ids = self.db.get_history()
1231 except Exception as e:
1228 except Exception as e:
1232 content = error.wrap_exception()
1229 content = error.wrap_exception()
1233 else:
1230 else:
1234 content = dict(status='ok', history=msg_ids)
1231 content = dict(status='ok', history=msg_ids)
1235
1232
1236 self.session.send(self.query, "history_reply", content=content,
1233 self.session.send(self.query, "history_reply", content=content,
1237 parent=msg, ident=client_id)
1234 parent=msg, ident=client_id)
1238
1235
1239 def db_query(self, client_id, msg):
1236 def db_query(self, client_id, msg):
1240 """Perform a raw query on the task record database."""
1237 """Perform a raw query on the task record database."""
1241 content = msg['content']
1238 content = msg['content']
1242 query = content.get('query', {})
1239 query = content.get('query', {})
1243 keys = content.get('keys', None)
1240 keys = content.get('keys', None)
1244 query = util.extract_dates(query)
1241 query = util.extract_dates(query)
1245 buffers = []
1242 buffers = []
1246 empty = list()
1243 empty = list()
1247
1244
1248 try:
1245 try:
1249 records = self.db.find_records(query, keys)
1246 records = self.db.find_records(query, keys)
1250 except Exception as e:
1247 except Exception as e:
1251 content = error.wrap_exception()
1248 content = error.wrap_exception()
1252 else:
1249 else:
1253 # extract buffers from reply content:
1250 # extract buffers from reply content:
1254 if keys is not None:
1251 if keys is not None:
1255 buffer_lens = [] if 'buffers' in keys else None
1252 buffer_lens = [] if 'buffers' in keys else None
1256 result_buffer_lens = [] if 'result_buffers' in keys else None
1253 result_buffer_lens = [] if 'result_buffers' in keys else None
1257 else:
1254 else:
1258 buffer_lens = []
1255 buffer_lens = []
1259 result_buffer_lens = []
1256 result_buffer_lens = []
1260
1257
1261 for rec in records:
1258 for rec in records:
1262 # buffers may be None, so double check
1259 # buffers may be None, so double check
1263 if buffer_lens is not None:
1260 if buffer_lens is not None:
1264 b = rec.pop('buffers', empty) or empty
1261 b = rec.pop('buffers', empty) or empty
1265 buffer_lens.append(len(b))
1262 buffer_lens.append(len(b))
1266 buffers.extend(b)
1263 buffers.extend(b)
1267 if result_buffer_lens is not None:
1264 if result_buffer_lens is not None:
1268 rb = rec.pop('result_buffers', empty) or empty
1265 rb = rec.pop('result_buffers', empty) or empty
1269 result_buffer_lens.append(len(rb))
1266 result_buffer_lens.append(len(rb))
1270 buffers.extend(rb)
1267 buffers.extend(rb)
1271 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1268 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1272 result_buffer_lens=result_buffer_lens)
1269 result_buffer_lens=result_buffer_lens)
1273
1270
1274 self.session.send(self.query, "db_reply", content=content,
1271 self.session.send(self.query, "db_reply", content=content,
1275 parent=msg, ident=client_id,
1272 parent=msg, ident=client_id,
1276 buffers=buffers)
1273 buffers=buffers)
1277
1274
@@ -1,339 +1,339 b''
1 """A TaskRecord backend using sqlite3"""
1 """A TaskRecord backend using sqlite3"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
3 # Copyright (C) 2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 import json
9 import json
10 import os
10 import os
11 import cPickle as pickle
11 import cPickle as pickle
12 from datetime import datetime
12 from datetime import datetime
13
13
14 import sqlite3
14 import sqlite3
15
15
16 from zmq.eventloop import ioloop
16 from zmq.eventloop import ioloop
17
17
18 from IPython.utils.traitlets import Unicode, Instance, List
18 from IPython.utils.traitlets import Unicode, Instance, List
19 from .dictdb import BaseDB
19 from .dictdb import BaseDB
20 from IPython.parallel.util import ISO8601
20 from IPython.utils.jsonutil import date_default, extract_dates
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # SQLite operators, adapters, and converters
23 # SQLite operators, adapters, and converters
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25
25
26 operators = {
26 operators = {
27 '$lt' : "<",
27 '$lt' : "<",
28 '$gt' : ">",
28 '$gt' : ">",
29 # null is handled weird with ==,!=
29 # null is handled weird with ==,!=
30 '$eq' : "=",
30 '$eq' : "=",
31 '$ne' : "!=",
31 '$ne' : "!=",
32 '$lte': "<=",
32 '$lte': "<=",
33 '$gte': ">=",
33 '$gte': ">=",
34 '$in' : ('=', ' OR '),
34 '$in' : ('=', ' OR '),
35 '$nin': ('!=', ' AND '),
35 '$nin': ('!=', ' AND '),
36 # '$all': None,
36 # '$all': None,
37 # '$mod': None,
37 # '$mod': None,
38 # '$exists' : None
38 # '$exists' : None
39 }
39 }
40 null_operators = {
40 null_operators = {
41 '=' : "IS NULL",
41 '=' : "IS NULL",
42 '!=' : "IS NOT NULL",
42 '!=' : "IS NOT NULL",
43 }
43 }
44
44
45 def _adapt_datetime(dt):
45 def _adapt_datetime(dt):
46 return dt.strftime(ISO8601)
46 return dt.strftime(ISO8601)
47
47
48 def _convert_datetime(ds):
48 def _convert_datetime(ds):
49 if ds is None:
49 if ds is None:
50 return ds
50 return ds
51 else:
51 else:
52 return datetime.strptime(ds, ISO8601)
52 return datetime.strptime(ds, ISO8601)
53
53
54 def _adapt_dict(d):
54 def _adapt_dict(d):
55 return json.dumps(d)
55 return json.dumps(d, default=date_default)
56
56
57 def _convert_dict(ds):
57 def _convert_dict(ds):
58 if ds is None:
58 if ds is None:
59 return ds
59 return ds
60 else:
60 else:
61 return json.loads(ds)
61 return extract_dates(json.loads(ds))
62
62
63 def _adapt_bufs(bufs):
63 def _adapt_bufs(bufs):
64 # this is *horrible*
64 # this is *horrible*
65 # copy buffers into single list and pickle it:
65 # copy buffers into single list and pickle it:
66 if bufs and isinstance(bufs[0], (bytes, buffer)):
66 if bufs and isinstance(bufs[0], (bytes, buffer)):
67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
68 elif bufs:
68 elif bufs:
69 return bufs
69 return bufs
70 else:
70 else:
71 return None
71 return None
72
72
73 def _convert_bufs(bs):
73 def _convert_bufs(bs):
74 if bs is None:
74 if bs is None:
75 return []
75 return []
76 else:
76 else:
77 return pickle.loads(bytes(bs))
77 return pickle.loads(bytes(bs))
78
78
79 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
80 # SQLiteDB class
80 # SQLiteDB class
81 #-----------------------------------------------------------------------------
81 #-----------------------------------------------------------------------------
82
82
83 class SQLiteDB(BaseDB):
83 class SQLiteDB(BaseDB):
84 """SQLite3 TaskRecord backend."""
84 """SQLite3 TaskRecord backend."""
85
85
86 filename = Unicode('tasks.db', config=True,
86 filename = Unicode('tasks.db', config=True,
87 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
87 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
88 location = Unicode('', config=True,
88 location = Unicode('', config=True,
89 help="""The directory containing the sqlite task database. The default
89 help="""The directory containing the sqlite task database. The default
90 is to use the cluster_dir location.""")
90 is to use the cluster_dir location.""")
91 table = Unicode("", config=True,
91 table = Unicode("", config=True,
92 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
92 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
93 a new table will be created with the Hub's IDENT. Specifying the table will result
93 a new table will be created with the Hub's IDENT. Specifying the table will result
94 in tasks from previous sessions being available via Clients' db_query and
94 in tasks from previous sessions being available via Clients' db_query and
95 get_result methods.""")
95 get_result methods.""")
96
96
97 _db = Instance('sqlite3.Connection')
97 _db = Instance('sqlite3.Connection')
98 _keys = List(['msg_id' ,
98 _keys = List(['msg_id' ,
99 'header' ,
99 'header' ,
100 'content',
100 'content',
101 'buffers',
101 'buffers',
102 'submitted',
102 'submitted',
103 'client_uuid' ,
103 'client_uuid' ,
104 'engine_uuid' ,
104 'engine_uuid' ,
105 'started',
105 'started',
106 'completed',
106 'completed',
107 'resubmitted',
107 'resubmitted',
108 'result_header' ,
108 'result_header' ,
109 'result_content' ,
109 'result_content' ,
110 'result_buffers' ,
110 'result_buffers' ,
111 'queue' ,
111 'queue' ,
112 'pyin' ,
112 'pyin' ,
113 'pyout',
113 'pyout',
114 'pyerr',
114 'pyerr',
115 'stdout',
115 'stdout',
116 'stderr',
116 'stderr',
117 ])
117 ])
118
118
119 def __init__(self, **kwargs):
119 def __init__(self, **kwargs):
120 super(SQLiteDB, self).__init__(**kwargs)
120 super(SQLiteDB, self).__init__(**kwargs)
121 if not self.table:
121 if not self.table:
122 # use session, and prefix _, since starting with # is illegal
122 # use session, and prefix _, since starting with # is illegal
123 self.table = '_'+self.session.replace('-','_')
123 self.table = '_'+self.session.replace('-','_')
124 if not self.location:
124 if not self.location:
125 # get current profile
125 # get current profile
126 from IPython.core.newapplication import BaseIPythonApplication
126 from IPython.core.newapplication import BaseIPythonApplication
127 if BaseIPythonApplication.initialized():
127 if BaseIPythonApplication.initialized():
128 app = BaseIPythonApplication.instance()
128 app = BaseIPythonApplication.instance()
129 if app.profile_dir is not None:
129 if app.profile_dir is not None:
130 self.location = app.profile_dir.location
130 self.location = app.profile_dir.location
131 else:
131 else:
132 self.location = u'.'
132 self.location = u'.'
133 else:
133 else:
134 self.location = u'.'
134 self.location = u'.'
135 self._init_db()
135 self._init_db()
136
136
137 # register db commit as 2s periodic callback
137 # register db commit as 2s periodic callback
138 # to prevent clogging pipes
138 # to prevent clogging pipes
139 # assumes we are being run in a zmq ioloop app
139 # assumes we are being run in a zmq ioloop app
140 loop = ioloop.IOLoop.instance()
140 loop = ioloop.IOLoop.instance()
141 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
141 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
142 pc.start()
142 pc.start()
143
143
144 def _defaults(self, keys=None):
144 def _defaults(self, keys=None):
145 """create an empty record"""
145 """create an empty record"""
146 d = {}
146 d = {}
147 keys = self._keys if keys is None else keys
147 keys = self._keys if keys is None else keys
148 for key in keys:
148 for key in keys:
149 d[key] = None
149 d[key] = None
150 return d
150 return d
151
151
152 def _init_db(self):
152 def _init_db(self):
153 """Connect to the database and get new session number."""
153 """Connect to the database and get new session number."""
154 # register adapters
154 # register adapters
155 sqlite3.register_adapter(datetime, _adapt_datetime)
155 sqlite3.register_adapter(datetime, _adapt_datetime)
156 sqlite3.register_converter('datetime', _convert_datetime)
156 sqlite3.register_converter('datetime', _convert_datetime)
157 sqlite3.register_adapter(dict, _adapt_dict)
157 sqlite3.register_adapter(dict, _adapt_dict)
158 sqlite3.register_converter('dict', _convert_dict)
158 sqlite3.register_converter('dict', _convert_dict)
159 sqlite3.register_adapter(list, _adapt_bufs)
159 sqlite3.register_adapter(list, _adapt_bufs)
160 sqlite3.register_converter('bufs', _convert_bufs)
160 sqlite3.register_converter('bufs', _convert_bufs)
161 # connect to the db
161 # connect to the db
162 dbfile = os.path.join(self.location, self.filename)
162 dbfile = os.path.join(self.location, self.filename)
163 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
163 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
164 # isolation_level = None)#,
164 # isolation_level = None)#,
165 cached_statements=64)
165 cached_statements=64)
166 # print dir(self._db)
166 # print dir(self._db)
167
167
168 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
168 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
169 (msg_id text PRIMARY KEY,
169 (msg_id text PRIMARY KEY,
170 header dict text,
170 header dict text,
171 content dict text,
171 content dict text,
172 buffers bufs blob,
172 buffers bufs blob,
173 submitted datetime text,
173 submitted datetime text,
174 client_uuid text,
174 client_uuid text,
175 engine_uuid text,
175 engine_uuid text,
176 started datetime text,
176 started datetime text,
177 completed datetime text,
177 completed datetime text,
178 resubmitted datetime text,
178 resubmitted datetime text,
179 result_header dict text,
179 result_header dict text,
180 result_content dict text,
180 result_content dict text,
181 result_buffers bufs blob,
181 result_buffers bufs blob,
182 queue text,
182 queue text,
183 pyin text,
183 pyin text,
184 pyout text,
184 pyout text,
185 pyerr text,
185 pyerr text,
186 stdout text,
186 stdout text,
187 stderr text)
187 stderr text)
188 """%self.table)
188 """%self.table)
189 self._db.commit()
189 self._db.commit()
190
190
191 def _dict_to_list(self, d):
191 def _dict_to_list(self, d):
192 """turn a mongodb-style record dict into a list."""
192 """turn a mongodb-style record dict into a list."""
193
193
194 return [ d[key] for key in self._keys ]
194 return [ d[key] for key in self._keys ]
195
195
196 def _list_to_dict(self, line, keys=None):
196 def _list_to_dict(self, line, keys=None):
197 """Inverse of dict_to_list"""
197 """Inverse of dict_to_list"""
198 keys = self._keys if keys is None else keys
198 keys = self._keys if keys is None else keys
199 d = self._defaults(keys)
199 d = self._defaults(keys)
200 for key,value in zip(keys, line):
200 for key,value in zip(keys, line):
201 d[key] = value
201 d[key] = value
202
202
203 return d
203 return d
204
204
205 def _render_expression(self, check):
205 def _render_expression(self, check):
206 """Turn a mongodb-style search dict into an SQL query."""
206 """Turn a mongodb-style search dict into an SQL query."""
207 expressions = []
207 expressions = []
208 args = []
208 args = []
209
209
210 skeys = set(check.keys())
210 skeys = set(check.keys())
211 skeys.difference_update(set(self._keys))
211 skeys.difference_update(set(self._keys))
212 skeys.difference_update(set(['buffers', 'result_buffers']))
212 skeys.difference_update(set(['buffers', 'result_buffers']))
213 if skeys:
213 if skeys:
214 raise KeyError("Illegal testing key(s): %s"%skeys)
214 raise KeyError("Illegal testing key(s): %s"%skeys)
215
215
216 for name,sub_check in check.iteritems():
216 for name,sub_check in check.iteritems():
217 if isinstance(sub_check, dict):
217 if isinstance(sub_check, dict):
218 for test,value in sub_check.iteritems():
218 for test,value in sub_check.iteritems():
219 try:
219 try:
220 op = operators[test]
220 op = operators[test]
221 except KeyError:
221 except KeyError:
222 raise KeyError("Unsupported operator: %r"%test)
222 raise KeyError("Unsupported operator: %r"%test)
223 if isinstance(op, tuple):
223 if isinstance(op, tuple):
224 op, join = op
224 op, join = op
225
225
226 if value is None and op in null_operators:
226 if value is None and op in null_operators:
227 expr = "%s %s"%null_operators[op]
227 expr = "%s %s"%null_operators[op]
228 else:
228 else:
229 expr = "%s %s ?"%(name, op)
229 expr = "%s %s ?"%(name, op)
230 if isinstance(value, (tuple,list)):
230 if isinstance(value, (tuple,list)):
231 if op in null_operators and any([v is None for v in value]):
231 if op in null_operators and any([v is None for v in value]):
232 # equality tests don't work with NULL
232 # equality tests don't work with NULL
233 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
233 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
234 expr = '( %s )'%( join.join([expr]*len(value)) )
234 expr = '( %s )'%( join.join([expr]*len(value)) )
235 args.extend(value)
235 args.extend(value)
236 else:
236 else:
237 args.append(value)
237 args.append(value)
238 expressions.append(expr)
238 expressions.append(expr)
239 else:
239 else:
240 # it's an equality check
240 # it's an equality check
241 if sub_check is None:
241 if sub_check is None:
242 expressions.append("%s IS NULL")
242 expressions.append("%s IS NULL")
243 else:
243 else:
244 expressions.append("%s = ?"%name)
244 expressions.append("%s = ?"%name)
245 args.append(sub_check)
245 args.append(sub_check)
246
246
247 expr = " AND ".join(expressions)
247 expr = " AND ".join(expressions)
248 return expr, args
248 return expr, args
249
249
250 def add_record(self, msg_id, rec):
250 def add_record(self, msg_id, rec):
251 """Add a new Task Record, by msg_id."""
251 """Add a new Task Record, by msg_id."""
252 d = self._defaults()
252 d = self._defaults()
253 d.update(rec)
253 d.update(rec)
254 d['msg_id'] = msg_id
254 d['msg_id'] = msg_id
255 line = self._dict_to_list(d)
255 line = self._dict_to_list(d)
256 tups = '(%s)'%(','.join(['?']*len(line)))
256 tups = '(%s)'%(','.join(['?']*len(line)))
257 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
257 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
258 # self._db.commit()
258 # self._db.commit()
259
259
260 def get_record(self, msg_id):
260 def get_record(self, msg_id):
261 """Get a specific Task Record, by msg_id."""
261 """Get a specific Task Record, by msg_id."""
262 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
262 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
263 line = cursor.fetchone()
263 line = cursor.fetchone()
264 if line is None:
264 if line is None:
265 raise KeyError("No such msg: %r"%msg_id)
265 raise KeyError("No such msg: %r"%msg_id)
266 return self._list_to_dict(line)
266 return self._list_to_dict(line)
267
267
268 def update_record(self, msg_id, rec):
268 def update_record(self, msg_id, rec):
269 """Update the data in an existing record."""
269 """Update the data in an existing record."""
270 query = "UPDATE %s SET "%self.table
270 query = "UPDATE %s SET "%self.table
271 sets = []
271 sets = []
272 keys = sorted(rec.keys())
272 keys = sorted(rec.keys())
273 values = []
273 values = []
274 for key in keys:
274 for key in keys:
275 sets.append('%s = ?'%key)
275 sets.append('%s = ?'%key)
276 values.append(rec[key])
276 values.append(rec[key])
277 query += ', '.join(sets)
277 query += ', '.join(sets)
278 query += ' WHERE msg_id == ?'
278 query += ' WHERE msg_id == ?'
279 values.append(msg_id)
279 values.append(msg_id)
280 self._db.execute(query, values)
280 self._db.execute(query, values)
281 # self._db.commit()
281 # self._db.commit()
282
282
283 def drop_record(self, msg_id):
283 def drop_record(self, msg_id):
284 """Remove a record from the DB."""
284 """Remove a record from the DB."""
285 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
285 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
286 # self._db.commit()
286 # self._db.commit()
287
287
288 def drop_matching_records(self, check):
288 def drop_matching_records(self, check):
289 """Remove a record from the DB."""
289 """Remove a record from the DB."""
290 expr,args = self._render_expression(check)
290 expr,args = self._render_expression(check)
291 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
291 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
292 self._db.execute(query,args)
292 self._db.execute(query,args)
293 # self._db.commit()
293 # self._db.commit()
294
294
295 def find_records(self, check, keys=None):
295 def find_records(self, check, keys=None):
296 """Find records matching a query dict, optionally extracting subset of keys.
296 """Find records matching a query dict, optionally extracting subset of keys.
297
297
298 Returns list of matching records.
298 Returns list of matching records.
299
299
300 Parameters
300 Parameters
301 ----------
301 ----------
302
302
303 check: dict
303 check: dict
304 mongodb-style query argument
304 mongodb-style query argument
305 keys: list of strs [optional]
305 keys: list of strs [optional]
306 if specified, the subset of keys to extract. msg_id will *always* be
306 if specified, the subset of keys to extract. msg_id will *always* be
307 included.
307 included.
308 """
308 """
309 if keys:
309 if keys:
310 bad_keys = [ key for key in keys if key not in self._keys ]
310 bad_keys = [ key for key in keys if key not in self._keys ]
311 if bad_keys:
311 if bad_keys:
312 raise KeyError("Bad record key(s): %s"%bad_keys)
312 raise KeyError("Bad record key(s): %s"%bad_keys)
313
313
314 if keys:
314 if keys:
315 # ensure msg_id is present and first:
315 # ensure msg_id is present and first:
316 if 'msg_id' in keys:
316 if 'msg_id' in keys:
317 keys.remove('msg_id')
317 keys.remove('msg_id')
318 keys.insert(0, 'msg_id')
318 keys.insert(0, 'msg_id')
319 req = ', '.join(keys)
319 req = ', '.join(keys)
320 else:
320 else:
321 req = '*'
321 req = '*'
322 expr,args = self._render_expression(check)
322 expr,args = self._render_expression(check)
323 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
323 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
324 cursor = self._db.execute(query, args)
324 cursor = self._db.execute(query, args)
325 matches = cursor.fetchall()
325 matches = cursor.fetchall()
326 records = []
326 records = []
327 for line in matches:
327 for line in matches:
328 rec = self._list_to_dict(line, keys)
328 rec = self._list_to_dict(line, keys)
329 records.append(rec)
329 records.append(rec)
330 return records
330 return records
331
331
332 def get_history(self):
332 def get_history(self):
333 """get all msg_ids, ordered by time submitted."""
333 """get all msg_ids, ordered by time submitted."""
334 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
334 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
335 cursor = self._db.execute(query)
335 cursor = self._db.execute(query)
336 # will be a list of length 1 tuples
336 # will be a list of length 1 tuples
337 return [ tup[0] for tup in cursor.fetchall()]
337 return [ tup[0] for tup in cursor.fetchall()]
338
338
339 __all__ = ['SQLiteDB'] No newline at end of file
339 __all__ = ['SQLiteDB']
@@ -1,165 +1,166 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple engine that talks to a controller over 0MQ.
2 """A simple engine that talks to a controller over 0MQ.
3 it handles registration, etc. and launches a kernel
3 it handles registration, etc. and launches a kernel
4 connected to the Controller's Schedulers.
4 connected to the Controller's Schedulers.
5 """
5 """
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2010-2011 The IPython Development Team
7 # Copyright (C) 2010-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 import sys
15 import sys
16 import time
16 import time
17
17
18 import zmq
18 import zmq
19 from zmq.eventloop import ioloop, zmqstream
19 from zmq.eventloop import ioloop, zmqstream
20
20
21 # internal
21 # internal
22 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
22 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
23 # from IPython.utils.localinterfaces import LOCALHOST
23 # from IPython.utils.localinterfaces import LOCALHOST
24
24
25 from IPython.parallel.controller.heartmonitor import Heart
25 from IPython.parallel.controller.heartmonitor import Heart
26 from IPython.parallel.factory import RegistrationFactory
26 from IPython.parallel.factory import RegistrationFactory
27 from IPython.parallel.streamsession import Message
28 from IPython.parallel.util import disambiguate_url
27 from IPython.parallel.util import disambiguate_url
29
28
29 from IPython.zmq.session import Message
30
30 from .streamkernel import Kernel
31 from .streamkernel import Kernel
31
32
32 class EngineFactory(RegistrationFactory):
33 class EngineFactory(RegistrationFactory):
33 """IPython engine"""
34 """IPython engine"""
34
35
35 # configurables:
36 # configurables:
36 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
37 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
37 help="""The OutStream for handling stdout/err.
38 help="""The OutStream for handling stdout/err.
38 Typically 'IPython.zmq.iostream.OutStream'""")
39 Typically 'IPython.zmq.iostream.OutStream'""")
39 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
40 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
40 help="""The class for handling displayhook.
41 help="""The class for handling displayhook.
41 Typically 'IPython.zmq.displayhook.DisplayHook'""")
42 Typically 'IPython.zmq.displayhook.DisplayHook'""")
42 location=Unicode(config=True,
43 location=Unicode(config=True,
43 help="""The location (an IP address) of the controller. This is
44 help="""The location (an IP address) of the controller. This is
44 used for disambiguating URLs, to determine whether
45 used for disambiguating URLs, to determine whether
45 loopback should be used to connect or the public address.""")
46 loopback should be used to connect or the public address.""")
46 timeout=CFloat(2,config=True,
47 timeout=CFloat(2,config=True,
47 help="""The time (in seconds) to wait for the Controller to respond
48 help="""The time (in seconds) to wait for the Controller to respond
48 to registration requests before giving up.""")
49 to registration requests before giving up.""")
49
50
50 # not configurable:
51 # not configurable:
51 user_ns=Dict()
52 user_ns=Dict()
52 id=Int(allow_none=True)
53 id=Int(allow_none=True)
53 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
54 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
54 kernel=Instance(Kernel)
55 kernel=Instance(Kernel)
55
56
56
57
57 def __init__(self, **kwargs):
58 def __init__(self, **kwargs):
58 super(EngineFactory, self).__init__(**kwargs)
59 super(EngineFactory, self).__init__(**kwargs)
59 self.ident = self.session.session
60 self.ident = self.session.session
60 ctx = self.context
61 ctx = self.context
61
62
62 reg = ctx.socket(zmq.XREQ)
63 reg = ctx.socket(zmq.XREQ)
63 reg.setsockopt(zmq.IDENTITY, self.ident)
64 reg.setsockopt(zmq.IDENTITY, self.ident)
64 reg.connect(self.url)
65 reg.connect(self.url)
65 self.registrar = zmqstream.ZMQStream(reg, self.loop)
66 self.registrar = zmqstream.ZMQStream(reg, self.loop)
66
67
67 def register(self):
68 def register(self):
68 """send the registration_request"""
69 """send the registration_request"""
69
70
70 self.log.info("registering")
71 self.log.info("registering")
71 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
72 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
72 self.registrar.on_recv(self.complete_registration)
73 self.registrar.on_recv(self.complete_registration)
73 # print (self.session.key)
74 # print (self.session.key)
74 self.session.send(self.registrar, "registration_request",content=content)
75 self.session.send(self.registrar, "registration_request",content=content)
75
76
76 def complete_registration(self, msg):
77 def complete_registration(self, msg):
77 # print msg
78 # print msg
78 self._abort_dc.stop()
79 self._abort_dc.stop()
79 ctx = self.context
80 ctx = self.context
80 loop = self.loop
81 loop = self.loop
81 identity = self.ident
82 identity = self.ident
82
83
83 idents,msg = self.session.feed_identities(msg)
84 idents,msg = self.session.feed_identities(msg)
84 msg = Message(self.session.unpack_message(msg))
85 msg = Message(self.session.unpack_message(msg))
85
86
86 if msg.content.status == 'ok':
87 if msg.content.status == 'ok':
87 self.id = int(msg.content.id)
88 self.id = int(msg.content.id)
88
89
89 # create Shell Streams (MUX, Task, etc.):
90 # create Shell Streams (MUX, Task, etc.):
90 queue_addr = msg.content.mux
91 queue_addr = msg.content.mux
91 shell_addrs = [ str(queue_addr) ]
92 shell_addrs = [ str(queue_addr) ]
92 task_addr = msg.content.task
93 task_addr = msg.content.task
93 if task_addr:
94 if task_addr:
94 shell_addrs.append(str(task_addr))
95 shell_addrs.append(str(task_addr))
95
96
96 # Uncomment this to go back to two-socket model
97 # Uncomment this to go back to two-socket model
97 # shell_streams = []
98 # shell_streams = []
98 # for addr in shell_addrs:
99 # for addr in shell_addrs:
99 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
100 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
100 # stream.setsockopt(zmq.IDENTITY, identity)
101 # stream.setsockopt(zmq.IDENTITY, identity)
101 # stream.connect(disambiguate_url(addr, self.location))
102 # stream.connect(disambiguate_url(addr, self.location))
102 # shell_streams.append(stream)
103 # shell_streams.append(stream)
103
104
104 # Now use only one shell stream for mux and tasks
105 # Now use only one shell stream for mux and tasks
105 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
106 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
106 stream.setsockopt(zmq.IDENTITY, identity)
107 stream.setsockopt(zmq.IDENTITY, identity)
107 shell_streams = [stream]
108 shell_streams = [stream]
108 for addr in shell_addrs:
109 for addr in shell_addrs:
109 stream.connect(disambiguate_url(addr, self.location))
110 stream.connect(disambiguate_url(addr, self.location))
110 # end single stream-socket
111 # end single stream-socket
111
112
112 # control stream:
113 # control stream:
113 control_addr = str(msg.content.control)
114 control_addr = str(msg.content.control)
114 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 control_stream.setsockopt(zmq.IDENTITY, identity)
116 control_stream.setsockopt(zmq.IDENTITY, identity)
116 control_stream.connect(disambiguate_url(control_addr, self.location))
117 control_stream.connect(disambiguate_url(control_addr, self.location))
117
118
118 # create iopub stream:
119 # create iopub stream:
119 iopub_addr = msg.content.iopub
120 iopub_addr = msg.content.iopub
120 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
121 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
121 iopub_stream.setsockopt(zmq.IDENTITY, identity)
122 iopub_stream.setsockopt(zmq.IDENTITY, identity)
122 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
123 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
123
124
124 # launch heartbeat
125 # launch heartbeat
125 hb_addrs = msg.content.heartbeat
126 hb_addrs = msg.content.heartbeat
126 # print (hb_addrs)
127 # print (hb_addrs)
127
128
128 # # Redirect input streams and set a display hook.
129 # # Redirect input streams and set a display hook.
129 if self.out_stream_factory:
130 if self.out_stream_factory:
130 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
131 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
131 sys.stdout.topic = 'engine.%i.stdout'%self.id
132 sys.stdout.topic = 'engine.%i.stdout'%self.id
132 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
133 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
133 sys.stderr.topic = 'engine.%i.stderr'%self.id
134 sys.stderr.topic = 'engine.%i.stderr'%self.id
134 if self.display_hook_factory:
135 if self.display_hook_factory:
135 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
136 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
136 sys.displayhook.topic = 'engine.%i.pyout'%self.id
137 sys.displayhook.topic = 'engine.%i.pyout'%self.id
137
138
138 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
139 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
139 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
140 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
140 loop=loop, user_ns = self.user_ns, log=self.log)
141 loop=loop, user_ns = self.user_ns, log=self.log)
141 self.kernel.start()
142 self.kernel.start()
142 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
143 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
143 heart = Heart(*map(str, hb_addrs), heart_id=identity)
144 heart = Heart(*map(str, hb_addrs), heart_id=identity)
144 heart.start()
145 heart.start()
145
146
146
147
147 else:
148 else:
148 self.log.fatal("Registration Failed: %s"%msg)
149 self.log.fatal("Registration Failed: %s"%msg)
149 raise Exception("Registration Failed: %s"%msg)
150 raise Exception("Registration Failed: %s"%msg)
150
151
151 self.log.info("Completed registration with id %i"%self.id)
152 self.log.info("Completed registration with id %i"%self.id)
152
153
153
154
154 def abort(self):
155 def abort(self):
155 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
156 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
156 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
157 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
157 time.sleep(1)
158 time.sleep(1)
158 sys.exit(255)
159 sys.exit(255)
159
160
160 def start(self):
161 def start(self):
161 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
162 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
162 dc.start()
163 dc.start()
163 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
164 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
164 self._abort_dc.start()
165 self._abort_dc.start()
165
166
@@ -1,225 +1,225 b''
1 """KernelStarter class that intercepts Control Queue messages, and handles process management."""
1 """KernelStarter class that intercepts Control Queue messages, and handles process management."""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
3 # Copyright (C) 2010-2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 from zmq.eventloop import ioloop
9 from zmq.eventloop import ioloop
10
10
11 from IPython.parallel.streamsession import StreamSession
11 from IPython.zmq.session import Session
12
12
13 class KernelStarter(object):
13 class KernelStarter(object):
14 """Object for resetting/killing the Kernel."""
14 """Object for resetting/killing the Kernel."""
15
15
16
16
17 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
17 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
18 self.session = session
18 self.session = session
19 self.upstream = upstream
19 self.upstream = upstream
20 self.downstream = downstream
20 self.downstream = downstream
21 self.kernel_args = kernel_args
21 self.kernel_args = kernel_args
22 self.kernel_kwargs = kernel_kwargs
22 self.kernel_kwargs = kernel_kwargs
23 self.handlers = {}
23 self.handlers = {}
24 for method in 'shutdown_request shutdown_reply'.split():
24 for method in 'shutdown_request shutdown_reply'.split():
25 self.handlers[method] = getattr(self, method)
25 self.handlers[method] = getattr(self, method)
26
26
27 def start(self):
27 def start(self):
28 self.upstream.on_recv(self.dispatch_request)
28 self.upstream.on_recv(self.dispatch_request)
29 self.downstream.on_recv(self.dispatch_reply)
29 self.downstream.on_recv(self.dispatch_reply)
30
30
31 #--------------------------------------------------------------------------
31 #--------------------------------------------------------------------------
32 # Dispatch methods
32 # Dispatch methods
33 #--------------------------------------------------------------------------
33 #--------------------------------------------------------------------------
34
34
35 def dispatch_request(self, raw_msg):
35 def dispatch_request(self, raw_msg):
36 idents, msg = self.session.feed_identities()
36 idents, msg = self.session.feed_identities()
37 try:
37 try:
38 msg = self.session.unpack_message(msg, content=False)
38 msg = self.session.unpack_message(msg, content=False)
39 except:
39 except:
40 print ("bad msg: %s"%msg)
40 print ("bad msg: %s"%msg)
41
41
42 msgtype = msg['msg_type']
42 msgtype = msg['msg_type']
43 handler = self.handlers.get(msgtype, None)
43 handler = self.handlers.get(msgtype, None)
44 if handler is None:
44 if handler is None:
45 self.downstream.send_multipart(raw_msg, copy=False)
45 self.downstream.send_multipart(raw_msg, copy=False)
46 else:
46 else:
47 handler(msg)
47 handler(msg)
48
48
49 def dispatch_reply(self, raw_msg):
49 def dispatch_reply(self, raw_msg):
50 idents, msg = self.session.feed_identities()
50 idents, msg = self.session.feed_identities()
51 try:
51 try:
52 msg = self.session.unpack_message(msg, content=False)
52 msg = self.session.unpack_message(msg, content=False)
53 except:
53 except:
54 print ("bad msg: %s"%msg)
54 print ("bad msg: %s"%msg)
55
55
56 msgtype = msg['msg_type']
56 msgtype = msg['msg_type']
57 handler = self.handlers.get(msgtype, None)
57 handler = self.handlers.get(msgtype, None)
58 if handler is None:
58 if handler is None:
59 self.upstream.send_multipart(raw_msg, copy=False)
59 self.upstream.send_multipart(raw_msg, copy=False)
60 else:
60 else:
61 handler(msg)
61 handler(msg)
62
62
63 #--------------------------------------------------------------------------
63 #--------------------------------------------------------------------------
64 # Handlers
64 # Handlers
65 #--------------------------------------------------------------------------
65 #--------------------------------------------------------------------------
66
66
67 def shutdown_request(self, msg):
67 def shutdown_request(self, msg):
68 """"""
68 """"""
69 self.downstream.send_multipart(msg)
69 self.downstream.send_multipart(msg)
70
70
71 #--------------------------------------------------------------------------
71 #--------------------------------------------------------------------------
72 # Kernel process management methods, from KernelManager:
72 # Kernel process management methods, from KernelManager:
73 #--------------------------------------------------------------------------
73 #--------------------------------------------------------------------------
74
74
75 def _check_local(addr):
75 def _check_local(addr):
76 if isinstance(addr, tuple):
76 if isinstance(addr, tuple):
77 addr = addr[0]
77 addr = addr[0]
78 return addr in LOCAL_IPS
78 return addr in LOCAL_IPS
79
79
80 def start_kernel(self, **kw):
80 def start_kernel(self, **kw):
81 """Starts a kernel process and configures the manager to use it.
81 """Starts a kernel process and configures the manager to use it.
82
82
83 If random ports (port=0) are being used, this method must be called
83 If random ports (port=0) are being used, this method must be called
84 before the channels are created.
84 before the channels are created.
85
85
86 Parameters:
86 Parameters:
87 -----------
87 -----------
88 ipython : bool, optional (default True)
88 ipython : bool, optional (default True)
89 Whether to use an IPython kernel instead of a plain Python kernel.
89 Whether to use an IPython kernel instead of a plain Python kernel.
90 """
90 """
91 self.kernel = Process(target=make_kernel, args=self.kernel_args,
91 self.kernel = Process(target=make_kernel, args=self.kernel_args,
92 kwargs=self.kernel_kwargs)
92 kwargs=self.kernel_kwargs)
93
93
94 def shutdown_kernel(self, restart=False):
94 def shutdown_kernel(self, restart=False):
95 """ Attempts to the stop the kernel process cleanly. If the kernel
95 """ Attempts to the stop the kernel process cleanly. If the kernel
96 cannot be stopped, it is killed, if possible.
96 cannot be stopped, it is killed, if possible.
97 """
97 """
98 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
98 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
99 if sys.platform == 'win32':
99 if sys.platform == 'win32':
100 self.kill_kernel()
100 self.kill_kernel()
101 return
101 return
102
102
103 # Don't send any additional kernel kill messages immediately, to give
103 # Don't send any additional kernel kill messages immediately, to give
104 # the kernel a chance to properly execute shutdown actions. Wait for at
104 # the kernel a chance to properly execute shutdown actions. Wait for at
105 # most 1s, checking every 0.1s.
105 # most 1s, checking every 0.1s.
106 self.xreq_channel.shutdown(restart=restart)
106 self.xreq_channel.shutdown(restart=restart)
107 for i in range(10):
107 for i in range(10):
108 if self.is_alive:
108 if self.is_alive:
109 time.sleep(0.1)
109 time.sleep(0.1)
110 else:
110 else:
111 break
111 break
112 else:
112 else:
113 # OK, we've waited long enough.
113 # OK, we've waited long enough.
114 if self.has_kernel:
114 if self.has_kernel:
115 self.kill_kernel()
115 self.kill_kernel()
116
116
117 def restart_kernel(self, now=False):
117 def restart_kernel(self, now=False):
118 """Restarts a kernel with the same arguments that were used to launch
118 """Restarts a kernel with the same arguments that were used to launch
119 it. If the old kernel was launched with random ports, the same ports
119 it. If the old kernel was launched with random ports, the same ports
120 will be used for the new kernel.
120 will be used for the new kernel.
121
121
122 Parameters
122 Parameters
123 ----------
123 ----------
124 now : bool, optional
124 now : bool, optional
125 If True, the kernel is forcefully restarted *immediately*, without
125 If True, the kernel is forcefully restarted *immediately*, without
126 having a chance to do any cleanup action. Otherwise the kernel is
126 having a chance to do any cleanup action. Otherwise the kernel is
127 given 1s to clean up before a forceful restart is issued.
127 given 1s to clean up before a forceful restart is issued.
128
128
129 In all cases the kernel is restarted, the only difference is whether
129 In all cases the kernel is restarted, the only difference is whether
130 it is given a chance to perform a clean shutdown or not.
130 it is given a chance to perform a clean shutdown or not.
131 """
131 """
132 if self._launch_args is None:
132 if self._launch_args is None:
133 raise RuntimeError("Cannot restart the kernel. "
133 raise RuntimeError("Cannot restart the kernel. "
134 "No previous call to 'start_kernel'.")
134 "No previous call to 'start_kernel'.")
135 else:
135 else:
136 if self.has_kernel:
136 if self.has_kernel:
137 if now:
137 if now:
138 self.kill_kernel()
138 self.kill_kernel()
139 else:
139 else:
140 self.shutdown_kernel(restart=True)
140 self.shutdown_kernel(restart=True)
141 self.start_kernel(**self._launch_args)
141 self.start_kernel(**self._launch_args)
142
142
143 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
143 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
144 # unless there is some delay here.
144 # unless there is some delay here.
145 if sys.platform == 'win32':
145 if sys.platform == 'win32':
146 time.sleep(0.2)
146 time.sleep(0.2)
147
147
148 @property
148 @property
149 def has_kernel(self):
149 def has_kernel(self):
150 """Returns whether a kernel process has been specified for the kernel
150 """Returns whether a kernel process has been specified for the kernel
151 manager.
151 manager.
152 """
152 """
153 return self.kernel is not None
153 return self.kernel is not None
154
154
155 def kill_kernel(self):
155 def kill_kernel(self):
156 """ Kill the running kernel. """
156 """ Kill the running kernel. """
157 if self.has_kernel:
157 if self.has_kernel:
158 # Pause the heart beat channel if it exists.
158 # Pause the heart beat channel if it exists.
159 if self._hb_channel is not None:
159 if self._hb_channel is not None:
160 self._hb_channel.pause()
160 self._hb_channel.pause()
161
161
162 # Attempt to kill the kernel.
162 # Attempt to kill the kernel.
163 try:
163 try:
164 self.kernel.kill()
164 self.kernel.kill()
165 except OSError, e:
165 except OSError, e:
166 # In Windows, we will get an Access Denied error if the process
166 # In Windows, we will get an Access Denied error if the process
167 # has already terminated. Ignore it.
167 # has already terminated. Ignore it.
168 if not (sys.platform == 'win32' and e.winerror == 5):
168 if not (sys.platform == 'win32' and e.winerror == 5):
169 raise
169 raise
170 self.kernel = None
170 self.kernel = None
171 else:
171 else:
172 raise RuntimeError("Cannot kill kernel. No kernel is running!")
172 raise RuntimeError("Cannot kill kernel. No kernel is running!")
173
173
174 def interrupt_kernel(self):
174 def interrupt_kernel(self):
175 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
175 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
176 well supported on all platforms.
176 well supported on all platforms.
177 """
177 """
178 if self.has_kernel:
178 if self.has_kernel:
179 if sys.platform == 'win32':
179 if sys.platform == 'win32':
180 from parentpoller import ParentPollerWindows as Poller
180 from parentpoller import ParentPollerWindows as Poller
181 Poller.send_interrupt(self.kernel.win32_interrupt_event)
181 Poller.send_interrupt(self.kernel.win32_interrupt_event)
182 else:
182 else:
183 self.kernel.send_signal(signal.SIGINT)
183 self.kernel.send_signal(signal.SIGINT)
184 else:
184 else:
185 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
185 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
186
186
187 def signal_kernel(self, signum):
187 def signal_kernel(self, signum):
188 """ Sends a signal to the kernel. Note that since only SIGTERM is
188 """ Sends a signal to the kernel. Note that since only SIGTERM is
189 supported on Windows, this function is only useful on Unix systems.
189 supported on Windows, this function is only useful on Unix systems.
190 """
190 """
191 if self.has_kernel:
191 if self.has_kernel:
192 self.kernel.send_signal(signum)
192 self.kernel.send_signal(signum)
193 else:
193 else:
194 raise RuntimeError("Cannot signal kernel. No kernel is running!")
194 raise RuntimeError("Cannot signal kernel. No kernel is running!")
195
195
196 @property
196 @property
197 def is_alive(self):
197 def is_alive(self):
198 """Is the kernel process still running?"""
198 """Is the kernel process still running?"""
199 # FIXME: not using a heartbeat means this method is broken for any
199 # FIXME: not using a heartbeat means this method is broken for any
200 # remote kernel, it's only capable of handling local kernels.
200 # remote kernel, it's only capable of handling local kernels.
201 if self.has_kernel:
201 if self.has_kernel:
202 if self.kernel.poll() is None:
202 if self.kernel.poll() is None:
203 return True
203 return True
204 else:
204 else:
205 return False
205 return False
206 else:
206 else:
207 # We didn't start the kernel with this KernelManager so we don't
207 # We didn't start the kernel with this KernelManager so we don't
208 # know if it is running. We should use a heartbeat for this case.
208 # know if it is running. We should use a heartbeat for this case.
209 return True
209 return True
210
210
211
211
212 def make_starter(up_addr, down_addr, *args, **kwargs):
212 def make_starter(up_addr, down_addr, *args, **kwargs):
213 """entry point function for launching a kernelstarter in a subprocess"""
213 """entry point function for launching a kernelstarter in a subprocess"""
214 loop = ioloop.IOLoop.instance()
214 loop = ioloop.IOLoop.instance()
215 ctx = zmq.Context()
215 ctx = zmq.Context()
216 session = StreamSession()
216 session = Session()
217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
218 upstream.connect(up_addr)
218 upstream.connect(up_addr)
219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
220 downstream.connect(down_addr)
220 downstream.connect(down_addr)
221
221
222 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
222 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
223 starter.start()
223 starter.start()
224 loop.start()
224 loop.start()
225 No newline at end of file
225
@@ -1,433 +1,434 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 Kernel adapted from kernel.py to use ZMQ Streams
3 Kernel adapted from kernel.py to use ZMQ Streams
4 """
4 """
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (C) 2010-2011 The IPython Development Team
6 # Copyright (C) 2010-2011 The IPython Development Team
7 #
7 #
8 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
9 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Imports
13 # Imports
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 # Standard library imports.
16 # Standard library imports.
17 from __future__ import print_function
17 from __future__ import print_function
18
18
19 import sys
19 import sys
20 import time
20 import time
21
21
22 from code import CommandCompiler
22 from code import CommandCompiler
23 from datetime import datetime
23 from datetime import datetime
24 from pprint import pprint
24 from pprint import pprint
25
25
26 # System library imports.
26 # System library imports.
27 import zmq
27 import zmq
28 from zmq.eventloop import ioloop, zmqstream
28 from zmq.eventloop import ioloop, zmqstream
29
29
30 # Local imports.
30 # Local imports.
31 from IPython.utils.jsonutil import ISO8601
31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 from IPython.zmq.completer import KernelCompleter
33 from IPython.zmq.completer import KernelCompleter
33
34
34 from IPython.parallel.error import wrap_exception
35 from IPython.parallel.error import wrap_exception
35 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601
37 from IPython.parallel.util import serialize_object, unpack_apply_message
37
38
38 def printer(*args):
39 def printer(*args):
39 pprint(args, stream=sys.__stdout__)
40 pprint(args, stream=sys.__stdout__)
40
41
41
42
42 class _Passer(zmqstream.ZMQStream):
43 class _Passer(zmqstream.ZMQStream):
43 """Empty class that implements `send()` that does nothing.
44 """Empty class that implements `send()` that does nothing.
44
45
45 Subclass ZMQStream for StreamSession typechecking
46 Subclass ZMQStream for Session typechecking
46
47
47 """
48 """
48 def __init__(self, *args, **kwargs):
49 def __init__(self, *args, **kwargs):
49 pass
50 pass
50
51
51 def send(self, *args, **kwargs):
52 def send(self, *args, **kwargs):
52 pass
53 pass
53 send_multipart = send
54 send_multipart = send
54
55
55
56
56 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
57 # Main kernel class
58 # Main kernel class
58 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
59
60
60 class Kernel(SessionFactory):
61 class Kernel(SessionFactory):
61
62
62 #---------------------------------------------------------------------------
63 #---------------------------------------------------------------------------
63 # Kernel interface
64 # Kernel interface
64 #---------------------------------------------------------------------------
65 #---------------------------------------------------------------------------
65
66
66 # kwargs:
67 # kwargs:
67 exec_lines = List(Unicode, config=True,
68 exec_lines = List(Unicode, config=True,
68 help="List of lines to execute")
69 help="List of lines to execute")
69
70
70 int_id = Int(-1)
71 int_id = Int(-1)
71 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
72 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
72
73
73 control_stream = Instance(zmqstream.ZMQStream)
74 control_stream = Instance(zmqstream.ZMQStream)
74 task_stream = Instance(zmqstream.ZMQStream)
75 task_stream = Instance(zmqstream.ZMQStream)
75 iopub_stream = Instance(zmqstream.ZMQStream)
76 iopub_stream = Instance(zmqstream.ZMQStream)
76 client = Instance('IPython.parallel.Client')
77 client = Instance('IPython.parallel.Client')
77
78
78 # internals
79 # internals
79 shell_streams = List()
80 shell_streams = List()
80 compiler = Instance(CommandCompiler, (), {})
81 compiler = Instance(CommandCompiler, (), {})
81 completer = Instance(KernelCompleter)
82 completer = Instance(KernelCompleter)
82
83
83 aborted = Set()
84 aborted = Set()
84 shell_handlers = Dict()
85 shell_handlers = Dict()
85 control_handlers = Dict()
86 control_handlers = Dict()
86
87
87 def _set_prefix(self):
88 def _set_prefix(self):
88 self.prefix = "engine.%s"%self.int_id
89 self.prefix = "engine.%s"%self.int_id
89
90
90 def _connect_completer(self):
91 def _connect_completer(self):
91 self.completer = KernelCompleter(self.user_ns)
92 self.completer = KernelCompleter(self.user_ns)
92
93
93 def __init__(self, **kwargs):
94 def __init__(self, **kwargs):
94 super(Kernel, self).__init__(**kwargs)
95 super(Kernel, self).__init__(**kwargs)
95 self._set_prefix()
96 self._set_prefix()
96 self._connect_completer()
97 self._connect_completer()
97
98
98 self.on_trait_change(self._set_prefix, 'id')
99 self.on_trait_change(self._set_prefix, 'id')
99 self.on_trait_change(self._connect_completer, 'user_ns')
100 self.on_trait_change(self._connect_completer, 'user_ns')
100
101
101 # Build dict of handlers for message types
102 # Build dict of handlers for message types
102 for msg_type in ['execute_request', 'complete_request', 'apply_request',
103 for msg_type in ['execute_request', 'complete_request', 'apply_request',
103 'clear_request']:
104 'clear_request']:
104 self.shell_handlers[msg_type] = getattr(self, msg_type)
105 self.shell_handlers[msg_type] = getattr(self, msg_type)
105
106
106 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
107 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
107 self.control_handlers[msg_type] = getattr(self, msg_type)
108 self.control_handlers[msg_type] = getattr(self, msg_type)
108
109
109 self._initial_exec_lines()
110 self._initial_exec_lines()
110
111
111 def _wrap_exception(self, method=None):
112 def _wrap_exception(self, method=None):
112 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
113 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
113 content=wrap_exception(e_info)
114 content=wrap_exception(e_info)
114 return content
115 return content
115
116
116 def _initial_exec_lines(self):
117 def _initial_exec_lines(self):
117 s = _Passer()
118 s = _Passer()
118 content = dict(silent=True, user_variable=[],user_expressions=[])
119 content = dict(silent=True, user_variable=[],user_expressions=[])
119 for line in self.exec_lines:
120 for line in self.exec_lines:
120 self.log.debug("executing initialization: %s"%line)
121 self.log.debug("executing initialization: %s"%line)
121 content.update({'code':line})
122 content.update({'code':line})
122 msg = self.session.msg('execute_request', content)
123 msg = self.session.msg('execute_request', content)
123 self.execute_request(s, [], msg)
124 self.execute_request(s, [], msg)
124
125
125
126
126 #-------------------- control handlers -----------------------------
127 #-------------------- control handlers -----------------------------
127 def abort_queues(self):
128 def abort_queues(self):
128 for stream in self.shell_streams:
129 for stream in self.shell_streams:
129 if stream:
130 if stream:
130 self.abort_queue(stream)
131 self.abort_queue(stream)
131
132
132 def abort_queue(self, stream):
133 def abort_queue(self, stream):
133 while True:
134 while True:
134 try:
135 try:
135 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
136 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
136 except zmq.ZMQError as e:
137 except zmq.ZMQError as e:
137 if e.errno == zmq.EAGAIN:
138 if e.errno == zmq.EAGAIN:
138 break
139 break
139 else:
140 else:
140 return
141 return
141 else:
142 else:
142 if msg is None:
143 if msg is None:
143 return
144 return
144 else:
145 else:
145 idents,msg = msg
146 idents,msg = msg
146
147
147 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
148 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
148 # msg = self.reply_socket.recv_json()
149 # msg = self.reply_socket.recv_json()
149 self.log.info("Aborting:")
150 self.log.info("Aborting:")
150 self.log.info(str(msg))
151 self.log.info(str(msg))
151 msg_type = msg['msg_type']
152 msg_type = msg['msg_type']
152 reply_type = msg_type.split('_')[0] + '_reply'
153 reply_type = msg_type.split('_')[0] + '_reply'
153 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
154 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
154 # self.reply_socket.send(ident,zmq.SNDMORE)
155 # self.reply_socket.send(ident,zmq.SNDMORE)
155 # self.reply_socket.send_json(reply_msg)
156 # self.reply_socket.send_json(reply_msg)
156 reply_msg = self.session.send(stream, reply_type,
157 reply_msg = self.session.send(stream, reply_type,
157 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
158 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
158 self.log.debug(str(reply_msg))
159 self.log.debug(str(reply_msg))
159 # We need to wait a bit for requests to come in. This can probably
160 # We need to wait a bit for requests to come in. This can probably
160 # be set shorter for true asynchronous clients.
161 # be set shorter for true asynchronous clients.
161 time.sleep(0.05)
162 time.sleep(0.05)
162
163
163 def abort_request(self, stream, ident, parent):
164 def abort_request(self, stream, ident, parent):
164 """abort a specifig msg by id"""
165 """abort a specifig msg by id"""
165 msg_ids = parent['content'].get('msg_ids', None)
166 msg_ids = parent['content'].get('msg_ids', None)
166 if isinstance(msg_ids, basestring):
167 if isinstance(msg_ids, basestring):
167 msg_ids = [msg_ids]
168 msg_ids = [msg_ids]
168 if not msg_ids:
169 if not msg_ids:
169 self.abort_queues()
170 self.abort_queues()
170 for mid in msg_ids:
171 for mid in msg_ids:
171 self.aborted.add(str(mid))
172 self.aborted.add(str(mid))
172
173
173 content = dict(status='ok')
174 content = dict(status='ok')
174 reply_msg = self.session.send(stream, 'abort_reply', content=content,
175 reply_msg = self.session.send(stream, 'abort_reply', content=content,
175 parent=parent, ident=ident)
176 parent=parent, ident=ident)
176 self.log.debug(str(reply_msg))
177 self.log.debug(str(reply_msg))
177
178
178 def shutdown_request(self, stream, ident, parent):
179 def shutdown_request(self, stream, ident, parent):
179 """kill ourself. This should really be handled in an external process"""
180 """kill ourself. This should really be handled in an external process"""
180 try:
181 try:
181 self.abort_queues()
182 self.abort_queues()
182 except:
183 except:
183 content = self._wrap_exception('shutdown')
184 content = self._wrap_exception('shutdown')
184 else:
185 else:
185 content = dict(parent['content'])
186 content = dict(parent['content'])
186 content['status'] = 'ok'
187 content['status'] = 'ok'
187 msg = self.session.send(stream, 'shutdown_reply',
188 msg = self.session.send(stream, 'shutdown_reply',
188 content=content, parent=parent, ident=ident)
189 content=content, parent=parent, ident=ident)
189 self.log.debug(str(msg))
190 self.log.debug(str(msg))
190 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
191 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
191 dc.start()
192 dc.start()
192
193
193 def dispatch_control(self, msg):
194 def dispatch_control(self, msg):
194 idents,msg = self.session.feed_identities(msg, copy=False)
195 idents,msg = self.session.feed_identities(msg, copy=False)
195 try:
196 try:
196 msg = self.session.unpack_message(msg, content=True, copy=False)
197 msg = self.session.unpack_message(msg, content=True, copy=False)
197 except:
198 except:
198 self.log.error("Invalid Message", exc_info=True)
199 self.log.error("Invalid Message", exc_info=True)
199 return
200 return
200
201
201 header = msg['header']
202 header = msg['header']
202 msg_id = header['msg_id']
203 msg_id = header['msg_id']
203
204
204 handler = self.control_handlers.get(msg['msg_type'], None)
205 handler = self.control_handlers.get(msg['msg_type'], None)
205 if handler is None:
206 if handler is None:
206 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
207 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
207 else:
208 else:
208 handler(self.control_stream, idents, msg)
209 handler(self.control_stream, idents, msg)
209
210
210
211
211 #-------------------- queue helpers ------------------------------
212 #-------------------- queue helpers ------------------------------
212
213
213 def check_dependencies(self, dependencies):
214 def check_dependencies(self, dependencies):
214 if not dependencies:
215 if not dependencies:
215 return True
216 return True
216 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
217 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
217 anyorall = dependencies[0]
218 anyorall = dependencies[0]
218 dependencies = dependencies[1]
219 dependencies = dependencies[1]
219 else:
220 else:
220 anyorall = 'all'
221 anyorall = 'all'
221 results = self.client.get_results(dependencies,status_only=True)
222 results = self.client.get_results(dependencies,status_only=True)
222 if results['status'] != 'ok':
223 if results['status'] != 'ok':
223 return False
224 return False
224
225
225 if anyorall == 'any':
226 if anyorall == 'any':
226 if not results['completed']:
227 if not results['completed']:
227 return False
228 return False
228 else:
229 else:
229 if results['pending']:
230 if results['pending']:
230 return False
231 return False
231
232
232 return True
233 return True
233
234
234 def check_aborted(self, msg_id):
235 def check_aborted(self, msg_id):
235 return msg_id in self.aborted
236 return msg_id in self.aborted
236
237
237 #-------------------- queue handlers -----------------------------
238 #-------------------- queue handlers -----------------------------
238
239
239 def clear_request(self, stream, idents, parent):
240 def clear_request(self, stream, idents, parent):
240 """Clear our namespace."""
241 """Clear our namespace."""
241 self.user_ns = {}
242 self.user_ns = {}
242 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
243 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
243 content = dict(status='ok'))
244 content = dict(status='ok'))
244 self._initial_exec_lines()
245 self._initial_exec_lines()
245
246
246 def execute_request(self, stream, ident, parent):
247 def execute_request(self, stream, ident, parent):
247 self.log.debug('execute request %s'%parent)
248 self.log.debug('execute request %s'%parent)
248 try:
249 try:
249 code = parent[u'content'][u'code']
250 code = parent[u'content'][u'code']
250 except:
251 except:
251 self.log.error("Got bad msg: %s"%parent, exc_info=True)
252 self.log.error("Got bad msg: %s"%parent, exc_info=True)
252 return
253 return
253 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
254 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
254 ident='%s.pyin'%self.prefix)
255 ident='%s.pyin'%self.prefix)
255 started = datetime.now().strftime(ISO8601)
256 started = datetime.now().strftime(ISO8601)
256 try:
257 try:
257 comp_code = self.compiler(code, '<zmq-kernel>')
258 comp_code = self.compiler(code, '<zmq-kernel>')
258 # allow for not overriding displayhook
259 # allow for not overriding displayhook
259 if hasattr(sys.displayhook, 'set_parent'):
260 if hasattr(sys.displayhook, 'set_parent'):
260 sys.displayhook.set_parent(parent)
261 sys.displayhook.set_parent(parent)
261 sys.stdout.set_parent(parent)
262 sys.stdout.set_parent(parent)
262 sys.stderr.set_parent(parent)
263 sys.stderr.set_parent(parent)
263 exec comp_code in self.user_ns, self.user_ns
264 exec comp_code in self.user_ns, self.user_ns
264 except:
265 except:
265 exc_content = self._wrap_exception('execute')
266 exc_content = self._wrap_exception('execute')
266 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
267 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
267 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
268 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
268 ident='%s.pyerr'%self.prefix)
269 ident='%s.pyerr'%self.prefix)
269 reply_content = exc_content
270 reply_content = exc_content
270 else:
271 else:
271 reply_content = {'status' : 'ok'}
272 reply_content = {'status' : 'ok'}
272
273
273 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
274 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
274 ident=ident, subheader = dict(started=started))
275 ident=ident, subheader = dict(started=started))
275 self.log.debug(str(reply_msg))
276 self.log.debug(str(reply_msg))
276 if reply_msg['content']['status'] == u'error':
277 if reply_msg['content']['status'] == u'error':
277 self.abort_queues()
278 self.abort_queues()
278
279
279 def complete_request(self, stream, ident, parent):
280 def complete_request(self, stream, ident, parent):
280 matches = {'matches' : self.complete(parent),
281 matches = {'matches' : self.complete(parent),
281 'status' : 'ok'}
282 'status' : 'ok'}
282 completion_msg = self.session.send(stream, 'complete_reply',
283 completion_msg = self.session.send(stream, 'complete_reply',
283 matches, parent, ident)
284 matches, parent, ident)
284 # print >> sys.__stdout__, completion_msg
285 # print >> sys.__stdout__, completion_msg
285
286
286 def complete(self, msg):
287 def complete(self, msg):
287 return self.completer.complete(msg.content.line, msg.content.text)
288 return self.completer.complete(msg.content.line, msg.content.text)
288
289
289 def apply_request(self, stream, ident, parent):
290 def apply_request(self, stream, ident, parent):
290 # flush previous reply, so this request won't block it
291 # flush previous reply, so this request won't block it
291 stream.flush(zmq.POLLOUT)
292 stream.flush(zmq.POLLOUT)
292
293
293 try:
294 try:
294 content = parent[u'content']
295 content = parent[u'content']
295 bufs = parent[u'buffers']
296 bufs = parent[u'buffers']
296 msg_id = parent['header']['msg_id']
297 msg_id = parent['header']['msg_id']
297 # bound = parent['header'].get('bound', False)
298 # bound = parent['header'].get('bound', False)
298 except:
299 except:
299 self.log.error("Got bad msg: %s"%parent, exc_info=True)
300 self.log.error("Got bad msg: %s"%parent, exc_info=True)
300 return
301 return
301 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
302 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
302 # self.iopub_stream.send(pyin_msg)
303 # self.iopub_stream.send(pyin_msg)
303 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
304 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
304 sub = {'dependencies_met' : True, 'engine' : self.ident,
305 sub = {'dependencies_met' : True, 'engine' : self.ident,
305 'started': datetime.now().strftime(ISO8601)}
306 'started': datetime.now().strftime(ISO8601)}
306 try:
307 try:
307 # allow for not overriding displayhook
308 # allow for not overriding displayhook
308 if hasattr(sys.displayhook, 'set_parent'):
309 if hasattr(sys.displayhook, 'set_parent'):
309 sys.displayhook.set_parent(parent)
310 sys.displayhook.set_parent(parent)
310 sys.stdout.set_parent(parent)
311 sys.stdout.set_parent(parent)
311 sys.stderr.set_parent(parent)
312 sys.stderr.set_parent(parent)
312 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
313 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
313 working = self.user_ns
314 working = self.user_ns
314 # suffix =
315 # suffix =
315 prefix = "_"+str(msg_id).replace("-","")+"_"
316 prefix = "_"+str(msg_id).replace("-","")+"_"
316
317
317 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
318 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
318 # if bound:
319 # if bound:
319 # bound_ns = Namespace(working)
320 # bound_ns = Namespace(working)
320 # args = [bound_ns]+list(args)
321 # args = [bound_ns]+list(args)
321
322
322 fname = getattr(f, '__name__', 'f')
323 fname = getattr(f, '__name__', 'f')
323
324
324 fname = prefix+"f"
325 fname = prefix+"f"
325 argname = prefix+"args"
326 argname = prefix+"args"
326 kwargname = prefix+"kwargs"
327 kwargname = prefix+"kwargs"
327 resultname = prefix+"result"
328 resultname = prefix+"result"
328
329
329 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
330 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
330 # print ns
331 # print ns
331 working.update(ns)
332 working.update(ns)
332 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
333 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
333 try:
334 try:
334 exec code in working,working
335 exec code in working,working
335 result = working.get(resultname)
336 result = working.get(resultname)
336 finally:
337 finally:
337 for key in ns.iterkeys():
338 for key in ns.iterkeys():
338 working.pop(key)
339 working.pop(key)
339 # if bound:
340 # if bound:
340 # working.update(bound_ns)
341 # working.update(bound_ns)
341
342
342 packed_result,buf = serialize_object(result)
343 packed_result,buf = serialize_object(result)
343 result_buf = [packed_result]+buf
344 result_buf = [packed_result]+buf
344 except:
345 except:
345 exc_content = self._wrap_exception('apply')
346 exc_content = self._wrap_exception('apply')
346 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
347 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
347 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
348 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
348 ident='%s.pyerr'%self.prefix)
349 ident='%s.pyerr'%self.prefix)
349 reply_content = exc_content
350 reply_content = exc_content
350 result_buf = []
351 result_buf = []
351
352
352 if exc_content['ename'] == 'UnmetDependency':
353 if exc_content['ename'] == 'UnmetDependency':
353 sub['dependencies_met'] = False
354 sub['dependencies_met'] = False
354 else:
355 else:
355 reply_content = {'status' : 'ok'}
356 reply_content = {'status' : 'ok'}
356
357
357 # put 'ok'/'error' status in header, for scheduler introspection:
358 # put 'ok'/'error' status in header, for scheduler introspection:
358 sub['status'] = reply_content['status']
359 sub['status'] = reply_content['status']
359
360
360 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
361 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
361 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
362 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
362
363
363 # flush i/o
364 # flush i/o
364 # should this be before reply_msg is sent, like in the single-kernel code,
365 # should this be before reply_msg is sent, like in the single-kernel code,
365 # or should nothing get in the way of real results?
366 # or should nothing get in the way of real results?
366 sys.stdout.flush()
367 sys.stdout.flush()
367 sys.stderr.flush()
368 sys.stderr.flush()
368
369
369 def dispatch_queue(self, stream, msg):
370 def dispatch_queue(self, stream, msg):
370 self.control_stream.flush()
371 self.control_stream.flush()
371 idents,msg = self.session.feed_identities(msg, copy=False)
372 idents,msg = self.session.feed_identities(msg, copy=False)
372 try:
373 try:
373 msg = self.session.unpack_message(msg, content=True, copy=False)
374 msg = self.session.unpack_message(msg, content=True, copy=False)
374 except:
375 except:
375 self.log.error("Invalid Message", exc_info=True)
376 self.log.error("Invalid Message", exc_info=True)
376 return
377 return
377
378
378
379
379 header = msg['header']
380 header = msg['header']
380 msg_id = header['msg_id']
381 msg_id = header['msg_id']
381 if self.check_aborted(msg_id):
382 if self.check_aborted(msg_id):
382 self.aborted.remove(msg_id)
383 self.aborted.remove(msg_id)
383 # is it safe to assume a msg_id will not be resubmitted?
384 # is it safe to assume a msg_id will not be resubmitted?
384 reply_type = msg['msg_type'].split('_')[0] + '_reply'
385 reply_type = msg['msg_type'].split('_')[0] + '_reply'
385 status = {'status' : 'aborted'}
386 status = {'status' : 'aborted'}
386 reply_msg = self.session.send(stream, reply_type, subheader=status,
387 reply_msg = self.session.send(stream, reply_type, subheader=status,
387 content=status, parent=msg, ident=idents)
388 content=status, parent=msg, ident=idents)
388 return
389 return
389 handler = self.shell_handlers.get(msg['msg_type'], None)
390 handler = self.shell_handlers.get(msg['msg_type'], None)
390 if handler is None:
391 if handler is None:
391 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
392 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
392 else:
393 else:
393 handler(stream, idents, msg)
394 handler(stream, idents, msg)
394
395
395 def start(self):
396 def start(self):
396 #### stream mode:
397 #### stream mode:
397 if self.control_stream:
398 if self.control_stream:
398 self.control_stream.on_recv(self.dispatch_control, copy=False)
399 self.control_stream.on_recv(self.dispatch_control, copy=False)
399 self.control_stream.on_err(printer)
400 self.control_stream.on_err(printer)
400
401
401 def make_dispatcher(stream):
402 def make_dispatcher(stream):
402 def dispatcher(msg):
403 def dispatcher(msg):
403 return self.dispatch_queue(stream, msg)
404 return self.dispatch_queue(stream, msg)
404 return dispatcher
405 return dispatcher
405
406
406 for s in self.shell_streams:
407 for s in self.shell_streams:
407 s.on_recv(make_dispatcher(s), copy=False)
408 s.on_recv(make_dispatcher(s), copy=False)
408 s.on_err(printer)
409 s.on_err(printer)
409
410
410 if self.iopub_stream:
411 if self.iopub_stream:
411 self.iopub_stream.on_err(printer)
412 self.iopub_stream.on_err(printer)
412
413
413 #### while True mode:
414 #### while True mode:
414 # while True:
415 # while True:
415 # idle = True
416 # idle = True
416 # try:
417 # try:
417 # msg = self.shell_stream.socket.recv_multipart(
418 # msg = self.shell_stream.socket.recv_multipart(
418 # zmq.NOBLOCK, copy=False)
419 # zmq.NOBLOCK, copy=False)
419 # except zmq.ZMQError, e:
420 # except zmq.ZMQError, e:
420 # if e.errno != zmq.EAGAIN:
421 # if e.errno != zmq.EAGAIN:
421 # raise e
422 # raise e
422 # else:
423 # else:
423 # idle=False
424 # idle=False
424 # self.dispatch_queue(self.shell_stream, msg)
425 # self.dispatch_queue(self.shell_stream, msg)
425 #
426 #
426 # if not self.task_stream.empty():
427 # if not self.task_stream.empty():
427 # idle=False
428 # idle=False
428 # msg = self.task_stream.recv_multipart()
429 # msg = self.task_stream.recv_multipart()
429 # self.dispatch_queue(self.task_stream, msg)
430 # self.dispatch_queue(self.task_stream, msg)
430 # if idle:
431 # if idle:
431 # # don't busywait
432 # # don't busywait
432 # time.sleep(1e-3)
433 # time.sleep(1e-3)
433
434
@@ -1,99 +1,99 b''
1 """Base config factories."""
1 """Base config factories."""
2
2
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2008-2009 The IPython Development Team
4 # Copyright (C) 2008-2009 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14
14
15 import logging
15 import logging
16 import os
16 import os
17
17
18 import zmq
18 import zmq
19 from zmq.eventloop.ioloop import IOLoop
19 from zmq.eventloop.ioloop import IOLoop
20
20
21 from IPython.config.configurable import Configurable
21 from IPython.config.configurable import Configurable
22 from IPython.utils.traitlets import Int, Instance, Unicode
22 from IPython.utils.traitlets import Int, Instance, Unicode
23
23
24 import IPython.parallel.streamsession as ss
25 from IPython.parallel.util import select_random_ports
24 from IPython.parallel.util import select_random_ports
25 from IPython.zmq.session import Session
26
26
27 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
28 # Classes
28 # Classes
29 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
30 class LoggingFactory(Configurable):
30 class LoggingFactory(Configurable):
31 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
31 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
32 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
32 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
33 logname = Unicode('ZMQ')
33 logname = Unicode('ZMQ')
34 def _logname_changed(self, name, old, new):
34 def _logname_changed(self, name, old, new):
35 self.log = logging.getLogger(new)
35 self.log = logging.getLogger(new)
36
36
37
37
38 class SessionFactory(LoggingFactory):
38 class SessionFactory(LoggingFactory):
39 """The Base factory from which every factory in IPython.parallel inherits"""
39 """The Base factory from which every factory in IPython.parallel inherits"""
40
40
41 # not configurable:
41 # not configurable:
42 context = Instance('zmq.Context')
42 context = Instance('zmq.Context')
43 def _context_default(self):
43 def _context_default(self):
44 return zmq.Context.instance()
44 return zmq.Context.instance()
45
45
46 session = Instance('IPython.parallel.streamsession.StreamSession')
46 session = Instance('IPython.zmq.session.Session')
47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
48 def _loop_default(self):
48 def _loop_default(self):
49 return IOLoop.instance()
49 return IOLoop.instance()
50
50
51
51
52 def __init__(self, **kwargs):
52 def __init__(self, **kwargs):
53 super(SessionFactory, self).__init__(**kwargs)
53 super(SessionFactory, self).__init__(**kwargs)
54
54
55 # construct the session
55 # construct the session
56 self.session = ss.StreamSession(**kwargs)
56 self.session = Session(**kwargs)
57
57
58
58
59 class RegistrationFactory(SessionFactory):
59 class RegistrationFactory(SessionFactory):
60 """The Base Configurable for objects that involve registration."""
60 """The Base Configurable for objects that involve registration."""
61
61
62 url = Unicode('', config=True,
62 url = Unicode('', config=True,
63 help="""The 0MQ url used for registration. This sets transport, ip, and port
63 help="""The 0MQ url used for registration. This sets transport, ip, and port
64 in one variable. For example: url='tcp://127.0.0.1:12345' or
64 in one variable. For example: url='tcp://127.0.0.1:12345' or
65 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
65 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
66 transport = Unicode('tcp', config=True,
66 transport = Unicode('tcp', config=True,
67 help="""The 0MQ transport for communications. This will likely be
67 help="""The 0MQ transport for communications. This will likely be
68 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
68 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
69 ip = Unicode('127.0.0.1', config=True,
69 ip = Unicode('127.0.0.1', config=True,
70 help="""The IP address for registration. This is generally either
70 help="""The IP address for registration. This is generally either
71 '127.0.0.1' for loopback only or '*' for all interfaces.
71 '127.0.0.1' for loopback only or '*' for all interfaces.
72 [default: '127.0.0.1']""")
72 [default: '127.0.0.1']""")
73 regport = Int(config=True,
73 regport = Int(config=True,
74 help="""The port on which the Hub listens for registration.""")
74 help="""The port on which the Hub listens for registration.""")
75 def _regport_default(self):
75 def _regport_default(self):
76 return select_random_ports(1)[0]
76 return select_random_ports(1)[0]
77
77
78 def __init__(self, **kwargs):
78 def __init__(self, **kwargs):
79 super(RegistrationFactory, self).__init__(**kwargs)
79 super(RegistrationFactory, self).__init__(**kwargs)
80 self._propagate_url()
80 self._propagate_url()
81 self._rebuild_url()
81 self._rebuild_url()
82 self.on_trait_change(self._propagate_url, 'url')
82 self.on_trait_change(self._propagate_url, 'url')
83 self.on_trait_change(self._rebuild_url, 'ip')
83 self.on_trait_change(self._rebuild_url, 'ip')
84 self.on_trait_change(self._rebuild_url, 'transport')
84 self.on_trait_change(self._rebuild_url, 'transport')
85 self.on_trait_change(self._rebuild_url, 'regport')
85 self.on_trait_change(self._rebuild_url, 'regport')
86
86
87 def _rebuild_url(self):
87 def _rebuild_url(self):
88 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
88 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
89
89
90 def _propagate_url(self):
90 def _propagate_url(self):
91 """Ensure self.url contains full transport://interface:port"""
91 """Ensure self.url contains full transport://interface:port"""
92 if self.url:
92 if self.url:
93 iface = self.url.split('://',1)
93 iface = self.url.split('://',1)
94 if len(iface) == 2:
94 if len(iface) == 2:
95 self.transport,iface = iface
95 self.transport,iface = iface
96 iface = iface.split(':')
96 iface = iface.split(':')
97 self.ip = iface[0]
97 self.ip = iface[0]
98 if iface[1]:
98 if iface[1]:
99 self.regport = int(iface[1])
99 self.regport = int(iface[1])
@@ -1,170 +1,173 b''
1 """Tests for db backends"""
1 """Tests for db backends"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14
14
15 import tempfile
15 import tempfile
16 import time
16 import time
17
17
18 from datetime import datetime, timedelta
18 from datetime import datetime, timedelta
19 from unittest import TestCase
19 from unittest import TestCase
20
20
21 from nose import SkipTest
21 from nose import SkipTest
22
22
23 from IPython.parallel import error, streamsession as ss
23 from IPython.parallel import error
24 from IPython.parallel.controller.dictdb import DictDB
24 from IPython.parallel.controller.dictdb import DictDB
25 from IPython.parallel.controller.sqlitedb import SQLiteDB
25 from IPython.parallel.controller.sqlitedb import SQLiteDB
26 from IPython.parallel.controller.hub import init_record, empty_record
26 from IPython.parallel.controller.hub import init_record, empty_record
27
27
28 from IPython.zmq.session import Session
29
30
28 #-------------------------------------------------------------------------------
31 #-------------------------------------------------------------------------------
29 # TestCases
32 # TestCases
30 #-------------------------------------------------------------------------------
33 #-------------------------------------------------------------------------------
31
34
32 class TestDictBackend(TestCase):
35 class TestDictBackend(TestCase):
33 def setUp(self):
36 def setUp(self):
34 self.session = ss.StreamSession()
37 self.session = Session()
35 self.db = self.create_db()
38 self.db = self.create_db()
36 self.load_records(16)
39 self.load_records(16)
37
40
38 def create_db(self):
41 def create_db(self):
39 return DictDB()
42 return DictDB()
40
43
41 def load_records(self, n=1):
44 def load_records(self, n=1):
42 """load n records for testing"""
45 """load n records for testing"""
43 #sleep 1/10 s, to ensure timestamp is different to previous calls
46 #sleep 1/10 s, to ensure timestamp is different to previous calls
44 time.sleep(0.1)
47 time.sleep(0.1)
45 msg_ids = []
48 msg_ids = []
46 for i in range(n):
49 for i in range(n):
47 msg = self.session.msg('apply_request', content=dict(a=5))
50 msg = self.session.msg('apply_request', content=dict(a=5))
48 msg['buffers'] = []
51 msg['buffers'] = []
49 rec = init_record(msg)
52 rec = init_record(msg)
50 msg_ids.append(msg['msg_id'])
53 msg_ids.append(msg['msg_id'])
51 self.db.add_record(msg['msg_id'], rec)
54 self.db.add_record(msg['msg_id'], rec)
52 return msg_ids
55 return msg_ids
53
56
54 def test_add_record(self):
57 def test_add_record(self):
55 before = self.db.get_history()
58 before = self.db.get_history()
56 self.load_records(5)
59 self.load_records(5)
57 after = self.db.get_history()
60 after = self.db.get_history()
58 self.assertEquals(len(after), len(before)+5)
61 self.assertEquals(len(after), len(before)+5)
59 self.assertEquals(after[:-5],before)
62 self.assertEquals(after[:-5],before)
60
63
61 def test_drop_record(self):
64 def test_drop_record(self):
62 msg_id = self.load_records()[-1]
65 msg_id = self.load_records()[-1]
63 rec = self.db.get_record(msg_id)
66 rec = self.db.get_record(msg_id)
64 self.db.drop_record(msg_id)
67 self.db.drop_record(msg_id)
65 self.assertRaises(KeyError,self.db.get_record, msg_id)
68 self.assertRaises(KeyError,self.db.get_record, msg_id)
66
69
67 def _round_to_millisecond(self, dt):
70 def _round_to_millisecond(self, dt):
68 """necessary because mongodb rounds microseconds"""
71 """necessary because mongodb rounds microseconds"""
69 micro = dt.microsecond
72 micro = dt.microsecond
70 extra = int(str(micro)[-3:])
73 extra = int(str(micro)[-3:])
71 return dt - timedelta(microseconds=extra)
74 return dt - timedelta(microseconds=extra)
72
75
73 def test_update_record(self):
76 def test_update_record(self):
74 now = self._round_to_millisecond(datetime.now())
77 now = self._round_to_millisecond(datetime.now())
75 #
78 #
76 msg_id = self.db.get_history()[-1]
79 msg_id = self.db.get_history()[-1]
77 rec1 = self.db.get_record(msg_id)
80 rec1 = self.db.get_record(msg_id)
78 data = {'stdout': 'hello there', 'completed' : now}
81 data = {'stdout': 'hello there', 'completed' : now}
79 self.db.update_record(msg_id, data)
82 self.db.update_record(msg_id, data)
80 rec2 = self.db.get_record(msg_id)
83 rec2 = self.db.get_record(msg_id)
81 self.assertEquals(rec2['stdout'], 'hello there')
84 self.assertEquals(rec2['stdout'], 'hello there')
82 self.assertEquals(rec2['completed'], now)
85 self.assertEquals(rec2['completed'], now)
83 rec1.update(data)
86 rec1.update(data)
84 self.assertEquals(rec1, rec2)
87 self.assertEquals(rec1, rec2)
85
88
86 # def test_update_record_bad(self):
89 # def test_update_record_bad(self):
87 # """test updating nonexistant records"""
90 # """test updating nonexistant records"""
88 # msg_id = str(uuid.uuid4())
91 # msg_id = str(uuid.uuid4())
89 # data = {'stdout': 'hello there'}
92 # data = {'stdout': 'hello there'}
90 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
93 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
91
94
92 def test_find_records_dt(self):
95 def test_find_records_dt(self):
93 """test finding records by date"""
96 """test finding records by date"""
94 hist = self.db.get_history()
97 hist = self.db.get_history()
95 middle = self.db.get_record(hist[len(hist)/2])
98 middle = self.db.get_record(hist[len(hist)/2])
96 tic = middle['submitted']
99 tic = middle['submitted']
97 before = self.db.find_records({'submitted' : {'$lt' : tic}})
100 before = self.db.find_records({'submitted' : {'$lt' : tic}})
98 after = self.db.find_records({'submitted' : {'$gte' : tic}})
101 after = self.db.find_records({'submitted' : {'$gte' : tic}})
99 self.assertEquals(len(before)+len(after),len(hist))
102 self.assertEquals(len(before)+len(after),len(hist))
100 for b in before:
103 for b in before:
101 self.assertTrue(b['submitted'] < tic)
104 self.assertTrue(b['submitted'] < tic)
102 for a in after:
105 for a in after:
103 self.assertTrue(a['submitted'] >= tic)
106 self.assertTrue(a['submitted'] >= tic)
104 same = self.db.find_records({'submitted' : tic})
107 same = self.db.find_records({'submitted' : tic})
105 for s in same:
108 for s in same:
106 self.assertTrue(s['submitted'] == tic)
109 self.assertTrue(s['submitted'] == tic)
107
110
108 def test_find_records_keys(self):
111 def test_find_records_keys(self):
109 """test extracting subset of record keys"""
112 """test extracting subset of record keys"""
110 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
113 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
111 for rec in found:
114 for rec in found:
112 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
115 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
113
116
114 def test_find_records_msg_id(self):
117 def test_find_records_msg_id(self):
115 """ensure msg_id is always in found records"""
118 """ensure msg_id is always in found records"""
116 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
117 for rec in found:
120 for rec in found:
118 self.assertTrue('msg_id' in rec.keys())
121 self.assertTrue('msg_id' in rec.keys())
119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
120 for rec in found:
123 for rec in found:
121 self.assertTrue('msg_id' in rec.keys())
124 self.assertTrue('msg_id' in rec.keys())
122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
123 for rec in found:
126 for rec in found:
124 self.assertTrue('msg_id' in rec.keys())
127 self.assertTrue('msg_id' in rec.keys())
125
128
126 def test_find_records_in(self):
129 def test_find_records_in(self):
127 """test finding records with '$in','$nin' operators"""
130 """test finding records with '$in','$nin' operators"""
128 hist = self.db.get_history()
131 hist = self.db.get_history()
129 even = hist[::2]
132 even = hist[::2]
130 odd = hist[1::2]
133 odd = hist[1::2]
131 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
134 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
132 found = [ r['msg_id'] for r in recs ]
135 found = [ r['msg_id'] for r in recs ]
133 self.assertEquals(set(even), set(found))
136 self.assertEquals(set(even), set(found))
134 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
137 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
135 found = [ r['msg_id'] for r in recs ]
138 found = [ r['msg_id'] for r in recs ]
136 self.assertEquals(set(odd), set(found))
139 self.assertEquals(set(odd), set(found))
137
140
138 def test_get_history(self):
141 def test_get_history(self):
139 msg_ids = self.db.get_history()
142 msg_ids = self.db.get_history()
140 latest = datetime(1984,1,1)
143 latest = datetime(1984,1,1)
141 for msg_id in msg_ids:
144 for msg_id in msg_ids:
142 rec = self.db.get_record(msg_id)
145 rec = self.db.get_record(msg_id)
143 newt = rec['submitted']
146 newt = rec['submitted']
144 self.assertTrue(newt >= latest)
147 self.assertTrue(newt >= latest)
145 latest = newt
148 latest = newt
146 msg_id = self.load_records(1)[-1]
149 msg_id = self.load_records(1)[-1]
147 self.assertEquals(self.db.get_history()[-1],msg_id)
150 self.assertEquals(self.db.get_history()[-1],msg_id)
148
151
149 def test_datetime(self):
152 def test_datetime(self):
150 """get/set timestamps with datetime objects"""
153 """get/set timestamps with datetime objects"""
151 msg_id = self.db.get_history()[-1]
154 msg_id = self.db.get_history()[-1]
152 rec = self.db.get_record(msg_id)
155 rec = self.db.get_record(msg_id)
153 self.assertTrue(isinstance(rec['submitted'], datetime))
156 self.assertTrue(isinstance(rec['submitted'], datetime))
154 self.db.update_record(msg_id, dict(completed=datetime.now()))
157 self.db.update_record(msg_id, dict(completed=datetime.now()))
155 rec = self.db.get_record(msg_id)
158 rec = self.db.get_record(msg_id)
156 self.assertTrue(isinstance(rec['completed'], datetime))
159 self.assertTrue(isinstance(rec['completed'], datetime))
157
160
158 def test_drop_matching(self):
161 def test_drop_matching(self):
159 msg_ids = self.load_records(10)
162 msg_ids = self.load_records(10)
160 query = {'msg_id' : {'$in':msg_ids}}
163 query = {'msg_id' : {'$in':msg_ids}}
161 self.db.drop_matching_records(query)
164 self.db.drop_matching_records(query)
162 recs = self.db.find_records(query)
165 recs = self.db.find_records(query)
163 self.assertTrue(len(recs)==0)
166 self.assertTrue(len(recs)==0)
164
167
165 class TestSQLiteBackend(TestDictBackend):
168 class TestSQLiteBackend(TestDictBackend):
166 def create_db(self):
169 def create_db(self):
167 return SQLiteDB(location=tempfile.gettempdir())
170 return SQLiteDB(location=tempfile.gettempdir())
168
171
169 def tearDown(self):
172 def tearDown(self):
170 self.db._db.close()
173 self.db._db.close()
@@ -1,483 +1,466 b''
1 """some generic utilities for dealing with classes, urls, and serialization"""
1 """some generic utilities for dealing with classes, urls, and serialization"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team
3 # Copyright (C) 2010-2011 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 # Standard library imports.
13 # Standard library imports.
14 import logging
14 import logging
15 import os
15 import os
16 import re
16 import re
17 import stat
17 import stat
18 import socket
18 import socket
19 import sys
19 import sys
20 from datetime import datetime
21 from signal import signal, SIGINT, SIGABRT, SIGTERM
20 from signal import signal, SIGINT, SIGABRT, SIGTERM
22 try:
21 try:
23 from signal import SIGKILL
22 from signal import SIGKILL
24 except ImportError:
23 except ImportError:
25 SIGKILL=None
24 SIGKILL=None
26
25
27 try:
26 try:
28 import cPickle
27 import cPickle
29 pickle = cPickle
28 pickle = cPickle
30 except:
29 except:
31 cPickle = None
30 cPickle = None
32 import pickle
31 import pickle
33
32
34 # System library imports
33 # System library imports
35 import zmq
34 import zmq
36 from zmq.log import handlers
35 from zmq.log import handlers
37
36
38 # IPython imports
37 # IPython imports
39 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
38 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
40 from IPython.utils.newserialized import serialize, unserialize
39 from IPython.utils.newserialized import serialize, unserialize
41 from IPython.zmq.log import EnginePUBHandler
40 from IPython.zmq.log import EnginePUBHandler
42
41
43 # globals
44 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
45 ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
46
47 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
48 # Classes
43 # Classes
49 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
50
45
51 class Namespace(dict):
46 class Namespace(dict):
52 """Subclass of dict for attribute access to keys."""
47 """Subclass of dict for attribute access to keys."""
53
48
54 def __getattr__(self, key):
49 def __getattr__(self, key):
55 """getattr aliased to getitem"""
50 """getattr aliased to getitem"""
56 if key in self.iterkeys():
51 if key in self.iterkeys():
57 return self[key]
52 return self[key]
58 else:
53 else:
59 raise NameError(key)
54 raise NameError(key)
60
55
61 def __setattr__(self, key, value):
56 def __setattr__(self, key, value):
62 """setattr aliased to setitem, with strict"""
57 """setattr aliased to setitem, with strict"""
63 if hasattr(dict, key):
58 if hasattr(dict, key):
64 raise KeyError("Cannot override dict keys %r"%key)
59 raise KeyError("Cannot override dict keys %r"%key)
65 self[key] = value
60 self[key] = value
66
61
67
62
68 class ReverseDict(dict):
63 class ReverseDict(dict):
69 """simple double-keyed subset of dict methods."""
64 """simple double-keyed subset of dict methods."""
70
65
71 def __init__(self, *args, **kwargs):
66 def __init__(self, *args, **kwargs):
72 dict.__init__(self, *args, **kwargs)
67 dict.__init__(self, *args, **kwargs)
73 self._reverse = dict()
68 self._reverse = dict()
74 for key, value in self.iteritems():
69 for key, value in self.iteritems():
75 self._reverse[value] = key
70 self._reverse[value] = key
76
71
77 def __getitem__(self, key):
72 def __getitem__(self, key):
78 try:
73 try:
79 return dict.__getitem__(self, key)
74 return dict.__getitem__(self, key)
80 except KeyError:
75 except KeyError:
81 return self._reverse[key]
76 return self._reverse[key]
82
77
83 def __setitem__(self, key, value):
78 def __setitem__(self, key, value):
84 if key in self._reverse:
79 if key in self._reverse:
85 raise KeyError("Can't have key %r on both sides!"%key)
80 raise KeyError("Can't have key %r on both sides!"%key)
86 dict.__setitem__(self, key, value)
81 dict.__setitem__(self, key, value)
87 self._reverse[value] = key
82 self._reverse[value] = key
88
83
89 def pop(self, key):
84 def pop(self, key):
90 value = dict.pop(self, key)
85 value = dict.pop(self, key)
91 self._reverse.pop(value)
86 self._reverse.pop(value)
92 return value
87 return value
93
88
94 def get(self, key, default=None):
89 def get(self, key, default=None):
95 try:
90 try:
96 return self[key]
91 return self[key]
97 except KeyError:
92 except KeyError:
98 return default
93 return default
99
94
100 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
101 # Functions
96 # Functions
102 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
103
98
104 def extract_dates(obj):
105 """extract ISO8601 dates from unpacked JSON"""
106 if isinstance(obj, dict):
107 for k,v in obj.iteritems():
108 obj[k] = extract_dates(v)
109 elif isinstance(obj, list):
110 obj = [ extract_dates(o) for o in obj ]
111 elif isinstance(obj, basestring):
112 if ISO8601_RE.match(obj):
113 obj = datetime.strptime(obj, ISO8601)
114 return obj
115
116 def validate_url(url):
99 def validate_url(url):
117 """validate a url for zeromq"""
100 """validate a url for zeromq"""
118 if not isinstance(url, basestring):
101 if not isinstance(url, basestring):
119 raise TypeError("url must be a string, not %r"%type(url))
102 raise TypeError("url must be a string, not %r"%type(url))
120 url = url.lower()
103 url = url.lower()
121
104
122 proto_addr = url.split('://')
105 proto_addr = url.split('://')
123 assert len(proto_addr) == 2, 'Invalid url: %r'%url
106 assert len(proto_addr) == 2, 'Invalid url: %r'%url
124 proto, addr = proto_addr
107 proto, addr = proto_addr
125 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
108 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
126
109
127 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
110 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
128 # author: Remi Sabourin
111 # author: Remi Sabourin
129 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
112 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
130
113
131 if proto == 'tcp':
114 if proto == 'tcp':
132 lis = addr.split(':')
115 lis = addr.split(':')
133 assert len(lis) == 2, 'Invalid url: %r'%url
116 assert len(lis) == 2, 'Invalid url: %r'%url
134 addr,s_port = lis
117 addr,s_port = lis
135 try:
118 try:
136 port = int(s_port)
119 port = int(s_port)
137 except ValueError:
120 except ValueError:
138 raise AssertionError("Invalid port %r in url: %r"%(port, url))
121 raise AssertionError("Invalid port %r in url: %r"%(port, url))
139
122
140 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
123 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
141
124
142 else:
125 else:
143 # only validate tcp urls currently
126 # only validate tcp urls currently
144 pass
127 pass
145
128
146 return True
129 return True
147
130
148
131
149 def validate_url_container(container):
132 def validate_url_container(container):
150 """validate a potentially nested collection of urls."""
133 """validate a potentially nested collection of urls."""
151 if isinstance(container, basestring):
134 if isinstance(container, basestring):
152 url = container
135 url = container
153 return validate_url(url)
136 return validate_url(url)
154 elif isinstance(container, dict):
137 elif isinstance(container, dict):
155 container = container.itervalues()
138 container = container.itervalues()
156
139
157 for element in container:
140 for element in container:
158 validate_url_container(element)
141 validate_url_container(element)
159
142
160
143
161 def split_url(url):
144 def split_url(url):
162 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
145 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
163 proto_addr = url.split('://')
146 proto_addr = url.split('://')
164 assert len(proto_addr) == 2, 'Invalid url: %r'%url
147 assert len(proto_addr) == 2, 'Invalid url: %r'%url
165 proto, addr = proto_addr
148 proto, addr = proto_addr
166 lis = addr.split(':')
149 lis = addr.split(':')
167 assert len(lis) == 2, 'Invalid url: %r'%url
150 assert len(lis) == 2, 'Invalid url: %r'%url
168 addr,s_port = lis
151 addr,s_port = lis
169 return proto,addr,s_port
152 return proto,addr,s_port
170
153
171 def disambiguate_ip_address(ip, location=None):
154 def disambiguate_ip_address(ip, location=None):
172 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
155 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
173 ones, based on the location (default interpretation of location is localhost)."""
156 ones, based on the location (default interpretation of location is localhost)."""
174 if ip in ('0.0.0.0', '*'):
157 if ip in ('0.0.0.0', '*'):
175 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
158 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
176 if location is None or location in external_ips:
159 if location is None or location in external_ips:
177 ip='127.0.0.1'
160 ip='127.0.0.1'
178 elif location:
161 elif location:
179 return location
162 return location
180 return ip
163 return ip
181
164
182 def disambiguate_url(url, location=None):
165 def disambiguate_url(url, location=None):
183 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
184 ones, based on the location (default interpretation is localhost).
167 ones, based on the location (default interpretation is localhost).
185
168
186 This is for zeromq urls, such as tcp://*:10101."""
169 This is for zeromq urls, such as tcp://*:10101."""
187 try:
170 try:
188 proto,ip,port = split_url(url)
171 proto,ip,port = split_url(url)
189 except AssertionError:
172 except AssertionError:
190 # probably not tcp url; could be ipc, etc.
173 # probably not tcp url; could be ipc, etc.
191 return url
174 return url
192
175
193 ip = disambiguate_ip_address(ip,location)
176 ip = disambiguate_ip_address(ip,location)
194
177
195 return "%s://%s:%s"%(proto,ip,port)
178 return "%s://%s:%s"%(proto,ip,port)
196
179
197
180
198 def rekey(dikt):
181 def rekey(dikt):
199 """Rekey a dict that has been forced to use str keys where there should be
182 """Rekey a dict that has been forced to use str keys where there should be
200 ints by json. This belongs in the jsonutil added by fperez."""
183 ints by json. This belongs in the jsonutil added by fperez."""
201 for k in dikt.iterkeys():
184 for k in dikt.iterkeys():
202 if isinstance(k, str):
185 if isinstance(k, str):
203 ik=fk=None
186 ik=fk=None
204 try:
187 try:
205 ik = int(k)
188 ik = int(k)
206 except ValueError:
189 except ValueError:
207 try:
190 try:
208 fk = float(k)
191 fk = float(k)
209 except ValueError:
192 except ValueError:
210 continue
193 continue
211 if ik is not None:
194 if ik is not None:
212 nk = ik
195 nk = ik
213 else:
196 else:
214 nk = fk
197 nk = fk
215 if nk in dikt:
198 if nk in dikt:
216 raise KeyError("already have key %r"%nk)
199 raise KeyError("already have key %r"%nk)
217 dikt[nk] = dikt.pop(k)
200 dikt[nk] = dikt.pop(k)
218 return dikt
201 return dikt
219
202
220 def serialize_object(obj, threshold=64e-6):
203 def serialize_object(obj, threshold=64e-6):
221 """Serialize an object into a list of sendable buffers.
204 """Serialize an object into a list of sendable buffers.
222
205
223 Parameters
206 Parameters
224 ----------
207 ----------
225
208
226 obj : object
209 obj : object
227 The object to be serialized
210 The object to be serialized
228 threshold : float
211 threshold : float
229 The threshold for not double-pickling the content.
212 The threshold for not double-pickling the content.
230
213
231
214
232 Returns
215 Returns
233 -------
216 -------
234 ('pmd', [bufs]) :
217 ('pmd', [bufs]) :
235 where pmd is the pickled metadata wrapper,
218 where pmd is the pickled metadata wrapper,
236 bufs is a list of data buffers
219 bufs is a list of data buffers
237 """
220 """
238 databuffers = []
221 databuffers = []
239 if isinstance(obj, (list, tuple)):
222 if isinstance(obj, (list, tuple)):
240 clist = canSequence(obj)
223 clist = canSequence(obj)
241 slist = map(serialize, clist)
224 slist = map(serialize, clist)
242 for s in slist:
225 for s in slist:
243 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
226 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
244 databuffers.append(s.getData())
227 databuffers.append(s.getData())
245 s.data = None
228 s.data = None
246 return pickle.dumps(slist,-1), databuffers
229 return pickle.dumps(slist,-1), databuffers
247 elif isinstance(obj, dict):
230 elif isinstance(obj, dict):
248 sobj = {}
231 sobj = {}
249 for k in sorted(obj.iterkeys()):
232 for k in sorted(obj.iterkeys()):
250 s = serialize(can(obj[k]))
233 s = serialize(can(obj[k]))
251 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
234 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
252 databuffers.append(s.getData())
235 databuffers.append(s.getData())
253 s.data = None
236 s.data = None
254 sobj[k] = s
237 sobj[k] = s
255 return pickle.dumps(sobj,-1),databuffers
238 return pickle.dumps(sobj,-1),databuffers
256 else:
239 else:
257 s = serialize(can(obj))
240 s = serialize(can(obj))
258 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
241 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
259 databuffers.append(s.getData())
242 databuffers.append(s.getData())
260 s.data = None
243 s.data = None
261 return pickle.dumps(s,-1),databuffers
244 return pickle.dumps(s,-1),databuffers
262
245
263
246
264 def unserialize_object(bufs):
247 def unserialize_object(bufs):
265 """reconstruct an object serialized by serialize_object from data buffers."""
248 """reconstruct an object serialized by serialize_object from data buffers."""
266 bufs = list(bufs)
249 bufs = list(bufs)
267 sobj = pickle.loads(bufs.pop(0))
250 sobj = pickle.loads(bufs.pop(0))
268 if isinstance(sobj, (list, tuple)):
251 if isinstance(sobj, (list, tuple)):
269 for s in sobj:
252 for s in sobj:
270 if s.data is None:
253 if s.data is None:
271 s.data = bufs.pop(0)
254 s.data = bufs.pop(0)
272 return uncanSequence(map(unserialize, sobj)), bufs
255 return uncanSequence(map(unserialize, sobj)), bufs
273 elif isinstance(sobj, dict):
256 elif isinstance(sobj, dict):
274 newobj = {}
257 newobj = {}
275 for k in sorted(sobj.iterkeys()):
258 for k in sorted(sobj.iterkeys()):
276 s = sobj[k]
259 s = sobj[k]
277 if s.data is None:
260 if s.data is None:
278 s.data = bufs.pop(0)
261 s.data = bufs.pop(0)
279 newobj[k] = uncan(unserialize(s))
262 newobj[k] = uncan(unserialize(s))
280 return newobj, bufs
263 return newobj, bufs
281 else:
264 else:
282 if sobj.data is None:
265 if sobj.data is None:
283 sobj.data = bufs.pop(0)
266 sobj.data = bufs.pop(0)
284 return uncan(unserialize(sobj)), bufs
267 return uncan(unserialize(sobj)), bufs
285
268
286 def pack_apply_message(f, args, kwargs, threshold=64e-6):
269 def pack_apply_message(f, args, kwargs, threshold=64e-6):
287 """pack up a function, args, and kwargs to be sent over the wire
270 """pack up a function, args, and kwargs to be sent over the wire
288 as a series of buffers. Any object whose data is larger than `threshold`
271 as a series of buffers. Any object whose data is larger than `threshold`
289 will not have their data copied (currently only numpy arrays support zero-copy)"""
272 will not have their data copied (currently only numpy arrays support zero-copy)"""
290 msg = [pickle.dumps(can(f),-1)]
273 msg = [pickle.dumps(can(f),-1)]
291 databuffers = [] # for large objects
274 databuffers = [] # for large objects
292 sargs, bufs = serialize_object(args,threshold)
275 sargs, bufs = serialize_object(args,threshold)
293 msg.append(sargs)
276 msg.append(sargs)
294 databuffers.extend(bufs)
277 databuffers.extend(bufs)
295 skwargs, bufs = serialize_object(kwargs,threshold)
278 skwargs, bufs = serialize_object(kwargs,threshold)
296 msg.append(skwargs)
279 msg.append(skwargs)
297 databuffers.extend(bufs)
280 databuffers.extend(bufs)
298 msg.extend(databuffers)
281 msg.extend(databuffers)
299 return msg
282 return msg
300
283
301 def unpack_apply_message(bufs, g=None, copy=True):
284 def unpack_apply_message(bufs, g=None, copy=True):
302 """unpack f,args,kwargs from buffers packed by pack_apply_message()
285 """unpack f,args,kwargs from buffers packed by pack_apply_message()
303 Returns: original f,args,kwargs"""
286 Returns: original f,args,kwargs"""
304 bufs = list(bufs) # allow us to pop
287 bufs = list(bufs) # allow us to pop
305 assert len(bufs) >= 3, "not enough buffers!"
288 assert len(bufs) >= 3, "not enough buffers!"
306 if not copy:
289 if not copy:
307 for i in range(3):
290 for i in range(3):
308 bufs[i] = bufs[i].bytes
291 bufs[i] = bufs[i].bytes
309 cf = pickle.loads(bufs.pop(0))
292 cf = pickle.loads(bufs.pop(0))
310 sargs = list(pickle.loads(bufs.pop(0)))
293 sargs = list(pickle.loads(bufs.pop(0)))
311 skwargs = dict(pickle.loads(bufs.pop(0)))
294 skwargs = dict(pickle.loads(bufs.pop(0)))
312 # print sargs, skwargs
295 # print sargs, skwargs
313 f = uncan(cf, g)
296 f = uncan(cf, g)
314 for sa in sargs:
297 for sa in sargs:
315 if sa.data is None:
298 if sa.data is None:
316 m = bufs.pop(0)
299 m = bufs.pop(0)
317 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
300 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
318 # always use a buffer, until memoryviews get sorted out
301 # always use a buffer, until memoryviews get sorted out
319 sa.data = buffer(m)
302 sa.data = buffer(m)
320 # disable memoryview support
303 # disable memoryview support
321 # if copy:
304 # if copy:
322 # sa.data = buffer(m)
305 # sa.data = buffer(m)
323 # else:
306 # else:
324 # sa.data = m.buffer
307 # sa.data = m.buffer
325 else:
308 else:
326 if copy:
309 if copy:
327 sa.data = m
310 sa.data = m
328 else:
311 else:
329 sa.data = m.bytes
312 sa.data = m.bytes
330
313
331 args = uncanSequence(map(unserialize, sargs), g)
314 args = uncanSequence(map(unserialize, sargs), g)
332 kwargs = {}
315 kwargs = {}
333 for k in sorted(skwargs.iterkeys()):
316 for k in sorted(skwargs.iterkeys()):
334 sa = skwargs[k]
317 sa = skwargs[k]
335 if sa.data is None:
318 if sa.data is None:
336 m = bufs.pop(0)
319 m = bufs.pop(0)
337 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
320 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
338 # always use a buffer, until memoryviews get sorted out
321 # always use a buffer, until memoryviews get sorted out
339 sa.data = buffer(m)
322 sa.data = buffer(m)
340 # disable memoryview support
323 # disable memoryview support
341 # if copy:
324 # if copy:
342 # sa.data = buffer(m)
325 # sa.data = buffer(m)
343 # else:
326 # else:
344 # sa.data = m.buffer
327 # sa.data = m.buffer
345 else:
328 else:
346 if copy:
329 if copy:
347 sa.data = m
330 sa.data = m
348 else:
331 else:
349 sa.data = m.bytes
332 sa.data = m.bytes
350
333
351 kwargs[k] = uncan(unserialize(sa), g)
334 kwargs[k] = uncan(unserialize(sa), g)
352
335
353 return f,args,kwargs
336 return f,args,kwargs
354
337
355 #--------------------------------------------------------------------------
338 #--------------------------------------------------------------------------
356 # helpers for implementing old MEC API via view.apply
339 # helpers for implementing old MEC API via view.apply
357 #--------------------------------------------------------------------------
340 #--------------------------------------------------------------------------
358
341
359 def interactive(f):
342 def interactive(f):
360 """decorator for making functions appear as interactively defined.
343 """decorator for making functions appear as interactively defined.
361 This results in the function being linked to the user_ns as globals()
344 This results in the function being linked to the user_ns as globals()
362 instead of the module globals().
345 instead of the module globals().
363 """
346 """
364 f.__module__ = '__main__'
347 f.__module__ = '__main__'
365 return f
348 return f
366
349
367 @interactive
350 @interactive
368 def _push(ns):
351 def _push(ns):
369 """helper method for implementing `client.push` via `client.apply`"""
352 """helper method for implementing `client.push` via `client.apply`"""
370 globals().update(ns)
353 globals().update(ns)
371
354
372 @interactive
355 @interactive
373 def _pull(keys):
356 def _pull(keys):
374 """helper method for implementing `client.pull` via `client.apply`"""
357 """helper method for implementing `client.pull` via `client.apply`"""
375 user_ns = globals()
358 user_ns = globals()
376 if isinstance(keys, (list,tuple, set)):
359 if isinstance(keys, (list,tuple, set)):
377 for key in keys:
360 for key in keys:
378 if not user_ns.has_key(key):
361 if not user_ns.has_key(key):
379 raise NameError("name '%s' is not defined"%key)
362 raise NameError("name '%s' is not defined"%key)
380 return map(user_ns.get, keys)
363 return map(user_ns.get, keys)
381 else:
364 else:
382 if not user_ns.has_key(keys):
365 if not user_ns.has_key(keys):
383 raise NameError("name '%s' is not defined"%keys)
366 raise NameError("name '%s' is not defined"%keys)
384 return user_ns.get(keys)
367 return user_ns.get(keys)
385
368
386 @interactive
369 @interactive
387 def _execute(code):
370 def _execute(code):
388 """helper method for implementing `client.execute` via `client.apply`"""
371 """helper method for implementing `client.execute` via `client.apply`"""
389 exec code in globals()
372 exec code in globals()
390
373
391 #--------------------------------------------------------------------------
374 #--------------------------------------------------------------------------
392 # extra process management utilities
375 # extra process management utilities
393 #--------------------------------------------------------------------------
376 #--------------------------------------------------------------------------
394
377
395 _random_ports = set()
378 _random_ports = set()
396
379
397 def select_random_ports(n):
380 def select_random_ports(n):
398 """Selects and return n random ports that are available."""
381 """Selects and return n random ports that are available."""
399 ports = []
382 ports = []
400 for i in xrange(n):
383 for i in xrange(n):
401 sock = socket.socket()
384 sock = socket.socket()
402 sock.bind(('', 0))
385 sock.bind(('', 0))
403 while sock.getsockname()[1] in _random_ports:
386 while sock.getsockname()[1] in _random_ports:
404 sock.close()
387 sock.close()
405 sock = socket.socket()
388 sock = socket.socket()
406 sock.bind(('', 0))
389 sock.bind(('', 0))
407 ports.append(sock)
390 ports.append(sock)
408 for i, sock in enumerate(ports):
391 for i, sock in enumerate(ports):
409 port = sock.getsockname()[1]
392 port = sock.getsockname()[1]
410 sock.close()
393 sock.close()
411 ports[i] = port
394 ports[i] = port
412 _random_ports.add(port)
395 _random_ports.add(port)
413 return ports
396 return ports
414
397
415 def signal_children(children):
398 def signal_children(children):
416 """Relay interupt/term signals to children, for more solid process cleanup."""
399 """Relay interupt/term signals to children, for more solid process cleanup."""
417 def terminate_children(sig, frame):
400 def terminate_children(sig, frame):
418 logging.critical("Got signal %i, terminating children..."%sig)
401 logging.critical("Got signal %i, terminating children..."%sig)
419 for child in children:
402 for child in children:
420 child.terminate()
403 child.terminate()
421
404
422 sys.exit(sig != SIGINT)
405 sys.exit(sig != SIGINT)
423 # sys.exit(sig)
406 # sys.exit(sig)
424 for sig in (SIGINT, SIGABRT, SIGTERM):
407 for sig in (SIGINT, SIGABRT, SIGTERM):
425 signal(sig, terminate_children)
408 signal(sig, terminate_children)
426
409
427 def generate_exec_key(keyfile):
410 def generate_exec_key(keyfile):
428 import uuid
411 import uuid
429 newkey = str(uuid.uuid4())
412 newkey = str(uuid.uuid4())
430 with open(keyfile, 'w') as f:
413 with open(keyfile, 'w') as f:
431 # f.write('ipython-key ')
414 # f.write('ipython-key ')
432 f.write(newkey+'\n')
415 f.write(newkey+'\n')
433 # set user-only RW permissions (0600)
416 # set user-only RW permissions (0600)
434 # this will have no effect on Windows
417 # this will have no effect on Windows
435 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
418 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
436
419
437
420
438 def integer_loglevel(loglevel):
421 def integer_loglevel(loglevel):
439 try:
422 try:
440 loglevel = int(loglevel)
423 loglevel = int(loglevel)
441 except ValueError:
424 except ValueError:
442 if isinstance(loglevel, str):
425 if isinstance(loglevel, str):
443 loglevel = getattr(logging, loglevel)
426 loglevel = getattr(logging, loglevel)
444 return loglevel
427 return loglevel
445
428
446 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
429 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
447 logger = logging.getLogger(logname)
430 logger = logging.getLogger(logname)
448 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
431 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
449 # don't add a second PUBHandler
432 # don't add a second PUBHandler
450 return
433 return
451 loglevel = integer_loglevel(loglevel)
434 loglevel = integer_loglevel(loglevel)
452 lsock = context.socket(zmq.PUB)
435 lsock = context.socket(zmq.PUB)
453 lsock.connect(iface)
436 lsock.connect(iface)
454 handler = handlers.PUBHandler(lsock)
437 handler = handlers.PUBHandler(lsock)
455 handler.setLevel(loglevel)
438 handler.setLevel(loglevel)
456 handler.root_topic = root
439 handler.root_topic = root
457 logger.addHandler(handler)
440 logger.addHandler(handler)
458 logger.setLevel(loglevel)
441 logger.setLevel(loglevel)
459
442
460 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
443 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
461 logger = logging.getLogger()
444 logger = logging.getLogger()
462 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
445 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
463 # don't add a second PUBHandler
446 # don't add a second PUBHandler
464 return
447 return
465 loglevel = integer_loglevel(loglevel)
448 loglevel = integer_loglevel(loglevel)
466 lsock = context.socket(zmq.PUB)
449 lsock = context.socket(zmq.PUB)
467 lsock.connect(iface)
450 lsock.connect(iface)
468 handler = EnginePUBHandler(engine, lsock)
451 handler = EnginePUBHandler(engine, lsock)
469 handler.setLevel(loglevel)
452 handler.setLevel(loglevel)
470 logger.addHandler(handler)
453 logger.addHandler(handler)
471 logger.setLevel(loglevel)
454 logger.setLevel(loglevel)
472
455
473 def local_logger(logname, loglevel=logging.DEBUG):
456 def local_logger(logname, loglevel=logging.DEBUG):
474 loglevel = integer_loglevel(loglevel)
457 loglevel = integer_loglevel(loglevel)
475 logger = logging.getLogger(logname)
458 logger = logging.getLogger(logname)
476 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
459 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
477 # don't add a second StreamHandler
460 # don't add a second StreamHandler
478 return
461 return
479 handler = logging.StreamHandler()
462 handler = logging.StreamHandler()
480 handler.setLevel(loglevel)
463 handler.setLevel(loglevel)
481 logger.addHandler(handler)
464 logger.addHandler(handler)
482 logger.setLevel(loglevel)
465 logger.setLevel(loglevel)
483
466
@@ -1,90 +1,121 b''
1 """Utilities to manipulate JSON objects.
1 """Utilities to manipulate JSON objects.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (C) 2010 The IPython Development Team
4 # Copyright (C) 2010 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING.txt, distributed as part of this software.
7 # the file COPYING.txt, distributed as part of this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # stdlib
13 # stdlib
14 import re
14 import types
15 import types
16 from datetime import datetime
17
18 #-----------------------------------------------------------------------------
19 # Globals and constants
20 #-----------------------------------------------------------------------------
21
22 # timestamp formats
23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
15
25
16 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
17 # Classes and functions
27 # Classes and functions
18 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
19
29
30 def extract_dates(obj):
31 """extract ISO8601 dates from unpacked JSON"""
32 if isinstance(obj, dict):
33 for k,v in obj.iteritems():
34 obj[k] = extract_dates(v)
35 elif isinstance(obj, list):
36 obj = [ extract_dates(o) for o in obj ]
37 elif isinstance(obj, basestring):
38 if ISO8601_PAT.match(obj):
39 obj = datetime.strptime(obj, ISO8601)
40 return obj
41
42 def date_default(obj):
43 """default function for packing datetime objects"""
44 if isinstance(obj, datetime):
45 return obj.strftime(ISO8601)
46 else:
47 raise TypeError("%r is not JSON serializable"%obj)
48
49
50
20 def json_clean(obj):
51 def json_clean(obj):
21 """Clean an object to ensure it's safe to encode in JSON.
52 """Clean an object to ensure it's safe to encode in JSON.
22
53
23 Atomic, immutable objects are returned unmodified. Sets and tuples are
54 Atomic, immutable objects are returned unmodified. Sets and tuples are
24 converted to lists, lists are copied and dicts are also copied.
55 converted to lists, lists are copied and dicts are also copied.
25
56
26 Note: dicts whose keys could cause collisions upon encoding (such as a dict
57 Note: dicts whose keys could cause collisions upon encoding (such as a dict
27 with both the number 1 and the string '1' as keys) will cause a ValueError
58 with both the number 1 and the string '1' as keys) will cause a ValueError
28 to be raised.
59 to be raised.
29
60
30 Parameters
61 Parameters
31 ----------
62 ----------
32 obj : any python object
63 obj : any python object
33
64
34 Returns
65 Returns
35 -------
66 -------
36 out : object
67 out : object
37
68
38 A version of the input which will not cause an encoding error when
69 A version of the input which will not cause an encoding error when
39 encoded as JSON. Note that this function does not *encode* its inputs,
70 encoded as JSON. Note that this function does not *encode* its inputs,
40 it simply sanitizes it so that there will be no encoding errors later.
71 it simply sanitizes it so that there will be no encoding errors later.
41
72
42 Examples
73 Examples
43 --------
74 --------
44 >>> json_clean(4)
75 >>> json_clean(4)
45 4
76 4
46 >>> json_clean(range(10))
77 >>> json_clean(range(10))
47 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
78 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
48 >>> json_clean(dict(x=1, y=2))
79 >>> json_clean(dict(x=1, y=2))
49 {'y': 2, 'x': 1}
80 {'y': 2, 'x': 1}
50 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
81 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
51 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
82 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
52 >>> json_clean(True)
83 >>> json_clean(True)
53 True
84 True
54 """
85 """
55 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
86 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
56 # listed explicitly because bools pass as int instances
87 # listed explicitly because bools pass as int instances
57 atomic_ok = (basestring, int, float, types.NoneType)
88 atomic_ok = (basestring, int, float, types.NoneType)
58
89
59 # containers that we need to convert into lists
90 # containers that we need to convert into lists
60 container_to_list = (tuple, set, types.GeneratorType)
91 container_to_list = (tuple, set, types.GeneratorType)
61
92
62 if isinstance(obj, atomic_ok):
93 if isinstance(obj, atomic_ok):
63 return obj
94 return obj
64
95
65 if isinstance(obj, container_to_list) or (
96 if isinstance(obj, container_to_list) or (
66 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
97 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
67 obj = list(obj)
98 obj = list(obj)
68
99
69 if isinstance(obj, list):
100 if isinstance(obj, list):
70 return [json_clean(x) for x in obj]
101 return [json_clean(x) for x in obj]
71
102
72 if isinstance(obj, dict):
103 if isinstance(obj, dict):
73 # First, validate that the dict won't lose data in conversion due to
104 # First, validate that the dict won't lose data in conversion due to
74 # key collisions after stringification. This can happen with keys like
105 # key collisions after stringification. This can happen with keys like
75 # True and 'true' or 1 and '1', which collide in JSON.
106 # True and 'true' or 1 and '1', which collide in JSON.
76 nkeys = len(obj)
107 nkeys = len(obj)
77 nkeys_collapsed = len(set(map(str, obj)))
108 nkeys_collapsed = len(set(map(str, obj)))
78 if nkeys != nkeys_collapsed:
109 if nkeys != nkeys_collapsed:
79 raise ValueError('dict can not be safely converted to JSON: '
110 raise ValueError('dict can not be safely converted to JSON: '
80 'key collision would lead to dropped values')
111 'key collision would lead to dropped values')
81 # If all OK, proceed by making the new dict that will be json-safe
112 # If all OK, proceed by making the new dict that will be json-safe
82 out = {}
113 out = {}
83 for k,v in obj.iteritems():
114 for k,v in obj.iteritems():
84 out[str(k)] = json_clean(v)
115 out[str(k)] = json_clean(v)
85 return out
116 return out
86
117
87 # If we get here, we don't know how to handle the object, so we just get
118 # If we get here, we don't know how to handle the object, so we just get
88 # its repr and return that. This will catch lambdas, open sockets, class
119 # its repr and return that. This will catch lambdas, open sockets, class
89 # objects, and any other complicated contraption that json can't encode
120 # objects, and any other complicated contraption that json can't encode
90 return repr(obj)
121 return repr(obj)
@@ -1,184 +1,479 b''
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
3 """
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
6 #
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
10
11 #-----------------------------------------------------------------------------
12 # Imports
13 #-----------------------------------------------------------------------------
14
15 import hmac
1 import os
16 import os
2 import uuid
3 import pprint
17 import pprint
18 import uuid
19 from datetime import datetime
20
21 try:
22 import cPickle
23 pickle = cPickle
24 except:
25 cPickle = None
26 import pickle
4
27
5 import zmq
28 import zmq
29 from zmq.utils import jsonapi
30 from zmq.eventloop.zmqstream import ZMQStream
31
32 from IPython.config.configurable import Configurable
33 from IPython.utils.importstring import import_item
34 from IPython.utils.jsonutil import date_default
35 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
36
37 #-----------------------------------------------------------------------------
38 # utility functions
39 #-----------------------------------------------------------------------------
40
41 def squash_unicode(obj):
42 """coerce unicode back to bytestrings."""
43 if isinstance(obj,dict):
44 for key in obj.keys():
45 obj[key] = squash_unicode(obj[key])
46 if isinstance(key, unicode):
47 obj[squash_unicode(key)] = obj.pop(key)
48 elif isinstance(obj, list):
49 for i,v in enumerate(obj):
50 obj[i] = squash_unicode(v)
51 elif isinstance(obj, unicode):
52 obj = obj.encode('utf8')
53 return obj
54
55 #-----------------------------------------------------------------------------
56 # globals and defaults
57 #-----------------------------------------------------------------------------
58
59 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
60 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:date_default})
61 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
6
62
7 from zmq.utils import jsonapi as json
63 pickle_packer = lambda o: pickle.dumps(o,-1)
64 pickle_unpacker = pickle.loads
65
66 default_packer = json_packer
67 default_unpacker = json_unpacker
68
69
70 DELIM="<IDS|MSG>"
71
72 #-----------------------------------------------------------------------------
73 # Classes
74 #-----------------------------------------------------------------------------
8
75
9 class Message(object):
76 class Message(object):
10 """A simple message object that maps dict keys to attributes.
77 """A simple message object that maps dict keys to attributes.
11
78
12 A Message can be created from a dict and a dict from a Message instance
79 A Message can be created from a dict and a dict from a Message instance
13 simply by calling dict(msg_obj)."""
80 simply by calling dict(msg_obj)."""
14
81
15 def __init__(self, msg_dict):
82 def __init__(self, msg_dict):
16 dct = self.__dict__
83 dct = self.__dict__
17 for k, v in msg_dict.iteritems():
84 for k, v in dict(msg_dict).iteritems():
18 if isinstance(v, dict):
85 if isinstance(v, dict):
19 v = Message(v)
86 v = Message(v)
20 dct[k] = v
87 dct[k] = v
21
88
22 # Having this iterator lets dict(msg_obj) work out of the box.
89 # Having this iterator lets dict(msg_obj) work out of the box.
23 def __iter__(self):
90 def __iter__(self):
24 return iter(self.__dict__.iteritems())
91 return iter(self.__dict__.iteritems())
25
92
26 def __repr__(self):
93 def __repr__(self):
27 return repr(self.__dict__)
94 return repr(self.__dict__)
28
95
29 def __str__(self):
96 def __str__(self):
30 return pprint.pformat(self.__dict__)
97 return pprint.pformat(self.__dict__)
31
98
32 def __contains__(self, k):
99 def __contains__(self, k):
33 return k in self.__dict__
100 return k in self.__dict__
34
101
35 def __getitem__(self, k):
102 def __getitem__(self, k):
36 return self.__dict__[k]
103 return self.__dict__[k]
37
104
38
105
39 def msg_header(msg_id, username, session):
106 def msg_header(msg_id, msg_type, username, session):
40 return {
107 date=datetime.now()
41 'msg_id' : msg_id,
108 return locals()
42 'username' : username,
43 'session' : session
44 }
45
46
109
47 def extract_header(msg_or_header):
110 def extract_header(msg_or_header):
48 """Given a message or header, return the header."""
111 """Given a message or header, return the header."""
49 if not msg_or_header:
112 if not msg_or_header:
50 return {}
113 return {}
51 try:
114 try:
52 # See if msg_or_header is the entire message.
115 # See if msg_or_header is the entire message.
53 h = msg_or_header['header']
116 h = msg_or_header['header']
54 except KeyError:
117 except KeyError:
55 try:
118 try:
56 # See if msg_or_header is just the header
119 # See if msg_or_header is just the header
57 h = msg_or_header['msg_id']
120 h = msg_or_header['msg_id']
58 except KeyError:
121 except KeyError:
59 raise
122 raise
60 else:
123 else:
61 h = msg_or_header
124 h = msg_or_header
62 if not isinstance(h, dict):
125 if not isinstance(h, dict):
63 h = dict(h)
126 h = dict(h)
64 return h
127 return h
65
128
66
129 class Session(Configurable):
67 class Session(object):
130 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
68
131 debug=Bool(False, config=True, help="""Debug output in the Session""")
69 def __init__(self, username=os.environ.get('USER','username'), session=None):
132 packer = Unicode('json',config=True,
70 self.username = username
133 help="""The name of the packer for serializing messages.
71 if session is None:
134 Should be one of 'json', 'pickle', or an import name
72 self.session = str(uuid.uuid4())
135 for a custom serializer.""")
136 def _packer_changed(self, name, old, new):
137 if new.lower() == 'json':
138 self.pack = json_packer
139 self.unpack = json_unpacker
140 elif new.lower() == 'pickle':
141 self.pack = pickle_packer
142 self.unpack = pickle_unpacker
73 else:
143 else:
74 self.session = session
144 self.pack = import_item(new)
75 self.msg_id = 0
76
145
77 def msg_header(self):
146 unpacker = Unicode('json',config=True,
78 h = msg_header(self.msg_id, self.username, self.session)
147 help="""The name of the unpacker for unserializing messages.
79 self.msg_id += 1
148 Only used with custom functions for `packer`.""")
80 return h
149 def _unpacker_changed(self, name, old, new):
150 if new.lower() == 'json':
151 self.pack = json_packer
152 self.unpack = json_unpacker
153 elif new.lower() == 'pickle':
154 self.pack = pickle_packer
155 self.unpack = pickle_unpacker
156 else:
157 self.unpack = import_item(new)
158
159 session = CStr('',config=True,
160 help="""The UUID identifying this session.""")
161 def _session_default(self):
162 return bytes(uuid.uuid4())
163 username = Unicode(os.environ.get('USER','username'), config=True,
164 help="""Username for the Session. Default is your system username.""")
165
166 # message signature related traits:
167 key = CStr('', config=True,
168 help="""execution key, for extra authentication.""")
169 def _key_changed(self, name, old, new):
170 if new:
171 self.auth = hmac.HMAC(new)
172 else:
173 self.auth = None
174 auth = Instance(hmac.HMAC)
175 counters = Instance('collections.defaultdict', (int,))
176 digest_history = Set()
177
178 keyfile = Unicode('', config=True,
179 help="""path to file containing execution key.""")
180 def _keyfile_changed(self, name, old, new):
181 with open(new, 'rb') as f:
182 self.key = f.read().strip()
81
183
82 def msg(self, msg_type, content=None, parent=None):
184 pack = Any(default_packer) # the actual packer function
83 """Construct a standard-form message, with a given type, content, and parent.
185 def _pack_changed(self, name, old, new):
186 if not callable(new):
187 raise TypeError("packer must be callable, not %s"%type(new))
84
188
85 NOT to be called directly.
189 unpack = Any(default_unpacker) # the actual packer function
86 """
190 def _unpack_changed(self, name, old, new):
191 if not callable(new):
192 raise TypeError("packer must be callable, not %s"%type(new))
193
194 def __init__(self, **kwargs):
195 super(Session, self).__init__(**kwargs)
196 self.none = self.pack({})
197
198 @property
199 def msg_id(self):
200 """always return new uuid"""
201 return str(uuid.uuid4())
202
203 def msg_header(self, msg_type):
204 return msg_header(self.msg_id, msg_type, self.username, self.session)
205
206 def msg(self, msg_type, content=None, parent=None, subheader=None):
87 msg = {}
207 msg = {}
88 msg['header'] = self.msg_header()
208 msg['header'] = self.msg_header(msg_type)
209 msg['msg_id'] = msg['header']['msg_id']
89 msg['parent_header'] = {} if parent is None else extract_header(parent)
210 msg['parent_header'] = {} if parent is None else extract_header(parent)
90 msg['msg_type'] = msg_type
211 msg['msg_type'] = msg_type
91 msg['content'] = {} if content is None else content
212 msg['content'] = {} if content is None else content
213 sub = {} if subheader is None else subheader
214 msg['header'].update(sub)
92 return msg
215 return msg
93
216
94 def send(self, socket, msg_or_type, content=None, parent=None, ident=None):
217 def check_key(self, msg_or_header):
95 """send a message via a socket, using a uniform message pattern.
218 """Check that a message's header has the right key"""
219 if not self.key:
220 return True
221 header = extract_header(msg_or_header)
222 return header.get('key', '') == self.key
223
224 def sign(self, msg):
225 """Sign a message with HMAC digest. If no auth, return b''."""
226 if self.auth is None:
227 return b''
228 h = self.auth.copy()
229 for m in msg:
230 h.update(m)
231 return h.hexdigest()
232
233 def serialize(self, msg, ident=None):
234 content = msg.get('content', {})
235 if content is None:
236 content = self.none
237 elif isinstance(content, dict):
238 content = self.pack(content)
239 elif isinstance(content, bytes):
240 # content is already packed, as in a relayed message
241 pass
242 elif isinstance(content, unicode):
243 # should be bytes, but JSON often spits out unicode
244 content = content.encode('utf8')
245 else:
246 raise TypeError("Content incorrect type: %s"%type(content))
247
248 real_message = [self.pack(msg['header']),
249 self.pack(msg['parent_header']),
250 content
251 ]
252
253 to_send = []
254
255 if isinstance(ident, list):
256 # accept list of idents
257 to_send.extend(ident)
258 elif ident is not None:
259 to_send.append(ident)
260 to_send.append(DELIM)
261
262 signature = self.sign(real_message)
263 to_send.append(signature)
264
265 to_send.extend(real_message)
266
267 return to_send
268
269 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
270 buffers=None, subheader=None, track=False):
271 """Build and send a message via stream or socket.
96
272
97 Parameters
273 Parameters
98 ----------
274 ----------
99 socket : zmq.Socket
275
100 The socket on which to send.
276 stream : zmq.Socket or ZMQStream
101 msg_or_type : Message/dict or str
277 the socket-like object used to send the data
102 if str : then a new message will be constructed from content,parent
278 msg_or_type : str or Message/dict
103 if Message/dict : then content and parent are ignored, and the message
279 Normally, msg_or_type will be a msg_type unless a message is being sent more
104 is sent. This is only for use when sending a Message for a second time.
280 than once.
105 content : dict, optional
281
106 The contents of the message
282 content : dict or None
107 parent : dict, optional
283 the content of the message (ignored if msg_or_type is a message)
108 The parent header, or parent message, of this message
284 parent : Message or dict or None
109 ident : bytes, optional
285 the parent or parent header describing the parent of this message
110 The zmq.IDENTITY prefix of the destination.
286 ident : bytes or list of bytes
111 Only for use on certain socket types.
287 the zmq.IDENTITY routing path
288 subheader : dict or None
289 extra header keys for this message's header
290 buffers : list or None
291 the already-serialized buffers to be appended to the message
292 track : bool
293 whether to track. Only for use with Sockets,
294 because ZMQStream objects cannot track messages.
112
295
113 Returns
296 Returns
114 -------
297 -------
115 msg : dict
298 msg : message dict
116 The message, as constructed by self.msg(msg_type,content,parent)
299 the constructed message
300 (msg,tracker) : (message dict, MessageTracker)
301 if track=True, then a 2-tuple will be returned,
302 the first element being the constructed
303 message, and the second being the MessageTracker
304
117 """
305 """
306
307 if not isinstance(stream, (zmq.Socket, ZMQStream)):
308 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
309 elif track and isinstance(stream, ZMQStream):
310 raise TypeError("ZMQStream cannot track messages")
311
118 if isinstance(msg_or_type, (Message, dict)):
312 if isinstance(msg_or_type, (Message, dict)):
119 msg = dict(msg_or_type)
313 # we got a Message, not a msg_type
314 # don't build a new Message
315 msg = msg_or_type
120 else:
316 else:
121 msg = self.msg(msg_or_type, content, parent)
317 msg = self.msg(msg_or_type, content, parent, subheader)
122 if ident is not None:
123 socket.send(ident, zmq.SNDMORE)
124 socket.send_json(msg)
125 return msg
126
127 def recv(self, socket, mode=zmq.NOBLOCK):
128 """recv a message on a socket.
129
318
130 Receive an optionally identity-prefixed message, as sent via session.send().
319 buffers = [] if buffers is None else buffers
320 to_send = self.serialize(msg, ident)
321 flag = 0
322 if buffers:
323 flag = zmq.SNDMORE
324 _track = False
325 else:
326 _track=track
327 if track:
328 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
329 else:
330 tracker = stream.send_multipart(to_send, flag, copy=False)
331 for b in buffers[:-1]:
332 stream.send(b, flag, copy=False)
333 if buffers:
334 if track:
335 tracker = stream.send(buffers[-1], copy=False, track=track)
336 else:
337 tracker = stream.send(buffers[-1], copy=False)
338
339 # omsg = Message(msg)
340 if self.debug:
341 pprint.pprint(msg)
342 pprint.pprint(to_send)
343 pprint.pprint(buffers)
131
344
132 Parameters
345 msg['tracker'] = tracker
133 ----------
134
346
135 socket : zmq.Socket
347 return msg
136 The socket on which to recv a message.
348
137 mode : int, optional
349 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
138 the mode flag passed to socket.recv
350 """Send a raw message via ident path.
139 default: zmq.NOBLOCK
140
351
141 Returns
352 Parameters
142 -------
353 ----------
143 (ident,msg) : tuple
354 msg : list of sendable buffers"""
144 always length 2. If no message received, then return is (None,None)
355 to_send = []
145 ident : bytes or None
356 if isinstance(ident, bytes):
146 the identity prefix is there was one, None otherwise.
357 ident = [ident]
147 msg : dict or None
358 if ident is not None:
148 The actual message. If mode==zmq.NOBLOCK and no message was waiting,
359 to_send.extend(ident)
149 it will be None.
360
150 """
361 to_send.append(DELIM)
362 to_send.append(self.sign(msg))
363 to_send.extend(msg)
364 stream.send_multipart(msg, flags, copy=copy)
365
366 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
367 """receives and unpacks a message
368 returns [idents], msg"""
369 if isinstance(socket, ZMQStream):
370 socket = socket.socket
151 try:
371 try:
152 msg = socket.recv_multipart(mode)
372 msg = socket.recv_multipart(mode)
153 except zmq.ZMQError, e:
373 except zmq.ZMQError as e:
154 if e.errno == zmq.EAGAIN:
374 if e.errno == zmq.EAGAIN:
155 # We can convert EAGAIN to None as we know in this case
375 # We can convert EAGAIN to None as we know in this case
156 # recv_json won't return None.
376 # recv_multipart won't return None.
157 return None,None
377 return None,None
158 else:
378 else:
159 raise
379 raise
160 if len(msg) == 1:
380 # return an actual Message object
161 ident=None
381 # determine the number of idents by trying to unpack them.
162 msg = msg[0]
382 # this is terrible:
163 elif len(msg) == 2:
383 idents, msg = self.feed_identities(msg, copy)
164 ident, msg = msg
384 try:
385 return idents, self.unpack_message(msg, content=content, copy=copy)
386 except Exception as e:
387 print (idents, msg)
388 # TODO: handle it
389 raise e
390
391 def feed_identities(self, msg, copy=True):
392 """feed until DELIM is reached, then return the prefix as idents and remainder as
393 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
394
395 Parameters
396 ----------
397 msg : a list of Message or bytes objects
398 the message to be split
399 copy : bool
400 flag determining whether the arguments are bytes or Messages
401
402 Returns
403 -------
404 (idents,msg) : two lists
405 idents will always be a list of bytes - the indentity prefix
406 msg will be a list of bytes or Messages, unchanged from input
407 msg should be unpackable via self.unpack_message at this point.
408 """
409 if copy:
410 idx = msg.index(DELIM)
411 return msg[:idx], msg[idx+1:]
412 else:
413 failed = True
414 for idx,m in enumerate(msg):
415 if m.bytes == DELIM:
416 failed = False
417 break
418 if failed:
419 raise ValueError("DELIM not in msg")
420 idents, msg = msg[:idx], msg[idx+1:]
421 return [m.bytes for m in idents], msg
422
423 def unpack_message(self, msg, content=True, copy=True):
424 """Return a message object from the format
425 sent by self.send.
426
427 Parameters:
428 -----------
429
430 content : bool (True)
431 whether to unpack the content dict (True),
432 or leave it serialized (False)
433
434 copy : bool (True)
435 whether to return the bytes (True),
436 or the non-copying Message object in each place (False)
437
438 """
439 minlen = 4
440 message = {}
441 if not copy:
442 for i in range(minlen):
443 msg[i] = msg[i].bytes
444 if self.auth is not None:
445 signature = msg[0]
446 if signature in self.digest_history:
447 raise ValueError("Duplicate Signature: %r"%signature)
448 self.digest_history.add(signature)
449 check = self.sign(msg[1:4])
450 if not signature == check:
451 raise ValueError("Invalid Signature: %r"%signature)
452 if not len(msg) >= minlen:
453 raise TypeError("malformed message, must have at least %i elements"%minlen)
454 message['header'] = self.unpack(msg[1])
455 message['msg_type'] = message['header']['msg_type']
456 message['parent_header'] = self.unpack(msg[2])
457 if content:
458 message['content'] = self.unpack(msg[3])
165 else:
459 else:
166 raise ValueError("Got message with length > 2, which is invalid")
460 message['content'] = msg[3]
167
461
168 return ident, json.loads(msg)
462 message['buffers'] = msg[4:]
463 return message
169
464
170 def test_msg2obj():
465 def test_msg2obj():
171 am = dict(x=1)
466 am = dict(x=1)
172 ao = Message(am)
467 ao = Message(am)
173 assert ao.x == am['x']
468 assert ao.x == am['x']
174
469
175 am['y'] = dict(z=1)
470 am['y'] = dict(z=1)
176 ao = Message(am)
471 ao = Message(am)
177 assert ao.y.z == am['y']['z']
472 assert ao.y.z == am['y']['z']
178
473
179 k1, k2 = 'y', 'z'
474 k1, k2 = 'y', 'z'
180 assert ao[k1][k2] == am[k1][k2]
475 assert ao[k1][k2] == am[k1][k2]
181
476
182 am2 = dict(ao)
477 am2 = dict(ao)
183 assert am['x'] == am2['x']
478 assert am['x'] == am2['x']
184 assert am['y']['z'] == am2['y']['z']
479 assert am['y']['z'] == am2['y']['z']
@@ -1,111 +1,111 b''
1 """test building messages with streamsession"""
1 """test building messages with streamsession"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import uuid
15 import uuid
16 import zmq
16 import zmq
17
17
18 from zmq.tests import BaseZMQTestCase
18 from zmq.tests import BaseZMQTestCase
19 from zmq.eventloop.zmqstream import ZMQStream
19 from zmq.eventloop.zmqstream import ZMQStream
20 # from IPython.zmq.tests import SessionTestCase
20
21 from IPython.parallel import streamsession as ss
21 from IPython.zmq import session as ss
22
22
23 class SessionTestCase(BaseZMQTestCase):
23 class SessionTestCase(BaseZMQTestCase):
24
24
25 def setUp(self):
25 def setUp(self):
26 BaseZMQTestCase.setUp(self)
26 BaseZMQTestCase.setUp(self)
27 self.session = ss.StreamSession()
27 self.session = ss.Session()
28
28
29 class TestSession(SessionTestCase):
29 class TestSession(SessionTestCase):
30
30
31 def test_msg(self):
31 def test_msg(self):
32 """message format"""
32 """message format"""
33 msg = self.session.msg('execute')
33 msg = self.session.msg('execute')
34 thekeys = set('header msg_id parent_header msg_type content'.split())
34 thekeys = set('header msg_id parent_header msg_type content'.split())
35 s = set(msg.keys())
35 s = set(msg.keys())
36 self.assertEquals(s, thekeys)
36 self.assertEquals(s, thekeys)
37 self.assertTrue(isinstance(msg['content'],dict))
37 self.assertTrue(isinstance(msg['content'],dict))
38 self.assertTrue(isinstance(msg['header'],dict))
38 self.assertTrue(isinstance(msg['header'],dict))
39 self.assertTrue(isinstance(msg['parent_header'],dict))
39 self.assertTrue(isinstance(msg['parent_header'],dict))
40 self.assertEquals(msg['msg_type'], 'execute')
40 self.assertEquals(msg['msg_type'], 'execute')
41
41
42
42
43
43
44 def test_args(self):
44 def test_args(self):
45 """initialization arguments for StreamSession"""
45 """initialization arguments for Session"""
46 s = self.session
46 s = self.session
47 self.assertTrue(s.pack is ss.default_packer)
47 self.assertTrue(s.pack is ss.default_packer)
48 self.assertTrue(s.unpack is ss.default_unpacker)
48 self.assertTrue(s.unpack is ss.default_unpacker)
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
50
50
51 s = ss.StreamSession()
51 s = ss.Session()
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
53
53
54 self.assertRaises(TypeError, ss.StreamSession, pack='hi')
54 self.assertRaises(TypeError, ss.Session, pack='hi')
55 self.assertRaises(TypeError, ss.StreamSession, unpack='hi')
55 self.assertRaises(TypeError, ss.Session, unpack='hi')
56 u = str(uuid.uuid4())
56 u = str(uuid.uuid4())
57 s = ss.StreamSession(username='carrot', session=u)
57 s = ss.Session(username='carrot', session=u)
58 self.assertEquals(s.session, u)
58 self.assertEquals(s.session, u)
59 self.assertEquals(s.username, 'carrot')
59 self.assertEquals(s.username, 'carrot')
60
60
61 def test_tracking(self):
61 def test_tracking(self):
62 """test tracking messages"""
62 """test tracking messages"""
63 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
63 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
64 s = self.session
64 s = self.session
65 stream = ZMQStream(a)
65 stream = ZMQStream(a)
66 msg = s.send(a, 'hello', track=False)
66 msg = s.send(a, 'hello', track=False)
67 self.assertTrue(msg['tracker'] is None)
67 self.assertTrue(msg['tracker'] is None)
68 msg = s.send(a, 'hello', track=True)
68 msg = s.send(a, 'hello', track=True)
69 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
69 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
70 M = zmq.Message(b'hi there', track=True)
70 M = zmq.Message(b'hi there', track=True)
71 msg = s.send(a, 'hello', buffers=[M], track=True)
71 msg = s.send(a, 'hello', buffers=[M], track=True)
72 t = msg['tracker']
72 t = msg['tracker']
73 self.assertTrue(isinstance(t, zmq.MessageTracker))
73 self.assertTrue(isinstance(t, zmq.MessageTracker))
74 self.assertRaises(zmq.NotDone, t.wait, .1)
74 self.assertRaises(zmq.NotDone, t.wait, .1)
75 del M
75 del M
76 t.wait(1) # this will raise
76 t.wait(1) # this will raise
77
77
78
78
79 # def test_rekey(self):
79 # def test_rekey(self):
80 # """rekeying dict around json str keys"""
80 # """rekeying dict around json str keys"""
81 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
81 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
82 # self.assertRaises(KeyError, ss.rekey, d)
82 # self.assertRaises(KeyError, ss.rekey, d)
83 #
83 #
84 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
84 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
85 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
85 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
86 # rd = ss.rekey(d)
86 # rd = ss.rekey(d)
87 # self.assertEquals(d2,rd)
87 # self.assertEquals(d2,rd)
88 #
88 #
89 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
89 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
90 # d2 = {1.5:d['1.5'],1:d['1']}
90 # d2 = {1.5:d['1.5'],1:d['1']}
91 # rd = ss.rekey(d)
91 # rd = ss.rekey(d)
92 # self.assertEquals(d2,rd)
92 # self.assertEquals(d2,rd)
93 #
93 #
94 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
94 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
95 # self.assertRaises(KeyError, ss.rekey, d)
95 # self.assertRaises(KeyError, ss.rekey, d)
96 #
96 #
97 def test_unique_msg_ids(self):
97 def test_unique_msg_ids(self):
98 """test that messages receive unique ids"""
98 """test that messages receive unique ids"""
99 ids = set()
99 ids = set()
100 for i in range(2**12):
100 for i in range(2**12):
101 h = self.session.msg_header('test')
101 h = self.session.msg_header('test')
102 msg_id = h['msg_id']
102 msg_id = h['msg_id']
103 self.assertTrue(msg_id not in ids)
103 self.assertTrue(msg_id not in ids)
104 ids.add(msg_id)
104 ids.add(msg_id)
105
105
106 def test_feed_identities(self):
106 def test_feed_identities(self):
107 """scrub the front for zmq IDENTITIES"""
107 """scrub the front for zmq IDENTITIES"""
108 theids = "engine client other".split()
108 theids = "engine client other".split()
109 content = dict(code='whoda',stuff=object())
109 content = dict(code='whoda',stuff=object())
110 themsg = self.session.msg('execute',content=content)
110 themsg = self.session.msg('execute',content=content)
111 pmsg = theids
111 pmsg = theids
1 NO CONTENT: file was removed
NO CONTENT: file was removed
1 NO CONTENT: file was removed
NO CONTENT: file was removed
General Comments 0
You need to be logged in to leave comments. Login now