merge IPython.parallel.streamsession into IPython.zmq.session...
MinRK
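
This commit replaces the parallel-specific StreamSession with the shared
IPython.zmq.session.Session wherever it is imported, constructed, or
configured. A minimal before/after sketch of the rename, assuming the
0.11-era traits visible in the diff (session, username, key, keyfile):

    # before the merge
    from IPython.parallel.streamsession import StreamSession
    session = StreamSession(username='me', key='secret')

    # after the merge
    from IPython.zmq.session import Session
    session = Session(username='me', key='secret')  # keyfile=... also accepted
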
@@ -1,402 +1,402 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import with_statement
19 19
20 20 import os
21 21 import socket
22 22 import stat
23 23 import sys
24 24 import uuid
25 25
26 26 from multiprocessing import Process
27 27
28 28 import zmq
29 29 from zmq.devices import ProcessMonitoredQueue
30 30 from zmq.log.handlers import PUBHandler
31 31 from zmq.utils import jsonapi as json
32 32
33 33 from IPython.config.application import boolean_flag
34 34 from IPython.core.newapplication import ProfileDir
35 35
36 36 from IPython.parallel.apps.baseapp import (
37 37 BaseParallelApplication,
38 38 base_flags
39 39 )
40 40 from IPython.utils.importstring import import_item
41 41 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
42 42
43 43 # from IPython.parallel.controller.controller import ControllerFactory
44 from IPython.parallel.streamsession import StreamSession
44 from IPython.zmq.session import Session
45 45 from IPython.parallel.controller.heartmonitor import HeartMonitor
46 46 from IPython.parallel.controller.hub import HubFactory
47 47 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
48 48 from IPython.parallel.controller.sqlitedb import SQLiteDB
49 49
50 50 from IPython.parallel.util import signal_children, split_url
51 51
52 52 # conditional import of MongoDB backend class
53 53
54 54 try:
55 55 from IPython.parallel.controller.mongodb import MongoDB
56 56 except ImportError:
57 57 maybe_mongo = []
58 58 else:
59 59 maybe_mongo = [MongoDB]
60 60
61 61
62 62 #-----------------------------------------------------------------------------
63 63 # Module level variables
64 64 #-----------------------------------------------------------------------------
65 65
66 66
67 67 #: The default config file name for this application
68 68 default_config_file_name = u'ipcontroller_config.py'
69 69
70 70
71 71 _description = """Start the IPython controller for parallel computing.
72 72
73 73 The IPython controller provides a gateway between the IPython engines and
74 74 clients. The controller needs to be started before the engines and can be
75 75 configured using command line options or using a cluster directory. Cluster
76 76 directories contain config, log and security files and are usually located in
77 77 your ipython directory and named "cluster_<profile>". See the `profile`
78 78 and `profile_dir` options for details.
79 79 """
80 80
81 81
82 82
83 83
84 84 #-----------------------------------------------------------------------------
85 85 # The main application
86 86 #-----------------------------------------------------------------------------
87 87 flags = {}
88 88 flags.update(base_flags)
89 89 flags.update({
90 90 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
91 91 'Use threads instead of processes for the schedulers'),
92 92 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
93 93 'use the SQLiteDB backend'),
94 94 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
95 95 'use the MongoDB backend'),
96 96 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
97 97 'use the in-memory DictDB backend'),
98 98 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
99 99 'reuse existing json connection files')
100 100 })
101 101
102 102 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
103 103 "Use HMAC digests for authentication of messages.",
104 104 "Don't authenticate messages."
105 105 ))
106 106
107 107 class IPControllerApp(BaseParallelApplication):
108 108
109 109 name = u'ipcontroller'
110 110 description = _description
111 111 config_file_name = Unicode(default_config_file_name)
112 classes = [ProfileDir, StreamSession, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
112 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
113 113
114 114 # change default to True
115 115 auto_create = Bool(True, config=True,
116 116 help="""Whether to create profile dir if it doesn't exist.""")
117 117
118 118 reuse_files = Bool(False, config=True,
119 119 help='Whether to reuse existing json connection files.'
120 120 )
121 121 secure = Bool(True, config=True,
122 122 help='Whether to use HMAC digests for extra message authentication.'
123 123 )
124 124 ssh_server = Unicode(u'', config=True,
125 125 help="""ssh url for clients to use when connecting to the Controller
126 126 processes. It should be of the form: [user@]server[:port]. The
127 127 Controller's listening addresses must be accessible from the ssh server""",
128 128 )
129 129 location = Unicode(u'', config=True,
130 130 help="""The external IP or domain name of the Controller, used for disambiguating
131 131 engine and client connections.""",
132 132 )
133 133 import_statements = List([], config=True,
134 134 help="import statements to be run at startup. Necessary in some environments"
135 135 )
136 136
137 137 use_threads = Bool(False, config=True,
138 138 help='Use threads instead of processes for the schedulers',
139 139 )
140 140
141 141 # internal
142 142 children = List()
143 143 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
144 144
145 145 def _use_threads_changed(self, name, old, new):
146 146 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
147 147
148 148 aliases = Dict(dict(
149 149 log_level = 'IPControllerApp.log_level',
150 150 log_url = 'IPControllerApp.log_url',
151 151 reuse_files = 'IPControllerApp.reuse_files',
152 152 secure = 'IPControllerApp.secure',
153 153 ssh = 'IPControllerApp.ssh_server',
154 154 use_threads = 'IPControllerApp.use_threads',
155 155 import_statements = 'IPControllerApp.import_statements',
156 156 location = 'IPControllerApp.location',
157 157
158 ident = 'StreamSession.session',
159 user = 'StreamSession.username',
160 exec_key = 'StreamSession.keyfile',
158 ident = 'Session.session',
159 user = 'Session.username',
160 exec_key = 'Session.keyfile',
161 161
162 162 url = 'HubFactory.url',
163 163 ip = 'HubFactory.ip',
164 164 transport = 'HubFactory.transport',
165 165 port = 'HubFactory.regport',
166 166
167 167 ping = 'HeartMonitor.period',
168 168
169 169 scheme = 'TaskScheduler.scheme_name',
170 170 hwm = 'TaskScheduler.hwm',
171 171
172 172
173 173 profile = "BaseIPythonApplication.profile",
174 174 profile_dir = 'ProfileDir.location',
175 175
176 176 ))
177 177 flags = Dict(flags)
178 178
179 179
180 180 def save_connection_dict(self, fname, cdict):
181 181 """save a connection dict to json file."""
182 182 c = self.config
183 183 url = cdict['url']
184 184 location = cdict['location']
185 185 if not location:
186 186 try:
187 187 proto,ip,port = split_url(url)
188 188 except AssertionError:
189 189 pass
190 190 else:
191 191 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
192 192 cdict['location'] = location
193 193 fname = os.path.join(self.profile_dir.security_dir, fname)
194 194 with open(fname, 'w') as f:
195 195 f.write(json.dumps(cdict, indent=2))
196 196 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
197 197
198 198 def load_config_from_json(self):
199 199 """load config from existing json connector files."""
200 200 c = self.config
201 201 # load from engine config
202 202 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
203 203 cfg = json.loads(f.read())
204 key = c.StreamSession.key = cfg['exec_key']
204 key = c.Session.key = cfg['exec_key']
205 205 xport,addr = cfg['url'].split('://')
206 206 c.HubFactory.engine_transport = xport
207 207 ip,ports = addr.split(':')
208 208 c.HubFactory.engine_ip = ip
209 209 c.HubFactory.regport = int(ports)
210 210 self.location = cfg['location']
211 211
212 212 # load client config
213 213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
214 214 cfg = json.loads(f.read())
215 215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
216 216 xport,addr = cfg['url'].split('://')
217 217 c.HubFactory.client_transport = xport
218 218 ip,ports = addr.split(':')
219 219 c.HubFactory.client_ip = ip
220 220 self.ssh_server = cfg['ssh']
221 221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
222 222
223 223 def init_hub(self):
224 224 c = self.config
225 225
226 226 self.do_import_statements()
227 227 reusing = self.reuse_files
228 228 if reusing:
229 229 try:
230 230 self.load_config_from_json()
231 231 except (AssertionError,IOError):
232 232 reusing=False
233 233 # check again, because reusing may have failed:
234 234 if reusing:
235 235 pass
236 236 elif self.secure:
237 237 key = str(uuid.uuid4())
238 238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
239 239 # with open(keyfile, 'w') as f:
240 240 # f.write(key)
241 241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 c.StreamSession.key = key
242 c.Session.key = key
243 243 else:
244 key = c.StreamSession.key = ''
244 key = c.Session.key = ''
245 245
246 246 try:
247 247 self.factory = HubFactory(config=c, log=self.log)
248 248 # self.start_logging()
249 249 self.factory.init_hub()
250 250 except:
251 251 self.log.error("Couldn't construct the Controller", exc_info=True)
252 252 self.exit(1)
253 253
254 254 if not reusing:
255 255 # save to new json config files
256 256 f = self.factory
257 257 cdict = {'exec_key' : key,
258 258 'ssh' : self.ssh_server,
259 259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
260 260 'location' : self.location
261 261 }
262 262 self.save_connection_dict('ipcontroller-client.json', cdict)
263 263 edict = cdict
264 264 edict['url'] = "%s://%s:%s"%(f.engine_transport, f.engine_ip, f.regport)
265 265 self.save_connection_dict('ipcontroller-engine.json', edict)
266 266
267 267 #
268 268 def init_schedulers(self):
269 269 children = self.children
270 270 mq = import_item(str(self.mq_class))
271 271
272 272 hub = self.factory
273 273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
274 274 # IOPub relay (in a Process)
275 275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
276 276 q.bind_in(hub.client_info['iopub'])
277 277 q.bind_out(hub.engine_info['iopub'])
278 278 q.setsockopt_out(zmq.SUBSCRIBE, '')
279 279 q.connect_mon(hub.monitor_url)
280 280 q.daemon=True
281 281 children.append(q)
282 282
283 283 # Multiplexer Queue (in a Process)
284 284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
285 285 q.bind_in(hub.client_info['mux'])
286 286 q.setsockopt_in(zmq.IDENTITY, 'mux')
287 287 q.bind_out(hub.engine_info['mux'])
288 288 q.connect_mon(hub.monitor_url)
289 289 q.daemon=True
290 290 children.append(q)
291 291
292 292 # Control Queue (in a Process)
293 293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
294 294 q.bind_in(hub.client_info['control'])
295 295 q.setsockopt_in(zmq.IDENTITY, 'control')
296 296 q.bind_out(hub.engine_info['control'])
297 297 q.connect_mon(hub.monitor_url)
298 298 q.daemon=True
299 299 children.append(q)
300 300 try:
301 301 scheme = self.config.TaskScheduler.scheme_name
302 302 except AttributeError:
303 303 scheme = TaskScheduler.scheme_name.get_default_value()
304 304 # Task Queue (in a Process)
305 305 if scheme == 'pure':
306 306 self.log.warn("task::using pure XREQ Task scheduler")
307 307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
308 308 # q.setsockopt_out(zmq.HWM, hub.hwm)
309 309 q.bind_in(hub.client_info['task'][1])
310 310 q.setsockopt_in(zmq.IDENTITY, 'task')
311 311 q.bind_out(hub.engine_info['task'])
312 312 q.connect_mon(hub.monitor_url)
313 313 q.daemon=True
314 314 children.append(q)
315 315 elif scheme == 'none':
316 316 self.log.warn("task::using no Task scheduler")
317 317
318 318 else:
319 319 self.log.info("task::using Python %s Task scheduler"%scheme)
320 320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
321 321 hub.monitor_url, hub.client_info['notification'])
322 322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
323 323 log_url = self.log_url, config=dict(self.config))
324 324 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
325 325 q.daemon=True
326 326 children.append(q)
327 327
328 328
329 329 def save_urls(self):
330 330 """save the registration urls to files."""
331 331 c = self.config
332 332
333 333 sec_dir = self.profile_dir.security_dir
334 334 cf = self.factory
335 335
336 336 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
337 337 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
338 338
339 339 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
340 340 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
341 341
342 342
343 343 def do_import_statements(self):
344 344 statements = self.import_statements
345 345 for s in statements:
346 346 try:
347 347 self.log.info("Executing statement: '%s'" % s)
348 348 exec s in globals(), locals()
349 349 except:
350 350 self.log.error("Error running statement: %s" % s)
351 351
352 352 def forward_logging(self):
353 353 if self.log_url:
354 354 self.log.info("Forwarding logging to %s"%self.log_url)
355 355 context = zmq.Context.instance()
356 356 lsock = context.socket(zmq.PUB)
357 357 lsock.connect(self.log_url)
358 358 handler = PUBHandler(lsock)
359 359 self.log.removeHandler(self._log_handler)
360 360 handler.root_topic = 'controller'
361 361 handler.setLevel(self.log_level)
362 362 self.log.addHandler(handler)
363 363 self._log_handler = handler
364 364 # #
365 365
366 366 def initialize(self, argv=None):
367 367 super(IPControllerApp, self).initialize(argv)
368 368 self.forward_logging()
369 369 self.init_hub()
370 370 self.init_schedulers()
371 371
372 372 def start(self):
373 373 # Start the subprocesses:
374 374 self.factory.start()
375 375 child_procs = []
376 376 for child in self.children:
377 377 child.start()
378 378 if isinstance(child, ProcessMonitoredQueue):
379 379 child_procs.append(child.launcher)
380 380 elif isinstance(child, Process):
381 381 child_procs.append(child)
382 382 if child_procs:
383 383 signal_children(child_procs)
384 384
385 385 self.write_pid_file(overwrite=True)
386 386
387 387 try:
388 388 self.factory.loop.start()
389 389 except KeyboardInterrupt:
390 390 self.log.critical("Interrupted, Exiting...\n")
391 391
392 392
393 393
394 394 def launch_new_instance():
395 395 """Create and run the IPython controller"""
396 396 app = IPControllerApp.instance()
397 397 app.initialize()
398 398 app.start()
399 399
400 400
401 401 if __name__ == '__main__':
402 402 launch_new_instance()
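
For the controller the merge is purely a config-prefix change: the ident,
user, and exec_key aliases and the generated HMAC key now target Session
instead of StreamSession. A sketch of an ipcontroller_config.py using the
new names (the key value is illustrative):

    # ipcontroller_config.py -- sketch, assuming the post-merge trait names
    c = get_config()
    c.Session.key = 'my-shared-hmac-key'     # was: c.StreamSession.key
    c.Session.username = 'controller-user'   # was: c.StreamSession.username

The same settings remain reachable on the command line through the --ident,
--user, and --exec_key aliases defined above.
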
@@ -1,270 +1,270 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import json
19 19 import os
20 20 import sys
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24
25 25 from IPython.core.newapplication import ProfileDir
26 26 from IPython.parallel.apps.baseapp import BaseParallelApplication
27 27 from IPython.zmq.log import EnginePUBHandler
28 28
29 29 from IPython.config.configurable import Configurable
30 from IPython.parallel.streamsession import StreamSession
30 from IPython.zmq.session import Session
31 31 from IPython.parallel.engine.engine import EngineFactory
32 32 from IPython.parallel.engine.streamkernel import Kernel
33 33 from IPython.parallel.util import disambiguate_url
34 34
35 35 from IPython.utils.importstring import import_item
36 36 from IPython.utils.traitlets import Bool, Unicode, Dict, List
37 37
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Module level variables
41 41 #-----------------------------------------------------------------------------
42 42
43 43 #: The default config file name for this application
44 44 default_config_file_name = u'ipengine_config.py'
45 45
46 46 _description = """Start an IPython engine for parallel computing.
47 47
48 48 IPython engines run in parallel and perform computations on behalf of a client
49 49 and controller. A controller needs to be started before the engines. The
50 50 engine can be configured using command line options or using a cluster
51 51 directory. Cluster directories contain config, log and security files and are
52 52 usually located in your ipython directory and named "cluster_<profile>".
53 53 See the `profile` and `profile_dir` options for details.
54 54 """
55 55
56 56
57 57 #-----------------------------------------------------------------------------
58 58 # MPI configuration
59 59 #-----------------------------------------------------------------------------
60 60
61 61 mpi4py_init = """from mpi4py import MPI as mpi
62 62 mpi.size = mpi.COMM_WORLD.Get_size()
63 63 mpi.rank = mpi.COMM_WORLD.Get_rank()
64 64 """
65 65
66 66
67 67 pytrilinos_init = """from PyTrilinos import Epetra
68 68 class SimpleStruct:
69 69 pass
70 70 mpi = SimpleStruct()
71 71 mpi.rank = 0
72 72 mpi.size = 0
73 73 """
74 74
75 75 class MPI(Configurable):
76 76 """Configurable for MPI initialization"""
77 77 use = Unicode('', config=True,
78 78 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
79 79 )
80 80
81 81 def _on_use_changed(self, old, new):
82 82 # load default init script if it's not set
83 83 if not self.init_script:
84 84 self.init_script = self.default_inits.get(new, '')
85 85
86 86 init_script = Unicode('', config=True,
87 87 help="Initialization code for MPI")
88 88
89 89 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
90 90 config=True)
91 91
92 92
93 93 #-----------------------------------------------------------------------------
94 94 # Main application
95 95 #-----------------------------------------------------------------------------
96 96
97 97
98 98 class IPEngineApp(BaseParallelApplication):
99 99
100 100 app_name = Unicode(u'ipengine')
101 101 description = Unicode(_description)
102 102 config_file_name = Unicode(default_config_file_name)
103 classes = List([ProfileDir, StreamSession, EngineFactory, Kernel, MPI])
103 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
104 104
105 105 startup_script = Unicode(u'', config=True,
106 106 help='specify a script to be run at startup')
107 107 startup_command = Unicode('', config=True,
108 108 help='specify a command to be run at startup')
109 109
110 110 url_file = Unicode(u'', config=True,
111 111 help="""The full location of the file containing the connection information for
112 112 the controller. If this is not given, the file must be in the
113 113 security directory of the cluster directory. This location is
114 114 resolved using the `profile` or `profile_dir` options.""",
115 115 )
116 116
117 117 url_file_name = Unicode(u'ipcontroller-engine.json')
118 118 log_url = Unicode('', config=True,
119 119 help="""The URL for the iploggerapp instance, for forwarding
120 120 logging to a central location.""")
121 121
122 122 aliases = Dict(dict(
123 123 file = 'IPEngineApp.url_file',
124 124 c = 'IPEngineApp.startup_command',
125 125 s = 'IPEngineApp.startup_script',
126 126
127 ident = 'StreamSession.session',
128 user = 'StreamSession.username',
129 exec_key = 'StreamSession.keyfile',
127 ident = 'Session.session',
128 user = 'Session.username',
129 exec_key = 'Session.keyfile',
130 130
131 131 url = 'EngineFactory.url',
132 132 ip = 'EngineFactory.ip',
133 133 transport = 'EngineFactory.transport',
134 134 port = 'EngineFactory.regport',
135 135 location = 'EngineFactory.location',
136 136
137 137 timeout = 'EngineFactory.timeout',
138 138
139 139 profile = "IPEngineApp.profile",
140 140 profile_dir = 'ProfileDir.location',
141 141
142 142 mpi = 'MPI.use',
143 143
144 144 log_level = 'IPEngineApp.log_level',
145 145 log_url = 'IPEngineApp.log_url'
146 146 ))
147 147
148 148 # def find_key_file(self):
149 149 # """Set the key file.
150 150 #
151 151 # Here we don't try to actually see if it exists or is valid, as that
152 152 # is handled by the connection logic.
153 153 # """
154 154 # config = self.master_config
155 155 # # Find the actual controller key file
156 156 # if not config.Global.key_file:
157 157 # try_this = os.path.join(
158 158 # config.Global.profile_dir,
159 159 # config.Global.security_dir,
160 160 # config.Global.key_file_name
161 161 # )
162 162 # config.Global.key_file = try_this
163 163
164 164 def find_url_file(self):
165 165 """Set the key file.
166 166
167 167 Here we don't try to actually see if it exists for is valid as that
168 168 is hadled by the connection logic.
169 169 """
170 170 config = self.config
171 171 # Find the actual controller key file
172 172 if not self.url_file:
173 173 self.url_file = os.path.join(
174 174 self.profile_dir.security_dir,
175 175 self.url_file_name
176 176 )
177 177 def init_engine(self):
178 178 # This is the working dir by now.
179 179 sys.path.insert(0, '')
180 180 config = self.config
181 181 # print config
182 182 self.find_url_file()
183 183
184 184 # if os.path.exists(config.Global.key_file) and config.Global.secure:
185 185 # config.SessionFactory.exec_key = config.Global.key_file
186 186 if os.path.exists(self.url_file):
187 187 with open(self.url_file) as f:
188 188 d = json.loads(f.read())
189 189 for k,v in d.iteritems():
190 190 if isinstance(v, unicode):
191 191 d[k] = v.encode()
192 192 if d['exec_key']:
193 config.StreamSession.key = d['exec_key']
193 config.Session.key = d['exec_key']
194 194 d['url'] = disambiguate_url(d['url'], d['location'])
195 195 config.EngineFactory.url = d['url']
196 196 config.EngineFactory.location = d['location']
197 197
198 198 try:
199 199 exec_lines = config.Kernel.exec_lines
200 200 except AttributeError:
201 201 config.Kernel.exec_lines = []
202 202 exec_lines = config.Kernel.exec_lines
203 203
204 204 if self.startup_script:
205 205 enc = sys.getfilesystemencoding() or 'utf8'
206 206 cmd="execfile(%r)"%self.startup_script.encode(enc)
207 207 exec_lines.append(cmd)
208 208 if self.startup_command:
209 209 exec_lines.append(self.startup_command)
210 210
211 211 # Create the underlying shell class and Engine
212 212 # shell_class = import_item(self.master_config.Global.shell_class)
213 213 # print self.config
214 214 try:
215 215 self.engine = EngineFactory(config=config, log=self.log)
216 216 except:
217 217 self.log.error("Couldn't start the Engine", exc_info=True)
218 218 self.exit(1)
219 219
220 220 def forward_logging(self):
221 221 if self.log_url:
222 222 self.log.info("Forwarding logging to %s"%self.log_url)
223 223 context = self.engine.context
224 224 lsock = context.socket(zmq.PUB)
225 225 lsock.connect(self.log_url)
226 226 self.log.removeHandler(self._log_handler)
227 227 handler = EnginePUBHandler(self.engine, lsock)
228 228 handler.setLevel(self.log_level)
229 229 self.log.addHandler(handler)
230 230 self._log_handler = handler
231 231 #
232 232 def init_mpi(self):
233 233 global mpi
234 234 self.mpi = MPI(config=self.config)
235 235
236 236 mpi_import_statement = self.mpi.init_script
237 237 if mpi_import_statement:
238 238 try:
239 239 self.log.info("Initializing MPI:")
240 240 self.log.info(mpi_import_statement)
241 241 exec mpi_import_statement in globals()
242 242 except:
243 243 mpi = None
244 244 else:
245 245 mpi = None
246 246
247 247 def initialize(self, argv=None):
248 248 super(IPEngineApp, self).initialize(argv)
249 249 self.init_mpi()
250 250 self.init_engine()
251 251 self.forward_logging()
252 252
253 253 def start(self):
254 254 self.engine.start()
255 255 try:
256 256 self.engine.loop.start()
257 257 except KeyboardInterrupt:
258 258 self.log.critical("Engine Interrupted, shutting down...\n")
259 259
260 260
261 261 def launch_new_instance():
262 262 """Create and run the IPython engine"""
263 263 app = IPEngineApp.instance()
264 264 app.initialize()
265 265 app.start()
266 266
267 267
268 268 if __name__ == '__main__':
269 269 launch_new_instance()
270 270
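
The engine applies the same rename when loading the controller's connection
file: the exec_key from ipcontroller-engine.json now lands in Session.key. A
hand-rolled sketch of that step, assuming the 0.11-era Config and Session
APIs (the file path matches the diff; other names are illustrative):

    import json
    from IPython.config.loader import Config
    from IPython.zmq.session import Session

    c = Config()
    with open('ipcontroller-engine.json') as f:
        d = json.load(f)
    if d['exec_key']:
        c.Session.key = d['exec_key']   # was: c.StreamSession.key
    session = Session(config=c)
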
@@ -1,1356 +1,1353 b''
1 1 """A semi-synchronous Client for the ZMQ cluster"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 from IPython.utils.jsonutil import extract_dates
26 27 from IPython.utils.path import get_ipython_dir
27 28 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
28 29 Dict, List, Bool, Set)
29 30 from IPython.external.decorator import decorator
30 31 from IPython.external.ssh import tunnel
31 32
32 33 from IPython.parallel import error
33 from IPython.parallel import streamsession as ss
34 34 from IPython.parallel import util
35 35
36 from IPython.zmq.session import Session, Message
37
36 38 from .asyncresult import AsyncResult, AsyncHubResult
37 39 from IPython.core.newapplication import ProfileDir, ProfileDirError
38 40 from .view import DirectView, LoadBalancedView
39 41
40 42 #--------------------------------------------------------------------------
41 43 # Decorators for Client methods
42 44 #--------------------------------------------------------------------------
43 45
44 46 @decorator
45 47 def spin_first(f, self, *args, **kwargs):
46 48 """Call spin() to sync state prior to calling the method."""
47 49 self.spin()
48 50 return f(self, *args, **kwargs)
49 51
50 52
51 53 #--------------------------------------------------------------------------
52 54 # Classes
53 55 #--------------------------------------------------------------------------
54 56
55 57 class Metadata(dict):
56 58 """Subclass of dict for initializing metadata values.
57 59
58 60 Attribute access works on keys.
59 61
60 62 These objects have a strict set of keys - errors will raise if you try
61 63 to add new keys.
62 64 """
63 65 def __init__(self, *args, **kwargs):
64 66 dict.__init__(self)
65 67 md = {'msg_id' : None,
66 68 'submitted' : None,
67 69 'started' : None,
68 70 'completed' : None,
69 71 'received' : None,
70 72 'engine_uuid' : None,
71 73 'engine_id' : None,
72 74 'follow' : None,
73 75 'after' : None,
74 76 'status' : None,
75 77
76 78 'pyin' : None,
77 79 'pyout' : None,
78 80 'pyerr' : None,
79 81 'stdout' : '',
80 82 'stderr' : '',
81 83 }
82 84 self.update(md)
83 85 self.update(dict(*args, **kwargs))
84 86
85 87 def __getattr__(self, key):
86 88 """getattr aliased to getitem"""
87 89 if key in self.iterkeys():
88 90 return self[key]
89 91 else:
90 92 raise AttributeError(key)
91 93
92 94 def __setattr__(self, key, value):
93 95 """setattr aliased to setitem, with strict"""
94 96 if key in self.iterkeys():
95 97 self[key] = value
96 98 else:
97 99 raise AttributeError(key)
98 100
99 101 def __setitem__(self, key, value):
100 102 """strict static key enforcement"""
101 103 if key in self.iterkeys():
102 104 dict.__setitem__(self, key, value)
103 105 else:
104 106 raise KeyError(key)
105 107
106 108
107 109 class Client(HasTraits):
108 110 """A semi-synchronous client to the IPython ZMQ cluster
109 111
110 112 Parameters
111 113 ----------
112 114
113 115 url_or_file : bytes; zmq url or path to ipcontroller-client.json
114 116 Connection information for the Hub's registration. If a json connector
115 117 file is given, then likely no further configuration is necessary.
116 118 [Default: use profile]
117 119 profile : bytes
118 120 The name of the Cluster profile to be used to find connector information.
119 121 [Default: 'default']
120 122 context : zmq.Context
121 123 Pass an existing zmq.Context instance, otherwise the client will create its own.
122 124 username : bytes
123 125 set username to be passed to the Session object
124 126 debug : bool
125 127 flag for lots of message printing for debug purposes
126 128
127 129 #-------------- ssh related args ----------------
128 130 # These are args for configuring the ssh tunnel to be used
129 131 # credentials are used to forward connections over ssh to the Controller
130 132 # Note that the ip given in `addr` needs to be relative to sshserver
131 133 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
132 134 # and set sshserver as the same machine the Controller is on. However,
133 135 # the only requirement is that sshserver is able to see the Controller
134 136 # (i.e. is within the same trusted network).
135 137
136 138 sshserver : str
137 139 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
138 140 If keyfile or password is specified, and this is not, it will default to
139 141 the ip given in addr.
140 142 sshkey : str; path to public ssh key file
141 143 This specifies a key to be used in ssh login, default None.
142 144 Regular default ssh keys will be used without specifying this argument.
143 145 password : str
144 146 Your ssh password to sshserver. Note that if this is left None,
145 147 you will be prompted for it if passwordless key based login is unavailable.
146 148 paramiko : bool
147 149 flag for whether to use paramiko instead of shell ssh for tunneling.
148 150 [default: True on win32, False else]
149 151
150 152 ------- exec authentication args -------
151 153 If even localhost is untrusted, you can have some protection against
152 154 unauthorized execution by using a key. Messages are still sent
153 155 as cleartext, so if someone can snoop your loopback traffic this will
154 156 not help against malicious attacks.
155 157
156 158 exec_key : str
157 159 an authentication key or file containing a key
158 160 default: None
159 161
160 162
161 163 Attributes
162 164 ----------
163 165
164 166 ids : list of int engine IDs
165 167 requesting the ids attribute always synchronizes
166 168 the registration state. To request ids without synchronization,
167 169 use semi-private _ids attributes.
168 170
169 171 history : list of msg_ids
170 172 a list of msg_ids, keeping track of all the execution
171 173 messages you have submitted in order.
172 174
173 175 outstanding : set of msg_ids
174 176 a set of msg_ids that have been submitted, but whose
175 177 results have not yet been received.
176 178
177 179 results : dict
178 180 a dict of all our results, keyed by msg_id
179 181
180 182 block : bool
181 183 determines default behavior when block not specified
182 184 in execution methods
183 185
184 186 Methods
185 187 -------
186 188
187 189 spin
188 190 flushes incoming results and registration state changes
189 191 control methods spin, and requesting `ids` also ensures up to date
190 192
191 193 wait
192 194 wait on one or more msg_ids
193 195
194 196 execution methods
195 197 apply
196 198 legacy: execute, run
197 199
198 200 data movement
199 201 push, pull, scatter, gather
200 202
201 203 query methods
202 204 queue_status, get_result, purge, result_status
203 205
204 206 control methods
205 207 abort, shutdown
206 208
207 209 """
208 210
209 211
210 212 block = Bool(False)
211 213 outstanding = Set()
212 214 results = Instance('collections.defaultdict', (dict,))
213 215 metadata = Instance('collections.defaultdict', (Metadata,))
214 216 history = List()
215 217 debug = Bool(False)
216 218 profile=Unicode('default')
217 219
218 220 _outstanding_dict = Instance('collections.defaultdict', (set,))
219 221 _ids = List()
220 222 _connected=Bool(False)
221 223 _ssh=Bool(False)
222 224 _context = Instance('zmq.Context')
223 225 _config = Dict()
224 226 _engines=Instance(util.ReverseDict, (), {})
225 227 # _hub_socket=Instance('zmq.Socket')
226 228 _query_socket=Instance('zmq.Socket')
227 229 _control_socket=Instance('zmq.Socket')
228 230 _iopub_socket=Instance('zmq.Socket')
229 231 _notification_socket=Instance('zmq.Socket')
230 232 _mux_socket=Instance('zmq.Socket')
231 233 _task_socket=Instance('zmq.Socket')
232 234 _task_scheme=Unicode()
233 235 _closed = False
234 236 _ignored_control_replies=Int(0)
235 237 _ignored_hub_replies=Int(0)
236 238
237 239 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
238 240 context=None, username=None, debug=False, exec_key=None,
239 241 sshserver=None, sshkey=None, password=None, paramiko=None,
240 242 timeout=10
241 243 ):
242 244 super(Client, self).__init__(debug=debug, profile=profile)
243 245 if context is None:
244 246 context = zmq.Context.instance()
245 247 self._context = context
246 248
247 249
248 250 self._setup_profile_dir(profile, profile_dir, ipython_dir)
249 251 if self._cd is not None:
250 252 if url_or_file is None:
251 253 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
252 254 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
253 255 " Please specify at least one of url_or_file or profile."
254 256
255 257 try:
256 258 util.validate_url(url_or_file)
257 259 except AssertionError:
258 260 if not os.path.exists(url_or_file):
259 261 if self._cd:
260 262 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
261 263 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
262 264 with open(url_or_file) as f:
263 265 cfg = json.loads(f.read())
264 266 else:
265 267 cfg = {'url':url_or_file}
266 268
267 269 # sync defaults from args, json:
268 270 if sshserver:
269 271 cfg['ssh'] = sshserver
270 272 if exec_key:
271 273 cfg['exec_key'] = exec_key
272 274 exec_key = cfg['exec_key']
273 275 sshserver=cfg['ssh']
274 276 url = cfg['url']
275 277 location = cfg.setdefault('location', None)
276 278 cfg['url'] = util.disambiguate_url(cfg['url'], location)
277 279 url = cfg['url']
278 280
279 281 self._config = cfg
280 282
281 283 self._ssh = bool(sshserver or sshkey or password)
282 284 if self._ssh and sshserver is None:
283 285 # default to ssh via localhost
284 286 sshserver = url.split('://')[1].split(':')[0]
285 287 if self._ssh and password is None:
286 288 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
287 289 password=False
288 290 else:
289 291 password = getpass("SSH Password for %s: "%sshserver)
290 292 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
291 293 if exec_key is not None and os.path.isfile(exec_key):
292 294 arg = 'keyfile'
293 295 else:
294 296 arg = 'key'
295 297 key_arg = {arg:exec_key}
296 298 if username is None:
297 self.session = ss.StreamSession(**key_arg)
299 self.session = Session(**key_arg)
298 300 else:
299 self.session = ss.StreamSession(username=username, **key_arg)
301 self.session = Session(username=username, **key_arg)
300 302 self._query_socket = self._context.socket(zmq.XREQ)
301 303 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
302 304 if self._ssh:
303 305 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
304 306 else:
305 307 self._query_socket.connect(url)
306 308
307 309 self.session.debug = self.debug
308 310
309 311 self._notification_handlers = {'registration_notification' : self._register_engine,
310 312 'unregistration_notification' : self._unregister_engine,
311 313 'shutdown_notification' : lambda msg: self.close(),
312 314 }
313 315 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
314 316 'apply_reply' : self._handle_apply_reply}
315 317 self._connect(sshserver, ssh_kwargs, timeout)
316 318
317 319 def __del__(self):
318 320 """cleanup sockets, but _not_ context."""
319 321 self.close()
320 322
321 323 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
322 324 if ipython_dir is None:
323 325 ipython_dir = get_ipython_dir()
324 326 if profile_dir is not None:
325 327 try:
326 328 self._cd = ProfileDir.find_profile_dir(profile_dir)
327 329 return
328 330 except ProfileDirError:
329 331 pass
330 332 elif profile is not None:
331 333 try:
332 334 self._cd = ProfileDir.find_profile_dir_by_name(
333 335 ipython_dir, profile)
334 336 return
335 337 except ProfileDirError:
336 338 pass
337 339 self._cd = None
338 340
339 341 def _update_engines(self, engines):
340 342 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
341 343 for k,v in engines.iteritems():
342 344 eid = int(k)
343 345 self._engines[eid] = bytes(v) # force not unicode
344 346 self._ids.append(eid)
345 347 self._ids = sorted(self._ids)
346 348 if sorted(self._engines.keys()) != range(len(self._engines)) and \
347 349 self._task_scheme == 'pure' and self._task_socket:
348 350 self._stop_scheduling_tasks()
349 351
350 352 def _stop_scheduling_tasks(self):
351 353 """Stop scheduling tasks because an engine has been unregistered
352 354 from a pure ZMQ scheduler.
353 355 """
354 356 self._task_socket.close()
355 357 self._task_socket = None
356 358 msg = "An engine has been unregistered, and we are using pure " +\
357 359 "ZMQ task scheduling. Task farming will be disabled."
358 360 if self.outstanding:
359 361 msg += " If you were running tasks when this happened, " +\
360 362 "some `outstanding` msg_ids may never resolve."
361 363 warnings.warn(msg, RuntimeWarning)
362 364
363 365 def _build_targets(self, targets):
364 366 """Turn valid target IDs or 'all' into two lists:
365 367 (int_ids, uuids).
366 368 """
367 369 if not self._ids:
368 370 # flush notification socket if no engines yet, just in case
369 371 if not self.ids:
370 372 raise error.NoEnginesRegistered("Can't build targets without any engines")
371 373
372 374 if targets is None:
373 375 targets = self._ids
374 376 elif isinstance(targets, str):
375 377 if targets.lower() == 'all':
376 378 targets = self._ids
377 379 else:
378 380 raise TypeError("%r not valid str target, must be 'all'"%(targets))
379 381 elif isinstance(targets, int):
380 382 if targets < 0:
381 383 targets = self.ids[targets]
382 384 if targets not in self._ids:
383 385 raise IndexError("No such engine: %i"%targets)
384 386 targets = [targets]
385 387
386 388 if isinstance(targets, slice):
387 389 indices = range(len(self._ids))[targets]
388 390 ids = self.ids
389 391 targets = [ ids[i] for i in indices ]
390 392
391 393 if not isinstance(targets, (tuple, list, xrange)):
392 394 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
393 395
394 396 return [self._engines[t] for t in targets], list(targets)
395 397
396 398 def _connect(self, sshserver, ssh_kwargs, timeout):
397 399 """setup all our socket connections to the cluster. This is called from
398 400 __init__."""
399 401
400 402 # Maybe allow reconnecting?
401 403 if self._connected:
402 404 return
403 405 self._connected=True
404 406
405 407 def connect_socket(s, url):
406 408 url = util.disambiguate_url(url, self._config['location'])
407 409 if self._ssh:
408 410 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
409 411 else:
410 412 return s.connect(url)
411 413
412 414 self.session.send(self._query_socket, 'connection_request')
413 415 r,w,x = zmq.select([self._query_socket],[],[], timeout)
414 416 if not r:
415 417 raise error.TimeoutError("Hub connection request timed out")
416 418 idents,msg = self.session.recv(self._query_socket,mode=0)
417 419 if self.debug:
418 420 pprint(msg)
419 msg = ss.Message(msg)
421 msg = Message(msg)
420 422 content = msg.content
421 423 self._config['registration'] = dict(content)
422 424 if content.status == 'ok':
423 425 if content.mux:
424 426 self._mux_socket = self._context.socket(zmq.XREQ)
425 427 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
426 428 connect_socket(self._mux_socket, content.mux)
427 429 if content.task:
428 430 self._task_scheme, task_addr = content.task
429 431 self._task_socket = self._context.socket(zmq.XREQ)
430 432 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
431 433 connect_socket(self._task_socket, task_addr)
432 434 if content.notification:
433 435 self._notification_socket = self._context.socket(zmq.SUB)
434 436 connect_socket(self._notification_socket, content.notification)
435 437 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
436 438 # if content.query:
437 439 # self._query_socket = self._context.socket(zmq.XREQ)
438 440 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
439 441 # connect_socket(self._query_socket, content.query)
440 442 if content.control:
441 443 self._control_socket = self._context.socket(zmq.XREQ)
442 444 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
443 445 connect_socket(self._control_socket, content.control)
444 446 if content.iopub:
445 447 self._iopub_socket = self._context.socket(zmq.SUB)
446 448 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
447 449 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
448 450 connect_socket(self._iopub_socket, content.iopub)
449 451 self._update_engines(dict(content.engines))
450 452 else:
451 453 self._connected = False
452 454 raise Exception("Failed to connect!")
453 455
454 456 #--------------------------------------------------------------------------
455 457 # handlers and callbacks for incoming messages
456 458 #--------------------------------------------------------------------------
457 459
458 460 def _unwrap_exception(self, content):
459 461 """unwrap exception, and remap engine_id to int."""
460 462 e = error.unwrap_exception(content)
461 463 # print e.traceback
462 464 if e.engine_info:
463 465 e_uuid = e.engine_info['engine_uuid']
464 466 eid = self._engines[e_uuid]
465 467 e.engine_info['engine_id'] = eid
466 468 return e
467 469
468 470 def _extract_metadata(self, header, parent, content):
469 471 md = {'msg_id' : parent['msg_id'],
470 472 'received' : datetime.now(),
471 473 'engine_uuid' : header.get('engine', None),
472 474 'follow' : parent.get('follow', []),
473 475 'after' : parent.get('after', []),
474 476 'status' : content['status'],
475 477 }
476 478
477 479 if md['engine_uuid'] is not None:
478 480 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
479 481
480 482 if 'date' in parent:
481 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
483 md['submitted'] = parent['date']
482 484 if 'started' in header:
483 md['started'] = datetime.strptime(header['started'], util.ISO8601)
485 md['started'] = header['started']
484 486 if 'date' in header:
485 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
487 md['completed'] = header['date']
486 488 return md
487 489
488 490 def _register_engine(self, msg):
489 491 """Register a new engine, and update our connection info."""
490 492 content = msg['content']
491 493 eid = content['id']
492 494 d = {eid : content['queue']}
493 495 self._update_engines(d)
494 496
495 497 def _unregister_engine(self, msg):
496 498 """Unregister an engine that has died."""
497 499 content = msg['content']
498 500 eid = int(content['id'])
499 501 if eid in self._ids:
500 502 self._ids.remove(eid)
501 503 uuid = self._engines.pop(eid)
502 504
503 505 self._handle_stranded_msgs(eid, uuid)
504 506
505 507 if self._task_socket and self._task_scheme == 'pure':
506 508 self._stop_scheduling_tasks()
507 509
508 510 def _handle_stranded_msgs(self, eid, uuid):
509 511 """Handle messages known to be on an engine when the engine unregisters.
510 512
511 513 It is possible that this will fire prematurely - that is, an engine will
512 514 go down after completing a result, and the client will be notified
513 515 of the unregistration and later receive the successful result.
514 516 """
515 517
516 518 outstanding = self._outstanding_dict[uuid]
517 519
518 520 for msg_id in list(outstanding):
519 521 if msg_id in self.results:
520 522 # we already have this result, skip it
521 523 continue
522 524 try:
523 525 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
524 526 except:
525 527 content = error.wrap_exception()
526 528 # build a fake message:
527 529 parent = {}
528 530 header = {}
529 531 parent['msg_id'] = msg_id
530 532 header['engine'] = uuid
531 header['date'] = datetime.now().strftime(util.ISO8601)
533 header['date'] = datetime.now()
532 534 msg = dict(parent_header=parent, header=header, content=content)
533 535 self._handle_apply_reply(msg)
534 536
535 537 def _handle_execute_reply(self, msg):
536 538 """Save the reply to an execute_request into our results.
537 539
538 540 execute messages are never actually used. apply is used instead.
539 541 """
540 542
541 543 parent = msg['parent_header']
542 544 msg_id = parent['msg_id']
543 545 if msg_id not in self.outstanding:
544 546 if msg_id in self.history:
545 547 print ("got stale result: %s"%msg_id)
546 548 else:
547 549 print ("got unknown result: %s"%msg_id)
548 550 else:
549 551 self.outstanding.remove(msg_id)
550 552 self.results[msg_id] = self._unwrap_exception(msg['content'])
551 553
552 554 def _handle_apply_reply(self, msg):
553 555 """Save the reply to an apply_request into our results."""
554 parent = msg['parent_header']
556 parent = extract_dates(msg['parent_header'])
555 557 msg_id = parent['msg_id']
556 558 if msg_id not in self.outstanding:
557 559 if msg_id in self.history:
558 560 print ("got stale result: %s"%msg_id)
559 561 print self.results[msg_id]
560 562 print msg
561 563 else:
562 564 print ("got unknown result: %s"%msg_id)
563 565 else:
564 566 self.outstanding.remove(msg_id)
565 567 content = msg['content']
566 header = msg['header']
568 header = extract_dates(msg['header'])
567 569
568 570 # construct metadata:
569 571 md = self.metadata[msg_id]
570 572 md.update(self._extract_metadata(header, parent, content))
571 573 # is this redundant?
572 574 self.metadata[msg_id] = md
573 575
574 576 e_outstanding = self._outstanding_dict[md['engine_uuid']]
575 577 if msg_id in e_outstanding:
576 578 e_outstanding.remove(msg_id)
577 579
578 580 # construct result:
579 581 if content['status'] == 'ok':
580 582 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
581 583 elif content['status'] == 'aborted':
582 584 self.results[msg_id] = error.TaskAborted(msg_id)
583 585 elif content['status'] == 'resubmitted':
584 586 # TODO: handle resubmission
585 587 pass
586 588 else:
587 589 self.results[msg_id] = self._unwrap_exception(content)
588 590
589 591 def _flush_notifications(self):
590 592 """Flush notifications of engine registrations waiting
591 593 in ZMQ queue."""
592 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
594 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
593 595 while msg is not None:
594 596 if self.debug:
595 597 pprint(msg)
596 msg = msg[-1]
597 598 msg_type = msg['msg_type']
598 599 handler = self._notification_handlers.get(msg_type, None)
599 600 if handler is None:
600 601 raise Exception("Unhandled message type: %s"%msg_type)
601 602 else:
602 603 handler(msg)
603 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
604 605
605 606 def _flush_results(self, sock):
606 607 """Flush task or queue results waiting in ZMQ queue."""
607 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
608 609 while msg is not None:
609 610 if self.debug:
610 611 pprint(msg)
611 msg = msg[-1]
612 612 msg_type = msg['msg_type']
613 613 handler = self._queue_handlers.get(msg_type, None)
614 614 if handler is None:
615 615 raise Exception("Unhandled message type: %s"%msg_type)
616 616 else:
617 617 handler(msg)
618 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
618 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
619 619
620 620 def _flush_control(self, sock):
621 621 """Flush replies from the control channel waiting
622 622 in the ZMQ queue.
623 623
624 624 Currently: ignore them."""
625 625 if self._ignored_control_replies <= 0:
626 626 return
627 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
628 628 while msg is not None:
629 629 self._ignored_control_replies -= 1
630 630 if self.debug:
631 631 pprint(msg)
632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
633 633
634 634 def _flush_ignored_control(self):
635 635 """flush ignored control replies"""
636 636 while self._ignored_control_replies > 0:
637 637 self.session.recv(self._control_socket)
638 638 self._ignored_control_replies -= 1
639 639
640 640 def _flush_ignored_hub_replies(self):
641 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
641 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
642 642 while msg is not None:
643 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
643 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
644 644
645 645 def _flush_iopub(self, sock):
646 646 """Flush replies from the iopub channel waiting
647 647 in the ZMQ queue.
648 648 """
649 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
650 650 while msg is not None:
651 651 if self.debug:
652 652 pprint(msg)
653 msg = msg[-1]
654 653 parent = msg['parent_header']
655 654 msg_id = parent['msg_id']
656 655 content = msg['content']
657 656 header = msg['header']
658 657 msg_type = msg['msg_type']
659 658
660 659 # init metadata:
661 660 md = self.metadata[msg_id]
662 661
663 662 if msg_type == 'stream':
664 663 name = content['name']
665 664 s = md[name] or ''
666 665 md[name] = s + content['data']
667 666 elif msg_type == 'pyerr':
668 667 md.update({'pyerr' : self._unwrap_exception(content)})
669 668 elif msg_type == 'pyin':
670 669 md.update({'pyin' : content['code']})
671 670 else:
672 671 md.update({msg_type : content.get('data', '')})
673 672
674 673 # redundant?
675 674 self.metadata[msg_id] = md
676 675
677 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
676 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
678 677
679 678 #--------------------------------------------------------------------------
680 679 # len, getitem
681 680 #--------------------------------------------------------------------------
682 681
683 682 def __len__(self):
684 683 """len(client) returns # of engines."""
685 684 return len(self.ids)
686 685
687 686 def __getitem__(self, key):
688 687 """index access returns DirectView multiplexer objects
689 688
690 689 Must be int, slice, or list/tuple/xrange of ints"""
691 690 if not isinstance(key, (int, slice, tuple, list, xrange)):
692 691 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
693 692 else:
694 693 return self.direct_view(key)
695 694
696 695 #--------------------------------------------------------------------------
697 696 # Begin public methods
698 697 #--------------------------------------------------------------------------
699 698
700 699 @property
701 700 def ids(self):
702 701 """Always up-to-date ids property."""
703 702 self._flush_notifications()
704 703 # always copy:
705 704 return list(self._ids)
706 705
707 706 def close(self):
708 707 if self._closed:
709 708 return
710 709 snames = filter(lambda n: n.endswith('socket'), dir(self))
711 710 for socket in map(lambda name: getattr(self, name), snames):
712 711 if isinstance(socket, zmq.Socket) and not socket.closed:
713 712 socket.close()
714 713 self._closed = True
715 714
716 715 def spin(self):
717 716 """Flush any registration notifications and execution results
718 717 waiting in the ZMQ queue.
719 718 """
720 719 if self._notification_socket:
721 720 self._flush_notifications()
722 721 if self._mux_socket:
723 722 self._flush_results(self._mux_socket)
724 723 if self._task_socket:
725 724 self._flush_results(self._task_socket)
726 725 if self._control_socket:
727 726 self._flush_control(self._control_socket)
728 727 if self._iopub_socket:
729 728 self._flush_iopub(self._iopub_socket)
730 729 if self._query_socket:
731 730 self._flush_ignored_hub_replies()
732 731
733 732 def wait(self, jobs=None, timeout=-1):
734 733 """waits on one or more `jobs`, for up to `timeout` seconds.
735 734
736 735 Parameters
737 736 ----------
738 737
739 738 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
740 739 ints are indices to self.history
741 740 strs are msg_ids
742 741 default: wait on all outstanding messages
743 742 timeout : float
744 743 a time in seconds, after which to give up.
745 744 default is -1, which means no timeout
746 745
747 746 Returns
748 747 -------
749 748
750 749 True : when all msg_ids are done
751 750 False : timeout reached, some msg_ids still outstanding
752 751 """
753 752 tic = time.time()
754 753 if jobs is None:
755 754 theids = self.outstanding
756 755 else:
757 756 if isinstance(jobs, (int, str, AsyncResult)):
758 757 jobs = [jobs]
759 758 theids = set()
760 759 for job in jobs:
761 760 if isinstance(job, int):
762 761 # index access
763 762 job = self.history[job]
764 763 elif isinstance(job, AsyncResult):
765 764 map(theids.add, job.msg_ids)
766 765 continue
767 766 theids.add(job)
768 767 if not theids.intersection(self.outstanding):
769 768 return True
770 769 self.spin()
771 770 while theids.intersection(self.outstanding):
772 771 if timeout >= 0 and ( time.time()-tic ) > timeout:
773 772 break
774 773 time.sleep(1e-3)
775 774 self.spin()
776 775 return len(theids.intersection(self.outstanding)) == 0
777 776
778 777 #--------------------------------------------------------------------------
779 778 # Control methods
780 779 #--------------------------------------------------------------------------
781 780
782 781 @spin_first
783 782 def clear(self, targets=None, block=None):
784 783 """Clear the namespace in target(s)."""
785 784 block = self.block if block is None else block
786 785 targets = self._build_targets(targets)[0]
787 786 for t in targets:
788 787 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
789 788 error = False
790 789 if block:
791 790 self._flush_ignored_control()
792 791 for i in range(len(targets)):
793 792 idents,msg = self.session.recv(self._control_socket,0)
794 793 if self.debug:
795 794 pprint(msg)
796 795 if msg['content']['status'] != 'ok':
797 796 error = self._unwrap_exception(msg['content'])
798 797 else:
799 798 self._ignored_control_replies += len(targets)
800 799 if error:
801 800 raise error
802 801
803 802
804 803 @spin_first
805 804 def abort(self, jobs=None, targets=None, block=None):
806 805 """Abort specific jobs from the execution queues of target(s).
807 806
808 807 This is a mechanism to prevent jobs that have already been submitted
809 808 from executing.
810 809
811 810 Parameters
812 811 ----------
813 812
814 813 jobs : msg_id, list of msg_ids, or AsyncResult
815 814 The jobs to be aborted
816 815
817 816
818 817 """
819 818 block = self.block if block is None else block
820 819 targets = self._build_targets(targets)[0]
821 820 msg_ids = []
822 821 if isinstance(jobs, (basestring,AsyncResult)):
823 822 jobs = [jobs]
824 823 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
825 824 if bad_ids:
826 825 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
827 826 for j in jobs:
828 827 if isinstance(j, AsyncResult):
829 828 msg_ids.extend(j.msg_ids)
830 829 else:
831 830 msg_ids.append(j)
832 831 content = dict(msg_ids=msg_ids)
833 832 for t in targets:
834 833 self.session.send(self._control_socket, 'abort_request',
835 834 content=content, ident=t)
836 835 error = False
837 836 if block:
838 837 self._flush_ignored_control()
839 838 for i in range(len(targets)):
840 839 idents,msg = self.session.recv(self._control_socket,0)
841 840 if self.debug:
842 841 pprint(msg)
843 842 if msg['content']['status'] != 'ok':
844 843 error = self._unwrap_exception(msg['content'])
845 844 else:
846 845 self._ignored_control_replies += len(targets)
847 846 if error:
848 847 raise error
849 848
850 849 @spin_first
851 850 def shutdown(self, targets=None, restart=False, hub=False, block=None):
852 851 """Terminates one or more engine processes, optionally including the hub."""
853 852 block = self.block if block is None else block
854 853 if hub:
855 854 targets = 'all'
856 855 targets = self._build_targets(targets)[0]
857 856 for t in targets:
858 857 self.session.send(self._control_socket, 'shutdown_request',
859 858 content={'restart':restart},ident=t)
860 859 error = False
861 860 if block or hub:
862 861 self._flush_ignored_control()
863 862 for i in range(len(targets)):
864 863 idents,msg = self.session.recv(self._control_socket, 0)
865 864 if self.debug:
866 865 pprint(msg)
867 866 if msg['content']['status'] != 'ok':
868 867 error = self._unwrap_exception(msg['content'])
869 868 else:
870 869 self._ignored_control_replies += len(targets)
871 870
872 871 if hub:
873 872 time.sleep(0.25)
874 873 self.session.send(self._query_socket, 'shutdown_request')
875 874 idents,msg = self.session.recv(self._query_socket, 0)
876 875 if self.debug:
877 876 pprint(msg)
878 877 if msg['content']['status'] != 'ok':
879 878 error = self._unwrap_exception(msg['content'])
880 879
881 880 if error:
882 881 raise error
883 882
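Taken together, the control methods map onto a session like this (a hedged sketch; `rc` is a connected Client and `ar` an AsyncResult from an earlier submission):

    rc.clear(targets='all')         # wipe the engines' namespaces
    rc.abort(jobs=ar, targets=[0])  # drop ar's queued msg_ids from engine 0
    rc.shutdown(hub=True)           # stop all engines, then the Hub itself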
884 883 #--------------------------------------------------------------------------
885 884 # Execution related methods
886 885 #--------------------------------------------------------------------------
887 886
888 887 def _maybe_raise(self, result):
889 888 """wrapper for maybe raising an exception if apply failed."""
890 889 if isinstance(result, error.RemoteError):
891 890 raise result
892 891
893 892 return result
894 893
895 894 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
896 895 ident=None):
897 896 """construct and send an apply message via a socket.
898 897
899 898 This is the principal method with which all engine execution is performed by views.
900 899 """
901 900
902 901 assert not self._closed, "cannot use me anymore, I'm closed!"
903 902 # defaults:
904 903 args = args if args is not None else []
905 904 kwargs = kwargs if kwargs is not None else {}
906 905 subheader = subheader if subheader is not None else {}
907 906
908 907 # validate arguments
909 908 if not callable(f):
910 909 raise TypeError("f must be callable, not %s"%type(f))
911 910 if not isinstance(args, (tuple, list)):
912 911 raise TypeError("args must be tuple or list, not %s"%type(args))
913 912 if not isinstance(kwargs, dict):
914 913 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
915 914 if not isinstance(subheader, dict):
916 915 raise TypeError("subheader must be dict, not %s"%type(subheader))
917 916
918 917 bufs = util.pack_apply_message(f,args,kwargs)
919 918
920 919 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
921 920 subheader=subheader, track=track)
922 921
923 922 msg_id = msg['msg_id']
924 923 self.outstanding.add(msg_id)
925 924 if ident:
926 925 # possibly routed to a specific engine
927 926 if isinstance(ident, list):
928 927 ident = ident[-1]
929 928 if ident in self._engines.values():
930 929 # save for later, in case of engine death
931 930 self._outstanding_dict[ident].add(msg_id)
932 931 self.history.append(msg_id)
933 932 self.metadata[msg_id]['submitted'] = datetime.now()
934 933
935 934 return msg
936 935
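Views are the intended callers of `send_apply_message`, but a hypothetical direct use looks like this; the returned message's `msg_id` keys all later bookkeeping:

    def double(x):
        return 2 * x

    msg = rc.send_apply_message(rc._task_socket, double, args=(21,))
    msg_id = msg['msg_id']          # now in rc.outstanding and rc.history
    rc.wait([msg_id])               # spin until the reply is flushed
    result = rc.results[msg_id]     # 42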
937 936 #--------------------------------------------------------------------------
938 937 # construct a View object
939 938 #--------------------------------------------------------------------------
940 939
941 940 def load_balanced_view(self, targets=None):
942 941 """construct a DirectView object.
943 942
944 943 If no arguments are specified, create a LoadBalancedView
945 944 using all engines.
946 945
947 946 Parameters
948 947 ----------
949 948
950 949 targets: list,slice,int,etc. [default: use all engines]
951 950 The subset of engines across which to load-balance
952 951 """
953 952 if targets is not None:
954 953 targets = self._build_targets(targets)[1]
955 954 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
956 955
957 956 def direct_view(self, targets='all'):
958 957 """construct a DirectView object.
959 958
960 959 If no targets are specified, create a DirectView
961 960 using all engines.
962 961
963 962 Parameters
964 963 ----------
965 964
966 965 targets: list,slice,int,etc. [default: use all engines]
967 966 The engines to use for the View
968 967 """
969 968 single = isinstance(targets, int)
970 969 targets = self._build_targets(targets)[1]
971 970 if single:
972 971 targets = targets[0]
973 972 return DirectView(client=self, socket=self._mux_socket, targets=targets)
974 973
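A short sketch of constructing both view types (same connected `rc` as above):

    dv = rc.direct_view()                # DirectView over all engines
    e0 = rc.direct_view(0)               # single-engine view (int target)
    lv = rc.load_balanced_view()         # LoadBalancedView over all engines
    lv2 = rc.load_balanced_view([0, 1])  # balance over a subset only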
975 974 #--------------------------------------------------------------------------
976 975 # Query methods
977 976 #--------------------------------------------------------------------------
978 977
979 978 @spin_first
980 979 def get_result(self, indices_or_msg_ids=None, block=None):
981 980 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
982 981
983 982 If the client already has the results, no request to the Hub will be made.
984 983
985 984 This is a convenient way to construct AsyncResult objects, which are wrappers
986 985 that include metadata about execution, and allow for awaiting results that
987 986 were not submitted by this Client.
988 987
989 988 It can also be a convenient way to retrieve the metadata associated with
990 989 blocking execution, since it always retrieves the metadata from the Hub.
991 990
992 991 Examples
993 992 --------
994 993 ::
995 994
996 995 In [10]: ar = client.get_result(-1) # AsyncResult for the most recent task
997 996
998 997 Parameters
999 998 ----------
1000 999
1001 1000 indices_or_msg_ids : integer history index, str msg_id, or list of either
1002 1001 The indices or msg_ids of indices to be retrieved
1003 1002
1004 1003 block : bool
1005 1004 Whether to wait for the result to be done
1006 1005
1007 1006 Returns
1008 1007 -------
1009 1008
1010 1009 AsyncResult
1011 1010 A single AsyncResult object will always be returned.
1012 1011
1013 1012 AsyncHubResult
1014 1013 A subclass of AsyncResult that retrieves results from the Hub
1015 1014
1016 1015 """
1017 1016 block = self.block if block is None else block
1018 1017 if indices_or_msg_ids is None:
1019 1018 indices_or_msg_ids = -1
1020 1019
1021 1020 if not isinstance(indices_or_msg_ids, (list,tuple)):
1022 1021 indices_or_msg_ids = [indices_or_msg_ids]
1023 1022
1024 1023 theids = []
1025 1024 for id in indices_or_msg_ids:
1026 1025 if isinstance(id, int):
1027 1026 id = self.history[id]
1028 1027 if not isinstance(id, str):
1029 1028 raise TypeError("indices must be str or int, not %r"%id)
1030 1029 theids.append(id)
1031 1030
1032 1031 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1033 1032 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1034 1033
1035 1034 if remote_ids:
1036 1035 ar = AsyncHubResult(self, msg_ids=theids)
1037 1036 else:
1038 1037 ar = AsyncResult(self, msg_ids=theids)
1039 1038
1040 1039 if block:
1041 1040 ar.wait()
1042 1041
1043 1042 return ar
1044 1043
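For example (hypothetical session), retrieving by history index and reading metadata:

    ar = rc.get_result(-1)              # most recent submission
    ar.wait()                           # block until it is done
    meta = rc.metadata[ar.msg_ids[0]]   # per-task metadata, e.g. 'submitted'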
1045 1044 @spin_first
1046 1045 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1047 1046 """Resubmit one or more tasks.
1048 1047
1049 1048 In-flight tasks may not be resubmitted.
1050 1049
1051 1050 Parameters
1052 1051 ----------
1053 1052
1054 1053 indices_or_msg_ids : integer history index, str msg_id, or list of either
1055 1054 The indices or msg_ids of indices to be retrieved
1056 1055
1057 1056 block : bool
1058 1057 Whether to wait for the result to be done
1059 1058
1060 1059 Returns
1061 1060 -------
1062 1061
1063 1062 AsyncHubResult
1064 1063 A subclass of AsyncResult that retrieves results from the Hub
1065 1064
1066 1065 """
1067 1066 block = self.block if block is None else block
1068 1067 if indices_or_msg_ids is None:
1069 1068 indices_or_msg_ids = -1
1070 1069
1071 1070 if not isinstance(indices_or_msg_ids, (list,tuple)):
1072 1071 indices_or_msg_ids = [indices_or_msg_ids]
1073 1072
1074 1073 theids = []
1075 1074 for id in indices_or_msg_ids:
1076 1075 if isinstance(id, int):
1077 1076 id = self.history[id]
1078 1077 if not isinstance(id, str):
1079 1078 raise TypeError("indices must be str or int, not %r"%id)
1080 1079 theids.append(id)
1081 1080
1082 1081 for msg_id in theids:
1083 1082 self.outstanding.discard(msg_id)
1084 1083 if msg_id in self.history:
1085 1084 self.history.remove(msg_id)
1086 1085 self.results.pop(msg_id, None)
1087 1086 self.metadata.pop(msg_id, None)
1088 1087 content = dict(msg_ids = theids)
1089 1088
1090 1089 self.session.send(self._query_socket, 'resubmit_request', content)
1091 1090
1092 1091 zmq.select([self._query_socket], [], [])
1093 1092 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1094 1093 if self.debug:
1095 1094 pprint(msg)
1096 1095 content = msg['content']
1097 1096 if content['status'] != 'ok':
1098 1097 raise self._unwrap_exception(content)
1099 1098
1100 1099 ar = AsyncHubResult(self, msg_ids=theids)
1101 1100
1102 1101 if block:
1103 1102 ar.wait()
1104 1103
1105 1104 return ar
1106 1105
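A sketch of resubmission; the chosen task must no longer be in flight:

    msg_id = rc.hub_history()[-1]   # any finished msg_id known to the Hub
    ar = rc.resubmit(msg_id)        # returns an AsyncHubResult
    ar.get()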
1107 1106 @spin_first
1108 1107 def result_status(self, msg_ids, status_only=True):
1109 1108 """Check on the status of the result(s) of the apply request with `msg_ids`.
1110 1109
1111 1110 If status_only is True, only the status of the results will be checked;
1112 1111 otherwise the actual results of completed tasks are also retrieved.
1113 1112
1114 1113 Parameters
1115 1114 ----------
1116 1115
1117 1116 msg_ids : list of msg_ids
1118 1117 if int:
1119 1118 Passed as index to self.history for convenience.
1120 1119 status_only : bool (default: True)
1121 1120 if False:
1122 1121 Retrieve the actual results of completed tasks.
1123 1122
1124 1123 Returns
1125 1124 -------
1126 1125
1127 1126 results : dict
1128 1127 There will always be the keys 'pending' and 'completed', which will
1129 1128 be lists of msg_ids that are incomplete or complete. If `status_only`
1130 1129 is False, then completed results will be keyed by their `msg_id`.
1131 1130 """
1132 1131 if not isinstance(msg_ids, (list,tuple)):
1133 1132 msg_ids = [msg_ids]
1134 1133
1135 1134 theids = []
1136 1135 for msg_id in msg_ids:
1137 1136 if isinstance(msg_id, int):
1138 1137 msg_id = self.history[msg_id]
1139 1138 if not isinstance(msg_id, basestring):
1140 1139 raise TypeError("msg_ids must be str, not %r"%msg_id)
1141 1140 theids.append(msg_id)
1142 1141
1143 1142 completed = []
1144 1143 local_results = {}
1145 1144
1146 1145 # comment this block out to temporarily disable local shortcut:
1147 1146 for msg_id in list(theids): # copy: theids is mutated in this loop
1148 1147 if msg_id in self.results:
1149 1148 completed.append(msg_id)
1150 1149 local_results[msg_id] = self.results[msg_id]
1151 1150 theids.remove(msg_id)
1152 1151
1153 1152 if theids: # some not locally cached
1154 1153 content = dict(msg_ids=theids, status_only=status_only)
1155 1154 msg = self.session.send(self._query_socket, "result_request", content=content)
1156 1155 zmq.select([self._query_socket], [], [])
1157 1156 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1158 1157 if self.debug:
1159 1158 pprint(msg)
1160 1159 content = msg['content']
1161 1160 if content['status'] != 'ok':
1162 1161 raise self._unwrap_exception(content)
1163 1162 buffers = msg['buffers']
1164 1163 else:
1165 1164 content = dict(completed=[],pending=[])
1166 1165
1167 1166 content['completed'].extend(completed)
1168 1167
1169 1168 if status_only:
1170 1169 return content
1171 1170
1172 1171 failures = []
1173 1172 # load cached results into result:
1174 1173 content.update(local_results)
1174 content = extract_dates(content)
1175 1175 # update cache with results:
1176 1176 for msg_id in sorted(theids):
1177 1177 if msg_id in content['completed']:
1178 1178 rec = content[msg_id]
1179 1179 parent = rec['header']
1180 1180 header = rec['result_header']
1181 1181 rcontent = rec['result_content']
1182 1182 iodict = rec['io']
1183 1183 if isinstance(rcontent, str):
1184 1184 rcontent = self.session.unpack(rcontent)
1185 1185
1186 1186 md = self.metadata[msg_id]
1187 1187 md.update(self._extract_metadata(header, parent, rcontent))
1188 1188 md.update(iodict)
1189 1189
1190 1190 if rcontent['status'] == 'ok':
1191 1191 res,buffers = util.unserialize_object(buffers)
1192 1192 else:
1193 1193 print rcontent
1194 1194 res = self._unwrap_exception(rcontent)
1195 1195 failures.append(res)
1196 1196
1197 1197 self.results[msg_id] = res
1198 1198 content[msg_id] = res
1199 1199
1200 1200 if len(theids) == 1 and failures:
1201 1201 raise failures[0]
1202 1202
1203 1203 error.collect_exceptions(failures, "result_status")
1204 1204 return content
1205 1205
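For instance (hedged sketch), polling a batch cheaply and then fetching the finished ones:

    st = rc.result_status(rc.history, status_only=True)
    pending, done = st['pending'], st['completed']    # lists of msg_ids
    full = rc.result_status(done, status_only=False)  # results keyed by msg_id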
1206 1206 @spin_first
1207 1207 def queue_status(self, targets='all', verbose=False):
1208 1208 """Fetch the status of engine queues.
1209 1209
1210 1210 Parameters
1211 1211 ----------
1212 1212
1213 1213 targets : int/str/list of ints/strs
1214 1214 the engines whose states are to be queried.
1215 1215 default : all
1216 1216 verbose : bool
1217 1217 Whether to return lengths only, or lists of ids for each element
1218 1218 """
1219 1219 engine_ids = self._build_targets(targets)[1]
1220 1220 content = dict(targets=engine_ids, verbose=verbose)
1221 1221 self.session.send(self._query_socket, "queue_request", content=content)
1222 1222 idents,msg = self.session.recv(self._query_socket, 0)
1223 1223 if self.debug:
1224 1224 pprint(msg)
1225 1225 content = msg['content']
1226 1226 status = content.pop('status')
1227 1227 if status != 'ok':
1228 1228 raise self._unwrap_exception(content)
1229 1229 content = util.rekey(content)
1230 1230 if isinstance(targets, int):
1231 1231 return content[targets]
1232 1232 else:
1233 1233 return content
1234 1234
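e.g. (output shapes are illustrative, not captured from a real session):

    rc.queue_status()                 # {0: {'queue': 2, 'completed': 5, 'tasks': 1}, ...}
    rc.queue_status(0, verbose=True)  # same keys, with lists of msg_ids instead of counts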
1235 1235 @spin_first
1236 1236 def purge_results(self, jobs=[], targets=[]):
1237 1237 """Tell the Hub to forget results.
1238 1238
1239 1239 Individual results can be purged by msg_id, or the entire
1240 1240 history of specific targets can be purged.
1241 1241
1242 1242 Parameters
1243 1243 ----------
1244 1244
1245 1245 jobs : str or list of str or AsyncResult objects
1246 1246 the msg_ids whose results should be forgotten.
1247 1247 targets : int/str/list of ints/strs
1248 1248 The targets, by uuid or int_id, whose entire history is to be purged.
1249 1249 Use `targets='all'` to scrub everything from the Hub's memory.
1250 1250
1251 1251 default : None
1252 1252 """
1253 1253 if not targets and not jobs:
1254 1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1255 1255 if targets:
1256 1256 targets = self._build_targets(targets)[1]
1257 1257
1258 1258 # construct msg_ids from jobs
1259 1259 msg_ids = []
1260 1260 if isinstance(jobs, (basestring,AsyncResult)):
1261 1261 jobs = [jobs]
1262 1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1263 1263 if bad_ids:
1264 1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1265 1265 for j in jobs:
1266 1266 if isinstance(j, AsyncResult):
1267 1267 msg_ids.extend(j.msg_ids)
1268 1268 else:
1269 1269 msg_ids.append(j)
1270 1270
1271 1271 content = dict(targets=targets, msg_ids=msg_ids)
1272 1272 self.session.send(self._query_socket, "purge_request", content=content)
1273 1273 idents, msg = self.session.recv(self._query_socket, 0)
1274 1274 if self.debug:
1275 1275 pprint(msg)
1276 1276 content = msg['content']
1277 1277 if content['status'] != 'ok':
1278 1278 raise self._unwrap_exception(content)
1279 1279
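A hedged sketch of the purge modes (the msg_ids are placeholders):

    rc.purge_results(jobs=['msg-id-1', 'msg-id-2'])  # forget specific results
    rc.purge_results(targets=[0, 1])                 # whole history of engines 0 and 1
    rc.purge_results(targets='all')                  # scrub the Hub's memory entirely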
1280 1280 @spin_first
1281 1281 def hub_history(self):
1282 1282 """Get the Hub's history
1283 1283
1284 1284 Just like the Client, the Hub has a history, which is a list of msg_ids.
1285 1285 This will contain the history of all clients, and, depending on configuration,
1286 1286 may contain history across multiple cluster sessions.
1287 1287
1288 1288 Any msg_id returned here is a valid argument to `get_result`.
1289 1289
1290 1290 Returns
1291 1291 -------
1292 1292
1293 1293 msg_ids : list of strs
1294 1294 list of all msg_ids, ordered by task submission time.
1295 1295 """
1296 1296
1297 1297 self.session.send(self._query_socket, "history_request", content={})
1298 1298 idents, msg = self.session.recv(self._query_socket, 0)
1299 1299
1300 1300 if self.debug:
1301 1301 pprint(msg)
1302 1302 content = msg['content']
1303 1303 if content['status'] != 'ok':
1304 1304 raise self._unwrap_exception(content)
1305 1305 else:
1306 1306 return content['history']
1307 1307
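Combined with `get_result`, this lets one client fetch results submitted by another (sketch):

    msg_ids = rc.hub_history()       # every msg_id the Hub knows about
    ar = rc.get_result(msg_ids[-1])  # valid even if another client submitted it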
1308 1308 @spin_first
1309 1309 def db_query(self, query, keys=None):
1310 1310 """Query the Hub's TaskRecord database
1311 1311
1312 1312 This will return a list of task record dicts that match `query`
1313 1313
1314 1314 Parameters
1315 1315 ----------
1316 1316
1317 1317 query : mongodb query dict
1318 1318 The search dict. See mongodb query docs for details.
1319 1319 keys : list of strs [optional]
1320 1320 The subset of keys to be returned. The default is to fetch everything but buffers.
1321 1321 'msg_id' will *always* be included.
1322 1322 """
1323 1323 if isinstance(keys, basestring):
1324 1324 keys = [keys]
1325 1325 content = dict(query=query, keys=keys)
1326 1326 self.session.send(self._query_socket, "db_request", content=content)
1327 1327 idents, msg = self.session.recv(self._query_socket, 0)
1328 1328 if self.debug:
1329 1329 pprint(msg)
1330 1330 content = msg['content']
1331 1331 if content['status'] != 'ok':
1332 1332 raise self._unwrap_exception(content)
1333 1333
1334 1334 records = content['records']
1335 1335 buffer_lens = content['buffer_lens']
1336 1336 result_buffer_lens = content['result_buffer_lens']
1337 1337 buffers = msg['buffers']
1338 1338 has_bufs = buffer_lens is not None
1339 1339 has_rbufs = result_buffer_lens is not None
1340 1340 for i,rec in enumerate(records):
1341 # unpack timestamps
1342 rec = extract_dates(rec)
1341 1343 # relink buffers
1342 1344 if has_bufs:
1343 1345 blen = buffer_lens[i]
1344 1346 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1345 1347 if has_rbufs:
1346 1348 blen = result_buffer_lens[i]
1347 1349 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1348 # turn timestamps back into times
1349 for key in 'submitted started completed resubmitted'.split():
1350 maybedate = rec.get(key, None)
1351 if maybedate and util.ISO8601_RE.match(maybedate):
1352 rec[key] = datetime.strptime(maybedate, util.ISO8601)
1353 1350
1354 1351 return records
1355 1352
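Query dicts follow MongoDB conventions even for non-Mongo backends; a hypothetical example:

    # all completed tasks, fetching only a few cheap keys
    recs = rc.db_query({'completed': {'$ne': None}},
                       keys=['msg_id', 'submitted', 'completed'])
    for rec in recs:
        elapsed = rec['completed'] - rec['submitted']  # timedelta after extract_dates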
1356 1353 __all__ = [ 'Client' ]
@@ -1,1277 +1,1274 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import (
29 29 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CStr
30 30 )
31 from IPython.utils.jsonutil import ISO8601, extract_dates
31 32
32 33 from IPython.parallel import error, util
33 34 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
34 35
35 36 from .heartmonitor import HeartMonitor
36 37
37 38 #-----------------------------------------------------------------------------
38 39 # Code
39 40 #-----------------------------------------------------------------------------
40 41
41 42 def _passer(*args, **kwargs):
42 43 return
43 44
44 45 def _printer(*args, **kwargs):
45 46 print (args)
46 47 print (kwargs)
47 48
48 49 def empty_record():
49 50 """Return an empty dict with all record keys."""
50 51 return {
51 52 'msg_id' : None,
52 53 'header' : None,
53 54 'content': None,
54 55 'buffers': None,
55 56 'submitted': None,
56 57 'client_uuid' : None,
57 58 'engine_uuid' : None,
58 59 'started': None,
59 60 'completed': None,
60 61 'resubmitted': None,
61 62 'result_header' : None,
62 63 'result_content' : None,
63 64 'result_buffers' : None,
64 65 'queue' : None,
65 66 'pyin' : None,
66 67 'pyout': None,
67 68 'pyerr': None,
68 69 'stdout': '',
69 70 'stderr': '',
70 71 }
71 72
72 73 def init_record(msg):
73 74 """Initialize a TaskRecord based on a request."""
74 header = msg['header']
75 header = extract_dates(msg['header'])
75 76 return {
76 77 'msg_id' : header['msg_id'],
77 78 'header' : header,
78 79 'content': msg['content'],
79 80 'buffers': msg['buffers'],
80 'submitted': datetime.strptime(header['date'], util.ISO8601),
81 'submitted': header['date'],
81 82 'client_uuid' : None,
82 83 'engine_uuid' : None,
83 84 'started': None,
84 85 'completed': None,
85 86 'resubmitted': None,
86 87 'result_header' : None,
87 88 'result_content' : None,
88 89 'result_buffers' : None,
89 90 'queue' : None,
90 91 'pyin' : None,
91 92 'pyout': None,
92 93 'pyerr': None,
93 94 'stdout': '',
94 95 'stderr': '',
95 96 }
96 97
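The `extract_dates` call above replaces the per-key `datetime.strptime` conversions this diff removes; a minimal sketch of the idea (not the actual IPython.utils.jsonutil implementation):

    import re
    from datetime import datetime

    ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
    ISO8601_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")

    def extract_dates(obj):
        """convert ISO8601 timestamp strings to datetimes, recursively"""
        if isinstance(obj, dict):
            return dict((k, extract_dates(v)) for k, v in obj.iteritems())
        elif isinstance(obj, (list, tuple)):
            return [extract_dates(o) for o in obj]
        elif isinstance(obj, basestring) and ISO8601_RE.match(obj):
            return datetime.strptime(obj, ISO8601)
        return obj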
97 98
98 99 class EngineConnector(HasTraits):
99 100 """A simple object for accessing the various zmq connections of an object.
100 101 Attributes are:
101 102 id (int): engine ID
102 103 uuid (str): uuid (unused?)
103 104 queue (str): identity of queue's XREQ socket
104 105 registration (str): identity of registration XREQ socket
105 106 heartbeat (str): identity of heartbeat XREQ socket
106 107 """
107 108 id=Int(0)
108 109 queue=CStr()
109 110 control=CStr()
110 111 registration=CStr()
111 112 heartbeat=CStr()
112 113 pending=Set()
113 114
114 115 class HubFactory(RegistrationFactory):
115 116 """The Configurable for setting up a Hub."""
116 117
117 118 # port-pairs for monitoredqueues:
118 119 hb = Tuple(Int,Int,config=True,
119 120 help="""XREQ/SUB Port pair for Engine heartbeats""")
120 121 def _hb_default(self):
121 122 return tuple(util.select_random_ports(2))
122 123
123 124 mux = Tuple(Int,Int,config=True,
124 125 help="""Engine/Client Port pair for MUX queue""")
125 126
126 127 def _mux_default(self):
127 128 return tuple(util.select_random_ports(2))
128 129
129 130 task = Tuple(Int,Int,config=True,
130 131 help="""Engine/Client Port pair for Task queue""")
131 132 def _task_default(self):
132 133 return tuple(util.select_random_ports(2))
133 134
134 135 control = Tuple(Int,Int,config=True,
135 136 help="""Engine/Client Port pair for Control queue""")
136 137
137 138 def _control_default(self):
138 139 return tuple(util.select_random_ports(2))
139 140
140 141 iopub = Tuple(Int,Int,config=True,
141 142 help="""Engine/Client Port pair for IOPub relay""")
142 143
143 144 def _iopub_default(self):
144 145 return tuple(util.select_random_ports(2))
145 146
146 147 # single ports:
147 148 mon_port = Int(config=True,
148 149 help="""Monitor (SUB) port for queue traffic""")
149 150
150 151 def _mon_port_default(self):
151 152 return util.select_random_ports(1)[0]
152 153
153 154 notifier_port = Int(config=True,
154 155 help="""PUB port for sending engine status notifications""")
155 156
156 157 def _notifier_port_default(self):
157 158 return util.select_random_ports(1)[0]
158 159
159 160 engine_ip = Unicode('127.0.0.1', config=True,
160 161 help="IP on which to listen for engine connections. [default: loopback]")
161 162 engine_transport = Unicode('tcp', config=True,
162 163 help="0MQ transport for engine connections. [default: tcp]")
163 164
164 165 client_ip = Unicode('127.0.0.1', config=True,
165 166 help="IP on which to listen for client connections. [default: loopback]")
166 167 client_transport = Unicode('tcp', config=True,
167 168 help="0MQ transport for client connections. [default : tcp]")
168 169
169 170 monitor_ip = Unicode('127.0.0.1', config=True,
170 171 help="IP on which to listen for monitor messages. [default: loopback]")
171 172 monitor_transport = Unicode('tcp', config=True,
172 173 help="0MQ transport for monitor messages. [default : tcp]")
173 174
174 175 monitor_url = Unicode('')
175 176
176 177 db_class = Unicode('IPython.parallel.controller.dictdb.DictDB', config=True,
177 178 help="""The class to use for the DB backend""")
178 179
179 180 # not configurable
180 181 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
181 182 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
182 183
183 184 def _ip_changed(self, name, old, new):
184 185 self.engine_ip = new
185 186 self.client_ip = new
186 187 self.monitor_ip = new
187 188 self._update_monitor_url()
188 189
189 190 def _update_monitor_url(self):
190 191 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
191 192
192 193 def _transport_changed(self, name, old, new):
193 194 self.engine_transport = new
194 195 self.client_transport = new
195 196 self.monitor_transport = new
196 197 self._update_monitor_url()
197 198
198 199 def __init__(self, **kwargs):
199 200 super(HubFactory, self).__init__(**kwargs)
200 201 self._update_monitor_url()
201 202
202 203
203 204 def construct(self):
204 205 self.init_hub()
205 206
206 207 def start(self):
207 208 self.heartmonitor.start()
208 209 self.log.info("Heartmonitor started")
209 210
210 211 def init_hub(self):
211 212 """construct"""
212 213 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
213 214 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
214 215
215 216 ctx = self.context
216 217 loop = self.loop
217 218
218 219 # Registrar socket
219 220 q = ZMQStream(ctx.socket(zmq.XREP), loop)
220 221 q.bind(client_iface % self.regport)
221 222 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
222 223 if self.client_ip != self.engine_ip:
223 224 q.bind(engine_iface % self.regport)
224 225 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
225 226
226 227 ### Engine connections ###
227 228
228 229 # heartbeat
229 230 hpub = ctx.socket(zmq.PUB)
230 231 hpub.bind(engine_iface % self.hb[0])
231 232 hrep = ctx.socket(zmq.XREP)
232 233 hrep.bind(engine_iface % self.hb[1])
233 234 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
234 235 config=self.config)
235 236
236 237 ### Client connections ###
237 238 # Notifier socket
238 239 n = ZMQStream(ctx.socket(zmq.PUB), loop)
239 240 n.bind(client_iface%self.notifier_port)
240 241
241 242 ### build and launch the queues ###
242 243
243 244 # monitor socket
244 245 sub = ctx.socket(zmq.SUB)
245 246 sub.setsockopt(zmq.SUBSCRIBE, "")
246 247 sub.bind(self.monitor_url)
247 248 sub.bind('inproc://monitor')
248 249 sub = ZMQStream(sub, loop)
249 250
250 251 # connect the db
251 252 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
252 253 # cdir = self.config.Global.cluster_dir
253 254 self.db = import_item(str(self.db_class))(session=self.session.session, config=self.config)
254 255 time.sleep(.25)
255 256 try:
256 257 scheme = self.config.TaskScheduler.scheme_name
257 258 except AttributeError:
258 259 from .scheduler import TaskScheduler
259 260 scheme = TaskScheduler.scheme_name.get_default_value()
260 261 # build connection dicts
261 262 self.engine_info = {
262 263 'control' : engine_iface%self.control[1],
263 264 'mux': engine_iface%self.mux[1],
264 265 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
265 266 'task' : engine_iface%self.task[1],
266 267 'iopub' : engine_iface%self.iopub[1],
267 268 # 'monitor' : engine_iface%self.mon_port,
268 269 }
269 270
270 271 self.client_info = {
271 272 'control' : client_iface%self.control[0],
272 273 'mux': client_iface%self.mux[0],
273 274 'task' : (scheme, client_iface%self.task[0]),
274 275 'iopub' : client_iface%self.iopub[0],
275 276 'notification': client_iface%self.notifier_port
276 277 }
277 278 self.log.debug("Hub engine addrs: %s"%self.engine_info)
278 279 self.log.debug("Hub client addrs: %s"%self.client_info)
279 280
280 281 # resubmit stream
281 282 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
282 283 url = util.disambiguate_url(self.client_info['task'][-1])
283 284 r.setsockopt(zmq.IDENTITY, self.session.session)
284 285 r.connect(url)
285 286
286 287 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
287 288 query=q, notifier=n, resubmit=r, db=self.db,
288 289 engine_info=self.engine_info, client_info=self.client_info,
289 290 logname=self.log.name)
290 291
291 292
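A hedged sketch of driving HubFactory by hand (ipcontroller normally does this; the `ip` trait comes from the parent RegistrationFactory):

    from IPython.parallel.controller.hub import HubFactory

    hf = HubFactory(ip='127.0.0.1')  # ports default to random, per the traits above
    hf.construct()                   # init_hub(): bind sockets, build the Hub
    hf.start()                       # start the HeartMonitor
    hf.loop.start()                  # hand control to the zmq IOLoop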
292 293 class Hub(LoggingFactory):
293 294 """The IPython Controller Hub with 0MQ connections
294 295
295 296 Parameters
296 297 ==========
297 298 loop: zmq IOLoop instance
298 session: StreamSession object
299 session: Session object
299 300 <removed> context: zmq context for creating new connections (?)
300 301 queue: ZMQStream for monitoring the command queue (SUB)
301 302 query: ZMQStream for engine registration and client queries requests (XREP)
302 303 heartbeat: HeartMonitor object checking the pulse of the engines
303 304 notifier: ZMQStream for broadcasting engine registration changes (PUB)
304 305 db: connection to db for out of memory logging of commands
305 306 NotImplemented
306 307 engine_info: dict of zmq connection information for engines to connect
307 308 to the queues.
308 309 client_info: dict of zmq connection information for clients to connect
309 310 to the queues.
310 311 """
311 312 # internal data structures:
312 313 ids=Set() # engine IDs
313 314 keytable=Dict()
314 315 by_ident=Dict()
315 316 engines=Dict()
316 317 clients=Dict()
317 318 hearts=Dict()
318 319 pending=Set()
319 320 queues=Dict() # pending msg_ids keyed by engine_id
320 321 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
321 322 completed=Dict() # completed msg_ids keyed by engine_id
322 323 all_completed=Set() # set of all completed msg_ids
323 324 dead_engines=Set() # set of uuids of engines that have unregistered
324 325 unassigned=Set() # set of task msg_ids not yet assigned a destination
325 326 incoming_registrations=Dict()
326 327 registration_timeout=Int()
327 328 _idcounter=Int(0)
328 329
329 330 # objects from constructor:
330 331 loop=Instance(ioloop.IOLoop)
331 332 query=Instance(ZMQStream)
332 333 monitor=Instance(ZMQStream)
333 334 notifier=Instance(ZMQStream)
334 335 resubmit=Instance(ZMQStream)
335 336 heartmonitor=Instance(HeartMonitor)
336 337 db=Instance(object)
337 338 client_info=Dict()
338 339 engine_info=Dict()
339 340
340 341
341 342 def __init__(self, **kwargs):
342 343 """
343 344 # universal:
344 345 loop: IOLoop for creating future connections
344 345 session: Session object for sending serialized data
346 347 # engine:
347 348 queue: ZMQStream for monitoring queue messages
348 349 query: ZMQStream for engine+client registration and client requests
349 350 heartbeat: HeartMonitor object for tracking engines
350 351 # extra:
351 352 db: ZMQStream for db connection (NotImplemented)
352 353 engine_info: zmq address/protocol dict for engine connections
353 354 client_info: zmq address/protocol dict for client connections
354 355 """
355 356
356 357 super(Hub, self).__init__(**kwargs)
357 358 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
358 359
359 360 # validate connection dicts:
360 361 for k,v in self.client_info.iteritems():
361 362 if k == 'task':
362 363 util.validate_url_container(v[1])
363 364 else:
364 365 util.validate_url_container(v)
365 366 # util.validate_url_container(self.client_info)
366 367 util.validate_url_container(self.engine_info)
367 368
368 369 # register our callbacks
369 370 self.query.on_recv(self.dispatch_query)
370 371 self.monitor.on_recv(self.dispatch_monitor_traffic)
371 372
372 373 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
373 374 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
374 375
375 376 self.monitor_handlers = { 'in' : self.save_queue_request,
376 377 'out': self.save_queue_result,
377 378 'intask': self.save_task_request,
378 379 'outtask': self.save_task_result,
379 380 'tracktask': self.save_task_destination,
380 381 'incontrol': _passer,
381 382 'outcontrol': _passer,
382 383 'iopub': self.save_iopub_message,
383 384 }
384 385
385 386 self.query_handlers = {'queue_request': self.queue_status,
386 387 'result_request': self.get_results,
387 388 'history_request': self.get_history,
388 389 'db_request': self.db_query,
389 390 'purge_request': self.purge_results,
390 391 'load_request': self.check_load,
391 392 'resubmit_request': self.resubmit_task,
392 393 'shutdown_request': self.shutdown_request,
393 394 'registration_request' : self.register_engine,
394 395 'unregistration_request' : self.unregister_engine,
395 396 'connection_request': self.connection_request,
396 397 }
397 398
398 399 # ignore resubmit replies
399 400 self.resubmit.on_recv(lambda msg: None, copy=False)
400 401
401 402 self.log.info("hub::created hub")
402 403
403 404 @property
404 405 def _next_id(self):
405 406 """gemerate a new ID.
406 407
407 408 No longer reuse old ids, just count from 0."""
408 409 newid = self._idcounter
409 410 self._idcounter += 1
410 411 return newid
411 412 # newid = 0
412 413 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
413 414 # # print newid, self.ids, self.incoming_registrations
414 415 # while newid in self.ids or newid in incoming:
415 416 # newid += 1
416 417 # return newid
417 418
418 419 #-----------------------------------------------------------------------------
419 420 # message validation
420 421 #-----------------------------------------------------------------------------
421 422
422 423 def _validate_targets(self, targets):
423 424 """turn any valid targets argument into a list of integer ids"""
424 425 if targets is None:
425 426 # default to all
426 427 targets = self.ids
427 428
428 429 if isinstance(targets, (int,str,unicode)):
429 430 # only one target specified
430 431 targets = [targets]
431 432 _targets = []
432 433 for t in targets:
433 434 # map raw identities to ids
434 435 if isinstance(t, (str,unicode)):
435 436 t = self.by_ident.get(t, t)
436 437 _targets.append(t)
437 438 targets = _targets
438 439 bad_targets = [ t for t in targets if t not in self.ids ]
439 440 if bad_targets:
440 441 raise IndexError("No Such Engine: %r"%bad_targets)
441 442 if not targets:
442 443 raise IndexError("No Engines Registered")
443 444 return targets
444 445
445 446 #-----------------------------------------------------------------------------
446 447 # dispatch methods (1 per stream)
447 448 #-----------------------------------------------------------------------------
448 449
449 450
450 451 def dispatch_monitor_traffic(self, msg):
451 452 """all ME and Task queue messages come through here, as well as
452 453 IOPub traffic."""
453 454 self.log.debug("monitor traffic: %r"%msg[:2])
454 455 switch = msg[0]
455 456 try:
456 457 idents, msg = self.session.feed_identities(msg[1:])
457 458 except ValueError:
458 459 idents=[]
459 460 if not idents:
460 461 self.log.error("Bad Monitor Message: %r"%msg)
461 462 return
462 463 handler = self.monitor_handlers.get(switch, None)
463 464 if handler is not None:
464 465 handler(idents, msg)
465 466 else:
466 467 self.log.error("Invalid monitor topic: %r"%switch)
467 468
468 469
469 470 def dispatch_query(self, msg):
470 471 """Route registration requests and queries from clients."""
471 472 try:
472 473 idents, msg = self.session.feed_identities(msg)
473 474 except ValueError:
474 475 idents = []
475 476 if not idents:
476 477 self.log.error("Bad Query Message: %r"%msg)
477 478 return
478 479 client_id = idents[0]
479 480 try:
480 481 msg = self.session.unpack_message(msg, content=True)
481 482 except Exception:
482 483 content = error.wrap_exception()
483 484 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
484 485 self.session.send(self.query, "hub_error", ident=client_id,
485 486 content=content)
486 487 return
487 488 # print( idents, msg)
488 489 # print client_id, header, parent, content
489 490 #switch on message type:
490 491 msg_type = msg['msg_type']
491 492 self.log.info("client::client %r requested %r"%(client_id, msg_type))
492 493 handler = self.query_handlers.get(msg_type, None)
493 494 try:
494 495 assert handler is not None, "Bad Message Type: %r"%msg_type
495 496 except:
496 497 content = error.wrap_exception()
497 498 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
498 499 self.session.send(self.query, "hub_error", ident=client_id,
499 500 content=content)
500 501 return
501 502
502 503 else:
503 504 handler(idents, msg)
504 505
505 506 def dispatch_db(self, msg):
506 507 """"""
507 508 raise NotImplementedError
508 509
509 510 #---------------------------------------------------------------------------
510 511 # handler methods (1 per event)
511 512 #---------------------------------------------------------------------------
512 513
513 514 #----------------------- Heartbeat --------------------------------------
514 515
515 516 def handle_new_heart(self, heart):
516 517 """handler to attach to heartbeater.
517 518 Called when a new heart starts to beat.
518 519 Triggers completion of registration."""
519 520 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
520 521 if heart not in self.incoming_registrations:
521 522 self.log.info("heartbeat::ignoring new heart: %r"%heart)
522 523 else:
523 524 self.finish_registration(heart)
524 525
525 526
526 527 def handle_heart_failure(self, heart):
527 528 """handler to attach to heartbeater.
528 529 called when a previously registered heart fails to respond to beat request.
529 530 triggers unregistration"""
530 531 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
531 532 eid = self.hearts.get(heart, None)
532 533 if eid is None:
533 534 self.log.info("heartbeat::ignoring heart failure %r"%heart)
534 535 else:
535 536 queue = self.engines[eid].queue
536 537 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
537 538
538 539 #----------------------- MUX Queue Traffic ------------------------------
539 540
540 541 def save_queue_request(self, idents, msg):
541 542 if len(idents) < 2:
542 543 self.log.error("invalid identity prefix: %r"%idents)
543 544 return
544 545 queue_id, client_id = idents[:2]
545 546 try:
546 547 msg = self.session.unpack_message(msg, content=False)
547 548 except Exception:
548 549 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
549 550 return
550 551
551 552 eid = self.by_ident.get(queue_id, None)
552 553 if eid is None:
553 554 self.log.error("queue::target %r not registered"%queue_id)
554 555 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
555 556 return
556 557
557 558 header = msg['header']
558 559 msg_id = header['msg_id']
559 560 record = init_record(msg)
560 561 record['engine_uuid'] = queue_id
561 562 record['client_uuid'] = client_id
562 563 record['queue'] = 'mux'
563 564
564 565 try:
565 566 # it's possible iopub arrived first:
566 567 existing = self.db.get_record(msg_id)
567 568 for key,evalue in existing.iteritems():
568 569 rvalue = record.get(key, None)
569 570 if evalue and rvalue and evalue != rvalue:
570 571 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
571 572 elif evalue and not rvalue:
572 573 record[key] = evalue
573 574 self.db.update_record(msg_id, record)
574 575 except KeyError:
575 576 self.db.add_record(msg_id, record)
576 577
577 578 self.pending.add(msg_id)
578 579 self.queues[eid].append(msg_id)
579 580
580 581 def save_queue_result(self, idents, msg):
581 582 if len(idents) < 2:
582 583 self.log.error("invalid identity prefix: %r"%idents)
583 584 return
584 585
585 586 client_id, queue_id = idents[:2]
586 587 try:
587 588 msg = self.session.unpack_message(msg, content=False)
588 589 except Exception:
589 590 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
590 591 queue_id,client_id, msg), exc_info=True)
591 592 return
592 593
593 594 eid = self.by_ident.get(queue_id, None)
594 595 if eid is None:
595 596 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
596 597 return
597 598
598 599 parent = msg['parent_header']
599 600 if not parent:
600 601 return
601 602 msg_id = parent['msg_id']
602 603 if msg_id in self.pending:
603 604 self.pending.remove(msg_id)
604 605 self.all_completed.add(msg_id)
605 606 self.queues[eid].remove(msg_id)
606 607 self.completed[eid].append(msg_id)
607 608 elif msg_id not in self.all_completed:
608 609 # it could be a result from a dead engine that died before delivering the
609 610 # result
610 611 self.log.warn("queue:: unknown msg finished %r"%msg_id)
611 612 return
612 613 # update record anyway, because the unregistration could have been premature
613 rheader = msg['header']
614 completed = datetime.strptime(rheader['date'], util.ISO8601)
614 rheader = extract_dates(msg['header'])
615 completed = rheader['date']
615 616 started = rheader.get('started', None)
616 if started is not None:
617 started = datetime.strptime(started, util.ISO8601)
618 617 result = {
619 618 'result_header' : rheader,
620 619 'result_content': msg['content'],
621 620 'started' : started,
622 621 'completed' : completed
623 622 }
624 623
625 624 result['result_buffers'] = msg['buffers']
626 625 try:
627 626 self.db.update_record(msg_id, result)
628 627 except Exception:
629 628 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
630 629
631 630
632 631 #--------------------- Task Queue Traffic ------------------------------
633 632
634 633 def save_task_request(self, idents, msg):
635 634 """Save the submission of a task."""
636 635 client_id = idents[0]
637 636
638 637 try:
639 638 msg = self.session.unpack_message(msg, content=False)
640 639 except Exception:
641 640 self.log.error("task::client %r sent invalid task message: %r"%(
642 641 client_id, msg), exc_info=True)
643 642 return
644 643 record = init_record(msg)
645 644
646 645 record['client_uuid'] = client_id
647 646 record['queue'] = 'task'
648 647 header = msg['header']
649 648 msg_id = header['msg_id']
650 649 self.pending.add(msg_id)
651 650 self.unassigned.add(msg_id)
652 651 try:
653 652 # it's possible iopub arrived first:
654 653 existing = self.db.get_record(msg_id)
655 654 if existing['resubmitted']:
656 655 for key in ('submitted', 'client_uuid', 'buffers'):
657 656 # don't clobber these keys on resubmit
658 657 # submitted and client_uuid should be different
659 658 # and buffers might be big, and shouldn't have changed
660 659 record.pop(key)
661 660 # still check content,header which should not change
662 661 # but are not expensive to compare, unlike buffers
663 662
664 663 for key,evalue in existing.iteritems():
665 664 if key.endswith('buffers'):
666 665 # don't compare buffers
667 666 continue
668 667 rvalue = record.get(key, None)
669 668 if evalue and rvalue and evalue != rvalue:
670 669 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
671 670 elif evalue and not rvalue:
672 671 record[key] = evalue
673 672 self.db.update_record(msg_id, record)
674 673 except KeyError:
675 674 self.db.add_record(msg_id, record)
676 675 except Exception:
677 676 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
678 677
679 678 def save_task_result(self, idents, msg):
680 679 """save the result of a completed task."""
681 680 client_id = idents[0]
682 681 try:
683 682 msg = self.session.unpack_message(msg, content=False)
684 683 except Exception:
685 684 self.log.error("task::invalid task result message send to %r: %r"%(
686 685 client_id, msg), exc_info=True)
687 686 return
688 687
689 688 parent = msg['parent_header']
690 689 if not parent:
691 690 # print msg
692 691 self.log.warn("Task %r had no parent!"%msg)
693 692 return
694 693 msg_id = parent['msg_id']
695 694 if msg_id in self.unassigned:
696 695 self.unassigned.remove(msg_id)
697 696
698 header = msg['header']
697 header = extract_dates(msg['header'])
699 698 engine_uuid = header.get('engine', None)
700 699 eid = self.by_ident.get(engine_uuid, None)
701 700
702 701 if msg_id in self.pending:
703 702 self.pending.remove(msg_id)
704 703 self.all_completed.add(msg_id)
705 704 if eid is not None:
706 705 self.completed[eid].append(msg_id)
707 706 if msg_id in self.tasks[eid]:
708 707 self.tasks[eid].remove(msg_id)
709 completed = datetime.strptime(header['date'], util.ISO8601)
708 completed = header['date']
710 709 started = header.get('started', None)
711 if started is not None:
712 started = datetime.strptime(started, util.ISO8601)
713 710 result = {
714 711 'result_header' : header,
715 712 'result_content': msg['content'],
716 713 'started' : started,
717 714 'completed' : completed,
718 715 'engine_uuid': engine_uuid
719 716 }
720 717
721 718 result['result_buffers'] = msg['buffers']
722 719 try:
723 720 self.db.update_record(msg_id, result)
724 721 except Exception:
725 722 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
726 723
727 724 else:
728 725 self.log.debug("task::unknown task %r finished"%msg_id)
729 726
730 727 def save_task_destination(self, idents, msg):
731 728 try:
732 729 msg = self.session.unpack_message(msg, content=True)
733 730 except Exception:
734 731 self.log.error("task::invalid task tracking message", exc_info=True)
735 732 return
736 733 content = msg['content']
737 734 # print (content)
738 735 msg_id = content['msg_id']
739 736 engine_uuid = content['engine_id']
740 737 eid = self.by_ident[engine_uuid]
741 738
742 739 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
743 740 if msg_id in self.unassigned:
744 741 self.unassigned.remove(msg_id)
745 742 # else:
746 743 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
747 744
748 745 self.tasks[eid].append(msg_id)
749 746 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
750 747 try:
751 748 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
752 749 except Exception:
753 750 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
754 751
755 752
756 753 def mia_task_request(self, idents, msg):
757 754 raise NotImplementedError
758 755 client_id = idents[0]
759 756 # content = dict(mia=self.mia,status='ok')
760 757 # self.session.send('mia_reply', content=content, idents=client_id)
761 758
762 759
763 760 #--------------------- IOPub Traffic ------------------------------
764 761
765 762 def save_iopub_message(self, topics, msg):
766 763 """save an iopub message into the db"""
767 764 # print (topics)
768 765 try:
769 766 msg = self.session.unpack_message(msg, content=True)
770 767 except Exception:
771 768 self.log.error("iopub::invalid IOPub message", exc_info=True)
772 769 return
773 770
774 771 parent = msg['parent_header']
775 772 if not parent:
776 773 self.log.error("iopub::invalid IOPub message: %r"%msg)
777 774 return
778 775 msg_id = parent['msg_id']
779 776 msg_type = msg['msg_type']
780 777 content = msg['content']
781 778
782 779 # ensure msg_id is in db
783 780 try:
784 781 rec = self.db.get_record(msg_id)
785 782 except KeyError:
786 783 rec = empty_record()
787 784 rec['msg_id'] = msg_id
788 785 self.db.add_record(msg_id, rec)
789 786 # stream
790 787 d = {}
791 788 if msg_type == 'stream':
792 789 name = content['name']
793 790 s = rec[name] or ''
794 791 d[name] = s + content['data']
795 792
796 793 elif msg_type == 'pyerr':
797 794 d['pyerr'] = content
798 795 elif msg_type == 'pyin':
799 796 d['pyin'] = content['code']
800 797 else:
801 798 d[msg_type] = content.get('data', '')
802 799
803 800 try:
804 801 self.db.update_record(msg_id, d)
805 802 except Exception:
806 803 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
807 804
808 805
809 806
810 807 #-------------------------------------------------------------------------
811 808 # Registration requests
812 809 #-------------------------------------------------------------------------
813 810
814 811 def connection_request(self, client_id, msg):
815 812 """Reply with connection addresses for clients."""
816 813 self.log.info("client::client %r connected"%client_id)
817 814 content = dict(status='ok')
818 815 content.update(self.client_info)
819 816 jsonable = {}
820 817 for k,v in self.keytable.iteritems():
821 818 if v not in self.dead_engines:
822 819 jsonable[str(k)] = v
823 820 content['engines'] = jsonable
824 821 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
825 822
826 823 def register_engine(self, reg, msg):
827 824 """Register a new engine."""
828 825 content = msg['content']
829 826 try:
830 827 queue = content['queue']
831 828 except KeyError:
832 829 self.log.error("registration::queue not specified", exc_info=True)
833 830 return
834 831 heart = content.get('heartbeat', None)
835 832 """register a new engine, and create the socket(s) necessary"""
836 833 eid = self._next_id
837 834 # print (eid, queue, reg, heart)
838 835
839 836 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
840 837
841 838 content = dict(id=eid,status='ok')
842 839 content.update(self.engine_info)
843 840 # check if requesting available IDs:
844 841 if queue in self.by_ident:
845 842 try:
846 843 raise KeyError("queue_id %r in use"%queue)
847 844 except:
848 845 content = error.wrap_exception()
849 846 self.log.error("queue_id %r in use"%queue, exc_info=True)
850 847 elif heart in self.hearts: # need to check unique hearts?
851 848 try:
852 849 raise KeyError("heart_id %r in use"%heart)
853 850 except:
854 851 self.log.error("heart_id %r in use"%heart, exc_info=True)
855 852 content = error.wrap_exception()
856 853 else:
857 854 for h, pack in self.incoming_registrations.iteritems():
858 855 if heart == h:
859 856 try:
860 857 raise KeyError("heart_id %r in use"%heart)
861 858 except:
862 859 self.log.error("heart_id %r in use"%heart, exc_info=True)
863 860 content = error.wrap_exception()
864 861 break
865 862 elif queue == pack[1]:
866 863 try:
867 864 raise KeyError("queue_id %r in use"%queue)
868 865 except:
869 866 self.log.error("queue_id %r in use"%queue, exc_info=True)
870 867 content = error.wrap_exception()
871 868 break
872 869
873 870 msg = self.session.send(self.query, "registration_reply",
874 871 content=content,
875 872 ident=reg)
876 873
877 874 if content['status'] == 'ok':
878 875 if heart in self.heartmonitor.hearts:
879 876 # already beating
880 877 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
881 878 self.finish_registration(heart)
882 879 else:
883 880 purge = lambda : self._purge_stalled_registration(heart)
884 881 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
885 882 dc.start()
886 883 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
887 884 else:
888 885 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
889 886 return eid
890 887
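The registration handshake as this handler reads it (field names from the code above; identities are hypothetical, transport framing elided):

    # engine -> hub, 'registration_request' content:
    content = {
        'queue': 'engine-xreq-identity',       # required, must be unused
        'heartbeat': 'engine-heart-identity',  # checked against live hearts
    }
    # hub -> engine, 'registration_reply' content on success:
    reply = dict(id=0, status='ok')  # merged with the engine_info connection dict
    # registration only finishes once the HeartMonitor sees this heart beat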
891 888 def unregister_engine(self, ident, msg):
892 889 """Unregister an engine that explicitly requested to leave."""
893 890 try:
894 891 eid = msg['content']['id']
895 892 except:
896 893 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
897 894 return
898 895 self.log.info("registration::unregister_engine(%r)"%eid)
899 896 # print (eid)
900 897 uuid = self.keytable[eid]
901 898 content=dict(id=eid, queue=uuid)
902 899 self.dead_engines.add(uuid)
903 900 # self.ids.remove(eid)
904 901 # uuid = self.keytable.pop(eid)
905 902 #
906 903 # ec = self.engines.pop(eid)
907 904 # self.hearts.pop(ec.heartbeat)
908 905 # self.by_ident.pop(ec.queue)
909 906 # self.completed.pop(eid)
910 907 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
911 908 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
912 909 dc.start()
913 910 ############## TODO: HANDLE IT ################
914 911
915 912 if self.notifier:
916 913 self.session.send(self.notifier, "unregistration_notification", content=content)
917 914
918 915 def _handle_stranded_msgs(self, eid, uuid):
919 916 """Handle messages known to be on an engine when the engine unregisters.
920 917
921 918 It is possible that this will fire prematurely - that is, an engine will
922 919 go down after completing a result, and the client will be notified
923 920 that the result failed and later receive the actual result.
924 921 """
925 922
926 923 outstanding = self.queues[eid]
927 924
928 925 for msg_id in outstanding:
929 926 self.pending.remove(msg_id)
930 927 self.all_completed.add(msg_id)
931 928 try:
932 929 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
933 930 except:
934 931 content = error.wrap_exception()
935 932 # build a fake header:
936 933 header = {}
937 934 header['engine'] = uuid
938 935 header['date'] = datetime.now()
939 936 rec = dict(result_content=content, result_header=header, result_buffers=[])
940 937 rec['completed'] = header['date']
941 938 rec['engine_uuid'] = uuid
942 939 try:
943 940 self.db.update_record(msg_id, rec)
944 941 except Exception:
945 942 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
946 943
947 944
948 945 def finish_registration(self, heart):
949 946 """Second half of engine registration, called after our HeartMonitor
950 947 has received a beat from the Engine's Heart."""
951 948 try:
952 949 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
953 950 except KeyError:
954 951 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
955 952 return
956 953 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
957 954 if purge is not None:
958 955 purge.stop()
959 956 control = queue
960 957 self.ids.add(eid)
961 958 self.keytable[eid] = queue
962 959 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
963 960 control=control, heartbeat=heart)
964 961 self.by_ident[queue] = eid
965 962 self.queues[eid] = list()
966 963 self.tasks[eid] = list()
967 964 self.completed[eid] = list()
968 965 self.hearts[heart] = eid
969 966 content = dict(id=eid, queue=self.engines[eid].queue)
970 967 if self.notifier:
971 968 self.session.send(self.notifier, "registration_notification", content=content)
972 969 self.log.info("engine::Engine Connected: %i"%eid)
973 970
974 971 def _purge_stalled_registration(self, heart):
975 972 if heart in self.incoming_registrations:
976 973 eid = self.incoming_registrations.pop(heart)[0]
977 974 self.log.info("registration::purging stalled registration: %i"%eid)
978 975 else:
979 976 pass
980 977
981 978 #-------------------------------------------------------------------------
982 979 # Client Requests
983 980 #-------------------------------------------------------------------------
984 981
985 982 def shutdown_request(self, client_id, msg):
986 983 """handle shutdown request."""
987 984 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
988 985 # also notify other clients of shutdown
989 986 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
990 987 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
991 988 dc.start()
992 989
993 990 def _shutdown(self):
994 991 self.log.info("hub::hub shutting down.")
995 992 time.sleep(0.1)
996 993 sys.exit(0)
997 994
998 995
999 996 def check_load(self, client_id, msg):
1000 997 content = msg['content']
1001 998 try:
1002 999 targets = content['targets']
1003 1000 targets = self._validate_targets(targets)
1004 1001 except:
1005 1002 content = error.wrap_exception()
1006 1003 self.session.send(self.query, "hub_error",
1007 1004 content=content, ident=client_id)
1008 1005 return
1009 1006
1010 1007 content = dict(status='ok')
1011 1008 # loads = {}
1012 1009 for t in targets:
1013 1010 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1014 1011 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1015 1012
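The load_reply content maps each engine id (stringified) to the combined depth of its MUX queue and task queue. A sketch with illustrative values for two engines:

    # hypothetical load_reply content for engines 0 and 1
    content = {
        'status': 'ok',
        '0': 3,   # len(self.queues[0]) + len(self.tasks[0])
        '1': 0,
    }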
1016 1013
1017 1014 def queue_status(self, client_id, msg):
1018 1015 """Return the Queue status of one or more targets.
1019 1016 if verbose: return the msg_ids
1020 1017 else: return the count of each type.
1021 1018 keys: queue (pending MUX jobs)
1022 1019 tasks (pending Task jobs)
1023 1020 completed (finished jobs from both queues)"""
1024 1021 content = msg['content']
1025 1022 targets = content['targets']
1026 1023 try:
1027 1024 targets = self._validate_targets(targets)
1028 1025 except:
1029 1026 content = error.wrap_exception()
1030 1027 self.session.send(self.query, "hub_error",
1031 1028 content=content, ident=client_id)
1032 1029 return
1033 1030 verbose = content.get('verbose', False)
1034 1031 content = dict(status='ok')
1035 1032 for t in targets:
1036 1033 queue = self.queues[t]
1037 1034 completed = self.completed[t]
1038 1035 tasks = self.tasks[t]
1039 1036 if not verbose:
1040 1037 queue = len(queue)
1041 1038 completed = len(completed)
1042 1039 tasks = len(tasks)
1043 1040 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1044 1041 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1045 1042
1046 1043 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1047 1044
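For reference, a sketch of queue_reply content in both modes (engine ids and counts are illustrative):

    # verbose=False: per-engine counts
    content = {'status': 'ok',
               '0': {'queue': 2, 'completed': 5, 'tasks': 1},
               'unassigned': 0}
    # verbose=True: the same keys, but lists of msg_ids instead of counts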
1048 1045 def purge_results(self, client_id, msg):
1049 1046 """Purge results from memory. This method is more valuable before we move
1050 1047 to a DB based message storage mechanism."""
1051 1048 content = msg['content']
1052 1049 msg_ids = content.get('msg_ids', [])
1053 1050 reply = dict(status='ok')
1054 1051 if msg_ids == 'all':
1055 1052 try:
1056 1053 self.db.drop_matching_records(dict(completed={'$ne':None}))
1057 1054 except Exception:
1058 1055 reply = error.wrap_exception()
1059 1056 else:
1060 1057 pending = filter(lambda m: m in self.pending, msg_ids)
1061 1058 if pending:
1062 1059 try:
1063 1060 raise IndexError("msg pending: %r"%pending[0])
1064 1061 except:
1065 1062 reply = error.wrap_exception()
1066 1063 else:
1067 1064 try:
1068 1065 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1069 1066 except Exception:
1070 1067 reply = error.wrap_exception()
1071 1068
1072 1069 if reply['status'] == 'ok':
1073 1070 eids = content.get('engine_ids', [])
1074 1071 for eid in eids:
1075 1072 if eid not in self.engines:
1076 1073 try:
1077 1074 raise IndexError("No such engine: %i"%eid)
1078 1075 except:
1079 1076 reply = error.wrap_exception()
1080 1077 break
1081 1078 msg_ids = self.completed.pop(eid)
1082 1079 uid = self.engines[eid].queue
1083 1080 try:
1084 1081 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1085 1082 except Exception:
1086 1083 reply = error.wrap_exception()
1087 1084 break
1088 1085
1089 1086 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1090 1087
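A purge request can target all completed records, specific msg_ids, or whole engines. A sketch of the accepted content shapes (ids are placeholders; listed msg_ids must not be pending, or the reply is an error):

    content = {'msg_ids': 'all'}
    content = {'msg_ids': ['<msg_id_1>', '<msg_id_2>'], 'engine_ids': [0, 1]}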
1091 1088 def resubmit_task(self, client_id, msg):
1092 1089 """Resubmit one or more tasks."""
1093 1090 def finish(reply):
1094 1091 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1095 1092
1096 1093 content = msg['content']
1097 1094 msg_ids = content['msg_ids']
1098 1095 reply = dict(status='ok')
1099 1096 try:
1100 1097 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1101 1098 'header', 'content', 'buffers'])
1102 1099 except Exception:
1103 1100 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1104 1101 return finish(error.wrap_exception())
1105 1102
1106 1103 # validate msg_ids
1107 1104 found_ids = [ rec['msg_id'] for rec in records ]
1108 1105 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1109 1106 if len(records) > len(msg_ids):
1110 1107 try:
1111 1108 raise RuntimeError("DB appears to be in an inconsistent state. "
1112 1109 "More matching records were found than should exist")
1113 1110 except Exception:
1114 1111 return finish(error.wrap_exception())
1115 1112 elif len(records) < len(msg_ids):
1116 1113 missing = [ m for m in msg_ids if m not in found_ids ]
1117 1114 try:
1118 1115 raise KeyError("No such msg(s): %r"%missing)
1119 1116 except KeyError:
1120 1117 return finish(error.wrap_exception())
1121 1118 elif invalid_ids:
1122 1119 msg_id = invalid_ids[0]
1123 1120 try:
1124 1121 raise ValueError("Task %r appears to be in flight"%(msg_id))
1125 1122 except Exception:
1126 1123 return finish(error.wrap_exception())
1127 1124
1128 1125 # clear the existing records
1129 1126 now = datetime.now()
1130 1127 rec = empty_record()
1131 1128 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1132 1129 rec['resubmitted'] = now
1133 1130 rec['queue'] = 'task'
1134 1131 rec['client_uuid'] = client_id[0]
1135 1132 try:
1136 1133 for msg_id in msg_ids:
1137 1134 self.all_completed.discard(msg_id)
1138 1135 self.db.update_record(msg_id, rec)
1139 1136 except Exception:
1140 1137 self.log.error('db::db error updating record', exc_info=True)
1141 1138 reply = error.wrap_exception()
1142 1139 else:
1143 1140 # send the messages
1144 now_s = now.strftime(util.ISO8601)
1141 now_s = now.strftime(ISO8601)
1145 1142 for rec in records:
1146 1143 header = rec['header']
1147 1144 # include resubmitted in header to prevent digest collision
1148 1145 header['resubmitted'] = now_s
1149 1146 msg = self.session.msg(header['msg_type'])
1150 1147 msg['content'] = rec['content']
1151 1148 msg['header'] = header
1152 1149 msg['msg_id'] = rec['msg_id']
1153 1150 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1154 1151
1155 1152 finish(dict(status='ok'))
1156 1153
1157 1154
1158 1155 def _extract_record(self, rec):
1159 1156 """decompose a TaskRecord dict into subsection of reply for get_result"""
1160 1157 io_dict = {}
1161 1158 for key in 'pyin pyout pyerr stdout stderr'.split():
1162 1159 io_dict[key] = rec[key]
1163 1160 content = { 'result_content': rec['result_content'],
1164 1161 'header': rec['header'],
1165 1162 'result_header' : rec['result_header'],
1166 1163 'io' : io_dict,
1167 1164 }
1168 1165 if rec['result_buffers']:
1169 1166 buffers = map(str, rec['result_buffers'])
1170 1167 else:
1171 1168 buffers = []
1172 1169
1173 1170 return content, buffers
1174 1171
1175 1172 def get_results(self, client_id, msg):
1176 1173 """Get the result of 1 or more messages."""
1177 1174 content = msg['content']
1178 1175 msg_ids = sorted(set(content['msg_ids']))
1179 1176 statusonly = content.get('status_only', False)
1180 1177 pending = []
1181 1178 completed = []
1182 1179 content = dict(status='ok')
1183 1180 content['pending'] = pending
1184 1181 content['completed'] = completed
1185 1182 buffers = []
1186 1183 if not statusonly:
1187 1184 try:
1188 1185 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1189 1186 # turn match list into dict, for faster lookup
1190 1187 records = {}
1191 1188 for rec in matches:
1192 1189 records[rec['msg_id']] = rec
1193 1190 except Exception:
1194 1191 content = error.wrap_exception()
1195 1192 self.session.send(self.query, "result_reply", content=content,
1196 1193 parent=msg, ident=client_id)
1197 1194 return
1198 1195 else:
1199 1196 records = {}
1200 1197 for msg_id in msg_ids:
1201 1198 if msg_id in self.pending:
1202 1199 pending.append(msg_id)
1203 1200 elif msg_id in self.all_completed:
1204 1201 completed.append(msg_id)
1205 1202 if not statusonly:
1206 1203 c,bufs = self._extract_record(records[msg_id])
1207 1204 content[msg_id] = c
1208 1205 buffers.extend(bufs)
1209 1206 elif msg_id in records:
1210 1207 if records[msg_id]['completed']:
1211 1208 completed.append(msg_id)
1212 1209 c,bufs = self._extract_record(records[msg_id])
1213 1210 content[msg_id] = c
1214 1211 buffers.extend(bufs)
1215 1212 else:
1216 1213 pending.append(msg_id)
1217 1214 else:
1218 1215 try:
1219 1216 raise KeyError('No such message: '+msg_id)
1220 1217 except:
1221 1218 content = error.wrap_exception()
1222 1219 break
1223 1220 self.session.send(self.query, "result_reply", content=content,
1224 1221 parent=msg, ident=client_id,
1225 1222 buffers=buffers)
1226 1223
1227 1224 def get_history(self, client_id, msg):
1228 1225 """Get a list of all msg_ids in our DB records"""
1229 1226 try:
1230 1227 msg_ids = self.db.get_history()
1231 1228 except Exception as e:
1232 1229 content = error.wrap_exception()
1233 1230 else:
1234 1231 content = dict(status='ok', history=msg_ids)
1235 1232
1236 1233 self.session.send(self.query, "history_reply", content=content,
1237 1234 parent=msg, ident=client_id)
1238 1235
1239 1236 def db_query(self, client_id, msg):
1240 1237 """Perform a raw query on the task record database."""
1241 1238 content = msg['content']
1242 1239 query = content.get('query', {})
1243 1240 keys = content.get('keys', None)
1244 1241 query = util.extract_dates(query)
1245 1242 buffers = []
1246 1243 empty = list()
1247 1244
1248 1245 try:
1249 1246 records = self.db.find_records(query, keys)
1250 1247 except Exception as e:
1251 1248 content = error.wrap_exception()
1252 1249 else:
1253 1250 # extract buffers from reply content:
1254 1251 if keys is not None:
1255 1252 buffer_lens = [] if 'buffers' in keys else None
1256 1253 result_buffer_lens = [] if 'result_buffers' in keys else None
1257 1254 else:
1258 1255 buffer_lens = []
1259 1256 result_buffer_lens = []
1260 1257
1261 1258 for rec in records:
1262 1259 # buffers may be None, so double check
1263 1260 if buffer_lens is not None:
1264 1261 b = rec.pop('buffers', empty) or empty
1265 1262 buffer_lens.append(len(b))
1266 1263 buffers.extend(b)
1267 1264 if result_buffer_lens is not None:
1268 1265 rb = rec.pop('result_buffers', empty) or empty
1269 1266 result_buffer_lens.append(len(rb))
1270 1267 buffers.extend(rb)
1271 1268 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1272 1269 result_buffer_lens=result_buffer_lens)
1273 1270
1274 1271 self.session.send(self.query, "db_reply", content=content,
1275 1272 parent=msg, ident=client_id,
1276 1273 buffers=buffers)
1277 1274
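db_reply pairs each record with a per-record buffer count so the client can slice the flat buffers list back apart. A hedged sketch of a request and the reply bookkeeping (values illustrative):

    # request: mongodb-style query, optional key subset
    content = {'query': {'completed': {'$ne': None}},
               'keys': ['msg_id', 'completed', 'result_buffers']}
    # reply: 'buffers' was not requested, so buffer_lens is None;
    # result_buffer_lens has one entry per record, and the buffers
    # themselves travel in the message's buffers list
    content = {'status': 'ok', 'records': [],  # matching record dicts
               'buffer_lens': None, 'result_buffer_lens': [2, 0, 1]}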
@@ -1,339 +1,339 b''
1 1 """A TaskRecord backend using sqlite3"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 import json
10 10 import os
11 11 import cPickle as pickle
12 12 from datetime import datetime
13 13
14 14 import sqlite3
15 15
16 16 from zmq.eventloop import ioloop
17 17
18 18 from IPython.utils.traitlets import Unicode, Instance, List
19 19 from .dictdb import BaseDB
20 from IPython.parallel.util import ISO8601
20 from IPython.utils.jsonutil import date_default, extract_dates, ISO8601
21 21
22 22 #-----------------------------------------------------------------------------
23 23 # SQLite operators, adapters, and converters
24 24 #-----------------------------------------------------------------------------
25 25
26 26 operators = {
27 27 '$lt' : "<",
28 28 '$gt' : ">",
29 29 # null is handled weird with ==,!=
30 30 '$eq' : "=",
31 31 '$ne' : "!=",
32 32 '$lte': "<=",
33 33 '$gte': ">=",
34 34 '$in' : ('=', ' OR '),
35 35 '$nin': ('!=', ' AND '),
36 36 # '$all': None,
37 37 # '$mod': None,
38 38 # '$exists' : None
39 39 }
40 40 null_operators = {
41 41 '=' : "IS NULL",
42 42 '!=' : "IS NOT NULL",
43 43 }
44 44
45 45 def _adapt_datetime(dt):
46 46 return dt.strftime(ISO8601)
47 47
48 48 def _convert_datetime(ds):
49 49 if ds is None:
50 50 return ds
51 51 else:
52 52 return datetime.strptime(ds, ISO8601)
53 53
54 54 def _adapt_dict(d):
55 return json.dumps(d)
55 return json.dumps(d, default=date_default)
56 56
57 57 def _convert_dict(ds):
58 58 if ds is None:
59 59 return ds
60 60 else:
61 return json.loads(ds)
61 return extract_dates(json.loads(ds))
62 62
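With date_default/extract_dates, datetimes now survive the JSON round trip through the dict columns. A minimal sketch, assuming IPython.utils.jsonutil provides both helpers as the import above implies:

    import json
    from datetime import datetime
    from IPython.utils.jsonutil import date_default, extract_dates

    d = {'started': datetime.now()}
    s = json.dumps(d, default=date_default)    # datetime -> ISO8601 string
    d2 = extract_dates(json.loads(s))          # ISO8601 string -> datetime
    assert isinstance(d2['started'], datetime)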
63 63 def _adapt_bufs(bufs):
64 64 # this is *horrible*
65 65 # copy buffers into single list and pickle it:
66 66 if bufs and isinstance(bufs[0], (bytes, buffer)):
67 67 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
68 68 elif bufs:
69 69 return bufs
70 70 else:
71 71 return None
72 72
73 73 def _convert_bufs(bs):
74 74 if bs is None:
75 75 return []
76 76 else:
77 77 return pickle.loads(bytes(bs))
78 78
79 79 #-----------------------------------------------------------------------------
80 80 # SQLiteDB class
81 81 #-----------------------------------------------------------------------------
82 82
83 83 class SQLiteDB(BaseDB):
84 84 """SQLite3 TaskRecord backend."""
85 85
86 86 filename = Unicode('tasks.db', config=True,
87 87 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
88 88 location = Unicode('', config=True,
89 89 help="""The directory containing the sqlite task database. The default
90 90 is to use the cluster_dir location.""")
91 91 table = Unicode("", config=True,
92 92 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
93 93 a new table will be created with the Hub's IDENT. Specifying the table will result
94 94 in tasks from previous sessions being available via Clients' db_query and
95 95 get_result methods.""")
96 96
97 97 _db = Instance('sqlite3.Connection')
98 98 _keys = List(['msg_id' ,
99 99 'header' ,
100 100 'content',
101 101 'buffers',
102 102 'submitted',
103 103 'client_uuid' ,
104 104 'engine_uuid' ,
105 105 'started',
106 106 'completed',
107 107 'resubmitted',
108 108 'result_header' ,
109 109 'result_content' ,
110 110 'result_buffers' ,
111 111 'queue' ,
112 112 'pyin' ,
113 113 'pyout',
114 114 'pyerr',
115 115 'stdout',
116 116 'stderr',
117 117 ])
118 118
119 119 def __init__(self, **kwargs):
120 120 super(SQLiteDB, self).__init__(**kwargs)
121 121 if not self.table:
122 122 # use session, and prefix _, since starting with # is illegal
123 123 self.table = '_'+self.session.replace('-','_')
124 124 if not self.location:
125 125 # get current profile
126 126 from IPython.core.newapplication import BaseIPythonApplication
127 127 if BaseIPythonApplication.initialized():
128 128 app = BaseIPythonApplication.instance()
129 129 if app.profile_dir is not None:
130 130 self.location = app.profile_dir.location
131 131 else:
132 132 self.location = u'.'
133 133 else:
134 134 self.location = u'.'
135 135 self._init_db()
136 136
137 137 # register db commit as 2s periodic callback
138 138 # to prevent clogging pipes
139 139 # assumes we are being run in a zmq ioloop app
140 140 loop = ioloop.IOLoop.instance()
141 141 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
142 142 pc.start()
143 143
144 144 def _defaults(self, keys=None):
145 145 """create an empty record"""
146 146 d = {}
147 147 keys = self._keys if keys is None else keys
148 148 for key in keys:
149 149 d[key] = None
150 150 return d
151 151
152 152 def _init_db(self):
153 153 """Connect to the database and get new session number."""
154 154 # register adapters
155 155 sqlite3.register_adapter(datetime, _adapt_datetime)
156 156 sqlite3.register_converter('datetime', _convert_datetime)
157 157 sqlite3.register_adapter(dict, _adapt_dict)
158 158 sqlite3.register_converter('dict', _convert_dict)
159 159 sqlite3.register_adapter(list, _adapt_bufs)
160 160 sqlite3.register_converter('bufs', _convert_bufs)
161 161 # connect to the db
162 162 dbfile = os.path.join(self.location, self.filename)
163 163 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
164 164 # isolation_level = None)#,
165 165 cached_statements=64)
166 166 # print dir(self._db)
167 167
168 168 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
169 169 (msg_id text PRIMARY KEY,
170 170 header dict text,
171 171 content dict text,
172 172 buffers bufs blob,
173 173 submitted datetime text,
174 174 client_uuid text,
175 175 engine_uuid text,
176 176 started datetime text,
177 177 completed datetime text,
178 178 resubmitted datetime text,
179 179 result_header dict text,
180 180 result_content dict text,
181 181 result_buffers bufs blob,
182 182 queue text,
183 183 pyin text,
184 184 pyout text,
185 185 pyerr text,
186 186 stdout text,
187 187 stderr text)
188 188 """%self.table)
189 189 self._db.commit()
190 190
191 191 def _dict_to_list(self, d):
192 192 """turn a mongodb-style record dict into a list."""
193 193
194 194 return [ d[key] for key in self._keys ]
195 195
196 196 def _list_to_dict(self, line, keys=None):
197 197 """Inverse of dict_to_list"""
198 198 keys = self._keys if keys is None else keys
199 199 d = self._defaults(keys)
200 200 for key,value in zip(keys, line):
201 201 d[key] = value
202 202
203 203 return d
204 204
205 205 def _render_expression(self, check):
206 206 """Turn a mongodb-style search dict into an SQL query."""
207 207 expressions = []
208 208 args = []
209 209
210 210 skeys = set(check.keys())
211 211 skeys.difference_update(set(self._keys))
212 212 skeys.difference_update(set(['buffers', 'result_buffers']))
213 213 if skeys:
214 214 raise KeyError("Illegal testing key(s): %s"%skeys)
215 215
216 216 for name,sub_check in check.iteritems():
217 217 if isinstance(sub_check, dict):
218 218 for test,value in sub_check.iteritems():
219 219 try:
220 220 op = operators[test]
221 221 except KeyError:
222 222 raise KeyError("Unsupported operator: %r"%test)
223 223 if isinstance(op, tuple):
224 224 op, join = op
225 225
226 226 if value is None and op in null_operators:
227 227 expr = "%s %s"%(name, null_operators[op])
228 228 else:
229 229 expr = "%s %s ?"%(name, op)
230 230 if isinstance(value, (tuple,list)):
231 231 if op in null_operators and any([v is None for v in value]):
232 232 # equality tests don't work with NULL
233 233 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
234 234 expr = '( %s )'%( join.join([expr]*len(value)) )
235 235 args.extend(value)
236 236 else:
237 237 args.append(value)
238 238 expressions.append(expr)
239 239 else:
240 240 # it's an equality check
241 241 if sub_check is None:
242 242 expressions.append("%s IS NULL"%name)
243 243 else:
244 244 expressions.append("%s = ?"%name)
245 245 args.append(sub_check)
246 246
247 247 expr = " AND ".join(expressions)
248 248 return expr, args
249 249
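To make the operator table concrete, a sketch of how a mongodb-style check renders, given a SQLiteDB instance db (clause order follows dict iteration):

    expr, args = db._render_expression({
        'completed': {'$ne': None},      # -> "completed IS NOT NULL", no arg
        'msg_id': {'$in': ['a', 'b']},   # -> "( msg_id = ? OR msg_id = ? )"
        'queue': 'task',                 # -> "queue = ?"
    })
    # expr joins the clauses with " AND "; args collects ['a', 'b', 'task']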
250 250 def add_record(self, msg_id, rec):
251 251 """Add a new Task Record, by msg_id."""
252 252 d = self._defaults()
253 253 d.update(rec)
254 254 d['msg_id'] = msg_id
255 255 line = self._dict_to_list(d)
256 256 tups = '(%s)'%(','.join(['?']*len(line)))
257 257 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
258 258 # self._db.commit()
259 259
260 260 def get_record(self, msg_id):
261 261 """Get a specific Task Record, by msg_id."""
262 262 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
263 263 line = cursor.fetchone()
264 264 if line is None:
265 265 raise KeyError("No such msg: %r"%msg_id)
266 266 return self._list_to_dict(line)
267 267
268 268 def update_record(self, msg_id, rec):
269 269 """Update the data in an existing record."""
270 270 query = "UPDATE %s SET "%self.table
271 271 sets = []
272 272 keys = sorted(rec.keys())
273 273 values = []
274 274 for key in keys:
275 275 sets.append('%s = ?'%key)
276 276 values.append(rec[key])
277 277 query += ', '.join(sets)
278 278 query += ' WHERE msg_id == ?'
279 279 values.append(msg_id)
280 280 self._db.execute(query, values)
281 281 # self._db.commit()
282 282
283 283 def drop_record(self, msg_id):
284 284 """Remove a record from the DB."""
285 285 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
286 286 # self._db.commit()
287 287
288 288 def drop_matching_records(self, check):
289 289 """Remove a record from the DB."""
290 290 expr,args = self._render_expression(check)
291 291 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
292 292 self._db.execute(query,args)
293 293 # self._db.commit()
294 294
295 295 def find_records(self, check, keys=None):
296 296 """Find records matching a query dict, optionally extracting subset of keys.
297 297
298 298 Returns list of matching records.
299 299
300 300 Parameters
301 301 ----------
302 302
303 303 check: dict
304 304 mongodb-style query argument
305 305 keys: list of strs [optional]
306 306 if specified, the subset of keys to extract. msg_id will *always* be
307 307 included.
308 308 """
309 309 if keys:
310 310 bad_keys = [ key for key in keys if key not in self._keys ]
311 311 if bad_keys:
312 312 raise KeyError("Bad record key(s): %s"%bad_keys)
313 313
314 314 if keys:
315 315 # ensure msg_id is present and first:
316 316 if 'msg_id' in keys:
317 317 keys.remove('msg_id')
318 318 keys.insert(0, 'msg_id')
319 319 req = ', '.join(keys)
320 320 else:
321 321 req = '*'
322 322 expr,args = self._render_expression(check)
323 323 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
324 324 cursor = self._db.execute(query, args)
325 325 matches = cursor.fetchall()
326 326 records = []
327 327 for line in matches:
328 328 rec = self._list_to_dict(line, keys)
329 329 records.append(rec)
330 330 return records
331 331
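Usage mirrors what the Hub does when purging or resubmitting, e.g. selecting the msg_ids completed on a given engine:

    recs = db.find_records({'engine_uuid': uid, 'completed': {'$ne': None}},
                           keys=['msg_id'])
    ids = [r['msg_id'] for r in recs]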
332 332 def get_history(self):
333 333 """get all msg_ids, ordered by time submitted."""
334 334 query = """SELECT msg_id FROM %s ORDER BY submitted ASC"""%self.table
335 335 cursor = self._db.execute(query)
336 336 # will be a list of length 1 tuples
337 337 return [ tup[0] for tup in cursor.fetchall()]
338 338
339 339 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,165 +1,166 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010-2011 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from __future__ import print_function
14 14
15 15 import sys
16 16 import time
17 17
18 18 import zmq
19 19 from zmq.eventloop import ioloop, zmqstream
20 20
21 21 # internal
22 22 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
23 23 # from IPython.utils.localinterfaces import LOCALHOST
24 24
25 25 from IPython.parallel.controller.heartmonitor import Heart
26 26 from IPython.parallel.factory import RegistrationFactory
27 from IPython.parallel.streamsession import Message
28 27 from IPython.parallel.util import disambiguate_url
29 28
29 from IPython.zmq.session import Message
30
30 31 from .streamkernel import Kernel
31 32
32 33 class EngineFactory(RegistrationFactory):
33 34 """IPython engine"""
34 35
35 36 # configurables:
36 37 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
37 38 help="""The OutStream for handling stdout/err.
38 39 Typically 'IPython.zmq.iostream.OutStream'""")
39 40 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True,
40 41 help="""The class for handling displayhook.
41 42 Typically 'IPython.zmq.displayhook.DisplayHook'""")
42 43 location=Unicode(config=True,
43 44 help="""The location (an IP address) of the controller. This is
44 45 used for disambiguating URLs, to determine whether
45 46 loopback should be used to connect or the public address.""")
46 47 timeout=CFloat(2,config=True,
47 48 help="""The time (in seconds) to wait for the Controller to respond
48 49 to registration requests before giving up.""")
49 50
50 51 # not configurable:
51 52 user_ns=Dict()
52 53 id=Int(allow_none=True)
53 54 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
54 55 kernel=Instance(Kernel)
55 56
56 57
57 58 def __init__(self, **kwargs):
58 59 super(EngineFactory, self).__init__(**kwargs)
59 60 self.ident = self.session.session
60 61 ctx = self.context
61 62
62 63 reg = ctx.socket(zmq.XREQ)
63 64 reg.setsockopt(zmq.IDENTITY, self.ident)
64 65 reg.connect(self.url)
65 66 self.registrar = zmqstream.ZMQStream(reg, self.loop)
66 67
67 68 def register(self):
68 69 """send the registration_request"""
69 70
70 71 self.log.info("registering")
71 72 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
72 73 self.registrar.on_recv(self.complete_registration)
73 74 # print (self.session.key)
74 75 self.session.send(self.registrar, "registration_request",content=content)
75 76
76 77 def complete_registration(self, msg):
77 78 # print msg
78 79 self._abort_dc.stop()
79 80 ctx = self.context
80 81 loop = self.loop
81 82 identity = self.ident
82 83
83 84 idents,msg = self.session.feed_identities(msg)
84 85 msg = Message(self.session.unpack_message(msg))
85 86
86 87 if msg.content.status == 'ok':
87 88 self.id = int(msg.content.id)
88 89
89 90 # create Shell Streams (MUX, Task, etc.):
90 91 queue_addr = msg.content.mux
91 92 shell_addrs = [ str(queue_addr) ]
92 93 task_addr = msg.content.task
93 94 if task_addr:
94 95 shell_addrs.append(str(task_addr))
95 96
96 97 # Uncomment this to go back to two-socket model
97 98 # shell_streams = []
98 99 # for addr in shell_addrs:
99 100 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
100 101 # stream.setsockopt(zmq.IDENTITY, identity)
101 102 # stream.connect(disambiguate_url(addr, self.location))
102 103 # shell_streams.append(stream)
103 104
104 105 # Now use only one shell stream for mux and tasks
105 106 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
106 107 stream.setsockopt(zmq.IDENTITY, identity)
107 108 shell_streams = [stream]
108 109 for addr in shell_addrs:
109 110 stream.connect(disambiguate_url(addr, self.location))
110 111 # end single stream-socket
111 112
112 113 # control stream:
113 114 control_addr = str(msg.content.control)
114 115 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 116 control_stream.setsockopt(zmq.IDENTITY, identity)
116 117 control_stream.connect(disambiguate_url(control_addr, self.location))
117 118
118 119 # create iopub stream:
119 120 iopub_addr = msg.content.iopub
120 121 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
121 122 iopub_stream.setsockopt(zmq.IDENTITY, identity)
122 123 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
123 124
124 125 # launch heartbeat
125 126 hb_addrs = msg.content.heartbeat
126 127 # print (hb_addrs)
127 128
128 129 # # Redirect input streams and set a display hook.
129 130 if self.out_stream_factory:
130 131 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
131 132 sys.stdout.topic = 'engine.%i.stdout'%self.id
132 133 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
133 134 sys.stderr.topic = 'engine.%i.stderr'%self.id
134 135 if self.display_hook_factory:
135 136 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
136 137 sys.displayhook.topic = 'engine.%i.pyout'%self.id
137 138
138 139 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
139 140 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
140 141 loop=loop, user_ns = self.user_ns, log=self.log)
141 142 self.kernel.start()
142 143 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
143 144 heart = Heart(*map(str, hb_addrs), heart_id=identity)
144 145 heart.start()
145 146
146 147
147 148 else:
148 149 self.log.fatal("Registration Failed: %s"%msg)
149 150 raise Exception("Registration Failed: %s"%msg)
150 151
151 152 self.log.info("Completed registration with id %i"%self.id)
152 153
153 154
154 155 def abort(self):
155 156 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
156 157 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
157 158 time.sleep(1)
158 159 sys.exit(255)
159 160
160 161 def start(self):
161 162 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
162 163 dc.start()
163 164 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
164 165 self._abort_dc.start()
165 166
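In summary, the handshake complete_registration consumes looks roughly like this (field names as read from msg.content above; addresses are illustrative):

    # engine -> hub, registration_request:
    content = dict(queue=ident, heartbeat=ident, control=ident)
    # hub -> engine, reply:
    content = dict(status='ok', id=0,
                   mux='tcp://127.0.0.1:10101', task='tcp://127.0.0.1:10102',
                   control='tcp://127.0.0.1:10103', iopub='tcp://127.0.0.1:10104',
                   heartbeat=['tcp://127.0.0.1:10105', 'tcp://127.0.0.1:10106'])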
@@ -1,225 +1,225 b''
1 1 """KernelStarter class that intercepts Control Queue messages, and handles process management."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010-2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 import zmq
10 10 from zmq.eventloop import ioloop, zmqstream
11 from IPython.parallel.streamsession import StreamSession
11 from IPython.zmq.session import Session
12 12
13 13 class KernelStarter(object):
14 14 """Object for resetting/killing the Kernel."""
15 15
16 16
17 17 def __init__(self, session, upstream, downstream, *kernel_args, **kernel_kwargs):
18 18 self.session = session
19 19 self.upstream = upstream
20 20 self.downstream = downstream
21 21 self.kernel_args = kernel_args
22 22 self.kernel_kwargs = kernel_kwargs
23 23 self.handlers = {}
24 24 for method in 'shutdown_request shutdown_reply'.split():
25 25 self.handlers[method] = getattr(self, method)
26 26
27 27 def start(self):
28 28 self.upstream.on_recv(self.dispatch_request)
29 29 self.downstream.on_recv(self.dispatch_reply)
30 30
31 31 #--------------------------------------------------------------------------
32 32 # Dispatch methods
33 33 #--------------------------------------------------------------------------
34 34
35 35 def dispatch_request(self, raw_msg):
36 36 idents, msg = self.session.feed_identities(raw_msg)
37 37 try:
38 38 msg = self.session.unpack_message(msg, content=False)
39 39 except:
40 40 print ("bad msg: %s"%msg)
41 41 return
42 42 msgtype = msg['msg_type']
43 43 handler = self.handlers.get(msgtype, None)
44 44 if handler is None:
45 45 self.downstream.send_multipart(raw_msg, copy=False)
46 46 else:
47 47 handler(msg)
48 48
49 49 def dispatch_reply(self, raw_msg):
50 50 idents, msg = self.session.feed_identities(raw_msg)
51 51 try:
52 52 msg = self.session.unpack_message(msg, content=False)
53 53 except:
54 54 print ("bad msg: %s"%msg)
55 55 return
56 56 msgtype = msg['msg_type']
57 57 handler = self.handlers.get(msgtype, None)
58 58 if handler is None:
59 59 self.upstream.send_multipart(raw_msg, copy=False)
60 60 else:
61 61 handler(msg)
62 62
63 63 #--------------------------------------------------------------------------
64 64 # Handlers
65 65 #--------------------------------------------------------------------------
66 66
67 67 def shutdown_request(self, msg):
68 68 """"""
69 69 self.downstream.send_multipart(msg)
70 70
71 71 #--------------------------------------------------------------------------
72 72 # Kernel process management methods, from KernelManager:
73 73 #--------------------------------------------------------------------------
74 74
75 75 def _check_local(self, addr):
76 76 if isinstance(addr, tuple):
77 77 addr = addr[0]
78 78 return addr in LOCAL_IPS
79 79
80 80 def start_kernel(self, **kw):
81 81 """Starts a kernel process and configures the manager to use it.
82 82
83 83 If random ports (port=0) are being used, this method must be called
84 84 before the channels are created.
85 85
86 86 Parameters
87 87 ----------
88 88 ipython : bool, optional (default True)
89 89 Whether to use an IPython kernel instead of a plain Python kernel.
90 90 """
91 91 self.kernel = Process(target=make_kernel, args=self.kernel_args,
92 92 kwargs=self.kernel_kwargs)
93 93
94 94 def shutdown_kernel(self, restart=False):
95 95 """ Attempts to the stop the kernel process cleanly. If the kernel
96 96 cannot be stopped, it is killed, if possible.
97 97 """
98 98 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
99 99 if sys.platform == 'win32':
100 100 self.kill_kernel()
101 101 return
102 102
103 103 # Don't send any additional kernel kill messages immediately, to give
104 104 # the kernel a chance to properly execute shutdown actions. Wait for at
105 105 # most 1s, checking every 0.1s.
106 106 self.xreq_channel.shutdown(restart=restart)
107 107 for i in range(10):
108 108 if self.is_alive:
109 109 time.sleep(0.1)
110 110 else:
111 111 break
112 112 else:
113 113 # OK, we've waited long enough.
114 114 if self.has_kernel:
115 115 self.kill_kernel()
116 116
117 117 def restart_kernel(self, now=False):
118 118 """Restarts a kernel with the same arguments that were used to launch
119 119 it. If the old kernel was launched with random ports, the same ports
120 120 will be used for the new kernel.
121 121
122 122 Parameters
123 123 ----------
124 124 now : bool, optional
125 125 If True, the kernel is forcefully restarted *immediately*, without
126 126 having a chance to do any cleanup action. Otherwise the kernel is
127 127 given 1s to clean up before a forceful restart is issued.
128 128
129 129 In all cases the kernel is restarted, the only difference is whether
130 130 it is given a chance to perform a clean shutdown or not.
131 131 """
132 132 if self._launch_args is None:
133 133 raise RuntimeError("Cannot restart the kernel. "
134 134 "No previous call to 'start_kernel'.")
135 135 else:
136 136 if self.has_kernel:
137 137 if now:
138 138 self.kill_kernel()
139 139 else:
140 140 self.shutdown_kernel(restart=True)
141 141 self.start_kernel(**self._launch_args)
142 142
143 143 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
144 144 # unless there is some delay here.
145 145 if sys.platform == 'win32':
146 146 time.sleep(0.2)
147 147
148 148 @property
149 149 def has_kernel(self):
150 150 """Returns whether a kernel process has been specified for the kernel
151 151 manager.
152 152 """
153 153 return self.kernel is not None
154 154
155 155 def kill_kernel(self):
156 156 """ Kill the running kernel. """
157 157 if self.has_kernel:
158 158 # Pause the heart beat channel if it exists.
159 159 if self._hb_channel is not None:
160 160 self._hb_channel.pause()
161 161
162 162 # Attempt to kill the kernel.
163 163 try:
164 164 self.kernel.kill()
165 165 except OSError as e:
166 166 # In Windows, we will get an Access Denied error if the process
167 167 # has already terminated. Ignore it.
168 168 if not (sys.platform == 'win32' and e.winerror == 5):
169 169 raise
170 170 self.kernel = None
171 171 else:
172 172 raise RuntimeError("Cannot kill kernel. No kernel is running!")
173 173
174 174 def interrupt_kernel(self):
175 175 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
176 176 well supported on all platforms.
177 177 """
178 178 if self.has_kernel:
179 179 if sys.platform == 'win32':
180 180 from parentpoller import ParentPollerWindows as Poller
181 181 Poller.send_interrupt(self.kernel.win32_interrupt_event)
182 182 else:
183 183 self.kernel.send_signal(signal.SIGINT)
184 184 else:
185 185 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
186 186
187 187 def signal_kernel(self, signum):
188 188 """ Sends a signal to the kernel. Note that since only SIGTERM is
189 189 supported on Windows, this function is only useful on Unix systems.
190 190 """
191 191 if self.has_kernel:
192 192 self.kernel.send_signal(signum)
193 193 else:
194 194 raise RuntimeError("Cannot signal kernel. No kernel is running!")
195 195
196 196 @property
197 197 def is_alive(self):
198 198 """Is the kernel process still running?"""
199 199 # FIXME: not using a heartbeat means this method is broken for any
200 200 # remote kernel, it's only capable of handling local kernels.
201 201 if self.has_kernel:
202 202 if self.kernel.poll() is None:
203 203 return True
204 204 else:
205 205 return False
206 206 else:
207 207 # We didn't start the kernel with this KernelManager so we don't
208 208 # know if it is running. We should use a heartbeat for this case.
209 209 return True
210 210
211 211
212 212 def make_starter(up_addr, down_addr, *args, **kwargs):
213 213 """entry point function for launching a kernelstarter in a subprocess"""
214 214 loop = ioloop.IOLoop.instance()
215 215 ctx = zmq.Context()
216 session = StreamSession()
216 session = Session()
217 217 upstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
218 218 upstream.connect(up_addr)
219 219 downstream = zmqstream.ZMQStream(ctx.socket(zmq.XREQ),loop)
220 220 downstream.connect(down_addr)
221 221
222 222 starter = KernelStarter(session, upstream, downstream, *args, **kwargs)
223 223 starter.start()
224 224 loop.start()
225 225 No newline at end of file
@@ -1,433 +1,434 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010-2011 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15
16 16 # Standard library imports.
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 from code import CommandCompiler
23 23 from datetime import datetime
24 24 from pprint import pprint
25 25
26 26 # System library imports.
27 27 import zmq
28 28 from zmq.eventloop import ioloop, zmqstream
29 29
30 30 # Local imports.
31 from IPython.utils.jsonutil import ISO8601
31 32 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
32 33 from IPython.zmq.completer import KernelCompleter
33 34
34 35 from IPython.parallel.error import wrap_exception
35 36 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601
37 from IPython.parallel.util import serialize_object, unpack_apply_message
37 38
38 39 def printer(*args):
39 40 pprint(args, stream=sys.__stdout__)
40 41
41 42
42 43 class _Passer(zmqstream.ZMQStream):
43 44 """Empty class that implements `send()` that does nothing.
44 45
45 Subclass ZMQStream for StreamSession typechecking
46 Subclass ZMQStream for Session typechecking
46 47
47 48 """
48 49 def __init__(self, *args, **kwargs):
49 50 pass
50 51
51 52 def send(self, *args, **kwargs):
52 53 pass
53 54 send_multipart = send
54 55
55 56
56 57 #-----------------------------------------------------------------------------
57 58 # Main kernel class
58 59 #-----------------------------------------------------------------------------
59 60
60 61 class Kernel(SessionFactory):
61 62
62 63 #---------------------------------------------------------------------------
63 64 # Kernel interface
64 65 #---------------------------------------------------------------------------
65 66
66 67 # kwargs:
67 68 exec_lines = List(Unicode, config=True,
68 69 help="List of lines to execute")
69 70
70 71 int_id = Int(-1)
71 72 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
72 73
73 74 control_stream = Instance(zmqstream.ZMQStream)
74 75 task_stream = Instance(zmqstream.ZMQStream)
75 76 iopub_stream = Instance(zmqstream.ZMQStream)
76 77 client = Instance('IPython.parallel.Client')
77 78
78 79 # internals
79 80 shell_streams = List()
80 81 compiler = Instance(CommandCompiler, (), {})
81 82 completer = Instance(KernelCompleter)
82 83
83 84 aborted = Set()
84 85 shell_handlers = Dict()
85 86 control_handlers = Dict()
86 87
87 88 def _set_prefix(self):
88 89 self.prefix = "engine.%s"%self.int_id
89 90
90 91 def _connect_completer(self):
91 92 self.completer = KernelCompleter(self.user_ns)
92 93
93 94 def __init__(self, **kwargs):
94 95 super(Kernel, self).__init__(**kwargs)
95 96 self._set_prefix()
96 97 self._connect_completer()
97 98
98 99 self.on_trait_change(self._set_prefix, 'id')
99 100 self.on_trait_change(self._connect_completer, 'user_ns')
100 101
101 102 # Build dict of handlers for message types
102 103 for msg_type in ['execute_request', 'complete_request', 'apply_request',
103 104 'clear_request']:
104 105 self.shell_handlers[msg_type] = getattr(self, msg_type)
105 106
106 107 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
107 108 self.control_handlers[msg_type] = getattr(self, msg_type)
108 109
109 110 self._initial_exec_lines()
110 111
111 112 def _wrap_exception(self, method=None):
112 113 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
113 114 content=wrap_exception(e_info)
114 115 return content
115 116
116 117 def _initial_exec_lines(self):
117 118 s = _Passer()
118 119 content = dict(silent=True, user_variables=[], user_expressions=[])
119 120 for line in self.exec_lines:
120 121 self.log.debug("executing initialization: %s"%line)
121 122 content.update({'code':line})
122 123 msg = self.session.msg('execute_request', content)
123 124 self.execute_request(s, [], msg)
124 125
125 126
126 127 #-------------------- control handlers -----------------------------
127 128 def abort_queues(self):
128 129 for stream in self.shell_streams:
129 130 if stream:
130 131 self.abort_queue(stream)
131 132
132 133 def abort_queue(self, stream):
133 134 while True:
134 135 try:
135 136 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
136 137 except zmq.ZMQError as e:
137 138 if e.errno == zmq.EAGAIN:
138 139 break
139 140 else:
140 141 return
141 142 else:
142 143 if msg is None:
143 144 return
144 145 else:
145 146 idents,msg = msg
146 147
147 148 # assert self.reply_socket.rcvmore(), "Unexpected missing message part."
148 149 # msg = self.reply_socket.recv_json()
149 150 self.log.info("Aborting:")
150 151 self.log.info(str(msg))
151 152 msg_type = msg['msg_type']
152 153 reply_type = msg_type.split('_')[0] + '_reply'
153 154 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
154 155 # self.reply_socket.send(ident,zmq.SNDMORE)
155 156 # self.reply_socket.send_json(reply_msg)
156 157 reply_msg = self.session.send(stream, reply_type,
157 158 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
158 159 self.log.debug(str(reply_msg))
159 160 # We need to wait a bit for requests to come in. This can probably
160 161 # be set shorter for true asynchronous clients.
161 162 time.sleep(0.05)
162 163
163 164 def abort_request(self, stream, ident, parent):
164 165 """abort a specifig msg by id"""
165 166 msg_ids = parent['content'].get('msg_ids', None)
166 167 if isinstance(msg_ids, basestring):
167 168 msg_ids = [msg_ids]
168 169 if not msg_ids:
169 170 self.abort_queues()
170 171 for mid in msg_ids:
171 172 self.aborted.add(str(mid))
172 173
173 174 content = dict(status='ok')
174 175 reply_msg = self.session.send(stream, 'abort_reply', content=content,
175 176 parent=parent, ident=ident)
176 177 self.log.debug(str(reply_msg))
177 178
178 179 def shutdown_request(self, stream, ident, parent):
179 180 """kill ourself. This should really be handled in an external process"""
180 181 try:
181 182 self.abort_queues()
182 183 except:
183 184 content = self._wrap_exception('shutdown')
184 185 else:
185 186 content = dict(parent['content'])
186 187 content['status'] = 'ok'
187 188 msg = self.session.send(stream, 'shutdown_reply',
188 189 content=content, parent=parent, ident=ident)
189 190 self.log.debug(str(msg))
190 191 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
191 192 dc.start()
192 193
193 194 def dispatch_control(self, msg):
194 195 idents,msg = self.session.feed_identities(msg, copy=False)
195 196 try:
196 197 msg = self.session.unpack_message(msg, content=True, copy=False)
197 198 except:
198 199 self.log.error("Invalid Message", exc_info=True)
199 200 return
200 201
201 202 header = msg['header']
202 203 msg_id = header['msg_id']
203 204
204 205 handler = self.control_handlers.get(msg['msg_type'], None)
205 206 if handler is None:
206 207 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
207 208 else:
208 209 handler(self.control_stream, idents, msg)
209 210
210 211
211 212 #-------------------- queue helpers ------------------------------
212 213
213 214 def check_dependencies(self, dependencies):
214 215 if not dependencies:
215 216 return True
216 217 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
217 218 anyorall = dependencies[0]
218 219 dependencies = dependencies[1]
219 220 else:
220 221 anyorall = 'all'
221 222 results = self.client.get_results(dependencies,status_only=True)
222 223 if results['status'] != 'ok':
223 224 return False
224 225
225 226 if anyorall == 'any':
226 227 if not results['completed']:
227 228 return False
228 229 else:
229 230 if results['pending']:
230 231 return False
231 232
232 233 return True
233 234
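So a dependency spec is either a flat list of msg_ids (implicitly 'all') or an explicit ['any'|'all', msg_ids] pair; a sketch with placeholder ids:

    deps_all = ['<msg_a>', '<msg_b>']            # all must be complete
    deps_any = ['any', ['<msg_a>', '<msg_b>']]   # any one suffices
    ok = kernel.check_dependencies(deps_any)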
234 235 def check_aborted(self, msg_id):
235 236 return msg_id in self.aborted
236 237
237 238 #-------------------- queue handlers -----------------------------
238 239
239 240 def clear_request(self, stream, idents, parent):
240 241 """Clear our namespace."""
241 242 self.user_ns = {}
242 243 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
243 244 content = dict(status='ok'))
244 245 self._initial_exec_lines()
245 246
246 247 def execute_request(self, stream, ident, parent):
247 248 self.log.debug('execute request %s'%parent)
248 249 try:
249 250 code = parent[u'content'][u'code']
250 251 except:
251 252 self.log.error("Got bad msg: %s"%parent, exc_info=True)
252 253 return
253 254 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
254 255 ident='%s.pyin'%self.prefix)
255 256 started = datetime.now().strftime(ISO8601)
256 257 try:
257 258 comp_code = self.compiler(code, '<zmq-kernel>')
258 259 # allow for not overriding displayhook
259 260 if hasattr(sys.displayhook, 'set_parent'):
260 261 sys.displayhook.set_parent(parent)
261 262 sys.stdout.set_parent(parent)
262 263 sys.stderr.set_parent(parent)
263 264 exec comp_code in self.user_ns, self.user_ns
264 265 except:
265 266 exc_content = self._wrap_exception('execute')
266 267 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
267 268 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
268 269 ident='%s.pyerr'%self.prefix)
269 270 reply_content = exc_content
270 271 else:
271 272 reply_content = {'status' : 'ok'}
272 273
273 274 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
274 275 ident=ident, subheader = dict(started=started))
275 276 self.log.debug(str(reply_msg))
276 277 if reply_msg['content']['status'] == u'error':
277 278 self.abort_queues()
278 279
279 280 def complete_request(self, stream, ident, parent):
280 281 matches = {'matches' : self.complete(parent),
281 282 'status' : 'ok'}
282 283 completion_msg = self.session.send(stream, 'complete_reply',
283 284 matches, parent, ident)
284 285 # print >> sys.__stdout__, completion_msg
285 286
286 287 def complete(self, msg):
287 288 return self.completer.complete(msg.content.line, msg.content.text)
288 289
289 290 def apply_request(self, stream, ident, parent):
290 291 # flush previous reply, so this request won't block it
291 292 stream.flush(zmq.POLLOUT)
292 293
293 294 try:
294 295 content = parent[u'content']
295 296 bufs = parent[u'buffers']
296 297 msg_id = parent['header']['msg_id']
297 298 # bound = parent['header'].get('bound', False)
298 299 except:
299 300 self.log.error("Got bad msg: %s"%parent, exc_info=True)
300 301 return
301 302 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
302 303 # self.iopub_stream.send(pyin_msg)
303 304 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
304 305 sub = {'dependencies_met' : True, 'engine' : self.ident,
305 306 'started': datetime.now().strftime(ISO8601)}
306 307 try:
307 308 # allow for not overriding displayhook
308 309 if hasattr(sys.displayhook, 'set_parent'):
309 310 sys.displayhook.set_parent(parent)
310 311 sys.stdout.set_parent(parent)
311 312 sys.stderr.set_parent(parent)
312 313 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
313 314 working = self.user_ns
314 315 # suffix =
315 316 prefix = "_"+str(msg_id).replace("-","")+"_"
316 317
317 318 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
318 319 # if bound:
319 320 # bound_ns = Namespace(working)
320 321 # args = [bound_ns]+list(args)
321 322
322 323 fname = getattr(f, '__name__', 'f')
323 324
324 325 fname = prefix+"f"
325 326 argname = prefix+"args"
326 327 kwargname = prefix+"kwargs"
327 328 resultname = prefix+"result"
328 329
329 330 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
330 331 # print ns
331 332 working.update(ns)
332 333 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
333 334 try:
334 335 exec code in working,working
335 336 result = working.get(resultname)
336 337 finally:
337 338 for key in ns.iterkeys():
338 339 working.pop(key)
339 340 # if bound:
340 341 # working.update(bound_ns)
341 342
342 343 packed_result,buf = serialize_object(result)
343 344 result_buf = [packed_result]+buf
344 345 except:
345 346 exc_content = self._wrap_exception('apply')
346 347 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
347 348 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
348 349 ident='%s.pyerr'%self.prefix)
349 350 reply_content = exc_content
350 351 result_buf = []
351 352
352 353 if exc_content['ename'] == 'UnmetDependency':
353 354 sub['dependencies_met'] = False
354 355 else:
355 356 reply_content = {'status' : 'ok'}
356 357
357 358 # put 'ok'/'error' status in header, for scheduler introspection:
358 359 sub['status'] = reply_content['status']
359 360
360 361 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
361 362 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
362 363
363 364 # flush i/o
364 365 # should this be before reply_msg is sent, like in the single-kernel code,
365 366 # or should nothing get in the way of real results?
366 367 sys.stdout.flush()
367 368 sys.stderr.flush()
368 369
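The msg_id-derived prefix keeps the function, its arguments, and the result from colliding with user variables. For example, for msg_id 'ab-12' the kernel effectively runs:

    # names injected into user_ns, then removed in the finally block:
    #   _ab12_f, _ab12_args, _ab12_kwargs, _ab12_result
    code = "_ab12_result=_ab12_f(*_ab12_args,**_ab12_kwargs)"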
369 370 def dispatch_queue(self, stream, msg):
370 371 self.control_stream.flush()
371 372 idents,msg = self.session.feed_identities(msg, copy=False)
372 373 try:
373 374 msg = self.session.unpack_message(msg, content=True, copy=False)
374 375 except:
375 376 self.log.error("Invalid Message", exc_info=True)
376 377 return
377 378
378 379
379 380 header = msg['header']
380 381 msg_id = header['msg_id']
381 382 if self.check_aborted(msg_id):
382 383 self.aborted.remove(msg_id)
383 384 # is it safe to assume a msg_id will not be resubmitted?
384 385 reply_type = msg['msg_type'].split('_')[0] + '_reply'
385 386 status = {'status' : 'aborted'}
386 387 reply_msg = self.session.send(stream, reply_type, subheader=status,
387 388 content=status, parent=msg, ident=idents)
388 389 return
389 390 handler = self.shell_handlers.get(msg['msg_type'], None)
390 391 if handler is None:
391 392 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
392 393 else:
393 394 handler(stream, idents, msg)
394 395
395 396 def start(self):
396 397 #### stream mode:
397 398 if self.control_stream:
398 399 self.control_stream.on_recv(self.dispatch_control, copy=False)
399 400 self.control_stream.on_err(printer)
400 401
401 402 def make_dispatcher(stream):
402 403 def dispatcher(msg):
403 404 return self.dispatch_queue(stream, msg)
404 405 return dispatcher
405 406
406 407 for s in self.shell_streams:
407 408 s.on_recv(make_dispatcher(s), copy=False)
408 409 s.on_err(printer)
409 410
410 411 if self.iopub_stream:
411 412 self.iopub_stream.on_err(printer)
412 413
413 414 #### while True mode:
414 415 # while True:
415 416 # idle = True
416 417 # try:
417 418 # msg = self.shell_stream.socket.recv_multipart(
418 419 # zmq.NOBLOCK, copy=False)
419 420 # except zmq.ZMQError, e:
420 421 # if e.errno != zmq.EAGAIN:
421 422 # raise e
422 423 # else:
423 424 # idle=False
424 425 # self.dispatch_queue(self.shell_stream, msg)
425 426 #
426 427 # if not self.task_stream.empty():
427 428 # idle=False
428 429 # msg = self.task_stream.recv_multipart()
429 430 # self.dispatch_queue(self.task_stream, msg)
430 431 # if idle:
431 432 # # don't busywait
432 433 # time.sleep(1e-3)
433 434
@@ -1,99 +1,99 b''
1 1 """Base config factories."""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2008-2009 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13
14 14
15 15 import logging
16 16 import os
17 17
18 18 import zmq
19 19 from zmq.eventloop.ioloop import IOLoop
20 20
21 21 from IPython.config.configurable import Configurable
22 22 from IPython.utils.traitlets import Int, Instance, Unicode
23 23
24 import IPython.parallel.streamsession as ss
25 24 from IPython.parallel.util import select_random_ports
25 from IPython.zmq.session import Session
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # Classes
29 29 #-----------------------------------------------------------------------------
30 30 class LoggingFactory(Configurable):
31 31 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
32 32 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
33 33 logname = Unicode('ZMQ')
34 34 def _logname_changed(self, name, old, new):
35 35 self.log = logging.getLogger(new)
36 36
37 37
38 38 class SessionFactory(LoggingFactory):
39 39 """The Base factory from which every factory in IPython.parallel inherits"""
40 40
41 41 # not configurable:
42 42 context = Instance('zmq.Context')
43 43 def _context_default(self):
44 44 return zmq.Context.instance()
45 45
46 session = Instance('IPython.parallel.streamsession.StreamSession')
46 session = Instance('IPython.zmq.session.Session')
47 47 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
48 48 def _loop_default(self):
49 49 return IOLoop.instance()
50 50
51 51
52 52 def __init__(self, **kwargs):
53 53 super(SessionFactory, self).__init__(**kwargs)
54 54
55 55 # construct the session
56 self.session = ss.StreamSession(**kwargs)
56 self.session = Session(**kwargs)
57 57
58 58
59 59 class RegistrationFactory(SessionFactory):
60 60 """The Base Configurable for objects that involve registration."""
61 61
62 62 url = Unicode('', config=True,
63 63 help="""The 0MQ url used for registration. This sets transport, ip, and port
64 64 in one variable. For example: url='tcp://127.0.0.1:12345' or
65 65 url='epgm://*:90210'""") # url takes precedence over ip,regport,transport
66 66 transport = Unicode('tcp', config=True,
67 67 help="""The 0MQ transport for communications. This will likely be
68 68 the default of 'tcp', but other values include 'ipc', 'epgm', 'inproc'.""")
69 69 ip = Unicode('127.0.0.1', config=True,
70 70 help="""The IP address for registration. This is generally either
71 71 '127.0.0.1' for loopback only or '*' for all interfaces.
72 72 [default: '127.0.0.1']""")
73 73 regport = Int(config=True,
74 74 help="""The port on which the Hub listens for registration.""")
75 75 def _regport_default(self):
76 76 return select_random_ports(1)[0]
77 77
78 78 def __init__(self, **kwargs):
79 79 super(RegistrationFactory, self).__init__(**kwargs)
80 80 self._propagate_url()
81 81 self._rebuild_url()
82 82 self.on_trait_change(self._propagate_url, 'url')
83 83 self.on_trait_change(self._rebuild_url, 'ip')
84 84 self.on_trait_change(self._rebuild_url, 'transport')
85 85 self.on_trait_change(self._rebuild_url, 'regport')
86 86
87 87 def _rebuild_url(self):
88 88 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
89 89
90 90 def _propagate_url(self):
91 91 """Ensure self.url contains full transport://interface:port"""
92 92 if self.url:
93 93 iface = self.url.split('://',1)
94 94 if len(iface) == 2:
95 95 self.transport,iface = iface
96 96 iface = iface.split(':')
97 97 self.ip = iface[0]
98 98 if iface[1]:
99 99 self.regport = int(iface[1])
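The trait handlers above keep `url` and its `transport`/`ip`/`regport` components in sync in both directions. A minimal sketch of that behavior (illustrative, not part of this changeset; assumes this module is importable as `IPython.parallel.factory`):

    from IPython.parallel.factory import RegistrationFactory

    f = RegistrationFactory(url='tcp://10.0.0.1:10101')
    # _propagate_url split the url into its components:
    assert (f.transport, f.ip, f.regport) == ('tcp', '10.0.0.1', 10101)
    f.regport = 10102
    # _rebuild_url fired on the trait change and rebuilt the url:
    assert f.url == 'tcp://10.0.0.1:10102'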
@@ -1,170 +1,173 b''
1 1 """Tests for db backends"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14
15 15 import tempfile
16 16 import time
17 17
18 18 from datetime import datetime, timedelta
19 19 from unittest import TestCase
20 20
21 21 from nose import SkipTest
22 22
23 from IPython.parallel import error, streamsession as ss
23 from IPython.parallel import error
24 24 from IPython.parallel.controller.dictdb import DictDB
25 25 from IPython.parallel.controller.sqlitedb import SQLiteDB
26 26 from IPython.parallel.controller.hub import init_record, empty_record
27 27
28 from IPython.zmq.session import Session
29
30
28 31 #-------------------------------------------------------------------------------
29 32 # TestCases
30 33 #-------------------------------------------------------------------------------
31 34
32 35 class TestDictBackend(TestCase):
33 36 def setUp(self):
34 self.session = ss.StreamSession()
37 self.session = Session()
35 38 self.db = self.create_db()
36 39 self.load_records(16)
37 40
38 41 def create_db(self):
39 42 return DictDB()
40 43
41 44 def load_records(self, n=1):
42 45 """load n records for testing"""
43 46 #sleep 1/10 s, to ensure timestamp is different to previous calls
44 47 time.sleep(0.1)
45 48 msg_ids = []
46 49 for i in range(n):
47 50 msg = self.session.msg('apply_request', content=dict(a=5))
48 51 msg['buffers'] = []
49 52 rec = init_record(msg)
50 53 msg_ids.append(msg['msg_id'])
51 54 self.db.add_record(msg['msg_id'], rec)
52 55 return msg_ids
53 56
54 57 def test_add_record(self):
55 58 before = self.db.get_history()
56 59 self.load_records(5)
57 60 after = self.db.get_history()
58 61 self.assertEquals(len(after), len(before)+5)
59 62 self.assertEquals(after[:-5],before)
60 63
61 64 def test_drop_record(self):
62 65 msg_id = self.load_records()[-1]
63 66 rec = self.db.get_record(msg_id)
64 67 self.db.drop_record(msg_id)
65 68 self.assertRaises(KeyError,self.db.get_record, msg_id)
66 69
67 70 def _round_to_millisecond(self, dt):
68 71 """necessary because mongodb rounds microseconds"""
69 72 micro = dt.microsecond
70 73 extra = int(str(micro)[-3:])
71 74 return dt - timedelta(microseconds=extra)
72 75
73 76 def test_update_record(self):
74 77 now = self._round_to_millisecond(datetime.now())
75 78 #
76 79 msg_id = self.db.get_history()[-1]
77 80 rec1 = self.db.get_record(msg_id)
78 81 data = {'stdout': 'hello there', 'completed' : now}
79 82 self.db.update_record(msg_id, data)
80 83 rec2 = self.db.get_record(msg_id)
81 84 self.assertEquals(rec2['stdout'], 'hello there')
82 85 self.assertEquals(rec2['completed'], now)
83 86 rec1.update(data)
84 87 self.assertEquals(rec1, rec2)
85 88
86 89 # def test_update_record_bad(self):
87 90 # """test updating nonexistant records"""
88 91 # msg_id = str(uuid.uuid4())
89 92 # data = {'stdout': 'hello there'}
90 93 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
91 94
92 95 def test_find_records_dt(self):
93 96 """test finding records by date"""
94 97 hist = self.db.get_history()
95 98 middle = self.db.get_record(hist[len(hist)/2])
96 99 tic = middle['submitted']
97 100 before = self.db.find_records({'submitted' : {'$lt' : tic}})
98 101 after = self.db.find_records({'submitted' : {'$gte' : tic}})
99 102 self.assertEquals(len(before)+len(after),len(hist))
100 103 for b in before:
101 104 self.assertTrue(b['submitted'] < tic)
102 105 for a in after:
103 106 self.assertTrue(a['submitted'] >= tic)
104 107 same = self.db.find_records({'submitted' : tic})
105 108 for s in same:
106 109 self.assertTrue(s['submitted'] == tic)
107 110
108 111 def test_find_records_keys(self):
109 112 """test extracting subset of record keys"""
110 113 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
111 114 for rec in found:
112 115 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
113 116
114 117 def test_find_records_msg_id(self):
115 118 """ensure msg_id is always in found records"""
116 119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
117 120 for rec in found:
118 121 self.assertTrue('msg_id' in rec.keys())
119 122 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
120 123 for rec in found:
121 124 self.assertTrue('msg_id' in rec.keys())
122 125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
123 126 for rec in found:
124 127 self.assertTrue('msg_id' in rec.keys())
125 128
126 129 def test_find_records_in(self):
127 130 """test finding records with '$in','$nin' operators"""
128 131 hist = self.db.get_history()
129 132 even = hist[::2]
130 133 odd = hist[1::2]
131 134 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
132 135 found = [ r['msg_id'] for r in recs ]
133 136 self.assertEquals(set(even), set(found))
134 137 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
135 138 found = [ r['msg_id'] for r in recs ]
136 139 self.assertEquals(set(odd), set(found))
137 140
138 141 def test_get_history(self):
139 142 msg_ids = self.db.get_history()
140 143 latest = datetime(1984,1,1)
141 144 for msg_id in msg_ids:
142 145 rec = self.db.get_record(msg_id)
143 146 newt = rec['submitted']
144 147 self.assertTrue(newt >= latest)
145 148 latest = newt
146 149 msg_id = self.load_records(1)[-1]
147 150 self.assertEquals(self.db.get_history()[-1],msg_id)
148 151
149 152 def test_datetime(self):
150 153 """get/set timestamps with datetime objects"""
151 154 msg_id = self.db.get_history()[-1]
152 155 rec = self.db.get_record(msg_id)
153 156 self.assertTrue(isinstance(rec['submitted'], datetime))
154 157 self.db.update_record(msg_id, dict(completed=datetime.now()))
155 158 rec = self.db.get_record(msg_id)
156 159 self.assertTrue(isinstance(rec['completed'], datetime))
157 160
158 161 def test_drop_matching(self):
159 162 msg_ids = self.load_records(10)
160 163 query = {'msg_id' : {'$in':msg_ids}}
161 164 self.db.drop_matching_records(query)
162 165 recs = self.db.find_records(query)
163 166 self.assertTrue(len(recs)==0)
164 167
165 168 class TestSQLiteBackend(TestDictBackend):
166 169 def create_db(self):
167 170 return SQLiteDB(location=tempfile.gettempdir())
168 171
169 172 def tearDown(self):
170 173 self.db._db.close()
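The tests above all drive the same MongoDB-flavored query interface, so a backend only has to implement `find_records(check, keys=None)` plus the record CRUD methods. A hedged sketch of the query style being exercised (standalone, not part of the diff):

    from datetime import datetime, timedelta
    from IPython.parallel.controller.dictdb import DictDB

    db = DictDB()
    cutoff = datetime.now() - timedelta(minutes=5)
    # comparison operators ($lt/$gte) and membership ($in/$nin) mirror MongoDB:
    recent = db.find_records({'submitted': {'$gte': cutoff}},
                             keys=['msg_id', 'submitted'])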
@@ -1,483 +1,466 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010-2011 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 # Standard library imports.
14 14 import logging
15 15 import os
16 16 import re
17 17 import stat
18 18 import socket
19 19 import sys
20 from datetime import datetime
21 20 from signal import signal, SIGINT, SIGABRT, SIGTERM
22 21 try:
23 22 from signal import SIGKILL
24 23 except ImportError:
25 24 SIGKILL=None
26 25
27 26 try:
28 27 import cPickle
29 28 pickle = cPickle
30 29 except:
31 30 cPickle = None
32 31 import pickle
33 32
34 33 # System library imports
35 34 import zmq
36 35 from zmq.log import handlers
37 36
38 37 # IPython imports
39 38 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
40 39 from IPython.utils.newserialized import serialize, unserialize
41 40 from IPython.zmq.log import EnginePUBHandler
42 41
43 # globals
44 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
45 ISO8601_RE=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
46
47 42 #-----------------------------------------------------------------------------
48 43 # Classes
49 44 #-----------------------------------------------------------------------------
50 45
51 46 class Namespace(dict):
52 47 """Subclass of dict for attribute access to keys."""
53 48
54 49 def __getattr__(self, key):
55 50 """getattr aliased to getitem"""
56 51 if key in self.iterkeys():
57 52 return self[key]
58 53 else:
59 54 raise NameError(key)
60 55
61 56 def __setattr__(self, key, value):
62 57 """setattr aliased to setitem, with strict"""
63 58 if hasattr(dict, key):
64 59 raise KeyError("Cannot override dict keys %r"%key)
65 60 self[key] = value
66 61
67 62
68 63 class ReverseDict(dict):
69 64 """simple double-keyed subset of dict methods."""
70 65
71 66 def __init__(self, *args, **kwargs):
72 67 dict.__init__(self, *args, **kwargs)
73 68 self._reverse = dict()
74 69 for key, value in self.iteritems():
75 70 self._reverse[value] = key
76 71
77 72 def __getitem__(self, key):
78 73 try:
79 74 return dict.__getitem__(self, key)
80 75 except KeyError:
81 76 return self._reverse[key]
82 77
83 78 def __setitem__(self, key, value):
84 79 if key in self._reverse:
85 80 raise KeyError("Can't have key %r on both sides!"%key)
86 81 dict.__setitem__(self, key, value)
87 82 self._reverse[value] = key
88 83
89 84 def pop(self, key):
90 85 value = dict.pop(self, key)
91 86 self._reverse.pop(value)
92 87 return value
93 88
94 89 def get(self, key, default=None):
95 90 try:
96 91 return self[key]
97 92 except KeyError:
98 93 return default
99 94
100 95 #-----------------------------------------------------------------------------
101 96 # Functions
102 97 #-----------------------------------------------------------------------------
103 98
104 def extract_dates(obj):
105 """extract ISO8601 dates from unpacked JSON"""
106 if isinstance(obj, dict):
107 for k,v in obj.iteritems():
108 obj[k] = extract_dates(v)
109 elif isinstance(obj, list):
110 obj = [ extract_dates(o) for o in obj ]
111 elif isinstance(obj, basestring):
112 if ISO8601_RE.match(obj):
113 obj = datetime.strptime(obj, ISO8601)
114 return obj
115
116 99 def validate_url(url):
117 100 """validate a url for zeromq"""
118 101 if not isinstance(url, basestring):
119 102 raise TypeError("url must be a string, not %r"%type(url))
120 103 url = url.lower()
121 104
122 105 proto_addr = url.split('://')
123 106 assert len(proto_addr) == 2, 'Invalid url: %r'%url
124 107 proto, addr = proto_addr
125 108 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
126 109
127 110 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
128 111 # author: Remi Sabourin
129 112 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
130 113
131 114 if proto == 'tcp':
132 115 lis = addr.split(':')
133 116 assert len(lis) == 2, 'Invalid url: %r'%url
134 117 addr,s_port = lis
135 118 try:
136 119 port = int(s_port)
137 120 except ValueError:
138 121 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
139 122
140 123 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
141 124
142 125 else:
143 126 # only validate tcp urls currently
144 127 pass
145 128
146 129 return True
147 130
148 131
149 132 def validate_url_container(container):
150 133 """validate a potentially nested collection of urls."""
151 134 if isinstance(container, basestring):
152 135 url = container
153 136 return validate_url(url)
154 137 elif isinstance(container, dict):
155 138 container = container.itervalues()
156 139
157 140 for element in container:
158 141 validate_url_container(element)
159 142
160 143
161 144 def split_url(url):
162 145 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
163 146 proto_addr = url.split('://')
164 147 assert len(proto_addr) == 2, 'Invalid url: %r'%url
165 148 proto, addr = proto_addr
166 149 lis = addr.split(':')
167 150 assert len(lis) == 2, 'Invalid url: %r'%url
168 151 addr,s_port = lis
169 152 return proto,addr,s_port
170 153
171 154 def disambiguate_ip_address(ip, location=None):
172 155 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
173 156 ones, based on the location (default interpretation of location is localhost)."""
174 157 if ip in ('0.0.0.0', '*'):
175 158 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
176 159 if location is None or location in external_ips:
177 160 ip='127.0.0.1'
178 161 elif location:
179 162 return location
180 163 return ip
181 164
182 165 def disambiguate_url(url, location=None):
183 166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
184 167 ones, based on the location (default interpretation is localhost).
185 168
186 169 This is for zeromq urls, such as tcp://*:10101."""
187 170 try:
188 171 proto,ip,port = split_url(url)
189 172 except AssertionError:
190 173 # probably not tcp url; could be ipc, etc.
191 174 return url
192 175
193 176 ip = disambiguate_ip_address(ip,location)
194 177
195 178 return "%s://%s:%s"%(proto,ip,port)
196 179
197 180
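Taken together, `split_url` and `disambiguate_url` let connection info be written with wildcard interfaces and still produce connectable addresses. A small sketch (illustrative only):

    from IPython.parallel.util import split_url, disambiguate_url

    split_url('tcp://127.0.0.1:10101')   # -> ('tcp', '127.0.0.1', '10101')
    # '*' and '0.0.0.0' are replaced with a connectable address (localhost here):
    disambiguate_url('tcp://*:10101')    # -> 'tcp://127.0.0.1:10101'
    disambiguate_url('ipc:///tmp/sock')  # non-tcp urls pass through unchanged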
198 181 def rekey(dikt):
199 182 """Rekey a dict that has been forced to use str keys where there should be
200 183 ints by json. This belongs in the jsonutil added by fperez."""
201 184 for k in dikt.iterkeys():
202 185 if isinstance(k, str):
203 186 ik=fk=None
204 187 try:
205 188 ik = int(k)
206 189 except ValueError:
207 190 try:
208 191 fk = float(k)
209 192 except ValueError:
210 193 continue
211 194 if ik is not None:
212 195 nk = ik
213 196 else:
214 197 nk = fk
215 198 if nk in dikt:
216 199 raise KeyError("already have key %r"%nk)
217 200 dikt[nk] = dikt.pop(k)
218 201 return dikt
219 202
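For example (a sketch; JSON decoding is what produces the str keys in practice):

    from IPython.parallel.util import rekey

    d = rekey({'0': 'a', '1.5': 'b', 'name': 'c'})
    # numeric-looking str keys become int/float; others are left alone:
    assert d == {0: 'a', 1.5: 'b', 'name': 'c'}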
220 203 def serialize_object(obj, threshold=64e-6):
221 204 """Serialize an object into a list of sendable buffers.
222 205
223 206 Parameters
224 207 ----------
225 208
226 209 obj : object
227 210 The object to be serialized
228 211 threshold : float
229 212 The threshold for not double-pickling the content.
230 213
231 214
232 215 Returns
233 216 -------
234 217 ('pmd', [bufs]) :
235 218 where pmd is the pickled metadata wrapper,
236 219 bufs is a list of data buffers
237 220 """
238 221 databuffers = []
239 222 if isinstance(obj, (list, tuple)):
240 223 clist = canSequence(obj)
241 224 slist = map(serialize, clist)
242 225 for s in slist:
243 226 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
244 227 databuffers.append(s.getData())
245 228 s.data = None
246 229 return pickle.dumps(slist,-1), databuffers
247 230 elif isinstance(obj, dict):
248 231 sobj = {}
249 232 for k in sorted(obj.iterkeys()):
250 233 s = serialize(can(obj[k]))
251 234 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
252 235 databuffers.append(s.getData())
253 236 s.data = None
254 237 sobj[k] = s
255 238 return pickle.dumps(sobj,-1),databuffers
256 239 else:
257 240 s = serialize(can(obj))
258 241 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
259 242 databuffers.append(s.getData())
260 243 s.data = None
261 244 return pickle.dumps(s,-1),databuffers
262 245
263 246
264 247 def unserialize_object(bufs):
265 248 """reconstruct an object serialized by serialize_object from data buffers."""
266 249 bufs = list(bufs)
267 250 sobj = pickle.loads(bufs.pop(0))
268 251 if isinstance(sobj, (list, tuple)):
269 252 for s in sobj:
270 253 if s.data is None:
271 254 s.data = bufs.pop(0)
272 255 return uncanSequence(map(unserialize, sobj)), bufs
273 256 elif isinstance(sobj, dict):
274 257 newobj = {}
275 258 for k in sorted(sobj.iterkeys()):
276 259 s = sobj[k]
277 260 if s.data is None:
278 261 s.data = bufs.pop(0)
279 262 newobj[k] = uncan(unserialize(s))
280 263 return newobj, bufs
281 264 else:
282 265 if sobj.data is None:
283 266 sobj.data = bufs.pop(0)
284 267 return uncan(unserialize(sobj)), bufs
285 268
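A round-trip sketch of the pair above (illustrative; small objects produce no separate data buffers because they fall under `threshold`):

    from IPython.parallel.util import serialize_object, unserialize_object

    pmd, bufs = serialize_object({'x': [1, 2, 3], 'y': 'hi'})
    obj, remaining = unserialize_object([pmd] + bufs)
    assert obj == {'x': [1, 2, 3], 'y': 'hi'} and remaining == []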
286 269 def pack_apply_message(f, args, kwargs, threshold=64e-6):
287 270 """pack up a function, args, and kwargs to be sent over the wire
288 271 as a series of buffers. Any object whose data is larger than `threshold`
289 272 will not have their data copied (currently only numpy arrays support zero-copy)"""
290 273 msg = [pickle.dumps(can(f),-1)]
291 274 databuffers = [] # for large objects
292 275 sargs, bufs = serialize_object(args,threshold)
293 276 msg.append(sargs)
294 277 databuffers.extend(bufs)
295 278 skwargs, bufs = serialize_object(kwargs,threshold)
296 279 msg.append(skwargs)
297 280 databuffers.extend(bufs)
298 281 msg.extend(databuffers)
299 282 return msg
300 283
301 284 def unpack_apply_message(bufs, g=None, copy=True):
302 285 """unpack f,args,kwargs from buffers packed by pack_apply_message()
303 286 Returns: original f,args,kwargs"""
304 287 bufs = list(bufs) # allow us to pop
305 288 assert len(bufs) >= 3, "not enough buffers!"
306 289 if not copy:
307 290 for i in range(3):
308 291 bufs[i] = bufs[i].bytes
309 292 cf = pickle.loads(bufs.pop(0))
310 293 sargs = list(pickle.loads(bufs.pop(0)))
311 294 skwargs = dict(pickle.loads(bufs.pop(0)))
312 295 # print sargs, skwargs
313 296 f = uncan(cf, g)
314 297 for sa in sargs:
315 298 if sa.data is None:
316 299 m = bufs.pop(0)
317 300 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
318 301 # always use a buffer, until memoryviews get sorted out
319 302 sa.data = buffer(m)
320 303 # disable memoryview support
321 304 # if copy:
322 305 # sa.data = buffer(m)
323 306 # else:
324 307 # sa.data = m.buffer
325 308 else:
326 309 if copy:
327 310 sa.data = m
328 311 else:
329 312 sa.data = m.bytes
330 313
331 314 args = uncanSequence(map(unserialize, sargs), g)
332 315 kwargs = {}
333 316 for k in sorted(skwargs.iterkeys()):
334 317 sa = skwargs[k]
335 318 if sa.data is None:
336 319 m = bufs.pop(0)
337 320 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
338 321 # always use a buffer, until memoryviews get sorted out
339 322 sa.data = buffer(m)
340 323 # disable memoryview support
341 324 # if copy:
342 325 # sa.data = buffer(m)
343 326 # else:
344 327 # sa.data = m.buffer
345 328 else:
346 329 if copy:
347 330 sa.data = m
348 331 else:
349 332 sa.data = m.bytes
350 333
351 334 kwargs[k] = uncan(unserialize(sa), g)
352 335
353 336 return f,args,kwargs
354 337
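The two functions are inverses, which is how engines receive `apply` calls off the wire. A hedged round-trip sketch:

    from IPython.parallel.util import pack_apply_message, unpack_apply_message

    def add(a, b):
        return a + b

    bufs = pack_apply_message(add, (5, 2), {})
    f, args, kwargs = unpack_apply_message(bufs)
    assert f(*args, **kwargs) == 7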
355 338 #--------------------------------------------------------------------------
356 339 # helpers for implementing old MEC API via view.apply
357 340 #--------------------------------------------------------------------------
358 341
359 342 def interactive(f):
360 343 """decorator for making functions appear as interactively defined.
361 344 This results in the function being linked to the user_ns as globals()
362 345 instead of the module globals().
363 346 """
364 347 f.__module__ = '__main__'
365 348 return f
366 349
367 350 @interactive
368 351 def _push(ns):
369 352 """helper method for implementing `client.push` via `client.apply`"""
370 353 globals().update(ns)
371 354
372 355 @interactive
373 356 def _pull(keys):
374 357 """helper method for implementing `client.pull` via `client.apply`"""
375 358 user_ns = globals()
376 359 if isinstance(keys, (list,tuple, set)):
377 360 for key in keys:
378 361 if not user_ns.has_key(key):
379 362 raise NameError("name '%s' is not defined"%key)
380 363 return map(user_ns.get, keys)
381 364 else:
382 365 if not user_ns.has_key(keys):
383 366 raise NameError("name '%s' is not defined"%keys)
384 367 return user_ns.get(keys)
385 368
386 369 @interactive
387 370 def _execute(code):
388 371 """helper method for implementing `client.execute` via `client.apply`"""
389 372 exec code in globals()
390 373
391 374 #--------------------------------------------------------------------------
392 375 # extra process management utilities
393 376 #--------------------------------------------------------------------------
394 377
395 378 _random_ports = set()
396 379
397 380 def select_random_ports(n):
398 381 """Selects and return n random ports that are available."""
399 382 ports = []
400 383 for i in xrange(n):
401 384 sock = socket.socket()
402 385 sock.bind(('', 0))
403 386 while sock.getsockname()[1] in _random_ports:
404 387 sock.close()
405 388 sock = socket.socket()
406 389 sock.bind(('', 0))
407 390 ports.append(sock)
408 391 for i, sock in enumerate(ports):
409 392 port = sock.getsockname()[1]
410 393 sock.close()
411 394 ports[i] = port
412 395 _random_ports.add(port)
413 396 return ports
414 397
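Because already-handed-out ports are remembered in `_random_ports`, repeated calls within a process will not return duplicates. For example:

    from IPython.parallel.util import select_random_ports

    ports = select_random_ports(3)   # e.g. [53201, 53407, 53912]
    assert len(set(ports)) == 3      # three distinct, currently-free ports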
415 398 def signal_children(children):
416 399 """Relay interupt/term signals to children, for more solid process cleanup."""
417 400 def terminate_children(sig, frame):
418 401 logging.critical("Got signal %i, terminating children..."%sig)
419 402 for child in children:
420 403 child.terminate()
421 404
422 405 sys.exit(sig != SIGINT)
423 406 # sys.exit(sig)
424 407 for sig in (SIGINT, SIGABRT, SIGTERM):
425 408 signal(sig, terminate_children)
426 409
427 410 def generate_exec_key(keyfile):
428 411 import uuid
429 412 newkey = str(uuid.uuid4())
430 413 with open(keyfile, 'w') as f:
431 414 # f.write('ipython-key ')
432 415 f.write(newkey+'\n')
433 416 # set user-only RW permissions (0600)
434 417 # this will have no effect on Windows
435 418 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
436 419
437 420
438 421 def integer_loglevel(loglevel):
439 422 try:
440 423 loglevel = int(loglevel)
441 424 except ValueError:
442 425 if isinstance(loglevel, str):
443 426 loglevel = getattr(logging, loglevel)
444 427 return loglevel
445 428
446 429 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
447 430 logger = logging.getLogger(logname)
448 431 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
449 432 # don't add a second PUBHandler
450 433 return
451 434 loglevel = integer_loglevel(loglevel)
452 435 lsock = context.socket(zmq.PUB)
453 436 lsock.connect(iface)
454 437 handler = handlers.PUBHandler(lsock)
455 438 handler.setLevel(loglevel)
456 439 handler.root_topic = root
457 440 logger.addHandler(handler)
458 441 logger.setLevel(loglevel)
459 442
460 443 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
461 444 logger = logging.getLogger()
462 445 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
463 446 # don't add a second PUBHandler
464 447 return
465 448 loglevel = integer_loglevel(loglevel)
466 449 lsock = context.socket(zmq.PUB)
467 450 lsock.connect(iface)
468 451 handler = EnginePUBHandler(engine, lsock)
469 452 handler.setLevel(loglevel)
470 453 logger.addHandler(handler)
471 454 logger.setLevel(loglevel)
472 455
473 456 def local_logger(logname, loglevel=logging.DEBUG):
474 457 loglevel = integer_loglevel(loglevel)
475 458 logger = logging.getLogger(logname)
476 459 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
477 460 # don't add a second StreamHandler
478 461 return
479 462 handler = logging.StreamHandler()
480 463 handler.setLevel(loglevel)
481 464 logger.addHandler(handler)
482 465 logger.setLevel(loglevel)
483 466
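`integer_loglevel` accepts either an int or a level name, so these helpers can be driven directly from config strings. A usage sketch (hypothetical logger name):

    import logging
    from IPython.parallel.util import local_logger

    local_logger('myapp', loglevel='DEBUG')   # 'DEBUG' -> logging.DEBUG
    logging.getLogger('myapp').debug('scheduler started')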
@@ -1,90 +1,121 b''
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 import re
14 15 import types
16 from datetime import datetime
17
18 #-----------------------------------------------------------------------------
19 # Globals and constants
20 #-----------------------------------------------------------------------------
21
22 # timestamp formats
23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
15 25
16 26 #-----------------------------------------------------------------------------
17 27 # Classes and functions
18 28 #-----------------------------------------------------------------------------
19 29
30 def extract_dates(obj):
31 """extract ISO8601 dates from unpacked JSON"""
32 if isinstance(obj, dict):
33 for k,v in obj.iteritems():
34 obj[k] = extract_dates(v)
35 elif isinstance(obj, list):
36 obj = [ extract_dates(o) for o in obj ]
37 elif isinstance(obj, basestring):
38 if ISO8601_PAT.match(obj):
39 obj = datetime.strptime(obj, ISO8601)
40 return obj
41
42 def date_default(obj):
43 """default function for packing datetime objects"""
44 if isinstance(obj, datetime):
45 return obj.strftime(ISO8601)
46 else:
47 raise TypeError("%r is not JSON serializable"%obj)
48
49
50
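`date_default` and `extract_dates` are inverses across a JSON round trip, which is what lets message timestamps survive the wire. A sketch using the stdlib json module:

    import json
    from datetime import datetime
    from IPython.utils.jsonutil import date_default, extract_dates

    wire = json.dumps({'submitted': datetime.now()}, default=date_default)
    obj = extract_dates(json.loads(wire))
    assert isinstance(obj['submitted'], datetime)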
20 51 def json_clean(obj):
21 52 """Clean an object to ensure it's safe to encode in JSON.
22 53
23 54 Atomic, immutable objects are returned unmodified. Sets and tuples are
24 55 converted to lists, lists are copied and dicts are also copied.
25 56
26 57 Note: dicts whose keys could cause collisions upon encoding (such as a dict
27 58 with both the number 1 and the string '1' as keys) will cause a ValueError
28 59 to be raised.
29 60
30 61 Parameters
31 62 ----------
32 63 obj : any python object
33 64
34 65 Returns
35 66 -------
36 67 out : object
37 68
38 69 A version of the input which will not cause an encoding error when
39 70 encoded as JSON. Note that this function does not *encode* its inputs,
40 71 it simply sanitizes it so that there will be no encoding errors later.
41 72
42 73 Examples
43 74 --------
44 75 >>> json_clean(4)
45 76 4
46 77 >>> json_clean(range(10))
47 78 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
48 79 >>> json_clean(dict(x=1, y=2))
49 80 {'y': 2, 'x': 1}
50 81 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
51 82 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
52 83 >>> json_clean(True)
53 84 True
54 85 """
55 86 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
56 87 # listed explicitly because bools pass as int instances
57 88 atomic_ok = (basestring, int, float, types.NoneType)
58 89
59 90 # containers that we need to convert into lists
60 91 container_to_list = (tuple, set, types.GeneratorType)
61 92
62 93 if isinstance(obj, atomic_ok):
63 94 return obj
64 95
65 96 if isinstance(obj, container_to_list) or (
66 97 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
67 98 obj = list(obj)
68 99
69 100 if isinstance(obj, list):
70 101 return [json_clean(x) for x in obj]
71 102
72 103 if isinstance(obj, dict):
73 104 # First, validate that the dict won't lose data in conversion due to
74 105 # key collisions after stringification. This can happen with keys like
75 106 # True and 'true' or 1 and '1', which collide in JSON.
76 107 nkeys = len(obj)
77 108 nkeys_collapsed = len(set(map(str, obj)))
78 109 if nkeys != nkeys_collapsed:
79 110 raise ValueError('dict can not be safely converted to JSON: '
80 111 'key collision would lead to dropped values')
81 112 # If all OK, proceed by making the new dict that will be json-safe
82 113 out = {}
83 114 for k,v in obj.iteritems():
84 115 out[str(k)] = json_clean(v)
85 116 return out
86 117
87 118 # If we get here, we don't know how to handle the object, so we just get
88 119 # its repr and return that. This will catch lambdas, open sockets, class
89 120 # objects, and any other complicated contraption that json can't encode
90 121 return repr(obj)
@@ -1,184 +1,479 b''
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
3 """
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2010-2011 The IPython Development Team
6 #
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
10
11 #-----------------------------------------------------------------------------
12 # Imports
13 #-----------------------------------------------------------------------------
14
15 import hmac
1 16 import os
2 import uuid
3 17 import pprint
18 import uuid
19 from datetime import datetime
20
21 try:
22 import cPickle
23 pickle = cPickle
24 except:
25 cPickle = None
26 import pickle
4 27
5 28 import zmq
29 from zmq.utils import jsonapi
30 from zmq.eventloop.zmqstream import ZMQStream
31
32 from IPython.config.configurable import Configurable
33 from IPython.utils.importstring import import_item
34 from IPython.utils.jsonutil import date_default
35 from IPython.utils.traitlets import CStr, Unicode, Bool, Any, Instance, Set
36
37 #-----------------------------------------------------------------------------
38 # utility functions
39 #-----------------------------------------------------------------------------
40
41 def squash_unicode(obj):
42 """coerce unicode back to bytestrings."""
43 if isinstance(obj,dict):
44 for key in obj.keys():
45 obj[key] = squash_unicode(obj[key])
46 if isinstance(key, unicode):
47 obj[squash_unicode(key)] = obj.pop(key)
48 elif isinstance(obj, list):
49 for i,v in enumerate(obj):
50 obj[i] = squash_unicode(v)
51 elif isinstance(obj, unicode):
52 obj = obj.encode('utf8')
53 return obj
54
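For example (Python 2 semantics, matching the code above):

    squash_unicode({u'a': [u'b', 1]})   # -> {'a': ['b', 1]}; keys and values coerced in place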
55 #-----------------------------------------------------------------------------
56 # globals and defaults
57 #-----------------------------------------------------------------------------
58
59 _default_key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
60 json_packer = lambda obj: jsonapi.dumps(obj, **{_default_key:date_default})
61 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
6 62
7 from zmq.utils import jsonapi as json
63 pickle_packer = lambda o: pickle.dumps(o,-1)
64 pickle_unpacker = pickle.loads
65
66 default_packer = json_packer
67 default_unpacker = json_unpacker
68
69
70 DELIM="<IDS|MSG>"
71
72 #-----------------------------------------------------------------------------
73 # Classes
74 #-----------------------------------------------------------------------------
8 75
9 76 class Message(object):
10 77 """A simple message object that maps dict keys to attributes.
11 78
12 79 A Message can be created from a dict and a dict from a Message instance
13 80 simply by calling dict(msg_obj)."""
14 81
15 82 def __init__(self, msg_dict):
16 83 dct = self.__dict__
17 for k, v in msg_dict.iteritems():
84 for k, v in dict(msg_dict).iteritems():
18 85 if isinstance(v, dict):
19 86 v = Message(v)
20 87 dct[k] = v
21 88
22 89 # Having this iterator lets dict(msg_obj) work out of the box.
23 90 def __iter__(self):
24 91 return iter(self.__dict__.iteritems())
25 92
26 93 def __repr__(self):
27 94 return repr(self.__dict__)
28 95
29 96 def __str__(self):
30 97 return pprint.pformat(self.__dict__)
31 98
32 99 def __contains__(self, k):
33 100 return k in self.__dict__
34 101
35 102 def __getitem__(self, k):
36 103 return self.__dict__[k]
37 104
38 105
39 def msg_header(msg_id, username, session):
40 return {
41 'msg_id' : msg_id,
42 'username' : username,
43 'session' : session
44 }
45
106 def msg_header(msg_id, msg_type, username, session):
107 date=datetime.now()
108 return locals()
46 109
47 110 def extract_header(msg_or_header):
48 111 """Given a message or header, return the header."""
49 112 if not msg_or_header:
50 113 return {}
51 114 try:
52 115 # See if msg_or_header is the entire message.
53 116 h = msg_or_header['header']
54 117 except KeyError:
55 118 try:
56 119 # See if msg_or_header is just the header
57 120 h = msg_or_header['msg_id']
58 121 except KeyError:
59 122 raise
60 123 else:
61 124 h = msg_or_header
62 125 if not isinstance(h, dict):
63 126 h = dict(h)
64 127 return h
65 128
66
67 class Session(object):
68
69 def __init__(self, username=os.environ.get('USER','username'), session=None):
70 self.username = username
71 if session is None:
72 self.session = str(uuid.uuid4())
129 class Session(Configurable):
130 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
131 debug=Bool(False, config=True, help="""Debug output in the Session""")
132 packer = Unicode('json',config=True,
133 help="""The name of the packer for serializing messages.
134 Should be one of 'json', 'pickle', or an import name
135 for a custom serializer.""")
136 def _packer_changed(self, name, old, new):
137 if new.lower() == 'json':
138 self.pack = json_packer
139 self.unpack = json_unpacker
140 elif new.lower() == 'pickle':
141 self.pack = pickle_packer
142 self.unpack = pickle_unpacker
73 143 else:
74 self.session = session
75 self.msg_id = 0
144 self.pack = import_item(new)
76 145
77 def msg_header(self):
78 h = msg_header(self.msg_id, self.username, self.session)
79 self.msg_id += 1
80 return h
146 unpacker = Unicode('json',config=True,
147 help="""The name of the unpacker for unserializing messages.
148 Only used with custom functions for `packer`.""")
149 def _unpacker_changed(self, name, old, new):
150 if new.lower() == 'json':
151 self.pack = json_packer
152 self.unpack = json_unpacker
153 elif new.lower() == 'pickle':
154 self.pack = pickle_packer
155 self.unpack = pickle_unpacker
156 else:
157 self.unpack = import_item(new)
158
159 session = CStr('',config=True,
160 help="""The UUID identifying this session.""")
161 def _session_default(self):
162 return bytes(uuid.uuid4())
163 username = Unicode(os.environ.get('USER','username'), config=True,
164 help="""Username for the Session. Default is your system username.""")
165
166 # message signature related traits:
167 key = CStr('', config=True,
168 help="""execution key, for extra authentication.""")
169 def _key_changed(self, name, old, new):
170 if new:
171 self.auth = hmac.HMAC(new)
172 else:
173 self.auth = None
174 auth = Instance(hmac.HMAC)
175 counters = Instance('collections.defaultdict', (int,))
176 digest_history = Set()
177
178 keyfile = Unicode('', config=True,
179 help="""path to file containing execution key.""")
180 def _keyfile_changed(self, name, old, new):
181 with open(new, 'rb') as f:
182 self.key = f.read().strip()
81 183
82 def msg(self, msg_type, content=None, parent=None):
83 """Construct a standard-form message, with a given type, content, and parent.
184 pack = Any(default_packer) # the actual packer function
185 def _pack_changed(self, name, old, new):
186 if not callable(new):
187 raise TypeError("packer must be callable, not %s"%type(new))
84 188
85 NOT to be called directly.
86 """
189 unpack = Any(default_unpacker) # the actual unpacker function
190 def _unpack_changed(self, name, old, new):
191 if not callable(new):
192 raise TypeError("unpacker must be callable, not %s"%type(new))
193
194 def __init__(self, **kwargs):
195 super(Session, self).__init__(**kwargs)
196 self.none = self.pack({})
197
198 @property
199 def msg_id(self):
200 """always return new uuid"""
201 return str(uuid.uuid4())
202
203 def msg_header(self, msg_type):
204 return msg_header(self.msg_id, msg_type, self.username, self.session)
205
206 def msg(self, msg_type, content=None, parent=None, subheader=None):
87 207 msg = {}
88 msg['header'] = self.msg_header()
208 msg['header'] = self.msg_header(msg_type)
209 msg['msg_id'] = msg['header']['msg_id']
89 210 msg['parent_header'] = {} if parent is None else extract_header(parent)
90 211 msg['msg_type'] = msg_type
91 212 msg['content'] = {} if content is None else content
213 sub = {} if subheader is None else subheader
214 msg['header'].update(sub)
92 215 return msg
93 216
94 def send(self, socket, msg_or_type, content=None, parent=None, ident=None):
95 """send a message via a socket, using a uniform message pattern.
217 def check_key(self, msg_or_header):
218 """Check that a message's header has the right key"""
219 if not self.key:
220 return True
221 header = extract_header(msg_or_header)
222 return header.get('key', '') == self.key
223
224 def sign(self, msg):
225 """Sign a message with HMAC digest. If no auth, return b''."""
226 if self.auth is None:
227 return b''
228 h = self.auth.copy()
229 for m in msg:
230 h.update(m)
231 return h.hexdigest()
232
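Setting `key` is what arms signing: `_key_changed` builds the HMAC, and `sign` digests the serialized frames. A sketch (assumes the post-merge location `IPython.zmq.session`):

    from IPython.zmq.session import Session

    s = Session(key='secret')
    frames = [b'{"msg_id": "0"}', b'{}', b'{}']
    assert s.sign(frames) == s.sign(frames)   # deterministic for a given key
    assert Session().sign(frames) == b''      # no key -> empty signature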
233 def serialize(self, msg, ident=None):
234 content = msg.get('content', {})
235 if content is None:
236 content = self.none
237 elif isinstance(content, dict):
238 content = self.pack(content)
239 elif isinstance(content, bytes):
240 # content is already packed, as in a relayed message
241 pass
242 elif isinstance(content, unicode):
243 # should be bytes, but JSON often spits out unicode
244 content = content.encode('utf8')
245 else:
246 raise TypeError("Content incorrect type: %s"%type(content))
247
248 real_message = [self.pack(msg['header']),
249 self.pack(msg['parent_header']),
250 content
251 ]
252
253 to_send = []
254
255 if isinstance(ident, list):
256 # accept list of idents
257 to_send.extend(ident)
258 elif ident is not None:
259 to_send.append(ident)
260 to_send.append(DELIM)
261
262 signature = self.sign(real_message)
263 to_send.append(signature)
264
265 to_send.extend(real_message)
266
267 return to_send
268
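`serialize` fixes the wire layout: optional routing idents, the DELIM sentinel, the signature, then header / parent_header / content. A sketch of the frames it produces:

    s = Session()
    msg = s.msg('execute', content={'code': 'a=5'})
    wire = s.serialize(msg, ident=b'engine-0')
    # wire == [b'engine-0', '<IDS|MSG>', signature, header, parent_header, content]
    assert wire[1] == '<IDS|MSG>'
    assert s.unpack(wire[3])['msg_type'] == 'execute'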
269 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
270 buffers=None, subheader=None, track=False):
271 """Build and send a message via stream or socket.
96 272
97 273 Parameters
98 274 ----------
99 socket : zmq.Socket
100 The socket on which to send.
101 msg_or_type : Message/dict or str
102 if str : then a new message will be constructed from content,parent
103 if Message/dict : then content and parent are ignored, and the message
104 is sent. This is only for use when sending a Message for a second time.
105 content : dict, optional
106 The contents of the message
107 parent : dict, optional
108 The parent header, or parent message, of this message
109 ident : bytes, optional
110 The zmq.IDENTITY prefix of the destination.
111 Only for use on certain socket types.
275
276 stream : zmq.Socket or ZMQStream
277 the socket-like object used to send the data
278 msg_or_type : str or Message/dict
279 Normally, msg_or_type will be a msg_type unless a message is being sent more
280 than once.
281
282 content : dict or None
283 the content of the message (ignored if msg_or_type is a message)
284 parent : Message or dict or None
285 the parent or parent header describing the parent of this message
286 ident : bytes or list of bytes
287 the zmq.IDENTITY routing path
288 subheader : dict or None
289 extra header keys for this message's header
290 buffers : list or None
291 the already-serialized buffers to be appended to the message
292 track : bool
293 whether to track. Only for use with Sockets,
294 because ZMQStream objects cannot track messages.
112 295
113 296 Returns
114 297 -------
115 msg : dict
116 The message, as constructed by self.msg(msg_type,content,parent)
298 msg : message dict
299 the constructed message
300 (msg,tracker) : (message dict, MessageTracker)
301 if track=True, then a 2-tuple will be returned,
302 the first element being the constructed
303 message, and the second being the MessageTracker
304
117 305 """
306
307 if not isinstance(stream, (zmq.Socket, ZMQStream)):
308 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
309 elif track and isinstance(stream, ZMQStream):
310 raise TypeError("ZMQStream cannot track messages")
311
118 312 if isinstance(msg_or_type, (Message, dict)):
119 msg = dict(msg_or_type)
313 # we got a Message, not a msg_type
314 # don't build a new Message
315 msg = msg_or_type
120 316 else:
121 msg = self.msg(msg_or_type, content, parent)
122 if ident is not None:
123 socket.send(ident, zmq.SNDMORE)
124 socket.send_json(msg)
125 return msg
126
127 def recv(self, socket, mode=zmq.NOBLOCK):
128 """recv a message on a socket.
317 msg = self.msg(msg_or_type, content, parent, subheader)
129 318
130 Receive an optionally identity-prefixed message, as sent via session.send().
319 buffers = [] if buffers is None else buffers
320 to_send = self.serialize(msg, ident)
321 flag = 0
322 if buffers:
323 flag = zmq.SNDMORE
324 _track = False
325 else:
326 _track=track
327 if track:
328 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
329 else:
330 tracker = stream.send_multipart(to_send, flag, copy=False)
331 for b in buffers[:-1]:
332 stream.send(b, flag, copy=False)
333 if buffers:
334 if track:
335 tracker = stream.send(buffers[-1], copy=False, track=track)
336 else:
337 tracker = stream.send(buffers[-1], copy=False)
338
339 # omsg = Message(msg)
340 if self.debug:
341 pprint.pprint(msg)
342 pprint.pprint(to_send)
343 pprint.pprint(buffers)
131 344
132 Parameters
133 ----------
345 msg['tracker'] = tracker
134 346
135 socket : zmq.Socket
136 The socket on which to recv a message.
137 mode : int, optional
138 the mode flag passed to socket.recv
139 default: zmq.NOBLOCK
347 return msg
348
349 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
350 """Send a raw message via ident path.
140 351
141 Returns
142 -------
143 (ident,msg) : tuple
144 always length 2. If no message received, then return is (None,None)
145 ident : bytes or None
146 the identity prefix is there was one, None otherwise.
147 msg : dict or None
148 The actual message. If mode==zmq.NOBLOCK and no message was waiting,
149 it will be None.
150 """
352 Parameters
353 ----------
354 msg : list of sendable buffers"""
355 to_send = []
356 if isinstance(ident, bytes):
357 ident = [ident]
358 if ident is not None:
359 to_send.extend(ident)
360
361 to_send.append(DELIM)
362 to_send.append(self.sign(msg))
363 to_send.extend(msg)
364 stream.send_multipart(to_send, flags, copy=copy)
365
366 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
367 """receives and unpacks a message
368 returns [idents], msg"""
369 if isinstance(socket, ZMQStream):
370 socket = socket.socket
151 371 try:
152 372 msg = socket.recv_multipart(mode)
153 except zmq.ZMQError, e:
373 except zmq.ZMQError as e:
154 374 if e.errno == zmq.EAGAIN:
155 375 # We can convert EAGAIN to None as we know in this case
156 # recv_json won't return None.
376 # recv_multipart won't return None.
157 377 return None,None
158 378 else:
159 379 raise
160 if len(msg) == 1:
161 ident=None
162 msg = msg[0]
163 elif len(msg) == 2:
164 ident, msg = msg
380 # return an actual Message object
381 # determine the number of idents by trying to unpack them.
382 # this is terrible:
383 idents, msg = self.feed_identities(msg, copy)
384 try:
385 return idents, self.unpack_message(msg, content=content, copy=copy)
386 except Exception as e:
387 print (idents, msg)
388 # TODO: handle it
389 raise e
390
391 def feed_identities(self, msg, copy=True):
392 """feed until DELIM is reached, then return the prefix as idents and remainder as
393 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
394
395 Parameters
396 ----------
397 msg : a list of Message or bytes objects
398 the message to be split
399 copy : bool
400 flag determining whether the arguments are bytes or Messages
401
402 Returns
403 -------
404 (idents,msg) : two lists
405 idents will always be a list of bytes - the identity prefix
406 msg will be a list of bytes or Messages, unchanged from input
407 msg should be unpackable via self.unpack_message at this point.
408 """
409 if copy:
410 idx = msg.index(DELIM)
411 return msg[:idx], msg[idx+1:]
412 else:
413 failed = True
414 for idx,m in enumerate(msg):
415 if m.bytes == DELIM:
416 failed = False
417 break
418 if failed:
419 raise ValueError("DELIM not in msg")
420 idents, msg = msg[:idx], msg[idx+1:]
421 return [m.bytes for m in idents], msg
422
423 def unpack_message(self, msg, content=True, copy=True):
424 """Return a message object from the format
425 sent by self.send.
426
427 Parameters
428 ----------
429
430 content : bool (True)
431 whether to unpack the content dict (True),
432 or leave it serialized (False)
433
434 copy : bool (True)
435 whether to return the bytes (True),
436 or the non-copying Message object in each place (False)
437
438 """
439 minlen = 4
440 message = {}
441 if not copy:
442 for i in range(minlen):
443 msg[i] = msg[i].bytes
444 if self.auth is not None:
445 signature = msg[0]
446 if signature in self.digest_history:
447 raise ValueError("Duplicate Signature: %r"%signature)
448 self.digest_history.add(signature)
449 check = self.sign(msg[1:4])
450 if signature != check:
451 raise ValueError("Invalid Signature: %r"%signature)
452 if not len(msg) >= minlen:
453 raise TypeError("malformed message, must have at least %i elements"%minlen)
454 message['header'] = self.unpack(msg[1])
455 message['msg_type'] = message['header']['msg_type']
456 message['parent_header'] = self.unpack(msg[2])
457 if content:
458 message['content'] = self.unpack(msg[3])
165 459 else:
166 raise ValueError("Got message with length > 2, which is invalid")
460 message['content'] = msg[3]
167 461
168 return ident, json.loads(msg)
462 message['buffers'] = msg[4:]
463 return message
169 464
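End to end, `send` -> `feed_identities` -> `unpack_message` round-trips a message over any zmq socket pair. A minimal sketch (inproc PAIR sockets, blocking recv):

    import zmq
    from IPython.zmq.session import Session

    ctx = zmq.Context.instance()
    a = ctx.socket(zmq.PAIR); a.bind('inproc://sess')
    b = ctx.socket(zmq.PAIR); b.connect('inproc://sess')

    s = Session()
    s.send(a, 'execute', content={'code': 'a=5'})
    idents, msg = s.recv(b, mode=0)   # mode=0 blocks instead of the NOBLOCK default
    assert msg['msg_type'] == 'execute'
    assert msg['content']['code'] == 'a=5'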
170 465 def test_msg2obj():
171 466 am = dict(x=1)
172 467 ao = Message(am)
173 468 assert ao.x == am['x']
174 469
175 470 am['y'] = dict(z=1)
176 471 ao = Message(am)
177 472 assert ao.y.z == am['y']['z']
178 473
179 474 k1, k2 = 'y', 'z'
180 475 assert ao[k1][k2] == am[k1][k2]
181 476
182 477 am2 = dict(ao)
183 478 assert am['x'] == am2['x']
184 479 assert am['y']['z'] == am2['y']['z']
@@ -1,111 +1,111 b''
1 1 """test building messages with streamsession"""
2 2
3 3 #-------------------------------------------------------------------------------
4 4 # Copyright (C) 2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-------------------------------------------------------------------------------
9 9
10 10 #-------------------------------------------------------------------------------
11 11 # Imports
12 12 #-------------------------------------------------------------------------------
13 13
14 14 import os
15 15 import uuid
16 16 import zmq
17 17
18 18 from zmq.tests import BaseZMQTestCase
19 19 from zmq.eventloop.zmqstream import ZMQStream
20 # from IPython.zmq.tests import SessionTestCase
21 from IPython.parallel import streamsession as ss
20
21 from IPython.zmq import session as ss
22 22
23 23 class SessionTestCase(BaseZMQTestCase):
24 24
25 25 def setUp(self):
26 26 BaseZMQTestCase.setUp(self)
27 self.session = ss.StreamSession()
27 self.session = ss.Session()
28 28
29 29 class TestSession(SessionTestCase):
30 30
31 31 def test_msg(self):
32 32 """message format"""
33 33 msg = self.session.msg('execute')
34 34 thekeys = set('header msg_id parent_header msg_type content'.split())
35 35 s = set(msg.keys())
36 36 self.assertEquals(s, thekeys)
37 37 self.assertTrue(isinstance(msg['content'],dict))
38 38 self.assertTrue(isinstance(msg['header'],dict))
39 39 self.assertTrue(isinstance(msg['parent_header'],dict))
40 40 self.assertEquals(msg['msg_type'], 'execute')
41 41
42 42
43 43
44 44 def test_args(self):
45 """initialization arguments for StreamSession"""
45 """initialization arguments for Session"""
46 46 s = self.session
47 47 self.assertTrue(s.pack is ss.default_packer)
48 48 self.assertTrue(s.unpack is ss.default_unpacker)
49 49 self.assertEquals(s.username, os.environ.get('USER', 'username'))
50 50
51 s = ss.StreamSession()
51 s = ss.Session()
52 52 self.assertEquals(s.username, os.environ.get('USER', 'username'))
53 53
54 self.assertRaises(TypeError, ss.StreamSession, pack='hi')
55 self.assertRaises(TypeError, ss.StreamSession, unpack='hi')
54 self.assertRaises(TypeError, ss.Session, pack='hi')
55 self.assertRaises(TypeError, ss.Session, unpack='hi')
56 56 u = str(uuid.uuid4())
57 s = ss.StreamSession(username='carrot', session=u)
57 s = ss.Session(username='carrot', session=u)
58 58 self.assertEquals(s.session, u)
59 59 self.assertEquals(s.username, 'carrot')
60 60
61 61 def test_tracking(self):
62 62 """test tracking messages"""
63 63 a,b = self.create_bound_pair(zmq.PAIR, zmq.PAIR)
64 64 s = self.session
65 65 stream = ZMQStream(a)
66 66 msg = s.send(a, 'hello', track=False)
67 67 self.assertTrue(msg['tracker'] is None)
68 68 msg = s.send(a, 'hello', track=True)
69 69 self.assertTrue(isinstance(msg['tracker'], zmq.MessageTracker))
70 70 M = zmq.Message(b'hi there', track=True)
71 71 msg = s.send(a, 'hello', buffers=[M], track=True)
72 72 t = msg['tracker']
73 73 self.assertTrue(isinstance(t, zmq.MessageTracker))
74 74 self.assertRaises(zmq.NotDone, t.wait, .1)
75 75 del M
76 76 t.wait(1) # this will raise
77 77
78 78
79 79 # def test_rekey(self):
80 80 # """rekeying dict around json str keys"""
81 81 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
82 82 # self.assertRaises(KeyError, ss.rekey, d)
83 83 #
84 84 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
85 85 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
86 86 # rd = ss.rekey(d)
87 87 # self.assertEquals(d2,rd)
88 88 #
89 89 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
90 90 # d2 = {1.5:d['1.5'],1:d['1']}
91 91 # rd = ss.rekey(d)
92 92 # self.assertEquals(d2,rd)
93 93 #
94 94 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
95 95 # self.assertRaises(KeyError, ss.rekey, d)
96 96 #
97 97 def test_unique_msg_ids(self):
98 98 """test that messages receive unique ids"""
99 99 ids = set()
100 100 for i in range(2**12):
101 101 h = self.session.msg_header('test')
102 102 msg_id = h['msg_id']
103 103 self.assertTrue(msg_id not in ids)
104 104 ids.add(msg_id)
105 105
106 106 def test_feed_identities(self):
107 107 """scrub the front for zmq IDENTITIES"""
108 108 theids = "engine client other".split()
109 109 content = dict(code='whoda',stuff=object())
110 110 themsg = self.session.msg('execute',content=content)
111 111 pmsg = theids
1 NO CONTENT: file was removed
1 NO CONTENT: file was removed