update parallel code for py3k...
MinRK
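The recurring change below is routing values through ensure_bytes before they reach zmq or HMAC code: on Python 3, zmq socket identities, subscription prefixes, and digest keys must be bytes, while config values and decoded json come back as unicode strings. A minimal sketch of such a helper (the real one lives in IPython.parallel.util; this shows assumed behavior, not the verbatim implementation):

    import sys

    # py2/py3 alias for the text type, so the isinstance check works on both
    if sys.version_info[0] >= 3:
        unicode_type = str
    else:
        unicode_type = unicode  # py2 builtin

    def ensure_bytes(s):
        """Coerce a unicode string to ascii bytes; pass bytes through unchanged."""
        if isinstance(s, unicode_type):
            s = s.encode('ascii')
        return s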
@@ -1,423 +1,422 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import os
27 27 import socket
28 28 import stat
29 29 import sys
30 30 import uuid
31 31
32 32 from multiprocessing import Process
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37 from zmq.utils import jsonapi as json
38 38
39 39 from IPython.config.application import boolean_flag
40 40 from IPython.core.profiledir import ProfileDir
41 41
42 42 from IPython.parallel.apps.baseapp import (
43 43 BaseParallelApplication,
44 44 base_aliases,
45 45 base_flags,
46 46 )
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
49 49
50 50 # from IPython.parallel.controller.controller import ControllerFactory
51 51 from IPython.zmq.session import Session
52 52 from IPython.parallel.controller.heartmonitor import HeartMonitor
53 53 from IPython.parallel.controller.hub import HubFactory
54 54 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
55 55 from IPython.parallel.controller.sqlitedb import SQLiteDB
56 56
57 from IPython.parallel.util import signal_children, split_url
57 from IPython.parallel.util import signal_children, split_url, ensure_bytes
58 58
59 59 # conditional import of MongoDB backend class
60 60
61 61 try:
62 62 from IPython.parallel.controller.mongodb import MongoDB
63 63 except ImportError:
64 64 maybe_mongo = []
65 65 else:
66 66 maybe_mongo = [MongoDB]
67 67
68 68
69 69 #-----------------------------------------------------------------------------
70 70 # Module level variables
71 71 #-----------------------------------------------------------------------------
72 72
73 73
74 74 #: The default config file name for this application
75 75 default_config_file_name = u'ipcontroller_config.py'
76 76
77 77
78 78 _description = """Start the IPython controller for parallel computing.
79 79
80 80 The IPython controller provides a gateway between the IPython engines and
81 81 clients. The controller needs to be started before the engines and can be
82 82 configured using command line options or using a cluster directory. Cluster
83 83 directories contain config, log and security files and are usually located in
84 84 your ipython directory and named as "profile_name". See the `profile`
85 85 and `profile_dir` options for details.
86 86 """
87 87
88 88
89 89
90 90
91 91 #-----------------------------------------------------------------------------
92 92 # The main application
93 93 #-----------------------------------------------------------------------------
94 94 flags = {}
95 95 flags.update(base_flags)
96 96 flags.update({
97 97 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
98 98 'Use threads instead of processes for the schedulers'),
99 99 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
100 100 'use the SQLiteDB backend'),
101 101 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
102 102 'use the MongoDB backend'),
103 103 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
104 104 'use the in-memory DictDB backend'),
105 105 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
106 106 'reuse existing json connection files')
107 107 })
108 108
109 109 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
110 110 "Use HMAC digests for authentication of messages.",
111 111 "Don't authenticate messages."
112 112 ))
113 113 aliases = dict(
114 114 reuse_files = 'IPControllerApp.reuse_files',
115 115 secure = 'IPControllerApp.secure',
116 116 ssh = 'IPControllerApp.ssh_server',
117 117 use_threads = 'IPControllerApp.use_threads',
118 118 location = 'IPControllerApp.location',
119 119
120 120 ident = 'Session.session',
121 121 user = 'Session.username',
122 122 exec_key = 'Session.keyfile',
123 123
124 124 url = 'HubFactory.url',
125 125 ip = 'HubFactory.ip',
126 126 transport = 'HubFactory.transport',
127 127 port = 'HubFactory.regport',
128 128
129 129 ping = 'HeartMonitor.period',
130 130
131 131 scheme = 'TaskScheduler.scheme_name',
132 132 hwm = 'TaskScheduler.hwm',
133 133 )
134 134 aliases.update(base_aliases)
135 135
136 136 class IPControllerApp(BaseParallelApplication):
137 137
138 138 name = u'ipcontroller'
139 139 description = _description
140 140 config_file_name = Unicode(default_config_file_name)
141 141 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
142 142
143 143 # change default to True
144 144 auto_create = Bool(True, config=True,
145 145 help="""Whether to create profile dir if it doesn't exist.""")
146 146
147 147 reuse_files = Bool(False, config=True,
148 148 help='Whether to reuse existing json connection files.'
149 149 )
150 150 secure = Bool(True, config=True,
151 151 help='Whether to use HMAC digests for extra message authentication.'
152 152 )
153 153 ssh_server = Unicode(u'', config=True,
154 154 help="""ssh url for clients to use when connecting to the Controller
155 155 processes. It should be of the form: [user@]server[:port]. The
156 156 Controller's listening addresses must be accessible from the ssh server""",
157 157 )
158 158 location = Unicode(u'', config=True,
159 159 help="""The external IP or domain name of the Controller, used for disambiguating
160 160 engine and client connections.""",
161 161 )
162 162 import_statements = List([], config=True,
163 163 help="import statements to be run at startup. Necessary in some environments"
164 164 )
165 165
166 166 use_threads = Bool(False, config=True,
167 167 help='Use threads instead of processes for the schedulers',
168 168 )
169 169
170 170 # internal
171 171 children = List()
172 172 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
173 173
174 174 def _use_threads_changed(self, name, old, new):
175 175 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
176 176
177 177 aliases = Dict(aliases)
178 178 flags = Dict(flags)
179 179
180 180
181 181 def save_connection_dict(self, fname, cdict):
182 182 """save a connection dict to json file."""
183 183 c = self.config
184 184 url = cdict['url']
185 185 location = cdict['location']
186 186 if not location:
187 187 try:
188 188 proto,ip,port = split_url(url)
189 189 except AssertionError:
190 190 pass
191 191 else:
192 192 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
193 193 cdict['location'] = location
194 194 fname = os.path.join(self.profile_dir.security_dir, fname)
195 195 with open(fname, 'wb') as f:
196 196 f.write(json.dumps(cdict, indent=2))
197 197 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
198 198
199 199 def load_config_from_json(self):
200 200 """load config from existing json connector files."""
201 201 c = self.config
202 202 # load from engine config
203 203 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
204 204 cfg = json.loads(f.read())
205 key = c.Session.key = cfg['exec_key']
205 key = c.Session.key = ensure_bytes(cfg['exec_key'])
206 206 xport,addr = cfg['url'].split('://')
207 207 c.HubFactory.engine_transport = xport
208 208 ip,ports = addr.split(':')
209 209 c.HubFactory.engine_ip = ip
210 210 c.HubFactory.regport = int(ports)
211 211 self.location = cfg['location']
212
213 212 # load client config
214 213 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
215 214 cfg = json.loads(f.read())
216 215 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
217 216 xport,addr = cfg['url'].split('://')
218 217 c.HubFactory.client_transport = xport
219 218 ip,ports = addr.split(':')
220 219 c.HubFactory.client_ip = ip
221 220 self.ssh_server = cfg['ssh']
222 221 assert int(ports) == c.HubFactory.regport, "regport mismatch"
223 222
224 223 def init_hub(self):
225 224 c = self.config
226 225
227 226 self.do_import_statements()
228 227 reusing = self.reuse_files
229 228 if reusing:
230 229 try:
231 230 self.load_config_from_json()
232 231 except (AssertionError,IOError):
233 232 reusing=False
234 233 # check again, because reusing may have failed:
235 234 if reusing:
236 235 pass
237 236 elif self.secure:
238 237 key = str(uuid.uuid4())
239 238 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
240 239 # with open(keyfile, 'w') as f:
241 240 # f.write(key)
242 241 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
243 c.Session.key = key
242 c.Session.key = ensure_bytes(key)
244 243 else:
245 key = c.Session.key = ''
244 key = c.Session.key = b''
246 245
247 246 try:
248 247 self.factory = HubFactory(config=c, log=self.log)
249 248 # self.start_logging()
250 249 self.factory.init_hub()
251 250 except:
252 251 self.log.error("Couldn't construct the Controller", exc_info=True)
253 252 self.exit(1)
254 253
255 254 if not reusing:
256 255 # save to new json config files
257 256 f = self.factory
258 257 cdict = {'exec_key' : key,
259 258 'ssh' : self.ssh_server,
260 259 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
261 260 'location' : self.location
262 261 }
263 262 self.save_connection_dict('ipcontroller-client.json', cdict)
264 263 edict = cdict
265 264 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
266 265 self.save_connection_dict('ipcontroller-engine.json', edict)
267 266
268 267 #
269 268 def init_schedulers(self):
270 269 children = self.children
271 270 mq = import_item(str(self.mq_class))
272 271
273 272 hub = self.factory
274 273 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
275 274 # IOPub relay (in a Process)
276 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
275 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
277 276 q.bind_in(hub.client_info['iopub'])
278 277 q.bind_out(hub.engine_info['iopub'])
279 q.setsockopt_out(zmq.SUBSCRIBE, '')
278 q.setsockopt_out(zmq.SUBSCRIBE, b'')
280 279 q.connect_mon(hub.monitor_url)
281 280 q.daemon=True
282 281 children.append(q)
283 282
284 283 # Multiplexer Queue (in a Process)
285 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
284 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'in', b'out')
286 285 q.bind_in(hub.client_info['mux'])
287 q.setsockopt_in(zmq.IDENTITY, 'mux')
286 q.setsockopt_in(zmq.IDENTITY, b'mux')
288 287 q.bind_out(hub.engine_info['mux'])
289 288 q.connect_mon(hub.monitor_url)
290 289 q.daemon=True
291 290 children.append(q)
292 291
293 292 # Control Queue (in a Process)
294 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'incontrol', b'outcontrol')
295 294 q.bind_in(hub.client_info['control'])
296 q.setsockopt_in(zmq.IDENTITY, 'control')
295 q.setsockopt_in(zmq.IDENTITY, b'control')
297 296 q.bind_out(hub.engine_info['control'])
298 297 q.connect_mon(hub.monitor_url)
299 298 q.daemon=True
300 299 children.append(q)
301 300 try:
302 301 scheme = self.config.TaskScheduler.scheme_name
303 302 except AttributeError:
304 303 scheme = TaskScheduler.scheme_name.get_default_value()
305 304 # Task Queue (in a Process)
306 305 if scheme == 'pure':
307 306 self.log.warn("task::using pure XREQ Task scheduler")
308 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
307 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, b'intask', b'outtask')
309 308 # q.setsockopt_out(zmq.HWM, hub.hwm)
310 309 q.bind_in(hub.client_info['task'][1])
311 q.setsockopt_in(zmq.IDENTITY, 'task')
310 q.setsockopt_in(zmq.IDENTITY, b'task')
312 311 q.bind_out(hub.engine_info['task'])
313 312 q.connect_mon(hub.monitor_url)
314 313 q.daemon=True
315 314 children.append(q)
316 315 elif scheme == 'none':
317 316 self.log.warn("task::using no Task scheduler")
318 317
319 318 else:
320 319 self.log.info("task::using Python %s Task scheduler"%scheme)
321 320 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
322 321 hub.monitor_url, hub.client_info['notification'])
323 322 kwargs = dict(logname='scheduler', loglevel=self.log_level,
324 323 log_url = self.log_url, config=dict(self.config))
325 324 if 'Process' in self.mq_class:
326 325 # run the Python scheduler in a Process
327 326 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
328 327 q.daemon=True
329 328 children.append(q)
330 329 else:
331 330 # single-threaded Controller
332 331 kwargs['in_thread'] = True
333 332 launch_scheduler(*sargs, **kwargs)
334 333
335 334
336 335 def save_urls(self):
337 336 """save the registration urls to files."""
338 337 c = self.config
339 338
340 339 sec_dir = self.profile_dir.security_dir
341 340 cf = self.factory
342 341
343 342 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
344 343 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
345 344
346 345 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
347 346 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
348 347
349 348
350 349 def do_import_statements(self):
351 350 statements = self.import_statements
352 351 for s in statements:
353 352 try:
354 353 self.log.msg("Executing statement: '%s'" % s)
355 354 exec s in globals(), locals()
356 355 except:
357 356 self.log.msg("Error running statement: %s" % s)
358 357
359 358 def forward_logging(self):
360 359 if self.log_url:
361 360 self.log.info("Forwarding logging to %s"%self.log_url)
362 361 context = zmq.Context.instance()
363 362 lsock = context.socket(zmq.PUB)
364 363 lsock.connect(self.log_url)
365 364 handler = PUBHandler(lsock)
366 365 self.log.removeHandler(self._log_handler)
367 366 handler.root_topic = 'controller'
368 367 handler.setLevel(self.log_level)
369 368 self.log.addHandler(handler)
370 369 self._log_handler = handler
371 370 # #
372 371
373 372 def initialize(self, argv=None):
374 373 super(IPControllerApp, self).initialize(argv)
375 374 self.forward_logging()
376 375 self.init_hub()
377 376 self.init_schedulers()
378 377
379 378 def start(self):
380 379 # Start the subprocesses:
381 380 self.factory.start()
382 381 child_procs = []
383 382 for child in self.children:
384 383 child.start()
385 384 if isinstance(child, ProcessMonitoredQueue):
386 385 child_procs.append(child.launcher)
387 386 elif isinstance(child, Process):
388 387 child_procs.append(child)
389 388 if child_procs:
390 389 signal_children(child_procs)
391 390
392 391 self.write_pid_file(overwrite=True)
393 392
394 393 try:
395 394 self.factory.loop.start()
396 395 except KeyboardInterrupt:
397 396 self.log.critical("Interrupted, Exiting...\n")
398 397
399 398
400 399
401 400 def launch_new_instance():
402 401 """Create and run the IPython controller"""
403 402 if sys.platform == 'win32':
404 403 # make sure we don't get called from a multiprocessing subprocess
405 404 # this can result in infinite Controllers being started on Windows
406 405 # which doesn't have a proper fork, so multiprocessing is wonky
407 406
408 407 # this only comes up when IPython has been installed using vanilla
409 408 # setuptools, and *not* distribute.
410 409 import multiprocessing
411 410 p = multiprocessing.current_process()
412 411 # the main process has name 'MainProcess'
413 412 # subprocesses will have names like 'Process-1'
414 413 if p.name != 'MainProcess':
415 414 # we are a subprocess, don't start another Controller!
416 415 return
417 416 app = IPControllerApp.instance()
418 417 app.initialize()
419 418 app.start()
420 419
421 420
422 421 if __name__ == '__main__':
423 422 launch_new_instance()
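In init_schedulers above, every MonitoredQueue name, zmq.IDENTITY, and zmq.SUBSCRIBE value gained a b'' prefix. That matches pyzmq's behavior on Python 3, where binary socket options reject str. A standalone illustration (the identity value here is just an example):

    import zmq

    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.XREQ)
    sock.setsockopt(zmq.IDENTITY, b'mux')    # ok on py2 and py3
    # sock.setsockopt(zmq.IDENTITY, 'mux')   # TypeError on py3: bytes required
    sock.close()
    ctx.term()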
@@ -1,304 +1,301 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import json
25 25 import os
26 26 import sys
27 27 import time
28 28
29 29 import zmq
30 30 from zmq.eventloop import ioloop
31 31
32 32 from IPython.core.profiledir import ProfileDir
33 33 from IPython.parallel.apps.baseapp import (
34 34 BaseParallelApplication,
35 35 base_aliases,
36 36 base_flags,
37 37 )
38 38 from IPython.zmq.log import EnginePUBHandler
39 39
40 40 from IPython.config.configurable import Configurable
41 41 from IPython.zmq.session import Session
42 42 from IPython.parallel.engine.engine import EngineFactory
43 43 from IPython.parallel.engine.streamkernel import Kernel
44 from IPython.parallel.util import disambiguate_url
44 from IPython.parallel.util import disambiguate_url, ensure_bytes
45 45
46 46 from IPython.utils.importstring import import_item
47 47 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float
48 48
49 49
50 50 #-----------------------------------------------------------------------------
51 51 # Module level variables
52 52 #-----------------------------------------------------------------------------
53 53
54 54 #: The default config file name for this application
55 55 default_config_file_name = u'ipengine_config.py'
56 56
57 57 _description = """Start an IPython engine for parallel computing.
58 58
59 59 IPython engines run in parallel and perform computations on behalf of a client
60 60 and controller. A controller needs to be started before the engines. The
61 61 engine can be configured using command line options or using a cluster
62 62 directory. Cluster directories contain config, log and security files and are
63 63 usually located in your ipython directory and named as "profile_name".
64 64 See the `profile` and `profile_dir` options for details.
65 65 """
66 66
67 67
68 68 #-----------------------------------------------------------------------------
69 69 # MPI configuration
70 70 #-----------------------------------------------------------------------------
71 71
72 72 mpi4py_init = """from mpi4py import MPI as mpi
73 73 mpi.size = mpi.COMM_WORLD.Get_size()
74 74 mpi.rank = mpi.COMM_WORLD.Get_rank()
75 75 """
76 76
77 77
78 78 pytrilinos_init = """from PyTrilinos import Epetra
79 79 class SimpleStruct:
80 80 pass
81 81 mpi = SimpleStruct()
82 82 mpi.rank = 0
83 83 mpi.size = 0
84 84 """
85 85
86 86 class MPI(Configurable):
87 87 """Configurable for MPI initialization"""
88 88 use = Unicode('', config=True,
89 89 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
90 90 )
91 91
92 92 def _on_use_changed(self, old, new):
93 93 # load default init script if it's not set
94 94 if not self.init_script:
95 95 self.init_script = self.default_inits.get(new, '')
96 96
97 97 init_script = Unicode('', config=True,
98 98 help="Initialization code for MPI")
99 99
100 100 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
101 101 config=True)
102 102
103 103
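The MPI configurable above wires `use` to a default init_script via _on_use_changed. In a profile config file that looks like (file contents are an assumed example):

    # in ipengine_config.py (hypothetical profile config)
    c = get_config()
    c.MPI.use = 'mpi4py'   # _on_use_changed loads mpi4py_init as the init script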
104 104 #-----------------------------------------------------------------------------
105 105 # Main application
106 106 #-----------------------------------------------------------------------------
107 107 aliases = dict(
108 108 file = 'IPEngineApp.url_file',
109 109 c = 'IPEngineApp.startup_command',
110 110 s = 'IPEngineApp.startup_script',
111 111
112 112 ident = 'Session.session',
113 113 user = 'Session.username',
114 114 exec_key = 'Session.keyfile',
115 115
116 116 url = 'EngineFactory.url',
117 117 ip = 'EngineFactory.ip',
118 118 transport = 'EngineFactory.transport',
119 119 port = 'EngineFactory.regport',
120 120 location = 'EngineFactory.location',
121 121
122 122 timeout = 'EngineFactory.timeout',
123 123
124 124 mpi = 'MPI.use',
125 125
126 126 )
127 127 aliases.update(base_aliases)
128 128
129 129 class IPEngineApp(BaseParallelApplication):
130 130
131 131 name = Unicode(u'ipengine')
132 132 description = Unicode(_description)
133 133 config_file_name = Unicode(default_config_file_name)
134 134 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
135 135
136 136 startup_script = Unicode(u'', config=True,
137 137 help='specify a script to be run at startup')
138 138 startup_command = Unicode('', config=True,
139 139 help='specify a command to be run at startup')
140 140
141 141 url_file = Unicode(u'', config=True,
142 142 help="""The full location of the file containing the connection information for
143 143 the controller. If this is not given, the file must be in the
144 144 security directory of the cluster directory. This location is
145 145 resolved using the `profile` or `profile_dir` options.""",
146 146 )
147 147 wait_for_url_file = Float(5, config=True,
148 148 help="""The maximum number of seconds to wait for url_file to exist.
149 149 This is useful for batch-systems and shared-filesystems where the
150 150 controller and engine are started at the same time and it
151 151 may take a moment for the controller to write the connector files.""")
152 152
153 153 url_file_name = Unicode(u'ipcontroller-engine.json')
154 154 log_url = Unicode('', config=True,
155 155 help="""The URL for the iploggerapp instance, for forwarding
156 156 logging to a central location.""")
157 157
158 158 aliases = Dict(aliases)
159 159
160 160 # def find_key_file(self):
161 161 # """Set the key file.
162 162 #
163 163 # Here we don't try to actually see if it exists or is valid, as that
164 164 # is handled by the connection logic.
165 165 # """
166 166 # config = self.master_config
167 167 # # Find the actual controller key file
168 168 # if not config.Global.key_file:
169 169 # try_this = os.path.join(
170 170 # config.Global.profile_dir,
171 171 # config.Global.security_dir,
172 172 # config.Global.key_file_name
173 173 # )
174 174 # config.Global.key_file = try_this
175 175
176 176 def find_url_file(self):
177 177 """Set the url file.
178 178
179 179 Here we don't try to actually see if it exists or is valid, as that
180 180 is handled by the connection logic.
181 181 """
182 182 config = self.config
183 183 # Find the actual controller key file
184 184 if not self.url_file:
185 185 self.url_file = os.path.join(
186 186 self.profile_dir.security_dir,
187 187 self.url_file_name
188 188 )
189 189 def init_engine(self):
190 190 # This is the working dir by now.
191 191 sys.path.insert(0, '')
192 192 config = self.config
193 193 # print config
194 194 self.find_url_file()
195 195
196 196 # was the url manually specified?
197 197 keys = set(self.config.EngineFactory.keys())
198 198 keys = keys.union(set(self.config.RegistrationFactory.keys()))
199 199
200 200 if keys.intersection(set(['ip', 'url', 'port'])):
201 201 # Connection info was specified, don't wait for the file
202 202 url_specified = True
203 203 self.wait_for_url_file = 0
204 204 else:
205 205 url_specified = False
206 206
207 207 if self.wait_for_url_file and not os.path.exists(self.url_file):
208 208 self.log.warn("url_file %r not found"%self.url_file)
209 209 self.log.warn("Waiting up to %.1f seconds for it to arrive."%self.wait_for_url_file)
210 210 tic = time.time()
211 211 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
212 212 # wait for url_file to exist, for up to 10 seconds
213 213 time.sleep(0.1)
214 214
215 215 if os.path.exists(self.url_file):
216 216 self.log.info("Loading url_file %r"%self.url_file)
217 217 with open(self.url_file) as f:
218 218 d = json.loads(f.read())
219 for k,v in d.iteritems():
220 if isinstance(v, unicode):
221 d[k] = v.encode()
222 219 if d['exec_key']:
223 config.Session.key = d['exec_key']
220 config.Session.key = ensure_bytes(d['exec_key'])
224 221 d['url'] = disambiguate_url(d['url'], d['location'])
225 222 config.EngineFactory.url = d['url']
226 223 config.EngineFactory.location = d['location']
227 224 elif not url_specified:
228 225 self.log.critical("Fatal: url file never arrived: %s"%self.url_file)
229 226 self.exit(1)
230 227
231 228
232 229 try:
233 230 exec_lines = config.Kernel.exec_lines
234 231 except AttributeError:
235 232 config.Kernel.exec_lines = []
236 233 exec_lines = config.Kernel.exec_lines
237 234
238 235 if self.startup_script:
239 236 enc = sys.getfilesystemencoding() or 'utf8'
240 237 cmd="execfile(%r)"%self.startup_script.encode(enc)
241 238 exec_lines.append(cmd)
242 239 if self.startup_command:
243 240 exec_lines.append(self.startup_command)
244 241
245 242 # Create the underlying shell class and Engine
246 243 # shell_class = import_item(self.master_config.Global.shell_class)
247 244 # print self.config
248 245 try:
249 246 self.engine = EngineFactory(config=config, log=self.log)
250 247 except:
251 248 self.log.error("Couldn't start the Engine", exc_info=True)
252 249 self.exit(1)
253 250
254 251 def forward_logging(self):
255 252 if self.log_url:
256 253 self.log.info("Forwarding logging to %s"%self.log_url)
257 254 context = self.engine.context
258 255 lsock = context.socket(zmq.PUB)
259 256 lsock.connect(self.log_url)
260 257 self.log.removeHandler(self._log_handler)
261 258 handler = EnginePUBHandler(self.engine, lsock)
262 259 handler.setLevel(self.log_level)
263 260 self.log.addHandler(handler)
264 261 self._log_handler = handler
265 262 #
266 263 def init_mpi(self):
267 264 global mpi
268 265 self.mpi = MPI(config=self.config)
269 266
270 267 mpi_import_statement = self.mpi.init_script
271 268 if mpi_import_statement:
272 269 try:
273 270 self.log.info("Initializing MPI:")
274 271 self.log.info(mpi_import_statement)
275 272 exec mpi_import_statement in globals()
276 273 except:
277 274 mpi = None
278 275 else:
279 276 mpi = None
280 277
281 278 def initialize(self, argv=None):
282 279 super(IPEngineApp, self).initialize(argv)
283 280 self.init_mpi()
284 281 self.init_engine()
285 282 self.forward_logging()
286 283
287 284 def start(self):
288 285 self.engine.start()
289 286 try:
290 287 self.engine.loop.start()
291 288 except KeyboardInterrupt:
292 289 self.log.critical("Engine Interrupted, shutting down...\n")
293 290
294 291
295 292 def launch_new_instance():
296 293 """Create and run the IPython engine"""
297 294 app = IPEngineApp.instance()
298 295 app.initialize()
299 296 app.start()
300 297
301 298
302 299 if __name__ == '__main__':
303 300 launch_new_instance()
304 301
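The engine-side fix above replaces a loop that byte-encoded every unicode value in the connection dict with one targeted coercion: json always decodes to unicode, but only exec_key (the HMAC key) actually needs to be bytes; url and location stay text. A sketch of that load pattern, with a hypothetical path:

    import json
    from IPython.parallel.util import ensure_bytes  # same helper the diff imports

    def load_engine_connection(path):
        """Load an ipcontroller-engine.json-style dict, coercing only the key."""
        with open(path) as f:
            d = json.loads(f.read())
        if d.get('exec_key'):
            # the HMAC key is the one value that must be bytes on py3
            d['exec_key'] = ensure_bytes(d['exec_key'])
        return d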
@@ -1,1422 +1,1428 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 import sys
20 21 import time
21 22 import warnings
22 23 from datetime import datetime
23 24 from getpass import getpass
24 25 from pprint import pprint
25 26
26 27 pjoin = os.path.join
27 28
28 29 import zmq
29 30 # from zmq.eventloop import ioloop, zmqstream
30 31
31 32 from IPython.config.configurable import MultipleInstanceError
32 33 from IPython.core.application import BaseIPythonApplication
33 34
34 35 from IPython.utils.jsonutil import rekey
35 36 from IPython.utils.localinterfaces import LOCAL_IPS
36 37 from IPython.utils.path import get_ipython_dir
37 38 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
38 39 Dict, List, Bool, Set)
39 40 from IPython.external.decorator import decorator
40 41 from IPython.external.ssh import tunnel
41 42
42 43 from IPython.parallel import error
43 44 from IPython.parallel import util
44 45
45 46 from IPython.zmq.session import Session, Message
46 47
47 48 from .asyncresult import AsyncResult, AsyncHubResult
48 49 from IPython.core.profiledir import ProfileDir, ProfileDirError
49 50 from .view import DirectView, LoadBalancedView
50 51
52 if sys.version_info[0] >= 3:
53 # xrange is used in a couple of 'isinstance' tests in py2
54 # should be just 'range' in 3k
55 xrange = range
56
51 57 #--------------------------------------------------------------------------
52 58 # Decorators for Client methods
53 59 #--------------------------------------------------------------------------
54 60
55 61 @decorator
56 62 def spin_first(f, self, *args, **kwargs):
57 63 """Call spin() to sync state prior to calling the method."""
58 64 self.spin()
59 65 return f(self, *args, **kwargs)
60 66
61 67
62 68 #--------------------------------------------------------------------------
63 69 # Classes
64 70 #--------------------------------------------------------------------------
65 71
66 72 class Metadata(dict):
67 73 """Subclass of dict for initializing metadata values.
68 74
69 75 Attribute access works on keys.
70 76
71 77 These objects have a strict set of keys - errors will raise if you try
72 78 to add new keys.
73 79 """
74 80 def __init__(self, *args, **kwargs):
75 81 dict.__init__(self)
76 82 md = {'msg_id' : None,
77 83 'submitted' : None,
78 84 'started' : None,
79 85 'completed' : None,
80 86 'received' : None,
81 87 'engine_uuid' : None,
82 88 'engine_id' : None,
83 89 'follow' : None,
84 90 'after' : None,
85 91 'status' : None,
86 92
87 93 'pyin' : None,
88 94 'pyout' : None,
89 95 'pyerr' : None,
90 96 'stdout' : '',
91 97 'stderr' : '',
92 98 }
93 99 self.update(md)
94 100 self.update(dict(*args, **kwargs))
95 101
96 102 def __getattr__(self, key):
97 103 """getattr aliased to getitem"""
98 104 if key in self.iterkeys():
99 105 return self[key]
100 106 else:
101 107 raise AttributeError(key)
102 108
103 109 def __setattr__(self, key, value):
104 110 """setattr aliased to setitem, with strict"""
105 111 if key in self.iterkeys():
106 112 self[key] = value
107 113 else:
108 114 raise AttributeError(key)
109 115
110 116 def __setitem__(self, key, value):
111 117 """strict static key enforcement"""
112 118 if key in self.iterkeys():
113 119 dict.__setitem__(self, key, value)
114 120 else:
115 121 raise KeyError(key)
116 122
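The strict-key behavior described in the docstring, using the Metadata class defined above (illustrative only; run under Python 2 as written, since 2to3 handles the iterkeys conversion for py3):

    md = Metadata(status='ok')
    print(md.status)         # 'ok' -- attribute access aliases item access
    # md['nonsense'] = 1     # KeyError: not in the fixed key set
    # md.nonsense = 1        # AttributeError, for the same reason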
117 123
118 124 class Client(HasTraits):
119 125 """A semi-synchronous client to the IPython ZMQ cluster
120 126
121 127 Parameters
122 128 ----------
123 129
124 130 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
125 131 Connection information for the Hub's registration. If a json connector
126 132 file is given, then likely no further configuration is necessary.
127 133 [Default: use profile]
128 134 profile : bytes
129 135 The name of the Cluster profile to be used to find connector information.
130 136 If run from an IPython application, the default profile will be the same
131 137 as the running application, otherwise it will be 'default'.
132 138 context : zmq.Context
133 139 Pass an existing zmq.Context instance, otherwise the client will create its own.
134 140 debug : bool
135 141 flag for lots of message printing for debug purposes
136 142 timeout : int/float
137 143 time (in seconds) to wait for connection replies from the Hub
138 144 [Default: 10]
139 145
140 146 #-------------- session related args ----------------
141 147
142 148 config : Config object
143 149 If specified, this will be relayed to the Session for configuration
144 150 username : str
145 151 set username for the session object
146 152 packer : str (import_string) or callable
147 153 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
148 154 function to serialize messages. Must support same input as
149 155 JSON, and output must be bytes.
150 156 You can pass a callable directly as `pack`
151 157 unpacker : str (import_string) or callable
152 158 The inverse of packer. Only necessary if packer is specified as *not* one
153 159 of 'json' or 'pickle'.
154 160
155 161 #-------------- ssh related args ----------------
156 162 # These are args for configuring the ssh tunnel to be used
157 163 # credentials are used to forward connections over ssh to the Controller
158 164 # Note that the ip given in `addr` needs to be relative to sshserver
159 165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
160 166 # and set sshserver as the same machine the Controller is on. However,
161 167 # the only requirement is that sshserver is able to see the Controller
162 168 # (i.e. is within the same trusted network).
163 169
164 170 sshserver : str
165 171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
166 172 If keyfile or password is specified, and this is not, it will default to
167 173 the ip given in addr.
168 174 sshkey : str; path to public ssh key file
169 175 This specifies a key to be used in ssh login, default None.
170 176 Regular default ssh keys will be used without specifying this argument.
171 177 password : str
172 178 Your ssh password to sshserver. Note that if this is left None,
173 179 you will be prompted for it if passwordless key based login is unavailable.
174 180 paramiko : bool
175 181 flag for whether to use paramiko instead of shell ssh for tunneling.
176 182 [default: True on win32, False otherwise]
177 183
178 184 ------- exec authentication args -------
179 185 If even localhost is untrusted, you can have some protection against
180 186 unauthorized execution by signing messages with HMAC digests.
181 187 Messages are still sent as cleartext, so if someone can snoop your
182 188 loopback traffic this will not protect your privacy, but will prevent
183 189 unauthorized execution.
184 190
185 191 exec_key : str
186 192 an authentication key or file containing a key
187 193 default: None
188 194
189 195
190 196 Attributes
191 197 ----------
192 198
193 199 ids : list of int engine IDs
194 200 requesting the ids attribute always synchronizes
195 201 the registration state. To request ids without synchronization,
196 202 use semi-private _ids attributes.
197 203
198 204 history : list of msg_ids
199 205 a list of msg_ids, keeping track of all the execution
200 206 messages you have submitted in order.
201 207
202 208 outstanding : set of msg_ids
203 209 a set of msg_ids that have been submitted, but whose
204 210 results have not yet been received.
205 211
206 212 results : dict
207 213 a dict of all our results, keyed by msg_id
208 214
209 215 block : bool
210 216 determines default behavior when block not specified
211 217 in execution methods
212 218
213 219 Methods
214 220 -------
215 221
216 222 spin
217 223 flushes incoming results and registration state changes
218 224 control methods spin, and requesting `ids` also ensures up to date
219 225
220 226 wait
221 227 wait on one or more msg_ids
222 228
223 229 execution methods
224 230 apply
225 231 legacy: execute, run
226 232
227 233 data movement
228 234 push, pull, scatter, gather
229 235
230 236 query methods
231 237 queue_status, get_result, purge, result_status
232 238
233 239 control methods
234 240 abort, shutdown
235 241
236 242 """
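As a usage note for the constructor documented above, the two common connection patterns look like this (profile name, file path, and ssh server are placeholders):

    from IPython.parallel import Client

    # find ipcontroller-client.json via the profile's security dir
    rc = Client(profile='default')

    # or point at a connector file directly, tunneling over ssh
    # rc = Client('/path/to/ipcontroller-client.json', sshserver='user@gateway')

    print(rc.ids)   # requesting ids also synchronizes registration state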
237 243
238 244
239 245 block = Bool(False)
240 246 outstanding = Set()
241 247 results = Instance('collections.defaultdict', (dict,))
242 248 metadata = Instance('collections.defaultdict', (Metadata,))
243 249 history = List()
244 250 debug = Bool(False)
245 251
246 252 profile=Unicode()
247 253 def _profile_default(self):
248 254 if BaseIPythonApplication.initialized():
249 255 # an IPython app *might* be running, try to get its profile
250 256 try:
251 257 return BaseIPythonApplication.instance().profile
252 258 except (AttributeError, MultipleInstanceError):
253 259 # could be a *different* subclass of config.Application,
254 260 # which would raise one of these two errors.
255 261 return u'default'
256 262 else:
257 263 return u'default'
258 264
259 265
260 266 _outstanding_dict = Instance('collections.defaultdict', (set,))
261 267 _ids = List()
262 268 _connected=Bool(False)
263 269 _ssh=Bool(False)
264 270 _context = Instance('zmq.Context')
265 271 _config = Dict()
266 272 _engines=Instance(util.ReverseDict, (), {})
267 273 # _hub_socket=Instance('zmq.Socket')
268 274 _query_socket=Instance('zmq.Socket')
269 275 _control_socket=Instance('zmq.Socket')
270 276 _iopub_socket=Instance('zmq.Socket')
271 277 _notification_socket=Instance('zmq.Socket')
272 278 _mux_socket=Instance('zmq.Socket')
273 279 _task_socket=Instance('zmq.Socket')
274 280 _task_scheme=Unicode()
275 281 _closed = False
276 282 _ignored_control_replies=Int(0)
277 283 _ignored_hub_replies=Int(0)
278 284
279 285 def __new__(self, *args, **kw):
280 286 # don't raise on positional args
281 287 return HasTraits.__new__(self, **kw)
282 288
283 289 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
284 290 context=None, debug=False, exec_key=None,
285 291 sshserver=None, sshkey=None, password=None, paramiko=None,
286 292 timeout=10, **extra_args
287 293 ):
288 294 if profile:
289 295 super(Client, self).__init__(debug=debug, profile=profile)
290 296 else:
291 297 super(Client, self).__init__(debug=debug)
292 298 if context is None:
293 299 context = zmq.Context.instance()
294 300 self._context = context
295 301
296 302 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
297 303 if self._cd is not None:
298 304 if url_or_file is None:
299 305 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
300 306 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
301 307 " Please specify at least one of url_or_file or profile."
302 308
303 309 try:
304 310 util.validate_url(url_or_file)
305 311 except AssertionError:
306 312 if not os.path.exists(url_or_file):
307 313 if self._cd:
308 314 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
309 315 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
310 316 with open(url_or_file) as f:
311 317 cfg = json.loads(f.read())
312 318 else:
313 319 cfg = {'url':url_or_file}
314 320
315 321 # sync defaults from args, json:
316 322 if sshserver:
317 323 cfg['ssh'] = sshserver
318 324 if exec_key:
319 325 cfg['exec_key'] = exec_key
320 326 exec_key = cfg['exec_key']
321 327 location = cfg.setdefault('location', None)
322 328 cfg['url'] = util.disambiguate_url(cfg['url'], location)
323 329 url = cfg['url']
324 330 proto,addr,port = util.split_url(url)
325 331 if location is not None and addr == '127.0.0.1':
326 332 # location specified, and connection is expected to be local
327 333 if location not in LOCAL_IPS and not sshserver:
328 334 # load ssh from JSON *only* if the controller is not on
329 335 # this machine
330 336 sshserver=cfg['ssh']
331 337 if location not in LOCAL_IPS and not sshserver:
332 338 # warn if no ssh specified, but SSH is probably needed
333 339 # This is only a warning, because the most likely cause
334 340 # is a local Controller on a laptop whose IP is dynamic
335 341 warnings.warn("""
336 342 Controller appears to be listening on localhost, but not on this machine.
337 343 If this is true, you should specify Client(...,sshserver='you@%s')
338 344 or instruct your controller to listen on an external IP."""%location,
339 345 RuntimeWarning)
340 346
341 347 self._config = cfg
342 348
343 349 self._ssh = bool(sshserver or sshkey or password)
344 350 if self._ssh and sshserver is None:
345 351 # default to ssh via localhost
346 352 sshserver = url.split('://')[1].split(':')[0]
347 353 if self._ssh and password is None:
348 354 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
349 355 password=False
350 356 else:
351 357 password = getpass("SSH Password for %s: "%sshserver)
352 358 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
353 359
354 360 # configure and construct the session
355 361 if exec_key is not None:
356 362 if os.path.isfile(exec_key):
357 363 extra_args['keyfile'] = exec_key
358 364 else:
359 if isinstance(exec_key, unicode):
360 exec_key = exec_key.encode('ascii')
365 exec_key = util.ensure_bytes(exec_key)
361 366 extra_args['key'] = exec_key
362 367 self.session = Session(**extra_args)
363 368
364 369 self._query_socket = self._context.socket(zmq.XREQ)
365 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
370 self._query_socket.setsockopt(zmq.IDENTITY, util.ensure_bytes(self.session.session))
366 371 if self._ssh:
367 372 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
368 373 else:
369 374 self._query_socket.connect(url)
370 375
371 376 self.session.debug = self.debug
372 377
373 378 self._notification_handlers = {'registration_notification' : self._register_engine,
374 379 'unregistration_notification' : self._unregister_engine,
375 380 'shutdown_notification' : lambda msg: self.close(),
376 381 }
377 382 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
378 383 'apply_reply' : self._handle_apply_reply}
379 384 self._connect(sshserver, ssh_kwargs, timeout)
380 385
381 386 def __del__(self):
382 387 """cleanup sockets, but _not_ context."""
383 388 self.close()
384 389
385 390 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
386 391 if ipython_dir is None:
387 392 ipython_dir = get_ipython_dir()
388 393 if profile_dir is not None:
389 394 try:
390 395 self._cd = ProfileDir.find_profile_dir(profile_dir)
391 396 return
392 397 except ProfileDirError:
393 398 pass
394 399 elif profile is not None:
395 400 try:
396 401 self._cd = ProfileDir.find_profile_dir_by_name(
397 402 ipython_dir, profile)
398 403 return
399 404 except ProfileDirError:
400 405 pass
401 406 self._cd = None
402 407
403 408 def _update_engines(self, engines):
404 409 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
405 410 for k,v in engines.iteritems():
406 411 eid = int(k)
407 self._engines[eid] = bytes(v) # force not unicode
412 self._engines[eid] = v
408 413 self._ids.append(eid)
409 414 self._ids = sorted(self._ids)
410 415 if sorted(self._engines.keys()) != range(len(self._engines)) and \
411 416 self._task_scheme == 'pure' and self._task_socket:
412 417 self._stop_scheduling_tasks()
413 418
414 419 def _stop_scheduling_tasks(self):
415 420 """Stop scheduling tasks because an engine has been unregistered
416 421 from a pure ZMQ scheduler.
417 422 """
418 423 self._task_socket.close()
419 424 self._task_socket = None
420 425 msg = "An engine has been unregistered, and we are using pure " +\
421 426 "ZMQ task scheduling. Task farming will be disabled."
422 427 if self.outstanding:
423 428 msg += " If you were running tasks when this happened, " +\
424 429 "some `outstanding` msg_ids may never resolve."
425 430 warnings.warn(msg, RuntimeWarning)
426 431
427 432 def _build_targets(self, targets):
428 433 """Turn valid target IDs or 'all' into two lists:
429 434 (int_ids, uuids).
430 435 """
431 436 if not self._ids:
432 437 # flush notification socket if no engines yet, just in case
433 438 if not self.ids:
434 439 raise error.NoEnginesRegistered("Can't build targets without any engines")
435 440
436 441 if targets is None:
437 442 targets = self._ids
438 443 elif isinstance(targets, str):
439 444 if targets.lower() == 'all':
440 445 targets = self._ids
441 446 else:
442 447 raise TypeError("%r not valid str target, must be 'all'"%(targets))
443 448 elif isinstance(targets, int):
444 449 if targets < 0:
445 450 targets = self.ids[targets]
446 451 if targets not in self._ids:
447 452 raise IndexError("No such engine: %i"%targets)
448 453 targets = [targets]
449 454
450 455 if isinstance(targets, slice):
451 456 indices = range(len(self._ids))[targets]
452 457 ids = self.ids
453 458 targets = [ ids[i] for i in indices ]
454 459
455 460 if not isinstance(targets, (tuple, list, xrange)):
456 461 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
457 462
458 return [self._engines[t] for t in targets], list(targets)
463 return [util.ensure_bytes(self._engines[t]) for t in targets], list(targets)
459 464
460 465 def _connect(self, sshserver, ssh_kwargs, timeout):
461 466 """setup all our socket connections to the cluster. This is called from
462 467 __init__."""
463 468
464 469 # Maybe allow reconnecting?
465 470 if self._connected:
466 471 return
467 472 self._connected=True
468 473
469 474 def connect_socket(s, url):
470 475 url = util.disambiguate_url(url, self._config['location'])
471 476 if self._ssh:
472 477 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
473 478 else:
474 479 return s.connect(url)
475 480
476 481 self.session.send(self._query_socket, 'connection_request')
477 482 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
478 483 poller = zmq.Poller()
479 484 poller.register(self._query_socket, zmq.POLLIN)
480 485 # poll expects milliseconds, timeout is seconds
481 486 evts = poller.poll(timeout*1000)
482 487 if not evts:
483 488 raise error.TimeoutError("Hub connection request timed out")
484 489 idents,msg = self.session.recv(self._query_socket,mode=0)
485 490 if self.debug:
486 491 pprint(msg)
487 492 msg = Message(msg)
488 493 content = msg.content
489 494 self._config['registration'] = dict(content)
490 495 if content.status == 'ok':
496 ident = util.ensure_bytes(self.session.session)
491 497 if content.mux:
492 498 self._mux_socket = self._context.socket(zmq.XREQ)
493 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
499 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
494 500 connect_socket(self._mux_socket, content.mux)
495 501 if content.task:
496 502 self._task_scheme, task_addr = content.task
497 503 self._task_socket = self._context.socket(zmq.XREQ)
498 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
504 self._task_socket.setsockopt(zmq.IDENTITY, ident)
499 505 connect_socket(self._task_socket, task_addr)
500 506 if content.notification:
501 507 self._notification_socket = self._context.socket(zmq.SUB)
502 508 connect_socket(self._notification_socket, content.notification)
503 509 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
504 510 # if content.query:
505 511 # self._query_socket = self._context.socket(zmq.XREQ)
506 512 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
507 513 # connect_socket(self._query_socket, content.query)
508 514 if content.control:
509 515 self._control_socket = self._context.socket(zmq.XREQ)
510 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
516 self._control_socket.setsockopt(zmq.IDENTITY, ident)
511 517 connect_socket(self._control_socket, content.control)
512 518 if content.iopub:
513 519 self._iopub_socket = self._context.socket(zmq.SUB)
514 520 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
515 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
521 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
516 522 connect_socket(self._iopub_socket, content.iopub)
517 523 self._update_engines(dict(content.engines))
518 524 else:
519 525 self._connected = False
520 526 raise Exception("Failed to connect!")
521 527
522 528 #--------------------------------------------------------------------------
523 529 # handlers and callbacks for incoming messages
524 530 #--------------------------------------------------------------------------
525 531
526 532 def _unwrap_exception(self, content):
527 533 """unwrap exception, and remap engine_id to int."""
528 534 e = error.unwrap_exception(content)
529 535 # print e.traceback
530 536 if e.engine_info:
531 537 e_uuid = e.engine_info['engine_uuid']
532 538 eid = self._engines[e_uuid]
533 539 e.engine_info['engine_id'] = eid
534 540 return e
535 541
536 542 def _extract_metadata(self, header, parent, content):
537 543 md = {'msg_id' : parent['msg_id'],
538 544 'received' : datetime.now(),
539 545 'engine_uuid' : header.get('engine', None),
540 546 'follow' : parent.get('follow', []),
541 547 'after' : parent.get('after', []),
542 548 'status' : content['status'],
543 549 }
544 550
545 551 if md['engine_uuid'] is not None:
546 552 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
547 553
548 554 if 'date' in parent:
549 555 md['submitted'] = parent['date']
550 556 if 'started' in header:
551 557 md['started'] = header['started']
552 558 if 'date' in header:
553 559 md['completed'] = header['date']
554 560 return md
555 561
556 562 def _register_engine(self, msg):
557 563 """Register a new engine, and update our connection info."""
558 564 content = msg['content']
559 565 eid = content['id']
560 566 d = {eid : content['queue']}
561 567 self._update_engines(d)
562 568
563 569 def _unregister_engine(self, msg):
564 570 """Unregister an engine that has died."""
565 571 content = msg['content']
566 572 eid = int(content['id'])
567 573 if eid in self._ids:
568 574 self._ids.remove(eid)
569 575 uuid = self._engines.pop(eid)
570 576
571 577 self._handle_stranded_msgs(eid, uuid)
572 578
573 579 if self._task_socket and self._task_scheme == 'pure':
574 580 self._stop_scheduling_tasks()
575 581
576 582 def _handle_stranded_msgs(self, eid, uuid):
577 583 """Handle messages known to be on an engine when the engine unregisters.
578 584
579 585 It is possible that this will fire prematurely - that is, an engine will
580 586 go down after completing a result, and the client will be notified
581 587 of the unregistration and later receive the successful result.
582 588 """
583 589
584 590 outstanding = self._outstanding_dict[uuid]
585 591
586 592 for msg_id in list(outstanding):
587 593 if msg_id in self.results:
588 594 # we already have this result
589 595 continue
590 596 try:
591 597 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
592 598 except:
593 599 content = error.wrap_exception()
594 600 # build a fake message:
595 601 parent = {}
596 602 header = {}
597 603 parent['msg_id'] = msg_id
598 604 header['engine'] = uuid
599 605 header['date'] = datetime.now()
600 606 msg = dict(parent_header=parent, header=header, content=content)
601 607 self._handle_apply_reply(msg)
602 608
603 609 def _handle_execute_reply(self, msg):
604 610 """Save the reply to an execute_request into our results.
605 611
606 612 execute messages are never actually used. apply is used instead.
607 613 """
608 614
609 615 parent = msg['parent_header']
610 616 msg_id = parent['msg_id']
611 617 if msg_id not in self.outstanding:
612 618 if msg_id in self.history:
613 619 print ("got stale result: %s"%msg_id)
614 620 else:
615 621 print ("got unknown result: %s"%msg_id)
616 622 else:
617 623 self.outstanding.remove(msg_id)
618 624 self.results[msg_id] = self._unwrap_exception(msg['content'])
619 625
620 626 def _handle_apply_reply(self, msg):
621 627 """Save the reply to an apply_request into our results."""
622 628 parent = msg['parent_header']
623 629 msg_id = parent['msg_id']
624 630 if msg_id not in self.outstanding:
625 631 if msg_id in self.history:
626 632 print ("got stale result: %s"%msg_id)
627 633 print self.results[msg_id]
628 634 print msg
629 635 else:
630 636 print ("got unknown result: %s"%msg_id)
631 637 else:
632 638 self.outstanding.remove(msg_id)
633 639 content = msg['content']
634 640 header = msg['header']
635 641
636 642 # construct metadata:
637 643 md = self.metadata[msg_id]
638 644 md.update(self._extract_metadata(header, parent, content))
639 645 # is this redundant?
640 646 self.metadata[msg_id] = md
641 647
642 648 e_outstanding = self._outstanding_dict[md['engine_uuid']]
643 649 if msg_id in e_outstanding:
644 650 e_outstanding.remove(msg_id)
645 651
646 652 # construct result:
647 653 if content['status'] == 'ok':
648 654 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
649 655 elif content['status'] == 'aborted':
650 656 self.results[msg_id] = error.TaskAborted(msg_id)
651 657 elif content['status'] == 'resubmitted':
652 658 # TODO: handle resubmission
653 659 pass
654 660 else:
655 661 self.results[msg_id] = self._unwrap_exception(content)
656 662
657 663 def _flush_notifications(self):
658 664 """Flush notifications of engine registrations waiting
659 665 in ZMQ queue."""
660 666 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
661 667 while msg is not None:
662 668 if self.debug:
663 669 pprint(msg)
664 670 msg_type = msg['msg_type']
665 671 handler = self._notification_handlers.get(msg_type, None)
666 672 if handler is None:
667 673 raise Exception("Unhandled message type: %s"%msg.msg_type)
668 674 else:
669 675 handler(msg)
670 676 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
671 677
672 678 def _flush_results(self, sock):
673 679 """Flush task or queue results waiting in ZMQ queue."""
674 680 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
675 681 while msg is not None:
676 682 if self.debug:
677 683 pprint(msg)
678 684 msg_type = msg['msg_type']
679 685 handler = self._queue_handlers.get(msg_type, None)
680 686 if handler is None:
681 687 raise Exception("Unhandled message type: %s"%msg.msg_type)
682 688 else:
683 689 handler(msg)
684 690 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
685 691
686 692 def _flush_control(self, sock):
687 693 """Flush replies from the control channel waiting
688 694 in the ZMQ queue.
689 695
690 696 Currently: ignore them."""
691 697 if self._ignored_control_replies <= 0:
692 698 return
693 699 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
694 700 while msg is not None:
695 701 self._ignored_control_replies -= 1
696 702 if self.debug:
697 703 pprint(msg)
698 704 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
699 705
700 706 def _flush_ignored_control(self):
701 707 """flush ignored control replies"""
702 708 while self._ignored_control_replies > 0:
703 709 self.session.recv(self._control_socket)
704 710 self._ignored_control_replies -= 1
705 711
706 712 def _flush_ignored_hub_replies(self):
707 713 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
708 714 while msg is not None:
709 715 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
710 716
711 717 def _flush_iopub(self, sock):
712 718 """Flush replies from the iopub channel waiting
713 719 in the ZMQ queue.
714 720 """
715 721 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
716 722 while msg is not None:
717 723 if self.debug:
718 724 pprint(msg)
719 725 parent = msg['parent_header']
720 726 msg_id = parent['msg_id']
721 727 content = msg['content']
722 728 header = msg['header']
723 729 msg_type = msg['msg_type']
724 730
725 731 # init metadata:
726 732 md = self.metadata[msg_id]
727 733
728 734 if msg_type == 'stream':
729 735 name = content['name']
730 736 s = md[name] or ''
731 737 md[name] = s + content['data']
732 738 elif msg_type == 'pyerr':
733 739 md.update({'pyerr' : self._unwrap_exception(content)})
734 740 elif msg_type == 'pyin':
735 741 md.update({'pyin' : content['code']})
736 742 else:
737 743 md.update({msg_type : content.get('data', '')})
738 744
739 745 # redundant?
740 746 self.metadata[msg_id] = md
741 747
742 748 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
743 749
744 750 #--------------------------------------------------------------------------
745 751 # len, getitem
746 752 #--------------------------------------------------------------------------
747 753
748 754 def __len__(self):
749 755 """len(client) returns # of engines."""
750 756 return len(self.ids)
751 757
752 758 def __getitem__(self, key):
753 759 """index access returns DirectView multiplexer objects
754 760
755 761 Must be int, slice, or list/tuple/xrange of ints"""
756 762 if not isinstance(key, (int, slice, tuple, list, xrange)):
757 763 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
758 764 else:
759 765 return self.direct_view(key)
760 766
761 767 #--------------------------------------------------------------------------
762 768 # Begin public methods
763 769 #--------------------------------------------------------------------------
764 770
765 771 @property
766 772 def ids(self):
767 773 """Always up-to-date ids property."""
768 774 self._flush_notifications()
769 775 # always copy:
770 776 return list(self._ids)
771 777
772 778 def close(self):
773 779 if self._closed:
774 780 return
775 781 snames = filter(lambda n: n.endswith('socket'), dir(self))
776 782 for socket in map(lambda name: getattr(self, name), snames):
777 783 if isinstance(socket, zmq.Socket) and not socket.closed:
778 784 socket.close()
779 785 self._closed = True
780 786
781 787 def spin(self):
782 788 """Flush any registration notifications and execution results
783 789 waiting in the ZMQ queue.
784 790 """
785 791 if self._notification_socket:
786 792 self._flush_notifications()
787 793 if self._mux_socket:
788 794 self._flush_results(self._mux_socket)
789 795 if self._task_socket:
790 796 self._flush_results(self._task_socket)
791 797 if self._control_socket:
792 798 self._flush_control(self._control_socket)
793 799 if self._iopub_socket:
794 800 self._flush_iopub(self._iopub_socket)
795 801 if self._query_socket:
796 802 self._flush_ignored_hub_replies()
797 803
798 804 def wait(self, jobs=None, timeout=-1):
799 805 """waits on one or more `jobs`, for up to `timeout` seconds.
800 806
801 807 Parameters
802 808 ----------
803 809
804 810 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
805 811 ints are indices to self.history
806 812 strs are msg_ids
807 813 default: wait on all outstanding messages
808 814 timeout : float
809 815 a time in seconds, after which to give up.
810 816 default is -1, which means no timeout
811 817
812 818 Returns
813 819 -------
814 820
815 821 True : when all msg_ids are done
816 822 False : timeout reached, some msg_ids still outstanding
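
Examples
--------
A minimal sketch; ``ar`` is assumed to be an AsyncResult from an
earlier submission::

    In [10]: client.wait([ar], timeout=5)
    Out[10]: True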
817 823 """
818 824 tic = time.time()
819 825 if jobs is None:
820 826 theids = self.outstanding
821 827 else:
822 828 if isinstance(jobs, (int, basestring, AsyncResult)):
823 829 jobs = [jobs]
824 830 theids = set()
825 831 for job in jobs:
826 832 if isinstance(job, int):
827 833 # index access
828 834 job = self.history[job]
829 835 elif isinstance(job, AsyncResult):
830 836 theids.update(job.msg_ids)
831 837 continue
832 838 theids.add(job)
833 839 if not theids.intersection(self.outstanding):
834 840 return True
835 841 self.spin()
836 842 while theids.intersection(self.outstanding):
837 843 if timeout >= 0 and ( time.time()-tic ) > timeout:
838 844 break
839 845 time.sleep(1e-3)
840 846 self.spin()
841 847 return len(theids.intersection(self.outstanding)) == 0
842 848
843 849 #--------------------------------------------------------------------------
844 850 # Control methods
845 851 #--------------------------------------------------------------------------
846 852
847 853 @spin_first
848 854 def clear(self, targets=None, block=None):
849 855 """Clear the namespace in target(s)."""
850 856 block = self.block if block is None else block
851 857 targets = self._build_targets(targets)[0]
852 858 for t in targets:
853 859 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
854 860 error = False
855 861 if block:
856 862 self._flush_ignored_control()
857 863 for i in range(len(targets)):
858 864 idents,msg = self.session.recv(self._control_socket,0)
859 865 if self.debug:
860 866 pprint(msg)
861 867 if msg['content']['status'] != 'ok':
862 868 error = self._unwrap_exception(msg['content'])
863 869 else:
864 870 self._ignored_control_replies += len(targets)
865 871 if error:
866 872 raise error
867 873
868 874
869 875 @spin_first
870 876 def abort(self, jobs=None, targets=None, block=None):
871 877 """Abort specific jobs from the execution queues of target(s).
872 878
873 879 This is a mechanism to prevent jobs that have already been submitted
874 880 from executing.
875 881
876 882 Parameters
877 883 ----------
878 884
879 885 jobs : msg_id, list of msg_ids, or AsyncResult
880 886 The jobs to be aborted
881 887
882 888
883 889 """
884 890 block = self.block if block is None else block
885 891 targets = self._build_targets(targets)[0]
886 892 msg_ids = []
887 893 if isinstance(jobs, (basestring,AsyncResult)):
888 894 jobs = [jobs]
889 895 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
890 896 if bad_ids:
891 897 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
892 898 for j in jobs:
893 899 if isinstance(j, AsyncResult):
894 900 msg_ids.extend(j.msg_ids)
895 901 else:
896 902 msg_ids.append(j)
897 903 content = dict(msg_ids=msg_ids)
898 904 for t in targets:
899 905 self.session.send(self._control_socket, 'abort_request',
900 906 content=content, ident=t)
901 907 error = False
902 908 if block:
903 909 self._flush_ignored_control()
904 910 for i in range(len(targets)):
905 911 idents,msg = self.session.recv(self._control_socket,0)
906 912 if self.debug:
907 913 pprint(msg)
908 914 if msg['content']['status'] != 'ok':
909 915 error = self._unwrap_exception(msg['content'])
910 916 else:
911 917 self._ignored_control_replies += len(targets)
912 918 if error:
913 919 raise error
914 920
915 921 @spin_first
916 922 def shutdown(self, targets=None, restart=False, hub=False, block=None):
917 923 """Terminates one or more engine processes, optionally including the hub."""
918 924 block = self.block if block is None else block
919 925 if hub:
920 926 targets = 'all'
921 927 targets = self._build_targets(targets)[0]
922 928 for t in targets:
923 929 self.session.send(self._control_socket, 'shutdown_request',
924 930 content={'restart':restart},ident=t)
925 931 error = False
926 932 if block or hub:
927 933 self._flush_ignored_control()
928 934 for i in range(len(targets)):
929 935 idents,msg = self.session.recv(self._control_socket, 0)
930 936 if self.debug:
931 937 pprint(msg)
932 938 if msg['content']['status'] != 'ok':
933 939 error = self._unwrap_exception(msg['content'])
934 940 else:
935 941 self._ignored_control_replies += len(targets)
936 942
937 943 if hub:
938 944 time.sleep(0.25)
939 945 self.session.send(self._query_socket, 'shutdown_request')
940 946 idents,msg = self.session.recv(self._query_socket, 0)
941 947 if self.debug:
942 948 pprint(msg)
943 949 if msg['content']['status'] != 'ok':
944 950 error = self._unwrap_exception(msg['content'])
945 951
946 952 if error:
947 953 raise error
948 954
949 955 #--------------------------------------------------------------------------
950 956 # Execution related methods
951 957 #--------------------------------------------------------------------------
952 958
953 959 def _maybe_raise(self, result):
954 960 """wrapper for maybe raising an exception if apply failed."""
955 961 if isinstance(result, error.RemoteError):
956 962 raise result
957 963
958 964 return result
959 965
960 966 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
961 967 ident=None):
962 968 """construct and send an apply message via a socket.
963 969
964 970 This is the principal method with which all engine execution is performed by views.
965 971 """
966 972
967 973 assert not self._closed, "cannot use me anymore, I'm closed!"
968 974 # defaults:
969 975 args = args if args is not None else []
970 976 kwargs = kwargs if kwargs is not None else {}
971 977 subheader = subheader if subheader is not None else {}
972 978
973 979 # validate arguments
974 980 if not callable(f):
975 981 raise TypeError("f must be callable, not %s"%type(f))
976 982 if not isinstance(args, (tuple, list)):
977 983 raise TypeError("args must be tuple or list, not %s"%type(args))
978 984 if not isinstance(kwargs, dict):
979 985 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
980 986 if not isinstance(subheader, dict):
981 987 raise TypeError("subheader must be dict, not %s"%type(subheader))
982 988
983 989 bufs = util.pack_apply_message(f,args,kwargs)
984 990
985 991 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
986 992 subheader=subheader, track=track)
987 993
988 994 msg_id = msg['msg_id']
989 995 self.outstanding.add(msg_id)
990 996 if ident:
991 997 # possibly routed to a specific engine
992 998 if isinstance(ident, list):
993 999 ident = ident[-1]
994 1000 if ident in self._engines.values():
995 1001 # save for later, in case of engine death
996 1002 self._outstanding_dict[ident].add(msg_id)
997 1003 self.history.append(msg_id)
998 1004 self.metadata[msg_id]['submitted'] = datetime.now()
999 1005
1000 1006 return msg
1001 1007
1002 1008 #--------------------------------------------------------------------------
1003 1009 # construct a View object
1004 1010 #--------------------------------------------------------------------------
1005 1011
1006 1012 def load_balanced_view(self, targets=None):
1007 1013 """construct a DirectView object.
1008 1014
1009 1015 If no arguments are specified, create a LoadBalancedView
1010 1016 using all engines.
1011 1017
1012 1018 Parameters
1013 1019 ----------
1014 1020
1015 1021 targets: list,slice,int,etc. [default: use all engines]
1016 1022 The subset of engines across which to load-balance
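
Examples
--------
A minimal sketch (assumes a connected Client ``client``)::

    In [10]: lview = client.load_balanced_view()

    In [11]: ar = lview.apply_async(lambda x: x*x, 7)

    In [12]: ar.get()
    Out[12]: 49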
1017 1023 """
1018 1024 if targets is not None:
1019 1025 targets = self._build_targets(targets)[1]
1020 1026 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1021 1027
1022 1028 def direct_view(self, targets='all'):
1023 1029 """construct a DirectView object.
1024 1030
1025 1031 If no targets are specified, create a DirectView
1026 1032 using all engines.
1027 1033
1028 1034 Parameters
1029 1035 ----------
1030 1036
1031 1037 targets: list,slice,int,etc. [default: use all engines]
1032 1038 The engines to use for the View
1033 1039 """
1034 1040 single = isinstance(targets, int)
1035 1041 targets = self._build_targets(targets)[1]
1036 1042 if single:
1037 1043 targets = targets[0]
1038 1044 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1039 1045
1040 1046 #--------------------------------------------------------------------------
1041 1047 # Query methods
1042 1048 #--------------------------------------------------------------------------
1043 1049
1044 1050 @spin_first
1045 1051 def get_result(self, indices_or_msg_ids=None, block=None):
1046 1052 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1047 1053
1048 1054 If the client already has the results, no request to the Hub will be made.
1049 1055
1050 1056 This is a convenient way to construct AsyncResult objects, which are wrappers
1051 1057 that include metadata about execution, and allow for awaiting results that
1052 1058 were not submitted by this Client.
1053 1059
1054 1060 It can also be a convenient way to retrieve the metadata associated with
1055 1061 blocking execution, since it always retrieves the metadata from the Hub.
1056 1062
1057 1063 Examples
1058 1064 --------
1059 1065 ::
1060 1066
1061 1067 In [10]: ar = client.get_result(msg_id)  # msg_id from an earlier submission
1062 1068
1063 1069 Parameters
1064 1070 ----------
1065 1071
1066 1072 indices_or_msg_ids : integer history index, str msg_id, or list of either
1067 1073 The history indices or msg_ids of the results to be retrieved
1068 1074
1069 1075 block : bool
1070 1076 Whether to wait for the result to be done
1071 1077
1072 1078 Returns
1073 1079 -------
1074 1080
1075 1081 AsyncResult
1076 1082 A single AsyncResult object will always be returned.
1077 1083
1078 1084 AsyncHubResult
1079 1085 A subclass of AsyncResult that retrieves results from the Hub
1080 1086
1081 1087 """
1082 1088 block = self.block if block is None else block
1083 1089 if indices_or_msg_ids is None:
1084 1090 indices_or_msg_ids = -1
1085 1091
1086 1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1087 1093 indices_or_msg_ids = [indices_or_msg_ids]
1088 1094
1089 1095 theids = []
1090 1096 for id in indices_or_msg_ids:
1091 1097 if isinstance(id, int):
1092 1098 id = self.history[id]
1093 1099 if not isinstance(id, basestring):
1094 1100 raise TypeError("indices must be str or int, not %r"%id)
1095 1101 theids.append(id)
1096 1102
1097 1103 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1098 1104 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1099 1105
1100 1106 if remote_ids:
1101 1107 ar = AsyncHubResult(self, msg_ids=theids)
1102 1108 else:
1103 1109 ar = AsyncResult(self, msg_ids=theids)
1104 1110
1105 1111 if block:
1106 1112 ar.wait()
1107 1113
1108 1114 return ar
1109 1115
1110 1116 @spin_first
1111 1117 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1112 1118 """Resubmit one or more tasks.
1113 1119
1114 1120 in-flight tasks may not be resubmitted.
1115 1121
1116 1122 Parameters
1117 1123 ----------
1118 1124
1119 1125 indices_or_msg_ids : integer history index, str msg_id, or list of either
1120 1126 The history indices or msg_ids of the tasks to be resubmitted
1121 1127
1122 1128 block : bool
1123 1129 Whether to wait for the result to be done
1124 1130
1125 1131 Returns
1126 1132 -------
1127 1133
1128 1134 AsyncHubResult
1129 1135 A subclass of AsyncResult that retrieves results from the Hub
1130 1136
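Examples
--------
A minimal sketch; ``msg_id`` is assumed to be the id of a completed
(not in-flight) task::

    In [10]: ar = client.resubmit(msg_id)

    In [11]: ar.get()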
1131 1137 """
1132 1138 block = self.block if block is None else block
1133 1139 if indices_or_msg_ids is None:
1134 1140 indices_or_msg_ids = -1
1135 1141
1136 1142 if not isinstance(indices_or_msg_ids, (list,tuple)):
1137 1143 indices_or_msg_ids = [indices_or_msg_ids]
1138 1144
1139 1145 theids = []
1140 1146 for id in indices_or_msg_ids:
1141 1147 if isinstance(id, int):
1142 1148 id = self.history[id]
1143 1149 if not isinstance(id, basestring):
1144 1150 raise TypeError("indices must be str or int, not %r"%id)
1145 1151 theids.append(id)
1146 1152
1147 1153 for msg_id in theids:
1148 1154 self.outstanding.discard(msg_id)
1149 1155 if msg_id in self.history:
1150 1156 self.history.remove(msg_id)
1151 1157 self.results.pop(msg_id, None)
1152 1158 self.metadata.pop(msg_id, None)
1153 1159 content = dict(msg_ids = theids)
1154 1160
1155 1161 self.session.send(self._query_socket, 'resubmit_request', content)
1156 1162
1157 1163 zmq.select([self._query_socket], [], [])
1158 1164 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1159 1165 if self.debug:
1160 1166 pprint(msg)
1161 1167 content = msg['content']
1162 1168 if content['status'] != 'ok':
1163 1169 raise self._unwrap_exception(content)
1164 1170
1165 1171 ar = AsyncHubResult(self, msg_ids=theids)
1166 1172
1167 1173 if block:
1168 1174 ar.wait()
1169 1175
1170 1176 return ar
1171 1177
1172 1178 @spin_first
1173 1179 def result_status(self, msg_ids, status_only=True):
1174 1180 """Check on the status of the result(s) of the apply request with `msg_ids`.
1175 1181
1176 1182 If status_only is False, then the actual results will be retrieved, else
1177 1183 only the status of the results will be checked.
1178 1184
1179 1185 Parameters
1180 1186 ----------
1181 1187
1182 1188 msg_ids : list of msg_ids
1183 1189 if int:
1184 1190 Passed as index to self.history for convenience.
1185 1191 status_only : bool (default: True)
1186 1192 if False:
1187 1193 Retrieve the actual results of completed tasks.
1188 1194
1189 1195 Returns
1190 1196 -------
1191 1197
1192 1198 results : dict
1193 1199 There will always be the keys 'pending' and 'completed', which will
1194 1200 be lists of msg_ids that are incomplete or complete. If `status_only`
1195 1201 is False, then completed results will be keyed by their `msg_id`.
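
Examples
--------
A minimal sketch; ``msg_ids`` is assumed to be a list of ids from
``client.history`` (the output shape is illustrative)::

    In [10]: client.result_status(msg_ids)
    Out[10]: {'pending': [...], 'completed': [...]}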
1196 1202 """
1197 1203 if not isinstance(msg_ids, (list,tuple)):
1198 1204 msg_ids = [msg_ids]
1199 1205
1200 1206 theids = []
1201 1207 for msg_id in msg_ids:
1202 1208 if isinstance(msg_id, int):
1203 1209 msg_id = self.history[msg_id]
1204 1210 if not isinstance(msg_id, basestring):
1205 1211 raise TypeError("msg_ids must be str, not %r"%msg_id)
1206 1212 theids.append(msg_id)
1207 1213
1208 1214 completed = []
1209 1215 local_results = {}
1210 1216
1211 1217 # comment this block out to temporarily disable local shortcut:
1212 1218 for msg_id in theids:
1213 1219 if msg_id in self.results:
1214 1220 completed.append(msg_id)
1215 1221 local_results[msg_id] = self.results[msg_id]
1216 1222 theids.remove(msg_id)
1217 1223
1218 1224 if theids: # some not locally cached
1219 1225 content = dict(msg_ids=theids, status_only=status_only)
1220 1226 msg = self.session.send(self._query_socket, "result_request", content=content)
1221 1227 zmq.select([self._query_socket], [], [])
1222 1228 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1223 1229 if self.debug:
1224 1230 pprint(msg)
1225 1231 content = msg['content']
1226 1232 if content['status'] != 'ok':
1227 1233 raise self._unwrap_exception(content)
1228 1234 buffers = msg['buffers']
1229 1235 else:
1230 1236 content = dict(completed=[],pending=[])
1231 1237
1232 1238 content['completed'].extend(completed)
1233 1239
1234 1240 if status_only:
1235 1241 return content
1236 1242
1237 1243 failures = []
1238 1244 # load cached results into result:
1239 1245 content.update(local_results)
1240 1246
1241 1247 # update cache with results:
1242 1248 for msg_id in sorted(theids):
1243 1249 if msg_id in content['completed']:
1244 1250 rec = content[msg_id]
1245 1251 parent = rec['header']
1246 1252 header = rec['result_header']
1247 1253 rcontent = rec['result_content']
1248 1254 iodict = rec['io']
1249 1255 if isinstance(rcontent, str):
1250 1256 rcontent = self.session.unpack(rcontent)
1251 1257
1252 1258 md = self.metadata[msg_id]
1253 1259 md.update(self._extract_metadata(header, parent, rcontent))
1254 1260 md.update(iodict)
1255 1261
1256 1262 if rcontent['status'] == 'ok':
1257 1263 res,buffers = util.unserialize_object(buffers)
1258 1264 else:
1259 1265 print rcontent
1260 1266 res = self._unwrap_exception(rcontent)
1261 1267 failures.append(res)
1262 1268
1263 1269 self.results[msg_id] = res
1264 1270 content[msg_id] = res
1265 1271
1266 1272 if len(theids) == 1 and failures:
1267 1273 raise failures[0]
1268 1274
1269 1275 error.collect_exceptions(failures, "result_status")
1270 1276 return content
1271 1277
1272 1278 @spin_first
1273 1279 def queue_status(self, targets='all', verbose=False):
1274 1280 """Fetch the status of engine queues.
1275 1281
1276 1282 Parameters
1277 1283 ----------
1278 1284
1279 1285 targets : int/str/list of ints/strs
1280 1286 the engines whose states are to be queried.
1281 1287 default : all
1282 1288 verbose : bool
1283 1289 Whether to return lengths only, or lists of ids for each element
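
Examples
--------
A minimal sketch of the default form (output shape is illustrative)::

    In [10]: client.queue_status()
    Out[10]: {0: {'queue': 0, 'completed': 10, 'tasks': 0}, 'unassigned': 0}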
1284 1290 """
1285 1291 engine_ids = self._build_targets(targets)[1]
1286 1292 content = dict(targets=engine_ids, verbose=verbose)
1287 1293 self.session.send(self._query_socket, "queue_request", content=content)
1288 1294 idents,msg = self.session.recv(self._query_socket, 0)
1289 1295 if self.debug:
1290 1296 pprint(msg)
1291 1297 content = msg['content']
1292 1298 status = content.pop('status')
1293 1299 if status != 'ok':
1294 1300 raise self._unwrap_exception(content)
1295 1301 content = rekey(content)
1296 1302 if isinstance(targets, int):
1297 1303 return content[targets]
1298 1304 else:
1299 1305 return content
1300 1306
1301 1307 @spin_first
1302 1308 def purge_results(self, jobs=[], targets=[]):
1303 1309 """Tell the Hub to forget results.
1304 1310
1305 1311 Individual results can be purged by msg_id, or the entire
1306 1312 history of specific targets can be purged.
1307 1313
1308 1314 Use `purge_results('all')` to scrub everything from the Hub's db.
1309 1315
1310 1316 Parameters
1311 1317 ----------
1312 1318
1313 1319 jobs : str or list of str or AsyncResult objects
1314 1320 the msg_ids whose results should be forgotten.
1315 1321 targets : int/str/list of ints/strs
1316 1322 The targets, by int_id, whose entire history is to be purged.
1317 1323
1318 1324 default : None
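
Examples
--------
A minimal sketch; the msg_ids are assumed to be known::

    In [10]: client.purge_results(jobs=[msg_id1, msg_id2])

    In [11]: client.purge_results('all')  # scrub everything from the Hub's db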
1319 1325 """
1320 1326 if not targets and not jobs:
1321 1327 raise ValueError("Must specify at least one of `targets` and `jobs`")
1322 1328 if targets:
1323 1329 targets = self._build_targets(targets)[1]
1324 1330
1325 1331 # construct msg_ids from jobs
1326 1332 if jobs == 'all':
1327 1333 msg_ids = jobs
1328 1334 else:
1329 1335 msg_ids = []
1330 1336 if isinstance(jobs, (basestring,AsyncResult)):
1331 1337 jobs = [jobs]
1332 1338 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1333 1339 if bad_ids:
1334 1340 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1335 1341 for j in jobs:
1336 1342 if isinstance(j, AsyncResult):
1337 1343 msg_ids.extend(j.msg_ids)
1338 1344 else:
1339 1345 msg_ids.append(j)
1340 1346
1341 1347 content = dict(engine_ids=targets, msg_ids=msg_ids)
1342 1348 self.session.send(self._query_socket, "purge_request", content=content)
1343 1349 idents, msg = self.session.recv(self._query_socket, 0)
1344 1350 if self.debug:
1345 1351 pprint(msg)
1346 1352 content = msg['content']
1347 1353 if content['status'] != 'ok':
1348 1354 raise self._unwrap_exception(content)
1349 1355
1350 1356 @spin_first
1351 1357 def hub_history(self):
1352 1358 """Get the Hub's history
1353 1359
1354 1360 Just like the Client, the Hub has a history, which is a list of msg_ids.
1355 1361 This will contain the history of all clients, and, depending on configuration,
1356 1362 may contain history across multiple cluster sessions.
1357 1363
1358 1364 Any msg_id returned here is a valid argument to `get_result`.
1359 1365
1360 1366 Returns
1361 1367 -------
1362 1368
1363 1369 msg_ids : list of strs
1364 1370 list of all msg_ids, ordered by task submission time.
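
Examples
--------
Fetch the result of the most recently submitted task::

    In [10]: msg_ids = client.hub_history()

    In [11]: ar = client.get_result(msg_ids[-1])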
1365 1371 """
1366 1372
1367 1373 self.session.send(self._query_socket, "history_request", content={})
1368 1374 idents, msg = self.session.recv(self._query_socket, 0)
1369 1375
1370 1376 if self.debug:
1371 1377 pprint(msg)
1372 1378 content = msg['content']
1373 1379 if content['status'] != 'ok':
1374 1380 raise self._unwrap_exception(content)
1375 1381 else:
1376 1382 return content['history']
1377 1383
1378 1384 @spin_first
1379 1385 def db_query(self, query, keys=None):
1380 1386 """Query the Hub's TaskRecord database
1381 1387
1382 1388 This will return a list of task record dicts that match `query`
1383 1389
1384 1390 Parameters
1385 1391 ----------
1386 1392
1387 1393 query : mongodb query dict
1388 1394 The search dict. See mongodb query docs for details.
1389 1395 keys : list of strs [optional]
1390 1396 The subset of keys to be returned. The default is to fetch everything but buffers.
1391 1397 'msg_id' will *always* be included.
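
Examples
--------
A minimal sketch, using a mongodb-style query dict; the available
operators depend on the configured db backend, and the record keys
shown are illustrative::

    In [10]: client.db_query({'msg_id': {'$in': msg_ids}},
       ....:                 keys=['started', 'completed'])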
1392 1398 """
1393 1399 if isinstance(keys, basestring):
1394 1400 keys = [keys]
1395 1401 content = dict(query=query, keys=keys)
1396 1402 self.session.send(self._query_socket, "db_request", content=content)
1397 1403 idents, msg = self.session.recv(self._query_socket, 0)
1398 1404 if self.debug:
1399 1405 pprint(msg)
1400 1406 content = msg['content']
1401 1407 if content['status'] != 'ok':
1402 1408 raise self._unwrap_exception(content)
1403 1409
1404 1410 records = content['records']
1405 1411
1406 1412 buffer_lens = content['buffer_lens']
1407 1413 result_buffer_lens = content['result_buffer_lens']
1408 1414 buffers = msg['buffers']
1409 1415 has_bufs = buffer_lens is not None
1410 1416 has_rbufs = result_buffer_lens is not None
1411 1417 for i,rec in enumerate(records):
1412 1418 # relink buffers
1413 1419 if has_bufs:
1414 1420 blen = buffer_lens[i]
1415 1421 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1416 1422 if has_rbufs:
1417 1423 blen = result_buffer_lens[i]
1418 1424 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1419 1425
1420 1426 return records
1421 1427
1422 1428 __all__ = [ 'Client' ]
@@ -1,165 +1,165 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes used in scattering and gathering sequences.
4 4
5 5 Scattering consists of partitioning a sequence and sending the various
6 6 pieces to individual nodes in a cluster.
7 7
8 8
9 9 Authors:
10 10
11 11 * Brian Granger
12 12 * MinRK
13 13
14 14 """
15 15
16 __docformat__ = "restructuredtext en"
17
18 16 #-------------------------------------------------------------------------------
19 17 # Copyright (C) 2008-2011 The IPython Development Team
20 18 #
21 19 # Distributed under the terms of the BSD License. The full license is in
22 20 # the file COPYING, distributed as part of this software.
23 21 #-------------------------------------------------------------------------------
24 22
25 23 #-------------------------------------------------------------------------------
26 24 # Imports
27 25 #-------------------------------------------------------------------------------
28 26
27 from __future__ import division
28
29 29 import types
30 30
31 31 from IPython.utils.data import flatten as utils_flatten
32 32
33 33 #-------------------------------------------------------------------------------
34 34 # Figure out which array packages are present and their array types
35 35 #-------------------------------------------------------------------------------
36 36
37 37 arrayModules = []
38 38 try:
39 39 import Numeric
40 40 except ImportError:
41 41 pass
42 42 else:
43 43 arrayModules.append({'module':Numeric, 'type':Numeric.arraytype})
44 44 try:
45 45 import numpy
46 46 except ImportError:
47 47 pass
48 48 else:
49 49 arrayModules.append({'module':numpy, 'type':numpy.ndarray})
50 50 try:
51 51 import numarray
52 52 except ImportError:
53 53 pass
54 54 else:
55 55 arrayModules.append({'module':numarray,
56 56 'type':numarray.numarraycore.NumArray})
57 57
58 58 class Map:
59 59 """A class for partitioning a sequence using a map."""
60 60
61 61 def getPartition(self, seq, p, q):
62 62 """Returns the pth partition of q partitions of seq."""
63 63
64 64 # Test for error conditions here
65 65 if p<0 or p>=q:
66 66 print "No partition exists."
67 67 return
68 68
69 69 remainder = len(seq)%q
70 basesize = len(seq)/q
70 basesize = len(seq)//q
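# worked example: len(seq)==10, q==3 gives remainder=1, basesize=3,
# so the first `remainder` partitions get an extra item: sizes 4,3,3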
71 71 hi = []
72 72 lo = []
73 73 for n in range(q):
74 74 if n < remainder:
75 75 lo.append(n * (basesize + 1))
76 76 hi.append(lo[-1] + basesize + 1)
77 77 else:
78 78 lo.append(n*basesize + remainder)
79 79 hi.append(lo[-1] + basesize)
80 80
81 81
82 82 result = seq[lo[p]:hi[p]]
83 83 return result
84 84
85 85 def joinPartitions(self, listOfPartitions):
86 86 return self.concatenate(listOfPartitions)
87 87
88 88 def concatenate(self, listOfPartitions):
89 89 testObject = listOfPartitions[0]
90 90 # First see if we have a known array type
91 91 for m in arrayModules:
92 92 #print m
93 93 if isinstance(testObject, m['type']):
94 94 return m['module'].concatenate(listOfPartitions)
95 95 # Next try for Python sequence types
96 96 if isinstance(testObject, (list, tuple)):
97 97 return utils_flatten(listOfPartitions)
98 98 # If we have scalars, just return listOfPartitions
99 99 return listOfPartitions
100 100
101 101 class RoundRobinMap(Map):
102 102 """Partitions a sequence in a roun robin fashion.
103 103
104 104 This currently does not work!
105 105 """
106 106
107 107 def getPartition(self, seq, p, q):
108 108 # if not isinstance(seq,(list,tuple)):
109 109 # raise NotImplementedError("cannot RR partition type %s"%type(seq))
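# stride through seq: partition p takes elements p, p+q, p+2q, ...
# e.g. q=3: getPartition(range(7), 0, 3) -> [0, 3, 6]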
110 110 return seq[p:len(seq):q]
111 111 #result = []
112 112 #for i in range(p,len(seq),q):
113 113 # result.append(seq[i])
114 114 #return result
115 115
116 116 def joinPartitions(self, listOfPartitions):
117 117 testObject = listOfPartitions[0]
118 118 # First see if we have a known array type
119 119 for m in arrayModules:
120 120 #print m
121 121 if isinstance(testObject, m['type']):
122 122 return self.flatten_array(m['type'], listOfPartitions)
123 123 if isinstance(testObject, (list, tuple)):
124 124 return self.flatten_list(listOfPartitions)
125 125 return listOfPartitions
126 126
127 127 def flatten_array(self, klass, listOfPartitions):
128 128 test = listOfPartitions[0]
129 129 shape = list(test.shape)
130 130 shape[0] = sum([ p.shape[0] for p in listOfPartitions])
131 131 A = klass(shape)
132 132 N = shape[0]
133 133 q = len(listOfPartitions)
134 134 for p,part in enumerate(listOfPartitions):
135 135 A[p:N:q] = part
136 136 return A
137 137
138 138 def flatten_list(self, listOfPartitions):
139 139 flat = []
140 140 for i in range(len(listOfPartitions[0])):
141 141 flat.extend([ part[i] for part in listOfPartitions if len(part) > i ])
142 142 return flat
143 143 #lengths = [len(x) for x in listOfPartitions]
144 144 #maxPartitionLength = len(listOfPartitions[0])
145 145 #numberOfPartitions = len(listOfPartitions)
146 146 #concat = self.concatenate(listOfPartitions)
147 147 #totalLength = len(concat)
148 148 #result = []
149 149 #for i in range(maxPartitionLength):
150 150 # result.append(concat[i:totalLength:maxPartitionLength])
151 151 # return self.concatenate(listOfPartitions)
152 152
153 153 def mappable(obj):
154 154 """return whether an object is mappable or not."""
155 155 if isinstance(obj, (tuple,list)):
156 156 return True
157 157 for m in arrayModules:
158 158 if isinstance(obj,m['type']):
159 159 return True
160 160 return False
161 161
162 162 dists = {'b':Map,'r':RoundRobinMap}
163 163
164 164
165 165
@@ -1,206 +1,211 b''
1 1 """Remote Functions and decorators for Views.
2 2
3 3 Authors:
4 4
5 5 * Brian Granger
6 6 * Min RK
7 7 """
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 from __future__ import division
20
 21 import sys
19 22 import warnings
20 22
21 23 from IPython.testing.skipdoctest import skip_doctest
22 24
23 25 from . import map as Map
24 26 from .asyncresult import AsyncMapResult
25 27
26 28 #-----------------------------------------------------------------------------
27 29 # Decorators
28 30 #-----------------------------------------------------------------------------
29 31
30 32 @skip_doctest
31 33 def remote(view, block=None, **flags):
32 34 """Turn a function into a remote function.
33 35
34 36 It can be used as a decorator:
35 37
36 38 In [1]: @remote(view,block=True)
37 39 ...: def func(a):
38 40 ...: pass
39 41 """
40 42
41 43 def remote_function(f):
42 44 return RemoteFunction(view, f, block=block, **flags)
43 45 return remote_function
44 46
45 47 @skip_doctest
46 48 def parallel(view, dist='b', block=None, **flags):
47 49 """Turn a function into a parallel remote function.
48 50
49 51 This method can be used for map:
50 52
51 53 In [1]: @parallel(view, block=True)
52 54 ...: def func(a):
53 55 ...: pass
54 56 """
55 57
56 58 def parallel_function(f):
57 59 return ParallelFunction(view, f, dist=dist, block=block, **flags)
58 60 return parallel_function
59 61
60 62 #--------------------------------------------------------------------------
61 63 # Classes
62 64 #--------------------------------------------------------------------------
63 65
64 66 class RemoteFunction(object):
65 67 """Turn an existing function into a remote function.
66 68
67 69 Parameters
68 70 ----------
69 71
70 72 view : View instance
71 73 The view to be used for execution
72 74 f : callable
73 75 The function to be wrapped into a remote function
74 76 block : bool [default: None]
75 77 Whether to wait for results or not. The default behavior is
76 78 to use the current `block` attribute of `view`
77 79
78 80 **flags : remaining kwargs are passed to View.temp_flags
79 81 """
80 82
81 83 view = None # the remote connection
82 84 func = None # the wrapped function
83 85 block = None # whether to block
84 86 flags = None # dict of extra kwargs for temp_flags
85 87
86 88 def __init__(self, view, f, block=None, **flags):
87 89 self.view = view
88 90 self.func = f
89 91 self.block=block
90 92 self.flags=flags
91 93
92 94 def __call__(self, *args, **kwargs):
93 95 block = self.view.block if self.block is None else self.block
94 96 with self.view.temp_flags(block=block, **self.flags):
95 97 return self.view.apply(self.func, *args, **kwargs)
96 98
97 99
98 100 class ParallelFunction(RemoteFunction):
99 101 """Class for mapping a function to sequences.
100 102
101 103 This will distribute the sequences according to a mapper, and call
102 104 the function on each sub-sequence. If called via map, then the function
103 105 will be called once on each element, rather than on each sub-sequence.
104 106
105 107 Parameters
106 108 ----------
107 109
108 110 view : View instance
109 111 The view to be used for execution
110 112 f : callable
111 113 The function to be wrapped into a remote function
112 114 dist : str [default: 'b']
113 115 The key for which mapObject to use to distribute sequences
114 116 options are:
115 117 * 'b' : use contiguous chunks in order
116 118 * 'r' : use round-robin striping
117 119 block : bool [default: None]
118 120 Whether to wait for results or not. The default behavior is
119 121 to use the current `block` attribute of `view`
120 122 chunksize : int or None
121 123 The size of chunk to use when breaking up sequences in a load-balanced manner
122 124 **flags : remaining kwargs are passed to View.temp_flags
123 125 """
124 126
125 127 chunksize=None
126 128 mapObject=None
127 129
128 130 def __init__(self, view, f, dist='b', block=None, chunksize=None, **flags):
129 131 super(ParallelFunction, self).__init__(view, f, block=block, **flags)
130 132 self.chunksize = chunksize
131 133
132 134 mapClass = Map.dists[dist]
133 135 self.mapObject = mapClass()
134 136
135 137 def __call__(self, *sequences):
136 138 # check that the lengths of the sequences match
137 139 len_0 = len(sequences[0])
138 140 for s in sequences:
139 141 if len(s)!=len_0:
140 142 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
141 143 raise ValueError(msg)
142 144 balanced = 'Balanced' in self.view.__class__.__name__
143 145 if balanced:
144 146 if self.chunksize:
145 nparts = len_0/self.chunksize + int(len_0%self.chunksize > 0)
147 nparts = len_0//self.chunksize + int(len_0%self.chunksize > 0)
146 148 else:
147 149 nparts = len_0
148 150 targets = [None]*nparts
149 151 else:
150 152 if self.chunksize:
151 153 warnings.warn("`chunksize` is ignored unless load balancing", UserWarning)
152 154 # multiplexed:
153 155 targets = self.view.targets
154 156 nparts = len(targets)
155 157
156 158 msg_ids = []
157 159 # my_f = lambda *a: map(self.func, *a)
158 160 client = self.view.client
159 161 for index, t in enumerate(targets):
160 162 args = []
161 163 for seq in sequences:
162 164 part = self.mapObject.getPartition(seq, index, nparts)
163 165 if len(part) == 0:
164 166 continue
165 167 else:
166 168 args.append(part)
167 169 if not args:
168 170 continue
169 171
170 172 # print (args)
171 173 if hasattr(self, '_map'):
172 f = map
174 if sys.version_info[0] >= 3:
175 f = lambda f, *sequences: list(map(f, *sequences))
176 else:
177 f = map
173 178 args = [self.func]+args
174 179 else:
175 180 f=self.func
176 181
177 182 view = self.view if balanced else client[t]
178 183 with view.temp_flags(block=False, **self.flags):
179 184 ar = view.apply(f, *args)
180 185
181 186 msg_ids.append(ar.msg_ids[0])
182 187
183 188 r = AsyncMapResult(self.view.client, msg_ids, self.mapObject, fname=self.func.__name__)
184 189
185 190 if self.block:
186 191 try:
187 192 return r.get()
188 193 except KeyboardInterrupt:
189 194 return r
190 195 else:
191 196 return r
192 197
193 198 def map(self, *sequences):
194 199 """call a function on each element of a sequence remotely.
195 200 This should behave very much like the builtin map, but return an AsyncMapResult
196 201 if self.block is False.
197 202 """
198 203 # set _map as a flag for use inside self.__call__
199 204 self._map = True
200 205 try:
201 206 ret = self.__call__(*sequences)
202 207 finally:
203 208 del self._map
204 209 return ret
205 210
206 211 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction']
@@ -1,1046 +1,1048 b''
1 1 """Views of remote engines.
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import imp
19 19 import sys
20 20 import warnings
21 21 from contextlib import contextmanager
22 22 from types import ModuleType
23 23
24 24 import zmq
25 25
26 26 from IPython.testing.skipdoctest import skip_doctest
27 27 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance, CFloat, CInt
28 28 from IPython.external.decorator import decorator
29 29
30 30 from IPython.parallel import util
31 31 from IPython.parallel.controller.dependency import Dependency, dependent
32 32
33 33 from . import map as Map
34 34 from .asyncresult import AsyncResult, AsyncMapResult
35 35 from .remotefunction import ParallelFunction, parallel, remote
36 36
37 37 #-----------------------------------------------------------------------------
38 38 # Decorators
39 39 #-----------------------------------------------------------------------------
40 40
41 41 @decorator
42 42 def save_ids(f, self, *args, **kwargs):
43 43 """Keep our history and outstanding attributes up to date after a method call."""
44 44 n_previous = len(self.client.history)
45 45 try:
46 46 ret = f(self, *args, **kwargs)
47 47 finally:
48 48 nmsgs = len(self.client.history) - n_previous
49 49 msg_ids = self.client.history[-nmsgs:]
50 50 self.history.extend(msg_ids)
51 51 map(self.outstanding.add, msg_ids)
52 52 return ret
53 53
54 54 @decorator
55 55 def sync_results(f, self, *args, **kwargs):
56 56 """sync relevant results from self.client to our results attribute."""
57 57 ret = f(self, *args, **kwargs)
58 58 delta = self.outstanding.difference(self.client.outstanding)
59 59 completed = self.outstanding.intersection(delta)
60 60 self.outstanding = self.outstanding.difference(completed)
61 61 for msg_id in completed:
62 62 self.results[msg_id] = self.client.results[msg_id]
63 63 return ret
64 64
65 65 @decorator
66 66 def spin_after(f, self, *args, **kwargs):
67 67 """call spin after the method."""
68 68 ret = f(self, *args, **kwargs)
69 69 self.spin()
70 70 return ret
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # Classes
74 74 #-----------------------------------------------------------------------------
75 75
76 76 @skip_doctest
77 77 class View(HasTraits):
78 78 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
79 79
80 80 Don't use this class, use subclasses.
81 81
82 82 Methods
83 83 -------
84 84
85 85 spin
86 86 flushes incoming results and registration state changes
87 87 control methods also spin, and requesting `ids` ensures the state is up to date
88 88
89 89 wait
90 90 wait on one or more msg_ids
91 91
92 92 execution methods
93 93 apply
94 94 legacy: execute, run
95 95
96 96 data movement
97 97 push, pull, scatter, gather
98 98
99 99 query methods
100 100 get_result, queue_status, purge_results, result_status
101 101
102 102 control methods
103 103 abort, shutdown
104 104
105 105 """
106 106 # flags
107 107 block=Bool(False)
108 108 track=Bool(True)
109 109 targets = Any()
110 110
111 111 history=List()
112 112 outstanding = Set()
113 113 results = Dict()
114 114 client = Instance('IPython.parallel.Client')
115 115
116 116 _socket = Instance('zmq.Socket')
117 117 _flag_names = List(['targets', 'block', 'track'])
118 118 _targets = Any()
119 119 _idents = Any()
120 120
121 121 def __init__(self, client=None, socket=None, **flags):
122 122 super(View, self).__init__(client=client, _socket=socket)
123 123 self.block = client.block
124 124
125 125 self.set_flags(**flags)
126 126
127 127 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
128 128
129 129
130 130 def __repr__(self):
131 131 strtargets = str(self.targets)
132 132 if len(strtargets) > 16:
133 133 strtargets = strtargets[:12]+'...]'
134 134 return "<%s %s>"%(self.__class__.__name__, strtargets)
135 135
136 136 def set_flags(self, **kwargs):
137 137 """set my attribute flags by keyword.
138 138
139 139 Views determine behavior with a few attributes (`block`, `track`, etc.).
140 140 These attributes can be set all at once by name with this method.
141 141
142 142 Parameters
143 143 ----------
144 144
145 145 block : bool
146 146 whether to wait for results
147 147 track : bool
148 148 whether to create a MessageTracker to allow the user to
149 149 safely edit arrays and buffers after non-copying
150 150 sends.
151 151 """
152 152 for name, value in kwargs.iteritems():
153 153 if name not in self._flag_names:
154 154 raise KeyError("Invalid name: %r"%name)
155 155 else:
156 156 setattr(self, name, value)
157 157
158 158 @contextmanager
159 159 def temp_flags(self, **kwargs):
160 160 """temporarily set flags, for use in `with` statements.
161 161
162 162 See set_flags for permanent setting of flags
163 163
164 164 Examples
165 165 --------
166 166
167 167 >>> view.track=False
168 168 ...
169 169 >>> with view.temp_flags(track=True):
170 170 ... ar = view.apply(dostuff, my_big_array)
171 171 ... ar.tracker.wait() # wait for send to finish
172 172 >>> view.track
173 173 False
174 174
175 175 """
176 176 # preflight: save flags, and set temporaries
177 177 saved_flags = {}
178 178 for f in self._flag_names:
179 179 saved_flags[f] = getattr(self, f)
180 180 self.set_flags(**kwargs)
181 181 # yield to the with-statement block
182 182 try:
183 183 yield
184 184 finally:
185 185 # postflight: restore saved flags
186 186 self.set_flags(**saved_flags)
187 187
188 188
189 189 #----------------------------------------------------------------
190 190 # apply
191 191 #----------------------------------------------------------------
192 192
193 193 @sync_results
194 194 @save_ids
195 195 def _really_apply(self, f, args, kwargs, block=None, **options):
196 196 """wrapper for client.send_apply_message"""
197 197 raise NotImplementedError("Implement in subclasses")
198 198
199 199 def apply(self, f, *args, **kwargs):
200 200 """calls f(*args, **kwargs) on remote engines, returning the result.
201 201
202 202 This method sets all apply flags via this View's attributes.
203 203
204 204 if self.block is False:
205 205 returns AsyncResult
206 206 else:
207 207 returns actual result of f(*args, **kwargs)
208 208 """
209 209 return self._really_apply(f, args, kwargs)
210 210
211 211 def apply_async(self, f, *args, **kwargs):
212 212 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
213 213
214 214 returns AsyncResult
215 215 """
216 216 return self._really_apply(f, args, kwargs, block=False)
217 217
218 218 @spin_after
219 219 def apply_sync(self, f, *args, **kwargs):
220 220 """calls f(*args, **kwargs) on remote engines in a blocking manner,
221 221 returning the result.
222 222
223 223 returns: actual result of f(*args, **kwargs)
224 224 """
225 225 return self._really_apply(f, args, kwargs, block=True)
226 226
227 227 #----------------------------------------------------------------
228 228 # wrappers for client and control methods
229 229 #----------------------------------------------------------------
230 230 @sync_results
231 231 def spin(self):
232 232 """spin the client, and sync"""
233 233 self.client.spin()
234 234
235 235 @sync_results
236 236 def wait(self, jobs=None, timeout=-1):
237 237 """waits on one or more `jobs`, for up to `timeout` seconds.
238 238
239 239 Parameters
240 240 ----------
241 241
242 242 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
243 243 ints are indices to self.history
244 244 strs are msg_ids
245 245 default: wait on all outstanding messages
246 246 timeout : float
247 247 a time in seconds, after which to give up.
248 248 default is -1, which means no timeout
249 249
250 250 Returns
251 251 -------
252 252
253 253 True : when all msg_ids are done
254 254 False : timeout reached, some msg_ids still outstanding
255 255 """
256 256 if jobs is None:
257 257 jobs = self.history
258 258 return self.client.wait(jobs, timeout)
259 259
260 260 def abort(self, jobs=None, targets=None, block=None):
261 261 """Abort jobs on my engines.
262 262
263 263 Parameters
264 264 ----------
265 265
266 266 jobs : None, str, list of strs, optional
267 267 if None: abort all jobs.
268 268 else: abort specific msg_id(s).
269 269 """
270 270 block = block if block is not None else self.block
271 271 targets = targets if targets is not None else self.targets
272 272 return self.client.abort(jobs=jobs, targets=targets, block=block)
273 273
274 274 def queue_status(self, targets=None, verbose=False):
275 275 """Fetch the Queue status of my engines"""
276 276 targets = targets if targets is not None else self.targets
277 277 return self.client.queue_status(targets=targets, verbose=verbose)
278 278
279 279 def purge_results(self, jobs=[], targets=[]):
280 280 """Instruct the controller to forget specific results."""
281 281 if targets is None or targets == 'all':
282 282 targets = self.targets
283 283 return self.client.purge_results(jobs=jobs, targets=targets)
284 284
285 285 def shutdown(self, targets=None, restart=False, hub=False, block=None):
286 286 """Terminates one or more engine processes, optionally including the hub.
287 287 """
288 288 block = self.block if block is None else block
289 289 if targets is None or targets == 'all':
290 290 targets = self.targets
291 291 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
292 292
293 293 @spin_after
294 294 def get_result(self, indices_or_msg_ids=None):
295 295 """return one or more results, specified by history index or msg_id.
296 296
297 297 See client.get_result for details.
298 298
299 299 """
300 300
301 301 if indices_or_msg_ids is None:
302 302 indices_or_msg_ids = -1
303 303 if isinstance(indices_or_msg_ids, int):
304 304 indices_or_msg_ids = self.history[indices_or_msg_ids]
305 305 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
306 306 indices_or_msg_ids = list(indices_or_msg_ids)
307 307 for i,index in enumerate(indices_or_msg_ids):
308 308 if isinstance(index, int):
309 309 indices_or_msg_ids[i] = self.history[index]
310 310 return self.client.get_result(indices_or_msg_ids)
311 311
312 312 #-------------------------------------------------------------------
313 313 # Map
314 314 #-------------------------------------------------------------------
315 315
316 316 def map(self, f, *sequences, **kwargs):
317 317 """override in subclasses"""
318 318 raise NotImplementedError
319 319
320 320 def map_async(self, f, *sequences, **kwargs):
321 321 """Parallel version of builtin `map`, using this view's engines.
322 322
323 323 This is equivalent to map(...block=False)
324 324
325 325 See `self.map` for details.
326 326 """
327 327 if 'block' in kwargs:
328 328 raise TypeError("map_async doesn't take a `block` keyword argument.")
329 329 kwargs['block'] = False
330 330 return self.map(f,*sequences,**kwargs)
331 331
332 332 def map_sync(self, f, *sequences, **kwargs):
333 333 """Parallel version of builtin `map`, using this view's engines.
334 334
335 335 This is equivalent to map(...block=True)
336 336
337 337 See `self.map` for details.
338 338 """
339 339 if 'block' in kwargs:
340 340 raise TypeError("map_sync doesn't take a `block` keyword argument.")
341 341 kwargs['block'] = True
342 342 return self.map(f,*sequences,**kwargs)
343 343
344 344 def imap(self, f, *sequences, **kwargs):
345 345 """Parallel version of `itertools.imap`.
346 346
347 347 See `self.map` for details.
348 348
349 349 """
350 350
351 351 return iter(self.map_async(f,*sequences, **kwargs))
352 352
353 353 #-------------------------------------------------------------------
354 354 # Decorators
355 355 #-------------------------------------------------------------------
356 356
357 357 def remote(self, block=True, **flags):
358 358 """Decorator for making a RemoteFunction"""
359 359 block = self.block if block is None else block
360 360 return remote(self, block=block, **flags)
361 361
362 362 def parallel(self, dist='b', block=None, **flags):
363 363 """Decorator for making a ParallelFunction"""
364 364 block = self.block if block is None else block
365 365 return parallel(self, dist=dist, block=block, **flags)
366 366
367 367 @skip_doctest
368 368 class DirectView(View):
369 369 """Direct Multiplexer View of one or more engines.
370 370
371 371 These are created via indexed access to a client:
372 372
373 373 >>> dv_1 = client[1]
374 374 >>> dv_all = client[:]
375 375 >>> dv_even = client[::2]
376 376 >>> dv_some = client[1:3]
377 377
378 378 This object provides dictionary access to engine namespaces:
379 379
380 380 # push a=5:
381 381 >>> dv['a'] = 5
382 382 # pull 'foo':
383 383 >>> dv['foo']
384 384
385 385 """
386 386
387 387 def __init__(self, client=None, socket=None, targets=None):
388 388 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
389 389
390 390 @property
391 391 def importer(self):
392 392 """sync_imports(local=True) as a property.
393 393
394 394 See sync_imports for details.
395 395
396 396 """
397 397 return self.sync_imports(True)
398 398
399 399 @contextmanager
400 400 def sync_imports(self, local=True):
401 401 """Context Manager for performing simultaneous local and remote imports.
402 402
403 403 'import x as y' will *not* work. The 'as y' part will simply be ignored.
404 404
405 405 >>> with view.sync_imports():
406 406 ... from numpy import recarray
407 407 importing recarray from numpy on engine(s)
408 408
409 409 """
410 410 import __builtin__
411 411 local_import = __builtin__.__import__
412 412 modules = set()
413 413 results = []
414 414 @util.interactive
415 415 def remote_import(name, fromlist, level):
416 416 """the function to be passed to apply, that actually performs the import
417 417 on the engine, and loads up the user namespace.
418 418 """
419 419 import sys
420 420 user_ns = globals()
421 421 mod = __import__(name, fromlist=fromlist, level=level)
422 422 if fromlist:
423 423 for key in fromlist:
424 424 user_ns[key] = getattr(mod, key)
425 425 else:
426 426 user_ns[name] = sys.modules[name]
427 427
428 428 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
429 429 """the drop-in replacement for __import__, that optionally imports
430 430 locally as well.
431 431 """
432 432 # don't override nested imports
433 433 save_import = __builtin__.__import__
434 434 __builtin__.__import__ = local_import
435 435
436 436 if imp.lock_held():
437 437 # this is a side-effect import, don't do it remotely, or even
438 438 # ignore the local effects
439 439 return local_import(name, globals, locals, fromlist, level)
440 440
441 441 imp.acquire_lock()
442 442 if local:
443 443 mod = local_import(name, globals, locals, fromlist, level)
444 444 else:
445 445 raise NotImplementedError("remote-only imports not yet implemented")
446 446 imp.release_lock()
447 447
448 448 key = name+':'+','.join(fromlist or [])
449 449 if level == -1 and key not in modules:
450 450 modules.add(key)
451 451 if fromlist:
452 452 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
453 453 else:
454 454 print "importing %s on engine(s)"%name
455 455 results.append(self.apply_async(remote_import, name, fromlist, level))
456 456 # restore override
457 457 __builtin__.__import__ = save_import
458 458
459 459 return mod
460 460
461 461 # override __import__
462 462 __builtin__.__import__ = view_import
463 463 try:
464 464 # enter the block
465 465 yield
466 466 except ImportError:
467 467 # ignore import errors only if not doing local imports
468 468 if local:
469 469 raise
470 470 finally:
471 471 # always restore __import__
472 472 __builtin__.__import__ = local_import
473 473
474 474 for r in results:
475 475 # raise possible remote ImportErrors here
476 476 r.get()
477 477
478 478
479 479 @sync_results
480 480 @save_ids
481 481 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
482 482 """calls f(*args, **kwargs) on remote engines, returning the result.
483 483
484 484 This method sets all of `apply`'s flags via this View's attributes.
485 485
486 486 Parameters
487 487 ----------
488 488
489 489 f : callable
490 490
491 491 args : list [default: empty]
492 492
493 493 kwargs : dict [default: empty]
494 494
495 495 targets : target list [default: self.targets]
496 496 where to run
497 497 block : bool [default: self.block]
498 498 whether to block
499 499 track : bool [default: self.track]
500 500 whether to ask zmq to track the message, for safe non-copying sends
501 501
502 502 Returns
503 503 -------
504 504
505 505 if self.block is False:
506 506 returns AsyncResult
507 507 else:
508 508 returns actual result of f(*args, **kwargs) on the engine(s)
509 509 This will be a list if self.targets is also a list (even length 1), or
510 510 the single result if self.targets is an integer engine id
511 511 """
512 512 args = [] if args is None else args
513 513 kwargs = {} if kwargs is None else kwargs
514 514 block = self.block if block is None else block
515 515 track = self.track if track is None else track
516 516 targets = self.targets if targets is None else targets
517 517
518 518 _idents = self.client._build_targets(targets)[0]
519 519 msg_ids = []
520 520 trackers = []
521 521 for ident in _idents:
522 522 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
523 523 ident=ident)
524 524 if track:
525 525 trackers.append(msg['tracker'])
526 526 msg_ids.append(msg['msg_id'])
527 527 tracker = None if track is False else zmq.MessageTracker(*trackers)
528 528 ar = AsyncResult(self.client, msg_ids, fname=f.__name__, targets=targets, tracker=tracker)
529 529 if block:
530 530 try:
531 531 return ar.get()
532 532 except KeyboardInterrupt:
533 533 pass
534 534 return ar
535 535
536 536 @spin_after
537 537 def map(self, f, *sequences, **kwargs):
538 538 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
539 539
540 540 Parallel version of builtin `map`, using this View's `targets`.
541 541
542 542 There will be one task per target, so work will be chunked
543 543 if the sequences are longer than `targets`.
544 544
545 545 Results can be iterated as they are ready, but will become available in chunks.
546 546
547 547 Parameters
548 548 ----------
549 549
550 550 f : callable
551 551 function to be mapped
552 552 *sequences: one or more sequences of matching length
553 553 the sequences to be distributed and passed to `f`
554 554 block : bool
555 555 whether to wait for the result or not [default self.block]
556 556
557 557 Returns
558 558 -------
559 559
560 560 if block=False:
561 561 AsyncMapResult
562 562 An object like AsyncResult, but which reassembles the sequence of results
563 563 into a single list. AsyncMapResults can be iterated through before all
564 564 results are complete.
565 565 else:
566 566 list
567 567 the result of map(f,*sequences)
568 568 """
569 569
570 570 block = kwargs.pop('block', self.block)
571 571 for k in kwargs.keys():
572 572 if k not in ['block', 'track']:
573 573 raise TypeError("invalid keyword arg, %r"%k)
574 574
575 575 assert len(sequences) > 0, "must have some sequences to map onto!"
576 576 pf = ParallelFunction(self, f, block=block, **kwargs)
577 577 return pf.map(*sequences)
578 578
579 579 def execute(self, code, targets=None, block=None):
580 580 """Executes `code` on `targets` in blocking or nonblocking manner.
581 581
582 582 ``execute`` is always `bound` (affects engine namespace)
583 583
584 584 Parameters
585 585 ----------
586 586
587 587 code : str
588 588 the code string to be executed
589 589 block : bool
590 590 whether or not to wait until done to return
591 591 default: self.block
592 592 """
593 593 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
594 594
595 595 def run(self, filename, targets=None, block=None):
596 596 """Execute contents of `filename` on my engine(s).
597 597
598 598 This simply reads the contents of the file and calls `execute`.
599 599
600 600 Parameters
601 601 ----------
602 602
603 603 filename : str
604 604 The path to the file
605 605 targets : int/str/list of ints/strs
606 606 the engines on which to execute
607 607 default : all
608 608 block : bool
609 609 whether or not to wait until done
610 610 default: self.block
611 611
612 612 """
613 613 with open(filename, 'r') as f:
614 614 # add newline in case of trailing indented whitespace
615 615 # which will cause SyntaxError
616 616 code = f.read()+'\n'
617 617 return self.execute(code, block=block, targets=targets)
618 618
619 619 def update(self, ns):
620 620 """update remote namespace with dict `ns`
621 621
622 622 See `push` for details.
623 623 """
624 624 return self.push(ns, block=self.block, track=self.track)
625 625
626 626 def push(self, ns, targets=None, block=None, track=None):
627 627 """update remote namespace with dict `ns`
628 628
629 629 Parameters
630 630 ----------
631 631
632 632 ns : dict
633 633 dict of keys with which to update engine namespace(s)
634 634 block : bool [default : self.block]
635 635 whether to wait to be notified of engine receipt
636 636
637 637 """
638 638
639 639 block = block if block is not None else self.block
640 640 track = track if track is not None else self.track
641 641 targets = targets if targets is not None else self.targets
642 642 # applier = self.apply_sync if block else self.apply_async
643 643 if not isinstance(ns, dict):
644 644 raise TypeError("Must be a dict, not %s"%type(ns))
645 645 return self._really_apply(util._push, (ns,), block=block, track=track, targets=targets)
646 646
647 647 def get(self, key_s):
648 648 """get object(s) by `key_s` from remote namespace
649 649
650 650 see `pull` for details.
651 651 """
652 652 # block = block if block is not None else self.block
653 653 return self.pull(key_s, block=True)
654 654
655 655 def pull(self, names, targets=None, block=None):
656 656 """get object(s) by `name` from remote namespace
657 657
658 658         Will return one object if `names` is a single key.
659 659         Can also take a list of keys, in which case it will return a list of objects.
660 660 """
661 661 block = block if block is not None else self.block
662 662 targets = targets if targets is not None else self.targets
663 663 applier = self.apply_sync if block else self.apply_async
664 664 if isinstance(names, basestring):
665 665 pass
666 666 elif isinstance(names, (list,tuple,set)):
667 667 for key in names:
668 668 if not isinstance(key, basestring):
669 669 raise TypeError("keys must be str, not type %r"%type(key))
670 670 else:
671 671 raise TypeError("names must be strs, not %r"%names)
672 672 return self._really_apply(util._pull, (names,), block=block, targets=targets)
673 673
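A round trip through `push`/`pull` and the dictionary-style shortcuts defined below (hedged sketch):

    dv.push(dict(a=5, b='hi'), block=True)   # update each engine's namespace
    print dv.pull('a', block=True)           # [5, 5, ...], one value per engine
    print dv.pull(('a', 'b'), block=True)    # [a, b] pairs, one per engine
    dv['c'] = 10                             # __setitem__ is update/push
    print dv['c']                            # __getitem__ is get/pull (blocking)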
674 674 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
675 675 """
676 676 Partition a Python sequence and send the partitions to a set of engines.
677 677 """
678 678 block = block if block is not None else self.block
679 679 track = track if track is not None else self.track
680 680 targets = targets if targets is not None else self.targets
681 681
682 682 mapObject = Map.dists[dist]()
683 683 nparts = len(targets)
684 684 msg_ids = []
685 685 trackers = []
686 686 for index, engineid in enumerate(targets):
687 687 partition = mapObject.getPartition(seq, index, nparts)
688 688 if flatten and len(partition) == 1:
689 689 ns = {key: partition[0]}
690 690 else:
691 691 ns = {key: partition}
692 692 r = self.push(ns, block=False, track=track, targets=engineid)
693 693 msg_ids.extend(r.msg_ids)
694 694 if track:
695 695 trackers.append(r._tracker)
696 696
697 697 if track:
698 698 tracker = zmq.MessageTracker(*trackers)
699 699 else:
700 700 tracker = None
701 701
702 702 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
703 703 if block:
704 704 r.wait()
705 705 else:
706 706 return r
707 707
708 708 @sync_results
709 709 @save_ids
710 710 def gather(self, key, dist='b', targets=None, block=None):
711 711 """
712 712 Gather a partitioned sequence on a set of engines as a single local seq.
713 713 """
714 714 block = block if block is not None else self.block
715 715 targets = targets if targets is not None else self.targets
716 716 mapObject = Map.dists[dist]()
717 717 msg_ids = []
718 718
719 719 for index, engineid in enumerate(targets):
720 720 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
721 721
722 722 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
723 723
724 724 if block:
725 725 try:
726 726 return r.get()
727 727 except KeyboardInterrupt:
728 728 pass
729 729 return r
730 730
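Scatter/gather round trip (hedged sketch): partition a sequence across the engines, transform the pieces remotely, and reassemble the result locally:

    dv.scatter('x', range(16), block=True)           # each engine gets a slice named x
    dv.execute('y = [i**2 for i in x]', block=True)  # operate on the local slice
    print dv.gather('y', block=True)                 # [0, 1, 4, ..., 225], in order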
731 731 def __getitem__(self, key):
732 732 return self.get(key)
733 733
734 734 def __setitem__(self,key, value):
735 735 self.update({key:value})
736 736
737 737 def clear(self, targets=None, block=False):
738 738 """Clear the remote namespaces on my engines."""
739 739 block = block if block is not None else self.block
740 740 targets = targets if targets is not None else self.targets
741 741 return self.client.clear(targets=targets, block=block)
742 742
743 743 def kill(self, targets=None, block=True):
744 744 """Kill my engines."""
745 745 block = block if block is not None else self.block
746 746 targets = targets if targets is not None else self.targets
747 747 return self.client.kill(targets=targets, block=block)
748 748
749 749 #----------------------------------------
750 750 # activate for %px,%autopx magics
751 751 #----------------------------------------
752 752 def activate(self):
753 753 """Make this `View` active for parallel magic commands.
754 754
755 755 IPython has a magic command syntax to work with `MultiEngineClient` objects.
756 756 In a given IPython session there is a single active one. While
757 757 there can be many `Views` created and used by the user,
758 758 there is only one active one. The active `View` is used whenever
759 759 the magic commands %px and %autopx are used.
760 760
761 761 The activate() method is called on a given `View` to make it
762 762 active. Once this has been done, the magic commands can be used.
763 763 """
764 764
765 765 try:
766 766 # This is injected into __builtins__.
767 767 ip = get_ipython()
768 768 except NameError:
769 769 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
770 770 else:
771 771 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
772 772 if pmagic is None:
773 773 ip.magic_load_ext('parallelmagic')
774 774 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
775 775
776 776 pmagic.active_view = self
777 777
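In an interactive session the flow above looks roughly like this (illustrative transcript, not a doctest):

    In [1]: from IPython.parallel import Client
    In [2]: dv = Client()[:]
    In [3]: dv.activate()      # dv is now the active view for %px/%autopx
    In [4]: %px import os      # executes on every engine in dv.targets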
778 778
779 779 @skip_doctest
780 780 class LoadBalancedView(View):
781 781     """A load-balancing View that only executes via the Task scheduler.
782 782
783 783 Load-balanced views can be created with the client's `view` method:
784 784
785 785 >>> v = client.load_balanced_view()
786 786
787 787 or targets can be specified, to restrict the potential destinations:
788 788
789 789         >>> v = client.load_balanced_view([1,3])
790 790
791 791     which would restrict load-balancing to engines 1 and 3.
792 792
793 793 """
794 794
795 795 follow=Any()
796 796 after=Any()
797 797 timeout=CFloat()
798 798 retries = CInt(0)
799 799
800 800 _task_scheme = Any()
801 801 _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
802 802
803 803 def __init__(self, client=None, socket=None, **flags):
804 804 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
805 805 self._task_scheme=client._task_scheme
806 806
807 807 def _validate_dependency(self, dep):
808 808 """validate a dependency.
809 809
810 810 For use in `set_flags`.
811 811 """
812 812 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
813 813 return True
814 814 elif isinstance(dep, (list,set, tuple)):
815 815 for d in dep:
816 816 if not isinstance(d, (str, AsyncResult)):
817 817 return False
818 818 elif isinstance(dep, dict):
819 819 if set(dep.keys()) != set(Dependency().as_dict().keys()):
820 820 return False
821 821 if not isinstance(dep['msg_ids'], list):
822 822 return False
823 823 for d in dep['msg_ids']:
824 824 if not isinstance(d, str):
825 825 return False
826 826 else:
827 827 return False
828 828
829 829 return True
830 830
831 831 def _render_dependency(self, dep):
832 832 """helper for building jsonable dependencies from various input forms."""
833 833 if isinstance(dep, Dependency):
834 834 return dep.as_dict()
835 835 elif isinstance(dep, AsyncResult):
836 836 return dep.msg_ids
837 837 elif dep is None:
838 838 return []
839 839 else:
840 840 # pass to Dependency constructor
841 841 return list(Dependency(dep))
842 842
843 843 def set_flags(self, **kwargs):
844 844 """set my attribute flags by keyword.
845 845
846 846         A View is a wrapper for the Client's apply method, with attributes
847 847         that specify keyword arguments. Those attributes can be set by keyword
848 848         argument with this method.
849 849
850 850 Parameters
851 851 ----------
852 852
853 853 block : bool
854 854 whether to wait for results
855 855 track : bool
856 856 whether to create a MessageTracker to allow the user to
857 857             safely edit arrays and buffers after non-copying
858 858 sends.
859 859
860 860 after : Dependency or collection of msg_ids
861 861 Only for load-balanced execution (targets=None)
862 862 Specify a list of msg_ids as a time-based dependency.
863 863 This job will only be run *after* the dependencies
864 864 have been met.
865 865
866 866 follow : Dependency or collection of msg_ids
867 867 Only for load-balanced execution (targets=None)
868 868 Specify a list of msg_ids as a location-based dependency.
869 869 This job will only be run on an engine where this dependency
870 870 is met.
871 871
872 872 timeout : float/int or None
873 873 Only for load-balanced execution (targets=None)
874 874 Specify an amount of time (in seconds) for the scheduler to
875 875 wait for dependencies to be met before failing with a
876 876 DependencyTimeout.
877 877
878 878 retries : int
879 879 Number of times a task will be retried on failure.
880 880 """
881 881
882 882 super(LoadBalancedView, self).set_flags(**kwargs)
883 883 for name in ('follow', 'after'):
884 884 if name in kwargs:
885 885 value = kwargs[name]
886 886 if self._validate_dependency(value):
887 887 setattr(self, name, value)
888 888 else:
889 889 raise ValueError("Invalid dependency: %r"%value)
890 890 if 'timeout' in kwargs:
891 891 t = kwargs['timeout']
892 892 if not isinstance(t, (int, long, float, type(None))):
893 893 raise TypeError("Invalid type for timeout: %r"%type(t))
894 894 if t is not None:
895 895 if t < 0:
896 896 raise ValueError("Invalid timeout: %s"%t)
897 897 self.timeout = t
898 898
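The dependency flags compose like this (hedged sketch; `rc` as before, and the two stage functions are purely illustrative):

    def stage1():
        return 'ready'

    def stage2():
        return 'done'

    lv = rc.load_balanced_view()
    ar = lv.apply_async(stage1)
    lv.set_flags(after=ar, timeout=30, retries=1)      # run stage2 only after stage1
    ar2 = lv.apply_async(stage2)
    lv.set_flags(after=None, timeout=None, retries=0)  # reset the flags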
899 899 @sync_results
900 900 @save_ids
901 901 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
902 902 after=None, follow=None, timeout=None,
903 903 targets=None, retries=None):
904 904 """calls f(*args, **kwargs) on a remote engine, returning the result.
905 905
906 906 This method temporarily sets all of `apply`'s flags for a single call.
907 907
908 908 Parameters
909 909 ----------
910 910
911 911 f : callable
912 912
913 913 args : list [default: empty]
914 914
915 915 kwargs : dict [default: empty]
916 916
917 917 block : bool [default: self.block]
918 918 whether to block
919 919 track : bool [default: self.track]
920 920 whether to ask zmq to track the message, for safe non-copying sends
921 921
922 922         after, follow, timeout, retries : see `set_flags` for details
923 923
924 924 Returns
925 925 -------
926 926
927 927 if self.block is False:
928 928 returns AsyncResult
929 929 else:
930 930 returns actual result of f(*args, **kwargs) on the engine(s)
931 931             This will be a list if self.targets is also a list (even of length 1), or
932 932 the single result if self.targets is an integer engine id
933 933 """
934 934
935 935 # validate whether we can run
936 936 if self._socket.closed:
937 937 msg = "Task farming is disabled"
938 938 if self._task_scheme == 'pure':
939 939 msg += " because the pure ZMQ scheduler cannot handle"
940 940 msg += " disappearing engines."
941 941 raise RuntimeError(msg)
942 942
943 943 if self._task_scheme == 'pure':
944 944 # pure zmq scheme doesn't support extra features
945 945             msg = ("Pure ZMQ scheduler doesn't support the following flags: "
946 946                    "follow, after, retries, targets, timeout")
947 947 if (follow or after or retries or targets or timeout):
948 948 # hard fail on Scheduler flags
949 949 raise RuntimeError(msg)
950 950 if isinstance(f, dependent):
951 951 # soft warn on functional dependencies
952 952 warnings.warn(msg, RuntimeWarning)
953 953
954 954 # build args
955 955 args = [] if args is None else args
956 956 kwargs = {} if kwargs is None else kwargs
957 957 block = self.block if block is None else block
958 958 track = self.track if track is None else track
959 959 after = self.after if after is None else after
960 960 retries = self.retries if retries is None else retries
961 961 follow = self.follow if follow is None else follow
962 962 timeout = self.timeout if timeout is None else timeout
963 963 targets = self.targets if targets is None else targets
964 964
965 965 if not isinstance(retries, int):
966 966 raise TypeError('retries must be int, not %r'%type(retries))
967 967
968 968 if targets is None:
969 969 idents = []
970 970 else:
971 971 idents = self.client._build_targets(targets)[0]
972 # ensure *not* bytes
973 idents = [ ident.decode() for ident in idents ]
972 974
973 975 after = self._render_dependency(after)
974 976 follow = self._render_dependency(follow)
975 977 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
976 978
977 979 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
978 980 subheader=subheader)
979 981 tracker = None if track is False else msg['tracker']
980 982
981 983 ar = AsyncResult(self.client, msg['msg_id'], fname=f.__name__, targets=None, tracker=tracker)
982 984
983 985 if block:
984 986 try:
985 987 return ar.get()
986 988 except KeyboardInterrupt:
987 989 pass
988 990 return ar
989 991
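Sketch of the retry behaviour wired up above (hedged; `flaky` is an illustrative function that sometimes raises):

    import random

    def flaky():
        if random.random() < 0.5:
            raise RuntimeError("transient failure")
        return 'ok'

    lv.retries = 3              # resubmit a failed task up to 3 times
    ar = lv.apply_async(flaky)
    print ar.get()              # raises only if all retries fail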
990 992 @spin_after
991 993 @save_ids
992 994 def map(self, f, *sequences, **kwargs):
993 995 """view.map(f, *sequences, block=self.block, chunksize=1) => list|AsyncMapResult
994 996
995 997 Parallel version of builtin `map`, load-balanced by this View.
996 998
997 999 `block`, and `chunksize` can be specified by keyword only.
998 1000
999 1001 Each `chunksize` elements will be a separate task, and will be
1000 1002 load-balanced. This lets individual elements be available for iteration
1001 1003 as soon as they arrive.
1002 1004
1003 1005 Parameters
1004 1006 ----------
1005 1007
1006 1008 f : callable
1007 1009 function to be mapped
1008 1010 *sequences: one or more sequences of matching length
1009 1011 the sequences to be distributed and passed to `f`
1010 1012 block : bool
1011 1013 whether to wait for the result or not [default self.block]
1012 1014 track : bool
1013 1015 whether to create a MessageTracker to allow the user to
1014 1016             safely edit arrays and buffers after non-copying
1015 1017 sends.
1016 1018 chunksize : int
1017 1019 how many elements should be in each task [default 1]
1018 1020
1019 1021 Returns
1020 1022 -------
1021 1023
1022 1024 if block=False:
1023 1025 AsyncMapResult
1024 1026 An object like AsyncResult, but which reassembles the sequence of results
1025 1027 into a single list. AsyncMapResults can be iterated through before all
1026 1028 results are complete.
1027 1029 else:
1028 1030 the result of map(f,*sequences)
1029 1031
1030 1032 """
1031 1033
1032 1034 # default
1033 1035 block = kwargs.get('block', self.block)
1034 1036 chunksize = kwargs.get('chunksize', 1)
1035 1037
1036 1038 keyset = set(kwargs.keys())
1037 1039         extra_keys = keyset.difference(set(['block', 'chunksize']))
1038 1040 if extra_keys:
1039 1041 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1040 1042
1041 1043 assert len(sequences) > 0, "must have some sequences to map onto!"
1042 1044
1043 1045 pf = ParallelFunction(self, f, block=block, chunksize=chunksize)
1044 1046 return pf.map(*sequences)
1045 1047
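Load-balanced map with chunking, per the docstring above (hedged sketch):

    amr = lv.map(str, range(100), chunksize=10, block=False)  # 10 tasks, 10 elements each
    for s in amr:               # elements become available as their chunks complete
        print s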
1046 1048 __all__ = ['LoadBalancedView', 'DirectView']
@@ -1,169 +1,173 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 4 and hearts are tracked based on their XREQ identities.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18 import time
19 19 import uuid
20 20
21 21 import zmq
22 22 from zmq.devices import ThreadDevice
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.config.configurable import LoggingConfigurable
26 26 from IPython.utils.traitlets import Set, Instance, CFloat
27 27
28 from IPython.parallel.util import ensure_bytes
29
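`ensure_bytes` is the py3k-compatibility helper this changeset threads through the identity and topic handling. A plausible stand-in, for illustration only (the real helper lives in `IPython.parallel.util`):

    def ensure_bytes(s, encoding='ascii'):
        # pass bytes through untouched; encode unicode/str to bytes
        if isinstance(s, bytes):
            return s
        return s.encode(encoding)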
28 30 class Heart(object):
29 31 """A basic heart object for responding to a HeartMonitor.
30 32 This is a simple wrapper with defaults for the most common
31 33 Device model for responding to heartbeats.
32 34
33 35 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
34 36 SUB/XREQ for in/out.
35 37
36 38 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
37 39 device=None
38 40 id=None
39 41 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.XREQ, heart_id=None):
40 42 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
41 43 self.device.daemon=True
42 44 self.device.connect_in(in_addr)
43 45 self.device.connect_out(out_addr)
44 46 if in_type == zmq.SUB:
45 self.device.setsockopt_in(zmq.SUBSCRIBE, "")
47 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
46 48 if heart_id is None:
47 heart_id = str(uuid.uuid4())
49 heart_id = ensure_bytes(uuid.uuid4())
48 50 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
49 51 self.id = heart_id
50 52
51 53 def start(self):
52 54 return self.device.start()
53 55
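Standing up a Heart against illustrative endpoints (a real engine gets these addresses from the controller's connection info; the ports here match the __main__ block at the bottom of this file):

    from IPython.parallel.controller.heartmonitor import Heart

    heart = Heart('tcp://127.0.0.1:5555', 'tcp://127.0.0.1:5556')
    heart.start()    # starts the FORWARDER device in a background thread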
54 56 class HeartMonitor(LoggingConfigurable):
55 57 """A basic HeartMonitor class
56 58 pingstream: a PUB stream
57 59 pongstream: an XREP stream
58 60 period: the period of the heartbeat in milliseconds"""
59 61
60 62 period=CFloat(1000, config=True,
61 63         help='The period (in ms) at which the Hub pings the engines for heartbeats '
62 64         '[default: 1000]',
63 65 )
64 66
65 67 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
66 68 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
67 69 loop = Instance('zmq.eventloop.ioloop.IOLoop')
68 70 def _loop_default(self):
69 71 return ioloop.IOLoop.instance()
70 72
71 73 # not settable:
72 74 hearts=Set()
73 75 responses=Set()
74 76 on_probation=Set()
75 77 last_ping=CFloat(0)
76 78 _new_handlers = Set()
77 79 _failure_handlers = Set()
78 80 lifetime = CFloat(0)
79 81 tic = CFloat(0)
80 82
81 83 def __init__(self, **kwargs):
82 84 super(HeartMonitor, self).__init__(**kwargs)
83 85
84 86 self.pongstream.on_recv(self.handle_pong)
85 87
86 88 def start(self):
87 89 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
88 90 self.caller.start()
89 91
90 92 def add_new_heart_handler(self, handler):
91 93 """add a new handler for new hearts"""
92 94 self.log.debug("heartbeat::new_heart_handler: %s"%handler)
93 95 self._new_handlers.add(handler)
94 96
95 97 def add_heart_failure_handler(self, handler):
96 98 """add a new handler for heart failure"""
97 99 self.log.debug("heartbeat::new heart failure handler: %s"%handler)
98 100 self._failure_handlers.add(handler)
99 101
100 102 def beat(self):
101 103 self.pongstream.flush()
102 104 self.last_ping = self.lifetime
103 105
104 106 toc = time.time()
105 107 self.lifetime += toc-self.tic
106 108 self.tic = toc
107 109 # self.log.debug("heartbeat::%s"%self.lifetime)
108 110 goodhearts = self.hearts.intersection(self.responses)
109 111 missed_beats = self.hearts.difference(goodhearts)
110 112 heartfailures = self.on_probation.intersection(missed_beats)
111 113 newhearts = self.responses.difference(goodhearts)
112 114 map(self.handle_new_heart, newhearts)
113 115 map(self.handle_heart_failure, heartfailures)
114 116 self.on_probation = missed_beats.intersection(self.hearts)
115 117 self.responses = set()
116 118 # print self.on_probation, self.hearts
117 119 # self.log.debug("heartbeat::beat %.3f, %i beating hearts"%(self.lifetime, len(self.hearts)))
118 self.pingstream.send(str(self.lifetime))
120 self.pingstream.send(ensure_bytes(str(self.lifetime)))
119 121
120 122 def handle_new_heart(self, heart):
121 123 if self._new_handlers:
122 124 for handler in self._new_handlers:
123 125 handler(heart)
124 126 else:
125 127 self.log.info("heartbeat::yay, got new heart %s!"%heart)
126 128 self.hearts.add(heart)
127 129
128 130 def handle_heart_failure(self, heart):
129 131 if self._failure_handlers:
130 132 for handler in self._failure_handlers:
131 133 try:
132 134 handler(heart)
133 135 except Exception as e:
134 136 self.log.error("heartbeat::Bad Handler! %s"%handler, exc_info=True)
135 137 pass
136 138 else:
137 139 self.log.info("heartbeat::Heart %s failed :("%heart)
138 140 self.hearts.remove(heart)
139 141
140 142
141 143 def handle_pong(self, msg):
142 144 "a heart just beat"
143 if msg[1] == str(self.lifetime):
145 current = ensure_bytes(str(self.lifetime))
146 last = ensure_bytes(str(self.last_ping))
147 if msg[1] == current:
144 148 delta = time.time()-self.tic
145 149 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
146 150 self.responses.add(msg[0])
147 elif msg[1] == str(self.last_ping):
151 elif msg[1] == last:
148 152 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
149 153 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond"%(msg[0], 1000*delta))
150 154 self.responses.add(msg[0])
151 155 else:
152 156 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)"%
153 157 (msg[1],self.lifetime))
154 158
155 159
156 160 if __name__ == '__main__':
157 161 loop = ioloop.IOLoop.instance()
158 162 context = zmq.Context()
159 163 pub = context.socket(zmq.PUB)
160 164 pub.bind('tcp://127.0.0.1:5555')
161 165 xrep = context.socket(zmq.XREP)
162 166 xrep.bind('tcp://127.0.0.1:5556')
163 167
164 168 outstream = zmqstream.ZMQStream(pub, loop)
165 169 instream = zmqstream.ZMQStream(xrep, loop)
166 170
167 171 hb = HeartMonitor(loop, outstream, instream)
168 172
169 173 loop.start()
@@ -1,1288 +1,1291 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Imports
19 19 #-----------------------------------------------------------------------------
20 20 from __future__ import print_function
21 21
22 22 import sys
23 23 import time
24 24 from datetime import datetime
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop
28 28 from zmq.eventloop.zmqstream import ZMQStream
29 29
30 30 # internal:
31 31 from IPython.utils.importstring import import_item
32 32 from IPython.utils.traitlets import (
33 33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 34 )
35 35
36 36 from IPython.parallel import error, util
37 37 from IPython.parallel.factory import RegistrationFactory
38 38
39 39 from IPython.zmq.session import SessionFactory
40 40
41 41 from .heartmonitor import HeartMonitor
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 50 def _printer(*args, **kwargs):
51 51 print (args)
52 52 print (kwargs)
53 53
54 54 def empty_record():
55 55 """Return an empty dict with all record keys."""
56 56 return {
57 57 'msg_id' : None,
58 58 'header' : None,
59 59 'content': None,
60 60 'buffers': None,
61 61 'submitted': None,
62 62 'client_uuid' : None,
63 63 'engine_uuid' : None,
64 64 'started': None,
65 65 'completed': None,
66 66 'resubmitted': None,
67 67 'result_header' : None,
68 68 'result_content' : None,
69 69 'result_buffers' : None,
70 70 'queue' : None,
71 71 'pyin' : None,
72 72 'pyout': None,
73 73 'pyerr': None,
74 74 'stdout': '',
75 75 'stderr': '',
76 76 }
77 77
78 78 def init_record(msg):
79 79 """Initialize a TaskRecord based on a request."""
80 80 header = msg['header']
81 81 return {
82 82 'msg_id' : header['msg_id'],
83 83 'header' : header,
84 84 'content': msg['content'],
85 85 'buffers': msg['buffers'],
86 86 'submitted': header['date'],
87 87 'client_uuid' : None,
88 88 'engine_uuid' : None,
89 89 'started': None,
90 90 'completed': None,
91 91 'resubmitted': None,
92 92 'result_header' : None,
93 93 'result_content' : None,
94 94 'result_buffers' : None,
95 95 'queue' : None,
96 96 'pyin' : None,
97 97 'pyout': None,
98 98 'pyerr': None,
99 99 'stdout': '',
100 100 'stderr': '',
101 101 }
102 102
103 103
104 104 class EngineConnector(HasTraits):
105 105 """A simple object for accessing the various zmq connections of an object.
106 106 Attributes are:
107 107 id (int): engine ID
108 108 uuid (str): uuid (unused?)
109 109         queue (bytes): identity of queue's XREQ socket
110 110         registration (bytes): identity of registration XREQ socket
111 111         heartbeat (bytes): identity of heartbeat XREQ socket
112 112 """
113 113 id=Int(0)
114 114 queue=CBytes()
115 115 control=CBytes()
116 116 registration=CBytes()
117 117 heartbeat=CBytes()
118 118 pending=Set()
119 119
120 120 class HubFactory(RegistrationFactory):
121 121 """The Configurable for setting up a Hub."""
122 122
123 123 # port-pairs for monitoredqueues:
124 124 hb = Tuple(Int,Int,config=True,
125 125 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 126 def _hb_default(self):
127 127 return tuple(util.select_random_ports(2))
128 128
129 129 mux = Tuple(Int,Int,config=True,
130 130 help="""Engine/Client Port pair for MUX queue""")
131 131
132 132 def _mux_default(self):
133 133 return tuple(util.select_random_ports(2))
134 134
135 135 task = Tuple(Int,Int,config=True,
136 136 help="""Engine/Client Port pair for Task queue""")
137 137 def _task_default(self):
138 138 return tuple(util.select_random_ports(2))
139 139
140 140 control = Tuple(Int,Int,config=True,
141 141 help="""Engine/Client Port pair for Control queue""")
142 142
143 143 def _control_default(self):
144 144 return tuple(util.select_random_ports(2))
145 145
146 146 iopub = Tuple(Int,Int,config=True,
147 147 help="""Engine/Client Port pair for IOPub relay""")
148 148
149 149 def _iopub_default(self):
150 150 return tuple(util.select_random_ports(2))
151 151
152 152 # single ports:
153 153 mon_port = Int(config=True,
154 154 help="""Monitor (SUB) port for queue traffic""")
155 155
156 156 def _mon_port_default(self):
157 157 return util.select_random_ports(1)[0]
158 158
159 159 notifier_port = Int(config=True,
160 160 help="""PUB port for sending engine status notifications""")
161 161
162 162 def _notifier_port_default(self):
163 163 return util.select_random_ports(1)[0]
164 164
165 165 engine_ip = Unicode('127.0.0.1', config=True,
166 166 help="IP on which to listen for engine connections. [default: loopback]")
167 167 engine_transport = Unicode('tcp', config=True,
168 168 help="0MQ transport for engine connections. [default: tcp]")
169 169
170 170 client_ip = Unicode('127.0.0.1', config=True,
171 171 help="IP on which to listen for client connections. [default: loopback]")
172 172 client_transport = Unicode('tcp', config=True,
173 173 help="0MQ transport for client connections. [default : tcp]")
174 174
175 175 monitor_ip = Unicode('127.0.0.1', config=True,
176 176 help="IP on which to listen for monitor messages. [default: loopback]")
177 177 monitor_transport = Unicode('tcp', config=True,
178 178 help="0MQ transport for monitor messages. [default : tcp]")
179 179
180 180 monitor_url = Unicode('')
181 181
182 182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 183 config=True, help="""The class to use for the DB backend""")
184 184
185 185 # not configurable
186 186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188 188
189 189 def _ip_changed(self, name, old, new):
190 190 self.engine_ip = new
191 191 self.client_ip = new
192 192 self.monitor_ip = new
193 193 self._update_monitor_url()
194 194
195 195 def _update_monitor_url(self):
196 196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197 197
198 198 def _transport_changed(self, name, old, new):
199 199 self.engine_transport = new
200 200 self.client_transport = new
201 201 self.monitor_transport = new
202 202 self._update_monitor_url()
203 203
204 204 def __init__(self, **kwargs):
205 205 super(HubFactory, self).__init__(**kwargs)
206 206 self._update_monitor_url()
207 207
208 208
209 209 def construct(self):
210 210 self.init_hub()
211 211
212 212 def start(self):
213 213 self.heartmonitor.start()
214 214 self.log.info("Heartmonitor started")
215 215
216 216 def init_hub(self):
217 217 """construct"""
218 218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220 220
221 221 ctx = self.context
222 222 loop = self.loop
223 223
224 224 # Registrar socket
225 225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 226 q.bind(client_iface % self.regport)
227 227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 228 if self.client_ip != self.engine_ip:
229 229 q.bind(engine_iface % self.regport)
230 230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
231 231
232 232 ### Engine connections ###
233 233
234 234 # heartbeat
235 235 hpub = ctx.socket(zmq.PUB)
236 236 hpub.bind(engine_iface % self.hb[0])
237 237 hrep = ctx.socket(zmq.XREP)
238 238 hrep.bind(engine_iface % self.hb[1])
239 239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
240 240 pingstream=ZMQStream(hpub,loop),
241 241 pongstream=ZMQStream(hrep,loop)
242 242 )
243 243
244 244 ### Client connections ###
245 245 # Notifier socket
246 246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
247 247 n.bind(client_iface%self.notifier_port)
248 248
249 249 ### build and launch the queues ###
250 250
251 251 # monitor socket
252 252 sub = ctx.socket(zmq.SUB)
253 253 sub.setsockopt(zmq.SUBSCRIBE, b"")
254 254 sub.bind(self.monitor_url)
255 255 sub.bind('inproc://monitor')
256 256 sub = ZMQStream(sub, loop)
257 257
258 258 # connect the db
259 259         self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
260 260 # cdir = self.config.Global.cluster_dir
261 261 self.db = import_item(str(self.db_class))(session=self.session.session,
262 262 config=self.config, log=self.log)
263 263 time.sleep(.25)
264 264 try:
265 265 scheme = self.config.TaskScheduler.scheme_name
266 266 except AttributeError:
267 267 from .scheduler import TaskScheduler
268 268 scheme = TaskScheduler.scheme_name.get_default_value()
269 269 # build connection dicts
270 270 self.engine_info = {
271 271 'control' : engine_iface%self.control[1],
272 272 'mux': engine_iface%self.mux[1],
273 273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
274 274 'task' : engine_iface%self.task[1],
275 275 'iopub' : engine_iface%self.iopub[1],
276 276 # 'monitor' : engine_iface%self.mon_port,
277 277 }
278 278
279 279 self.client_info = {
280 280 'control' : client_iface%self.control[0],
281 281 'mux': client_iface%self.mux[0],
282 282 'task' : (scheme, client_iface%self.task[0]),
283 283 'iopub' : client_iface%self.iopub[0],
284 284 'notification': client_iface%self.notifier_port
285 285 }
286 286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
287 287 self.log.debug("Hub client addrs: %s"%self.client_info)
288 288
289 289 # resubmit stream
290 290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
291 291 url = util.disambiguate_url(self.client_info['task'][-1])
292 r.setsockopt(zmq.IDENTITY, self.session.session)
292 r.setsockopt(zmq.IDENTITY, util.ensure_bytes(self.session.session))
293 293 r.connect(url)
294 294
295 295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
296 296 query=q, notifier=n, resubmit=r, db=self.db,
297 297 engine_info=self.engine_info, client_info=self.client_info,
298 298 log=self.log)
299 299
300 300
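A hedged sketch of standing up a Hub via the factory above; normally the ipcontroller application does this:

    from IPython.parallel.controller.hub import HubFactory

    hf = HubFactory()   # random ports on loopback by default
    hf.construct()      # build sockets, heart monitor, db backend, and the Hub
    hf.start()          # start the heart monitor
    hf.loop.start()     # enter the zmq event loop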
301 301 class Hub(SessionFactory):
302 302 """The IPython Controller Hub with 0MQ connections
303 303
304 304 Parameters
305 305 ==========
306 306 loop: zmq IOLoop instance
307 307 session: Session object
308 308 <removed> context: zmq context for creating new connections (?)
309 309 queue: ZMQStream for monitoring the command queue (SUB)
310 310 query: ZMQStream for engine registration and client queries requests (XREP)
311 311 heartbeat: HeartMonitor object checking the pulse of the engines
312 312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 313 db: connection to db for out of memory logging of commands
314 314 NotImplemented
315 315 engine_info: dict of zmq connection information for engines to connect
316 316 to the queues.
317 317 client_info: dict of zmq connection information for engines to connect
318 318 to the queues.
319 319 """
320 320 # internal data structures:
321 321 ids=Set() # engine IDs
322 322 keytable=Dict()
323 323 by_ident=Dict()
324 324 engines=Dict()
325 325 clients=Dict()
326 326 hearts=Dict()
327 327 pending=Set()
328 328 queues=Dict() # pending msg_ids keyed by engine_id
329 329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 330 completed=Dict() # completed msg_ids keyed by engine_id
331 331     all_completed=Set() # set of all completed msg_ids
332 332     dead_engines=Set() # uuids of engines that have died
333 333     unassigned=Set() # set of task msg_ids not yet assigned a destination
334 334 incoming_registrations=Dict()
335 335 registration_timeout=Int()
336 336 _idcounter=Int(0)
337 337
338 338 # objects from constructor:
339 339 query=Instance(ZMQStream)
340 340 monitor=Instance(ZMQStream)
341 341 notifier=Instance(ZMQStream)
342 342 resubmit=Instance(ZMQStream)
343 343 heartmonitor=Instance(HeartMonitor)
344 344 db=Instance(object)
345 345 client_info=Dict()
346 346 engine_info=Dict()
347 347
348 348
349 349 def __init__(self, **kwargs):
350 350 """
351 351 # universal:
352 352 loop: IOLoop for creating future connections
353 353 session: streamsession for sending serialized data
354 354 # engine:
355 355 queue: ZMQStream for monitoring queue messages
356 356 query: ZMQStream for engine+client registration and client requests
357 357 heartbeat: HeartMonitor object for tracking engines
358 358 # extra:
359 359 db: ZMQStream for db connection (NotImplemented)
360 360 engine_info: zmq address/protocol dict for engine connections
361 361 client_info: zmq address/protocol dict for client connections
362 362 """
363 363
364 364 super(Hub, self).__init__(**kwargs)
365 365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366 366
367 367 # validate connection dicts:
368 368 for k,v in self.client_info.iteritems():
369 369 if k == 'task':
370 370 util.validate_url_container(v[1])
371 371 else:
372 372 util.validate_url_container(v)
373 373 # util.validate_url_container(self.client_info)
374 374 util.validate_url_container(self.engine_info)
375 375
376 376 # register our callbacks
377 377 self.query.on_recv(self.dispatch_query)
378 378 self.monitor.on_recv(self.dispatch_monitor_traffic)
379 379
380 380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382 382
383 self.monitor_handlers = { 'in' : self.save_queue_request,
384 'out': self.save_queue_result,
385 'intask': self.save_task_request,
386 'outtask': self.save_task_result,
387 'tracktask': self.save_task_destination,
388 'incontrol': _passer,
389 'outcontrol': _passer,
390 'iopub': self.save_iopub_message,
383 self.monitor_handlers = {b'in' : self.save_queue_request,
384 b'out': self.save_queue_result,
385 b'intask': self.save_task_request,
386 b'outtask': self.save_task_result,
387 b'tracktask': self.save_task_destination,
388 b'incontrol': _passer,
389 b'outcontrol': _passer,
390 b'iopub': self.save_iopub_message,
391 391 }
392 392
393 393 self.query_handlers = {'queue_request': self.queue_status,
394 394 'result_request': self.get_results,
395 395 'history_request': self.get_history,
396 396 'db_request': self.db_query,
397 397 'purge_request': self.purge_results,
398 398 'load_request': self.check_load,
399 399 'resubmit_request': self.resubmit_task,
400 400 'shutdown_request': self.shutdown_request,
401 401 'registration_request' : self.register_engine,
402 402 'unregistration_request' : self.unregister_engine,
403 403 'connection_request': self.connection_request,
404 404 }
405 405
406 406 # ignore resubmit replies
407 407 self.resubmit.on_recv(lambda msg: None, copy=False)
408 408
409 409 self.log.info("hub::created hub")
410 410
411 411 @property
412 412 def _next_id(self):
413 413         """generate a new ID.
414 414
415 415 No longer reuse old ids, just count from 0."""
416 416 newid = self._idcounter
417 417 self._idcounter += 1
418 418 return newid
419 419 # newid = 0
420 420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 421 # # print newid, self.ids, self.incoming_registrations
422 422 # while newid in self.ids or newid in incoming:
423 423 # newid += 1
424 424 # return newid
425 425
426 426 #-----------------------------------------------------------------------------
427 427 # message validation
428 428 #-----------------------------------------------------------------------------
429 429
430 430 def _validate_targets(self, targets):
431 431 """turn any valid targets argument into a list of integer ids"""
432 432 if targets is None:
433 433 # default to all
434 434 targets = self.ids
435 435
436 436 if isinstance(targets, (int,str,unicode)):
437 437 # only one target specified
438 438 targets = [targets]
439 439 _targets = []
440 440 for t in targets:
441 441 # map raw identities to ids
442 442 if isinstance(t, (str,unicode)):
443 443 t = self.by_ident.get(t, t)
444 444 _targets.append(t)
445 445 targets = _targets
446 446 bad_targets = [ t for t in targets if t not in self.ids ]
447 447 if bad_targets:
448 448 raise IndexError("No Such Engine: %r"%bad_targets)
449 449 if not targets:
450 450 raise IndexError("No Engines Registered")
451 451 return targets
452 452
453 453 #-----------------------------------------------------------------------------
454 454 # dispatch methods (1 per stream)
455 455 #-----------------------------------------------------------------------------
456 456
457 457
458 458 def dispatch_monitor_traffic(self, msg):
459 459 """all ME and Task queue messages come through here, as well as
460 460 IOPub traffic."""
461 461 self.log.debug("monitor traffic: %r"%msg[:2])
462 462 switch = msg[0]
463 463 try:
464 464 idents, msg = self.session.feed_identities(msg[1:])
465 465 except ValueError:
466 466 idents=[]
467 467 if not idents:
468 468 self.log.error("Bad Monitor Message: %r"%msg)
469 469 return
470 470 handler = self.monitor_handlers.get(switch, None)
471 471 if handler is not None:
472 472 handler(idents, msg)
473 473 else:
474 474 self.log.error("Invalid monitor topic: %r"%switch)
475 475
476 476
477 477 def dispatch_query(self, msg):
478 478 """Route registration requests and queries from clients."""
479 479 try:
480 480 idents, msg = self.session.feed_identities(msg)
481 481 except ValueError:
482 482 idents = []
483 483 if not idents:
484 484 self.log.error("Bad Query Message: %r"%msg)
485 485 return
486 486 client_id = idents[0]
487 487 try:
488 488 msg = self.session.unpack_message(msg, content=True)
489 489 except Exception:
490 490 content = error.wrap_exception()
491 491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 492 self.session.send(self.query, "hub_error", ident=client_id,
493 493 content=content)
494 494 return
495 495 # print client_id, header, parent, content
496 496 #switch on message type:
497 497 msg_type = msg['msg_type']
498 498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 499 handler = self.query_handlers.get(msg_type, None)
500 500 try:
501 501 assert handler is not None, "Bad Message Type: %r"%msg_type
502 502 except:
503 503 content = error.wrap_exception()
504 504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 505 self.session.send(self.query, "hub_error", ident=client_id,
506 506 content=content)
507 507 return
508 508
509 509 else:
510 510 handler(idents, msg)
511 511
512 512 def dispatch_db(self, msg):
513 513 """"""
514 514 raise NotImplementedError
515 515
516 516 #---------------------------------------------------------------------------
517 517 # handler methods (1 per event)
518 518 #---------------------------------------------------------------------------
519 519
520 520 #----------------------- Heartbeat --------------------------------------
521 521
522 522 def handle_new_heart(self, heart):
523 523 """handler to attach to heartbeater.
524 524 Called when a new heart starts to beat.
525 525 Triggers completion of registration."""
526 526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 527 if heart not in self.incoming_registrations:
528 528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 529 else:
530 530 self.finish_registration(heart)
531 531
532 532
533 533 def handle_heart_failure(self, heart):
534 534 """handler to attach to heartbeater.
535 535 called when a previously registered heart fails to respond to beat request.
536 536 triggers unregistration"""
537 537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 538         eid = self.hearts.get(heart, None)
539 539         if eid is None:
540 540             self.log.info("heartbeat::ignoring heart failure %r"%heart)
541 541         else:
542 542             queue = self.engines[eid].queue
543 543             self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 544
545 545 #----------------------- MUX Queue Traffic ------------------------------
546 546
547 547 def save_queue_request(self, idents, msg):
548 548 if len(idents) < 2:
549 549 self.log.error("invalid identity prefix: %r"%idents)
550 550 return
551 551 queue_id, client_id = idents[:2]
552 552 try:
553 553 msg = self.session.unpack_message(msg)
554 554 except Exception:
555 555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 556 return
557 557
558 558 eid = self.by_ident.get(queue_id, None)
559 559 if eid is None:
560 560 self.log.error("queue::target %r not registered"%queue_id)
561 561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 562 return
563 563 record = init_record(msg)
564 564 msg_id = record['msg_id']
565 record['engine_uuid'] = queue_id
566 record['client_uuid'] = client_id
565 # Unicode in records
566 record['engine_uuid'] = queue_id.decode('utf8', 'replace')
567 record['client_uuid'] = client_id.decode('utf8', 'replace')
567 568 record['queue'] = 'mux'
568 569
569 570 try:
570 571             # it's possible iopub arrived first:
571 572 existing = self.db.get_record(msg_id)
572 573 for key,evalue in existing.iteritems():
573 574 rvalue = record.get(key, None)
574 575 if evalue and rvalue and evalue != rvalue:
575 576 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
576 577 elif evalue and not rvalue:
577 578 record[key] = evalue
578 579 try:
579 580 self.db.update_record(msg_id, record)
580 581 except Exception:
581 582 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
582 583 except KeyError:
583 584 try:
584 585 self.db.add_record(msg_id, record)
585 586 except Exception:
586 587 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
587 588
588 589
589 590 self.pending.add(msg_id)
590 591 self.queues[eid].append(msg_id)
591 592
592 593 def save_queue_result(self, idents, msg):
593 594 if len(idents) < 2:
594 595 self.log.error("invalid identity prefix: %r"%idents)
595 596 return
596 597
597 598 client_id, queue_id = idents[:2]
598 599 try:
599 600 msg = self.session.unpack_message(msg)
600 601 except Exception:
601 602 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
602 603 queue_id,client_id, msg), exc_info=True)
603 604 return
604 605
605 606 eid = self.by_ident.get(queue_id, None)
606 607 if eid is None:
607 608 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
608 609 return
609 610
610 611 parent = msg['parent_header']
611 612 if not parent:
612 613 return
613 614 msg_id = parent['msg_id']
614 615 if msg_id in self.pending:
615 616 self.pending.remove(msg_id)
616 617 self.all_completed.add(msg_id)
617 618 self.queues[eid].remove(msg_id)
618 619 self.completed[eid].append(msg_id)
619 620 elif msg_id not in self.all_completed:
620 621 # it could be a result from a dead engine that died before delivering the
621 622 # result
622 623 self.log.warn("queue:: unknown msg finished %r"%msg_id)
623 624 return
624 625 # update record anyway, because the unregistration could have been premature
625 626 rheader = msg['header']
626 627 completed = rheader['date']
627 628 started = rheader.get('started', None)
628 629 result = {
629 630 'result_header' : rheader,
630 631 'result_content': msg['content'],
631 632 'started' : started,
632 633 'completed' : completed
633 634 }
634 635
635 636 result['result_buffers'] = msg['buffers']
636 637 try:
637 638 self.db.update_record(msg_id, result)
638 639 except Exception:
639 640 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
640 641
641 642
642 643 #--------------------- Task Queue Traffic ------------------------------
643 644
644 645 def save_task_request(self, idents, msg):
645 646 """Save the submission of a task."""
646 647 client_id = idents[0]
647 648
648 649 try:
649 650 msg = self.session.unpack_message(msg)
650 651 except Exception:
651 652 self.log.error("task::client %r sent invalid task message: %r"%(
652 653 client_id, msg), exc_info=True)
653 654 return
654 655 record = init_record(msg)
655 656
656 657 record['client_uuid'] = client_id
657 658 record['queue'] = 'task'
658 659 header = msg['header']
659 660 msg_id = header['msg_id']
660 661 self.pending.add(msg_id)
661 662 self.unassigned.add(msg_id)
662 663 try:
663 664             # it's possible iopub arrived first:
664 665 existing = self.db.get_record(msg_id)
665 666 if existing['resubmitted']:
666 667 for key in ('submitted', 'client_uuid', 'buffers'):
667 668 # don't clobber these keys on resubmit
668 669 # submitted and client_uuid should be different
669 670 # and buffers might be big, and shouldn't have changed
670 671 record.pop(key)
671 672 # still check content,header which should not change
672 673 # but are not expensive to compare as buffers
673 674
674 675 for key,evalue in existing.iteritems():
675 676 if key.endswith('buffers'):
676 677 # don't compare buffers
677 678 continue
678 679 rvalue = record.get(key, None)
679 680 if evalue and rvalue and evalue != rvalue:
680 681 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
681 682 elif evalue and not rvalue:
682 683 record[key] = evalue
683 684 try:
684 685 self.db.update_record(msg_id, record)
685 686 except Exception:
686 687 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
687 688 except KeyError:
688 689 try:
689 690 self.db.add_record(msg_id, record)
690 691 except Exception:
691 692 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
692 693 except Exception:
693 694 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
694 695
695 696 def save_task_result(self, idents, msg):
696 697 """save the result of a completed task."""
697 698 client_id = idents[0]
698 699 try:
699 700 msg = self.session.unpack_message(msg)
700 701 except Exception:
701 702             self.log.error("task::invalid task result message sent to %r: %r"%(
702 703 client_id, msg), exc_info=True)
703 704 return
704 705
705 706 parent = msg['parent_header']
706 707 if not parent:
707 708 # print msg
708 709 self.log.warn("Task %r had no parent!"%msg)
709 710 return
710 711 msg_id = parent['msg_id']
711 712 if msg_id in self.unassigned:
712 713 self.unassigned.remove(msg_id)
713 714
714 715 header = msg['header']
715 716 engine_uuid = header.get('engine', None)
716 717 eid = self.by_ident.get(engine_uuid, None)
717 718
718 719 if msg_id in self.pending:
719 720 self.pending.remove(msg_id)
720 721 self.all_completed.add(msg_id)
721 722 if eid is not None:
722 723 self.completed[eid].append(msg_id)
723 724 if msg_id in self.tasks[eid]:
724 725 self.tasks[eid].remove(msg_id)
725 726 completed = header['date']
726 727 started = header.get('started', None)
727 728 result = {
728 729 'result_header' : header,
729 730 'result_content': msg['content'],
730 731 'started' : started,
731 732 'completed' : completed,
732 733 'engine_uuid': engine_uuid
733 734 }
734 735
735 736 result['result_buffers'] = msg['buffers']
736 737 try:
737 738 self.db.update_record(msg_id, result)
738 739 except Exception:
739 740 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
740 741
741 742 else:
742 743 self.log.debug("task::unknown task %r finished"%msg_id)
743 744
744 745 def save_task_destination(self, idents, msg):
745 746 try:
746 747 msg = self.session.unpack_message(msg, content=True)
747 748 except Exception:
748 749 self.log.error("task::invalid task tracking message", exc_info=True)
749 750 return
750 751 content = msg['content']
751 752 # print (content)
752 753 msg_id = content['msg_id']
753 754 engine_uuid = content['engine_id']
754 eid = self.by_ident[engine_uuid]
755 eid = self.by_ident[util.ensure_bytes(engine_uuid)]
755 756
756 757 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
757 758 if msg_id in self.unassigned:
758 759 self.unassigned.remove(msg_id)
759 760 # else:
760 761 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
761 762
762 763 self.tasks[eid].append(msg_id)
763 764 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
764 765 try:
765 766 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
766 767 except Exception:
767 768 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
768 769
769 770
770 771 def mia_task_request(self, idents, msg):
771 772 raise NotImplementedError
772 773 client_id = idents[0]
773 774 # content = dict(mia=self.mia,status='ok')
774 775 # self.session.send('mia_reply', content=content, idents=client_id)
775 776
776 777
777 778 #--------------------- IOPub Traffic ------------------------------
778 779
779 780 def save_iopub_message(self, topics, msg):
780 781 """save an iopub message into the db"""
781 782 # print (topics)
782 783 try:
783 784 msg = self.session.unpack_message(msg, content=True)
784 785 except Exception:
785 786 self.log.error("iopub::invalid IOPub message", exc_info=True)
786 787 return
787 788
788 789 parent = msg['parent_header']
789 790 if not parent:
790 791 self.log.error("iopub::invalid IOPub message: %r"%msg)
791 792 return
792 793 msg_id = parent['msg_id']
793 794 msg_type = msg['msg_type']
794 795 content = msg['content']
795 796
796 797 # ensure msg_id is in db
797 798 try:
798 799 rec = self.db.get_record(msg_id)
799 800 except KeyError:
800 801 rec = empty_record()
801 802 rec['msg_id'] = msg_id
802 803 self.db.add_record(msg_id, rec)
803 804 # stream
804 805 d = {}
805 806 if msg_type == 'stream':
806 807 name = content['name']
807 808 s = rec[name] or ''
808 809 d[name] = s + content['data']
809 810
810 811 elif msg_type == 'pyerr':
811 812 d['pyerr'] = content
812 813 elif msg_type == 'pyin':
813 814 d['pyin'] = content['code']
814 815 else:
815 816 d[msg_type] = content.get('data', '')
816 817
817 818 try:
818 819 self.db.update_record(msg_id, d)
819 820 except Exception:
820 821 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
821 822
822 823
823 824
824 825 #-------------------------------------------------------------------------
825 826 # Registration requests
826 827 #-------------------------------------------------------------------------
827 828
828 829 def connection_request(self, client_id, msg):
829 830 """Reply with connection addresses for clients."""
830 831 self.log.info("client::client %r connected"%client_id)
831 832 content = dict(status='ok')
832 833 content.update(self.client_info)
833 834 jsonable = {}
834 835 for k,v in self.keytable.iteritems():
835 836 if v not in self.dead_engines:
836 jsonable[str(k)] = v
837 jsonable[str(k)] = v.decode()
837 838 content['engines'] = jsonable
838 839 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
839 840
840 841 def register_engine(self, reg, msg):
841 842 """Register a new engine."""
842 843 content = msg['content']
843 844 try:
844 queue = content['queue']
845 queue = util.ensure_bytes(content['queue'])
845 846 except KeyError:
846 847 self.log.error("registration::queue not specified", exc_info=True)
847 848 return
848 849 heart = content.get('heartbeat', None)
850 if heart:
851 heart = util.ensure_bytes(heart)
849 852 """register a new engine, and create the socket(s) necessary"""
850 853 eid = self._next_id
851 854 # print (eid, queue, reg, heart)
852 855
853 856 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
854 857
855 858 content = dict(id=eid,status='ok')
856 859 content.update(self.engine_info)
857 860 # check if requesting available IDs:
858 861 if queue in self.by_ident:
859 862 try:
860 863 raise KeyError("queue_id %r in use"%queue)
861 864 except:
862 865 content = error.wrap_exception()
863 866 self.log.error("queue_id %r in use"%queue, exc_info=True)
864 867 elif heart in self.hearts: # need to check unique hearts?
865 868 try:
866 869 raise KeyError("heart_id %r in use"%heart)
867 870 except:
868 871 self.log.error("heart_id %r in use"%heart, exc_info=True)
869 872 content = error.wrap_exception()
870 873 else:
871 874 for h, pack in self.incoming_registrations.iteritems():
872 875 if heart == h:
873 876 try:
874 877 raise KeyError("heart_id %r in use"%heart)
875 878 except:
876 879 self.log.error("heart_id %r in use"%heart, exc_info=True)
877 880 content = error.wrap_exception()
878 881 break
879 882 elif queue == pack[1]:
880 883 try:
881 884 raise KeyError("queue_id %r in use"%queue)
882 885 except:
883 886 self.log.error("queue_id %r in use"%queue, exc_info=True)
884 887 content = error.wrap_exception()
885 888 break
886 889
887 890 msg = self.session.send(self.query, "registration_reply",
888 891 content=content,
889 892 ident=reg)
890 893
891 894 if content['status'] == 'ok':
892 895 if heart in self.heartmonitor.hearts:
893 896 # already beating
894 897 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
895 898 self.finish_registration(heart)
896 899 else:
897 900 purge = lambda : self._purge_stalled_registration(heart)
898 901 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
899 902 dc.start()
900 903 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
901 904 else:
902 905 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
903 906 return eid
904 907
905 908 def unregister_engine(self, ident, msg):
906 909 """Unregister an engine that explicitly requested to leave."""
907 910 try:
908 911 eid = msg['content']['id']
909 912 except:
910 913 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
911 914 return
912 915 self.log.info("registration::unregister_engine(%r)"%eid)
913 916 # print (eid)
914 917 uuid = self.keytable[eid]
915 content=dict(id=eid, queue=uuid)
918 content=dict(id=eid, queue=uuid.decode())
916 919 self.dead_engines.add(uuid)
917 920 # self.ids.remove(eid)
918 921 # uuid = self.keytable.pop(eid)
919 922 #
920 923 # ec = self.engines.pop(eid)
921 924 # self.hearts.pop(ec.heartbeat)
922 925 # self.by_ident.pop(ec.queue)
923 926 # self.completed.pop(eid)
924 927 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
925 928 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
926 929 dc.start()
927 930 ############## TODO: HANDLE IT ################
928 931
929 932 if self.notifier:
930 933 self.session.send(self.notifier, "unregistration_notification", content=content)
931 934
932 935 def _handle_stranded_msgs(self, eid, uuid):
933 936 """Handle messages known to be on an engine when the engine unregisters.
934 937
935 938 It is possible that this will fire prematurely - that is, an engine will
936 939 go down after completing a result, and the client will be notified
937 940 that the result failed and later receive the actual result.
938 941 """
939 942
940 943 outstanding = self.queues[eid]
941 944
942 945 for msg_id in outstanding:
943 946 self.pending.remove(msg_id)
944 947 self.all_completed.add(msg_id)
945 948 try:
946 949 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
947 950 except:
948 951 content = error.wrap_exception()
949 952 # build a fake header:
950 953 header = {}
951 954 header['engine'] = uuid
952 955 header['date'] = datetime.now()
953 956 rec = dict(result_content=content, result_header=header, result_buffers=[])
954 957 rec['completed'] = header['date']
955 958 rec['engine_uuid'] = uuid
956 959 try:
957 960 self.db.update_record(msg_id, rec)
958 961 except Exception:
959 962 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
960 963
961 964
962 965 def finish_registration(self, heart):
963 966 """Second half of engine registration, called after our HeartMonitor
964 967 has received a beat from the Engine's Heart."""
965 968 try:
966 969 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
967 970 except KeyError:
968 971 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
969 972 return
970 973 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
971 974 if purge is not None:
972 975 purge.stop()
973 976 control = queue
974 977 self.ids.add(eid)
975 978 self.keytable[eid] = queue
976 979 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
977 980 control=control, heartbeat=heart)
978 981 self.by_ident[queue] = eid
979 982 self.queues[eid] = list()
980 983 self.tasks[eid] = list()
981 984 self.completed[eid] = list()
982 985 self.hearts[heart] = eid
983 content = dict(id=eid, queue=self.engines[eid].queue)
986 content = dict(id=eid, queue=self.engines[eid].queue.decode())
984 987 if self.notifier:
985 988 self.session.send(self.notifier, "registration_notification", content=content)
986 989 self.log.info("engine::Engine Connected: %i"%eid)
987 990
988 991 def _purge_stalled_registration(self, heart):
989 992 if heart in self.incoming_registrations:
990 993 eid = self.incoming_registrations.pop(heart)[0]
991 994 self.log.info("registration::purging stalled registration: %i"%eid)
992 995 else:
993 996 pass
994 997
995 998 #-------------------------------------------------------------------------
996 999 # Client Requests
997 1000 #-------------------------------------------------------------------------
998 1001
999 1002 def shutdown_request(self, client_id, msg):
1000 1003 """handle shutdown request."""
1001 1004 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1002 1005 # also notify other clients of shutdown
1003 1006 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1004 1007 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1005 1008 dc.start()
1006 1009
1007 1010 def _shutdown(self):
1008 1011 self.log.info("hub::hub shutting down.")
1009 1012 time.sleep(0.1)
1010 1013 sys.exit(0)
1011 1014
1012 1015
1013 1016 def check_load(self, client_id, msg):
1014 1017 content = msg['content']
1015 1018 try:
1016 1019 targets = content['targets']
1017 1020 targets = self._validate_targets(targets)
1018 1021 except:
1019 1022 content = error.wrap_exception()
1020 1023 self.session.send(self.query, "hub_error",
1021 1024 content=content, ident=client_id)
1022 1025 return
1023 1026
1024 1027 content = dict(status='ok')
1025 1028 # loads = {}
1026 1029 for t in targets:
1027 1030 content[str(t)] = len(self.queues[t])+len(self.tasks[t])
1028 1031 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1029 1032
1030 1033
1031 1034 def queue_status(self, client_id, msg):
1032 1035 """Return the Queue status of one or more targets.
1033 1036 if verbose: return the msg_ids
1034 1037 else: return len of each type.
1035 1038 keys: queue (pending MUX jobs)
1036 1039 tasks (pending Task jobs)
1037 1040 completed (finished jobs from both queues)"""
1038 1041 content = msg['content']
1039 1042 targets = content['targets']
1040 1043 try:
1041 1044 targets = self._validate_targets(targets)
1042 1045 except:
1043 1046 content = error.wrap_exception()
1044 1047 self.session.send(self.query, "hub_error",
1045 1048 content=content, ident=client_id)
1046 1049 return
1047 1050 verbose = content.get('verbose', False)
1048 1051 content = dict(status='ok')
1049 1052 for t in targets:
1050 1053 queue = self.queues[t]
1051 1054 completed = self.completed[t]
1052 1055 tasks = self.tasks[t]
1053 1056 if not verbose:
1054 1057 queue = len(queue)
1055 1058 completed = len(completed)
1056 1059 tasks = len(tasks)
1057 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1060 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1058 1061 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1059
1062 # print (content)
1060 1063 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1061 1064
1062 1065 def purge_results(self, client_id, msg):
1063 1066 """Purge results from memory. This method is more valuable before we move
1064 1067 to a DB based message storage mechanism."""
1065 1068 content = msg['content']
1066 1069 self.log.info("Dropping records with %s", content)
1067 1070 msg_ids = content.get('msg_ids', [])
1068 1071 reply = dict(status='ok')
1069 1072 if msg_ids == 'all':
1070 1073 try:
1071 1074 self.db.drop_matching_records(dict(completed={'$ne':None}))
1072 1075 except Exception:
1073 1076 reply = error.wrap_exception()
1074 1077 else:
1075 1078 pending = [ m for m in msg_ids if m in self.pending ]
1076 1079 if pending:
1077 1080 try:
1078 1081 raise IndexError("msg pending: %r"%pending[0])
1079 1082 except:
1080 1083 reply = error.wrap_exception()
1081 1084 else:
1082 1085 try:
1083 1086 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1084 1087 except Exception:
1085 1088 reply = error.wrap_exception()
1086 1089
1087 1090 if reply['status'] == 'ok':
1088 1091 eids = content.get('engine_ids', [])
1089 1092 for eid in eids:
1090 1093 if eid not in self.engines:
1091 1094 try:
1092 1095 raise IndexError("No such engine: %i"%eid)
1093 1096 except:
1094 1097 reply = error.wrap_exception()
1095 1098 break
1096 1099 uid = self.engines[eid].queue
1097 1100 try:
1098 1101 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1099 1102 except Exception:
1100 1103 reply = error.wrap_exception()
1101 1104 break
1102 1105
1103 1106 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1104 1107
1105 1108 def resubmit_task(self, client_id, msg):
1106 1109 """Resubmit one or more tasks."""
1107 1110 def finish(reply):
1108 1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1109 1112
1110 1113 content = msg['content']
1111 1114 msg_ids = content['msg_ids']
1112 1115 reply = dict(status='ok')
1113 1116 try:
1114 1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1115 1118 'header', 'content', 'buffers'])
1116 1119 except Exception:
1117 1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1118 1121 return finish(error.wrap_exception())
1119 1122
1120 1123 # validate msg_ids
1121 1124 found_ids = [ rec['msg_id'] for rec in records ]
1122 1125 invalid_ids = [ m for m in found_ids if m in self.pending ]
1123 1126 if len(records) > len(msg_ids):
1124 1127 try:
1125 1128 raise RuntimeError("DB appears to be in an inconsistent state. "
1126 1129 "More matching records were found than should exist")
1127 1130 except Exception:
1128 1131 return finish(error.wrap_exception())
1129 1132 elif len(records) < len(msg_ids):
1130 1133 missing = [ m for m in msg_ids if m not in found_ids ]
1131 1134 try:
1132 1135 raise KeyError("No such msg(s): %r"%missing)
1133 1136 except KeyError:
1134 1137 return finish(error.wrap_exception())
1135 1138 elif invalid_ids:
1136 1139 msg_id = invalid_ids[0]
1137 1140 try:
1138 1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1139 1142 except Exception:
1140 1143 return finish(error.wrap_exception())
1141 1144
1142 1145 # clear the existing records
1143 1146 now = datetime.now()
1144 1147 rec = empty_record()
1145 1148 for key in ('msg_id', 'header', 'content', 'buffers', 'submitted'): rec.pop(key)
1146 1149 rec['resubmitted'] = now
1147 1150 rec['queue'] = 'task'
1148 1151 rec['client_uuid'] = client_id[0]
1149 1152 try:
1150 1153 for msg_id in msg_ids:
1151 1154 self.all_completed.discard(msg_id)
1152 1155 self.db.update_record(msg_id, rec)
1153 1156 except Exception:
1154 1157 self.log.error('db::db error updating record', exc_info=True)
1155 1158 reply = error.wrap_exception()
1156 1159 else:
1157 1160 # send the messages
1158 1161 for rec in records:
1159 1162 header = rec['header']
1160 1163 # include resubmitted in header to prevent digest collision
1161 1164 header['resubmitted'] = now
1162 1165 msg = self.session.msg(header['msg_type'])
1163 1166 msg['content'] = rec['content']
1164 1167 msg['header'] = header
1165 1168 msg['msg_id'] = rec['msg_id']
1166 1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1167 1170
1168 1171 finish(dict(status='ok'))
1169 1172
1170 1173
1171 1174 def _extract_record(self, rec):
1172 1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1173 1176 io_dict = {}
1174 1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1175 1178 io_dict[key] = rec[key]
1176 1179 content = { 'result_content': rec['result_content'],
1177 1180 'header': rec['header'],
1178 1181 'result_header' : rec['result_header'],
1179 1182 'io' : io_dict,
1180 1183 }
1181 1184 if rec['result_buffers']:
1182 buffers = map(str, rec['result_buffers'])
1185 buffers = map(bytes, rec['result_buffers'])
1183 1186 else:
1184 1187 buffers = []
1185 1188
1186 1189 return content, buffers
1187 1190
1188 1191 def get_results(self, client_id, msg):
1189 1192 """Get the result of 1 or more messages."""
1190 1193 content = msg['content']
1191 1194 msg_ids = sorted(set(content['msg_ids']))
1192 1195 statusonly = content.get('status_only', False)
1193 1196 pending = []
1194 1197 completed = []
1195 1198 content = dict(status='ok')
1196 1199 content['pending'] = pending
1197 1200 content['completed'] = completed
1198 1201 buffers = []
1199 1202 if not statusonly:
1200 1203 try:
1201 1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1202 1205 # turn match list into dict, for faster lookup
1203 1206 records = {}
1204 1207 for rec in matches:
1205 1208 records[rec['msg_id']] = rec
1206 1209 except Exception:
1207 1210 content = error.wrap_exception()
1208 1211 self.session.send(self.query, "result_reply", content=content,
1209 1212 parent=msg, ident=client_id)
1210 1213 return
1211 1214 else:
1212 1215 records = {}
1213 1216 for msg_id in msg_ids:
1214 1217 if msg_id in self.pending:
1215 1218 pending.append(msg_id)
1216 1219 elif msg_id in self.all_completed:
1217 1220 completed.append(msg_id)
1218 1221 if not statusonly:
1219 1222 c,bufs = self._extract_record(records[msg_id])
1220 1223 content[msg_id] = c
1221 1224 buffers.extend(bufs)
1222 1225 elif msg_id in records:
1223 1226 if records[msg_id]['completed']:
1224 1227 completed.append(msg_id)
1225 1228 c,bufs = self._extract_record(records[msg_id])
1226 1229 content[msg_id] = c
1227 1230 buffers.extend(bufs)
1228 1231 else:
1229 1232 pending.append(msg_id)
1230 1233 else:
1231 1234 try:
1232 1235 raise KeyError('No such message: '+msg_id)
1233 1236 except:
1234 1237 content = error.wrap_exception()
1235 1238 break
1236 1239 self.session.send(self.query, "result_reply", content=content,
1237 1240 parent=msg, ident=client_id,
1238 1241 buffers=buffers)
1239 1242
1240 1243 def get_history(self, client_id, msg):
1241 1244 """Get a list of all msg_ids in our DB records"""
1242 1245 try:
1243 1246 msg_ids = self.db.get_history()
1244 1247 except Exception as e:
1245 1248 content = error.wrap_exception()
1246 1249 else:
1247 1250 content = dict(status='ok', history=msg_ids)
1248 1251
1249 1252 self.session.send(self.query, "history_reply", content=content,
1250 1253 parent=msg, ident=client_id)
1251 1254
1252 1255 def db_query(self, client_id, msg):
1253 1256 """Perform a raw query on the task record database."""
1254 1257 content = msg['content']
1255 1258 query = content.get('query', {})
1256 1259 keys = content.get('keys', None)
1257 1260 buffers = []
1258 1261 empty = list()
1259 1262 try:
1260 1263 records = self.db.find_records(query, keys)
1261 1264 except Exception as e:
1262 1265 content = error.wrap_exception()
1263 1266 else:
1264 1267 # extract buffers from reply content:
1265 1268 if keys is not None:
1266 1269 buffer_lens = [] if 'buffers' in keys else None
1267 1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1268 1271 else:
1269 1272 buffer_lens = []
1270 1273 result_buffer_lens = []
1271 1274
1272 1275 for rec in records:
1273 1276 # buffers may be None, so double check
1274 1277 if buffer_lens is not None:
1275 1278 b = rec.pop('buffers', empty) or empty
1276 1279 buffer_lens.append(len(b))
1277 1280 buffers.extend(b)
1278 1281 if result_buffer_lens is not None:
1279 1282 rb = rec.pop('result_buffers', empty) or empty
1280 1283 result_buffer_lens.append(len(rb))
1281 1284 buffers.extend(rb)
1282 1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1283 1286 result_buffer_lens=result_buffer_lens)
1284
1287 # self.log.debug (content)
1285 1288 self.session.send(self.query, "db_reply", content=content,
1286 1289 parent=msg, ident=client_id,
1287 1290 buffers=buffers)
1288 1291
@@ -1,705 +1,714 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6
7 7 Authors:
8 8
9 9 * Min RK
10 10 """
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2010-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #----------------------------------------------------------------------
19 19 # Imports
20 20 #----------------------------------------------------------------------
21 21
22 22 from __future__ import print_function
23 23
24 24 import logging
25 25 import sys
26 26
27 27 from datetime import datetime, timedelta
28 28 from random import randint, random
29 29 from types import FunctionType
30 30
31 31 try:
32 32 import numpy
33 33 except ImportError:
34 34 numpy = None
35 35
36 36 import zmq
37 37 from zmq.eventloop import ioloop, zmqstream
38 38
39 39 # local imports
40 40 from IPython.external.decorator import decorator
41 41 from IPython.config.application import Application
42 42 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Int, Enum, CBytes
44 44
45 45 from IPython.parallel import error
46 46 from IPython.parallel.factory import SessionFactory
47 from IPython.parallel.util import connect_logger, local_logger
47 from IPython.parallel.util import connect_logger, local_logger, ensure_bytes
48 48
49 49 from .dependency import Dependency
50 50
51 51 @decorator
52 52 def logged(f,self,*args,**kwargs):
53 53 # print ("#--------------------")
54 54 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
55 55 # print ("#--")
56 56 return f(self,*args, **kwargs)
57 57
58 58 #----------------------------------------------------------------------
59 59 # Chooser functions
60 60 #----------------------------------------------------------------------
61 61
62 62 def plainrandom(loads):
63 63 """Plain random pick."""
64 64 n = len(loads)
65 65 return randint(0,n-1)
66 66
67 67 def lru(loads):
68 68 """Always pick the front of the line.
69 69
70 70 The content of `loads` is ignored.
71 71
72 72 Assumes LRU ordering of loads, with oldest first.
73 73 """
74 74 return 0
75 75
76 76 def twobin(loads):
77 77 """Pick two at random, use the LRU of the two.
78 78
79 79 The content of loads is ignored.
80 80
81 81 Assumes LRU ordering of loads, with oldest first.
82 82 """
83 83 n = len(loads)
84 84 a = randint(0,n-1)
85 85 b = randint(0,n-1)
86 86 return min(a,b)
87 87
88 88 def weighted(loads):
89 89 """Pick two at random using inverse load as weight.
90 90
91 91 Return the less loaded of the two.
92 92 """
93 93 # weight 0 a million times more than 1:
94 94 weights = 1./(1e-6+numpy.array(loads))
95 95 sums = weights.cumsum()
96 96 t = sums[-1]
97 97 x = random()*t
98 98 y = random()*t
99 99 idx = 0
100 100 idy = 0
101 101 while sums[idx] < x:
102 102 idx += 1
103 103 while sums[idy] < y:
104 104 idy += 1
105 105 if weights[idy] > weights[idx]:
106 106 return idy
107 107 else:
108 108 return idx
109 109
110 110 def leastload(loads):
111 111 """Always choose the lowest load.
112 112
113 113 If the lowest load occurs more than once, the first
114 114 occurrence will be used. If loads has LRU ordering, this means
115 115 the LRU of those with the lowest load is chosen.
116 116 """
117 117 return loads.index(min(loads))
118 118
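All of the chooser functions above share one signature: they take a list of engine loads (LRU-ordered, oldest first) and return the index of the engine to use. A quick illustrative session, with made-up loads (numpy is required only for weighted):

    loads = [3, 0, 5, 1]   # hypothetical per-engine loads
    leastload(loads)       # -> 1: index of the smallest load
    lru(loads)             # -> 0: always the front of the line
    twobin(loads)          # the less loaded of two random picks
    weighted(loads)        # random, biased toward lightly loaded engines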
119 119 #---------------------------------------------------------------------
120 120 # Classes
121 121 #---------------------------------------------------------------------
122 122 # store empty default dependency:
123 123 MET = Dependency([])
124 124
125 125 class TaskScheduler(SessionFactory):
126 126 """Python TaskScheduler object.
127 127
128 128 This is the simplest object that supports msg_id based
129 129 DAG dependencies. *Only* task msg_ids are checked, not
130 130 msg_ids of jobs submitted via the MUX queue.
131 131
132 132 """
133 133
134 134 hwm = Int(0, config=True, shortname='hwm',
135 135 help="""specify the High Water Mark (HWM) for the downstream
136 136 socket in the Task scheduler. This is the maximum number
137 137 of allowed outstanding tasks on each engine."""
138 138 )
139 139 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
140 140 'leastload', config=True, shortname='scheme', allow_none=False,
141 141 help="""select the task scheduler scheme [default: leastload]
142 142 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
143 143 )
144 144 def _scheme_name_changed(self, old, new):
145 145 self.log.debug("Using scheme %r"%new)
146 146 self.scheme = globals()[new]
147 147
148 148 # input arguments:
149 149 scheme = Instance(FunctionType) # function for determining the destination
150 150 def _scheme_default(self):
151 151 return leastload
152 152 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
153 153 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
154 154 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
155 155 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
156 156
157 157 # internals:
158 158 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
159 159 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
160 160 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
161 161 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
162 162 pending = Dict() # dict by engine_uuid of submitted tasks
163 163 completed = Dict() # dict by engine_uuid of completed tasks
164 164 failed = Dict() # dict by engine_uuid of failed tasks
165 165 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
166 166 clients = Dict() # dict by msg_id for who submitted the task
167 167 targets = List() # list of target IDENTs
168 168 loads = List() # list of engine loads
169 169 # full = Set() # set of IDENTs that have HWM outstanding tasks
170 170 all_completed = Set() # set of all completed tasks
171 171 all_failed = Set() # set of all failed tasks
172 172 all_done = Set() # set of all finished tasks=union(completed,failed)
173 173 all_ids = Set() # set of all submitted task IDs
174 174 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
175 175 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
176 176
177 ident = CBytes() # ZMQ identity. This should just be self.session.session,
178 # but ensured to be bytes
179 def _ident_default(self):
180 return ensure_bytes(self.session.session)
177 181
178 182 def start(self):
179 183 self.engine_stream.on_recv(self.dispatch_result, copy=False)
180 184 self._notification_handlers = dict(
181 185 registration_notification = self._register_engine,
182 186 unregistration_notification = self._unregister_engine
183 187 )
184 188 self.notifier_stream.on_recv(self.dispatch_notification)
185 189 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
186 190 self.auditor.start()
187 191 self.log.info("Scheduler started [%s]"%self.scheme_name)
188 192
189 193 def resume_receiving(self):
190 194 """Resume accepting jobs."""
191 195 self.client_stream.on_recv(self.dispatch_submission, copy=False)
192 196
193 197 def stop_receiving(self):
194 198 """Stop accepting jobs while there are no engines.
195 199 Leave them in the ZMQ queue."""
196 200 self.client_stream.on_recv(None)
197 201
198 202 #-----------------------------------------------------------------------
199 203 # [Un]Registration Handling
200 204 #-----------------------------------------------------------------------
201 205
202 206 def dispatch_notification(self, msg):
203 207 """dispatch register/unregister events."""
204 208 try:
205 209 idents,msg = self.session.feed_identities(msg)
206 210 except ValueError:
207 self.log.warn("task::Invalid Message: %r"%msg)
211 self.log.warn("task::Invalid Message: %r",msg)
208 212 return
209 213 try:
210 214 msg = self.session.unpack_message(msg)
211 215 except ValueError:
212 216 self.log.warn("task::Unauthorized message from: %r"%idents)
213 217 return
214 218
215 219 msg_type = msg['msg_type']
216 220
217 221 handler = self._notification_handlers.get(msg_type, None)
218 222 if handler is None:
219 223 self.log.error("Unhandled message type: %r"%msg_type)
220 224 else:
221 225 try:
222 handler(str(msg['content']['queue']))
223 except KeyError:
224 self.log.error("task::Invalid notification msg: %r"%msg)
226 handler(ensure_bytes(msg['content']['queue']))
227 except Exception:
228 self.log.error("task::Invalid notification msg: %r",msg)
225 229
226 230 def _register_engine(self, uid):
227 231 """New engine with ident `uid` became available."""
228 232 # head of the line:
229 233 self.targets.insert(0,uid)
230 234 self.loads.insert(0,0)
235
231 236 # initialize sets
232 237 self.completed[uid] = set()
233 238 self.failed[uid] = set()
234 239 self.pending[uid] = {}
235 240 if len(self.targets) == 1:
236 241 self.resume_receiving()
237 242 # rescan the graph:
238 243 self.update_graph(None)
239 244
240 245 def _unregister_engine(self, uid):
241 246 """Existing engine with ident `uid` became unavailable."""
242 247 if len(self.targets) == 1:
243 248 # this was our only engine
244 249 self.stop_receiving()
245 250
246 251 # handle any potentially finished tasks:
247 252 self.engine_stream.flush()
248 253
249 254 # don't pop destinations, because they might be used later
250 255 # map(self.destinations.pop, self.completed.pop(uid))
251 256 # map(self.destinations.pop, self.failed.pop(uid))
252 257
253 258 # prevent this engine from receiving work
254 259 idx = self.targets.index(uid)
255 260 self.targets.pop(idx)
256 261 self.loads.pop(idx)
257 262
258 263 # wait 5 seconds before cleaning up pending jobs, since the results might
259 264 # still be incoming
260 265 if self.pending[uid]:
261 266 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
262 267 dc.start()
263 268 else:
264 269 self.completed.pop(uid)
265 270 self.failed.pop(uid)
266 271
267 272
268 273 def handle_stranded_tasks(self, engine):
269 274 """Deal with jobs resident in an engine that died."""
270 275 lost = self.pending[engine]
271 276 for msg_id in list(lost.keys()):
272 277 if msg_id not in self.pending[engine]:
273 278 # prevent double-handling of messages
274 279 continue
275 280
276 281 raw_msg = lost[msg_id][0]
277 282 idents,msg = self.session.feed_identities(raw_msg, copy=False)
278 283 parent = self.session.unpack(msg[1].bytes)
279 284 idents = [engine, idents[0]]
280 285
281 286 # build fake error reply
282 287 try:
283 288 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
284 289 except:
285 290 content = error.wrap_exception()
286 291 msg = self.session.msg('apply_reply', content, parent=parent, subheader={'status':'error'})
287 292 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
288 293 # and dispatch it
289 294 self.dispatch_result(raw_reply)
290 295
291 296 # finally scrub completed/failed lists
292 297 self.completed.pop(engine)
293 298 self.failed.pop(engine)
294 299
295 300
296 301 #-----------------------------------------------------------------------
297 302 # Job Submission
298 303 #-----------------------------------------------------------------------
299 304 def dispatch_submission(self, raw_msg):
300 305 """Dispatch job submission to appropriate handlers."""
301 306 # ensure targets up to date:
302 307 self.notifier_stream.flush()
303 308 try:
304 309 idents, msg = self.session.feed_identities(raw_msg, copy=False)
305 310 msg = self.session.unpack_message(msg, content=False, copy=False)
306 311 except Exception:
307 312 self.log.error("task::Invalid task msg: %r"%raw_msg, exc_info=True)
308 313 return
309 314
310 315
311 316 # send to monitor
312 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
317 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
313 318
314 319 header = msg['header']
315 320 msg_id = header['msg_id']
316 321 self.all_ids.add(msg_id)
317 322
318 # targets
319 targets = set(header.get('targets', []))
323 # get targets as a set of bytes objects
324 # from a list of unicode objects
325 targets = header.get('targets', [])
326 targets = map(ensure_bytes, targets)
327 targets = set(targets)
328
320 329 retries = header.get('retries', 0)
321 330 self.retries[msg_id] = retries
322 331
323 332 # time dependencies
324 333 after = header.get('after', None)
325 334 if after:
326 335 after = Dependency(after)
327 336 if after.all:
328 337 if after.success:
329 338 after = Dependency(after.difference(self.all_completed),
330 339 success=after.success,
331 340 failure=after.failure,
332 341 all=after.all,
333 342 )
334 343 if after.failure:
335 344 after = Dependency(after.difference(self.all_failed),
336 345 success=after.success,
337 346 failure=after.failure,
338 347 all=after.all,
339 348 )
340 349 if after.check(self.all_completed, self.all_failed):
341 350 # recast as empty set, if `after` already met,
342 351 # to prevent unnecessary set comparisons
343 352 after = MET
344 353 else:
345 354 after = MET
346 355
347 356 # location dependencies
348 357 follow = Dependency(header.get('follow', []))
349 358
350 359 # turn timeouts into datetime objects:
351 360 timeout = header.get('timeout', None)
352 361 if timeout:
353 362 timeout = datetime.now() + timedelta(0,timeout,0)
354 363
355 364 args = [raw_msg, targets, after, follow, timeout]
356 365
357 366 # validate and reduce dependencies:
358 367 for dep in after,follow:
359 368 if not dep: # empty dependency
360 369 continue
361 370 # check valid:
362 371 if msg_id in dep or dep.difference(self.all_ids):
363 372 self.depending[msg_id] = args
364 373 return self.fail_unreachable(msg_id, error.InvalidDependency)
365 374 # check if unreachable:
366 375 if dep.unreachable(self.all_completed, self.all_failed):
367 376 self.depending[msg_id] = args
368 377 return self.fail_unreachable(msg_id)
369 378
370 379 if after.check(self.all_completed, self.all_failed):
371 380 # time deps already met, try to run
372 381 if not self.maybe_run(msg_id, *args):
373 382 # can't run yet
374 383 if msg_id not in self.all_failed:
375 384 # could have failed as unreachable
376 385 self.save_unmet(msg_id, *args)
377 386 else:
378 387 self.save_unmet(msg_id, *args)
379 388
380 389 def audit_timeouts(self):
381 390 """Audit all waiting tasks for expired timeouts."""
382 391 now = datetime.now()
383 392 for msg_id in list(self.depending.keys()):
384 393 # must recheck, in case one failure cascaded to another:
385 394 if msg_id in self.depending:
386 395 raw,after,targets,follow,timeout = self.depending[msg_id]
387 396 if timeout and timeout < now:
388 397 self.fail_unreachable(msg_id, error.TaskTimeout)
389 398
390 399 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
391 400 """a task has become unreachable, send a reply with an ImpossibleDependency
392 401 error."""
393 402 if msg_id not in self.depending:
394 403 self.log.error("msg %r already failed!", msg_id)
395 404 return
396 405 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
397 406 for mid in follow.union(after):
398 407 if mid in self.graph:
399 408 self.graph[mid].remove(msg_id)
400 409
401 410 # FIXME: unpacking a message I've already unpacked, but didn't save:
402 411 idents,msg = self.session.feed_identities(raw_msg, copy=False)
403 412 header = self.session.unpack(msg[1].bytes)
404 413
405 414 try:
406 415 raise why()
407 416 except:
408 417 content = error.wrap_exception()
409 418
410 419 self.all_done.add(msg_id)
411 420 self.all_failed.add(msg_id)
412 421
413 422 msg = self.session.send(self.client_stream, 'apply_reply', content,
414 423 parent=header, ident=idents)
415 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
424 self.session.send(self.mon_stream, msg, ident=[b'outtask']+idents)
416 425
417 426 self.update_graph(msg_id, success=False)
418 427
419 428 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
420 429 """check location dependencies, and run if they are met."""
421 430 blacklist = self.blacklist.setdefault(msg_id, set())
422 431 if follow or targets or blacklist or self.hwm:
423 432 # we need a can_run filter
424 433 def can_run(idx):
425 434 # check hwm
426 435 if self.hwm and self.loads[idx] == self.hwm:
427 436 return False
428 437 target = self.targets[idx]
429 438 # check blacklist
430 439 if target in blacklist:
431 440 return False
432 441 # check targets
433 442 if targets and target not in targets:
434 443 return False
435 444 # check follow
436 445 return follow.check(self.completed[target], self.failed[target])
437 446
438 447 indices = [ i for i in range(len(self.targets)) if can_run(i) ]
439 448
440 449 if not indices:
441 450 # couldn't run
442 451 if follow.all:
443 452 # check follow for impossibility
444 453 dests = set()
445 454 relevant = set()
446 455 if follow.success:
447 456 relevant = self.all_completed
448 457 if follow.failure:
449 458 relevant = relevant.union(self.all_failed)
450 459 for m in follow.intersection(relevant):
451 460 dests.add(self.destinations[m])
452 461 if len(dests) > 1:
453 462 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
454 463 self.fail_unreachable(msg_id)
455 464 return False
456 465 if targets:
457 466 # check blacklist+targets for impossibility
458 467 targets.difference_update(blacklist)
459 468 if not targets or not targets.intersection(self.targets):
460 469 self.depending[msg_id] = (raw_msg, targets, after, follow, timeout)
461 470 self.fail_unreachable(msg_id)
462 471 return False
463 472 return False
464 473 else:
465 474 indices = None
466 475
467 476 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
468 477 return True
469 478
470 479 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
471 480 """Save a message for later submission when its dependencies are met."""
472 481 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
473 482 # track the ids in follow or after, but not those already finished
474 483 for dep_id in after.union(follow).difference(self.all_done):
475 484 if dep_id not in self.graph:
476 485 self.graph[dep_id] = set()
477 486 self.graph[dep_id].add(msg_id)
478 487
479 488 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
480 489 """Submit a task to any of a subset of our targets."""
481 490 if indices:
482 491 loads = [self.loads[i] for i in indices]
483 492 else:
484 493 loads = self.loads
485 494 idx = self.scheme(loads)
486 495 if indices:
487 496 idx = indices[idx]
488 497 target = self.targets[idx]
489 498 # print (target, map(str, msg[:3]))
490 499 # send job to the engine
491 500 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
492 501 self.engine_stream.send_multipart(raw_msg, copy=False)
493 502 # update load
494 503 self.add_job(idx)
495 504 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
496 505 # notify Hub
497 content = dict(msg_id=msg_id, engine_id=target)
506 content = dict(msg_id=msg_id, engine_id=target.decode())
498 507 self.session.send(self.mon_stream, 'task_destination', content=content,
499 ident=['tracktask',self.session.session])
508 ident=[b'tracktask',self.ident])
500 509
501 510
502 511 #-----------------------------------------------------------------------
503 512 # Result Handling
504 513 #-----------------------------------------------------------------------
505 514 def dispatch_result(self, raw_msg):
506 515 """dispatch method for result replies"""
507 516 try:
508 517 idents,msg = self.session.feed_identities(raw_msg, copy=False)
509 518 msg = self.session.unpack_message(msg, content=False, copy=False)
510 519 engine = idents[0]
511 520 try:
512 521 idx = self.targets.index(engine)
513 522 except ValueError:
514 523 pass # skip load-update for dead engines
515 524 else:
516 525 self.finish_job(idx)
517 526 except Exception:
518 527 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
519 528 return
520 529
521 530 header = msg['header']
522 531 parent = msg['parent_header']
523 532 if header.get('dependencies_met', True):
524 533 success = (header['status'] == 'ok')
525 534 msg_id = parent['msg_id']
526 535 retries = self.retries[msg_id]
527 536 if not success and retries > 0:
528 537 # failed
529 538 self.retries[msg_id] = retries - 1
530 539 self.handle_unmet_dependency(idents, parent)
531 540 else:
532 541 del self.retries[msg_id]
533 542 # relay to client and update graph
534 543 self.handle_result(idents, parent, raw_msg, success)
535 544 # send to Hub monitor
536 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
545 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
537 546 else:
538 547 self.handle_unmet_dependency(idents, parent)
539 548
540 549 def handle_result(self, idents, parent, raw_msg, success=True):
541 550 """handle a real task result, either success or failure"""
542 551 # first, relay result to client
543 552 engine = idents[0]
544 553 client = idents[1]
545 554 # swap_ids for XREP-XREP mirror
546 555 raw_msg[:2] = [client,engine]
547 556 # print (map(str, raw_msg[:4]))
548 557 self.client_stream.send_multipart(raw_msg, copy=False)
549 558 # now, update our data structures
550 559 msg_id = parent['msg_id']
551 560 self.blacklist.pop(msg_id, None)
552 561 self.pending[engine].pop(msg_id)
553 562 if success:
554 563 self.completed[engine].add(msg_id)
555 564 self.all_completed.add(msg_id)
556 565 else:
557 566 self.failed[engine].add(msg_id)
558 567 self.all_failed.add(msg_id)
559 568 self.all_done.add(msg_id)
560 569 self.destinations[msg_id] = engine
561 570
562 571 self.update_graph(msg_id, success)
563 572
564 573 def handle_unmet_dependency(self, idents, parent):
565 574 """handle an unmet dependency"""
566 575 engine = idents[0]
567 576 msg_id = parent['msg_id']
568 577
569 578 if msg_id not in self.blacklist:
570 579 self.blacklist[msg_id] = set()
571 580 self.blacklist[msg_id].add(engine)
572 581
573 582 args = self.pending[engine].pop(msg_id)
574 583 raw,targets,after,follow,timeout = args
575 584
576 585 if self.blacklist[msg_id] == targets:
577 586 self.depending[msg_id] = args
578 587 self.fail_unreachable(msg_id)
579 588 elif not self.maybe_run(msg_id, *args):
580 589 # resubmit failed
581 590 if msg_id not in self.all_failed:
582 591 # put it back in our dependency tree
583 592 self.save_unmet(msg_id, *args)
584 593
585 594 if self.hwm:
586 595 try:
587 596 idx = self.targets.index(engine)
588 597 except ValueError:
589 598 pass # skip load-update for dead engines
590 599 else:
591 600 if self.loads[idx] == self.hwm-1:
592 601 self.update_graph(None)
593 602
594 603
595 604
596 605 def update_graph(self, dep_id=None, success=True):
597 606 """dep_id just finished. Update our dependency
598 607 graph and submit any jobs that just became runnable.
599 608
600 609 Called with dep_id=None to update entire graph for hwm, but without finishing
601 610 a task.
602 611 """
603 612 # print ("\n\n***********")
604 613 # pprint (dep_id)
605 614 # pprint (self.graph)
606 615 # pprint (self.depending)
607 616 # pprint (self.all_completed)
608 617 # pprint (self.all_failed)
609 618 # print ("\n\n***********\n\n")
610 619 # update any jobs that depended on the dependency
611 620 jobs = self.graph.pop(dep_id, [])
612 621
613 622 # recheck *all* jobs if
614 623 # a) we have HWM and an engine just become no longer full
615 624 # or b) dep_id was given as None
616 625 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
617 626 jobs = list(self.depending.keys())
618 627
619 628 for msg_id in jobs:
620 629 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
621 630
622 631 if after.unreachable(self.all_completed, self.all_failed)\
623 632 or follow.unreachable(self.all_completed, self.all_failed):
624 633 self.fail_unreachable(msg_id)
625 634
626 635 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
627 636 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
628 637
629 638 self.depending.pop(msg_id)
630 639 for mid in follow.union(after):
631 640 if mid in self.graph:
632 641 self.graph[mid].remove(msg_id)
633 642
634 643 #----------------------------------------------------------------------
635 644 # methods to be overridden by subclasses
636 645 #----------------------------------------------------------------------
637 646
638 647 def add_job(self, idx):
639 648 """Called after self.targets[idx] just got the job with header.
640 649 Override with subclasses. The default ordering is simple LRU.
641 650 The default loads are the number of outstanding jobs."""
642 651 self.loads[idx] += 1
643 652 for lis in (self.targets, self.loads):
644 653 lis.append(lis.pop(idx))
645 654
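add_job above implements the LRU ordering its docstring mentions by rotating the chosen engine to the back of both parallel lists. With illustrative values:

    targets = [b'e0', b'e1', b'e2']
    loads = [1, 0, 2]
    idx = 1                        # engine e1 just received a job
    loads[idx] += 1
    for lis in (targets, loads):
        lis.append(lis.pop(idx))   # rotate e1 (and its load) to the back
    # targets == [b'e0', b'e2', b'e1']; loads == [1, 2, 1]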
646 655
647 656 def finish_job(self, idx):
648 657 """Called after self.targets[idx] just finished a job.
649 658 Override with subclasses."""
650 659 self.loads[idx] -= 1
651 660
652 661
653 662
654 663 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
655 664 logname='root', log_url=None, loglevel=logging.DEBUG,
656 665 identity=b'task', in_thread=False):
657 666
658 667 ZMQStream = zmqstream.ZMQStream
659 668
660 669 if config:
661 670 # unwrap dict back into Config
662 671 config = Config(config)
663 672
664 673 if in_thread:
665 674 # use instance() to get the same Context/Loop as our parent
666 675 ctx = zmq.Context.instance()
667 676 loop = ioloop.IOLoop.instance()
668 677 else:
669 678 # in a process, don't use instance()
670 679 # for safety with multiprocessing
671 680 ctx = zmq.Context()
672 681 loop = ioloop.IOLoop()
673 682 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
674 683 ins.setsockopt(zmq.IDENTITY, identity)
675 684 ins.bind(in_addr)
676 685
677 686 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
678 687 outs.setsockopt(zmq.IDENTITY, identity)
679 688 outs.bind(out_addr)
680 689 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
681 690 mons.connect(mon_addr)
682 691 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
683 692 nots.setsockopt(zmq.SUBSCRIBE, b'')
684 693 nots.connect(not_addr)
685 694
686 695 # setup logging.
687 696 if in_thread:
688 697 log = Application.instance().log
689 698 else:
690 699 if log_url:
691 700 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
692 701 else:
693 702 log = local_logger(logname, loglevel)
694 703
695 704 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
696 705 mon_stream=mons, notifier_stream=nots,
697 706 loop=loop, log=log,
698 707 config=config)
699 708 scheduler.start()
700 709 if not in_thread:
701 710 try:
702 711 loop.start()
703 712 except KeyboardInterrupt:
704 713 print ("interrupted, exiting...", file=sys.__stderr__)
705 714
@@ -1,391 +1,401 b''
1 1 """A TaskRecord backend using sqlite3
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 import json
15 15 import os
16 16 import cPickle as pickle
17 17 from datetime import datetime
18 18
19 19 import sqlite3
20 20
21 21 from zmq.eventloop import ioloop
22 22
23 23 from IPython.utils.traitlets import Unicode, Instance, List, Dict
24 24 from .dictdb import BaseDB
25 25 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
26 26
27 27 #-----------------------------------------------------------------------------
28 28 # SQLite operators, adapters, and converters
29 29 #-----------------------------------------------------------------------------
30 30
31 try:
32 buffer
33 except NameError:
34 # py3k
35 buffer = memoryview
36
31 37 operators = {
32 38 '$lt' : "<",
33 39 '$gt' : ">",
34 40 # null is handled weird with ==,!=
35 41 '$eq' : "=",
36 42 '$ne' : "!=",
37 43 '$lte': "<=",
38 44 '$gte': ">=",
39 45 '$in' : ('=', ' OR '),
40 46 '$nin': ('!=', ' AND '),
41 47 # '$all': None,
42 48 # '$mod': None,
43 49 # '$exists' : None
44 50 }
45 51 null_operators = {
46 52 '=' : "IS NULL",
47 53 '!=' : "IS NOT NULL",
48 54 }
49 55
50 56 def _adapt_dict(d):
51 57 return json.dumps(d, default=date_default)
52 58
53 59 def _convert_dict(ds):
54 60 if ds is None:
55 61 return ds
56 62 else:
57 return extract_dates(json.loads(ds))
63 if isinstance(ds, bytes):
64 # If I understand the sqlite doc correctly, this will always be utf8
65 ds = ds.decode('utf8')
66 d = json.loads(ds)
67 return extract_dates(d)
58 68
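The adapter/converter pair above round-trips dicts through JSON text, with date_default/extract_dates (from IPython.utils.jsonutil) handling datetime values. A sketch of the intended round trip, with illustrative values:

    d = {'status': 'ok', 'completed': datetime(2011, 7, 1, 12, 0)}
    s = _adapt_dict(d)            # JSON text; the datetime is ISO-formatted
    assert _convert_dict(s) == d  # extract_dates restores the datetime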
59 69 def _adapt_bufs(bufs):
60 70 # this is *horrible*
61 71 # copy buffers into single list and pickle it:
62 72 if bufs and isinstance(bufs[0], (bytes, buffer)):
63 73 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
64 74 elif bufs:
65 75 return bufs
66 76 else:
67 77 return None
68 78
69 79 def _convert_bufs(bs):
70 80 if bs is None:
71 81 return []
72 82 else:
73 83 return pickle.loads(bytes(bs))
74 84
75 85 #-----------------------------------------------------------------------------
76 86 # SQLiteDB class
77 87 #-----------------------------------------------------------------------------
78 88
79 89 class SQLiteDB(BaseDB):
80 90 """SQLite3 TaskRecord backend."""
81 91
82 92 filename = Unicode('tasks.db', config=True,
83 93 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
84 94 location = Unicode('', config=True,
85 95 help="""The directory containing the sqlite task database. The default
86 96 is to use the cluster_dir location.""")
87 97 table = Unicode("", config=True,
88 98 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
89 99 a new table will be created with the Hub's IDENT. Specifying the table will result
90 100 in tasks from previous sessions being available via Clients' db_query and
91 101 get_result methods.""")
92 102
93 103 _db = Instance('sqlite3.Connection')
94 104 # the ordered list of column names
95 105 _keys = List(['msg_id' ,
96 106 'header' ,
97 107 'content',
98 108 'buffers',
99 109 'submitted',
100 110 'client_uuid' ,
101 111 'engine_uuid' ,
102 112 'started',
103 113 'completed',
104 114 'resubmitted',
105 115 'result_header' ,
106 116 'result_content' ,
107 117 'result_buffers' ,
108 118 'queue' ,
109 119 'pyin' ,
110 120 'pyout',
111 121 'pyerr',
112 122 'stdout',
113 123 'stderr',
114 124 ])
115 125 # sqlite datatypes for checking that db is current format
116 126 _types = Dict({'msg_id' : 'text' ,
117 127 'header' : 'dict text',
118 128 'content' : 'dict text',
119 129 'buffers' : 'bufs blob',
120 130 'submitted' : 'timestamp',
121 131 'client_uuid' : 'text',
122 132 'engine_uuid' : 'text',
123 133 'started' : 'timestamp',
124 134 'completed' : 'timestamp',
125 135 'resubmitted' : 'timestamp',
126 136 'result_header' : 'dict text',
127 137 'result_content' : 'dict text',
128 138 'result_buffers' : 'bufs blob',
129 139 'queue' : 'text',
130 140 'pyin' : 'text',
131 141 'pyout' : 'text',
132 142 'pyerr' : 'text',
133 143 'stdout' : 'text',
134 144 'stderr' : 'text',
135 145 })
136 146
137 147 def __init__(self, **kwargs):
138 148 super(SQLiteDB, self).__init__(**kwargs)
139 149 if not self.table:
140 150 # use session, and prefix _, since starting with # is illegal
141 151 self.table = '_'+self.session.replace('-','_')
142 152 if not self.location:
143 153 # get current profile
144 154 from IPython.core.application import BaseIPythonApplication
145 155 if BaseIPythonApplication.initialized():
146 156 app = BaseIPythonApplication.instance()
147 157 if app.profile_dir is not None:
148 158 self.location = app.profile_dir.location
149 159 else:
150 160 self.location = u'.'
151 161 else:
152 162 self.location = u'.'
153 163 self._init_db()
154 164
155 165 # register db commit as 2s periodic callback
156 166 # to prevent clogging pipes
157 167 # assumes we are being run in a zmq ioloop app
158 168 loop = ioloop.IOLoop.instance()
159 169 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
160 170 pc.start()
161 171
162 172 def _defaults(self, keys=None):
163 173 """create an empty record"""
164 174 d = {}
165 175 keys = self._keys if keys is None else keys
166 176 for key in keys:
167 177 d[key] = None
168 178 return d
169 179
170 180 def _check_table(self):
171 181 """Ensure that an incorrect table doesn't exist
172 182
173 183 If a bad (old) table does exist, return False
174 184 """
175 185 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
176 186 lines = cursor.fetchall()
177 187 if not lines:
178 188 # table does not exist
179 189 return True
180 190 types = {}
181 191 keys = []
182 192 for line in lines:
183 193 keys.append(line[1])
184 194 types[line[1]] = line[2]
185 195 if self._keys != keys:
186 196 # key mismatch
187 197 self.log.warn('keys mismatch')
188 198 return False
189 199 for key in self._keys:
190 200 if types[key] != self._types[key]:
191 201 self.log.warn(
192 202 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
193 203 )
194 204 return False
195 205 return True
196 206
197 207 def _init_db(self):
198 208 """Connect to the database and get new session number."""
199 209 # register adapters
200 210 sqlite3.register_adapter(dict, _adapt_dict)
201 211 sqlite3.register_converter('dict', _convert_dict)
202 212 sqlite3.register_adapter(list, _adapt_bufs)
203 213 sqlite3.register_converter('bufs', _convert_bufs)
204 214 # connect to the db
205 215 dbfile = os.path.join(self.location, self.filename)
206 216 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
207 217 # isolation_level = None)#,
208 218 cached_statements=64)
209 219 # print dir(self._db)
210 220 first_table = self.table
211 221 i=0
212 222 while not self._check_table():
213 223 i+=1
214 224 self.table = first_table+'_%i'%i
215 225 self.log.warn(
216 226 "Table %s exists and doesn't match db format, trying %s"%
217 227 (first_table,self.table)
218 228 )
219 229
220 230 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
221 231 (msg_id text PRIMARY KEY,
222 232 header dict text,
223 233 content dict text,
224 234 buffers bufs blob,
225 235 submitted timestamp,
226 236 client_uuid text,
227 237 engine_uuid text,
228 238 started timestamp,
229 239 completed timestamp,
230 240 resubmitted timestamp,
231 241 result_header dict text,
232 242 result_content dict text,
233 243 result_buffers bufs blob,
234 244 queue text,
235 245 pyin text,
236 246 pyout text,
237 247 pyerr text,
238 248 stdout text,
239 249 stderr text)
240 250 """%self.table)
241 251 self._db.commit()
242 252
243 253 def _dict_to_list(self, d):
244 254 """turn a mongodb-style record dict into a list."""
245 255
246 256 return [ d[key] for key in self._keys ]
247 257
248 258 def _list_to_dict(self, line, keys=None):
249 259 """Inverse of dict_to_list"""
250 260 keys = self._keys if keys is None else keys
251 261 d = self._defaults(keys)
252 262 for key,value in zip(keys, line):
253 263 d[key] = value
254 264
255 265 return d
256 266
257 267 def _render_expression(self, check):
258 268 """Turn a mongodb-style search dict into an SQL query."""
259 269 expressions = []
260 270 args = []
261 271
262 272 skeys = set(check.keys())
263 273 skeys.difference_update(set(self._keys))
264 274 skeys.difference_update(set(['buffers', 'result_buffers']))
265 275 if skeys:
266 276 raise KeyError("Illegal testing key(s): %s"%skeys)
267 277
268 278 for name,sub_check in check.iteritems():
269 279 if isinstance(sub_check, dict):
270 280 for test,value in sub_check.iteritems():
271 281 try:
272 282 op = operators[test]
273 283 except KeyError:
274 284 raise KeyError("Unsupported operator: %r"%test)
275 285 if isinstance(op, tuple):
276 286 op, join = op
277 287
278 288 if value is None and op in null_operators:
279 289 expr = "%s %s"%(name, null_operators[op])
280 290 else:
281 291 expr = "%s %s ?"%(name, op)
282 292 if isinstance(value, (tuple,list)):
283 293 if op in null_operators and any([v is None for v in value]):
284 294 # equality tests don't work with NULL
285 295 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
286 296 expr = '( %s )'%( join.join([expr]*len(value)) )
287 297 args.extend(value)
288 298 else:
289 299 args.append(value)
290 300 expressions.append(expr)
291 301 else:
292 302 # it's an equality check
293 303 if sub_check is None:
294 304 expressions.append("%s IS NULL"%name)
295 305 else:
296 306 expressions.append("%s = ?"%name)
297 307 args.append(sub_check)
298 308
299 309 expr = " AND ".join(expressions)
300 310 return expr, args
301 311
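_render_expression turns a mongodb-style check dict into a parametrized SQL WHERE clause plus its argument list. Given a hypothetical SQLiteDB instance db, and with the null-operator fixes above:

    expr, args = db._render_expression({'msg_id': {'$in': ['a', 'b']}})
    # expr == "( msg_id = ? OR msg_id = ? )"
    # args == ['a', 'b']
    expr, args = db._render_expression({'completed': {'$ne': None}})
    # expr == "completed IS NOT NULL"
    # args == []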
302 312 def add_record(self, msg_id, rec):
303 313 """Add a new Task Record, by msg_id."""
304 314 d = self._defaults()
305 315 d.update(rec)
306 316 d['msg_id'] = msg_id
307 317 line = self._dict_to_list(d)
308 318 tups = '(%s)'%(','.join(['?']*len(line)))
309 319 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
310 320 # self._db.commit()
311 321
312 322 def get_record(self, msg_id):
313 323 """Get a specific Task Record, by msg_id."""
314 324 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
315 325 line = cursor.fetchone()
316 326 if line is None:
317 327 raise KeyError("No such msg: %r"%msg_id)
318 328 return self._list_to_dict(line)
319 329
320 330 def update_record(self, msg_id, rec):
321 331 """Update the data in an existing record."""
322 332 query = "UPDATE %s SET "%self.table
323 333 sets = []
324 334 keys = sorted(rec.keys())
325 335 values = []
326 336 for key in keys:
327 337 sets.append('%s = ?'%key)
328 338 values.append(rec[key])
329 339 query += ', '.join(sets)
330 340 query += ' WHERE msg_id == ?'
331 341 values.append(msg_id)
332 342 self._db.execute(query, values)
333 343 # self._db.commit()
334 344
335 345 def drop_record(self, msg_id):
336 346 """Remove a record from the DB."""
337 347 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
338 348 # self._db.commit()
339 349
340 350 def drop_matching_records(self, check):
341 351 """Remove a record from the DB."""
342 352 expr,args = self._render_expression(check)
343 353 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
344 354 self._db.execute(query,args)
345 355 # self._db.commit()
346 356
347 357 def find_records(self, check, keys=None):
348 358 """Find records matching a query dict, optionally extracting subset of keys.
349 359
350 360 Returns list of matching records.
351 361
352 362 Parameters
353 363 ----------
354 364
355 365 check: dict
356 366 mongodb-style query argument
357 367 keys: list of strs [optional]
358 368 if specified, the subset of keys to extract. msg_id will *always* be
359 369 included.
360 370 """
361 371 if keys:
362 372 bad_keys = [ key for key in keys if key not in self._keys ]
363 373 if bad_keys:
364 374 raise KeyError("Bad record key(s): %s"%bad_keys)
365 375
366 376 if keys:
367 377 # ensure msg_id is present and first:
368 378 if 'msg_id' in keys:
369 379 keys.remove('msg_id')
370 380 keys.insert(0, 'msg_id')
371 381 req = ', '.join(keys)
372 382 else:
373 383 req = '*'
374 384 expr,args = self._render_expression(check)
375 385 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
376 386 cursor = self._db.execute(query, args)
377 387 matches = cursor.fetchall()
378 388 records = []
379 389 for line in matches:
380 390 rec = self._list_to_dict(line, keys)
381 391 records.append(rec)
382 392 return records
383 393
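find_records combines the expression renderer with an optional column subset. A hedged usage sketch (db is a hypothetical SQLiteDB instance):

    recs = db.find_records({'completed': {'$ne': None}},
                           keys=['msg_id', 'completed'])
    # -> list of dicts, each containing 'msg_id' (always first)
    #    plus the other requested keys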
384 394 def get_history(self):
385 395 """get all msg_ids, ordered by time submitted."""
386 396 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
387 397 cursor = self._db.execute(query)
388 398 # will be a list of length 1 tuples
389 399 return [ tup[0] for tup in cursor.fetchall()]
390 400
391 401 __all__ = ['SQLiteDB'] No newline at end of file
@@ -1,170 +1,170 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 It handles registration, etc., and launches a kernel
4 4 connected to the Controller's Schedulers.
5 5
6 6 Authors:
7 7
8 8 * Min RK
9 9 """
10 10 #-----------------------------------------------------------------------------
11 11 # Copyright (C) 2010-2011 The IPython Development Team
12 12 #
13 13 # Distributed under the terms of the BSD License. The full license is in
14 14 # the file COPYING, distributed as part of this software.
15 15 #-----------------------------------------------------------------------------
16 16
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 # internal
26 26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode
27 27 # from IPython.utils.localinterfaces import LOCALHOST
28 28
29 29 from IPython.parallel.controller.heartmonitor import Heart
30 30 from IPython.parallel.factory import RegistrationFactory
31 from IPython.parallel.util import disambiguate_url
31 from IPython.parallel.util import disambiguate_url, ensure_bytes
32 32
33 33 from IPython.zmq.session import Message
34 34
35 35 from .streamkernel import Kernel
36 36
37 37 class EngineFactory(RegistrationFactory):
38 38 """IPython engine"""
39 39
40 40 # configurables:
41 41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 42 help="""The OutStream for handling stdout/err.
43 43 Typically 'IPython.zmq.iostream.OutStream'""")
44 44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 45 help="""The class for handling displayhook.
46 46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 47 location=Unicode(config=True,
48 48 help="""The location (an IP address) of the controller. This is
49 49 used for disambiguating URLs, to determine whether
50 50 loopback should be used to connect or the public address.""")
51 51 timeout=CFloat(2,config=True,
52 52 help="""The time (in seconds) to wait for the Controller to respond
53 53 to registration requests before giving up.""")
54 54
55 55 # not configurable:
56 56 user_ns=Dict()
57 57 id=Int(allow_none=True)
58 58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 59 kernel=Instance(Kernel)
60 60
61 61
62 62 def __init__(self, **kwargs):
63 63 super(EngineFactory, self).__init__(**kwargs)
64 64 self.ident = self.session.session
65 65 ctx = self.context
66 66
67 67 reg = ctx.socket(zmq.XREQ)
68 reg.setsockopt(zmq.IDENTITY, self.ident)
68 reg.setsockopt(zmq.IDENTITY, ensure_bytes(self.ident))
69 69 reg.connect(self.url)
70 70 self.registrar = zmqstream.ZMQStream(reg, self.loop)
71 71
72 72 def register(self):
73 73 """send the registration_request"""
74 74
75 75 self.log.info("Registering with controller at %s"%self.url)
76 76 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
77 77 self.registrar.on_recv(self.complete_registration)
78 78 # print (self.session.key)
79 79 self.session.send(self.registrar, "registration_request",content=content)
80 80
81 81 def complete_registration(self, msg):
82 82 # print msg
83 83 self._abort_dc.stop()
84 84 ctx = self.context
85 85 loop = self.loop
86 identity = self.ident
86 identity = ensure_bytes(self.ident)
87 87
88 88 idents,msg = self.session.feed_identities(msg)
89 89 msg = Message(self.session.unpack_message(msg))
90 90
91 91 if msg.content.status == 'ok':
92 92 self.id = int(msg.content.id)
93 93
94 94 # create Shell Streams (MUX, Task, etc.):
95 95 queue_addr = msg.content.mux
96 96 shell_addrs = [ str(queue_addr) ]
97 97 task_addr = msg.content.task
98 98 if task_addr:
99 99 shell_addrs.append(str(task_addr))
100 100
101 101 # Uncomment this to go back to two-socket model
102 102 # shell_streams = []
103 103 # for addr in shell_addrs:
104 104 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
105 105 # stream.setsockopt(zmq.IDENTITY, identity)
106 106 # stream.connect(disambiguate_url(addr, self.location))
107 107 # shell_streams.append(stream)
108 108
109 109 # Now use only one shell stream for mux and tasks
110 110 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
111 111 stream.setsockopt(zmq.IDENTITY, identity)
112 112 shell_streams = [stream]
113 113 for addr in shell_addrs:
114 114 stream.connect(disambiguate_url(addr, self.location))
115 115 # end single stream-socket
116 116
117 117 # control stream:
118 118 control_addr = str(msg.content.control)
119 119 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
120 120 control_stream.setsockopt(zmq.IDENTITY, identity)
121 121 control_stream.connect(disambiguate_url(control_addr, self.location))
122 122
123 123 # create iopub stream:
124 124 iopub_addr = msg.content.iopub
125 125 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
126 126 iopub_stream.setsockopt(zmq.IDENTITY, identity)
127 127 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
128 128
129 129 # launch heartbeat
130 130 hb_addrs = msg.content.heartbeat
131 131 # print (hb_addrs)
132 132
133 133 # # Redirect input streams and set a display hook.
134 134 if self.out_stream_factory:
135 135 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
136 136 sys.stdout.topic = 'engine.%i.stdout'%self.id
137 137 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
138 138 sys.stderr.topic = 'engine.%i.stderr'%self.id
139 139 if self.display_hook_factory:
140 140 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
141 141 sys.displayhook.topic = 'engine.%i.pyout'%self.id
142 142
143 143 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
144 144 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
145 145 loop=loop, user_ns = self.user_ns, log=self.log)
146 146 self.kernel.start()
147 147 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
148 148 heart = Heart(*map(str, hb_addrs), heart_id=identity)
149 149 heart.start()
150 150
151 151
152 152 else:
153 153 self.log.fatal("Registration Failed: %s"%msg)
154 154 raise Exception("Registration Failed: %s"%msg)
155 155
156 156 self.log.info("Completed registration with id %i"%self.id)
157 157
158 158
159 159 def abort(self):
160 160 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
161 161 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
162 162 time.sleep(1)
163 163 sys.exit(255)
164 164
165 165 def start(self):
166 166 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
167 167 dc.start()
168 168 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
169 169 self._abort_dc.start()
170 170
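
The recurring change in this file, wrapping `self.ident` in `ensure_bytes` before it is used as a zmq socket IDENTITY or heart id, is needed because pyzmq requires socket identities to be `bytes`, while session idents become unicode text under Python 3. A rough sketch of what the new `ensure_bytes` helper in IPython.parallel.util presumably does; the real implementation may differ in detail:

    import sys

    def ensure_bytes(s):
        """Encode a text string to utf8 bytes; pass bytes through unchanged."""
        text_type = str if sys.version_info[0] >= 3 else unicode
        if isinstance(s, text_type):
            s = s.encode('utf8')
        return s

    # usage, as in EngineFactory.__init__ above:
    #   reg.setsockopt(zmq.IDENTITY, ensure_bytes(self.ident))
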
@@ -1,429 +1,438 @@
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 * Brian Granger
9 9 * Fernando Perez
10 10 * Evan Patterson
11 11 """
12 12 #-----------------------------------------------------------------------------
13 13 # Copyright (C) 2010-2011 The IPython Development Team
14 14 #
15 15 # Distributed under the terms of the BSD License. The full license is in
16 16 # the file COPYING, distributed as part of this software.
17 17 #-----------------------------------------------------------------------------
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 # Standard library imports.
24 24 from __future__ import print_function
25 25
26 26 import sys
27 27 import time
28 28
29 29 from code import CommandCompiler
30 30 from datetime import datetime
31 31 from pprint import pprint
32 32
33 33 # System library imports.
34 34 import zmq
35 35 from zmq.eventloop import ioloop, zmqstream
36 36
37 37 # Local imports.
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 39 from IPython.zmq.completer import KernelCompleter
40 40
41 41 from IPython.parallel.error import wrap_exception
42 42 from IPython.parallel.factory import SessionFactory
43 from IPython.parallel.util import serialize_object, unpack_apply_message
43 from IPython.parallel.util import serialize_object, unpack_apply_message, ensure_bytes
44 44
45 45 def printer(*args):
46 46 pprint(args, stream=sys.__stdout__)
47 47
48 48
49 49 class _Passer(zmqstream.ZMQStream):
50 50 """Empty class that implements `send()` that does nothing.
51 51
52 52 Subclass ZMQStream for Session typechecking
53 53
54 54 """
55 55 def __init__(self, *args, **kwargs):
56 56 pass
57 57
58 58 def send(self, *args, **kwargs):
59 59 pass
60 60 send_multipart = send
61 61
62 62
63 63 #-----------------------------------------------------------------------------
64 64 # Main kernel class
65 65 #-----------------------------------------------------------------------------
66 66
67 67 class Kernel(SessionFactory):
68 68
69 69 #---------------------------------------------------------------------------
70 70 # Kernel interface
71 71 #---------------------------------------------------------------------------
72 72
73 73 # kwargs:
74 74 exec_lines = List(Unicode, config=True,
75 75 help="List of lines to execute")
76
76
77 # identities:
77 78 int_id = Int(-1)
79 bident = CBytes()
80 ident = Unicode()
81 def _ident_changed(self, name, old, new):
82 self.bident = ensure_bytes(new)
83
78 84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
79 85
80 86 control_stream = Instance(zmqstream.ZMQStream)
81 87 task_stream = Instance(zmqstream.ZMQStream)
82 88 iopub_stream = Instance(zmqstream.ZMQStream)
83 89 client = Instance('IPython.parallel.Client')
84 90
85 91 # internals
86 92 shell_streams = List()
87 93 compiler = Instance(CommandCompiler, (), {})
88 94 completer = Instance(KernelCompleter)
89 95
90 96 aborted = Set()
91 97 shell_handlers = Dict()
92 98 control_handlers = Dict()
93 99
94 100 def _set_prefix(self):
95 101 self.prefix = "engine.%s"%self.int_id
96 102
97 103 def _connect_completer(self):
98 104 self.completer = KernelCompleter(self.user_ns)
99 105
100 106 def __init__(self, **kwargs):
101 107 super(Kernel, self).__init__(**kwargs)
102 108 self._set_prefix()
103 109 self._connect_completer()
104 110
105 111 self.on_trait_change(self._set_prefix, 'id')
106 112 self.on_trait_change(self._connect_completer, 'user_ns')
107 113
108 114 # Build dict of handlers for message types
109 115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
110 116 'clear_request']:
111 117 self.shell_handlers[msg_type] = getattr(self, msg_type)
112 118
113 119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
114 120 self.control_handlers[msg_type] = getattr(self, msg_type)
115 121
116 122 self._initial_exec_lines()
117 123
118 124 def _wrap_exception(self, method=None):
119 125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
120 126 content=wrap_exception(e_info)
121 127 return content
122 128
123 129 def _initial_exec_lines(self):
124 130 s = _Passer()
125 131 content = dict(silent=True, user_variable=[],user_expressions=[])
126 132 for line in self.exec_lines:
127 133 self.log.debug("executing initialization: %s"%line)
128 134 content.update({'code':line})
129 135 msg = self.session.msg('execute_request', content)
130 136 self.execute_request(s, [], msg)
131 137
132 138
133 139 #-------------------- control handlers -----------------------------
134 140 def abort_queues(self):
135 141 for stream in self.shell_streams:
136 142 if stream:
137 143 self.abort_queue(stream)
138 144
139 145 def abort_queue(self, stream):
140 146 while True:
141 147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
142 148 if msg is None:
143 149 return
144 150
145 151 self.log.info("Aborting:")
146 152 self.log.info(str(msg))
147 153 msg_type = msg['msg_type']
148 154 reply_type = msg_type.split('_')[0] + '_reply'
149 155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
150 156 # self.reply_socket.send(ident,zmq.SNDMORE)
151 157 # self.reply_socket.send_json(reply_msg)
152 158 reply_msg = self.session.send(stream, reply_type,
153 159 content={'status' : 'aborted'}, parent=msg, ident=idents)
154 160 self.log.debug(str(reply_msg))
155 161 # We need to wait a bit for requests to come in. This can probably
156 162 # be set shorter for true asynchronous clients.
157 163 time.sleep(0.05)
158 164
159 165 def abort_request(self, stream, ident, parent):
160 166 """abort a specifig msg by id"""
161 167 msg_ids = parent['content'].get('msg_ids', None)
162 168 if isinstance(msg_ids, basestring):
163 169 msg_ids = [msg_ids]
164 170 if not msg_ids:
165 171 self.abort_queues()
166 172 for mid in msg_ids:
167 173 self.aborted.add(str(mid))
168 174
169 175 content = dict(status='ok')
170 176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
171 177 parent=parent, ident=ident)
172 178 self.log.debug(str(reply_msg))
173 179
174 180 def shutdown_request(self, stream, ident, parent):
175 181 """kill ourself. This should really be handled in an external process"""
176 182 try:
177 183 self.abort_queues()
178 184 except:
179 185 content = self._wrap_exception('shutdown')
180 186 else:
181 187 content = dict(parent['content'])
182 188 content['status'] = 'ok'
183 189 msg = self.session.send(stream, 'shutdown_reply',
184 190 content=content, parent=parent, ident=ident)
185 191 self.log.debug(str(msg))
186 192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
187 193 dc.start()
188 194
189 195 def dispatch_control(self, msg):
190 196 idents,msg = self.session.feed_identities(msg, copy=False)
191 197 try:
192 198 msg = self.session.unpack_message(msg, content=True, copy=False)
193 199 except:
194 200 self.log.error("Invalid Message", exc_info=True)
195 201 return
202 else:
203 self.log.debug("Control received, %s", msg)
196 204
197 205 header = msg['header']
198 206 msg_id = header['msg_id']
199 207
200 208 handler = self.control_handlers.get(msg['msg_type'], None)
201 209 if handler is None:
202 210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
203 211 else:
204 212 handler(self.control_stream, idents, msg)
205 213
206 214
207 215 #-------------------- queue helpers ------------------------------
208 216
209 217 def check_dependencies(self, dependencies):
210 218 if not dependencies:
211 219 return True
212 220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
213 221 anyorall = dependencies[0]
214 222 dependencies = dependencies[1]
215 223 else:
216 224 anyorall = 'all'
217 225 results = self.client.get_results(dependencies,status_only=True)
218 226 if results['status'] != 'ok':
219 227 return False
220 228
221 229 if anyorall == 'any':
222 230 if not results['completed']:
223 231 return False
224 232 else:
225 233 if results['pending']:
226 234 return False
227 235
228 236 return True
229 237
230 238 def check_aborted(self, msg_id):
231 239 return msg_id in self.aborted
232 240
233 241 #-------------------- queue handlers -----------------------------
234 242
235 243 def clear_request(self, stream, idents, parent):
236 244 """Clear our namespace."""
237 245 self.user_ns = {}
238 246 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
239 247 content = dict(status='ok'))
240 248 self._initial_exec_lines()
241 249
242 250 def execute_request(self, stream, ident, parent):
243 251 self.log.debug('execute request %s'%parent)
244 252 try:
245 253 code = parent[u'content'][u'code']
246 254 except:
247 255 self.log.error("Got bad msg: %s"%parent, exc_info=True)
248 256 return
249 257 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
250 ident='%s.pyin'%self.prefix)
258 ident=ensure_bytes('%s.pyin'%self.prefix))
251 259 started = datetime.now()
252 260 try:
253 261 comp_code = self.compiler(code, '<zmq-kernel>')
254 262 # allow for not overriding displayhook
255 263 if hasattr(sys.displayhook, 'set_parent'):
256 264 sys.displayhook.set_parent(parent)
257 265 sys.stdout.set_parent(parent)
258 266 sys.stderr.set_parent(parent)
259 267 exec comp_code in self.user_ns, self.user_ns
260 268 except:
261 269 exc_content = self._wrap_exception('execute')
262 270 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
263 271 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
264 ident='%s.pyerr'%self.prefix)
272 ident=ensure_bytes('%s.pyerr'%self.prefix))
265 273 reply_content = exc_content
266 274 else:
267 275 reply_content = {'status' : 'ok'}
268 276
269 277 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
270 278 ident=ident, subheader = dict(started=started))
271 279 self.log.debug(str(reply_msg))
272 280 if reply_msg['content']['status'] == u'error':
273 281 self.abort_queues()
274 282
275 283 def complete_request(self, stream, ident, parent):
276 284 matches = {'matches' : self.complete(parent),
277 285 'status' : 'ok'}
278 286 completion_msg = self.session.send(stream, 'complete_reply',
279 287 matches, parent, ident)
280 288 # print >> sys.__stdout__, completion_msg
281 289
282 290 def complete(self, msg):
283 291 return self.completer.complete(msg.content.line, msg.content.text)
284 292
285 293 def apply_request(self, stream, ident, parent):
286 294 # flush previous reply, so this request won't block it
287 295 stream.flush(zmq.POLLOUT)
288
289 296 try:
290 297 content = parent[u'content']
291 298 bufs = parent[u'buffers']
292 299 msg_id = parent['header']['msg_id']
293 300 # bound = parent['header'].get('bound', False)
294 301 except:
295 302 self.log.error("Got bad msg: %s"%parent, exc_info=True)
296 303 return
297 304 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
298 305 # self.iopub_stream.send(pyin_msg)
299 306 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
300 307 sub = {'dependencies_met' : True, 'engine' : self.ident,
301 308 'started': datetime.now()}
302 309 try:
303 310 # allow for not overriding displayhook
304 311 if hasattr(sys.displayhook, 'set_parent'):
305 312 sys.displayhook.set_parent(parent)
306 313 sys.stdout.set_parent(parent)
307 314 sys.stderr.set_parent(parent)
308 315 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
309 316 working = self.user_ns
310 317 # suffix =
311 318 prefix = "_"+str(msg_id).replace("-","")+"_"
312 319
313 320 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
314 321 # if bound:
315 322 # bound_ns = Namespace(working)
316 323 # args = [bound_ns]+list(args)
317 324
318 325 fname = getattr(f, '__name__', 'f')
319 326
320 327 fname = prefix+"f"
321 328 argname = prefix+"args"
322 329 kwargname = prefix+"kwargs"
323 330 resultname = prefix+"result"
324 331
325 332 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
326 333 # print ns
327 334 working.update(ns)
328 335 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
329 336 try:
330 337 exec code in working,working
331 338 result = working.get(resultname)
332 339 finally:
333 340 for key in ns.iterkeys():
334 341 working.pop(key)
335 342 # if bound:
336 343 # working.update(bound_ns)
337 344
338 345 packed_result,buf = serialize_object(result)
339 346 result_buf = [packed_result]+buf
340 347 except:
341 348 exc_content = self._wrap_exception('apply')
342 349 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
343 350 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
344 ident='%s.pyerr'%self.prefix)
351 ident=ensure_bytes('%s.pyerr'%self.prefix))
345 352 reply_content = exc_content
346 353 result_buf = []
347 354
348 355 if exc_content['ename'] == 'UnmetDependency':
349 356 sub['dependencies_met'] = False
350 357 else:
351 358 reply_content = {'status' : 'ok'}
352 359
353 360 # put 'ok'/'error' status in header, for scheduler introspection:
354 361 sub['status'] = reply_content['status']
355 362
356 363 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
357 364 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
358 365
359 366 # flush i/o
360 367 # should this be before reply_msg is sent, like in the single-kernel code,
361 368 # or should nothing get in the way of real results?
362 369 sys.stdout.flush()
363 370 sys.stderr.flush()
364 371
365 372 def dispatch_queue(self, stream, msg):
366 373 self.control_stream.flush()
367 374 idents,msg = self.session.feed_identities(msg, copy=False)
368 375 try:
369 376 msg = self.session.unpack_message(msg, content=True, copy=False)
370 377 except:
371 378 self.log.error("Invalid Message", exc_info=True)
372 379 return
380 else:
381 self.log.debug("Message received, %s", msg)
373 382
374 383
375 384 header = msg['header']
376 385 msg_id = header['msg_id']
377 386 if self.check_aborted(msg_id):
378 387 self.aborted.remove(msg_id)
379 388 # is it safe to assume a msg_id will not be resubmitted?
380 389 reply_type = msg['msg_type'].split('_')[0] + '_reply'
381 390 status = {'status' : 'aborted'}
382 391 reply_msg = self.session.send(stream, reply_type, subheader=status,
383 392 content=status, parent=msg, ident=idents)
384 393 return
385 394 handler = self.shell_handlers.get(msg['msg_type'], None)
386 395 if handler is None:
387 396 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
388 397 else:
389 398 handler(stream, idents, msg)
390 399
391 400 def start(self):
392 401 #### stream mode:
393 402 if self.control_stream:
394 403 self.control_stream.on_recv(self.dispatch_control, copy=False)
395 404 self.control_stream.on_err(printer)
396 405
397 406 def make_dispatcher(stream):
398 407 def dispatcher(msg):
399 408 return self.dispatch_queue(stream, msg)
400 409 return dispatcher
401 410
402 411 for s in self.shell_streams:
403 412 s.on_recv(make_dispatcher(s), copy=False)
404 413 s.on_err(printer)
405 414
406 415 if self.iopub_stream:
407 416 self.iopub_stream.on_err(printer)
408 417
409 418 #### while True mode:
410 419 # while True:
411 420 # idle = True
412 421 # try:
413 422 # msg = self.shell_stream.socket.recv_multipart(
414 423 # zmq.NOBLOCK, copy=False)
415 424 # except zmq.ZMQError, e:
416 425 # if e.errno != zmq.EAGAIN:
417 426 # raise e
418 427 # else:
419 428 # idle=False
420 429 # self.dispatch_queue(self.shell_stream, msg)
421 430 #
422 431 # if not self.task_stream.empty():
423 432 # idle=False
424 433 # msg = self.task_stream.recv_multipart()
425 434 # self.dispatch_queue(self.task_stream, msg)
426 435 # if idle:
427 436 # # don't busywait
428 437 # time.sleep(1e-3)
429 438
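
The new `ident`/`bident` trait pair added near the top of this class relies on traitlets' static notification convention: a method named `_<trait>_changed` is invoked automatically whenever that trait is assigned, so the bytes mirror of the identity never goes stale. A minimal standalone illustration of the pattern, with a hypothetical class name:

    from IPython.utils.traitlets import HasTraits, Unicode, CBytes

    class Identified(HasTraits):
        ident = Unicode()   # human-readable session identity
        bident = CBytes()   # bytes mirror, suitable for zmq.IDENTITY

        def _ident_changed(self, name, old, new):
            # called by traitlets on every assignment to .ident
            self.bident = new.encode('utf8')

    obj = Identified()
    obj.ident = u'engine-1'
    assert obj.bident == b'engine-1'
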
@@ -1,134 +1,136 @@
1 1 """base class for parallel client tests
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 import sys
16 16 import tempfile
17 17 import time
18 18
19 19 from nose import SkipTest
20 20
21 21 import zmq
22 22 from zmq.tests import BaseZMQTestCase
23 23
24 24 from IPython.external.decorator import decorator
25 25
26 26 from IPython.parallel import error
27 27 from IPython.parallel import Client
28 28
29 29 from IPython.parallel.tests import launchers, add_engines
30 30
31 31 # simple tasks for use in apply tests
32 32
33 33 def segfault():
34 34 """this will segfault"""
35 35 import ctypes
36 36 ctypes.memset(-1,0,1)
37 37
38 38 def crash():
39 39 """from stdlib crashers in the test suite"""
40 40 import types
41 41 if sys.platform.startswith('win'):
42 42 import ctypes
43 43 ctypes.windll.kernel32.SetErrorMode(0x0002);
44
45 co = types.CodeType(0, 0, 0, 0, b'\x04\x71\x00\x00',
46 (), (), (), '', '', 1, b'')
44 args = [ 0, 0, 0, 0, b'\x04\x71\x00\x00', (), (), (), '', '', 1, b'']
45 if sys.version_info[0] >= 3:
46 args.insert(1, 0)
47
48 co = types.CodeType(*args)
47 49 exec(co)
48 50
49 51 def wait(n):
50 52 """sleep for a time"""
51 53 import time
52 54 time.sleep(n)
53 55 return n
54 56
55 57 def raiser(eclass):
56 58 """raise an exception"""
57 59 raise eclass()
58 60
59 61 # test decorator for skipping tests when libraries are unavailable
60 62 def skip_without(*names):
61 63 """skip a test if some names are not importable"""
62 64 @decorator
63 65 def skip_without_names(f, *args, **kwargs):
64 66 """decorator to skip tests in the absence of numpy."""
65 67 for name in names:
66 68 try:
67 69 __import__(name)
68 70 except ImportError:
69 71 raise SkipTest
70 72 return f(*args, **kwargs)
71 73 return skip_without_names
72 74
73 75 class ClusterTestCase(BaseZMQTestCase):
74 76
75 77 def add_engines(self, n=1, block=True):
76 78 """add multiple engines to our cluster"""
77 79 self.engines.extend(add_engines(n))
78 80 if block:
79 81 self.wait_on_engines()
80 82
81 83 def wait_on_engines(self, timeout=5):
82 84 """wait for our engines to connect."""
83 85 n = len(self.engines)+self.base_engine_count
84 86 tic = time.time()
85 87 while time.time()-tic < timeout and len(self.client.ids) < n:
86 88 time.sleep(0.1)
87 89
88 90 assert not len(self.client.ids) < n, "waiting for engines timed out"
89 91
90 92 def connect_client(self):
91 93 """connect a client with my Context, and track its sockets for cleanup"""
92 94 c = Client(profile='iptest', context=self.context)
93 95 for name in filter(lambda n:n.endswith('socket'), dir(c)):
94 96 s = getattr(c, name)
95 97 s.setsockopt(zmq.LINGER, 0)
96 98 self.sockets.append(s)
97 99 return c
98 100
99 101 def assertRaisesRemote(self, etype, f, *args, **kwargs):
100 102 try:
101 103 try:
102 104 f(*args, **kwargs)
103 105 except error.CompositeError as e:
104 106 e.raise_exception()
105 107 except error.RemoteError as e:
106 108 self.assertEquals(etype.__name__, e.ename, "Should have raised %r, but raised %r"%(etype.__name__, e.ename))
107 109 else:
108 110 self.fail("should have raised a RemoteError")
109 111
110 112 def setUp(self):
111 113 BaseZMQTestCase.setUp(self)
112 114 self.client = self.connect_client()
113 115 # start every test with clean engine namespaces:
114 116 self.client.clear(block=True)
115 117 self.base_engine_count=len(self.client.ids)
116 118 self.engines=[]
117 119
118 120 def tearDown(self):
119 121 # self.client.clear(block=True)
120 122 # close fds:
121 123 for e in filter(lambda e: e.poll() is not None, launchers):
122 124 launchers.remove(e)
123 125
124 126 # allow flushing of incoming messages to prevent crash on socket close
125 127 self.client.wait(timeout=2)
126 128 # time.sleep(2)
127 129 self.client.spin()
128 130 self.client.close()
129 131 BaseZMQTestCase.tearDown(self)
130 132 # this will be redundant when pyzmq merges PR #88
131 133 # self.context.term()
132 134 # print tempfile.TemporaryFile().fileno(),
133 135 # sys.stdout.flush()
134 136
\ No newline at end of file
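
The rewritten `crash()` builds its `types.CodeType` arguments as a list because Python 3 inserted a new second positional parameter, `kwonlyargcount`, into the CodeType constructor, so a positional call written for Python 2 misaligns on Python 3. A small sketch of the portability trick in isolation (note that later interpreters insert still more leading arguments, e.g. `posonlyargcount` in 3.8):

    import sys
    import types

    def make_code(bytecode):
        """Build a minimal code object on both Python 2 and 3."""
        # argcount, nlocals, stacksize, flags, codestring,
        # consts, names, varnames, filename, name, firstlineno, lnotab
        args = [0, 0, 0, 0, bytecode, (), (), (), '', '', 1, b'']
        if sys.version_info[0] >= 3:
            args.insert(1, 0)  # kwonlyargcount, new in Python 3
        return types.CodeType(*args)
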
@@ -1,73 +1,73 @@
1 1 """Tests for asyncresult.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19
20 20 from IPython.parallel.error import TimeoutError
21 21
22 22 from IPython.parallel.tests import add_engines
23 23 from .clienttest import ClusterTestCase
24 24
25 25 def setup():
26 26 add_engines(2)
27 27
28 28 def wait(n):
29 29 import time
30 30 time.sleep(n)
31 31 return n
32 32
33 33 class AsyncResultTest(ClusterTestCase):
34 34
35 35 def test_single_result(self):
36 36 eid = self.client.ids[-1]
37 37 ar = self.client[eid].apply_async(lambda : 42)
38 38 self.assertEquals(ar.get(), 42)
39 39 ar = self.client[[eid]].apply_async(lambda : 42)
40 40 self.assertEquals(ar.get(), [42])
41 41 ar = self.client[-1:].apply_async(lambda : 42)
42 42 self.assertEquals(ar.get(), [42])
43 43
44 44 def test_get_after_done(self):
45 45 ar = self.client[-1].apply_async(lambda : 42)
46 46 ar.wait()
47 47 self.assertTrue(ar.ready())
48 48 self.assertEquals(ar.get(), 42)
49 49 self.assertEquals(ar.get(), 42)
50 50
51 51 def test_get_before_done(self):
52 52 ar = self.client[-1].apply_async(wait, 0.1)
53 53 self.assertRaises(TimeoutError, ar.get, 0)
54 54 ar.wait(0)
55 55 self.assertFalse(ar.ready())
56 56 self.assertEquals(ar.get(), 0.1)
57 57
58 58 def test_get_after_error(self):
59 59 ar = self.client[-1].apply_async(lambda : 1/0)
60 ar.wait()
60 ar.wait(10)
61 61 self.assertRaisesRemote(ZeroDivisionError, ar.get)
62 62 self.assertRaisesRemote(ZeroDivisionError, ar.get)
63 63 self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)
64 64
65 65 def test_get_dict(self):
66 66 n = len(self.client)
67 67 ar = self.client[:].apply_async(lambda : 5)
68 68 self.assertEquals(ar.get(), [5]*n)
69 69 d = ar.get_dict()
70 70 self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
71 71 for eid,r in d.iteritems():
72 72 self.assertEquals(r, 5)
73 73
@@ -1,257 +1,261 @@
1 1 """Tests for parallel client.py
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 from __future__ import division
20
19 21 import time
20 22 from datetime import datetime
21 23 from tempfile import mktemp
22 24
23 25 import zmq
24 26
25 27 from IPython.parallel.client import client as clientmod
26 28 from IPython.parallel import error
27 29 from IPython.parallel import AsyncResult, AsyncHubResult
28 30 from IPython.parallel import LoadBalancedView, DirectView
29 31
30 32 from clienttest import ClusterTestCase, segfault, wait, add_engines
31 33
32 34 def setup():
33 35 add_engines(4)
34 36
35 37 class TestClient(ClusterTestCase):
36 38
37 39 def test_ids(self):
38 40 n = len(self.client.ids)
39 41 self.add_engines(3)
40 42 self.assertEquals(len(self.client.ids), n+3)
41 43
42 44 def test_view_indexing(self):
43 45 """test index access for views"""
44 46 self.add_engines(2)
45 47 targets = self.client._build_targets('all')[-1]
46 48 v = self.client[:]
47 49 self.assertEquals(v.targets, targets)
48 50 t = self.client.ids[2]
49 51 v = self.client[t]
50 52 self.assert_(isinstance(v, DirectView))
51 53 self.assertEquals(v.targets, t)
52 54 t = self.client.ids[2:4]
53 55 v = self.client[t]
54 56 self.assert_(isinstance(v, DirectView))
55 57 self.assertEquals(v.targets, t)
56 58 v = self.client[::2]
57 59 self.assert_(isinstance(v, DirectView))
58 60 self.assertEquals(v.targets, targets[::2])
59 61 v = self.client[1::3]
60 62 self.assert_(isinstance(v, DirectView))
61 63 self.assertEquals(v.targets, targets[1::3])
62 64 v = self.client[:-3]
63 65 self.assert_(isinstance(v, DirectView))
64 66 self.assertEquals(v.targets, targets[:-3])
65 67 v = self.client[-1]
66 68 self.assert_(isinstance(v, DirectView))
67 69 self.assertEquals(v.targets, targets[-1])
68 70 self.assertRaises(TypeError, lambda : self.client[None])
69 71
70 72 def test_lbview_targets(self):
71 73 """test load_balanced_view targets"""
72 74 v = self.client.load_balanced_view()
73 75 self.assertEquals(v.targets, None)
74 76 v = self.client.load_balanced_view(-1)
75 77 self.assertEquals(v.targets, [self.client.ids[-1]])
76 78 v = self.client.load_balanced_view('all')
77 79 self.assertEquals(v.targets, self.client.ids)
78 80
79 81 def test_targets(self):
80 82 """test various valid targets arguments"""
81 83 build = self.client._build_targets
82 84 ids = self.client.ids
83 85 idents,targets = build(None)
84 86 self.assertEquals(ids, targets)
85 87
86 88 def test_clear(self):
87 89 """test clear behavior"""
88 90 # self.add_engines(2)
89 91 v = self.client[:]
90 92 v.block=True
91 93 v.push(dict(a=5))
92 94 v.pull('a')
93 95 id0 = self.client.ids[-1]
94 96 self.client.clear(targets=id0, block=True)
95 97 a = self.client[:-1].get('a')
96 98 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
97 99 self.client.clear(block=True)
98 100 for i in self.client.ids:
99 101 # print i
100 102 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
101 103
102 104 def test_get_result(self):
103 105 """test getting results from the Hub."""
104 106 c = clientmod.Client(profile='iptest')
105 107 # self.add_engines(1)
106 108 t = c.ids[-1]
107 109 ar = c[t].apply_async(wait, 1)
108 110 # give the monitor time to notice the message
109 111 time.sleep(.25)
110 112 ahr = self.client.get_result(ar.msg_ids)
111 113 self.assertTrue(isinstance(ahr, AsyncHubResult))
112 114 self.assertEquals(ahr.get(), ar.get())
113 115 ar2 = self.client.get_result(ar.msg_ids)
114 116 self.assertFalse(isinstance(ar2, AsyncHubResult))
115 117 c.close()
116 118
117 119 def test_ids_list(self):
118 120 """test client.ids"""
119 121 # self.add_engines(2)
120 122 ids = self.client.ids
121 123 self.assertEquals(ids, self.client._ids)
122 124 self.assertFalse(ids is self.client._ids)
123 125 ids.remove(ids[-1])
124 126 self.assertNotEquals(ids, self.client._ids)
125 127
126 128 def test_queue_status(self):
127 129 # self.addEngine(4)
128 130 ids = self.client.ids
129 131 id0 = ids[0]
130 132 qs = self.client.queue_status(targets=id0)
131 133 self.assertTrue(isinstance(qs, dict))
132 134 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
133 135 allqs = self.client.queue_status()
134 136 self.assertTrue(isinstance(allqs, dict))
135 self.assertEquals(sorted(allqs.keys()), sorted(self.client.ids + ['unassigned']))
137 intkeys = list(allqs.keys())
138 intkeys.remove('unassigned')
139 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
136 140 unassigned = allqs.pop('unassigned')
137 141 for eid,qs in allqs.items():
138 142 self.assertTrue(isinstance(qs, dict))
139 143 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
140 144
141 145 def test_shutdown(self):
142 146 # self.addEngine(4)
143 147 ids = self.client.ids
144 148 id0 = ids[0]
145 149 self.client.shutdown(id0, block=True)
146 150 while id0 in self.client.ids:
147 151 time.sleep(0.1)
148 152 self.client.spin()
149 153
150 154 self.assertRaises(IndexError, lambda : self.client[id0])
151 155
152 156 def test_result_status(self):
153 157 pass
154 158 # to be written
155 159
156 160 def test_db_query_dt(self):
157 161 """test db query by date"""
158 162 hist = self.client.hub_history()
159 middle = self.client.db_query({'msg_id' : hist[len(hist)/2]})[0]
163 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
160 164 tic = middle['submitted']
161 165 before = self.client.db_query({'submitted' : {'$lt' : tic}})
162 166 after = self.client.db_query({'submitted' : {'$gte' : tic}})
163 167 self.assertEquals(len(before)+len(after),len(hist))
164 168 for b in before:
165 169 self.assertTrue(b['submitted'] < tic)
166 170 for a in after:
167 171 self.assertTrue(a['submitted'] >= tic)
168 172 same = self.client.db_query({'submitted' : tic})
169 173 for s in same:
170 174 self.assertTrue(s['submitted'] == tic)
171 175
172 176 def test_db_query_keys(self):
173 177 """test extracting subset of record keys"""
174 178 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
175 179 for rec in found:
176 180 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
177 181
178 182 def test_db_query_msg_id(self):
179 183 """ensure msg_id is always in db queries"""
180 184 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
181 185 for rec in found:
182 186 self.assertTrue('msg_id' in rec.keys())
183 187 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
184 188 for rec in found:
185 189 self.assertTrue('msg_id' in rec.keys())
186 190 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
187 191 for rec in found:
188 192 self.assertTrue('msg_id' in rec.keys())
189 193
190 194 def test_db_query_in(self):
191 195 """test db query with '$in','$nin' operators"""
192 196 hist = self.client.hub_history()
193 197 even = hist[::2]
194 198 odd = hist[1::2]
195 199 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
196 200 found = [ r['msg_id'] for r in recs ]
197 201 self.assertEquals(set(even), set(found))
198 202 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
199 203 found = [ r['msg_id'] for r in recs ]
200 204 self.assertEquals(set(odd), set(found))
201 205
202 206 def test_hub_history(self):
203 207 hist = self.client.hub_history()
204 208 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
205 209 recdict = {}
206 210 for rec in recs:
207 211 recdict[rec['msg_id']] = rec
208 212
209 213 latest = datetime(1984,1,1)
210 214 for msg_id in hist:
211 215 rec = recdict[msg_id]
212 216 newt = rec['submitted']
213 217 self.assertTrue(newt >= latest)
214 218 latest = newt
215 219 ar = self.client[-1].apply_async(lambda : 1)
216 220 ar.get()
217 221 time.sleep(0.25)
218 222 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
219 223
220 224 def test_resubmit(self):
221 225 def f():
222 226 import random
223 227 return random.random()
224 228 v = self.client.load_balanced_view()
225 229 ar = v.apply_async(f)
226 230 r1 = ar.get(1)
227 231 ahr = self.client.resubmit(ar.msg_ids)
228 232 r2 = ahr.get(1)
229 233 self.assertFalse(r1 == r2)
230 234
231 235 def test_resubmit_inflight(self):
232 236 """ensure ValueError on resubmit of inflight task"""
233 237 v = self.client.load_balanced_view()
234 238 ar = v.apply_async(time.sleep,1)
235 239 # give the message a chance to arrive
236 240 time.sleep(0.2)
237 241 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
238 242 ar.get(2)
239 243
240 244 def test_resubmit_badkey(self):
241 245 """ensure KeyError on resubmit of nonexistant task"""
242 246 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
243 247
244 248 def test_purge_results(self):
245 249 # ensure there are some tasks
246 250 for i in range(5):
247 251 self.client[:].apply_sync(lambda : 1)
248 252 hist = self.client.hub_history()
249 253 self.client.purge_results(hist[-1])
250 254 newhist = self.client.hub_history()
251 255 self.assertEquals(len(newhist)+1,len(hist))
252 256
253 257 def test_purge_all_results(self):
254 258 self.client.purge_results('all')
255 259 hist = self.client.hub_history()
256 260 self.assertEquals(len(hist), 0)
257 261
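
The `from __future__ import division` plus `len(hist)//2` change in this file is the py3k true-division fix: under true division, `/` on two ints yields a float, which is not a legal sequence index. Floor division keeps the arithmetic an int on both interpreters:

    from __future__ import division  # makes / behave as it does on Python 3

    hist = ['m0', 'm1', 'm2', 'm3', 'm4']
    print(len(hist) / 2)            # 2.5, a float, unusable as a list index
    middle = hist[len(hist) // 2]   # 'm2'; // stays an int everywhere
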
@@ -1,178 +1,179 @@
1 1 """Tests for db backends
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 from __future__ import division
19 20
20 21 import tempfile
21 22 import time
22 23
23 24 from datetime import datetime, timedelta
24 25 from unittest import TestCase
25 26
26 27 from nose import SkipTest
27 28
28 29 from IPython.parallel import error
29 30 from IPython.parallel.controller.dictdb import DictDB
30 31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 32 from IPython.parallel.controller.hub import init_record, empty_record
32 33
33 34 from IPython.zmq.session import Session
34 35
35 36
36 37 #-------------------------------------------------------------------------------
37 38 # TestCases
38 39 #-------------------------------------------------------------------------------
39 40
40 41 class TestDictBackend(TestCase):
41 42 def setUp(self):
42 43 self.session = Session()
43 44 self.db = self.create_db()
44 45 self.load_records(16)
45 46
46 47 def create_db(self):
47 48 return DictDB()
48 49
49 50 def load_records(self, n=1):
50 51 """load n records for testing"""
51 52 #sleep 1/10 s, to ensure the timestamp differs from previous calls
52 53 time.sleep(0.1)
53 54 msg_ids = []
54 55 for i in range(n):
55 56 msg = self.session.msg('apply_request', content=dict(a=5))
56 57 msg['buffers'] = []
57 58 rec = init_record(msg)
58 59 msg_ids.append(msg['msg_id'])
59 60 self.db.add_record(msg['msg_id'], rec)
60 61 return msg_ids
61 62
62 63 def test_add_record(self):
63 64 before = self.db.get_history()
64 65 self.load_records(5)
65 66 after = self.db.get_history()
66 67 self.assertEquals(len(after), len(before)+5)
67 68 self.assertEquals(after[:-5],before)
68 69
69 70 def test_drop_record(self):
70 71 msg_id = self.load_records()[-1]
71 72 rec = self.db.get_record(msg_id)
72 73 self.db.drop_record(msg_id)
73 74 self.assertRaises(KeyError,self.db.get_record, msg_id)
74 75
75 76 def _round_to_millisecond(self, dt):
76 77 """necessary because mongodb rounds microseconds"""
77 78 micro = dt.microsecond
78 79 extra = int(str(micro)[-3:])
79 80 return dt - timedelta(microseconds=extra)
80 81
81 82 def test_update_record(self):
82 83 now = self._round_to_millisecond(datetime.now())
83 84 #
84 85 msg_id = self.db.get_history()[-1]
85 86 rec1 = self.db.get_record(msg_id)
86 87 data = {'stdout': 'hello there', 'completed' : now}
87 88 self.db.update_record(msg_id, data)
88 89 rec2 = self.db.get_record(msg_id)
89 90 self.assertEquals(rec2['stdout'], 'hello there')
90 91 self.assertEquals(rec2['completed'], now)
91 92 rec1.update(data)
92 93 self.assertEquals(rec1, rec2)
93 94
94 95 # def test_update_record_bad(self):
95 96 # """test updating nonexistant records"""
96 97 # msg_id = str(uuid.uuid4())
97 98 # data = {'stdout': 'hello there'}
98 99 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
99 100
100 101 def test_find_records_dt(self):
101 102 """test finding records by date"""
102 103 hist = self.db.get_history()
103 middle = self.db.get_record(hist[len(hist)/2])
104 middle = self.db.get_record(hist[len(hist)//2])
104 105 tic = middle['submitted']
105 106 before = self.db.find_records({'submitted' : {'$lt' : tic}})
106 107 after = self.db.find_records({'submitted' : {'$gte' : tic}})
107 108 self.assertEquals(len(before)+len(after),len(hist))
108 109 for b in before:
109 110 self.assertTrue(b['submitted'] < tic)
110 111 for a in after:
111 112 self.assertTrue(a['submitted'] >= tic)
112 113 same = self.db.find_records({'submitted' : tic})
113 114 for s in same:
114 115 self.assertTrue(s['submitted'] == tic)
115 116
116 117 def test_find_records_keys(self):
117 118 """test extracting subset of record keys"""
118 119 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
119 120 for rec in found:
120 121 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
121 122
122 123 def test_find_records_msg_id(self):
123 124 """ensure msg_id is always in found records"""
124 125 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
125 126 for rec in found:
126 127 self.assertTrue('msg_id' in rec.keys())
127 128 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
128 129 for rec in found:
129 130 self.assertTrue('msg_id' in rec.keys())
130 131 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
131 132 for rec in found:
132 133 self.assertTrue('msg_id' in rec.keys())
133 134
134 135 def test_find_records_in(self):
135 136 """test finding records with '$in','$nin' operators"""
136 137 hist = self.db.get_history()
137 138 even = hist[::2]
138 139 odd = hist[1::2]
139 140 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
140 141 found = [ r['msg_id'] for r in recs ]
141 142 self.assertEquals(set(even), set(found))
142 143 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
143 144 found = [ r['msg_id'] for r in recs ]
144 145 self.assertEquals(set(odd), set(found))
145 146
146 147 def test_get_history(self):
147 148 msg_ids = self.db.get_history()
148 149 latest = datetime(1984,1,1)
149 150 for msg_id in msg_ids:
150 151 rec = self.db.get_record(msg_id)
151 152 newt = rec['submitted']
152 153 self.assertTrue(newt >= latest)
153 154 latest = newt
154 155 msg_id = self.load_records(1)[-1]
155 156 self.assertEquals(self.db.get_history()[-1],msg_id)
156 157
157 158 def test_datetime(self):
158 159 """get/set timestamps with datetime objects"""
159 160 msg_id = self.db.get_history()[-1]
160 161 rec = self.db.get_record(msg_id)
161 162 self.assertTrue(isinstance(rec['submitted'], datetime))
162 163 self.db.update_record(msg_id, dict(completed=datetime.now()))
163 164 rec = self.db.get_record(msg_id)
164 165 self.assertTrue(isinstance(rec['completed'], datetime))
165 166
166 167 def test_drop_matching(self):
167 168 msg_ids = self.load_records(10)
168 169 query = {'msg_id' : {'$in':msg_ids}}
169 170 self.db.drop_matching_records(query)
170 171 recs = self.db.find_records(query)
171 self.assertTrue(len(recs)==0)
172 self.assertEquals(len(recs), 0)
172 173
173 174 class TestSQLiteBackend(TestDictBackend):
174 175 def create_db(self):
175 176 return SQLiteDB(location=tempfile.gettempdir())
176 177
177 178 def tearDown(self):
178 179 self.db._db.close()
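
`_round_to_millisecond` exists because timestamps that round-trip through a backend like MongoDB keep only millisecond precision, so exact-equality queries on a freshly generated `datetime` would otherwise fail. An equivalent, slightly more direct formulation of the helper, offered as a sketch rather than a proposed change:

    from datetime import datetime, timedelta

    def round_to_millisecond(dt):
        """Drop sub-millisecond precision from a datetime."""
        return dt - timedelta(microseconds=dt.microsecond % 1000)

    now = round_to_millisecond(datetime.now())
    assert now.microsecond % 1000 == 0
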
@@ -1,446 +1,449 @@
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25
26 26 from IPython import parallel as pmod
27 27 from IPython.parallel import error
28 28 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
29 29 from IPython.parallel import DirectView
30 30 from IPython.parallel.util import interactive
31 31
32 32 from IPython.parallel.tests import add_engines
33 33
34 34 from .clienttest import ClusterTestCase, crash, wait, skip_without
35 35
36 36 def setup():
37 37 add_engines(3)
38 38
39 39 class TestView(ClusterTestCase):
40 40
41 41 def test_z_crash_mux(self):
42 42 """test graceful handling of engine death (direct)"""
43 43 # self.add_engines(1)
44 44 eid = self.client.ids[-1]
45 45 ar = self.client[eid].apply_async(crash)
46 self.assertRaisesRemote(error.EngineError, ar.get)
46 self.assertRaisesRemote(error.EngineError, ar.get, 10)
47 47 eid = ar.engine_id
48 48 tic = time.time()
49 49 while eid in self.client.ids and time.time()-tic < 5:
50 50 time.sleep(.01)
51 51 self.client.spin()
52 52 self.assertFalse(eid in self.client.ids, "Engine should have died")
53 53
54 54 def test_push_pull(self):
55 55 """test pushing and pulling"""
56 56 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
57 57 t = self.client.ids[-1]
58 58 v = self.client[t]
59 59 push = v.push
60 60 pull = v.pull
61 61 v.block=True
62 62 nengines = len(self.client)
63 63 push({'data':data})
64 64 d = pull('data')
65 65 self.assertEquals(d, data)
66 66 self.client[:].push({'data':data})
67 67 d = self.client[:].pull('data', block=True)
68 68 self.assertEquals(d, nengines*[data])
69 69 ar = push({'data':data}, block=False)
70 70 self.assertTrue(isinstance(ar, AsyncResult))
71 71 r = ar.get()
72 72 ar = self.client[:].pull('data', block=False)
73 73 self.assertTrue(isinstance(ar, AsyncResult))
74 74 r = ar.get()
75 75 self.assertEquals(r, nengines*[data])
76 76 self.client[:].push(dict(a=10,b=20))
77 77 r = self.client[:].pull(('a','b'), block=True)
78 78 self.assertEquals(r, nengines*[[10,20]])
79 79
80 80 def test_push_pull_function(self):
81 81 "test pushing and pulling functions"
82 82 def testf(x):
83 83 return 2.0*x
84 84
85 85 t = self.client.ids[-1]
86 86 v = self.client[t]
87 87 v.block=True
88 88 push = v.push
89 89 pull = v.pull
90 90 execute = v.execute
91 91 push({'testf':testf})
92 92 r = pull('testf')
93 93 self.assertEqual(r(1.0), testf(1.0))
94 94 execute('r = testf(10)')
95 95 r = pull('r')
96 96 self.assertEquals(r, testf(10))
97 97 ar = self.client[:].push({'testf':testf}, block=False)
98 98 ar.get()
99 99 ar = self.client[:].pull('testf', block=False)
100 100 rlist = ar.get()
101 101 for r in rlist:
102 102 self.assertEqual(r(1.0), testf(1.0))
103 103 execute("def g(x): return x*x")
104 104 r = pull(('testf','g'))
105 105 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
106 106
107 107 def test_push_function_globals(self):
108 108 """test that pushed functions have access to globals"""
109 109 @interactive
110 110 def geta():
111 111 return a
112 112 # self.add_engines(1)
113 113 v = self.client[-1]
114 114 v.block=True
115 115 v['f'] = geta
116 116 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
117 117 v.execute('a=5')
118 118 v.execute('b=f()')
119 119 self.assertEquals(v['b'], 5)
120 120
121 121 def test_push_function_defaults(self):
122 122 """test that pushed functions preserve default args"""
123 123 def echo(a=10):
124 124 return a
125 125 v = self.client[-1]
126 126 v.block=True
127 127 v['f'] = echo
128 128 v.execute('b=f()')
129 129 self.assertEquals(v['b'], 10)
130 130
131 131 def test_get_result(self):
132 132 """test getting results from the Hub."""
133 133 c = pmod.Client(profile='iptest')
134 134 # self.add_engines(1)
135 135 t = c.ids[-1]
136 136 v = c[t]
137 137 v2 = self.client[t]
138 138 ar = v.apply_async(wait, 1)
139 139 # give the monitor time to notice the message
140 140 time.sleep(.25)
141 141 ahr = v2.get_result(ar.msg_ids)
142 142 self.assertTrue(isinstance(ahr, AsyncHubResult))
143 143 self.assertEquals(ahr.get(), ar.get())
144 144 ar2 = v2.get_result(ar.msg_ids)
145 145 self.assertFalse(isinstance(ar2, AsyncHubResult))
146 146 c.spin()
147 147 c.close()
148 148
149 149 def test_run_newline(self):
150 150 """test that run appends newline to files"""
151 151 tmpfile = mktemp()
152 152 with open(tmpfile, 'w') as f:
153 153 f.write("""def g():
154 154 return 5
155 155 """)
156 156 v = self.client[-1]
157 157 v.run(tmpfile, block=True)
158 158 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
159 159
160 160 def test_apply_tracked(self):
161 161 """test tracking for apply"""
162 162 # self.add_engines(1)
163 163 t = self.client.ids[-1]
164 164 v = self.client[t]
165 165 v.block=False
166 166 def echo(n=1024*1024, **kwargs):
167 167 with v.temp_flags(**kwargs):
168 168 return v.apply(lambda x: x, 'x'*n)
169 169 ar = echo(1, track=False)
170 170 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
171 171 self.assertTrue(ar.sent)
172 172 ar = echo(track=True)
173 173 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
174 174 self.assertEquals(ar.sent, ar._tracker.done)
175 175 ar._tracker.wait()
176 176 self.assertTrue(ar.sent)
177 177
178 178 def test_push_tracked(self):
179 179 t = self.client.ids[-1]
180 180 ns = dict(x='x'*1024*1024)
181 181 v = self.client[t]
182 182 ar = v.push(ns, block=False, track=False)
183 183 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
184 184 self.assertTrue(ar.sent)
185 185
186 186 ar = v.push(ns, block=False, track=True)
187 187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 188 self.assertEquals(ar.sent, ar._tracker.done)
189 189 ar._tracker.wait()
190 190 self.assertTrue(ar.sent)
191 191 ar.get()
192 192
193 193 def test_scatter_tracked(self):
194 194 t = self.client.ids
195 195 x='x'*1024*1024
196 196 ar = self.client[t].scatter('x', x, block=False, track=False)
197 197 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
198 198 self.assertTrue(ar.sent)
199 199
200 200 ar = self.client[t].scatter('x', x, block=False, track=True)
201 201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 202 self.assertEquals(ar.sent, ar._tracker.done)
203 203 ar._tracker.wait()
204 204 self.assertTrue(ar.sent)
205 205 ar.get()
206 206
207 207 def test_remote_reference(self):
208 208 v = self.client[-1]
209 209 v['a'] = 123
210 210 ra = pmod.Reference('a')
211 211 b = v.apply_sync(lambda x: x, ra)
212 212 self.assertEquals(b, 123)
213 213
214 214
215 215 def test_scatter_gather(self):
216 216 view = self.client[:]
217 217 seq1 = range(16)
218 218 view.scatter('a', seq1)
219 219 seq2 = view.gather('a', block=True)
220 220 self.assertEquals(seq2, seq1)
221 221 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
222 222
223 223 @skip_without('numpy')
224 224 def test_scatter_gather_numpy(self):
225 225 import numpy
226 226 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
227 227 view = self.client[:]
228 228 a = numpy.arange(64)
229 229 view.scatter('a', a)
230 230 b = view.gather('a', block=True)
231 231 assert_array_equal(b, a)
232 232
233 233 def test_map(self):
234 234 view = self.client[:]
235 235 def f(x):
236 236 return x**2
237 237 data = range(16)
238 238 r = view.map_sync(f, data)
239 239 self.assertEquals(r, map(f, data))
240 240
241 241 def test_scatterGatherNonblocking(self):
242 242 data = range(16)
243 243 view = self.client[:]
244 244 view.scatter('a', data, block=False)
245 245 ar = view.gather('a', block=False)
246 246 self.assertEquals(ar.get(), data)
247 247
248 248 @skip_without('numpy')
249 249 def test_scatter_gather_numpy_nonblocking(self):
250 250 import numpy
251 251 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
252 252 a = numpy.arange(64)
253 253 view = self.client[:]
254 254 ar = view.scatter('a', a, block=False)
255 255 self.assertTrue(isinstance(ar, AsyncResult))
256 256 amr = view.gather('a', block=False)
257 257 self.assertTrue(isinstance(amr, AsyncMapResult))
258 258 assert_array_equal(amr.get(), a)
259 259
260 260 def test_execute(self):
261 261 view = self.client[:]
262 262 # self.client.debug=True
263 263 execute = view.execute
264 264 ar = execute('c=30', block=False)
265 265 self.assertTrue(isinstance(ar, AsyncResult))
266 266 ar = execute('d=[0,1,2]', block=False)
267 267 self.client.wait(ar, 1)
268 268 self.assertEquals(len(ar.get()), len(self.client))
269 269 for c in view['c']:
270 270 self.assertEquals(c, 30)
271 271
272 272 def test_abort(self):
273 273 view = self.client[-1]
274 274 ar = view.execute('import time; time.sleep(0.25)', block=False)
275 275 ar2 = view.apply_async(lambda : 2)
276 276 ar3 = view.apply_async(lambda : 3)
277 277 view.abort(ar2)
278 278 view.abort(ar3.msg_ids)
279 279 self.assertRaises(error.TaskAborted, ar2.get)
280 280 self.assertRaises(error.TaskAborted, ar3.get)
281 281
282 282 def test_temp_flags(self):
283 283 view = self.client[-1]
284 284 view.block=True
285 285 with view.temp_flags(block=False):
286 286 self.assertFalse(view.block)
287 287 self.assertTrue(view.block)
288 288
289 289 def test_importer(self):
290 290 view = self.client[-1]
291 291 view.clear(block=True)
292 292 with view.importer:
293 293 import re
294 294
295 295 @interactive
296 296 def findall(pat, s):
297 297 # this globals() step isn't necessary in real code
298 298 # only to prevent a closure in the test
299 299 re = globals()['re']
300 300 return re.findall(pat, s)
301 301
302 302 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
303 303
304 304 # parallel magic tests
305 305
306 306 def test_magic_px_blocking(self):
307 307 ip = get_ipython()
308 308 v = self.client[-1]
309 309 v.activate()
310 310 v.block=True
311 311
312 312 ip.magic_px('a=5')
313 313 self.assertEquals(v['a'], 5)
314 314 ip.magic_px('a=10')
315 315 self.assertEquals(v['a'], 10)
316 316 sio = StringIO()
317 317 savestdout = sys.stdout
318 318 sys.stdout = sio
319 319 ip.magic_px('print a')
320 320 sys.stdout = savestdout
321 321 sio.read()
322 322 self.assertTrue('[stdout:%i]'%v.targets in sio.buf)
323 323 self.assertTrue(sio.buf.rstrip().endswith('10'))
324 324 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
325 325
326 326 def test_magic_px_nonblocking(self):
327 327 ip = get_ipython()
328 328 v = self.client[-1]
329 329 v.activate()
330 330 v.block=False
331 331
332 332 ip.magic_px('a=5')
333 333 self.assertEquals(v['a'], 5)
334 334 ip.magic_px('a=10')
335 335 self.assertEquals(v['a'], 10)
336 336 sio = StringIO()
337 337 savestdout = sys.stdout
338 338 sys.stdout = sio
339 339 ip.magic_px('print a')
340 340 sys.stdout = savestdout
341 341 sio.read()
342 342 self.assertFalse('[stdout:%i]'%v.targets in sio.buf)
343 343 ip.magic_px('1/0')
344 344 ar = v.get_result(-1)
345 345 self.assertRaisesRemote(ZeroDivisionError, ar.get)
346 346
347 347 def test_magic_autopx_blocking(self):
348 348 ip = get_ipython()
349 349 v = self.client[-1]
350 350 v.activate()
351 351 v.block=True
352 352
353 353 sio = StringIO()
354 354 savestdout = sys.stdout
355 355 sys.stdout = sio
356 356 ip.magic_autopx()
357 357 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
358 358 ip.run_cell('print b')
359 359 ip.run_cell("b/c")
360 360 ip.run_code(compile('b*=2', '', 'single'))
361 361 ip.magic_autopx()
362 362 sys.stdout = savestdout
363 363 sio.read()
364 364 output = sio.buf.strip()
365 365 self.assertTrue(output.startswith('%autopx enabled'))
366 366 self.assertTrue(output.endswith('%autopx disabled'))
367 367 self.assertTrue('RemoteError: ZeroDivisionError' in output)
368 368 ar = v.get_result(-2)
369 369 self.assertEquals(v['a'], 5)
370 370 self.assertEquals(v['b'], 20)
371 371 self.assertRaisesRemote(ZeroDivisionError, ar.get)
372 372
373 373 def test_magic_autopx_nonblocking(self):
374 374 ip = get_ipython()
375 375 v = self.client[-1]
376 376 v.activate()
377 377 v.block=False
378 378
379 379 sio = StringIO()
380 380 savestdout = sys.stdout
381 381 sys.stdout = sio
382 382 ip.magic_autopx()
383 383 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
384 384 ip.run_cell('print b')
385 385 ip.run_cell("b/c")
386 386 ip.run_code(compile('b*=2', '', 'single'))
387 387 ip.magic_autopx()
388 388 sys.stdout = savestdout
389 389 sio.read()
390 390 output = sio.buf.strip()
391 391 self.assertTrue(output.startswith('%autopx enabled'))
392 392 self.assertTrue(output.endswith('%autopx disabled'))
393 393 self.assertFalse('ZeroDivisionError' in output)
394 394 ar = v.get_result(-2)
395 395 self.assertEquals(v['a'], 5)
396 396 self.assertEquals(v['b'], 20)
397 397 self.assertRaisesRemote(ZeroDivisionError, ar.get)
398 398
399 399 def test_magic_result(self):
400 400 ip = get_ipython()
401 401 v = self.client[-1]
402 402 v.activate()
403 403 v['a'] = 111
404 404 ra = v['a']
405 405
406 406 ar = ip.magic_result()
407 407 self.assertEquals(ar.msg_ids, [v.history[-1]])
408 408 self.assertEquals(ar.get(), 111)
409 409 ar = ip.magic_result('-2')
410 410 self.assertEquals(ar.msg_ids, [v.history[-2]])
411 411
412 412 def test_unicode_execute(self):
413 413 """test executing unicode strings"""
414 414 v = self.client[-1]
415 415 v.block=True
416 code=u"a=u'é'"
416 if sys.version_info[0] >= 3:
417 code="a='é'"
418 else:
419 code=u"a=u'é'"
417 420 v.execute(code)
418 421 self.assertEquals(v['a'], u'é')
419 422
420 423 def test_unicode_apply_result(self):
421 424 """test unicode apply results"""
422 425 v = self.client[-1]
423 426 r = v.apply_sync(lambda : u'é')
424 427 self.assertEquals(r, u'é')
425 428
426 429 def test_unicode_apply_arg(self):
427 430 """test passing unicode arguments to apply"""
428 431 v = self.client[-1]
429 432
430 433 @interactive
431 434 def check_unicode(a, check):
432 435 assert isinstance(a, unicode), "%r is not unicode"%a
433 436 assert isinstance(check, bytes), "%r is not bytes"%check
434 437 assert a.encode('utf8') == check, "%s != %s"%(a,check)
435 438
436 for s in [ u'é', u'ßø®∫','asdf'.decode() ]:
439 for s in [ u'é', u'ßø®∫',u'asdf' ]:
437 440 try:
438 441 v.apply_sync(check_unicode, s, s.encode('utf8'))
439 442 except error.RemoteError as e:
440 443 if e.ename == 'AssertionError':
441 444 self.fail(e.evalue)
442 445 else:
443 446 raise e
444 447
445 448
446 449
@@ -1,450 +1,456 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 44 from IPython.utils.newserialized import serialize, unserialize
45 45 from IPython.zmq.log import EnginePUBHandler
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Classes
49 49 #-----------------------------------------------------------------------------
50 50
51 51 class Namespace(dict):
52 52 """Subclass of dict for attribute access to keys."""
53 53
54 54 def __getattr__(self, key):
55 55 """getattr aliased to getitem"""
56 56 if key in self.iterkeys():
57 57 return self[key]
58 58 else:
59 59 raise NameError(key)
60 60
61 61 def __setattr__(self, key, value):
62 62 """setattr aliased to setitem, with strict"""
63 63 if hasattr(dict, key):
64 64 raise KeyError("Cannot override dict keys %r"%key)
65 65 self[key] = value
66 66
67 67
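A minimal usage sketch for Namespace, assuming only the class above (illustrative values): keys double as attributes, and assigning a name that collides with a dict method is rejected.

    ns = Namespace(a=1)
    assert ns.a == 1          # __getattr__ falls through to __getitem__
    ns.b = 2                  # __setattr__ stores into the dict
    assert ns['b'] == 2
    try:
        ns.keys = 3           # shadows dict.keys, so __setattr__ raises
    except KeyError:
        pass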
68 68 class ReverseDict(dict):
69 69 """simple double-keyed subset of dict methods."""
70 70
71 71 def __init__(self, *args, **kwargs):
72 72 dict.__init__(self, *args, **kwargs)
73 73 self._reverse = dict()
74 74 for key, value in self.iteritems():
75 75 self._reverse[value] = key
76 76
77 77 def __getitem__(self, key):
78 78 try:
79 79 return dict.__getitem__(self, key)
80 80 except KeyError:
81 81 return self._reverse[key]
82 82
83 83 def __setitem__(self, key, value):
84 84 if key in self._reverse:
85 85 raise KeyError("Can't have key %r on both sides!"%key)
86 86 dict.__setitem__(self, key, value)
87 87 self._reverse[value] = key
88 88
89 89 def pop(self, key):
90 90 value = dict.pop(self, key)
91 91 self._reverse.pop(value)
92 92 return value
93 93
94 94 def get(self, key, default=None):
95 95 try:
96 96 return self[key]
97 97 except KeyError:
98 98 return default
99 99
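A short sketch of the double-keyed lookup ReverseDict provides (the url value here is purely illustrative):

    rd = ReverseDict()
    rd['hub'] = 'tcp://127.0.0.1:10101'
    assert rd['hub'] == 'tcp://127.0.0.1:10101'    # forward lookup
    assert rd['tcp://127.0.0.1:10101'] == 'hub'    # reverse lookup
    rd.pop('hub')                                  # drops both mappings
    assert rd.get('hub') is None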
100 100 #-----------------------------------------------------------------------------
101 101 # Functions
102 102 #-----------------------------------------------------------------------------
103 103
104 def ensure_bytes(s):
105 """ensure that an object is bytes"""
106 if isinstance(s, unicode):
107 s = s.encode(sys.getdefaultencoding(), 'replace')
108 return s
109
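ensure_bytes is the new py3k-oriented helper this changeset threads through the parallel code: unicode is encoded (with 'replace' on failure) and bytes pass through untouched. A quick sketch, assuming ascii-safe input:

    assert ensure_bytes(u'session-key') == b'session-key'   # unicode -> bytes
    assert ensure_bytes(b'session-key') == b'session-key'   # bytes unchanged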
104 110 def validate_url(url):
105 111 """validate a url for zeromq"""
106 112 if not isinstance(url, basestring):
107 113 raise TypeError("url must be a string, not %r"%type(url))
108 114 url = url.lower()
109 115
110 116 proto_addr = url.split('://')
111 117 assert len(proto_addr) == 2, 'Invalid url: %r'%url
112 118 proto, addr = proto_addr
113 119 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
114 120
115 121 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
116 122 # author: Remi Sabourin
117 123 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
118 124
119 125 if proto == 'tcp':
120 126 lis = addr.split(':')
121 127 assert len(lis) == 2, 'Invalid url: %r'%url
122 128 addr,s_port = lis
123 129 try:
124 130 port = int(s_port)
125 131 except ValueError:
126 132 raise AssertionError("Invalid port %r in url: %r"%(port, url))
127 133
128 134 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
129 135
130 136 else:
131 137 # only validate tcp urls currently
132 138 pass
133 139
134 140 return True
135 141
136 142
137 143 def validate_url_container(container):
138 144 """validate a potentially nested collection of urls."""
139 145 if isinstance(container, basestring):
140 146 url = container
141 147 return validate_url(url)
142 148 elif isinstance(container, dict):
143 149 container = container.itervalues()
144 150
145 151 for element in container:
146 152 validate_url_container(element)
147 153
148 154
149 155 def split_url(url):
150 156 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
151 157 proto_addr = url.split('://')
152 158 assert len(proto_addr) == 2, 'Invalid url: %r'%url
153 159 proto, addr = proto_addr
154 160 lis = addr.split(':')
155 161 assert len(lis) == 2, 'Invalid url: %r'%url
156 162 addr,s_port = lis
157 163 return proto,addr,s_port
158 164
159 165 def disambiguate_ip_address(ip, location=None):
160 166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
161 167 ones, based on the location (default interpretation of location is localhost)."""
162 168 if ip in ('0.0.0.0', '*'):
163 169 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
164 170 if location is None or location in external_ips:
165 171 ip='127.0.0.1'
166 172 elif location:
167 173 return location
168 174 return ip
169 175
170 176 def disambiguate_url(url, location=None):
171 177 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
172 178 ones, based on the location (default interpretation is localhost).
173 179
174 180 This is for zeromq urls, such as tcp://*:10101."""
175 181 try:
176 182 proto,ip,port = split_url(url)
177 183 except AssertionError:
178 184 # probably not tcp url; could be ipc, etc.
179 185 return url
180 186
181 187 ip = disambiguate_ip_address(ip,location)
182 188
183 189 return "%s://%s:%s"%(proto,ip,port)
184 190
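A sketch of the rewrite these two helpers perform on wildcard bind urls (no location hint given, so localhost is assumed; the port is illustrative):

    assert disambiguate_url('tcp://*:10101') == 'tcp://127.0.0.1:10101'
    # non-tcp urls fail split_url's assertions and pass through unchanged
    assert disambiguate_url('ipc:///tmp/ctrl.sock') == 'ipc:///tmp/ctrl.sock'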
185 191 def serialize_object(obj, threshold=64e-6):
186 192 """Serialize an object into a list of sendable buffers.
187 193
188 194 Parameters
189 195 ----------
190 196
191 197 obj : object
192 198 The object to be serialized
193 199 threshold : float
194 200 The threshold for not double-pickling the content.
195 201
196 202
197 203 Returns
198 204 -------
199 205 ('pmd', [bufs]) :
200 206 where pmd is the pickled metadata wrapper,
201 207 bufs is a list of data buffers
202 208 """
203 209 databuffers = []
204 210 if isinstance(obj, (list, tuple)):
205 211 clist = canSequence(obj)
206 212 slist = map(serialize, clist)
207 213 for s in slist:
208 214 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
209 215 databuffers.append(s.getData())
210 216 s.data = None
211 217 return pickle.dumps(slist,-1), databuffers
212 218 elif isinstance(obj, dict):
213 219 sobj = {}
214 220 for k in sorted(obj.iterkeys()):
215 221 s = serialize(can(obj[k]))
216 222 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
217 223 databuffers.append(s.getData())
218 224 s.data = None
219 225 sobj[k] = s
220 226 return pickle.dumps(sobj,-1),databuffers
221 227 else:
222 228 s = serialize(can(obj))
223 229 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
224 230 databuffers.append(s.getData())
225 231 s.data = None
226 232 return pickle.dumps(s,-1),databuffers
227 233
228 234
229 235 def unserialize_object(bufs):
230 236 """reconstruct an object serialized by serialize_object from data buffers."""
231 237 bufs = list(bufs)
232 238 sobj = pickle.loads(bufs.pop(0))
233 239 if isinstance(sobj, (list, tuple)):
234 240 for s in sobj:
235 241 if s.data is None:
236 242 s.data = bufs.pop(0)
237 243 return uncanSequence(map(unserialize, sobj)), bufs
238 244 elif isinstance(sobj, dict):
239 245 newobj = {}
240 246 for k in sorted(sobj.iterkeys()):
241 247 s = sobj[k]
242 248 if s.data is None:
243 249 s.data = bufs.pop(0)
244 250 newobj[k] = uncan(unserialize(s))
245 251 return newobj, bufs
246 252 else:
247 253 if sobj.data is None:
248 254 sobj.data = bufs.pop(0)
249 255 return uncan(unserialize(sobj)), bufs
250 256
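The pair above is symmetric; a roundtrip sketch with a small dict, where everything stays inline because nothing exceeds the ~64-byte threshold:

    pmd, bufs = serialize_object({'a': range(4), 'b': 'hi'})
    obj, remaining = unserialize_object([pmd] + bufs)
    assert obj == {'a': range(4), 'b': 'hi'} and remaining == []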
251 257 def pack_apply_message(f, args, kwargs, threshold=64e-6):
252 258 """pack up a function, args, and kwargs to be sent over the wire
253 259 as a series of buffers. Any object whose data is larger than `threshold`
254 260 will not have its data copied (currently only numpy arrays support zero-copy)"""
255 261 msg = [pickle.dumps(can(f),-1)]
256 262 databuffers = [] # for large objects
257 263 sargs, bufs = serialize_object(args,threshold)
258 264 msg.append(sargs)
259 265 databuffers.extend(bufs)
260 266 skwargs, bufs = serialize_object(kwargs,threshold)
261 267 msg.append(skwargs)
262 268 databuffers.extend(bufs)
263 269 msg.extend(databuffers)
264 270 return msg
265 271
266 272 def unpack_apply_message(bufs, g=None, copy=True):
267 273 """unpack f,args,kwargs from buffers packed by pack_apply_message()
268 274 Returns: original f,args,kwargs"""
269 275 bufs = list(bufs) # allow us to pop
270 276 assert len(bufs) >= 3, "not enough buffers!"
271 277 if not copy:
272 278 for i in range(3):
273 279 bufs[i] = bufs[i].bytes
274 280 cf = pickle.loads(bufs.pop(0))
275 281 sargs = list(pickle.loads(bufs.pop(0)))
276 282 skwargs = dict(pickle.loads(bufs.pop(0)))
277 283 # print sargs, skwargs
278 284 f = uncan(cf, g)
279 285 for sa in sargs:
280 286 if sa.data is None:
281 287 m = bufs.pop(0)
282 288 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
283 289 # always use a buffer, until memoryviews get sorted out
284 290 sa.data = buffer(m)
285 291 # disable memoryview support
286 292 # if copy:
287 293 # sa.data = buffer(m)
288 294 # else:
289 295 # sa.data = m.buffer
290 296 else:
291 297 if copy:
292 298 sa.data = m
293 299 else:
294 300 sa.data = m.bytes
295 301
296 302 args = uncanSequence(map(unserialize, sargs), g)
297 303 kwargs = {}
298 304 for k in sorted(skwargs.iterkeys()):
299 305 sa = skwargs[k]
300 306 if sa.data is None:
301 307 m = bufs.pop(0)
302 308 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
303 309 # always use a buffer, until memoryviews get sorted out
304 310 sa.data = buffer(m)
305 311 # disable memoryview support
306 312 # if copy:
307 313 # sa.data = buffer(m)
308 314 # else:
309 315 # sa.data = m.buffer
310 316 else:
311 317 if copy:
312 318 sa.data = m
313 319 else:
314 320 sa.data = m.bytes
315 321
316 322 kwargs[k] = uncan(unserialize(sa), g)
317 323
318 324 return f,args,kwargs
319 325
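A roundtrip sketch for this wire format (illustrative; the function must be closure-free, since reduce_code in pickleutil refuses closures):

    def add(a, b=1):
        return a + b

    msg = pack_apply_message(add, (5,), dict(b=2))
    f, args, kwargs = unpack_apply_message(msg)
    assert f(*args, **kwargs) == 7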
320 326 #--------------------------------------------------------------------------
321 327 # helpers for implementing old MEC API via view.apply
322 328 #--------------------------------------------------------------------------
323 329
324 330 def interactive(f):
325 331 """decorator for making functions appear as interactively defined.
326 332 This results in the function being linked to the user_ns as globals()
327 333 instead of the module globals().
328 334 """
329 335 f.__module__ = '__main__'
330 336 return f
331 337
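The decorator's only effect is rebinding __module__, so the function's globals resolve against the engine's user namespace once it is unpickled there. For instance:

    @interactive
    def report():
        return a   # 'a' is looked up in the engine's user_ns, not this module

    assert report.__module__ == '__main__'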
332 338 @interactive
333 339 def _push(ns):
334 340 """helper method for implementing `client.push` via `client.apply`"""
335 341 globals().update(ns)
336 342
337 343 @interactive
338 344 def _pull(keys):
339 345 """helper method for implementing `client.pull` via `client.apply`"""
340 346 user_ns = globals()
341 347 if isinstance(keys, (list,tuple, set)):
342 348 for key in keys:
343 349 if not user_ns.has_key(key):
344 350 raise NameError("name '%s' is not defined"%key)
345 351 return map(user_ns.get, keys)
346 352 else:
347 353 if not user_ns.has_key(keys):
348 354 raise NameError("name '%s' is not defined"%keys)
349 355 return user_ns.get(keys)
350 356
351 357 @interactive
352 358 def _execute(code):
353 359 """helper method for implementing `client.execute` via `client.apply`"""
354 360 exec code in globals()
355 361
356 362 #--------------------------------------------------------------------------
357 363 # extra process management utilities
358 364 #--------------------------------------------------------------------------
359 365
360 366 _random_ports = set()
361 367
362 368 def select_random_ports(n):
363 369 """Select and return n random ports that are available."""
364 370 ports = []
365 371 for i in xrange(n):
366 372 sock = socket.socket()
367 373 sock.bind(('', 0))
368 374 while sock.getsockname()[1] in _random_ports:
369 375 sock.close()
370 376 sock = socket.socket()
371 377 sock.bind(('', 0))
372 378 ports.append(sock)
373 379 for i, sock in enumerate(ports):
374 380 port = sock.getsockname()[1]
375 381 sock.close()
376 382 ports[i] = port
377 383 _random_ports.add(port)
378 384 return ports
379 385
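The helper binds throwaway sockets to port 0 so the OS picks free ports, then records the picks in _random_ports to avoid handing the same port out twice. A sketch:

    ports = select_random_ports(3)
    assert len(set(ports)) == 3     # three distinct, currently-free ports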
380 386 def signal_children(children):
381 387 """Relay interrupt/term signals to children, for more solid process cleanup."""
382 388 def terminate_children(sig, frame):
383 389 logging.critical("Got signal %i, terminating children..."%sig)
384 390 for child in children:
385 391 child.terminate()
386 392
387 393 sys.exit(sig != SIGINT)
388 394 # sys.exit(sig)
389 395 for sig in (SIGINT, SIGABRT, SIGTERM):
390 396 signal(sig, terminate_children)
391 397
392 398 def generate_exec_key(keyfile):
393 399 import uuid
394 400 newkey = str(uuid.uuid4())
395 401 with open(keyfile, 'w') as f:
396 402 # f.write('ipython-key ')
397 403 f.write(newkey+'\n')
398 404 # set user-only RW permissions (0600)
399 405 # this will have no effect on Windows
400 406 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
401 407
402 408
403 409 def integer_loglevel(loglevel):
404 410 try:
405 411 loglevel = int(loglevel)
406 412 except ValueError:
407 413 if isinstance(loglevel, str):
408 414 loglevel = getattr(logging, loglevel)
409 415 return loglevel
410 416
411 417 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
412 418 logger = logging.getLogger(logname)
413 419 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
414 420 # don't add a second PUBHandler
415 421 return
416 422 loglevel = integer_loglevel(loglevel)
417 423 lsock = context.socket(zmq.PUB)
418 424 lsock.connect(iface)
419 425 handler = handlers.PUBHandler(lsock)
420 426 handler.setLevel(loglevel)
421 427 handler.root_topic = root
422 428 logger.addHandler(handler)
423 429 logger.setLevel(loglevel)
424 430
425 431 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
426 432 logger = logging.getLogger()
427 433 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
428 434 # don't add a second PUBHandler
429 435 return
430 436 loglevel = integer_loglevel(loglevel)
431 437 lsock = context.socket(zmq.PUB)
432 438 lsock.connect(iface)
433 439 handler = EnginePUBHandler(engine, lsock)
434 440 handler.setLevel(loglevel)
435 441 logger.addHandler(handler)
436 442 logger.setLevel(loglevel)
437 443 return logger
438 444
439 445 def local_logger(logname, loglevel=logging.DEBUG):
440 446 loglevel = integer_loglevel(loglevel)
441 447 logger = logging.getLogger(logname)
442 448 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
443 449 # don't add a second StreamHandler
444 450 return
445 451 handler = logging.StreamHandler()
446 452 handler.setLevel(loglevel)
447 453 logger.addHandler(handler)
448 454 logger.setLevel(loglevel)
449 455 return logger
450 456
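A wiring sketch for the logging helpers above (the iface url is an assumption; any subscriber may sit on the other end of the PUB socket):

    import logging
    import zmq

    logger = local_logger('ipcontroller', logging.INFO)
    ctx = zmq.Context()
    connect_logger('ipcontroller', ctx, 'tcp://127.0.0.1:20202',
                   root='ip', loglevel=logging.INFO)
    logging.getLogger('ipcontroller').info('controller online')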
@@ -1,39 +1,43 b''
1 1 # encoding: utf-8
2 2
3 3 """Utilities to enable code objects to be pickled.
4 4
5 5 Any process that imports this module will be able to pickle code objects. This
6 6 includes the func_code attribute of any function. Once unpickled, new
7 7 functions can be built using new.function(code, globals()). Eventually
8 8 we need to automate all of this so that functions themselves can be pickled.
9 9
10 10 Reference: A. Tremols, P Cogolo, "Python Cookbook," p 302-305
11 11 """
12 12
13 13 __docformat__ = "restructuredtext en"
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Copyright (C) 2008 The IPython Development Team
17 17 #
18 18 # Distributed under the terms of the BSD License. The full license is in
19 19 # the file COPYING, distributed as part of this software.
20 20 #-------------------------------------------------------------------------------
21 21
22 22 #-------------------------------------------------------------------------------
23 23 # Imports
24 24 #-------------------------------------------------------------------------------
25 25
26 import sys
26 27 import new, types, copy_reg
27 28
28 29 def code_ctor(*args):
29 30 return new.code(*args)
30 31
31 32 def reduce_code(co):
32 33 if co.co_freevars or co.co_cellvars:
33 34 raise ValueError("Sorry, cannot pickle code objects with closures")
34 return code_ctor, (co.co_argcount, co.co_nlocals, co.co_stacksize,
35 co.co_flags, co.co_code, co.co_consts, co.co_names,
36 co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno,
37 co.co_lnotab)
35 args = [co.co_argcount, co.co_nlocals, co.co_stacksize,
36 co.co_flags, co.co_code, co.co_consts, co.co_names,
37 co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno,
38 co.co_lnotab]
39 if sys.version_info[0] >= 3:
40 args.insert(1, co.co_kwonlyargcount)
41 return code_ctor, tuple(args)
38 42
39 43 copy_reg.pickle(types.CodeType, reduce_code)
No newline at end of file
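With the copy_reg registration above, closure-free code objects pickle on both Python 2 and 3 (the py3k branch inserts co_kwonlyargcount). A sketch:

    import pickle

    code = compile('x + 1', '<sketch>', 'eval')
    restored = pickle.loads(pickle.dumps(code))
    assert eval(restored, {'x': 1}) == 2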
@@ -1,169 +1,179 b''
1 1 # encoding: utf-8
2 2 # -*- test-case-name: IPython.kernel.test.test_newserialized -*-
3 3
4 4 """Refactored serialization classes and interfaces."""
5 5
6 6 __docformat__ = "restructuredtext en"
7 7
8 8 # Tell nose to skip this module
9 9 __test__ = {}
10 10
11 11 #-------------------------------------------------------------------------------
12 12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-------------------------------------------------------------------------------
17 17
18 18 #-------------------------------------------------------------------------------
19 19 # Imports
20 20 #-------------------------------------------------------------------------------
21 21
22 import sys
22 23 import cPickle as pickle
23 24
24 25 try:
25 26 import numpy
26 27 except ImportError:
27 pass
28 numpy = None
28 29
29 30 class SerializationError(Exception):
30 31 pass
31 32
33 if sys.version_info[0] >= 3:
34 buffer = memoryview
35 py3k = True
36 else:
37 py3k = False
38
32 39 #-----------------------------------------------------------------------------
33 40 # Classes and functions
34 41 #-----------------------------------------------------------------------------
35 42
36 43 class ISerialized:
37 44
38 45 def getData():
39 46 """"""
40 47
41 48 def getDataSize(units=10.0**6):
42 49 """"""
43 50
44 51 def getTypeDescriptor():
45 52 """"""
46 53
47 54 def getMetadata():
48 55 """"""
49 56
50 57
51 58 class IUnSerialized:
52 59
53 60 def getObject():
54 61 """"""
55 62
56 63 class Serialized(object):
57 64
58 65 # implements(ISerialized)
59 66
60 67 def __init__(self, data, typeDescriptor, metadata={}):
61 68 self.data = data
62 69 self.typeDescriptor = typeDescriptor
63 70 self.metadata = metadata
64 71
65 72 def getData(self):
66 73 return self.data
67 74
68 75 def getDataSize(self, units=10.0**6):
69 76 return len(self.data)/units
70 77
71 78 def getTypeDescriptor(self):
72 79 return self.typeDescriptor
73 80
74 81 def getMetadata(self):
75 82 return self.metadata
76 83
77 84
78 85 class UnSerialized(object):
79 86
80 87 # implements(IUnSerialized)
81 88
82 89 def __init__(self, obj):
83 90 self.obj = obj
84 91
85 92 def getObject(self):
86 93 return self.obj
87 94
88 95
89 96 class SerializeIt(object):
90 97
91 98 # implements(ISerialized)
92 99
93 100 def __init__(self, unSerialized):
94 101 self.data = None
95 102 self.obj = unSerialized.getObject()
96 if globals().has_key('numpy') and isinstance(self.obj, numpy.ndarray):
97 if len(self.obj.shape) == 0: # length 0 arrays are just pickled
103 if numpy is not None and isinstance(self.obj, numpy.ndarray):
104 if py3k or len(self.obj.shape) == 0: # length 0 arrays are just pickled
105 # FIXME:
106 # also use pickle for numpy arrays on py3k, since
107 # pyzmq doesn't rebuild from memoryviews properly
98 108 self.typeDescriptor = 'pickle'
99 109 self.metadata = {}
100 110 else:
101 111 self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
102 112 self.typeDescriptor = 'ndarray'
103 113 self.metadata = {'shape':self.obj.shape,
104 114 'dtype':self.obj.dtype.str}
105 elif isinstance(self.obj, str):
115 elif isinstance(self.obj, bytes):
106 116 self.typeDescriptor = 'bytes'
107 117 self.metadata = {}
108 118 elif isinstance(self.obj, buffer):
109 119 self.typeDescriptor = 'buffer'
110 120 self.metadata = {}
111 121 else:
112 122 self.typeDescriptor = 'pickle'
113 123 self.metadata = {}
114 124 self._generateData()
115 125
116 126 def _generateData(self):
117 127 if self.typeDescriptor == 'ndarray':
118 128 self.data = numpy.getbuffer(self.obj)
119 129 elif self.typeDescriptor in ('bytes', 'buffer'):
120 130 self.data = self.obj
121 131 elif self.typeDescriptor == 'pickle':
122 132 self.data = pickle.dumps(self.obj, -1)
123 133 else:
124 134 raise SerializationError("Really weird serialization error.")
125 135 del self.obj
126 136
127 137 def getData(self):
128 138 return self.data
129 139
130 140 def getDataSize(self, units=10.0**6):
131 141 return 1.0*len(self.data)/units
132 142
133 143 def getTypeDescriptor(self):
134 144 return self.typeDescriptor
135 145
136 146 def getMetadata(self):
137 147 return self.metadata
138 148
139 149
140 150 class UnSerializeIt(UnSerialized):
141 151
142 152 # implements(IUnSerialized)
143 153
144 154 def __init__(self, serialized):
145 155 self.serialized = serialized
146 156
147 157 def getObject(self):
148 158 typeDescriptor = self.serialized.getTypeDescriptor()
149 if globals().has_key('numpy') and typeDescriptor == 'ndarray':
159 if numpy is not None and typeDescriptor == 'ndarray':
150 160 buf = self.serialized.getData()
151 if isinstance(buf, (str, buffer)):
161 if isinstance(buf, (bytes, buffer)):
152 162 result = numpy.frombuffer(buf, dtype = self.serialized.metadata['dtype'])
153 163 else:
154 164 # memoryview
155 165 result = numpy.array(buf, dtype = self.serialized.metadata['dtype'])
156 166 result.shape = self.serialized.metadata['shape']
157 167 elif typeDescriptor == 'pickle':
158 168 result = pickle.loads(self.serialized.getData())
159 169 elif typeDescriptor in ('bytes', 'buffer'):
160 170 result = self.serialized.getData()
161 171 else:
162 172 raise SerializationError("Really weird serialization error.")
163 173 return result
164 174
165 175 def serialize(obj):
166 176 return SerializeIt(UnSerialized(obj))
167 177
168 178 def unserialize(serialized):
169 179 return UnSerializeIt(serialized).getObject()
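A roundtrip sketch for these two entry points (plain objects take the 'pickle' path; bytes and ndarrays get their own descriptors):

    s = serialize({'shape': (2, 2)})
    assert s.getTypeDescriptor() == 'pickle'
    assert unserialize(s) == {'shape': (2, 2)}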