improved client.get_results() behavior
MinRK
@@ -0,0 +1,35 @@
1 """some generic utilities"""
2
3 class ReverseDict(dict):
4 """simple double-keyed subset of dict methods."""
5
6 def __init__(self, *args, **kwargs):
7 dict.__init__(self, *args, **kwargs)
8 self._reverse = dict()
9 for key, value in self.iteritems():
10 self._reverse[value] = key
11
12 def __getitem__(self, key):
13 try:
14 return dict.__getitem__(self, key)
15 except KeyError:
16 return self._reverse[key]
17
18 def __setitem__(self, key, value):
19 if key in self._reverse:
20 raise KeyError("Can't have key %r on both sides!"%key)
21 dict.__setitem__(self, key, value)
22 self._reverse[value] = key
23
24 def pop(self, key):
25 value = dict.pop(self, key)
 26         self._reverse.pop(value)
27 return value
28
29 def get(self, key, default=None):
30 try:
31 return self[key]
32 except KeyError:
33 return default
34
35
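A quick sketch of the two-way lookup this new ReverseDict is meant to provide; the values here are hypothetical:

    rd = ReverseDict()
    rd[0] = 'abc123'        # forward: engine id -> uuid
    print(rd['abc123'])     # reverse lookup by value returns 0
    rd.pop(0)               # pop() removes the pair from both directions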
@@ -1,113 +1,113 @@
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import error
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Classes
17 17 #-----------------------------------------------------------------------------
18 18
19 19 class AsyncResult(object):
20 20 """Class for representing results of non-blocking calls.
21 21
22 22         Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
23 23 """
24 24 def __init__(self, client, msg_ids, fname=''):
25 25 self._client = client
26 26 self.msg_ids = msg_ids
27 27 self._fname=fname
28 28 self._ready = False
29 29 self._success = None
30 30
31 31 def __repr__(self):
32 32 if self._ready:
33 33 return "<%s: finished>"%(self.__class__.__name__)
34 34 else:
35 35 return "<%s: %s>"%(self.__class__.__name__,self._fname)
36 36
37 37
38 38 def _reconstruct_result(self, res):
39 39 """
40 40 Override me in subclasses for turning a list of results
41 41 into the expected form.
42 42 """
43 if len(res) == 1:
43 if len(self.msg_ids) == 1:
44 44 return res[0]
45 45 else:
46 46 return res
47 47
48 48 def get(self, timeout=-1):
49 49 """Return the result when it arrives.
50 50
51 51 If `timeout` is not ``None`` and the result does not arrive within
52 52 `timeout` seconds then ``TimeoutError`` is raised. If the
53 53 remote call raised an exception then that exception will be reraised
54 54 by get().
55 55 """
56 56 if not self.ready():
57 57 self.wait(timeout)
58 58
59 59 if self._ready:
60 60 if self._success:
61 61 return self._result
62 62 else:
63 63 raise self._exception
64 64 else:
65 65 raise error.TimeoutError("Result not ready.")
66 66
67 67 def ready(self):
68 68 """Return whether the call has completed."""
69 69 if not self._ready:
70 70 self.wait(0)
71 71 return self._ready
72 72
73 73 def wait(self, timeout=-1):
74 74 """Wait until the result is available or until `timeout` seconds pass.
75 75 """
76 76 if self._ready:
77 77 return
78 78 self._ready = self._client.barrier(self.msg_ids, timeout)
79 79 if self._ready:
80 80 try:
81 81 results = map(self._client.results.get, self.msg_ids)
82 82 results = error.collect_exceptions(results, self._fname)
83 83 self._result = self._reconstruct_result(results)
84 84 except Exception, e:
85 85 self._exception = e
86 86 self._success = False
87 87 else:
88 88 self._success = True
89 89
90 90
91 91 def successful(self):
92 92 """Return whether the call completed without raising an exception.
93 93
94 94 Will raise ``AssertionError`` if the result is not ready.
95 95 """
96 96 assert self._ready
97 97 return self._success
98 98
99 99 class AsyncMapResult(AsyncResult):
100 100 """Class for representing results of non-blocking gathers.
101 101
102 102 This will properly reconstruct the gather.
103 103 """
104 104
105 105 def __init__(self, client, msg_ids, mapObject, fname=''):
106 106 self._mapObject = mapObject
107 107 AsyncResult.__init__(self, client, msg_ids, fname=fname)
108 108
109 109 def _reconstruct_result(self, res):
110 110 """Perform the gather on the actual results."""
111 111 return self._mapObject.joinPartitions(res)
112 112
113 113
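A usage sketch for AsyncResult, assuming `client` is a connected Client and `f` is a serializable function:

    ar = client.apply(f, (1, 2), block=False)   # returns an AsyncResult
    if not ar.ready():
        ar.wait(1.0)                            # wait up to one second
    try:
        result = ar.get(timeout=5)
    except error.TimeoutError:
        print("still pending: %s" % ar.msg_ids)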
@@ -1,1045 +1,1105 @@
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import time
15 15 from getpass import getpass
16 16 from pprint import pprint
17 from datetime import datetime
17 18
18 19 import zmq
19 20 from zmq.eventloop import ioloop, zmqstream
20 21
21 22 from IPython.external.decorator import decorator
22 23 from IPython.zmq import tunnel
23 24
24 25 import streamsession as ss
25 26 # from remotenamespace import RemoteNamespace
26 27 from view import DirectView, LoadBalancedView
27 28 from dependency import Dependency, depend, require
28 29 import error
29 30 import map as Map
30 31 from asyncresult import AsyncResult, AsyncMapResult
31 32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
33 from util import ReverseDict
32 34
33 35 #--------------------------------------------------------------------------
34 36 # helpers for implementing old MEC API via client.apply
35 37 #--------------------------------------------------------------------------
36 38
37 39 def _push(ns):
38 40 """helper method for implementing `client.push` via `client.apply`"""
39 41 globals().update(ns)
40 42
41 43 def _pull(keys):
42 44 """helper method for implementing `client.pull` via `client.apply`"""
43 45 g = globals()
44 46 if isinstance(keys, (list,tuple, set)):
45 47 for key in keys:
46 48             if key not in g:
47 49 raise NameError("name '%s' is not defined"%key)
48 50 return map(g.get, keys)
49 51 else:
50 52         if keys not in g:
51 53 raise NameError("name '%s' is not defined"%keys)
52 54 return g.get(keys)
53 55
54 56 def _clear():
55 57 """helper method for implementing `client.clear` via `client.apply`"""
56 58 globals().clear()
57 59
58 60 def _execute(code):
59 61 """helper method for implementing `client.execute` via `client.apply`"""
60 62 exec code in globals()
61 63
62 64
63 65 #--------------------------------------------------------------------------
64 66 # Decorators for Client methods
65 67 #--------------------------------------------------------------------------
66 68
67 69 @decorator
68 70 def spinfirst(f, self, *args, **kwargs):
69 71 """Call spin() to sync state prior to calling the method."""
70 72 self.spin()
71 73 return f(self, *args, **kwargs)
72 74
73 75 @decorator
74 76 def defaultblock(f, self, *args, **kwargs):
75 77 """Default to self.block; preserve self.block."""
76 78 block = kwargs.get('block',None)
77 79 block = self.block if block is None else block
78 80 saveblock = self.block
79 81 self.block = block
80 82 try:
81 83 ret = f(self, *args, **kwargs)
82 84 finally:
83 85 self.block = saveblock
84 86 return ret
85 87
88
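The save/restore in `defaultblock` means a per-call `block` flag overrides the client-wide default only for the duration of that call; a sketch of the effective behavior:

    client.block = False
    client.clear(block=True)       # this one call runs synchronously
    assert client.block is False   # the default is restored afterwards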
89 #--------------------------------------------------------------------------
90 # Classes
91 #--------------------------------------------------------------------------
92
86 93 class AbortedTask(object):
87 94 """A basic wrapper object describing an aborted task."""
88 95 def __init__(self, msg_id):
89 96 self.msg_id = msg_id
90 97
91 98 class ResultDict(dict):
92 99 """A subclass of dict that raises errors if it has them."""
93 100 def __getitem__(self, key):
94 101 res = dict.__getitem__(self, key)
95 102 if isinstance(res, error.KernelError):
96 103 raise res
97 104 return res
98 105
99 106 class Client(object):
100 107 """A semi-synchronous client to the IPython ZMQ controller
101 108
102 109 Parameters
103 110 ----------
104 111
105 112 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
106 113 The address of the controller's registration socket.
107 114 [Default: 'tcp://127.0.0.1:10101']
108 115 context : zmq.Context
109 116 Pass an existing zmq.Context instance, otherwise the client will create its own
110 117 username : bytes
111 118 set username to be passed to the Session object
112 119 debug : bool
113 120 flag for lots of message printing for debug purposes
114 121
115 122 #-------------- ssh related args ----------------
116 123 # These are args for configuring the ssh tunnel to be used
117 124 # credentials are used to forward connections over ssh to the Controller
118 125 # Note that the ip given in `addr` needs to be relative to sshserver
119 126 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
120 127 # and set sshserver as the same machine the Controller is on. However,
121 128 # the only requirement is that sshserver is able to see the Controller
122 129 # (i.e. is within the same trusted network).
123 130
124 131 sshserver : str
125 132 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
126 133 If keyfile or password is specified, and this is not, it will default to
127 134 the ip given in addr.
128 135 sshkey : str; path to public ssh key file
129 136 This specifies a key to be used in ssh login, default None.
130 137 Regular default ssh keys will be used without specifying this argument.
131 138 password : str;
132 139 Your ssh password to sshserver. Note that if this is left None,
133 140 you will be prompted for it if passwordless key based login is unavailable.
134 141
135 142 #------- exec authentication args -------
136 143 # If even localhost is untrusted, you can have some protection against
137 144 # unauthorized execution by using a key. Messages are still sent
138 145 # as cleartext, so if someone can snoop your loopback traffic this will
139 146 # not help anything.
140 147
141 148 exec_key : str
142 149 an authentication key or file containing a key
143 150 default: None
144 151
145 152
146 153 Attributes
147 154 ----------
148 155 ids : set of int engine IDs
149 156 requesting the ids attribute always synchronizes
150 157 the registration state. To request ids without synchronization,
151 158         use the semi-private _ids attribute.
152 159
153 160 history : list of msg_ids
154 161 a list of msg_ids, keeping track of all the execution
155 162 messages you have submitted in order.
156 163
157 164 outstanding : set of msg_ids
158 165 a set of msg_ids that have been submitted, but whose
159 166 results have not yet been received.
160 167
161 168 results : dict
162 169 a dict of all our results, keyed by msg_id
163 170
164 171 block : bool
165 172 determines default behavior when block not specified
166 173 in execution methods
167 174
168 175 Methods
169 176 -------
170 177 spin : flushes incoming results and registration state changes
171 178 control methods spin, and requesting `ids` also ensures up to date
172 179
173 180 barrier : wait on one or more msg_ids
174 181
175 182     execution methods: apply/apply_bound/apply_to
176 183 legacy: execute, run
177 184
178 185 query methods: queue_status, get_result, purge
179 186
180 187 control methods: abort, kill
181 188
182 189 """
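A minimal connection sketch under the defaults documented above; the ssh form is only needed when tunneling, and the server name is hypothetical:

    c = Client()    # connects to tcp://127.0.0.1:10101
    # tunneled variant:
    # c = Client('tcp://127.0.0.1:10101', sshserver='user@server.tld')
    print(c.ids)    # ids of currently registered engines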
183 190
184 191
185 192 _connected=False
186 193 _ssh=False
187 194 _engines=None
188 195 _addr='tcp://127.0.0.1:10101'
189 196 _registration_socket=None
190 197 _query_socket=None
191 198 _control_socket=None
192 199 _notification_socket=None
193 200 _mux_socket=None
194 201 _task_socket=None
195 202 block = False
196 203 outstanding=None
197 204 results = None
198 205 history = None
199 206 debug = False
200 207 targets = None
201 208
202 209 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
203 210 sshserver=None, sshkey=None, password=None, paramiko=None,
204 211 exec_key=None,):
205 212 if context is None:
206 213 context = zmq.Context()
207 214 self.context = context
208 215 self.targets = 'all'
209 216 self._addr = addr
210 217 self._ssh = bool(sshserver or sshkey or password)
211 218 if self._ssh and sshserver is None:
212 219 # default to the same
213 220 sshserver = addr.split('://')[1].split(':')[0]
214 221 if self._ssh and password is None:
215 222 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
216 223 password=False
217 224 else:
218 225 password = getpass("SSH Password for %s: "%sshserver)
219 226 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
220 227
221 228 if exec_key is not None and os.path.isfile(exec_key):
222 229 arg = 'keyfile'
223 230 else:
224 231 arg = 'key'
225 232 key_arg = {arg:exec_key}
226 233 if username is None:
227 234 self.session = ss.StreamSession(**key_arg)
228 235 else:
229 236 self.session = ss.StreamSession(username, **key_arg)
230 237 self._registration_socket = self.context.socket(zmq.XREQ)
231 238 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
232 239 if self._ssh:
233 240 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
234 241 else:
235 242 self._registration_socket.connect(addr)
236 self._engines = {}
243 self._engines = ReverseDict()
237 244 self._ids = set()
238 245 self.outstanding=set()
239 246 self.results = {}
247 self.metadata = {}
240 248 self.history = []
241 249 self.debug = debug
242 250 self.session.debug = debug
243 251
244 252 self._notification_handlers = {'registration_notification' : self._register_engine,
245 253 'unregistration_notification' : self._unregister_engine,
246 254 }
247 255 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
248 256 'apply_reply' : self._handle_apply_reply}
249 257 self._connect(sshserver, ssh_kwargs)
250 258
251 259
252 260 @property
253 261 def ids(self):
254 262 """Always up to date ids property."""
255 263 self._flush_notifications()
256 264 return self._ids
257 265
258 266 def _update_engines(self, engines):
259 267 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
260 268 for k,v in engines.iteritems():
261 269 eid = int(k)
262 270 self._engines[eid] = bytes(v) # force not unicode
263 271 self._ids.add(eid)
264 272
265 273 def _build_targets(self, targets):
266 274 """Turn valid target IDs or 'all' into two lists:
267 275 (int_ids, uuids).
268 276 """
269 277 if targets is None:
270 278 targets = self._ids
271 279 elif isinstance(targets, str):
272 280 if targets.lower() == 'all':
273 281 targets = self._ids
274 282 else:
275 283 raise TypeError("%r not valid str target, must be 'all'"%(targets))
276 284 elif isinstance(targets, int):
277 285 targets = [targets]
278 286 return [self._engines[t] for t in targets], list(targets)
279 287
280 288 def _connect(self, sshserver, ssh_kwargs):
281 289 """setup all our socket connections to the controller. This is called from
282 290 __init__."""
283 291 if self._connected:
284 292 return
285 293 self._connected=True
286 294
287 295 def connect_socket(s, addr):
288 296 if self._ssh:
289 297 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
290 298 else:
291 299 return s.connect(addr)
292 300
293 301 self.session.send(self._registration_socket, 'connection_request')
294 302 idents,msg = self.session.recv(self._registration_socket,mode=0)
295 303 if self.debug:
296 304 pprint(msg)
297 305 msg = ss.Message(msg)
298 306 content = msg.content
299 307 if content.status == 'ok':
300 308 if content.queue:
301 309 self._mux_socket = self.context.socket(zmq.PAIR)
302 310 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
303 311 connect_socket(self._mux_socket, content.queue)
304 312 if content.task:
305 313 self._task_socket = self.context.socket(zmq.PAIR)
306 314 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
307 315 connect_socket(self._task_socket, content.task)
308 316 if content.notification:
309 317 self._notification_socket = self.context.socket(zmq.SUB)
310 318 connect_socket(self._notification_socket, content.notification)
311 319 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
312 320 if content.query:
313 321 self._query_socket = self.context.socket(zmq.PAIR)
314 322 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
315 323 connect_socket(self._query_socket, content.query)
316 324 if content.control:
317 325 self._control_socket = self.context.socket(zmq.PAIR)
318 326 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
319 327 connect_socket(self._control_socket, content.control)
320 328 self._update_engines(dict(content.engines))
321 329
322 330 else:
323 331 self._connected = False
324 332 raise Exception("Failed to connect!")
325 333
326 334 #--------------------------------------------------------------------------
327 335 # handlers and callbacks for incoming messages
328 336 #--------------------------------------------------------------------------
329 337
330 338 def _register_engine(self, msg):
331 339 """Register a new engine, and update our connection info."""
332 340 content = msg['content']
333 341 eid = content['id']
334 342 d = {eid : content['queue']}
335 343 self._update_engines(d)
336 344 self._ids.add(int(eid))
337 345
338 346 def _unregister_engine(self, msg):
339 347 """Unregister an engine that has died."""
340 348 content = msg['content']
341 349 eid = int(content['id'])
342 350 if eid in self._ids:
343 351 self._ids.remove(eid)
344 352 self._engines.pop(eid)
353 #
354 def _build_metadata(self, header, parent, content):
355 md = {'msg_id' : parent['msg_id'],
356 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
357 'started' : datetime.strptime(header['started'], ss.ISO8601),
358 'completed' : datetime.strptime(header['date'], ss.ISO8601),
359 'received' : datetime.now(),
360 'engine_uuid' : header['engine'],
361 'engine_id' : self._engines.get(header['engine'], None),
362 'follow' : parent['follow'],
363 'after' : parent['after'],
364 'status' : content['status']
365 }
366 return md
345 367
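With this change, timing and routing info for each reply lands in the new `metadata` dict; a sketch, assuming `msg_id` names a completed request:

    md = client.metadata[msg_id]
    print(md['engine_id'], md['status'])
    turnaround = md['received'] - md['submitted']   # datetime arithmetic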
346 368 def _handle_execute_reply(self, msg):
347 """Save the reply to an execute_request into our results."""
369 """Save the reply to an execute_request into our results.
370
371 execute messages are never actually used. apply is used instead.
372 """
373
348 374 parent = msg['parent_header']
349 375 msg_id = parent['msg_id']
350 376 if msg_id not in self.outstanding:
351 377 print("got unknown result: %s"%msg_id)
352 378 else:
353 379 self.outstanding.remove(msg_id)
354 380 self.results[msg_id] = ss.unwrap_exception(msg['content'])
355 381
356 382 def _handle_apply_reply(self, msg):
357 383 """Save the reply to an apply_request into our results."""
358 384 parent = msg['parent_header']
359 385 msg_id = parent['msg_id']
360 386 if msg_id not in self.outstanding:
361 387 print ("got unknown result: %s"%msg_id)
362 388 else:
363 389 self.outstanding.remove(msg_id)
364 390 content = msg['content']
391 header = msg['header']
392
393 self.metadata[msg_id] = self._build_metadata(header, parent, content)
394
365 395 if content['status'] == 'ok':
366 self.results[msg_id] = ss.unserialize_object(msg['buffers'])
396 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
367 397 elif content['status'] == 'aborted':
368 398 self.results[msg_id] = error.AbortedTask(msg_id)
369 399 elif content['status'] == 'resubmitted':
370 400 # TODO: handle resubmission
371 401 pass
372 402 else:
373 403 e = ss.unwrap_exception(content)
374 404 e_uuid = e.engine_info['engineid']
375 for k,v in self._engines.iteritems():
376 if v == e_uuid:
377 e.engine_info['engineid'] = k
378 break
405 eid = self._engines[e_uuid]
406 e.engine_info['engineid'] = eid
379 407 self.results[msg_id] = e
380 408
381 409 def _flush_notifications(self):
382 410 """Flush notifications of engine registrations waiting
383 411 in ZMQ queue."""
384 412 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
385 413 while msg is not None:
386 414 if self.debug:
387 415 pprint(msg)
388 416 msg = msg[-1]
389 417 msg_type = msg['msg_type']
390 418 handler = self._notification_handlers.get(msg_type, None)
391 419 if handler is None:
392 420 raise Exception("Unhandled message type: %s"%msg.msg_type)
393 421 else:
394 422 handler(msg)
395 423 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
396 424
397 425 def _flush_results(self, sock):
398 426 """Flush task or queue results waiting in ZMQ queue."""
399 427 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
400 428 while msg is not None:
401 429 if self.debug:
402 430 pprint(msg)
403 431 msg = msg[-1]
404 432 msg_type = msg['msg_type']
405 433 handler = self._queue_handlers.get(msg_type, None)
406 434 if handler is None:
407 435 raise Exception("Unhandled message type: %s"%msg.msg_type)
408 436 else:
409 437 handler(msg)
410 438 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
411 439
412 440 def _flush_control(self, sock):
413 441 """Flush replies from the control channel waiting
414 442 in the ZMQ queue.
415 443
416 444 Currently: ignore them."""
417 445 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
418 446 while msg is not None:
419 447 if self.debug:
420 448 pprint(msg)
421 449 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
422 450
423 451 #--------------------------------------------------------------------------
424 452 # getitem
425 453 #--------------------------------------------------------------------------
426 454
427 455 def __getitem__(self, key):
428 456 """Dict access returns DirectView multiplexer objects or,
429 457 if key is None, a LoadBalancedView."""
430 458 if key is None:
431 459 return LoadBalancedView(self)
432 460 if isinstance(key, int):
433 461 if key not in self.ids:
434 462 raise IndexError("No such engine: %i"%key)
435 463 return DirectView(self, key)
436 464
437 465 if isinstance(key, slice):
438 466 indices = range(len(self.ids))[key]
439 467 ids = sorted(self._ids)
440 468 key = [ ids[i] for i in indices ]
441 469 # newkeys = sorted(self._ids)[thekeys[k]]
442 470
443 471 if isinstance(key, (tuple, list, xrange)):
444 472 _,targets = self._build_targets(list(key))
445 473 return DirectView(self, targets)
446 474 else:
447 475 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
448 476
449 477 #--------------------------------------------------------------------------
450 478 # Begin public methods
451 479 #--------------------------------------------------------------------------
452 480
453 481 @property
454 482 def remote(self):
455 483 """property for convenient RemoteFunction generation.
456 484
457 485 >>> @client.remote
458 486 ... def f():
459 487         ...     import os
460 488         ...     print (os.getpid())
461 489 """
462 490 return remote(self, block=self.block)
463 491
464 492 def spin(self):
465 493 """Flush any registration notifications and execution results
466 494 waiting in the ZMQ queue.
467 495 """
468 496 if self._notification_socket:
469 497 self._flush_notifications()
470 498 if self._mux_socket:
471 499 self._flush_results(self._mux_socket)
472 500 if self._task_socket:
473 501 self._flush_results(self._task_socket)
474 502 if self._control_socket:
475 503 self._flush_control(self._control_socket)
476 504
477 505 def barrier(self, msg_ids=None, timeout=-1):
478 506 """waits on one or more `msg_ids`, for up to `timeout` seconds.
479 507
480 508 Parameters
481 509 ----------
482 510 msg_ids : int, str, or list of ints and/or strs, or one or more AsyncResult objects
483 511 ints are indices to self.history
484 512 strs are msg_ids
485 513 default: wait on all outstanding messages
486 514 timeout : float
487 515 a time in seconds, after which to give up.
488 516 default is -1, which means no timeout
489 517
490 518 Returns
491 519 -------
492 520 True : when all msg_ids are done
493 521 False : timeout reached, some msg_ids still outstanding
494 522 """
495 523 tic = time.time()
496 524 if msg_ids is None:
497 525 theids = self.outstanding
498 526 else:
499 527 if isinstance(msg_ids, (int, str, AsyncResult)):
500 528 msg_ids = [msg_ids]
501 529 theids = set()
502 530 for msg_id in msg_ids:
503 531 if isinstance(msg_id, int):
504 532 msg_id = self.history[msg_id]
505 533 elif isinstance(msg_id, AsyncResult):
506 534 map(theids.add, msg_id.msg_ids)
507 535 continue
508 536 theids.add(msg_id)
509 537 if not theids.intersection(self.outstanding):
510 538 return True
511 539 self.spin()
512 540 while theids.intersection(self.outstanding):
513 541 if timeout >= 0 and ( time.time()-tic ) > timeout:
514 542 break
515 543 time.sleep(1e-3)
516 544 self.spin()
517 545 return len(theids.intersection(self.outstanding)) == 0
518 546
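A sketch of the argument forms `barrier` accepts; the string msg_ids are hypothetical:

    client.barrier()                  # wait on all outstanding messages
    client.barrier(-1)                # int: index into client.history
    client.barrier(ar, timeout=2.0)   # AsyncResult: waits on its msg_ids
    done = client.barrier(['msg-a', 'msg-b'], timeout=0.5)  # False on timeout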
519 547 #--------------------------------------------------------------------------
520 548 # Control methods
521 549 #--------------------------------------------------------------------------
522 550
523 551 @spinfirst
524 552 @defaultblock
525 553 def clear(self, targets=None, block=None):
526 554 """Clear the namespace in target(s)."""
527 555 targets = self._build_targets(targets)[0]
528 556 for t in targets:
529 557 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
530 558 error = False
531 559 if self.block:
532 560 for i in range(len(targets)):
533 561 idents,msg = self.session.recv(self._control_socket,0)
534 562 if self.debug:
535 563 pprint(msg)
536 564 if msg['content']['status'] != 'ok':
537 565 error = ss.unwrap_exception(msg['content'])
538 566 if error:
539 567 return error
540 568
541 569
542 570 @spinfirst
543 571 @defaultblock
544 572 def abort(self, msg_ids = None, targets=None, block=None):
545 573 """Abort the execution queues of target(s)."""
546 574 targets = self._build_targets(targets)[0]
547 575 if isinstance(msg_ids, basestring):
548 576 msg_ids = [msg_ids]
549 577 content = dict(msg_ids=msg_ids)
550 578 for t in targets:
551 579 self.session.send(self._control_socket, 'abort_request',
552 580 content=content, ident=t)
553 581 error = False
554 582 if self.block:
555 583 for i in range(len(targets)):
556 584 idents,msg = self.session.recv(self._control_socket,0)
557 585 if self.debug:
558 586 pprint(msg)
559 587 if msg['content']['status'] != 'ok':
560 588 error = ss.unwrap_exception(msg['content'])
561 589 if error:
562 590 return error
563 591
564 592 @spinfirst
565 593 @defaultblock
566 594 def shutdown(self, targets=None, restart=False, controller=False, block=None):
567 595 """Terminates one or more engine processes, optionally including the controller."""
568 596 if controller:
569 597 targets = 'all'
570 598 targets = self._build_targets(targets)[0]
571 599 for t in targets:
572 600 self.session.send(self._control_socket, 'shutdown_request',
573 601 content={'restart':restart},ident=t)
574 602 error = False
575 603 if block or controller:
576 604 for i in range(len(targets)):
577 605 idents,msg = self.session.recv(self._control_socket,0)
578 606 if self.debug:
579 607 pprint(msg)
580 608 if msg['content']['status'] != 'ok':
581 609 error = ss.unwrap_exception(msg['content'])
582 610
583 611 if controller:
584 612 time.sleep(0.25)
585 613 self.session.send(self._query_socket, 'shutdown_request')
586 614 idents,msg = self.session.recv(self._query_socket, 0)
587 615 if self.debug:
588 616 pprint(msg)
589 617 if msg['content']['status'] != 'ok':
590 618 error = ss.unwrap_exception(msg['content'])
591 619
592 620 if error:
593 621 raise error
594 622
595 623 #--------------------------------------------------------------------------
596 624 # Execution methods
597 625 #--------------------------------------------------------------------------
598 626
599 627 @defaultblock
600 628 def execute(self, code, targets='all', block=None):
601 629 """Executes `code` on `targets` in blocking or nonblocking manner.
602 630
603 631 ``execute`` is always `bound` (affects engine namespace)
604 632
605 633 Parameters
606 634 ----------
607 635 code : str
608 636 the code string to be executed
609 637 targets : int/str/list of ints/strs
610 638 the engines on which to execute
611 639 default : all
612 640 block : bool
613 641 whether or not to wait until done to return
614 642 default: self.block
615 643 """
616 644 result = self.apply(_execute, (code,), targets=targets, block=self.block, bound=True)
617 645 return result
618 646
619 647 def run(self, code, block=None):
620 648 """Runs `code` on an engine.
621 649
622 650 Calls to this are load-balanced.
623 651
624 652 ``run`` is never `bound` (no effect on engine namespace)
625 653
626 654 Parameters
627 655 ----------
628 656 code : str
629 657 the code string to be executed
630 658 block : bool
631 659 whether or not to wait until done
632 660
633 661 """
634 662 result = self.apply(_execute, (code,), targets=None, block=block, bound=False)
635 663 return result
636 664
637 665 def _maybe_raise(self, result):
638 666 """wrapper for maybe raising an exception if apply failed."""
639 667 if isinstance(result, error.RemoteError):
640 668 raise result
641 669
642 670 return result
643 671
644 672 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
645 673 after=None, follow=None):
646 674 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
647 675
648 676 This is the central execution command for the client.
649 677
650 678 Parameters
651 679 ----------
652 680
653 681 f : function
654 682             The function to be called remotely
655 683 args : tuple/list
656 684 The positional arguments passed to `f`
657 685 kwargs : dict
658 686 The keyword arguments passed to `f`
659 687 bound : bool (default: True)
660 688 Whether to execute in the Engine(s) namespace, or in a clean
661 689 namespace not affecting the engine.
662 690 block : bool (default: self.block)
663 691 Whether to wait for the result, or return immediately.
664 692 False:
665 693 returns msg_id(s)
666 694 if multiple targets:
667 695 list of ids
668 696 True:
669 697 returns actual result(s) of f(*args, **kwargs)
670 698 if multiple targets:
671 699 dict of results, by engine ID
672 700 targets : int,list of ints, 'all', None
673 701 Specify the destination of the job.
674 702 if None:
675 703 Submit via Task queue for load-balancing.
676 704 if 'all':
677 705 Run on all active engines
678 706 if list:
679 707 Run on each specified engine
680 708 if int:
681 709 Run on single engine
682 710
683 711 after : Dependency or collection of msg_ids
684 712 Only for load-balanced execution (targets=None)
685 713 Specify a list of msg_ids as a time-based dependency.
686 714 This job will only be run *after* the dependencies
687 715 have been met.
688 716
689 717 follow : Dependency or collection of msg_ids
690 718 Only for load-balanced execution (targets=None)
691 719 Specify a list of msg_ids as a location-based dependency.
692 720 This job will only be run on an engine where this dependency
693 721 is met.
694 722
695 723 Returns
696 724 -------
697 725 if block is False:
698 726 if single target:
699 727 return msg_id
700 728 else:
701 729 return list of msg_ids
702 730 ? (should this be dict like block=True) ?
703 731 else:
704 732 if single target:
705 733 return result of f(*args, **kwargs)
706 734 else:
707 735 return dict of results, keyed by engine
708 736 """
709 737
710 738 # defaults:
711 739 block = block if block is not None else self.block
712 740 args = args if args is not None else []
713 741 kwargs = kwargs if kwargs is not None else {}
714 742
715 743         # enforce types of f, args, kwargs
716 744 if not callable(f):
717 745 raise TypeError("f must be callable, not %s"%type(f))
718 746 if not isinstance(args, (tuple, list)):
719 747 raise TypeError("args must be tuple or list, not %s"%type(args))
720 748 if not isinstance(kwargs, dict):
721 749 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
722 750
723 751 if isinstance(after, Dependency):
724 752 after = after.as_dict()
725 753 elif isinstance(after, AsyncResult):
726 754 after=after.msg_ids
727 755 elif after is None:
728 756 after = []
729 757 if isinstance(follow, Dependency):
730 758 follow = follow.as_dict()
731 759 elif isinstance(follow, AsyncResult):
732 760 follow=follow.msg_ids
733 761 elif follow is None:
734 762 follow = []
735 763 options = dict(bound=bound, block=block, after=after, follow=follow)
736 764
737 765 if targets is None:
738 766 return self._apply_balanced(f, args, kwargs, **options)
739 767 else:
740 768 return self._apply_direct(f, args, kwargs, targets=targets, **options)
741 769
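A few call shapes for `apply`, assuming `f` is a serializable function and `ar` an earlier AsyncResult:

    client.apply(f, (1,), targets=0, block=True)            # engine 0, wait
    ar = client.apply(f, (2,), targets='all', block=False)  # all engines, async
    client.apply(f, (3,), targets=None, after=ar)           # load-balanced, runs after ar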
742 770 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
743 771 after=None, follow=None):
744 772 """The underlying method for applying functions in a load balanced
745 773 manner, via the task queue."""
746 774
747 775 subheader = dict(after=after, follow=follow)
748 776 bufs = ss.pack_apply_message(f,args,kwargs)
749 777 content = dict(bound=bound)
750 778
751 779 msg = self.session.send(self._task_socket, "apply_request",
752 780 content=content, buffers=bufs, subheader=subheader)
753 781 msg_id = msg['msg_id']
754 782 self.outstanding.add(msg_id)
755 783 self.history.append(msg_id)
756 784 ar = AsyncResult(self, [msg_id], fname=f.__name__)
757 785 if block:
758 786 return ar.get()
759 787 else:
760 788 return ar
761 789
762 790 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
763 791 after=None, follow=None):
764 792         """The underlying method for applying functions to specific engines
765 793 via the MUX queue."""
766 794
767 795 queues,targets = self._build_targets(targets)
768 796
769 797 subheader = dict(after=after, follow=follow)
770 798 content = dict(bound=bound)
771 799 bufs = ss.pack_apply_message(f,args,kwargs)
772 800
773 801 msg_ids = []
774 802 for queue in queues:
775 803 msg = self.session.send(self._mux_socket, "apply_request",
776 804 content=content, buffers=bufs,ident=queue, subheader=subheader)
777 805 msg_id = msg['msg_id']
778 806 self.outstanding.add(msg_id)
779 807 self.history.append(msg_id)
780 808 msg_ids.append(msg_id)
781 809 ar = AsyncResult(self, msg_ids, fname=f.__name__)
782 810 if block:
783 811 return ar.get()
784 812 else:
785 813 return ar
786 814
787 815 #--------------------------------------------------------------------------
788 816 # Map and decorators
789 817 #--------------------------------------------------------------------------
790 818
791 819 def map(self, f, *sequences):
792 820 """Parallel version of builtin `map`, using all our engines."""
793 821 pf = ParallelFunction(self, f, block=self.block,
794 822 bound=True, targets='all')
795 823 return pf.map(*sequences)
796 824
797 825 def parallel(self, bound=True, targets='all', block=True):
798 826 """Decorator for making a ParallelFunction."""
799 827 return parallel(self, bound=bound, targets=targets, block=block)
800 828
801 829 def remote(self, bound=True, targets='all', block=True):
802 830 """Decorator for making a RemoteFunction."""
803 831 return remote(self, bound=bound, targets=targets, block=block)
804 832
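Sketches of the map and decorator idioms; `square` and `hostname` are hypothetical functions:

    def square(x):
        return x**2

    squares = client.map(square, range(32))   # partitioned across all engines

    @client.remote(targets=0, block=True)
    def hostname():
        import socket
        return socket.gethostname()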
805 833 #--------------------------------------------------------------------------
806 834 # Data movement
807 835 #--------------------------------------------------------------------------
808 836
809 837 @defaultblock
810 838 def push(self, ns, targets='all', block=None):
811 839 """Push the contents of `ns` into the namespace on `target`"""
812 840 if not isinstance(ns, dict):
813 841 raise TypeError("Must be a dict, not %s"%type(ns))
814 842 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
815 843 return result
816 844
817 845 @defaultblock
818 846 def pull(self, keys, targets='all', block=None):
819 847 """Pull objects from `target`'s namespace by `keys`"""
820 848 if isinstance(keys, str):
821 849 pass
822 850 elif isinstance(keys, (list,tuple,set)):
823 851 for key in keys:
824 852 if not isinstance(key, str):
825 853                     raise TypeError("keys must all be strs, not %r"%key)
826 854 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
827 855 return result
828 856
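A push/pull round trip, assuming a connected client:

    client.push(dict(a=1, b=[1, 2, 3]), targets='all', block=True)
    a = client.pull('a', targets=0, block=True)
    a_b = client.pull(('a', 'b'), targets=0, block=True)   # list of values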
829 857 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
830 858 """
831 859 Partition a Python sequence and send the partitions to a set of engines.
832 860 """
833 861 block = block if block is not None else self.block
834 862 targets = self._build_targets(targets)[-1]
835 863 mapObject = Map.dists[dist]()
836 864 nparts = len(targets)
837 865 msg_ids = []
838 866 for index, engineid in enumerate(targets):
839 867 partition = mapObject.getPartition(seq, index, nparts)
840 868 if flatten and len(partition) == 1:
841 869 r = self.push({key: partition[0]}, targets=engineid, block=False)
842 870 else:
843 871 r = self.push({key: partition}, targets=engineid, block=False)
844 872 msg_ids.extend(r.msg_ids)
845 873 r = AsyncResult(self, msg_ids, fname='scatter')
846 874 if block:
847 875 return r.get()
848 876 else:
849 877 return r
850 878
851 879 def gather(self, key, dist='b', targets='all', block=None):
852 880 """
853 881 Gather a partitioned sequence on a set of engines as a single local seq.
854 882 """
855 883 block = block if block is not None else self.block
856 884
857 885 targets = self._build_targets(targets)[-1]
858 886 mapObject = Map.dists[dist]()
859 887 msg_ids = []
860 888 for index, engineid in enumerate(targets):
861 889 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
862 890
863 891 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
864 892 if block:
865 893 return r.get()
866 894 else:
867 895 return r
868 896
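A scatter/gather round trip with the default 'b' (block) distribution:

    client.scatter('x', range(16), targets='all', block=True)
    parts = client.gather('x', targets='all', block=True)
    assert parts == range(16)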
869 897 #--------------------------------------------------------------------------
870 898 # Query methods
871 899 #--------------------------------------------------------------------------
872 900
873 901 @spinfirst
874 902 def get_results(self, msg_ids, status_only=False):
875 903 """Returns the result of the execute or task request with `msg_ids`.
876 904
877 905 Parameters
878 906 ----------
879 907         msg_ids : list of ints or msg_id strs
880 908 if int:
881 909 Passed as index to self.history for convenience.
882 910 status_only : bool (default: False)
883 911 if False:
884 912 return the actual results
913
914 Returns
915 -------
916
917 results : dict
918 There will always be the keys 'pending' and 'completed', which will
919 be lists of msg_ids.
885 920 """
886 921 if not isinstance(msg_ids, (list,tuple)):
887 922 msg_ids = [msg_ids]
888 923 theids = []
889 924 for msg_id in msg_ids:
890 925 if isinstance(msg_id, int):
891 926 msg_id = self.history[msg_id]
892 927 if not isinstance(msg_id, str):
893 928 raise TypeError("msg_ids must be str, not %r"%msg_id)
894 929 theids.append(msg_id)
895 930
896 931 completed = []
897 932 local_results = {}
898 for msg_id in list(theids):
899 if msg_id in self.results:
900 completed.append(msg_id)
901 local_results[msg_id] = self.results[msg_id]
902 theids.remove(msg_id)
933 # temporarily disable local shortcut
934 # for msg_id in list(theids):
935 # if msg_id in self.results:
936 # completed.append(msg_id)
937 # local_results[msg_id] = self.results[msg_id]
938 # theids.remove(msg_id)
903 939
904 940 if theids: # some not locally cached
905 941 content = dict(msg_ids=theids, status_only=status_only)
906 942 msg = self.session.send(self._query_socket, "result_request", content=content)
907 943 zmq.select([self._query_socket], [], [])
908 944 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
909 945 if self.debug:
910 946 pprint(msg)
911 947 content = msg['content']
912 948 if content['status'] != 'ok':
913 949 raise ss.unwrap_exception(content)
950 buffers = msg['buffers']
914 951 else:
915 952 content = dict(completed=[],pending=[])
916 if not status_only:
917 # load cached results into result:
953
918 954 content['completed'].extend(completed)
955
956 if status_only:
957 return content
958
959 failures = []
960 # load cached results into result:
919 961 content.update(local_results)
920 962 # update cache with results:
921 for msg_id in msg_ids:
963 for msg_id in sorted(theids):
922 964 if msg_id in content['completed']:
923 self.results[msg_id] = content[msg_id]
965 rec = content[msg_id]
966 parent = rec['header']
967 header = rec['result_header']
968 rcontent = rec['result_content']
969 if isinstance(rcontent, str):
970 rcontent = self.session.unpack(rcontent)
971
972 self.metadata[msg_id] = self._build_metadata(header, parent, rcontent)
973
974 if rcontent['status'] == 'ok':
975 res,buffers = ss.unserialize_object(buffers)
976 else:
977 res = ss.unwrap_exception(rcontent)
978 failures.append(res)
979
980 self.results[msg_id] = res
981 content[msg_id] = res
982
983 error.collect_exceptions(failures, "get_results")
924 984 return content
925 985
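A sketch of the reworked call; ints index into client.history:

    stat = client.get_results(client.history[-3:], status_only=True)
    print(stat['pending'], stat['completed'])
    res = client.get_results(-1)   # results dict for the most recent request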
926 986 @spinfirst
927 987 def queue_status(self, targets=None, verbose=False):
928 988 """Fetch the status of engine queues.
929 989
930 990 Parameters
931 991 ----------
932 992 targets : int/str/list of ints/strs
933 993 the engines on which to execute
934 994 default : all
935 995 verbose : bool
936 996 Whether to return lengths only, or lists of ids for each element
937 997 """
938 998 targets = self._build_targets(targets)[1]
939 999 content = dict(targets=targets, verbose=verbose)
940 1000 self.session.send(self._query_socket, "queue_request", content=content)
941 1001 idents,msg = self.session.recv(self._query_socket, 0)
942 1002 if self.debug:
943 1003 pprint(msg)
944 1004 content = msg['content']
945 1005 status = content.pop('status')
946 1006 if status != 'ok':
947 1007 raise ss.unwrap_exception(content)
948 return content
1008 return ss.rekey(content)
949 1009
950 1010 @spinfirst
951 1011 def purge_results(self, msg_ids=[], targets=[]):
952 1012 """Tell the controller to forget results.
953 1013
954 1014 Individual results can be purged by msg_id, or the entire
955 1015 history of specific targets can be purged.
956 1016
957 1017 Parameters
958 1018 ----------
959 1019 msg_ids : str or list of strs
960 1020 the msg_ids whose results should be forgotten.
961 1021 targets : int/str/list of ints/strs
962 1022 The targets, by uuid or int_id, whose entire history is to be purged.
963 1023 Use `targets='all'` to scrub everything from the controller's memory.
964 1024
965 1025 default : None
966 1026 """
967 1027 if not targets and not msg_ids:
968 1028             raise ValueError("must specify at least one of msg_ids or targets")
969 1029 if targets:
970 1030 targets = self._build_targets(targets)[1]
971 1031 content = dict(targets=targets, msg_ids=msg_ids)
972 1032 self.session.send(self._query_socket, "purge_request", content=content)
973 1033 idents, msg = self.session.recv(self._query_socket, 0)
974 1034 if self.debug:
975 1035 pprint(msg)
976 1036 content = msg['content']
977 1037 if content['status'] != 'ok':
978 1038 raise ss.unwrap_exception(content)
979 1039
980 1040 #----------------------------------------
981 1041 # activate for %px,%autopx magics
982 1042 #----------------------------------------
983 1043 def activate(self):
984 1044 """Make this `View` active for parallel magic commands.
985 1045
986 1046 IPython has a magic command syntax to work with `MultiEngineClient` objects.
987 1047 In a given IPython session there is a single active one. While
988 1048 there can be many `Views` created and used by the user,
989 1049 there is only one active one. The active `View` is used whenever
990 1050 the magic commands %px and %autopx are used.
991 1051
992 1052 The activate() method is called on a given `View` to make it
993 1053 active. Once this has been done, the magic commands can be used.
994 1054 """
995 1055
996 1056 try:
997 1057 # This is injected into __builtins__.
998 1058 ip = get_ipython()
999 1059 except NameError:
1000 1060             print("The IPython parallel magics (%result, %px, %autopx) only work within IPython.")
1001 1061 else:
1002 1062 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
1003 1063 if pmagic is not None:
1004 1064 pmagic.active_multiengine_client = self
1005 1065 else:
1006 1066             print("You must first load the parallelmagic extension "
1007 1067                 "by doing '%load_ext parallelmagic'")
1008 1068
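The magic workflow this enables inside an IPython session, assuming the parallelmagic extension is installed:

    %load_ext parallelmagic
    client.activate()
    %px import os
    %px print(os.getpid())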
1009 1069 class AsynClient(Client):
1010 1070 """An Asynchronous client, using the Tornado Event Loop.
1011 1071 !!!unfinished!!!"""
1012 1072 io_loop = None
1013 1073 _queue_stream = None
1014 1074 _notifier_stream = None
1015 1075 _task_stream = None
1016 1076 _control_stream = None
1017 1077
1018 1078 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
1019 1079 Client.__init__(self, addr, context, username, debug)
1020 1080 if io_loop is None:
1021 1081 io_loop = ioloop.IOLoop.instance()
1022 1082 self.io_loop = io_loop
1023 1083
1024 1084 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
1025 1085 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
1026 1086 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
1027 1087 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
1028 1088
1029 1089 def spin(self):
1030 1090         for stream in (self._queue_stream, self._notification_stream,
1031 1091                 self._task_stream, self._control_stream):
1032 1092 stream.flush()
1033 1093
1034 1094 __all__ = [ 'Client',
1035 1095 'depend',
1036 1096 'require',
1037 1097 'remote',
1038 1098 'parallel',
1039 1099 'RemoteFunction',
1040 1100 'ParallelFunction',
1041 1101 'DirectView',
1042 1102 'LoadBalancedView',
1043 1103 'AsyncResult',
1044 1104 'AsyncMapResult'
1045 1105 ]
@@ -1,1050 +1,1035 @@
1 1 #!/usr/bin/env python
2 2 """The IPython Controller with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import os
20 20 from datetime import datetime
21 21 import logging
22 22 import time
23 23 import uuid
24 24
25 25 import zmq
26 26 from zmq.eventloop import zmqstream, ioloop
27 27
28 28 # internal:
29 29 from IPython.zmq.log import logger # a Logger object
30 30 from IPython.zmq.entry_point import bind_port
31 31
32 32 from streamsession import Message, wrap_exception, ISO8601
33 33 from entry_point import (make_base_argument_parser, select_random_ports, split_ports,
34 34 connect_logger, parse_url, signal_children, generate_exec_key)
35 35 from dictdb import DictDB
36 36 try:
37 37 from pymongo.binary import Binary
38 38 except ImportError:
39 39 MongoDB=None
40 40 else:
41 41 from mongodb import MongoDB
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Code
45 45 #-----------------------------------------------------------------------------
46 46
47 47 def _passer(*args, **kwargs):
48 48 return
49 49
50 class ReverseDict(dict):
51 """simple double-keyed subset of dict methods."""
52
53 def __init__(self, *args, **kwargs):
54 dict.__init__(self, *args, **kwargs)
55 self.reverse = dict()
56 for key, value in self.iteritems():
57 self.reverse[value] = key
58
59 def __getitem__(self, key):
60 try:
61 return dict.__getitem__(self, key)
62 except KeyError:
63 return self.reverse[key]
64
65 def __setitem__(self, key, value):
66 if key in self.reverse:
67 raise KeyError("Can't have key %r on both sides!"%key)
68 dict.__setitem__(self, key, value)
69 self.reverse[value] = key
70
71 def pop(self, key):
72 value = dict.pop(self, key)
73 self.d1.pop(value)
74 return value
75
76
77 50 def init_record(msg):
78 51 """return an empty TaskRecord dict, with all keys initialized with None."""
79 52 header = msg['header']
80 53 return {
81 54 'msg_id' : header['msg_id'],
82 55 'header' : header,
83 56 'content': msg['content'],
84 57 'buffers': msg['buffers'],
85 58 'submitted': datetime.strptime(header['date'], ISO8601),
86 59 'client_uuid' : None,
87 60 'engine_uuid' : None,
88 61 'started': None,
89 62 'completed': None,
90 63 'resubmitted': None,
91 64 'result_header' : None,
92 65 'result_content' : None,
93 66 'result_buffers' : None,
94 67 'queue' : None
95 68 }
96 69
97 70
98 71 class EngineConnector(object):
99 72 """A simple object for accessing the various zmq connections of an object.
100 73 Attributes are:
101 74 id (int): engine ID
102 75 uuid (str): uuid (unused?)
103 76 queue (str): identity of queue's XREQ socket
104 77 registration (str): identity of registration XREQ socket
105 78 heartbeat (str): identity of heartbeat XREQ socket
106 79 """
107 80 id=0
108 81 queue=None
109 82 control=None
110 83 registration=None
111 84 heartbeat=None
112 85 pending=None
113 86
114 87 def __init__(self, id, queue, registration, control, heartbeat=None):
115 88 logger.info("engine::Engine Connected: %i"%id)
116 89 self.id = id
117 90 self.queue = queue
118 91 self.registration = registration
119 92 self.control = control
120 93 self.heartbeat = heartbeat
121 94
122 95 class Controller(object):
123 96 """The IPython Controller with 0MQ connections
124 97
125 98 Parameters
126 99 ==========
127 100 loop: zmq IOLoop instance
128 101 session: StreamSession object
129 102 <removed> context: zmq context for creating new connections (?)
130 103 queue: ZMQStream for monitoring the command queue (SUB)
131 104 registrar: ZMQStream for engine registration requests (XREP)
132 105 heartbeat: HeartMonitor object checking the pulse of the engines
133 106 clientele: ZMQStream for client connections (XREP)
134 107 not used for jobs, only query/control commands
135 108 notifier: ZMQStream for broadcasting engine registration changes (PUB)
136 109 db: connection to db for out of memory logging of commands
137 110 NotImplemented
138 111 engine_addrs: dict of zmq connection information for engines to connect
139 112 to the queues.
140 113 client_addrs: dict of zmq connection information for engines to connect
141 114 to the queues.
142 115 """
143 116 # internal data structures:
144 117 ids=None # engine IDs
145 118 keytable=None
146 119 engines=None
147 120 clients=None
148 121 hearts=None
149 122 pending=None
150 123 results=None
151 124 tasks=None
152 125 completed=None
153 126 mia=None
154 127 incoming_registrations=None
155 128 registration_timeout=None
156 129
157 130 #objects from constructor:
158 131 loop=None
159 132 registrar=None
160 133     clientele=None
161 134 queue=None
162 135 heartbeat=None
163 136 notifier=None
164 137 db=None
165 138     client_addrs=None
166 139 engine_addrs=None
167 140
168 141
169 142 def __init__(self, loop, session, queue, registrar, heartbeat, clientele, notifier, db, engine_addrs, client_addrs):
170 143 """
171 144 # universal:
172 145 loop: IOLoop for creating future connections
173 146 session: streamsession for sending serialized data
174 147 # engine:
175 148 queue: ZMQStream for monitoring queue messages
176 149 registrar: ZMQStream for engine registration
177 150 heartbeat: HeartMonitor object for tracking engines
178 151 # client:
179 152 clientele: ZMQStream for client connections
180 153 # extra:
181 154 db: ZMQStream for db connection (NotImplemented)
182 155 engine_addrs: zmq address/protocol dict for engine connections
183 156 client_addrs: zmq address/protocol dict for client connections
184 157 """
185 158 self.ids = set()
186 159 self.keytable={}
187 160 self.incoming_registrations={}
188 161 self.engines = {}
189 162 self.by_ident = {}
190 163 self.clients = {}
191 164 self.hearts = {}
192 165 # self.mia = set()
193 166
194 167 # self.sockets = {}
195 168 self.loop = loop
196 169 self.session = session
197 170 self.registrar = registrar
198 171 self.clientele = clientele
199 172 self.queue = queue
200 173 self.heartbeat = heartbeat
201 174 self.notifier = notifier
202 175 self.db = db
203 176
204 177 # validate connection dicts:
205 178 self.client_addrs = client_addrs
206 179 assert isinstance(client_addrs['queue'], str)
207 180 assert isinstance(client_addrs['control'], str)
208 181 # self.hb_addrs = hb_addrs
209 182 self.engine_addrs = engine_addrs
210 183 assert isinstance(engine_addrs['queue'], str)
211 184         assert isinstance(engine_addrs['control'], str)
212 185 assert len(engine_addrs['heartbeat']) == 2
213 186
214 187 # register our callbacks
215 188 self.registrar.on_recv(self.dispatch_register_request)
216 189 self.clientele.on_recv(self.dispatch_client_msg)
217 190 self.queue.on_recv(self.dispatch_queue_traffic)
218 191
219 192 if heartbeat is not None:
220 193 heartbeat.add_heart_failure_handler(self.handle_heart_failure)
221 194 heartbeat.add_new_heart_handler(self.handle_new_heart)
222 195
223 196 self.queue_handlers = { 'in' : self.save_queue_request,
224 197 'out': self.save_queue_result,
225 198 'intask': self.save_task_request,
226 199 'outtask': self.save_task_result,
227 200 'tracktask': self.save_task_destination,
228 201 'incontrol': _passer,
229 202 'outcontrol': _passer,
230 203 }
231 204
232 205 self.client_handlers = {'queue_request': self.queue_status,
233 206 'result_request': self.get_results,
234 207 'purge_request': self.purge_results,
235 208 'load_request': self.check_load,
236 209 'resubmit_request': self.resubmit_task,
237 210 'shutdown_request': self.shutdown_request,
238 211 }
239 212
240 213 self.registrar_handlers = {'registration_request' : self.register_engine,
241 214 'unregistration_request' : self.unregister_engine,
242 215 'connection_request': self.connection_request,
243 216 }
244 217 self.registration_timeout = max(5000, 2*self.heartbeat.period)
245 218 # this is the stuff that will move to DB:
246 219 # self.results = {} # completed results
247 220 self.pending = set() # pending messages, keyed by msg_id
248 221 self.queues = {} # pending msg_ids keyed by engine_id
249 222 self.tasks = {} # pending msg_ids submitted as tasks, keyed by client_id
250 223 self.completed = {} # completed msg_ids keyed by engine_id
251 224 self.all_completed = set()
252 225
253 226 logger.info("controller::created controller")
254 227
255 228 def _new_id(self):
256 229         """generate a new ID"""
257 230 newid = 0
258 231 incoming = [id[0] for id in self.incoming_registrations.itervalues()]
259 232 # print newid, self.ids, self.incoming_registrations
260 233 while newid in self.ids or newid in incoming:
261 234 newid += 1
262 235 return newid
263 236
264 237 #-----------------------------------------------------------------------------
265 238 # message validation
266 239 #-----------------------------------------------------------------------------
267 240
268 241 def _validate_targets(self, targets):
269 242 """turn any valid targets argument into a list of integer ids"""
270 243 if targets is None:
271 244 # default to all
272 245 targets = self.ids
273 246
274 247 if isinstance(targets, (int,str,unicode)):
275 248 # only one target specified
276 249 targets = [targets]
277 250 _targets = []
278 251 for t in targets:
279 252 # map raw identities to ids
280 253 if isinstance(t, (str,unicode)):
281 254 t = self.by_ident.get(t, t)
282 255 _targets.append(t)
283 256 targets = _targets
284 257 bad_targets = [ t for t in targets if t not in self.ids ]
285 258 if bad_targets:
286 259 raise IndexError("No Such Engine: %r"%bad_targets)
287 260 if not targets:
288 261 raise IndexError("No Engines Registered")
289 262 return targets
290 263
291 264 def _validate_client_msg(self, msg):
292 265 """validates and unpacks headers of a message. Returns False if invalid,
293 266 (ident, header, parent, content)"""
294 267 client_id = msg[0]
295 268 try:
296 269 msg = self.session.unpack_message(msg[1:], content=True)
297 270 except:
298 271 logger.error("client::Invalid Message %s"%msg)
299 272 return False
300 273
301 274 msg_type = msg.get('msg_type', None)
302 275 if msg_type is None:
303 276 return False
304 277 header = msg.get('header')
305 278 # session doesn't handle split content for now:
306 279 return client_id, msg
307 280
308 281
309 282 #-----------------------------------------------------------------------------
310 283 # dispatch methods (1 per stream)
311 284 #-----------------------------------------------------------------------------
312 285
313 286 def dispatch_register_request(self, msg):
314 287 """"""
315 288 logger.debug("registration::dispatch_register_request(%s)"%msg)
316 289 idents,msg = self.session.feed_identities(msg)
317 290 if not idents:
318 291 logger.error("Bad Queue Message: %s"%msg, exc_info=True)
319 292 return
320 293 try:
321 294 msg = self.session.unpack_message(msg,content=True)
322 295 except:
323 296 logger.error("registration::got bad registration message: %s"%msg, exc_info=True)
324 297 return
325 298
326 299 msg_type = msg['msg_type']
327 300 content = msg['content']
328 301
329 302 handler = self.registrar_handlers.get(msg_type, None)
330 303 if handler is None:
331 304 logger.error("registration::got bad registration message: %s"%msg)
332 305 else:
333 306 handler(idents, msg)
334 307
335 308 def dispatch_queue_traffic(self, msg):
336 309 """all ME and Task queue messages come through here"""
337 310 logger.debug("queue traffic: %s"%msg[:2])
338 311 switch = msg[0]
339 312 idents, msg = self.session.feed_identities(msg[1:])
340 313 if not idents:
341 314 logger.error("Bad Queue Message: %s"%msg)
342 315 return
343 316 handler = self.queue_handlers.get(switch, None)
344 317 if handler is not None:
345 318 handler(idents, msg)
346 319 else:
347 320 logger.error("Invalid message topic: %s"%switch)
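The monitor prefixes each frame list with its topic, so dispatch reduces to a dict lookup; a minimal sketch with a stand-in handler:

def route(msg, handlers):
    """Dispatch [topic, frames...] to handlers[topic], as above."""
    switch, rest = msg[0], msg[1:]
    handler = handlers.get(switch, None)
    if handler is None:
        raise KeyError("Invalid message topic: %s" % switch)
    return handler(rest)

handlers = {'in': lambda frames: ('saved request', frames)}
print(route(['in', 'ident', 'payload'], handlers))  # -> ('saved request', ['ident', 'payload'])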
348 321
349 322
350 323 def dispatch_client_msg(self, msg):
351 324 """Route messages from clients"""
352 325 idents, msg = self.session.feed_identities(msg)
353 326 if not idents:
354 327 logger.error("Bad Client Message: %s"%msg)
355 328 return
356 329 client_id = idents[0]
357 330 try:
358 331 msg = self.session.unpack_message(msg, content=True)
359 332 except:
360 333 content = wrap_exception()
361 334 logger.error("Bad Client Message: %s"%msg, exc_info=True)
362 335 self.session.send(self.clientele, "controller_error", ident=client_id,
363 336 content=content)
364 337 return
365 338
366 339 # print client_id, header, parent, content
367 340 #switch on message type:
368 341 msg_type = msg['msg_type']
369 342 logger.info("client:: client %s requested %s"%(client_id, msg_type))
370 343 handler = self.client_handlers.get(msg_type, None)
371 344 try:
372 345 assert handler is not None, "Bad Message Type: %s"%msg_type
373 346 except:
374 347 content = wrap_exception()
375 348 logger.error("Bad Message Type: %s"%msg_type, exc_info=True)
376 349 self.session.send(self.clientele, "controller_error", ident=client_id,
377 350 content=content)
378 351 return
379 352 else:
380 353 handler(client_id, msg)
381 354
382 355 def dispatch_db(self, msg):
383 356 """"""
384 357 raise NotImplementedError
385 358
386 359 #---------------------------------------------------------------------------
387 360 # handler methods (1 per event)
388 361 #---------------------------------------------------------------------------
389 362
390 363 #----------------------- Heartbeat --------------------------------------
391 364
392 365 def handle_new_heart(self, heart):
393 366 """handler to attach to heartbeater.
394 367 Called when a new heart starts to beat.
395 368 Triggers completion of registration."""
396 369 logger.debug("heartbeat::handle_new_heart(%r)"%heart)
397 370 if heart not in self.incoming_registrations:
398 371 logger.info("heartbeat::ignoring new heart: %r"%heart)
399 372 else:
400 373 self.finish_registration(heart)
401 374
402 375
403 376 def handle_heart_failure(self, heart):
404 377 """handler to attach to heartbeater.
405 378 called when a previously registered heart fails to respond to beat request.
406 379 triggers unregistration"""
407 380 logger.debug("heartbeat::handle_heart_failure(%r)"%heart)
408 381 eid = self.hearts.get(heart, None)
409 382 if eid is None:
410 383 logger.info("heartbeat::ignoring heart failure %r"%heart)
411 384 else:
412 385 queue = self.engines[eid].queue
413 386 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
414 387
415 388 #----------------------- MUX Queue Traffic ------------------------------
416 389
417 390 def save_queue_request(self, idents, msg):
418 391 if len(idents) < 2:
419 392 logger.error("invalid identity prefix: %s"%idents)
420 393 return
421 394 queue_id, client_id = idents[:2]
422 395 try:
423 396 msg = self.session.unpack_message(msg, content=False)
424 397 except:
425 398 logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
426 399 return
427 400
428 401 eid = self.by_ident.get(queue_id, None)
429 402 if eid is None:
430 403 logger.error("queue::target %r not registered"%queue_id)
431 404 logger.debug("queue:: valid are: %s"%(self.by_ident.keys()))
432 405 return
433 406
434 407 header = msg['header']
435 408 msg_id = header['msg_id']
436 409 record = init_record(msg)
437 410 record['engine_uuid'] = queue_id
438 411 record['client_uuid'] = client_id
439 412 record['queue'] = 'mux'
440 413 if MongoDB is not None and isinstance(self.db, MongoDB):
441 414 record['buffers'] = map(Binary, record['buffers'])
442 415 self.pending.add(msg_id)
443 416 self.queues[eid].append(msg_id)
444 417 self.db.add_record(msg_id, record)
445 418
446 419 def save_queue_result(self, idents, msg):
447 420 if len(idents) < 2:
448 421 logger.error("invalid identity prefix: %s"%idents)
449 422 return
450 423
451 424 client_id, queue_id = idents[:2]
452 425 try:
453 426 msg = self.session.unpack_message(msg, content=False)
454 427 except:
455 428 logger.error("queue::engine %r sent invalid message to %r: %s"%(
456 429 queue_id,client_id, msg), exc_info=True)
457 430 return
458 431
459 432 eid = self.by_ident.get(queue_id, None)
460 433 if eid is None:
461 434 logger.error("queue::unknown engine %r is sending a reply: "%queue_id)
462 435 logger.debug("queue:: %s"%msg[2:])
463 436 return
464 437
465 438 parent = msg['parent_header']
466 439 if not parent:
467 440 return
468 441 msg_id = parent['msg_id']
469 442 if msg_id in self.pending:
470 443 self.pending.remove(msg_id)
471 444 self.all_completed.add(msg_id)
472 445 self.queues[eid].remove(msg_id)
473 446 self.completed[eid].append(msg_id)
474 447 rheader = msg['header']
475 448 completed = datetime.strptime(rheader['date'], ISO8601)
476 449 started = rheader.get('started', None)
477 450 if started is not None:
478 451 started = datetime.strptime(started, ISO8601)
479 452 result = {
480 453 'result_header' : rheader,
481 454 'result_content': msg['content'],
482 455 'started' : started,
483 456 'completed' : completed
484 457 }
485 458 if MongoDB is not None and isinstance(self.db, MongoDB):
486 459 result['result_buffers'] = map(Binary, msg['buffers'])
460 else:
461 result['result_buffers'] = msg['buffers']
487 462 self.db.update_record(msg_id, result)
488 463 else:
489 464 logger.debug("queue:: unknown msg finished %s"%msg_id)
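The started/completed timestamps are parsed from ISO8601 header strings, so task runtime falls out as a timedelta; for example (dates invented):

from datetime import datetime

ISO8601 = "%Y-%m-%dT%H:%M:%S.%f"
started = datetime.strptime("2010-11-01T12:00:00.000000", ISO8601)
completed = datetime.strptime("2010-11-01T12:00:02.500000", ISO8601)
print(completed - started)  # -> 0:00:02.500000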
490 465
491 466 #--------------------- Task Queue Traffic ------------------------------
492 467
493 468 def save_task_request(self, idents, msg):
494 469 """Save the submission of a task."""
495 470 client_id = idents[0]
496 471
497 472 try:
498 473 msg = self.session.unpack_message(msg, content=False)
499 474 except:
500 475 logger.error("task::client %r sent invalid task message: %s"%(
501 476 client_id, msg), exc_info=True)
502 477 return
503 478 record = init_record(msg)
504 479 if MongoDB is not None and isinstance(self.db, MongoDB):
505 480 record['buffers'] = map(Binary, record['buffers'])
506 481 record['client_uuid'] = client_id
507 482 record['queue'] = 'task'
508 483 header = msg['header']
509 484 msg_id = header['msg_id']
510 485 self.pending.add(msg_id)
511 486 self.db.add_record(msg_id, record)
512 487
513 488 def save_task_result(self, idents, msg):
514 489 """save the result of a completed task."""
515 490 client_id = idents[0]
516 491 try:
517 492 msg = self.session.unpack_message(msg, content=False)
518 493 except:
519 494 logger.error("task::invalid task result message sent to %r: %s"%(
520 495 client_id, msg), exc_info=True)
521 496 return
523 498
524 499 parent = msg['parent_header']
525 500 if not parent:
526 501 # print msg
527 502 logger.warn("Task %r had no parent!"%msg)
528 503 return
529 504 msg_id = parent['msg_id']
530 505
531 506 header = msg['header']
532 507 engine_uuid = header.get('engine', None)
533 508 eid = self.by_ident.get(engine_uuid, None)
534 509
535 510 if msg_id in self.pending:
536 511 self.pending.remove(msg_id)
537 512 self.all_completed.add(msg_id)
538 513 if eid is not None:
539 514 self.completed[eid].append(msg_id)
540 515 if msg_id in self.tasks[eid]:
541 516 self.tasks[eid].remove(msg_id)
542 517 completed = datetime.strptime(header['date'], ISO8601)
543 518 started = header.get('started', None)
544 519 if started is not None:
545 520 started = datetime.strptime(started, ISO8601)
546 521 result = {
547 522 'result_header' : header,
548 523 'result_content': msg['content'],
549 524 'started' : started,
550 525 'completed' : completed,
551 526 'engine_uuid': engine_uuid
552 527 }
553 528 if MongoDB is not None and isinstance(self.db, MongoDB):
554 529 result['result_buffers'] = map(Binary, msg['buffers'])
530 else:
531 result['result_buffers'] = msg['buffers']
555 532 self.db.update_record(msg_id, result)
556 533
557 534 else:
558 535 logger.debug("task::unknown task %s finished"%msg_id)
559 536
560 537 def save_task_destination(self, idents, msg):
561 538 try:
562 539 msg = self.session.unpack_message(msg, content=True)
563 540 except:
564 541 logger.error("task::invalid task tracking message")
565 542 return
566 543 content = msg['content']
567 544 # print (content)
568 545 msg_id = content['msg_id']
569 546 engine_uuid = content['engine_id']
570 547 eid = self.by_ident[engine_uuid]
571 548
572 549 logger.info("task::task %s arrived on %s"%(msg_id, eid))
573 550 # if msg_id in self.mia:
574 551 # self.mia.remove(msg_id)
575 552 # else:
576 553 # logger.debug("task::task %s not listed as MIA?!"%(msg_id))
577 554
578 555 self.tasks[eid].append(msg_id)
579 556 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
580 557 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
581 558
582 559 def mia_task_request(self, idents, msg):
583 560 raise NotImplementedError
584 561 client_id = idents[0]
585 562 # content = dict(mia=self.mia,status='ok')
586 563 # self.session.send('mia_reply', content=content, idents=client_id)
587 564
588 565
589 566
590 567 #-------------------------------------------------------------------------
591 568 # Registration requests
592 569 #-------------------------------------------------------------------------
593 570
594 571 def connection_request(self, client_id, msg):
595 572 """Reply with connection addresses for clients."""
596 573 logger.info("client::client %s connected"%client_id)
597 574 content = dict(status='ok')
598 575 content.update(self.client_addrs)
599 576 jsonable = {}
600 577 for k,v in self.keytable.iteritems():
601 578 jsonable[str(k)] = v
602 579 content['engines'] = jsonable
603 580 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
604 581
605 582 def register_engine(self, reg, msg):
606 583 """Register a new engine."""
607 584 content = msg['content']
608 585 try:
609 586 queue = content['queue']
610 587 except KeyError:
611 588 logger.error("registration::queue not specified")
612 589 return
613 590 heart = content.get('heartbeat', None)
614 591 """register a new engine, and create the socket(s) necessary"""
615 592 eid = self._new_id()
616 593 # print (eid, queue, reg, heart)
617 594
618 595 logger.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
619 596
620 597 content = dict(id=eid,status='ok')
621 598 content.update(self.engine_addrs)
622 599 # check if requesting available IDs:
623 600 if queue in self.by_ident:
624 601 try:
625 602 raise KeyError("queue_id %r in use"%queue)
626 603 except:
627 604 content = wrap_exception()
628 605 elif heart in self.hearts: # need to check unique hearts?
629 606 try:
630 607 raise KeyError("heart_id %r in use"%heart)
631 608 except:
632 609 content = wrap_exception()
633 610 else:
634 611 for h, pack in self.incoming_registrations.iteritems():
635 612 if heart == h:
636 613 try:
637 614 raise KeyError("heart_id %r in use"%heart)
638 615 except:
639 616 content = wrap_exception()
640 617 break
641 618 elif queue == pack[1]:
642 619 try:
643 620 raise KeyError("queue_id %r in use"%queue)
644 621 except:
645 622 content = wrap_exception()
646 623 break
647 624
648 625 msg = self.session.send(self.registrar, "registration_reply",
649 626 content=content,
650 627 ident=reg)
651 628
652 629 if content['status'] == 'ok':
653 630 if heart in self.heartbeat.hearts:
654 631 # already beating
655 632 self.incoming_registrations[heart] = (eid,queue,reg,None)
656 633 self.finish_registration(heart)
657 634 else:
658 635 purge = lambda : self._purge_stalled_registration(heart)
659 636 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
660 637 dc.start()
661 638 self.incoming_registrations[heart] = (eid,queue,reg,dc)
662 639 else:
663 640 logger.error("registration::registration %i failed: %s"%(eid, content['evalue']))
664 641 return eid
665 642
666 643 def unregister_engine(self, ident, msg):
667 644 """Unregister an engine that explicitly requested to leave."""
668 645 try:
669 646 eid = msg['content']['id']
670 647 except:
671 648 logger.error("registration::bad engine id for unregistration: %s"%ident)
672 649 return
673 650 logger.info("registration::unregister_engine(%s)"%eid)
674 651 content=dict(id=eid, queue=self.engines[eid].queue)
675 652 self.ids.remove(eid)
676 653 self.keytable.pop(eid)
677 654 ec = self.engines.pop(eid)
678 655 self.hearts.pop(ec.heartbeat)
679 656 self.by_ident.pop(ec.queue)
680 657 self.completed.pop(eid)
681 658 for msg_id in self.queues.pop(eid):
682 659 self.pending.remove(msg_id)
683 660 ############## TODO: HANDLE IT ################
684 661
685 662 if self.notifier:
686 663 self.session.send(self.notifier, "unregistration_notification", content=content)
687 664
688 665 def finish_registration(self, heart):
689 666 """Second half of engine registration, called after our HeartMonitor
690 667 has received a beat from the Engine's Heart."""
691 668 try:
692 669 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
693 670 except KeyError:
694 671 logger.error("registration::tried to finish nonexistent registration")
695 672 return
696 673 logger.info("registration::finished registering engine %i:%r"%(eid,queue))
697 674 if purge is not None:
698 675 purge.stop()
699 676 control = queue
700 677 self.ids.add(eid)
701 678 self.keytable[eid] = queue
702 679 self.engines[eid] = EngineConnector(eid, queue, reg, control, heart)
703 680 self.by_ident[queue] = eid
704 681 self.queues[eid] = list()
705 682 self.tasks[eid] = list()
706 683 self.completed[eid] = list()
707 684 self.hearts[heart] = eid
708 685 content = dict(id=eid, queue=self.engines[eid].queue)
709 686 if self.notifier:
710 687 self.session.send(self.notifier, "registration_notification", content=content)
711 688
712 689 def _purge_stalled_registration(self, heart):
713 690 if heart in self.incoming_registrations:
714 691 eid = self.incoming_registrations.pop(heart)[0]
715 692 logger.info("registration::purging stalled registration: %i"%eid)
716 693 else:
717 694 pass
718 695
719 696 #-------------------------------------------------------------------------
720 697 # Client Requests
721 698 #-------------------------------------------------------------------------
722 699
723 700 def shutdown_request(self, client_id, msg):
724 701 """handle shutdown request."""
725 702 # s = self.context.socket(zmq.XREQ)
726 703 # s.connect(self.client_connections['mux'])
727 704 # time.sleep(0.1)
728 705 # for eid,ec in self.engines.iteritems():
729 706 # self.session.send(s, 'shutdown_request', content=dict(restart=False), ident=ec.queue)
730 707 # time.sleep(1)
731 708 self.session.send(self.clientele, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
732 709 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
733 710 dc.start()
734 711
735 712 def _shutdown(self):
736 713 logger.info("controller::controller shutting down.")
737 714 time.sleep(0.1)
738 715 sys.exit(0)
739 716
740 717
741 718 def check_load(self, client_id, msg):
742 719 content = msg['content']
743 720 try:
744 721 targets = content['targets']
745 722 targets = self._validate_targets(targets)
746 723 except:
747 724 content = wrap_exception()
748 725 self.session.send(self.clientele, "controller_error",
749 726 content=content, ident=client_id)
750 727 return
751 728
752 729 content = dict(status='ok')
753 730 # loads = {}
754 731 for t in targets:
755 732 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
756 733 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
757 734
758 735
759 736 def queue_status(self, client_id, msg):
760 737 """Return the Queue status of one or more targets.
761 738 if verbose: return the msg_ids
762 739 else: return the number of each type.
763 740 keys: queue (pending MUX jobs)
764 741 tasks (pending Task jobs)
765 742 completed (finished jobs from both queues)"""
766 743 content = msg['content']
767 744 targets = content['targets']
768 745 try:
769 746 targets = self._validate_targets(targets)
770 747 except:
771 748 content = wrap_exception()
772 749 self.session.send(self.clientele, "controller_error",
773 750 content=content, ident=client_id)
774 751 return
775 752 verbose = content.get('verbose', False)
776 753 content = dict(status='ok')
777 754 for t in targets:
778 755 queue = self.queues[t]
779 756 completed = self.completed[t]
780 757 tasks = self.tasks[t]
781 758 if not verbose:
782 759 queue = len(queue)
783 760 completed = len(completed)
784 761 tasks = len(tasks)
785 762 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
786 763 # pending
787 764 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
788 765
789 766 def purge_results(self, client_id, msg):
790 767 """Purge results from memory. This method is more valuable before we move
791 768 to a DB based message storage mechanism."""
792 769 content = msg['content']
793 770 msg_ids = content.get('msg_ids', [])
794 771 reply = dict(status='ok')
795 772 if msg_ids == 'all':
796 773 self.db.drop_matching_records(dict(completed={'$ne':None}))
797 774 else:
798 775 for msg_id in msg_ids:
799 776 if msg_id in self.all_completed:
800 777 self.db.drop_record(msg_id)
801 778 else:
802 779 if msg_id in self.pending:
803 780 try:
804 781 raise IndexError("msg pending: %r"%msg_id)
805 782 except:
806 783 reply = wrap_exception()
807 784 else:
808 785 try:
809 786 raise IndexError("No such msg: %r"%msg_id)
810 787 except:
811 788 reply = wrap_exception()
812 789 break
813 790 eids = content.get('engine_ids', [])
814 791 for eid in eids:
815 792 if eid not in self.engines:
816 793 try:
817 794 raise IndexError("No such engine: %i"%eid)
818 795 except:
819 796 reply = wrap_exception()
820 797 break
821 798 msg_ids = self.completed.pop(eid)
822 799 uid = self.engines[eid].queue
823 800 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
824 801
825 802 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
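A client-side sketch of the two request shapes this handler accepts (the msg_ids here are placeholders):

purge_all = dict(msg_ids='all')  # drop every completed record
purge_some = dict(msg_ids=['<msg-uuid-1>', '<msg-uuid-2>'],  # specific messages...
                  engine_ids=[0, 1])  # ...plus everything completed on engines 0 and 1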
826 803
827 804 def resubmit_task(self, client_id, msg, buffers):
828 805 """Resubmit a task."""
829 806 raise NotImplementedError
830 807
831 808 def get_results(self, client_id, msg):
832 809 """Get the result of 1 or more messages."""
833 810 content = msg['content']
834 msg_ids = set(content['msg_ids'])
811 msg_ids = sorted(set(content['msg_ids']))
835 812 statusonly = content.get('status_only', False)
836 813 pending = []
837 814 completed = []
838 815 content = dict(status='ok')
839 816 content['pending'] = pending
840 817 content['completed'] = completed
818 buffers = []
841 819 if not statusonly:
820 content['results'] = {}
842 821 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
843 822 for msg_id in msg_ids:
844 823 if msg_id in self.pending:
845 824 pending.append(msg_id)
846 825 elif msg_id in self.all_completed:
847 826 completed.append(msg_id)
848 827 if not statusonly:
849 content[msg_id] = records[msg_id]['result_content']
828 rec = records[msg_id]
829 content[msg_id] = { 'result_content': rec['result_content'],
830 'header': rec['header'],
831 'result_header' : rec['result_header'],
832 }
833 buffers.extend(map(str, rec['result_buffers']))
850 834 else:
851 835 try:
852 836 raise KeyError('No such message: '+msg_id)
853 837 except:
854 838 content = wrap_exception()
855 839 break
856 840 self.session.send(self.clientele, "result_reply", content=content,
857 parent=msg, ident=client_id)
841 parent=msg, ident=client_id,
842 buffers=buffers)
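With this change, a non-status_only result_reply roughly carries (all values invented; empty dicts stand in for the real headers/content):

reply = {
    'status': 'ok',
    'pending': ['<msg-id-still-pending>'],
    'completed': ['<msg-id-done>'],
    '<msg-id-done>': {
        'result_content': {},   # the engine's reply content
        'header': {},           # the original submission header
        'result_header': {},    # the completion header
    },
}
# the raw result_buffers of the completed messages ride along as the
# reply's buffer frames, in msg_id order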
858 843
859 844
860 845 #-------------------------------------------------------------------------
861 846 # Entry Point
862 847 #-------------------------------------------------------------------------
863 848
864 849 def make_argument_parser():
865 850 """Make an argument parser"""
866 851 parser = make_base_argument_parser()
867 852
868 853 parser.add_argument('--client', type=int, metavar='PORT', default=0,
869 854 help='set the XREP port for clients [default: random]')
870 855 parser.add_argument('--notice', type=int, metavar='PORT', default=0,
871 856 help='set the PUB socket for registration notification [default: random]')
872 857 parser.add_argument('--hb', type=str, metavar='PORTS',
873 858 help='set the 2 ports for heartbeats [default: random]')
874 859 parser.add_argument('--ping', type=int, default=3000,
875 860 help='set the heartbeat period in ms [default: 3000]')
876 861 parser.add_argument('--monitor', type=int, metavar='PORT', default=0,
877 862 help='set the SUB port for queue monitoring [default: random]')
878 863 parser.add_argument('--mux', type=str, metavar='PORTS',
879 864 help='set the XREP ports for the MUX queue [default: random]')
880 865 parser.add_argument('--task', type=str, metavar='PORTS',
881 866 help='set the XREP/XREQ ports for the task queue [default: random]')
882 867 parser.add_argument('--control', type=str, metavar='PORTS',
883 868 help='set the XREP ports for the control queue [default: random]')
884 869 parser.add_argument('--scheduler', type=str, default='pure',
885 870 choices = ['pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'],
886 871 help='select the task scheduler [default: pure ZMQ]')
887 872 parser.add_argument('--mongodb', action='store_true',
888 873 help='Use MongoDB task storage [default: in-memory]')
889 874
890 875 return parser
891 876
892 877 def main(argv=None):
893 878 import time
894 879 from multiprocessing import Process
895 880
896 881 from zmq.eventloop.zmqstream import ZMQStream
897 882 from zmq.devices import ProcessMonitoredQueue
898 883 from zmq.log import handlers
899 884
900 885 import streamsession as session
901 886 import heartmonitor
902 887 from scheduler import launch_scheduler
903 888
904 889 parser = make_argument_parser()
905 890
906 891 args = parser.parse_args(argv)
907 892 parse_url(args)
908 893
909 894 iface="%s://%s"%(args.transport,args.ip)+':%i'
910 895
911 896 random_ports = 0
912 897 if args.hb:
913 898 hb = split_ports(args.hb, 2)
914 899 else:
915 900 hb = select_random_ports(2)
916 901 if args.mux:
917 902 mux = split_ports(args.mux, 2)
918 903 else:
919 904 mux = None
920 905 random_ports += 2
921 906 if args.task:
922 907 task = split_ports(args.task, 2)
923 908 else:
924 909 task = None
925 910 random_ports += 2
926 911 if args.control:
927 912 control = split_ports(args.control, 2)
928 913 else:
929 914 control = None
930 915 random_ports += 2
931 916
932 917 ctx = zmq.Context()
933 918 loop = ioloop.IOLoop.instance()
934 919
935 920 # setup logging
936 921 connect_logger(ctx, iface%args.logport, root="controller", loglevel=args.loglevel)
937 922
938 923 # Registrar socket
939 924 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
940 925 regport = bind_port(reg, args.ip, args.regport)
941 926
942 927 ### Engine connections ###
943 928
944 929 # heartbeat
945 930 hpub = ctx.socket(zmq.PUB)
946 931 bind_port(hpub, args.ip, hb[0])
947 932 hrep = ctx.socket(zmq.XREP)
948 933 bind_port(hrep, args.ip, hb[1])
949 934
950 935 hmon = heartmonitor.HeartMonitor(loop, ZMQStream(hpub,loop), ZMQStream(hrep,loop),args.ping)
951 936 hmon.start()
952 937
953 938 ### Client connections ###
954 939 # Clientele socket
955 940 c = ZMQStream(ctx.socket(zmq.XREP), loop)
956 941 cport = bind_port(c, args.ip, args.client)
957 942 # Notifier socket
958 943 n = ZMQStream(ctx.socket(zmq.PUB), loop)
959 944 nport = bind_port(n, args.ip, args.notice)
960 945
961 946 ### Key File ###
962 947 if args.execkey and not os.path.isfile(args.execkey):
963 948 generate_exec_key(args.execkey)
964 949
965 950 thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey)
966 951
967 952 ### build and launch the queues ###
968 953
969 954 # monitor socket
970 955 sub = ctx.socket(zmq.SUB)
971 956 sub.setsockopt(zmq.SUBSCRIBE, "")
972 957 monport = bind_port(sub, args.ip, args.monitor)
973 958 sub = ZMQStream(sub, loop)
974 959
975 960 ports = select_random_ports(random_ports)
976 961 children = []
977 962 # Multiplexer Queue (in a Process)
978 963 if not mux:
979 964 mux = (ports.pop(),ports.pop())
980 965 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
981 966 q.bind_in(iface%mux[0])
982 967 q.bind_out(iface%mux[1])
983 968 q.connect_mon(iface%monport)
984 969 q.daemon=True
985 970 q.start()
986 971 children.append(q.launcher)
987 972
988 973 # Control Queue (in a Process)
989 974 if not control:
990 975 control = (ports.pop(),ports.pop())
991 976 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
992 977 q.bind_in(iface%control[0])
993 978 q.bind_out(iface%control[1])
994 979 q.connect_mon(iface%monport)
995 980 q.daemon=True
996 981 q.start()
997 982 children.append(q.launcher)
998 983 # Task Queue (in a Process)
999 984 if not task:
1000 985 task = (ports.pop(),ports.pop())
1001 986 if args.scheduler == 'pure':
1002 987 q = ProcessMonitoredQueue(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
1003 988 q.bind_in(iface%task[0])
1004 989 q.bind_out(iface%task[1])
1005 990 q.connect_mon(iface%monport)
1006 991 q.daemon=True
1007 992 q.start()
1008 993 children.append(q.launcher)
1009 994 else:
1010 995 sargs = (iface%task[0],iface%task[1],iface%monport,iface%nport,args.scheduler)
1011 996 print (sargs)
1012 997 q = Process(target=launch_scheduler, args=sargs)
1013 998 q.daemon=True
1014 999 q.start()
1015 1000 children.append(q)
1016 1001
1017 1002 if args.mongodb:
1018 1003 from mongodb import MongoDB
1019 1004 db = MongoDB(thesession.session)
1020 1005 else:
1021 1006 db = DictDB()
1022 1007 time.sleep(.25)
1023 1008
1024 1009 # build connection dicts
1025 1010 engine_addrs = {
1026 1011 'control' : iface%control[1],
1027 1012 'queue': iface%mux[1],
1028 1013 'heartbeat': (iface%hb[0], iface%hb[1]),
1029 1014 'task' : iface%task[1],
1030 1015 'monitor' : iface%monport,
1031 1016 }
1032 1017
1033 1018 client_addrs = {
1034 1019 'control' : iface%control[0],
1035 1020 'query': iface%cport,
1036 1021 'queue': iface%mux[0],
1037 1022 'task' : iface%task[0],
1038 1023 'notification': iface%nport
1039 1024 }
1040 1025 signal_children(children)
1041 1026 con = Controller(loop, thesession, sub, reg, hmon, c, n, db, engine_addrs, client_addrs)
1042 1027 dc = ioloop.DelayedCallback(lambda : print("Controller started..."), 100, loop)
1043 1028 dc.start()
1044 1029 loop.start()
1045 1030
1046 1031
1047 1032
1048 1033
1049 1034 if __name__ == '__main__':
1050 1035 main()
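Since main() takes an argv list, a controller with a Python task scheduler and MongoDB storage could also be started programmatically, assuming the base parser's defaults suffice for the remaining options:

main(['--scheduler', 'lru', '--mongodb'])  # equivalent to passing the same command-line flags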
@@ -1,56 +1,60
1 1 """A TaskRecord backend using mongodb"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 from datetime import datetime
10 10
11 11 from pymongo import Connection
12 12
13 13 #----------------------
14 14 # MongoDB class
15 15 #----------------------
16 16 class MongoDB(object):
17 17 """MongoDB TaskRecord backend."""
18 18 def __init__(self, session_uuid, *args, **kwargs):
19 19 self._connection = Connection(*args, **kwargs)
20 20 self._db = self._connection[session_uuid]
21 21 self._records = self._db['task_records']
22 22 self._table = {}
23 23
24 24
25 25 def add_record(self, msg_id, rec):
26 26 """Add a new Task Record, by msg_id."""
27 27 # print rec
28 28 obj_id = self._records.insert(rec)
29 29 self._table[msg_id] = obj_id
30 30
31 31 def get_record(self, msg_id):
32 32 """Get a specific Task Record, by msg_id."""
33 33 return self._records.find_one(self._table[msg_id])
34 34
35 35 def update_record(self, msg_id, rec):
36 36 """Update the data in an existing record."""
37 37 obj_id = self._table[msg_id]
38 self._records.update({'_id':obj_id}, rec)
38 self._records.update({'_id':obj_id}, {'$set': rec})
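The $set wrapper is the actual fix here: with a bare second argument pymongo replaces the whole stored document, while $set merges only the given fields. A sketch of the difference (field names illustrative):

# stored: {'_id': oid, 'msg_id': 'abc', 'completed': None}
# coll.update({'_id': oid}, {'completed': now})            # record becomes ONLY {'_id': oid, 'completed': now}
# coll.update({'_id': oid}, {'$set': {'completed': now}})  # 'msg_id' survives the update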
39 39
40 40 def drop_matching_records(self, check):
41 41 """Remove a record from the DB."""
42 42 self._records.remove(check)
43 43
44 44 def drop_record(self, msg_id):
45 45 """Remove a record from the DB."""
46 46 obj_id = self._table.pop(msg_id)
47 47 self._records.remove(obj_id)
48 48
49 49 def find_records(self, check, id_only=False):
50 50 """Find records matching a query dict."""
51 51 matches = list(self._records.find(check))
52 52 if id_only:
53 matches = [ rec['msg_id'] for rec in matches ]
54 return matches
53 return [ rec['msg_id'] for rec in matches ]
54 else:
55 data = {}
56 for rec in matches:
57 data[rec['msg_id']] = rec
58 return data
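Callers now get either a list of ids or a dict keyed by msg_id; for example (the session uuid is a placeholder, connecting to the default localhost mongod):

db = MongoDB('<session-uuid>')
done_ids = db.find_records({'completed': {'$ne': None}}, id_only=True)  # ['<msg-id>', ...]
done = db.find_records({'completed': {'$ne': None}})                    # {'<msg-id>': record, ...}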
55 59
56 60
@@ -1,145 +1,145
1 1 """Remote Functions and decorators for the client."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import map as Map
14 14 from asyncresult import AsyncMapResult
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Decorators
18 18 #-----------------------------------------------------------------------------
19 19
20 20 def remote(client, bound=False, block=None, targets=None):
21 21 """Turn a function into a remote function.
22 22
23 23 Usage:
24 24
25 25 >>> @remote(client,block=True)
26 26 def func(a)
27 27 """
28 28 def remote_function(f):
29 29 return RemoteFunction(client, f, bound, block, targets)
30 30 return remote_function
31 31
32 32 def parallel(client, dist='b', bound=False, block=None, targets='all'):
33 33 """Turn a function into a parallel remote function.
34 34
35 35 This method can be used for map:
36 36
37 37 >>> @parallel(client,block=True)
38 38 ... def func(a): pass
39 39 """
40 40 def parallel_function(f):
41 41 return ParallelFunction(client, f, dist, bound, block, targets)
42 42 return parallel_function
43 43
44 44 #--------------------------------------------------------------------------
45 45 # Classes
46 46 #--------------------------------------------------------------------------
47 47
48 48 class RemoteFunction(object):
49 49 """Turn an existing function into a remote function.
50 50
51 51 Parameters
52 52 ----------
53 53
54 54 client : Client instance
55 55 The client to be used to connect to engines
56 56 f : callable
57 57 The function to be wrapped into a remote function
58 58 bound : bool [default: False]
59 59 Whether to affect the remote namespace when called
60 60 block : bool [default: None]
61 61 Whether to wait for results or not. The default behavior is
62 62 to use the current `block` attribute of `client`
63 63 targets : valid target list [default: all]
64 64 The targets on which to execute.
65 65 """
66 66
67 67 client = None # the remote connection
68 68 func = None # the wrapped function
69 69 block = None # whether to block
70 70 bound = None # whether to affect the namespace
71 71 targets = None # where to execute
72 72
73 73 def __init__(self, client, f, bound=False, block=None, targets=None):
74 74 self.client = client
75 75 self.func = f
76 76 self.block=block
77 77 self.bound=bound
78 78 self.targets=targets
79 79
80 80 def __call__(self, *args, **kwargs):
81 81 return self.client.apply(self.func, args=args, kwargs=kwargs,
82 82 block=self.block, targets=self.targets, bound=self.bound)
83 83
84 84
85 85 class ParallelFunction(RemoteFunction):
86 86 """Class for mapping a function to sequences."""
87 87 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all'):
88 88 super(ParallelFunction, self).__init__(client,f,bound,block,targets)
89 89 mapClass = Map.dists[dist]
90 90 self.mapObject = mapClass()
91 91
92 92 def __call__(self, *sequences):
93 93 len_0 = len(sequences[0])
94 94 for s in sequences:
95 95 if len(s)!=len_0:
96 96 raise ValueError('all sequences must have equal length')
97 97
98 98 if self.targets is None:
99 99 # load-balanced:
100 100 engines = [None]*len_0
101 101 elif isinstance(self.targets, int):
102 102 engines = [None]*self.targets
103 103 else:
104 104 # multiplexed:
105 105 engines = self.client._build_targets(self.targets)[-1]
106 106
107 107 nparts = len(engines)
108 108 msg_ids = []
109 109 # my_f = lambda *a: map(self.func, *a)
110 110 for index, engineid in enumerate(engines):
111 111 args = []
112 112 for seq in sequences:
113 113 part = self.mapObject.getPartition(seq, index, nparts)
114 114 if not part:
115 115 continue
116 116 else:
117 117 args.append(part)
118 118 if not args:
119 119 continue
120 120
121 121 # print (args)
122 122 if hasattr(self, '_map'):
123 123 f = map
124 124 args = [self.func]+args
125 125 else:
126 126 f=self.func
127 127 mid = self.client.apply(f, args=args, block=False,
128 128 bound=self.bound,
129 targets=engineid)._msg_ids[0]
129 targets=engineid).msg_ids[0]
130 130 msg_ids.append(mid)
131 131
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject)
132 r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
133 133 if self.block:
134 134 r.wait()
135 135 return r.result
136 136 else:
137 137 return r
138 138
139 139 def map(self, *sequences):
140 140 """call a function on each element of a sequence remotely."""
141 141 self._map = True
142 142 ret = self.__call__(*sequences)
143 143 del self._map
144 144 return ret
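Usage sketch, assuming `client` is a connected Client with at least one engine:

@parallel(client, dist='b', block=True)
def double(x):
    return 2 * x

double.map(range(10))  # maps element-wise, split across the engines
double(range(10))      # calls double(partition) once per engine instead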
145 145
@@ -1,544 +1,544
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import sys
8 8 import traceback
9 9 import pprint
10 10 import uuid
11 11 from datetime import datetime
12 12
13 13 import zmq
14 14 from zmq.utils import jsonapi
15 15 from zmq.eventloop.zmqstream import ZMQStream
16 16
17 17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 18 from IPython.utils.newserialized import serialize, unserialize
19 19
20 20 from IPython.zmq.parallel.error import RemoteError
21 21
22 22 try:
23 23 import cPickle
24 24 pickle = cPickle
25 25 except:
26 26 cPickle = None
27 27 import pickle
28 28
29 29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
30 30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
31 31 if json_name in ('jsonlib', 'jsonlib2'):
32 32 use_json = True
33 33 elif json_name:
34 34 if cPickle is None:
35 35 use_json = True
36 36 else:
37 37 use_json = False
38 38 else:
39 39 use_json = False
40 40
41 41 def squash_unicode(obj):
42 42 if isinstance(obj,dict):
43 43 for key in obj.keys():
44 44 obj[key] = squash_unicode(obj[key])
45 45 if isinstance(key, unicode):
46 46 obj[squash_unicode(key)] = obj.pop(key)
47 47 elif isinstance(obj, list):
48 48 for i,v in enumerate(obj):
49 49 obj[i] = squash_unicode(v)
50 50 elif isinstance(obj, unicode):
51 51 obj = obj.encode('utf8')
52 52 return obj
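For example:

print(squash_unicode({u'a': [u'b', {u'c': u'd'}]}))
# -> {'a': ['b', {'c': 'd'}]} : every unicode key/value coerced to a utf8 str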
53 53
54 54 if use_json:
55 55 default_packer = jsonapi.dumps
56 56 default_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
57 57 else:
58 58 default_packer = lambda o: pickle.dumps(o,-1)
59 59 default_unpacker = pickle.loads
60 60
61 61
62 62 DELIM="<IDS|MSG>"
63 63 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
64 64
65 65 def wrap_exception(engine_info={}):
66 66 etype, evalue, tb = sys.exc_info()
67 67 stb = traceback.format_exception(etype, evalue, tb)
68 68 exc_content = {
69 69 'status' : 'error',
70 70 'traceback' : stb,
71 71 'ename' : unicode(etype.__name__),
72 72 'evalue' : unicode(evalue),
73 73 'engine_info' : engine_info
74 74 }
75 75 return exc_content
76 76
77 77 def unwrap_exception(content):
78 78 err = RemoteError(content['ename'], content['evalue'],
79 79 ''.join(content['traceback']),
80 80 content.get('engine_info', {}))
81 81 return err
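A round trip of the pair above, called from inside an except block:

try:
    1 / 0
except:
    content = wrap_exception()
err = unwrap_exception(content)
# err is a RemoteError carrying ename='ZeroDivisionError', the evalue,
# and the formatted remote traceback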
82 82
83 83
84 84 class Message(object):
85 85 """A simple message object that maps dict keys to attributes.
86 86
87 87 A Message can be created from a dict and a dict from a Message instance
88 88 simply by calling dict(msg_obj)."""
89 89
90 90 def __init__(self, msg_dict):
91 91 dct = self.__dict__
92 92 for k, v in dict(msg_dict).iteritems():
93 93 if isinstance(v, dict):
94 94 v = Message(v)
95 95 dct[k] = v
96 96
97 97 # Having this iterator lets dict(msg_obj) work out of the box.
98 98 def __iter__(self):
99 99 return iter(self.__dict__.iteritems())
100 100
101 101 def __repr__(self):
102 102 return repr(self.__dict__)
103 103
104 104 def __str__(self):
105 105 return pprint.pformat(self.__dict__)
106 106
107 107 def __contains__(self, k):
108 108 return k in self.__dict__
109 109
110 110 def __getitem__(self, k):
111 111 return self.__dict__[k]
112 112
113 113
114 114 def msg_header(msg_id, msg_type, username, session):
115 115 date=datetime.now().strftime(ISO8601)
116 116 return locals()
117 117
118 118 def extract_header(msg_or_header):
119 119 """Given a message or header, return the header."""
120 120 if not msg_or_header:
121 121 return {}
122 122 try:
123 123 # See if msg_or_header is the entire message.
124 124 h = msg_or_header['header']
125 125 except KeyError:
126 126 try:
127 127 # See if msg_or_header is just the header
128 128 h = msg_or_header['msg_id']
129 129 except KeyError:
130 130 raise
131 131 else:
132 132 h = msg_or_header
133 133 if not isinstance(h, dict):
134 134 h = dict(h)
135 135 return h
136 136
137 137 def rekey(dikt):
138 138 """Rekey a dict that has been forced to use str keys where there should be
139 139 ints by json. This belongs in the jsonutil added by fperez."""
140 140 for k in dikt.iterkeys():
141 141 if isinstance(k, str):
142 142 ik=fk=None
143 143 try:
144 144 ik = int(k)
145 145 except ValueError:
146 146 try:
147 147 fk = float(k)
148 148 except ValueError:
149 149 continue
150 150 if ik is not None:
151 151 nk = ik
152 152 else:
153 153 nk = fk
154 154 if nk in dikt:
155 155 raise KeyError("already have key %r"%nk)
156 156 dikt[nk] = dikt.pop(k)
157 157 return dikt
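For example (key order not guaranteed):

print(rekey({'1': 'a', '2.5': 'b', 'x': 'c'}))
# -> {1: 'a', 2.5: 'b', 'x': 'c'} : numeric-looking str keys restored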
158 158
159 159 def serialize_object(obj, threshold=64e-6):
160 160 """Serialize an object into a list of sendable buffers.
161 161
162 162 Parameters
163 163 ----------
164 164
165 165 obj : object
166 166 The object to be serialized
167 167 threshold : float
168 168 The threshold for not double-pickling the content.
169 169
170 170
171 171 Returns
172 172 -------
173 173 ('pmd', [bufs]) :
174 174 where pmd is the pickled metadata wrapper,
175 175 bufs is a list of data buffers
176 176 """
177 177 databuffers = []
178 178 if isinstance(obj, (list, tuple)):
179 179 clist = canSequence(obj)
180 180 slist = map(serialize, clist)
181 181 for s in slist:
182 182 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
183 183 databuffers.append(s.getData())
184 184 s.data = None
185 185 return pickle.dumps(slist,-1), databuffers
186 186 elif isinstance(obj, dict):
187 187 sobj = {}
188 188 for k in sorted(obj.iterkeys()):
189 189 s = serialize(can(obj[k]))
190 190 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
191 191 databuffers.append(s.getData())
192 192 s.data = None
193 193 sobj[k] = s
194 194 return pickle.dumps(sobj,-1),databuffers
195 195 else:
196 196 s = serialize(can(obj))
197 197 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
198 198 databuffers.append(s.getData())
199 199 s.data = None
200 200 return pickle.dumps(s,-1),databuffers
201 201
202 202
203 203 def unserialize_object(bufs):
204 204 """reconstruct an object serialized by serialize_object from data buffers."""
205 205 bufs = list(bufs)
206 206 sobj = pickle.loads(bufs.pop(0))
207 207 if isinstance(sobj, (list, tuple)):
208 208 for s in sobj:
209 209 if s.data is None:
210 210 s.data = bufs.pop(0)
211 return uncanSequence(map(unserialize, sobj))
211 return uncanSequence(map(unserialize, sobj)), bufs
212 212 elif isinstance(sobj, dict):
213 213 newobj = {}
214 214 for k in sorted(sobj.iterkeys()):
215 215 s = sobj[k]
216 216 if s.data is None:
217 217 s.data = bufs.pop(0)
218 218 newobj[k] = uncan(unserialize(s))
219 return newobj
219 return newobj, bufs
220 220 else:
221 221 if sobj.data is None:
222 222 sobj.data = bufs.pop(0)
223 return uncan(unserialize(sobj))
223 return uncan(unserialize(sobj)), bufs
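A round-trip sketch using the two helpers (the buffer list starts with the pickled metadata, followed by any extracted data buffers):

obj = {'a': range(10), 'b': 'hello'}
pmd, bufs = serialize_object(obj)
restored, leftover = unserialize_object([pmd] + bufs)
assert restored == obj and leftover == []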
224 224
225 225 def pack_apply_message(f, args, kwargs, threshold=64e-6):
226 226 """pack up a function, args, and kwargs to be sent over the wire
227 227 as a series of buffers. Any object whose data is larger than `threshold`
228 228 will not have their data copied (currently only numpy arrays support zero-copy)"""
229 229 msg = [pickle.dumps(can(f),-1)]
230 230 databuffers = [] # for large objects
231 231 sargs, bufs = serialize_object(args,threshold)
232 232 msg.append(sargs)
233 233 databuffers.extend(bufs)
234 234 skwargs, bufs = serialize_object(kwargs,threshold)
235 235 msg.append(skwargs)
236 236 databuffers.extend(bufs)
237 237 msg.extend(databuffers)
238 238 return msg
239 239
240 240 def unpack_apply_message(bufs, g=None, copy=True):
241 241 """unpack f,args,kwargs from buffers packed by pack_apply_message()
242 242 Returns: original f,args,kwargs"""
243 243 bufs = list(bufs) # allow us to pop
244 244 assert len(bufs) >= 3, "not enough buffers!"
245 245 if not copy:
246 246 for i in range(3):
247 247 bufs[i] = bufs[i].bytes
248 248 cf = pickle.loads(bufs.pop(0))
249 249 sargs = list(pickle.loads(bufs.pop(0)))
250 250 skwargs = dict(pickle.loads(bufs.pop(0)))
251 251 # print sargs, skwargs
252 252 f = uncan(cf, g)
253 253 for sa in sargs:
254 254 if sa.data is None:
255 255 m = bufs.pop(0)
256 256 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
257 257 if copy:
258 258 sa.data = buffer(m)
259 259 else:
260 260 sa.data = m.buffer
261 261 else:
262 262 if copy:
263 263 sa.data = m
264 264 else:
265 265 sa.data = m.bytes
266 266
267 267 args = uncanSequence(map(unserialize, sargs), g)
268 268 kwargs = {}
269 269 for k in sorted(skwargs.iterkeys()):
270 270 sa = skwargs[k]
271 271 if sa.data is None:
272 272 sa.data = bufs.pop(0)
273 273 kwargs[k] = uncan(unserialize(sa), g)
274 274
275 275 return f,args,kwargs
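Round trip with pack_apply_message, assuming `add` can be canned/pickled in this environment:

def add(a, b=1):
    return a + b

bufs = pack_apply_message(add, (5,), {'b': 2})
f, args, kwargs = unpack_apply_message(bufs)
assert f(*args, **kwargs) == 7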
276 276
277 277 class StreamSession(object):
278 278 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
279 279 debug=False
280 280 key=None
281 281
282 282 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
283 283 if username is None:
284 284 username = os.environ.get('USER','username')
285 285 self.username = username
286 286 if session is None:
287 287 self.session = str(uuid.uuid4())
288 288 else:
289 289 self.session = session
290 290 self.msg_id = str(uuid.uuid4())
291 291 if packer is None:
292 292 self.pack = default_packer
293 293 else:
294 294 if not callable(packer):
295 295 raise TypeError("packer must be callable, not %s"%type(packer))
296 296 self.pack = packer
297 297
298 298 if unpacker is None:
299 299 self.unpack = default_unpacker
300 300 else:
301 301 if not callable(unpacker):
302 302 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
303 303 self.unpack = unpacker
304 304
305 305 if key is not None and keyfile is not None:
306 306 raise TypeError("Must specify key OR keyfile, not both")
307 307 if keyfile is not None:
308 308 with open(keyfile) as f:
309 309 self.key = f.read().strip()
310 310 else:
311 311 self.key = key
312 312 # print key, keyfile, self.key
313 313 self.none = self.pack({})
314 314
315 315 def msg_header(self, msg_type):
316 316 h = msg_header(self.msg_id, msg_type, self.username, self.session)
317 317 self.msg_id = str(uuid.uuid4())
318 318 return h
319 319
320 320 def msg(self, msg_type, content=None, parent=None, subheader=None):
321 321 msg = {}
322 322 msg['header'] = self.msg_header(msg_type)
323 323 msg['msg_id'] = msg['header']['msg_id']
324 324 msg['parent_header'] = {} if parent is None else extract_header(parent)
325 325 msg['msg_type'] = msg_type
326 326 msg['content'] = {} if content is None else content
327 327 sub = {} if subheader is None else subheader
328 328 msg['header'].update(sub)
329 329 return msg
330 330
331 331 def check_key(self, msg_or_header):
332 332 """Check that a message's header has the right key"""
333 333 if self.key is None:
334 334 return True
335 335 header = extract_header(msg_or_header)
336 336 return header.get('key', None) == self.key
337 337
338 338
339 339 def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
340 340 """Build and send a message via stream or socket.
341 341
342 342 Parameters
343 343 ----------
344 344
345 345 stream : zmq.Socket or ZMQStream
346 346 the socket-like object used to send the data
347 347 msg_type : str or Message/dict
348 348 Normally, msg_type will be the string name of the message type;
349 349 a prebuilt Message or dict may also be passed, in which case it
350 350 is sent as-is and `content` is taken from it.
351 351
352 352 Returns
353 353 -------
354 354 (msg,sent) : tuple
355 355 msg : Message
356 356 the nice wrapped dict-like object containing the headers
357 357
358 358 """
359 359 if isinstance(msg_type, (Message, dict)):
360 360 # we got a Message, not a msg_type
361 361 # don't build a new Message
362 362 msg = msg_type
363 363 content = msg['content']
364 364 else:
365 365 msg = self.msg(msg_type, content, parent, subheader)
366 366 buffers = [] if buffers is None else buffers
367 367 to_send = []
368 368 if isinstance(ident, list):
369 369 # accept list of idents
370 370 to_send.extend(ident)
371 371 elif ident is not None:
372 372 to_send.append(ident)
373 373 to_send.append(DELIM)
374 374 if self.key is not None:
375 375 to_send.append(self.key)
376 376 to_send.append(self.pack(msg['header']))
377 377 to_send.append(self.pack(msg['parent_header']))
378 378
379 379 if content is None:
380 380 content = self.none
381 381 elif isinstance(content, dict):
382 382 content = self.pack(content)
383 383 elif isinstance(content, str):
384 384 # content is already packed, as in a relayed message
385 385 pass
386 386 else:
387 387 raise TypeError("Content incorrect type: %s"%type(content))
388 388 to_send.append(content)
389 389 flag = 0
390 390 if buffers:
391 391 flag = zmq.SNDMORE
392 392 stream.send_multipart(to_send, flag, copy=False)
393 393 for b in buffers[:-1]:
394 394 stream.send(b, flag, copy=False)
395 395 if buffers:
396 396 stream.send(buffers[-1], copy=False)
397 397 omsg = Message(msg)
398 398 if self.debug:
399 399 pprint.pprint(omsg)
400 400 pprint.pprint(to_send)
401 401 pprint.pprint(buffers)
402 402 return omsg
403 403
404 404 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
405 405 """Send a raw message via ident path.
406 406
407 407 Parameters
408 408 ----------
409 409 msg : list of sendable buffers"""
410 410 to_send = []
411 411 if isinstance(ident, str):
412 412 ident = [ident]
413 413 if ident is not None:
414 414 to_send.extend(ident)
415 415 to_send.append(DELIM)
416 416 if self.key is not None:
417 417 to_send.append(self.key)
418 418 to_send.extend(msg)
419 419 stream.send_multipart(to_send, flags, copy=copy)
420 420
421 421 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
422 422 """receives and unpacks a message
423 423 returns [idents], msg"""
424 424 if isinstance(socket, ZMQStream):
425 425 socket = socket.socket
426 426 try:
427 427 msg = socket.recv_multipart(mode)
428 428 except zmq.ZMQError as e:
429 429 if e.errno == zmq.EAGAIN:
430 430 # We can convert EAGAIN to None as we know in this case
431 431 # recv_json won't return None.
432 432 return None
433 433 else:
434 434 raise
435 435 # return an actual Message object
436 436 # determine the number of idents by trying to unpack them.
437 437 # this is terrible:
438 438 idents, msg = self.feed_identities(msg, copy)
439 439 try:
440 440 return idents, self.unpack_message(msg, content=content, copy=copy)
441 441 except Exception as e:
442 442 print (idents, msg)
443 443 # TODO: handle it
444 444 raise e
445 445
446 446 def feed_identities(self, msg, copy=True):
447 447 """feed until DELIM is reached, then return the prefix as idents and remainder as
448 448 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
449 449
450 450 Parameters
451 451 ----------
452 452 msg : a list of Message or bytes objects
453 453 the message to be split
454 454 copy : bool
455 455 flag determining whether the arguments are bytes or Messages
456 456
457 457 Returns
458 458 -------
459 459 (idents,msg) : two lists
460 460 idents will always be a list of bytes - the identity prefix
461 461 msg will be a list of bytes or Messages, unchanged from input
462 462 msg should be unpackable via self.unpack_message at this point.
463 463 """
464 464 msg = list(msg)
465 465 idents = []
466 466 while len(msg) > 3:
467 467 if copy:
468 468 s = msg[0]
469 469 else:
470 470 s = msg[0].bytes
471 471 if s == DELIM:
472 472 msg.pop(0)
473 473 break
474 474 else:
475 475 idents.append(s)
476 476 msg.pop(0)
477 477
478 478 return idents, msg
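For example, with plain bytes frames:

session = StreamSession(username='tester')
idents, msg = session.feed_identities(['client-ident', DELIM, 'header', 'parent', 'content'])
# idents == ['client-ident']; msg == ['header', 'parent', 'content']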
479 479
480 480 def unpack_message(self, msg, content=True, copy=True):
481 481 """Return a message object from the format
482 482 sent by self.send.
483 483
484 484 Parameters
485 485 ----------
486 486
487 487 content : bool (True)
488 488 whether to unpack the content dict (True),
489 489 or leave it serialized (False)
490 490
491 491 copy : bool (True)
492 492 whether to return the bytes (True),
493 493 or the non-copying Message object in each place (False)
494 494
495 495 """
496 496 ikey = int(self.key is not None)
497 497 minlen = 3 + ikey
498 498 if not len(msg) >= minlen:
499 499 raise TypeError("malformed message, must have at least %i elements"%minlen)
500 500 message = {}
501 501 if not copy:
502 502 for i in range(minlen):
503 503 msg[i] = msg[i].bytes
504 504 if ikey:
505 505 if not self.key == msg[0]:
506 506 raise KeyError("Invalid Session Key: %s"%msg[0])
507 507 message['header'] = self.unpack(msg[ikey+0])
508 508 message['msg_type'] = message['header']['msg_type']
509 509 message['parent_header'] = self.unpack(msg[ikey+1])
510 510 if content:
511 511 message['content'] = self.unpack(msg[ikey+2])
512 512 else:
513 513 message['content'] = msg[ikey+2]
514 514
515 515 # message['buffers'] = msg[3:]
516 516 # else:
517 517 # message['header'] = self.unpack(msg[0].bytes)
518 518 # message['msg_type'] = message['header']['msg_type']
519 519 # message['parent_header'] = self.unpack(msg[1].bytes)
520 520 # if content:
521 521 # message['content'] = self.unpack(msg[2].bytes)
522 522 # else:
523 523 # message['content'] = msg[2].bytes
524 524
525 525 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
526 526 return message
527 527
528 528
529 529
530 530 def test_msg2obj():
531 531 am = dict(x=1)
532 532 ao = Message(am)
533 533 assert ao.x == am['x']
534 534
535 535 am['y'] = dict(z=1)
536 536 ao = Message(am)
537 537 assert ao.y.z == am['y']['z']
538 538
539 539 k1, k2 = 'y', 'z'
540 540 assert ao[k1][k2] == am[k1][k2]
541 541
542 542 am2 = dict(ao)
543 543 assert am['x'] == am2['x']
544 544 assert am['y']['z'] == am2['y']['z']
@@ -1,353 +1,355
1 1 """Views of remote engines"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from IPython.external.decorator import decorator
14 14 from IPython.zmq.parallel.remotefunction import ParallelFunction, parallel
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Decorators
18 18 #-----------------------------------------------------------------------------
19 19
20 20 @decorator
21 21 def myblock(f, self, *args, **kwargs):
22 22 """override client.block with self.block during a call"""
23 23 block = self.client.block
24 24 self.client.block = self.block
25 25 try:
26 26 ret = f(self, *args, **kwargs)
27 27 finally:
28 28 self.client.block = block
29 29 return ret
30 30
31 31 @decorator
32 32 def save_ids(f, self, *args, **kwargs):
33 33 """Keep our history and outstanding attributes up to date after a method call."""
34 34 n_previous = len(self.client.history)
35 35 ret = f(self, *args, **kwargs)
36 36 nmsgs = len(self.client.history) - n_previous
37 37 msg_ids = self.client.history[-nmsgs:]
38 38 self.history.extend(msg_ids)
39 39 map(self.outstanding.add, msg_ids)
40 40 return ret
41 41
42 42 @decorator
43 43 def sync_results(f, self, *args, **kwargs):
44 44 """sync relevant results from self.client to our results attribute."""
45 45 ret = f(self, *args, **kwargs)
46 46 delta = self.outstanding.difference(self.client.outstanding)
47 47 completed = self.outstanding.intersection(delta)
48 48 self.outstanding = self.outstanding.difference(completed)
49 49 for msg_id in completed:
50 50 self.results[msg_id] = self.client.results[msg_id]
51 51 return ret
52 52
53 53 @decorator
54 54 def spin_after(f, self, *args, **kwargs):
55 55 """call spin after the method."""
56 56 ret = f(self, *args, **kwargs)
57 57 self.spin()
58 58 return ret
59 59
60 60 #-----------------------------------------------------------------------------
61 61 # Classes
62 62 #-----------------------------------------------------------------------------
63 63
64 64 class View(object):
65 65 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
66 66
67 67 Don't use this class, use subclasses.
68 68 """
69 69 _targets = None
70 70 block=None
71 71 bound=None
72 72 history=None
73 73
74 74 def __init__(self, client, targets=None):
75 75 self.client = client
76 76 self._targets = targets
77 77 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
78 78 self.block = client.block
79 79 self.bound=False
80 80 self.history = []
81 81 self.outstanding = set()
82 82 self.results = {}
83 83
84 84 def __repr__(self):
85 85 strtargets = str(self._targets)
86 86 if len(strtargets) > 16:
87 87 strtargets = strtargets[:12]+'...]'
88 88 return "<%s %s>"%(self.__class__.__name__, strtargets)
89 89
90 90 @property
91 91 def targets(self):
92 92 return self._targets
93 93
94 94 @targets.setter
95 95 def targets(self, value):
96 96 self._targets = value
97 97 # raise AttributeError("Cannot set my targets argument after construction!")
98 98
99 99 @sync_results
100 100 def spin(self):
101 101 """spin the client, and sync"""
102 102 self.client.spin()
103 103
104 104 @sync_results
105 105 @save_ids
106 106 def apply(self, f, *args, **kwargs):
107 107 """calls f(*args, **kwargs) on remote engines, returning the result.
108 108
109 109 This method does not involve the engine's namespace.
110 110
111 111 if self.block is False:
112 112 returns msg_id
113 113 else:
114 114 returns actual result of f(*args, **kwargs)
115 115 """
116 116 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=self.bound)
117 117
118 118 @save_ids
119 119 def apply_async(self, f, *args, **kwargs):
120 120 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
121 121
122 122 This method does not involve the engine's namespace.
123 123
124 124 returns msg_id
125 125 """
        return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=False)
127 127
128 128 @spin_after
129 129 @save_ids
130 130 def apply_sync(self, f, *args, **kwargs):
131 131 """calls f(*args, **kwargs) on remote engines in a blocking manner,
132 132 returning the result.
133 133
134 134 This method does not involve the engine's namespace.
135 135
136 136 returns: actual result of f(*args, **kwargs)
137 137 """
        return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=False)
139 139
140 140 @sync_results
141 141 @save_ids
142 142 def apply_bound(self, f, *args, **kwargs):
143 143 """calls f(*args, **kwargs) bound to engine namespace(s).
144 144
145 145 if self.block is False:
146 146 returns msg_id
147 147 else:
148 148 returns actual result of f(*args, **kwargs)
149 149
150 150 This method has access to the targets' globals
151 151
152 152 """
153 153 return self.client.apply(f, args, kwargs, block=self.block, targets=self.targets, bound=True)
154 154
155 155 @sync_results
156 156 @save_ids
157 157 def apply_async_bound(self, f, *args, **kwargs):
158 158 """calls f(*args, **kwargs) bound to engine namespace(s)
159 159 in a nonblocking manner.
160 160
161 161 returns: msg_id
162 162
163 163 This method has access to the targets' globals
164 164
165 165 """
166 166 return self.client.apply(f, args, kwargs, block=False, targets=self.targets, bound=True)
167 167
168 168 @spin_after
169 169 @save_ids
170 170 def apply_sync_bound(self, f, *args, **kwargs):
171 171 """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result.
172 172
173 173 returns: actual result of f(*args, **kwargs)
174 174
175 175 This method has access to the targets' globals
176 176
177 177 """
178 178 return self.client.apply(f, args, kwargs, block=True, targets=self.targets, bound=True)
179 179
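The six apply variants above differ only along two axes; as an orientation aid (summary added here, not in the original source):

    #                      blocks?       bound to engine namespace?
    # apply                self.block    self.bound
    # apply_async          never         no
    # apply_sync           always        no
    # apply_bound          self.block    yes
    # apply_async_bound    never         yes
    # apply_sync_bound     always        yes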
180 180 @spin_after
181 181 @save_ids
182 182 def map(self, f, *sequences):
183 183 """Parallel version of builtin `map`, using this view's engines."""
184 184 if isinstance(self.targets, int):
185 185 targets = [self.targets]
186 else:
187 targets = self.targets
186 188 pf = ParallelFunction(self.client, f, block=self.block,
187 189 bound=True, targets=targets)
188 190 return pf.map(*sequences)
189 191
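A usage sketch, assuming `v` is a blocking view over some engines; results come back in sequence order, like the builtin:

    >>> v.map(lambda x: x**2, range(8))
    [0, 1, 4, 9, 16, 25, 36, 49]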
190 192 def parallel(self, bound=True, block=True):
191 193 """Decorator for making a ParallelFunction"""
192 194 return parallel(self.client, bound=bound, targets=self.targets, block=block)
193 195
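A sketch of the decorator form, assuming the returned ParallelFunction exposes the same map used by View.map above:

    >>> @v.parallel(block=True)
    ... def double(x):
    ...     return 2 * x
    >>> double.map(range(4))
    [0, 2, 4, 6]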
194 196 def abort(self, msg_ids=None, block=None):
195 197 """Abort jobs on my engines.
196 198
197 199 Parameters
198 200 ----------
199 201
200 202 msg_ids : None, str, list of strs, optional
201 203 if None: abort all jobs.
202 204 else: abort specific msg_id(s).
203 205 """
204 206 block = block if block is not None else self.block
205 207 return self.client.abort(msg_ids=msg_ids, targets=self.targets, block=block)
206 208
207 209 def queue_status(self, verbose=False):
208 210 """Fetch the Queue status of my engines"""
209 211 return self.client.queue_status(targets=self.targets, verbose=verbose)
210 212
    def purge_results(self, msg_ids=None, targets=None):
        """Instruct the controller to forget specific results."""
        # avoid mutable default arguments; None means no explicit ids
        msg_ids = msg_ids if msg_ids is not None else []
        if targets is None or targets == 'all':
            targets = self.targets
        return self.client.purge_results(msg_ids=msg_ids, targets=targets)
216 218
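Job control in one hedged snippet (the queue_status payload shape varies by version; ids are illustrative):

    >>> v.queue_status()
    {0: {'queue': 2, 'completed': 17, 'tasks': 0}}
    >>> v.abort()                                  # abort anything still queued on my targets
    >>> v.purge_results(msg_ids=v.history[:10])    # forget the ten oldest results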
217 219
218 220
219 221 class DirectView(View):
220 222 """Direct Multiplexer View of one or more engines.
221 223
222 224 These are created via indexed access to a client:
223 225
224 226 >>> dv_1 = client[1]
225 227 >>> dv_all = client[:]
226 228 >>> dv_even = client[::2]
227 229 >>> dv_some = client[1:3]
228 230
    This object also provides dictionary-style access to the engines'
    namespaces: `view[key]` and `view[key] = value` map to pull and push.
230 232
231 233 """
232 234
233 235 @sync_results
234 236 @save_ids
    def execute(self, code, block=None):
        """execute some code on my targets, following self.block unless overridden."""
        block = block if block is not None else self.block
        return self.client.execute(code, block=block, targets=self.targets)
238 240
239 241 def update(self, ns):
240 242 """update remote namespace with dict `ns`"""
241 243 return self.client.push(ns, targets=self.targets, block=self.block)
242 244
243 245 push = update
244 246
    def get(self, key_s):
        """Get object(s) by `key_s` from the remote namespace.

        Returns a single object if `key_s` is one key, or a list of
        objects if `key_s` is a list of keys.  Always blocks; use
        `pull` for control over blocking.
        """
        return self.client.pull(key_s, block=True, targets=self.targets)
251 253
252 254 @sync_results
253 255 @save_ids
    def pull(self, key_s, block=None):
        """Get object(s) by `key_s` from the remote namespace.

        Returns a single object if `key_s` is one key, or a list of
        objects if `key_s` is a list of keys.
        """
        block = block if block is not None else self.block
        return self.client.pull(key_s, block=block, targets=self.targets)
260 262
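Round-tripping values with push/pull, sketched on a hypothetical single-engine DirectView `dv`:

    >>> dv.push({'a': 5, 'b': 10})
    >>> dv.pull('a')
    5
    >>> dv.pull(['a', 'b'])        # list of keys -> list of objects
    [5, 10]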
261 263 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None):
262 264 """
263 265 Partition a Python sequence and send the partitions to a set of engines.
264 266 """
265 267 block = block if block is not None else self.block
266 268 targets = targets if targets is not None else self.targets
267 269
268 270 return self.client.scatter(key, seq, dist=dist, flatten=flatten,
269 271 targets=targets, block=block)
270 272
271 273 @sync_results
272 274 @save_ids
273 275 def gather(self, key, dist='b', targets=None, block=None):
274 276 """
275 277 Gather a partitioned sequence on a set of engines as a single local seq.
276 278 """
277 279 block = block if block is not None else self.block
278 280 targets = targets if targets is not None else self.targets
279 281
280 282 return self.client.gather(key, dist=dist, targets=targets, block=block)
281 283
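The classic scatter/compute/gather round trip this enables, assuming `dv` spans four engines and the default block distribution 'b':

    >>> dv.scatter('chunk', range(8))             # two elements per engine
    >>> dv.execute('sq = [x**2 for x in chunk]')
    >>> dv.gather('sq')                           # reassembled in engine order
    [0, 1, 4, 9, 16, 25, 36, 49]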
282 284 def __getitem__(self, key):
283 285 return self.get(key)
284 286
    def __setitem__(self, key, value):
        self.update({key: value})
287 289
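Together, the two methods above make a DirectView quack like a dict over the remote namespace:

    >>> dv['a'] = 100        # push
    >>> dv['a']              # pull
    100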
    def clear(self, block=None):
        """Clear the remote namespaces on my engines."""
        block = block if block is not None else self.block
291 293 return self.client.clear(targets=self.targets, block=block)
292 294
    def kill(self, block=None):
        """Kill my engines."""
        block = block if block is not None else self.block
296 298 return self.client.kill(targets=self.targets, block=block)
297 299
298 300 #----------------------------------------
299 301 # activate for %px,%autopx magics
300 302 #----------------------------------------
301 303 def activate(self):
302 304 """Make this `View` active for parallel magic commands.
303 305
304 306 IPython has a magic command syntax to work with `MultiEngineClient` objects.
305 307 In a given IPython session there is a single active one. While
306 308 there can be many `Views` created and used by the user,
307 309 there is only one active one. The active `View` is used whenever
308 310 the magic commands %px and %autopx are used.
309 311
310 312 The activate() method is called on a given `View` to make it
311 313 active. Once this has been done, the magic commands can be used.
312 314 """
313 315
314 316 try:
315 317 # This is injected into __builtins__.
316 318 ip = get_ipython()
317 319 except NameError:
318 320 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
319 321 else:
320 322 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
321 323 if pmagic is not None:
322 324 pmagic.active_multiengine_client = self
323 325 else:
324 326 print "You must first load the parallelmagic extension " \
325 327 "by doing '%load_ext parallelmagic'"
326 328
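The activation dance from an interactive session, sketched (prompt output omitted; the exact %px feedback depends on the extension version):

    In [1]: %load_ext parallelmagic
    In [2]: dv = client[:]
    In [3]: dv.activate()
    In [4]: %px import os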
327 329
328 330 class LoadBalancedView(View):
329 331 """An engine-agnostic View that only executes via the Task queue.
330 332
331 333 Typically created via:
332 334
    >>> lbv = client[None]
    >>> lbv
    <LoadBalancedView tcp://127.0.0.1:12345>
335 337
336 338 but can also be created with:
337 339
338 340 >>> lbc = LoadBalancedView(client)
339 341
340 342 TODO: allow subset of engines across which to balance.
341 343 """
342 344 def __repr__(self):
343 345 return "<%s %s>"%(self.__class__.__name__, self.client._addr)
344 346
345 347 @property
346 348 def targets(self):
347 349 return None
348 350
349 351 @targets.setter
350 352 def targets(self, value):
        raise AttributeError("Cannot set targets for a LoadBalancedView!")
352 354
353 355 No newline at end of file
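An end-to-end sketch of load-balanced submission (the scheduler, not the caller, picks the engine):

    >>> lbv = client[None]
    >>> lbv.block = True
    >>> lbv.apply(lambda x: x * 2, 21)
    42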