MinRK - cleanup per review
@@ -1,422 +1,422 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import os
27 27 import socket
28 28 import stat
29 29 import sys
30 30 import uuid
31 31
32 32 from multiprocessing import Process
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37 from zmq.utils import jsonapi as json
38 38
39 39 from IPython.config.application import boolean_flag
40 40 from IPython.core.profiledir import ProfileDir
41 41
42 42 from IPython.parallel.apps.baseapp import (
43 43 BaseParallelApplication,
44 44 base_aliases,
45 45 base_flags,
46 46 )
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
49 49
50 50 # from IPython.parallel.controller.controller import ControllerFactory
51 51 from IPython.zmq.session import Session
52 52 from IPython.parallel.controller.heartmonitor import HeartMonitor
53 53 from IPython.parallel.controller.hub import HubFactory
54 54 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
55 55 from IPython.parallel.controller.sqlitedb import SQLiteDB
56 56
57 from IPython.parallel.util import signal_children, split_url, ensure_bytes
57 from IPython.parallel.util import signal_children, split_url, asbytes
58 58
59 59 # conditional import of MongoDB backend class
60 60
61 61 try:
62 62 from IPython.parallel.controller.mongodb import MongoDB
63 63 except ImportError:
64 64 maybe_mongo = []
65 65 else:
66 66 maybe_mongo = [MongoDB]
67 67
68 68
69 69 #-----------------------------------------------------------------------------
70 70 # Module level variables
71 71 #-----------------------------------------------------------------------------
72 72
73 73
74 74 #: The default config file name for this application
75 75 default_config_file_name = u'ipcontroller_config.py'
76 76
77 77
78 78 _description = """Start the IPython controller for parallel computing.
79 79
80 80 The IPython controller provides a gateway between the IPython engines and
81 81 clients. The controller needs to be started before the engines and can be
82 82 configured using command line options or using a cluster directory. Cluster
83 83 directories contain config, log and security files and are usually located in
84 84 your ipython directory and named as "profile_name". See the `profile`
85 85 and `profile_dir` options for details.
86 86 """
87 87
88 88
89 89
90 90
91 91 #-----------------------------------------------------------------------------
92 92 # The main application
93 93 #-----------------------------------------------------------------------------
94 94 flags = {}
95 95 flags.update(base_flags)
96 96 flags.update({
97 97 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
98 98 'Use threads instead of processes for the schedulers'),
99 99 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
100 100 'use the SQLiteDB backend'),
101 101 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
102 102 'use the MongoDB backend'),
103 103 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
104 104 'use the in-memory DictDB backend'),
105 105 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
106 106 'reuse existing json connection files')
107 107 })
108 108
109 109 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
110 110 "Use HMAC digests for authentication of messages.",
111 111 "Don't authenticate messages."
112 112 ))
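Review note: each flag above is shorthand for a config assignment; the same settings can live in a profile's ipcontroller_config.py. A minimal sketch, with attribute paths taken verbatim from the flag dicts (get_config is the standard helper available inside IPython config files):

    # ipcontroller_config.py -- equivalent of passing the flags at the command line
    c = get_config()
    c.IPControllerApp.reuse_files = True                                    # --reuse
    c.HubFactory.db_class = 'IPython.parallel.controller.mongodb.MongoDB'   # --mongodb
    c.IPControllerApp.secure = False                                        # --no-secure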
113 113 aliases = dict(
114 114 reuse_files = 'IPControllerApp.reuse_files',
115 115 secure = 'IPControllerApp.secure',
116 116 ssh = 'IPControllerApp.ssh_server',
117 117 use_threads = 'IPControllerApp.use_threads',
118 118 location = 'IPControllerApp.location',
119 119
120 120 ident = 'Session.session',
121 121 user = 'Session.username',
122 122 exec_key = 'Session.keyfile',
123 123
124 124 url = 'HubFactory.url',
125 125 ip = 'HubFactory.ip',
126 126 transport = 'HubFactory.transport',
127 127 port = 'HubFactory.regport',
128 128
129 129 ping = 'HeartMonitor.period',
130 130
131 131 scheme = 'TaskScheduler.scheme_name',
132 132 hwm = 'TaskScheduler.hwm',
133 133 )
134 134 aliases.update(base_aliases)
135 135
136 136 class IPControllerApp(BaseParallelApplication):
137 137
138 138 name = u'ipcontroller'
139 139 description = _description
140 140 config_file_name = Unicode(default_config_file_name)
141 141 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
142 142
143 143 # change default to True
144 144 auto_create = Bool(True, config=True,
145 145 help="""Whether to create profile dir if it doesn't exist.""")
146 146
147 147 reuse_files = Bool(False, config=True,
148 148 help='Whether to reuse existing json connection files.'
149 149 )
150 150 secure = Bool(True, config=True,
151 151 help='Whether to use HMAC digests for extra message authentication.'
152 152 )
153 153 ssh_server = Unicode(u'', config=True,
154 154 help="""ssh url for clients to use when connecting to the Controller
155 155 processes. It should be of the form: [user@]server[:port]. The
156 156 Controller's listening addresses must be accessible from the ssh server""",
157 157 )
158 158 location = Unicode(u'', config=True,
159 159 help="""The external IP or domain name of the Controller, used for disambiguating
160 160 engine and client connections.""",
161 161 )
162 162 import_statements = List([], config=True,
163 163 help="import statements to be run at startup. Necessary in some environments"
164 164 )
165 165
166 166 use_threads = Bool(False, config=True,
167 167 help='Use threads instead of processes for the schedulers',
168 168 )
169 169
170 170 # internal
171 171 children = List()
172 172 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
173 173
174 174 def _use_threads_changed(self, name, old, new):
175 175 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
176 176
177 177 aliases = Dict(aliases)
178 178 flags = Dict(flags)
179 179
180 180
181 181 def save_connection_dict(self, fname, cdict):
182 182 """save a connection dict to json file."""
183 183 c = self.config
184 184 url = cdict['url']
185 185 location = cdict['location']
186 186 if not location:
187 187 try:
188 188 proto,ip,port = split_url(url)
189 189 except AssertionError:
190 190 pass
191 191 else:
192 192 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
193 193 cdict['location'] = location
194 194 fname = os.path.join(self.profile_dir.security_dir, fname)
195 195 with open(fname, 'wb') as f:
196 196 f.write(json.dumps(cdict, indent=2))
197 197 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
198 198
199 199 def load_config_from_json(self):
200 200 """load config from existing json connector files."""
201 201 c = self.config
202 202 # load from engine config
203 203 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
204 204 cfg = json.loads(f.read())
205 key = c.Session.key = ensure_bytes(cfg['exec_key'])
205 key = c.Session.key = asbytes(cfg['exec_key'])
206 206 xport,addr = cfg['url'].split('://')
207 207 c.HubFactory.engine_transport = xport
208 208 ip,ports = addr.split(':')
209 209 c.HubFactory.engine_ip = ip
210 210 c.HubFactory.regport = int(ports)
211 211 self.location = cfg['location']
212 212 # load client config
213 213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
214 214 cfg = json.loads(f.read())
215 215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
216 216 xport,addr = cfg['url'].split('://')
217 217 c.HubFactory.client_transport = xport
218 218 ip,ports = addr.split(':')
219 219 c.HubFactory.client_ip = ip
220 220 self.ssh_server = cfg['ssh']
221 221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
222 222
223 223 def init_hub(self):
224 224 c = self.config
225 225
226 226 self.do_import_statements()
227 227 reusing = self.reuse_files
228 228 if reusing:
229 229 try:
230 230 self.load_config_from_json()
231 231 except (AssertionError,IOError):
232 232 reusing=False
233 233 # check again, because reusing may have failed:
234 234 if reusing:
235 235 pass
236 236 elif self.secure:
237 237 key = str(uuid.uuid4())
238 238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
239 239 # with open(keyfile, 'w') as f:
240 240 # f.write(key)
241 241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
242 c.Session.key = ensure_bytes(key)
242 c.Session.key = asbytes(key)
243 243 else:
244 244 key = c.Session.key = b''
245 245
246 246 try:
247 247 self.factory = HubFactory(config=c, log=self.log)
248 248 # self.start_logging()
249 249 self.factory.init_hub()
250 250 except:
251 251 self.log.error("Couldn't construct the Controller", exc_info=True)
252 252 self.exit(1)
253 253
254 254 if not reusing:
255 255 # save to new json config files
256 256 f = self.factory
257 257 cdict = {'exec_key' : key,
258 258 'ssh' : self.ssh_server,
259 259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
260 260 'location' : self.location
261 261 }
262 262 self.save_connection_dict('ipcontroller-client.json', cdict)
263 263 edict = cdict
264 264 edict['url'] = "%s://%s:%s" % (f.client_transport, f.client_ip, f.regport)
265 265 self.save_connection_dict('ipcontroller-engine.json', edict)
266 266
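Review note: for reference, a hedged sketch of consuming the ipcontroller-client.json written above; the keys mirror cdict in init_hub, and the path assumes a default profile layout:

    import json, os

    fname = os.path.expanduser(
        '~/.ipython/profile_default/security/ipcontroller-client.json')
    with open(fname) as f:
        cfg = json.load(f)
    # cfg['url']      -> 'tcp://<client_ip>:<regport>'
    # cfg['exec_key'] -> HMAC key (empty string when secure is disabled)
    # cfg['ssh'], cfg['location'] -> tunneling and disambiguation hints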
267 267 #
268 268 def init_schedulers(self):
269 269 children = self.children
270 270 mq = import_item(str(self.mq_class))
271 271
272 272 hub = self.factory
273 273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
274 274 # IOPub relay (in a Process)
275 275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
276 276 q.bind_in(hub.client_info['iopub'])
277 277 q.bind_out(hub.engine_info['iopub'])
278 278 q.setsockopt_out(zmq.SUBSCRIBE, b'')
279 279 q.connect_mon(hub.monitor_url)
280 280 q.daemon=True
281 281 children.append(q)
282 282
283 283 # Multiplexer Queue (in a Process)
284 284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'in', b'out')
285 285 q.bind_in(hub.client_info['mux'])
286 286 q.setsockopt_in(zmq.IDENTITY, b'mux')
287 287 q.bind_out(hub.engine_info['mux'])
288 288 q.connect_mon(hub.monitor_url)
289 289 q.daemon=True
290 290 children.append(q)
291 291
292 292 # Control Queue (in a Process)
293 293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'incontrol', b'outcontrol')
294 294 q.bind_in(hub.client_info['control'])
295 295 q.setsockopt_in(zmq.IDENTITY, b'control')
296 296 q.bind_out(hub.engine_info['control'])
297 297 q.connect_mon(hub.monitor_url)
298 298 q.daemon=True
299 299 children.append(q)
300 300 try:
301 301 scheme = self.config.TaskScheduler.scheme_name
302 302 except AttributeError:
303 303 scheme = TaskScheduler.scheme_name.get_default_value()
304 304 # Task Queue (in a Process)
305 305 if scheme == 'pure':
306 306 self.log.warn("task::using pure XREQ Task scheduler")
307 307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, b'intask', b'outtask')
308 308 # q.setsockopt_out(zmq.HWM, hub.hwm)
309 309 q.bind_in(hub.client_info['task'][1])
310 310 q.setsockopt_in(zmq.IDENTITY, b'task')
311 311 q.bind_out(hub.engine_info['task'])
312 312 q.connect_mon(hub.monitor_url)
313 313 q.daemon=True
314 314 children.append(q)
315 315 elif scheme == 'none':
316 316 self.log.warn("task::using no Task scheduler")
317 317
318 318 else:
319 319 self.log.info("task::using Python %s Task scheduler"%scheme)
320 320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
321 321 hub.monitor_url, hub.client_info['notification'])
322 322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
323 323 log_url = self.log_url, config=dict(self.config))
324 324 if 'Process' in self.mq_class:
325 325 # run the Python scheduler in a Process
326 326 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
327 327 q.daemon=True
328 328 children.append(q)
329 329 else:
330 330 # single-threaded Controller
331 331 kwargs['in_thread'] = True
332 332 launch_scheduler(*sargs, **kwargs)
333 333
334 334
335 335 def save_urls(self):
336 336 """save the registration urls to files."""
337 337 c = self.config
338 338
339 339 sec_dir = self.profile_dir.security_dir
340 340 cf = self.factory
341 341
342 342 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
343 343 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
344 344
345 345 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
346 346 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
347 347
348 348
349 349 def do_import_statements(self):
350 350 statements = self.import_statements
351 351 for s in statements:
352 352 try:
353 353 self.log.info("Executing statement: '%s'" % s)
354 354 exec s in globals(), locals()
355 355 except:
356 356 self.log.error("Error running statement: %s" % s)
357 357
358 358 def forward_logging(self):
359 359 if self.log_url:
360 360 self.log.info("Forwarding logging to %s"%self.log_url)
361 361 context = zmq.Context.instance()
362 362 lsock = context.socket(zmq.PUB)
363 363 lsock.connect(self.log_url)
364 364 handler = PUBHandler(lsock)
365 365 self.log.removeHandler(self._log_handler)
366 366 handler.root_topic = 'controller'
367 367 handler.setLevel(self.log_level)
368 368 self.log.addHandler(handler)
369 369 self._log_handler = handler
370 370 # #
371 371
372 372 def initialize(self, argv=None):
373 373 super(IPControllerApp, self).initialize(argv)
374 374 self.forward_logging()
375 375 self.init_hub()
376 376 self.init_schedulers()
377 377
378 378 def start(self):
379 379 # Start the subprocesses:
380 380 self.factory.start()
381 381 child_procs = []
382 382 for child in self.children:
383 383 child.start()
384 384 if isinstance(child, ProcessMonitoredQueue):
385 385 child_procs.append(child.launcher)
386 386 elif isinstance(child, Process):
387 387 child_procs.append(child)
388 388 if child_procs:
389 389 signal_children(child_procs)
390 390
391 391 self.write_pid_file(overwrite=True)
392 392
393 393 try:
394 394 self.factory.loop.start()
395 395 except KeyboardInterrupt:
396 396 self.log.critical("Interrupted, Exiting...\n")
397 397
398 398
399 399
400 400 def launch_new_instance():
401 401 """Create and run the IPython controller"""
402 402 if sys.platform == 'win32':
403 403 # make sure we don't get called from a multiprocessing subprocess
404 404 # this can result in infinite Controllers being started on Windows
405 405 # which doesn't have a proper fork, so multiprocessing is wonky
406 406
407 407 # this only comes up when IPython has been installed using vanilla
408 408 # setuptools, and *not* distribute.
409 409 import multiprocessing
410 410 p = multiprocessing.current_process()
411 411 # the main process has name 'MainProcess'
412 412 # subprocesses will have names like 'Process-1'
413 413 if p.name != 'MainProcess':
414 414 # we are a subprocess, don't start another Controller!
415 415 return
416 416 app = IPControllerApp.instance()
417 417 app.initialize()
418 418 app.start()
419 419
420 420
421 421 if __name__ == '__main__':
422 422 launch_new_instance()
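Review note: the substantive change in this file is the ensure_bytes -> asbytes rename at the call sites for keys and socket identities. A sketch of the assumed helper semantics (the real implementation lives in IPython.parallel.util):

    def asbytes(s):
        # assumption: unicode is cast to ascii bytes, bytes pass through
        if isinstance(s, unicode):   # Python 2 codebase
            s = s.encode('ascii')
        return s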
@@ -1,301 +1,301 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import json
25 25 import os
26 26 import sys
27 27 import time
28 28
29 29 import zmq
30 30 from zmq.eventloop import ioloop
31 31
32 32 from IPython.core.profiledir import ProfileDir
33 33 from IPython.parallel.apps.baseapp import (
34 34 BaseParallelApplication,
35 35 base_aliases,
36 36 base_flags,
37 37 )
38 38 from IPython.zmq.log import EnginePUBHandler
39 39
40 40 from IPython.config.configurable import Configurable
41 41 from IPython.zmq.session import Session
42 42 from IPython.parallel.engine.engine import EngineFactory
43 43 from IPython.parallel.engine.streamkernel import Kernel
44 from IPython.parallel.util import disambiguate_url, ensure_bytes
44 from IPython.parallel.util import disambiguate_url, asbytes
45 45
46 46 from IPython.utils.importstring import import_item
47 47 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float
48 48
49 49
50 50 #-----------------------------------------------------------------------------
51 51 # Module level variables
52 52 #-----------------------------------------------------------------------------
53 53
54 54 #: The default config file name for this application
55 55 default_config_file_name = u'ipengine_config.py'
56 56
57 57 _description = """Start an IPython engine for parallel computing.
58 58
59 59 IPython engines run in parallel and perform computations on behalf of a client
60 60 and controller. A controller needs to be started before the engines. The
61 61 engine can be configured using command line options or using a cluster
62 62 directory. Cluster directories contain config, log and security files and are
63 63 usually located in your ipython directory and named as "profile_name".
64 64 See the `profile` and `profile_dir` options for details.
65 65 """
66 66
67 67
68 68 #-----------------------------------------------------------------------------
69 69 # MPI configuration
70 70 #-----------------------------------------------------------------------------
71 71
72 72 mpi4py_init = """from mpi4py import MPI as mpi
73 73 mpi.size = mpi.COMM_WORLD.Get_size()
74 74 mpi.rank = mpi.COMM_WORLD.Get_rank()
75 75 """
76 76
77 77
78 78 pytrilinos_init = """from PyTrilinos import Epetra
79 79 class SimpleStruct:
80 80 pass
81 81 mpi = SimpleStruct()
82 82 mpi.rank = 0
83 83 mpi.size = 0
84 84 """
85 85
86 86 class MPI(Configurable):
87 87 """Configurable for MPI initialization"""
88 88 use = Unicode('', config=True,
89 89 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
90 90 )
91 91
92 92 def _on_use_changed(self, old, new):
93 93 # load default init script if it's not set
94 94 if not self.init_script:
95 95 self.init_script = self.default_inits.get(new, '')
96 96
97 97 init_script = Unicode('', config=True,
98 98 help="Initialization code for MPI")
99 99
100 100 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
101 101 config=True)
102 102
103 103
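Review note: enabling MPI is a single trait; a hypothetical ipengine_config.py showing the hook. Setting MPI.use fires _on_use_changed above, which loads the matching script from default_inits:

    c = get_config()
    c.MPI.use = 'mpi4py'   # or 'pytrilinos'; the empty string disables MPI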
104 104 #-----------------------------------------------------------------------------
105 105 # Main application
106 106 #-----------------------------------------------------------------------------
107 107 aliases = dict(
108 108 file = 'IPEngineApp.url_file',
109 109 c = 'IPEngineApp.startup_command',
110 110 s = 'IPEngineApp.startup_script',
111 111
112 112 ident = 'Session.session',
113 113 user = 'Session.username',
114 114 exec_key = 'Session.keyfile',
115 115
116 116 url = 'EngineFactory.url',
117 117 ip = 'EngineFactory.ip',
118 118 transport = 'EngineFactory.transport',
119 119 port = 'EngineFactory.regport',
120 120 location = 'EngineFactory.location',
121 121
122 122 timeout = 'EngineFactory.timeout',
123 123
124 124 mpi = 'MPI.use',
125 125
126 126 )
127 127 aliases.update(base_aliases)
128 128
129 129 class IPEngineApp(BaseParallelApplication):
130 130
131 131 name = Unicode(u'ipengine')
132 132 description = Unicode(_description)
133 133 config_file_name = Unicode(default_config_file_name)
134 134 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
135 135
136 136 startup_script = Unicode(u'', config=True,
137 137 help='specify a script to be run at startup')
138 138 startup_command = Unicode('', config=True,
139 139 help='specify a command to be run at startup')
140 140
141 141 url_file = Unicode(u'', config=True,
142 142 help="""The full location of the file containing the connection information for
143 143 the controller. If this is not given, the file must be in the
144 144 security directory of the cluster directory. This location is
145 145 resolved using the `profile` or `profile_dir` options.""",
146 146 )
147 147 wait_for_url_file = Float(5, config=True,
148 148 help="""The maximum number of seconds to wait for url_file to exist.
149 149 This is useful for batch-systems and shared-filesystems where the
150 150 controller and engine are started at the same time and it
151 151 may take a moment for the controller to write the connector files.""")
152 152
153 153 url_file_name = Unicode(u'ipcontroller-engine.json')
154 154 log_url = Unicode('', config=True,
155 155 help="""The URL for the iploggerapp instance, for forwarding
156 156 logging to a central location.""")
157 157
158 158 aliases = Dict(aliases)
159 159
160 160 # def find_key_file(self):
161 161 # """Set the key file.
162 162 #
163 163 # Here we don't try to actually see if it exists or is valid, as that
164 164 # is handled by the connection logic.
165 165 # """
166 166 # config = self.master_config
167 167 # # Find the actual controller key file
168 168 # if not config.Global.key_file:
169 169 # try_this = os.path.join(
170 170 # config.Global.profile_dir,
171 171 # config.Global.security_dir,
172 172 # config.Global.key_file_name
173 173 # )
174 174 # config.Global.key_file = try_this
175 175
176 176 def find_url_file(self):
177 177 """Set the url file.
178 178
179 179 Here we don't try to actually see if it exists or is valid, as that
180 180 is handled by the connection logic.
181 181 """
182 182 config = self.config
183 183 # Find the actual controller key file
184 184 if not self.url_file:
185 185 self.url_file = os.path.join(
186 186 self.profile_dir.security_dir,
187 187 self.url_file_name
188 188 )
189 189 def init_engine(self):
190 190 # This is the working dir by now.
191 191 sys.path.insert(0, '')
192 192 config = self.config
193 193 # print config
194 194 self.find_url_file()
195 195
196 196 # was the url manually specified?
197 197 keys = set(self.config.EngineFactory.keys())
198 198 keys = keys.union(set(self.config.RegistrationFactory.keys()))
199 199
200 200 if keys.intersection(set(['ip', 'url', 'port'])):
201 201 # Connection info was specified, don't wait for the file
202 202 url_specified = True
203 203 self.wait_for_url_file = 0
204 204 else:
205 205 url_specified = False
206 206
207 207 if self.wait_for_url_file and not os.path.exists(self.url_file):
208 208 self.log.warn("url_file %r not found"%self.url_file)
209 209 self.log.warn("Waiting up to %.1f seconds for it to arrive."%self.wait_for_url_file)
210 210 tic = time.time()
211 211 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
212 212 # wait for url_file to exist, for up to wait_for_url_file seconds
213 213 time.sleep(0.1)
214 214
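Review note: the polling loop above generalizes to a small helper; a standalone sketch with the same semantics (poll until the file exists or the timeout elapses):

    import os, time

    def wait_for_file(path, timeout=5.0, interval=0.1):
        """Return True as soon as path exists, False after timeout seconds."""
        tic = time.time()
        while not os.path.exists(path) and (time.time() - tic) < timeout:
            time.sleep(interval)
        return os.path.exists(path)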
215 215 if os.path.exists(self.url_file):
216 216 self.log.info("Loading url_file %r"%self.url_file)
217 217 with open(self.url_file) as f:
218 218 d = json.loads(f.read())
219 219 if d['exec_key']:
220 config.Session.key = ensure_bytes(d['exec_key'])
220 config.Session.key = asbytes(d['exec_key'])
221 221 d['url'] = disambiguate_url(d['url'], d['location'])
222 222 config.EngineFactory.url = d['url']
223 223 config.EngineFactory.location = d['location']
224 224 elif not url_specified:
225 225 self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
226 226 self.exit(1)
227 227
228 228
229 229 try:
230 230 exec_lines = config.Kernel.exec_lines
231 231 except AttributeError:
232 232 config.Kernel.exec_lines = []
233 233 exec_lines = config.Kernel.exec_lines
234 234
235 235 if self.startup_script:
236 236 enc = sys.getfilesystemencoding() or 'utf8'
237 237 cmd="execfile(%r)"%self.startup_script.encode(enc)
238 238 exec_lines.append(cmd)
239 239 if self.startup_command:
240 240 exec_lines.append(self.startup_command)
241 241
242 242 # Create the underlying shell class and Engine
243 243 # shell_class = import_item(self.master_config.Global.shell_class)
244 244 # print self.config
245 245 try:
246 246 self.engine = EngineFactory(config=config, log=self.log)
247 247 except:
248 248 self.log.error("Couldn't start the Engine", exc_info=True)
249 249 self.exit(1)
250 250
251 251 def forward_logging(self):
252 252 if self.log_url:
253 253 self.log.info("Forwarding logging to %s"%self.log_url)
254 254 context = self.engine.context
255 255 lsock = context.socket(zmq.PUB)
256 256 lsock.connect(self.log_url)
257 257 self.log.removeHandler(self._log_handler)
258 258 handler = EnginePUBHandler(self.engine, lsock)
259 259 handler.setLevel(self.log_level)
260 260 self.log.addHandler(handler)
261 261 self._log_handler = handler
262 262 #
263 263 def init_mpi(self):
264 264 global mpi
265 265 self.mpi = MPI(config=self.config)
266 266
267 267 mpi_import_statement = self.mpi.init_script
268 268 if mpi_import_statement:
269 269 try:
270 270 self.log.info("Initializing MPI:")
271 271 self.log.info(mpi_import_statement)
272 272 exec mpi_import_statement in globals()
273 273 except:
274 274 mpi = None
275 275 else:
276 276 mpi = None
277 277
278 278 def initialize(self, argv=None):
279 279 super(IPEngineApp, self).initialize(argv)
280 280 self.init_mpi()
281 281 self.init_engine()
282 282 self.forward_logging()
283 283
284 284 def start(self):
285 285 self.engine.start()
286 286 try:
287 287 self.engine.loop.start()
288 288 except KeyboardInterrupt:
289 289 self.log.critical("Engine Interrupted, shutting down...\n")
290 290
291 291
292 292 def launch_new_instance():
293 293 """Create and run the IPython engine"""
294 294 app = IPEngineApp.instance()
295 295 app.initialize()
296 296 app.start()
297 297
298 298
299 299 if __name__ == '__main__':
300 300 launch_new_instance()
301 301
@@ -1,1428 +1,1428 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import sys
21 21 import time
22 22 import warnings
23 23 from datetime import datetime
24 24 from getpass import getpass
25 25 from pprint import pprint
26 26
27 27 pjoin = os.path.join
28 28
29 29 import zmq
30 30 # from zmq.eventloop import ioloop, zmqstream
31 31
32 32 from IPython.config.configurable import MultipleInstanceError
33 33 from IPython.core.application import BaseIPythonApplication
34 34
35 35 from IPython.utils.jsonutil import rekey
36 36 from IPython.utils.localinterfaces import LOCAL_IPS
37 37 from IPython.utils.path import get_ipython_dir
38 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
39 39 Dict, List, Bool, Set)
40 40 from IPython.external.decorator import decorator
41 41 from IPython.external.ssh import tunnel
42 42
43 43 from IPython.parallel import error
44 44 from IPython.parallel import util
45 45
46 46 from IPython.zmq.session import Session, Message
47 47
48 48 from .asyncresult import AsyncResult, AsyncHubResult
49 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
50 50 from .view import DirectView, LoadBalancedView
51 51
52 52 if sys.version_info[0] >= 3:
53 # xrange is used in a coupe 'isinstance' tests in py2
53 # xrange is used in a couple 'isinstance' tests in py2
54 54 # should be just 'range' in 3k
55 55 xrange = range
56 56
57 57 #--------------------------------------------------------------------------
58 58 # Decorators for Client methods
59 59 #--------------------------------------------------------------------------
60 60
61 61 @decorator
62 62 def spin_first(f, self, *args, **kwargs):
63 63 """Call spin() to sync state prior to calling the method."""
64 64 self.spin()
65 65 return f(self, *args, **kwargs)
66 66
67 67
68 68 #--------------------------------------------------------------------------
69 69 # Classes
70 70 #--------------------------------------------------------------------------
71 71
72 72 class Metadata(dict):
73 73 """Subclass of dict for initializing metadata values.
74 74
75 75 Attribute access works on keys.
76 76
77 77 These objects have a strict set of keys - errors will raise if you try
78 78 to add new keys.
79 79 """
80 80 def __init__(self, *args, **kwargs):
81 81 dict.__init__(self)
82 82 md = {'msg_id' : None,
83 83 'submitted' : None,
84 84 'started' : None,
85 85 'completed' : None,
86 86 'received' : None,
87 87 'engine_uuid' : None,
88 88 'engine_id' : None,
89 89 'follow' : None,
90 90 'after' : None,
91 91 'status' : None,
92 92
93 93 'pyin' : None,
94 94 'pyout' : None,
95 95 'pyerr' : None,
96 96 'stdout' : '',
97 97 'stderr' : '',
98 98 }
99 99 self.update(md)
100 100 self.update(dict(*args, **kwargs))
101 101
102 102 def __getattr__(self, key):
103 103 """getattr aliased to getitem"""
104 104 if key in self.iterkeys():
105 105 return self[key]
106 106 else:
107 107 raise AttributeError(key)
108 108
109 109 def __setattr__(self, key, value):
110 110 """setattr aliased to setitem, with strict"""
111 111 if key in self.iterkeys():
112 112 self[key] = value
113 113 else:
114 114 raise AttributeError(key)
115 115
116 116 def __setitem__(self, key, value):
117 117 """strict static key enforcement"""
118 118 if key in self.iterkeys():
119 119 dict.__setitem__(self, key, value)
120 120 else:
121 121 raise KeyError(key)
122 122
123 123
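Review note: a usage sketch for Metadata's strict-key contract described in the docstring; reads and writes work for the predeclared keys, anything else raises:

    md = Metadata(status='ok')
    md.status              # 'ok' -- attribute access aliases item access
    md['stdout'] += 'hi'   # existing keys are writable
    try:
        md.not_a_key = 1   # unknown keys are rejected...
    except AttributeError:
        pass               # ...with AttributeError, as the docstring promises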
124 124 class Client(HasTraits):
125 125 """A semi-synchronous client to the IPython ZMQ cluster
126 126
127 127 Parameters
128 128 ----------
129 129
130 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
131 131 Connection information for the Hub's registration. If a json connector
132 132 file is given, then likely no further configuration is necessary.
133 133 [Default: use profile]
134 134 profile : bytes
135 135 The name of the Cluster profile to be used to find connector information.
136 136 If run from an IPython application, the default profile will be the same
137 137 as the running application, otherwise it will be 'default'.
138 138 context : zmq.Context
139 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
140 140 debug : bool
141 141 flag for lots of message printing for debug purposes
142 142 timeout : int/float
143 143 time (in seconds) to wait for connection replies from the Hub
144 144 [Default: 10]
145 145
146 146 #-------------- session related args ----------------
147 147
148 148 config : Config object
149 149 If specified, this will be relayed to the Session for configuration
150 150 username : str
151 151 set username for the session object
152 152 packer : str (import_string) or callable
153 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
154 154 function to serialize messages. Must support same input as
155 155 JSON, and output must be bytes.
156 156 You can pass a callable directly as `pack`
157 157 unpacker : str (import_string) or callable
158 158 The inverse of packer. Only necessary if packer is specified as *not* one
159 159 of 'json' or 'pickle'.
160 160
161 161 #-------------- ssh related args ----------------
162 162 # These are args for configuring the ssh tunnel to be used
163 163 # credentials are used to forward connections over ssh to the Controller
164 164 # Note that the ip given in `addr` needs to be relative to sshserver
165 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
166 166 # and set sshserver as the same machine the Controller is on. However,
167 167 # the only requirement is that sshserver is able to see the Controller
168 168 # (i.e. is within the same trusted network).
169 169
170 170 sshserver : str
171 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
172 172 If keyfile or password is specified, and this is not, it will default to
173 173 the ip given in addr.
174 174 sshkey : str; path to public ssh key file
175 175 This specifies a key to be used in ssh login, default None.
176 176 Regular default ssh keys will be used without specifying this argument.
177 177 password : str
178 178 Your ssh password to sshserver. Note that if this is left None,
179 179 you will be prompted for it if passwordless key based login is unavailable.
180 180 paramiko : bool
181 181 flag for whether to use paramiko instead of shell ssh for tunneling.
182 182 [default: True on win32, False else]
183 183
184 184 ------- exec authentication args -------
185 185 If even localhost is untrusted, you can have some protection against
186 186 unauthorized execution by signing messages with HMAC digests.
187 187 Messages are still sent as cleartext, so if someone can snoop your
188 188 loopback traffic this will not protect your privacy, but will prevent
189 189 unauthorized execution.
190 190
191 191 exec_key : str
192 192 an authentication key or file containing a key
193 193 default: None
194 194
195 195
196 196 Attributes
197 197 ----------
198 198
199 199 ids : list of int engine IDs
200 200 requesting the ids attribute always synchronizes
201 201 the registration state. To request ids without synchronization,
202 202 use semi-private _ids attributes.
203 203
204 204 history : list of msg_ids
205 205 a list of msg_ids, keeping track of all the execution
206 206 messages you have submitted in order.
207 207
208 208 outstanding : set of msg_ids
209 209 a set of msg_ids that have been submitted, but whose
210 210 results have not yet been received.
211 211
212 212 results : dict
213 213 a dict of all our results, keyed by msg_id
214 214
215 215 block : bool
216 216 determines default behavior when block not specified
217 217 in execution methods
218 218
219 219 Methods
220 220 -------
221 221
222 222 spin
223 223 flushes incoming results and registration state changes
224 224 control methods spin, and requesting `ids` also ensures up to date
225 225
226 226 wait
227 227 wait on one or more msg_ids
228 228
229 229 execution methods
230 230 apply
231 231 legacy: execute, run
232 232
233 233 data movement
234 234 push, pull, scatter, gather
235 235
236 236 query methods
237 237 queue_status, get_result, purge, result_status
238 238
239 239 control methods
240 240 abort, shutdown
241 241
242 242 """
243 243
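Review note: the parameter list above covers several connection styles; a few hedged construction examples (paths and hostnames are placeholders):

    rc = Client()                                      # default profile's connector file
    rc = Client('/path/to/ipcontroller-client.json')   # explicit json connector file
    rc = Client(profile='mpi')                         # named cluster profile
    rc = Client(sshserver='user@gateway.example.com')  # tunnel the connections over ssh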
244 244
245 245 block = Bool(False)
246 246 outstanding = Set()
247 247 results = Instance('collections.defaultdict', (dict,))
248 248 metadata = Instance('collections.defaultdict', (Metadata,))
249 249 history = List()
250 250 debug = Bool(False)
251 251
252 252 profile=Unicode()
253 253 def _profile_default(self):
254 254 if BaseIPythonApplication.initialized():
255 255 # an IPython app *might* be running, try to get its profile
256 256 try:
257 257 return BaseIPythonApplication.instance().profile
258 258 except (AttributeError, MultipleInstanceError):
259 259 # could be a *different* subclass of config.Application,
260 260 # which would raise one of these two errors.
261 261 return u'default'
262 262 else:
263 263 return u'default'
264 264
265 265
266 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
267 267 _ids = List()
268 268 _connected=Bool(False)
269 269 _ssh=Bool(False)
270 270 _context = Instance('zmq.Context')
271 271 _config = Dict()
272 272 _engines=Instance(util.ReverseDict, (), {})
273 273 # _hub_socket=Instance('zmq.Socket')
274 274 _query_socket=Instance('zmq.Socket')
275 275 _control_socket=Instance('zmq.Socket')
276 276 _iopub_socket=Instance('zmq.Socket')
277 277 _notification_socket=Instance('zmq.Socket')
278 278 _mux_socket=Instance('zmq.Socket')
279 279 _task_socket=Instance('zmq.Socket')
280 280 _task_scheme=Unicode()
281 281 _closed = False
282 282 _ignored_control_replies=Int(0)
283 283 _ignored_hub_replies=Int(0)
284 284
285 285 def __new__(self, *args, **kw):
286 286 # don't raise on positional args
287 287 return HasTraits.__new__(self, **kw)
288 288
289 289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
290 290 context=None, debug=False, exec_key=None,
291 291 sshserver=None, sshkey=None, password=None, paramiko=None,
292 292 timeout=10, **extra_args
293 293 ):
294 294 if profile:
295 295 super(Client, self).__init__(debug=debug, profile=profile)
296 296 else:
297 297 super(Client, self).__init__(debug=debug)
298 298 if context is None:
299 299 context = zmq.Context.instance()
300 300 self._context = context
301 301
302 302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
303 303 if self._cd is not None:
304 304 if url_or_file is None:
305 305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
306 306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
307 307 " Please specify at least one of url_or_file or profile."
308 308
309 309 try:
310 310 util.validate_url(url_or_file)
311 311 except AssertionError:
312 312 if not os.path.exists(url_or_file):
313 313 if self._cd:
314 314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
315 315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
316 316 with open(url_or_file) as f:
317 317 cfg = json.loads(f.read())
318 318 else:
319 319 cfg = {'url':url_or_file}
320 320
321 321 # sync defaults from args, json:
322 322 if sshserver:
323 323 cfg['ssh'] = sshserver
324 324 if exec_key:
325 325 cfg['exec_key'] = exec_key
326 326 exec_key = cfg['exec_key']
327 327 location = cfg.setdefault('location', None)
328 328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
329 329 url = cfg['url']
330 330 proto,addr,port = util.split_url(url)
331 331 if location is not None and addr == '127.0.0.1':
332 332 # location specified, and connection is expected to be local
333 333 if location not in LOCAL_IPS and not sshserver:
334 334 # load ssh from JSON *only* if the controller is not on
335 335 # this machine
336 336 sshserver=cfg['ssh']
337 337 if location not in LOCAL_IPS and not sshserver:
338 338 # warn if no ssh specified, but SSH is probably needed
339 339 # This is only a warning, because the most likely cause
340 340 # is a local Controller on a laptop whose IP is dynamic
341 341 warnings.warn("""
342 342 Controller appears to be listening on localhost, but not on this machine.
343 343 If this is true, you should specify Client(...,sshserver='you@%s')
344 344 or instruct your controller to listen on an external IP."""%location,
345 345 RuntimeWarning)
346 346
347 347 self._config = cfg
348 348
349 349 self._ssh = bool(sshserver or sshkey or password)
350 350 if self._ssh and sshserver is None:
351 351 # default to ssh via localhost
352 352 sshserver = url.split('://')[1].split(':')[0]
353 353 if self._ssh and password is None:
354 354 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
355 355 password=False
356 356 else:
357 357 password = getpass("SSH Password for %s: "%sshserver)
358 358 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
359 359
360 360 # configure and construct the session
361 361 if exec_key is not None:
362 362 if os.path.isfile(exec_key):
363 363 extra_args['keyfile'] = exec_key
364 364 else:
365 exec_key = util.ensure_bytes(exec_key)
365 exec_key = util.asbytes(exec_key)
366 366 extra_args['key'] = exec_key
367 367 self.session = Session(**extra_args)
368 368
369 369 self._query_socket = self._context.socket(zmq.XREQ)
370 self._query_socket.setsockopt(zmq.IDENTITY, util.ensure_bytes(self.session.session))
370 self._query_socket.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
371 371 if self._ssh:
372 372 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
373 373 else:
374 374 self._query_socket.connect(url)
375 375
376 376 self.session.debug = self.debug
377 377
378 378 self._notification_handlers = {'registration_notification' : self._register_engine,
379 379 'unregistration_notification' : self._unregister_engine,
380 380 'shutdown_notification' : lambda msg: self.close(),
381 381 }
382 382 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
383 383 'apply_reply' : self._handle_apply_reply}
384 384 self._connect(sshserver, ssh_kwargs, timeout)
385 385
386 386 def __del__(self):
387 387 """cleanup sockets, but _not_ context."""
388 388 self.close()
389 389
390 390 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
391 391 if ipython_dir is None:
392 392 ipython_dir = get_ipython_dir()
393 393 if profile_dir is not None:
394 394 try:
395 395 self._cd = ProfileDir.find_profile_dir(profile_dir)
396 396 return
397 397 except ProfileDirError:
398 398 pass
399 399 elif profile is not None:
400 400 try:
401 401 self._cd = ProfileDir.find_profile_dir_by_name(
402 402 ipython_dir, profile)
403 403 return
404 404 except ProfileDirError:
405 405 pass
406 406 self._cd = None
407 407
408 408 def _update_engines(self, engines):
409 409 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
410 410 for k,v in engines.iteritems():
411 411 eid = int(k)
412 412 self._engines[eid] = v
413 413 self._ids.append(eid)
414 414 self._ids = sorted(self._ids)
415 415 if sorted(self._engines.keys()) != range(len(self._engines)) and \
416 416 self._task_scheme == 'pure' and self._task_socket:
417 417 self._stop_scheduling_tasks()
418 418
419 419 def _stop_scheduling_tasks(self):
420 420 """Stop scheduling tasks because an engine has been unregistered
421 421 from a pure ZMQ scheduler.
422 422 """
423 423 self._task_socket.close()
424 424 self._task_socket = None
425 425 msg = "An engine has been unregistered, and we are using pure " +\
426 426 "ZMQ task scheduling. Task farming will be disabled."
427 427 if self.outstanding:
428 428 msg += " If you were running tasks when this happened, " +\
429 429 "some `outstanding` msg_ids may never resolve."
430 430 warnings.warn(msg, RuntimeWarning)
431 431
432 432 def _build_targets(self, targets):
433 433 """Turn valid target IDs or 'all' into two lists:
434 434 (int_ids, uuids).
435 435 """
436 436 if not self._ids:
437 437 # flush notification socket if no engines yet, just in case
438 438 if not self.ids:
439 439 raise error.NoEnginesRegistered("Can't build targets without any engines")
440 440
441 441 if targets is None:
442 442 targets = self._ids
443 443 elif isinstance(targets, str):
444 444 if targets.lower() == 'all':
445 445 targets = self._ids
446 446 else:
447 447 raise TypeError("%r not valid str target, must be 'all'"%(targets))
448 448 elif isinstance(targets, int):
449 449 if targets < 0:
450 450 targets = self.ids[targets]
451 451 if targets not in self._ids:
452 452 raise IndexError("No such engine: %i"%targets)
453 453 targets = [targets]
454 454
455 455 if isinstance(targets, slice):
456 456 indices = range(len(self._ids))[targets]
457 457 ids = self.ids
458 458 targets = [ ids[i] for i in indices ]
459 459
460 460 if not isinstance(targets, (tuple, list, xrange)):
461 461 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
462 462
463 return [util.ensure_bytes(self._engines[t]) for t in targets], list(targets)
463 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
464 464
465 465 def _connect(self, sshserver, ssh_kwargs, timeout):
466 466 """setup all our socket connections to the cluster. This is called from
467 467 __init__."""
468 468
469 469 # Maybe allow reconnecting?
470 470 if self._connected:
471 471 return
472 472 self._connected=True
473 473
474 474 def connect_socket(s, url):
475 475 url = util.disambiguate_url(url, self._config['location'])
476 476 if self._ssh:
477 477 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
478 478 else:
479 479 return s.connect(url)
480 480
481 481 self.session.send(self._query_socket, 'connection_request')
482 482 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
483 483 poller = zmq.Poller()
484 484 poller.register(self._query_socket, zmq.POLLIN)
485 485 # poll expects milliseconds, timeout is seconds
486 486 evts = poller.poll(timeout*1000)
487 487 if not evts:
488 488 raise error.TimeoutError("Hub connection request timed out")
489 489 idents,msg = self.session.recv(self._query_socket,mode=0)
490 490 if self.debug:
491 491 pprint(msg)
492 492 msg = Message(msg)
493 493 content = msg.content
494 494 self._config['registration'] = dict(content)
495 495 if content.status == 'ok':
496 ident = util.ensure_bytes(self.session.session)
496 ident = util.asbytes(self.session.session)
497 497 if content.mux:
498 498 self._mux_socket = self._context.socket(zmq.XREQ)
499 499 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
500 500 connect_socket(self._mux_socket, content.mux)
501 501 if content.task:
502 502 self._task_scheme, task_addr = content.task
503 503 self._task_socket = self._context.socket(zmq.XREQ)
504 504 self._task_socket.setsockopt(zmq.IDENTITY, ident)
505 505 connect_socket(self._task_socket, task_addr)
506 506 if content.notification:
507 507 self._notification_socket = self._context.socket(zmq.SUB)
508 508 connect_socket(self._notification_socket, content.notification)
509 509 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
510 510 # if content.query:
511 511 # self._query_socket = self._context.socket(zmq.XREQ)
512 512 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
513 513 # connect_socket(self._query_socket, content.query)
514 514 if content.control:
515 515 self._control_socket = self._context.socket(zmq.XREQ)
516 516 self._control_socket.setsockopt(zmq.IDENTITY, ident)
517 517 connect_socket(self._control_socket, content.control)
518 518 if content.iopub:
519 519 self._iopub_socket = self._context.socket(zmq.SUB)
520 520 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
521 521 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
522 522 connect_socket(self._iopub_socket, content.iopub)
523 523 self._update_engines(dict(content.engines))
524 524 else:
525 525 self._connected = False
526 526 raise Exception("Failed to connect!")
527 527
528 528 #--------------------------------------------------------------------------
529 529 # handlers and callbacks for incoming messages
530 530 #--------------------------------------------------------------------------
531 531
532 532 def _unwrap_exception(self, content):
533 533 """unwrap exception, and remap engine_id to int."""
534 534 e = error.unwrap_exception(content)
535 535 # print e.traceback
536 536 if e.engine_info:
537 537 e_uuid = e.engine_info['engine_uuid']
538 538 eid = self._engines[e_uuid]
539 539 e.engine_info['engine_id'] = eid
540 540 return e
541 541
542 542 def _extract_metadata(self, header, parent, content):
543 543 md = {'msg_id' : parent['msg_id'],
544 544 'received' : datetime.now(),
545 545 'engine_uuid' : header.get('engine', None),
546 546 'follow' : parent.get('follow', []),
547 547 'after' : parent.get('after', []),
548 548 'status' : content['status'],
549 549 }
550 550
551 551 if md['engine_uuid'] is not None:
552 552 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
553 553
554 554 if 'date' in parent:
555 555 md['submitted'] = parent['date']
556 556 if 'started' in header:
557 557 md['started'] = header['started']
558 558 if 'date' in header:
559 559 md['completed'] = header['date']
560 560 return md
561 561
562 562 def _register_engine(self, msg):
563 563 """Register a new engine, and update our connection info."""
564 564 content = msg['content']
565 565 eid = content['id']
566 566 d = {eid : content['queue']}
567 567 self._update_engines(d)
568 568
569 569 def _unregister_engine(self, msg):
570 570 """Unregister an engine that has died."""
571 571 content = msg['content']
572 572 eid = int(content['id'])
573 573 if eid in self._ids:
574 574 self._ids.remove(eid)
575 575 uuid = self._engines.pop(eid)
576 576
577 577 self._handle_stranded_msgs(eid, uuid)
578 578
579 579 if self._task_socket and self._task_scheme == 'pure':
580 580 self._stop_scheduling_tasks()
581 581
582 582 def _handle_stranded_msgs(self, eid, uuid):
583 583 """Handle messages known to be on an engine when the engine unregisters.
584 584
585 585 It is possible that this will fire prematurely - that is, an engine will
586 586 go down after completing a result, and the client will be notified
587 587 of the unregistration and later receive the successful result.
588 588 """
589 589
590 590 outstanding = self._outstanding_dict[uuid]
591 591
592 592 for msg_id in list(outstanding):
593 593 if msg_id in self.results:
594 594 # we already have the result
595 595 continue
596 596 try:
597 597 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
598 598 except:
599 599 content = error.wrap_exception()
600 600 # build a fake message:
601 601 parent = {}
602 602 header = {}
603 603 parent['msg_id'] = msg_id
604 604 header['engine'] = uuid
605 605 header['date'] = datetime.now()
606 606 msg = dict(parent_header=parent, header=header, content=content)
607 607 self._handle_apply_reply(msg)
608 608
609 609 def _handle_execute_reply(self, msg):
610 610 """Save the reply to an execute_request into our results.
611 611
612 612 execute messages are never actually used. apply is used instead.
613 613 """
614 614
615 615 parent = msg['parent_header']
616 616 msg_id = parent['msg_id']
617 617 if msg_id not in self.outstanding:
618 618 if msg_id in self.history:
619 619 print ("got stale result: %s"%msg_id)
620 620 else:
621 621 print ("got unknown result: %s"%msg_id)
622 622 else:
623 623 self.outstanding.remove(msg_id)
624 624 self.results[msg_id] = self._unwrap_exception(msg['content'])
625 625
626 626 def _handle_apply_reply(self, msg):
627 627 """Save the reply to an apply_request into our results."""
628 628 parent = msg['parent_header']
629 629 msg_id = parent['msg_id']
630 630 if msg_id not in self.outstanding:
631 631 if msg_id in self.history:
632 632 print ("got stale result: %s"%msg_id)
633 633 print self.results[msg_id]
634 634 print msg
635 635 else:
636 636 print ("got unknown result: %s"%msg_id)
637 637 else:
638 638 self.outstanding.remove(msg_id)
639 639 content = msg['content']
640 640 header = msg['header']
641 641
642 642 # construct metadata:
643 643 md = self.metadata[msg_id]
644 644 md.update(self._extract_metadata(header, parent, content))
645 645 # is this redundant?
646 646 self.metadata[msg_id] = md
647 647
648 648 e_outstanding = self._outstanding_dict[md['engine_uuid']]
649 649 if msg_id in e_outstanding:
650 650 e_outstanding.remove(msg_id)
651 651
652 652 # construct result:
653 653 if content['status'] == 'ok':
654 654 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
655 655 elif content['status'] == 'aborted':
656 656 self.results[msg_id] = error.TaskAborted(msg_id)
657 657 elif content['status'] == 'resubmitted':
658 658 # TODO: handle resubmission
659 659 pass
660 660 else:
661 661 self.results[msg_id] = self._unwrap_exception(content)
662 662
663 663 def _flush_notifications(self):
664 664 """Flush notifications of engine registrations waiting
665 665 in ZMQ queue."""
666 666 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
667 667 while msg is not None:
668 668 if self.debug:
669 669 pprint(msg)
670 670 msg_type = msg['msg_type']
671 671 handler = self._notification_handlers.get(msg_type, None)
672 672 if handler is None:
673 673 raise Exception("Unhandled message type: %s"%msg_type)
674 674 else:
675 675 handler(msg)
676 676 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
677 677
678 678 def _flush_results(self, sock):
679 679 """Flush task or queue results waiting in ZMQ queue."""
680 680 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
681 681 while msg is not None:
682 682 if self.debug:
683 683 pprint(msg)
684 684 msg_type = msg['msg_type']
685 685 handler = self._queue_handlers.get(msg_type, None)
686 686 if handler is None:
687 687 raise Exception("Unhandled message type: %s"%msg_type)
688 688 else:
689 689 handler(msg)
690 690 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
691 691
692 692 def _flush_control(self, sock):
693 693 """Flush replies from the control channel waiting
694 694 in the ZMQ queue.
695 695
696 696 Currently: ignore them."""
697 697 if self._ignored_control_replies <= 0:
698 698 return
699 699 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
700 700 while msg is not None:
701 701 self._ignored_control_replies -= 1
702 702 if self.debug:
703 703 pprint(msg)
704 704 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
705 705
706 706 def _flush_ignored_control(self):
707 707 """flush ignored control replies"""
708 708 while self._ignored_control_replies > 0:
709 709 self.session.recv(self._control_socket)
710 710 self._ignored_control_replies -= 1
711 711
712 712 def _flush_ignored_hub_replies(self):
713 713 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
714 714 while msg is not None:
715 715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
716 716
717 717 def _flush_iopub(self, sock):
718 718 """Flush replies from the iopub channel waiting
719 719 in the ZMQ queue.
720 720 """
721 721 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
722 722 while msg is not None:
723 723 if self.debug:
724 724 pprint(msg)
725 725 parent = msg['parent_header']
726 726 msg_id = parent['msg_id']
727 727 content = msg['content']
728 728 header = msg['header']
729 729 msg_type = msg['msg_type']
730 730
731 731 # init metadata:
732 732 md = self.metadata[msg_id]
733 733
734 734 if msg_type == 'stream':
735 735 name = content['name']
736 736 s = md[name] or ''
737 737 md[name] = s + content['data']
738 738 elif msg_type == 'pyerr':
739 739 md.update({'pyerr' : self._unwrap_exception(content)})
740 740 elif msg_type == 'pyin':
741 741 md.update({'pyin' : content['code']})
742 742 else:
743 743 md.update({msg_type : content.get('data', '')})
744 744
745 745 # redundant?
746 746 self.metadata[msg_id] = md
747 747
748 748 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
749 749
750 750 #--------------------------------------------------------------------------
751 751 # len, getitem
752 752 #--------------------------------------------------------------------------
753 753
754 754 def __len__(self):
755 755 """len(client) returns # of engines."""
756 756 return len(self.ids)
757 757
758 758 def __getitem__(self, key):
759 759 """index access returns DirectView multiplexer objects
760 760
761 761 Must be int, slice, or list/tuple/xrange of ints"""
762 762 if not isinstance(key, (int, slice, tuple, list, xrange)):
763 763 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
764 764 else:
765 765 return self.direct_view(key)
766 766
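Review note: index access builds DirectViews; the accepted key forms, per the isinstance check above (rc is assumed to be a connected Client):

    rc[0]        # one engine, by id
    rc[-1]       # negative index into rc.ids
    rc[::2]      # slice over registered engines
    rc[[0, 2]]   # explicit list of engine ids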
767 767 #--------------------------------------------------------------------------
768 768 # Begin public methods
769 769 #--------------------------------------------------------------------------
770 770
771 771 @property
772 772 def ids(self):
773 773 """Always up-to-date ids property."""
774 774 self._flush_notifications()
775 775 # always copy:
776 776 return list(self._ids)
777 777
778 778 def close(self):
779 779 if self._closed:
780 780 return
781 781 snames = filter(lambda n: n.endswith('socket'), dir(self))
782 782 for socket in map(lambda name: getattr(self, name), snames):
783 783 if isinstance(socket, zmq.Socket) and not socket.closed:
784 784 socket.close()
785 785 self._closed = True
786 786
787 787 def spin(self):
788 788 """Flush any registration notifications and execution results
789 789 waiting in the ZMQ queue.
790 790 """
791 791 if self._notification_socket:
792 792 self._flush_notifications()
793 793 if self._mux_socket:
794 794 self._flush_results(self._mux_socket)
795 795 if self._task_socket:
796 796 self._flush_results(self._task_socket)
797 797 if self._control_socket:
798 798 self._flush_control(self._control_socket)
799 799 if self._iopub_socket:
800 800 self._flush_iopub(self._iopub_socket)
801 801 if self._query_socket:
802 802 self._flush_ignored_hub_replies()
803 803
804 804 def wait(self, jobs=None, timeout=-1):
805 805 """waits on one or more `jobs`, for up to `timeout` seconds.
806 806
807 807 Parameters
808 808 ----------
809 809
810 810 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
811 811 ints are indices to self.history
812 812 strs are msg_ids
813 813 default: wait on all outstanding messages
814 814 timeout : float
815 815 a time in seconds, after which to give up.
816 816 default is -1, which means no timeout
817 817
818 818 Returns
819 819 -------
820 820
821 821 True : when all msg_ids are done
822 822 False : timeout reached, some msg_ids still outstanding
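
        Examples
        --------
        ::

            # a minimal sketch (assumes `rc` is a connected Client):
            In [1]: ar = rc[:].apply_async(lambda : 1)

            In [2]: rc.wait(ar, timeout=5)
            Out[2]: True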
823 823 """
824 824 tic = time.time()
825 825 if jobs is None:
826 826 theids = self.outstanding
827 827 else:
828 828             if isinstance(jobs, (int, basestring, AsyncResult)):
829 829 jobs = [jobs]
830 830 theids = set()
831 831 for job in jobs:
832 832 if isinstance(job, int):
833 833 # index access
834 834 job = self.history[job]
835 835 elif isinstance(job, AsyncResult):
836 836 map(theids.add, job.msg_ids)
837 837 continue
838 838 theids.add(job)
839 839 if not theids.intersection(self.outstanding):
840 840 return True
841 841 self.spin()
842 842 while theids.intersection(self.outstanding):
843 843 if timeout >= 0 and ( time.time()-tic ) > timeout:
844 844 break
845 845 time.sleep(1e-3)
846 846 self.spin()
847 847 return len(theids.intersection(self.outstanding)) == 0
848 848
849 849 #--------------------------------------------------------------------------
850 850 # Control methods
851 851 #--------------------------------------------------------------------------
852 852
853 853 @spin_first
854 854 def clear(self, targets=None, block=None):
855 855 """Clear the namespace in target(s)."""
856 856 block = self.block if block is None else block
857 857 targets = self._build_targets(targets)[0]
858 858 for t in targets:
859 859 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
860 860 error = False
861 861 if block:
862 862 self._flush_ignored_control()
863 863 for i in range(len(targets)):
864 864 idents,msg = self.session.recv(self._control_socket,0)
865 865 if self.debug:
866 866 pprint(msg)
867 867 if msg['content']['status'] != 'ok':
868 868 error = self._unwrap_exception(msg['content'])
869 869 else:
870 870 self._ignored_control_replies += len(targets)
871 871 if error:
872 872 raise error
873 873
874 874
875 875 @spin_first
876 876 def abort(self, jobs=None, targets=None, block=None):
877 877 """Abort specific jobs from the execution queues of target(s).
878 878
879 879 This is a mechanism to prevent jobs that have already been submitted
880 880 from executing.
881 881
882 882 Parameters
883 883 ----------
884 884
885 885 jobs : msg_id, list of msg_ids, or AsyncResult
886 886 The jobs to be aborted
887 887
888 888
889 889 """
890 890 block = self.block if block is None else block
891 891 targets = self._build_targets(targets)[0]
892 892 msg_ids = []
893 893 if isinstance(jobs, (basestring,AsyncResult)):
894 894 jobs = [jobs]
895 895 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
896 896 if bad_ids:
897 897 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
898 898 for j in jobs:
899 899 if isinstance(j, AsyncResult):
900 900 msg_ids.extend(j.msg_ids)
901 901 else:
902 902 msg_ids.append(j)
903 903 content = dict(msg_ids=msg_ids)
904 904 for t in targets:
905 905 self.session.send(self._control_socket, 'abort_request',
906 906 content=content, ident=t)
907 907 error = False
908 908 if block:
909 909 self._flush_ignored_control()
910 910 for i in range(len(targets)):
911 911 idents,msg = self.session.recv(self._control_socket,0)
912 912 if self.debug:
913 913 pprint(msg)
914 914 if msg['content']['status'] != 'ok':
915 915 error = self._unwrap_exception(msg['content'])
916 916 else:
917 917 self._ignored_control_replies += len(targets)
918 918 if error:
919 919 raise error
920 920
921 921 @spin_first
922 922 def shutdown(self, targets=None, restart=False, hub=False, block=None):
923 923 """Terminates one or more engine processes, optionally including the hub."""
924 924 block = self.block if block is None else block
925 925 if hub:
926 926 targets = 'all'
927 927 targets = self._build_targets(targets)[0]
928 928 for t in targets:
929 929 self.session.send(self._control_socket, 'shutdown_request',
930 930 content={'restart':restart},ident=t)
931 931 error = False
932 932 if block or hub:
933 933 self._flush_ignored_control()
934 934 for i in range(len(targets)):
935 935 idents,msg = self.session.recv(self._control_socket, 0)
936 936 if self.debug:
937 937 pprint(msg)
938 938 if msg['content']['status'] != 'ok':
939 939 error = self._unwrap_exception(msg['content'])
940 940 else:
941 941 self._ignored_control_replies += len(targets)
942 942
943 943 if hub:
944 944 time.sleep(0.25)
945 945 self.session.send(self._query_socket, 'shutdown_request')
946 946 idents,msg = self.session.recv(self._query_socket, 0)
947 947 if self.debug:
948 948 pprint(msg)
949 949 if msg['content']['status'] != 'ok':
950 950 error = self._unwrap_exception(msg['content'])
951 951
952 952 if error:
953 953 raise error
954 954
955 955 #--------------------------------------------------------------------------
956 956 # Execution related methods
957 957 #--------------------------------------------------------------------------
958 958
959 959 def _maybe_raise(self, result):
960 960 """wrapper for maybe raising an exception if apply failed."""
961 961 if isinstance(result, error.RemoteError):
962 962 raise result
963 963
964 964 return result
965 965
966 966 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
967 967 ident=None):
968 968 """construct and send an apply message via a socket.
969 969
970 970 This is the principal method with which all engine execution is performed by views.
971 971 """
972 972
973 973 assert not self._closed, "cannot use me anymore, I'm closed!"
974 974 # defaults:
975 975 args = args if args is not None else []
976 976 kwargs = kwargs if kwargs is not None else {}
977 977 subheader = subheader if subheader is not None else {}
978 978
979 979 # validate arguments
980 980 if not callable(f):
981 981 raise TypeError("f must be callable, not %s"%type(f))
982 982 if not isinstance(args, (tuple, list)):
983 983 raise TypeError("args must be tuple or list, not %s"%type(args))
984 984 if not isinstance(kwargs, dict):
985 985 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
986 986 if not isinstance(subheader, dict):
987 987 raise TypeError("subheader must be dict, not %s"%type(subheader))
988 988
989 989 bufs = util.pack_apply_message(f,args,kwargs)
990 990
991 991 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
992 992 subheader=subheader, track=track)
993 993
994 994 msg_id = msg['msg_id']
995 995 self.outstanding.add(msg_id)
996 996 if ident:
997 997 # possibly routed to a specific engine
998 998 if isinstance(ident, list):
999 999 ident = ident[-1]
1000 1000 if ident in self._engines.values():
1001 1001 # save for later, in case of engine death
1002 1002 self._outstanding_dict[ident].add(msg_id)
1003 1003 self.history.append(msg_id)
1004 1004 self.metadata[msg_id]['submitted'] = datetime.now()
1005 1005
1006 1006 return msg
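    # a sketch of how a View typically drives this method (names assumed,
    # not part of this changeset):
    #   msg = client.send_apply_message(client._mux_socket, f, args=(1, 2),
    #                                   ident=engine_ident)
    #   msg_id = msg['msg_id']  # now tracked in client.outstanding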
1007 1007
1008 1008 #--------------------------------------------------------------------------
1009 1009 # construct a View object
1010 1010 #--------------------------------------------------------------------------
1011 1011
1012 1012 def load_balanced_view(self, targets=None):
1013 1013         """construct a LoadBalancedView object.
1014 1014
1015 1015 If no arguments are specified, create a LoadBalancedView
1016 1016 using all engines.
1017 1017
1018 1018 Parameters
1019 1019 ----------
1020 1020
1021 1021 targets: list,slice,int,etc. [default: use all engines]
1022 1022 The subset of engines across which to load-balance
1023 1023 """
1024 1024 if targets is not None:
1025 1025 targets = self._build_targets(targets)[1]
1026 1026 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
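    # a usage sketch (assumes `rc` is a connected Client):
    #   v = rc.load_balanced_view()
    #   ar = v.apply_async(lambda x: x**2, 10)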
1027 1027
1028 1028 def direct_view(self, targets='all'):
1029 1029 """construct a DirectView object.
1030 1030
1031 1031 If no targets are specified, create a DirectView
1032 1032 using all engines.
1033 1033
1034 1034 Parameters
1035 1035 ----------
1036 1036
1037 1037 targets: list,slice,int,etc. [default: use all engines]
1038 1038 The engines to use for the View
1039 1039 """
1040 1040 single = isinstance(targets, int)
1041 1041 targets = self._build_targets(targets)[1]
1042 1042 if single:
1043 1043 targets = targets[0]
1044 1044 return DirectView(client=self, socket=self._mux_socket, targets=targets)
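    # a usage sketch (assumes `rc` is a connected Client):
    #   dv = rc.direct_view('all')
    #   dv.execute('import numpy')  # runs on every engine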
1045 1045
1046 1046 #--------------------------------------------------------------------------
1047 1047 # Query methods
1048 1048 #--------------------------------------------------------------------------
1049 1049
1050 1050 @spin_first
1051 1051 def get_result(self, indices_or_msg_ids=None, block=None):
1052 1052 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1053 1053
1054 1054 If the client already has the results, no request to the Hub will be made.
1055 1055
1056 1056 This is a convenient way to construct AsyncResult objects, which are wrappers
1057 1057 that include metadata about execution, and allow for awaiting results that
1058 1058 were not submitted by this Client.
1059 1059
1060 1060 It can also be a convenient way to retrieve the metadata associated with
1061 1061         blocking execution, since it always retrieves the metadata as well.
1062 1062
1063 1063 Examples
1064 1064 --------
1065 1065 ::
1066 1066
1067 1067             In [10]: ar = client.get_result(client.history[-1])
1068 1068
1069 1069 Parameters
1070 1070 ----------
1071 1071
1072 1072 indices_or_msg_ids : integer history index, str msg_id, or list of either
1073 1073             The indices or msg_ids of the results to be retrieved
1074 1074
1075 1075 block : bool
1076 1076 Whether to wait for the result to be done
1077 1077
1078 1078 Returns
1079 1079 -------
1080 1080
1081 1081 AsyncResult
1082 1082 A single AsyncResult object will always be returned.
1083 1083
1084 1084 AsyncHubResult
1085 1085 A subclass of AsyncResult that retrieves results from the Hub
1086 1086
1087 1087 """
1088 1088 block = self.block if block is None else block
1089 1089 if indices_or_msg_ids is None:
1090 1090 indices_or_msg_ids = -1
1091 1091
1092 1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1093 1093 indices_or_msg_ids = [indices_or_msg_ids]
1094 1094
1095 1095 theids = []
1096 1096 for id in indices_or_msg_ids:
1097 1097 if isinstance(id, int):
1098 1098 id = self.history[id]
1099 1099             if not isinstance(id, basestring):
1100 1100 raise TypeError("indices must be str or int, not %r"%id)
1101 1101 theids.append(id)
1102 1102
1103 1103 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1104 1104 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1105 1105
1106 1106 if remote_ids:
1107 1107 ar = AsyncHubResult(self, msg_ids=theids)
1108 1108 else:
1109 1109 ar = AsyncResult(self, msg_ids=theids)
1110 1110
1111 1111 if block:
1112 1112 ar.wait()
1113 1113
1114 1114 return ar
1115 1115
1116 1116 @spin_first
1117 1117 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1118 1118 """Resubmit one or more tasks.
1119 1119
1120 1120         In-flight tasks may not be resubmitted.
1121 1121
1122 1122 Parameters
1123 1123 ----------
1124 1124
1125 1125 indices_or_msg_ids : integer history index, str msg_id, or list of either
1126 1126             The indices or msg_ids of the results to be retrieved
1127 1127
1128 1128 block : bool
1129 1129 Whether to wait for the result to be done
1130 1130
1131 1131 Returns
1132 1132 -------
1133 1133
1134 1134 AsyncHubResult
1135 1135 A subclass of AsyncResult that retrieves results from the Hub
1136 1136
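        Examples
        --------
        ::

            # a sketch (assumes `rc` is a connected Client with prior tasks):
            In [1]: ar = rc.resubmit(rc.history[-1])
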
1137 1137 """
1138 1138 block = self.block if block is None else block
1139 1139 if indices_or_msg_ids is None:
1140 1140 indices_or_msg_ids = -1
1141 1141
1142 1142 if not isinstance(indices_or_msg_ids, (list,tuple)):
1143 1143 indices_or_msg_ids = [indices_or_msg_ids]
1144 1144
1145 1145 theids = []
1146 1146 for id in indices_or_msg_ids:
1147 1147 if isinstance(id, int):
1148 1148 id = self.history[id]
1149 1149             if not isinstance(id, basestring):
1150 1150 raise TypeError("indices must be str or int, not %r"%id)
1151 1151 theids.append(id)
1152 1152
1153 1153 for msg_id in theids:
1154 1154 self.outstanding.discard(msg_id)
1155 1155 if msg_id in self.history:
1156 1156 self.history.remove(msg_id)
1157 1157 self.results.pop(msg_id, None)
1158 1158 self.metadata.pop(msg_id, None)
1159 1159 content = dict(msg_ids = theids)
1160 1160
1161 1161 self.session.send(self._query_socket, 'resubmit_request', content)
1162 1162
1163 1163 zmq.select([self._query_socket], [], [])
1164 1164 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1165 1165 if self.debug:
1166 1166 pprint(msg)
1167 1167 content = msg['content']
1168 1168 if content['status'] != 'ok':
1169 1169 raise self._unwrap_exception(content)
1170 1170
1171 1171 ar = AsyncHubResult(self, msg_ids=theids)
1172 1172
1173 1173 if block:
1174 1174 ar.wait()
1175 1175
1176 1176 return ar
1177 1177
1178 1178 @spin_first
1179 1179 def result_status(self, msg_ids, status_only=True):
1180 1180 """Check on the status of the result(s) of the apply request with `msg_ids`.
1181 1181
1182 1182 If status_only is False, then the actual results will be retrieved, else
1183 1183 only the status of the results will be checked.
1184 1184
1185 1185 Parameters
1186 1186 ----------
1187 1187
1188 1188 msg_ids : list of msg_ids
1189 1189 if int:
1190 1190 Passed as index to self.history for convenience.
1191 1191 status_only : bool (default: True)
1192 1192 if False:
1193 1193 Retrieve the actual results of completed tasks.
1194 1194
1195 1195 Returns
1196 1196 -------
1197 1197
1198 1198 results : dict
1199 1199 There will always be the keys 'pending' and 'completed', which will
1200 1200 be lists of msg_ids that are incomplete or complete. If `status_only`
1201 1201 is False, then completed results will be keyed by their `msg_id`.
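
        Examples
        --------
        ::

            # a sketch (assumes `rc` is a connected Client with some history):
            In [1]: rc.result_status(rc.history[-3:])  # -> dict with 'pending'/'completed' msg_id lists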
1202 1202 """
1203 1203 if not isinstance(msg_ids, (list,tuple)):
1204 1204 msg_ids = [msg_ids]
1205 1205
1206 1206 theids = []
1207 1207 for msg_id in msg_ids:
1208 1208 if isinstance(msg_id, int):
1209 1209 msg_id = self.history[msg_id]
1210 1210 if not isinstance(msg_id, basestring):
1211 1211 raise TypeError("msg_ids must be str, not %r"%msg_id)
1212 1212 theids.append(msg_id)
1213 1213
1214 1214 completed = []
1215 1215 local_results = {}
1216 1216
1217 1217 # comment this block out to temporarily disable local shortcut:
1218 1218         for msg_id in theids[:]:  # iterate over a copy, since we remove from theids
1219 1219 if msg_id in self.results:
1220 1220 completed.append(msg_id)
1221 1221 local_results[msg_id] = self.results[msg_id]
1222 1222 theids.remove(msg_id)
1223 1223
1224 1224 if theids: # some not locally cached
1225 1225 content = dict(msg_ids=theids, status_only=status_only)
1226 1226 msg = self.session.send(self._query_socket, "result_request", content=content)
1227 1227 zmq.select([self._query_socket], [], [])
1228 1228 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1229 1229 if self.debug:
1230 1230 pprint(msg)
1231 1231 content = msg['content']
1232 1232 if content['status'] != 'ok':
1233 1233 raise self._unwrap_exception(content)
1234 1234 buffers = msg['buffers']
1235 1235 else:
1236 1236 content = dict(completed=[],pending=[])
1237 1237
1238 1238 content['completed'].extend(completed)
1239 1239
1240 1240 if status_only:
1241 1241 return content
1242 1242
1243 1243 failures = []
1244 1244 # load cached results into result:
1245 1245 content.update(local_results)
1246 1246
1247 1247 # update cache with results:
1248 1248 for msg_id in sorted(theids):
1249 1249 if msg_id in content['completed']:
1250 1250 rec = content[msg_id]
1251 1251 parent = rec['header']
1252 1252 header = rec['result_header']
1253 1253 rcontent = rec['result_content']
1254 1254 iodict = rec['io']
1255 1255 if isinstance(rcontent, str):
1256 1256 rcontent = self.session.unpack(rcontent)
1257 1257
1258 1258 md = self.metadata[msg_id]
1259 1259 md.update(self._extract_metadata(header, parent, rcontent))
1260 1260 md.update(iodict)
1261 1261
1262 1262 if rcontent['status'] == 'ok':
1263 1263 res,buffers = util.unserialize_object(buffers)
1264 1264 else:
1265 1265 print rcontent
1266 1266 res = self._unwrap_exception(rcontent)
1267 1267 failures.append(res)
1268 1268
1269 1269 self.results[msg_id] = res
1270 1270 content[msg_id] = res
1271 1271
1272 1272 if len(theids) == 1 and failures:
1273 1273 raise failures[0]
1274 1274
1275 1275 error.collect_exceptions(failures, "result_status")
1276 1276 return content
1277 1277
1278 1278 @spin_first
1279 1279 def queue_status(self, targets='all', verbose=False):
1280 1280 """Fetch the status of engine queues.
1281 1281
1282 1282 Parameters
1283 1283 ----------
1284 1284
1285 1285 targets : int/str/list of ints/strs
1286 1286 the engines whose states are to be queried.
1287 1287 default : all
1288 1288 verbose : bool
1289 1289 Whether to return lengths only, or lists of ids for each element
1290 1290             Whether to return lengths only, or lists of ids for each element"""
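        # a sketch of the (assumed) return shape for a one-engine cluster:
        #   rc.queue_status() -> {0: {'queue': 0, 'completed': 5, 'tasks': 0}, u'unassigned': 0}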
1291 1291 engine_ids = self._build_targets(targets)[1]
1292 1292 content = dict(targets=engine_ids, verbose=verbose)
1293 1293 self.session.send(self._query_socket, "queue_request", content=content)
1294 1294 idents,msg = self.session.recv(self._query_socket, 0)
1295 1295 if self.debug:
1296 1296 pprint(msg)
1297 1297 content = msg['content']
1298 1298 status = content.pop('status')
1299 1299 if status != 'ok':
1300 1300 raise self._unwrap_exception(content)
1301 1301 content = rekey(content)
1302 1302 if isinstance(targets, int):
1303 1303 return content[targets]
1304 1304 else:
1305 1305 return content
1306 1306
1307 1307 @spin_first
1308 1308 def purge_results(self, jobs=[], targets=[]):
1309 1309 """Tell the Hub to forget results.
1310 1310
1311 1311 Individual results can be purged by msg_id, or the entire
1312 1312 history of specific targets can be purged.
1313 1313
1314 1314 Use `purge_results('all')` to scrub everything from the Hub's db.
1315 1315
1316 1316 Parameters
1317 1317 ----------
1318 1318
1319 1319 jobs : str or list of str or AsyncResult objects
1320 1320 the msg_ids whose results should be forgotten.
1321 1321 targets : int/str/list of ints/strs
1322 1322 The targets, by int_id, whose entire history is to be purged.
1323 1323
1324 1324 default : None
1325 1325 """
1326 1326 if not targets and not jobs:
1327 1327 raise ValueError("Must specify at least one of `targets` and `jobs`")
1328 1328 if targets:
1329 1329 targets = self._build_targets(targets)[1]
1330 1330
1331 1331 # construct msg_ids from jobs
1332 1332 if jobs == 'all':
1333 1333 msg_ids = jobs
1334 1334 else:
1335 1335 msg_ids = []
1336 1336 if isinstance(jobs, (basestring,AsyncResult)):
1337 1337 jobs = [jobs]
1338 1338 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1339 1339 if bad_ids:
1340 1340 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1341 1341 for j in jobs:
1342 1342 if isinstance(j, AsyncResult):
1343 1343 msg_ids.extend(j.msg_ids)
1344 1344 else:
1345 1345 msg_ids.append(j)
1346 1346
1347 1347 content = dict(engine_ids=targets, msg_ids=msg_ids)
1348 1348 self.session.send(self._query_socket, "purge_request", content=content)
1349 1349 idents, msg = self.session.recv(self._query_socket, 0)
1350 1350 if self.debug:
1351 1351 pprint(msg)
1352 1352 content = msg['content']
1353 1353 if content['status'] != 'ok':
1354 1354 raise self._unwrap_exception(content)
1355 1355
1356 1356 @spin_first
1357 1357 def hub_history(self):
1358 1358 """Get the Hub's history
1359 1359
1360 1360 Just like the Client, the Hub has a history, which is a list of msg_ids.
1361 1361 This will contain the history of all clients, and, depending on configuration,
1362 1362 may contain history across multiple cluster sessions.
1363 1363
1364 1364 Any msg_id returned here is a valid argument to `get_result`.
1365 1365
1366 1366 Returns
1367 1367 -------
1368 1368
1369 1369 msg_ids : list of strs
1370 1370 list of all msg_ids, ordered by task submission time.
1371 1371 """
1372 1372
1373 1373 self.session.send(self._query_socket, "history_request", content={})
1374 1374 idents, msg = self.session.recv(self._query_socket, 0)
1375 1375
1376 1376 if self.debug:
1377 1377 pprint(msg)
1378 1378 content = msg['content']
1379 1379 if content['status'] != 'ok':
1380 1380 raise self._unwrap_exception(content)
1381 1381 else:
1382 1382 return content['history']
1383 1383
1384 1384 @spin_first
1385 1385 def db_query(self, query, keys=None):
1386 1386 """Query the Hub's TaskRecord database
1387 1387
1388 1388 This will return a list of task record dicts that match `query`
1389 1389
1390 1390 Parameters
1391 1391 ----------
1392 1392
1393 1393 query : mongodb query dict
1394 1394 The search dict. See mongodb query docs for details.
1395 1395 keys : list of strs [optional]
1396 1396 The subset of keys to be returned. The default is to fetch everything but buffers.
1397 1397 'msg_id' will *always* be included.
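
        Examples
        --------
        ::

            # a sketch (assumes `rc` is a connected Client; operator support
            # depends on the configured DB backend):
            In [1]: rc.db_query({'completed': {'$ne': None}}, keys=['msg_id', 'completed'])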
1398 1398 """
1399 1399 if isinstance(keys, basestring):
1400 1400 keys = [keys]
1401 1401 content = dict(query=query, keys=keys)
1402 1402 self.session.send(self._query_socket, "db_request", content=content)
1403 1403 idents, msg = self.session.recv(self._query_socket, 0)
1404 1404 if self.debug:
1405 1405 pprint(msg)
1406 1406 content = msg['content']
1407 1407 if content['status'] != 'ok':
1408 1408 raise self._unwrap_exception(content)
1409 1409
1410 1410 records = content['records']
1411 1411
1412 1412 buffer_lens = content['buffer_lens']
1413 1413 result_buffer_lens = content['result_buffer_lens']
1414 1414 buffers = msg['buffers']
1415 1415 has_bufs = buffer_lens is not None
1416 1416 has_rbufs = result_buffer_lens is not None
1417 1417 for i,rec in enumerate(records):
1418 1418 # relink buffers
1419 1419 if has_bufs:
1420 1420 blen = buffer_lens[i]
1421 1421 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1422 1422 if has_rbufs:
1423 1423 blen = result_buffer_lens[i]
1424 1424 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1425 1425
1426 1426 return records
1427 1427
1428 1428 __all__ = [ 'Client' ]
@@ -1,173 +1,173 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. Pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18 import time
19 19 import uuid
20 20
21 21 import zmq
22 22 from zmq.devices import ThreadDevice
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.config.configurable import LoggingConfigurable
26 26 from IPython.utils.traitlets import Set, Instance, CFloat
27 27
28 from IPython.parallel.util import ensure_bytes
28 from IPython.parallel.util import asbytes
29 29
30 30 class Heart(object):
31 31 """A basic heart object for responding to a HeartMonitor.
32 32 This is a simple wrapper with defaults for the most common
33 33 Device model for responding to heartbeats.
34 34
35 35 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
36 36 SUB/XREQ for in/out.
37 37
38 38 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
39 39 device=None
40 40 id=None
41 41 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
42 42 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
43 43 self.device.daemon=True
44 44 self.device.connect_in(in_addr)
45 45 self.device.connect_out(out_addr)
46 46 if in_type == zmq.SUB:
47 47 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
48 48 if heart_id is None:
49 49 heart_id = uuid.uuid4().bytes
50 50 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
51 51 self.id = heart_id
52 52
53 53 def start(self):
54 54 return self.device.start()
55 55
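# a usage sketch (addresses assumed; see the __main__ demo at the bottom of
# this file for a matching HeartMonitor):
#   heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556')
#   heart.start()
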
56 56 class HeartMonitor(LoggingConfigurable):
57 57 """A basic HeartMonitor class
58 58 pingstream: a PUB stream
59 59 pongstream: an XREP stream
60 60 period: the period of the heartbeat in milliseconds"""
61 61
62 62 period=CFloat(1000, config=True,
63 63         help='The period (in ms) at which the Hub pings the engines for heartbeats '
64 64         '[default: 1000]',
65 65 )
66 66
67 67 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
68 68 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
69 69 loop = Instance('zmq.eventloop.ioloop.IOLoop')
70 70 def _loop_default(self):
71 71 return ioloop.IOLoop.instance()
72 72
73 73 # not settable:
74 74 hearts=Set()
75 75 responses=Set()
76 76 on_probation=Set()
77 77 last_ping=CFloat(0)
78 78 _new_handlers = Set()
79 79 _failure_handlers = Set()
80 80 lifetime = CFloat(0)
81 81 tic = CFloat(0)
82 82
83 83 def __init__(self, **kwargs):
84 84 super(HeartMonitor, self).__init__(**kwargs)
85 85
86 86 self.pongstream.on_recv(self.handle_pong)
87 87
88 88 def start(self):
89 89 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
90 90 self.caller.start()
91 91
92 92 def add_new_heart_handler(self, handler):
93 93 """add a new handler for new hearts"""
94 94 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
95 95 self._new_handlers.add(handler)
96 96
97 97 def add_heart_failure_handler(self, handler):
98 98 """add a new handler for heart failure"""
99 99 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
100 100 self._failure_handlers.add(handler)
101 101
102 102 def beat(self):
103 103 self.pongstream.flush()
104 104 self.last_ping = self.lifetime
105 105
106 106 toc = time.time()
107 107 self.lifetime += toc-self.tic
108 108 self.tic = toc
109 109 # self.log.debug("heartbeat::%s"%self.lifetime)
110 110         goodhearts = self.hearts.intersection(self.responses)  # hearts that replied this beat
111 111         missed_beats = self.hearts.difference(goodhearts)  # known hearts that did not reply
112 112         heartfailures = self.on_probation.intersection(missed_beats)  # missed two beats in a row
113 113         newhearts = self.responses.difference(goodhearts)  # replies from unregistered hearts
114 114 map(self.handle_new_heart, newhearts)
115 115 map(self.handle_heart_failure, heartfailures)
116 116 self.on_probation = missed_beats.intersection(self.hearts)
117 117 self.responses = set()
118 118 # print self.on_probation, self.hearts
119 119 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
120 self.pingstream.send(ensure_bytes(str(self.lifetime)))
120 self.pingstream.send(asbytes(str(self.lifetime)))
121 121
122 122 def handle_new_heart(self, heart):
123 123 if self._new_handlers:
124 124 for handler in self._new_handlers:
125 125 handler(heart)
126 126 else:
127 127 self.log.info("heartbeat::yay, got new heart %s!"%heart)
128 128 self.hearts.add(heart)
129 129
130 130 def handle_heart_failure(self, heart):
131 131 if self._failure_handlers:
132 132 for handler in self._failure_handlers:
133 133 try:
134 134 handler(heart)
135 135 except Exception as e:
136 136 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
137 137 pass
138 138 else:
139 139 self.log.info("heartbeat::Heart %s failed :("%heart)
140 140 self.hearts.remove(heart)
141 141
142 142
143 143 def handle_pong(self, msg):
144 144 "a heart just beat"
145 current = ensure_bytes(str(self.lifetime))
146 last = ensure_bytes(str(self.last_ping))
145 current = asbytes(str(self.lifetime))
146 last = asbytes(str(self.last_ping))
147 147 if msg[1] == current:
148 148 delta = time.time()-self.tic
149 149 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
150 150 self.responses.add(msg[0])
151 151 elif msg[1] == last:
152 152 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
153 153 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
154 154 self.responses.add(msg[0])
155 155 else:
156 156 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
157 157 (msg[1],self.lifetime))
158 158
159 159
160 160 if __name__ == '__main__':
161 161 loop = ioloop.IOLoop.instance()
162 162 context = zmq.Context()
163 163 pub = context.socket(zmq.PUB)
164 164 pub.bind('tcp://127.0.0.1:5555')
165 165 xrep = context.socket(zmq.XREP)
166 166 xrep.bind('tcp://127.0.0.1:5556')
167 167
168 168 outstream = zmqstream.ZMQStream(pub, loop)
169 169 instream = zmqstream.ZMQStream(xrep, loop)
170 170
171 171     hb = HeartMonitor(loop=loop, pingstream=outstream, pongstream=instream)
172 172
173 173 loop.start()
@@ -1,1291 +1,1291 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20 from __future__ import print_function
21 21
22 22 import sys
23 23 import time
24 24 from datetime import datetime
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop
28 28 from zmq.eventloop.zmqstream import ZMQStream
29 29
30 30 # internal:
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
54 54 def empty_record():
55 55 """Return an empty dict with all record keys."""
56 56 return {
57 57 'msg_id' : None,
58 58 'header' : None,
59 59 'content': None,
60 60 'buffers': None,
61 61 'submitted': None,
62 62 'client_uuid' : None,
63 63 'engine_uuid' : None,
64 64 'started': None,
65 65 'completed': None,
66 66 'resubmitted': None,
67 67 'result_header' : None,
68 68 'result_content' : None,
69 69 'result_buffers' : None,
70 70 'queue' : None,
71 71 'pyin' : None,
72 72 'pyout': None,
73 73 'pyerr': None,
74 74 'stdout': '',
75 75 'stderr': '',
76 76 }
77 77
78 78 def init_record(msg):
79 79 """Initialize a TaskRecord based on a request."""
80 80 header = msg['header']
81 81 return {
82 82 'msg_id' : header['msg_id'],
83 83 'header' : header,
84 84 'content': msg['content'],
85 85 'buffers': msg['buffers'],
86 86 'submitted': header['date'],
87 87 'client_uuid' : None,
88 88 'engine_uuid' : None,
89 89 'started': None,
90 90 'completed': None,
91 91 'resubmitted': None,
92 92 'result_header' : None,
93 93 'result_content' : None,
94 94 'result_buffers' : None,
95 95 'queue' : None,
96 96 'pyin' : None,
97 97 'pyout': None,
98 98 'pyerr': None,
99 99 'stdout': '',
100 100 'stderr': '',
101 101 }
102 102
103 103
104 104 class EngineConnector(HasTraits):
105 105 """A simple object for accessing the various zmq connections of an object.
106 106 Attributes are:
107 107 id (int): engine ID
108 108 uuid (str): uuid (unused?)
109 109 queue (str): identity of queue's XREQ socket
110 110 registration (str): identity of registration XREQ socket
111 111 heartbeat (str): identity of heartbeat XREQ socket
112 112 """
113 113 id=Int(0)
114 114 queue=CBytes()
115 115 control=CBytes()
116 116 registration=CBytes()
117 117 heartbeat=CBytes()
118 118 pending=Set()
119 119
120 120 class HubFactory(RegistrationFactory):
121 121 """The Configurable for setting up a Hub."""
122 122
123 123 # port-pairs for monitoredqueues:
124 124 hb = Tuple(Int,Int,config=True,
125 125 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 126 def _hb_default(self):
127 127 return tuple(util.select_random_ports(2))
128 128
129 129 mux = Tuple(Int,Int,config=True,
130 130 help="""Engine/Client Port pair for MUX queue""")
131 131
132 132 def _mux_default(self):
133 133 return tuple(util.select_random_ports(2))
134 134
135 135 task = Tuple(Int,Int,config=True,
136 136 help="""Engine/Client Port pair for Task queue""")
137 137 def _task_default(self):
138 138 return tuple(util.select_random_ports(2))
139 139
140 140 control = Tuple(Int,Int,config=True,
141 141 help="""Engine/Client Port pair for Control queue""")
142 142
143 143 def _control_default(self):
144 144 return tuple(util.select_random_ports(2))
145 145
146 146 iopub = Tuple(Int,Int,config=True,
147 147 help="""Engine/Client Port pair for IOPub relay""")
148 148
149 149 def _iopub_default(self):
150 150 return tuple(util.select_random_ports(2))
151 151
152 152 # single ports:
153 153 mon_port = Int(config=True,
154 154 help="""Monitor (SUB) port for queue traffic""")
155 155
156 156 def _mon_port_default(self):
157 157 return util.select_random_ports(1)[0]
158 158
159 159 notifier_port = Int(config=True,
160 160 help="""PUB port for sending engine status notifications""")
161 161
162 162 def _notifier_port_default(self):
163 163 return util.select_random_ports(1)[0]
164 164
165 165 engine_ip = Unicode('127.0.0.1', config=True,
166 166 help="IP on which to listen for engine connections. [default: loopback]")
167 167 engine_transport = Unicode('tcp', config=True,
168 168 help="0MQ transport for engine connections. [default: tcp]")
169 169
170 170 client_ip = Unicode('127.0.0.1', config=True,
171 171 help="IP on which to listen for client connections. [default: loopback]")
172 172 client_transport = Unicode('tcp', config=True,
173 173 help="0MQ transport for client connections. [default : tcp]")
174 174
175 175 monitor_ip = Unicode('127.0.0.1', config=True,
176 176 help="IP on which to listen for monitor messages. [default: loopback]")
177 177 monitor_transport = Unicode('tcp', config=True,
178 178 help="0MQ transport for monitor messages. [default : tcp]")
179 179
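    # a config sketch (values assumed), e.g. in ipcontroller_config.py:
    #   c.HubFactory.hb = (10100, 10101)
    #   c.HubFactory.engine_ip = '0.0.0.0'
    #   c.HubFactory.db_class = 'IPython.parallel.controller.sqlitedb.SQLiteDB'
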
180 180 monitor_url = Unicode('')
181 181
182 182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 183 config=True, help="""The class to use for the DB backend""")
184 184
185 185 # not configurable
186 186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188 188
189 189 def _ip_changed(self, name, old, new):
190 190 self.engine_ip = new
191 191 self.client_ip = new
192 192 self.monitor_ip = new
193 193 self._update_monitor_url()
194 194
195 195 def _update_monitor_url(self):
196 196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197 197
198 198 def _transport_changed(self, name, old, new):
199 199 self.engine_transport = new
200 200 self.client_transport = new
201 201 self.monitor_transport = new
202 202 self._update_monitor_url()
203 203
204 204 def __init__(self, **kwargs):
205 205 super(HubFactory, self).__init__(**kwargs)
206 206 self._update_monitor_url()
207 207
208 208
209 209 def construct(self):
210 210 self.init_hub()
211 211
212 212 def start(self):
213 213 self.heartmonitor.start()
214 214 self.log.info("Heartmonitor started")
215 215
216 216 def init_hub(self):
217 217 """construct"""
218 218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220 220
221 221 ctx = self.context
222 222 loop = self.loop
223 223
224 224 # Registrar socket
225 225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 226 q.bind(client_iface % self.regport)
227 227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 228 if self.client_ip != self.engine_ip:
229 229 q.bind(engine_iface % self.regport)
230 230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
231 231
232 232 ### Engine connections ###
233 233
234 234 # heartbeat
235 235 hpub = ctx.socket(zmq.PUB)
236 236 hpub.bind(engine_iface % self.hb[0])
237 237 hrep = ctx.socket(zmq.XREP)
238 238 hrep.bind(engine_iface % self.hb[1])
239 239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
240 240 pingstream=ZMQStream(hpub,loop),
241 241 pongstream=ZMQStream(hrep,loop)
242 242 )
243 243
244 244 ### Client connections ###
245 245 # Notifier socket
246 246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
247 247 n.bind(client_iface%self.notifier_port)
248 248
249 249 ### build and launch the queues ###
250 250
251 251 # monitor socket
252 252 sub = ctx.socket(zmq.SUB)
253 253 sub.setsockopt(zmq.SUBSCRIBE, b"")
254 254 sub.bind(self.monitor_url)
255 255 sub.bind('inproc://monitor')
256 256 sub = ZMQStream(sub, loop)
257 257
258 258 # connect the db
259 259         self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
260 260 # cdir = self.config.Global.cluster_dir
261 261 self.db = import_item(str(self.db_class))(session=self.session.session,
262 262 config=self.config, log=self.log)
263 263 time.sleep(.25)
264 264 try:
265 265 scheme = self.config.TaskScheduler.scheme_name
266 266 except AttributeError:
267 267 from .scheduler import TaskScheduler
268 268 scheme = TaskScheduler.scheme_name.get_default_value()
269 269 # build connection dicts
270 270 self.engine_info = {
271 271 'control' : engine_iface%self.control[1],
272 272 'mux': engine_iface%self.mux[1],
273 273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
274 274 'task' : engine_iface%self.task[1],
275 275 'iopub' : engine_iface%self.iopub[1],
276 276 # 'monitor' : engine_iface%self.mon_port,
277 277 }
278 278
279 279 self.client_info = {
280 280 'control' : client_iface%self.control[0],
281 281 'mux': client_iface%self.mux[0],
282 282 'task' : (scheme, client_iface%self.task[0]),
283 283 'iopub' : client_iface%self.iopub[0],
284 284 'notification': client_iface%self.notifier_port
285 285 }
286 286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
287 287 self.log.debug("Hub client addrs: %s"%self.client_info)
288 288
289 289 # resubmit stream
290 290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
291 291 url = util.disambiguate_url(self.client_info['task'][-1])
292 r.setsockopt(zmq.IDENTITY, util.ensure_bytes(self.session.session))
292 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
293 293 r.connect(url)
294 294
295 295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
296 296 query=q, notifier=n, resubmit=r, db=self.db,
297 297 engine_info=self.engine_info, client_info=self.client_info,
298 298 log=self.log)
299 299
300 300
301 301 class Hub(SessionFactory):
302 302 """The IPython Controller Hub with 0MQ connections
303 303
304 304 Parameters
305 305 ==========
306 306 loop: zmq IOLoop instance
307 307 session: Session object
308 308 <removed> context: zmq context for creating new connections (?)
309 309 queue: ZMQStream for monitoring the command queue (SUB)
310 310 query: ZMQStream for engine registration and client queries requests (XREP)
311 311 heartbeat: HeartMonitor object checking the pulse of the engines
312 312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 313 db: connection to db for out of memory logging of commands
314 314 NotImplemented
315 315 engine_info: dict of zmq connection information for engines to connect
316 316 to the queues.
317 317     client_info: dict of zmq connection information for clients to connect
318 318 to the queues.
319 319 """
320 320 # internal data structures:
321 321 ids=Set() # engine IDs
322 322 keytable=Dict()
323 323 by_ident=Dict()
324 324 engines=Dict()
325 325 clients=Dict()
326 326 hearts=Dict()
327 327 pending=Set()
328 328 queues=Dict() # pending msg_ids keyed by engine_id
329 329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 330 completed=Dict() # completed msg_ids keyed by engine_id
331 331     all_completed=Set() # set of all completed msg_ids, across all engines
332 332     dead_engines=Set() # queue identities of engines that have died
333 333     unassigned=Set() # set of task msg_ids not yet assigned a destination
334 334 incoming_registrations=Dict()
335 335 registration_timeout=Int()
336 336 _idcounter=Int(0)
337 337
338 338 # objects from constructor:
339 339 query=Instance(ZMQStream)
340 340 monitor=Instance(ZMQStream)
341 341 notifier=Instance(ZMQStream)
342 342 resubmit=Instance(ZMQStream)
343 343 heartmonitor=Instance(HeartMonitor)
344 344 db=Instance(object)
345 345 client_info=Dict()
346 346 engine_info=Dict()
347 347
348 348
349 349 def __init__(self, **kwargs):
350 350 """
351 351 # universal:
352 352 loop: IOLoop for creating future connections
353 353 session: streamsession for sending serialized data
354 354 # engine:
355 355 queue: ZMQStream for monitoring queue messages
356 356 query: ZMQStream for engine+client registration and client requests
357 357 heartbeat: HeartMonitor object for tracking engines
358 358 # extra:
359 359 db: ZMQStream for db connection (NotImplemented)
360 360 engine_info: zmq address/protocol dict for engine connections
361 361 client_info: zmq address/protocol dict for client connections
362 362 """
363 363
364 364 super(Hub, self).__init__(**kwargs)
365 365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366 366
367 367 # validate connection dicts:
368 368 for k,v in self.client_info.iteritems():
369 369 if k == 'task':
370 370 util.validate_url_container(v[1])
371 371 else:
372 372 util.validate_url_container(v)
373 373 # util.validate_url_container(self.client_info)
374 374 util.validate_url_container(self.engine_info)
375 375
376 376 # register our callbacks
377 377 self.query.on_recv(self.dispatch_query)
378 378 self.monitor.on_recv(self.dispatch_monitor_traffic)
379 379
380 380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382 382
383 383 self.monitor_handlers = {b'in' : self.save_queue_request,
384 384 b'out': self.save_queue_result,
385 385 b'intask': self.save_task_request,
386 386 b'outtask': self.save_task_result,
387 387 b'tracktask': self.save_task_destination,
388 388 b'incontrol': _passer,
389 389 b'outcontrol': _passer,
390 390 b'iopub': self.save_iopub_message,
391 391 }
392 392
393 393 self.query_handlers = {'queue_request': self.queue_status,
394 394 'result_request': self.get_results,
395 395 'history_request': self.get_history,
396 396 'db_request': self.db_query,
397 397 'purge_request': self.purge_results,
398 398 'load_request': self.check_load,
399 399 'resubmit_request': self.resubmit_task,
400 400 'shutdown_request': self.shutdown_request,
401 401 'registration_request' : self.register_engine,
402 402 'unregistration_request' : self.unregister_engine,
403 403 'connection_request': self.connection_request,
404 404 }
405 405
406 406 # ignore resubmit replies
407 407 self.resubmit.on_recv(lambda msg: None, copy=False)
408 408
409 409 self.log.info("hub::created hub")
410 410
411 411 @property
412 412 def _next_id(self):
413 413         """generate a new ID.
414 414
415 415 No longer reuse old ids, just count from 0."""
416 416 newid = self._idcounter
417 417 self._idcounter += 1
418 418 return newid
419 419 # newid = 0
420 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 421 # # print newid, self.ids, self.incoming_registrations
422 422 # while newid in self.ids or newid in incoming:
423 423 # newid += 1
424 424 # return newid
425 425
426 426 #-----------------------------------------------------------------------------
427 427 # message validation
428 428 #-----------------------------------------------------------------------------
429 429
430 430 def _validate_targets(self, targets):
431 431 """turn any valid targets argument into a list of integer ids"""
432 432 if targets is None:
433 433 # default to all
434 434 targets = self.ids
435 435
436 436 if isinstance(targets, (int,str,unicode)):
437 437 # only one target specified
438 438 targets = [targets]
439 439 _targets = []
440 440 for t in targets:
441 441 # map raw identities to ids
442 442 if isinstance(t, (str,unicode)):
443 443 t = self.by_ident.get(t, t)
444 444 _targets.append(t)
445 445 targets = _targets
446 446 bad_targets = [ t for t in targets if t not in self.ids ]
447 447 if bad_targets:
448 448 raise IndexError("No Such Engine: %r"%bad_targets)
449 449 if not targets:
450 450 raise IndexError("No Engines Registered")
451 451 return targets
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # dispatch methods (1 per stream)
455 455 #-----------------------------------------------------------------------------
456 456
457 457
458 458 def dispatch_monitor_traffic(self, msg):
459 459 """all ME and Task queue messages come through here, as well as
460 460 IOPub traffic."""
461 461 self.log.debug("monitor traffic: %r"%msg[:2])
462 462 switch = msg[0]
463 463 try:
464 464 idents, msg = self.session.feed_identities(msg[1:])
465 465 except ValueError:
466 466 idents=[]
467 467 if not idents:
468 468 self.log.error("Bad Monitor Message: %r"%msg)
469 469 return
470 470 handler = self.monitor_handlers.get(switch, None)
471 471 if handler is not None:
472 472 handler(idents, msg)
473 473 else:
474 474 self.log.error("Invalid monitor topic: %r"%switch)
475 475
476 476
477 477 def dispatch_query(self, msg):
478 478 """Route registration requests and queries from clients."""
479 479 try:
480 480 idents, msg = self.session.feed_identities(msg)
481 481 except ValueError:
482 482 idents = []
483 483 if not idents:
484 484 self.log.error("Bad Query Message: %r"%msg)
485 485 return
486 486 client_id = idents[0]
487 487 try:
488 488 msg = self.session.unpack_message(msg, content=True)
489 489 except Exception:
490 490 content = error.wrap_exception()
491 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 492 self.session.send(self.query, "hub_error", ident=client_id,
493 493 content=content)
494 494 return
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 497 msg_type = msg['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
501 501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 505 self.session.send(self.query, "hub_error", ident=client_id,
506 506 content=content)
507 507 return
508 508
509 509 else:
510 510 handler(idents, msg)
511 511
512 512 def dispatch_db(self, msg):
513 513 """"""
514 514 raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538 eid = self.hearts.get(heart, None)
539 539         if eid is None:
540 540             self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 541         else:
542 542             queue = self.engines[eid].queue
543 543             self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
547 547 def save_queue_request(self, idents, msg):
548 548 if len(idents) < 2:
549 549 self.log.error("invalid identity prefix: %r"%idents)
550 550 return
551 551 queue_id, client_id = idents[:2]
552 552 try:
553 553 msg = self.session.unpack_message(msg)
554 554 except Exception:
555 555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 556 return
557 557
558 558 eid = self.by_ident.get(queue_id, None)
559 559 if eid is None:
560 560 self.log.error("queue::target %r not registered"%queue_id)
561 561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 562 return
563 563 record = init_record(msg)
564 564 msg_id = record['msg_id']
565 565 # Unicode in records
566 566 record['engine_uuid'] = queue_id.decode('ascii')
567 567 record['client_uuid'] = client_id.decode('ascii')
568 568 record['queue'] = 'mux'
569 569
570 570 try:
571 571             # it's possible iopub arrived first:
572 572 existing = self.db.get_record(msg_id)
573 573 for key,evalue in existing.iteritems():
574 574 rvalue = record.get(key, None)
575 575 if evalue and rvalue and evalue != rvalue:
576 576 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
577 577 elif evalue and not rvalue:
578 578 record[key] = evalue
579 579 try:
580 580 self.db.update_record(msg_id, record)
581 581 except Exception:
582 582 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
583 583 except KeyError:
584 584 try:
585 585 self.db.add_record(msg_id, record)
586 586 except Exception:
587 587 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
588 588
589 589
590 590 self.pending.add(msg_id)
591 591 self.queues[eid].append(msg_id)
592 592
593 593 def save_queue_result(self, idents, msg):
594 594 if len(idents) < 2:
595 595 self.log.error("invalid identity prefix: %r"%idents)
596 596 return
597 597
598 598 client_id, queue_id = idents[:2]
599 599 try:
600 600 msg = self.session.unpack_message(msg)
601 601 except Exception:
602 602 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
603 603 queue_id,client_id, msg), exc_info=True)
604 604 return
605 605
606 606 eid = self.by_ident.get(queue_id, None)
607 607 if eid is None:
608 608             self.log.error("queue::unknown engine %r is sending a reply"%queue_id)
609 609 return
610 610
611 611 parent = msg['parent_header']
612 612 if not parent:
613 613 return
614 614 msg_id = parent['msg_id']
615 615 if msg_id in self.pending:
616 616 self.pending.remove(msg_id)
617 617 self.all_completed.add(msg_id)
618 618 self.queues[eid].remove(msg_id)
619 619 self.completed[eid].append(msg_id)
620 620 elif msg_id not in self.all_completed:
621 621 # it could be a result from a dead engine that died before delivering the
622 622 # result
623 623 self.log.warn("queue:: unknown msg finished %r"%msg_id)
624 624 return
625 625 # update record anyway, because the unregistration could have been premature
626 626 rheader = msg['header']
627 627 completed = rheader['date']
628 628 started = rheader.get('started', None)
629 629 result = {
630 630 'result_header' : rheader,
631 631 'result_content': msg['content'],
632 632 'started' : started,
633 633 'completed' : completed
634 634 }
635 635
636 636 result['result_buffers'] = msg['buffers']
637 637 try:
638 638 self.db.update_record(msg_id, result)
639 639 except Exception:
640 640 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
641 641
642 642
643 643 #--------------------- Task Queue Traffic ------------------------------
644 644
645 645 def save_task_request(self, idents, msg):
646 646 """Save the submission of a task."""
647 647 client_id = idents[0]
648 648
649 649 try:
650 650 msg = self.session.unpack_message(msg)
651 651 except Exception:
652 652 self.log.error("task::client %r sent invalid task message: %r"%(
653 653 client_id, msg), exc_info=True)
654 654 return
655 655 record = init_record(msg)
656 656
657 657 record['client_uuid'] = client_id
658 658 record['queue'] = 'task'
659 659 header = msg['header']
660 660 msg_id = header['msg_id']
661 661 self.pending.add(msg_id)
662 662 self.unassigned.add(msg_id)
663 663 try:
664 664             # it's possible iopub arrived first:
665 665 existing = self.db.get_record(msg_id)
666 666 if existing['resubmitted']:
667 667 for key in ('submitted', 'client_uuid', 'buffers'):
668 668 # don't clobber these keys on resubmit
669 669 # submitted and client_uuid should be different
670 670 # and buffers might be big, and shouldn't have changed
671 671 record.pop(key)
672 672                 # still check content and header, which should not change,
673 673                 # but are not as expensive to compare as buffers
674 674
675 675 for key,evalue in existing.iteritems():
676 676 if key.endswith('buffers'):
677 677 # don't compare buffers
678 678 continue
679 679 rvalue = record.get(key, None)
680 680 if evalue and rvalue and evalue != rvalue:
681 681 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
682 682 elif evalue and not rvalue:
683 683 record[key] = evalue
684 684 try:
685 685 self.db.update_record(msg_id, record)
686 686 except Exception:
687 687 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
688 688 except KeyError:
689 689 try:
690 690 self.db.add_record(msg_id, record)
691 691 except Exception:
692 692 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
693 693 except Exception:
694 694 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
695 695
696 696 def save_task_result(self, idents, msg):
697 697 """save the result of a completed task."""
698 698 client_id = idents[0]
699 699 try:
700 700 msg = self.session.unpack_message(msg)
701 701 except Exception:
702 702             self.log.error("task::invalid task result message sent to %r: %r"%(
703 703 client_id, msg), exc_info=True)
704 704 return
705 705
706 706 parent = msg['parent_header']
707 707 if not parent:
708 708 # print msg
709 709 self.log.warn("Task %r had no parent!"%msg)
710 710 return
711 711 msg_id = parent['msg_id']
712 712 if msg_id in self.unassigned:
713 713 self.unassigned.remove(msg_id)
714 714
715 715 header = msg['header']
716 716 engine_uuid = header.get('engine', None)
717 717 eid = self.by_ident.get(engine_uuid, None)
718 718
719 719 if msg_id in self.pending:
720 720 self.pending.remove(msg_id)
721 721 self.all_completed.add(msg_id)
722 722 if eid is not None:
723 723 self.completed[eid].append(msg_id)
724 724 if msg_id in self.tasks[eid]:
725 725 self.tasks[eid].remove(msg_id)
726 726 completed = header['date']
727 727 started = header.get('started', None)
728 728 result = {
729 729 'result_header' : header,
730 730 'result_content': msg['content'],
731 731 'started' : started,
732 732 'completed' : completed,
733 733 'engine_uuid': engine_uuid
734 734 }
735 735
736 736 result['result_buffers'] = msg['buffers']
737 737 try:
738 738 self.db.update_record(msg_id, result)
739 739 except Exception:
740 740 self.log.error("DB Error saving task result %r"%msg_id, exc_info=True)
741 741
742 742 else:
743 743 self.log.debug("task::unknown task %r finished"%msg_id)
744 744
745 745 def save_task_destination(self, idents, msg):
746 746 try:
747 747 msg = self.session.unpack_message(msg, content=True)
748 748 except Exception:
749 749 self.log.error("task::invalid task tracking message", exc_info=True)
750 750 return
751 751 content = msg['content']
752 752 # print (content)
753 753 msg_id = content['msg_id']
754 754 engine_uuid = content['engine_id']
755 eid = self.by_ident[util.ensure_bytes(engine_uuid)]
755 eid = self.by_ident[util.asbytes(engine_uuid)]
756 756
757 757 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
758 758 if msg_id in self.unassigned:
759 759 self.unassigned.remove(msg_id)
760 760 # else:
761 761 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
762 762
763 763 self.tasks[eid].append(msg_id)
764 764 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
765 765 try:
766 766 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
767 767 except Exception:
768 768 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
769 769
770 770
771 771 def mia_task_request(self, idents, msg):
772 772 raise NotImplementedError
773 773 client_id = idents[0]
774 774 # content = dict(mia=self.mia,status='ok')
775 775 # self.session.send('mia_reply', content=content, idents=client_id)
776 776
777 777
778 778 #--------------------- IOPub Traffic ------------------------------
779 779
780 780 def save_iopub_message(self, topics, msg):
781 781 """save an iopub message into the db"""
782 782 # print (topics)
783 783 try:
784 784 msg = self.session.unpack_message(msg, content=True)
785 785 except Exception:
786 786 self.log.error("iopub::invalid IOPub message", exc_info=True)
787 787 return
788 788
789 789 parent = msg['parent_header']
790 790 if not parent:
791 791 self.log.error("iopub::invalid IOPub message: %r"%msg)
792 792 return
793 793 msg_id = parent['msg_id']
794 794 msg_type = msg['msg_type']
795 795 content = msg['content']
796 796
797 797 # ensure msg_id is in db
798 798 try:
799 799 rec = self.db.get_record(msg_id)
800 800 except KeyError:
801 801 rec = empty_record()
802 802 rec['msg_id'] = msg_id
803 803 self.db.add_record(msg_id, rec)
804 804 # stream
805 805 d = {}
806 806 if msg_type == 'stream':
807 807 name = content['name']
808 808 s = rec[name] or ''
809 809 d[name] = s + content['data']
810 810
811 811 elif msg_type == 'pyerr':
812 812 d['pyerr'] = content
813 813 elif msg_type == 'pyin':
814 814 d['pyin'] = content['code']
815 815 else:
816 816 d[msg_type] = content.get('data', '')
817 817
818 818 try:
819 819 self.db.update_record(msg_id, d)
820 820 except Exception:
821 821 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
822 822
823 823
824 824
825 825 #-------------------------------------------------------------------------
826 826 # Registration requests
827 827 #-------------------------------------------------------------------------
828 828
829 829 def connection_request(self, client_id, msg):
830 830 """Reply with connection addresses for clients."""
831 831 self.log.info("client::client %r connected"%client_id)
832 832 content = dict(status='ok')
833 833 content.update(self.client_info)
834 834 jsonable = {}
835 835 for k,v in self.keytable.iteritems():
836 836 if v not in self.dead_engines:
837 837 jsonable[str(k)] = v.decode('ascii')
838 838 content['engines'] = jsonable
839 839 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
840 840
841 841 def register_engine(self, reg, msg):
842 842 """Register a new engine."""
843 843 content = msg['content']
844 844 try:
845 queue = util.ensure_bytes(content['queue'])
845 queue = util.asbytes(content['queue'])
846 846 except KeyError:
847 847 self.log.error("registration::queue not specified", exc_info=True)
848 848 return
849 849 heart = content.get('heartbeat', None)
850 850 if heart:
851 heart = util.ensure_bytes(heart)
851 heart = util.asbytes(heart)
852 852 """register a new engine, and create the socket(s) necessary"""
853 853 eid = self._next_id
854 854 # print (eid, queue, reg, heart)
855 855
856 856 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
857 857
858 858 content = dict(id=eid,status='ok')
859 859 content.update(self.engine_info)
860 860 # check if requesting available IDs:
861 861 if queue in self.by_ident:
862 862 try:
863 863 raise KeyError("queue_id %r in use"%queue)
864 864 except:
865 865 content = error.wrap_exception()
866 866 self.log.error("queue_id %r in use"%queue, exc_info=True)
867 867 elif heart in self.hearts: # need to check unique hearts?
868 868 try:
869 869 raise KeyError("heart_id %r in use"%heart)
870 870 except:
871 871 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 872 content = error.wrap_exception()
873 873 else:
874 874 for h, pack in self.incoming_registrations.iteritems():
875 875 if heart == h:
876 876 try:
877 877 raise KeyError("heart_id %r in use"%heart)
878 878 except:
879 879 self.log.error("heart_id %r in use"%heart, exc_info=True)
880 880 content = error.wrap_exception()
881 881 break
882 882 elif queue == pack[1]:
883 883 try:
884 884 raise KeyError("queue_id %r in use"%queue)
885 885 except:
886 886 self.log.error("queue_id %r in use"%queue, exc_info=True)
887 887 content = error.wrap_exception()
888 888 break
889 889
890 890 msg = self.session.send(self.query, "registration_reply",
891 891 content=content,
892 892 ident=reg)
893 893
894 894 if content['status'] == 'ok':
895 895 if heart in self.heartmonitor.hearts:
896 896 # already beating
897 897 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
898 898 self.finish_registration(heart)
899 899 else:
900 900 purge = lambda : self._purge_stalled_registration(heart)
901 901 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
902 902 dc.start()
903 903 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
904 904 else:
905 905 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
906 906 return eid
907 907
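# Note: registration is two-phase (a sketch inferred from register_engine
# above and finish_registration below; 'eid'/'heart' values illustrative):
#   1. register_engine() replies with content like {'id': eid, 'status': 'ok'}
#      and parks (eid, queue, reg, dc) in self.incoming_registrations
#   2. once the HeartMonitor sees a beat from `heart`,
#      finish_registration(heart) activates the engine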
908 908 def unregister_engine(self, ident, msg):
909 909 """Unregister an engine that explicitly requested to leave."""
910 910 try:
911 911 eid = msg['content']['id']
912 912 except:
913 913 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
914 914 return
915 915 self.log.info("registration::unregister_engine(%r)"%eid)
916 916 # print (eid)
917 917 uuid = self.keytable[eid]
918 content=dict(id=eid, queue=uuid.decode())
918 content=dict(id=eid, queue=uuid.decode('ascii'))
919 919 self.dead_engines.add(uuid)
920 920 # self.ids.remove(eid)
921 921 # uuid = self.keytable.pop(eid)
922 922 #
923 923 # ec = self.engines.pop(eid)
924 924 # self.hearts.pop(ec.heartbeat)
925 925 # self.by_ident.pop(ec.queue)
926 926 # self.completed.pop(eid)
927 927 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
928 928 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
929 929 dc.start()
930 930 ############## TODO: HANDLE IT ################
931 931
932 932 if self.notifier:
933 933 self.session.send(self.notifier, "unregistration_notification", content=content)
934 934
935 935 def _handle_stranded_msgs(self, eid, uuid):
936 936 """Handle messages known to be on an engine when the engine unregisters.
937 937
938 938 It is possible that this will fire prematurely - that is, an engine will
939 939 go down after completing a result, and the client will be notified
940 940 that the result failed and later receive the actual result.
941 941 """
942 942
943 943 outstanding = self.queues[eid]
944 944
945 945 for msg_id in outstanding:
946 946 self.pending.remove(msg_id)
947 947 self.all_completed.add(msg_id)
948 948 try:
949 949 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
950 950 except:
951 951 content = error.wrap_exception()
952 952 # build a fake header:
953 953 header = {}
954 954 header['engine'] = uuid
955 955 header['date'] = datetime.now()
956 956 rec = dict(result_content=content, result_header=header, result_buffers=[])
957 957 rec['completed'] = header['date']
958 958 rec['engine_uuid'] = uuid
959 959 try:
960 960 self.db.update_record(msg_id, rec)
961 961 except Exception:
962 962 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
963 963
964 964
965 965 def finish_registration(self, heart):
966 966 """Second half of engine registration, called after our HeartMonitor
967 967 has received a beat from the Engine's Heart."""
968 968 try:
969 969 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
970 970 except KeyError:
971 971 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
972 972 return
973 973 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
974 974 if purge is not None:
975 975 purge.stop()
976 976 control = queue
977 977 self.ids.add(eid)
978 978 self.keytable[eid] = queue
979 979 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
980 980 control=control, heartbeat=heart)
981 981 self.by_ident[queue] = eid
982 982 self.queues[eid] = list()
983 983 self.tasks[eid] = list()
984 984 self.completed[eid] = list()
985 985 self.hearts[heart] = eid
986 content = dict(id=eid, queue=self.engines[eid].queue.decode())
986 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
987 987 if self.notifier:
988 988 self.session.send(self.notifier, "registration_notification", content=content)
989 989 self.log.info("engine::Engine Connected: %i"%eid)
990 990
991 991 def _purge_stalled_registration(self, heart):
992 992 if heart in self.incoming_registrations:
993 993 eid = self.incoming_registrations.pop(heart)[0]
994 994 self.log.info("registration::purging stalled registration: %i"%eid)
995 995 else:
996 996 pass
997 997
998 998 #-------------------------------------------------------------------------
999 999 # Client Requests
1000 1000 #-------------------------------------------------------------------------
1001 1001
1002 1002 def shutdown_request(self, client_id, msg):
1003 1003 """handle shutdown request."""
1004 1004 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1005 1005 # also notify other clients of shutdown
1006 1006 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1007 1007 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1008 1008 dc.start()
1009 1009
1010 1010 def _shutdown(self):
1011 1011 self.log.info("hub::hub shutting down.")
1012 1012 time.sleep(0.1)
1013 1013 sys.exit(0)
1014 1014
1015 1015
1016 1016 def check_load(self, client_id, msg):
1017 1017 content = msg['content']
1018 1018 try:
1019 1019 targets = content['targets']
1020 1020 targets = self._validate_targets(targets)
1021 1021 except:
1022 1022 content = error.wrap_exception()
1023 1023 self.session.send(self.query, "hub_error",
1024 1024 content=content, ident=client_id)
1025 1025 return
1026 1026
1027 1027 content = dict(status='ok')
1028 1028 # loads = {}
1029 1029 for t in targets:
1030 1030 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1031 1031 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1032 1032
1033 1033
1034 1034 def queue_status(self, client_id, msg):
1035 1035 """Return the Queue status of one or more targets.
1036 1036 if verbose: return the msg_ids
1037 1037 else: return len of each type.
1038 1038 keys: queue (pending MUX jobs)
1039 1039 tasks (pending Task jobs)
1040 1040 completed (finished jobs from both queues)"""
1041 1041 content = msg['content']
1042 1042 targets = content['targets']
1043 1043 try:
1044 1044 targets = self._validate_targets(targets)
1045 1045 except:
1046 1046 content = error.wrap_exception()
1047 1047 self.session.send(self.query, "hub_error",
1048 1048 content=content, ident=client_id)
1049 1049 return
1050 1050 verbose = content.get('verbose', False)
1051 1051 content = dict(status='ok')
1052 1052 for t in targets:
1053 1053 queue = self.queues[t]
1054 1054 completed = self.completed[t]
1055 1055 tasks = self.tasks[t]
1056 1056 if not verbose:
1057 1057 queue = len(queue)
1058 1058 completed = len(completed)
1059 1059 tasks = len(tasks)
1060 1060 content[str(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1061 1061 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1062 1062 # print (content)
1063 1063 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1064 1064
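# A sketch of a non-verbose queue_reply content built above (values
# hypothetical; with verbose=True the counts are lists of msg_ids instead):
#   {'status': 'ok',
#    '0': {'queue': 2, 'completed': 5, 'tasks': 1},
#    'unassigned': 0}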
1065 1065 def purge_results(self, client_id, msg):
1066 1066 """Purge results from memory. This method is more valuable before we move
1067 1067 to a DB based message storage mechanism."""
1068 1068 content = msg['content']
1069 1069 self.log.info("Dropping records with %s", content)
1070 1070 msg_ids = content.get('msg_ids', [])
1071 1071 reply = dict(status='ok')
1072 1072 if msg_ids == 'all':
1073 1073 try:
1074 1074 self.db.drop_matching_records(dict(completed={'$ne':None}))
1075 1075 except Exception:
1076 1076 reply = error.wrap_exception()
1077 1077 else:
1078 1078 pending = filter(lambda m: m in self.pending, msg_ids)
1079 1079 if pending:
1080 1080 try:
1081 1081 raise IndexError("msg pending: %r"%pending[0])
1082 1082 except:
1083 1083 reply = error.wrap_exception()
1084 1084 else:
1085 1085 try:
1086 1086 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1087 1087 except Exception:
1088 1088 reply = error.wrap_exception()
1089 1089
1090 1090 if reply['status'] == 'ok':
1091 1091 eids = content.get('engine_ids', [])
1092 1092 for eid in eids:
1093 1093 if eid not in self.engines:
1094 1094 try:
1095 1095 raise IndexError("No such engine: %i"%eid)
1096 1096 except:
1097 1097 reply = error.wrap_exception()
1098 1098 break
1099 1099 uid = self.engines[eid].queue
1100 1100 try:
1101 1101 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1102 1102 except Exception:
1103 1103 reply = error.wrap_exception()
1104 1104 break
1105 1105
1106 1106 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1107 1107
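# Sketches of purge request contents accepted above (ids hypothetical):
#   {'msg_ids': 'all'}                        # drop every completed record
#   {'msg_ids': ['a1b2'], 'engine_ids': [0]}  # specific msgs, then whole engines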
1108 1108 def resubmit_task(self, client_id, msg):
1109 1109 """Resubmit one or more tasks."""
1110 1110 def finish(reply):
1111 1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112 1112
1113 1113 content = msg['content']
1114 1114 msg_ids = content['msg_ids']
1115 1115 reply = dict(status='ok')
1116 1116 try:
1117 1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 1118 'header', 'content', 'buffers'])
1119 1119 except Exception:
1120 1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 1121 return finish(error.wrap_exception())
1122 1122
1123 1123 # validate msg_ids
1124 1124 found_ids = [ rec['msg_id'] for rec in records ]
1125 1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 1126 if len(records) > len(msg_ids):
1127 1127 try:
1128 1128 raise RuntimeError("DB appears to be in an inconsistent state. "
1129 1129 "More matching records were found than should exist")
1130 1130 except Exception:
1131 1131 return finish(error.wrap_exception())
1132 1132 elif len(records) < len(msg_ids):
1133 1133 missing = [ m for m in msg_ids if m not in found_ids ]
1134 1134 try:
1135 1135 raise KeyError("No such msg(s): %r"%missing)
1136 1136 except KeyError:
1137 1137 return finish(error.wrap_exception())
1138 1138 elif invalid_ids:
1139 1139 msg_id = invalid_ids[0]
1140 1140 try:
1141 1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 1142 except Exception:
1143 1143 return finish(error.wrap_exception())
1144 1144
1145 1145 # clear the existing records
1146 1146 now = datetime.now()
1147 1147 rec = empty_record()
1148 1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 1149 rec['resubmitted'] = now
1150 1150 rec['queue'] = 'task'
1151 1151 rec['client_uuid'] = client_id[0]
1152 1152 try:
1153 1153 for msg_id in msg_ids:
1154 1154 self.all_completed.discard(msg_id)
1155 1155 self.db.update_record(msg_id, rec)
1156 1156 except Exception:
1157 1157 self.log.error('db::db error updating record', exc_info=True)
1158 1158 reply = error.wrap_exception()
1159 1159 else:
1160 1160 # send the messages
1161 1161 for rec in records:
1162 1162 header = rec['header']
1163 1163 # include resubmitted in header to prevent digest collision
1164 1164 header['resubmitted'] = now
1165 1165 msg = self.session.msg(header['msg_type'])
1166 1166 msg['content'] = rec['content']
1167 1167 msg['header'] = header
1168 1168 msg['msg_id'] = rec['msg_id']
1169 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170 1170
1171 1171 finish(reply)
1172 1172
1173 1173
1174 1174 def _extract_record(self, rec):
1175 1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 1176 io_dict = {}
1177 1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 1178 io_dict[key] = rec[key]
1179 1179 content = { 'result_content': rec['result_content'],
1180 1180 'header': rec['header'],
1181 1181 'result_header' : rec['result_header'],
1182 1182 'io' : io_dict,
1183 1183 }
1184 1184 if rec['result_buffers']:
1185 1185 buffers = map(bytes, rec['result_buffers'])
1186 1186 else:
1187 1187 buffers = []
1188 1188
1189 1189 return content, buffers
1190 1190
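# Sketch of _extract_record output for a hypothetical record `rec`:
#   content == {'result_content': ..., 'header': ..., 'result_header': ...,
#               'io': {'pyin': ..., 'pyout': ..., 'pyerr': ...,
#                      'stdout': ..., 'stderr': ...}}
#   buffers == rec['result_buffers'] coerced to bytes (or [] if empty)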
1191 1191 def get_results(self, client_id, msg):
1192 1192 """Get the result of 1 or more messages."""
1193 1193 content = msg['content']
1194 1194 msg_ids = sorted(set(content['msg_ids']))
1195 1195 statusonly = content.get('status_only', False)
1196 1196 pending = []
1197 1197 completed = []
1198 1198 content = dict(status='ok')
1199 1199 content['pending'] = pending
1200 1200 content['completed'] = completed
1201 1201 buffers = []
1202 1202 if not statusonly:
1203 1203 try:
1204 1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 1205 # turn match list into dict, for faster lookup
1206 1206 records = {}
1207 1207 for rec in matches:
1208 1208 records[rec['msg_id']] = rec
1209 1209 except Exception:
1210 1210 content = error.wrap_exception()
1211 1211 self.session.send(self.query, "result_reply", content=content,
1212 1212 parent=msg, ident=client_id)
1213 1213 return
1214 1214 else:
1215 1215 records = {}
1216 1216 for msg_id in msg_ids:
1217 1217 if msg_id in self.pending:
1218 1218 pending.append(msg_id)
1219 1219 elif msg_id in self.all_completed:
1220 1220 completed.append(msg_id)
1221 1221 if not statusonly:
1222 1222 c,bufs = self._extract_record(records[msg_id])
1223 1223 content[msg_id] = c
1224 1224 buffers.extend(bufs)
1225 1225 elif msg_id in records:
1226 1226 if records[msg_id]['completed']:
1227 1227 completed.append(msg_id)
1228 1228 c,bufs = self._extract_record(records[msg_id])
1229 1229 content[msg_id] = c
1230 1230 buffers.extend(bufs)
1231 1231 else:
1232 1232 pending.append(msg_id)
1233 1233 else:
1234 1234 try:
1235 1235 raise KeyError('No such message: '+msg_id)
1236 1236 except:
1237 1237 content = error.wrap_exception()
1238 1238 break
1239 1239 self.session.send(self.query, "result_reply", content=content,
1240 1240 parent=msg, ident=client_id,
1241 1241 buffers=buffers)
1242 1242
1243 1243 def get_history(self, client_id, msg):
1244 1244 """Get a list of all msg_ids in our DB records"""
1245 1245 try:
1246 1246 msg_ids = self.db.get_history()
1247 1247 except Exception as e:
1248 1248 content = error.wrap_exception()
1249 1249 else:
1250 1250 content = dict(status='ok', history=msg_ids)
1251 1251
1252 1252 self.session.send(self.query, "history_reply", content=content,
1253 1253 parent=msg, ident=client_id)
1254 1254
1255 1255 def db_query(self, client_id, msg):
1256 1256 """Perform a raw query on the task record database."""
1257 1257 content = msg['content']
1258 1258 query = content.get('query', {})
1259 1259 keys = content.get('keys', None)
1260 1260 buffers = []
1261 1261 empty = list()
1262 1262 try:
1263 1263 records = self.db.find_records(query, keys)
1264 1264 except Exception as e:
1265 1265 content = error.wrap_exception()
1266 1266 else:
1267 1267 # extract buffers from reply content:
1268 1268 if keys is not None:
1269 1269 buffer_lens = [] if 'buffers' in keys else None
1270 1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1271 1271 else:
1272 1272 buffer_lens = []
1273 1273 result_buffer_lens = []
1274 1274
1275 1275 for rec in records:
1276 1276 # buffers may be None, so double check
1277 1277 if buffer_lens is not None:
1278 1278 b = rec.pop('buffers', empty) or empty
1279 1279 buffer_lens.append(len(b))
1280 1280 buffers.extend(b)
1281 1281 if result_buffer_lens is not None:
1282 1282 rb = rec.pop('result_buffers', empty) or empty
1283 1283 result_buffer_lens.append(len(rb))
1284 1284 buffers.extend(rb)
1285 1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1286 1286 result_buffer_lens=result_buffer_lens)
1287 1287 # self.log.debug (content)
1288 1288 self.session.send(self.query, "db_reply", content=content,
1289 1289 parent=msg, ident=client_id,
1290 1290 buffers=buffers)
1291 1291
@@ -1,714 +1,714 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger, ensure_bytes
47 from IPython.parallel.util import connect_logger, local_logger, asbytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
58 58 #----------------------------------------------------------------------
59 59 # Chooser functions
60 60 #----------------------------------------------------------------------
61 61
62 62 def plainrandom(loads):
63 63 """Plain random pick."""
64 64 n = len(loads)
65 65 return randint(0,n-1)
66 66
67 67 def lru(loads):
68 68 """Always pick the front of the line.
69 69
70 70 The content of `loads` is ignored.
71 71
72 72 Assumes LRU ordering of loads, with oldest first.
73 73 """
74 74 return 0
75 75
76 76 def twobin(loads):
77 77 """Pick two at random, use the LRU of the two.
78 78
79 79 The content of loads is ignored.
80 80
81 81 Assumes LRU ordering of loads, with oldest first.
82 82 """
83 83 n = len(loads)
84 84 a = randint(0,n-1)
85 85 b = randint(0,n-1)
86 86 return min(a,b)
87 87
88 88 def weighted(loads):
89 89 """Pick two at random using inverse load as weight.
90 90
91 91 Return the less loaded of the two.
92 92 """
93 93 # weight 0 a million times more than 1:
94 94 weights = 1./(1e-6+numpy.array(loads))
95 95 sums = weights.cumsum()
96 96 t = sums[-1]
97 97 x = random()*t
98 98 y = random()*t
99 99 idx = 0
100 100 idy = 0
101 101 while sums[idx] < x:
102 102 idx += 1
103 103 while sums[idy] < y:
104 104 idy += 1
105 105 if weights[idy] > weights[idx]:
106 106 return idy
107 107 else:
108 108 return idx
109 109
110 110 def leastload(loads):
111 111 """Always choose the lowest load.
112 112
113 113 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
115 115 the LRU of those with the lowest load is chosen.
116 116 """
117 117 return loads.index(min(loads))
118 118
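# A minimal sketch of the chooser contract (loads are hypothetical): each
# chooser takes a list of engine loads, LRU-ordered with oldest first, and
# returns the index of the chosen engine.
#
#   loads = [2, 0, 1]
#   leastload(loads)   # -> 1, the lowest load
#   lru(loads)         # -> 0, front of the line; loads ignored
#   twobin(loads)      # -> min of two random indices
#   weighted(loads)    # -> 1 almost always; weights ~ [0.5, 1e6, 1.0]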
119 119 #---------------------------------------------------------------------
120 120 # Classes
121 121 #---------------------------------------------------------------------
122 122 # store empty default dependency:
123 123 MET = Dependency([])
124 124
125 125 class TaskScheduler(SessionFactory):
126 126 """Python TaskScheduler object.
127 127
128 128 This is the simplest object that supports msg_id based
129 129 DAG dependencies. *Only* task msg_ids are checked, not
130 130 msg_ids of jobs submitted via the MUX queue.
131 131
132 132 """
133 133
134 134 hwm = Int(0, config=True, shortname='hwm',
135 135 help="""specify the High Water Mark (HWM) for the downstream
136 136 socket in the Task scheduler. This is the maximum number
137 137 of allowed outstanding tasks on each engine."""
138 138 )
139 139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 140 'leastload', config=True, shortname='scheme', allow_none=False,
141 141 help="""select the task scheduler scheme [default: 'leastload']
142 142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin', 'leastload'"""
143 143 )
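# A hypothetical ipcontroller_config.py sketch driving the traits above
# (assumes the standard IPython config-file convention):
#   c = get_config()
#   c.TaskScheduler.hwm = 1                  # at most 1 outstanding task/engine
#   c.TaskScheduler.scheme_name = 'twobin'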
144 144 def _scheme_name_changed(self, old, new):
145 145 self.log.debug("Using scheme %r"%new)
146 146 self.scheme = globals()[new]
147 147
148 148 # input arguments:
149 149 scheme = Instance(FunctionType) # function for determining the destination
150 150 def _scheme_default(self):
151 151 return leastload
152 152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
153 153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
154 154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
155 155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
156 156
157 157 # internals:
158 158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
159 159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
160 160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
161 161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
162 162 pending = Dict() # dict by engine_uuid of submitted tasks
163 163 completed = Dict() # dict by engine_uuid of completed tasks
164 164 failed = Dict() # dict by engine_uuid of failed tasks
165 165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
166 166 clients = Dict() # dict by msg_id for who submitted the task
167 167 targets = List() # list of target IDENTs
168 168 loads = List() # list of engine loads
169 169 # full = Set() # set of IDENTs that have HWM outstanding tasks
170 170 all_completed = Set() # set of all completed tasks
171 171 all_failed = Set() # set of all failed tasks
172 172 all_done = Set() # set of all finished tasks=union(completed,failed)
173 173 all_ids = Set() # set of all submitted task IDs
174 174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
175 175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176 176
177 177 ident = CBytes() # ZMQ identity. This should just be self.session.session
178 178 # but ensure Bytes
179 179 def _ident_default(self):
180 return ensure_bytes(self.session.session)
180 return asbytes(self.session.session)
181 181
182 182 def start(self):
183 183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
184 184 self._notification_handlers = dict(
185 185 registration_notification = self._register_engine,
186 186 unregistration_notification = self._unregister_engine
187 187 )
188 188 self.notifier_stream.on_recv(self.dispatch_notification)
189 189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2s
190 190 self.auditor.start()
191 191 self.log.info("Scheduler started [%s]"%self.scheme_name)
192 192
193 193 def resume_receiving(self):
194 194 """Resume accepting jobs."""
195 195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
196 196
197 197 def stop_receiving(self):
198 198 """Stop accepting jobs while there are no engines.
199 199 Leave them in the ZMQ queue."""
200 200 self.client_stream.on_recv(None)
201 201
202 202 #-----------------------------------------------------------------------
203 203 # [Un]Registration Handling
204 204 #-----------------------------------------------------------------------
205 205
206 206 def dispatch_notification(self, msg):
207 207 """dispatch register/unregister events."""
208 208 try:
209 209 idents,msg = self.session.feed_identities(msg)
210 210 except ValueError:
211 211 self.log.warn("task::Invalid Message: %r",msg)
212 212 return
213 213 try:
214 214 msg = self.session.unpack_message(msg)
215 215 except ValueError:
216 216 self.log.warn("task::Unauthorized message from: %r"%idents)
217 217 return
218 218
219 219 msg_type = msg['msg_type']
220 220
221 221 handler = self._notification_handlers.get(msg_type, None)
222 222 if handler is None:
223 223 self.log.error("Unhandled message type: %r"%msg_type)
224 224 else:
225 225 try:
226 handler(ensure_bytes(msg['content']['queue']))
226 handler(asbytes(msg['content']['queue']))
227 227 except Exception:
228 228 self.log.error("task::Invalid notification msg: %r",msg)
229 229
230 230 def _register_engine(self, uid):
231 231 """New engine with ident `uid` became available."""
232 232 # head of the line:
233 233 self.targets.insert(0,uid)
234 234 self.loads.insert(0,0)
235 235
236 236 # initialize sets
237 237 self.completed[uid] = set()
238 238 self.failed[uid] = set()
239 239 self.pending[uid] = {}
240 240 if len(self.targets) == 1:
241 241 self.resume_receiving()
242 242 # rescan the graph:
243 243 self.update_graph(None)
244 244
245 245 def _unregister_engine(self, uid):
246 246 """Existing engine with ident `uid` became unavailable."""
247 247 if len(self.targets) == 1:
248 248 # this was our only engine
249 249 self.stop_receiving()
250 250
251 251 # handle any potentially finished tasks:
252 252 self.engine_stream.flush()
253 253
254 254 # don't pop destinations, because they might be used later
255 255 # map(self.destinations.pop, self.completed.pop(uid))
256 256 # map(self.destinations.pop, self.failed.pop(uid))
257 257
258 258 # prevent this engine from receiving work
259 259 idx = self.targets.index(uid)
260 260 self.targets.pop(idx)
261 261 self.loads.pop(idx)
262 262
263 263 # wait 5 seconds before cleaning up pending jobs, since the results might
264 264 # still be incoming
265 265 if self.pending[uid]:
266 266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
267 267 dc.start()
268 268 else:
269 269 self.completed.pop(uid)
270 270 self.failed.pop(uid)
271 271
272 272
273 273 def handle_stranded_tasks(self, engine):
274 274 """Deal with jobs resident in an engine that died."""
275 275 lost = self.pending[engine]
276 276 for msg_id in lost.keys():
277 277 if msg_id not in self.pending[engine]:
278 278 # prevent double-handling of messages
279 279 continue
280 280
281 281 raw_msg = lost[msg_id][0]
282 282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
283 283 parent = self.session.unpack(msg[1].bytes)
284 284 idents = [engine, idents[0]]
285 285
286 286 # build fake error reply
287 287 try:
288 288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
289 289 except:
290 290 content = error.wrap_exception()
291 291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
292 292 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
293 293 # and dispatch it
294 294 self.dispatch_result(raw_reply)
295 295
296 296 # finally scrub completed/failed lists
297 297 self.completed.pop(engine)
298 298 self.failed.pop(engine)
299 299
300 300
301 301 #-----------------------------------------------------------------------
302 302 # Job Submission
303 303 #-----------------------------------------------------------------------
304 304 def dispatch_submission(self, raw_msg):
305 305 """Dispatch job submission to appropriate handlers."""
306 306 # ensure targets up to date:
307 307 self.notifier_stream.flush()
308 308 try:
309 309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
310 310 msg = self.session.unpack_message(msg, content=False, copy=False)
311 311 except Exception:
312 312 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
313 313 return
314 314
315 315
316 316 # send to monitor
317 317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
318 318
319 319 header = msg['header']
320 320 msg_id = header['msg_id']
321 321 self.all_ids.add(msg_id)
322 322
323 323 # get targets as a set of bytes objects
324 324 # from a list of unicode objects
325 325 targets = header.get('targets', [])
326 targets = map(ensure_bytes, targets)
326 targets = map(asbytes, targets)
327 327 targets = set(targets)
328 328
329 329 retries = header.get('retries', 0)
330 330 self.retries[msg_id] = retries
331 331
332 332 # time dependencies
333 333 after = header.get('after', None)
334 334 if after:
335 335 after = Dependency(after)
336 336 if after.all:
337 337 if after.success:
338 338 after = Dependency(after.difference(self.all_completed),
339 339 success=after.success,
340 340 failure=after.failure,
341 341 all=after.all,
342 342 )
343 343 if after.failure:
344 344 after = Dependency(after.difference(self.all_failed),
345 345 success=after.success,
346 346 failure=after.failure,
347 347 all=after.all,
348 348 )
349 349 if after.check(self.all_completed, self.all_failed):
350 350 # recast as empty set, if `after` already met,
351 351 # to prevent unnecessary set comparisons
352 352 after = MET
353 353 else:
354 354 after = MET
355 355
356 356 # location dependencies
357 357 follow = Dependency(header.get('follow', []))
358 358
359 359 # turn timeouts into datetime objects:
360 360 timeout = header.get('timeout', None)
361 361 if timeout:
362 362 timeout = datetime.now() + timedelta(0,timeout,0)
363 363
364 364 args = [raw_msg, targets, after, follow, timeout]
365 365
366 366 # validate and reduce dependencies:
367 367 for dep in after,follow:
368 368 if not dep: # empty dependency
369 369 continue
370 370 # check valid:
371 371 if msg_id in dep or dep.difference(self.all_ids):
372 372 self.depending[msg_id] = args
373 373 return self.fail_unreachable(msg_id, error.InvalidDependency)
374 374 # check if unreachable:
375 375 if dep.unreachable(self.all_completed, self.all_failed):
376 376 self.depending[msg_id] = args
377 377 return self.fail_unreachable(msg_id)
378 378
379 379 if after.check(self.all_completed, self.all_failed):
380 380 # time deps already met, try to run
381 381 if not self.maybe_run(msg_id, *args):
382 382 # can't run yet
383 383 if msg_id not in self.all_failed:
384 384 # could have failed as unreachable
385 385 self.save_unmet(msg_id, *args)
386 386 else:
387 387 self.save_unmet(msg_id, *args)
388 388
389 389 def audit_timeouts(self):
390 390 """Audit all waiting tasks for expired timeouts."""
391 391 now = datetime.now()
392 392 for msg_id in self.depending.keys():
393 393 # must recheck, in case one failure cascaded to another:
394 394 if msg_id in self.depending:
395 395 raw,targets,after,follow,timeout = self.depending[msg_id]
396 396 if timeout and timeout < now:
397 397 self.fail_unreachable(msg_id, error.TaskTimeout)
398 398
399 399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
400 400 """a task has become unreachable, send a reply with an ImpossibleDependency
401 401 error."""
402 402 if msg_id not in self.depending:
403 403 self.log.error("msg %r already failed!", msg_id)
404 404 return
405 405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
406 406 for mid in follow.union(after):
407 407 if mid in self.graph:
408 408 self.graph[mid].remove(msg_id)
409 409
410 410 # FIXME: unpacking a message I've already unpacked, but didn't save:
411 411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
412 412 header = self.session.unpack(msg[1].bytes)
413 413
414 414 try:
415 415 raise why()
416 416 except:
417 417 content = error.wrap_exception()
418 418
419 419 self.all_done.add(msg_id)
420 420 self.all_failed.add(msg_id)
421 421
422 422 msg = self.session.send(self.client_stream, 'apply_reply', content,
423 423 parent=header, ident=idents)
424 424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
425 425
426 426 self.update_graph(msg_id, success=False)
427 427
428 428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
429 429 """check location dependencies, and run if they are met."""
430 430 blacklist = self.blacklist.setdefault(msg_id, set())
431 431 if follow or targets or blacklist or self.hwm:
432 432 # we need a can_run filter
433 433 def can_run(idx):
434 434 # check hwm
435 435 if self.hwm and self.loads[idx] == self.hwm:
436 436 return False
437 437 target = self.targets[idx]
438 438 # check blacklist
439 439 if target in blacklist:
440 440 return False
441 441 # check targets
442 442 if targets and target not in targets:
443 443 return False
444 444 # check follow
445 445 return follow.check(self.completed[target], self.failed[target])
446 446
447 447 indices = filter(can_run, range(len(self.targets)))
448 448
449 449 if not indices:
450 450 # couldn't run
451 451 if follow.all:
452 452 # check follow for impossibility
453 453 dests = set()
454 454 relevant = set()
455 455 if follow.success:
456 456 relevant = self.all_completed
457 457 if follow.failure:
458 458 relevant = relevant.union(self.all_failed)
459 459 for m in follow.intersection(relevant):
460 460 dests.add(self.destinations[m])
461 461 if len(dests) > 1:
462 462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
463 463 self.fail_unreachable(msg_id)
464 464 return False
465 465 if targets:
466 466 # check blacklist+targets for impossibility
467 467 targets.difference_update(blacklist)
468 468 if not targets or not targets.intersection(self.targets):
469 469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
470 470 self.fail_unreachable(msg_id)
471 471 return False
472 472 return False
473 473 else:
474 474 indices = None
475 475
476 476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
477 477 return True
478 478
479 479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
480 480 """Save a message for later submission when its dependencies are met."""
481 481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
482 482 # track the ids in follow or after, but not those already finished
483 483 for dep_id in after.union(follow).difference(self.all_done):
484 484 if dep_id not in self.graph:
485 485 self.graph[dep_id] = set()
486 486 self.graph[dep_id].add(msg_id)
487 487
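# Sketch of the structures save_unmet builds (msg_ids hypothetical): after
# saving 'C' with after=Dependency(['A']) and follow=Dependency(['B']),
#   self.graph == {'A': set(['C']), 'B': set(['C'])}
# and update_graph('A') / update_graph('B') re-check 'C' via maybe_run
# as each dependency finishes.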
488 488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
489 489 """Submit a task to any of a subset of our targets."""
490 490 if indices:
491 491 loads = [self.loads[i] for i in indices]
492 492 else:
493 493 loads = self.loads
494 494 idx = self.scheme(loads)
495 495 if indices:
496 496 idx = indices[idx]
497 497 target = self.targets[idx]
498 498 # print (target, map(str, msg[:3]))
499 499 # send job to the engine
500 500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
501 501 self.engine_stream.send_multipart(raw_msg, copy=False)
502 502 # update load
503 503 self.add_job(idx)
504 504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
505 505 # notify Hub
506 506 content = dict(msg_id=msg_id, engine_id=target.decode('ascii'))
507 507 self.session.send(self.mon_stream, 'task_destination', content=content,
508 508 ident=[b'tracktask',self.ident])
509 509
510 510
511 511 #-----------------------------------------------------------------------
512 512 # Result Handling
513 513 #-----------------------------------------------------------------------
514 514 def dispatch_result(self, raw_msg):
515 515 """dispatch method for result replies"""
516 516 try:
517 517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
518 518 msg = self.session.unpack_message(msg, content=False, copy=False)
519 519 engine = idents[0]
520 520 try:
521 521 idx = self.targets.index(engine)
522 522 except ValueError:
523 523 pass # skip load-update for dead engines
524 524 else:
525 525 self.finish_job(idx)
526 526 except Exception:
527 527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
528 528 return
529 529
530 530 header = msg['header']
531 531 parent = msg['parent_header']
532 532 if header.get('dependencies_met', True):
533 533 success = (header['status'] == 'ok')
534 534 msg_id = parent['msg_id']
535 535 retries = self.retries[msg_id]
536 536 if not success and retries > 0:
537 537 # failed
538 538 self.retries[msg_id] = retries - 1
539 539 self.handle_unmet_dependency(idents, parent)
540 540 else:
541 541 del self.retries[msg_id]
542 542 # relay to client and update graph
543 543 self.handle_result(idents, parent, raw_msg, success)
544 544 # send to Hub monitor
545 545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
546 546 else:
547 547 self.handle_unmet_dependency(idents, parent)
548 548
549 549 def handle_result(self, idents, parent, raw_msg, success=True):
550 550 """handle a real task result, either success or failure"""
551 551 # first, relay result to client
552 552 engine = idents[0]
553 553 client = idents[1]
554 554 # swap_ids for XREP-XREP mirror
555 555 raw_msg[:2] = [client,engine]
556 556 # print (map(str, raw_msg[:4]))
557 557 self.client_stream.send_multipart(raw_msg, copy=False)
558 558 # now, update our data structures
559 559 msg_id = parent['msg_id']
560 560 self.blacklist.pop(msg_id, None)
561 561 self.pending[engine].pop(msg_id)
562 562 if success:
563 563 self.completed[engine].add(msg_id)
564 564 self.all_completed.add(msg_id)
565 565 else:
566 566 self.failed[engine].add(msg_id)
567 567 self.all_failed.add(msg_id)
568 568 self.all_done.add(msg_id)
569 569 self.destinations[msg_id] = engine
570 570
571 571 self.update_graph(msg_id, success)
572 572
573 573 def handle_unmet_dependency(self, idents, parent):
574 574 """handle an unmet dependency"""
575 575 engine = idents[0]
576 576 msg_id = parent['msg_id']
577 577
578 578 if msg_id not in self.blacklist:
579 579 self.blacklist[msg_id] = set()
580 580 self.blacklist[msg_id].add(engine)
581 581
582 582 args = self.pending[engine].pop(msg_id)
583 583 raw,targets,after,follow,timeout = args
584 584
585 585 if self.blacklist[msg_id] == targets:
586 586 self.depending[msg_id] = args
587 587 self.fail_unreachable(msg_id)
588 588 elif not self.maybe_run(msg_id, *args):
589 589 # resubmit failed
590 590 if msg_id not in self.all_failed:
591 591 # put it back in our dependency tree
592 592 self.save_unmet(msg_id, *args)
593 593
594 594 if self.hwm:
595 595 try:
596 596 idx = self.targets.index(engine)
597 597 except ValueError:
598 598 pass # skip load-update for dead engines
599 599 else:
600 600 if self.loads[idx] == self.hwm-1:
601 601 self.update_graph(None)
602 602
603 603
604 604
605 605 def update_graph(self, dep_id=None, success=True):
606 606 """dep_id just finished. Update our dependency
607 607 graph and submit any jobs that just became runnable.
608 608
609 609 Called with dep_id=None to update entire graph for hwm, but without finishing
610 610 a task.
611 611 """
612 612 # print ("\n\n***********")
613 613 # pprint (dep_id)
614 614 # pprint (self.graph)
615 615 # pprint (self.depending)
616 616 # pprint (self.all_completed)
617 617 # pprint (self.all_failed)
618 618 # print ("\n\n***********\n\n")
619 619 # update any jobs that depended on the dependency
620 620 jobs = self.graph.pop(dep_id, [])
621 621
622 622 # recheck *all* jobs if
623 623 # a) we have HWM and an engine just become no longer full
624 624 # or b) dep_id was given as None
625 625 if dep_id is None or (self.hwm and any([ load==self.hwm-1 for load in self.loads ])):
626 626 jobs = self.depending.keys()
627 627
628 628 for msg_id in jobs:
629 629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
630 630
631 631 if after.unreachable(self.all_completed, self.all_failed)\
632 632 or follow.unreachable(self.all_completed, self.all_failed):
633 633 self.fail_unreachable(msg_id)
634 634
635 635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
636 636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
637 637
638 638 self.depending.pop(msg_id)
639 639 for mid in follow.union(after):
640 640 if mid in self.graph:
641 641 self.graph[mid].remove(msg_id)
642 642
643 643 #----------------------------------------------------------------------
644 644 # methods to be overridden by subclasses
645 645 #----------------------------------------------------------------------
646 646
647 647 def add_job(self, idx):
648 648 """Called after self.targets[idx] just got the job with header.
649 649 Override with subclasses. The default ordering is simple LRU.
650 650 The default loads are the number of outstanding jobs."""
651 651 self.loads[idx] += 1
652 652 for lis in (self.targets, self.loads):
653 653 lis.append(lis.pop(idx))
654 654
655 655
656 656 def finish_job(self, idx):
657 657 """Called after self.targets[idx] just finished a job.
658 658 Override with subclasses."""
659 659 self.loads[idx] -= 1
660 660
661 661
662 662
663 663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
664 664 logname='root', log_url=None, loglevel=logging.DEBUG,
665 665 identity=b'task', in_thread=False):
666 666
667 667 ZMQStream = zmqstream.ZMQStream
668 668
669 669 if config:
670 670 # unwrap dict back into Config
671 671 config = Config(config)
672 672
673 673 if in_thread:
674 674 # use instance() to get the same Context/Loop as our parent
675 675 ctx = zmq.Context.instance()
676 676 loop = ioloop.IOLoop.instance()
677 677 else:
678 678 # in a process, don't use instance()
679 679 # for safety with multiprocessing
680 680 ctx = zmq.Context()
681 681 loop = ioloop.IOLoop()
682 682 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
683 683 ins.setsockopt(zmq.IDENTITY, identity)
684 684 ins.bind(in_addr)
685 685
686 686 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
687 687 outs.setsockopt(zmq.IDENTITY, identity)
688 688 outs.bind(out_addr)
689 689 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
690 690 mons.connect(mon_addr)
691 691 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
692 692 nots.setsockopt(zmq.SUBSCRIBE, b'')
693 693 nots.connect(not_addr)
694 694
695 695 # setup logging.
696 696 if in_thread:
697 697 log = Application.instance().log
698 698 else:
699 699 if log_url:
700 700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
701 701 else:
702 702 log = local_logger(logname, loglevel)
703 703
704 704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
705 705 mon_stream=mons, notifier_stream=nots,
706 706 loop=loop, log=log,
707 707 config=config)
708 708 scheduler.start()
709 709 if not in_thread:
710 710 try:
711 711 loop.start()
712 712 except KeyboardInterrupt:
713 713 print ("interrupted, exiting...", file=sys.__stderr__)
714 714
@@ -1,401 +1,400 b''
1 1 """A TaskRecord backend using sqlite3
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import json
15 15 import os
16 16 import cPickle as pickle
17 17 from datetime import datetime
18 18
19 19 import sqlite3
20 20
21 21 from zmq.eventloop import ioloop
22 22
23 23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
24 24 from .dictdb import BaseDB
25 25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # SQLite operators, adapters, and converters
29 29 #-----------------------------------------------------------------------------
30 30
31 31 try:
32 32 buffer
33 33 except NameError:
34 34 # py3k
35 35 buffer = memoryview
36 36
37 37 operators = {
38 38 '$lt' : "<",
39 39 '$gt' : ">",
40 40 # NULL is handled specially with ==, !=
41 41 '$eq' : "=",
42 42 '$ne' : "!=",
43 43 '$lte': "<=",
44 44 '$gte': ">=",
45 45 '$in' : ('=', ' OR '),
46 46 '$nin': ('!=', ' AND '),
47 47 # '$all': None,
48 48 # '$mod': None,
49 49 # '$exists' : None
50 50 }
51 51 null_operators = {
52 52 '=' : "IS NULL",
53 53 '!=' : "IS NOT NULL",
54 54 }
55 55
56 56 def _adapt_dict(d):
57 57 return json.dumps(d, default=date_default)
58 58
59 59 def _convert_dict(ds):
60 60 if ds is None:
61 61 return ds
62 62 else:
63 63 if isinstance(ds, bytes):
64 64 # If I understand the sqlite doc correctly, this will always be utf8
65 65 ds = ds.decode('utf8')
66 d = json.loads(ds)
67 return extract_dates(d)
66 return extract_dates(json.loads(ds))
68 67
69 68 def _adapt_bufs(bufs):
70 69 # this is *horrible*
71 70 # copy buffers into single list and pickle it:
72 71 if bufs and isinstance(bufs[0], (bytes, buffer)):
73 72 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
74 73 elif bufs:
75 74 return bufs
76 75 else:
77 76 return None
78 77
79 78 def _convert_bufs(bs):
80 79 if bs is None:
81 80 return []
82 81 else:
83 82 return pickle.loads(bytes(bs))
84 83
85 84 #-----------------------------------------------------------------------------
86 85 # SQLiteDB class
87 86 #-----------------------------------------------------------------------------
88 87
89 88 class SQLiteDB(BaseDB):
90 89 """SQLite3 TaskRecord backend."""
91 90
92 91 filename = Unicode('tasks.db', config=True,
93 92 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
94 93 location = Unicode('', config=True,
95 94 help="""The directory containing the sqlite task database. The default
96 95 is to use the cluster_dir location.""")
97 96 table = Unicode("", config=True,
98 97 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
99 98 a new table will be created with the Hub's IDENT. Specifying the table will result
100 99 in tasks from previous sessions being available via Clients' db_query and
101 100 get_result methods.""")
102 101
103 102 _db = Instance('sqlite3.Connection')
104 103 # the ordered list of column names
105 104 _keys = List(['msg_id' ,
106 105 'header' ,
107 106 'content',
108 107 'buffers',
109 108 'submitted',
110 109 'client_uuid' ,
111 110 'engine_uuid' ,
112 111 'started',
113 112 'completed',
114 113 'resubmitted',
115 114 'result_header' ,
116 115 'result_content' ,
117 116 'result_buffers' ,
118 117 'queue' ,
119 118 'pyin' ,
120 119 'pyout',
121 120 'pyerr',
122 121 'stdout',
123 122 'stderr',
124 123 ])
125 124 # sqlite datatypes for checking that db is current format
126 125 _types = Dict({'msg_id' : 'text' ,
127 126 'header' : 'dict text',
128 127 'content' : 'dict text',
129 128 'buffers' : 'bufs blob',
130 129 'submitted' : 'timestamp',
131 130 'client_uuid' : 'text',
132 131 'engine_uuid' : 'text',
133 132 'started' : 'timestamp',
134 133 'completed' : 'timestamp',
135 134 'resubmitted' : 'timestamp',
136 135 'result_header' : 'dict text',
137 136 'result_content' : 'dict text',
138 137 'result_buffers' : 'bufs blob',
139 138 'queue' : 'text',
140 139 'pyin' : 'text',
141 140 'pyout' : 'text',
142 141 'pyerr' : 'text',
143 142 'stdout' : 'text',
144 143 'stderr' : 'text',
145 144 })
146 145
147 146 def __init__(self, **kwargs):
148 147 super(SQLiteDB, self).__init__(**kwargs)
149 148 if not self.table:
150 149 # use session, and prefix _, since starting with # is illegal
151 150 self.table = '_'+self.session.replace('-','_')
152 151 if not self.location:
153 152 # get current profile
154 153 from IPython.core.application import BaseIPythonApplication
155 154 if BaseIPythonApplication.initialized():
156 155 app = BaseIPythonApplication.instance()
157 156 if app.profile_dir is not None:
158 157 self.location = app.profile_dir.location
159 158 else:
160 159 self.location = u'.'
161 160 else:
162 161 self.location = u'.'
163 162 self._init_db()
164 163
165 164 # register db commit as 2s periodic callback
166 165 # to prevent clogging pipes
167 166 # assumes we are being run in a zmq ioloop app
168 167 loop = ioloop.IOLoop.instance()
169 168 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
170 169 pc.start()
171 170
172 171 def _defaults(self, keys=None):
173 172 """create an empty record"""
174 173 d = {}
175 174 keys = self._keys if keys is None else keys
176 175 for key in keys:
177 176 d[key] = None
178 177 return d
179 178
180 179 def _check_table(self):
181 180 """Ensure that an incorrect table doesn't exist
182 181
183 182 If a bad (old) table does exist, return False
184 183 """
185 184 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
186 185 lines = cursor.fetchall()
187 186 if not lines:
188 187 # table does not exist
189 188 return True
190 189 types = {}
191 190 keys = []
192 191 for line in lines:
193 192 keys.append(line[1])
194 193 types[line[1]] = line[2]
195 194 if self._keys != keys:
196 195 # key mismatch
197 196 self.log.warn('keys mismatch')
198 197 return False
199 198 for key in self._keys:
200 199 if types[key] != self._types[key]:
201 200 self.log.warn(
202 201 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
203 202 )
204 203 return False
205 204 return True
206 205
207 206 def _init_db(self):
208 207 """Connect to the database and get new session number."""
209 208 # register adapters
210 209 sqlite3.register_adapter(dict, _adapt_dict)
211 210 sqlite3.register_converter('dict', _convert_dict)
212 211 sqlite3.register_adapter(list, _adapt_bufs)
213 212 sqlite3.register_converter('bufs', _convert_bufs)
214 213 # connect to the db
215 214 dbfile = os.path.join(self.location, self.filename)
216 215 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
217 216 # isolation_level = None)#,
218 217 cached_statements=64)
219 218 # print dir(self._db)
220 219 first_table = self.table
221 220 i=0
222 221 while not self._check_table():
223 222 i+=1
224 223 self.table = first_table+'_%i'%i
225 224 self.log.warn(
226 225 "Table %s exists and doesn't match db format, trying %s"%
227 226 (first_table,self.table)
228 227 )
229 228
230 229 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
231 230 (msg_id text PRIMARY KEY,
232 231 header dict text,
233 232 content dict text,
234 233 buffers bufs blob,
235 234 submitted timestamp,
236 235 client_uuid text,
237 236 engine_uuid text,
238 237 started timestamp,
239 238 completed timestamp,
240 239 resubmitted timestamp,
241 240 result_header dict text,
242 241 result_content dict text,
243 242 result_buffers bufs blob,
244 243 queue text,
245 244 pyin text,
246 245 pyout text,
247 246 pyerr text,
248 247 stdout text,
249 248 stderr text)
250 249 """%self.table)
251 250 self._db.commit()
252 251
253 252 def _dict_to_list(self, d):
254 253 """turn a mongodb-style record dict into a list."""
255 254
256 255 return [ d[key] for key in self._keys ]
257 256
258 257 def _list_to_dict(self, line, keys=None):
259 258 """Inverse of dict_to_list"""
260 259 keys = self._keys if keys is None else keys
261 260 d = self._defaults(keys)
262 261 for key,value in zip(keys, line):
263 262 d[key] = value
264 263
265 264 return d
266 265
267 266 def _render_expression(self, check):
268 267 """Turn a mongodb-style search dict into an SQL query."""
269 268 expressions = []
270 269 args = []
271 270
272 271 skeys = set(check.keys())
273 272 skeys.difference_update(set(self._keys))
274 273 skeys.difference_update(set(['buffers', 'result_buffers']))
275 274 if skeys:
276 275 raise KeyError("Illegal testing key(s): %s"%skeys)
277 276
278 277 for name,sub_check in check.iteritems():
279 278 if isinstance(sub_check, dict):
280 279 for test,value in sub_check.iteritems():
281 280 try:
282 281 op = operators[test]
283 282 except KeyError:
284 283 raise KeyError("Unsupported operator: %r"%test)
285 284 if isinstance(op, tuple):
286 285 op, join = op
287 286
288 287 if value is None and op in null_operators:
289 288 expr = "%s %s"%(name, null_operators[op])
290 289 else:
291 290 expr = "%s %s ?"%(name, op)
292 291 if isinstance(value, (tuple,list)):
293 292 if op in null_operators and any([v is None for v in value]):
294 293 # equality tests don't work with NULL
295 294 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
296 295 expr = '( %s )'%( join.join([expr]*len(value)) )
297 296 args.extend(value)
298 297 else:
299 298 args.append(value)
300 299 expressions.append(expr)
301 300 else:
302 301 # it's an equality check
303 302 if sub_check is None:
304 303 expressions.append("%s IS NULL"%name)
305 304 else:
306 305 expressions.append("%s = ?"%name)
307 306 args.append(sub_check)
308 307
309 308 expr = " AND ".join(expressions)
310 309 return expr, args
311 310
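# A hedged sketch (not part of the diff) of the rendering above, assuming the
# module-level `operators` dict maps '$gt' to '>':
#
#     expr, args = self._render_expression({'completed': {'$gt': then}})
#     # expr == "completed > ?", args == [then]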
312 311 def add_record(self, msg_id, rec):
313 312 """Add a new Task Record, by msg_id."""
314 313 d = self._defaults()
315 314 d.update(rec)
316 315 d['msg_id'] = msg_id
317 316 line = self._dict_to_list(d)
318 317 tups = '(%s)'%(','.join(['?']*len(line)))
319 318 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
320 319 # self._db.commit()
321 320
322 321 def get_record(self, msg_id):
323 322 """Get a specific Task Record, by msg_id."""
324 323 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
325 324 line = cursor.fetchone()
326 325 if line is None:
327 326 raise KeyError("No such msg: %r"%msg_id)
328 327 return self._list_to_dict(line)
329 328
330 329 def update_record(self, msg_id, rec):
331 330 """Update the data in an existing record."""
332 331 query = "UPDATE %s SET "%self.table
333 332 sets = []
334 333 keys = sorted(rec.keys())
335 334 values = []
336 335 for key in keys:
337 336 sets.append('%s = ?'%key)
338 337 values.append(rec[key])
339 338 query += ', '.join(sets)
340 339 query += ' WHERE msg_id == ?'
341 340 values.append(msg_id)
342 341 self._db.execute(query, values)
343 342 # self._db.commit()
344 343
345 344 def drop_record(self, msg_id):
346 345 """Remove a record from the DB."""
347 346 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
348 347 # self._db.commit()
349 348
350 349 def drop_matching_records(self, check):
351 350 """Remove any records from the DB that match a query."""
352 351 expr,args = self._render_expression(check)
353 352 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
354 353 self._db.execute(query,args)
355 354 # self._db.commit()
356 355
357 356 def find_records(self, check, keys=None):
358 357 """Find records matching a query dict, optionally extracting a subset of keys.
359 358
360 359 Returns list of matching records.
361 360
362 361 Parameters
363 362 ----------
364 363
365 364 check: dict
366 365 mongodb-style query argument
367 366 keys: list of strs [optional]
368 367 if specified, the subset of keys to extract. msg_id will *always* be
369 368 included.
370 369 """
371 370 if keys:
372 371 bad_keys = [ key for key in keys if key not in self._keys ]
373 372 if bad_keys:
374 373 raise KeyError("Bad record key(s): %s"%bad_keys)
375 374
376 375 if keys:
377 376 # ensure msg_id is present and first:
378 377 if 'msg_id' in keys:
379 378 keys.remove('msg_id')
380 379 keys.insert(0, 'msg_id')
381 380 req = ', '.join(keys)
382 381 else:
383 382 req = '*'
384 383 expr,args = self._render_expression(check)
385 384 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
386 385 cursor = self._db.execute(query, args)
387 386 matches = cursor.fetchall()
388 387 records = []
389 388 for line in matches:
390 389 rec = self._list_to_dict(line, keys)
391 390 records.append(rec)
392 391 return records
393 392
394 393 def get_history(self):
395 394 """get all msg_ids, ordered by time submitted."""
396 395 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
397 396 cursor = self._db.execute(query)
398 397 # will be a list of length 1 tuples
399 398 return [ tup[0] for tup in cursor.fetchall()]
400 399
401 400 __all__ = ['SQLiteDB'] No newline at end of file
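The record API this file implements can be exercised directly; a minimal sketch, assuming an IPython 0.11-era install, a writable working directory, and illustrative msg_ids and record fields:

    from IPython.parallel.controller.sqlitedb import SQLiteDB

    db = SQLiteDB(location=u'.', filename=u'tasks.sqlite')
    db.add_record('msg-1', {'queue': 'default'})            # illustrative record
    rec = db.get_record('msg-1')                            # full record dict
    found = db.find_records({'queue': 'default'}, keys=['msg_id'])
    history = db.get_history()                              # msg_ids by submit time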
@@ -1,174 +1,174 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 It handles registration, etc., and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 # internal
26 26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 27 # from IPython.utils.localinterfaces import LOCALHOST
28 28
29 29 from IPython.parallel.controller.heartmonitor import Heart
30 30 from IPython.parallel.factory import RegistrationFactory
31 from IPython.parallel.util import disambiguate_url, ensure_bytes
31 from IPython.parallel.util import disambiguate_url, asbytes
32 32
33 33 from IPython.zmq.session import Message
34 34
35 35 from .streamkernel import Kernel
36 36
37 37 class EngineFactory(RegistrationFactory):
38 38 """IPython engine"""
39 39
40 40 # configurables:
41 41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 42 help="""The OutStream for handling stdout/err.
43 43 Typically 'IPython.zmq.iostream.OutStream'""")
44 44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 45 help="""The class for handling displayhook.
46 46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 47 location=Unicode(config=True,
48 48 help="""The location (an IP address) of the controller. This is
49 49 used for disambiguating URLs, to determine whether
50 50 loopback should be used to connect or the public address.""")
51 51 timeout=CFloat(2,config=True,
52 52 help="""The time (in seconds) to wait for the Controller to respond
53 53 to registration requests before giving up.""")
54 54
55 55 # not configurable:
56 56 user_ns=Dict()
57 57 id=Int(allow_none=True)
58 58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 59 kernel=Instance(Kernel)
60 60
61 61 bident = CBytes()
62 62 ident = Unicode()
63 63 def _ident_changed(self, name, old, new):
64 self.bident = ensure_bytes(new)
64 self.bident = asbytes(new)
65 65
66 66
67 67 def __init__(self, **kwargs):
68 68 super(EngineFactory, self).__init__(**kwargs)
69 69 self.ident = self.session.session
70 70 ctx = self.context
71 71
72 72 reg = ctx.socket(zmq.XREQ)
73 73 reg.setsockopt(zmq.IDENTITY, self.bident)
74 74 reg.connect(self.url)
75 75 self.registrar = zmqstream.ZMQStream(reg, self.loop)
76 76
77 77 def register(self):
78 78 """send the registration_request"""
79 79
80 80 self.log.info("Registering with controller at %s"%self.url)
81 81 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
82 82 self.registrar.on_recv(self.complete_registration)
83 83 # print (self.session.key)
84 84 self.session.send(self.registrar, "registration_request",content=content)
85 85
86 86 def complete_registration(self, msg):
87 87 # print msg
88 88 self._abort_dc.stop()
89 89 ctx = self.context
90 90 loop = self.loop
91 91 identity = self.bident
92 92 idents,msg = self.session.feed_identities(msg)
93 93 msg = Message(self.session.unpack_message(msg))
94 94
95 95 if msg.content.status == 'ok':
96 96 self.id = int(msg.content.id)
97 97
98 98 # create Shell Streams (MUX, Task, etc.):
99 99 queue_addr = msg.content.mux
100 100 shell_addrs = [ str(queue_addr) ]
101 101 task_addr = msg.content.task
102 102 if task_addr:
103 103 shell_addrs.append(str(task_addr))
104 104
105 105 # Uncomment this to go back to two-socket model
106 106 # shell_streams = []
107 107 # for addr in shell_addrs:
108 108 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
109 109 # stream.setsockopt(zmq.IDENTITY, identity)
110 110 # stream.connect(disambiguate_url(addr, self.location))
111 111 # shell_streams.append(stream)
112 112
113 113 # Now use only one shell stream for mux and tasks
114 114 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 115 stream.setsockopt(zmq.IDENTITY, identity)
116 116 shell_streams = [stream]
117 117 for addr in shell_addrs:
118 118 stream.connect(disambiguate_url(addr, self.location))
119 119 # end single stream-socket
120 120
121 121 # control stream:
122 122 control_addr = str(msg.content.control)
123 123 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
124 124 control_stream.setsockopt(zmq.IDENTITY, identity)
125 125 control_stream.connect(disambiguate_url(control_addr, self.location))
126 126
127 127 # create iopub stream:
128 128 iopub_addr = msg.content.iopub
129 129 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
130 130 iopub_stream.setsockopt(zmq.IDENTITY, identity)
131 131 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
132 132
133 133 # launch heartbeat
134 134 hb_addrs = msg.content.heartbeat
135 135 # print (hb_addrs)
136 136
137 137 # # Redirect input streams and set a display hook.
138 138 if self.out_stream_factory:
139 139 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
140 140 sys.stdout.topic = 'engine.%i.stdout'%self.id
141 141 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
142 142 sys.stderr.topic = 'engine.%i.stderr'%self.id
143 143 if self.display_hook_factory:
144 144 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
145 145 sys.displayhook.topic = 'engine.%i.pyout'%self.id
146 146
147 147 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
148 148 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
149 149 loop=loop, user_ns = self.user_ns, log=self.log)
150 150 self.kernel.start()
151 151 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
152 152 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 153 heart.start()
154 154
155 155
156 156 else:
157 157 self.log.fatal("Registration Failed: %s"%msg)
158 158 raise Exception("Registration Failed: %s"%msg)
159 159
160 160 self.log.info("Completed registration with id %i"%self.id)
161 161
162 162
163 163 def abort(self):
164 164 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
165 165 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
166 166 time.sleep(1)
167 167 sys.exit(255)
168 168
169 169 def start(self):
170 170 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
171 171 dc.start()
172 172 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
173 173 self._abort_dc.start()
174 174
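The registration handshake above hinges on stamping the socket identity before connecting, so the controller can route replies back to this engine. A standalone sketch, assuming pyzmq and an illustrative controller URL (zmq.DEALER is the modern spelling of the zmq.XREQ used above):

    import zmq

    ctx = zmq.Context.instance()
    reg = ctx.socket(zmq.DEALER)                  # zmq.XREQ in the code above
    reg.setsockopt(zmq.IDENTITY, b'engine-0')     # illustrative identity
    reg.connect('tcp://127.0.0.1:10101')          # illustrative controller URL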
@@ -1,438 +1,438 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 * Brian Granger
9 9 * Fernando Perez
10 10 * Evan Patterson
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2010-2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 # Standard library imports.
24 24 from __future__ import print_function
25 25
26 26 import sys
27 27 import time
28 28
29 29 from code import CommandCompiler
30 30 from datetime import datetime
31 31 from pprint import pprint
32 32
33 33 # System library imports.
34 34 import zmq
35 35 from zmq.eventloop import ioloop, zmqstream
36 36
37 37 # Local imports.
38 38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 39 from IPython.zmq.completer import KernelCompleter
40 40
41 41 from IPython.parallel.error import wrap_exception
42 42 from IPython.parallel.factory import SessionFactory
43 from IPython.parallel.util import serialize_object, unpack_apply_message, ensure_bytes
43 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44 44
45 45 def printer(*args):
46 46 pprint(args, stream=sys.__stdout__)
47 47
48 48
49 49 class _Passer(zmqstream.ZMQStream):
50 50 """Empty class that implements `send()` that does nothing.
51 51
52 52 Subclass ZMQStream for Session typechecking
53 53
54 54 """
55 55 def __init__(self, *args, **kwargs):
56 56 pass
57 57
58 58 def send(self, *args, **kwargs):
59 59 pass
60 60 send_multipart = send
61 61
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Main kernel class
65 65 #-----------------------------------------------------------------------------
66 66
67 67 class Kernel(SessionFactory):
68 68
69 69 #---------------------------------------------------------------------------
70 70 # Kernel interface
71 71 #---------------------------------------------------------------------------
72 72
73 73 # kwargs:
74 74 exec_lines = List(Unicode, config=True,
75 75 help="List of lines to execute")
76 76
77 77 # identities:
78 78 int_id = Int(-1)
79 79 bident = CBytes()
80 80 ident = Unicode()
81 81 def _ident_changed(self, name, old, new):
82 self.bident = ensure_bytes(new)
82 self.bident = asbytes(new)
83 83
84 84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85 85
86 86 control_stream = Instance(zmqstream.ZMQStream)
87 87 task_stream = Instance(zmqstream.ZMQStream)
88 88 iopub_stream = Instance(zmqstream.ZMQStream)
89 89 client = Instance('IPython.parallel.Client')
90 90
91 91 # internals
92 92 shell_streams = List()
93 93 compiler = Instance(CommandCompiler, (), {})
94 94 completer = Instance(KernelCompleter)
95 95
96 96 aborted = Set()
97 97 shell_handlers = Dict()
98 98 control_handlers = Dict()
99 99
100 100 def _set_prefix(self):
101 101 self.prefix = "engine.%s"%self.int_id
102 102
103 103 def _connect_completer(self):
104 104 self.completer = KernelCompleter(self.user_ns)
105 105
106 106 def __init__(self, **kwargs):
107 107 super(Kernel, self).__init__(**kwargs)
108 108 self._set_prefix()
109 109 self._connect_completer()
110 110
111 111 self.on_trait_change(self._set_prefix, 'id')
112 112 self.on_trait_change(self._connect_completer, 'user_ns')
113 113
114 114 # Build dict of handlers for message types
115 115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 116 'clear_request']:
117 117 self.shell_handlers[msg_type] = getattr(self, msg_type)
118 118
119 119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 120 self.control_handlers[msg_type] = getattr(self, msg_type)
121 121
122 122 self._initial_exec_lines()
123 123
124 124 def _wrap_exception(self, method=None):
125 125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 126 content=wrap_exception(e_info)
127 127 return content
128 128
129 129 def _initial_exec_lines(self):
130 130 s = _Passer()
131 131 content = dict(silent=True, user_variable=[],user_expressions=[])
132 132 for line in self.exec_lines:
133 133 self.log.debug("executing initialization: %s"%line)
134 134 content.update({'code':line})
135 135 msg = self.session.msg('execute_request', content)
136 136 self.execute_request(s, [], msg)
137 137
138 138
139 139 #-------------------- control handlers -----------------------------
140 140 def abort_queues(self):
141 141 for stream in self.shell_streams:
142 142 if stream:
143 143 self.abort_queue(stream)
144 144
145 145 def abort_queue(self, stream):
146 146 while True:
147 147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 148 if msg is None:
149 149 return
150 150
151 151 self.log.info("Aborting:")
152 152 self.log.info(str(msg))
153 153 msg_type = msg['msg_type']
154 154 reply_type = msg_type.split('_')[0] + '_reply'
155 155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 156 # self.reply_socket.send(ident,zmq.SNDMORE)
157 157 # self.reply_socket.send_json(reply_msg)
158 158 reply_msg = self.session.send(stream, reply_type,
159 159 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 160 self.log.debug(str(reply_msg))
161 161 # We need to wait a bit for requests to come in. This can probably
162 162 # be set shorter for true asynchronous clients.
163 163 time.sleep(0.05)
164 164
165 165 def abort_request(self, stream, ident, parent):
166 166 """abort a specific msg by id"""
167 167 msg_ids = parent['content'].get('msg_ids', None)
168 168 if isinstance(msg_ids, basestring):
169 169 msg_ids = [msg_ids]
170 170 if not msg_ids:
171 171 self.abort_queues()
172 172 for mid in msg_ids:
173 173 self.aborted.add(str(mid))
174 174
175 175 content = dict(status='ok')
176 176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 177 parent=parent, ident=ident)
178 178 self.log.debug(str(reply_msg))
179 179
180 180 def shutdown_request(self, stream, ident, parent):
181 181 """Kill ourselves. This should really be handled in an external process"""
182 182 try:
183 183 self.abort_queues()
184 184 except:
185 185 content = self._wrap_exception('shutdown')
186 186 else:
187 187 content = dict(parent['content'])
188 188 content['status'] = 'ok'
189 189 msg = self.session.send(stream, 'shutdown_reply',
190 190 content=content, parent=parent, ident=ident)
191 191 self.log.debug(str(msg))
192 192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 193 dc.start()
194 194
195 195 def dispatch_control(self, msg):
196 196 idents,msg = self.session.feed_identities(msg, copy=False)
197 197 try:
198 198 msg = self.session.unpack_message(msg, content=True, copy=False)
199 199 except:
200 200 self.log.error("Invalid Message", exc_info=True)
201 201 return
202 202 else:
203 203 self.log.debug("Control received, %s", msg)
204 204
205 205 header = msg['header']
206 206 msg_id = header['msg_id']
207 207
208 208 handler = self.control_handlers.get(msg['msg_type'], None)
209 209 if handler is None:
210 210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 211 else:
212 212 handler(self.control_stream, idents, msg)
213 213
214 214
215 215 #-------------------- queue helpers ------------------------------
216 216
217 217 def check_dependencies(self, dependencies):
218 218 if not dependencies:
219 219 return True
220 220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 221 anyorall = dependencies[0]
222 222 dependencies = dependencies[1]
223 223 else:
224 224 anyorall = 'all'
225 225 results = self.client.get_results(dependencies,status_only=True)
226 226 if results['status'] != 'ok':
227 227 return False
228 228
229 229 if anyorall == 'any':
230 230 if not results['completed']:
231 231 return False
232 232 else:
233 233 if results['pending']:
234 234 return False
235 235
236 236 return True
237 237
238 238 def check_aborted(self, msg_id):
239 239 return msg_id in self.aborted
240 240
241 241 #-------------------- queue handlers -----------------------------
242 242
243 243 def clear_request(self, stream, idents, parent):
244 244 """Clear our namespace."""
245 245 self.user_ns = {}
246 246 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
247 247 content = dict(status='ok'))
248 248 self._initial_exec_lines()
249 249
250 250 def execute_request(self, stream, ident, parent):
251 251 self.log.debug('execute request %s'%parent)
252 252 try:
253 253 code = parent[u'content'][u'code']
254 254 except:
255 255 self.log.error("Got bad msg: %s"%parent, exc_info=True)
256 256 return
257 257 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
258 ident=ensure_bytes('%s.pyin'%self.prefix))
258 ident=asbytes('%s.pyin'%self.prefix))
259 259 started = datetime.now()
260 260 try:
261 261 comp_code = self.compiler(code, '<zmq-kernel>')
262 262 # allow for not overriding displayhook
263 263 if hasattr(sys.displayhook, 'set_parent'):
264 264 sys.displayhook.set_parent(parent)
265 265 sys.stdout.set_parent(parent)
266 266 sys.stderr.set_parent(parent)
267 267 exec comp_code in self.user_ns, self.user_ns
268 268 except:
269 269 exc_content = self._wrap_exception('execute')
270 270 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
271 271 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
272 ident=ensure_bytes('%s.pyerr'%self.prefix))
272 ident=asbytes('%s.pyerr'%self.prefix))
273 273 reply_content = exc_content
274 274 else:
275 275 reply_content = {'status' : 'ok'}
276 276
277 277 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
278 278 ident=ident, subheader = dict(started=started))
279 279 self.log.debug(str(reply_msg))
280 280 if reply_msg['content']['status'] == u'error':
281 281 self.abort_queues()
282 282
283 283 def complete_request(self, stream, ident, parent):
284 284 matches = {'matches' : self.complete(parent),
285 285 'status' : 'ok'}
286 286 completion_msg = self.session.send(stream, 'complete_reply',
287 287 matches, parent, ident)
288 288 # print >> sys.__stdout__, completion_msg
289 289
290 290 def complete(self, msg):
291 291 return self.completer.complete(msg.content.line, msg.content.text)
292 292
293 293 def apply_request(self, stream, ident, parent):
294 294 # flush previous reply, so this request won't block it
295 295 stream.flush(zmq.POLLOUT)
296 296 try:
297 297 content = parent[u'content']
298 298 bufs = parent[u'buffers']
299 299 msg_id = parent['header']['msg_id']
300 300 # bound = parent['header'].get('bound', False)
301 301 except:
302 302 self.log.error("Got bad msg: %s"%parent, exc_info=True)
303 303 return
304 304 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
305 305 # self.iopub_stream.send(pyin_msg)
306 306 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
307 307 sub = {'dependencies_met' : True, 'engine' : self.ident,
308 308 'started': datetime.now()}
309 309 try:
310 310 # allow for not overriding displayhook
311 311 if hasattr(sys.displayhook, 'set_parent'):
312 312 sys.displayhook.set_parent(parent)
313 313 sys.stdout.set_parent(parent)
314 314 sys.stderr.set_parent(parent)
315 315 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
316 316 working = self.user_ns
317 317 # suffix =
318 318 prefix = "_"+str(msg_id).replace("-","")+"_"
319 319
320 320 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
321 321 # if bound:
322 322 # bound_ns = Namespace(working)
323 323 # args = [bound_ns]+list(args)
324 324
325 325 fname = getattr(f, '__name__', 'f')
326 326
327 327 fname = prefix+"f"
328 328 argname = prefix+"args"
329 329 kwargname = prefix+"kwargs"
330 330 resultname = prefix+"result"
331 331
332 332 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
333 333 # print ns
334 334 working.update(ns)
335 335 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
336 336 try:
337 337 exec code in working,working
338 338 result = working.get(resultname)
339 339 finally:
340 340 for key in ns.iterkeys():
341 341 working.pop(key)
342 342 # if bound:
343 343 # working.update(bound_ns)
344 344
345 345 packed_result,buf = serialize_object(result)
346 346 result_buf = [packed_result]+buf
347 347 except:
348 348 exc_content = self._wrap_exception('apply')
349 349 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
350 350 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
351 ident=ensure_bytes('%s.pyerr'%self.prefix))
351 ident=asbytes('%s.pyerr'%self.prefix))
352 352 reply_content = exc_content
353 353 result_buf = []
354 354
355 355 if exc_content['ename'] == 'UnmetDependency':
356 356 sub['dependencies_met'] = False
357 357 else:
358 358 reply_content = {'status' : 'ok'}
359 359
360 360 # put 'ok'/'error' status in header, for scheduler introspection:
361 361 sub['status'] = reply_content['status']
362 362
363 363 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
364 364 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
365 365
366 366 # flush i/o
367 367 # should this be before reply_msg is sent, like in the single-kernel code,
368 368 # or should nothing get in the way of real results?
369 369 sys.stdout.flush()
370 370 sys.stderr.flush()
371 371
372 372 def dispatch_queue(self, stream, msg):
373 373 self.control_stream.flush()
374 374 idents,msg = self.session.feed_identities(msg, copy=False)
375 375 try:
376 376 msg = self.session.unpack_message(msg, content=True, copy=False)
377 377 except:
378 378 self.log.error("Invalid Message", exc_info=True)
379 379 return
380 380 else:
381 381 self.log.debug("Message received, %s", msg)
382 382
383 383
384 384 header = msg['header']
385 385 msg_id = header['msg_id']
386 386 if self.check_aborted(msg_id):
387 387 self.aborted.remove(msg_id)
388 388 # is it safe to assume a msg_id will not be resubmitted?
389 389 reply_type = msg['msg_type'].split('_')[0] + '_reply'
390 390 status = {'status' : 'aborted'}
391 391 reply_msg = self.session.send(stream, reply_type, subheader=status,
392 392 content=status, parent=msg, ident=idents)
393 393 return
394 394 handler = self.shell_handlers.get(msg['msg_type'], None)
395 395 if handler is None:
396 396 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
397 397 else:
398 398 handler(stream, idents, msg)
399 399
400 400 def start(self):
401 401 #### stream mode:
402 402 if self.control_stream:
403 403 self.control_stream.on_recv(self.dispatch_control, copy=False)
404 404 self.control_stream.on_err(printer)
405 405
406 406 def make_dispatcher(stream):
407 407 def dispatcher(msg):
408 408 return self.dispatch_queue(stream, msg)
409 409 return dispatcher
410 410
411 411 for s in self.shell_streams:
412 412 s.on_recv(make_dispatcher(s), copy=False)
413 413 s.on_err(printer)
414 414
415 415 if self.iopub_stream:
416 416 self.iopub_stream.on_err(printer)
417 417
418 418 #### while True mode:
419 419 # while True:
420 420 # idle = True
421 421 # try:
422 422 # msg = self.shell_stream.socket.recv_multipart(
423 423 # zmq.NOBLOCK, copy=False)
424 424 # except zmq.ZMQError, e:
425 425 # if e.errno != zmq.EAGAIN:
426 426 # raise e
427 427 # else:
428 428 # idle=False
429 429 # self.dispatch_queue(self.shell_stream, msg)
430 430 #
431 431 # if not self.task_stream.empty():
432 432 # idle=False
433 433 # msg = self.task_stream.recv_multipart()
434 434 # self.dispatch_queue(self.task_stream, msg)
435 435 # if idle:
436 436 # # don't busywait
437 437 # time.sleep(1e-3)
438 438
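The heart of apply_request above is a name-mangling exec: f, args, and kwargs are bound under msg_id-prefixed names in the working namespace, a one-line call is exec'd there, and the temporaries are popped afterwards. A reduced sketch with an illustrative prefix:

    def run_apply(f, args, kwargs, working):
        prefix = "_deadbeef_"                      # illustrative msg_id-derived prefix
        ns = {prefix + 'f': f, prefix + 'args': args,
              prefix + 'kwargs': kwargs, prefix + 'result': None}
        working.update(ns)
        code = "%sresult = %sf(*%sargs, **%skwargs)" % ((prefix,) * 4)
        try:
            exec(code, working)                    # result lands in working
            return working.get(prefix + 'result')
        finally:
            for key in ns:                         # clean up the temporaries
                working.pop(key)

    print(run_apply(lambda x, y: x + y, (2, 3), {}, {}))   # -> 5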
@@ -1,136 +1,137 b''
1 1 """base class for parallel client tests
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 import sys
16 16 import tempfile
17 17 import time
18 18
19 19 from nose import SkipTest
20 20
21 21 import zmq
22 22 from zmq.tests import BaseZMQTestCase
23 23
24 24 from IPython.external.decorator import decorator
25 25
26 26 from IPython.parallel import error
27 27 from IPython.parallel import Client
28 28
29 29 from IPython.parallel.tests import launchers, add_engines
30 30
31 31 # simple tasks for use in apply tests
32 32
33 33 def segfault():
34 34 """this will segfault"""
35 35 import ctypes
36 36 ctypes.memset(-1,0,1)
37 37
38 38 def crash():
39 39 """from stdlib crashers in the test suite"""
40 40 import types
41 41 if sys.platform.startswith('win'):
42 42 import ctypes
43 43 ctypes.windll.kernel32.SetErrorMode(0x0002);
44 44 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 45 if sys.version_info[0] >= 3:
46 # Python3 adds 'kwonlyargcount' as the second argument to Code
46 47 args.insert(1, 0)
47 48
48 49 co = types.CodeType(*args)
49 50 exec(co)
50 51
51 52 def wait(n):
52 53 """sleep for a time"""
53 54 import time
54 55 time.sleep(n)
55 56 return n
56 57
57 58 def raiser(eclass):
58 59 """raise an exception"""
59 60 raise eclass()
60 61
61 62 # test decorator for skipping tests when libraries are unavailable
62 63 def skip_without(*names):
63 64 """skip a test if some names are not importable"""
64 65 @decorator
65 66 def skip_without_names(f, *args, **kwargs):
66 67 """decorator to skip tests in the absence of numpy."""
67 68 for name in names:
68 69 try:
69 70 __import__(name)
70 71 except ImportError:
71 72 raise SkipTest
72 73 return f(*args, **kwargs)
73 74 return skip_without_names
74 75
75 76 class ClusterTestCase(BaseZMQTestCase):
76 77
77 78 def add_engines(self, n=1, block=True):
78 79 """add multiple engines to our cluster"""
79 80 self.engines.extend(add_engines(n))
80 81 if block:
81 82 self.wait_on_engines()
82 83
83 84 def wait_on_engines(self, timeout=5):
84 85 """wait for our engines to connect."""
85 86 n = len(self.engines)+self.base_engine_count
86 87 tic = time.time()
87 88 while time.time()-tic < timeout and len(self.client.ids) < n:
88 89 time.sleep(0.1)
89 90
90 91 assert not len(self.client.ids) < n, "waiting for engines timed out"
91 92
92 93 def connect_client(self):
93 94 """connect a client with my Context, and track its sockets for cleanup"""
94 95 c = Client(profile='iptest', context=self.context)
95 96 for name in filter(lambda n:n.endswith('socket'), dir(c)):
96 97 s = getattr(c, name)
97 98 s.setsockopt(zmq.LINGER, 0)
98 99 self.sockets.append(s)
99 100 return c
100 101
101 102 def assertRaisesRemote(self, etype, f, *args, **kwargs):
102 103 try:
103 104 try:
104 105 f(*args, **kwargs)
105 106 except error.CompositeError as e:
106 107 e.raise_exception()
107 108 except error.RemoteError as e:
108 109 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
109 110 else:
110 111 self.fail("should have raised a RemoteError")
111 112
112 113 def setUp(self):
113 114 BaseZMQTestCase.setUp(self)
114 115 self.client = self.connect_client()
115 116 # start every test with clean engine namespaces:
116 117 self.client.clear(block=True)
117 118 self.base_engine_count=len(self.client.ids)
118 119 self.engines=[]
119 120
120 121 def tearDown(self):
121 122 # self.client.clear(block=True)
122 123 # close fds:
123 124 for e in filter(lambda e: e.poll() is not None, launchers):
124 125 launchers.remove(e)
125 126
126 127 # allow flushing of incoming messages to prevent crash on socket close
127 128 self.client.wait(timeout=2)
128 129 # time.sleep(2)
129 130 self.client.spin()
130 131 self.client.close()
131 132 BaseZMQTestCase.tearDown(self)
132 133 # this will be redundant when pyzmq merges PR #88
133 134 # self.context.term()
134 135 # print tempfile.TemporaryFile().fileno(),
135 136 # sys.stdout.flush()
136 137 No newline at end of file
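Typical use of the skip_without decorator above, as a sketch (the test class and body are illustrative):

    class TestWithNumpy(ClusterTestCase):

        @skip_without('numpy')
        def test_apply_roundtrip(self):
            import numpy                     # safe: the decorator already checked
            view = self.client[:]            # illustrative use of the test client
            # ... apply/push numpy data here; the test is skipped cleanly
            # wherever numpy is not installed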
@@ -1,456 +1,456 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 44 from IPython.utils.newserialized import serialize, unserialize
45 45 from IPython.zmq.log import EnginePUBHandler
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Classes
49 49 #-----------------------------------------------------------------------------
50 50
51 51 class Namespace(dict):
52 52 """Subclass of dict for attribute access to keys."""
53 53
54 54 def __getattr__(self, key):
55 55 """getattr aliased to getitem"""
56 56 if key in self.iterkeys():
57 57 return self[key]
58 58 else:
59 59 raise NameError(key)
60 60
61 61 def __setattr__(self, key, value):
62 62 """setattr aliased to setitem, with strict"""
63 63 if hasattr(dict, key):
64 64 raise KeyError("Cannot override dict keys %r"%key)
65 65 self[key] = value
66 66
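# e.g. (a hedged sketch) ns = Namespace(a=1); ns.a -> 1, and ns.b = 2
# stores ns['b'] = 2, while ns.missing raises NameError.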
67 67
68 68 class ReverseDict(dict):
69 69 """simple double-keyed subset of dict methods."""
70 70
71 71 def __init__(self, *args, **kwargs):
72 72 dict.__init__(self, *args, **kwargs)
73 73 self._reverse = dict()
74 74 for key, value in self.iteritems():
75 75 self._reverse[value] = key
76 76
77 77 def __getitem__(self, key):
78 78 try:
79 79 return dict.__getitem__(self, key)
80 80 except KeyError:
81 81 return self._reverse[key]
82 82
83 83 def __setitem__(self, key, value):
84 84 if key in self._reverse:
85 85 raise KeyError("Can't have key %r on both sides!"%key)
86 86 dict.__setitem__(self, key, value)
87 87 self._reverse[value] = key
88 88
89 89 def pop(self, key):
90 90 value = dict.pop(self, key)
91 91 self._reverse.pop(value)
92 92 return value
93 93
94 94 def get(self, key, default=None):
95 95 try:
96 96 return self[key]
97 97 except KeyError:
98 98 return default
99 99
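# e.g. (a hedged sketch) rd = ReverseDict(a=1); rd['a'] -> 1 and rd[1] -> 'a',
# while rd[2] raises KeyError from the reverse lookup.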
100 100 #-----------------------------------------------------------------------------
101 101 # Functions
102 102 #-----------------------------------------------------------------------------
103 103
104 def ensure_bytes(s):
104 def asbytes(s):
105 105 """ensure that an object is ascii bytes"""
106 106 if isinstance(s, unicode):
107 107 s = s.encode('ascii')
108 108 return s
109 109
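# e.g. asbytes(u'abc') -> 'abc' (ascii bytes); byte strings pass through unchanged.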
110 110 def validate_url(url):
111 111 """validate a url for zeromq"""
112 112 if not isinstance(url, basestring):
113 113 raise TypeError("url must be a string, not %r"%type(url))
114 114 url = url.lower()
115 115
116 116 proto_addr = url.split('://')
117 117 assert len(proto_addr) == 2, 'Invalid url: %r'%url
118 118 proto, addr = proto_addr
119 119 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
120 120
121 121 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
122 122 # author: Remi Sabourin
123 123 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
124 124
125 125 if proto == 'tcp':
126 126 lis = addr.split(':')
127 127 assert len(lis) == 2, 'Invalid url: %r'%url
128 128 addr,s_port = lis
129 129 try:
130 130 port = int(s_port)
131 131 except ValueError:
132 132 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
133 133
134 134 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
135 135
136 136 else:
137 137 # only validate tcp urls currently
138 138 pass
139 139
140 140 return True
141 141
142 142
143 143 def validate_url_container(container):
144 144 """validate a potentially nested collection of urls."""
145 145 if isinstance(container, basestring):
146 146 url = container
147 147 return validate_url(url)
148 148 elif isinstance(container, dict):
149 149 container = container.itervalues()
150 150
151 151 for element in container:
152 152 validate_url_container(element)
153 153
154 154
155 155 def split_url(url):
156 156 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
157 157 proto_addr = url.split('://')
158 158 assert len(proto_addr) == 2, 'Invalid url: %r'%url
159 159 proto, addr = proto_addr
160 160 lis = addr.split(':')
161 161 assert len(lis) == 2, 'Invalid url: %r'%url
162 162 addr,s_port = lis
163 163 return proto,addr,s_port
164 164
165 165 def disambiguate_ip_address(ip, location=None):
166 166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
167 167 ones, based on the location (default interpretation of location is localhost)."""
168 168 if ip in ('0.0.0.0', '*'):
169 169 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
170 170 if location is None or location in external_ips:
171 171 ip='127.0.0.1'
172 172 elif location:
173 173 return location
174 174 return ip
175 175
176 176 def disambiguate_url(url, location=None):
177 177 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
178 178 ones, based on the location (default interpretation is localhost).
179 179
180 180 This is for zeromq urls, such as tcp://*:10101."""
181 181 try:
182 182 proto,ip,port = split_url(url)
183 183 except AssertionError:
184 184 # probably not tcp url; could be ipc, etc.
185 185 return url
186 186
187 187 ip = disambiguate_ip_address(ip,location)
188 188
189 189 return "%s://%s:%s"%(proto,ip,port)
190 190
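# e.g. disambiguate_url('tcp://*:10101') -> 'tcp://127.0.0.1:10101' when no
# location is given, or when the location matches one of this host's IPs.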
191 191 def serialize_object(obj, threshold=64e-6):
192 192 """Serialize an object into a list of sendable buffers.
193 193
194 194 Parameters
195 195 ----------
196 196
197 197 obj : object
198 198 The object to be serialized
199 199 threshold : float
200 200 The threshold for not double-pickling the content.
201 201
202 202
203 203 Returns
204 204 -------
205 205 ('pmd', [bufs]) :
206 206 where pmd is the pickled metadata wrapper,
207 207 bufs is a list of data buffers
208 208 """
209 209 databuffers = []
210 210 if isinstance(obj, (list, tuple)):
211 211 clist = canSequence(obj)
212 212 slist = map(serialize, clist)
213 213 for s in slist:
214 214 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
215 215 databuffers.append(s.getData())
216 216 s.data = None
217 217 return pickle.dumps(slist,-1), databuffers
218 218 elif isinstance(obj, dict):
219 219 sobj = {}
220 220 for k in sorted(obj.iterkeys()):
221 221 s = serialize(can(obj[k]))
222 222 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
223 223 databuffers.append(s.getData())
224 224 s.data = None
225 225 sobj[k] = s
226 226 return pickle.dumps(sobj,-1),databuffers
227 227 else:
228 228 s = serialize(can(obj))
229 229 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
230 230 databuffers.append(s.getData())
231 231 s.data = None
232 232 return pickle.dumps(s,-1),databuffers
233 233
234 234
235 235 def unserialize_object(bufs):
236 236 """reconstruct an object serialized by serialize_object from data buffers."""
237 237 bufs = list(bufs)
238 238 sobj = pickle.loads(bufs.pop(0))
239 239 if isinstance(sobj, (list, tuple)):
240 240 for s in sobj:
241 241 if s.data is None:
242 242 s.data = bufs.pop(0)
243 243 return uncanSequence(map(unserialize, sobj)), bufs
244 244 elif isinstance(sobj, dict):
245 245 newobj = {}
246 246 for k in sorted(sobj.iterkeys()):
247 247 s = sobj[k]
248 248 if s.data is None:
249 249 s.data = bufs.pop(0)
250 250 newobj[k] = uncan(unserialize(s))
251 251 return newobj, bufs
252 252 else:
253 253 if sobj.data is None:
254 254 sobj.data = bufs.pop(0)
255 255 return uncan(unserialize(sobj)), bufs
256 256
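# A round-trip sketch of the pair above:
#
#     pmd, bufs = serialize_object({'a': 1})
#     obj, remainder = unserialize_object([pmd] + bufs)
#     # obj == {'a': 1}; remainder == []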
257 257 def pack_apply_message(f, args, kwargs, threshold=64e-6):
258 258 """pack up a function, args, and kwargs to be sent over the wire
259 259 as a series of buffers. Any object whose data is larger than `threshold`
260 260 will not have their data copied (currently only numpy arrays support zero-copy)"""
261 261 msg = [pickle.dumps(can(f),-1)]
262 262 databuffers = [] # for large objects
263 263 sargs, bufs = serialize_object(args,threshold)
264 264 msg.append(sargs)
265 265 databuffers.extend(bufs)
266 266 skwargs, bufs = serialize_object(kwargs,threshold)
267 267 msg.append(skwargs)
268 268 databuffers.extend(bufs)
269 269 msg.extend(databuffers)
270 270 return msg
271 271
272 272 def unpack_apply_message(bufs, g=None, copy=True):
273 273 """unpack f,args,kwargs from buffers packed by pack_apply_message()
274 274 Returns: original f,args,kwargs"""
275 275 bufs = list(bufs) # allow us to pop
276 276 assert len(bufs) >= 3, "not enough buffers!"
277 277 if not copy:
278 278 for i in range(3):
279 279 bufs[i] = bufs[i].bytes
280 280 cf = pickle.loads(bufs.pop(0))
281 281 sargs = list(pickle.loads(bufs.pop(0)))
282 282 skwargs = dict(pickle.loads(bufs.pop(0)))
283 283 # print sargs, skwargs
284 284 f = uncan(cf, g)
285 285 for sa in sargs:
286 286 if sa.data is None:
287 287 m = bufs.pop(0)
288 288 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
289 289 # always use a buffer, until memoryviews get sorted out
290 290 sa.data = buffer(m)
291 291 # disable memoryview support
292 292 # if copy:
293 293 # sa.data = buffer(m)
294 294 # else:
295 295 # sa.data = m.buffer
296 296 else:
297 297 if copy:
298 298 sa.data = m
299 299 else:
300 300 sa.data = m.bytes
301 301
302 302 args = uncanSequence(map(unserialize, sargs), g)
303 303 kwargs = {}
304 304 for k in sorted(skwargs.iterkeys()):
305 305 sa = skwargs[k]
306 306 if sa.data is None:
307 307 m = bufs.pop(0)
308 308 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
309 309 # always use a buffer, until memoryviews get sorted out
310 310 sa.data = buffer(m)
311 311 # disable memoryview support
312 312 # if copy:
313 313 # sa.data = buffer(m)
314 314 # else:
315 315 # sa.data = m.buffer
316 316 else:
317 317 if copy:
318 318 sa.data = m
319 319 else:
320 320 sa.data = m.bytes
321 321
322 322 kwargs[k] = uncan(unserialize(sa), g)
323 323
324 324 return f,args,kwargs
325 325
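# And likewise for whole apply messages (a sketch):
#
#     msg = pack_apply_message(max, (1, 2), {})
#     f, args, kwargs = unpack_apply_message(msg)
#     # f(*args, **kwargs) == 2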
326 326 #--------------------------------------------------------------------------
327 327 # helpers for implementing old MEC API via view.apply
328 328 #--------------------------------------------------------------------------
329 329
330 330 def interactive(f):
331 331 """decorator for making functions appear as interactively defined.
332 332 This results in the function being linked to the user_ns as globals()
333 333 instead of the module globals().
334 334 """
335 335 f.__module__ = '__main__'
336 336 return f
337 337
338 338 @interactive
339 339 def _push(ns):
340 340 """helper method for implementing `client.push` via `client.apply`"""
341 341 globals().update(ns)
342 342
343 343 @interactive
344 344 def _pull(keys):
345 345 """helper method for implementing `client.pull` via `client.apply`"""
346 346 user_ns = globals()
347 347 if isinstance(keys, (list,tuple, set)):
348 348 for key in keys:
349 349 if not user_ns.has_key(key):
350 350 raise NameError("name '%s' is not defined"%key)
351 351 return map(user_ns.get, keys)
352 352 else:
353 353 if not user_ns.has_key(keys):
354 354 raise NameError("name '%s' is not defined"%keys)
355 355 return user_ns.get(keys)
356 356
357 357 @interactive
358 358 def _execute(code):
359 359 """helper method for implementing `client.execute` via `client.apply`"""
360 360 exec code in globals()
361 361
362 362 #--------------------------------------------------------------------------
363 363 # extra process management utilities
364 364 #--------------------------------------------------------------------------
365 365
366 366 _random_ports = set()
367 367
368 368 def select_random_ports(n):
369 369 """Select and return n random ports that are available."""
370 370 ports = []
371 371 for i in xrange(n):
372 372 sock = socket.socket()
373 373 sock.bind(('', 0))
374 374 while sock.getsockname()[1] in _random_ports:
375 375 sock.close()
376 376 sock = socket.socket()
377 377 sock.bind(('', 0))
378 378 ports.append(sock)
379 379 for i, sock in enumerate(ports):
380 380 port = sock.getsockname()[1]
381 381 sock.close()
382 382 ports[i] = port
383 383 _random_ports.add(port)
384 384 return ports
385 385
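# e.g. ports = select_random_ports(2)  # -> two distinct, currently-free ports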
386 386 def signal_children(children):
387 387 """Relay interrupt/term signals to children, for more solid process cleanup."""
388 388 def terminate_children(sig, frame):
389 389 logging.critical("Got signal %i, terminating children..."%sig)
390 390 for child in children:
391 391 child.terminate()
392 392
393 393 sys.exit(sig != SIGINT)
394 394 # sys.exit(sig)
395 395 for sig in (SIGINT, SIGABRT, SIGTERM):
396 396 signal(sig, terminate_children)
397 397
398 398 def generate_exec_key(keyfile):
399 399 import uuid
400 400 newkey = str(uuid.uuid4())
401 401 with open(keyfile, 'w') as f:
402 402 # f.write('ipython-key ')
403 403 f.write(newkey+'\n')
404 404 # set user-only RW permissions (0600)
405 405 # this will have no effect on Windows
406 406 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
407 407
408 408
409 409 def integer_loglevel(loglevel):
410 410 try:
411 411 loglevel = int(loglevel)
412 412 except ValueError:
413 413 if isinstance(loglevel, str):
414 414 loglevel = getattr(logging, loglevel)
415 415 return loglevel
416 416
417 417 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
418 418 logger = logging.getLogger(logname)
419 419 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
420 420 # don't add a second PUBHandler
421 421 return
422 422 loglevel = integer_loglevel(loglevel)
423 423 lsock = context.socket(zmq.PUB)
424 424 lsock.connect(iface)
425 425 handler = handlers.PUBHandler(lsock)
426 426 handler.setLevel(loglevel)
427 427 handler.root_topic = root
428 428 logger.addHandler(handler)
429 429 logger.setLevel(loglevel)
430 430
431 431 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
432 432 logger = logging.getLogger()
433 433 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
434 434 # don't add a second PUBHandler
435 435 return
436 436 loglevel = integer_loglevel(loglevel)
437 437 lsock = context.socket(zmq.PUB)
438 438 lsock.connect(iface)
439 439 handler = EnginePUBHandler(engine, lsock)
440 440 handler.setLevel(loglevel)
441 441 logger.addHandler(handler)
442 442 logger.setLevel(loglevel)
443 443 return logger
444 444
445 445 def local_logger(logname, loglevel=logging.DEBUG):
446 446 loglevel = integer_loglevel(loglevel)
447 447 logger = logging.getLogger(logname)
448 448 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
449 449 # don't add a second StreamHandler
450 450 return
451 451 handler = logging.StreamHandler()
452 452 handler.setLevel(loglevel)
453 453 logger.addHandler(handler)
454 454 logger.setLevel(loglevel)
455 455 return logger
456 456
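A sketch of the logging helpers above (names are illustrative; integer_loglevel accepts either an int or a level name such as 'INFO'):

    log = local_logger('iptest', 'INFO')    # attaches a single StreamHandler
    log.info("engines ready")               # illustrative message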
@@ -1,43 +1,43 b''
1 1 # encoding: utf-8
2 2
3 3 """Utilities to enable code objects to be pickled.
4 4
5 5 Any process that imports this module will be able to pickle code objects. This
6 6 includes the func_code attribute of any function. Once unpickled, new
7 7 functions can be built using types.FunctionType(code, globals()). Eventually
8 8 we need to automate all of this so that functions themselves can be pickled.
9 9
10 10 Reference: A. Tremols, P Cogolo, "Python Cookbook," p 302-305
11 11 """
12 12
13 13 __docformat__ = "restructuredtext en"
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Copyright (C) 2008 The IPython Development Team
17 17 #
18 18 # Distributed under the terms of the BSD License. The full license is in
19 19 # the file COPYING, distributed as part of this software.
20 20 #-------------------------------------------------------------------------------
21 21
22 22 #-------------------------------------------------------------------------------
23 23 # Imports
24 24 #-------------------------------------------------------------------------------
25 25
26 26 import sys
27 import new, types, copy_reg
27 import types, copy_reg
28 28
29 29 def code_ctor(*args):
30 return new.code(*args)
30 return types.CodeType(*args)
31 31
32 32 def reduce_code(co):
33 33 if co.co_freevars or co.co_cellvars:
34 34 raise ValueError("Sorry, cannot pickle code objects with closures")
35 35 args = [co.co_argcount, co.co_nlocals, co.co_stacksize,
36 36 co.co_flags, co.co_code, co.co_consts, co.co_names,
37 37 co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno,
38 38 co.co_lnotab]
39 39 if sys.version_info[0] >= 3:
40 40 args.insert(1, co.co_kwonlyargcount)
41 41 return code_ctor, tuple(args)
42 42
43 43 copy_reg.pickle(types.CodeType, reduce_code) No newline at end of file
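With the reducer above registered, plain code objects survive a pickle round trip; a Python 2 sketch (this module must already be imported so the copy_reg hook is in place):

    import pickle

    co = compile("x + 1", "<sketch>", "eval")      # no closures, so picklable
    co2 = pickle.loads(pickle.dumps(co))
    assert eval(co2, {"x": 41}) == 42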