simplify IPython.parallel connections...
MinRK
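Summary of the change: instead of writing a single registration `url` and having clients and engines reconstruct the other addresses, the controller now records every per-channel address in the two connection files (`ipcontroller-client.json` and `ipcontroller-engine.json`), together with the session's serialization settings. A rough sketch of the client file's resulting shape, written as a Python dict (all values below are invented for illustration; the real ones come from HubFactory and Session at startup):

    # Illustrative layout of ipcontroller-client.json after this change;
    # addresses, ports, and the key are made up.
    client_cfg = {
        "exec_key": "hmac-key-here",           # Session key for signing messages
        "location": "10.0.0.1",                # controller's external IP/hostname
        "pack": "json",                        # Session packer name
        "unpack": "json",                      # Session unpacker name
        "ssh": "",                             # optional ssh server for tunneling
        "task_scheme": "leastload",            # scheduler scheme, read by Client
        "registration": "tcp://10.0.0.1:55555",
        "mux": "tcp://10.0.0.1:55556",
        "task": "tcp://10.0.0.1:55557",
        "control": "tcp://10.0.0.1:55558",
        "notification": "tcp://10.0.0.1:55559",
        "iopub": "tcp://10.0.0.1:55560",
    }

The engine file is analogous, with heartbeat channels (`hb_ping`, `hb_pong`) in place of the client-only sockets.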
@@ -1,491 +1,497 @@ IPython/parallel/apps/ipcontrollerapp.py
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import json
27 27 import os
28 28 import socket
29 29 import stat
30 30 import sys
31 31
32 32 from multiprocessing import Process
33 33 from signal import signal, SIGINT, SIGABRT, SIGTERM
34 34
35 35 import zmq
36 36 from zmq.devices import ProcessMonitoredQueue
37 37 from zmq.log.handlers import PUBHandler
38 38
39 39 from IPython.core.profiledir import ProfileDir
40 40
41 41 from IPython.parallel.apps.baseapp import (
42 42 BaseParallelApplication,
43 43 base_aliases,
44 44 base_flags,
45 45 catch_config_error,
46 46 )
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
49 49
50 50 from IPython.zmq.session import (
51 51 Session, session_aliases, session_flags, default_secure
52 52 )
53 53
54 54 from IPython.parallel.controller.heartmonitor import HeartMonitor
55 55 from IPython.parallel.controller.hub import HubFactory
56 56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
57 57 from IPython.parallel.controller.sqlitedb import SQLiteDB
58 58
59 59 from IPython.parallel.util import split_url, disambiguate_url
60 60
61 61 # conditional import of MongoDB backend class
62 62
63 63 try:
64 64 from IPython.parallel.controller.mongodb import MongoDB
65 65 except ImportError:
66 66 maybe_mongo = []
67 67 else:
68 68 maybe_mongo = [MongoDB]
69 69
70 70
71 71 #-----------------------------------------------------------------------------
72 72 # Module level variables
73 73 #-----------------------------------------------------------------------------
74 74
75 75
76 76 #: The default config file name for this application
77 77 default_config_file_name = u'ipcontroller_config.py'
78 78
79 79
80 80 _description = """Start the IPython controller for parallel computing.
81 81
82 82 The IPython controller provides a gateway between the IPython engines and
83 83 clients. The controller needs to be started before the engines and can be
84 84 configured using command line options or using a cluster directory. Cluster
85 85 directories contain config, log and security files and are usually located in
86 86 your ipython directory and named "profile_name". See the `profile`
87 87 and `profile-dir` options for details.
88 88 """
89 89
90 90 _examples = """
91 91 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
92 92 ipcontroller --scheme=pure # use the pure zeromq scheduler
93 93 """
94 94
95 95
96 96 #-----------------------------------------------------------------------------
97 97 # The main application
98 98 #-----------------------------------------------------------------------------
99 99 flags = {}
100 100 flags.update(base_flags)
101 101 flags.update({
102 102 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
103 103 'Use threads instead of processes for the schedulers'),
104 104 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
105 105 'use the SQLiteDB backend'),
106 106 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
107 107 'use the MongoDB backend'),
108 108 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
109 109 'use the in-memory DictDB backend'),
110 110 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
111 111 """use dummy DB backend, which doesn't store any information.
112 112
113 113 This is the default as of IPython 0.13.
114 114
115 115 To enable delayed or repeated retrieval of results from the Hub,
116 116 select one of the true db backends.
117 117 """),
118 118 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
119 119 'reuse existing json connection files')
120 120 })
121 121
122 122 flags.update(session_flags)
123 123
124 124 aliases = dict(
125 125 ssh = 'IPControllerApp.ssh_server',
126 126 enginessh = 'IPControllerApp.engine_ssh_server',
127 127 location = 'IPControllerApp.location',
128 128
129 129 url = 'HubFactory.url',
130 130 ip = 'HubFactory.ip',
131 131 transport = 'HubFactory.transport',
132 132 port = 'HubFactory.regport',
133 133
134 134 ping = 'HeartMonitor.period',
135 135
136 136 scheme = 'TaskScheduler.scheme_name',
137 137 hwm = 'TaskScheduler.hwm',
138 138 )
139 139 aliases.update(base_aliases)
140 140 aliases.update(session_aliases)
141 141
142 142 class IPControllerApp(BaseParallelApplication):
143 143
144 144 name = u'ipcontroller'
145 145 description = _description
146 146 examples = _examples
147 147 config_file_name = Unicode(default_config_file_name)
148 148 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
149 149
150 150 # change default to True
151 151 auto_create = Bool(True, config=True,
152 152 help="""Whether to create profile dir if it doesn't exist.""")
153 153
154 154 reuse_files = Bool(False, config=True,
155 155 help="""Whether to reuse existing json connection files.
156 156 If False, connection files will be removed on a clean exit.
157 157 """
158 158 )
159 159 ssh_server = Unicode(u'', config=True,
160 160 help="""ssh url for clients to use when connecting to the Controller
161 161 processes. It should be of the form: [user@]server[:port]. The
162 162 Controller's listening addresses must be accessible from the ssh server""",
163 163 )
164 164 engine_ssh_server = Unicode(u'', config=True,
165 165 help="""ssh url for engines to use when connecting to the Controller
166 166 processes. It should be of the form: [user@]server[:port]. The
167 167 Controller's listening addresses must be accessible from the ssh server""",
168 168 )
169 169 location = Unicode(u'', config=True,
170 170 help="""The external IP or domain name of the Controller, used for disambiguating
171 171 engine and client connections.""",
172 172 )
173 173 import_statements = List([], config=True,
174 174 help="import statements to be run at startup. Necessary in some environments"
175 175 )
176 176
177 177 use_threads = Bool(False, config=True,
178 178 help='Use threads instead of processes for the schedulers',
179 179 )
180 180
181 181 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
182 182 help="JSON filename where engine connection info will be stored.")
183 183 client_json_file = Unicode('ipcontroller-client.json', config=True,
184 184 help="JSON filename where client connection info will be stored.")
185 185
186 186 def _cluster_id_changed(self, name, old, new):
187 187 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
188 188 self.engine_json_file = "%s-engine.json" % self.name
189 189 self.client_json_file = "%s-client.json" % self.name
190 190
191 191
192 192 # internal
193 193 children = List()
194 194 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
195 195
196 196 def _use_threads_changed(self, name, old, new):
197 197 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
198 198
199 199 write_connection_files = Bool(True,
200 200 help="""Whether to write connection files to disk.
201 201 True in all cases other than runs with `reuse_files=True` *after the first*
202 202 """
203 203 )
204 204
205 205 aliases = Dict(aliases)
206 206 flags = Dict(flags)
207 207
208 208
209 209 def save_connection_dict(self, fname, cdict):
210 210 """save a connection dict to json file."""
211 211 c = self.config
212 url = cdict['url']
212 url = cdict['registration']
213 213 location = cdict['location']
214 214 if not location:
215 215 try:
216 216 proto,ip,port = split_url(url)
217 217 except AssertionError:
218 218 pass
219 219 else:
220 220 try:
221 221 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
222 222 except (socket.gaierror, IndexError):
223 223 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
224 224 " You may need to specify '--location=<external_ip_address>' to help"
225 225 " IPython decide when to connect via loopback.")
226 226 location = '127.0.0.1'
227 227 cdict['location'] = location
228 228 fname = os.path.join(self.profile_dir.security_dir, fname)
229 229 self.log.info("writing connection info to %s", fname)
230 230 with open(fname, 'w') as f:
231 231 f.write(json.dumps(cdict, indent=2))
232 232 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
233 233
234 234 def load_config_from_json(self):
235 235 """load config from existing json connector files."""
236 236 c = self.config
237 237 self.log.debug("loading config from JSON")
238 238 # load from engine config
239 239 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
240 240 self.log.info("loading connection info from %s", fname)
241 241 with open(fname) as f:
242 242 cfg = json.loads(f.read())
243 243 key = cfg['exec_key']
244 244 # json gives unicode, Session.key wants bytes
245 245 c.Session.key = key.encode('ascii')
246 246 xport,addr = cfg['url'].split('://')
247 247 c.HubFactory.engine_transport = xport
248 248 ip,ports = addr.split(':')
249 249 c.HubFactory.engine_ip = ip
250 250 c.HubFactory.regport = int(ports)
251 251 self.location = cfg['location']
252 252 if not self.engine_ssh_server:
253 253 self.engine_ssh_server = cfg['ssh']
254 254 # load client config
255 255 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
256 256 self.log.info("loading connection info from %s", fname)
257 257 with open(fname) as f:
258 258 cfg = json.loads(f.read())
259 259 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
260 260 xport,addr = cfg['url'].split('://')
261 261 c.HubFactory.client_transport = xport
262 262 ip,ports = addr.split(':')
263 263 c.HubFactory.client_ip = ip
264 264 if not self.ssh_server:
265 265 self.ssh_server = cfg['ssh']
266 266 assert int(ports) == c.HubFactory.regport, "regport mismatch"
267 267
268 268 def cleanup_connection_files(self):
269 269 if self.reuse_files:
270 270 self.log.debug("leaving JSON connection files for reuse")
271 271 return
272 272 self.log.debug("cleaning up JSON connection files")
273 273 for f in (self.client_json_file, self.engine_json_file):
274 274 f = os.path.join(self.profile_dir.security_dir, f)
275 275 try:
276 276 os.remove(f)
277 277 except Exception as e:
278 278 self.log.error("Failed to cleanup connection file: %s", e)
279 279 else:
280 280 self.log.debug(u"removed %s", f)
281 281
282 282 def load_secondary_config(self):
283 283 """secondary config, loading from JSON and setting defaults"""
284 284 if self.reuse_files:
285 285 try:
286 286 self.load_config_from_json()
287 287 except (AssertionError,IOError) as e:
288 288 self.log.error("Could not load config from JSON: %s" % e)
289 289 else:
290 290 # successfully loaded config from JSON, and reuse=True
291 291 # no need to write back the same file
292 292 self.write_connection_files = False
293 293
294 294 # switch Session.key default to secure
295 295 default_secure(self.config)
296 296 self.log.debug("Config changed")
297 297 self.log.debug(repr(self.config))
298 298
299 299 def init_hub(self):
300 300 c = self.config
301 301
302 302 self.do_import_statements()
303 303
304 304 try:
305 305 self.factory = HubFactory(config=c, log=self.log)
306 306 # self.start_logging()
307 307 self.factory.init_hub()
308 308 except TraitError:
309 309 raise
310 310 except Exception:
311 311 self.log.error("Couldn't construct the Controller", exc_info=True)
312 312 self.exit(1)
313 313
314 314 if self.write_connection_files:
315 315 # save to new json config files
316 316 f = self.factory
317 cdict = {'exec_key' : f.session.key.decode('ascii'),
318 'ssh' : self.ssh_server,
319 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
320 'location' : self.location
317 base = {
318 'exec_key' : f.session.key.decode('ascii'),
319 'location' : self.location,
320 'pack' : f.session.packer,
321 'unpack' : f.session.unpacker,
321 322 }
323
324 cdict = {'ssh' : self.ssh_server}
325 cdict.update(f.client_info)
326 cdict.update(base)
322 327 self.save_connection_dict(self.client_json_file, cdict)
323 edict = cdict
324 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
325 edict['ssh'] = self.engine_ssh_server
328
329 edict = {'ssh' : self.engine_ssh_server}
330 edict.update(f.engine_info)
331 edict.update(base)
326 332 self.save_connection_dict(self.engine_json_file, edict)
327 333
328 334 def init_schedulers(self):
329 335 children = self.children
330 336 mq = import_item(str(self.mq_class))
331 337
332 338 hub = self.factory
333 339 # disambiguate url, in case of *
334 340 monitor_url = disambiguate_url(hub.monitor_url)
335 341 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
336 342 # IOPub relay (in a Process)
337 343 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
338 344 q.bind_in(hub.client_info['iopub'])
339 345 q.bind_out(hub.engine_info['iopub'])
340 346 q.setsockopt_out(zmq.SUBSCRIBE, b'')
341 347 q.connect_mon(monitor_url)
342 348 q.daemon=True
343 349 children.append(q)
344 350
345 351 # Multiplexer Queue (in a Process)
346 352 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
347 353 q.bind_in(hub.client_info['mux'])
348 354 q.setsockopt_in(zmq.IDENTITY, b'mux')
349 355 q.bind_out(hub.engine_info['mux'])
350 356 q.connect_mon(monitor_url)
351 357 q.daemon=True
352 358 children.append(q)
353 359
354 360 # Control Queue (in a Process)
355 361 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
356 362 q.bind_in(hub.client_info['control'])
357 363 q.setsockopt_in(zmq.IDENTITY, b'control')
358 364 q.bind_out(hub.engine_info['control'])
359 365 q.connect_mon(monitor_url)
360 366 q.daemon=True
361 367 children.append(q)
362 368 try:
363 369 scheme = self.config.TaskScheduler.scheme_name
364 370 except AttributeError:
365 371 scheme = TaskScheduler.scheme_name.get_default_value()
366 372 # Task Queue (in a Process)
367 373 if scheme == 'pure':
368 374 self.log.warn("task::using pure DEALER Task scheduler")
369 375 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
370 376 # q.setsockopt_out(zmq.HWM, hub.hwm)
371 377 q.bind_in(hub.client_info['task'][1])
372 378 q.setsockopt_in(zmq.IDENTITY, b'task')
373 379 q.bind_out(hub.engine_info['task'])
374 380 q.connect_mon(monitor_url)
375 381 q.daemon=True
376 382 children.append(q)
377 383 elif scheme == 'none':
378 384 self.log.warn("task::using no Task scheduler")
379 385
380 386 else:
381 387 self.log.info("task::using Python %s Task scheduler"%scheme)
382 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
388 sargs = (hub.client_info['task'], hub.engine_info['task'],
383 389 monitor_url, disambiguate_url(hub.client_info['notification']))
384 390 kwargs = dict(logname='scheduler', loglevel=self.log_level,
385 391 log_url = self.log_url, config=dict(self.config))
386 392 if 'Process' in self.mq_class:
387 393 # run the Python scheduler in a Process
388 394 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
389 395 q.daemon=True
390 396 children.append(q)
391 397 else:
392 398 # single-threaded Controller
393 399 kwargs['in_thread'] = True
394 400 launch_scheduler(*sargs, **kwargs)
395 401
396 402 def terminate_children(self):
397 403 child_procs = []
398 404 for child in self.children:
399 405 if isinstance(child, ProcessMonitoredQueue):
400 406 child_procs.append(child.launcher)
401 407 elif isinstance(child, Process):
402 408 child_procs.append(child)
403 409 if child_procs:
404 410 self.log.critical("terminating children...")
405 411 for child in child_procs:
406 412 try:
407 413 child.terminate()
408 414 except OSError:
409 415 # already dead
410 416 pass
411 417
412 418 def handle_signal(self, sig, frame):
413 419 self.log.critical("Received signal %i, shutting down", sig)
414 420 self.terminate_children()
415 421 self.loop.stop()
416 422
417 423 def init_signal(self):
418 424 for sig in (SIGINT, SIGABRT, SIGTERM):
419 425 signal(sig, self.handle_signal)
420 426
421 427 def do_import_statements(self):
422 428 statements = self.import_statements
423 429 for s in statements:
424 430 try:
425 431 self.log.msg("Executing statement: '%s'" % s)
426 432 exec s in globals(), locals()
427 433 except:
428 434 self.log.msg("Error running statement: %s" % s)
429 435
430 436 def forward_logging(self):
431 437 if self.log_url:
432 438 self.log.info("Forwarding logging to %s"%self.log_url)
433 439 context = zmq.Context.instance()
434 440 lsock = context.socket(zmq.PUB)
435 441 lsock.connect(self.log_url)
436 442 handler = PUBHandler(lsock)
437 443 handler.root_topic = 'controller'
438 444 handler.setLevel(self.log_level)
439 445 self.log.addHandler(handler)
440 446
441 447 @catch_config_error
442 448 def initialize(self, argv=None):
443 449 super(IPControllerApp, self).initialize(argv)
444 450 self.forward_logging()
445 451 self.load_secondary_config()
446 452 self.init_hub()
447 453 self.init_schedulers()
448 454
449 455 def start(self):
450 456 # Start the subprocesses:
451 457 self.factory.start()
452 458 # children must be started before signals are setup,
453 459 # otherwise signal-handling will fire multiple times
454 460 for child in self.children:
455 461 child.start()
456 462 self.init_signal()
457 463
458 464 self.write_pid_file(overwrite=True)
459 465
460 466 try:
461 467 self.factory.loop.start()
462 468 except KeyboardInterrupt:
463 469 self.log.critical("Interrupted, Exiting...\n")
464 470 finally:
465 471 self.cleanup_connection_files()
466 472
467 473
468 474
469 475 def launch_new_instance():
470 476 """Create and run the IPython controller"""
471 477 if sys.platform == 'win32':
472 478 # make sure we don't get called from a multiprocessing subprocess
473 479 # this can result in infinite Controllers being started on Windows
474 480 # which doesn't have a proper fork, so multiprocessing is wonky
475 481
476 482 # this only comes up when IPython has been installed using vanilla
477 483 # setuptools, and *not* distribute.
478 484 import multiprocessing
479 485 p = multiprocessing.current_process()
480 486 # the main process has name 'MainProcess'
481 487 # subprocesses will have names like 'Process-1'
482 488 if p.name != 'MainProcess':
483 489 # we are a subprocess, don't start another Controller!
484 490 return
485 491 app = IPControllerApp.instance()
486 492 app.initialize()
487 493 app.start()
488 494
489 495
490 496 if __name__ == '__main__':
491 497 launch_new_instance()
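One thing worth keeping in mind when reading the controller diff above: `save_connection_dict` still guesses `location` from the local hostname when it isn't configured, and every consumer then resolves each per-channel URL against that location with `disambiguate_url`. A quick sanity check of that helper's behavior as I understand it (not part of this diff):

    # disambiguate_url turns wildcard binds into connectable addresses,
    # using the controller's advertised location.
    from IPython.parallel.util import disambiguate_url

    print(disambiguate_url("tcp://*:55555", "10.0.0.1"))
    # expected: tcp://10.0.0.1:55555 when 10.0.0.1 is not a local interface;
    # on the controller's own machine it resolves to loopback instead.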
@@ -1,377 +1,390 @@ IPython/parallel/apps/ipengineapp.py
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import json
25 25 import os
26 26 import sys
27 27 import time
28 28
29 29 import zmq
30 30 from zmq.eventloop import ioloop
31 31
32 32 from IPython.core.profiledir import ProfileDir
33 33 from IPython.parallel.apps.baseapp import (
34 34 BaseParallelApplication,
35 35 base_aliases,
36 36 base_flags,
37 37 catch_config_error,
38 38 )
39 39 from IPython.zmq.log import EnginePUBHandler
40 40 from IPython.zmq.ipkernel import Kernel, IPKernelApp
41 41 from IPython.zmq.session import (
42 42 Session, session_aliases, session_flags
43 43 )
44 44
45 45 from IPython.config.configurable import Configurable
46 46
47 47 from IPython.parallel.engine.engine import EngineFactory
48 48 from IPython.parallel.util import disambiguate_url
49 49
50 50 from IPython.utils.importstring import import_item
51 51 from IPython.utils.py3compat import cast_bytes
52 52 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float, Instance
53 53
54 54
55 55 #-----------------------------------------------------------------------------
56 56 # Module level variables
57 57 #-----------------------------------------------------------------------------
58 58
59 59 #: The default config file name for this application
60 60 default_config_file_name = u'ipengine_config.py'
61 61
62 62 _description = """Start an IPython engine for parallel computing.
63 63
64 64 IPython engines run in parallel and perform computations on behalf of a client
65 65 and controller. A controller needs to be started before the engines. The
66 66 engine can be configured using command line options or using a cluster
67 67 directory. Cluster directories contain config, log and security files and are
68 68 usually located in your ipython directory and named "profile_name".
69 69 See the `profile` and `profile-dir` options for details.
70 70 """
71 71
72 72 _examples = """
73 73 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
74 74 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
75 75 """
76 76
77 77 #-----------------------------------------------------------------------------
78 78 # MPI configuration
79 79 #-----------------------------------------------------------------------------
80 80
81 81 mpi4py_init = """from mpi4py import MPI as mpi
82 82 mpi.size = mpi.COMM_WORLD.Get_size()
83 83 mpi.rank = mpi.COMM_WORLD.Get_rank()
84 84 """
85 85
86 86
87 87 pytrilinos_init = """from PyTrilinos import Epetra
88 88 class SimpleStruct:
89 89 pass
90 90 mpi = SimpleStruct()
91 91 mpi.rank = 0
92 92 mpi.size = 0
93 93 """
94 94
95 95 class MPI(Configurable):
96 96 """Configurable for MPI initialization"""
97 97 use = Unicode('', config=True,
98 98 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
99 99 )
100 100
101 101 def _use_changed(self, name, old, new):
102 102 # load default init script if it's not set
103 103 if not self.init_script:
104 104 self.init_script = self.default_inits.get(new, '')
105 105
106 106 init_script = Unicode('', config=True,
107 107 help="Initialization code for MPI")
108 108
109 109 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
110 110 config=True)
111 111
112 112
113 113 #-----------------------------------------------------------------------------
114 114 # Main application
115 115 #-----------------------------------------------------------------------------
116 116 aliases = dict(
117 117 file = 'IPEngineApp.url_file',
118 118 c = 'IPEngineApp.startup_command',
119 119 s = 'IPEngineApp.startup_script',
120 120
121 121 url = 'EngineFactory.url',
122 122 ssh = 'EngineFactory.sshserver',
123 123 sshkey = 'EngineFactory.sshkey',
124 124 ip = 'EngineFactory.ip',
125 125 transport = 'EngineFactory.transport',
126 126 port = 'EngineFactory.regport',
127 127 location = 'EngineFactory.location',
128 128
129 129 timeout = 'EngineFactory.timeout',
130 130
131 131 mpi = 'MPI.use',
132 132
133 133 )
134 134 aliases.update(base_aliases)
135 135 aliases.update(session_aliases)
136 136 flags = {}
137 137 flags.update(base_flags)
138 138 flags.update(session_flags)
139 139
140 140 class IPEngineApp(BaseParallelApplication):
141 141
142 142 name = 'ipengine'
143 143 description = _description
144 144 examples = _examples
145 145 config_file_name = Unicode(default_config_file_name)
146 146 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
147 147
148 148 startup_script = Unicode(u'', config=True,
149 149 help='specify a script to be run at startup')
150 150 startup_command = Unicode('', config=True,
151 151 help='specify a command to be run at startup')
152 152
153 153 url_file = Unicode(u'', config=True,
154 154 help="""The full location of the file containing the connection information for
155 155 the controller. If this is not given, the file must be in the
156 156 security directory of the cluster directory. This location is
157 157 resolved using the `profile` or `profile_dir` options.""",
158 158 )
159 159 wait_for_url_file = Float(5, config=True,
160 160 help="""The maximum number of seconds to wait for url_file to exist.
161 161 This is useful for batch-systems and shared-filesystems where the
162 162 controller and engine are started at the same time and it
163 163 may take a moment for the controller to write the connector files.""")
164 164
165 165 url_file_name = Unicode(u'ipcontroller-engine.json', config=True)
166 166
167 167 def _cluster_id_changed(self, name, old, new):
168 168 if new:
169 169 base = 'ipcontroller-%s' % new
170 170 else:
171 171 base = 'ipcontroller'
172 172 self.url_file_name = "%s-engine.json" % base
173 173
174 174 log_url = Unicode('', config=True,
175 175 help="""The URL for the iploggerapp instance, for forwarding
176 176 logging to a central location.""")
177 177
178 178 # an IPKernelApp instance, used to setup listening for shell frontends
179 179 kernel_app = Instance(IPKernelApp)
180 180
181 181 aliases = Dict(aliases)
182 182 flags = Dict(flags)
183 183
184 184 @property
185 185 def kernel(self):
186 186 """allow access to the Kernel object, so I look like IPKernelApp"""
187 187 return self.engine.kernel
188 188
189 189 def find_url_file(self):
190 190 """Set the url file.
191 191
192 192 Here we don't try to actually see if it exists or is valid as that
193 193 is handled by the connection logic.
194 194 """
195 195 config = self.config
196 196 # Find the actual controller key file
197 197 if not self.url_file:
198 198 self.url_file = os.path.join(
199 199 self.profile_dir.security_dir,
200 200 self.url_file_name
201 201 )
202 202
203 203 def load_connector_file(self):
204 204 """load config from a JSON connector file,
205 205 at a *lower* priority than command-line/config files.
206 206 """
207 207
208 208 self.log.info("Loading url_file %r", self.url_file)
209 209 config = self.config
210 210
211 211 with open(self.url_file) as f:
212 212 d = json.loads(f.read())
213 213
214 if 'exec_key' in d:
215 config.Session.key = cast_bytes(d['exec_key'])
216
214 # allow hand-override of location for disambiguation
215 # and ssh-server
217 216 try:
218 217 config.EngineFactory.location
219 218 except AttributeError:
220 219 config.EngineFactory.location = d['location']
221 220
222 d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
223 try:
224 config.EngineFactory.url
225 except AttributeError:
226 config.EngineFactory.url = d['url']
227
228 221 try:
229 222 config.EngineFactory.sshserver
230 223 except AttributeError:
231 config.EngineFactory.sshserver = d['ssh']
224 config.EngineFactory.sshserver = d.get('ssh')
225
226 location = config.EngineFactory.location
227
228 for key in ('registration', 'hb_ping', 'hb_pong', 'mux', 'task', 'control'):
229 d[key] = disambiguate_url(d[key], location)
230
231 # DO NOT allow override of basic URLs, serialization, or exec_key
232 # JSON file takes top priority there
233 config.Session.key = cast_bytes(d['exec_key'])
234
235 config.EngineFactory.url = d['registration']
236
237 config.Session.packer = d['pack']
238 config.Session.unpacker = d['unpack']
239
240 self.log.debug("Config changed:")
241 self.log.debug("%r", config)
242 self.connection_info = d
232 243
233 244 def bind_kernel(self, **kwargs):
234 245 """Promote engine to listening kernel, accessible to frontends."""
235 246 if self.kernel_app is not None:
236 247 return
237 248
238 249 self.log.info("Opening ports for direct connections as an IPython kernel")
239 250
240 251 kernel = self.kernel
241 252
242 253 kwargs.setdefault('config', self.config)
243 254 kwargs.setdefault('log', self.log)
244 255 kwargs.setdefault('profile_dir', self.profile_dir)
245 256 kwargs.setdefault('session', self.engine.session)
246 257
247 258 app = self.kernel_app = IPKernelApp(**kwargs)
248 259
249 260 # allow IPKernelApp.instance():
250 261 IPKernelApp._instance = app
251 262
252 263 app.init_connection_file()
253 264 # relevant contents of init_sockets:
254 265
255 266 app.shell_port = app._bind_socket(kernel.shell_streams[0], app.shell_port)
256 267 app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)
257 268
258 269 app.iopub_port = app._bind_socket(kernel.iopub_socket, app.iopub_port)
259 270 app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)
260 271
261 272 kernel.stdin_socket = self.engine.context.socket(zmq.ROUTER)
262 273 app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
263 274 app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)
264 275
265 276 # start the heartbeat, and log connection info:
266 277
267 278 app.init_heartbeat()
268 279
269 280 app.log_connection_info()
270 281 app.write_connection_file()
271 282
272 283
273 284 def init_engine(self):
274 285 # This is the working dir by now.
275 286 sys.path.insert(0, '')
276 287 config = self.config
277 288 # print config
278 289 self.find_url_file()
279 290
280 291 # was the url manually specified?
281 292 keys = set(self.config.EngineFactory.keys())
282 293 keys = keys.union(set(self.config.RegistrationFactory.keys()))
283 294
284 295 if keys.intersection(set(['ip', 'url', 'port'])):
285 296 # Connection info was specified, don't wait for the file
286 297 url_specified = True
287 298 self.wait_for_url_file = 0
288 299 else:
289 300 url_specified = False
290 301
291 302 if self.wait_for_url_file and not os.path.exists(self.url_file):
292 303 self.log.warn("url_file %r not found", self.url_file)
293 304 self.log.warn("Waiting up to %.1f seconds for it to arrive.", self.wait_for_url_file)
294 305 tic = time.time()
295 306 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
296 307 # wait for url_file to exist, or until time limit
297 308 time.sleep(0.1)
298 309
299 310 if os.path.exists(self.url_file):
300 311 self.load_connector_file()
301 312 elif not url_specified:
302 313 self.log.fatal("Fatal: url file never arrived: %s", self.url_file)
303 314 self.exit(1)
304 315
305 316
306 317 try:
307 318 exec_lines = config.Kernel.exec_lines
308 319 except AttributeError:
309 320 config.Kernel.exec_lines = []
310 321 exec_lines = config.Kernel.exec_lines
311 322
312 323 if self.startup_script:
313 324 enc = sys.getfilesystemencoding() or 'utf8'
314 325 cmd="execfile(%r)" % self.startup_script.encode(enc)
315 326 exec_lines.append(cmd)
316 327 if self.startup_command:
317 328 exec_lines.append(self.startup_command)
318 329
319 330 # Create the underlying shell class and Engine
320 331 # shell_class = import_item(self.master_config.Global.shell_class)
321 332 # print self.config
322 333 try:
323 self.engine = EngineFactory(config=config, log=self.log)
334 self.engine = EngineFactory(config=config, log=self.log,
335 connection_info=self.connection_info,
336 )
324 337 except:
325 338 self.log.error("Couldn't start the Engine", exc_info=True)
326 339 self.exit(1)
327 340
328 341 def forward_logging(self):
329 342 if self.log_url:
330 343 self.log.info("Forwarding logging to %s", self.log_url)
331 344 context = self.engine.context
332 345 lsock = context.socket(zmq.PUB)
333 346 lsock.connect(self.log_url)
334 347 handler = EnginePUBHandler(self.engine, lsock)
335 348 handler.setLevel(self.log_level)
336 349 self.log.addHandler(handler)
337 350
338 351 def init_mpi(self):
339 352 global mpi
340 353 self.mpi = MPI(config=self.config)
341 354
342 355 mpi_import_statement = self.mpi.init_script
343 356 if mpi_import_statement:
344 357 try:
345 358 self.log.info("Initializing MPI:")
346 359 self.log.info(mpi_import_statement)
347 360 exec mpi_import_statement in globals()
348 361 except:
349 362 mpi = None
350 363 else:
351 364 mpi = None
352 365
353 366 @catch_config_error
354 367 def initialize(self, argv=None):
355 368 super(IPEngineApp, self).initialize(argv)
356 369 self.init_mpi()
357 370 self.init_engine()
358 371 self.forward_logging()
359 372
360 373 def start(self):
361 374 self.engine.start()
362 375 try:
363 376 self.engine.loop.start()
364 377 except KeyboardInterrupt:
365 378 self.log.critical("Engine Interrupted, shutting down...\n")
366 379
367 380
368 381 def launch_new_instance():
369 382 """Create and run the IPython engine"""
370 383 app = IPEngineApp.instance()
371 384 app.initialize()
372 385 app.start()
373 386
374 387
375 388 if __name__ == '__main__':
376 389 launch_new_instance()
377 390
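Taken together, the engine-side consumption of the new file boils down to roughly the following sketch of the `load_connector_file` flow (the path is illustrative; the real file lives in the profile's security directory):

    import json

    from IPython.utils.py3compat import cast_bytes
    from IPython.parallel.util import disambiguate_url

    with open("ipcontroller-engine.json") as f:
        d = json.load(f)

    # Session.key wants bytes; json.load gives unicode
    key = cast_bytes(d["exec_key"])

    # resolve every per-channel URL against the controller's location
    location = d["location"]
    for name in ("registration", "hb_ping", "hb_pong", "mux", "task", "control"):
        d[name] = disambiguate_url(d[name], location)

Serialization is then pinned from `d['pack']` and `d['unpack']`, so for those settings the JSON file deliberately wins over command-line config.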
@@ -1,1713 +1,1692 @@ IPython/parallel/client/client.py
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 from threading import Thread, Event
22 22 import time
23 23 import warnings
24 24 from datetime import datetime
25 25 from getpass import getpass
26 26 from pprint import pprint
27 27
28 28 pjoin = os.path.join
29 29
30 30 import zmq
31 31 # from zmq.eventloop import ioloop, zmqstream
32 32
33 33 from IPython.config.configurable import MultipleInstanceError
34 34 from IPython.core.application import BaseIPythonApplication
35 35 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 36
37 37 from IPython.utils.coloransi import TermColors
38 38 from IPython.utils.jsonutil import rekey
39 39 from IPython.utils.localinterfaces import LOCAL_IPS
40 40 from IPython.utils.path import get_ipython_dir
41 41 from IPython.utils.py3compat import cast_bytes
42 42 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
43 43 Dict, List, Bool, Set, Any)
44 44 from IPython.external.decorator import decorator
45 45 from IPython.external.ssh import tunnel
46 46
47 47 from IPython.parallel import Reference
48 48 from IPython.parallel import error
49 49 from IPython.parallel import util
50 50
51 51 from IPython.zmq.session import Session, Message
52 52
53 53 from .asyncresult import AsyncResult, AsyncHubResult
54 54 from .view import DirectView, LoadBalancedView
55 55
56 56 if sys.version_info[0] >= 3:
57 57 # xrange is used in a couple 'isinstance' tests in py2
58 58 # should be just 'range' in 3k
59 59 xrange = range
60 60
61 61 #--------------------------------------------------------------------------
62 62 # Decorators for Client methods
63 63 #--------------------------------------------------------------------------
64 64
65 65 @decorator
66 66 def spin_first(f, self, *args, **kwargs):
67 67 """Call spin() to sync state prior to calling the method."""
68 68 self.spin()
69 69 return f(self, *args, **kwargs)
70 70
71 71
72 72 #--------------------------------------------------------------------------
73 73 # Classes
74 74 #--------------------------------------------------------------------------
75 75
76 76
77 77 class ExecuteReply(object):
78 78 """wrapper for finished Execute results"""
79 79 def __init__(self, msg_id, content, metadata):
80 80 self.msg_id = msg_id
81 81 self._content = content
82 82 self.execution_count = content['execution_count']
83 83 self.metadata = metadata
84 84
85 85 def __getitem__(self, key):
86 86 return self.metadata[key]
87 87
88 88 def __getattr__(self, key):
89 89 if key not in self.metadata:
90 90 raise AttributeError(key)
91 91 return self.metadata[key]
92 92
93 93 def __repr__(self):
94 94 pyout = self.metadata['pyout'] or {'data':{}}
95 95 text_out = pyout['data'].get('text/plain', '')
96 96 if len(text_out) > 32:
97 97 text_out = text_out[:29] + '...'
98 98
99 99 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
100 100
101 101 def _repr_pretty_(self, p, cycle):
102 102 pyout = self.metadata['pyout'] or {'data':{}}
103 103 text_out = pyout['data'].get('text/plain', '')
104 104
105 105 if not text_out:
106 106 return
107 107
108 108 try:
109 109 ip = get_ipython()
110 110 except NameError:
111 111 colors = "NoColor"
112 112 else:
113 113 colors = ip.colors
114 114
115 115 if colors == "NoColor":
116 116 out = normal = ""
117 117 else:
118 118 out = TermColors.Red
119 119 normal = TermColors.Normal
120 120
121 121 if '\n' in text_out and not text_out.startswith('\n'):
122 122 # add newline for multiline reprs
123 123 text_out = '\n' + text_out
124 124
125 125 p.text(
126 126 out + u'Out[%i:%i]: ' % (
127 127 self.metadata['engine_id'], self.execution_count
128 128 ) + normal + text_out
129 129 )
130 130
131 131 def _repr_html_(self):
132 132 pyout = self.metadata['pyout'] or {'data':{}}
133 133 return pyout['data'].get("text/html")
134 134
135 135 def _repr_latex_(self):
136 136 pyout = self.metadata['pyout'] or {'data':{}}
137 137 return pyout['data'].get("text/latex")
138 138
139 139 def _repr_json_(self):
140 140 pyout = self.metadata['pyout'] or {'data':{}}
141 141 return pyout['data'].get("application/json")
142 142
143 143 def _repr_javascript_(self):
144 144 pyout = self.metadata['pyout'] or {'data':{}}
145 145 return pyout['data'].get("application/javascript")
146 146
147 147 def _repr_png_(self):
148 148 pyout = self.metadata['pyout'] or {'data':{}}
149 149 return pyout['data'].get("image/png")
150 150
151 151 def _repr_jpeg_(self):
152 152 pyout = self.metadata['pyout'] or {'data':{}}
153 153 return pyout['data'].get("image/jpeg")
154 154
155 155 def _repr_svg_(self):
156 156 pyout = self.metadata['pyout'] or {'data':{}}
157 157 return pyout['data'].get("image/svg+xml")
158 158
159 159
160 160 class Metadata(dict):
161 161 """Subclass of dict for initializing metadata values.
162 162
163 163 Attribute access works on keys.
164 164
165 165 These objects have a strict set of keys - errors will raise if you try
166 166 to add new keys.
167 167 """
168 168 def __init__(self, *args, **kwargs):
169 169 dict.__init__(self)
170 170 md = {'msg_id' : None,
171 171 'submitted' : None,
172 172 'started' : None,
173 173 'completed' : None,
174 174 'received' : None,
175 175 'engine_uuid' : None,
176 176 'engine_id' : None,
177 177 'follow' : None,
178 178 'after' : None,
179 179 'status' : None,
180 180
181 181 'pyin' : None,
182 182 'pyout' : None,
183 183 'pyerr' : None,
184 184 'stdout' : '',
185 185 'stderr' : '',
186 186 'outputs' : [],
187 187 'outputs_ready' : False,
188 188 }
189 189 self.update(md)
190 190 self.update(dict(*args, **kwargs))
191 191
192 192 def __getattr__(self, key):
193 193 """getattr aliased to getitem"""
194 194 if key in self.iterkeys():
195 195 return self[key]
196 196 else:
197 197 raise AttributeError(key)
198 198
199 199 def __setattr__(self, key, value):
200 200 """setattr aliased to setitem, with strict"""
201 201 if key in self.iterkeys():
202 202 self[key] = value
203 203 else:
204 204 raise AttributeError(key)
205 205
206 206 def __setitem__(self, key, value):
207 207 """strict static key enforcement"""
208 208 if key in self.iterkeys():
209 209 dict.__setitem__(self, key, value)
210 210 else:
211 211 raise KeyError(key)
212 212
213 213
214 214 class Client(HasTraits):
215 215 """A semi-synchronous client to the IPython ZMQ cluster
216 216
217 217 Parameters
218 218 ----------
219 219
220 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
220 url_file : str/unicode; path to ipcontroller-client.json
221 This JSON file should contain all the information needed to connect to a cluster,
222 and is likely the only argument needed.
221 223 Connection information for the Hub's registration. If a json connector
222 224 file is given, then likely no further configuration is necessary.
223 225 [Default: use profile]
224 226 profile : bytes
225 227 The name of the Cluster profile to be used to find connector information.
226 228 If run from an IPython application, the default profile will be the same
227 229 as the running application, otherwise it will be 'default'.
228 230 context : zmq.Context
229 231 Pass an existing zmq.Context instance, otherwise the client will create its own.
230 232 debug : bool
231 233 flag for lots of message printing for debug purposes
232 234 timeout : int/float
233 235 time (in seconds) to wait for connection replies from the Hub
234 236 [Default: 10]
235 237
236 238 #-------------- session related args ----------------
237 239
238 240 config : Config object
239 241 If specified, this will be relayed to the Session for configuration
240 242 username : str
241 243 set username for the session object
242 packer : str (import_string) or callable
243 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
244 function to serialize messages. Must support same input as
245 JSON, and output must be bytes.
246 You can pass a callable directly as `pack`
247 unpacker : str (import_string) or callable
248 The inverse of packer. Only necessary if packer is specified as *not* one
249 of 'json' or 'pickle'.
250 244
251 245 #-------------- ssh related args ----------------
252 246 # These are args for configuring the ssh tunnel to be used
253 247 # credentials are used to forward connections over ssh to the Controller
254 248 # Note that the ip given in `addr` needs to be relative to sshserver
255 249 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
256 250 # and set sshserver as the same machine the Controller is on. However,
257 251 # the only requirement is that sshserver is able to see the Controller
258 252 # (i.e. is within the same trusted network).
259 253
260 254 sshserver : str
261 255 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
262 256 If keyfile or password is specified, and this is not, it will default to
263 257 the ip given in addr.
264 258 sshkey : str; path to ssh private key file
265 259 This specifies a key to be used in ssh login, default None.
266 260 Regular default ssh keys will be used without specifying this argument.
267 261 password : str
268 262 Your ssh password to sshserver. Note that if this is left None,
269 263 you will be prompted for it if passwordless key based login is unavailable.
270 264 paramiko : bool
271 265 flag for whether to use paramiko instead of shell ssh for tunneling.
272 266 [default: True on win32, False else]
273 267
274 ------- exec authentication args -------
275 If even localhost is untrusted, you can have some protection against
276 unauthorized execution by signing messages with HMAC digests.
277 Messages are still sent as cleartext, so if someone can snoop your
278 loopback traffic this will not protect your privacy, but will prevent
279 unauthorized execution.
280
281 exec_key : str
282 an authentication key or file containing a key
283 default: None
284
285 268
286 269 Attributes
287 270 ----------
288 271
289 272 ids : list of int engine IDs
290 273 requesting the ids attribute always synchronizes
291 274 the registration state. To request ids without synchronization,
292 275 use semi-private _ids attributes.
293 276
294 277 history : list of msg_ids
295 278 a list of msg_ids, keeping track of all the execution
296 279 messages you have submitted in order.
297 280
298 281 outstanding : set of msg_ids
299 282 a set of msg_ids that have been submitted, but whose
300 283 results have not yet been received.
301 284
302 285 results : dict
303 286 a dict of all our results, keyed by msg_id
304 287
305 288 block : bool
306 289 determines default behavior when block not specified
307 290 in execution methods
308 291
309 292 Methods
310 293 -------
311 294
312 295 spin
313 296 flushes incoming results and registration state changes
314 297 control methods spin, and requesting `ids` also ensures the state is up to date
315 298
316 299 wait
317 300 wait on one or more msg_ids
318 301
319 302 execution methods
320 303 apply
321 304 legacy: execute, run
322 305
323 306 data movement
324 307 push, pull, scatter, gather
325 308
326 309 query methods
327 310 queue_status, get_result, purge, result_status
328 311
329 312 control methods
330 313 abort, shutdown
331 314
332 315 """
333 316
334 317
335 318 block = Bool(False)
336 319 outstanding = Set()
337 320 results = Instance('collections.defaultdict', (dict,))
338 321 metadata = Instance('collections.defaultdict', (Metadata,))
339 322 history = List()
340 323 debug = Bool(False)
341 324 _spin_thread = Any()
342 325 _stop_spinning = Any()
343 326
344 327 profile=Unicode()
345 328 def _profile_default(self):
346 329 if BaseIPythonApplication.initialized():
347 330 # an IPython app *might* be running, try to get its profile
348 331 try:
349 332 return BaseIPythonApplication.instance().profile
350 333 except (AttributeError, MultipleInstanceError):
351 334 # could be a *different* subclass of config.Application,
352 335 # which would raise one of these two errors.
353 336 return u'default'
354 337 else:
355 338 return u'default'
356 339
357 340
358 341 _outstanding_dict = Instance('collections.defaultdict', (set,))
359 342 _ids = List()
360 343 _connected=Bool(False)
361 344 _ssh=Bool(False)
362 345 _context = Instance('zmq.Context')
363 346 _config = Dict()
364 347 _engines=Instance(util.ReverseDict, (), {})
365 348 # _hub_socket=Instance('zmq.Socket')
366 349 _query_socket=Instance('zmq.Socket')
367 350 _control_socket=Instance('zmq.Socket')
368 351 _iopub_socket=Instance('zmq.Socket')
369 352 _notification_socket=Instance('zmq.Socket')
370 353 _mux_socket=Instance('zmq.Socket')
371 354 _task_socket=Instance('zmq.Socket')
372 355 _task_scheme=Unicode()
373 356 _closed = False
374 357 _ignored_control_replies=Integer(0)
375 358 _ignored_hub_replies=Integer(0)
376 359
377 360 def __new__(self, *args, **kw):
378 361 # don't raise on positional args
379 362 return HasTraits.__new__(self, **kw)
380 363
381 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
382 context=None, debug=False, exec_key=None,
364 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
365 context=None, debug=False,
383 366 sshserver=None, sshkey=None, password=None, paramiko=None,
384 367 timeout=10, **extra_args
385 368 ):
386 369 if profile:
387 370 super(Client, self).__init__(debug=debug, profile=profile)
388 371 else:
389 372 super(Client, self).__init__(debug=debug)
390 373 if context is None:
391 374 context = zmq.Context.instance()
392 375 self._context = context
393 376 self._stop_spinning = Event()
394 377
378 if 'url_or_file' in extra_args:
379 url_file = extra_args['url_or_file']
380 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
381
382 if url_file and util.is_url(url_file):
383 raise ValueError("single urls cannot be specified, url-files must be used.")
384
395 385 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
386
396 387 if self._cd is not None:
397 if url_or_file is None:
398 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
399 if url_or_file is None:
388 if url_file is None:
389 url_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
390 if url_file is None:
400 391 raise ValueError(
401 392 "I can't find enough information to connect to a hub!"
402 " Please specify at least one of url_or_file or profile."
393 " Please specify at least one of url_file or profile."
403 394 )
404 395
405 if not util.is_url(url_or_file):
406 # it's not a url, try for a file
407 if not os.path.exists(url_or_file):
408 if self._cd:
409 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
410 if not os.path.exists(url_or_file):
411 raise IOError("Connection file not found: %r" % url_or_file)
412 with open(url_or_file) as f:
413 cfg = json.loads(f.read())
414 else:
415 cfg = {'url':url_or_file}
396 with open(url_file) as f:
397 cfg = json.load(f)
398
399 self._task_scheme = cfg['task_scheme']
416 400
417 401 # sync defaults from args, json:
418 402 if sshserver:
419 403 cfg['ssh'] = sshserver
420 if exec_key:
421 cfg['exec_key'] = exec_key
422 exec_key = cfg['exec_key']
404
423 405 location = cfg.setdefault('location', None)
424 cfg['url'] = util.disambiguate_url(cfg['url'], location)
425 url = cfg['url']
406 for key in ('control', 'task', 'mux', 'notification', 'registration'):
407 cfg[key] = util.disambiguate_url(cfg[key], location)
408 url = cfg['registration']
426 409 proto,addr,port = util.split_url(url)
427 410 if location is not None and addr == '127.0.0.1':
428 411 # location specified, and connection is expected to be local
429 412 if location not in LOCAL_IPS and not sshserver:
430 413 # load ssh from JSON *only* if the controller is not on
431 414 # this machine
432 415 sshserver=cfg['ssh']
433 416 if location not in LOCAL_IPS and not sshserver:
434 417 # warn if no ssh specified, but SSH is probably needed
435 418 # This is only a warning, because the most likely cause
436 419 # is a local Controller on a laptop whose IP is dynamic
437 420 warnings.warn("""
438 421 Controller appears to be listening on localhost, but not on this machine.
439 422 If this is true, you should specify Client(...,sshserver='you@%s')
440 423 or instruct your controller to listen on an external IP."""%location,
441 424 RuntimeWarning)
442 425 elif not sshserver:
443 426 # otherwise sync with cfg
444 427 sshserver = cfg['ssh']
445 428
446 429 self._config = cfg
447 430
448 431 self._ssh = bool(sshserver or sshkey or password)
449 432 if self._ssh and sshserver is None:
450 433 # default to ssh via localhost
451 434 sshserver = url.split('://')[1].split(':')[0]
452 435 if self._ssh and password is None:
453 436 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
454 437 password=False
455 438 else:
456 439 password = getpass("SSH Password for %s: "%sshserver)
457 440 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
458 441
459 442 # configure and construct the session
460 if exec_key is not None:
461 if os.path.isfile(exec_key):
462 extra_args['keyfile'] = exec_key
463 else:
464 exec_key = cast_bytes(exec_key)
465 extra_args['key'] = exec_key
443 extra_args['packer'] = cfg['pack']
444 extra_args['unpacker'] = cfg['unpack']
445 extra_args['key'] = cfg['exec_key']
446
466 447 self.session = Session(**extra_args)
467 448
468 449 self._query_socket = self._context.socket(zmq.DEALER)
469 450
470 451 if self._ssh:
471 452 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
472 453 else:
473 454 self._query_socket.connect(url)
474 455
475 456 self.session.debug = self.debug
476 457
477 458 self._notification_handlers = {'registration_notification' : self._register_engine,
478 459 'unregistration_notification' : self._unregister_engine,
479 460 'shutdown_notification' : lambda msg: self.close(),
480 461 }
481 462 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
482 463 'apply_reply' : self._handle_apply_reply}
483 464 self._connect(sshserver, ssh_kwargs, timeout)
484 465
485 466 # last step: setup magics, if we are in IPython:
486 467
487 468 try:
488 469 ip = get_ipython()
489 470 except NameError:
490 471 return
491 472 else:
492 473 if 'px' not in ip.magics_manager.magics:
493 474 # in IPython but we are the first Client.
494 475 # activate a default view for parallel magics.
495 476 self.activate()
496 477
497 478 def __del__(self):
498 479 """cleanup sockets, but _not_ context."""
499 480 self.close()
500 481
501 482 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
502 483 if ipython_dir is None:
503 484 ipython_dir = get_ipython_dir()
504 485 if profile_dir is not None:
505 486 try:
506 487 self._cd = ProfileDir.find_profile_dir(profile_dir)
507 488 return
508 489 except ProfileDirError:
509 490 pass
510 491 elif profile is not None:
511 492 try:
512 493 self._cd = ProfileDir.find_profile_dir_by_name(
513 494 ipython_dir, profile)
514 495 return
515 496 except ProfileDirError:
516 497 pass
517 498 self._cd = None
518 499
519 500 def _update_engines(self, engines):
520 501 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
521 502 for k,v in engines.iteritems():
522 503 eid = int(k)
523 504 self._engines[eid] = v
524 505 self._ids.append(eid)
525 506 self._ids = sorted(self._ids)
526 507 if sorted(self._engines.keys()) != range(len(self._engines)) and \
527 508 self._task_scheme == 'pure' and self._task_socket:
528 509 self._stop_scheduling_tasks()
529 510
530 511 def _stop_scheduling_tasks(self):
531 512 """Stop scheduling tasks because an engine has been unregistered
532 513 from a pure ZMQ scheduler.
533 514 """
534 515 self._task_socket.close()
535 516 self._task_socket = None
536 517 msg = "An engine has been unregistered, and we are using pure " +\
537 518 "ZMQ task scheduling. Task farming will be disabled."
538 519 if self.outstanding:
539 520 msg += " If you were running tasks when this happened, " +\
540 521 "some `outstanding` msg_ids may never resolve."
541 522 warnings.warn(msg, RuntimeWarning)
542 523
543 524 def _build_targets(self, targets):
544 525 """Turn valid target IDs or 'all' into two lists:
545 526 (int_ids, uuids).
546 527 """
547 528 if not self._ids:
548 529 # flush notification socket if no engines yet, just in case
549 530 if not self.ids:
550 531 raise error.NoEnginesRegistered("Can't build targets without any engines")
551 532
552 533 if targets is None:
553 534 targets = self._ids
554 535 elif isinstance(targets, basestring):
555 536 if targets.lower() == 'all':
556 537 targets = self._ids
557 538 else:
558 539 raise TypeError("%r not valid str target, must be 'all'"%(targets))
559 540 elif isinstance(targets, int):
560 541 if targets < 0:
561 542 targets = self.ids[targets]
562 543 if targets not in self._ids:
563 544 raise IndexError("No such engine: %i"%targets)
564 545 targets = [targets]
565 546
566 547 if isinstance(targets, slice):
567 548 indices = range(len(self._ids))[targets]
568 549 ids = self.ids
569 550 targets = [ ids[i] for i in indices ]
570 551
571 552 if not isinstance(targets, (tuple, list, xrange)):
572 553 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
573 554
574 555 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
575 556
576 557 def _connect(self, sshserver, ssh_kwargs, timeout):
577 558 """setup all our socket connections to the cluster. This is called from
578 559 __init__."""
579 560
580 561 # Maybe allow reconnecting?
581 562 if self._connected:
582 563 return
583 564 self._connected=True
584 565
585 566 def connect_socket(s, url):
586 url = util.disambiguate_url(url, self._config['location'])
567 # url = util.disambiguate_url(url, self._config['location'])
587 568 if self._ssh:
588 569 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
589 570 else:
590 571 return s.connect(url)
591 572
592 573 self.session.send(self._query_socket, 'connection_request')
593 574 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
594 575 poller = zmq.Poller()
595 576 poller.register(self._query_socket, zmq.POLLIN)
596 577 # poll expects milliseconds, timeout is seconds
597 578 evts = poller.poll(timeout*1000)
598 579 if not evts:
599 580 raise error.TimeoutError("Hub connection request timed out")
600 581 idents,msg = self.session.recv(self._query_socket,mode=0)
601 582 if self.debug:
602 583 pprint(msg)
603 msg = Message(msg)
604 content = msg.content
605 self._config['registration'] = dict(content)
606 if content.status == 'ok':
607 ident = self.session.bsession
608 if content.mux:
584 content = msg['content']
585 # self._config['registration'] = dict(content)
586 cfg = self._config
587 if content['status'] == 'ok':
609 588 self._mux_socket = self._context.socket(zmq.DEALER)
610 connect_socket(self._mux_socket, content.mux)
611 if content.task:
612 self._task_scheme, task_addr = content.task
589 connect_socket(self._mux_socket, cfg['mux'])
590
613 591 self._task_socket = self._context.socket(zmq.DEALER)
614 connect_socket(self._task_socket, task_addr)
615 if content.notification:
592 connect_socket(self._task_socket, cfg['task'])
593
616 594 self._notification_socket = self._context.socket(zmq.SUB)
617 connect_socket(self._notification_socket, content.notification)
618 595 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
619 if content.control:
596 connect_socket(self._notification_socket, cfg['notification'])
597
620 598 self._control_socket = self._context.socket(zmq.DEALER)
621 connect_socket(self._control_socket, content.control)
622 if content.iopub:
599 connect_socket(self._control_socket, cfg['control'])
600
623 601 self._iopub_socket = self._context.socket(zmq.SUB)
624 602 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
625 connect_socket(self._iopub_socket, content.iopub)
626 self._update_engines(dict(content.engines))
603 connect_socket(self._iopub_socket, cfg['iopub'])
604
605 self._update_engines(dict(content['engines']))
627 606 else:
628 607 self._connected = False
629 608 raise Exception("Failed to connect!")
630 609
631 610 #--------------------------------------------------------------------------
632 611 # handlers and callbacks for incoming messages
633 612 #--------------------------------------------------------------------------
634 613
635 614 def _unwrap_exception(self, content):
636 615 """unwrap exception, and remap engine_id to int."""
637 616 e = error.unwrap_exception(content)
638 617 # print e.traceback
639 618 if e.engine_info:
640 619 e_uuid = e.engine_info['engine_uuid']
641 620 eid = self._engines[e_uuid]
642 621 e.engine_info['engine_id'] = eid
643 622 return e
644 623
645 624 def _extract_metadata(self, header, parent, content):
646 625 md = {'msg_id' : parent['msg_id'],
647 626 'received' : datetime.now(),
648 627 'engine_uuid' : header.get('engine', None),
649 628 'follow' : parent.get('follow', []),
650 629 'after' : parent.get('after', []),
651 630 'status' : content['status'],
652 631 }
653 632
654 633 if md['engine_uuid'] is not None:
655 634 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
656 635
657 636 if 'date' in parent:
658 637 md['submitted'] = parent['date']
659 638 if 'started' in header:
660 639 md['started'] = header['started']
661 640 if 'date' in header:
662 641 md['completed'] = header['date']
663 642 return md
664 643
665 644 def _register_engine(self, msg):
666 645 """Register a new engine, and update our connection info."""
667 646 content = msg['content']
668 647 eid = content['id']
669 648 d = {eid : content['queue']}
670 649 self._update_engines(d)
671 650
672 651 def _unregister_engine(self, msg):
673 652 """Unregister an engine that has died."""
674 653 content = msg['content']
675 654 eid = int(content['id'])
676 655 if eid in self._ids:
677 656 self._ids.remove(eid)
678 657 uuid = self._engines.pop(eid)
679 658
680 659 self._handle_stranded_msgs(eid, uuid)
681 660
682 661 if self._task_socket and self._task_scheme == 'pure':
683 662 self._stop_scheduling_tasks()
684 663
685 664 def _handle_stranded_msgs(self, eid, uuid):
686 665 """Handle messages known to be on an engine when the engine unregisters.
687 666
688 667 It is possible that this will fire prematurely - that is, an engine will
689 668 go down after completing a result, and the client will be notified
690 669 of the unregistration and later receive the successful result.
691 670 """
692 671
693 672 outstanding = self._outstanding_dict[uuid]
694 673
695 674 for msg_id in list(outstanding):
696 675 if msg_id in self.results:
697 676 # we already have this result; skip it
698 677 continue
699 678 try:
700 679 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
701 680 except:
702 681 content = error.wrap_exception()
703 682 # build a fake message:
704 683 parent = {}
705 684 header = {}
706 685 parent['msg_id'] = msg_id
707 686 header['engine'] = uuid
708 687 header['date'] = datetime.now()
709 688 msg = dict(parent_header=parent, header=header, content=content)
710 689 self._handle_apply_reply(msg)
711 690
712 691 def _handle_execute_reply(self, msg):
713 692 """Save the reply to an execute_request into our results.
714 693
715 694 execute messages are never actually used. apply is used instead.
716 695 """
717 696
718 697 parent = msg['parent_header']
719 698 msg_id = parent['msg_id']
720 699 if msg_id not in self.outstanding:
721 700 if msg_id in self.history:
722 701 print ("got stale result: %s"%msg_id)
723 702 else:
724 703 print ("got unknown result: %s"%msg_id)
725 704 else:
726 705 self.outstanding.remove(msg_id)
727 706
728 707 content = msg['content']
729 708 header = msg['header']
730 709
731 710 # construct metadata:
732 711 md = self.metadata[msg_id]
733 712 md.update(self._extract_metadata(header, parent, content))
734 713 # is this redundant?
735 714 self.metadata[msg_id] = md
736 715
737 716 e_outstanding = self._outstanding_dict[md['engine_uuid']]
738 717 if msg_id in e_outstanding:
739 718 e_outstanding.remove(msg_id)
740 719
741 720 # construct result:
742 721 if content['status'] == 'ok':
743 722 self.results[msg_id] = ExecuteReply(msg_id, content, md)
744 723 elif content['status'] == 'aborted':
745 724 self.results[msg_id] = error.TaskAborted(msg_id)
746 725 elif content['status'] == 'resubmitted':
747 726 # TODO: handle resubmission
748 727 pass
749 728 else:
750 729 self.results[msg_id] = self._unwrap_exception(content)
751 730
752 731 def _handle_apply_reply(self, msg):
753 732 """Save the reply to an apply_request into our results."""
754 733 parent = msg['parent_header']
755 734 msg_id = parent['msg_id']
756 735 if msg_id not in self.outstanding:
757 736 if msg_id in self.history:
758 737 print ("got stale result: %s"%msg_id)
759 738 print self.results[msg_id]
760 739 print msg
761 740 else:
762 741 print ("got unknown result: %s"%msg_id)
763 742 else:
764 743 self.outstanding.remove(msg_id)
765 744 content = msg['content']
766 745 header = msg['header']
767 746
768 747 # construct metadata:
769 748 md = self.metadata[msg_id]
770 749 md.update(self._extract_metadata(header, parent, content))
771 750 # is this redundant?
772 751 self.metadata[msg_id] = md
773 752
774 753 e_outstanding = self._outstanding_dict[md['engine_uuid']]
775 754 if msg_id in e_outstanding:
776 755 e_outstanding.remove(msg_id)
777 756
778 757 # construct result:
779 758 if content['status'] == 'ok':
780 759 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
781 760 elif content['status'] == 'aborted':
782 761 self.results[msg_id] = error.TaskAborted(msg_id)
783 762 elif content['status'] == 'resubmitted':
784 763 # TODO: handle resubmission
785 764 pass
786 765 else:
787 766 self.results[msg_id] = self._unwrap_exception(content)
788 767
789 768 def _flush_notifications(self):
790 769 """Flush notifications of engine registrations waiting
791 770 in ZMQ queue."""
792 771 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
793 772 while msg is not None:
794 773 if self.debug:
795 774 pprint(msg)
796 775 msg_type = msg['header']['msg_type']
797 776 handler = self._notification_handlers.get(msg_type, None)
798 777 if handler is None:
799 778 raise Exception("Unhandled message type: %s"%msg_type)
800 779 else:
801 780 handler(msg)
802 781 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
803 782
804 783 def _flush_results(self, sock):
805 784 """Flush task or queue results waiting in ZMQ queue."""
806 785 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
807 786 while msg is not None:
808 787 if self.debug:
809 788 pprint(msg)
810 789 msg_type = msg['header']['msg_type']
811 790 handler = self._queue_handlers.get(msg_type, None)
812 791 if handler is None:
813 792 raise Exception("Unhandled message type: %s"%msg_type)
814 793 else:
815 794 handler(msg)
816 795 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
817 796
818 797 def _flush_control(self, sock):
819 798 """Flush replies from the control channel waiting
820 799 in the ZMQ queue.
821 800
822 801 Currently: ignore them."""
823 802 if self._ignored_control_replies <= 0:
824 803 return
825 804 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
826 805 while msg is not None:
827 806 self._ignored_control_replies -= 1
828 807 if self.debug:
829 808 pprint(msg)
830 809 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 810
832 811 def _flush_ignored_control(self):
833 812 """flush ignored control replies"""
834 813 while self._ignored_control_replies > 0:
835 814 self.session.recv(self._control_socket)
836 815 self._ignored_control_replies -= 1
837 816
838 817 def _flush_ignored_hub_replies(self):
839 818 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
840 819 while msg is not None:
841 820 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
842 821
843 822 def _flush_iopub(self, sock):
844 823 """Flush replies from the iopub channel waiting
845 824 in the ZMQ queue.
846 825 """
847 826 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
848 827 while msg is not None:
849 828 if self.debug:
850 829 pprint(msg)
851 830 parent = msg['parent_header']
852 831 # ignore IOPub messages with no parent.
853 832 # Caused by print statements or warnings from before the first execution.
854 833 if not parent:
855 834 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK); continue # must advance msg first, or this loops forever
856 835 msg_id = parent['msg_id']
857 836 content = msg['content']
858 837 header = msg['header']
859 838 msg_type = msg['header']['msg_type']
860 839
861 840 # init metadata:
862 841 md = self.metadata[msg_id]
863 842
864 843 if msg_type == 'stream':
865 844 name = content['name']
866 845 s = md[name] or ''
867 846 md[name] = s + content['data']
868 847 elif msg_type == 'pyerr':
869 848 md.update({'pyerr' : self._unwrap_exception(content)})
870 849 elif msg_type == 'pyin':
871 850 md.update({'pyin' : content['code']})
872 851 elif msg_type == 'display_data':
873 852 md['outputs'].append(content)
874 853 elif msg_type == 'pyout':
875 854 md['pyout'] = content
876 855 elif msg_type == 'status':
877 856 # idle message comes after all outputs
878 857 if content['execution_state'] == 'idle':
879 858 md['outputs_ready'] = True
880 859 else:
881 860 # unhandled msg_type (status, etc.)
882 861 pass
883 862
884 863 # is this redundant?
885 864 self.metadata[msg_id] = md
886 865
887 866 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
888 867
889 868 #--------------------------------------------------------------------------
890 869 # len, getitem
891 870 #--------------------------------------------------------------------------
892 871
893 872 def __len__(self):
894 873 """len(client) returns # of engines."""
895 874 return len(self.ids)
896 875
897 876 def __getitem__(self, key):
898 877 """index access returns DirectView multiplexer objects
899 878
900 879 Must be int, slice, or list/tuple/xrange of ints"""
901 880 if not isinstance(key, (int, slice, tuple, list, xrange)):
902 881 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
903 882 else:
904 883 return self.direct_view(key)
905 884
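# A minimal indexing sketch, assuming a running cluster and the default
# profile; all names below are hypothetical:
from IPython.parallel import Client
rc = Client()
dv0 = rc[0]           # DirectView on engine 0
dv_even = rc[::2]     # DirectView on every other *currently registered* engine
dv_pair = rc[[0, 1]]  # DirectView on an explicit list of engine ids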
906 885 #--------------------------------------------------------------------------
907 886 # Begin public methods
908 887 #--------------------------------------------------------------------------
909 888
910 889 @property
911 890 def ids(self):
912 891 """Always up-to-date ids property."""
913 892 self._flush_notifications()
914 893 # always copy:
915 894 return list(self._ids)
916 895
917 896 def activate(self, targets='all', suffix=''):
918 897 """Create a DirectView and register it with IPython magics
919 898
920 899 Defines the magics `%px, %autopx, %pxresult, %%px`
921 900
922 901 Parameters
923 902 ----------
924 903
925 904 targets: int, list of ints, or 'all'
926 905 The engines on which the view's magics will run
927 906 suffix: str [default: '']
928 907 The suffix, if any, for the magics. This allows you to have
929 908 multiple views associated with parallel magics at the same time.
930 909
931 910 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
932 911 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
933 912 on engine 0.
934 913 """
935 914 view = self.direct_view(targets)
936 915 view.block = True
937 916 view.activate(suffix)
938 917 return view
939 918
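# Sketch of suffixed magics, assuming a running cluster; after the second call,
# %px0 targets engine 0 only while plain %px still uses all engines:
from IPython.parallel import Client
rc = Client()
rc.activate()                       # defines %px, %autopx, %pxresult, %%px
rc.activate(targets=0, suffix='0')  # defines %px0, %pxresult0, ...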
940 919 def close(self):
941 920 if self._closed:
942 921 return
943 922 self.stop_spin_thread()
944 923 snames = filter(lambda n: n.endswith('socket'), dir(self))
945 924 for socket in map(lambda name: getattr(self, name), snames):
946 925 if isinstance(socket, zmq.Socket) and not socket.closed:
947 926 socket.close()
948 927 self._closed = True
949 928
950 929 def _spin_every(self, interval=1):
951 930 """target func for use in spin_thread"""
952 931 while True:
953 932 if self._stop_spinning.is_set():
954 933 return
955 934 time.sleep(interval)
956 935 self.spin()
957 936
958 937 def spin_thread(self, interval=1):
959 938 """call Client.spin() in a background thread on some regular interval
960 939
961 940 This helps ensure that messages don't pile up too much in the zmq queue
962 941 while you are working on other things, or just leaving an idle terminal.
963 942
964 943 It also helps limit potential padding of the `received` timestamp
965 944 on AsyncResult objects, used for timings.
966 945
967 946 Parameters
968 947 ----------
969 948
970 949 interval : float, optional
971 950 The interval on which to spin the client in the background thread
972 951 (simply passed to time.sleep).
973 952
974 953 Notes
975 954 -----
976 955
977 956 For precision timing, you may want to use this method to put a bound
978 957 on the jitter (in seconds) in `received` timestamps used
979 958 in AsyncResult.wall_time.
980 959
981 960 """
982 961 if self._spin_thread is not None:
983 962 self.stop_spin_thread()
984 963 self._stop_spinning.clear()
985 964 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
986 965 self._spin_thread.daemon = True
987 966 self._spin_thread.start()
988 967
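# Sketch: bound `received`-timestamp jitter during long local work, assuming a
# connected Client `rc`; the 0.1s interval and the work function are examples:
rc.spin_thread(interval=0.1)
try:
    do_long_local_work()  # hypothetical; replies are flushed in the background
finally:
    rc.stop_spin_thread()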
989 968 def stop_spin_thread(self):
990 969 """stop background spin_thread, if any"""
991 970 if self._spin_thread is not None:
992 971 self._stop_spinning.set()
993 972 self._spin_thread.join()
994 973 self._spin_thread = None
995 974
996 975 def spin(self):
997 976 """Flush any registration notifications and execution results
998 977 waiting in the ZMQ queue.
999 978 """
1000 979 if self._notification_socket:
1001 980 self._flush_notifications()
1002 981 if self._iopub_socket:
1003 982 self._flush_iopub(self._iopub_socket)
1004 983 if self._mux_socket:
1005 984 self._flush_results(self._mux_socket)
1006 985 if self._task_socket:
1007 986 self._flush_results(self._task_socket)
1008 987 if self._control_socket:
1009 988 self._flush_control(self._control_socket)
1010 989 if self._query_socket:
1011 990 self._flush_ignored_hub_replies()
1012 991
1013 992 def wait(self, jobs=None, timeout=-1):
1014 993 """waits on one or more `jobs`, for up to `timeout` seconds.
1015 994
1016 995 Parameters
1017 996 ----------
1018 997
1019 998 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1020 999 ints are indices to self.history
1021 1000 strs are msg_ids
1022 1001 default: wait on all outstanding messages
1023 1002 timeout : float
1024 1003 a time in seconds, after which to give up.
1025 1004 default is -1, which means no timeout
1026 1005
1027 1006 Returns
1028 1007 -------
1029 1008
1030 1009 True : when all msg_ids are done
1031 1010 False : timeout reached, some msg_ids still outstanding
1032 1011 """
1033 1012 tic = time.time()
1034 1013 if jobs is None:
1035 1014 theids = self.outstanding
1036 1015 else:
1037 1016 if isinstance(jobs, (int, basestring, AsyncResult)):
1038 1017 jobs = [jobs]
1039 1018 theids = set()
1040 1019 for job in jobs:
1041 1020 if isinstance(job, int):
1042 1021 # index access
1043 1022 job = self.history[job]
1044 1023 elif isinstance(job, AsyncResult):
1045 1024 map(theids.add, job.msg_ids)
1046 1025 continue
1047 1026 theids.add(job)
1048 1027 if not theids.intersection(self.outstanding):
1049 1028 return True
1050 1029 self.spin()
1051 1030 while theids.intersection(self.outstanding):
1052 1031 if timeout >= 0 and ( time.time()-tic ) > timeout:
1053 1032 break
1054 1033 time.sleep(1e-3)
1055 1034 self.spin()
1056 1035 return len(theids.intersection(self.outstanding)) == 0
1057 1036
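# Sketch of waiting with a timeout, assuming a connected Client `rc` and a
# DirectView `dv`; wait() also accepts AsyncResult objects directly:
ar = dv.apply_async(lambda: 42)
if not rc.wait(ar, timeout=5):
    print("still outstanding after 5s: %s" % rc.outstanding)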
1058 1037 #--------------------------------------------------------------------------
1059 1038 # Control methods
1060 1039 #--------------------------------------------------------------------------
1061 1040
1062 1041 @spin_first
1063 1042 def clear(self, targets=None, block=None):
1064 1043 """Clear the namespace in target(s)."""
1065 1044 block = self.block if block is None else block
1066 1045 targets = self._build_targets(targets)[0]
1067 1046 for t in targets:
1068 1047 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1069 1048 error = False
1070 1049 if block:
1071 1050 self._flush_ignored_control()
1072 1051 for i in range(len(targets)):
1073 1052 idents,msg = self.session.recv(self._control_socket,0)
1074 1053 if self.debug:
1075 1054 pprint(msg)
1076 1055 if msg['content']['status'] != 'ok':
1077 1056 error = self._unwrap_exception(msg['content'])
1078 1057 else:
1079 1058 self._ignored_control_replies += len(targets)
1080 1059 if error:
1081 1060 raise error
1082 1061
1083 1062
1084 1063 @spin_first
1085 1064 def abort(self, jobs=None, targets=None, block=None):
1086 1065 """Abort specific jobs from the execution queues of target(s).
1087 1066
1088 1067 This is a mechanism to prevent jobs that have already been submitted
1089 1068 from executing.
1090 1069
1091 1070 Parameters
1092 1071 ----------
1093 1072
1094 1073 jobs : msg_id, list of msg_ids, or AsyncResult
1095 1074 The jobs to be aborted
1096 1075
1097 1076 If unspecified/None: abort all outstanding jobs.
1098 1077
1099 1078 """
1100 1079 block = self.block if block is None else block
1101 1080 jobs = jobs if jobs is not None else list(self.outstanding)
1102 1081 targets = self._build_targets(targets)[0]
1103 1082
1104 1083 msg_ids = []
1105 1084 if isinstance(jobs, (basestring,AsyncResult)):
1106 1085 jobs = [jobs]
1107 1086 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1108 1087 if bad_ids:
1109 1088 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1110 1089 for j in jobs:
1111 1090 if isinstance(j, AsyncResult):
1112 1091 msg_ids.extend(j.msg_ids)
1113 1092 else:
1114 1093 msg_ids.append(j)
1115 1094 content = dict(msg_ids=msg_ids)
1116 1095 for t in targets:
1117 1096 self.session.send(self._control_socket, 'abort_request',
1118 1097 content=content, ident=t)
1119 1098 error = False
1120 1099 if block:
1121 1100 self._flush_ignored_control()
1122 1101 for i in range(len(targets)):
1123 1102 idents,msg = self.session.recv(self._control_socket,0)
1124 1103 if self.debug:
1125 1104 pprint(msg)
1126 1105 if msg['content']['status'] != 'ok':
1127 1106 error = self._unwrap_exception(msg['content'])
1128 1107 else:
1129 1108 self._ignored_control_replies += len(targets)
1130 1109 if error:
1131 1110 raise error
1132 1111
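# Sketch: abort still-queued tasks, assuming a connected Client `rc` and a
# LoadBalancedView `lview`; tasks that have already started are unaffected:
ars = [lview.apply_async(lambda x: x, i) for i in range(100)]
rc.abort(ars[50:])  # abort by AsyncResult
rc.abort()          # or abort everything outstanding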
1133 1112 @spin_first
1134 1113 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1135 1114 """Terminates one or more engine processes, optionally including the hub.
1136 1115
1137 1116 Parameters
1138 1117 ----------
1139 1118
1140 1119 targets: list of ints or 'all' [default: all]
1141 1120 Which engines to shut down.
1142 1121 hub: bool [default: False]
1143 1122 Whether to include the Hub. hub=True implies targets='all'.
1144 1123 block: bool [default: self.block]
1145 1124 Whether to wait for clean shutdown replies or not.
1146 1125 restart: bool [default: False]
1147 1126 NOT IMPLEMENTED
1148 1127 whether to restart engines after shutting them down.
1149 1128 """
1150 1129
1151 1130 if restart:
1152 1131 raise NotImplementedError("Engine restart is not yet implemented")
1153 1132
1154 1133 block = self.block if block is None else block
1155 1134 if hub:
1156 1135 targets = 'all'
1157 1136 targets = self._build_targets(targets)[0]
1158 1137 for t in targets:
1159 1138 self.session.send(self._control_socket, 'shutdown_request',
1160 1139 content={'restart':restart},ident=t)
1161 1140 error = False
1162 1141 if block or hub:
1163 1142 self._flush_ignored_control()
1164 1143 for i in range(len(targets)):
1165 1144 idents,msg = self.session.recv(self._control_socket, 0)
1166 1145 if self.debug:
1167 1146 pprint(msg)
1168 1147 if msg['content']['status'] != 'ok':
1169 1148 error = self._unwrap_exception(msg['content'])
1170 1149 else:
1171 1150 self._ignored_control_replies += len(targets)
1172 1151
1173 1152 if hub:
1174 1153 time.sleep(0.25)
1175 1154 self.session.send(self._query_socket, 'shutdown_request')
1176 1155 idents,msg = self.session.recv(self._query_socket, 0)
1177 1156 if self.debug:
1178 1157 pprint(msg)
1179 1158 if msg['content']['status'] != 'ok':
1180 1159 error = self._unwrap_exception(msg['content'])
1181 1160
1182 1161 if error:
1183 1162 raise error
1184 1163
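# Sketch, assuming a connected Client `rc` on a cluster where engines 2 and 3
# exist:
rc.shutdown(targets=[2, 3])  # stop two engines, keep the rest running
rc.shutdown(hub=True)        # stop everything; implies targets='all'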
1185 1164 #--------------------------------------------------------------------------
1186 1165 # Execution related methods
1187 1166 #--------------------------------------------------------------------------
1188 1167
1189 1168 def _maybe_raise(self, result):
1190 1169 """wrapper for maybe raising an exception if apply failed."""
1191 1170 if isinstance(result, error.RemoteError):
1192 1171 raise result
1193 1172
1194 1173 return result
1195 1174
1196 1175 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1197 1176 ident=None):
1198 1177 """construct and send an apply message via a socket.
1199 1178
1200 1179 This is the principal method with which all engine execution is performed by views.
1201 1180 """
1202 1181
1203 1182 if self._closed:
1204 1183 raise RuntimeError("Client cannot be used after its sockets have been closed")
1205 1184
1206 1185 # defaults:
1207 1186 args = args if args is not None else []
1208 1187 kwargs = kwargs if kwargs is not None else {}
1209 1188 subheader = subheader if subheader is not None else {}
1210 1189
1211 1190 # validate arguments
1212 1191 if not callable(f) and not isinstance(f, Reference):
1213 1192 raise TypeError("f must be callable, not %s"%type(f))
1214 1193 if not isinstance(args, (tuple, list)):
1215 1194 raise TypeError("args must be tuple or list, not %s"%type(args))
1216 1195 if not isinstance(kwargs, dict):
1217 1196 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1218 1197 if not isinstance(subheader, dict):
1219 1198 raise TypeError("subheader must be dict, not %s"%type(subheader))
1220 1199
1221 1200 bufs = util.pack_apply_message(f,args,kwargs)
1222 1201
1223 1202 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1224 1203 subheader=subheader, track=track)
1225 1204
1226 1205 msg_id = msg['header']['msg_id']
1227 1206 self.outstanding.add(msg_id)
1228 1207 if ident:
1229 1208 # possibly routed to a specific engine
1230 1209 if isinstance(ident, list):
1231 1210 ident = ident[-1]
1232 1211 if ident in self._engines.values():
1233 1212 # save for later, in case of engine death
1234 1213 self._outstanding_dict[ident].add(msg_id)
1235 1214 self.history.append(msg_id)
1236 1215 self.metadata[msg_id]['submitted'] = datetime.now()
1237 1216
1238 1217 return msg
1239 1218
1240 1219 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1241 1220 """construct and send an execute request via a socket.
1242 1221
1243 1222 """
1244 1223
1245 1224 if self._closed:
1246 1225 raise RuntimeError("Client cannot be used after its sockets have been closed")
1247 1226
1248 1227 # defaults:
1249 1228 subheader = subheader if subheader is not None else {}
1250 1229
1251 1230 # validate arguments
1252 1231 if not isinstance(code, basestring):
1253 1232 raise TypeError("code must be text, not %s" % type(code))
1254 1233 if not isinstance(subheader, dict):
1255 1234 raise TypeError("subheader must be dict, not %s" % type(subheader))
1256 1235
1257 1236 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1258 1237
1259 1238
1260 1239 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1261 1240 subheader=subheader)
1262 1241
1263 1242 msg_id = msg['header']['msg_id']
1264 1243 self.outstanding.add(msg_id)
1265 1244 if ident:
1266 1245 # possibly routed to a specific engine
1267 1246 if isinstance(ident, list):
1268 1247 ident = ident[-1]
1269 1248 if ident in self._engines.values():
1270 1249 # save for later, in case of engine death
1271 1250 self._outstanding_dict[ident].add(msg_id)
1272 1251 self.history.append(msg_id)
1273 1252 self.metadata[msg_id]['submitted'] = datetime.now()
1274 1253
1275 1254 return msg
1276 1255
1277 1256 #--------------------------------------------------------------------------
1278 1257 # construct a View object
1279 1258 #--------------------------------------------------------------------------
1280 1259
1281 1260 def load_balanced_view(self, targets=None):
1282 1261 """construct a DirectView object.
1283 1262
1284 1263 If no arguments are specified, create a LoadBalancedView
1285 1264 using all engines.
1286 1265
1287 1266 Parameters
1288 1267 ----------
1289 1268
1290 1269 targets: list,slice,int,etc. [default: use all engines]
1291 1270 The subset of engines across which to load-balance
1292 1271 """
1293 1272 if targets == 'all':
1294 1273 targets = None
1295 1274 if targets is not None:
1296 1275 targets = self._build_targets(targets)[1]
1297 1276 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1298 1277
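# Sketch, assuming a connected Client `rc`:
lview = rc.load_balanced_view()               # balance across all engines
lsub = rc.load_balanced_view(targets=[0, 1])  # restrict the scheduler's pool
amr = lview.map_async(lambda x: x ** 2, range(32))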
1299 1278 def direct_view(self, targets='all'):
1300 1279 """construct a DirectView object.
1301 1280
1302 1281 If no targets are specified, create a DirectView using all engines.
1303 1282
1304 1283 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1305 1284 evaluate the target engines at each execution, whereas rc[:] will connect to
1306 1285 all *current* engines, and that list will not change.
1307 1286
1308 1287 That is, 'all' will always use all engines, whereas rc[:] will not use
1309 1288 engines added after the DirectView is constructed.
1310 1289
1311 1290 Parameters
1312 1291 ----------
1313 1292
1314 1293 targets: list,slice,int,etc. [default: use all engines]
1315 1294 The engines to use for the View
1316 1295 """
1317 1296 single = isinstance(targets, int)
1318 1297 # allow 'all' to be lazily evaluated at each execution
1319 1298 if targets != 'all':
1320 1299 targets = self._build_targets(targets)[1]
1321 1300 if single:
1322 1301 targets = targets[0]
1323 1302 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1324 1303
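# Sketch of the lazy-'all' distinction described above, assuming a connected
# Client `rc`:
dv_all = rc.direct_view('all')  # re-evaluates the engine list at each execution
dv_now = rc[:]                  # pinned to the engines registered right now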
1325 1304 #--------------------------------------------------------------------------
1326 1305 # Query methods
1327 1306 #--------------------------------------------------------------------------
1328 1307
1329 1308 @spin_first
1330 1309 def get_result(self, indices_or_msg_ids=None, block=None):
1331 1310 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1332 1311
1333 1312 If the client already has the results, no request to the Hub will be made.
1334 1313
1335 1314 This is a convenient way to construct AsyncResult objects, which are wrappers
1336 1315 that include metadata about execution, and allow for awaiting results that
1337 1316 were not submitted by this Client.
1338 1317
1339 1318 It can also be a convenient way to retrieve the metadata associated with
1340 1319 blocking execution, since it always retrieves the metadata along with the result.
1341 1320
1342 1321 Examples
1343 1322 --------
1344 1323 ::
1345 1324
1346 1325 In [10]: ar = client.get_result(-1) # AsyncResult for the most recent task
1347 1326
1348 1327 Parameters
1349 1328 ----------
1350 1329
1351 1330 indices_or_msg_ids : integer history index, str msg_id, or list of either
1352 1331 The indices or msg_ids of the results to be retrieved
1353 1332
1354 1333 block : bool
1355 1334 Whether to wait for the result to be done
1356 1335
1357 1336 Returns
1358 1337 -------
1359 1338
1360 1339 AsyncResult
1361 1340 A single AsyncResult object will always be returned.
1362 1341
1363 1342 AsyncHubResult
1364 1343 A subclass of AsyncResult that retrieves results from the Hub
1365 1344
1366 1345 """
1367 1346 block = self.block if block is None else block
1368 1347 if indices_or_msg_ids is None:
1369 1348 indices_or_msg_ids = -1
1370 1349
1371 1350 if not isinstance(indices_or_msg_ids, (list,tuple)):
1372 1351 indices_or_msg_ids = [indices_or_msg_ids]
1373 1352
1374 1353 theids = []
1375 1354 for id in indices_or_msg_ids:
1376 1355 if isinstance(id, int):
1377 1356 id = self.history[id]
1378 1357 if not isinstance(id, basestring):
1379 1358 raise TypeError("indices must be str or int, not %r"%id)
1380 1359 theids.append(id)
1381 1360
1382 1361 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1383 1362 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1384 1363
1385 1364 if remote_ids:
1386 1365 ar = AsyncHubResult(self, msg_ids=theids)
1387 1366 else:
1388 1367 ar = AsyncResult(self, msg_ids=theids)
1389 1368
1390 1369 if block:
1391 1370 ar.wait()
1392 1371
1393 1372 return ar
1394 1373
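# Sketch: rebuild an AsyncResult for earlier work, assuming a connected Client
# `rc` with at least one entry in its history; -1 is the most recent msg_id:
ar = rc.get_result(-1, block=False)
ar.wait()
print(ar.get())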
1395 1374 @spin_first
1396 1375 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1397 1376 """Resubmit one or more tasks.
1398 1377
1399 1378 in-flight tasks may not be resubmitted.
1400 1379
1401 1380 Parameters
1402 1381 ----------
1403 1382
1404 1383 indices_or_msg_ids : integer history index, str msg_id, or list of either
1405 1384 The indices or msg_ids of the tasks to be resubmitted
1406 1385
1407 1386 block : bool
1408 1387 Whether to wait for the result to be done
1409 1388
1410 1389 Returns
1411 1390 -------
1412 1391
1413 1392 AsyncHubResult
1414 1393 A subclass of AsyncResult that retrieves results from the Hub
1415 1394
1416 1395 """
1417 1396 block = self.block if block is None else block
1418 1397 if indices_or_msg_ids is None:
1419 1398 indices_or_msg_ids = -1
1420 1399
1421 1400 if not isinstance(indices_or_msg_ids, (list,tuple)):
1422 1401 indices_or_msg_ids = [indices_or_msg_ids]
1423 1402
1424 1403 theids = []
1425 1404 for id in indices_or_msg_ids:
1426 1405 if isinstance(id, int):
1427 1406 id = self.history[id]
1428 1407 if not isinstance(id, basestring):
1429 1408 raise TypeError("indices must be str or int, not %r"%id)
1430 1409 theids.append(id)
1431 1410
1432 1411 content = dict(msg_ids = theids)
1433 1412
1434 1413 self.session.send(self._query_socket, 'resubmit_request', content)
1435 1414
1436 1415 zmq.select([self._query_socket], [], [])
1437 1416 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1438 1417 if self.debug:
1439 1418 pprint(msg)
1440 1419 content = msg['content']
1441 1420 if content['status'] != 'ok':
1442 1421 raise self._unwrap_exception(content)
1443 1422 mapping = content['resubmitted']
1444 1423 new_ids = [ mapping[msg_id] for msg_id in theids ]
1445 1424
1446 1425 ar = AsyncHubResult(self, msg_ids=new_ids)
1447 1426
1448 1427 if block:
1449 1428 ar.wait()
1450 1429
1451 1430 return ar
1452 1431
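# Sketch, assuming a connected Client `rc` whose most recent task has already
# completed (in-flight tasks may not be resubmitted):
ar = rc.resubmit(rc.history[-1], block=True)  # AsyncHubResult for the new task
print(ar.get())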
1453 1432 @spin_first
1454 1433 def result_status(self, msg_ids, status_only=True):
1455 1434 """Check on the status of the result(s) of the apply request with `msg_ids`.
1456 1435
1457 1436 If status_only is False, then the actual results will be retrieved, else
1458 1437 only the status of the results will be checked.
1459 1438
1460 1439 Parameters
1461 1440 ----------
1462 1441
1463 1442 msg_ids : list of msg_ids
1464 1443 if int:
1465 1444 Passed as index to self.history for convenience.
1466 1445 status_only : bool (default: True)
1467 1446 if False:
1468 1447 Retrieve the actual results of completed tasks.
1469 1448
1470 1449 Returns
1471 1450 -------
1472 1451
1473 1452 results : dict
1474 1453 There will always be the keys 'pending' and 'completed', which will
1475 1454 be lists of msg_ids that are incomplete or complete. If `status_only`
1476 1455 is False, then completed results will be keyed by their `msg_id`.
1477 1456 """
1478 1457 if not isinstance(msg_ids, (list,tuple)):
1479 1458 msg_ids = [msg_ids]
1480 1459
1481 1460 theids = []
1482 1461 for msg_id in msg_ids:
1483 1462 if isinstance(msg_id, int):
1484 1463 msg_id = self.history[msg_id]
1485 1464 if not isinstance(msg_id, basestring):
1486 1465 raise TypeError("msg_ids must be str, not %r"%msg_id)
1487 1466 theids.append(msg_id)
1488 1467
1489 1468 completed = []
1490 1469 local_results = {}
1491 1470
1492 1471 # comment this block out to temporarily disable local shortcut:
1493 1472 for msg_id in list(theids): # iterate over a copy; theids is mutated below
1494 1473 if msg_id in self.results:
1495 1474 completed.append(msg_id)
1496 1475 local_results[msg_id] = self.results[msg_id]
1497 1476 theids.remove(msg_id)
1498 1477
1499 1478 if theids: # some not locally cached
1500 1479 content = dict(msg_ids=theids, status_only=status_only)
1501 1480 msg = self.session.send(self._query_socket, "result_request", content=content)
1502 1481 zmq.select([self._query_socket], [], [])
1503 1482 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1504 1483 if self.debug:
1505 1484 pprint(msg)
1506 1485 content = msg['content']
1507 1486 if content['status'] != 'ok':
1508 1487 raise self._unwrap_exception(content)
1509 1488 buffers = msg['buffers']
1510 1489 else:
1511 1490 content = dict(completed=[],pending=[])
1512 1491
1513 1492 content['completed'].extend(completed)
1514 1493
1515 1494 if status_only:
1516 1495 return content
1517 1496
1518 1497 failures = []
1519 1498 # load cached results into result:
1520 1499 content.update(local_results)
1521 1500
1522 1501 # update cache with results:
1523 1502 for msg_id in sorted(theids):
1524 1503 if msg_id in content['completed']:
1525 1504 rec = content[msg_id]
1526 1505 parent = rec['header']
1527 1506 header = rec['result_header']
1528 1507 rcontent = rec['result_content']
1529 1508 iodict = rec['io']
1530 1509 if isinstance(rcontent, str):
1531 1510 rcontent = self.session.unpack(rcontent)
1532 1511
1533 1512 md = self.metadata[msg_id]
1534 1513 md.update(self._extract_metadata(header, parent, rcontent))
1535 1514 if rec.get('received'):
1536 1515 md['received'] = rec['received']
1537 1516 md.update(iodict)
1538 1517
1539 1518 if rcontent['status'] == 'ok':
1540 1519 if header['msg_type'] == 'apply_reply':
1541 1520 res,buffers = util.unserialize_object(buffers)
1542 1521 elif header['msg_type'] == 'execute_reply':
1543 1522 res = ExecuteReply(msg_id, rcontent, md)
1544 1523 else:
1545 1524 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1546 1525 else:
1547 1526 res = self._unwrap_exception(rcontent)
1548 1527 failures.append(res)
1549 1528
1550 1529 self.results[msg_id] = res
1551 1530 content[msg_id] = res
1552 1531
1553 1532 if len(theids) == 1 and failures:
1554 1533 raise failures[0]
1555 1534
1556 1535 error.collect_exceptions(failures, "result_status")
1557 1536 return content
1558 1537
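# Sketch: poll completion without fetching results, assuming a connected
# Client `rc` with some tasks outstanding:
status = rc.result_status(list(rc.outstanding), status_only=True)
print("%i done, %i pending" % (len(status['completed']), len(status['pending'])))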
1559 1538 @spin_first
1560 1539 def queue_status(self, targets='all', verbose=False):
1561 1540 """Fetch the status of engine queues.
1562 1541
1563 1542 Parameters
1564 1543 ----------
1565 1544
1566 1545 targets : int/str/list of ints/strs
1567 1546 the engines whose states are to be queried.
1568 1547 default : all
1569 1548 verbose : bool
1570 1549 Whether to return lengths only, or lists of ids for each element
1571 1550 """
1572 1551 if targets == 'all':
1573 1552 # allow 'all' to be evaluated on the engine
1574 1553 engine_ids = None
1575 1554 else:
1576 1555 engine_ids = self._build_targets(targets)[1]
1577 1556 content = dict(targets=engine_ids, verbose=verbose)
1578 1557 self.session.send(self._query_socket, "queue_request", content=content)
1579 1558 idents,msg = self.session.recv(self._query_socket, 0)
1580 1559 if self.debug:
1581 1560 pprint(msg)
1582 1561 content = msg['content']
1583 1562 status = content.pop('status')
1584 1563 if status != 'ok':
1585 1564 raise self._unwrap_exception(content)
1586 1565 content = rekey(content)
1587 1566 if isinstance(targets, int):
1588 1567 return content[targets]
1589 1568 else:
1590 1569 return content
1591 1570
1592 1571 @spin_first
1593 1572 def purge_results(self, jobs=[], targets=[]):
1594 1573 """Tell the Hub to forget results.
1595 1574
1596 1575 Individual results can be purged by msg_id, or the entire
1597 1576 history of specific targets can be purged.
1598 1577
1599 1578 Use `purge_results('all')` to scrub everything from the Hub's db.
1600 1579
1601 1580 Parameters
1602 1581 ----------
1603 1582
1604 1583 jobs : str or list of str or AsyncResult objects
1605 1584 the msg_ids whose results should be forgotten.
1606 1585 targets : int/str/list of ints/strs
1607 1586 The targets, by int_id, whose entire history is to be purged.
1608 1587
1609 1588 default : None
1610 1589 """
1611 1590 if not targets and not jobs:
1612 1591 raise ValueError("Must specify at least one of `targets` and `jobs`")
1613 1592 if targets:
1614 1593 targets = self._build_targets(targets)[1]
1615 1594
1616 1595 # construct msg_ids from jobs
1617 1596 if jobs == 'all':
1618 1597 msg_ids = jobs
1619 1598 else:
1620 1599 msg_ids = []
1621 1600 if isinstance(jobs, (basestring,AsyncResult)):
1622 1601 jobs = [jobs]
1623 1602 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1624 1603 if bad_ids:
1625 1604 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1626 1605 for j in jobs:
1627 1606 if isinstance(j, AsyncResult):
1628 1607 msg_ids.extend(j.msg_ids)
1629 1608 else:
1630 1609 msg_ids.append(j)
1631 1610
1632 1611 content = dict(engine_ids=targets, msg_ids=msg_ids)
1633 1612 self.session.send(self._query_socket, "purge_request", content=content)
1634 1613 idents, msg = self.session.recv(self._query_socket, 0)
1635 1614 if self.debug:
1636 1615 pprint(msg)
1637 1616 content = msg['content']
1638 1617 if content['status'] != 'ok':
1639 1618 raise self._unwrap_exception(content)
1640 1619
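# Sketch, assuming a connected Client `rc`:
rc.purge_results(jobs=rc.history[:100])  # forget specific msg_ids
rc.purge_results(targets=[0])            # forget everything engine 0 has run
rc.purge_results('all')                  # scrub the Hub's entire results db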
1641 1620 @spin_first
1642 1621 def hub_history(self):
1643 1622 """Get the Hub's history
1644 1623
1645 1624 Just like the Client, the Hub has a history, which is a list of msg_ids.
1646 1625 This will contain the history of all clients, and, depending on configuration,
1647 1626 may contain history across multiple cluster sessions.
1648 1627
1649 1628 Any msg_id returned here is a valid argument to `get_result`.
1650 1629
1651 1630 Returns
1652 1631 -------
1653 1632
1654 1633 msg_ids : list of strs
1655 1634 list of all msg_ids, ordered by task submission time.
1656 1635 """
1657 1636
1658 1637 self.session.send(self._query_socket, "history_request", content={})
1659 1638 idents, msg = self.session.recv(self._query_socket, 0)
1660 1639
1661 1640 if self.debug:
1662 1641 pprint(msg)
1663 1642 content = msg['content']
1664 1643 if content['status'] != 'ok':
1665 1644 raise self._unwrap_exception(content)
1666 1645 else:
1667 1646 return content['history']
1668 1647
1669 1648 @spin_first
1670 1649 def db_query(self, query, keys=None):
1671 1650 """Query the Hub's TaskRecord database
1672 1651
1673 1652 This will return a list of task record dicts that match `query`
1674 1653
1675 1654 Parameters
1676 1655 ----------
1677 1656
1678 1657 query : mongodb query dict
1679 1658 The search dict. See mongodb query docs for details.
1680 1659 keys : list of strs [optional]
1681 1660 The subset of keys to be returned. The default is to fetch everything but buffers.
1682 1661 'msg_id' will *always* be included.
1683 1662 """
1684 1663 if isinstance(keys, basestring):
1685 1664 keys = [keys]
1686 1665 content = dict(query=query, keys=keys)
1687 1666 self.session.send(self._query_socket, "db_request", content=content)
1688 1667 idents, msg = self.session.recv(self._query_socket, 0)
1689 1668 if self.debug:
1690 1669 pprint(msg)
1691 1670 content = msg['content']
1692 1671 if content['status'] != 'ok':
1693 1672 raise self._unwrap_exception(content)
1694 1673
1695 1674 records = content['records']
1696 1675
1697 1676 buffer_lens = content['buffer_lens']
1698 1677 result_buffer_lens = content['result_buffer_lens']
1699 1678 buffers = msg['buffers']
1700 1679 has_bufs = buffer_lens is not None
1701 1680 has_rbufs = result_buffer_lens is not None
1702 1681 for i,rec in enumerate(records):
1703 1682 # relink buffers
1704 1683 if has_bufs:
1705 1684 blen = buffer_lens[i]
1706 1685 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1707 1686 if has_rbufs:
1708 1687 blen = result_buffer_lens[i]
1709 1688 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1710 1689
1711 1690 return records
1712 1691
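# Sketch of a TaskRecord query, assuming a connected Client `rc`; the query
# dict uses MongoDB-style operators, a subset of which the SQLite backend
# understands as well:
from datetime import datetime, timedelta
cutoff = datetime.now() - timedelta(hours=1)
recs = rc.db_query({'submitted': {'$gt': cutoff}}, keys=['msg_id', 'completed'])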
1713 1692 __all__ = [ 'Client' ]
@@ -1,1338 +1,1342 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.py3compat import cast_bytes
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
54 54 def empty_record():
55 55 """Return an empty dict with all record keys."""
56 56 return {
57 57 'msg_id' : None,
58 58 'header' : None,
59 59 'content': None,
60 60 'buffers': None,
61 61 'submitted': None,
62 62 'client_uuid' : None,
63 63 'engine_uuid' : None,
64 64 'started': None,
65 65 'completed': None,
66 66 'resubmitted': None,
67 67 'received': None,
68 68 'result_header' : None,
69 69 'result_content' : None,
70 70 'result_buffers' : None,
71 71 'queue' : None,
72 72 'pyin' : None,
73 73 'pyout': None,
74 74 'pyerr': None,
75 75 'stdout': '',
76 76 'stderr': '',
77 77 }
78 78
79 79 def init_record(msg):
80 80 """Initialize a TaskRecord based on a request."""
81 81 header = msg['header']
82 82 return {
83 83 'msg_id' : header['msg_id'],
84 84 'header' : header,
85 85 'content': msg['content'],
86 86 'buffers': msg['buffers'],
87 87 'submitted': header['date'],
88 88 'client_uuid' : None,
89 89 'engine_uuid' : None,
90 90 'started': None,
91 91 'completed': None,
92 92 'resubmitted': None,
93 93 'received': None,
94 94 'result_header' : None,
95 95 'result_content' : None,
96 96 'result_buffers' : None,
97 97 'queue' : None,
98 98 'pyin' : None,
99 99 'pyout': None,
100 100 'pyerr': None,
101 101 'stdout': '',
102 102 'stderr': '',
103 103 }
104 104
105 105
106 106 class EngineConnector(HasTraits):
107 107 """A simple object for accessing the various zmq connections of an object.
108 108 Attributes are:
109 109 id (int): engine ID
110 110 uuid (str): uuid (unused?)
111 111 queue (str): identity of queue's DEALER socket
112 112 registration (str): identity of registration DEALER socket
113 113 heartbeat (str): identity of heartbeat DEALER socket
114 114 """
115 115 id=Integer(0)
116 116 queue=CBytes()
117 117 control=CBytes()
118 118 registration=CBytes()
119 119 heartbeat=CBytes()
120 120 pending=Set()
121 121
122 122 _db_shortcuts = {
123 123 'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
124 124 'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
125 125 'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
126 126 'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
127 127 }
128 128
129 129 class HubFactory(RegistrationFactory):
130 130 """The Configurable for setting up a Hub."""
131 131
132 132 # port-pairs for monitoredqueues:
133 133 hb = Tuple(Integer,Integer,config=True,
134 134 help="""DEALER/SUB Port pair for Engine heartbeats""")
135 135 def _hb_default(self):
136 136 return tuple(util.select_random_ports(2))
137 137
138 138 mux = Tuple(Integer,Integer,config=True,
139 139 help="""Engine/Client Port pair for MUX queue""")
140 140
141 141 def _mux_default(self):
142 142 return tuple(util.select_random_ports(2))
143 143
144 144 task = Tuple(Integer,Integer,config=True,
145 145 help="""Engine/Client Port pair for Task queue""")
146 146 def _task_default(self):
147 147 return tuple(util.select_random_ports(2))
148 148
149 149 control = Tuple(Integer,Integer,config=True,
150 150 help="""Engine/Client Port pair for Control queue""")
151 151
152 152 def _control_default(self):
153 153 return tuple(util.select_random_ports(2))
154 154
155 155 iopub = Tuple(Integer,Integer,config=True,
156 156 help="""Engine/Client Port pair for IOPub relay""")
157 157
158 158 def _iopub_default(self):
159 159 return tuple(util.select_random_ports(2))
160 160
161 161 # single ports:
162 162 mon_port = Integer(config=True,
163 163 help="""Monitor (SUB) port for queue traffic""")
164 164
165 165 def _mon_port_default(self):
166 166 return util.select_random_ports(1)[0]
167 167
168 168 notifier_port = Integer(config=True,
169 169 help="""PUB port for sending engine status notifications""")
170 170
171 171 def _notifier_port_default(self):
172 172 return util.select_random_ports(1)[0]
173 173
174 174 engine_ip = Unicode('127.0.0.1', config=True,
175 175 help="IP on which to listen for engine connections. [default: loopback]")
176 176 engine_transport = Unicode('tcp', config=True,
177 177 help="0MQ transport for engine connections. [default: tcp]")
178 178
179 179 client_ip = Unicode('127.0.0.1', config=True,
180 180 help="IP on which to listen for client connections. [default: loopback]")
181 181 client_transport = Unicode('tcp', config=True,
182 182 help="0MQ transport for client connections. [default : tcp]")
183 183
184 184 monitor_ip = Unicode('127.0.0.1', config=True,
185 185 help="IP on which to listen for monitor messages. [default: loopback]")
186 186 monitor_transport = Unicode('tcp', config=True,
187 187 help="0MQ transport for monitor messages. [default : tcp]")
188 188
189 189 monitor_url = Unicode('')
190 190
191 191 db_class = DottedObjectName('NoDB',
192 192 config=True, help="""The class to use for the DB backend
193 193
194 194 Options include:
195 195
196 196 SQLiteDB: SQLite
197 197 MongoDB : use MongoDB
198 198 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
199 199 NoDB : disable database altogether (default)
200 200
201 201 """)
202 202
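# Hypothetical ipcontroller_config.py sketch for the traits above; setting
# `ip` (inherited from RegistrationFactory) fans out to engine_ip, client_ip,
# and monitor_ip via _ip_changed below:
#   c = get_config()
#   c.HubFactory.ip = '0.0.0.0'
#   c.HubFactory.db_class = 'sqlitedb'  # shortcut resolved via _db_shortcuts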
203 203 # not configurable
204 204 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
205 205 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
206 206
207 207 def _ip_changed(self, name, old, new):
208 208 self.engine_ip = new
209 209 self.client_ip = new
210 210 self.monitor_ip = new
211 211 self._update_monitor_url()
212 212
213 213 def _update_monitor_url(self):
214 214 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
215 215
216 216 def _transport_changed(self, name, old, new):
217 217 self.engine_transport = new
218 218 self.client_transport = new
219 219 self.monitor_transport = new
220 220 self._update_monitor_url()
221 221
222 222 def __init__(self, **kwargs):
223 223 super(HubFactory, self).__init__(**kwargs)
224 224 self._update_monitor_url()
225 225
226 226
227 227 def construct(self):
228 228 self.init_hub()
229 229
230 230 def start(self):
231 231 self.heartmonitor.start()
232 232 self.log.info("Heartmonitor started")
233 233
234 234 def init_hub(self):
235 235 """construct"""
236 236 client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
237 237 engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"
238 238
239 239 ctx = self.context
240 240 loop = self.loop
241 241
242 try:
243 scheme = self.config.TaskScheduler.scheme_name
244 except AttributeError:
245 from .scheduler import TaskScheduler
246 scheme = TaskScheduler.scheme_name.get_default_value()
247
248 # build connection dicts
249 engine = self.engine_info = {
250 'registration' : engine_iface % self.regport,
251 'control' : engine_iface % self.control[1],
252 'mux' : engine_iface % self.mux[1],
253 'hb_ping' : engine_iface % self.hb[0],
254 'hb_pong' : engine_iface % self.hb[1],
255 'task' : engine_iface % self.task[1],
256 'iopub' : engine_iface % self.iopub[1],
257 }
258
259 client = self.client_info = {
260 'registration' : client_iface % self.regport,
261 'control' : client_iface % self.control[0],
262 'mux' : client_iface % self.mux[0],
263 'task' : client_iface % self.task[0],
264 'task_scheme' : scheme,
265 'iopub' : client_iface % self.iopub[0],
266 'notification' : client_iface % self.notifier_port,
267 }
268
269 self.log.debug("Hub engine addrs: %s", self.engine_info)
270 self.log.debug("Hub client addrs: %s", self.client_info)
271
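# Illustrative shape of the dicts just built, with hypothetical ports; every
# entry is now a single fully-specified url, and the task scheme travels
# separately as client_info['task_scheme'] instead of a (scheme, url) tuple:
#   engine_info['registration'] == 'tcp://127.0.0.1:55671'
#   client_info['task'] == 'tcp://127.0.0.1:55677'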
242 272 # Registrar socket
243 273 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
244 q.bind(client_iface % self.regport)
274 q.bind(client['registration'])
245 275 self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
246 276 if self.client_ip != self.engine_ip:
247 q.bind(engine_iface % self.regport)
277 q.bind(engine['registration'])
248 278 self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)
249 279
250 280 ### Engine connections ###
251 281
252 282 # heartbeat
253 283 hpub = ctx.socket(zmq.PUB)
254 hpub.bind(engine_iface % self.hb[0])
284 hpub.bind(engine['hb_ping'])
255 285 hrep = ctx.socket(zmq.ROUTER)
256 hrep.bind(engine_iface % self.hb[1])
286 hrep.bind(engine['hb_pong'])
257 287 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
258 288 pingstream=ZMQStream(hpub,loop),
259 289 pongstream=ZMQStream(hrep,loop)
260 290 )
261 291
262 292 ### Client connections ###
293
263 294 # Notifier socket
264 295 n = ZMQStream(ctx.socket(zmq.PUB), loop)
265 n.bind(client_iface%self.notifier_port)
296 n.bind(client['notification'])
266 297
267 298 ### build and launch the queues ###
268 299
269 300 # monitor socket
270 301 sub = ctx.socket(zmq.SUB)
271 302 sub.setsockopt(zmq.SUBSCRIBE, b"")
272 303 sub.bind(self.monitor_url)
273 304 sub.bind('inproc://monitor')
274 305 sub = ZMQStream(sub, loop)
275 306
276 307 # connect the db
277 308 db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
278 309 self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
279 310 self.db = import_item(str(db_class))(session=self.session.session,
280 311 config=self.config, log=self.log)
281 312 time.sleep(.25)
282 try:
283 scheme = self.config.TaskScheduler.scheme_name
284 except AttributeError:
285 from .scheduler import TaskScheduler
286 scheme = TaskScheduler.scheme_name.get_default_value()
287 # build connection dicts
288 self.engine_info = {
289 'control' : engine_iface%self.control[1],
290 'mux': engine_iface%self.mux[1],
291 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
292 'task' : engine_iface%self.task[1],
293 'iopub' : engine_iface%self.iopub[1],
294 # 'monitor' : engine_iface%self.mon_port,
295 }
296
297 self.client_info = {
298 'control' : client_iface%self.control[0],
299 'mux': client_iface%self.mux[0],
300 'task' : (scheme, client_iface%self.task[0]),
301 'iopub' : client_iface%self.iopub[0],
302 'notification': client_iface%self.notifier_port
303 }
304 self.log.debug("Hub engine addrs: %s", self.engine_info)
305 self.log.debug("Hub client addrs: %s", self.client_info)
306 313
307 314 # resubmit stream
308 315 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
309 url = util.disambiguate_url(self.client_info['task'][-1])
310 r.setsockopt(zmq.IDENTITY, self.session.bsession)
316 url = util.disambiguate_url(self.client_info['task'])
311 317 r.connect(url)
312 318
313 319 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
314 320 query=q, notifier=n, resubmit=r, db=self.db,
315 321 engine_info=self.engine_info, client_info=self.client_info,
316 322 log=self.log)
317 323
318 324
319 325 class Hub(SessionFactory):
320 326 """The IPython Controller Hub with 0MQ connections
321 327
322 328 Parameters
323 329 ==========
324 330 loop: zmq IOLoop instance
325 331 session: Session object
326 332 <removed> context: zmq context for creating new connections (?)
327 333 queue: ZMQStream for monitoring the command queue (SUB)
328 334 query: ZMQStream for engine registration and client queries requests (ROUTER)
329 335 heartbeat: HeartMonitor object checking the pulse of the engines
330 336 notifier: ZMQStream for broadcasting engine registration changes (PUB)
331 337 db: connection to db for out of memory logging of commands
332 338 NotImplemented
333 339 engine_info: dict of zmq connection information for engines to connect
334 340 to the queues.
335 341 client_info: dict of zmq connection information for clients to connect
336 342 to the queues.
337 343 """
338 344 # internal data structures:
339 345 ids=Set() # engine IDs
340 346 keytable=Dict()
341 347 by_ident=Dict()
342 348 engines=Dict()
343 349 clients=Dict()
344 350 hearts=Dict()
345 351 pending=Set()
346 352 queues=Dict() # pending msg_ids keyed by engine_id
347 353 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
348 354 completed=Dict() # completed msg_ids keyed by engine_id
349 355 all_completed=Set() # set of all completed msg_ids
350 356 dead_engines=Set() # uuids of engines that have died
351 357 unassigned=Set() # set of task msg_ids not yet assigned a destination
352 358 incoming_registrations=Dict()
353 359 registration_timeout=Integer()
354 360 _idcounter=Integer(0)
355 361
356 362 # objects from constructor:
357 363 query=Instance(ZMQStream)
358 364 monitor=Instance(ZMQStream)
359 365 notifier=Instance(ZMQStream)
360 366 resubmit=Instance(ZMQStream)
361 367 heartmonitor=Instance(HeartMonitor)
362 368 db=Instance(object)
363 369 client_info=Dict()
364 370 engine_info=Dict()
365 371
366 372
367 373 def __init__(self, **kwargs):
368 374 """
369 375 # universal:
370 376 loop: IOLoop for creating future connections
371 377 session: streamsession for sending serialized data
372 378 # engine:
373 379 queue: ZMQStream for monitoring queue messages
374 380 query: ZMQStream for engine+client registration and client requests
375 381 heartbeat: HeartMonitor object for tracking engines
376 382 # extra:
377 383 db: ZMQStream for db connection (NotImplemented)
378 384 engine_info: zmq address/protocol dict for engine connections
379 385 client_info: zmq address/protocol dict for client connections
380 386 """
381 387
382 388 super(Hub, self).__init__(**kwargs)
383 389 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
384 390
385 391 # validate connection dicts:
386 392 for k,v in self.client_info.iteritems():
387 if k == 'task':
388 util.validate_url_container(v[1])
393 if k == 'task_scheme':
394 continue
389 395 else:
390 396 util.validate_url_container(v)
391 397 # util.validate_url_container(self.client_info)
392 398 util.validate_url_container(self.engine_info)
393 399
394 400 # register our callbacks
395 401 self.query.on_recv(self.dispatch_query)
396 402 self.monitor.on_recv(self.dispatch_monitor_traffic)
397 403
398 404 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
399 405 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
400 406
401 407 self.monitor_handlers = {b'in' : self.save_queue_request,
402 408 b'out': self.save_queue_result,
403 409 b'intask': self.save_task_request,
404 410 b'outtask': self.save_task_result,
405 411 b'tracktask': self.save_task_destination,
406 412 b'incontrol': _passer,
407 413 b'outcontrol': _passer,
408 414 b'iopub': self.save_iopub_message,
409 415 }
410 416
411 417 self.query_handlers = {'queue_request': self.queue_status,
412 418 'result_request': self.get_results,
413 419 'history_request': self.get_history,
414 420 'db_request': self.db_query,
415 421 'purge_request': self.purge_results,
416 422 'load_request': self.check_load,
417 423 'resubmit_request': self.resubmit_task,
418 424 'shutdown_request': self.shutdown_request,
419 425 'registration_request' : self.register_engine,
420 426 'unregistration_request' : self.unregister_engine,
421 427 'connection_request': self.connection_request,
422 428 }
423 429
424 430 # ignore resubmit replies
425 431 self.resubmit.on_recv(lambda msg: None, copy=False)
426 432
427 433 self.log.info("hub::created hub")
428 434
429 435 @property
430 436 def _next_id(self):
431 437 """gemerate a new ID.
432 438
433 439 No longer reuse old ids, just count from 0."""
434 440 newid = self._idcounter
435 441 self._idcounter += 1
436 442 return newid
437 443 # newid = 0
438 444 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
439 445 # # print newid, self.ids, self.incoming_registrations
440 446 # while newid in self.ids or newid in incoming:
441 447 # newid += 1
442 448 # return newid
443 449
444 450 #-----------------------------------------------------------------------------
445 451 # message validation
446 452 #-----------------------------------------------------------------------------
447 453
448 454 def _validate_targets(self, targets):
449 455 """turn any valid targets argument into a list of integer ids"""
450 456 if targets is None:
451 457 # default to all
452 458 return self.ids
453 459
454 460 if isinstance(targets, (int,str,unicode)):
455 461 # only one target specified
456 462 targets = [targets]
457 463 _targets = []
458 464 for t in targets:
459 465 # map raw identities to ids
460 466 if isinstance(t, (str,unicode)):
461 467 t = self.by_ident.get(cast_bytes(t), t)
462 468 _targets.append(t)
463 469 targets = _targets
464 470 bad_targets = [ t for t in targets if t not in self.ids ]
465 471 if bad_targets:
466 472 raise IndexError("No Such Engine: %r" % bad_targets)
467 473 if not targets:
468 474 raise IndexError("No Engines Registered")
469 475 return targets
470 476
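For reference, a standalone sketch of the normalization _validate_targets performs, with a toy by_ident table (engine names and ids here are invented, and the unicode handling is simplified to plain str):

by_ident = {b'engine-uuid-0': 0, b'engine-uuid-1': 1}
ids = set([0, 1])

def validate_targets(targets):
    if targets is None:
        return sorted(ids)  # default to all engines
    if isinstance(targets, (int, str)):
        targets = [targets]
    # map raw identities to integer ids, leaving ints untouched
    resolved = [by_ident.get(t.encode('ascii'), t) if isinstance(t, str) else t
                for t in targets]
    bad = [t for t in resolved if t not in ids]
    if bad:
        raise IndexError("No Such Engine: %r" % bad)
    return resolved

print(validate_targets('engine-uuid-1'))  # -> [1]
print(validate_targets(None))             # -> [0, 1]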
471 477 #-----------------------------------------------------------------------------
472 478 # dispatch methods (1 per stream)
473 479 #-----------------------------------------------------------------------------
474 480
475 481
476 482 @util.log_errors
477 483 def dispatch_monitor_traffic(self, msg):
478 484 """all ME and Task queue messages come through here, as well as
479 485 IOPub traffic."""
480 486 self.log.debug("monitor traffic: %r", msg[0])
481 487 switch = msg[0]
482 488 try:
483 489 idents, msg = self.session.feed_identities(msg[1:])
484 490 except ValueError:
485 491 idents=[]
486 492 if not idents:
487 493 self.log.error("Monitor message without topic: %r", msg)
488 494 return
489 495 handler = self.monitor_handlers.get(switch, None)
490 496 if handler is not None:
491 497 handler(idents, msg)
492 498 else:
493 499 self.log.error("Unrecognized monitor topic: %r", switch)
494 500
495 501
496 502 @util.log_errors
497 503 def dispatch_query(self, msg):
498 504 """Route registration requests and queries from clients."""
499 505 try:
500 506 idents, msg = self.session.feed_identities(msg)
501 507 except ValueError:
502 508 idents = []
503 509 if not idents:
504 510 self.log.error("Bad Query Message: %r", msg)
505 511 return
506 512 client_id = idents[0]
507 513 try:
508 514 msg = self.session.unserialize(msg, content=True)
509 515 except Exception:
510 516 content = error.wrap_exception()
511 517 self.log.error("Bad Query Message: %r", msg, exc_info=True)
512 518 self.session.send(self.query, "hub_error", ident=client_id,
513 519 content=content)
514 520 return
515 521 # print client_id, header, parent, content
516 522 #switch on message type:
517 523 msg_type = msg['header']['msg_type']
518 524 self.log.info("client::client %r requested %r", client_id, msg_type)
519 525 handler = self.query_handlers.get(msg_type, None)
520 526 try:
521 527 assert handler is not None, "Bad Message Type: %r" % msg_type
522 528 except:
523 529 content = error.wrap_exception()
524 530 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
525 531 self.session.send(self.query, "hub_error", ident=client_id,
526 532 content=content)
527 533 return
528 534
529 535 else:
530 536 handler(idents, msg)
531 537
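The two dispatch callbacks above are thin wrappers around the handler tables built in __init__: the first frame (monitor) or the msg_type (query) selects a bound method. A minimal, self-contained sketch of that dispatch-table pattern, with placeholder lambdas standing in for the real Hub methods:

query_handlers = {
    'queue_request': lambda idents, msg: 'queue_status',
    'shutdown_request': lambda idents, msg: 'shutdown',
}

def dispatch(msg_type, idents, msg):
    # mirror the lookup-then-fail pattern used by dispatch_query
    handler = query_handlers.get(msg_type, None)
    if handler is None:
        raise KeyError("Bad Message Type: %r" % msg_type)
    return handler(idents, msg)

print(dispatch('queue_request', [b'client'], {}))  # -> queue_status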
532 538 def dispatch_db(self, msg):
533 539 """"""
534 540 raise NotImplementedError
535 541
536 542 #---------------------------------------------------------------------------
537 543 # handler methods (1 per event)
538 544 #---------------------------------------------------------------------------
539 545
540 546 #----------------------- Heartbeat --------------------------------------
541 547
542 548 def handle_new_heart(self, heart):
543 549 """handler to attach to heartbeater.
544 550 Called when a new heart starts to beat.
545 551 Triggers completion of registration."""
546 552 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
547 553 if heart not in self.incoming_registrations:
548 554 self.log.info("heartbeat::ignoring new heart: %r", heart)
549 555 else:
550 556 self.finish_registration(heart)
551 557
552 558
553 559 def handle_heart_failure(self, heart):
554 560 """handler to attach to heartbeater.
555 561 called when a previously registered heart fails to respond to beat request.
556 562 triggers unregistration"""
557 563 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
558 564 eid = self.hearts.get(heart, None)
559 565 if eid is None or self.keytable[eid] in self.dead_engines:
560 566 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
561 567 else:
562 568 queue = self.engines[eid].queue
563 569 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
564 570
565 571 #----------------------- MUX Queue Traffic ------------------------------
566 572
567 573 def save_queue_request(self, idents, msg):
568 574 if len(idents) < 2:
569 575 self.log.error("invalid identity prefix: %r", idents)
570 576 return
571 577 queue_id, client_id = idents[:2]
572 578 try:
573 579 msg = self.session.unserialize(msg)
574 580 except Exception:
575 581 self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
576 582 return
577 583
578 584 eid = self.by_ident.get(queue_id, None)
579 585 if eid is None:
580 586 self.log.error("queue::target %r not registered", queue_id)
581 587 self.log.debug("queue:: valid are: %r", self.by_ident.keys())
582 588 return
583 589 record = init_record(msg)
584 590 msg_id = record['msg_id']
585 591 self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
586 592 # Unicode in records
587 593 record['engine_uuid'] = queue_id.decode('ascii')
588 594 record['client_uuid'] = msg['header']['session']
589 595 record['queue'] = 'mux'
590 596
591 597 try:
592 598 # it's possible iopub arrived first:
593 599 existing = self.db.get_record(msg_id)
594 600 for key,evalue in existing.iteritems():
595 601 rvalue = record.get(key, None)
596 602 if evalue and rvalue and evalue != rvalue:
597 603 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
598 604 elif evalue and not rvalue:
599 605 record[key] = evalue
600 606 try:
601 607 self.db.update_record(msg_id, record)
602 608 except Exception:
603 609 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
604 610 except KeyError:
605 611 try:
606 612 self.db.add_record(msg_id, record)
607 613 except Exception:
608 614 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
609 615
610 616
611 617 self.pending.add(msg_id)
612 618 self.queues[eid].append(msg_id)
613 619
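The merge loop above (repeated for task requests below) prefers values already in the DB only where the fresh record is empty, and logs a conflict when both sides disagree. Extracted as a standalone sketch:

def merge_record(record, existing):
    for key, evalue in existing.items():
        rvalue = record.get(key, None)
        if evalue and rvalue and evalue != rvalue:
            print("conflicting initial state for %r: %r != %r" % (key, rvalue, evalue))
        elif evalue and not rvalue:
            # the existing value wins only where the new record has nothing
            record[key] = evalue
    return record

rec = merge_record({'msg_id': 'abc', 'stdout': ''},
                   {'msg_id': 'abc', 'stdout': 'early iopub output'})
print(rec['stdout'])  # -> early iopub output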
614 620 def save_queue_result(self, idents, msg):
615 621 if len(idents) < 2:
616 622 self.log.error("invalid identity prefix: %r", idents)
617 623 return
618 624
619 625 client_id, queue_id = idents[:2]
620 626 try:
621 627 msg = self.session.unserialize(msg)
622 628 except Exception:
623 629 self.log.error("queue::engine %r sent invalid message to %r: %r",
624 630 queue_id, client_id, msg, exc_info=True)
625 631 return
626 632
627 633 eid = self.by_ident.get(queue_id, None)
628 634 if eid is None:
629 635 self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
630 636 return
631 637
632 638 parent = msg['parent_header']
633 639 if not parent:
634 640 return
635 641 msg_id = parent['msg_id']
636 642 if msg_id in self.pending:
637 643 self.pending.remove(msg_id)
638 644 self.all_completed.add(msg_id)
639 645 self.queues[eid].remove(msg_id)
640 646 self.completed[eid].append(msg_id)
641 647 self.log.info("queue::request %r completed on %s", msg_id, eid)
642 648 elif msg_id not in self.all_completed:
643 649 # it could be a result from a dead engine that died before delivering the
644 650 # result
645 651 self.log.warn("queue:: unknown msg finished %r", msg_id)
646 652 return
647 653 # update record anyway, because the unregistration could have been premature
648 654 rheader = msg['header']
649 655 completed = rheader['date']
650 656 started = rheader.get('started', None)
651 657 result = {
652 658 'result_header' : rheader,
653 659 'result_content': msg['content'],
654 660 'received': datetime.now(),
655 661 'started' : started,
656 662 'completed' : completed
657 663 }
658 664
659 665 result['result_buffers'] = msg['buffers']
660 666 try:
661 667 self.db.update_record(msg_id, result)
662 668 except Exception:
663 669 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
664 670
665 671
666 672 #--------------------- Task Queue Traffic ------------------------------
667 673
668 674 def save_task_request(self, idents, msg):
669 675 """Save the submission of a task."""
670 676 client_id = idents[0]
671 677
672 678 try:
673 679 msg = self.session.unserialize(msg)
674 680 except Exception:
675 681 self.log.error("task::client %r sent invalid task message: %r",
676 682 client_id, msg, exc_info=True)
677 683 return
678 684 record = init_record(msg)
679 685
680 686 record['client_uuid'] = msg['header']['session']
681 687 record['queue'] = 'task'
682 688 header = msg['header']
683 689 msg_id = header['msg_id']
684 690 self.pending.add(msg_id)
685 691 self.unassigned.add(msg_id)
686 692 try:
687 693 # it's possible iopub arrived first:
688 694 existing = self.db.get_record(msg_id)
689 695 if existing['resubmitted']:
690 696 for key in ('submitted', 'client_uuid', 'buffers'):
691 697 # don't clobber these keys on resubmit
692 698 # submitted and client_uuid should be different
693 699 # and buffers might be big, and shouldn't have changed
694 700 record.pop(key)
695 701 # still check content and header, which should not change
696 702 # and are not as expensive to compare as buffers
697 703
698 704 for key,evalue in existing.iteritems():
699 705 if key.endswith('buffers'):
700 706 # don't compare buffers
701 707 continue
702 708 rvalue = record.get(key, None)
703 709 if evalue and rvalue and evalue != rvalue:
704 710 self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
705 711 elif evalue and not rvalue:
706 712 record[key] = evalue
707 713 try:
708 714 self.db.update_record(msg_id, record)
709 715 except Exception:
710 716 self.log.error("DB Error updating record %r", msg_id, exc_info=True)
711 717 except KeyError:
712 718 try:
713 719 self.db.add_record(msg_id, record)
714 720 except Exception:
715 721 self.log.error("DB Error adding record %r", msg_id, exc_info=True)
716 722 except Exception:
717 723 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
718 724
719 725 def save_task_result(self, idents, msg):
720 726 """save the result of a completed task."""
721 727 client_id = idents[0]
722 728 try:
723 729 msg = self.session.unserialize(msg)
724 730 except Exception:
725 731 self.log.error("task::invalid task result message send to %r: %r",
726 732 client_id, msg, exc_info=True)
727 733 return
728 734
729 735 parent = msg['parent_header']
730 736 if not parent:
731 737 # print msg
732 738 self.log.warn("Task %r had no parent!", msg)
733 739 return
734 740 msg_id = parent['msg_id']
735 741 if msg_id in self.unassigned:
736 742 self.unassigned.remove(msg_id)
737 743
738 744 header = msg['header']
739 745 engine_uuid = header.get('engine', u'')
740 746 eid = self.by_ident.get(cast_bytes(engine_uuid), None)
741 747
742 748 status = header.get('status', None)
743 749
744 750 if msg_id in self.pending:
745 751 self.log.info("task::task %r finished on %s", msg_id, eid)
746 752 self.pending.remove(msg_id)
747 753 self.all_completed.add(msg_id)
748 754 if eid is not None:
749 755 if status != 'aborted':
750 756 self.completed[eid].append(msg_id)
751 757 if msg_id in self.tasks[eid]:
752 758 self.tasks[eid].remove(msg_id)
753 759 completed = header['date']
754 760 started = header.get('started', None)
755 761 result = {
756 762 'result_header' : header,
757 763 'result_content': msg['content'],
758 764 'started' : started,
759 765 'completed' : completed,
760 766 'received' : datetime.now(),
761 767 'engine_uuid': engine_uuid,
762 768 }
763 769
764 770 result['result_buffers'] = msg['buffers']
765 771 try:
766 772 self.db.update_record(msg_id, result)
767 773 except Exception:
768 774 self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
769 775
770 776 else:
771 777 self.log.debug("task::unknown task %r finished", msg_id)
772 778
773 779 def save_task_destination(self, idents, msg):
774 780 try:
775 781 msg = self.session.unserialize(msg, content=True)
776 782 except Exception:
777 783 self.log.error("task::invalid task tracking message", exc_info=True)
778 784 return
779 785 content = msg['content']
780 786 # print (content)
781 787 msg_id = content['msg_id']
782 788 engine_uuid = content['engine_id']
783 789 eid = self.by_ident[cast_bytes(engine_uuid)]
784 790
785 791 self.log.info("task::task %r arrived on %r", msg_id, eid)
786 792 if msg_id in self.unassigned:
787 793 self.unassigned.remove(msg_id)
788 794 # else:
789 795 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
790 796
791 797 self.tasks[eid].append(msg_id)
792 798 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
793 799 try:
794 800 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
795 801 except Exception:
796 802 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
797 803
798 804
799 805 def mia_task_request(self, idents, msg):
800 806 raise NotImplementedError
801 807 client_id = idents[0]
802 808 # content = dict(mia=self.mia,status='ok')
803 809 # self.session.send('mia_reply', content=content, idents=client_id)
804 810
805 811
806 812 #--------------------- IOPub Traffic ------------------------------
807 813
808 814 def save_iopub_message(self, topics, msg):
809 815 """save an iopub message into the db"""
810 816 # print (topics)
811 817 try:
812 818 msg = self.session.unserialize(msg, content=True)
813 819 except Exception:
814 820 self.log.error("iopub::invalid IOPub message", exc_info=True)
815 821 return
816 822
817 823 parent = msg['parent_header']
818 824 if not parent:
819 825 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
820 826 return
821 827 msg_id = parent['msg_id']
822 828 msg_type = msg['header']['msg_type']
823 829 content = msg['content']
824 830
825 831 # ensure msg_id is in db
826 832 try:
827 833 rec = self.db.get_record(msg_id)
828 834 except KeyError:
829 835 rec = empty_record()
830 836 rec['msg_id'] = msg_id
831 837 self.db.add_record(msg_id, rec)
832 838 # stream
833 839 d = {}
834 840 if msg_type == 'stream':
835 841 name = content['name']
836 842 s = rec[name] or ''
837 843 d[name] = s + content['data']
838 844
839 845 elif msg_type == 'pyerr':
840 846 d['pyerr'] = content
841 847 elif msg_type == 'pyin':
842 848 d['pyin'] = content['code']
843 849 elif msg_type in ('display_data', 'pyout'):
844 850 d[msg_type] = content
845 851 elif msg_type == 'status':
846 852 pass
847 853 else:
848 854 self.log.warn("unhandled iopub msg_type: %r", msg_type)
849 855
850 856 if not d:
851 857 return
852 858
853 859 try:
854 860 self.db.update_record(msg_id, d)
855 861 except Exception:
856 862 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
857 863
858 864
859 865
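In the 'stream' branch above, output accumulates by plain string concatenation onto the existing record field. In isolation:

rec = {'stdout': 'hello '}                       # record already in the db
content = {'name': 'stdout', 'data': 'world\n'}  # incoming stream message
name = content['name']
d = {name: (rec.get(name) or '') + content['data']}
print(d['stdout'])  # -> hello world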
860 866 #-------------------------------------------------------------------------
861 867 # Registration requests
862 868 #-------------------------------------------------------------------------
863 869
864 870 def connection_request(self, client_id, msg):
865 871 """Reply with connection addresses for clients."""
866 872 self.log.info("client::client %r connected", client_id)
867 873 content = dict(status='ok')
868 content.update(self.client_info)
869 874 jsonable = {}
870 875 for k,v in self.keytable.iteritems():
871 876 if v not in self.dead_engines:
872 877 jsonable[str(k)] = v.decode('ascii')
873 878 content['engines'] = jsonable
874 879 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
875 880
876 881 def register_engine(self, reg, msg):
877 882 """Register a new engine."""
878 883 content = msg['content']
879 884 try:
880 885 queue = cast_bytes(content['queue'])
881 886 except KeyError:
882 887 self.log.error("registration::queue not specified", exc_info=True)
883 888 return
884 889 heart = content.get('heartbeat', None)
885 890 if heart:
886 891 heart = cast_bytes(heart)
887 892 """register a new engine, and create the socket(s) necessary"""
888 893 eid = self._next_id
889 894 # print (eid, queue, reg, heart)
890 895
891 896 self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)
892 897
893 898 content = dict(id=eid,status='ok')
894 content.update(self.engine_info)
895 899 # check if requesting available IDs:
896 900 if queue in self.by_ident:
897 901 try:
898 902 raise KeyError("queue_id %r in use" % queue)
899 903 except:
900 904 content = error.wrap_exception()
901 905 self.log.error("queue_id %r in use", queue, exc_info=True)
902 906 elif heart in self.hearts: # need to check unique hearts?
903 907 try:
904 908 raise KeyError("heart_id %r in use" % heart)
905 909 except:
906 910 self.log.error("heart_id %r in use", heart, exc_info=True)
907 911 content = error.wrap_exception()
908 912 else:
909 913 for h, pack in self.incoming_registrations.iteritems():
910 914 if heart == h:
911 915 try:
912 916 raise KeyError("heart_id %r in use" % heart)
913 917 except:
914 918 self.log.error("heart_id %r in use", heart, exc_info=True)
915 919 content = error.wrap_exception()
916 920 break
917 921 elif queue == pack[1]:
918 922 try:
919 923 raise KeyError("queue_id %r in use" % queue)
920 924 except:
921 925 self.log.error("queue_id %r in use", queue, exc_info=True)
922 926 content = error.wrap_exception()
923 927 break
924 928
925 929 msg = self.session.send(self.query, "registration_reply",
926 930 content=content,
927 931 ident=reg)
928 932
929 933 if content['status'] == 'ok':
930 934 if heart in self.heartmonitor.hearts:
931 935 # already beating
932 936 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
933 937 self.finish_registration(heart)
934 938 else:
935 939 purge = lambda : self._purge_stalled_registration(heart)
936 940 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
937 941 dc.start()
938 942 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
939 943 else:
940 944 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
941 945 return eid
942 946
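Per the engine code later in this changeset, a registration_request carries queue, heartbeat, and control identities, and the Hub rejects a duplicate queue or heart id. A toy version of those uniqueness checks (identities invented for illustration):

hearts = {b'heart-1': 0}    # heartbeat ident -> engine id
by_ident = {b'queue-1': 0}  # queue ident -> engine id
content = dict(queue=b'queue-2', heartbeat=b'heart-2', control=b'queue-2')

if content['queue'] in by_ident:
    status = 'error: queue_id in use'
elif content['heartbeat'] in hearts:
    status = 'error: heart_id in use'
else:
    status = 'ok'
print(status)  # -> ok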
943 947 def unregister_engine(self, ident, msg):
944 948 """Unregister an engine that explicitly requested to leave."""
945 949 try:
946 950 eid = msg['content']['id']
947 951 except:
948 952 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
949 953 return
950 954 self.log.info("registration::unregister_engine(%r)", eid)
951 955 # print (eid)
952 956 uuid = self.keytable[eid]
953 957 content=dict(id=eid, queue=uuid.decode('ascii'))
954 958 self.dead_engines.add(uuid)
955 959 # self.ids.remove(eid)
956 960 # uuid = self.keytable.pop(eid)
957 961 #
958 962 # ec = self.engines.pop(eid)
959 963 # self.hearts.pop(ec.heartbeat)
960 964 # self.by_ident.pop(ec.queue)
961 965 # self.completed.pop(eid)
962 966 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
963 967 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
964 968 dc.start()
965 969 ############## TODO: HANDLE IT ################
966 970
967 971 if self.notifier:
968 972 self.session.send(self.notifier, "unregistration_notification", content=content)
969 973
970 974 def _handle_stranded_msgs(self, eid, uuid):
971 975 """Handle messages known to be on an engine when the engine unregisters.
972 976
973 977 It is possible that this will fire prematurely - that is, an engine will
974 978 go down after completing a result, and the client will be notified
975 979 that the result failed and later receive the actual result.
976 980 """
977 981
978 982 outstanding = self.queues[eid]
979 983
980 984 for msg_id in outstanding:
981 985 self.pending.remove(msg_id)
982 986 self.all_completed.add(msg_id)
983 987 try:
984 988 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
985 989 except:
986 990 content = error.wrap_exception()
987 991 # build a fake header:
988 992 header = {}
989 993 header['engine'] = uuid
990 994 header['date'] = datetime.now()
991 995 rec = dict(result_content=content, result_header=header, result_buffers=[])
992 996 rec['completed'] = header['date']
993 997 rec['engine_uuid'] = uuid
994 998 try:
995 999 self.db.update_record(msg_id, rec)
996 1000 except Exception:
997 1001 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
998 1002
999 1003
1000 1004 def finish_registration(self, heart):
1001 1005 """Second half of engine registration, called after our HeartMonitor
1002 1006 has received a beat from the Engine's Heart."""
1003 1007 try:
1004 1008 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
1005 1009 except KeyError:
1006 1010 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
1007 1011 return
1008 1012 self.log.info("registration::finished registering engine %i:%r", eid, queue)
1009 1013 if purge is not None:
1010 1014 purge.stop()
1011 1015 control = queue
1012 1016 self.ids.add(eid)
1013 1017 self.keytable[eid] = queue
1014 1018 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
1015 1019 control=control, heartbeat=heart)
1016 1020 self.by_ident[queue] = eid
1017 1021 self.queues[eid] = list()
1018 1022 self.tasks[eid] = list()
1019 1023 self.completed[eid] = list()
1020 1024 self.hearts[heart] = eid
1021 1025 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
1022 1026 if self.notifier:
1023 1027 self.session.send(self.notifier, "registration_notification", content=content)
1024 1028 self.log.info("engine::Engine Connected: %i", eid)
1025 1029
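Registration is thus two-phase: register_engine parks the engine in incoming_registrations, and finish_registration promotes it once the HeartMonitor sees its first beat. A minimal sketch of the promotion step, with plain dicts standing in for the traitlets containers:

incoming = {b'heart-2': (5, b'queue-2', b'reg-ident', None)}
hearts, engines = {}, {}

def on_new_heart(heart):
    """move an engine from pending to active once its heart beats"""
    if heart not in incoming:
        return None  # unknown heart: ignore, as handle_new_heart does
    eid, queue, reg, purge_dc = incoming.pop(heart)
    hearts[heart] = eid
    engines[eid] = queue
    return eid

print(on_new_heart(b'heart-2'))  # -> 5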
1026 1030 def _purge_stalled_registration(self, heart):
1027 1031 if heart in self.incoming_registrations:
1028 1032 eid = self.incoming_registrations.pop(heart)[0]
1029 1033 self.log.info("registration::purging stalled registration: %i", eid)
1030 1034 else:
1031 1035 pass
1032 1036
1033 1037 #-------------------------------------------------------------------------
1034 1038 # Client Requests
1035 1039 #-------------------------------------------------------------------------
1036 1040
1037 1041 def shutdown_request(self, client_id, msg):
1038 1042 """handle shutdown request."""
1039 1043 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1040 1044 # also notify other clients of shutdown
1041 1045 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1042 1046 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1043 1047 dc.start()
1044 1048
1045 1049 def _shutdown(self):
1046 1050 self.log.info("hub::hub shutting down.")
1047 1051 time.sleep(0.1)
1048 1052 sys.exit(0)
1049 1053
1050 1054
1051 1055 def check_load(self, client_id, msg):
1052 1056 content = msg['content']
1053 1057 try:
1054 1058 targets = content['targets']
1055 1059 targets = self._validate_targets(targets)
1056 1060 except:
1057 1061 content = error.wrap_exception()
1058 1062 self.session.send(self.query, "hub_error",
1059 1063 content=content, ident=client_id)
1060 1064 return
1061 1065
1062 1066 content = dict(status='ok')
1063 1067 # loads = {}
1064 1068 for t in targets:
1065 1069 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1066 1070 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1067 1071
1068 1072
1069 1073 def queue_status(self, client_id, msg):
1070 1074 """Return the Queue status of one or more targets.
1071 1075 if verbose: return the msg_ids
1072 1076 else: return len of each type.
1073 1077 keys: queue (pending MUX jobs)
1074 1078 tasks (pending Task jobs)
1075 1079 completed (finished jobs from both queues)"""
1076 1080 content = msg['content']
1077 1081 targets = content['targets']
1078 1082 try:
1079 1083 targets = self._validate_targets(targets)
1080 1084 except:
1081 1085 content = error.wrap_exception()
1082 1086 self.session.send(self.query, "hub_error",
1083 1087 content=content, ident=client_id)
1084 1088 return
1085 1089 verbose = content.get('verbose', False)
1086 1090 content = dict(status='ok')
1087 1091 for t in targets:
1088 1092 queue = self.queues[t]
1089 1093 completed = self.completed[t]
1090 1094 tasks = self.tasks[t]
1091 1095 if not verbose:
1092 1096 queue = len(queue)
1093 1097 completed = len(completed)
1094 1098 tasks = len(tasks)
1095 1099 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1096 1100 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1097 1101 # print (content)
1098 1102 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1099 1103
1100 1104 def purge_results(self, client_id, msg):
1101 1105 """Purge results from memory. This method is more valuable before we move
1102 1106 to a DB based message storage mechanism."""
1103 1107 content = msg['content']
1104 1108 self.log.info("Dropping records with %s", content)
1105 1109 msg_ids = content.get('msg_ids', [])
1106 1110 reply = dict(status='ok')
1107 1111 if msg_ids == 'all':
1108 1112 try:
1109 1113 self.db.drop_matching_records(dict(completed={'$ne':None}))
1110 1114 except Exception:
1111 1115 reply = error.wrap_exception()
1112 1116 else:
1113 1117 pending = filter(lambda m: m in self.pending, msg_ids)
1114 1118 if pending:
1115 1119 try:
1116 1120 raise IndexError("msg pending: %r" % pending[0])
1117 1121 except:
1118 1122 reply = error.wrap_exception()
1119 1123 else:
1120 1124 try:
1121 1125 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1122 1126 except Exception:
1123 1127 reply = error.wrap_exception()
1124 1128
1125 1129 if reply['status'] == 'ok':
1126 1130 eids = content.get('engine_ids', [])
1127 1131 for eid in eids:
1128 1132 if eid not in self.engines:
1129 1133 try:
1130 1134 raise IndexError("No such engine: %i" % eid)
1131 1135 except:
1132 1136 reply = error.wrap_exception()
1133 1137 break
1134 1138 uid = self.engines[eid].queue
1135 1139 try:
1136 1140 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1137 1141 except Exception:
1138 1142 reply = error.wrap_exception()
1139 1143 break
1140 1144
1141 1145 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1142 1146
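purge_results accepts either msg_ids='all' or an explicit list, and refuses to drop anything still pending. A condensed sketch of that branching:

pending = set(['m2'])

def purge(msg_ids):
    if msg_ids == 'all':
        return dict(status='ok', dropped='all completed records')
    still_pending = [m for m in msg_ids if m in pending]
    if still_pending:
        return dict(status='error', evalue='msg pending: %r' % still_pending[0])
    return dict(status='ok', dropped=msg_ids)

print(purge(['m1'])['status'])  # -> ok
print(purge(['m2'])['status'])  # -> error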
1143 1147 def resubmit_task(self, client_id, msg):
1144 1148 """Resubmit one or more tasks."""
1145 1149 def finish(reply):
1146 1150 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1147 1151
1148 1152 content = msg['content']
1149 1153 msg_ids = content['msg_ids']
1150 1154 reply = dict(status='ok')
1151 1155 try:
1152 1156 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1153 1157 'header', 'content', 'buffers'])
1154 1158 except Exception:
1155 1159 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1156 1160 return finish(error.wrap_exception())
1157 1161
1158 1162 # validate msg_ids
1159 1163 found_ids = [ rec['msg_id'] for rec in records ]
1160 1164 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1161 1165 if len(records) > len(msg_ids):
1162 1166 try:
1163 1167 raise RuntimeError("DB appears to be in an inconsistent state. "
1164 1168 "More matching records were found than should exist")
1165 1169 except Exception:
1166 1170 return finish(error.wrap_exception())
1167 1171 elif len(records) < len(msg_ids):
1168 1172 missing = [ m for m in msg_ids if m not in found_ids ]
1169 1173 try:
1170 1174 raise KeyError("No such msg(s): %r" % missing)
1171 1175 except KeyError:
1172 1176 return finish(error.wrap_exception())
1173 1177 elif pending_ids:
1174 1178 pass
1175 1179 # no need to raise on resubmit of pending task, now that we
1176 1180 # resubmit under new ID, but do we want to raise anyway?
1177 1181 # msg_id = invalid_ids[0]
1178 1182 # try:
1179 1183 # raise ValueError("Task(s) %r appears to be inflight" % )
1180 1184 # except Exception:
1181 1185 # return finish(error.wrap_exception())
1182 1186
1183 1187 # mapping of original IDs to resubmitted IDs
1184 1188 resubmitted = {}
1185 1189
1186 1190 # send the messages
1187 1191 for rec in records:
1188 1192 header = rec['header']
1189 1193 msg = self.session.msg(header['msg_type'], parent=header)
1190 1194 msg_id = msg['msg_id']
1191 1195 msg['content'] = rec['content']
1192 1196
1193 1197 # use the old header, but update msg_id and timestamp
1194 1198 fresh = msg['header']
1195 1199 header['msg_id'] = fresh['msg_id']
1196 1200 header['date'] = fresh['date']
1197 1201 msg['header'] = header
1198 1202
1199 1203 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1200 1204
1201 1205 resubmitted[rec['msg_id']] = msg_id
1202 1206 self.pending.add(msg_id)
1203 1207 msg['buffers'] = rec['buffers']
1204 1208 try:
1205 1209 self.db.add_record(msg_id, init_record(msg))
1206 1210 except Exception:
1207 1211 self.log.error("db::DB Error adding record: %s", msg_id, exc_info=True)
1208 1212 return finish(error.wrap_exception())
1209 1213
1210 1214 finish(dict(status='ok', resubmitted=resubmitted))
1211 1215
1212 1216 # store the new IDs in the Task DB
1213 1217 for msg_id, resubmit_id in resubmitted.iteritems():
1214 1218 try:
1215 1219 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1216 1220 except Exception:
1217 1221 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1218 1222
1219 1223
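The key invariant above: resubmission mints a fresh msg_id per task and records the old-to-new mapping both in the reply and back on the original record. Schematically:

import uuid

records = {'old-1': {'header': {'msg_type': 'apply_request'}}}
resubmitted = {}
for old_id in records:
    new_id = str(uuid.uuid4())   # resubmit under a brand new msg_id
    resubmitted[old_id] = new_id

# the original record is then annotated, roughly:
#   db.update_record(old_id, {'resubmitted': resubmitted[old_id]})
print(resubmitted['old-1'] != 'old-1')  # -> True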
1220 1224 def _extract_record(self, rec):
1221 1225 """decompose a TaskRecord dict into subsection of reply for get_result"""
1222 1226 io_dict = {}
1223 1227 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1224 1228 io_dict[key] = rec[key]
1225 1229 content = { 'result_content': rec['result_content'],
1226 1230 'header': rec['header'],
1227 1231 'result_header' : rec['result_header'],
1228 1232 'received' : rec['received'],
1229 1233 'io' : io_dict,
1230 1234 }
1231 1235 if rec['result_buffers']:
1232 1236 buffers = map(bytes, rec['result_buffers'])
1233 1237 else:
1234 1238 buffers = []
1235 1239
1236 1240 return content, buffers
1237 1241
1238 1242 def get_results(self, client_id, msg):
1239 1243 """Get the result of 1 or more messages."""
1240 1244 content = msg['content']
1241 1245 msg_ids = sorted(set(content['msg_ids']))
1242 1246 statusonly = content.get('status_only', False)
1243 1247 pending = []
1244 1248 completed = []
1245 1249 content = dict(status='ok')
1246 1250 content['pending'] = pending
1247 1251 content['completed'] = completed
1248 1252 buffers = []
1249 1253 if not statusonly:
1250 1254 try:
1251 1255 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1252 1256 # turn match list into dict, for faster lookup
1253 1257 records = {}
1254 1258 for rec in matches:
1255 1259 records[rec['msg_id']] = rec
1256 1260 except Exception:
1257 1261 content = error.wrap_exception()
1258 1262 self.session.send(self.query, "result_reply", content=content,
1259 1263 parent=msg, ident=client_id)
1260 1264 return
1261 1265 else:
1262 1266 records = {}
1263 1267 for msg_id in msg_ids:
1264 1268 if msg_id in self.pending:
1265 1269 pending.append(msg_id)
1266 1270 elif msg_id in self.all_completed:
1267 1271 completed.append(msg_id)
1268 1272 if not statusonly:
1269 1273 c,bufs = self._extract_record(records[msg_id])
1270 1274 content[msg_id] = c
1271 1275 buffers.extend(bufs)
1272 1276 elif msg_id in records:
1273 1277 if records[msg_id]['completed']:
1274 1278 completed.append(msg_id)
1275 1279 c,bufs = self._extract_record(records[msg_id])
1276 1280 content[msg_id] = c
1277 1281 buffers.extend(bufs)
1278 1282 else:
1279 1283 pending.append(msg_id)
1280 1284 else:
1281 1285 try:
1282 1286 raise KeyError('No such message: '+msg_id)
1283 1287 except:
1284 1288 content = error.wrap_exception()
1285 1289 break
1286 1290 self.session.send(self.query, "result_reply", content=content,
1287 1291 parent=msg, ident=client_id,
1288 1292 buffers=buffers)
1289 1293
1290 1294 def get_history(self, client_id, msg):
1291 1295 """Get a list of all msg_ids in our DB records"""
1292 1296 try:
1293 1297 msg_ids = self.db.get_history()
1294 1298 except Exception as e:
1295 1299 content = error.wrap_exception()
1296 1300 else:
1297 1301 content = dict(status='ok', history=msg_ids)
1298 1302
1299 1303 self.session.send(self.query, "history_reply", content=content,
1300 1304 parent=msg, ident=client_id)
1301 1305
1302 1306 def db_query(self, client_id, msg):
1303 1307 """Perform a raw query on the task record database."""
1304 1308 content = msg['content']
1305 1309 query = content.get('query', {})
1306 1310 keys = content.get('keys', None)
1307 1311 buffers = []
1308 1312 empty = list()
1309 1313 try:
1310 1314 records = self.db.find_records(query, keys)
1311 1315 except Exception as e:
1312 1316 content = error.wrap_exception()
1313 1317 else:
1314 1318 # extract buffers from reply content:
1315 1319 if keys is not None:
1316 1320 buffer_lens = [] if 'buffers' in keys else None
1317 1321 result_buffer_lens = [] if 'result_buffers' in keys else None
1318 1322 else:
1319 1323 buffer_lens = None
1320 1324 result_buffer_lens = None
1321 1325
1322 1326 for rec in records:
1323 1327 # buffers may be None, so double check
1324 1328 b = rec.pop('buffers', empty) or empty
1325 1329 if buffer_lens is not None:
1326 1330 buffer_lens.append(len(b))
1327 1331 buffers.extend(b)
1328 1332 rb = rec.pop('result_buffers', empty) or empty
1329 1333 if result_buffer_lens is not None:
1330 1334 result_buffer_lens.append(len(rb))
1331 1335 buffers.extend(rb)
1332 1336 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1333 1337 result_buffer_lens=result_buffer_lens)
1334 1338 # self.log.debug (content)
1335 1339 self.session.send(self.query, "db_reply", content=content,
1336 1340 parent=msg, ident=client_id,
1337 1341 buffers=buffers)
1338 1342
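The query dicts handled here are Mongo-flavored ($in and $ne appear throughout this file) and are interpreted by the DB backend. A hedged in-memory evaluator for just those two operators, to make the semantics concrete:

records = [{'msg_id': 'a', 'completed': None},
           {'msg_id': 'b', 'completed': 1}]

def matches(rec, query):
    for key, test in query.items():
        if isinstance(test, dict):
            # only $ne and $in are modeled here
            if '$ne' in test and rec.get(key) == test['$ne']:
                return False
            if '$in' in test and rec.get(key) not in test['$in']:
                return False
        elif rec.get(key) != test:
            return False
    return True

print([r['msg_id'] for r in records if matches(r, {'completed': {'$ne': None}})])
# -> ['b']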
@@ -1,237 +1,227 b''
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 it handles registration, etc. and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from getpass import getpass
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.external.ssh import tunnel
26 26 # internal
27 27 from IPython.utils.traitlets import (
28 28 Instance, Dict, Integer, Type, CFloat, Unicode, CBytes, Bool
29 29 )
30 30 from IPython.utils.py3compat import cast_bytes
31 31
32 32 from IPython.parallel.controller.heartmonitor import Heart
33 33 from IPython.parallel.factory import RegistrationFactory
34 34 from IPython.parallel.util import disambiguate_url
35 35
36 36 from IPython.zmq.session import Message
37 37 from IPython.zmq.ipkernel import Kernel
38 38
39 39 class EngineFactory(RegistrationFactory):
40 40 """IPython engine"""
41 41
42 42 # configurables:
43 43 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
44 44 help="""The OutStream for handling stdout/err.
45 45 Typically 'IPython.zmq.iostream.OutStream'""")
46 46 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
47 47 help="""The class for handling displayhook.
48 48 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
49 49 location=Unicode(config=True,
50 50 help="""The location (an IP address) of the controller. This is
51 51 used for disambiguating URLs, to determine whether
52 52 loopback should be used to connect or the public address.""")
53 timeout=CFloat(2,config=True,
53 timeout=CFloat(5, config=True,
54 54 help="""The time (in seconds) to wait for the Controller to respond
55 55 to registration requests before giving up.""")
56 56 sshserver=Unicode(config=True,
57 57 help="""The SSH server to use for tunneling connections to the Controller.""")
58 58 sshkey=Unicode(config=True,
59 59 help="""The SSH private key file to use when tunneling connections to the Controller.""")
60 60 paramiko=Bool(sys.platform == 'win32', config=True,
61 61 help="""Whether to use paramiko instead of openssh for tunnels.""")
62 62
63 63 # not configurable:
64 connection_info = Dict()
64 65 user_ns=Dict()
65 66 id=Integer(allow_none=True)
66 67 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
67 68 kernel=Instance(Kernel)
68 69
69 70 bident = CBytes()
70 71 ident = Unicode()
71 72 def _ident_changed(self, name, old, new):
72 73 self.bident = cast_bytes(new)
73 74 using_ssh=Bool(False)
74 75
75 76
76 77 def __init__(self, **kwargs):
77 78 super(EngineFactory, self).__init__(**kwargs)
78 79 self.ident = self.session.session
79 80
80 81 def init_connector(self):
81 82 """construct connection function, which handles tunnels."""
82 83 self.using_ssh = bool(self.sshkey or self.sshserver)
83 84
84 85 if self.sshkey and not self.sshserver:
85 86 # We are using ssh directly to the controller, tunneling localhost to localhost
86 87 self.sshserver = self.url.split('://')[1].split(':')[0]
87 88
88 89 if self.using_ssh:
89 90 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
90 91 password=False
91 92 else:
92 93 password = getpass("SSH Password for %s: "%self.sshserver)
93 94 else:
94 95 password = False
95 96
96 97 def connect(s, url):
97 98 url = disambiguate_url(url, self.location)
98 99 if self.using_ssh:
99 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
100 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
100 101 return tunnel.tunnel_connection(s, url, self.sshserver,
101 102 keyfile=self.sshkey, paramiko=self.paramiko,
102 103 password=password,
103 104 )
104 105 else:
105 106 return s.connect(url)
106 107
107 108 def maybe_tunnel(url):
108 109 """like connect, but don't complete the connection (for use by heartbeat)"""
109 110 url = disambiguate_url(url, self.location)
110 111 if self.using_ssh:
111 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
112 self.log.debug("Tunneling connection to %s via %s", url, self.sshserver)
112 113 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
113 114 keyfile=self.sshkey, paramiko=self.paramiko,
114 115 password=password,
115 116 )
116 return url
117 return str(url)
117 118 return connect, maybe_tunnel
118 119
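The sshserver fallback above just pulls the host out of the controller URL; standalone:

url = 'tcp://192.168.1.10:12345'
sshserver = url.split('://')[1].split(':')[0]
print(sshserver)  # -> 192.168.1.10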
119 120 def register(self):
120 121 """send the registration_request"""
121 122
122 123 self.log.info("Registering with controller at %s"%self.url)
123 124 ctx = self.context
124 125 connect,maybe_tunnel = self.init_connector()
125 126 reg = ctx.socket(zmq.DEALER)
126 127 reg.setsockopt(zmq.IDENTITY, self.bident)
127 128 connect(reg, self.url)
128 129 self.registrar = zmqstream.ZMQStream(reg, self.loop)
129 130
130 131
131 132 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
132 133 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
133 134 # print (self.session.key)
134 135 self.session.send(self.registrar, "registration_request",content=content)
135 136
136 137 def complete_registration(self, msg, connect, maybe_tunnel):
137 138 # print msg
138 139 self._abort_dc.stop()
139 140 ctx = self.context
140 141 loop = self.loop
141 142 identity = self.bident
142 143 idents,msg = self.session.feed_identities(msg)
143 msg = Message(self.session.unserialize(msg))
144 msg = self.session.unserialize(msg)
145 content = msg['content']
146 info = self.connection_info
144 147
145 if msg.content.status == 'ok':
146 self.id = int(msg.content.id)
148 if content['status'] == 'ok':
149 self.id = int(content['id'])
147 150
148 151 # launch heartbeat
149 hb_addrs = msg.content.heartbeat
150
151 152 # possibly forward hb ports with tunnels
152 hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
153 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 hb_ping = maybe_tunnel(info['hb_ping'])
154 hb_pong = maybe_tunnel(info['hb_pong'])
155
156 heart = Heart(hb_ping, hb_pong, heart_id=identity)
154 157 heart.start()
155 158
156 # create Shell Streams (MUX, Task, etc.):
157 queue_addr = msg.content.mux
158 shell_addrs = [ str(queue_addr) ]
159 task_addr = msg.content.task
160 if task_addr:
161 shell_addrs.append(str(task_addr))
162
163 # Uncomment this to go back to two-socket model
164 # shell_streams = []
165 # for addr in shell_addrs:
166 # stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
167 # stream.setsockopt(zmq.IDENTITY, identity)
168 # stream.connect(disambiguate_url(addr, self.location))
169 # shell_streams.append(stream)
170
171 # Now use only one shell stream for mux and tasks
159 # create Shell Connections (MUX, Task, etc.):
160 shell_addrs = map(str, [info['mux'], info['task']])
161
162 # Use only one shell stream for mux and tasks
172 163 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
173 164 stream.setsockopt(zmq.IDENTITY, identity)
174 165 shell_streams = [stream]
175 166 for addr in shell_addrs:
176 167 connect(stream, addr)
177 # end single stream-socket
178 168
179 169 # control stream:
180 control_addr = str(msg.content.control)
170 control_addr = str(info['control'])
181 171 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
182 172 control_stream.setsockopt(zmq.IDENTITY, identity)
183 173 connect(control_stream, control_addr)
184 174
185 175 # create iopub stream:
186 iopub_addr = msg.content.iopub
176 iopub_addr = info['iopub']
187 177 iopub_socket = ctx.socket(zmq.PUB)
188 178 iopub_socket.setsockopt(zmq.IDENTITY, identity)
189 179 connect(iopub_socket, iopub_addr)
190 180
191 181 # disable history:
192 182 self.config.HistoryManager.hist_file = ':memory:'
193 183
194 184 # Redirect input streams and set a display hook.
195 185 if self.out_stream_factory:
196 186 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
197 187 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
198 188 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
199 189 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
200 190 if self.display_hook_factory:
201 191 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
202 192 sys.displayhook.topic = cast_bytes('engine.%i.pyout' % self.id)
203 193
204 194 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
205 195 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
206 196 loop=loop, user_ns=self.user_ns, log=self.log)
207 197 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
208 198 self.kernel.start()
209 199
210 200
211 201 else:
212 202 self.log.fatal("Registration Failed: %s"%msg)
213 203 raise Exception("Registration Failed: %s"%msg)
214 204
215 205 self.log.info("Completed registration with id %i"%self.id)
216 206
217 207
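complete_registration now reads every endpoint from a single connection_info dict keyed by role (hb_ping, hb_pong, mux, task, control, iopub). A mock of that dict with invented addresses:

info = {
    'hb_ping': 'tcp://127.0.0.1:5001', 'hb_pong': 'tcp://127.0.0.1:5002',
    'mux':     'tcp://127.0.0.1:5003', 'task':    'tcp://127.0.0.1:5004',
    'control': 'tcp://127.0.0.1:5005', 'iopub':   'tcp://127.0.0.1:5006',
}
shell_addrs = list(map(str, [info['mux'], info['task']]))
print(shell_addrs)  # both shell endpoints share one ROUTER stream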
218 208 def abort(self):
219 209 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
220 210 if self.url.startswith('127.'):
221 211 self.log.fatal("""
222 212 If the controller and engines are not on the same machine,
223 213 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
224 214 c.HubFactory.ip='*' # for all interfaces, internal and external
225 215 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
226 216 or tunnel connections via ssh.
227 217 """)
228 218 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
229 219 time.sleep(1)
230 220 sys.exit(255)
231 221
232 222 def start(self):
233 223 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
234 224 dc.start()
235 225 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
236 226 self._abort_dc.start()
237 227
@@ -1,756 +1,760 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 27 import hmac
28 28 import logging
29 29 import os
30 30 import pprint
31 31 import uuid
32 32 from datetime import datetime
33 33
34 34 try:
35 35 import cPickle
36 36 pickle = cPickle
37 37 except:
38 38 cPickle = None
39 39 import pickle
40 40
41 41 import zmq
42 42 from zmq.utils import jsonapi
43 43 from zmq.eventloop.ioloop import IOLoop
44 44 from zmq.eventloop.zmqstream import ZMQStream
45 45
46 46 from IPython.config.application import Application, boolean_flag
47 47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 48 from IPython.utils.importstring import import_item
49 49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 50 from IPython.utils.py3compat import str_to_bytes
51 51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 52 DottedObjectName, CUnicode)
53 53
54 54 #-----------------------------------------------------------------------------
55 55 # utility functions
56 56 #-----------------------------------------------------------------------------
57 57
58 58 def squash_unicode(obj):
59 59 """coerce unicode back to bytestrings."""
60 60 if isinstance(obj,dict):
61 61 for key in obj.keys():
62 62 obj[key] = squash_unicode(obj[key])
63 63 if isinstance(key, unicode):
64 64 obj[squash_unicode(key)] = obj.pop(key)
65 65 elif isinstance(obj, list):
66 66 for i,v in enumerate(obj):
67 67 obj[i] = squash_unicode(v)
68 68 elif isinstance(obj, unicode):
69 69 obj = obj.encode('utf8')
70 70 return obj
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # globals and defaults
74 74 #-----------------------------------------------------------------------------
75 75
76 76
77 77 # ISO8601-ify datetime objects
78 78 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
79 79 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
80 80
81 81 pickle_packer = lambda o: pickle.dumps(o,-1)
82 82 pickle_unpacker = pickle.loads
83 83
84 84 default_packer = json_packer
85 85 default_unpacker = json_unpacker
86 86
87 87 DELIM=b"<IDS|MSG>"
88 88
89 89
90 90 #-----------------------------------------------------------------------------
91 91 # Mixin tools for apps that use Sessions
92 92 #-----------------------------------------------------------------------------
93 93
94 94 session_aliases = dict(
95 95 ident = 'Session.session',
96 96 user = 'Session.username',
97 97 keyfile = 'Session.keyfile',
98 98 )
99 99
100 100 session_flags = {
101 101 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
102 102 'keyfile' : '' }},
103 103 """Use HMAC digests for authentication of messages.
104 104 Setting this flag will generate a new UUID to use as the HMAC key.
105 105 """),
106 106 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
107 107 """Don't authenticate messages."""),
108 108 }
109 109
110 110 def default_secure(cfg):
111 111 """Set the default behavior for a config environment to be secure.
112 112
113 113 If Session.key/keyfile have not been set, set Session.key to
114 114 a new random UUID.
115 115 """
116 116
117 117 if 'Session' in cfg:
118 118 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
119 119 return
120 120 # key/keyfile not specified, generate new UUID:
121 121 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
122 122
123 123
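The same key-defaulting logic, sketched with a plain dict standing in for an IPython Config object:

import uuid

cfg = {'Session': {}}
if 'key' not in cfg['Session'] and 'keyfile' not in cfg['Session']:
    cfg['Session']['key'] = str(uuid.uuid4()).encode('ascii')
print(len(cfg['Session']['key']))  # -> 36, a UUID used as the HMAC key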
124 124 #-----------------------------------------------------------------------------
125 125 # Classes
126 126 #-----------------------------------------------------------------------------
127 127
128 128 class SessionFactory(LoggingConfigurable):
129 129 """The Base class for configurables that have a Session, Context, logger,
130 130 and IOLoop.
131 131 """
132 132
133 133 logname = Unicode('')
134 134 def _logname_changed(self, name, old, new):
135 135 self.log = logging.getLogger(new)
136 136
137 137 # not configurable:
138 138 context = Instance('zmq.Context')
139 139 def _context_default(self):
140 140 return zmq.Context.instance()
141 141
142 142 session = Instance('IPython.zmq.session.Session')
143 143
144 144 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
145 145 def _loop_default(self):
146 146 return IOLoop.instance()
147 147
148 148 def __init__(self, **kwargs):
149 149 super(SessionFactory, self).__init__(**kwargs)
150 150
151 151 if self.session is None:
152 152 # construct the session
153 153 self.session = Session(**kwargs)
154 154
155 155
156 156 class Message(object):
157 157 """A simple message object that maps dict keys to attributes.
158 158
159 159 A Message can be created from a dict and a dict from a Message instance
160 160 simply by calling dict(msg_obj)."""
161 161
162 162 def __init__(self, msg_dict):
163 163 dct = self.__dict__
164 164 for k, v in dict(msg_dict).iteritems():
165 165 if isinstance(v, dict):
166 166 v = Message(v)
167 167 dct[k] = v
168 168
169 169 # Having this iterator lets dict(msg_obj) work out of the box.
170 170 def __iter__(self):
171 171 return iter(self.__dict__.iteritems())
172 172
173 173 def __repr__(self):
174 174 return repr(self.__dict__)
175 175
176 176 def __str__(self):
177 177 return pprint.pformat(self.__dict__)
178 178
179 179 def __contains__(self, k):
180 180 return k in self.__dict__
181 181
182 182 def __getitem__(self, k):
183 183 return self.__dict__[k]
184 184
185 185
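A short round trip showing the attribute access and dict conversion the docstring describes (assumes this module is importable as IPython.zmq.session):

from IPython.zmq.session import Message

msg = Message({'header': {'msg_id': 'abc'}, 'content': {'status': 'ok'}})
print(msg.content.status)    # attribute access -> ok
print('header' in msg)       # __contains__ -> True
print(dict(msg))             # back to a dict (of nested Messages)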
186 186 def msg_header(msg_id, msg_type, username, session):
187 187 date = datetime.now()
188 188 return locals()
189 189
190 190 def extract_header(msg_or_header):
191 191 """Given a message or header, return the header."""
192 192 if not msg_or_header:
193 193 return {}
194 194 try:
195 195 # See if msg_or_header is the entire message.
196 196 h = msg_or_header['header']
197 197 except KeyError:
198 198 try:
199 199 # See if msg_or_header is just the header
200 200 h = msg_or_header['msg_id']
201 201 except KeyError:
202 202 raise
203 203 else:
204 204 h = msg_or_header
205 205 if not isinstance(h, dict):
206 206 h = dict(h)
207 207 return h
208 208
209 209 class Session(Configurable):
210 210 """Object for handling serialization and sending of messages.
211 211
212 212 The Session object handles building messages and sending them
213 213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 214 other over the network via Session objects, and only need to work with the
215 215 dict-based IPython message spec. The Session will handle
216 216 serialization/deserialization, security, and metadata.
217 217
218 218 Sessions support configurable serialization via packer/unpacker traits,
219 219 and signing with HMAC digests via the key/keyfile traits.
220 220
221 221 Parameters
222 222 ----------
223 223
224 224 debug : bool
225 225 whether to trigger extra debugging statements
226 226 packer/unpacker : str : 'json', 'pickle' or import_string
227 227 importstrings for methods to serialize message parts. If just
228 228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 229 Otherwise, the entire importstring must be used.
230 230
231 231 The functions must accept at least valid JSON input, and output *bytes*.
232 232
233 233 For example, to use msgpack:
234 234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 235 pack/unpack : callables
236 236 You can also set the pack/unpack callables for serialization directly.
237 237 session : bytes
238 238 the ID of this Session object. The default is to generate a new UUID.
239 239 username : unicode
240 240 username added to message headers. The default is to ask the OS.
241 241 key : bytes
242 242 The key used to initialize an HMAC signature. If unset, messages
243 243 will not be signed or checked.
244 244 keyfile : filepath
245 245 The file containing a key. If this is set, `key` will be initialized
246 246 to the contents of the file.
247 247
248 248 """
249 249
250 250 debug=Bool(False, config=True, help="""Debug output in the Session""")
251 251
252 252 packer = DottedObjectName('json',config=True,
253 253 help="""The name of the packer for serializing messages.
254 254 Should be one of 'json', 'pickle', or an import name
255 255 for a custom callable serializer.""")
256 256 def _packer_changed(self, name, old, new):
257 257 if new.lower() == 'json':
258 258 self.pack = json_packer
259 259 self.unpack = json_unpacker
260 self.unpacker = new
260 261 elif new.lower() == 'pickle':
261 262 self.pack = pickle_packer
262 263 self.unpack = pickle_unpacker
264 self.unpacker = new
263 265 else:
264 266 self.pack = import_item(str(new))
265 267
266 268 unpacker = DottedObjectName('json', config=True,
267 269 help="""The name of the unpacker for unserializing messages.
268 270 Only used with custom functions for `packer`.""")
269 271 def _unpacker_changed(self, name, old, new):
270 272 if new.lower() == 'json':
271 273 self.pack = json_packer
272 274 self.unpack = json_unpacker
275 self.packer = new
273 276 elif new.lower() == 'pickle':
274 277 self.pack = pickle_packer
275 278 self.unpack = pickle_unpacker
279 self.packer = new
276 280 else:
277 281 self.unpack = import_item(str(new))
278 282
    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        u = unicode(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(os.environ.get('USER', u'username'), config=True,
        help="""Username for the Session. Default is your system username.""")

    # message signature related traits:

    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")
    def _key_changed(self, name, old, new):
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None
    auth = Instance(hmac.HMAC)
    digest_history = Set()

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        with open(new, 'rb') as f:
            self.key = f.read().strip()

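    # Illustrative sketch (an addition, not in the original source): signing
    # is enabled by setting `key` directly, or by pointing `keyfile` (a
    # placeholder path below) at a file whose contents become the key:
    #
    #   >>> s = Session(key=b'a1b2c3')
    #   >>> s = Session(keyfile='/path/to/keyfile')
    #
    # With no key, auth is None and messages are neither signed nor checked.
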
    # serialization traits:

    pack = Any(default_packer)  # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s" % type(new))

    unpack = Any(default_unpacker)  # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        # unpacker output is not checked - it is assumed to be the inverse of pack
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s" % type(new))

    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts. If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object. The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers. The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature. If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key. If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        self._check_packers()
        self.none = self.pack({})
        # ensure self._session_default() is called if necessary, so bsession is defined:
        self.session

    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())

    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1, 'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required" % type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))

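    # Note (an illustration, not in the original source): the datetime check
    # above means a packer that cannot round-trip datetimes, such as msgpack,
    # is transparently wrapped with squash_dates/extract_dates, so callers
    # never see the difference:
    #
    #   >>> s = Session(packer='msgpack.packb', unpacker='msgpack.unpackb')
    #   >>> s.unpack(s.pack(dict(t=datetime.now())))  # dates still round-trip
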
    def msg_header(self, msg_type):
        return msg_header(self.msg_id, msg_type, self.username, self.session)

    def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
        """Return the nested message dict.

        This format is different from what is sent over the wire. The
        serialize/unserialize methods convert this nested message dict to the wire
        format, which is a list of message parts.
        """
        msg = {}
        header = self.msg_header(msg_type) if header is None else header
        msg['header'] = header
        msg['msg_id'] = header['msg_id']
        msg['msg_type'] = header['msg_type']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

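    # Illustrative sketch (an addition, not in the original source): the
    # nested dict returned by msg() before serialization, where msg_id and
    # session are fresh UUIDs (plus any extra fields msg_header may add):
    #
    #   >>> s = Session(username=u'kernel')
    #   >>> s.msg('execute_request', content=dict(code='a=5'))
    #   {'header': {'msg_id': '...', 'msg_type': 'execute_request',
    #               'username': u'kernel', 'session': u'...', ...},
    #    'msg_id': '...', 'msg_type': 'execute_request',
    #    'parent_header': {}, 'content': {'code': 'a=5'}}
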
    def sign(self, msg_list):
        """Sign a message with HMAC digest. If no auth, return b''.

        Parameters
        ----------
        msg_list : list
            The [p_header,p_parent,p_content] part of the message list.
        """
        if self.auth is None:
            return b''
        h = self.auth.copy()
        for m in msg_list:
            h.update(m)
        return str_to_bytes(h.hexdigest())

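    # Illustrative sketch (an addition, not in the original source): the
    # signature is the hex HMAC digest of the packed header, parent, and
    # content parts, as bytes; without a key it is empty:
    #
    #   >>> sig = Session(key=b'a1b2c3').sign([b'{}', b'{}', b'{}'])  # hex digest
    #   >>> Session().sign([b'{}', b'{}', b'{}'])
    #   ''
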
    def serialize(self, msg, ident=None):
        """Serialize the message components to bytes.

        This is roughly the inverse of unserialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg : dict or Message
            The nested message dict as returned by the self.msg method.

        Returns
        -------
        msg_list : list
            The list of bytes objects to be sent with the format:
            [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
             buffer1,buffer2,...]. In this list, the p_* entities are
            the packed or serialized versions, so if JSON is used, these
            are utf8 encoded JSON strings.
        """
        content = msg.get('content', {})
        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, bytes):
            # content is already packed, as in a relayed message
            pass
        elif isinstance(content, unicode):
            # should be bytes, but JSON often spits out unicode
            content = content.encode('utf8')
        else:
            raise TypeError("Content incorrect type: %s" % type(content))

        real_message = [self.pack(msg['header']),
                        self.pack(msg['parent_header']),
                        content,
        ]

        to_send = []

        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)

        signature = self.sign(real_message)
        to_send.append(signature)

        to_send.extend(real_message)

        return to_send

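    # Illustrative sketch (an addition, not in the original source): with one
    # routing ident, the wire list is [ident, DELIM, HMAC, p_header, p_parent,
    # p_content]:
    #
    #   >>> s = Session()
    #   >>> wire = s.serialize(s.msg('apply_request'), ident=b'engine-0')
    #   >>> len(wire)  # ident + DELIM + HMAC + three packed parts
    #   6
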
    def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
             buffers=None, subheader=None, track=False, header=None):
        """Build and send a message via stream or socket.

        The message format used by this function internally is as follows:

        [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
         buffer1,buffer2,...]

        The serialize/unserialize methods convert the nested message dict into this
        format.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            The socket-like object used to send the data.
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being
            sent more than once. If a header is supplied, this can be set to
            None and the msg_type will be pulled from the header.

        content : dict or None
            The content of the message (ignored if msg_or_type is a message).
        header : dict or None
            The header dict for the message (ignored if msg_or_type is a message).
        parent : Message or dict or None
            The parent or parent header describing the parent of this message
            (ignored if msg_or_type is a message).
        ident : bytes or list of bytes
            The zmq.IDENTITY routing path.
        subheader : dict or None
            Extra header keys for this message's header (ignored if msg_or_type
            is a message).
        buffers : list or None
            The already-serialized buffers to be appended to the message.
        track : bool
            Whether to track. Only for use with Sockets, because ZMQStream
            objects cannot track messages.

        Returns
        -------
        msg : dict
            The constructed message. If track=True, the MessageTracker from
            the final send is available as msg['tracker'].

        """

        if not isinstance(stream, (zmq.Socket, ZMQStream)):
            raise TypeError("stream must be Socket or ZMQStream, not %r" % type(stream))
        elif track and isinstance(stream, ZMQStream):
            raise TypeError("ZMQStream cannot track messages")

        if isinstance(msg_or_type, (Message, dict)):
            # We got a Message or message dict, not a msg_type so don't
            # build a new Message.
            msg = msg_or_type
        else:
            msg = self.msg(msg_or_type, content=content, parent=parent,
                           subheader=subheader, header=header)

        buffers = [] if buffers is None else buffers
        to_send = self.serialize(msg, ident)
        flag = 0
        if buffers:
            flag = zmq.SNDMORE
            _track = False
        else:
            _track = track
        if track:
            tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
        else:
            tracker = stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            if track:
                tracker = stream.send(buffers[-1], copy=False, track=track)
            else:
                tracker = stream.send(buffers[-1], copy=False)

        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)

        msg['tracker'] = tracker

        return msg

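    # Illustrative usage sketch (an addition, not in the original source);
    # `sock` is assumed to be a connected zmq.Socket, and `data` an existing
    # bytes buffer:
    #
    #   >>> s = Session(key=b'a1b2c3')
    #   >>> msg = s.send(sock, 'execute_request', content=dict(code='a=5'))
    #   >>> msg = s.send(sock, 'apply_request', buffers=[data], track=True)
    #   >>> msg['tracker']  # MessageTracker from the final send
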
    def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        This method is used to send an already serialized message.

        Parameters
        ----------
        stream : ZMQStream or Socket
            The ZMQ stream or socket to use for sending the message.
        msg_list : list
            The serialized list of messages to send. This only includes the
            [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
            the message.
        ident : ident or list
            A single ident or a list of idents to use in sending.
        """
        to_send = []
        if isinstance(ident, bytes):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)

        to_send.append(DELIM)
        to_send.append(self.sign(msg_list))
        to_send.extend(msg_list)
        # send the fully assembled list (idents + DELIM + signature + parts),
        # not the bare msg_list
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """Receive and unpack a message.

        Parameters
        ----------
        socket : ZMQStream or Socket
            The socket or stream to use in receiving.

        Returns
        -------
        [idents], msg
            [idents] is a list of idents and msg is a nested message dict of
            same format as self.msg returns.
        """
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg_list = socket.recv_multipart(mode, copy=copy)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_multipart won't return None.
                return None, None
            else:
                raise
        # split multipart message into identity list and message dict
        # invalid large messages can cause very expensive string comparisons
        idents, msg_list = self.feed_identities(msg_list, copy)
        try:
            return idents, self.unserialize(msg_list, content=content, copy=copy)
        except Exception:
            # TODO: handle it
            raise

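    # Illustrative usage sketch (an addition, not in the original source;
    # `sock` is a placeholder for a connected socket): recv is non-blocking by
    # default and returns (None, None) when nothing is waiting:
    #
    #   >>> idents, msg = s.recv(sock)          # zmq.NOBLOCK by default
    #   >>> idents, msg = s.recv(sock, mode=0)  # block until a message arrives
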
    def feed_identities(self, msg_list, copy=True):
        """Split the identities from the rest of the message.

        Feed until DELIM is reached, then return the prefix as idents and
        remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
        but that would be silly.

        Parameters
        ----------
        msg_list : a list of Message or bytes objects
            The message to be split.
        copy : bool
            flag determining whether msg_list contains bytes or zmq.Messages

        Returns
        -------
        (idents, msg_list) : two lists
            idents will always be a list of bytes, each of which is a ZMQ
            identity. msg_list will be a list of bytes or zmq.Messages of the
            form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
            should be unpackable/unserializable via self.unserialize at this
            point.
        """
        if copy:
            idx = msg_list.index(DELIM)
            return msg_list[:idx], msg_list[idx+1:]
        else:
            failed = True
            for idx, m in enumerate(msg_list):
                if m.bytes == DELIM:
                    failed = False
                    break
            if failed:
                raise ValueError("DELIM not in msg_list")
            idents, msg_list = msg_list[:idx], msg_list[idx+1:]
            return [m.bytes for m in idents], msg_list

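    # Illustrative sketch (an addition, not in the original source): splitting
    # a wire-format list on DELIM:
    #
    #   >>> s.feed_identities([b'engine-0', DELIM, b'hmac', b'{}', b'{}', b'{}'])
    #   ([b'engine-0'], [b'hmac', b'{}', b'{}', b'{}'])
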
    def unserialize(self, msg_list, content=True, copy=True):
        """Unserialize a msg_list to a nested message dict.

        This is roughly the inverse of serialize. The serialize/unserialize
        methods work with full message lists, whereas pack/unpack work with
        the individual message parts in the message list.

        Parameters
        ----------
        msg_list : list of bytes or Message objects
            The list of message parts of the form [HMAC,p_header,p_parent,
            p_content,buffer1,buffer2,...].
        content : bool (True)
            Whether to unpack the content dict (True), or leave it packed
            (False).
        copy : bool (True)
            Whether to return the bytes (True), or the non-copying Message
            object in each place (False).

        Returns
        -------
        msg : dict
            The nested message dict with top-level keys [header, parent_header,
            content, buffers].
        """
        minlen = 4
        message = {}
        # validate the length before indexing into msg_list
        if len(msg_list) < minlen:
            raise TypeError("malformed message, must have at least %i elements" % minlen)
        if not copy:
            for i in range(minlen):
                msg_list[i] = msg_list[i].bytes
        if self.auth is not None:
            signature = msg_list[0]
            if not signature:
                raise ValueError("Unsigned Message")
            if signature in self.digest_history:
                raise ValueError("Duplicate Signature: %r" % signature)
            check = self.sign(msg_list[1:4])
            if signature != check:
                raise ValueError("Invalid Signature: %r" % signature)
            # only record signatures that pass verification
            self.digest_history.add(signature)
        header = self.unpack(msg_list[1])
        message['header'] = header
        message['msg_id'] = header['msg_id']
        message['msg_type'] = header['msg_type']
        message['parent_header'] = self.unpack(msg_list[2])
        if content:
            message['content'] = self.unpack(msg_list[3])
        else:
            message['content'] = msg_list[3]

        message['buffers'] = msg_list[4:]
        return message

def test_msg2obj():
    am = dict(x=1)
    ao = Message(am)
    assert ao.x == am['x']

    am['y'] = dict(z=1)
    ao = Message(am)
    assert ao.y.z == am['y']['z']

    k1, k2 = 'y', 'z'
    assert ao[k1][k2] == am[k1][k2]

    am2 = dict(ao)
    assert am['x'] == am2['x']
    assert am['y']['z'] == am2['y']['z']

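# Illustrative round-trip check (an addition, not part of the original
# module): exercises msg() -> serialize() -> feed_identities() ->
# unserialize() with signing enabled, using only names defined above.
def test_roundtrip():
    s = Session(key=b'a1b2c3')
    msg = s.msg('execute_request', content=dict(code='a=5'))
    wire = s.serialize(msg, ident=b'engine-0')
    # wire is [ident, DELIM, HMAC, p_header, p_parent, p_content]
    idents, msg_list = s.feed_identities(wire)
    assert idents == [b'engine-0']
    msg2 = s.unserialize(msg_list)
    assert msg2['content'] == msg['content']
    assert msg2['header']['msg_id'] == msg['header']['msg_id']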