##// END OF EJS Templates
support iterating through map results as they arrive
MinRK -
Show More
@@ -1,208 +1,232 b''
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from IPython.external.decorator import decorator
14 14 import error
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Classes
18 18 #-----------------------------------------------------------------------------
19 19
@decorator
def check_ready(f, self, *args, **kwargs):
    """Sync state with a zero-timeout wait, then invoke the wrapped method.

    Raises error.TimeoutError if the result has not arrived yet.
    """
    self.wait(0)
    if self._ready:
        return f(self, *args, **kwargs)
    raise error.TimeoutError("result not ready")
27 27
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
    """

    # list of msg_ids whose results this object aggregates
    msg_ids = None

    def __init__(self, client, msg_ids, fname=''):
        self._client = client
        # accept a single msg_id string as well as a list of them
        if isinstance(msg_ids, basestring):
            msg_ids = [msg_ids]
        self.msg_ids = msg_ids
        self._fname=fname
        self._ready = False
        self._success = None
        # single-result mode: values are unwrapped from their
        # one-element list on access
        self._single_result = len(msg_ids) == 1

    def __repr__(self):
        if self._ready:
            return "<%s: finished>"%(self.__class__.__name__)
        else:
            return "<%s: %s>"%(self.__class__.__name__,self._fname)


    def _reconstruct_result(self, res):
        """
        Override me in subclasses for turning a list of results
        into the expected form.
        """
        if self._single_result:
            return res[0]
        else:
            return res

    def get(self, timeout=-1):
        """Return the result when it arrives.

        If `timeout` is not ``None`` and the result does not arrive within
        `timeout` seconds then ``TimeoutError`` is raised. If the
        remote call raised an exception then that exception will be reraised
        by get().
        """
        if not self.ready():
            self.wait(timeout)

        if self._ready:
            if self._success:
                return self._result
            else:
                # the remote call failed: re-raise its exception locally
                raise self._exception
        else:
            raise error.TimeoutError("Result not ready.")

    def ready(self):
        """Return whether the call has completed."""
        if not self._ready:
            # zero-timeout wait acts as a poll; sets self._ready when done
            self.wait(0)
        return self._ready

    def wait(self, timeout=-1):
        """Wait until the result is available or until `timeout` seconds pass.
        """
        if self._ready:
            return
        self._ready = self._client.barrier(self.msg_ids, timeout)
        if self._ready:
            try:
                results = map(self._client.results.get, self.msg_ids)
                self._result = results
                if self._single_result:
                    r = results[0]
                    if isinstance(r, Exception):
                        # single failed call: raise so the except clause
                        # below records it as this AsyncResult's exception
                        raise r
                else:
                    results = error.collect_exceptions(results, self._fname)
                self._result = self._reconstruct_result(results)
            except Exception, e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                # metadata is captured whether the calls succeeded or not
                self._metadata = map(self._client.metadata.get, self.msg_ids)


    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise ``AssertionError`` if the result is not ready.
        """
        assert self._ready
        return self._success

    #----------------------------------------------------------------
    # Extra methods not in mp.pool.AsyncResult
    #----------------------------------------------------------------

    def get_dict(self, timeout=-1):
        """Get the results as a dict, keyed by engine_id."""
        results = self.get(timeout)
        engine_ids = [ md['engine_id'] for md in self._metadata ]
        # sort ids by frequency so the most-repeated id lands last;
        # any engine appearing more than once makes the dict ambiguous
        bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
        maxcount = bycount.count(bycount[-1])
        if maxcount > 1:
            raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
                    maxcount, bycount[-1]))

        return dict(zip(engine_ids,results))

    @property
    @check_ready
    def result(self):
        """result property."""
        return self._result

    # abbreviated alias:
    r = result

    @property
    @check_ready
    def metadata(self):
        """metadata property."""
        if self._single_result:
            return self._metadata[0]
        else:
            return self._metadata

    @property
    def result_dict(self):
        """result property as a dict."""
        return self.get_dict(0)

    def __dict__(self):
        # NOTE(review): defining __dict__ as a method is unusual -- it
        # shadows the normal instance-dict protocol; confirm callers
        # really invoke it as obj.__dict__() before relying on it.
        return self.get_dict(0)

    #-------------------------------------
    # dict-access
    #-------------------------------------

    @check_ready
    def __getitem__(self, key):
        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
        """
        if isinstance(key, int):
            return error.collect_exceptions([self._result[key]], self._fname)[0]
        elif isinstance(key, slice):
            return error.collect_exceptions(self._result[key], self._fname)
        elif isinstance(key, basestring):
            # string keys index into metadata, one value per message
            values = [ md[key] for md in self._metadata ]
            if self._single_result:
                return values[0]
            else:
                return values
        else:
            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))

    @check_ready
    def __getattr__(self, key):
        """getattr maps to getitem for convenient access to metadata."""
        if key not in self._metadata[0].keys():
            raise AttributeError("%r object has no attribute %r"%(
                    self.__class__.__name__, key))
        return self.__getitem__(key)
191 193
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    This will properly reconstruct the gather.
    """

    def __init__(self, client, msg_ids, mapObject, fname=''):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        self._mapObject = mapObject
        # a map is always a collection, even when only one chunk exists
        self._single_result = False

    def _reconstruct_result(self, res):
        """Perform the gather on the actual results."""
        return self._mapObject.joinPartitions(res)

    # asynchronous iterator:
    def __iter__(self):
        try:
            # non-blocking probe: have all chunks already arrived?
            rlist = self.get(0)
        except error.TimeoutError:
            # wait for each result individually
            for msg_id in self.msg_ids:
                # block on one chunk at a time so results can be
                # consumed as they arrive
                ar = AsyncResult(self._client, msg_id, self._fname)
                rlist = ar.get()
                try:
                    for r in rlist:
                        yield r
                except TypeError:
                    # flattened, not a list
                    # this could get broken by flattened data that returns iterables
                    # but most calls to map do not expose the `flatten` argument
                    yield rlist
        else:
            # already done
            for r in rlist:
                yield r
230
207 231
208 232 __all__ = ['AsyncResult', 'AsyncMapResult'] No newline at end of file
@@ -1,1400 +1,1472 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import time
15 15 from getpass import getpass
16 16 from pprint import pprint
17 17 from datetime import datetime
18 18 import warnings
19 19 import json
20 20 pjoin = os.path.join
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.utils.path import get_ipython_dir
26 26 from IPython.external.decorator import decorator
27 27 from IPython.external.ssh import tunnel
28 28
29 29 import streamsession as ss
30 30 from clusterdir import ClusterDir, ClusterDirError
31 31 # from remotenamespace import RemoteNamespace
32 32 from view import DirectView, LoadBalancedView
33 33 from dependency import Dependency, depend, require, dependent
34 34 import error
35 35 import map as Map
36 36 from asyncresult import AsyncResult, AsyncMapResult
37 37 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
38 38 from util import ReverseDict, disambiguate_url, validate_url
39 39
40 40 #--------------------------------------------------------------------------
41 41 # helpers for implementing old MEC API via client.apply
42 42 #--------------------------------------------------------------------------
43 43
44 44 def _push(ns):
45 45 """helper method for implementing `client.push` via `client.apply`"""
46 46 globals().update(ns)
47 47
48 48 def _pull(keys):
49 49 """helper method for implementing `client.pull` via `client.apply`"""
50 50 g = globals()
51 51 if isinstance(keys, (list,tuple, set)):
52 52 for key in keys:
53 53 if not g.has_key(key):
54 54 raise NameError("name '%s' is not defined"%key)
55 55 return map(g.get, keys)
56 56 else:
57 57 if not g.has_key(keys):
58 58 raise NameError("name '%s' is not defined"%keys)
59 59 return g.get(keys)
60 60
61 61 def _clear():
62 62 """helper method for implementing `client.clear` via `client.apply`"""
63 63 globals().clear()
64 64
65 65 def _execute(code):
66 66 """helper method for implementing `client.execute` via `client.apply`"""
67 67 exec code in globals()
68 68
69 69
70 70 #--------------------------------------------------------------------------
71 71 # Decorators for Client methods
72 72 #--------------------------------------------------------------------------
73 73
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # pull any pending messages off the wire before running f
    self.spin()
    result = f(self, *args, **kwargs)
    return result
79 79
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Default to self.block; preserve self.block."""
    # block=None (or absent) means "use the client's current default"
    requested = kwargs.get('block', None)
    saved = self.block
    self.block = saved if requested is None else requested
    try:
        return f(self, *args, **kwargs)
    finally:
        # always restore the caller-visible blocking mode
        self.block = saved
92 92
93 93
94 94 #--------------------------------------------------------------------------
95 95 # Classes
96 96 #--------------------------------------------------------------------------
97 97
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # fixed key set: every metadata record carries exactly these keys
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE(review): dict.update is not overridden, so this call (and
        # any later .update()) bypasses the strict __setitem__ check and
        # can introduce keys outside the fixed set -- confirm intended.
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        if key in self.iterkeys():
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self.iterkeys():
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self.iterkeys():
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
148 148
149 149
150 150 class Client(object):
151 151 """A semi-synchronous client to the IPython ZMQ controller
152 152
153 153 Parameters
154 154 ----------
155 155
156 156 url_or_file : bytes; zmq url or path to ipcontroller-client.json
157 157 Connection information for the Hub's registration. If a json connector
158 158 file is given, then likely no further configuration is necessary.
159 159 [Default: use profile]
160 160 profile : bytes
161 161 The name of the Cluster profile to be used to find connector information.
162 162 [Default: 'default']
163 163 context : zmq.Context
164 164 Pass an existing zmq.Context instance, otherwise the client will create its own.
165 165 username : bytes
166 166 set username to be passed to the Session object
167 167 debug : bool
168 168 flag for lots of message printing for debug purposes
169 169
170 170 #-------------- ssh related args ----------------
171 171 # These are args for configuring the ssh tunnel to be used
172 172 # credentials are used to forward connections over ssh to the Controller
173 173 # Note that the ip given in `addr` needs to be relative to sshserver
174 174 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
175 175 # and set sshserver as the same machine the Controller is on. However,
176 176 # the only requirement is that sshserver is able to see the Controller
177 177 # (i.e. is within the same trusted network).
178 178
179 179 sshserver : str
180 180 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
181 181 If keyfile or password is specified, and this is not, it will default to
182 182 the ip given in addr.
183 183 sshkey : str; path to public ssh key file
184 184 This specifies a key to be used in ssh login, default None.
185 185 Regular default ssh keys will be used without specifying this argument.
186 186 password : str
187 187 Your ssh password to sshserver. Note that if this is left None,
188 188 you will be prompted for it if passwordless key based login is unavailable.
189 189 paramiko : bool
190 190 flag for whether to use paramiko instead of shell ssh for tunneling.
191 191 [default: True on win32, False else]
192 192
193 193 #------- exec authentication args -------
194 194 # If even localhost is untrusted, you can have some protection against
195 195 # unauthorized execution by using a key. Messages are still sent
196 196 # as cleartext, so if someone can snoop your loopback traffic this will
197 197 # not help against malicious attacks.
198 198
199 199 exec_key : str
200 200 an authentication key or file containing a key
201 201 default: None
202 202
203 203
204 204 Attributes
205 205 ----------
206 206 ids : set of int engine IDs
207 207 requesting the ids attribute always synchronizes
208 208 the registration state. To request ids without synchronization,
209 209 use semi-private _ids attributes.
210 210
211 211 history : list of msg_ids
212 212 a list of msg_ids, keeping track of all the execution
213 213 messages you have submitted in order.
214 214
215 215 outstanding : set of msg_ids
216 216 a set of msg_ids that have been submitted, but whose
217 217 results have not yet been received.
218 218
219 219 results : dict
220 220 a dict of all our results, keyed by msg_id
221 221
222 222 block : bool
223 223 determines default behavior when block not specified
224 224 in execution methods
225 225
226 226 Methods
227 227 -------
228 228 spin : flushes incoming results and registration state changes
229 229 control methods spin, and requesting `ids` also ensures up to date
230 230
231 231 barrier : wait on one or more msg_ids
232 232
233 233 execution methods: apply/apply_bound/apply_to/apply_bound
234 234 legacy: execute, run
235 235
236 236 query methods: queue_status, get_result, purge
237 237
238 238 control methods: abort, kill
239 239
240 240 """
241 241
242 242
243 243 _connected=False
244 244 _ssh=False
245 245 _engines=None
246 246 _registration_socket=None
247 247 _query_socket=None
248 248 _control_socket=None
249 249 _iopub_socket=None
250 250 _notification_socket=None
251 251 _mux_socket=None
252 252 _task_socket=None
253 253 _task_scheme=None
254 254 block = False
255 255 outstanding=None
256 256 results = None
257 257 history = None
258 258 debug = False
259 259 targets = None
260 260
    def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
            context=None, username=None, debug=False, exec_key=None,
            sshserver=None, sshkey=None, password=None, paramiko=None,
            ):
        """Resolve connection info, build the Session, and connect.

        See the class docstring for parameter documentation.
        """
        if context is None:
            context = zmq.Context()
        self.context = context
        self.targets = 'all'

        # locate the cluster profile dir to find a connection file if needed
        self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
        if self._cd is not None:
            if url_or_file is None:
                url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
        assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
            " Please specify at least one of url_or_file or profile."

        try:
            validate_url(url_or_file)
        except AssertionError:
            # not a url: treat it as a path to a JSON connection file,
            # resolving relative paths against the profile's security dir
            if not os.path.exists(url_or_file):
                if self._cd:
                    url_or_file = os.path.join(self._cd.security_dir, url_or_file)
                assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
            with open(url_or_file) as f:
                cfg = json.loads(f.read())
        else:
            cfg = {'url':url_or_file}

        # sync defaults from args, json:
        # NOTE(review): cfg['exec_key'] / cfg['ssh'] are read unconditionally
        # below -- assumes the connection file always provides both keys.
        if sshserver:
            cfg['ssh'] = sshserver
        if exec_key:
            cfg['exec_key'] = exec_key
        exec_key = cfg['exec_key']
        sshserver=cfg['ssh']
        url = cfg['url']
        location = cfg.setdefault('location', None)
        cfg['url'] = disambiguate_url(cfg['url'], location)
        url = cfg['url']

        self._config = cfg

        self._ssh = bool(sshserver or sshkey or password)
        if self._ssh and sshserver is None:
            # default to ssh via localhost
            sshserver = url.split('://')[1].split(':')[0]
        if self._ssh and password is None:
            if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
                password=False
            else:
                password = getpass("SSH Password for %s: "%sshserver)
        ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
        # exec_key may be a literal key or a path to a keyfile
        if exec_key is not None and os.path.isfile(exec_key):
            arg = 'keyfile'
        else:
            arg = 'key'
        key_arg = {arg:exec_key}
        if username is None:
            self.session = ss.StreamSession(**key_arg)
        else:
            self.session = ss.StreamSession(username, **key_arg)
        self._registration_socket = self.context.socket(zmq.XREQ)
        self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
        if self._ssh:
            tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
        else:
            self._registration_socket.connect(url)
        # per-client bookkeeping state
        self._engines = ReverseDict()
        self._ids = []
        self.outstanding=set()
        self.results = {}
        self.metadata = {}
        self.history = []
        self.debug = debug
        self.session.debug = debug

        # dispatch tables for incoming messages, keyed by msg_type
        self._notification_handlers = {'registration_notification' : self._register_engine,
                                    'unregistration_notification' : self._unregister_engine,
                                    }
        self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
                                'apply_reply' : self._handle_apply_reply}
        self._connect(sshserver, ssh_kwargs)
344 344
345 345 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
346 346 if ipython_dir is None:
347 347 ipython_dir = get_ipython_dir()
348 348 if cluster_dir is not None:
349 349 try:
350 350 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
351 351 except ClusterDirError:
352 352 pass
353 353 elif profile is not None:
354 354 try:
355 355 self._cd = ClusterDir.find_cluster_dir_by_profile(
356 356 ipython_dir, profile)
357 357 except ClusterDirError:
358 358 pass
359 359 else:
360 360 self._cd = None
361 361
    @property
    def ids(self):
        """Always up-to-date ids property."""
        # flush any pending (un)registration events before reporting
        self._flush_notifications()
        return self._ids
367 367
    def _update_engines(self, engines):
        """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
        for k,v in engines.iteritems():
            eid = int(k)
            self._engines[eid] = bytes(v) # force not unicode
            self._ids.append(eid)
        self._ids = sorted(self._ids)
        # a pure ZMQ scheduler addresses engines by position, so it only
        # works while engine ids are exactly 0..n-1 with no gaps
        if sorted(self._engines.keys()) != range(len(self._engines)) and \
                        self._task_scheme == 'pure' and self._task_socket:
            self._stop_scheduling_tasks()
378 378
379 379 def _stop_scheduling_tasks(self):
380 380 """Stop scheduling tasks because an engine has been unregistered
381 381 from a pure ZMQ scheduler.
382 382 """
383 383
384 384 self._task_socket.close()
385 385 self._task_socket = None
386 386 msg = "An engine has been unregistered, and we are using pure " +\
387 387 "ZMQ task scheduling. Task farming will be disabled."
388 388 if self.outstanding:
389 389 msg += " If you were running tasks when this happened, " +\
390 390 "some `outstanding` msg_ids may never resolve."
391 391 warnings.warn(msg, RuntimeWarning)
392 392
393 393 def _build_targets(self, targets):
394 394 """Turn valid target IDs or 'all' into two lists:
395 395 (int_ids, uuids).
396 396 """
397 397 if targets is None:
398 398 targets = self._ids
399 399 elif isinstance(targets, str):
400 400 if targets.lower() == 'all':
401 401 targets = self._ids
402 402 else:
403 403 raise TypeError("%r not valid str target, must be 'all'"%(targets))
404 404 elif isinstance(targets, int):
405 405 targets = [targets]
406 406 return [self._engines[t] for t in targets], list(targets)
407 407
    def _connect(self, sshserver, ssh_kwargs):
        """setup all our socket connections to the controller. This is called from
        __init__."""

        # Maybe allow reconnecting?
        if self._connected:
            return
        self._connected=True

        def connect_socket(s, url):
            # resolve and (optionally) tunnel every connection the same way
            url = disambiguate_url(url, self._config['location'])
            if self._ssh:
                return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
            else:
                return s.connect(url)

        self.session.send(self._registration_socket, 'connection_request')
        idents,msg = self.session.recv(self._registration_socket,mode=0)
        if self.debug:
            pprint(msg)
        msg = ss.Message(msg)
        content = msg.content
        self._config['registration'] = dict(content)
        if content.status == 'ok':
            # connect one socket per service the controller advertises
            if content.mux:
                self._mux_socket = self.context.socket(zmq.PAIR)
                self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._mux_socket, content.mux)
            if content.task:
                # task service also tells us which scheduler scheme is in use
                self._task_scheme, task_addr = content.task
                self._task_socket = self.context.socket(zmq.PAIR)
                self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._task_socket, task_addr)
            if content.notification:
                self._notification_socket = self.context.socket(zmq.SUB)
                connect_socket(self._notification_socket, content.notification)
                # subscribe to all notification topics
                self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
            if content.query:
                self._query_socket = self.context.socket(zmq.PAIR)
                self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._query_socket, content.query)
            if content.control:
                self._control_socket = self.context.socket(zmq.PAIR)
                self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._control_socket, content.control)
            if content.iopub:
                self._iopub_socket = self.context.socket(zmq.SUB)
                self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
                self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
                connect_socket(self._iopub_socket, content.iopub)
            # seed our engine table with the currently-registered engines
            self._update_engines(dict(content.engines))

        else:
            self._connected = False
            raise Exception("Failed to connect!")
463 463
464 464 #--------------------------------------------------------------------------
465 465 # handlers and callbacks for incoming messages
466 466 #--------------------------------------------------------------------------
467 467
468 468 def _register_engine(self, msg):
469 469 """Register a new engine, and update our connection info."""
470 470 content = msg['content']
471 471 eid = content['id']
472 472 d = {eid : content['queue']}
473 473 self._update_engines(d)
474 474
475 475 def _unregister_engine(self, msg):
476 476 """Unregister an engine that has died."""
477 477 content = msg['content']
478 478 eid = int(content['id'])
479 479 if eid in self._ids:
480 480 self._ids.remove(eid)
481 481 self._engines.pop(eid)
482 482 if self._task_socket and self._task_scheme == 'pure':
483 483 self._stop_scheduling_tasks()
484 484
485 485 def _extract_metadata(self, header, parent, content):
486 486 md = {'msg_id' : parent['msg_id'],
487 487 'received' : datetime.now(),
488 488 'engine_uuid' : header.get('engine', None),
489 489 'follow' : parent.get('follow', []),
490 490 'after' : parent.get('after', []),
491 491 'status' : content['status'],
492 492 }
493 493
494 494 if md['engine_uuid'] is not None:
495 495 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
496 496
497 497 if 'date' in parent:
498 498 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
499 499 if 'started' in header:
500 500 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
501 501 if 'date' in header:
502 502 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
503 503 return md
504 504
    def _handle_execute_reply(self, msg):
        """Save the reply to an execute_request into our results.

        execute messages are never actually used. apply is used instead.
        """

        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # not one of ours, or already handled; just report it
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            self.results[msg_id] = ss.unwrap_exception(msg['content'])
521 521
    def _handle_apply_reply(self, msg):
        """Save the reply to an apply_request into our results."""
        parent = msg['parent_header']
        msg_id = parent['msg_id']
        if msg_id not in self.outstanding:
            # not one of ours, or already handled; just report it
            if msg_id in self.history:
                print ("got stale result: %s"%msg_id)
                print self.results[msg_id]
                print msg
            else:
                print ("got unknown result: %s"%msg_id)
        else:
            self.outstanding.remove(msg_id)
            content = msg['content']
            header = msg['header']

            # construct metadata:
            md = self.metadata.setdefault(msg_id, Metadata())
            md.update(self._extract_metadata(header, parent, content))
            self.metadata[msg_id] = md

            # construct result:
            if content['status'] == 'ok':
                self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
            elif content['status'] == 'aborted':
                self.results[msg_id] = error.AbortedTask(msg_id)
            elif content['status'] == 'resubmitted':
                # TODO: handle resubmission
                pass
            else:
                # remote raised: store the unwrapped exception as the result,
                # translating the engine uuid into its integer id if known
                e = ss.unwrap_exception(content)
                if e.engine_info:
                    e_uuid = e.engine_info['engineid']
                    eid = self._engines[e_uuid]
                    e.engine_info['engineid'] = eid
                self.results[msg_id] = e
558 558
559 559 def _flush_notifications(self):
560 560 """Flush notifications of engine registrations waiting
561 561 in ZMQ queue."""
562 562 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
563 563 while msg is not None:
564 564 if self.debug:
565 565 pprint(msg)
566 566 msg = msg[-1]
567 567 msg_type = msg['msg_type']
568 568 handler = self._notification_handlers.get(msg_type, None)
569 569 if handler is None:
570 570 raise Exception("Unhandled message type: %s"%msg.msg_type)
571 571 else:
572 572 handler(msg)
573 573 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
574 574
575 575 def _flush_results(self, sock):
576 576 """Flush task or queue results waiting in ZMQ queue."""
577 577 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
578 578 while msg is not None:
579 579 if self.debug:
580 580 pprint(msg)
581 581 msg = msg[-1]
582 582 msg_type = msg['msg_type']
583 583 handler = self._queue_handlers.get(msg_type, None)
584 584 if handler is None:
585 585 raise Exception("Unhandled message type: %s"%msg.msg_type)
586 586 else:
587 587 handler(msg)
588 588 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
589 589
590 590 def _flush_control(self, sock):
591 591 """Flush replies from the control channel waiting
592 592 in the ZMQ queue.
593 593
594 594 Currently: ignore them."""
595 595 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
596 596 while msg is not None:
597 597 if self.debug:
598 598 pprint(msg)
599 599 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
600 600
    def _flush_iopub(self, sock):
        """Flush replies from the iopub channel waiting
        in the ZMQ queue.
        """
        msg = self.session.recv(sock, mode=zmq.NOBLOCK)
        while msg is not None:
            if self.debug:
                pprint(msg)
            msg = msg[-1]
            parent = msg['parent_header']
            msg_id = parent['msg_id']
            content = msg['content']
            header = msg['header']
            msg_type = msg['msg_type']

            # init metadata:
            md = self.metadata.setdefault(msg_id, Metadata())

            if msg_type == 'stream':
                # append stream output to whatever was captured so far
                name = content['name']
                s = md[name] or ''
                md[name] = s + content['data']
            elif msg_type == 'pyerr':
                md.update({'pyerr' : ss.unwrap_exception(content)})
            else:
                # e.g. pyin/pyout: store the raw data under the msg_type key
                md.update({msg_type : content['data']})

            self.metadata[msg_id] = md

            msg = self.session.recv(sock, mode=zmq.NOBLOCK)
631 631
632 632 #--------------------------------------------------------------------------
633 633 # getitem
634 634 #--------------------------------------------------------------------------
635 635
    def __getitem__(self, key):
        """Dict access returns DirectView multiplexer objects or,
        if key is None, a LoadBalancedView."""
        if key is None:
            return LoadBalancedView(self)
        if isinstance(key, int):
            if key not in self.ids:
                raise IndexError("No such engine: %i"%key)
            return DirectView(self, key)

        if isinstance(key, slice):
            # translate the slice into a concrete list of engine ids,
            # then fall through to the list/tuple case below
            indices = range(len(self.ids))[key]
            ids = sorted(self._ids)
            key = [ ids[i] for i in indices ]
            # newkeys = sorted(self._ids)[thekeys[k]]

        if isinstance(key, (tuple, list, xrange)):
            _,targets = self._build_targets(list(key))
            return DirectView(self, targets)
        else:
            raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
657 657
658 658 #--------------------------------------------------------------------------
659 659 # Begin public methods
660 660 #--------------------------------------------------------------------------
661 661
    @property
    def remote(self):
        """property for convenient RemoteFunction generation.

        >>> @client.remote
        ... def getpid():
        ...     import os
        ...     return os.getpid()
        """
        # NOTE(review): a plain method named `remote` is defined later in this
        # class; since class bodies bind names sequentially, that later
        # definition wins and this property is unreachable -- confirm which
        # of the two is intended and remove the other.
        return remote(self, block=self.block)
672 672
673 673 def spin(self):
674 674 """Flush any registration notifications and execution results
675 675 waiting in the ZMQ queue.
676 676 """
677 677 if self._notification_socket:
678 678 self._flush_notifications()
679 679 if self._mux_socket:
680 680 self._flush_results(self._mux_socket)
681 681 if self._task_socket:
682 682 self._flush_results(self._task_socket)
683 683 if self._control_socket:
684 684 self._flush_control(self._control_socket)
685 685 if self._iopub_socket:
686 686 self._flush_iopub(self._iopub_socket)
687 687
    def barrier(self, msg_ids=None, timeout=-1):
        """waits on one or more `msg_ids`, for up to `timeout` seconds.

        Parameters
        ----------
        msg_ids : int, str, or list of ints and/or strs, or one or more AsyncResult objects
            ints are indices to self.history
            strs are msg_ids
            default: wait on all outstanding messages
        timeout : float
            a time in seconds, after which to give up.
            default is -1, which means no timeout

        Returns
        -------
        True : when all msg_ids are done
        False : timeout reached, some msg_ids still outstanding
        """
        tic = time.time()
        if msg_ids is None:
            # no explicit ids: wait on everything currently outstanding
            theids = self.outstanding
        else:
            if isinstance(msg_ids, (int, str, AsyncResult)):
                msg_ids = [msg_ids]
            theids = set()
            for msg_id in msg_ids:
                if isinstance(msg_id, int):
                    # ints index into the submission history
                    msg_id = self.history[msg_id]
                elif isinstance(msg_id, AsyncResult):
                    # an AsyncResult may wrap several msg_ids; add them all
                    map(theids.add, msg_id.msg_ids)
                    continue
                theids.add(msg_id)
        if not theids.intersection(self.outstanding):
            # everything requested has already completed
            return True
        self.spin()
        while theids.intersection(self.outstanding):
            if timeout >= 0 and ( time.time()-tic ) > timeout:
                break
            # poll at ~1ms granularity between flushes
            time.sleep(1e-3)
            self.spin()
        return len(theids.intersection(self.outstanding)) == 0
729 729
730 730 #--------------------------------------------------------------------------
731 731 # Control methods
732 732 #--------------------------------------------------------------------------
733 733
    @spinfirst
    @defaultblock
    def clear(self, targets=None, block=None):
        """Clear the namespace in target(s).

        Sends a 'clear_request' to each target engine; when blocking,
        collects the replies and returns the last failure (if any).
        """
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
        # NOTE(review): this local shadows the imported `error` module within
        # this method; harmless here, but a rename would be clearer
        error = False
        if self.block:
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    # only the *last* failure seen is kept
                    error = ss.unwrap_exception(msg['content'])
        if error:
            return error
751 751
752 752
753 753 @spinfirst
754 754 @defaultblock
755 755 def abort(self, msg_ids = None, targets=None, block=None):
756 756 """Abort the execution queues of target(s)."""
757 757 targets = self._build_targets(targets)[0]
758 758 if isinstance(msg_ids, basestring):
759 759 msg_ids = [msg_ids]
760 760 content = dict(msg_ids=msg_ids)
761 761 for t in targets:
762 762 self.session.send(self._control_socket, 'abort_request',
763 763 content=content, ident=t)
764 764 error = False
765 765 if self.block:
766 766 for i in range(len(targets)):
767 767 idents,msg = self.session.recv(self._control_socket,0)
768 768 if self.debug:
769 769 pprint(msg)
770 770 if msg['content']['status'] != 'ok':
771 771 error = ss.unwrap_exception(msg['content'])
772 772 if error:
773 773 return error
774 774
    @spinfirst
    @defaultblock
    def shutdown(self, targets=None, restart=False, controller=False, block=None):
        """Terminates one or more engine processes, optionally including the controller.

        When `controller` is True, all engines are shut down first, then
        the controller itself; any failure reply is raised at the end.
        """
        if controller:
            # shutting down the controller implies shutting down all engines
            targets = 'all'
        targets = self._build_targets(targets)[0]
        for t in targets:
            self.session.send(self._control_socket, 'shutdown_request',
                        content={'restart':restart},ident=t)
        error = False
        if block or controller:
            for i in range(len(targets)):
                idents,msg = self.session.recv(self._control_socket,0)
                if self.debug:
                    pprint(msg)
                if msg['content']['status'] != 'ok':
                    # only the last failure seen is kept
                    error = ss.unwrap_exception(msg['content'])

        if controller:
            # give the engines a moment to go down before stopping the controller
            time.sleep(0.25)
            self.session.send(self._query_socket, 'shutdown_request')
            idents,msg = self.session.recv(self._query_socket, 0)
            if self.debug:
                pprint(msg)
            if msg['content']['status'] != 'ok':
                error = ss.unwrap_exception(msg['content'])

        if error:
            raise error
805 805
806 806 #--------------------------------------------------------------------------
807 807 # Execution methods
808 808 #--------------------------------------------------------------------------
809 809
810 810 @defaultblock
811 811 def execute(self, code, targets='all', block=None):
812 812 """Executes `code` on `targets` in blocking or nonblocking manner.
813 813
814 814 ``execute`` is always `bound` (affects engine namespace)
815 815
816 816 Parameters
817 817 ----------
818 818 code : str
819 819 the code string to be executed
820 820 targets : int/str/list of ints/strs
821 821 the engines on which to execute
822 822 default : all
823 823 block : bool
824 824 whether or not to wait until done to return
825 825 default: self.block
826 826 """
827 827 result = self.apply(_execute, (code,), targets=targets, block=self.block, bound=True)
828 828 return result
829 829
830 830 def run(self, filename, targets='all', block=None):
831 831 """Execute contents of `filename` on engine(s).
832 832
833 833 This simply reads the contents of the file and calls `execute`.
834 834
835 835 Parameters
836 836 ----------
837 837 filename : str
838 838 The path to the file
839 839 targets : int/str/list of ints/strs
840 840 the engines on which to execute
841 841 default : all
842 842 block : bool
843 843 whether or not to wait until done
844 844 default: self.block
845 845
846 846 """
847 847 with open(filename, 'rb') as f:
848 848 code = f.read()
849 849 return self.execute(code, targets=targets, block=block)
850 850
    def _maybe_raise(self, result):
        """wrapper for maybe raising an exception if apply failed.

        Re-raises `result` locally if it is a RemoteError; otherwise
        passes it through unchanged.
        """
        if isinstance(result, error.RemoteError):
            raise result

        return result
857 857
858 858 def _build_dependency(self, dep):
859 859 """helper for building jsonable dependencies from various input forms"""
860 860 if isinstance(dep, Dependency):
861 861 return dep.as_dict()
862 862 elif isinstance(dep, AsyncResult):
863 863 return dep.msg_ids
864 864 elif dep is None:
865 865 return []
866 866 else:
867 867 # pass to Dependency constructor
868 868 return list(Dependency(dep))
869 869
870 870 @defaultblock
871 871 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
872 872 after=None, follow=None, timeout=None):
873 873 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
874 874
875 875 This is the central execution command for the client.
876 876
877 877 Parameters
878 878 ----------
879 879
880 880 f : function
881 881 The fuction to be called remotely
882 882 args : tuple/list
883 883 The positional arguments passed to `f`
884 884 kwargs : dict
885 885 The keyword arguments passed to `f`
886 886 bound : bool (default: True)
887 887 Whether to execute in the Engine(s) namespace, or in a clean
888 888 namespace not affecting the engine.
889 889 block : bool (default: self.block)
890 890 Whether to wait for the result, or return immediately.
891 891 False:
892 892 returns AsyncResult
893 893 True:
894 894 returns actual result(s) of f(*args, **kwargs)
895 895 if multiple targets:
896 896 list of results, matching `targets`
897 897 targets : int,list of ints, 'all', None
898 898 Specify the destination of the job.
899 899 if None:
900 900 Submit via Task queue for load-balancing.
901 901 if 'all':
902 902 Run on all active engines
903 903 if list:
904 904 Run on each specified engine
905 905 if int:
906 906 Run on single engine
907 907
908 908 after,follow,timeout only used in `apply_balanced`. See that docstring
909 909 for details.
910 910
911 911 Returns
912 912 -------
913 913 if block is False:
914 914 return AsyncResult wrapping msg_ids
915 915 output of AsyncResult.get() is identical to that of `apply(...block=True)`
916 916 else:
917 917 if single target:
918 918 return result of `f(*args, **kwargs)`
919 919 else:
920 920 return list of results, matching `targets`
921 921 """
922 922
923 923 # defaults:
924 924 block = block if block is not None else self.block
925 925 args = args if args is not None else []
926 926 kwargs = kwargs if kwargs is not None else {}
927 927
928 928 # enforce types of f,args,kwrags
929 929 if not callable(f):
930 930 raise TypeError("f must be callable, not %s"%type(f))
931 931 if not isinstance(args, (tuple, list)):
932 932 raise TypeError("args must be tuple or list, not %s"%type(args))
933 933 if not isinstance(kwargs, dict):
934 934 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
935 935
936 936 options = dict(bound=bound, block=block, targets=targets)
937 937
938 938 if targets is None:
939 939 return self.apply_balanced(f, args, kwargs, timeout=timeout,
940 940 after=after, follow=follow, **options)
941 941 else:
942 942 if follow or after or timeout:
943 943 msg = "follow, after, and timeout args are only used for load-balanced"
944 944 msg += "execution."
945 945 raise ValueError(msg)
946 946 return self._apply_direct(f, args, kwargs, **options)
947 947
    @defaultblock
    def apply_balanced(self, f, args, kwargs, bound=True, block=None, targets=None,
                            after=None, follow=None, timeout=None):
        """call f(*args, **kwargs) remotely in a load-balanced manner.

        Parameters
        ----------

        f : function
            The fuction to be called remotely
        args : tuple/list
            The positional arguments passed to `f`
        kwargs : dict
            The keyword arguments passed to `f`
        bound : bool (default: True)
            Whether to execute in the Engine(s) namespace, or in a clean
            namespace not affecting the engine.
        block : bool (default: self.block)
            Whether to wait for the result, or return immediately.
            False:
                returns AsyncResult
            True:
                returns actual result(s) of f(*args, **kwargs)
                if multiple targets:
                    list of results, matching `targets`
        targets : int,list of ints, 'all', None
            Specify the destination of the job.
            if None:
                Submit via Task queue for load-balancing.
            if 'all':
                Run on all active engines
            if list:
                Run on each specified engine
            if int:
                Run on single engine

        after : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a time-based dependency.
            This job will only be run *after* the dependencies
            have been met.

        follow : Dependency or collection of msg_ids
            Only for load-balanced execution (targets=None)
            Specify a list of msg_ids as a location-based dependency.
            This job will only be run on an engine where this dependency
            is met.

        timeout : float/int or None
            Only for load-balanced execution (targets=None)
            Specify an amount of time (in seconds) for the scheduler to
            wait for dependencies to be met before failing with a
            DependencyTimeout.

        Returns
        -------
        if block is False:
            return AsyncResult wrapping msg_id
            output of AsyncResult.get() is identical to that of `apply(...block=True)`
        else:
            wait for, and return actual result of `f(*args, **kwargs)`

        """

        if self._task_socket is None:
            msg = "Task farming is disabled"
            if self._task_scheme == 'pure':
                msg += " because the pure ZMQ scheduler cannot handle"
                msg += " disappearing engines."
            raise RuntimeError(msg)

        if self._task_scheme == 'pure':
            # pure zmq scheme doesn't support dependencies
            msg = "Pure ZMQ scheduler doesn't support dependencies"
            if (follow or after):
                # hard fail on DAG dependencies
                raise RuntimeError(msg)
            if isinstance(f, dependent):
                # soft warn on functional dependencies
                warnings.warn(msg, RuntimeWarning)


        # defaults:
        # NOTE(review): args/kwargs are required positional parameters here;
        # these lines only matter when a caller passes None explicitly
        args = args if args is not None else []
        kwargs = kwargs if kwargs is not None else {}

        if targets:
            # restrict the load-balanced job to specific engine identities
            idents,_ = self._build_targets(targets)
        else:
            idents = []

        # enforce types of f, args, kwargs
        if not callable(f):
            raise TypeError("f must be callable, not %s"%type(f))
        if not isinstance(args, (tuple, list)):
            raise TypeError("args must be tuple or list, not %s"%type(args))
        if not isinstance(kwargs, dict):
            raise TypeError("kwargs must be dict, not %s"%type(kwargs))

        # dependencies are sent in the subheader for the scheduler
        after = self._build_dependency(after)
        follow = self._build_dependency(follow)
        subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
        bufs = ss.pack_apply_message(f,args,kwargs)
        content = dict(bound=bound)

        msg = self.session.send(self._task_socket, "apply_request",
                content=content, buffers=bufs, subheader=subheader)
        msg_id = msg['msg_id']
        self.outstanding.add(msg_id)
        self.history.append(msg_id)
        ar = AsyncResult(self, [msg_id], fname=f.__name__)
        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # interrupted while waiting: hand back the pending AsyncResult
                return ar
        else:
            return ar
1066 1066
    def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None):
        """The underlying method for applying functions to specific engines
        via the MUX queue.

        Not to be called directly!
        """

        idents,targets = self._build_targets(targets)

        subheader = {}
        content = dict(bound=bound)
        # serialize once; the same buffers are sent to every target
        bufs = ss.pack_apply_message(f,args,kwargs)

        msg_ids = []
        for ident in idents:
            # one apply_request per engine, addressed by identity
            msg = self.session.send(self._mux_socket, "apply_request",
                    content=content, buffers=bufs, ident=ident, subheader=subheader)
            msg_id = msg['msg_id']
            self.outstanding.add(msg_id)
            self.history.append(msg_id)
            msg_ids.append(msg_id)
        ar = AsyncResult(self, msg_ids, fname=f.__name__)
        if block:
            try:
                return ar.get()
            except KeyboardInterrupt:
                # interrupted while waiting: hand back the pending AsyncResult
                return ar
        else:
            return ar
1096 1096
1097 1097 #--------------------------------------------------------------------------
1098 1098 # Map and decorators
1099 1099 #--------------------------------------------------------------------------
1100 1100
    def map(self, f, *sequences, **kwargs):
        """Parallel version of builtin `map`, using all our engines.

        `block` and `targets` can be passed as keyword arguments only.

        There will be one task per target, so work will be chunked
        if the sequences are longer than `targets`.

        Results can be iterated as they are ready, but will become available in chunks.

        Parameters
        ----------

        f : callable
            function to be mapped
        *sequences: one or more sequences of matching length
            the sequences to be distributed and passed to `f`
        block : bool
            whether to wait for the result or not [default self.block]
        targets : valid targets
            targets to be used [default self.targets]

        Returns
        -------

        if block=False:
            AsyncMapResult
            An object like AsyncResult, but which reassembles the sequence of results
            into a single list. AsyncMapResults can be iterated through before all
            results are complete.
        else:
            the result of map(f,*sequences)

        """
        block = kwargs.get('block', self.block)
        targets = kwargs.get('targets', self.targets)
        # NOTE(review): assert is stripped under `python -O`; consider raising
        # ValueError for input validation instead
        assert len(sequences) > 0, "must have some sequences to map onto!"
        pf = ParallelFunction(self, f, block=block,
                        bound=True, targets=targets)
        return pf.map(*sequences)
1141
1142 def imap(self, f, *sequences, **kwargs):
1143 """Parallel version of builtin `itertools.imap`, load-balanced across all engines.
1144
1145 Each element will be a separate task, and will be load-balanced. This
1146 lets individual elements be ready for iteration as soon as they come.
1147
1148 Parameters
1149 ----------
1150
1151 f : callable
1152 function to be mapped
1153 *sequences: one or more sequences of matching length
1154 the sequences to be distributed and passed to `f`
1155 block : bool
1156 whether to wait for the result or not [default self.block]
1157
1158 Returns
1159 -------
1160
1161 if block=False:
1162 AsyncMapResult
1163 An object like AsyncResult, but which reassembles the sequence of results
1164 into a single list. AsyncMapResults can be iterated through before all
1165 results are complete.
1166 else:
1167 the result of map(f,*sequences)
1168
1169 """
1170
1171 block = kwargs.get('block', self.block)
1172
1173 assert len(sequences) > 0, "must have some sequences to map onto!"
1174
1103 1175 pf = ParallelFunction(self, f, block=self.block,
1104 bound=True, targets='all')
1176 bound=True, targets=None)
1105 1177 return pf.map(*sequences)
1106 1178
    def parallel(self, bound=True, targets='all', block=True):
        """Decorator for making a ParallelFunction.

        Parameters
        ----------
        bound : bool
            whether to execute in the engine namespace or a clean one
        targets : valid targets
            engines the parallel function will run on
        block : bool
            whether calls wait for their results
        """
        # delegates to the module-level `parallel` factory (not this method)
        return parallel(self, bound=bound, targets=targets, block=block)
1110 1182
    def remote(self, bound=True, targets='all', block=True):
        """Decorator for making a RemoteFunction.

        NOTE(review): this method shares its name with the `remote` property
        defined earlier in this class; this later definition wins, making the
        property unreachable -- confirm which of the two is intended.
        """
        # delegates to the module-level `remote` factory (not this method)
        return remote(self, bound=bound, targets=targets, block=block)
1114 1186
1115 1187 def view(self, targets=None, balanced=False):
1116 1188 """Method for constructing View objects"""
1117 1189 if not balanced:
1118 1190 if not targets:
1119 1191 targets = slice(None)
1120 1192 return self[targets]
1121 1193 else:
1122 1194 return LoadBalancedView(self, targets)
1123 1195
1124 1196 #--------------------------------------------------------------------------
1125 1197 # Data movement
1126 1198 #--------------------------------------------------------------------------
1127 1199
1128 1200 @defaultblock
1129 1201 def push(self, ns, targets='all', block=None):
1130 1202 """Push the contents of `ns` into the namespace on `target`"""
1131 1203 if not isinstance(ns, dict):
1132 1204 raise TypeError("Must be a dict, not %s"%type(ns))
1133 1205 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
1134 1206 return result
1135 1207
1136 1208 @defaultblock
1137 1209 def pull(self, keys, targets='all', block=None):
1138 1210 """Pull objects from `target`'s namespace by `keys`"""
1139 1211 if isinstance(keys, str):
1140 1212 pass
1141 1213 elif isinstance(keys, (list,tuple,set)):
1142 1214 for key in keys:
1143 1215 if not isinstance(key, str):
1144 1216 raise TypeError
1145 1217 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
1146 1218 return result
1147 1219
    def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
        """
        Partition a Python sequence and send the partitions to a set of engines.

        Each engine receives its partition of `seq` under the name `key`
        via `push`; the pushes are gathered into a single AsyncResult.
        """
        block = block if block is not None else self.block
        targets = self._build_targets(targets)[-1]
        # the Map object implements the partitioning strategy selected by `dist`
        mapObject = Map.dists[dist]()
        nparts = len(targets)
        msg_ids = []
        for index, engineid in enumerate(targets):
            partition = mapObject.getPartition(seq, index, nparts)
            if flatten and len(partition) == 1:
                # push the bare element rather than a length-1 list
                r = self.push({key: partition[0]}, targets=engineid, block=False)
            else:
                r = self.push({key: partition}, targets=engineid, block=False)
            msg_ids.extend(r.msg_ids)
        r = AsyncResult(self, msg_ids, fname='scatter')
        if block:
            return r.get()
        else:
            return r
1169 1241
    def gather(self, key, dist='b', targets='all', block=None):
        """
        Gather a partitioned sequence on a set of engines as a single local seq.

        Pulls `key` from each engine and lets an AsyncMapResult (with the
        `dist` Map object) reassemble the partitions into one sequence.
        """
        block = block if block is not None else self.block

        targets = self._build_targets(targets)[-1]
        mapObject = Map.dists[dist]()
        msg_ids = []
        for index, engineid in enumerate(targets):
            # one non-blocking pull per engine, in target order
            msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)

        r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
        if block:
            return r.get()
        else:
            return r
1187 1259
1188 1260 #--------------------------------------------------------------------------
1189 1261 # Query methods
1190 1262 #--------------------------------------------------------------------------
1191 1263
    @spinfirst
    def get_results(self, msg_ids, status_only=False):
        """Returns the result of the execute or task request with `msg_ids`.

        Parameters
        ----------
        msg_ids : list of ints or msg_ids
            if int:
                Passed as index to self.history for convenience.
        status_only : bool (default: False)
            if False:
                return the actual results

        Returns
        -------

        results : dict
            There will always be the keys 'pending' and 'completed', which will
            be lists of msg_ids.
        """
        if not isinstance(msg_ids, (list,tuple)):
            msg_ids = [msg_ids]
        theids = []
        for msg_id in msg_ids:
            if isinstance(msg_id, int):
                # ints index into the submission history
                msg_id = self.history[msg_id]
            if not isinstance(msg_id, str):
                raise TypeError("msg_ids must be str, not %r"%msg_id)
            theids.append(msg_id)

        completed = []
        local_results = {}

        # comment this block out to temporarily disable local shortcut:
        # (results already cached locally are not re-requested from the controller)
        for msg_id in list(theids):
            if msg_id in self.results:
                completed.append(msg_id)
                local_results[msg_id] = self.results[msg_id]
                theids.remove(msg_id)

        if theids: # some not locally cached
            content = dict(msg_ids=theids, status_only=status_only)
            msg = self.session.send(self._query_socket, "result_request", content=content)
            # block until the reply arrives, then recv without blocking
            zmq.select([self._query_socket], [], [])
            idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
            if self.debug:
                pprint(msg)
            content = msg['content']
            if content['status'] != 'ok':
                raise ss.unwrap_exception(content)
            buffers = msg['buffers']
        else:
            content = dict(completed=[],pending=[])

        content['completed'].extend(completed)

        if status_only:
            return content

        failures = []
        # load cached results into result:
        content.update(local_results)
        # update cache with results:
        # iterate in sorted order -- presumably matching the buffer order in
        # the reply, since each 'ok' result consumes a prefix of `buffers`
        # (TODO confirm against the controller's result_request handler)
        for msg_id in sorted(theids):
            if msg_id in content['completed']:
                rec = content[msg_id]
                parent = rec['header']
                header = rec['result_header']
                rcontent = rec['result_content']
                iodict = rec['io']
                if isinstance(rcontent, str):
                    rcontent = self.session.unpack(rcontent)

                md = self.metadata.setdefault(msg_id, Metadata())
                md.update(self._extract_metadata(header, parent, rcontent))
                md.update(iodict)

                if rcontent['status'] == 'ok':
                    # consumes this result's buffers; the rest carry over
                    res,buffers = ss.unserialize_object(buffers)
                else:
                    res = ss.unwrap_exception(rcontent)
                    failures.append(res)

                self.results[msg_id] = res
                content[msg_id] = res

        error.collect_exceptions(failures, "get_results")
        return content
1280 1352
    @spinfirst
    def queue_status(self, targets=None, verbose=False):
        """Fetch the status of engine queues.

        Parameters
        ----------
        targets : int/str/list of ints/strs
            the engines on which to execute
            default : all
        verbose : bool
            Whether to return lengths only, or lists of ids for each element
        """
        targets = self._build_targets(targets)[1]
        content = dict(targets=targets, verbose=verbose)
        self.session.send(self._query_socket, "queue_request", content=content)
        idents,msg = self.session.recv(self._query_socket, 0)
        if self.debug:
            pprint(msg)
        content = msg['content']
        status = content.pop('status')
        if status != 'ok':
            raise ss.unwrap_exception(content)
        # rekey the reply dict -- presumably restoring engine-id keys from
        # their serialized (string) form; confirm in streamsession.rekey
        return ss.rekey(content)
1304 1376
1305 1377 @spinfirst
1306 1378 def purge_results(self, msg_ids=[], targets=[]):
1307 1379 """Tell the controller to forget results.
1308 1380
1309 1381 Individual results can be purged by msg_id, or the entire
1310 1382 history of specific targets can be purged.
1311 1383
1312 1384 Parameters
1313 1385 ----------
1314 1386 msg_ids : str or list of strs
1315 1387 the msg_ids whose results should be forgotten.
1316 1388 targets : int/str/list of ints/strs
1317 1389 The targets, by uuid or int_id, whose entire history is to be purged.
1318 1390 Use `targets='all'` to scrub everything from the controller's memory.
1319 1391
1320 1392 default : None
1321 1393 """
1322 1394 if not targets and not msg_ids:
1323 1395 raise ValueError
1324 1396 if targets:
1325 1397 targets = self._build_targets(targets)[1]
1326 1398 content = dict(targets=targets, msg_ids=msg_ids)
1327 1399 self.session.send(self._query_socket, "purge_request", content=content)
1328 1400 idents, msg = self.session.recv(self._query_socket, 0)
1329 1401 if self.debug:
1330 1402 pprint(msg)
1331 1403 content = msg['content']
1332 1404 if content['status'] != 'ok':
1333 1405 raise ss.unwrap_exception(content)
1334 1406
1335 1407 #----------------------------------------
1336 1408 # activate for %px,%autopx magics
1337 1409 #----------------------------------------
    def activate(self):
        """Make this `View` active for parallel magic commands.

        IPython has a magic command syntax to work with `MultiEngineClient` objects.
        In a given IPython session there is a single active one. While
        there can be many `Views` created and used by the user,
        there is only one active one. The active `View` is used whenever
        the magic commands %px and %autopx are used.

        The activate() method is called on a given `View` to make it
        active. Once this has been done, the magic commands can be used.
        """

        try:
            # This is injected into __builtins__.
            ip = get_ipython()
        except NameError:
            # not running inside IPython: the magics cannot be registered
            print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
        else:
            pmagic = ip.plugin_manager.get_plugin('parallelmagic')
            if pmagic is not None:
                # register self as the client the %px/%autopx magics act on
                pmagic.active_multiengine_client = self
            else:
                print "You must first load the parallelmagic extension " \
                    "by doing '%load_ext parallelmagic'"
1363 1435
class AsynClient(Client):
    """An Asynchronous client, using the Tornado Event Loop.
    !!!unfinished!!!"""
    # ZMQStream wrappers around the blocking Client sockets
    io_loop = None
    _queue_stream = None
    _notifier_stream = None
    _task_stream = None
    _control_stream = None

    def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
        """Wrap the Client's sockets in ZMQStreams attached to `io_loop`."""
        Client.__init__(self, addr, context, username, debug)
        if io_loop is None:
            io_loop = ioloop.IOLoop.instance()
        self.io_loop = io_loop

        self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
        self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
        self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
        # fix: store under the attribute declared above (_notifier_stream);
        # previously this was assigned to self._notification_stream, leaving
        # _notifier_stream as None
        self._notifier_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)

    def spin(self):
        """Flush all streams so pending messages are processed."""
        # fix: use the underscore-prefixed attributes actually set in
        # __init__; the previous names (self.queue_stream, ...) were never
        # assigned and raised AttributeError
        for stream in (self._queue_stream, self._notifier_stream,
                self._task_stream, self._control_stream):
            stream.flush()
1388 1460
1389 1461 __all__ = [ 'Client',
1390 1462 'depend',
1391 1463 'require',
1392 1464 'remote',
1393 1465 'parallel',
1394 1466 'RemoteFunction',
1395 1467 'ParallelFunction',
1396 1468 'DirectView',
1397 1469 'LoadBalancedView',
1398 1470 'AsyncResult',
1399 1471 'AsyncMapResult'
1400 1472 ]
@@ -1,483 +1,485 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 12 import __builtin__
13 13 from code import CommandCompiler
14 14 import os
15 15 import sys
16 16 import time
17 17 import traceback
18 18 import logging
19 19 from datetime import datetime
20 20 from signal import SIGTERM, SIGKILL
21 21 from pprint import pprint
22 22
23 23 # System library imports.
24 24 import zmq
25 25 from zmq.eventloop import ioloop, zmqstream
26 26
27 27 # Local imports.
28 28 from IPython.core import ultratb
29 29 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
30 30 from IPython.zmq.completer import KernelCompleter
31 31 from IPython.zmq.iostream import OutStream
32 32 from IPython.zmq.displayhook import DisplayHook
33 33
34 34 from factory import SessionFactory
35 35 from streamsession import StreamSession, Message, extract_header, serialize_object,\
36 36 unpack_apply_message, ISO8601, wrap_exception
37 37 import heartmonitor
38 38 from client import Client
39 39
def printer(*args):
    """Debug helper: pretty-print all positional arguments to the real stdout."""
    pprint(args, stream=sys.__stdout__)
42 42
43 43
44 44 class _Passer:
45 45 """Empty class that implements `send()` that does nothing."""
46 46 def send(self, *args, **kwargs):
47 47 pass
48 48 send_multipart = send
49 49
50 50
51 51 #-----------------------------------------------------------------------------
52 52 # Main kernel class
53 53 #-----------------------------------------------------------------------------
54 54
class Kernel(SessionFactory):
    """ZMQ-stream engine kernel: executes code/apply requests arriving on the
    shell and control streams, publishing side effects on the iopub stream."""

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # kwargs:
    int_id = Int(-1, config=True)    # integer engine id assigned by the hub
    user_ns = Dict(config=True)      # namespace user code executes in
    exec_lines = List(config=True)   # lines run at startup and after a clear

    control_stream = Instance(zmqstream.ZMQStream)
    task_stream = Instance(zmqstream.ZMQStream)
    iopub_stream = Instance(zmqstream.ZMQStream)
    client = Instance('IPython.zmq.parallel.client.Client')

    # internals
    shell_streams = List()
    compiler = Instance(CommandCompiler, (), {})
    completer = Instance(KernelCompleter)

    aborted = Set()              # msg_ids marked for abort (see abort_request)
    shell_handlers = Dict()      # msg_type -> bound handler for shell messages
    control_handlers = Dict()    # msg_type -> bound handler for control messages
79 79
80 80 def _set_prefix(self):
81 81 self.prefix = "engine.%s"%self.int_id
82 82
83 83 def _connect_completer(self):
84 84 self.completer = KernelCompleter(self.user_ns)
85 85
86 86 def __init__(self, **kwargs):
87 87 super(Kernel, self).__init__(**kwargs)
88 88 self._set_prefix()
89 89 self._connect_completer()
90 90
91 91 self.on_trait_change(self._set_prefix, 'id')
92 92 self.on_trait_change(self._connect_completer, 'user_ns')
93 93
94 94 # Build dict of handlers for message types
95 95 for msg_type in ['execute_request', 'complete_request', 'apply_request',
96 96 'clear_request']:
97 97 self.shell_handlers[msg_type] = getattr(self, msg_type)
98 98
99 99 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
100 100 self.control_handlers[msg_type] = getattr(self, msg_type)
101 101
102 102 self._initial_exec_lines()
103 103
104 104 def _wrap_exception(self, method=None):
105 105 e_info = dict(engineid=self.ident, method=method)
106 106 content=wrap_exception(e_info)
107 107 return content
108 108
109 109 def _initial_exec_lines(self):
110 110 s = _Passer()
111 111 content = dict(silent=True, user_variable=[],user_expressions=[])
112 112 for line in self.exec_lines:
113 113 self.log.debug("executing initialization: %s"%line)
114 114 content.update({'code':line})
115 115 msg = self.session.msg('execute_request', content)
116 116 self.execute_request(s, [], msg)
117 117
118 118
119 119 #-------------------- control handlers -----------------------------
120 120 def abort_queues(self):
121 121 for stream in self.shell_streams:
122 122 if stream:
123 123 self.abort_queue(stream)
124 124
    def abort_queue(self, stream):
        """Drain *stream*, replying 'aborted' to every queued request.

        Loops until a non-blocking recv raises EAGAIN (queue empty).
        """
        while True:
            try:
                # non-blocking recv: stop as soon as the queue is drained
                msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
            except zmq.ZMQError as e:
                if e.errno == zmq.EAGAIN:
                    break
                else:
                    return
            else:
                if msg is None:
                    return
                else:
                    idents,msg = msg

            # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
            # msg = self.reply_socket.recv_json()
            self.log.info("Aborting:")
            self.log.info(str(msg))
            msg_type = msg['msg_type']
            # e.g. 'execute_request' is answered with an 'execute_reply'
            reply_type = msg_type.split('_')[0] + '_reply'
            # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
            # self.reply_socket.send(ident,zmq.SNDMORE)
            # self.reply_socket.send_json(reply_msg)
            reply_msg = self.session.send(stream, reply_type,
                        content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
            self.log.debug(str(reply_msg))
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            time.sleep(0.05)
155 155
156 156 def abort_request(self, stream, ident, parent):
157 157 """abort a specifig msg by id"""
158 158 msg_ids = parent['content'].get('msg_ids', None)
159 159 if isinstance(msg_ids, basestring):
160 160 msg_ids = [msg_ids]
161 161 if not msg_ids:
162 162 self.abort_queues()
163 163 for mid in msg_ids:
164 164 self.aborted.add(str(mid))
165 165
166 166 content = dict(status='ok')
167 167 reply_msg = self.session.send(stream, 'abort_reply', content=content,
168 168 parent=parent, ident=ident)[0]
169 169 self.log.debug(str(reply_msg))
170 170
    def shutdown_request(self, stream, ident, parent):
        """kill ourself.  This should really be handled in an external process"""
        try:
            self.abort_queues()
        # bare except is deliberate here: shutdown must reply even if the
        # abort fails for any reason, reporting the error in the reply
        except:
            content = self._wrap_exception('shutdown')
        else:
            content = dict(parent['content'])
            content['status'] = 'ok'
        msg = self.session.send(stream, 'shutdown_reply',
                    content=content, parent=parent, ident=ident)
        # msg = self.session.send(self.pub_socket, 'shutdown_reply',
        #                         content, parent, ident)
        # print >> sys.__stdout__, msg
        # time.sleep(0.2)
        # delay the actual exit long enough for the reply to be delivered
        dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
        dc.start()
188 188
189 189 def dispatch_control(self, msg):
190 190 idents,msg = self.session.feed_identities(msg, copy=False)
191 191 try:
192 192 msg = self.session.unpack_message(msg, content=True, copy=False)
193 193 except:
194 194 self.log.error("Invalid Message", exc_info=True)
195 195 return
196 196
197 197 header = msg['header']
198 198 msg_id = header['msg_id']
199 199
200 200 handler = self.control_handlers.get(msg['msg_type'], None)
201 201 if handler is None:
202 202 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
203 203 else:
204 204 handler(self.control_stream, idents, msg)
205 205
206 206
207 207 #-------------------- queue helpers ------------------------------
208 208
209 209 def check_dependencies(self, dependencies):
210 210 if not dependencies:
211 211 return True
212 212 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
213 213 anyorall = dependencies[0]
214 214 dependencies = dependencies[1]
215 215 else:
216 216 anyorall = 'all'
217 217 results = self.client.get_results(dependencies,status_only=True)
218 218 if results['status'] != 'ok':
219 219 return False
220 220
221 221 if anyorall == 'any':
222 222 if not results['completed']:
223 223 return False
224 224 else:
225 225 if results['pending']:
226 226 return False
227 227
228 228 return True
229 229
230 230 def check_aborted(self, msg_id):
231 231 return msg_id in self.aborted
232 232
233 233 #-------------------- queue handlers -----------------------------
234 234
235 235 def clear_request(self, stream, idents, parent):
236 236 """Clear our namespace."""
237 237 self.user_ns = {}
238 238 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
239 239 content = dict(status='ok'))
240 240 self._initial_exec_lines()
241 241
    def execute_request(self, stream, ident, parent):
        """Execute the code string in *parent* in the user namespace.

        Publishes pyin (and pyerr on failure) on the iopub stream, sends an
        execute_reply on *stream*, and aborts the queues on error.
        """
        self.log.debug('execute request %s'%parent)
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        # echo the input on iopub so monitors can see what is being run
        self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
                            ident='%s.pyin'%self.prefix)
        started = datetime.now().strftime(ISO8601)
        try:
            comp_code = self.compiler(code, '<zmq-kernel>')
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            exec comp_code in self.user_ns, self.user_ns
        except:
            exc_content = self._wrap_exception('execute')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
                            ident='%s.pyerr'%self.prefix)
            reply_content = exc_content
        else:
            reply_content = {'status' : 'ok'}

        reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
                    ident=ident, subheader = dict(started=started))
        self.log.debug(str(reply_msg))
        if reply_msg['content']['status'] == u'error':
            self.abort_queues()
274 274
275 275 def complete_request(self, stream, ident, parent):
276 276 matches = {'matches' : self.complete(parent),
277 277 'status' : 'ok'}
278 278 completion_msg = self.session.send(stream, 'complete_reply',
279 279 matches, parent, ident)
280 280 # print >> sys.__stdout__, completion_msg
281 281
282 282 def complete(self, msg):
283 283 return self.completer.complete(msg.content.line, msg.content.text)
284 284
    def apply_request(self, stream, ident, parent):
        """Run a serialized f(*args,**kwargs) request and reply with the result.

        If 'bound' is set in the request content, execution happens inside
        self.user_ns (names uniquified by msg_id); otherwise in a throwaway
        namespace.  The serialized result travels in the reply's buffers.
        """
        # flush previous reply, so this request won't block it
        stream.flush(zmq.POLLOUT)

        try:
            content = parent[u'content']
            bufs = parent[u'buffers']
            msg_id = parent['header']['msg_id']
            bound = content.get('bound', False)
        except:
            self.log.error("Got bad msg: %s"%parent, exc_info=True)
            return
        # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
        # self.iopub_stream.send(pyin_msg)
        # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
        sub = {'dependencies_met' : True, 'engine' : self.ident,
                'started': datetime.now().strftime(ISO8601)}
        try:
            # allow for not overriding displayhook
            if hasattr(sys.displayhook, 'set_parent'):
                sys.displayhook.set_parent(parent)
                sys.stdout.set_parent(parent)
                sys.stderr.set_parent(parent)
            # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
            if bound:
                # bound: run in the user namespace; suffix names with the
                # msg_id so concurrent requests don't collide
                working = self.user_ns
                suffix = str(msg_id).replace("-","")
                prefix = "_"

            else:
                working = dict()
                suffix = prefix = "_" # prevent keyword collisions with lambda
            f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
            # if f.fun
            fname = getattr(f, '__name__', 'f')

            fname = prefix+fname.strip('<>')+suffix
            argname = prefix+"args"+suffix
            kwargname = prefix+"kwargs"+suffix
            resultname = prefix+"result"+suffix

            ns = { fname : f, argname : args, kwargname : kwargs }
            # print ns
            working.update(ns)
            code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
            exec code in working, working
            result = working.get(resultname)
            # clear the namespace
            if bound:
                for key in ns.iterkeys():
                    self.user_ns.pop(key)
            else:
                del working

            packed_result,buf = serialize_object(result)
            result_buf = [packed_result]+buf
        except:
            exc_content = self._wrap_exception('apply')
            # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
            self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
                                ident='%s.pyerr'%self.prefix)
            reply_content = exc_content
            result_buf = []

            # let the scheduler know this failure was only an unmet dependency
            if exc_content['ename'] == 'UnmetDependency':
                sub['dependencies_met'] = False
        else:
            reply_content = {'status' : 'ok'}

        # put 'ok'/'error' status in header, for scheduler introspection:
        sub['status'] = reply_content['status']

        reply_msg = self.session.send(stream, u'apply_reply', reply_content,
                    parent=parent, ident=ident,buffers=result_buf, subheader=sub)

        # if reply_msg['content']['status'] == u'error':
        #     self.abort_queues()
360 362
    def dispatch_queue(self, stream, msg):
        """Unpack a shell-channel message and run its handler, unless its
        msg_id was previously aborted, in which case reply 'aborted'."""
        # service any pending control messages (e.g. aborts) first
        self.control_stream.flush()
        idents,msg = self.session.feed_identities(msg, copy=False)
        try:
            msg = self.session.unpack_message(msg, content=True, copy=False)
        except:
            self.log.error("Invalid Message", exc_info=True)
            return


        header = msg['header']
        msg_id = header['msg_id']
        if self.check_aborted(msg_id):
            self.aborted.remove(msg_id)
            # is it safe to assume a msg_id will not be resubmitted?
            reply_type = msg['msg_type'].split('_')[0] + '_reply'
            reply_msg = self.session.send(stream, reply_type,
                        content={'status' : 'aborted'}, parent=msg, ident=idents)
            return
        handler = self.shell_handlers.get(msg['msg_type'], None)
        if handler is None:
            self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
        else:
            handler(stream, idents, msg)
385 387
386 388 def start(self):
387 389 #### stream mode:
388 390 if self.control_stream:
389 391 self.control_stream.on_recv(self.dispatch_control, copy=False)
390 392 self.control_stream.on_err(printer)
391 393
392 394 def make_dispatcher(stream):
393 395 def dispatcher(msg):
394 396 return self.dispatch_queue(stream, msg)
395 397 return dispatcher
396 398
397 399 for s in self.shell_streams:
398 400 s.on_recv(make_dispatcher(s), copy=False)
399 401 s.on_err(printer)
400 402
401 403 if self.iopub_stream:
402 404 self.iopub_stream.on_err(printer)
403 405
404 406 #### while True mode:
405 407 # while True:
406 408 # idle = True
407 409 # try:
408 410 # msg = self.shell_stream.socket.recv_multipart(
409 411 # zmq.NOBLOCK, copy=False)
410 412 # except zmq.ZMQError, e:
411 413 # if e.errno != zmq.EAGAIN:
412 414 # raise e
413 415 # else:
414 416 # idle=False
415 417 # self.dispatch_queue(self.shell_stream, msg)
416 418 #
417 419 # if not self.task_stream.empty():
418 420 # idle=False
419 421 # msg = self.task_stream.recv_multipart()
420 422 # self.dispatch_queue(self.task_stream, msg)
421 423 # if idle:
422 424 # # don't busywait
423 425 # time.sleep(1e-3)
424 426
def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
                client_addr=None, loop=None, context=None, key=None,
                out_stream_factory=OutStream, display_hook_factory=DisplayHook):
    """NO LONGER IN USE

    Wire up and start a Kernel: connects control/shell/iopub streams to the
    given addresses, redirects stdout/stderr/displayhook onto iopub, starts
    a heartbeat, and optionally creates a Client.

    Returns (loop, context, kernel).
    """
    # create loop, context, and session:
    if loop is None:
        loop = ioloop.IOLoop.instance()
    if context is None:
        context = zmq.Context()
    c = context
    session = StreamSession(key=key)
    # print (session.key)
    # print (control_addr, shell_addrs, iopub_addr, hb_addrs)

    # create Control Stream
    control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
    control_stream.setsockopt(zmq.IDENTITY, identity)
    control_stream.connect(control_addr)

    # create Shell Streams (MUX, Task, etc.):
    shell_streams = []
    for addr in shell_addrs:
        stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
        stream.setsockopt(zmq.IDENTITY, identity)
        stream.connect(addr)
        shell_streams.append(stream)

    # create iopub stream:
    iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
    iopub_stream.setsockopt(zmq.IDENTITY, identity)
    iopub_stream.connect(iopub_addr)

    # Redirect input streams and set a display hook.
    if out_stream_factory:
        sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
        sys.stdout.topic = 'engine.%i.stdout'%int_id
        sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
        sys.stderr.topic = 'engine.%i.stderr'%int_id
    if display_hook_factory:
        sys.displayhook = display_hook_factory(session, iopub_stream)
        sys.displayhook.topic = 'engine.%i.pyout'%int_id


    # launch heartbeat
    heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
    heart.start()

    # create (optional) Client
    if client_addr:
        client = Client(client_addr, username=identity)
    else:
        client = None

    kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
            shell_streams=shell_streams, iopub_stream=iopub_stream,
            client=client, loop=loop)
    kernel.start()
    return loop, c, kernel
483 485
General Comments 0
You need to be logged in to leave comments. Login now