cleanup pass
MinRK
@@ -0,0 +1,18 b''
1 """The IPython ZMQ-based parallel computing interface."""
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2011 The IPython Development Team
4 #
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
8
9 #-----------------------------------------------------------------------------
10 # Imports
11 #-----------------------------------------------------------------------------
12
13 from .asyncresult import *
14 from .client import Client
15 from .dependency import *
16 from .remotefunction import *
17 from .view import *
18
@@ -1,294 +1,305 b''
1 1 """AsyncResult objects for the client"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import time
14 14
15 15 from IPython.external.decorator import decorator
16 16 from . import error
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Classes
20 20 #-----------------------------------------------------------------------------
21 21
22 22 @decorator
23 23 def check_ready(f, self, *args, **kwargs):
24 24 """Call spin() to sync state prior to calling the method."""
25 25 self.wait(0)
26 26 if not self._ready:
27 27 raise error.TimeoutError("result not ready")
28 28 return f(self, *args, **kwargs)
29 29
30 30 class AsyncResult(object):
31 31 """Class for representing results of non-blocking calls.
32 32
33 Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
33 Provides the same interface as :py:class:`multiprocessing.pool.AsyncResult`.
34 34 """
35 35
36 36 msg_ids = None
37 37
38 38 def __init__(self, client, msg_ids, fname='unknown'):
39 39 self._client = client
40 40 if isinstance(msg_ids, basestring):
41 41 msg_ids = [msg_ids]
42 42 self.msg_ids = msg_ids
43 43 self._fname=fname
44 44 self._ready = False
45 45 self._success = None
46 46 self._single_result = len(msg_ids) == 1
47 47
48 48 def __repr__(self):
49 49 if self._ready:
50 50 return "<%s: finished>"%(self.__class__.__name__)
51 51 else:
52 52 return "<%s: %s>"%(self.__class__.__name__,self._fname)
53 53
54 54
55 55 def _reconstruct_result(self, res):
56 """
56 """Reconstruct our result from actual result list (always a list)
57
57 58 Override me in subclasses for turning a list of results
58 59 into the expected form.
59 60 """
60 61 if self._single_result:
61 62 return res[0]
62 63 else:
63 64 return res
64 65
65 66 def get(self, timeout=-1):
66 67 """Return the result when it arrives.
67 68
68 69 If `timeout` is not ``None`` and the result does not arrive within
69 70 `timeout` seconds then ``TimeoutError`` is raised. If the
70 71 remote call raised an exception then that exception will be reraised
71 by get().
72 by get() inside a `RemoteError`.
72 73 """
73 74 if not self.ready():
74 75 self.wait(timeout)
75 76
76 77 if self._ready:
77 78 if self._success:
78 79 return self._result
79 80 else:
80 81 raise self._exception
81 82 else:
82 83 raise error.TimeoutError("Result not ready.")
83 84
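# Example: basic use of AsyncResult.get()/wait()/ready(). A minimal sketch
# assuming a running cluster; the import path IPython.zmq.parallel is an
# assumption and may differ per install.

from IPython.zmq.parallel.client import Client

def double(x):
    return 2 * x

rc = Client()                               # connect using the default profile
ar = rc.apply(double, (21,), block=False)   # returns an AsyncResult
ar.wait(1)                                  # poll up to 1s; always returns None
print ar.get(timeout=5)                     # 42, or TimeoutError/RemoteError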
84 85 def ready(self):
85 86 """Return whether the call has completed."""
86 87 if not self._ready:
87 88 self.wait(0)
88 89 return self._ready
89 90
90 91 def wait(self, timeout=-1):
91 92 """Wait until the result is available or until `timeout` seconds pass.
93
94 This method always returns None.
92 95 """
93 96 if self._ready:
94 97 return
95 98 self._ready = self._client.barrier(self.msg_ids, timeout)
96 99 if self._ready:
97 100 try:
98 101 results = map(self._client.results.get, self.msg_ids)
99 102 self._result = results
100 103 if self._single_result:
101 104 r = results[0]
102 105 if isinstance(r, Exception):
103 106 raise r
104 107 else:
105 108 results = error.collect_exceptions(results, self._fname)
106 109 self._result = self._reconstruct_result(results)
107 110 except Exception, e:
108 111 self._exception = e
109 112 self._success = False
110 113 else:
111 114 self._success = True
112 115 finally:
113 116 self._metadata = map(self._client.metadata.get, self.msg_ids)
114 117
115 118
116 119 def successful(self):
117 120 """Return whether the call completed without raising an exception.
118 121
119 122 Will raise ``AssertionError`` if the result is not ready.
120 123 """
121 assert self._ready
124 assert self.ready()
122 125 return self._success
123 126
124 127 #----------------------------------------------------------------
125 128 # Extra methods not in mp.pool.AsyncResult
126 129 #----------------------------------------------------------------
127 130
128 131 def get_dict(self, timeout=-1):
129 """Get the results as a dict, keyed by engine_id."""
132 """Get the results as a dict, keyed by engine_id.
133
134 timeout behavior is described in `get()`.
135 """
136
130 137 results = self.get(timeout)
131 138 engine_ids = [ md['engine_id'] for md in self._metadata ]
132 139 bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
133 140 maxcount = bycount.count(bycount[-1])
134 141 if maxcount > 1:
135 142 raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
136 143 maxcount, bycount[-1]))
137 144
138 145 return dict(zip(engine_ids,results))
139 146
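# Example: get_dict() keys results by engine_id, so it is only valid when
# each engine ran at most one of the jobs. A sketch; `rc` is a connected
# Client as in the sketch after get() above.

import os
ar = rc.apply(os.getpid, targets='all', balanced=False, block=False)
pids = ar.get_dict()    # e.g. {0: 1234, 1: 1235, ...}, keyed by engine id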
140 147 @property
141 148 @check_ready
142 149 def result(self):
143 """result property."""
150 """result property wrapper for `get(timeout=0)`."""
144 151 return self._result
145 152
146 153 # abbreviated alias:
147 154 r = result
148 155
149 156 @property
150 157 @check_ready
151 158 def metadata(self):
152 """metadata property."""
159 """property for accessing execution metadata."""
153 160 if self._single_result:
154 161 return self._metadata[0]
155 162 else:
156 163 return self._metadata
157 164
158 165 @property
159 166 def result_dict(self):
160 167 """result property as a dict."""
161 168 return self.get_dict(0)
162 169
163 170 def __dict__(self):
164 171 return self.get_dict(0)
165 172
166 173 #-------------------------------------
167 174 # dict-access
168 175 #-------------------------------------
169 176
170 177 @check_ready
171 178 def __getitem__(self, key):
172 179 """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
173 180 """
174 181 if isinstance(key, int):
175 182 return error.collect_exceptions([self._result[key]], self._fname)[0]
176 183 elif isinstance(key, slice):
177 184 return error.collect_exceptions(self._result[key], self._fname)
178 185 elif isinstance(key, basestring):
179 186 values = [ md[key] for md in self._metadata ]
180 187 if self._single_result:
181 188 return values[0]
182 189 else:
183 190 return values
184 191 else:
185 192 raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))
186 193
187 194 @check_ready
188 195 def __getattr__(self, key):
189 """getattr maps to getitem for convenient access to metadata."""
196 """getattr maps to getitem for convenient attr access to metadata."""
190 197 if key not in self._metadata[0].keys():
191 198 raise AttributeError("%r object has no attribute %r"%(
192 199 self.__class__.__name__, key))
193 200 return self.__getitem__(key)
194 201
195 202 # asynchronous iterator:
196 203 def __iter__(self):
197 204 if self._single_result:
198 205 raise TypeError("AsyncResults with a single result are not iterable.")
199 206 try:
200 207 rlist = self.get(0)
201 208 except error.TimeoutError:
202 209 # wait for each result individually
203 210 for msg_id in self.msg_ids:
204 211 ar = AsyncResult(self._client, msg_id, self._fname)
205 212 yield ar.get()
206 213 else:
207 214 # already done
208 215 for r in rlist:
209 216 yield r
210 217
211 218
212 219
213 220 class AsyncMapResult(AsyncResult):
214 221 """Class for representing results of non-blocking gathers.
215 222
216 223 This will properly reconstruct the gather.
217 224 """
218 225
219 226 def __init__(self, client, msg_ids, mapObject, fname=''):
220 227 AsyncResult.__init__(self, client, msg_ids, fname=fname)
221 228 self._mapObject = mapObject
222 229 self._single_result = False
223 230
224 231 def _reconstruct_result(self, res):
225 232 """Perform the gather on the actual results."""
226 233 return self._mapObject.joinPartitions(res)
227 234
228 235 # asynchronous iterator:
229 236 def __iter__(self):
230 237 try:
231 238 rlist = self.get(0)
232 239 except error.TimeoutError:
233 240 # wait for each result individually
234 241 for msg_id in self.msg_ids:
235 242 ar = AsyncResult(self._client, msg_id, self._fname)
236 243 rlist = ar.get()
237 244 try:
238 245 for r in rlist:
239 246 yield r
240 247 except TypeError:
241 248 # flattened, not a list
242 249 # this could get broken by flattened data that returns iterables
243 250 # but most calls to map do not expose the `flatten` argument
244 251 yield rlist
245 252 else:
246 253 # already done
247 254 for r in rlist:
248 255 yield r
249 256
250 257
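# Example: iterating an AsyncMapResult yields results as their partitions
# arrive. A sketch assuming the view layer (view.py, not shown in this
# changeset) exposes a map() returning an AsyncMapResult.

amr = rc.view(balanced=True).map(double, range(8), block=False)
for item in amr:         # blocks per partition, not on the whole map
    print item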
251 258 class AsyncHubResult(AsyncResult):
252 """Class to wrap pending results that must be requested from the Hub"""
259 """Class to wrap pending results that must be requested from the Hub.
260
261 Note that waiting/polling on these objects requires polling the Hub over the network,
262 so use `AsyncHubResult.wait()` sparingly.
263 """
253 264
254 265 def wait(self, timeout=-1):
255 266 """wait for result to complete."""
256 267 start = time.time()
257 268 if self._ready:
258 269 return
259 270 local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
260 271 local_ready = self._client.barrier(local_ids, timeout)
261 272 if local_ready:
262 273 remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
263 274 if not remote_ids:
264 275 self._ready = True
265 276 else:
266 277 rdict = self._client.result_status(remote_ids, status_only=False)
267 278 pending = rdict['pending']
268 279 while pending and (timeout < 0 or time.time() < start+timeout):
269 280 rdict = self._client.result_status(remote_ids, status_only=False)
270 281 pending = rdict['pending']
271 282 if pending:
272 283 time.sleep(0.1)
273 284 if not pending:
274 285 self._ready = True
275 286 if self._ready:
276 287 try:
277 288 results = map(self._client.results.get, self.msg_ids)
278 289 self._result = results
279 290 if self._single_result:
280 291 r = results[0]
281 292 if isinstance(r, Exception):
282 293 raise r
283 294 else:
284 295 results = error.collect_exceptions(results, self._fname)
285 296 self._result = self._reconstruct_result(results)
286 297 except Exception, e:
287 298 self._exception = e
288 299 self._success = False
289 300 else:
290 301 self._success = True
291 302 finally:
292 303 self._metadata = map(self._client.metadata.get, self.msg_ids)
293 304
294 305 __all__ = ['AsyncResult', 'AsyncMapResult', 'AsyncHubResult'] No newline at end of file
@@ -1,1499 +1,1501 b''
1 1 """A semi-synchronous Client for the ZMQ controller"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.pickleutil import Reference
28 28 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
29 29 Dict, List, Bool, Str, Set)
30 30 from IPython.external.decorator import decorator
31 31 from IPython.external.ssh import tunnel
32 32
33 33 from . import error
34 34 from . import map as Map
35 from . import util
35 36 from . import streamsession as ss
36 37 from .asyncresult import AsyncResult, AsyncMapResult, AsyncHubResult
37 38 from .clusterdir import ClusterDir, ClusterDirError
38 39 from .dependency import Dependency, depend, require, dependent
39 from .remotefunction import remote,parallel,ParallelFunction,RemoteFunction
40 from .util import ReverseDict, disambiguate_url, validate_url
40 from .remotefunction import remote, parallel, ParallelFunction, RemoteFunction
41 from .util import ReverseDict, validate_url, disambiguate_url
41 42 from .view import DirectView, LoadBalancedView
42 43
43 44 #--------------------------------------------------------------------------
44 45 # helpers for implementing old MEC API via client.apply
45 46 #--------------------------------------------------------------------------
46 47
47 48 def _push(ns):
48 49 """helper method for implementing `client.push` via `client.apply`"""
49 50 globals().update(ns)
50 51
51 52 def _pull(keys):
52 53 """helper method for implementing `client.pull` via `client.apply`"""
53 54 g = globals()
54 55 if isinstance(keys, (list,tuple, set)):
55 56 for key in keys:
56 57 if not g.has_key(key):
57 58 raise NameError("name '%s' is not defined"%key)
58 59 return map(g.get, keys)
59 60 else:
60 61 if not g.has_key(keys):
61 62 raise NameError("name '%s' is not defined"%keys)
62 63 return g.get(keys)
63 64
64 65 def _clear():
65 66 """helper method for implementing `client.clear` via `client.apply`"""
66 67 globals().clear()
67 68
68 69 def _execute(code):
69 70 """helper method for implementing `client.execute` via `client.apply`"""
70 71 exec code in globals()
71 72
72 73
73 74 #--------------------------------------------------------------------------
74 75 # Decorators for Client methods
75 76 #--------------------------------------------------------------------------
76 77
77 78 @decorator
78 79 def spinfirst(f, self, *args, **kwargs):
79 80 """Call spin() to sync state prior to calling the method."""
80 81 self.spin()
81 82 return f(self, *args, **kwargs)
82 83
83 84 @decorator
84 85 def defaultblock(f, self, *args, **kwargs):
85 86 """Default to self.block; preserve self.block."""
86 87 block = kwargs.get('block',None)
87 88 block = self.block if block is None else block
88 89 saveblock = self.block
89 90 self.block = block
90 91 try:
91 92 ret = f(self, *args, **kwargs)
92 93 finally:
93 94 self.block = saveblock
94 95 return ret
95 96
96 97
97 98 #--------------------------------------------------------------------------
98 99 # Classes
99 100 #--------------------------------------------------------------------------
100 101
101 102 class Metadata(dict):
102 103 """Subclass of dict for initializing metadata values.
103 104
104 105 Attribute access works on keys.
105 106
106 107 These objects have a strict set of keys - errors will raise if you try
107 108 to add new keys.
108 109 """
109 110 def __init__(self, *args, **kwargs):
110 111 dict.__init__(self)
111 112 md = {'msg_id' : None,
112 113 'submitted' : None,
113 114 'started' : None,
114 115 'completed' : None,
115 116 'received' : None,
116 117 'engine_uuid' : None,
117 118 'engine_id' : None,
118 119 'follow' : None,
119 120 'after' : None,
120 121 'status' : None,
121 122
122 123 'pyin' : None,
123 124 'pyout' : None,
124 125 'pyerr' : None,
125 126 'stdout' : '',
126 127 'stderr' : '',
127 128 }
128 129 self.update(md)
129 130 self.update(dict(*args, **kwargs))
130 131
131 132 def __getattr__(self, key):
132 133 """getattr aliased to getitem"""
133 134 if key in self.iterkeys():
134 135 return self[key]
135 136 else:
136 137 raise AttributeError(key)
137 138
138 139 def __setattr__(self, key, value):
139 140 """setattr aliased to setitem, with strict"""
140 141 if key in self.iterkeys():
141 142 self[key] = value
142 143 else:
143 144 raise AttributeError(key)
144 145
145 146 def __setitem__(self, key, value):
146 147 """strict static key enforcement"""
147 148 if key in self.iterkeys():
148 149 dict.__setitem__(self, key, value)
149 150 else:
150 151 raise KeyError(key)
151 152
152 153
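# Example: Metadata is a dict with a fixed key set; attribute access
# aliases item access, and unknown keys raise. A sketch of the behavior
# defined above.

md = Metadata(status='ok')
print md.status          # 'ok'
md['stdout'] = 'hello\n' # known key: allowed
try:
    md['nonsense'] = 1   # unknown key
except KeyError:
    print "strict keys enforced"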
153 154 class Client(HasTraits):
154 155 """A semi-synchronous client to the IPython ZMQ controller
155 156
156 157 Parameters
157 158 ----------
158 159
159 160 url_or_file : bytes; zmq url or path to ipcontroller-client.json
160 161 Connection information for the Hub's registration. If a json connector
161 162 file is given, then likely no further configuration is necessary.
162 163 [Default: use profile]
163 164 profile : bytes
164 165 The name of the Cluster profile to be used to find connector information.
165 166 [Default: 'default']
166 167 context : zmq.Context
167 168 Pass an existing zmq.Context instance, otherwise the client will create its own.
168 169 username : bytes
169 170 set username to be passed to the Session object
170 171 debug : bool
171 172 flag for lots of message printing for debug purposes
172 173
173 174 #-------------- ssh related args ----------------
174 175 # These are args for configuring the ssh tunnel to be used
175 176 # credentials are used to forward connections over ssh to the Controller
176 177 # Note that the ip given in `addr` needs to be relative to sshserver
177 178 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
178 179 # and set sshserver as the same machine the Controller is on. However,
179 180 # the only requirement is that sshserver is able to see the Controller
180 181 # (i.e. is within the same trusted network).
181 182
182 183 sshserver : str
183 184 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
184 185 If keyfile or password is specified, and this is not, it will default to
185 186 the ip given in addr.
186 187 sshkey : str; path to public ssh key file
187 188 This specifies a key to be used in ssh login, default None.
188 189 Regular default ssh keys will be used without specifying this argument.
189 190 password : str
190 191 Your ssh password to sshserver. Note that if this is left None,
191 192 you will be prompted for it if passwordless key based login is unavailable.
192 193 paramiko : bool
193 194 flag for whether to use paramiko instead of shell ssh for tunneling.
194 195 [default: True on win32, False else]
195 196
196 197 #------- exec authentication args -------
197 198 # If even localhost is untrusted, you can have some protection against
198 199 # unauthorized execution by using a key. Messages are still sent
199 200 # as cleartext, so if someone can snoop your loopback traffic this will
200 201 # not help against malicious attacks.
201 202
202 203 exec_key : str
203 204 an authentication key or file containing a key
204 205 default: None
205 206
206 207
207 208 Attributes
208 209 ----------
209 210
210 211 ids : set of int engine IDs
211 212 requesting the ids attribute always synchronizes
212 213 the registration state. To request ids without synchronization,
213 214 use semi-private _ids attributes.
214 215
215 216 history : list of msg_ids
216 217 a list of msg_ids, keeping track of all the execution
217 218 messages you have submitted in order.
218 219
219 220 outstanding : set of msg_ids
220 221 a set of msg_ids that have been submitted, but whose
221 222 results have not yet been received.
222 223
223 224 results : dict
224 225 a dict of all our results, keyed by msg_id
225 226
226 227 block : bool
227 228 determines default behavior when block not specified
228 229 in execution methods
229 230
230 231 Methods
231 232 -------
232 233
233 234 spin
234 235 flushes incoming results and registration state changes
235 236 control methods spin, and requesting `ids` also ensures up to date
236 237
237 238 barrier
238 239 wait on one or more msg_ids
239 240
240 241 execution methods
241 242 apply
242 243 legacy: execute, run
243 244
244 245 query methods
245 246 queue_status, get_result, purge
246 247
247 248 control methods
248 249 abort, shutdown
249 250
250 251 """
251 252
252 253
253 254 block = Bool(False)
254 255 outstanding=Set()
255 256 results = Dict()
256 257 metadata = Dict()
257 258 history = List()
258 259 debug = Bool(False)
259 260 profile=CUnicode('default')
260 261
261 262 _ids = List()
262 263 _connected=Bool(False)
263 264 _ssh=Bool(False)
264 265 _context = Instance('zmq.Context')
265 266 _config = Dict()
266 267 _engines=Instance(ReverseDict, (), {})
267 268 _registration_socket=Instance('zmq.Socket')
268 269 _query_socket=Instance('zmq.Socket')
269 270 _control_socket=Instance('zmq.Socket')
270 271 _iopub_socket=Instance('zmq.Socket')
271 272 _notification_socket=Instance('zmq.Socket')
272 273 _mux_socket=Instance('zmq.Socket')
273 274 _task_socket=Instance('zmq.Socket')
274 275 _task_scheme=Str()
275 276 _balanced_views=Dict()
276 277 _direct_views=Dict()
277 278 _closed = False
278 279
279 280 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
280 281 context=None, username=None, debug=False, exec_key=None,
281 282 sshserver=None, sshkey=None, password=None, paramiko=None,
282 283 ):
283 284 super(Client, self).__init__(debug=debug, profile=profile)
284 285 if context is None:
285 286 context = zmq.Context()
286 287 self._context = context
287 288
288 289
289 290 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
290 291 if self._cd is not None:
291 292 if url_or_file is None:
292 293 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
293 294 assert url_or_file is not None, "I can't find enough information to connect to a controller!"\
294 295 " Please specify at least one of url_or_file or profile."
295 296
296 297 try:
297 298 validate_url(url_or_file)
298 299 except AssertionError:
299 300 if not os.path.exists(url_or_file):
300 301 if self._cd:
301 302 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
302 303 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
303 304 with open(url_or_file) as f:
304 305 cfg = json.loads(f.read())
305 306 else:
306 307 cfg = {'url':url_or_file}
307 308
308 309 # sync defaults from args, json:
309 310 if sshserver:
310 311 cfg['ssh'] = sshserver
311 312 if exec_key:
312 313 cfg['exec_key'] = exec_key
313 314 exec_key = cfg['exec_key']
314 315 sshserver=cfg['ssh']
315 316 url = cfg['url']
316 317 location = cfg.setdefault('location', None)
317 318 cfg['url'] = disambiguate_url(cfg['url'], location)
318 319 url = cfg['url']
319 320
320 321 self._config = cfg
321 322
322 323 self._ssh = bool(sshserver or sshkey or password)
323 324 if self._ssh and sshserver is None:
324 325 # default to ssh via localhost
325 326 sshserver = url.split('://')[1].split(':')[0]
326 327 if self._ssh and password is None:
327 328 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
328 329 password=False
329 330 else:
330 331 password = getpass("SSH Password for %s: "%sshserver)
331 332 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
332 333 if exec_key is not None and os.path.isfile(exec_key):
333 334 arg = 'keyfile'
334 335 else:
335 336 arg = 'key'
336 337 key_arg = {arg:exec_key}
337 338 if username is None:
338 339 self.session = ss.StreamSession(**key_arg)
339 340 else:
340 341 self.session = ss.StreamSession(username, **key_arg)
341 342 self._registration_socket = self._context.socket(zmq.XREQ)
342 343 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
343 344 if self._ssh:
344 345 tunnel.tunnel_connection(self._registration_socket, url, sshserver, **ssh_kwargs)
345 346 else:
346 347 self._registration_socket.connect(url)
347 348
348 349 self.session.debug = self.debug
349 350
350 351 self._notification_handlers = {'registration_notification' : self._register_engine,
351 352 'unregistration_notification' : self._unregister_engine,
352 353 }
353 354 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
354 355 'apply_reply' : self._handle_apply_reply}
355 356 self._connect(sshserver, ssh_kwargs)
356 357
357 358
358 359 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
359 360 if ipython_dir is None:
360 361 ipython_dir = get_ipython_dir()
361 362 if cluster_dir is not None:
362 363 try:
363 364 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
364 365 return
365 366 except ClusterDirError:
366 367 pass
367 368 elif profile is not None:
368 369 try:
369 370 self._cd = ClusterDir.find_cluster_dir_by_profile(
370 371 ipython_dir, profile)
371 372 return
372 373 except ClusterDirError:
373 374 pass
374 375 self._cd = None
375 376
376 377 @property
377 378 def ids(self):
378 379 """Always up-to-date ids property."""
379 380 self._flush_notifications()
380 381 return self._ids
381 382
382 383 def close(self):
383 384 if self._closed:
384 385 return
385 386 snames = filter(lambda n: n.endswith('socket'), dir(self))
386 387 for socket in map(lambda name: getattr(self, name), snames):
387 388 socket.close()
388 389 self._closed = True
389 390
390 391 def _update_engines(self, engines):
391 392 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
392 393 for k,v in engines.iteritems():
393 394 eid = int(k)
394 395 self._engines[eid] = bytes(v) # force not unicode
395 396 self._ids.append(eid)
396 397 self._ids = sorted(self._ids)
397 398 if sorted(self._engines.keys()) != range(len(self._engines)) and \
398 399 self._task_scheme == 'pure' and self._task_socket:
399 400 self._stop_scheduling_tasks()
400 401
401 402 def _stop_scheduling_tasks(self):
402 403 """Stop scheduling tasks because an engine has been unregistered
403 404 from a pure ZMQ scheduler.
404 405 """
405 406
406 407 self._task_socket.close()
407 408 self._task_socket = None
408 409 msg = "An engine has been unregistered, and we are using pure " +\
409 410 "ZMQ task scheduling. Task farming will be disabled."
410 411 if self.outstanding:
411 412 msg += " If you were running tasks when this happened, " +\
412 413 "some `outstanding` msg_ids may never resolve."
413 414 warnings.warn(msg, RuntimeWarning)
414 415
415 416 def _build_targets(self, targets):
416 417 """Turn valid target IDs or 'all' into two lists:
417 418 (int_ids, uuids).
418 419 """
419 420 if targets is None:
420 421 targets = self._ids
421 422 elif isinstance(targets, str):
422 423 if targets.lower() == 'all':
423 424 targets = self._ids
424 425 else:
425 426 raise TypeError("%r not valid str target, must be 'all'"%(targets))
426 427 elif isinstance(targets, int):
427 428 targets = [targets]
428 429 return [self._engines[t] for t in targets], list(targets)
429 430
430 431 def _connect(self, sshserver, ssh_kwargs):
431 432 """setup all our socket connections to the controller. This is called from
432 433 __init__."""
433 434
434 435 # Maybe allow reconnecting?
435 436 if self._connected:
436 437 return
437 438 self._connected=True
438 439
439 440 def connect_socket(s, url):
440 441 url = disambiguate_url(url, self._config['location'])
441 442 if self._ssh:
442 443 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
443 444 else:
444 445 return s.connect(url)
445 446
446 447 self.session.send(self._registration_socket, 'connection_request')
447 448 idents,msg = self.session.recv(self._registration_socket,mode=0)
448 449 if self.debug:
449 450 pprint(msg)
450 451 msg = ss.Message(msg)
451 452 content = msg.content
452 453 self._config['registration'] = dict(content)
453 454 if content.status == 'ok':
454 455 if content.mux:
455 456 self._mux_socket = self._context.socket(zmq.PAIR)
456 457 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
457 458 connect_socket(self._mux_socket, content.mux)
458 459 if content.task:
459 460 self._task_scheme, task_addr = content.task
460 461 self._task_socket = self._context.socket(zmq.PAIR)
461 462 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
462 463 connect_socket(self._task_socket, task_addr)
463 464 if content.notification:
464 465 self._notification_socket = self._context.socket(zmq.SUB)
465 466 connect_socket(self._notification_socket, content.notification)
466 467 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
467 468 if content.query:
468 469 self._query_socket = self._context.socket(zmq.PAIR)
469 470 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
470 471 connect_socket(self._query_socket, content.query)
471 472 if content.control:
472 473 self._control_socket = self._context.socket(zmq.PAIR)
473 474 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
474 475 connect_socket(self._control_socket, content.control)
475 476 if content.iopub:
476 477 self._iopub_socket = self._context.socket(zmq.SUB)
477 478 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
478 479 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
479 480 connect_socket(self._iopub_socket, content.iopub)
480 481 self._update_engines(dict(content.engines))
481 482
482 483 else:
483 484 self._connected = False
484 485 raise Exception("Failed to connect!")
485 486
486 487 #--------------------------------------------------------------------------
487 488 # handlers and callbacks for incoming messages
488 489 #--------------------------------------------------------------------------
489 490
490 491 def _unwrap_exception(self, content):
491 492 """unwrap exception, and remap engineid to int."""
492 e = ss.unwrap_exception(content)
493 e = error.unwrap_exception(content)
493 494 if e.engine_info:
494 495 e_uuid = e.engine_info['engine_uuid']
495 496 eid = self._engines[e_uuid]
496 497 e.engine_info['engine_id'] = eid
497 498 return e
498 499
499 500 def _register_engine(self, msg):
500 501 """Register a new engine, and update our connection info."""
501 502 content = msg['content']
502 503 eid = content['id']
503 504 d = {eid : content['queue']}
504 505 self._update_engines(d)
505 506
506 507 def _unregister_engine(self, msg):
507 508 """Unregister an engine that has died."""
508 509 content = msg['content']
509 510 eid = int(content['id'])
510 511 if eid in self._ids:
511 512 self._ids.remove(eid)
512 513 self._engines.pop(eid)
513 514 if self._task_socket and self._task_scheme == 'pure':
514 515 self._stop_scheduling_tasks()
515 516
516 517 def _extract_metadata(self, header, parent, content):
517 518 md = {'msg_id' : parent['msg_id'],
518 519 'received' : datetime.now(),
519 520 'engine_uuid' : header.get('engine', None),
520 521 'follow' : parent.get('follow', []),
521 522 'after' : parent.get('after', []),
522 523 'status' : content['status'],
523 524 }
524 525
525 526 if md['engine_uuid'] is not None:
526 527 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
527 528
528 529 if 'date' in parent:
529 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
530 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
530 531 if 'started' in header:
531 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
532 md['started'] = datetime.strptime(header['started'], util.ISO8601)
532 533 if 'date' in header:
533 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
534 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
534 535 return md
535 536
536 537 def _handle_execute_reply(self, msg):
537 538 """Save the reply to an execute_request into our results.
538 539
539 540 execute messages are never actually used. apply is used instead.
540 541 """
541 542
542 543 parent = msg['parent_header']
543 544 msg_id = parent['msg_id']
544 545 if msg_id not in self.outstanding:
545 546 if msg_id in self.history:
546 547 print ("got stale result: %s"%msg_id)
547 548 else:
548 549 print ("got unknown result: %s"%msg_id)
549 550 else:
550 551 self.outstanding.remove(msg_id)
551 552 self.results[msg_id] = self._unwrap_exception(msg['content'])
552 553
553 554 def _handle_apply_reply(self, msg):
554 555 """Save the reply to an apply_request into our results."""
555 556 parent = msg['parent_header']
556 557 msg_id = parent['msg_id']
557 558 if msg_id not in self.outstanding:
558 559 if msg_id in self.history:
559 560 print ("got stale result: %s"%msg_id)
560 561 print self.results[msg_id]
561 562 print msg
562 563 else:
563 564 print ("got unknown result: %s"%msg_id)
564 565 else:
565 566 self.outstanding.remove(msg_id)
566 567 content = msg['content']
567 568 header = msg['header']
568 569
569 570 # construct metadata:
570 571 md = self.metadata.setdefault(msg_id, Metadata())
571 572 md.update(self._extract_metadata(header, parent, content))
572 573 self.metadata[msg_id] = md
573 574
574 575 # construct result:
575 576 if content['status'] == 'ok':
576 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
577 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
577 578 elif content['status'] == 'aborted':
578 579 self.results[msg_id] = error.AbortedTask(msg_id)
579 580 elif content['status'] == 'resubmitted':
580 581 # TODO: handle resubmission
581 582 pass
582 583 else:
583 584 self.results[msg_id] = self._unwrap_exception(content)
584 585
585 586 def _flush_notifications(self):
586 587 """Flush notifications of engine registrations waiting
587 588 in ZMQ queue."""
588 589 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
589 590 while msg is not None:
590 591 if self.debug:
591 592 pprint(msg)
592 593 msg = msg[-1]
593 594 msg_type = msg['msg_type']
594 595 handler = self._notification_handlers.get(msg_type, None)
595 596 if handler is None:
596 597 raise Exception("Unhandled message type: %s"%msg_type)
597 598 else:
598 599 handler(msg)
599 600 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
600 601
601 602 def _flush_results(self, sock):
602 603 """Flush task or queue results waiting in ZMQ queue."""
603 604 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
604 605 while msg is not None:
605 606 if self.debug:
606 607 pprint(msg)
607 608 msg = msg[-1]
608 609 msg_type = msg['msg_type']
609 610 handler = self._queue_handlers.get(msg_type, None)
610 611 if handler is None:
611 612 raise Exception("Unhandled message type: %s"%msg_type)
612 613 else:
613 614 handler(msg)
614 615 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
615 616
616 617 def _flush_control(self, sock):
617 618 """Flush replies from the control channel waiting
618 619 in the ZMQ queue.
619 620
620 621 Currently: ignore them."""
621 622 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
622 623 while msg is not None:
623 624 if self.debug:
624 625 pprint(msg)
625 626 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
626 627
627 628 def _flush_iopub(self, sock):
628 629 """Flush replies from the iopub channel waiting
629 630 in the ZMQ queue.
630 631 """
631 632 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
632 633 while msg is not None:
633 634 if self.debug:
634 635 pprint(msg)
635 636 msg = msg[-1]
636 637 parent = msg['parent_header']
637 638 msg_id = parent['msg_id']
638 639 content = msg['content']
639 640 header = msg['header']
640 641 msg_type = msg['msg_type']
641 642
642 643 # init metadata:
643 644 md = self.metadata.setdefault(msg_id, Metadata())
644 645
645 646 if msg_type == 'stream':
646 647 name = content['name']
647 648 s = md[name] or ''
648 649 md[name] = s + content['data']
649 650 elif msg_type == 'pyerr':
650 651 md.update({'pyerr' : self._unwrap_exception(content)})
651 652 else:
652 653 md.update({msg_type : content['data']})
653 654
654 655 self.metadata[msg_id] = md
655 656
656 657 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
657 658
658 659 #--------------------------------------------------------------------------
659 660 # len, getitem
660 661 #--------------------------------------------------------------------------
661 662
662 663 def __len__(self):
663 664 """len(client) returns # of engines."""
664 665 return len(self.ids)
665 666
666 667 def __getitem__(self, key):
667 668 """index access returns DirectView multiplexer objects
668 669
669 670 Must be int, slice, or list/tuple/xrange of ints"""
670 671 if not isinstance(key, (int, slice, tuple, list, xrange)):
671 672 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
672 673 else:
673 674 return self.view(key, balanced=False)
674 675
675 676 #--------------------------------------------------------------------------
676 677 # Begin public methods
677 678 #--------------------------------------------------------------------------
678 679
679 680 def spin(self):
680 681 """Flush any registration notifications and execution results
681 682 waiting in the ZMQ queue.
682 683 """
683 684 if self._notification_socket:
684 685 self._flush_notifications()
685 686 if self._mux_socket:
686 687 self._flush_results(self._mux_socket)
687 688 if self._task_socket:
688 689 self._flush_results(self._task_socket)
689 690 if self._control_socket:
690 691 self._flush_control(self._control_socket)
691 692 if self._iopub_socket:
692 693 self._flush_iopub(self._iopub_socket)
693 694
694 695 def barrier(self, jobs=None, timeout=-1):
695 696 """waits on one or more `jobs`, for up to `timeout` seconds.
696 697
697 698 Parameters
698 699 ----------
699 700
700 701 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
701 702 ints are indices to self.history
702 703 strs are msg_ids
703 704 default: wait on all outstanding messages
704 705 timeout : float
705 706 a time in seconds, after which to give up.
706 707 default is -1, which means no timeout
707 708
708 709 Returns
709 710 -------
710 711
711 712 True : when all msg_ids are done
712 713 False : timeout reached, some msg_ids still outstanding
713 714 """
714 715 tic = time.time()
715 716 if jobs is None:
716 717 theids = self.outstanding
717 718 else:
718 719 if isinstance(jobs, (int, str, AsyncResult)):
719 720 jobs = [jobs]
720 721 theids = set()
721 722 for job in jobs:
722 723 if isinstance(job, int):
723 724 # index access
724 725 job = self.history[job]
725 726 elif isinstance(job, AsyncResult):
726 727 map(theids.add, job.msg_ids)
727 728 continue
728 729 theids.add(job)
729 730 if not theids.intersection(self.outstanding):
730 731 return True
731 732 self.spin()
732 733 while theids.intersection(self.outstanding):
733 734 if timeout >= 0 and ( time.time()-tic ) > timeout:
734 735 break
735 736 time.sleep(1e-3)
736 737 self.spin()
737 738 return len(theids.intersection(self.outstanding)) == 0
738 739
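# Example: barrier() accepts history indices, msg_ids, or AsyncResult
# objects, and returns True once all are done. A sketch; `rc` is a
# connected Client.

import time
ar = rc.apply(time.sleep, (1,), block=False)
if not rc.barrier(ar, timeout=5):
    print "still outstanding: %s" % rc.outstanding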
739 740 #--------------------------------------------------------------------------
740 741 # Control methods
741 742 #--------------------------------------------------------------------------
742 743
743 744 @spinfirst
744 745 @defaultblock
745 746 def clear(self, targets=None, block=None):
746 747 """Clear the namespace in target(s)."""
747 748 targets = self._build_targets(targets)[0]
748 749 for t in targets:
749 750 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
750 751 error = False
751 752 if self.block:
752 753 for i in range(len(targets)):
753 754 idents,msg = self.session.recv(self._control_socket,0)
754 755 if self.debug:
755 756 pprint(msg)
756 757 if msg['content']['status'] != 'ok':
757 758 error = self._unwrap_exception(msg['content'])
758 759 if error:
759 760 return error
760 761
761 762
762 763 @spinfirst
763 764 @defaultblock
764 765 def abort(self, jobs=None, targets=None, block=None):
765 766 """Abort specific jobs from the execution queues of target(s).
766 767
767 768 This is a mechanism to prevent jobs that have already been submitted
768 769 from executing.
769 770
770 771 Parameters
771 772 ----------
772 773
773 774 jobs : msg_id, list of msg_ids, or AsyncResult
774 775 The jobs to be aborted
775 776
776 777
777 778 """
778 779 targets = self._build_targets(targets)[0]
779 780 msg_ids = []
780 781 if isinstance(jobs, (basestring,AsyncResult)):
781 782 jobs = [jobs]
782 783 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
783 784 if bad_ids:
784 785 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
785 786 for j in jobs:
786 787 if isinstance(j, AsyncResult):
787 788 msg_ids.extend(j.msg_ids)
788 789 else:
789 790 msg_ids.append(j)
790 791 content = dict(msg_ids=msg_ids)
791 792 for t in targets:
792 793 self.session.send(self._control_socket, 'abort_request',
793 794 content=content, ident=t)
794 795 error = False
795 796 if self.block:
796 797 for i in range(len(targets)):
797 798 idents,msg = self.session.recv(self._control_socket,0)
798 799 if self.debug:
799 800 pprint(msg)
800 801 if msg['content']['status'] != 'ok':
801 802 error = self._unwrap_exception(msg['content'])
802 803 if error:
803 804 return error
804 805
805 806 @spinfirst
806 807 @defaultblock
807 808 def shutdown(self, targets=None, restart=False, controller=False, block=None):
808 809 """Terminates one or more engine processes, optionally including the controller."""
809 810 if controller:
810 811 targets = 'all'
811 812 targets = self._build_targets(targets)[0]
812 813 for t in targets:
813 814 self.session.send(self._control_socket, 'shutdown_request',
814 815 content={'restart':restart},ident=t)
815 816 error = False
816 817 if block or controller:
817 818 for i in range(len(targets)):
818 819 idents,msg = self.session.recv(self._control_socket,0)
819 820 if self.debug:
820 821 pprint(msg)
821 822 if msg['content']['status'] != 'ok':
822 823 error = self._unwrap_exception(msg['content'])
823 824
824 825 if controller:
825 826 time.sleep(0.25)
826 827 self.session.send(self._query_socket, 'shutdown_request')
827 828 idents,msg = self.session.recv(self._query_socket, 0)
828 829 if self.debug:
829 830 pprint(msg)
830 831 if msg['content']['status'] != 'ok':
831 832 error = self._unwrap_exception(msg['content'])
832 833
833 834 if error:
834 835 raise error
835 836
836 837 #--------------------------------------------------------------------------
837 838 # Execution methods
838 839 #--------------------------------------------------------------------------
839 840
840 841 @defaultblock
841 842 def execute(self, code, targets='all', block=None):
842 843 """Executes `code` on `targets` in blocking or nonblocking manner.
843 844
844 845 ``execute`` is always `bound` (affects engine namespace)
845 846
846 847 Parameters
847 848 ----------
848 849
849 850 code : str
850 851 the code string to be executed
851 852 targets : int/str/list of ints/strs
852 853 the engines on which to execute
853 854 default : all
854 855 block : bool
855 856 whether or not to wait until done to return
856 857 default: self.block
857 858 """
858 859 result = self.apply(_execute, (code,), targets=targets, block=block, bound=True, balanced=False)
859 860 if not block:
860 861 return result
861 862
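# Example: execute() runs source in the engines' namespaces; results come
# back by name via pull(). A sketch; `rc` is a connected Client.

rc.execute('a = 10', targets='all', block=True)
print rc.pull('a', targets=0, block=True)    # 10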
862 863 def run(self, filename, targets='all', block=None):
863 864 """Execute contents of `filename` on engine(s).
864 865
865 866 This simply reads the contents of the file and calls `execute`.
866 867
867 868 Parameters
868 869 ----------
869 870
870 871 filename : str
871 872 The path to the file
872 873 targets : int/str/list of ints/strs
873 874 the engines on which to execute
874 875 default : all
875 876 block : bool
876 877 whether or not to wait until done
877 878 default: self.block
878 879
879 880 """
880 881 with open(filename, 'rb') as f:
881 882 code = f.read()
882 883 return self.execute(code, targets=targets, block=block)
883 884
884 885 def _maybe_raise(self, result):
885 886 """wrapper for maybe raising an exception if apply failed."""
886 887 if isinstance(result, error.RemoteError):
887 888 raise result
888 889
889 890 return result
890 891
891 892 def _build_dependency(self, dep):
892 893 """helper for building jsonable dependencies from various input forms"""
893 894 if isinstance(dep, Dependency):
894 895 return dep.as_dict()
895 896 elif isinstance(dep, AsyncResult):
896 897 return dep.msg_ids
897 898 elif dep is None:
898 899 return []
899 900 else:
900 901 # pass to Dependency constructor
901 902 return list(Dependency(dep))
902 903
903 904 @defaultblock
904 905 def apply(self, f, args=None, kwargs=None, bound=True, block=None,
905 906 targets=None, balanced=None,
906 907 after=None, follow=None, timeout=None):
907 908 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
908 909
909 910 This is the central execution command for the client.
910 911
911 912 Parameters
912 913 ----------
913 914
914 915 f : function
915 916 The function to be called remotely
916 917 args : tuple/list
917 918 The positional arguments passed to `f`
918 919 kwargs : dict
919 920 The keyword arguments passed to `f`
920 921 bound : bool (default: True)
921 922 Whether to execute in the Engine(s) namespace, or in a clean
922 923 namespace not affecting the engine.
923 924 block : bool (default: self.block)
924 925 Whether to wait for the result, or return immediately.
925 926 False:
926 927 returns AsyncResult
927 928 True:
928 929 returns actual result(s) of f(*args, **kwargs)
929 930 if multiple targets:
930 931 list of results, matching `targets`
931 932 targets : int,list of ints, 'all', None
932 933 Specify the destination of the job.
933 934 if None:
934 935 Submit via Task queue for load-balancing.
935 936 if 'all':
936 937 Run on all active engines
937 938 if list:
938 939 Run on each specified engine
939 940 if int:
940 941 Run on single engine
941 942
942 943 balanced : bool, default None
943 944 whether to load-balance. This will default to True
944 945 if targets is unspecified, or False if targets is specified.
945 946
946 947 The following arguments are only used when balanced is True:
947 948 after : Dependency or collection of msg_ids
948 949 Only for load-balanced execution (targets=None)
949 950 Specify a list of msg_ids as a time-based dependency.
950 951 This job will only be run *after* the dependencies
951 952 have been met.
952 953
953 954 follow : Dependency or collection of msg_ids
954 955 Only for load-balanced execution (targets=None)
955 956 Specify a list of msg_ids as a location-based dependency.
956 957 This job will only be run on an engine where this dependency
957 958 is met.
958 959
959 960 timeout : float/int or None
960 961 Only for load-balanced execution (targets=None)
961 962 Specify an amount of time (in seconds) for the scheduler to
962 963 wait for dependencies to be met before failing with a
963 964 DependencyTimeout.
964 965
965 966 after,follow,timeout only used if `balanced=True`.
966 967
967 968 Returns
968 969 -------
969 970
970 971 if block is False:
971 972 return AsyncResult wrapping msg_ids
972 973 output of AsyncResult.get() is identical to that of `apply(...block=True)`
973 974 else:
974 975 if single target:
975 976 return result of `f(*args, **kwargs)`
976 977 else:
977 978 return list of results, matching `targets`
978 979 """
979 980 assert not self._closed, "cannot use me anymore, I'm closed!"
980 981 # defaults:
981 982 block = block if block is not None else self.block
982 983 args = args if args is not None else []
983 984 kwargs = kwargs if kwargs is not None else {}
984 985
985 986 if balanced is None:
986 987 if targets is None:
987 988 # default to balanced if targets unspecified
988 989 balanced = True
989 990 else:
990 991 # otherwise default to multiplexing
991 992 balanced = False
992 993
993 994 if targets is None and balanced is False:
994 995 # default to all if *not* balanced, and targets is unspecified
995 996 targets = 'all'
996 997
997 998 # enforce types of f, args, kwargs
998 999 if not callable(f):
999 1000 raise TypeError("f must be callable, not %s"%type(f))
1000 1001 if not isinstance(args, (tuple, list)):
1001 1002 raise TypeError("args must be tuple or list, not %s"%type(args))
1002 1003 if not isinstance(kwargs, dict):
1003 1004 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1004 1005
1005 1006 options = dict(bound=bound, block=block, targets=targets)
1006 1007
1007 1008 if balanced:
1008 1009 return self._apply_balanced(f, args, kwargs, timeout=timeout,
1009 1010 after=after, follow=follow, **options)
1010 1011 elif follow or after or timeout:
1011 1012 msg = "follow, after, and timeout args are only used for"
1012 1013 msg += " load-balanced execution."
1013 1014 raise ValueError(msg)
1014 1015 else:
1015 1016 return self._apply_direct(f, args, kwargs, **options)
1016 1017
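# Example: the three main apply() modes. A sketch; `rc` and `double` are
# from the first sketch above.

r1 = rc.apply(double, (3,), block=True)                  # load-balanced (targets=None)
r2 = rc.apply(double, (3,), targets='all', block=True)   # one result per engine
ar = rc.apply(double, (3,), targets=0, block=False)      # single engine, async
print ar.get()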
1017 1018 def _apply_balanced(self, f, args, kwargs, bound=None, block=None, targets=None,
1018 1019 after=None, follow=None, timeout=None):
1019 1020 """call f(*args, **kwargs) remotely in a load-balanced manner.
1020 1021
1021 1022 This is a private method, see `apply` for details.
1022 1023 Not to be called directly!
1023 1024 """
1024 1025
1025 1026 loc = locals()
1026 1027 for name in ('bound', 'block'):
1027 1028 assert loc[name] is not None, "kwarg %r must be specified!"%name
1028 1029
1029 1030 if self._task_socket is None:
1030 1031 msg = "Task farming is disabled"
1031 1032 if self._task_scheme == 'pure':
1032 1033 msg += " because the pure ZMQ scheduler cannot handle"
1033 1034 msg += " disappearing engines."
1034 1035 raise RuntimeError(msg)
1035 1036
1036 1037 if self._task_scheme == 'pure':
1037 1038 # pure zmq scheme doesn't support dependencies
1038 1039 msg = "Pure ZMQ scheduler doesn't support dependencies"
1039 1040 if (follow or after):
1040 1041 # hard fail on DAG dependencies
1041 1042 raise RuntimeError(msg)
1042 1043 if isinstance(f, dependent):
1043 1044 # soft warn on functional dependencies
1044 1045 warnings.warn(msg, RuntimeWarning)
1045 1046
1046 1047 # defaults:
1047 1048 args = args if args is not None else []
1048 1049 kwargs = kwargs if kwargs is not None else {}
1049 1050
1050 1051 if targets:
1051 1052 idents,_ = self._build_targets(targets)
1052 1053 else:
1053 1054 idents = []
1054 1055
1055 1056 after = self._build_dependency(after)
1056 1057 follow = self._build_dependency(follow)
1057 1058 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents)
1058 bufs = ss.pack_apply_message(f,args,kwargs)
1059 bufs = util.pack_apply_message(f,args,kwargs)
1059 1060 content = dict(bound=bound)
1060 1061
1061 1062 msg = self.session.send(self._task_socket, "apply_request",
1062 1063 content=content, buffers=bufs, subheader=subheader)
1063 1064 msg_id = msg['msg_id']
1064 1065 self.outstanding.add(msg_id)
1065 1066 self.history.append(msg_id)
1066 1067 ar = AsyncResult(self, [msg_id], fname=f.__name__)
1067 1068 if block:
1068 1069 try:
1069 1070 return ar.get()
1070 1071 except KeyboardInterrupt:
1071 1072 return ar
1072 1073 else:
1073 1074 return ar
1074 1075
1075 1076 def _apply_direct(self, f, args, kwargs, bound=None, block=None, targets=None):
1076 1077 """Then underlying method for applying functions to specific engines
1077 1078 via the MUX queue.
1078 1079
1079 1080 This is a private method, see `apply` for details.
1080 1081 Not to be called directly!
1081 1082 """
1082 1083 loc = locals()
1083 1084 for name in ('bound', 'block', 'targets'):
1084 1085 assert loc[name] is not None, "kwarg %r must be specified!"%name
1085 1086
1086 1087 idents,targets = self._build_targets(targets)
1087 1088
1088 1089 subheader = {}
1089 1090 content = dict(bound=bound)
1090 bufs = ss.pack_apply_message(f,args,kwargs)
1091 bufs = util.pack_apply_message(f,args,kwargs)
1091 1092
1092 1093 msg_ids = []
1093 1094 for ident in idents:
1094 1095 msg = self.session.send(self._mux_socket, "apply_request",
1095 1096 content=content, buffers=bufs, ident=ident, subheader=subheader)
1096 1097 msg_id = msg['msg_id']
1097 1098 self.outstanding.add(msg_id)
1098 1099 self.history.append(msg_id)
1099 1100 msg_ids.append(msg_id)
1100 1101 ar = AsyncResult(self, msg_ids, fname=f.__name__)
1101 1102 if block:
1102 1103 try:
1103 1104 return ar.get()
1104 1105 except KeyboardInterrupt:
1105 1106 return ar
1106 1107 else:
1107 1108 return ar
1108 1109
1109 1110 #--------------------------------------------------------------------------
1110 1111 # construct a View object
1111 1112 #--------------------------------------------------------------------------
1112 1113
1113 1114 @defaultblock
1114 1115 def remote(self, bound=True, block=None, targets=None, balanced=None):
1115 1116 """Decorator for making a RemoteFunction"""
1116 1117 return remote(self, bound=bound, targets=targets, block=block, balanced=balanced)
1117 1118
1118 1119 @defaultblock
1119 1120 def parallel(self, dist='b', bound=True, block=None, targets=None, balanced=None):
1120 1121 """Decorator for making a ParallelFunction"""
1121 1122 return parallel(self, bound=bound, targets=targets, block=block, balanced=balanced)
1122 1123
1123 1124 def _cache_view(self, targets, balanced):
1124 1125 """save views, so subsequent requests don't create new objects."""
1125 1126 if balanced:
1126 1127 view_class = LoadBalancedView
1127 1128 view_cache = self._balanced_views
1128 1129 else:
1129 1130 view_class = DirectView
1130 1131 view_cache = self._direct_views
1131 1132
1132 1133 # use str, since often targets will be a list
1133 1134 key = str(targets)
1134 1135 if key not in view_cache:
1135 1136 view_cache[key] = view_class(client=self, targets=targets)
1136 1137
1137 1138 return view_cache[key]
1138 1139
1139 1140 def view(self, targets=None, balanced=None):
1140 1141 """Method for constructing View objects.
1141 1142
1142 1143 If no arguments are specified, create a LoadBalancedView
1143 1144 using all engines. If only `targets` specified, it will
1144 1145 be a DirectView. This method is the underlying implementation
1145 1146 of ``client.__getitem__``.
1146 1147
1147 1148 Parameters
1148 1149 ----------
1149 1150
1150 1151 targets: list,slice,int,etc. [default: use all engines]
1151 1152 The engines to use for the View
1152 1153 balanced : bool [default: False if targets specified, True else]
1153 1154 whether to build a LoadBalancedView or a DirectView
1154 1155
1155 1156 """
1156 1157
1157 1158 balanced = (targets is None) if balanced is None else balanced
1158 1159
1159 1160 if targets is None:
1160 1161 if balanced:
1161 1162 return self._cache_view(None,True)
1162 1163 else:
1163 1164 targets = slice(None)
1164 1165
1165 1166 if isinstance(targets, int):
1166 1167 if targets < 0:
1167 1168 targets = self.ids[targets]
1168 1169 if targets not in self.ids:
1169 1170 raise IndexError("No such engine: %i"%targets)
1170 1171 return self._cache_view(targets, balanced)
1171 1172
1172 1173 if isinstance(targets, slice):
1173 1174 indices = range(len(self.ids))[targets]
1174 1175 ids = sorted(self._ids)
1175 1176 targets = [ ids[i] for i in indices ]
1176 1177
1177 1178 if isinstance(targets, (tuple, list, xrange)):
1178 1179 _,targets = self._build_targets(list(targets))
1179 1180 return self._cache_view(targets, balanced)
1180 1181 else:
1181 1182 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
1182 1183
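# Example: view() underlies client indexing, and views are cached so
# repeated requests return the same object. A sketch; `rc` is a
# connected Client.

dv_all = rc[:]              # DirectView on all engines
dv0 = rc.view(0)            # DirectView on engine 0
lbv = rc.view()             # LoadBalancedView on all engines
assert rc.view(0) is dv0    # cached: same object both times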
1183 1184 #--------------------------------------------------------------------------
1184 1185 # Data movement
1185 1186 #--------------------------------------------------------------------------
1186 1187
1187 1188 @defaultblock
1188 1189 def push(self, ns, targets='all', block=None):
1189 1190 """Push the contents of `ns` into the namespace on `target`"""
1190 1191 if not isinstance(ns, dict):
1191 1192 raise TypeError("Must be a dict, not %s"%type(ns))
1192 1193 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True, balanced=False)
1193 1194 if not block:
1194 1195 return result
1195 1196
1196 1197 @defaultblock
1197 1198 def pull(self, keys, targets='all', block=None):
1198 1199 """Pull objects from `target`'s namespace by `keys`"""
1199 1200 if isinstance(keys, str):
1200 1201 pass
1201 1202 elif isinstance(keys, (list,tuple,set)):
1202 1203 for key in keys:
1203 1204 if not isinstance(key, str):
1203 1204 raise TypeError("keys must be str, not %s"%type(key))
1205 1206 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1206 1207 return result
1207 1208
1208 1209 @defaultblock
1209 1210 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
1210 1211 """
1211 1212 Partition a Python sequence and send the partitions to a set of engines.
1212 1213 """
1213 1214 targets = self._build_targets(targets)[-1]
1214 1215 mapObject = Map.dists[dist]()
1215 1216 nparts = len(targets)
1216 1217 msg_ids = []
1217 1218 for index, engineid in enumerate(targets):
1218 1219 partition = mapObject.getPartition(seq, index, nparts)
1219 1220 if flatten and len(partition) == 1:
1220 1221 r = self.push({key: partition[0]}, targets=engineid, block=False)
1221 1222 else:
1222 1223 r = self.push({key: partition}, targets=engineid, block=False)
1223 1224 msg_ids.extend(r.msg_ids)
1224 1225 r = AsyncResult(self, msg_ids, fname='scatter')
1225 1226 if block:
1226 1227 r.get()
1227 1228 else:
1228 1229 return r
1229 1230
1230 1231 @defaultblock
1231 1232 def gather(self, key, dist='b', targets='all', block=None):
1232 1233 """
1233 1234 Gather a partitioned sequence on a set of engines as a single local seq.
1234 1235 """
1235 1236
1236 1237 targets = self._build_targets(targets)[-1]
1237 1238 mapObject = Map.dists[dist]()
1238 1239 msg_ids = []
1239 1240 for index, engineid in enumerate(targets):
1240 1241 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
1241 1242
1242 1243 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
1243 1244 if block:
1244 1245 return r.get()
1245 1246 else:
1246 1247 return r
1247 1248
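
For reference, a self-contained sketch of the ``'b'`` (block) `dist` used by scatter/gather above; `get_partition` is an inlined equivalent of ``Map.dists['b']().getPartition``, not the actual implementation::

    def get_partition(seq, index, nparts):
        # contiguous blocks; the first (len % nparts) blocks get one extra item
        base, extra = divmod(len(seq), nparts)
        sizes = [base + (1 if i < extra else 0) for i in range(nparts)]
        start = sum(sizes[:index])
        return seq[start:start + sizes[index]]

    parts = [get_partition(range(10), i, 3) for i in range(3)]
    print(parts)        # [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]

    gathered = []
    for p in parts:     # gather concatenates partitions in engine order
        gathered.extend(p)
    print(gathered)     # [0, 1, ..., 9]
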
1248 1249 #--------------------------------------------------------------------------
1249 1250 # Query methods
1250 1251 #--------------------------------------------------------------------------
1251 1252
1252 1253 @spinfirst
1253 1254 @defaultblock
1254 1255 def get_result(self, indices_or_msg_ids=None, block=None):
1255 1256 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1256 1257
1257 1258 If the client already has the results, no request to the Hub will be made.
1258 1259
1259 1260 This is a convenient way to construct AsyncResult objects, which are wrappers
1260 1261 that include metadata about execution, and allow for awaiting results that
1261 1262 were not submitted by this Client.
1262 1263
1263 1264 It can also be a convenient way to retrieve the metadata associated with
1264 1265 blocking execution, since it always retrieves the full result, including its metadata.
1265 1266
1266 1267 Examples
1267 1268 --------
1268 1269 ::
1269 1270
1270 1271 In [10]: ar = client.get_result(-1)  # AsyncResult for the most recent task
1271 1272
1272 1273 Parameters
1273 1274 ----------
1274 1275
1275 1276 indices_or_msg_ids : integer history index, str msg_id, or list of either
1276 1277 The history indices or msg_ids of the results to be retrieved
1277 1278
1278 1279 block : bool
1279 1280 Whether to wait for the result to be done
1280 1281
1281 1282 Returns
1282 1283 -------
1283 1284
1284 1285 AsyncResult
1285 1286 A single AsyncResult object will always be returned.
1286 1287
1287 1288 AsyncHubResult
1288 1289 A subclass of AsyncResult that retrieves results from the Hub
1289 1290
1290 1291 """
1291 1292 if indices_or_msg_ids is None:
1292 1293 indices_or_msg_ids = -1
1293 1294
1294 1295 if not isinstance(indices_or_msg_ids, (list,tuple)):
1295 1296 indices_or_msg_ids = [indices_or_msg_ids]
1296 1297
1297 1298 theids = []
1298 1299 for id in indices_or_msg_ids:
1299 1300 if isinstance(id, int):
1300 1301 id = self.history[id]
1301 1302 if not isinstance(id, str):
1302 1303 raise TypeError("indices must be str or int, not %r"%id)
1303 1304 theids.append(id)
1304 1305
1305 1306 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1306 1307 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1307 1308
1308 1309 if remote_ids:
1309 1310 ar = AsyncHubResult(self, msg_ids=theids)
1310 1311 else:
1311 1312 ar = AsyncResult(self, msg_ids=theids)
1312 1313
1313 1314 if block:
1314 1315 ar.wait()
1315 1316
1316 1317 return ar
1317 1318
1318 1319 @spinfirst
1319 1320 def result_status(self, msg_ids, status_only=True):
1320 1321 """Check on the status of the result(s) of the apply request with `msg_ids`.
1321 1322
1322 1323 If status_only is False, then the actual results will be retrieved, else
1323 1324 only the status of the results will be checked.
1324 1325
1325 1326 Parameters
1326 1327 ----------
1327 1328
1328 1329 msg_ids : list of msg_ids
1329 1330 if int:
1330 1331 Passed as index to self.history for convenience.
1331 1332 status_only : bool (default: True)
1332 1333 if False:
1333 1334 Retrieve the actual results of completed tasks.
1334 1335
1335 1336 Returns
1336 1337 -------
1337 1338
1338 1339 results : dict
1339 1340 There will always be the keys 'pending' and 'completed', which will
1340 1341 be lists of msg_ids that are incomplete or complete. If `status_only`
1341 1342 is False, then completed results will be keyed by their `msg_id`.
1342 1343 """
1343 1344 if not isinstance(msg_ids, (list,tuple)):
1344 1345 msg_ids = [msg_ids]
1345 1346
1346 1347 theids = []
1347 1348 for msg_id in msg_ids:
1348 1349 if isinstance(msg_id, int):
1349 1350 msg_id = self.history[msg_id]
1350 1351 if not isinstance(msg_id, basestring):
1351 1352 raise TypeError("msg_ids must be str, not %r"%msg_id)
1352 1353 theids.append(msg_id)
1353 1354
1354 1355 completed = []
1355 1356 local_results = {}
1356 1357
1357 1358 # comment this block out to temporarily disable local shortcut:
1358 1359 for msg_id in list(theids): # iterate over a copy; theids is mutated below
1359 1360 if msg_id in self.results:
1360 1361 completed.append(msg_id)
1361 1362 local_results[msg_id] = self.results[msg_id]
1362 1363 theids.remove(msg_id)
1363 1364
1364 1365 if theids: # some not locally cached
1365 1366 content = dict(msg_ids=theids, status_only=status_only)
1366 1367 msg = self.session.send(self._query_socket, "result_request", content=content)
1367 1368 zmq.select([self._query_socket], [], [])
1368 1369 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1369 1370 if self.debug:
1370 1371 pprint(msg)
1371 1372 content = msg['content']
1372 1373 if content['status'] != 'ok':
1373 1374 raise self._unwrap_exception(content)
1374 1375 buffers = msg['buffers']
1375 1376 else:
1376 1377 content = dict(completed=[],pending=[])
1377 1378
1378 1379 content['completed'].extend(completed)
1379 1380
1380 1381 if status_only:
1381 1382 return content
1382 1383
1383 1384 failures = []
1384 1385 # load cached results into result:
1385 1386 content.update(local_results)
1386 1387 # update cache with results:
1387 1388 for msg_id in sorted(theids):
1388 1389 if msg_id in content['completed']:
1389 1390 rec = content[msg_id]
1390 1391 parent = rec['header']
1391 1392 header = rec['result_header']
1392 1393 rcontent = rec['result_content']
1393 1394 iodict = rec['io']
1394 1395 if isinstance(rcontent, str):
1395 1396 rcontent = self.session.unpack(rcontent)
1396 1397
1397 1398 md = self.metadata.setdefault(msg_id, Metadata())
1398 1399 md.update(self._extract_metadata(header, parent, rcontent))
1399 1400 md.update(iodict)
1400 1401
1401 1402 if rcontent['status'] == 'ok':
1402 res,buffers = ss.unserialize_object(buffers)
1403 res,buffers = util.unserialize_object(buffers)
1403 1404 else:
1404 1405 print rcontent
1405 1406 res = self._unwrap_exception(rcontent)
1406 1407 failures.append(res)
1407 1408
1408 1409 self.results[msg_id] = res
1409 1410 content[msg_id] = res
1410 1411
1411 1412 if len(theids) == 1 and failures:
1412 1413 raise failures[0]
1413 1414
1414 1415 error.collect_exceptions(failures, "result_status")
1415 1416 return content
1416 1417
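
To illustrate the Returns contract above, a sketch of consuming a ``status_only=True`` result (the msg_ids are invented)::

    status = {'completed': ['a1b2', 'c3d4'], 'pending': ['e5f6']}
    done = len(status['completed'])
    total = done + len(status['pending'])
    print('%i/%i tasks complete' % (done, total))   # 2/3 tasks complete
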
1417 1418 @spinfirst
1418 1419 def queue_status(self, targets='all', verbose=False):
1419 1420 """Fetch the status of engine queues.
1420 1421
1421 1422 Parameters
1422 1423 ----------
1423 1424
1424 1425 targets : int/str/list of ints/strs
1425 1426 the engines whose states are to be queried.
1426 1427 default : all
1427 1428 verbose : bool
1428 1429 Whether to return lengths only, or lists of ids for each element
1429 1430 """
1430 1431 targets = self._build_targets(targets)[1]
1431 1432 content = dict(targets=targets, verbose=verbose)
1432 1433 self.session.send(self._query_socket, "queue_request", content=content)
1433 1434 idents,msg = self.session.recv(self._query_socket, 0)
1434 1435 if self.debug:
1435 1436 pprint(msg)
1436 1437 content = msg['content']
1437 1438 status = content.pop('status')
1438 1439 if status != 'ok':
1439 1440 raise self._unwrap_exception(content)
1440 return ss.rekey(content)
1441 return util.rekey(content)
1441 1442
1442 1443 @spinfirst
1443 1444 def purge_results(self, jobs=[], targets=[]):
1444 1445 """Tell the controller to forget results.
1445 1446
1446 1447 Individual results can be purged by msg_id, or the entire
1447 1448 history of specific targets can be purged.
1448 1449
1449 1450 Parameters
1450 1451 ----------
1451 1452
1452 1453 jobs : str or list of strs or AsyncResult objects
1453 1454 the msg_ids whose results should be forgotten.
1454 1455 targets : int/str/list of ints/strs
1455 1456 The targets, by uuid or int_id, whose entire history is to be purged.
1456 1457 Use `targets='all'` to scrub everything from the controller's memory.
1457 1458
1458 1459 default : None
1459 1460 """
1460 1461 if not targets and not jobs:
1461 1462 raise ValueError("Must specify at least one of `targets` and `jobs`")
1462 1463 if targets:
1463 1464 targets = self._build_targets(targets)[1]
1464 1465
1465 1466 # construct msg_ids from jobs
1466 1467 msg_ids = []
1467 1468 if isinstance(jobs, (basestring,AsyncResult)):
1468 1469 jobs = [jobs]
1469 1470 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1470 1471 if bad_ids:
1471 1472 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1472 1473 for j in jobs:
1473 1474 if isinstance(j, AsyncResult):
1474 1475 msg_ids.extend(j.msg_ids)
1475 1476 else:
1476 1477 msg_ids.append(j)
1477 1478
1478 1479 content = dict(targets=targets, msg_ids=msg_ids)
1479 1480 self.session.send(self._query_socket, "purge_request", content=content)
1480 1481 idents, msg = self.session.recv(self._query_socket, 0)
1481 1482 if self.debug:
1482 1483 pprint(msg)
1483 1484 content = msg['content']
1484 1485 if content['status'] != 'ok':
1485 1486 raise self._unwrap_exception(content)
1486 1487
1487 1488
1488 1489 __all__ = [ 'Client',
1489 1490 'depend',
1490 1491 'require',
1491 1492 'remote',
1492 1493 'parallel',
1493 1494 'RemoteFunction',
1494 1495 'ParallelFunction',
1495 1496 'DirectView',
1496 1497 'LoadBalancedView',
1497 1498 'AsyncResult',
1498 'AsyncMapResult'
1499 'AsyncMapResult',
1500 'Reference'
1499 1501 ]
@@ -1,538 +1,537 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython cluster directory
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 from __future__ import with_statement
19 19
20 20 import os
21 21 import logging
22 22 import re
23 23 import shutil
24 24 import sys
25 import warnings
26 25
27 26 from IPython.config.loader import PyFileConfigLoader
28 27 from IPython.config.configurable import Configurable
29 28 from IPython.core.application import Application, BaseAppConfigLoader
30 29 from IPython.core.crashhandler import CrashHandler
31 30 from IPython.core import release
32 31 from IPython.utils.path import (
33 32 get_ipython_package_dir,
34 33 expand_path
35 34 )
36 35 from IPython.utils.traitlets import Unicode
37 36
38 37 #-----------------------------------------------------------------------------
39 38 # Module errors
40 39 #-----------------------------------------------------------------------------
41 40
42 41 class ClusterDirError(Exception):
43 42 pass
44 43
45 44
46 45 class PIDFileError(Exception):
47 46 pass
48 47
49 48
50 49 #-----------------------------------------------------------------------------
51 50 # Class for managing cluster directories
52 51 #-----------------------------------------------------------------------------
53 52
54 53 class ClusterDir(Configurable):
55 54 """An object to manage the cluster directory and its resources.
56 55
57 56 The cluster directory is used by :command:`ipengine`,
58 57 :command:`ipcontroller` and :command:`ipcluster` to manage the
59 58 configuration, logging and security of these applications.
60 59
61 60 This object knows how to find, create and manage these directories. This
63 62 should be used by any code that wants to handle cluster directories.
63 62 """
64 63
65 64 security_dir_name = Unicode('security')
66 65 log_dir_name = Unicode('log')
67 66 pid_dir_name = Unicode('pid')
68 67 security_dir = Unicode(u'')
69 68 log_dir = Unicode(u'')
70 69 pid_dir = Unicode(u'')
71 70 location = Unicode(u'')
72 71
73 72 def __init__(self, location=u''):
74 73 super(ClusterDir, self).__init__(location=location)
75 74
76 75 def _location_changed(self, name, old, new):
77 76 if not os.path.isdir(new):
78 77 os.makedirs(new)
79 78 self.security_dir = os.path.join(new, self.security_dir_name)
80 79 self.log_dir = os.path.join(new, self.log_dir_name)
81 80 self.pid_dir = os.path.join(new, self.pid_dir_name)
82 81 self.check_dirs()
83 82
84 83 def _log_dir_changed(self, name, old, new):
85 84 self.check_log_dir()
86 85
87 86 def check_log_dir(self):
88 87 if not os.path.isdir(self.log_dir):
89 88 os.mkdir(self.log_dir)
90 89
91 90 def _security_dir_changed(self, name, old, new):
92 91 self.check_security_dir()
93 92
94 93 def check_security_dir(self):
95 94 if not os.path.isdir(self.security_dir):
96 95 os.mkdir(self.security_dir, 0700)
97 96 os.chmod(self.security_dir, 0700)
98 97
99 98 def _pid_dir_changed(self, name, old, new):
100 99 self.check_pid_dir()
101 100
102 101 def check_pid_dir(self):
103 102 if not os.path.isdir(self.pid_dir):
104 103 os.mkdir(self.pid_dir, 0700)
105 104 os.chmod(self.pid_dir, 0700)
106 105
107 106 def check_dirs(self):
108 107 self.check_security_dir()
109 108 self.check_log_dir()
110 109 self.check_pid_dir()
111 110
112 111 def load_config_file(self, filename):
113 112 """Load a config file from the top level of the cluster dir.
114 113
115 114 Parameters
116 115 ----------
117 116 filename : unicode or str
118 117 The filename only of the config file that must be located in
119 118 the top-level of the cluster directory.
120 119 """
121 120 loader = PyFileConfigLoader(filename, self.location)
122 121 return loader.load_config()
123 122
124 123 def copy_config_file(self, config_file, path=None, overwrite=False):
125 124 """Copy a default config file into the active cluster directory.
126 125
127 126 Default configuration files are kept in :mod:`IPython.config.default`.
128 127 This function moves these from that location to the working cluster
129 128 directory.
130 129 """
131 130 if path is None:
132 131 import IPython.config.default
133 132 path = IPython.config.default.__file__.split(os.path.sep)[:-1]
134 133 path = os.path.sep.join(path)
135 134 src = os.path.join(path, config_file)
136 135 dst = os.path.join(self.location, config_file)
137 136 if not os.path.isfile(dst) or overwrite:
138 137 shutil.copy(src, dst)
139 138
140 139 def copy_all_config_files(self, path=None, overwrite=False):
141 140 """Copy all config files into the active cluster directory."""
142 141 for f in [u'ipcontrollerz_config.py', u'ipenginez_config.py',
143 142 u'ipclusterz_config.py']:
144 143 self.copy_config_file(f, path=path, overwrite=overwrite)
145 144
146 145 @classmethod
147 146 def create_cluster_dir(csl, cluster_dir):
148 147 """Create a new cluster directory given a full path.
149 148
150 149 Parameters
151 150 ----------
152 151 cluster_dir : str
153 152 The full path to the cluster directory. If it does exist, it will
154 153 be used. If not, it will be created.
155 154 """
156 155 return ClusterDir(location=cluster_dir)
157 156
158 157 @classmethod
159 158 def create_cluster_dir_by_profile(cls, path, profile=u'default'):
160 159 """Create a cluster dir by profile name and path.
161 160
162 161 Parameters
163 162 ----------
164 163 path : str
165 164 The path (directory) to put the cluster directory in.
166 165 profile : str
167 166 The name of the profile. The name of the cluster directory will
168 167 be "clusterz_<profile>".
169 168 """
170 169 if not os.path.isdir(path):
171 170 raise ClusterDirError('Directory not found: %s' % path)
172 171 cluster_dir = os.path.join(path, u'clusterz_' + profile)
173 172 return ClusterDir(location=cluster_dir)
174 173
175 174 @classmethod
176 175 def find_cluster_dir_by_profile(cls, ipython_dir, profile=u'default'):
177 176 """Find an existing cluster dir by profile name, return its ClusterDir.
178 177
179 178 This searches through a sequence of paths for a cluster dir. If it
180 179 is not found, a :class:`ClusterDirError` exception will be raised.
181 180
182 181 The search path algorithm is:
183 182 1. ``os.getcwd()``
184 183 2. ``ipython_dir``
185 184 3. The directories found in the ":" separated
186 185 :env:`IPCLUSTER_DIR_PATH` environment variable.
187 186
188 187 Parameters
189 188 ----------
190 189 ipython_dir : unicode or str
191 190 The IPython directory to use.
192 191 profile : unicode or str
193 192 The name of the profile. The name of the cluster directory
194 193 will be "clusterz_<profile>".
195 194 """
196 195 dirname = u'clusterz_' + profile
197 196 cluster_dir_paths = os.environ.get('IPCLUSTER_DIR_PATH','')
198 197 if cluster_dir_paths:
199 198 cluster_dir_paths = cluster_dir_paths.split(':')
200 199 else:
201 200 cluster_dir_paths = []
202 201 paths = [os.getcwd(), ipython_dir] + cluster_dir_paths
203 202 for p in paths:
204 203 cluster_dir = os.path.join(p, dirname)
205 204 if os.path.isdir(cluster_dir):
206 205 return ClusterDir(location=cluster_dir)
207 206 else:
208 207 raise ClusterDirError('Cluster directory not found in paths: %s' % dirname)
209 208
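
A standalone sketch of the search-path construction documented above (the ipython_dir value here is illustrative)::

    import os

    profile = 'default'
    dirname = u'clusterz_' + profile
    env = os.environ.get('IPCLUSTER_DIR_PATH', '')
    extra = env.split(':') if env else []
    paths = [os.getcwd(), os.path.expanduser('~/.ipython')] + extra
    print([os.path.join(p, dirname) for p in paths])
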
210 209 @classmethod
211 210 def find_cluster_dir(cls, cluster_dir):
212 211 """Find/create a cluster dir and return its ClusterDir.
213 212
214 213 This will create the cluster directory if it doesn't exist.
215 214
216 215 Parameters
217 216 ----------
218 217 cluster_dir : unicode or str
219 218 The path of the cluster directory. This is expanded using
220 219 :func:`IPython.utils.genutils.expand_path`.
221 220 """
222 221 cluster_dir = expand_path(cluster_dir)
223 222 if not os.path.isdir(cluster_dir):
224 223 raise ClusterDirError('Cluster directory not found: %s' % cluster_dir)
225 224 return ClusterDir(location=cluster_dir)
226 225
227 226
228 227 #-----------------------------------------------------------------------------
229 228 # Command line options
230 229 #-----------------------------------------------------------------------------
231 230
232 231 class ClusterDirConfigLoader(BaseAppConfigLoader):
233 232
234 233 def _add_cluster_profile(self, parser):
235 234 paa = parser.add_argument
236 235 paa('-p', '--profile',
237 236 dest='Global.profile',type=unicode,
238 237 help=
239 238 """The string name of the profile to be used. This determines the name
240 239 of the cluster dir as: cluster_<profile>. The default profile is named
241 240 'default'. The cluster directory is resolved this way if the
242 241 --cluster-dir option is not used.""",
243 242 metavar='Global.profile')
244 243
245 244 def _add_cluster_dir(self, parser):
246 245 paa = parser.add_argument
247 246 paa('--cluster-dir',
248 247 dest='Global.cluster_dir',type=unicode,
249 248 help="""Set the cluster dir. This overrides the logic used by the
250 249 --profile option.""",
251 250 metavar='Global.cluster_dir')
252 251
253 252 def _add_work_dir(self, parser):
254 253 paa = parser.add_argument
255 254 paa('--work-dir',
256 255 dest='Global.work_dir',type=unicode,
257 256 help='Set the working dir for the process.',
258 257 metavar='Global.work_dir')
259 258
260 259 def _add_clean_logs(self, parser):
261 260 paa = parser.add_argument
262 261 paa('--clean-logs',
263 262 dest='Global.clean_logs', action='store_true',
264 263 help='Delete old log files before starting.')
265 264
266 265 def _add_no_clean_logs(self, parser):
267 266 paa = parser.add_argument
268 267 paa('--no-clean-logs',
269 268 dest='Global.clean_logs', action='store_false',
270 269 help="Don't Delete old log flies before starting.")
271 270
272 271 def _add_arguments(self):
273 272 super(ClusterDirConfigLoader, self)._add_arguments()
274 273 self._add_cluster_profile(self.parser)
275 274 self._add_cluster_dir(self.parser)
276 275 self._add_work_dir(self.parser)
277 276 self._add_clean_logs(self.parser)
278 277 self._add_no_clean_logs(self.parser)
279 278
280 279
281 280 #-----------------------------------------------------------------------------
282 281 # Crash handler for this application
283 282 #-----------------------------------------------------------------------------
284 283
285 284
286 285 _message_template = """\
287 286 Oops, $self.app_name crashed. We do our best to make it stable, but...
288 287
289 288 A crash report was automatically generated with the following information:
290 289 - A verbatim copy of the crash traceback.
291 290 - Data on your current $self.app_name configuration.
292 291
293 292 It was left in the file named:
294 293 \t'$self.crash_report_fname'
295 294 If you can email this file to the developers, the information in it will help
296 295 them in understanding and correcting the problem.
297 296
298 297 You can mail it to: $self.contact_name at $self.contact_email
299 298 with the subject '$self.app_name Crash Report'.
300 299
301 300 If you want to do it now, the following command will work (under Unix):
302 301 mail -s '$self.app_name Crash Report' $self.contact_email < $self.crash_report_fname
303 302
304 303 To ensure accurate tracking of this issue, please file a report about it at:
305 304 $self.bug_tracker
306 305 """
307 306
308 307 class ClusterDirCrashHandler(CrashHandler):
309 308 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
310 309
311 310 message_template = _message_template
312 311
313 312 def __init__(self, app):
314 313 contact_name = release.authors['Brian'][0]
315 314 contact_email = release.authors['Brian'][1]
316 315 bug_tracker = 'http://github.com/ipython/ipython/issues'
317 316 super(ClusterDirCrashHandler,self).__init__(
318 317 app, contact_name, contact_email, bug_tracker
319 318 )
320 319
321 320
322 321 #-----------------------------------------------------------------------------
323 322 # Main application
324 323 #-----------------------------------------------------------------------------
325 324
326 325 class ApplicationWithClusterDir(Application):
327 326 """An application that puts everything into a cluster directory.
328 327
329 328 Instead of looking for things in the ipython_dir, this type of application
330 329 will use its own private directory called the "cluster directory"
331 330 for things like config files, log files, etc.
332 331
333 332 The cluster directory is resolved as follows:
334 333
335 334 * If the ``--cluster-dir`` option is given, it is used.
336 335 * If ``--cluster-dir`` is not given, the application directory is
337 336 resolved using the profile name as ``cluster_<profile>``. The search
338 337 path for this directory is then i) cwd if it is found there
339 338 and ii) in ipython_dir otherwise.
340 339
341 340 The config file for the application is to be put in the cluster
342 341 dir and named the value of the ``config_file_name`` class attribute.
343 342 """
344 343
345 344 command_line_loader = ClusterDirConfigLoader
346 345 crash_handler_class = ClusterDirCrashHandler
347 346 auto_create_cluster_dir = True
348 347 # temporarily override default_log_level to INFO
349 348 default_log_level = logging.INFO
350 349
351 350 def create_default_config(self):
352 351 super(ApplicationWithClusterDir, self).create_default_config()
353 352 self.default_config.Global.profile = u'default'
354 353 self.default_config.Global.cluster_dir = u''
355 354 self.default_config.Global.work_dir = os.getcwd()
356 355 self.default_config.Global.log_to_file = False
357 356 self.default_config.Global.log_url = None
358 357 self.default_config.Global.clean_logs = False
359 358
360 359 def find_resources(self):
361 360 """This resolves the cluster directory.
362 361
363 362 This tries to find the cluster directory and if successful, it will
364 363 have done:
365 364 * Sets ``self.cluster_dir_obj`` to the :class:`ClusterDir` object for
366 365 the application.
367 366 * Sets ``self.cluster_dir`` attribute of the application and config
368 367 objects.
369 368
370 369 The algorithm used for this is as follows:
371 370 1. Try ``Global.cluster_dir``.
372 371 2. Try using ``Global.profile``.
373 372 3. If both of these fail and ``self.auto_create_cluster_dir`` is
374 373 ``True``, then create the new cluster dir in the IPython directory.
375 374 4. If all fails, then raise :class:`ClusterDirError`.
376 375 """
377 376
378 377 try:
379 378 cluster_dir = self.command_line_config.Global.cluster_dir
380 379 except AttributeError:
381 380 cluster_dir = self.default_config.Global.cluster_dir
382 381 cluster_dir = expand_path(cluster_dir)
383 382 try:
384 383 self.cluster_dir_obj = ClusterDir.find_cluster_dir(cluster_dir)
385 384 except ClusterDirError:
386 385 pass
387 386 else:
388 387 self.log.info('Using existing cluster dir: %s' % \
389 388 self.cluster_dir_obj.location
390 389 )
391 390 self.finish_cluster_dir()
392 391 return
393 392
394 393 try:
395 394 self.profile = self.command_line_config.Global.profile
396 395 except AttributeError:
397 396 self.profile = self.default_config.Global.profile
398 397 try:
399 398 self.cluster_dir_obj = ClusterDir.find_cluster_dir_by_profile(
400 399 self.ipython_dir, self.profile)
401 400 except ClusterDirError:
402 401 pass
403 402 else:
404 403 self.log.info('Using existing cluster dir: %s' % \
405 404 self.cluster_dir_obj.location
406 405 )
407 406 self.finish_cluster_dir()
408 407 return
409 408
410 409 if self.auto_create_cluster_dir:
411 410 self.cluster_dir_obj = ClusterDir.create_cluster_dir_by_profile(
412 411 self.ipython_dir, self.profile
413 412 )
414 413 self.log.info('Creating new cluster dir: %s' % \
415 414 self.cluster_dir_obj.location
416 415 )
417 416 self.finish_cluster_dir()
418 417 else:
419 418 raise ClusterDirError('Could not find a valid cluster directory.')
420 419
421 420 def finish_cluster_dir(self):
422 421 # Set the cluster directory
423 422 self.cluster_dir = self.cluster_dir_obj.location
424 423
425 424 # These have to be set because they could be different from the one
426 425 # that we just computed. Because command line has the highest
427 426 # priority, this will always end up in the master_config.
428 427 self.default_config.Global.cluster_dir = self.cluster_dir
429 428 self.command_line_config.Global.cluster_dir = self.cluster_dir
430 429
431 430 def find_config_file_name(self):
432 431 """Find the config file name for this application."""
433 432 # For this type of Application it should be set as a class attribute.
434 433 if not hasattr(self, 'default_config_file_name'):
435 434 self.log.critical("No config filename found")
436 435 else:
437 436 self.config_file_name = self.default_config_file_name
438 437
439 438 def find_config_file_paths(self):
440 439 # Set the search path to the cluster directory. We should NOT
441 440 # include IPython.config.default here as the default config files
442 441 # are ALWAYS automatically moved to the cluster directory.
443 442 conf_dir = os.path.join(get_ipython_package_dir(), 'config', 'default')
444 443 self.config_file_paths = (self.cluster_dir,)
445 444
446 445 def pre_construct(self):
447 446 # The log and security dirs were set earlier, but here we put them
448 447 # into the config and log them.
449 448 config = self.master_config
450 449 sdir = self.cluster_dir_obj.security_dir
451 450 self.security_dir = config.Global.security_dir = sdir
452 451 ldir = self.cluster_dir_obj.log_dir
453 452 self.log_dir = config.Global.log_dir = ldir
454 453 pdir = self.cluster_dir_obj.pid_dir
455 454 self.pid_dir = config.Global.pid_dir = pdir
456 455 self.log.info("Cluster directory set to: %s" % self.cluster_dir)
457 456 config.Global.work_dir = unicode(expand_path(config.Global.work_dir))
458 457 # Change to the working directory. We do this just before construct
459 458 # is called so all the components there have the right working dir.
460 459 self.to_work_dir()
461 460
462 461 def to_work_dir(self):
463 462 wd = self.master_config.Global.work_dir
464 463 if unicode(wd) != unicode(os.getcwd()):
465 464 os.chdir(wd)
466 465 self.log.info("Changing to working dir: %s" % wd)
467 466
468 467 def start_logging(self):
469 468 # Remove old log files
470 469 if self.master_config.Global.clean_logs:
471 470 log_dir = self.master_config.Global.log_dir
472 471 for f in os.listdir(log_dir):
473 472 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
474 473 # if f.startswith(self.name + u'-') and f.endswith('.log'):
475 474 os.remove(os.path.join(log_dir, f))
476 475 # Start logging to the new log file
477 476 if self.master_config.Global.log_to_file:
478 477 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
479 478 logfile = os.path.join(self.log_dir, log_filename)
480 479 open_log_file = open(logfile, 'w')
481 480 elif self.master_config.Global.log_url:
482 481 open_log_file = None
483 482 else:
484 483 open_log_file = sys.stdout
485 484 if open_log_file is not None:
486 485 self.log.removeHandler(self._log_handler)
487 486 self._log_handler = logging.StreamHandler(open_log_file)
488 487 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
489 488 self._log_handler.setFormatter(self._log_formatter)
490 489 self.log.addHandler(self._log_handler)
491 490 # log.startLogging(open_log_file)
492 491
493 492 def write_pid_file(self, overwrite=False):
494 493 """Create a .pid file in the pid_dir with my pid.
495 494
496 495 This must be called after pre_construct, which sets `self.pid_dir`.
497 496 This raises :exc:`PIDFileError` if the pid file exists already.
498 497 """
499 498 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
500 499 if os.path.isfile(pid_file):
501 500 pid = self.get_pid_from_file()
502 501 if not overwrite:
503 502 raise PIDFileError(
504 503 'The pid file [%s] already exists. \nThis could mean that this '
505 504 'server is already running with [pid=%s].' % (pid_file, pid)
506 505 )
507 506 with open(pid_file, 'w') as f:
508 507 self.log.info("Creating pid file: %s" % pid_file)
509 508 f.write(repr(os.getpid())+'\n')
510 509
511 510 def remove_pid_file(self):
512 511 """Remove the pid file.
513 512
514 513 This should be called at shutdown by registering a callback with
515 514 :func:`reactor.addSystemEventTrigger`. This needs to return
516 515 ``None``.
517 516 """
518 517 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
519 518 if os.path.isfile(pid_file):
520 519 try:
521 520 self.log.info("Removing pid file: %s" % pid_file)
522 521 os.remove(pid_file)
523 522 except:
524 523 self.log.warn("Error removing the pid file: %s" % pid_file)
525 524
526 525 def get_pid_from_file(self):
527 526 """Get the pid from the pid file.
528 527
529 528 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
530 529 """
531 530 pid_file = os.path.join(self.pid_dir, self.name + u'.pid')
532 531 if os.path.isfile(pid_file):
533 532 with open(pid_file, 'r') as f:
534 533 pid = int(f.read().strip())
535 534 return pid
536 535 else:
537 536 raise PIDFileError('pid file not found: %s' % pid_file)
538 537
@@ -1,115 +1,115 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller with 0MQ
3 3 This is a collection of one Hub and several Schedulers.
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15 from __future__ import print_function
16 16
17 17 import logging
18 18 from multiprocessing import Process
19 19
20 20 import zmq
21 21 from zmq.devices import ProcessMonitoredQueue
22 22 # internal:
23 23 from IPython.utils.importstring import import_item
24 from IPython.utils.traitlets import Int, Str, Instance, List, Bool
24 from IPython.utils.traitlets import Int, CStr, Instance, List, Bool
25 25
26 26 from .entry_point import signal_children
27 27 from .hub import Hub, HubFactory
28 28 from .scheduler import launch_scheduler
29 29
30 30 #-----------------------------------------------------------------------------
31 31 # Configurable
32 32 #-----------------------------------------------------------------------------
33 33
34 34
35 35 class ControllerFactory(HubFactory):
36 36 """Configurable for setting up a Hub and Schedulers."""
37 37
38 38 usethreads = Bool(False, config=True)
39 39 # pure-zmq downstream HWM
40 40 hwm = Int(0, config=True)
41 41
42 42 # internal
43 43 children = List()
44 mq_class = Str('zmq.devices.ProcessMonitoredQueue')
44 mq_class = CStr('zmq.devices.ProcessMonitoredQueue')
45 45
46 46 def _usethreads_changed(self, name, old, new):
47 47 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
48 48
49 49 def __init__(self, **kwargs):
50 50 super(ControllerFactory, self).__init__(**kwargs)
51 51 self.subconstructors.append(self.construct_schedulers)
52 52
53 53 def start(self):
54 54 super(ControllerFactory, self).start()
55 55 child_procs = []
56 56 for child in self.children:
57 57 child.start()
58 58 if isinstance(child, ProcessMonitoredQueue):
59 59 child_procs.append(child.launcher)
60 60 elif isinstance(child, Process):
61 61 child_procs.append(child)
62 62 if child_procs:
63 63 signal_children(child_procs)
64 64
65 65
66 66 def construct_schedulers(self):
67 67 children = self.children
68 68 mq = import_item(self.mq_class)
69 69
70 70 maybe_inproc = 'inproc://monitor' if self.usethreads else self.monitor_url
71 71 # IOPub relay (in a Process)
72 72 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, 'N/A','iopub')
73 73 q.bind_in(self.client_info['iopub'])
74 74 q.bind_out(self.engine_info['iopub'])
75 75 q.setsockopt_out(zmq.SUBSCRIBE, '')
76 76 q.connect_mon(maybe_inproc)
77 77 q.daemon=True
78 78 children.append(q)
79 79
80 80 # Multiplexer Queue (in a Process)
81 81 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'in', 'out')
82 82 q.bind_in(self.client_info['mux'])
83 83 q.bind_out(self.engine_info['mux'])
84 84 q.connect_mon(maybe_inproc)
85 85 q.daemon=True
86 86 children.append(q)
87 87
88 88 # Control Queue (in a Process)
89 89 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, 'incontrol', 'outcontrol')
90 90 q.bind_in(self.client_info['control'])
91 91 q.bind_out(self.engine_info['control'])
92 92 q.connect_mon(maybe_inproc)
93 93 q.daemon=True
94 94 children.append(q)
95 95 # Task Queue (in a Process)
96 96 if self.scheme == 'pure':
97 97 self.log.warn("task::using pure XREQ Task scheduler")
98 98 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, 'intask', 'outtask')
99 99 q.setsockopt_out(zmq.HWM, self.hwm)
100 100 q.bind_in(self.client_info['task'][1])
101 101 q.bind_out(self.engine_info['task'])
102 102 q.connect_mon(maybe_inproc)
103 103 q.daemon=True
104 104 children.append(q)
105 105 elif self.scheme == 'none':
106 106 self.log.warn("task::using no Task scheduler")
107 107
108 108 else:
109 109 self.log.info("task::using Python %s Task scheduler"%self.scheme)
110 110 sargs = (self.client_info['task'][1], self.engine_info['task'], self.monitor_url, self.client_info['notification'])
111 111 kwargs = dict(scheme=self.scheme,logname=self.log.name, loglevel=self.log.level, config=self.config)
112 112 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
113 113 q.daemon=True
114 114 children.append(q)
115 115
@@ -1,111 +1,153 b''
1 1 """Dependency utilities"""
2 2
3 3 from IPython.external.decorator import decorator
4 4
5 5 from .asyncresult import AsyncResult
6 6 from .error import UnmetDependency
7 7
8 8
9 9 class depend(object):
10 """Dependency decorator, for use with tasks."""
10 """Dependency decorator, for use with tasks.
11
12 `@depend` lets you define a function for engine dependencies
13 just like you use `apply` for tasks.
14
15
16 Examples
17 --------
18 ::
19
20 @depend(df, a,b, c=5)
21 def f(m,n,p):
22
23 view.apply(f, 1,2,3)
24
25 will call df(a,b,c=5) on the engine, and if it returns False or
26 raises an UnmetDependency error, then the task will not be run
27 and another engine will be tried.
28 """
11 29 def __init__(self, f, *args, **kwargs):
12 30 self.f = f
13 31 self.args = args
14 32 self.kwargs = kwargs
15 33
16 34 def __call__(self, f):
17 35 return dependent(f, self.f, *self.args, **self.kwargs)
18 36
19 37 class dependent(object):
20 38 """A function that depends on another function.
21 39 This is an object to prevent the closure used
22 40 in traditional decorators, which are not picklable.
23 41 """
24 42
25 43 def __init__(self, f, df, *dargs, **dkwargs):
26 44 self.f = f
27 45 self.func_name = getattr(f, '__name__', 'f')
28 46 self.df = df
29 47 self.dargs = dargs
30 48 self.dkwargs = dkwargs
31 49
32 50 def __call__(self, *args, **kwargs):
33 51 if self.df(*self.dargs, **self.dkwargs) is False:
34 52 raise UnmetDependency()
35 53 return self.f(*args, **kwargs)
36 54
37 55 @property
38 56 def __name__(self):
39 57 return self.func_name
40 58
41 59 def _require(*names):
60 """Helper for @require decorator."""
42 61 for name in names:
43 62 try:
44 63 __import__(name)
45 64 except ImportError:
46 65 return False
47 66 return True
48 67
49 68 def require(*names):
69 """Simple decorator for requiring names to be importable.
70
71 Examples
72 --------
73
74 In [1]: @require('numpy')
75 ...: def norm(a):
76 ...: import numpy
77 ...: return numpy.linalg.norm(a,2)
78 """
50 79 return depend(_require, *names)
51 80
52 81 class Dependency(set):
53 82 """An object for representing a set of msg_id dependencies.
54 83
55 Subclassed from set()."""
84 Subclassed from set().
85
86 Parameters
87 ----------
88 dependencies: list/set of msg_ids or AsyncResult objects or output of Dependency.as_dict()
89 The msg_ids to depend on
90 all : bool [default True]
91 Whether the dependency should be considered met when *all* depending tasks have completed
92 or only when *any* have been completed.
93 success_only : bool [default True]
94 Whether to consider only successes for Dependencies, or consider failures as well.
95 If `all=success_only=True`, then this task will fail with an ImpossibleDependency
96 as soon as the first depended-upon task fails.
97 """
56 98
57 99 all=True
58 100 success_only=True
59 101
60 102 def __init__(self, dependencies=[], all=True, success_only=True):
61 103 if isinstance(dependencies, dict):
62 104 # load from dict
63 105 all = dependencies.get('all', True)
64 106 success_only = dependencies.get('success_only', success_only)
65 107 dependencies = dependencies.get('dependencies', [])
66 108 ids = []
67 109 if isinstance(dependencies, AsyncResult):
68 110 ids.extend(dependencies.msg_ids)
69 111 else:
70 112 for d in dependencies:
71 113 if isinstance(d, basestring):
72 114 ids.append(d)
73 115 elif isinstance(d, AsyncResult):
74 116 ids.extend(d.msg_ids)
75 117 else:
76 118 raise TypeError("invalid dependency type: %r"%type(d))
77 119 set.__init__(self, ids)
78 120 self.all = all
79 121 self.success_only=success_only
80 122
81 123 def check(self, completed, failed=None):
82 124 if failed is not None and not self.success_only:
83 125 completed = completed.union(failed)
84 126 if len(self) == 0:
85 127 return True
86 128 if self.all:
87 129 return self.issubset(completed)
88 130 else:
89 131 return not self.isdisjoint(completed)
90 132
91 133 def unreachable(self, failed):
92 134 if len(self) == 0 or len(failed) == 0 or not self.success_only:
93 135 return False
94 136 # print self, self.success_only, self.all, failed
95 137 if self.all:
96 138 return not self.isdisjoint(failed)
97 139 else:
98 140 return self.issubset(failed)
99 141
100 142
101 143 def as_dict(self):
102 144 """Represent this dependency as a dict. For json compatibility."""
103 145 return dict(
104 146 dependencies=list(self),
105 147 all=self.all,
106 148 success_only=self.success_only,
107 149 )
108 150
109 151
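
The set arithmetic behind `check` above, shown standalone (msg_ids invented)::

    dep = set(['m1', 'm2'])
    completed = set(['m1', 'm3'])
    print(dep.issubset(completed))         # all=True  -> False ('m2' not done)
    print(not dep.isdisjoint(completed))   # all=False -> True  ('m1' is done)
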
110 152 __all__ = ['depend', 'require', 'dependent', 'Dependency']
111 153
@@ -1,152 +1,152 b''
1 1 """A Task logger that presents our DB interface,
2 2 but exists entirely in memory and implemented with dicts.
3 3
4 4 TaskRecords are dicts of the form:
5 5 {
6 6 'msg_id' : str(uuid),
7 7 'client_uuid' : str(uuid),
8 8 'engine_uuid' : str(uuid) or None,
9 9 'header' : dict(header),
10 10 'content': dict(content),
11 11 'buffers': list(buffers),
12 12 'submitted': datetime,
13 13 'started': datetime or None,
14 14 'completed': datetime or None,
15 15 'resubmitted': datetime or None,
16 16 'result_header' : dict(header) or None,
17 17 'result_content' : dict(content) or None,
18 18 'result_buffers' : list(buffers) or None,
19 19 }
20 20 With this info, many of the special categories of tasks can be defined by query:
21 21
22 22 pending: completed is None
23 23 client's outstanding: client_uuid = uuid && completed is None
24 24 MIA: arrived is None (and completed is None)
25 25 etc.
26 26
27 27 EngineRecords are dicts of the form:
28 28 {
29 29 'eid' : int(id),
30 30 'uuid': str(uuid)
31 31 }
32 32 This may be extended, but currently holds only these two fields.
33 33
34 34 We support a subset of mongodb operators:
35 35 $lt,$gt,$lte,$gte,$ne,$in,$nin,$all,$mod,$exists
36 36 """
37 37 #-----------------------------------------------------------------------------
38 38 # Copyright (C) 2010 The IPython Development Team
39 39 #
40 40 # Distributed under the terms of the BSD License. The full license is in
41 41 # the file COPYING, distributed as part of this software.
42 42 #-----------------------------------------------------------------------------
43 43
44 44
45 45 from datetime import datetime
46 46
47 47 filters = {
48 '$eq' : lambda a,b: a==b,
49 48 '$lt' : lambda a,b: a < b,
50 49 '$gt' : lambda a,b: a > b,
50 '$eq' : lambda a,b: a == b,
51 '$ne' : lambda a,b: a != b,
51 52 '$lte': lambda a,b: a <= b,
52 53 '$gte': lambda a,b: a >= b,
53 '$ne' : lambda a,b: not a==b,
54 54 '$in' : lambda a,b: a in b,
55 55 '$nin': lambda a,b: a not in b,
56 '$all' : lambda a,b: all([ a in bb for bb in b ]),
56 '$all': lambda a,b: all([ a in bb for bb in b ]),
57 57 '$mod': lambda a,b: a%b[0] == b[1],
58 58 '$exists' : lambda a,b: (b and a is not None) or (a is None and not b)
59 59 }
60 60
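
A quick standalone demo of applying the operator table above, mirroring what `CompositeFilter` below does::

    ops = {'$gte': lambda a, b: a >= b, '$lt': lambda a, b: a < b}

    def matches(value, spec):
        # spec is a dict like {'$gte': 2, '$lt': 5}
        return all(ops[op](value, ref) for op, ref in spec.items())

    print([v for v in range(8) if matches(v, {'$gte': 2, '$lt': 5})])   # [2, 3, 4]
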
61 61
62 62 class CompositeFilter(object):
63 63 """Composite filter for matching multiple properties."""
64 64
65 65 def __init__(self, dikt):
66 66 self.tests = []
67 67 self.values = []
68 68 for key, value in dikt.iteritems():
69 69 self.tests.append(filters[key])
70 70 self.values.append(value)
71 71
72 72 def __call__(self, value):
73 73 for test,check in zip(self.tests, self.values):
74 74 if not test(value, check):
75 75 return False
76 76 return True
77 77
78 78 class BaseDB(object):
79 79 """Empty Parent class so traitlets work on DB."""
80 80 pass
81 81
82 82 class DictDB(BaseDB):
83 83 """Basic in-memory dict-based object for saving Task Records.
84 84
85 85 This is the first object to present the DB interface
86 86 for logging tasks out of memory.
87 87
88 88 The interface is based on MongoDB, so adding a MongoDB
89 89 backend should be straightforward.
90 90 """
91 91 _records = None
92 92
93 93 def __init__(self, *args, **kwargs):
94 94 self._records = dict()
95 95
96 96 def _match_one(self, rec, tests):
97 97 """Check if a specific record matches tests."""
98 98 for key,test in tests.iteritems():
99 99 if not test(rec.get(key, None)):
100 100 return False
101 101 return True
102 102
103 103 def _match(self, check, id_only=True):
104 104 """Find all the matches for a check dict."""
105 105 matches = {}
106 106 tests = {}
107 107 for k,v in check.iteritems():
108 108 if isinstance(v, dict):
109 109 tests[k] = CompositeFilter(v)
110 110 else:
111 111 tests[k] = lambda o,v=v: o==v # bind v now; a plain closure would capture only the last v
112 112
113 113 for msg_id, rec in self._records.iteritems():
114 114 if self._match_one(rec, tests):
115 115 matches[msg_id] = rec
116 116 if id_only:
117 117 return matches.keys()
118 118 else:
119 119 return matches
120 120
121 121
122 122 def add_record(self, msg_id, rec):
123 123 """Add a new Task Record, by msg_id."""
124 124 if self._records.has_key(msg_id):
125 125 raise KeyError("Already have msg_id %r"%(msg_id))
126 126 self._records[msg_id] = rec
127 127
128 128 def get_record(self, msg_id):
129 129 """Get a specific Task Record, by msg_id."""
130 130 if not self._records.has_key(msg_id):
131 131 raise KeyError("No such msg_id %r"%(msg_id))
132 132 return self._records[msg_id]
133 133
134 134 def update_record(self, msg_id, rec):
135 135 """Update the data in an existing record."""
136 136 self._records[msg_id].update(rec)
137 137
138 138 def drop_matching_records(self, check):
139 139 """Remove a record from the DB."""
140 140 matches = self._match(check, id_only=True)
141 141 for m in matches:
142 142 del self._records[m]
143 143
144 144 def drop_record(self, msg_id):
145 145 """Remove a record from the DB."""
146 146 del self._records[msg_id]
147 147
148 148
149 149 def find_records(self, check, id_only=False):
150 150 """Find records matching a query dict."""
151 151 matches = self._match(check, id_only)
152 152 return matches No newline at end of file
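
A hedged usage sketch of the DictDB class above, with records trimmed to the two fields being queried::

    from datetime import datetime

    db = DictDB()
    db.add_record('m1', {'completed': None, 'engine_uuid': 'e0'})
    db.add_record('m2', {'completed': datetime.now(), 'engine_uuid': 'e0'})
    print(db.find_records({'completed': None}, id_only=True))   # ['m1']
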
@@ -1,147 +1,139 b''
1 1 #!/usr/bin/env python
2 2 """A simple engine that talks to a controller over 0MQ.
3 3 it handles registration, etc. and launches a kernel
4 connected to the Controller's queue(s).
4 connected to the Controller's Schedulers.
5 5 """
6 6 from __future__ import print_function
7 7
8 import logging
9 8 import sys
10 9 import time
11 import uuid
12 from pprint import pprint
13 10
14 11 import zmq
15 12 from zmq.eventloop import ioloop, zmqstream
16 13
17 14 # internal
18 from IPython.config.configurable import Configurable
19 15 from IPython.utils.traitlets import Instance, Str, Dict, Int, Type, CFloat
20 16 # from IPython.utils.localinterfaces import LOCALHOST
21 17
22 18 from . import heartmonitor
23 19 from .factory import RegistrationFactory
24 20 from .streamkernel import Kernel
25 21 from .streamsession import Message
26 22 from .util import disambiguate_url
27 23
28 def printer(*msg):
29 # print (self.log.handlers, file=sys.__stdout__)
30 self.log.info(str(msg))
31
32 24 class EngineFactory(RegistrationFactory):
33 25 """IPython engine"""
34 26
35 27 # configurables:
36 28 user_ns=Dict(config=True)
37 29 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True)
38 30 display_hook_factory=Type('IPython.zmq.displayhook.DisplayHook', config=True)
39 31 location=Str(config=True)
40 32 timeout=CFloat(2,config=True)
41 33
42 34 # not configurable:
43 35 id=Int(allow_none=True)
44 36 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
45 37 kernel=Instance(Kernel)
46 38
47 39
48 40 def __init__(self, **kwargs):
49 41 super(EngineFactory, self).__init__(**kwargs)
50 42 ctx = self.context
51 43
52 44 reg = ctx.socket(zmq.PAIR)
53 45 reg.setsockopt(zmq.IDENTITY, self.ident)
54 46 reg.connect(self.url)
55 47 self.registrar = zmqstream.ZMQStream(reg, self.loop)
56 48
57 49 def register(self):
58 50 """send the registration_request"""
59 51
60 52 self.log.info("registering")
61 53 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
62 54 self.registrar.on_recv(self.complete_registration)
63 55 # print (self.session.key)
64 56 self.session.send(self.registrar, "registration_request",content=content)
65 57
66 58 def complete_registration(self, msg):
67 59 # print msg
68 60 self._abort_dc.stop()
69 61 ctx = self.context
70 62 loop = self.loop
71 63 identity = self.ident
72 64 print (identity)
73 65
74 66 idents,msg = self.session.feed_identities(msg)
75 67 msg = Message(self.session.unpack_message(msg))
76 68
77 69 if msg.content.status == 'ok':
78 70 self.id = int(msg.content.id)
79 71
80 72 # create Shell Streams (MUX, Task, etc.):
81 73 queue_addr = msg.content.mux
82 74 shell_addrs = [ str(queue_addr) ]
83 75 task_addr = msg.content.task
84 76 if task_addr:
85 77 shell_addrs.append(str(task_addr))
86 78 shell_streams = []
87 79 for addr in shell_addrs:
88 80 stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
89 81 stream.setsockopt(zmq.IDENTITY, identity)
90 82 stream.connect(disambiguate_url(addr, self.location))
91 83 shell_streams.append(stream)
92 84
93 85 # control stream:
94 86 control_addr = str(msg.content.control)
95 87 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.PAIR), loop)
96 88 control_stream.setsockopt(zmq.IDENTITY, identity)
97 89 control_stream.connect(disambiguate_url(control_addr, self.location))
98 90
99 91 # create iopub stream:
100 92 iopub_addr = msg.content.iopub
101 93 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
102 94 iopub_stream.setsockopt(zmq.IDENTITY, identity)
103 95 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
104 96
105 97 # launch heartbeat
106 98 hb_addrs = msg.content.heartbeat
107 99 # print (hb_addrs)
108 100
109 101 # # Redirect input streams and set a display hook.
110 102 if self.out_stream_factory:
111 103 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
112 104 sys.stdout.topic = 'engine.%i.stdout'%self.id
113 105 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
114 106 sys.stderr.topic = 'engine.%i.stderr'%self.id
115 107 if self.display_hook_factory:
116 108 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
117 109 sys.displayhook.topic = 'engine.%i.pyout'%self.id
118 110
119 111 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
120 112 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
121 113 loop=loop, user_ns = self.user_ns, logname=self.log.name)
122 114 self.kernel.start()
123 115 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
124 116 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
125 117 # ioloop.DelayedCallback(heart.start, 1000, self.loop).start()
126 118 heart.start()
127 119
128 120
129 121 else:
130 122 self.log.fatal("Registration Failed: %s"%msg)
131 123 raise Exception("Registration Failed: %s"%msg)
132 124
133 125 self.log.info("Completed registration with id %i"%self.id)
134 126
135 127
136 128 def abort(self):
137 129 self.log.fatal("Registration timed out")
138 130 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
139 131 time.sleep(1)
140 132 sys.exit(255)
141 133
142 134 def start(self):
143 135 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
144 136 dc.start()
145 137 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
146 138 self._abort_dc.start()
147 139
@@ -1,292 +1,313 b''
1 1 # encoding: utf-8
2 2
3 3 """Classes and functions for kernel related errors and exceptions."""
4 4 from __future__ import print_function
5 5
6 import sys
7 import traceback
8
6 9 __docformat__ = "restructuredtext en"
7 10
8 11 # Tell nose to skip this module
9 12 __test__ = {}
10 13
11 14 #-------------------------------------------------------------------------------
12 15 # Copyright (C) 2008 The IPython Development Team
13 16 #
14 17 # Distributed under the terms of the BSD License. The full license is in
15 18 # the file COPYING, distributed as part of this software.
16 19 #-------------------------------------------------------------------------------
17 20
18 21 #-------------------------------------------------------------------------------
19 22 # Error classes
20 23 #-------------------------------------------------------------------------------
21 24 class IPythonError(Exception):
22 25 """Base exception that all of our exceptions inherit from.
23 26
24 27 This can be raised by code that doesn't have any more specific
25 28 information."""
26 29
27 30 pass
28 31
29 32 # Exceptions associated with the controller objects
30 33 class ControllerError(IPythonError): pass
31 34
32 35 class ControllerCreationError(ControllerError): pass
33 36
34 37
35 38 # Exceptions associated with the Engines
36 39 class EngineError(IPythonError): pass
37 40
38 41 class EngineCreationError(EngineError): pass
39 42
40 43 class KernelError(IPythonError):
41 44 pass
42 45
43 46 class NotDefined(KernelError):
44 47 def __init__(self, name):
45 48 self.name = name
46 49 self.args = (name,)
47 50
48 51 def __repr__(self):
49 52 return '<NotDefined: %s>' % self.name
50 53
51 54 __str__ = __repr__
52 55
53 56
54 57 class QueueCleared(KernelError):
55 58 pass
56 59
57 60
58 61 class IdInUse(KernelError):
59 62 pass
60 63
61 64
62 65 class ProtocolError(KernelError):
63 66 pass
64 67
65 68
66 69 class ConnectionError(KernelError):
67 70 pass
68 71
69 72
70 73 class InvalidEngineID(KernelError):
71 74 pass
72 75
73 76
74 77 class NoEnginesRegistered(KernelError):
75 78 pass
76 79
77 80
78 81 class InvalidClientID(KernelError):
79 82 pass
80 83
81 84
82 85 class InvalidDeferredID(KernelError):
83 86 pass
84 87
85 88
86 89 class SerializationError(KernelError):
87 90 pass
88 91
89 92
90 93 class MessageSizeError(KernelError):
91 94 pass
92 95
93 96
94 97 class PBMessageSizeError(MessageSizeError):
95 98 pass
96 99
97 100
98 101 class ResultNotCompleted(KernelError):
99 102 pass
100 103
101 104
102 105 class ResultAlreadyRetrieved(KernelError):
103 106 pass
104 107
105 108 class ClientError(KernelError):
106 109 pass
107 110
108 111
109 112 class TaskAborted(KernelError):
110 113 pass
111 114
112 115
113 116 class TaskTimeout(KernelError):
114 117 pass
115 118
116 119
117 120 class NotAPendingResult(KernelError):
118 121 pass
119 122
120 123
121 124 class UnpickleableException(KernelError):
122 125 pass
123 126
124 127
125 128 class AbortedPendingDeferredError(KernelError):
126 129 pass
127 130
128 131
129 132 class InvalidProperty(KernelError):
130 133 pass
131 134
132 135
133 136 class MissingBlockArgument(KernelError):
134 137 pass
135 138
136 139
137 140 class StopLocalExecution(KernelError):
138 141 pass
139 142
140 143
141 144 class SecurityError(KernelError):
142 145 pass
143 146
144 147
145 148 class FileTimeoutError(KernelError):
146 149 pass
147 150
148 151 class TimeoutError(KernelError):
149 152 pass
150 153
151 154 class UnmetDependency(KernelError):
152 155 pass
153 156
154 157 class ImpossibleDependency(UnmetDependency):
155 158 pass
156 159
157 160 class DependencyTimeout(ImpossibleDependency):
158 161 pass
159 162
160 163 class InvalidDependency(ImpossibleDependency):
161 164 pass
162 165
163 166 class RemoteError(KernelError):
164 167 """Error raised elsewhere"""
165 168 ename=None
166 169 evalue=None
167 170 traceback=None
168 171 engine_info=None
169 172
170 173 def __init__(self, ename, evalue, traceback, engine_info=None):
171 174 self.ename=ename
172 175 self.evalue=evalue
173 176 self.traceback=traceback
174 177 self.engine_info=engine_info or {}
175 178 self.args=(ename, evalue)
176 179
177 180 def __repr__(self):
178 181 engineid = self.engine_info.get('engine_id', ' ')
179 182 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
180 183
181 184 def __str__(self):
182 185 sig = "%s(%s)"%(self.ename, self.evalue)
183 186 if self.traceback:
184 187 return sig + '\n' + self.traceback
185 188 else:
186 189 return sig
187 190
188 191
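
What a RemoteError looks like once constructed, using the class above (the values are invented)::

    err = RemoteError('ZeroDivisionError', 'integer division or modulo by zero',
                      'Traceback (most recent call last): ...', {'engine_id': 0})
    print(repr(err))   # <Remote[0]:ZeroDivisionError(integer division or modulo by zero)>
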
189 192 class TaskRejectError(KernelError):
190 193 """Exception to raise when a task should be rejected by an engine.
191 194
192 195 This exception can be used to allow a task running on an engine to test
193 196 if the engine (or the user's namespace on the engine) has the needed
194 197 task dependencies. If not, the task should raise this exception. For
195 198 the task to be retried on another engine, the task should be created
196 199 with the `retries` argument > 1.
197 200
198 201 The advantage of this approach over our older properties system is that
199 202 tasks have full access to the user's namespace on the engines and the
200 203 properties don't have to be managed or tested by the controller.
201 204 """
202 205
203 206
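A brief sketch of the pattern this docstring describes (the task body is hypothetical; the retry plumbing lives in the client and scheduler, not in this class):

    # hypothetical task body, submitted with retries > 1 as described above
    def task_body():
        try:
            import numpy
        except ImportError:
            # reject so the scheduler can retry on another engine
            raise TaskRejectError("numpy is unavailable on this engine")
        return numpy.__version__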
204 207 class CompositeError(RemoteError):
205 208 """Error for representing possibly multiple errors on engines"""
206 209 def __init__(self, message, elist):
207 210 Exception.__init__(self, *(message, elist))
208 211 # Don't use pack_exception because it will conflict with the .message
209 212 # attribute that is being deprecated in 2.6 and beyond.
210 213 self.msg = message
211 214 self.elist = elist
212 215 self.args = [ e[0] for e in elist ]
213 216
214 217 def _get_engine_str(self, ei):
215 218 if not ei:
216 219 return '[Engine Exception]'
217 220 else:
218 221 return '[%s:%s]: ' % (ei['engine_id'], ei['method'])
219 222
220 223 def _get_traceback(self, ev):
221 224 try:
222 225 tb = ev._ipython_traceback_text
223 226 except AttributeError:
224 227 return 'No traceback available'
225 228 else:
226 229 return tb
227 230
228 231 def __str__(self):
229 232 s = str(self.msg)
230 233 for en, ev, etb, ei in self.elist:
231 234 engine_str = self._get_engine_str(ei)
232 235 s = s + '\n' + engine_str + en + ': ' + str(ev)
233 236 return s
234 237
235 238 def __repr__(self):
236 239 return "CompositeError(%i)"%len(self.elist)
237 240
238 241 def print_tracebacks(self, excid=None):
239 242 if excid is None:
240 243 for (en,ev,etb,ei) in self.elist:
241 244 print (self._get_engine_str(ei))
242 245 print (etb or 'No traceback available')
243 246 print ()
244 247 else:
245 248 try:
246 249 en,ev,etb,ei = self.elist[excid]
247 250 except:
248 251 raise IndexError("an exception with index %i does not exist"%excid)
249 252 else:
250 253 print (self._get_engine_str(ei))
251 254 print (etb or 'No traceback available')
252 255
253 256 def raise_exception(self, excid=0):
254 257 try:
255 258 en,ev,etb,ei = self.elist[excid]
256 259 except:
257 260 raise IndexError("an exception with index %i does not exist"%excid)
258 261 else:
259 262 raise RemoteError(en, ev, etb, ei)
260 263
261 264
262 265 def collect_exceptions(rdict_or_list, method='unspecified'):
263 266 """check a result dict for errors, and raise CompositeError if any exist.
264 267 Passthrough otherwise."""
265 268 elist = []
266 269 if isinstance(rdict_or_list, dict):
267 270 rlist = rdict_or_list.values()
268 271 else:
269 272 rlist = rdict_or_list
270 273 for r in rlist:
271 274 if isinstance(r, RemoteError):
272 275 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
273 276 # Sometimes we could have CompositeError in our list. Just take
274 277 # the errors out of them and put them in our new list. This
275 278 # has the effect of flattening lists of CompositeErrors into one
276 279 # CompositeError
277 280 if en=='CompositeError':
278 281 for e in ev.elist:
279 282 elist.append(e)
280 283 else:
281 284 elist.append((en, ev, etb, ei))
282 285 if len(elist)==0:
283 286 return rdict_or_list
284 287 else:
285 288 msg = "one or more exceptions from call to method: %s" % (method)
286 289 # This silliness is needed so the debugger has access to the exception
287 290 # instance (e in this case)
288 291 try:
289 292 raise CompositeError(msg, elist)
290 293 except CompositeError as e:
291 294 raise e
292 295
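A hedged sketch of `collect_exceptions` in action (the results list is hypothetical):

    # clean results pass straight through
    ok = [1, 2]
    assert collect_exceptions(ok, 'apply') == ok
    # any RemoteError in the list triggers a CompositeError
    bad = ok + [RemoteError('ValueError', 'boom', 'Traceback...', {})]
    try:
        collect_exceptions(bad, 'apply')
    except CompositeError as e:
        print(e)  # one or more exceptions from call to method: apply ...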
296 def wrap_exception(engine_info={}):
297 etype, evalue, tb = sys.exc_info()
298 stb = traceback.format_exception(etype, evalue, tb)
299 exc_content = {
300 'status' : 'error',
301 'traceback' : stb,
302 'ename' : unicode(etype.__name__),
303 'evalue' : unicode(evalue),
304 'engine_info' : engine_info
305 }
306 return exc_content
307
308 def unwrap_exception(content):
309 err = RemoteError(content['ename'], content['evalue'],
310 ''.join(content['traceback']),
311 content.get('engine_info', {}))
312 return err
313
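A quick round-trip sketch of how the two helpers above are meant to pair up (illustrative only; it assumes the module's own `sys`/`traceback` imports and the `RemoteError` class defined earlier):

    # hypothetical usage: pack an exception on an engine, rebuild it client-side
    try:
        1/0
    except Exception:
        content = wrap_exception(engine_info={'engine_id': 0, 'method': 'apply'})

    # `content` is a json-friendly dict: status/ename/evalue/traceback/engine_info
    err = unwrap_exception(content)
    print(err.ename)   # ZeroDivisionError
    print(repr(err))   # <Remote[0]:ZeroDivisionError(...)>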
@@ -1,152 +1,152 b''
1 1 """Base config factories."""
2 2
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2008-2009 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13
14 14
15 15 import logging
16 16 import os
17 17 import uuid
18 18
19 19 from zmq.eventloop.ioloop import IOLoop
20 20
21 21 from IPython.config.configurable import Configurable
22 22 from IPython.utils.importstring import import_item
23 23 from IPython.utils.traitlets import Str,Int,Instance, CUnicode, CStr
24 24
25 25 import IPython.zmq.parallel.streamsession as ss
26 26 from IPython.zmq.parallel.entry_point import select_random_ports
27 27
28 28 #-----------------------------------------------------------------------------
29 29 # Classes
30 30 #-----------------------------------------------------------------------------
31 31 class LoggingFactory(Configurable):
32 32 """A most basic class, that has a `log` (type:`Logger`) attribute, set via a `logname` Trait."""
33 33 log = Instance('logging.Logger', ('ZMQ', logging.WARN))
34 logname = CStr('ZMQ')
34 logname = CUnicode('ZMQ')
35 35 def _logname_changed(self, name, old, new):
36 36 self.log = logging.getLogger(new)
37 37
38 38
39 39 class SessionFactory(LoggingFactory):
40 40 """The Base factory from which every factory in IPython.zmq.parallel inherits"""
41 41
42 42 packer = Str('',config=True)
43 43 unpacker = Str('',config=True)
44 44 ident = CStr('',config=True)
45 45 def _ident_default(self):
46 46 return str(uuid.uuid4())
47 username = Str(os.environ.get('USER','username'),config=True)
48 exec_key = CStr('',config=True)
47 username = CUnicode(os.environ.get('USER','username'),config=True)
48 exec_key = CUnicode('',config=True)
49 49 # not configurable:
50 50 context = Instance('zmq.Context', (), {})
51 51 session = Instance('IPython.zmq.parallel.streamsession.StreamSession')
52 52 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
53 53 def _loop_default(self):
54 54 return IOLoop.instance()
55 55
56 56
57 57 def __init__(self, **kwargs):
58 58 super(SessionFactory, self).__init__(**kwargs)
59 59 exec_key = self.exec_key or None
60 60 # set the packers:
61 61 if not self.packer:
62 62 packer_f = unpacker_f = None
63 63 elif self.packer.lower() == 'json':
64 64 packer_f = ss.json_packer
65 65 unpacker_f = ss.json_unpacker
66 66 elif self.packer.lower() == 'pickle':
67 67 packer_f = ss.pickle_packer
68 68 unpacker_f = ss.pickle_unpacker
69 69 else:
70 70 packer_f = import_item(self.packer)
71 71 unpacker_f = import_item(self.unpacker)
72 72
73 73 # construct the session
74 74 self.session = ss.StreamSession(self.username, self.ident, packer=packer_f, unpacker=unpacker_f, key=exec_key)
75 75
76 76
77 77 class RegistrationFactory(SessionFactory):
78 78 """The Base Configurable for objects that involve registration."""
79 79
80 80 url = Str('', config=True) # url takes precedence over ip,regport,transport
81 81 transport = Str('tcp', config=True)
82 82 ip = Str('127.0.0.1', config=True)
83 83 regport = Instance(int, config=True)
84 84 def _regport_default(self):
85 85 # return 10101
86 86 return select_random_ports(1)[0]
87 87
88 88 def __init__(self, **kwargs):
89 89 super(RegistrationFactory, self).__init__(**kwargs)
90 90 self._propagate_url()
91 91 self._rebuild_url()
92 92 self.on_trait_change(self._propagate_url, 'url')
93 93 self.on_trait_change(self._rebuild_url, 'ip')
94 94 self.on_trait_change(self._rebuild_url, 'transport')
95 95 self.on_trait_change(self._rebuild_url, 'regport')
96 96
97 97 def _rebuild_url(self):
98 98 self.url = "%s://%s:%i"%(self.transport, self.ip, self.regport)
99 99
100 100 def _propagate_url(self):
101 101 """Ensure self.url contains full transport://interface:port"""
102 102 if self.url:
103 103 iface = self.url.split('://',1)
104 104 if len(iface) == 2:
105 105 self.transport,iface = iface
106 106 iface = iface.split(':')
107 107 self.ip = iface[0]
108 108 if iface[1]:
109 109 self.regport = int(iface[1])
110 110
111 111 #-----------------------------------------------------------------------------
112 112 # argparse argument extenders
113 113 #-----------------------------------------------------------------------------
114 114
115 115
116 116 def add_session_arguments(parser):
117 117 paa = parser.add_argument
118 118 paa('--ident',
119 119 type=str, dest='SessionFactory.ident',
120 120 help='set the ZMQ and session identity [default: random uuid]',
121 121 metavar='identity')
122 122 # paa('--execkey',
123 123 # type=str, dest='SessionFactory.exec_key',
124 124 # help='path to a file containing an execution key.',
125 125 # metavar='execkey')
126 126 paa('--packer',
127 127 type=str, dest='SessionFactory.packer',
128 128 help='method to serialize messages: {json,pickle} [default: json]',
129 129 metavar='packer')
130 130 paa('--unpacker',
131 131 type=str, dest='SessionFactory.unpacker',
132 132 help='inverse function of `packer`. Only necessary when using something other than json|pickle',
133 133 metavar='packer')
134 134
135 135 def add_registration_arguments(parser):
136 136 paa = parser.add_argument
137 137 paa('--ip',
138 138 type=str, dest='RegistrationFactory.ip',
139 139 help="The IP used for registration [default: localhost]",
140 140 metavar='ip')
141 141 paa('--transport',
142 142 type=str, dest='RegistrationFactory.transport',
143 143 help="The ZeroMQ transport used for registration [default: tcp]",
144 144 metavar='transport')
145 145 paa('--url',
146 146 type=str, dest='RegistrationFactory.url',
147 147 help='set transport,ip,regport in one go, e.g. tcp://127.0.0.1:10101',
148 148 metavar='url')
149 149 paa('--regport',
150 150 type=int, dest='RegistrationFactory.regport',
151 151 help="The port used for registration [default: 10101]",
152 152 metavar='port')
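As a rough sketch of the url/ip/transport/regport coupling implemented above (hypothetical values; constructing the factory needs a working pyzmq install, since SessionFactory creates a zmq Context):

    # setting `url` in one go propagates to the component traits...
    f = RegistrationFactory(url='tcp://10.0.0.1:10101')
    assert (f.transport, f.ip, f.regport) == ('tcp', '10.0.0.1', 10101)
    # ...and changing a component rebuilds `url` via the trait-change handlers
    f.regport = 10102
    assert f.url == 'tcp://10.0.0.1:10102'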
@@ -1,1054 +1,1052 b''
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 import logging
19 18 import sys
20 19 import time
21 20 from datetime import datetime
22 21
23 22 import zmq
24 23 from zmq.eventloop import ioloop
25 24 from zmq.eventloop.zmqstream import ZMQStream
26 25
27 26 # internal:
28 from IPython.config.configurable import Configurable
29 27 from IPython.utils.importstring import import_item
30 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
31 29
32 30 from .entry_point import select_random_ports
33 31 from .factory import RegistrationFactory, LoggingFactory
34 32
33 from . import error
35 34 from .heartmonitor import HeartMonitor
36 from .streamsession import Message, wrap_exception, ISO8601
37 from .util import validate_url_container
35 from .util import validate_url_container, ISO8601
38 36
39 37 try:
40 38 from pymongo.binary import Binary
41 39 except ImportError:
42 40 MongoDB=None
43 41 else:
44 42 from mongodb import MongoDB
45 43
46 44 #-----------------------------------------------------------------------------
47 45 # Code
48 46 #-----------------------------------------------------------------------------
49 47
50 48 def _passer(*args, **kwargs):
51 49 return
52 50
53 51 def _printer(*args, **kwargs):
54 52 print (args)
55 53 print (kwargs)
56 54
57 55 def init_record(msg):
58 56 """Initialize a TaskRecord based on a request."""
59 57 header = msg['header']
60 58 return {
61 59 'msg_id' : header['msg_id'],
62 60 'header' : header,
63 61 'content': msg['content'],
64 62 'buffers': msg['buffers'],
65 63 'submitted': datetime.strptime(header['date'], ISO8601),
66 64 'client_uuid' : None,
67 65 'engine_uuid' : None,
68 66 'started': None,
69 67 'completed': None,
70 68 'resubmitted': None,
71 69 'result_header' : None,
72 70 'result_content' : None,
73 71 'result_buffers' : None,
74 72 'queue' : None,
75 73 'pyin' : None,
76 74 'pyout': None,
77 75 'pyerr': None,
78 76 'stdout': '',
79 77 'stderr': '',
80 78 }
81 79
82 80
83 81 class EngineConnector(HasTraits):
84 82 """A simple object for accessing the various zmq connections of an object.
85 83 Attributes are:
86 84 id (int): engine ID
87 85 uuid (str): uuid (unused?)
88 86 queue (str): identity of queue's XREQ socket
89 87 registration (str): identity of registration XREQ socket
90 88 heartbeat (str): identity of heartbeat XREQ socket
91 89 """
92 90 id=Int(0)
93 91 queue=Str()
94 92 control=Str()
95 93 registration=Str()
96 94 heartbeat=Str()
97 95 pending=Set()
98 96
99 97 class HubFactory(RegistrationFactory):
100 98 """The Configurable for setting up a Hub."""
101 99
102 100 # name of a scheduler scheme
103 101 scheme = Str('leastload', config=True)
104 102
105 103 # port-pairs for monitoredqueues:
106 104 hb = Instance(list, config=True)
107 105 def _hb_default(self):
108 106 return select_random_ports(2)
109 107
110 108 mux = Instance(list, config=True)
111 109 def _mux_default(self):
112 110 return select_random_ports(2)
113 111
114 112 task = Instance(list, config=True)
115 113 def _task_default(self):
116 114 return select_random_ports(2)
117 115
118 116 control = Instance(list, config=True)
119 117 def _control_default(self):
120 118 return select_random_ports(2)
121 119
122 120 iopub = Instance(list, config=True)
123 121 def _iopub_default(self):
124 122 return select_random_ports(2)
125 123
126 124 # single ports:
127 125 mon_port = Instance(int, config=True)
128 126 def _mon_port_default(self):
129 127 return select_random_ports(1)[0]
130 128
131 129 query_port = Instance(int, config=True)
132 130 def _query_port_default(self):
133 131 return select_random_ports(1)[0]
134 132
135 133 notifier_port = Instance(int, config=True)
136 134 def _notifier_port_default(self):
137 135 return select_random_ports(1)[0]
138 136
139 137 ping = Int(1000, config=True) # ping frequency
140 138
141 139 engine_ip = CStr('127.0.0.1', config=True)
142 140 engine_transport = CStr('tcp', config=True)
143 141
144 142 client_ip = CStr('127.0.0.1', config=True)
145 143 client_transport = CStr('tcp', config=True)
146 144
147 145 monitor_ip = CStr('127.0.0.1', config=True)
148 146 monitor_transport = CStr('tcp', config=True)
149 147
150 148 monitor_url = CStr('')
151 149
152 150 db_class = CStr('IPython.zmq.parallel.dictdb.DictDB', config=True)
153 151
154 152 # not configurable
155 153 db = Instance('IPython.zmq.parallel.dictdb.BaseDB')
156 154 heartmonitor = Instance('IPython.zmq.parallel.heartmonitor.HeartMonitor')
157 155 subconstructors = List()
158 156 _constructed = Bool(False)
159 157
160 158 def _ip_changed(self, name, old, new):
161 159 self.engine_ip = new
162 160 self.client_ip = new
163 161 self.monitor_ip = new
164 162 self._update_monitor_url()
165 163
166 164 def _update_monitor_url(self):
167 165 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
168 166
169 167 def _transport_changed(self, name, old, new):
170 168 self.engine_transport = new
171 169 self.client_transport = new
172 170 self.monitor_transport = new
173 171 self._update_monitor_url()
174 172
175 173 def __init__(self, **kwargs):
176 174 super(HubFactory, self).__init__(**kwargs)
177 175 self._update_monitor_url()
178 176 # self.on_trait_change(self._sync_ips, 'ip')
179 177 # self.on_trait_change(self._sync_transports, 'transport')
180 178 self.subconstructors.append(self.construct_hub)
181 179
182 180
183 181 def construct(self):
184 182 assert not self._constructed, "already constructed!"
185 183
186 184 for subc in self.subconstructors:
187 185 subc()
188 186
189 187 self._constructed = True
190 188
191 189
192 190 def start(self):
193 191 assert self._constructed, "must be constructed by self.construct() first!"
194 192 self.heartmonitor.start()
195 193 self.log.info("Heartmonitor started")
196 194
197 195 def construct_hub(self):
198 196 """construct"""
199 197 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
200 198 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
201 199
202 200 ctx = self.context
203 201 loop = self.loop
204 202
205 203 # Registrar socket
206 204 reg = ZMQStream(ctx.socket(zmq.XREP), loop)
207 205 reg.bind(client_iface % self.regport)
208 206 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
209 207 if self.client_ip != self.engine_ip:
210 208 reg.bind(engine_iface % self.regport)
211 209 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
212 210
213 211 ### Engine connections ###
214 212
215 213 # heartbeat
216 214 hpub = ctx.socket(zmq.PUB)
217 215 hpub.bind(engine_iface % self.hb[0])
218 216 hrep = ctx.socket(zmq.XREP)
219 217 hrep.bind(engine_iface % self.hb[1])
220 218 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
221 219 period=self.ping, logname=self.log.name)
222 220
223 221 ### Client connections ###
224 222 # Clientele socket
225 223 c = ZMQStream(ctx.socket(zmq.XREP), loop)
226 224 c.bind(client_iface%self.query_port)
227 225 # Notifier socket
228 226 n = ZMQStream(ctx.socket(zmq.PUB), loop)
229 227 n.bind(client_iface%self.notifier_port)
230 228
231 229 ### build and launch the queues ###
232 230
233 231 # monitor socket
234 232 sub = ctx.socket(zmq.SUB)
235 233 sub.setsockopt(zmq.SUBSCRIBE, "")
236 234 sub.bind(self.monitor_url)
237 235 sub.bind('inproc://monitor')
238 236 sub = ZMQStream(sub, loop)
239 237
240 238 # connect the db
241 239 self.db = import_item(self.db_class)(self.session.session)
242 240 time.sleep(.25)
243 241
244 242 # build connection dicts
245 243 self.engine_info = {
246 244 'control' : engine_iface%self.control[1],
247 245 'mux': engine_iface%self.mux[1],
248 246 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
249 247 'task' : engine_iface%self.task[1],
250 248 'iopub' : engine_iface%self.iopub[1],
251 249 # 'monitor' : engine_iface%self.mon_port,
252 250 }
253 251
254 252 self.client_info = {
255 253 'control' : client_iface%self.control[0],
256 254 'query': client_iface%self.query_port,
257 255 'mux': client_iface%self.mux[0],
258 256 'task' : (self.scheme, client_iface%self.task[0]),
259 257 'iopub' : client_iface%self.iopub[0],
260 258 'notification': client_iface%self.notifier_port
261 259 }
262 260 self.log.debug("hub::Hub engine addrs: %s"%self.engine_info)
263 261 self.log.debug("hub::Hub client addrs: %s"%self.client_info)
264 262 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
265 263 registrar=reg, clientele=c, notifier=n, db=self.db,
266 264 engine_info=self.engine_info, client_info=self.client_info,
267 265 logname=self.log.name)
268 266
269 267
270 268 class Hub(LoggingFactory):
271 269 """The IPython Controller Hub with 0MQ connections
272 270
273 271 Parameters
274 272 ==========
275 273 loop: zmq IOLoop instance
276 274 session: StreamSession object
277 275 <removed> context: zmq context for creating new connections (?)
278 276 queue: ZMQStream for monitoring the command queue (SUB)
279 277 registrar: ZMQStream for engine registration requests (XREP)
280 278 heartbeat: HeartMonitor object checking the pulse of the engines
281 279 clientele: ZMQStream for client connections (XREP)
282 280 not used for jobs, only query/control commands
283 281 notifier: ZMQStream for broadcasting engine registration changes (PUB)
284 282 db: connection to db for out of memory logging of commands
285 283 NotImplemented
286 284 engine_info: dict of zmq connection information for engines to connect
287 285 to the queues.
288 286 client_info: dict of zmq connection information for clients to connect
289 287 to the queues.
290 288 """
291 289 # internal data structures:
292 290 ids=Set() # engine IDs
293 291 keytable=Dict()
294 292 by_ident=Dict()
295 293 engines=Dict()
296 294 clients=Dict()
297 295 hearts=Dict()
298 296 pending=Set()
299 297 queues=Dict() # pending msg_ids keyed by engine_id
300 298 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
301 299 completed=Dict() # completed msg_ids keyed by engine_id
302 300 all_completed=Set() # completed msg_ids keyed by engine_id
303 301 # mia=None
304 302 incoming_registrations=Dict()
305 303 registration_timeout=Int()
306 304 _idcounter=Int(0)
307 305
308 306 # objects from constructor:
309 307 loop=Instance(ioloop.IOLoop)
310 308 registrar=Instance(ZMQStream)
311 309 clientele=Instance(ZMQStream)
312 310 monitor=Instance(ZMQStream)
313 311 heartmonitor=Instance(HeartMonitor)
314 312 notifier=Instance(ZMQStream)
315 313 db=Instance(object)
316 314 client_info=Dict()
317 315 engine_info=Dict()
318 316
319 317
320 318 def __init__(self, **kwargs):
321 319 """
322 320 # universal:
323 321 loop: IOLoop for creating future connections
324 322 session: streamsession for sending serialized data
325 323 # engine:
326 324 queue: ZMQStream for monitoring queue messages
327 325 registrar: ZMQStream for engine registration
328 326 heartbeat: HeartMonitor object for tracking engines
329 327 # client:
330 328 clientele: ZMQStream for client connections
331 329 # extra:
332 330 db: ZMQStream for db connection (NotImplemented)
333 331 engine_info: zmq address/protocol dict for engine connections
334 332 client_info: zmq address/protocol dict for client connections
335 333 """
336 334
337 335 super(Hub, self).__init__(**kwargs)
338 336 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
339 337
340 338 # validate connection dicts:
341 339 for k,v in self.client_info.iteritems():
342 340 if k == 'task':
343 341 validate_url_container(v[1])
344 342 else:
345 343 validate_url_container(v)
346 344 # validate_url_container(self.client_info)
347 345 validate_url_container(self.engine_info)
348 346
349 347 # register our callbacks
350 348 self.registrar.on_recv(self.dispatch_register_request)
351 349 self.clientele.on_recv(self.dispatch_client_msg)
352 350 self.monitor.on_recv(self.dispatch_monitor_traffic)
353 351
354 352 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
355 353 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
356 354
357 355 self.monitor_handlers = { 'in' : self.save_queue_request,
358 356 'out': self.save_queue_result,
359 357 'intask': self.save_task_request,
360 358 'outtask': self.save_task_result,
361 359 'tracktask': self.save_task_destination,
362 360 'incontrol': _passer,
363 361 'outcontrol': _passer,
364 362 'iopub': self.save_iopub_message,
365 363 }
366 364
367 365 self.client_handlers = {'queue_request': self.queue_status,
368 366 'result_request': self.get_results,
369 367 'purge_request': self.purge_results,
370 368 'load_request': self.check_load,
371 369 'resubmit_request': self.resubmit_task,
372 370 'shutdown_request': self.shutdown_request,
373 371 }
374 372
375 373 self.registrar_handlers = {'registration_request' : self.register_engine,
376 374 'unregistration_request' : self.unregister_engine,
377 375 'connection_request': self.connection_request,
378 376 }
379 377
380 378 self.log.info("hub::created hub")
381 379
382 380 @property
383 381 def _next_id(self):
384 382 """gemerate a new ID.
385 383
386 384 No longer reuse old ids, just count from 0."""
387 385 newid = self._idcounter
388 386 self._idcounter += 1
389 387 return newid
390 388 # newid = 0
391 389 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
392 390 # # print newid, self.ids, self.incoming_registrations
393 391 # while newid in self.ids or newid in incoming:
394 392 # newid += 1
395 393 # return newid
396 394
397 395 #-----------------------------------------------------------------------------
398 396 # message validation
399 397 #-----------------------------------------------------------------------------
400 398
401 399 def _validate_targets(self, targets):
402 400 """turn any valid targets argument into a list of integer ids"""
403 401 if targets is None:
404 402 # default to all
405 403 targets = self.ids
406 404
407 405 if isinstance(targets, (int,str,unicode)):
408 406 # only one target specified
409 407 targets = [targets]
410 408 _targets = []
411 409 for t in targets:
412 410 # map raw identities to ids
413 411 if isinstance(t, (str,unicode)):
414 412 t = self.by_ident.get(t, t)
415 413 _targets.append(t)
416 414 targets = _targets
417 415 bad_targets = [ t for t in targets if t not in self.ids ]
418 416 if bad_targets:
419 417 raise IndexError("No Such Engine: %r"%bad_targets)
420 418 if not targets:
421 419 raise IndexError("No Engines Registered")
422 420 return targets
423 421
424 422 def _validate_client_msg(self, msg):
425 423 """validates and unpacks headers of a message. Returns False if invalid,
426 424 (ident, header, parent, content)"""
427 425 client_id = msg[0]
428 426 try:
429 427 msg = self.session.unpack_message(msg[1:], content=True)
430 428 except:
431 429 self.log.error("client::Invalid Message %s"%msg, exc_info=True)
432 430 return False
433 431
434 432 msg_type = msg.get('msg_type', None)
435 433 if msg_type is None:
436 434 return False
437 435 header = msg.get('header')
438 436 # session doesn't handle split content for now:
439 437 return client_id, msg
440 438
441 439
442 440 #-----------------------------------------------------------------------------
443 441 # dispatch methods (1 per stream)
444 442 #-----------------------------------------------------------------------------
445 443
446 444 def dispatch_register_request(self, msg):
447 445 """"""
448 446 self.log.debug("registration::dispatch_register_request(%s)"%msg)
449 447 idents,msg = self.session.feed_identities(msg)
450 448 if not idents:
451 449 self.log.error("Bad Queue Message: %s"%msg, exc_info=True)
452 450 return
453 451 try:
454 452 msg = self.session.unpack_message(msg,content=True)
455 453 except:
456 454 self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
457 455 return
458 456
459 457 msg_type = msg['msg_type']
460 458 content = msg['content']
461 459
462 460 handler = self.registrar_handlers.get(msg_type, None)
463 461 if handler is None:
464 462 self.log.error("registration::got bad registration message: %s"%msg)
465 463 else:
466 464 handler(idents, msg)
467 465
468 466 def dispatch_monitor_traffic(self, msg):
469 467 """all ME and Task queue messages come through here, as well as
470 468 IOPub traffic."""
471 469 self.log.debug("monitor traffic: %s"%msg[:2])
472 470 switch = msg[0]
473 471 idents, msg = self.session.feed_identities(msg[1:])
474 472 if not idents:
475 473 self.log.error("Bad Monitor Message: %s"%msg)
476 474 return
477 475 handler = self.monitor_handlers.get(switch, None)
478 476 if handler is not None:
479 477 handler(idents, msg)
480 478 else:
481 479 self.log.error("Invalid monitor topic: %s"%switch)
482 480
483 481
484 482 def dispatch_client_msg(self, msg):
485 483 """Route messages from clients"""
486 484 idents, msg = self.session.feed_identities(msg)
487 485 if not idents:
488 486 self.log.error("Bad Client Message: %s"%msg)
489 487 return
490 488 client_id = idents[0]
491 489 try:
492 490 msg = self.session.unpack_message(msg, content=True)
493 491 except:
494 content = wrap_exception()
492 content = error.wrap_exception()
495 493 self.log.error("Bad Client Message: %s"%msg, exc_info=True)
496 494 self.session.send(self.clientele, "hub_error", ident=client_id,
497 495 content=content)
498 496 return
499 497
500 498 # print client_id, header, parent, content
501 499 #switch on message type:
502 500 msg_type = msg['msg_type']
503 501 self.log.info("client:: client %s requested %s"%(client_id, msg_type))
504 502 handler = self.client_handlers.get(msg_type, None)
505 503 try:
506 504 assert handler is not None, "Bad Message Type: %s"%msg_type
507 505 except:
508 content = wrap_exception()
506 content = error.wrap_exception()
509 507 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
510 508 self.session.send(self.clientele, "hub_error", ident=client_id,
511 509 content=content)
512 510 return
513 511 else:
514 512 handler(client_id, msg)
515 513
516 514 def dispatch_db(self, msg):
517 515 """"""
518 516 raise NotImplementedError
519 517
520 518 #---------------------------------------------------------------------------
521 519 # handler methods (1 per event)
522 520 #---------------------------------------------------------------------------
523 521
524 522 #----------------------- Heartbeat --------------------------------------
525 523
526 524 def handle_new_heart(self, heart):
527 525 """handler to attach to heartbeater.
528 526 Called when a new heart starts to beat.
529 527 Triggers completion of registration."""
530 528 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
531 529 if heart not in self.incoming_registrations:
532 530 self.log.info("heartbeat::ignoring new heart: %r"%heart)
533 531 else:
534 532 self.finish_registration(heart)
535 533
536 534
537 535 def handle_heart_failure(self, heart):
538 536 """handler to attach to heartbeater.
539 537 called when a previously registered heart fails to respond to beat request.
540 538 triggers unregistration"""
541 539 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
542 540 eid = self.hearts.get(heart, None)
543 541 if eid is None:
544 542 self.log.info("heartbeat::ignoring heart failure %r"%heart)
545 543 else:
546 544 queue = self.engines[eid].queue
547 545 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
548 546
549 547 #----------------------- MUX Queue Traffic ------------------------------
550 548
551 549 def save_queue_request(self, idents, msg):
552 550 if len(idents) < 2:
553 551 self.log.error("invalid identity prefix: %s"%idents)
554 552 return
555 553 queue_id, client_id = idents[:2]
556 554 try:
557 555 msg = self.session.unpack_message(msg, content=False)
558 556 except:
559 557 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
560 558 return
561 559
562 560 eid = self.by_ident.get(queue_id, None)
563 561 if eid is None:
564 562 self.log.error("queue::target %r not registered"%queue_id)
565 563 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
566 564 return
567 565
568 566 header = msg['header']
569 567 msg_id = header['msg_id']
570 568 record = init_record(msg)
571 569 record['engine_uuid'] = queue_id
572 570 record['client_uuid'] = client_id
573 571 record['queue'] = 'mux'
574 572 if MongoDB is not None and isinstance(self.db, MongoDB):
575 573 record['buffers'] = map(Binary, record['buffers'])
576 574 self.pending.add(msg_id)
577 575 self.queues[eid].append(msg_id)
578 576 self.db.add_record(msg_id, record)
579 577
580 578 def save_queue_result(self, idents, msg):
581 579 if len(idents) < 2:
582 580 self.log.error("invalid identity prefix: %s"%idents)
583 581 return
584 582
585 583 client_id, queue_id = idents[:2]
586 584 try:
587 585 msg = self.session.unpack_message(msg, content=False)
588 586 except:
589 587 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
590 588 queue_id,client_id, msg), exc_info=True)
591 589 return
592 590
593 591 eid = self.by_ident.get(queue_id, None)
594 592 if eid is None:
595 593 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
596 594 self.log.debug("queue:: %s"%msg[2:])
597 595 return
598 596
599 597 parent = msg['parent_header']
600 598 if not parent:
601 599 return
602 600 msg_id = parent['msg_id']
603 601 if msg_id in self.pending:
604 602 self.pending.remove(msg_id)
605 603 self.all_completed.add(msg_id)
606 604 self.queues[eid].remove(msg_id)
607 605 self.completed[eid].append(msg_id)
608 606 rheader = msg['header']
609 607 completed = datetime.strptime(rheader['date'], ISO8601)
610 608 started = rheader.get('started', None)
611 609 if started is not None:
612 610 started = datetime.strptime(started, ISO8601)
613 611 result = {
614 612 'result_header' : rheader,
615 613 'result_content': msg['content'],
616 614 'started' : started,
617 615 'completed' : completed
618 616 }
619 617 if MongoDB is not None and isinstance(self.db, MongoDB):
620 618 result['result_buffers'] = map(Binary, msg['buffers'])
621 619 else:
622 620 result['result_buffers'] = msg['buffers']
623 621 self.db.update_record(msg_id, result)
624 622 else:
625 623 self.log.debug("queue:: unknown msg finished %s"%msg_id)
626 624
627 625 #--------------------- Task Queue Traffic ------------------------------
628 626
629 627 def save_task_request(self, idents, msg):
630 628 """Save the submission of a task."""
631 629 client_id = idents[0]
632 630
633 631 try:
634 632 msg = self.session.unpack_message(msg, content=False)
635 633 except:
636 634 self.log.error("task::client %r sent invalid task message: %s"%(
637 635 client_id, msg), exc_info=True)
638 636 return
639 637 record = init_record(msg)
640 638 if MongoDB is not None and isinstance(self.db, MongoDB):
641 639 record['buffers'] = map(Binary, record['buffers'])
642 640 record['client_uuid'] = client_id
643 641 record['queue'] = 'task'
644 642 header = msg['header']
645 643 msg_id = header['msg_id']
646 644 self.pending.add(msg_id)
647 645 self.db.add_record(msg_id, record)
648 646
649 647 def save_task_result(self, idents, msg):
650 648 """save the result of a completed task."""
651 649 client_id = idents[0]
652 650 try:
653 651 msg = self.session.unpack_message(msg, content=False)
654 652 except:
655 653 self.log.error("task::invalid task result message send to %r: %s"%(
656 654 client_id, msg), exc_info=True)
657 655 # message is unusable; the error was logged above
658 656 return
659 657
660 658 parent = msg['parent_header']
661 659 if not parent:
662 660 # print msg
663 661 self.log.warn("Task %r had no parent!"%msg)
664 662 return
665 663 msg_id = parent['msg_id']
666 664
667 665 header = msg['header']
668 666 engine_uuid = header.get('engine', None)
669 667 eid = self.by_ident.get(engine_uuid, None)
670 668
671 669 if msg_id in self.pending:
672 670 self.pending.remove(msg_id)
673 671 self.all_completed.add(msg_id)
674 672 if eid is not None:
675 673 self.completed[eid].append(msg_id)
676 674 if msg_id in self.tasks[eid]:
677 675 self.tasks[eid].remove(msg_id)
678 676 completed = datetime.strptime(header['date'], ISO8601)
679 677 started = header.get('started', None)
680 678 if started is not None:
681 679 started = datetime.strptime(started, ISO8601)
682 680 result = {
683 681 'result_header' : header,
684 682 'result_content': msg['content'],
685 683 'started' : started,
686 684 'completed' : completed,
687 685 'engine_uuid': engine_uuid
688 686 }
689 687 if MongoDB is not None and isinstance(self.db, MongoDB):
690 688 result['result_buffers'] = map(Binary, msg['buffers'])
691 689 else:
692 690 result['result_buffers'] = msg['buffers']
693 691 self.db.update_record(msg_id, result)
694 692
695 693 else:
696 694 self.log.debug("task::unknown task %s finished"%msg_id)
697 695
698 696 def save_task_destination(self, idents, msg):
699 697 try:
700 698 msg = self.session.unpack_message(msg, content=True)
701 699 except:
702 700 self.log.error("task::invalid task tracking message", exc_info=True)
703 701 return
704 702 content = msg['content']
705 703 # print (content)
706 704 msg_id = content['msg_id']
707 705 engine_uuid = content['engine_id']
708 706 eid = self.by_ident[engine_uuid]
709 707
710 708 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
711 709 # if msg_id in self.mia:
712 710 # self.mia.remove(msg_id)
713 711 # else:
714 712 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
715 713
716 714 self.tasks[eid].append(msg_id)
717 715 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
718 716 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
719 717
720 718 def mia_task_request(self, idents, msg):
721 719 raise NotImplementedError
722 720 client_id = idents[0]
723 721 # content = dict(mia=self.mia,status='ok')
724 722 # self.session.send('mia_reply', content=content, idents=client_id)
725 723
726 724
727 725 #--------------------- IOPub Traffic ------------------------------
728 726
729 727 def save_iopub_message(self, topics, msg):
730 728 """save an iopub message into the db"""
731 729 # print (topics)
732 730 try:
733 731 msg = self.session.unpack_message(msg, content=True)
734 732 except:
735 733 self.log.error("iopub::invalid IOPub message", exc_info=True)
736 734 return
737 735
738 736 parent = msg['parent_header']
739 737 if not parent:
740 738 self.log.error("iopub::invalid IOPub message: %s"%msg)
741 739 return
742 740 msg_id = parent['msg_id']
743 741 msg_type = msg['msg_type']
744 742 content = msg['content']
745 743
746 744 # ensure msg_id is in db
747 745 try:
748 746 rec = self.db.get_record(msg_id)
749 747 except:
750 748 self.log.error("iopub::IOPub message has invalid parent", exc_info=True)
751 749 return
752 750 # stream
753 751 d = {}
754 752 if msg_type == 'stream':
755 753 name = content['name']
756 754 s = rec[name] or ''
757 755 d[name] = s + content['data']
758 756
759 757 elif msg_type == 'pyerr':
760 758 d['pyerr'] = content
761 759 else:
762 760 d[msg_type] = content['data']
763 761
764 762 self.db.update_record(msg_id, d)
765 763
766 764
767 765
768 766 #-------------------------------------------------------------------------
769 767 # Registration requests
770 768 #-------------------------------------------------------------------------
771 769
772 770 def connection_request(self, client_id, msg):
773 771 """Reply with connection addresses for clients."""
774 772 self.log.info("client::client %s connected"%client_id)
775 773 content = dict(status='ok')
776 774 content.update(self.client_info)
777 775 jsonable = {}
778 776 for k,v in self.keytable.iteritems():
779 777 jsonable[str(k)] = v
780 778 content['engines'] = jsonable
781 779 self.session.send(self.registrar, 'connection_reply', content, parent=msg, ident=client_id)
782 780
783 781 def register_engine(self, reg, msg):
784 782 """Register a new engine."""
785 783 content = msg['content']
786 784 try:
787 785 queue = content['queue']
788 786 except KeyError:
789 787 self.log.error("registration::queue not specified", exc_info=True)
790 788 return
791 789 heart = content.get('heartbeat', None)
792 790 """register a new engine, and create the socket(s) necessary"""
793 791 eid = self._next_id
794 792 # print (eid, queue, reg, heart)
795 793
796 794 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
797 795
798 796 content = dict(id=eid,status='ok')
799 797 content.update(self.engine_info)
800 798 # check if requesting available IDs:
801 799 if queue in self.by_ident:
802 800 try:
803 801 raise KeyError("queue_id %r in use"%queue)
804 802 except:
805 content = wrap_exception()
803 content = error.wrap_exception()
806 804 self.log.error("queue_id %r in use"%queue, exc_info=True)
807 805 elif heart in self.hearts: # need to check unique hearts?
808 806 try:
809 807 raise KeyError("heart_id %r in use"%heart)
810 808 except:
811 809 self.log.error("heart_id %r in use"%heart, exc_info=True)
812 content = wrap_exception()
810 content = error.wrap_exception()
813 811 else:
814 812 for h, pack in self.incoming_registrations.iteritems():
815 813 if heart == h:
816 814 try:
817 815 raise KeyError("heart_id %r in use"%heart)
818 816 except:
819 817 self.log.error("heart_id %r in use"%heart, exc_info=True)
820 content = wrap_exception()
818 content = error.wrap_exception()
821 819 break
822 820 elif queue == pack[1]:
823 821 try:
824 822 raise KeyError("queue_id %r in use"%queue)
825 823 except:
826 824 self.log.error("queue_id %r in use"%queue, exc_info=True)
827 content = wrap_exception()
825 content = error.wrap_exception()
828 826 break
829 827
830 828 msg = self.session.send(self.registrar, "registration_reply",
831 829 content=content,
832 830 ident=reg)
833 831
834 832 if content['status'] == 'ok':
835 833 if heart in self.heartmonitor.hearts:
836 834 # already beating
837 835 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
838 836 self.finish_registration(heart)
839 837 else:
840 838 purge = lambda : self._purge_stalled_registration(heart)
841 839 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
842 840 dc.start()
843 841 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
844 842 else:
845 843 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
846 844 return eid
847 845
848 846 def unregister_engine(self, ident, msg):
849 847 """Unregister an engine that explicitly requested to leave."""
850 848 try:
851 849 eid = msg['content']['id']
852 850 except:
853 851 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
854 852 return
855 853 self.log.info("registration::unregister_engine(%s)"%eid)
856 854 content=dict(id=eid, queue=self.engines[eid].queue)
857 855 self.ids.remove(eid)
858 856 self.keytable.pop(eid)
859 857 ec = self.engines.pop(eid)
860 858 self.hearts.pop(ec.heartbeat)
861 859 self.by_ident.pop(ec.queue)
862 860 self.completed.pop(eid)
863 861 for msg_id in self.queues.pop(eid):
864 862 msg = self.pending.remove(msg_id)
865 863 ############## TODO: HANDLE IT ################
866 864
867 865 if self.notifier:
868 866 self.session.send(self.notifier, "unregistration_notification", content=content)
869 867
870 868 def finish_registration(self, heart):
871 869 """Second half of engine registration, called after our HeartMonitor
872 870 has received a beat from the Engine's Heart."""
873 871 try:
874 872 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
875 873 except KeyError:
876 874 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
877 875 return
878 876 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
879 877 if purge is not None:
880 878 purge.stop()
881 879 control = queue
882 880 self.ids.add(eid)
883 881 self.keytable[eid] = queue
884 882 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
885 883 control=control, heartbeat=heart)
886 884 self.by_ident[queue] = eid
887 885 self.queues[eid] = list()
888 886 self.tasks[eid] = list()
889 887 self.completed[eid] = list()
890 888 self.hearts[heart] = eid
891 889 content = dict(id=eid, queue=self.engines[eid].queue)
892 890 if self.notifier:
893 891 self.session.send(self.notifier, "registration_notification", content=content)
894 892 self.log.info("engine::Engine Connected: %i"%eid)
895 893
896 894 def _purge_stalled_registration(self, heart):
897 895 if heart in self.incoming_registrations:
898 896 eid = self.incoming_registrations.pop(heart)[0]
899 897 self.log.info("registration::purging stalled registration: %i"%eid)
900 898 else:
901 899 pass
902 900
903 901 #-------------------------------------------------------------------------
904 902 # Client Requests
905 903 #-------------------------------------------------------------------------
906 904
907 905 def shutdown_request(self, client_id, msg):
908 906 """handle shutdown request."""
909 907 # s = self.context.socket(zmq.XREQ)
910 908 # s.connect(self.client_connections['mux'])
911 909 # time.sleep(0.1)
912 910 # for eid,ec in self.engines.iteritems():
913 911 # self.session.send(s, 'shutdown_request', content=dict(restart=False), ident=ec.queue)
914 912 # time.sleep(1)
915 913 self.session.send(self.clientele, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
916 914 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
917 915 dc.start()
918 916
919 917 def _shutdown(self):
920 918 self.log.info("hub::hub shutting down.")
921 919 time.sleep(0.1)
922 920 sys.exit(0)
923 921
924 922
925 923 def check_load(self, client_id, msg):
926 924 content = msg['content']
927 925 try:
928 926 targets = content['targets']
929 927 targets = self._validate_targets(targets)
930 928 except:
931 content = wrap_exception()
929 content = error.wrap_exception()
932 930 self.session.send(self.clientele, "hub_error",
933 931 content=content, ident=client_id)
934 932 return
935 933
936 934 content = dict(status='ok')
937 935 # loads = {}
938 936 for t in targets:
939 937 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
940 938 self.session.send(self.clientele, "load_reply", content=content, ident=client_id)
941 939
942 940
943 941 def queue_status(self, client_id, msg):
944 942 """Return the Queue status of one or more targets.
945 943 if verbose: return the msg_ids
946 944 else: return len of each type.
947 945 keys: queue (pending MUX jobs)
948 946 tasks (pending Task jobs)
949 947 completed (finished jobs from both queues)"""
950 948 content = msg['content']
951 949 targets = content['targets']
952 950 try:
953 951 targets = self._validate_targets(targets)
954 952 except:
955 content = wrap_exception()
953 content = error.wrap_exception()
956 954 self.session.send(self.clientele, "hub_error",
957 955 content=content, ident=client_id)
958 956 return
959 957 verbose = content.get('verbose', False)
960 958 content = dict(status='ok')
961 959 for t in targets:
962 960 queue = self.queues[t]
963 961 completed = self.completed[t]
964 962 tasks = self.tasks[t]
965 963 if not verbose:
966 964 queue = len(queue)
967 965 completed = len(completed)
968 966 tasks = len(tasks)
969 967 content[bytes(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
970 968 # pending
971 969 self.session.send(self.clientele, "queue_reply", content=content, ident=client_id)
972 970
973 971 def purge_results(self, client_id, msg):
974 972 """Purge results from memory. This method is more valuable before we move
975 973 to a DB based message storage mechanism."""
976 974 content = msg['content']
977 975 msg_ids = content.get('msg_ids', [])
978 976 reply = dict(status='ok')
979 977 if msg_ids == 'all':
980 978 self.db.drop_matching_records(dict(completed={'$ne':None}))
981 979 else:
982 980 for msg_id in msg_ids:
983 981 if msg_id in self.all_completed:
984 982 self.db.drop_record(msg_id)
985 983 else:
986 984 if msg_id in self.pending:
987 985 try:
988 986 raise IndexError("msg pending: %r"%msg_id)
989 987 except:
990 reply = wrap_exception()
988 reply = error.wrap_exception()
991 989 else:
992 990 try:
993 991 raise IndexError("No such msg: %r"%msg_id)
994 992 except:
995 reply = wrap_exception()
993 reply = error.wrap_exception()
996 994 break
997 995 eids = content.get('engine_ids', [])
998 996 for eid in eids:
999 997 if eid not in self.engines:
1000 998 try:
1001 999 raise IndexError("No such engine: %i"%eid)
1002 1000 except:
1003 reply = wrap_exception()
1001 reply = error.wrap_exception()
1004 1002 break
1005 1003 msg_ids = self.completed.pop(eid)
1006 1004 uid = self.engines[eid].queue
1007 1005 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1008 1006
1009 1007 self.session.send(self.clientele, 'purge_reply', content=reply, ident=client_id)
1010 1008
1011 1009 def resubmit_task(self, client_id, msg, buffers):
1012 1010 """Resubmit a task."""
1013 1011 raise NotImplementedError
1014 1012
1015 1013 def get_results(self, client_id, msg):
1016 1014 """Get the result of 1 or more messages."""
1017 1015 content = msg['content']
1018 1016 msg_ids = sorted(set(content['msg_ids']))
1019 1017 statusonly = content.get('status_only', False)
1020 1018 pending = []
1021 1019 completed = []
1022 1020 content = dict(status='ok')
1023 1021 content['pending'] = pending
1024 1022 content['completed'] = completed
1025 1023 buffers = []
1026 1024 if not statusonly:
1027 1025 content['results'] = {}
1028 1026 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1029 1027 for msg_id in msg_ids:
1030 1028 if msg_id in self.pending:
1031 1029 pending.append(msg_id)
1032 1030 elif msg_id in self.all_completed:
1033 1031 completed.append(msg_id)
1034 1032 if not statusonly:
1035 1033 rec = records[msg_id]
1036 1034 io_dict = {}
1037 1035 for key in 'pyin pyout pyerr stdout stderr'.split():
1038 1036 io_dict[key] = rec[key]
1039 1037 content[msg_id] = { 'result_content': rec['result_content'],
1040 1038 'header': rec['header'],
1041 1039 'result_header' : rec['result_header'],
1042 1040 'io' : io_dict,
1043 1041 }
1044 1042 buffers.extend(map(str, rec['result_buffers']))
1045 1043 else:
1046 1044 try:
1047 1045 raise KeyError('No such message: '+msg_id)
1048 1046 except:
1049 content = wrap_exception()
1047 content = error.wrap_exception()
1050 1048 break
1051 1049 self.session.send(self.clientele, "result_reply", content=content,
1052 1050 parent=msg, ident=client_id,
1053 1051 buffers=buffers)
1054 1052
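For orientation, a sketch of the payload shapes the client handlers above consume and produce (field names taken from the handlers; transport and serialization are elided, and the values are hypothetical):

    # queue_request content, as read by Hub.queue_status:
    content = {'targets': [0, 1], 'verbose': False}
    # queue_reply content, keyed by stringified engine id when verbose=False:
    # {'status': 'ok', '0': {'queue': 2, 'completed': 10, 'tasks': 1}, ...}

    # purge_request content, as read by Hub.purge_results ('all' is also
    # accepted for msg_ids, dropping every completed record):
    content = {'msg_ids': [], 'engine_ids': [0]}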
@@ -1,174 +1,203 b''
1 1 """Remote Functions and decorators for the client."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import warnings
14 14
15 15 from IPython.testing import decorators as testdec
16 16
17 17 from . import map as Map
18 18 from .asyncresult import AsyncMapResult
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Decorators
22 22 #-----------------------------------------------------------------------------
23 23
24 24 @testdec.skip_doctest
25 25 def remote(client, bound=True, block=None, targets=None, balanced=None):
26 26 """Turn a function into a remote function.
27 27
28 28 This method can be used for map:
29 29
30 30 In [1]: @remote(client,block=True)
31 31 ...: def func(a):
32 32 ...: pass
33 33 """
34 34
35 35 def remote_function(f):
36 36 return RemoteFunction(client, f, bound, block, targets, balanced)
37 37 return remote_function
38 38
39 39 @testdec.skip_doctest
40 40 def parallel(client, dist='b', bound=True, block=None, targets='all', balanced=None):
41 41 """Turn a function into a parallel remote function.
42 42
43 43 This method can be used for map:
44 44
45 45 In [1]: @parallel(client,block=True)
46 46 ...: def func(a):
47 47 ...: pass
48 48 """
49 49
50 50 def parallel_function(f):
51 51 return ParallelFunction(client, f, dist, bound, block, targets, balanced)
52 52 return parallel_function
53 53
54 54 #--------------------------------------------------------------------------
55 55 # Classes
56 56 #--------------------------------------------------------------------------
57 57
58 58 class RemoteFunction(object):
59 59 """Turn an existing function into a remote function.
60 60
61 61 Parameters
62 62 ----------
63 63
64 64 client : Client instance
65 65 The client to be used to connect to engines
66 66 f : callable
67 67 The function to be wrapped into a remote function
68 68 bound : bool [default: False]
69 69 Whether to affect the remote namespace when called
70 70 block : bool [default: None]
71 71 Whether to wait for results or not. The default behavior is
72 72 to use the current `block` attribute of `client`
73 73 targets : valid target list [default: all]
74 74 The targets on which to execute.
75 75 balanced : bool
76 76 Whether to load-balance with the Task scheduler or not
77 77 """
78 78
79 79 client = None # the remote connection
80 80 func = None # the wrapped function
81 81 block = None # whether to block
82 82 bound = None # whether to affect the namespace
83 83 targets = None # where to execute
84 84 balanced = None # whether to load-balance
85 85
86 86 def __init__(self, client, f, bound=False, block=None, targets=None, balanced=None):
87 87 self.client = client
88 88 self.func = f
89 89 self.block=block
90 90 self.bound=bound
91 91 self.targets=targets
92 92 if balanced is None:
93 93 if targets is None:
94 94 balanced = True
95 95 else:
96 96 balanced = False
97 97 self.balanced = balanced
98 98
99 99 def __call__(self, *args, **kwargs):
100 100 return self.client.apply(self.func, args=args, kwargs=kwargs,
101 101 block=self.block, targets=self.targets, bound=self.bound, balanced=self.balanced)
102 102
103 103
104 104 class ParallelFunction(RemoteFunction):
105 """Class for mapping a function to sequences."""
105 """Class for mapping a function to sequences.
106
107 This will distribute the sequences according to a mapper, and call
108 the function on each sub-sequence. If called via map, then the function
109 will be called once on each element, rather than once on each sub-sequence.
110
111 Parameters
112 ----------
113
114 client : Client instance
115 The client to be used to connect to engines
116 f : callable
117 The function to be wrapped into a remote function
118 bound : bool [default: False]
119 Whether to affect the remote namespace when called
120 block : bool [default: None]
121 Whether to wait for results or not. The default behavior is
122 to use the current `block` attribute of `client`
123 targets : valid target list [default: all]
124 The targets on which to execute.
125 balanced : bool
126 Whether to load-balance with the Task scheduler or not
127 chunk_size : int or None
128 The size of chunk to use when breaking up sequences in a load-balanced manner
129 """
106 130 def __init__(self, client, f, dist='b', bound=False, block=None, targets='all', balanced=None, chunk_size=None):
107 131 super(ParallelFunction, self).__init__(client,f,bound,block,targets,balanced)
108 132 self.chunk_size = chunk_size
109 133
110 134 mapClass = Map.dists[dist]
111 135 self.mapObject = mapClass()
112 136
113 137 def __call__(self, *sequences):
114 138 len_0 = len(sequences[0])
115 139 for s in sequences:
116 140 if len(s)!=len_0:
117 141 msg = 'all sequences must have equal length, but %i!=%i'%(len_0,len(s))
118 142 raise ValueError(msg)
119 143
120 144 if self.balanced:
121 145 if self.chunk_size:
122 146 nparts = len_0/self.chunk_size + int(len_0%self.chunk_size > 0)
123 147 else:
124 148 nparts = len_0
125 149 targets = [self.targets]*nparts
126 150 else:
127 151 if self.chunk_size:
128 152 warnings.warn("`chunk_size` is ignored when `balanced=False`", UserWarning)
129 153 # multiplexed:
130 154 targets = self.client._build_targets(self.targets)[-1]
131 155 nparts = len(targets)
132 156
133 157 msg_ids = []
134 158 # my_f = lambda *a: map(self.func, *a)
135 159 for index, t in enumerate(targets):
136 160 args = []
137 161 for seq in sequences:
138 162 part = self.mapObject.getPartition(seq, index, nparts)
139 163 if len(part) == 0:
140 164 continue
141 165 else:
142 166 args.append(part)
143 167 if not args:
144 168 continue
145 169
146 170 # print (args)
147 171 if hasattr(self, '_map'):
148 172 f = map
149 173 args = [self.func]+args
150 174 else:
151 175 f=self.func
152 176 ar = self.client.apply(f, args=args, block=False, bound=self.bound,
153 177 targets=t, balanced=self.balanced)
154 178
155 179 msg_ids.append(ar.msg_ids[0])
156 180
157 181 r = AsyncMapResult(self.client, msg_ids, self.mapObject, fname=self.func.__name__)
158 182 if self.block:
159 183 try:
160 184 return r.get()
161 185 except KeyboardInterrupt:
162 186 return r
163 187 else:
164 188 return r
165 189
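The `nparts` computation in the balanced branch above is just a ceiling division; a tiny worked sketch (`//` written explicitly, matching the Python 2 integer `/` used above):

    # 10 elements with chunk_size=3 -> 3 full chunks + 1 remainder chunk = 4 parts
    len_0, chunk_size = 10, 3
    nparts = len_0 // chunk_size + int(len_0 % chunk_size > 0)
    assert nparts == 4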
166 190 def map(self, *sequences):
167 """call a function on each element of a sequence remotely."""
191 """call a function on each element of a sequence remotely.
192 This should behave very much like the builtin map, but return an AsyncMapResult
193 if self.block is False.
194 """
195 # set _map as a flag for use inside self.__call__
168 196 self._map = True
169 197 try:
170 198 ret = self.__call__(*sequences)
171 199 finally:
172 200 del self._map
173 201 return ret
174 202
203 __all__ = ['remote', 'parallel', 'RemoteFunction', 'ParallelFunction'] No newline at end of file
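A hedged usage sketch of the two decorators (assumes a running cluster and an already-connected `client`; return shapes depend on `targets`/`balanced`, so the comments are indicative only):

    # hypothetical session
    @remote(client, block=True)
    def getpid():
        import os
        return os.getpid()

    pids = getpid()                  # executed remotely via client.apply

    @parallel(client, block=True)
    def double(x):
        return 2*x

    doubled = double.map(range(8))   # like builtin map, scattered across engines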
@@ -1,581 +1,580 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6 6 """
7 7
8 8 #----------------------------------------------------------------------
9 9 # Imports
10 10 #----------------------------------------------------------------------
11 11
12 12 from __future__ import print_function
13 13
14 14 import logging
15 15 import sys
16 16
17 17 from datetime import datetime, timedelta
18 18 from random import randint, random
19 19 from types import FunctionType
20 20
21 21 try:
22 22 import numpy
23 23 except ImportError:
24 24 numpy = None
25 25
26 26 import zmq
27 27 from zmq.eventloop import ioloop, zmqstream
28 28
29 29 # local imports
30 30 from IPython.external.decorator import decorator
31 31 from IPython.utils.traitlets import Instance, Dict, List, Set
32 32
33 33 from . import error
34 from . import streamsession as ss
35 34 from .dependency import Dependency
36 35 from .entry_point import connect_logger, local_logger
37 36 from .factory import SessionFactory
38 37
39 38
40 39 @decorator
41 40 def logged(f,self,*args,**kwargs):
42 41 # print ("#--------------------")
43 42 self.log.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
44 43 # print ("#--")
45 44 return f(self,*args, **kwargs)
46 45
47 46 #----------------------------------------------------------------------
48 47 # Chooser functions
49 48 #----------------------------------------------------------------------
50 49
51 50 def plainrandom(loads):
52 51 """Plain random pick."""
53 52 n = len(loads)
54 53 return randint(0,n-1)
55 54
56 55 def lru(loads):
57 56 """Always pick the front of the line.
58 57
59 58 The content of `loads` is ignored.
60 59
61 60 Assumes LRU ordering of loads, with oldest first.
62 61 """
63 62 return 0
64 63
65 64 def twobin(loads):
66 65 """Pick two at random, use the LRU of the two.
67 66
68 67 The content of loads is ignored.
69 68
70 69 Assumes LRU ordering of loads, with oldest first.
71 70 """
72 71 n = len(loads)
73 72 a = randint(0,n-1)
74 73 b = randint(0,n-1)
75 74 return min(a,b)
76 75
77 76 def weighted(loads):
78 77 """Pick two at random using inverse load as weight.
79 78
80 79 Return the less loaded of the two.
81 80 """
82 81 # weight 0 a million times more than 1:
83 82 weights = 1./(1e-6+numpy.array(loads))
84 83 sums = weights.cumsum()
85 84 t = sums[-1]
86 85 x = random()*t
87 86 y = random()*t
88 87 idx = 0
89 88 idy = 0
90 89 while sums[idx] < x:
91 90 idx += 1
92 91 while sums[idy] < y:
93 92 idy += 1
94 93 if weights[idy] > weights[idx]:
95 94 return idy
96 95 else:
97 96 return idx
98 97
99 98 def leastload(loads):
100 99 """Always choose the lowest load.
101 100
102 101 If the lowest load occurs more than once, the first
103 102 occurrence will be used. If loads has LRU ordering, this means
104 103 the LRU of those with the lowest load is chosen.
105 104 """
106 105 return loads.index(min(loads))
107 106
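
A quick illustration of the chooser contract: each scheme maps a list of
engine loads (LRU-ordered, oldest first) to the index of the chosen engine.
Sample values below; the randomized schemes vary between calls:

    loads = [3, 0, 5, 1]    # outstanding jobs per engine
    leastload(loads)        # -> 1, the lowest load
    lru(loads)              # -> 0, always the front of the line
    twobin(loads)           # LRU of two random picks
    weighted(loads)         # inverse-load weighted pick (requires numpy)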
108 107 #---------------------------------------------------------------------
109 108 # Classes
110 109 #---------------------------------------------------------------------
111 110 # store empty default dependency:
112 111 MET = Dependency([])
113 112
114 113 class TaskScheduler(SessionFactory):
115 114 """Python TaskScheduler object.
116 115
117 116 This is the simplest object that supports msg_id based
118 117 DAG dependencies. *Only* task msg_ids are checked, not
119 118 msg_ids of jobs submitted via the MUX queue.
120 119
121 120 """
122 121
123 122 # input arguments:
124 123 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
125 124 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
126 125 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
127 126 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
128 127 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
129 128
130 129 # internals:
131 130 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
132 131 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
133 132 pending = Dict() # dict by engine_uuid of submitted tasks
134 133 completed = Dict() # dict by engine_uuid of completed tasks
135 134 failed = Dict() # dict by engine_uuid of failed tasks
136 135 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
137 136 clients = Dict() # dict by msg_id for who submitted the task
138 137 targets = List() # list of target IDENTs
139 138 loads = List() # list of engine loads
140 139 all_completed = Set() # set of all completed tasks
141 140 all_failed = Set() # set of all failed tasks
142 141 all_done = Set() # set of all finished tasks=union(completed,failed)
143 142 all_ids = Set() # set of all submitted task IDs
144 143 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
145 144 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
146 145
147 146
148 147 def start(self):
149 148 self.engine_stream.on_recv(self.dispatch_result, copy=False)
150 149 self._notification_handlers = dict(
151 150 registration_notification = self._register_engine,
152 151 unregistration_notification = self._unregister_engine
153 152 )
154 153 self.notifier_stream.on_recv(self.dispatch_notification)
155 154 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # every 2 seconds
156 155 self.auditor.start()
157 156 self.log.info("Scheduler started...%r"%self)
158 157
159 158 def resume_receiving(self):
160 159 """Resume accepting jobs."""
161 160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
162 161
163 162 def stop_receiving(self):
164 163 """Stop accepting jobs while there are no engines.
165 164 Leave them in the ZMQ queue."""
166 165 self.client_stream.on_recv(None)
167 166
168 167 #-----------------------------------------------------------------------
169 168 # [Un]Registration Handling
170 169 #-----------------------------------------------------------------------
171 170
172 171 def dispatch_notification(self, msg):
173 172 """dispatch register/unregister events."""
174 173 idents,msg = self.session.feed_identities(msg)
175 174 msg = self.session.unpack_message(msg)
176 175 msg_type = msg['msg_type']
177 176 handler = self._notification_handlers.get(msg_type, None)
178 177 if handler is None:
179 178 raise Exception("Unhandled message type: %s"%msg_type)
180 179 else:
181 180 try:
182 181 handler(str(msg['content']['queue']))
183 182 except KeyError:
184 183 self.log.error("task::Invalid notification msg: %s"%msg)
185 184
186 185 @logged
187 186 def _register_engine(self, uid):
188 187 """New engine with ident `uid` became available."""
189 188 # head of the line:
190 189 self.targets.insert(0,uid)
191 190 self.loads.insert(0,0)
192 191 # initialize sets
193 192 self.completed[uid] = set()
194 193 self.failed[uid] = set()
195 194 self.pending[uid] = {}
196 195 if len(self.targets) == 1:
197 196 self.resume_receiving()
198 197
199 198 def _unregister_engine(self, uid):
200 199 """Existing engine with ident `uid` became unavailable."""
201 200 if len(self.targets) == 1:
202 201 # this was our only engine
203 202 self.stop_receiving()
204 203
205 204 # handle any potentially finished tasks:
206 205 self.engine_stream.flush()
207 206
208 207 self.completed.pop(uid)
209 208 self.failed.pop(uid)
210 209 # don't pop destinations, because it might be used later
211 210 # map(self.destinations.pop, self.completed.pop(uid))
212 211 # map(self.destinations.pop, self.failed.pop(uid))
213 212
214 213 idx = self.targets.index(uid)
215 214 self.targets.pop(idx)
216 215 self.loads.pop(idx)
217 216
218 217 # wait 5 seconds before cleaning up pending jobs, since the results might
219 218 # still be incoming
220 219 if self.pending[uid]:
221 220 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
222 221 dc.start()
223 222
224 223 @logged
225 224 def handle_stranded_tasks(self, engine):
226 225 """Deal with jobs resident in an engine that died."""
227 226 lost = self.pending.pop(engine)
228 227
229 228 for msg_id, (raw_msg, targets, MET, follow, timeout) in lost.iteritems():
230 229 self.all_failed.add(msg_id)
231 230 self.all_done.add(msg_id)
232 231 idents,msg = self.session.feed_identities(raw_msg, copy=False)
233 232 msg = self.session.unpack_message(msg, copy=False, content=False)
234 233 parent = msg['header']
235 234 idents = [idents[0],engine]+idents[1:]
236 235 print (idents)
237 236 try:
238 237 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
239 238 except:
240 content = ss.wrap_exception()
239 content = error.wrap_exception()
241 240 msg = self.session.send(self.client_stream, 'apply_reply', content,
242 241 parent=parent, ident=idents)
243 242 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
244 243 self.update_graph(msg_id)
245 244
246 245
247 246 #-----------------------------------------------------------------------
248 247 # Job Submission
249 248 #-----------------------------------------------------------------------
250 249 @logged
251 250 def dispatch_submission(self, raw_msg):
252 251 """Dispatch job submission to appropriate handlers."""
253 252 # ensure targets up to date:
254 253 self.notifier_stream.flush()
255 254 try:
256 255 idents, msg = self.session.feed_identities(raw_msg, copy=False)
257 256 msg = self.session.unpack_message(msg, content=False, copy=False)
258 257 except:
259 258 self.log.error("task::Invaid task: %s"%raw_msg, exc_info=True)
260 259 return
261 260
262 261 # send to monitor
263 262 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
264 263
265 264 header = msg['header']
266 265 msg_id = header['msg_id']
267 266 self.all_ids.add(msg_id)
268 267
269 268 # targets
270 269 targets = set(header.get('targets', []))
271 270
272 271 # time dependencies
273 272 after = Dependency(header.get('after', []))
274 273 if after.all:
275 274 after.difference_update(self.all_completed)
276 275 if not after.success_only:
277 276 after.difference_update(self.all_failed)
278 277 if after.check(self.all_completed, self.all_failed):
279 278 # recast as empty set, if `after` already met,
280 279 # to prevent unnecessary set comparisons
281 280 after = MET
282 281
283 282 # location dependencies
284 283 follow = Dependency(header.get('follow', []))
285 284
286 285 # turn timeouts into datetime objects:
287 286 timeout = header.get('timeout', None)
288 287 if timeout:
289 288 timeout = datetime.now() + timedelta(0,timeout,0)
290 289
291 290 args = [raw_msg, targets, after, follow, timeout]
292 291
293 292 # validate and reduce dependencies:
294 293 for dep in after,follow:
295 294 # check valid:
296 295 if msg_id in dep or dep.difference(self.all_ids):
297 296 self.depending[msg_id] = args
298 297 return self.fail_unreachable(msg_id, error.InvalidDependency)
299 298 # check if unreachable:
300 299 if dep.unreachable(self.all_failed):
301 300 self.depending[msg_id] = args
302 301 return self.fail_unreachable(msg_id)
303 302
304 303 if after.check(self.all_completed, self.all_failed):
305 304 # time deps already met, try to run
306 305 if not self.maybe_run(msg_id, *args):
307 306 # can't run yet
308 307 self.save_unmet(msg_id, *args)
309 308 else:
310 309 self.save_unmet(msg_id, *args)
311 310
312 311 # @logged
313 312 def audit_timeouts(self):
314 313 """Audit all waiting tasks for expired timeouts."""
315 314 now = datetime.now()
316 315 for msg_id in self.depending.keys():
317 316 # must recheck, in case one failure cascaded to another:
318 317 if msg_id in self.depending:
319 318 raw_msg,targets,after,follow,timeout = self.depending[msg_id]
320 319 if timeout and timeout < now:
321 320 self.fail_unreachable(msg_id, timeout=True)
322 321
323 322 @logged
324 323 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
325 324 """a task has become unreachable, send a reply with an ImpossibleDependency
326 325 error."""
327 326 if msg_id not in self.depending:
328 327 self.log.error("msg %r already failed!"%msg_id)
329 328 return
330 329 raw_msg,targets,after,follow,timeout = self.depending.pop(msg_id)
331 330 for mid in follow.union(after):
332 331 if mid in self.graph:
333 332 self.graph[mid].remove(msg_id)
334 333
335 334 # FIXME: unpacking a message I've already unpacked, but didn't save:
336 335 idents,msg = self.session.feed_identities(raw_msg, copy=False)
337 336 msg = self.session.unpack_message(msg, copy=False, content=False)
338 337 header = msg['header']
339 338
340 339 try:
341 340 raise why()
342 341 except:
343 content = ss.wrap_exception()
342 content = error.wrap_exception()
344 343
345 344 self.all_done.add(msg_id)
346 345 self.all_failed.add(msg_id)
347 346
348 347 msg = self.session.send(self.client_stream, 'apply_reply', content,
349 348 parent=header, ident=idents)
350 349 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
351 350
352 351 self.update_graph(msg_id, success=False)
353 352
354 353 @logged
355 354 def maybe_run(self, msg_id, raw_msg, targets, after, follow, timeout):
356 355 """check location dependencies, and run if they are met."""
357 356 blacklist = self.blacklist.setdefault(msg_id, set())
358 357 if follow or targets or blacklist:
359 358 # we need a can_run filter
360 359 def can_run(idx):
361 360 target = self.targets[idx]
362 361 # check targets
363 362 if targets and target not in targets:
364 363 return False
365 364 # check blacklist
366 365 if target in blacklist:
367 366 return False
368 367 # check follow
369 368 return follow.check(self.completed[target], self.failed[target])
370 369
371 370 indices = filter(can_run, range(len(self.targets)))
372 371 if not indices:
373 372 # couldn't run
374 373 if follow.all:
375 374 # check follow for impossibility
376 375 dests = set()
377 376 relevant = self.all_completed if follow.success_only else self.all_done
378 377 for m in follow.intersection(relevant):
379 378 dests.add(self.destinations[m])
380 379 if len(dests) > 1:
381 380 self.fail_unreachable(msg_id)
382 381 return False
383 382 if targets:
384 383 # check blacklist+targets for impossibility
385 384 targets.difference_update(blacklist)
386 385 if not targets or not targets.intersection(self.targets):
387 386 self.fail_unreachable(msg_id)
388 387 return False
389 388 return False
390 389 else:
391 390 indices = None
392 391
393 392 self.submit_task(msg_id, raw_msg, targets, follow, timeout, indices)
394 393 return True
395 394
396 395 @logged
397 396 def save_unmet(self, msg_id, raw_msg, targets, after, follow, timeout):
398 397 """Save a message for later submission when its dependencies are met."""
399 398 self.depending[msg_id] = [raw_msg,targets,after,follow,timeout]
400 399 # track the ids in follow or after, but not those already finished
401 400 for dep_id in after.union(follow).difference(self.all_done):
402 401 if dep_id not in self.graph:
403 402 self.graph[dep_id] = set()
404 403 self.graph[dep_id].add(msg_id)
405 404
406 405 @logged
407 406 def submit_task(self, msg_id, raw_msg, targets, follow, timeout, indices=None):
408 407 """Submit a task to any of a subset of our targets."""
409 408 if indices:
410 409 loads = [self.loads[i] for i in indices]
411 410 else:
412 411 loads = self.loads
413 412 idx = self.scheme(loads)
414 413 if indices:
415 414 idx = indices[idx]
416 415 target = self.targets[idx]
417 416 # print (target, map(str, msg[:3]))
418 417 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
419 418 self.engine_stream.send_multipart(raw_msg, copy=False)
420 419 self.add_job(idx)
421 420 self.pending[target][msg_id] = (raw_msg, targets, MET, follow, timeout)
422 421 content = dict(msg_id=msg_id, engine_id=target)
423 422 self.session.send(self.mon_stream, 'task_destination', content=content,
424 423 ident=['tracktask',self.session.session])
425 424
426 425 #-----------------------------------------------------------------------
427 426 # Result Handling
428 427 #-----------------------------------------------------------------------
429 428 @logged
430 429 def dispatch_result(self, raw_msg):
431 430 """dispatch method for result replies"""
432 431 try:
433 432 idents,msg = self.session.feed_identities(raw_msg, copy=False)
434 433 msg = self.session.unpack_message(msg, content=False, copy=False)
435 434 except:
436 435 self.log.error("task::Invaid result: %s"%raw_msg, exc_info=True)
437 436 return
438 437
439 438 header = msg['header']
440 439 if header.get('dependencies_met', True):
441 440 success = (header['status'] == 'ok')
442 441 self.handle_result(idents, msg['parent_header'], raw_msg, success)
443 442 # send to Hub monitor
444 443 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
445 444 else:
446 445 self.handle_unmet_dependency(idents, msg['parent_header'])
447 446
448 447 @logged
449 448 def handle_result(self, idents, parent, raw_msg, success=True):
450 449 """handle a real task result, either success or failure"""
451 450 # first, relay result to client
452 451 engine = idents[0]
453 452 client = idents[1]
454 453 # swap_ids for XREP-XREP mirror
455 454 raw_msg[:2] = [client,engine]
456 455 # print (map(str, raw_msg[:4]))
457 456 self.client_stream.send_multipart(raw_msg, copy=False)
458 457 # now, update our data structures
459 458 msg_id = parent['msg_id']
460 459 self.blacklist.pop(msg_id, None)
461 460 self.pending[engine].pop(msg_id)
462 461 if success:
463 462 self.completed[engine].add(msg_id)
464 463 self.all_completed.add(msg_id)
465 464 else:
466 465 self.failed[engine].add(msg_id)
467 466 self.all_failed.add(msg_id)
468 467 self.all_done.add(msg_id)
469 468 self.destinations[msg_id] = engine
470 469
471 470 self.update_graph(msg_id, success)
472 471
473 472 @logged
474 473 def handle_unmet_dependency(self, idents, parent):
475 474 """handle an unmet dependency"""
476 475 engine = idents[0]
477 476 msg_id = parent['msg_id']
478 477
479 478 if msg_id not in self.blacklist:
480 479 self.blacklist[msg_id] = set()
481 480 self.blacklist[msg_id].add(engine)
482 481
483 482 args = self.pending[engine].pop(msg_id)
484 483 raw,targets,after,follow,timeout = args
485 484
486 485 if self.blacklist[msg_id] == targets:
487 486 self.depending[msg_id] = args
488 487 return self.fail_unreachable(msg_id)
489 488
490 489 elif not self.maybe_run(msg_id, *args):
491 490 # resubmit failed, put it back in our dependency tree
492 491 self.save_unmet(msg_id, *args)
493 492
494 493
495 494 @logged
496 495 def update_graph(self, dep_id, success=True):
497 496 """dep_id just finished. Update our dependency
498 497 graph and submit any jobs that just became runnable."""
499 498 # print ("\n\n***********")
500 499 # pprint (dep_id)
501 500 # pprint (self.graph)
502 501 # pprint (self.depending)
503 502 # pprint (self.all_completed)
504 503 # pprint (self.all_failed)
505 504 # print ("\n\n***********\n\n")
506 505 if dep_id not in self.graph:
507 506 return
508 507 jobs = self.graph.pop(dep_id)
509 508
510 509 for msg_id in jobs:
511 510 raw_msg, targets, after, follow, timeout = self.depending[msg_id]
512 511 # if dep_id in after:
513 512 # if after.all and (success or not after.success_only):
514 513 # after.remove(dep_id)
515 514
516 515 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
517 516 self.fail_unreachable(msg_id)
518 517
519 518 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
520 519 if self.maybe_run(msg_id, raw_msg, targets, MET, follow, timeout):
521 520
522 521 self.depending.pop(msg_id)
523 522 for mid in follow.union(after):
524 523 if mid in self.graph:
525 524 self.graph[mid].remove(msg_id)
526 525
527 526 #----------------------------------------------------------------------
528 527 # methods to be overridden by subclasses
529 528 #----------------------------------------------------------------------
530 529
531 530 def add_job(self, idx):
532 531 """Called after self.targets[idx] just got the job with header.
533 532 Override in subclasses. The default ordering is simple LRU.
534 533 The default loads are the number of outstanding jobs."""
535 534 self.loads[idx] += 1
536 535 for lis in (self.targets, self.loads):
537 536 lis.append(lis.pop(idx))
538 537
539 538
540 539 def finish_job(self, idx):
541 540 """Called after self.targets[idx] just finished a job.
542 541 Override in subclasses."""
543 542 self.loads[idx] -= 1
544 543
545 544
546 545
547 546 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,logname='ZMQ',
548 547 log_addr=None, loglevel=logging.DEBUG, scheme='lru'):
549 548 from zmq.eventloop import ioloop
550 549 from zmq.eventloop.zmqstream import ZMQStream
551 550
552 551 ctx = zmq.Context()
553 552 loop = ioloop.IOLoop()
554 553 print (in_addr, out_addr, mon_addr, not_addr)
555 554 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
556 555 ins.bind(in_addr)
557 556 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
558 557 outs.bind(out_addr)
559 558 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
560 559 mons.connect(mon_addr)
561 560 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
562 561 nots.setsockopt(zmq.SUBSCRIBE, '')
563 562 nots.connect(not_addr)
564 563
565 564 scheme = globals().get(scheme, None)
566 565 # setup logging
567 566 if log_addr:
568 567 connect_logger(logname, ctx, log_addr, root="scheduler", loglevel=loglevel)
569 568 else:
570 569 local_logger(logname, loglevel)
571 570
572 571 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
573 572 mon_stream=mons, notifier_stream=nots,
574 573 scheme=scheme, loop=loop, logname=logname,
575 574 config=config)
576 575 scheduler.start()
577 576 try:
578 577 loop.start()
579 578 except KeyboardInterrupt:
580 579 print ("interrupted, exiting...", file=sys.__stderr__)
581 580
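
For context, dispatch_submission reads all of its scheduling information out
of the task header. A submission carrying both kinds of dependency would look
roughly like this (a sketch; the real header is built by the client's apply
machinery, and the msg_ids here are placeholders):

    header = {
        'msg_id'  : '...',
        'after'   : ['msg-a', 'msg-b'],  # time deps: run once these finish
        'follow'  : ['msg-a'],           # location dep: run where msg-a ran
        'targets' : [],                  # empty -> any engine may run it
        'timeout' : 10.0,                # seconds before fail_unreachable fires
    }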
@@ -1,487 +1,484 b''
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5
6 6 #-----------------------------------------------------------------------------
7 7 # Imports
8 8 #-----------------------------------------------------------------------------
9 9
10 10 # Standard library imports.
11 11 from __future__ import print_function
12 import __builtin__
13 12
14 import logging
15 import os
16 13 import sys
17 14 import time
18 import traceback
19 15
20 16 from code import CommandCompiler
21 17 from datetime import datetime
22 18 from pprint import pprint
23 19 from signal import SIGTERM, SIGKILL
24 20
25 21 # System library imports.
26 22 import zmq
27 23 from zmq.eventloop import ioloop, zmqstream
28 24
29 25 # Local imports.
30 26 from IPython.core import ultratb
31 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
27 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Str
32 28 from IPython.zmq.completer import KernelCompleter
33 29 from IPython.zmq.iostream import OutStream
34 30 from IPython.zmq.displayhook import DisplayHook
35 31
36 32 from . import heartmonitor
37 33 from .client import Client
34 from .error import wrap_exception
38 35 from .factory import SessionFactory
39 from .streamsession import StreamSession, Message, extract_header, serialize_object,\
40 unpack_apply_message, ISO8601, wrap_exception
36 from .streamsession import StreamSession
37 from .util import serialize_object, unpack_apply_message, ISO8601
41 38
42 39 def printer(*args):
43 40 pprint(args, stream=sys.__stdout__)
44 41
45 42
46 43 class _Passer:
47 44 """Empty class that implements `send()` that does nothing."""
48 45 def send(self, *args, **kwargs):
49 46 pass
50 47 send_multipart = send
51 48
52 49
53 50 #-----------------------------------------------------------------------------
54 51 # Main kernel class
55 52 #-----------------------------------------------------------------------------
56 53
57 54 class Kernel(SessionFactory):
58 55
59 56 #---------------------------------------------------------------------------
60 57 # Kernel interface
61 58 #---------------------------------------------------------------------------
62 59
63 60 # kwargs:
64 61 int_id = Int(-1, config=True)
65 62 user_ns = Dict(config=True)
66 63 exec_lines = List(config=True)
67 64
68 65 control_stream = Instance(zmqstream.ZMQStream)
69 66 task_stream = Instance(zmqstream.ZMQStream)
70 67 iopub_stream = Instance(zmqstream.ZMQStream)
71 68 client = Instance('IPython.zmq.parallel.client.Client')
72 69
73 70 # internals
74 71 shell_streams = List()
75 72 compiler = Instance(CommandCompiler, (), {})
76 73 completer = Instance(KernelCompleter)
77 74
78 75 aborted = Set()
79 76 shell_handlers = Dict()
80 77 control_handlers = Dict()
81 78
82 79 def _set_prefix(self):
83 80 self.prefix = "engine.%s"%self.int_id
84 81
85 82 def _connect_completer(self):
86 83 self.completer = KernelCompleter(self.user_ns)
87 84
88 85 def __init__(self, **kwargs):
89 86 super(Kernel, self).__init__(**kwargs)
90 87 self._set_prefix()
91 88 self._connect_completer()
92 89
93 90 self.on_trait_change(self._set_prefix, 'id')
94 91 self.on_trait_change(self._connect_completer, 'user_ns')
95 92
96 93 # Build dict of handlers for message types
97 94 for msg_type in ['execute_request', 'complete_request', 'apply_request',
98 95 'clear_request']:
99 96 self.shell_handlers[msg_type] = getattr(self, msg_type)
100 97
101 98 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
102 99 self.control_handlers[msg_type] = getattr(self, msg_type)
103 100
104 101 self._initial_exec_lines()
105 102
106 103 def _wrap_exception(self, method=None):
107 104 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
108 105 content=wrap_exception(e_info)
109 106 return content
110 107
111 108 def _initial_exec_lines(self):
112 109 s = _Passer()
113 110 content = dict(silent=True, user_variables=[], user_expressions=[])
114 111 for line in self.exec_lines:
115 112 self.log.debug("executing initialization: %s"%line)
116 113 content.update({'code':line})
117 114 msg = self.session.msg('execute_request', content)
118 115 self.execute_request(s, [], msg)
119 116
120 117
121 118 #-------------------- control handlers -----------------------------
122 119 def abort_queues(self):
123 120 for stream in self.shell_streams:
124 121 if stream:
125 122 self.abort_queue(stream)
126 123
127 124 def abort_queue(self, stream):
128 125 while True:
129 126 try:
130 127 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
131 128 except zmq.ZMQError as e:
132 129 if e.errno == zmq.EAGAIN:
133 130 break
134 131 else:
135 132 return
136 133 else:
137 134 if msg is None:
138 135 return
139 136 else:
140 137 idents,msg = msg
141 138
142 139 # assert self.reply_socket.rcvmore(), "Unexpected missing message part."
143 140 # msg = self.reply_socket.recv_json()
144 141 self.log.info("Aborting:")
145 142 self.log.info(str(msg))
146 143 msg_type = msg['msg_type']
147 144 reply_type = msg_type.split('_')[0] + '_reply'
148 145 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
149 146 # self.reply_socket.send(ident,zmq.SNDMORE)
150 147 # self.reply_socket.send_json(reply_msg)
151 148 reply_msg = self.session.send(stream, reply_type,
152 149 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
153 150 self.log.debug(str(reply_msg))
154 151 # We need to wait a bit for requests to come in. This can probably
155 152 # be set shorter for true asynchronous clients.
156 153 time.sleep(0.05)
157 154
158 155 def abort_request(self, stream, ident, parent):
159 156 """abort a specifig msg by id"""
160 157 msg_ids = parent['content'].get('msg_ids', None)
161 158 if isinstance(msg_ids, basestring):
162 159 msg_ids = [msg_ids]
163 160 if not msg_ids:
164 161 self.abort_queues()
165 162 for mid in msg_ids:
166 163 self.aborted.add(str(mid))
167 164
168 165 content = dict(status='ok')
169 166 reply_msg = self.session.send(stream, 'abort_reply', content=content,
170 167 parent=parent, ident=ident)[0]
171 168 self.log.debug(str(reply_msg))
172 169
173 170 def shutdown_request(self, stream, ident, parent):
174 171 """kill ourself. This should really be handled in an external process"""
175 172 try:
176 173 self.abort_queues()
177 174 except:
178 175 content = self._wrap_exception('shutdown')
179 176 else:
180 177 content = dict(parent['content'])
181 178 content['status'] = 'ok'
182 179 msg = self.session.send(stream, 'shutdown_reply',
183 180 content=content, parent=parent, ident=ident)
184 181 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
185 182 # content, parent, ident)
186 183 # print >> sys.__stdout__, msg
187 184 # time.sleep(0.2)
188 185 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
189 186 dc.start()
190 187
191 188 def dispatch_control(self, msg):
192 189 idents,msg = self.session.feed_identities(msg, copy=False)
193 190 try:
194 191 msg = self.session.unpack_message(msg, content=True, copy=False)
195 192 except:
196 193 self.log.error("Invalid Message", exc_info=True)
197 194 return
198 195
199 196 header = msg['header']
200 197 msg_id = header['msg_id']
201 198
202 199 handler = self.control_handlers.get(msg['msg_type'], None)
203 200 if handler is None:
204 201 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
205 202 else:
206 203 handler(self.control_stream, idents, msg)
207 204
208 205
209 206 #-------------------- queue helpers ------------------------------
210 207
211 208 def check_dependencies(self, dependencies):
212 209 if not dependencies:
213 210 return True
214 211 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
215 212 anyorall = dependencies[0]
216 213 dependencies = dependencies[1]
217 214 else:
218 215 anyorall = 'all'
219 216 results = self.client.get_results(dependencies,status_only=True)
220 217 if results['status'] != 'ok':
221 218 return False
222 219
223 220 if anyorall == 'any':
224 221 if not results['completed']:
225 222 return False
226 223 else:
227 224 if results['pending']:
228 225 return False
229 226
230 227 return True
231 228
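
Note that the `dependencies` argument accepted above is either a flat list of
msg_ids (implicitly 'all'), or a two-element ('any'|'all', [msg_ids]) pair;
the ids below are placeholders:

    self.check_dependencies(['id1', 'id2'])           # all of id1, id2 done
    self.check_dependencies(['any', ['id1', 'id2']])  # at least one done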
232 229 def check_aborted(self, msg_id):
233 230 return msg_id in self.aborted
234 231
235 232 #-------------------- queue handlers -----------------------------
236 233
237 234 def clear_request(self, stream, idents, parent):
238 235 """Clear our namespace."""
239 236 self.user_ns = {}
240 237 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
241 238 content = dict(status='ok'))
242 239 self._initial_exec_lines()
243 240
244 241 def execute_request(self, stream, ident, parent):
245 242 self.log.debug('execute request %s'%parent)
246 243 try:
247 244 code = parent[u'content'][u'code']
248 245 except:
249 246 self.log.error("Got bad msg: %s"%parent, exc_info=True)
250 247 return
251 248 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
252 249 ident='%s.pyin'%self.prefix)
253 250 started = datetime.now().strftime(ISO8601)
254 251 try:
255 252 comp_code = self.compiler(code, '<zmq-kernel>')
256 253 # allow for not overriding displayhook
257 254 if hasattr(sys.displayhook, 'set_parent'):
258 255 sys.displayhook.set_parent(parent)
259 256 sys.stdout.set_parent(parent)
260 257 sys.stderr.set_parent(parent)
261 258 exec comp_code in self.user_ns, self.user_ns
262 259 except:
263 260 exc_content = self._wrap_exception('execute')
264 261 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
265 262 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
266 263 ident='%s.pyerr'%self.prefix)
267 264 reply_content = exc_content
268 265 else:
269 266 reply_content = {'status' : 'ok'}
270 267
271 268 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
272 269 ident=ident, subheader = dict(started=started))
273 270 self.log.debug(str(reply_msg))
274 271 if reply_msg['content']['status'] == u'error':
275 272 self.abort_queues()
276 273
277 274 def complete_request(self, stream, ident, parent):
278 275 matches = {'matches' : self.complete(parent),
279 276 'status' : 'ok'}
280 277 completion_msg = self.session.send(stream, 'complete_reply',
281 278 matches, parent, ident)
282 279 # print >> sys.__stdout__, completion_msg
283 280
284 281 def complete(self, msg):
285 282 return self.completer.complete(msg.content.line, msg.content.text)
286 283
287 284 def apply_request(self, stream, ident, parent):
288 285 # flush previous reply, so this request won't block it
289 286 stream.flush(zmq.POLLOUT)
290 287
291 288 try:
292 289 content = parent[u'content']
293 290 bufs = parent[u'buffers']
294 291 msg_id = parent['header']['msg_id']
295 292 bound = content.get('bound', False)
296 293 except:
297 294 self.log.error("Got bad msg: %s"%parent, exc_info=True)
298 295 return
299 296 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
300 297 # self.iopub_stream.send(pyin_msg)
301 298 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
302 299 sub = {'dependencies_met' : True, 'engine' : self.ident,
303 300 'started': datetime.now().strftime(ISO8601)}
304 301 try:
305 302 # allow for not overriding displayhook
306 303 if hasattr(sys.displayhook, 'set_parent'):
307 304 sys.displayhook.set_parent(parent)
308 305 sys.stdout.set_parent(parent)
309 306 sys.stderr.set_parent(parent)
310 307 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
311 308 if bound:
312 309 working = self.user_ns
313 310 suffix = str(msg_id).replace("-","")
314 311 prefix = "_"
315 312
316 313 else:
317 314 working = dict()
318 315 suffix = prefix = "_" # prevent keyword collisions with lambda
319 316 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
320 317 # if f.fun
321 318 fname = getattr(f, '__name__', 'f')
322 319
323 320 fname = prefix+fname.strip('<>')+suffix
324 321 argname = prefix+"args"+suffix
325 322 kwargname = prefix+"kwargs"+suffix
326 323 resultname = prefix+"result"+suffix
327 324
328 325 ns = { fname : f, argname : args, kwargname : kwargs }
329 326 # print ns
330 327 working.update(ns)
331 328 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
332 329 exec code in working, working
333 330 result = working.get(resultname)
334 331 # clear the namespace
335 332 if bound:
336 333 for key in ns.iterkeys():
337 334 self.user_ns.pop(key)
338 335 else:
339 336 del working
340 337
341 338 packed_result,buf = serialize_object(result)
342 339 result_buf = [packed_result]+buf
343 340 except:
344 341 exc_content = self._wrap_exception('apply')
345 342 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
346 343 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
347 344 ident='%s.pyerr'%self.prefix)
348 345 reply_content = exc_content
349 346 result_buf = []
350 347
351 348 if exc_content['ename'] == 'UnmetDependency':
352 349 sub['dependencies_met'] = False
353 350 else:
354 351 reply_content = {'status' : 'ok'}
355 352
356 353 # put 'ok'/'error' status in header, for scheduler introspection:
357 354 sub['status'] = reply_content['status']
358 355
359 356 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
360 357 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
361 358
362 359 # if reply_msg['content']['status'] == u'error':
363 360 # self.abort_queues()
364 361
365 362 def dispatch_queue(self, stream, msg):
366 363 self.control_stream.flush()
367 364 idents,msg = self.session.feed_identities(msg, copy=False)
368 365 try:
369 366 msg = self.session.unpack_message(msg, content=True, copy=False)
370 367 except:
371 368 self.log.error("Invalid Message", exc_info=True)
372 369 return
373 370
374 371
375 372 header = msg['header']
376 373 msg_id = header['msg_id']
377 374 if self.check_aborted(msg_id):
378 375 self.aborted.remove(msg_id)
379 376 # is it safe to assume a msg_id will not be resubmitted?
380 377 reply_type = msg['msg_type'].split('_')[0] + '_reply'
381 378 reply_msg = self.session.send(stream, reply_type,
382 379 content={'status' : 'aborted'}, parent=msg, ident=idents)
383 380 return
384 381 handler = self.shell_handlers.get(msg['msg_type'], None)
385 382 if handler is None:
386 383 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
387 384 else:
388 385 handler(stream, idents, msg)
389 386
390 387 def start(self):
391 388 #### stream mode:
392 389 if self.control_stream:
393 390 self.control_stream.on_recv(self.dispatch_control, copy=False)
394 391 self.control_stream.on_err(printer)
395 392
396 393 def make_dispatcher(stream):
397 394 def dispatcher(msg):
398 395 return self.dispatch_queue(stream, msg)
399 396 return dispatcher
400 397
401 398 for s in self.shell_streams:
402 399 s.on_recv(make_dispatcher(s), copy=False)
403 400 s.on_err(printer)
404 401
405 402 if self.iopub_stream:
406 403 self.iopub_stream.on_err(printer)
407 404
408 405 #### while True mode:
409 406 # while True:
410 407 # idle = True
411 408 # try:
412 409 # msg = self.shell_stream.socket.recv_multipart(
413 410 # zmq.NOBLOCK, copy=False)
414 411 # except zmq.ZMQError, e:
415 412 # if e.errno != zmq.EAGAIN:
416 413 # raise e
417 414 # else:
418 415 # idle=False
419 416 # self.dispatch_queue(self.shell_stream, msg)
420 417 #
421 418 # if not self.task_stream.empty():
422 419 # idle=False
423 420 # msg = self.task_stream.recv_multipart()
424 421 # self.dispatch_queue(self.task_stream, msg)
425 422 # if idle:
426 423 # # don't busywait
427 424 # time.sleep(1e-3)
428 425
429 426 def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
430 427 client_addr=None, loop=None, context=None, key=None,
431 428 out_stream_factory=OutStream, display_hook_factory=DisplayHook):
432 429 """NO LONGER IN USE"""
433 430 # create loop, context, and session:
434 431 if loop is None:
435 432 loop = ioloop.IOLoop.instance()
436 433 if context is None:
437 434 context = zmq.Context()
438 435 c = context
439 436 session = StreamSession(key=key)
440 437 # print (session.key)
441 438 # print (control_addr, shell_addrs, iopub_addr, hb_addrs)
442 439
443 440 # create Control Stream
444 441 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
445 442 control_stream.setsockopt(zmq.IDENTITY, identity)
446 443 control_stream.connect(control_addr)
447 444
448 445 # create Shell Streams (MUX, Task, etc.):
449 446 shell_streams = []
450 447 for addr in shell_addrs:
451 448 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
452 449 stream.setsockopt(zmq.IDENTITY, identity)
453 450 stream.connect(addr)
454 451 shell_streams.append(stream)
455 452
456 453 # create iopub stream:
457 454 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
458 455 iopub_stream.setsockopt(zmq.IDENTITY, identity)
459 456 iopub_stream.connect(iopub_addr)
460 457
461 458 # Redirect input streams and set a display hook.
462 459 if out_stream_factory:
463 460 sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
464 461 sys.stdout.topic = 'engine.%i.stdout'%int_id
465 462 sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
466 463 sys.stderr.topic = 'engine.%i.stderr'%int_id
467 464 if display_hook_factory:
468 465 sys.displayhook = display_hook_factory(session, iopub_stream)
469 466 sys.displayhook.topic = 'engine.%i.pyout'%int_id
470 467
471 468
472 469 # launch heartbeat
473 470 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
474 471 heart.start()
475 472
476 473 # create (optional) Client
477 474 if client_addr:
478 475 client = Client(client_addr, username=identity)
479 476 else:
480 477 client = None
481 478
482 479 kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
483 480 shell_streams=shell_streams, iopub_stream=iopub_stream,
484 481 client=client, loop=loop)
485 482 kernel.start()
486 483 return loop, c, kernel
487 484
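
The namespace-isolation trick in apply_request reduces to the following
standalone pattern (a sketch, not the kernel code itself; Python 2 exec
statement, with names mangled by the msg_id so user code cannot collide):

    def isolated_apply(f, args, kwargs, msg_id):
        suffix = str(msg_id).replace("-", "")
        fname, argname, kwargname, resultname = [
            "_" + n + suffix for n in ("f", "args", "kwargs", "result")]
        working = {fname: f, argname: args, kwargname: kwargs}
        # executes e.g. "_resultabc123 = _fabc123(*_argsabc123, **_kwargsabc123)"
        exec "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname) in working
        return working[resultname]

    isolated_apply(pow, (2, 10), {}, 'abc-123')   # -> 1024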
@@ -1,542 +1,377 b''
1 1 #!/usr/bin/env python
2 2 """edited session.py to work with streams, and move msg_type to the header
3 3 """
4 4
5 5
6 6 import os
7 7 import pprint
8 import sys
9 import traceback
10 8 import uuid
11 9 from datetime import datetime
12 10
13 11 try:
14 12 import cPickle
15 13 pickle = cPickle
16 14 except:
17 15 cPickle = None
18 16 import pickle
19 17
20 18 import zmq
21 19 from zmq.utils import jsonapi
22 20 from zmq.eventloop.zmqstream import ZMQStream
23 21
24 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
25 from IPython.utils.newserialized import serialize, unserialize
26
27 from .error import RemoteError
22 from .util import ISO8601
28 23
29 24 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
30 25 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
31 26 if json_name in ('jsonlib', 'jsonlib2'):
32 27 use_json = True
33 28 elif json_name:
34 29 if cPickle is None:
35 30 use_json = True
36 31 else:
37 32 use_json = False
38 33 else:
39 34 use_json = False
40 35
41 36 def squash_unicode(obj):
42 37 if isinstance(obj,dict):
43 38 for key in obj.keys():
44 39 obj[key] = squash_unicode(obj[key])
45 40 if isinstance(key, unicode):
46 41 obj[squash_unicode(key)] = obj.pop(key)
47 42 elif isinstance(obj, list):
48 43 for i,v in enumerate(obj):
49 44 obj[i] = squash_unicode(v)
50 45 elif isinstance(obj, unicode):
51 46 obj = obj.encode('utf8')
52 47 return obj
53 48
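
A one-line illustration of the recursive byte-ification performed above:

    squash_unicode({u'a': [u'b', {u'c': u'd'}]})   # -> {'a': ['b', {'c': 'd'}]}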
54 49 json_packer = jsonapi.dumps
55 50 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
56 51
57 52 pickle_packer = lambda o: pickle.dumps(o,-1)
58 53 pickle_unpacker = pickle.loads
59 54
60 55 if use_json:
61 56 default_packer = json_packer
62 57 default_unpacker = json_unpacker
63 58 else:
64 59 default_packer = pickle_packer
65 60 default_unpacker = pickle_unpacker
66 61
67 62
68 63 DELIM="<IDS|MSG>"
69 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
70
71 def wrap_exception(engine_info={}):
72 etype, evalue, tb = sys.exc_info()
73 stb = traceback.format_exception(etype, evalue, tb)
74 exc_content = {
75 'status' : 'error',
76 'traceback' : stb,
77 'ename' : unicode(etype.__name__),
78 'evalue' : unicode(evalue),
79 'engine_info' : engine_info
80 }
81 return exc_content
82
83 def unwrap_exception(content):
84 err = RemoteError(content['ename'], content['evalue'],
85 ''.join(content['traceback']),
86 content.get('engine_info', {}))
87 return err
88
89 64
90 65 class Message(object):
91 66 """A simple message object that maps dict keys to attributes.
92 67
93 68 A Message can be created from a dict and a dict from a Message instance
94 69 simply by calling dict(msg_obj)."""
95 70
96 71 def __init__(self, msg_dict):
97 72 dct = self.__dict__
98 73 for k, v in dict(msg_dict).iteritems():
99 74 if isinstance(v, dict):
100 75 v = Message(v)
101 76 dct[k] = v
102 77
103 78 # Having this iterator lets dict(msg_obj) work out of the box.
104 79 def __iter__(self):
105 80 return iter(self.__dict__.iteritems())
106 81
107 82 def __repr__(self):
108 83 return repr(self.__dict__)
109 84
110 85 def __str__(self):
111 86 return pprint.pformat(self.__dict__)
112 87
113 88 def __contains__(self, k):
114 89 return k in self.__dict__
115 90
116 91 def __getitem__(self, k):
117 92 return self.__dict__[k]
118 93
119 94
120 95 def msg_header(msg_id, msg_type, username, session):
121 96 date=datetime.now().strftime(ISO8601)
122 97 return locals()
123 98
124 99 def extract_header(msg_or_header):
125 100 """Given a message or header, return the header."""
126 101 if not msg_or_header:
127 102 return {}
128 103 try:
129 104 # See if msg_or_header is the entire message.
130 105 h = msg_or_header['header']
131 106 except KeyError:
132 107 try:
133 108 # See if msg_or_header is just the header
134 109 h = msg_or_header['msg_id']
135 110 except KeyError:
136 111 raise
137 112 else:
138 113 h = msg_or_header
139 114 if not isinstance(h, dict):
140 115 h = dict(h)
141 116 return h
142 117
143 def rekey(dikt):
144 """Rekey a dict that has been forced to use str keys where there should be
145 ints by json. This belongs in the jsonutil added by fperez."""
146 for k in dikt.iterkeys():
147 if isinstance(k, str):
148 ik=fk=None
149 try:
150 ik = int(k)
151 except ValueError:
152 try:
153 fk = float(k)
154 except ValueError:
155 continue
156 if ik is not None:
157 nk = ik
158 else:
159 nk = fk
160 if nk in dikt:
161 raise KeyError("already have key %r"%nk)
162 dikt[nk] = dikt.pop(k)
163 return dikt
164
165 def serialize_object(obj, threshold=64e-6):
166 """Serialize an object into a list of sendable buffers.
167
168 Parameters
169 ----------
170
171 obj : object
172 The object to be serialized
173 threshold : float
174 The threshold for not double-pickling the content.
175
176
177 Returns
178 -------
179 ('pmd', [bufs]) :
180 where pmd is the pickled metadata wrapper,
181 bufs is a list of data buffers
182 """
183 databuffers = []
184 if isinstance(obj, (list, tuple)):
185 clist = canSequence(obj)
186 slist = map(serialize, clist)
187 for s in slist:
188 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
189 databuffers.append(s.getData())
190 s.data = None
191 return pickle.dumps(slist,-1), databuffers
192 elif isinstance(obj, dict):
193 sobj = {}
194 for k in sorted(obj.iterkeys()):
195 s = serialize(can(obj[k]))
196 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
197 databuffers.append(s.getData())
198 s.data = None
199 sobj[k] = s
200 return pickle.dumps(sobj,-1),databuffers
201 else:
202 s = serialize(can(obj))
203 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
204 databuffers.append(s.getData())
205 s.data = None
206 return pickle.dumps(s,-1),databuffers
207
208
209 def unserialize_object(bufs):
210 """reconstruct an object serialized by serialize_object from data buffers."""
211 bufs = list(bufs)
212 sobj = pickle.loads(bufs.pop(0))
213 if isinstance(sobj, (list, tuple)):
214 for s in sobj:
215 if s.data is None:
216 s.data = bufs.pop(0)
217 return uncanSequence(map(unserialize, sobj)), bufs
218 elif isinstance(sobj, dict):
219 newobj = {}
220 for k in sorted(sobj.iterkeys()):
221 s = sobj[k]
222 if s.data is None:
223 s.data = bufs.pop(0)
224 newobj[k] = uncan(unserialize(s))
225 return newobj, bufs
226 else:
227 if sobj.data is None:
228 sobj.data = bufs.pop(0)
229 return uncan(unserialize(sobj)), bufs
230
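
These serialization helpers are moving to util.py in this change; their
contract is a simple round-trip, with large or buffer-like data riding in
separate frames instead of being double-pickled (a sketch, assuming numpy is
available):

    import numpy
    a = numpy.arange(1000)            # ndarrays always travel as raw buffers
    pmd, bufs = serialize_object(a)   # pickled metadata + data buffers
    obj, rest = unserialize_object([pmd] + bufs)
    assert (obj == a).all() and rest == []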
231 def pack_apply_message(f, args, kwargs, threshold=64e-6):
232 """pack up a function, args, and kwargs to be sent over the wire
233 as a series of buffers. Any object whose data is larger than `threshold`
234 will not have their data copied (currently only numpy arrays support zero-copy)"""
235 msg = [pickle.dumps(can(f),-1)]
236 databuffers = [] # for large objects
237 sargs, bufs = serialize_object(args,threshold)
238 msg.append(sargs)
239 databuffers.extend(bufs)
240 skwargs, bufs = serialize_object(kwargs,threshold)
241 msg.append(skwargs)
242 databuffers.extend(bufs)
243 msg.extend(databuffers)
244 return msg
245
246 def unpack_apply_message(bufs, g=None, copy=True):
247 """unpack f,args,kwargs from buffers packed by pack_apply_message()
248 Returns: original f,args,kwargs"""
249 bufs = list(bufs) # allow us to pop
250 assert len(bufs) >= 3, "not enough buffers!"
251 if not copy:
252 for i in range(3):
253 bufs[i] = bufs[i].bytes
254 cf = pickle.loads(bufs.pop(0))
255 sargs = list(pickle.loads(bufs.pop(0)))
256 skwargs = dict(pickle.loads(bufs.pop(0)))
257 # print sargs, skwargs
258 f = uncan(cf, g)
259 for sa in sargs:
260 if sa.data is None:
261 m = bufs.pop(0)
262 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
263 if copy:
264 sa.data = buffer(m)
265 else:
266 sa.data = m.buffer
267 else:
268 if copy:
269 sa.data = m
270 else:
271 sa.data = m.bytes
272
273 args = uncanSequence(map(unserialize, sargs), g)
274 kwargs = {}
275 for k in sorted(skwargs.iterkeys()):
276 sa = skwargs[k]
277 if sa.data is None:
278 sa.data = bufs.pop(0)
279 kwargs[k] = uncan(unserialize(sa), g)
280
281 return f,args,kwargs
282
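
And the matching local round-trip for the apply-message packing (also moving
to util.py; assumes the function pickles cleanly, as builtins do):

    bufs = pack_apply_message(pow, (2, 8), {})
    f, args, kwargs = unpack_apply_message(bufs)
    assert f(*args, **kwargs) == 256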
283 118 class StreamSession(object):
284 119 """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
285 120 debug=False
286 121 key=None
287 122
288 123 def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
289 124 if username is None:
290 125 username = os.environ.get('USER','username')
291 126 self.username = username
292 127 if session is None:
293 128 self.session = str(uuid.uuid4())
294 129 else:
295 130 self.session = session
296 131 self.msg_id = str(uuid.uuid4())
297 132 if packer is None:
298 133 self.pack = default_packer
299 134 else:
300 135 if not callable(packer):
301 136 raise TypeError("packer must be callable, not %s"%type(packer))
302 137 self.pack = packer
303 138
304 139 if unpacker is None:
305 140 self.unpack = default_unpacker
306 141 else:
307 142 if not callable(unpacker):
308 143 raise TypeError("unpacker must be callable, not %s"%type(unpacker))
309 144 self.unpack = unpacker
310 145
311 146 if key is not None and keyfile is not None:
312 147 raise TypeError("Must specify key OR keyfile, not both")
313 148 if keyfile is not None:
314 149 with open(keyfile) as f:
315 150 self.key = f.read().strip()
316 151 else:
317 152 self.key = key
318 153 if isinstance(self.key, unicode):
319 154 self.key = self.key.encode('utf8')
320 155 # print key, keyfile, self.key
321 156 self.none = self.pack({})
322 157
323 158 def msg_header(self, msg_type):
324 159 h = msg_header(self.msg_id, msg_type, self.username, self.session)
325 160 self.msg_id = str(uuid.uuid4())
326 161 return h
327 162
328 163 def msg(self, msg_type, content=None, parent=None, subheader=None):
329 164 msg = {}
330 165 msg['header'] = self.msg_header(msg_type)
331 166 msg['msg_id'] = msg['header']['msg_id']
332 167 msg['parent_header'] = {} if parent is None else extract_header(parent)
333 168 msg['msg_type'] = msg_type
334 169 msg['content'] = {} if content is None else content
335 170 sub = {} if subheader is None else subheader
336 171 msg['header'].update(sub)
337 172 return msg
338 173
339 174 def check_key(self, msg_or_header):
340 175 """Check that a message's header has the right key"""
341 176 if self.key is None:
342 177 return True
343 178 header = extract_header(msg_or_header)
344 179 return header.get('key', None) == self.key
345 180
346 181
347 182 def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
348 183 """Build and send a message via stream or socket.
349 184
350 185 Parameters
351 186 ----------
352 187
353 188 stream : zmq.Socket or ZMQStream
354 189 the socket-like object used to send the data
355 190 msg_or_type : str or Message/dict
356 191 Normally, msg_or_type will be a msg_type unless a message is being sent more
357 192 than once.
358 193
359 194 Returns
360 195 -------
361 196 (msg,sent) : tuple
362 197 msg : Message
363 198 the nice wrapped dict-like object containing the headers
364 199
365 200 """
366 201 if isinstance(msg_or_type, (Message, dict)):
367 202 # we got a Message, not a msg_type
368 203 # don't build a new Message
369 204 msg = msg_or_type
370 205 content = msg['content']
371 206 else:
372 207 msg = self.msg(msg_or_type, content, parent, subheader)
373 208 buffers = [] if buffers is None else buffers
374 209 to_send = []
375 210 if isinstance(ident, list):
376 211 # accept list of idents
377 212 to_send.extend(ident)
378 213 elif ident is not None:
379 214 to_send.append(ident)
380 215 to_send.append(DELIM)
381 216 if self.key is not None:
382 217 to_send.append(self.key)
383 218 to_send.append(self.pack(msg['header']))
384 219 to_send.append(self.pack(msg['parent_header']))
385 220
386 221 if content is None:
387 222 content = self.none
388 223 elif isinstance(content, dict):
389 224 content = self.pack(content)
390 225 elif isinstance(content, str):
391 226 # content is already packed, as in a relayed message
392 227 pass
393 228 else:
394 229 raise TypeError("Content incorrect type: %s"%type(content))
395 230 to_send.append(content)
396 231 flag = 0
397 232 if buffers:
398 233 flag = zmq.SNDMORE
399 234 stream.send_multipart(to_send, flag, copy=False)
400 235 for b in buffers[:-1]:
401 236 stream.send(b, flag, copy=False)
402 237 if buffers:
403 238 stream.send(buffers[-1], copy=False)
404 239 # omsg = Message(msg)
405 240 if self.debug:
406 241 pprint.pprint(msg)
407 242 pprint.pprint(to_send)
408 243 pprint.pprint(buffers)
409 244 return msg
410 245
411 246 def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
412 247 """Send a raw message via ident path.
413 248
414 249 Parameters
415 250 ----------
416 251 msg : list of sendable buffers"""
417 252 to_send = []
418 253 if isinstance(ident, str):
419 254 ident = [ident]
420 255 if ident is not None:
421 256 to_send.extend(ident)
422 257 to_send.append(DELIM)
423 258 if self.key is not None:
424 259 to_send.append(self.key)
425 260 to_send.extend(msg)
426 261 stream.send_multipart(to_send, flags, copy=copy)
427 262
428 263 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
429 264 """receives and unpacks a message
430 265 returns [idents], msg"""
431 266 if isinstance(socket, ZMQStream):
432 267 socket = socket.socket
433 268 try:
434 269 msg = socket.recv_multipart(mode)
435 270 except zmq.ZMQError as e:
436 271 if e.errno == zmq.EAGAIN:
437 272 # We can convert EAGAIN to None as we know in this case
438 273 # recv_multipart won't return None.
439 274 return None
440 275 else:
441 276 raise
442 277 # return an actual Message object
443 278 # determine the number of idents by trying to unpack them.
444 279 # this is terrible:
445 280 idents, msg = self.feed_identities(msg, copy)
446 281 try:
447 282 return idents, self.unpack_message(msg, content=content, copy=copy)
448 283 except Exception as e:
449 284 print (idents, msg)
450 285 # TODO: handle it
451 286 raise e
452 287
453 288 def feed_identities(self, msg, copy=True):
454 289 """feed until DELIM is reached, then return the prefix as idents and remainder as
455 290 msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.
456 291
457 292 Parameters
458 293 ----------
459 294 msg : a list of Message or bytes objects
460 295 the message to be split
461 296 copy : bool
462 297 flag determining whether the arguments are bytes or Messages
463 298
464 299 Returns
465 300 -------
466 301 (idents,msg) : two lists
467 302 idents will always be a list of bytes - the identity prefix
468 303 msg will be a list of bytes or Messages, unchanged from input
469 304 msg should be unpackable via self.unpack_message at this point.
470 305 """
471 306 ikey = int(self.key is not None)
472 307 minlen = 3 + ikey
473 308 msg = list(msg)
474 309 idents = []
475 310 while len(msg) > minlen:
476 311 if copy:
477 312 s = msg[0]
478 313 else:
479 314 s = msg[0].bytes
480 315 if s == DELIM:
481 316 msg.pop(0)
482 317 break
483 318 else:
484 319 idents.append(s)
485 320 msg.pop(0)
486 321
487 322 return idents, msg
488 323
489 324 def unpack_message(self, msg, content=True, copy=True):
490 325 """Return a message object from the format
491 326 sent by self.send.
492 327
493 328 Parameters
494 329 ----------
495 330
496 331 content : bool (True)
497 332 whether to unpack the content dict (True),
498 333 or leave it serialized (False)
499 334
500 335 copy : bool (True)
501 336 whether to return the bytes (True),
502 337 or the non-copying Message object in each place (False)
503 338
504 339 """
505 340 ikey = int(self.key is not None)
506 341 minlen = 3 + ikey
507 342 message = {}
508 343 if not copy:
509 344 for i in range(minlen):
510 345 msg[i] = msg[i].bytes
511 346 if ikey:
512 347 if not self.key == msg[0]:
513 348 raise KeyError("Invalid Session Key: %s"%msg[0])
514 349 if not len(msg) >= minlen:
515 350 raise TypeError("malformed message, must have at least %i elements"%minlen)
516 351 message['header'] = self.unpack(msg[ikey+0])
517 352 message['msg_type'] = message['header']['msg_type']
518 353 message['parent_header'] = self.unpack(msg[ikey+1])
519 354 if content:
520 355 message['content'] = self.unpack(msg[ikey+2])
521 356 else:
522 357 message['content'] = msg[ikey+2]
523 358
524 359 message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ]
525 360 return message
526 361
527 362
528 363 def test_msg2obj():
529 364 am = dict(x=1)
530 365 ao = Message(am)
531 366 assert ao.x == am['x']
532 367
533 368 am['y'] = dict(z=1)
534 369 ao = Message(am)
535 370 assert ao.y.z == am['y']['z']
536 371
537 372 k1, k2 = 'y', 'z'
538 373 assert ao[k1][k2] == am[k1][k2]
539 374
540 375 am2 = dict(ao)
541 376 assert am['x'] == am2['x']
542 377 assert am['y']['z'] == am2['y']['z']
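
For context, the framing implemented by send/feed_identities looks like this
on the wire; the key frame is present only when a session key is set:

    [ident_0, ..., ident_n,    # zmq routing identities (possibly none)
     '<IDS|MSG>',              # DELIM separating idents from the message
     key,                      # optional session key
     pack(header),             # msg_id, msg_type, username, session, date
     pack(parent_header),
     pack(content),
     buffer_0, ..., buffer_m]  # raw buffers, passed through untouched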
@@ -1,82 +1,82 b''
1 1
2 2 import os
3 3 import uuid
4 4 import zmq
5 5
6 6 from zmq.tests import BaseZMQTestCase
7 7
8 8 # from IPython.zmq.tests import SessionTestCase
9 9 from IPython.zmq.parallel import streamsession as ss
10 10
11 11 class SessionTestCase(BaseZMQTestCase):
12 12
13 13 def setUp(self):
14 14 BaseZMQTestCase.setUp(self)
15 15 self.session = ss.StreamSession()
16 16
17 17 class TestSession(SessionTestCase):
18 18
19 19 def test_msg(self):
20 20 """message format"""
21 21 msg = self.session.msg('execute')
22 22 thekeys = set('header msg_id parent_header msg_type content'.split())
23 23 s = set(msg.keys())
24 24 self.assertEquals(s, thekeys)
25 25 self.assertTrue(isinstance(msg['content'],dict))
26 26 self.assertTrue(isinstance(msg['header'],dict))
27 27 self.assertTrue(isinstance(msg['parent_header'],dict))
28 28 self.assertEquals(msg['msg_type'], 'execute')
29 29
30 30
31 31
32 32 def test_args(self):
33 33 """initialization arguments for StreamSession"""
34 34 s = ss.StreamSession()
35 35 self.assertTrue(s.pack is ss.default_packer)
36 36 self.assertTrue(s.unpack is ss.default_unpacker)
37 37 self.assertEquals(s.username, os.environ.get('USER', 'username'))
38 38
39 39 s = ss.StreamSession(username=None)
40 40 self.assertEquals(s.username, os.environ.get('USER', 'username'))
41 41
42 42 self.assertRaises(TypeError, ss.StreamSession, packer='hi')
43 43 self.assertRaises(TypeError, ss.StreamSession, unpacker='hi')
44 44 u = str(uuid.uuid4())
45 45 s = ss.StreamSession(username='carrot', session=u)
46 46 self.assertEquals(s.session, u)
47 47 self.assertEquals(s.username, 'carrot')
48 48
49 49
50 def test_rekey(self):
51 """rekeying dict around json str keys"""
52 d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 self.assertRaises(KeyError, ss.rekey, d)
54
55 d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 rd = ss.rekey(d)
58 self.assertEquals(d2,rd)
59
60 d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 d2 = {1.5:d['1.5'],1:d['1']}
62 rd = ss.rekey(d)
63 self.assertEquals(d2,rd)
64
65 d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 self.assertRaises(KeyError, ss.rekey, d)
67
50 # def test_rekey(self):
51 # """rekeying dict around json str keys"""
52 # d = {'0': uuid.uuid4(), 0:uuid.uuid4()}
53 # self.assertRaises(KeyError, ss.rekey, d)
54 #
55 # d = {'0': uuid.uuid4(), 1:uuid.uuid4(), 'asdf':uuid.uuid4()}
56 # d2 = {0:d['0'],1:d[1],'asdf':d['asdf']}
57 # rd = ss.rekey(d)
58 # self.assertEquals(d2,rd)
59 #
60 # d = {'1.5':uuid.uuid4(),'1':uuid.uuid4()}
61 # d2 = {1.5:d['1.5'],1:d['1']}
62 # rd = ss.rekey(d)
63 # self.assertEquals(d2,rd)
64 #
65 # d = {'1.0':uuid.uuid4(),'1':uuid.uuid4()}
66 # self.assertRaises(KeyError, ss.rekey, d)
67 #
68 68 def test_unique_msg_ids(self):
69 69 """test that messages receive unique ids"""
70 70 ids = set()
71 71 for i in range(2**12):
72 72 h = self.session.msg_header('test')
73 73 msg_id = h['msg_id']
74 74 self.assertTrue(msg_id not in ids)
75 75 ids.add(msg_id)
76 76
77 77 def test_feed_identities(self):
78 78 """scrub the front for zmq IDENTITIES"""
79 79 theids = "engine client other".split()
80 80 content = dict(code='whoda',stuff=object())
81 81 themsg = self.session.msg('execute',content=content)
82 82 pmsg = theids
@@ -1,119 +1,271 b''
1 """some generic utilities"""
1 """some generic utilities for dealing with classes, urls, and serialization"""
2 2 import re
3 3 import socket
4 4
5 try:
6 import cPickle
7 pickle = cPickle
8 except ImportError:
9 cPickle = None
10 import pickle
11
12
13 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
14 from IPython.utils.newserialized import serialize, unserialize
15
16 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
17
5 18 class ReverseDict(dict):
6 19 """simple double-keyed subset of dict methods."""
7 20
8 21 def __init__(self, *args, **kwargs):
9 22 dict.__init__(self, *args, **kwargs)
10 23 self._reverse = dict()
11 24 for key, value in self.iteritems():
12 25 self._reverse[value] = key
13 26
14 27 def __getitem__(self, key):
15 28 try:
16 29 return dict.__getitem__(self, key)
17 30 except KeyError:
18 31 return self._reverse[key]
19 32
20 33 def __setitem__(self, key, value):
21 34 if key in self._reverse:
22 35 raise KeyError("Can't have key %r on both sides!"%key)
23 36 dict.__setitem__(self, key, value)
24 37 self._reverse[value] = key
25 38
26 39 def pop(self, key):
27 40 value = dict.pop(self, key)
28 41 self._reverse.pop(value)
29 42 return value
30 43
31 44 def get(self, key, default=None):
32 45 try:
33 46 return self[key]
34 47 except KeyError:
35 48 return default
36
37 49
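A short usage sketch for ReverseDict: failed forward lookups fall back to the reverse mapping, so values can be used as keys (the engine-id pairing is illustrative):

    rd = ReverseDict()
    rd['engine.0'] = 0
    rd['engine.0']                # -> 0, the forward mapping
    rd[0]                         # -> 'engine.0', via the reverse mapping
    rd.get('missing', 'default')  # -> 'default'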
38 50 def validate_url(url):
39 51 """validate a url for zeromq"""
40 52 if not isinstance(url, basestring):
41 53 raise TypeError("url must be a string, not %r"%type(url))
42 54 url = url.lower()
43 55
44 56 proto_addr = url.split('://')
45 57 assert len(proto_addr) == 2, 'Invalid url: %r'%url
46 58 proto, addr = proto_addr
47 59 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
48 60
49 61 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
50 62 # author: Remi Sabourin
51 63 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
52 64
53 65 if proto == 'tcp':
54 66 lis = addr.split(':')
55 67 assert len(lis) == 2, 'Invalid url: %r'%url
56 68 addr,s_port = lis
57 69 try:
58 70 port = int(s_port)
59 71 except ValueError:
60 72 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
61 73
62 74 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
63 75
64 76 else:
65 77 # only validate tcp urls currently
66 78 pass
67 79
68 80 return True
69 81
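A few illustrative calls (invalid urls surface as AssertionError):

    validate_url('tcp://127.0.0.1:10101')   # -> True
    validate_url('tcp://*:5555')            # -> True; '*' is allowed for tcp
    try:
        validate_url('http://example.com:80')
    except AssertionError:
        pass  # only zmq transports (tcp, pgm, epgm, ipc, inproc) are accepted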
70 82
71 83 def validate_url_container(container):
72 84 """validate a potentially nested collection of urls."""
73 85 if isinstance(container, basestring):
74 86 url = container
75 87 return validate_url(url)
76 88 elif isinstance(container, dict):
77 89 container = container.itervalues()
78 90
79 91 for element in container:
80 92 validate_url_container(element)
81 93
82 94
83 95 def split_url(url):
84 96 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
85 97 proto_addr = url.split('://')
86 98 assert len(proto_addr) == 2, 'Invalid url: %r'%url
87 99 proto, addr = proto_addr
88 100 lis = addr.split(':')
89 101 assert len(lis) == 2, 'Invalid url: %r'%url
90 102 addr,s_port = lis
91 103 return proto,addr,s_port
92 104
93 105 def disambiguate_ip_address(ip, location=None):
94 106 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
95 107 ones, based on the location (default interpretation of location is localhost)."""
96 108 if ip in ('0.0.0.0', '*'):
97 109 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
98 110 if location is None or location in external_ips:
99 111 ip='127.0.0.1'
100 112 elif location:
101 113 return location
102 114 return ip
103 115
104 116 def disambiguate_url(url, location=None):
105 117 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
106 118 ones, based on the location (default interpretation is localhost).
107 119
108 120 This is for zeromq urls, such as tcp://*:10101."""
109 121 try:
110 122 proto,ip,port = split_url(url)
111 123 except AssertionError:
112 124 # probably not tcp url; could be ipc, etc.
113 125 return url
114 126
115 127 ip = disambiguate_ip_address(ip,location)
116 128
117 129 return "%s://%s:%s"%(proto,ip,port)
118 130
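Usage sketch for the url helpers; the rewritten address depends on where the caller sits relative to the bound interface, so treat the results as illustrative:

    split_url('tcp://10.0.0.5:9999')     # -> ('tcp', '10.0.0.5', '9999')
    disambiguate_url('tcp://*:10101')    # -> 'tcp://127.0.0.1:10101' when connecting locally
    disambiguate_url('ipc:///tmp/sock')  # -> returned unchanged; only tcp urls are rewritten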
119 131
132 def rekey(dikt):
133 """Rekey a dict that has been forced to use str keys where there should be
134 ints by json. This belongs in the jsonutil added by fperez."""
135 for k in dikt.iterkeys():
136 if isinstance(k, str):
137 ik=fk=None
138 try:
139 ik = int(k)
140 except ValueError:
141 try:
142 fk = float(k)
143 except ValueError:
144 continue
145 if ik is not None:
146 nk = ik
147 else:
148 nk = fk
149 if nk in dikt:
150 raise KeyError("already have key %r"%nk)
151 dikt[nk] = dikt.pop(k)
152 return dikt
153
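For example, after a json round-trip numeric keys come back as strings, and rekey restores them in place:

    d = rekey({'0': 'a', '2.5': 'b', 'name': 'c'})
    # -> {0: 'a', 2.5: 'b', 'name': 'c'}
    # a collision such as rekey({'0': 'a', 0: 'b'}) raises KeyError
    # rather than silently overwriting one of the values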
154 def serialize_object(obj, threshold=64e-6):
155 """Serialize an object into a list of sendable buffers.
156
157 Parameters
158 ----------
159
160 obj : object
161 The object to be serialized
162 threshold : float
163 The size threshold above which content is shipped as a separate buffer rather than being double-pickled inside the metadata.
164
165
166 Returns
167 -------
168 ('pmd', [bufs]) :
169 where pmd is the pickled metadata wrapper,
170 bufs is a list of data buffers
171 """
172 databuffers = []
173 if isinstance(obj, (list, tuple)):
174 clist = canSequence(obj)
175 slist = map(serialize, clist)
176 for s in slist:
177 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
178 databuffers.append(s.getData())
179 s.data = None
180 return pickle.dumps(slist,-1), databuffers
181 elif isinstance(obj, dict):
182 sobj = {}
183 for k in sorted(obj.iterkeys()):
184 s = serialize(can(obj[k]))
185 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
186 databuffers.append(s.getData())
187 s.data = None
188 sobj[k] = s
189 return pickle.dumps(sobj,-1),databuffers
190 else:
191 s = serialize(can(obj))
192 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
193 databuffers.append(s.getData())
194 s.data = None
195 return pickle.dumps(s,-1),databuffers
196
197
198 def unserialize_object(bufs):
199 """reconstruct an object serialized by serialize_object from data buffers."""
200 bufs = list(bufs)
201 sobj = pickle.loads(bufs.pop(0))
202 if isinstance(sobj, (list, tuple)):
203 for s in sobj:
204 if s.data is None:
205 s.data = bufs.pop(0)
206 return uncanSequence(map(unserialize, sobj)), bufs
207 elif isinstance(sobj, dict):
208 newobj = {}
209 for k in sorted(sobj.iterkeys()):
210 s = sobj[k]
211 if s.data is None:
212 s.data = bufs.pop(0)
213 newobj[k] = uncan(unserialize(s))
214 return newobj, bufs
215 else:
216 if sobj.data is None:
217 sobj.data = bufs.pop(0)
218 return uncan(unserialize(sobj)), bufs
219
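A hedged round-trip sketch for the pair above (the sizes are illustrative; anything over the threshold travels as a separate buffer instead of being pickled twice):

    data = {'small': 1, 'big': 'x' * (1 << 20)}
    pmd, bufs = serialize_object(data)
    # pmd is a single pickle of lightweight metadata; the megabyte string rides in bufs
    obj, remaining = unserialize_object([pmd] + bufs)
    # obj == data, remaining == []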
220 def pack_apply_message(f, args, kwargs, threshold=64e-6):
221 """pack up a function, args, and kwargs to be sent over the wire
222 as a series of buffers. Any object whose data is larger than `threshold`
223 will not have its data copied (currently only numpy arrays support zero-copy)"""
224 msg = [pickle.dumps(can(f),-1)]
225 databuffers = [] # for large objects
226 sargs, bufs = serialize_object(args,threshold)
227 msg.append(sargs)
228 databuffers.extend(bufs)
229 skwargs, bufs = serialize_object(kwargs,threshold)
230 msg.append(skwargs)
231 databuffers.extend(bufs)
232 msg.extend(databuffers)
233 return msg
234
235 def unpack_apply_message(bufs, g=None, copy=True):
236 """unpack f,args,kwargs from buffers packed by pack_apply_message()
237 Returns: original f,args,kwargs"""
238 bufs = list(bufs) # allow us to pop
239 assert len(bufs) >= 3, "not enough buffers!"
240 if not copy:
241 for i in range(3):
242 bufs[i] = bufs[i].bytes
243 cf = pickle.loads(bufs.pop(0))
244 sargs = list(pickle.loads(bufs.pop(0)))
245 skwargs = dict(pickle.loads(bufs.pop(0)))
246 # print sargs, skwargs
247 f = uncan(cf, g)
248 for sa in sargs:
249 if sa.data is None:
250 m = bufs.pop(0)
251 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
252 if copy:
253 sa.data = buffer(m)
254 else:
255 sa.data = m.buffer
256 else:
257 if copy:
258 sa.data = m
259 else:
260 sa.data = m.bytes
261
262 args = uncanSequence(map(unserialize, sargs), g)
263 kwargs = {}
264 for k in sorted(skwargs.iterkeys()):
265 sa = skwargs[k]
266 if sa.data is None:
267 sa.data = bufs.pop(0)
268 kwargs[k] = uncan(unserialize(sa), g)
269
270 return f,args,kwargs
271
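And the apply round-trip, roughly as a test might exercise it (add is a stand-in; can/uncan also handle closures and interactively defined functions by shipping their code objects):

    def add(a, b=0):
        return a + b

    msg = pack_apply_message(add, (1,), {'b': 2})
    f, args, kwargs = unpack_apply_message(msg)
    # f(*args, **kwargs) -> 3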
@@ -1,657 +1,658 b''
1 """Views of remote engines"""
1 """Views of remote engines."""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 from IPython.testing import decorators as testdec
14 from IPython.utils.traitlets import HasTraits, Bool, List, Dict, Set, Int, Instance
14 from IPython.utils.traitlets import HasTraits, Any, Bool, List, Dict, Set, Int, Instance
15 15
16 16 from IPython.external.decorator import decorator
17 17
18 18 from .asyncresult import AsyncResult
19 19 from .dependency import Dependency
20 20 from .remotefunction import ParallelFunction, parallel, remote
21 21
22 22 #-----------------------------------------------------------------------------
23 23 # Decorators
24 24 #-----------------------------------------------------------------------------
25 25
26 26 @decorator
27 27 def myblock(f, self, *args, **kwargs):
28 28 """override client.block with self.block during a call"""
29 29 block = self.client.block
30 30 self.client.block = self.block
31 31 try:
32 32 ret = f(self, *args, **kwargs)
33 33 finally:
34 34 self.client.block = block
35 35 return ret
36 36
37 37 @decorator
38 38 def save_ids(f, self, *args, **kwargs):
39 39 """Keep our history and outstanding attributes up to date after a method call."""
40 40 n_previous = len(self.client.history)
41 41 ret = f(self, *args, **kwargs)
42 42 nmsgs = len(self.client.history) - n_previous
43 43 msg_ids = self.client.history[-nmsgs:]
44 44 self.history.extend(msg_ids)
45 45 map(self.outstanding.add, msg_ids)
46 46 return ret
47 47
48 48 @decorator
49 49 def sync_results(f, self, *args, **kwargs):
50 50 """sync relevant results from self.client to our results attribute."""
51 51 ret = f(self, *args, **kwargs)
52 52 delta = self.outstanding.difference(self.client.outstanding)
53 53 completed = self.outstanding.intersection(delta)
54 54 self.outstanding = self.outstanding.difference(completed)
55 55 for msg_id in completed:
56 56 self.results[msg_id] = self.client.results[msg_id]
57 57 return ret
58 58
59 59 @decorator
60 60 def spin_after(f, self, *args, **kwargs):
61 61 """call spin after the method."""
62 62 ret = f(self, *args, **kwargs)
63 63 self.spin()
64 64 return ret
65 65
66 66 #-----------------------------------------------------------------------------
67 67 # Classes
68 68 #-----------------------------------------------------------------------------
69 69
70 70 class View(HasTraits):
71 71 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
72 72
73 73 Don't use this class, use subclasses.
74 74 """
75 75 block=Bool(False)
76 76 bound=Bool(False)
77 77 history=List()
78 78 outstanding = Set()
79 79 results = Dict()
80 80 client = Instance('IPython.zmq.parallel.client.Client')
81 81
82 82 _ntargets = Int(1)
83 83 _balanced = Bool(False)
84 84 _default_names = List(['block', 'bound'])
85 _targets = None
85 _targets = Any()
86 86
87 87 def __init__(self, client=None, targets=None):
88 88 super(View, self).__init__(client=client)
89 89 self._targets = targets
90 90 self._ntargets = 1 if isinstance(targets, (int,type(None))) else len(targets)
91 91 self.block = client.block
92 92
93 93 for name in self._default_names:
94 94 setattr(self, name, getattr(self, name, None))
95 95
96 96 assert self.__class__ is not View, "Don't use base View objects, use subclasses"
97 97
98 98
99 99 def __repr__(self):
100 100 strtargets = str(self._targets)
101 101 if len(strtargets) > 16:
102 102 strtargets = strtargets[:12]+'...]'
103 103 return "<%s %s>"%(self.__class__.__name__, strtargets)
104 104
105 105 @property
106 106 def targets(self):
107 107 return self._targets
108 108
109 109 @targets.setter
110 110 def targets(self, value):
111 111 raise AttributeError("Cannot set View `targets` after construction!")
112 112
113 113 @property
114 114 def balanced(self):
115 115 return self._balanced
116 116
117 117 @balanced.setter
118 118 def balanced(self, value):
119 119 raise AttributeError("Cannot set View `balanced` after construction!")
120 120
121 121 def _defaults(self, *excludes):
122 122 """return dict of our default attributes, excluding names given."""
123 123 d = dict(balanced=self._balanced, targets=self._targets)
124 124 for name in self._default_names:
125 125 if name not in excludes:
126 126 d[name] = getattr(self, name)
127 127 return d
128 128
129 129 def set_flags(self, **kwargs):
130 130 """set my attribute flags by keyword.
131 131
132 132 A View is a wrapper for the Client's apply method, but
133 133 with attributes that specify keyword arguments, those attributes
134 134 can be set by keyword argument with this method.
135 135
136 136 Parameters
137 137 ----------
138 138
139 139 block : bool
140 140 whether to wait for results
141 141 bound : bool
142 142 whether to use the client's namespace
143 143 """
144 144 for key in kwargs:
145 145 if key not in self._default_names:
146 146 raise KeyError("Invalid name: %r"%key)
147 147 for name in ('block', 'bound'):
148 148 if name in kwargs:
149 149 setattr(self, name, kwargs[name])
150 150
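For instance, to flip a view into blocking, unbound operation (v is any View subclass instance):

    v.set_flags(block=True, bound=False)
    v.set_flags(chunk=1)   # raises KeyError: not one of this view's flag names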
151 151 #----------------------------------------------------------------
152 152 # wrappers for client methods:
153 153 #----------------------------------------------------------------
154 154 @sync_results
155 155 def spin(self):
156 156 """spin the client, and sync"""
157 157 self.client.spin()
158 158
159 159 @sync_results
160 160 @save_ids
161 161 def apply(self, f, *args, **kwargs):
162 162 """calls f(*args, **kwargs) on remote engines, returning the result.
163 163
164 164 This method does not involve the engine's namespace.
165 165
166 166 if self.block is False:
167 167 returns msg_id
168 168 else:
169 169 returns actual result of f(*args, **kwargs)
170 170 """
171 171 return self.client.apply(f, args, kwargs, **self._defaults())
172 172
173 173 @save_ids
174 174 def apply_async(self, f, *args, **kwargs):
175 175 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
176 176
177 177 This method does not involve the engine's namespace.
178 178
179 179 returns msg_id
180 180 """
181 181 d = self._defaults('block', 'bound')
182 182 return self.client.apply(f,args,kwargs, block=False, bound=False, **d)
183 183
184 184 @spin_after
185 185 @save_ids
186 186 def apply_sync(self, f, *args, **kwargs):
187 187 """calls f(*args, **kwargs) on remote engines in a blocking manner,
188 188 returning the result.
189 189
190 190 This method does not involve the engine's namespace.
191 191
192 192 returns: actual result of f(*args, **kwargs)
193 193 """
194 194 d = self._defaults('block', 'bound')
195 195 return self.client.apply(f,args,kwargs, block=True, bound=False, **d)
196 196
197 197 # @sync_results
198 198 # @save_ids
199 199 # def apply_bound(self, f, *args, **kwargs):
200 200 # """calls f(*args, **kwargs) bound to engine namespace(s).
201 201 #
202 202 # if self.block is False:
203 203 # returns msg_id
204 204 # else:
205 205 # returns actual result of f(*args, **kwargs)
206 206 #
207 207 # This method has access to the targets' namespace via globals()
208 208 #
209 209 # """
210 210 # d = self._defaults('bound')
211 211 # return self.client.apply(f, args, kwargs, bound=True, **d)
212 212 #
213 213 @sync_results
214 214 @save_ids
215 215 def apply_async_bound(self, f, *args, **kwargs):
216 216 """calls f(*args, **kwargs) bound to engine namespace(s)
217 217 in a nonblocking manner.
218 218
219 219 returns: msg_id
220 220
221 221 This method has access to the targets' namespace via globals()
222 222
223 223 """
224 224 d = self._defaults('block', 'bound')
225 225 return self.client.apply(f, args, kwargs, block=False, bound=True, **d)
226 226
227 227 @spin_after
228 228 @save_ids
229 229 def apply_sync_bound(self, f, *args, **kwargs):
230 230 """calls f(*args, **kwargs) bound to engine namespace(s), waiting for the result.
231 231
232 232 returns: actual result of f(*args, **kwargs)
233 233
234 234 This method has access to the targets' namespace via globals()
235 235
236 236 """
237 237 d = self._defaults('block', 'bound')
238 238 return self.client.apply(f, args, kwargs, block=True, bound=True, **d)
239 239
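The apply variants above differ only in the block/bound flags they force; a usage sketch (v is any View, os.getpid is just a convenient remote callable):

    import os
    mid = v.apply_async(os.getpid)    # returns a msg_id immediately
    pid = v.apply_sync(os.getpid)     # blocks until the engine answers
    v.apply_sync_bound(len, 'abc')    # same, but evaluated with engine globals()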
240 240 def abort(self, jobs=None, block=None):
241 241 """Abort jobs on my engines.
242 242
243 243 Parameters
244 244 ----------
245 245
246 246 jobs : None, str, list of strs, optional
247 247 if None: abort all jobs.
248 248 else: abort specific msg_id(s).
249 249 """
250 250 block = block if block is not None else self.block
251 251 return self.client.abort(jobs=jobs, targets=self._targets, block=block)
252 252
253 253 def queue_status(self, verbose=False):
254 254 """Fetch the Queue status of my engines"""
255 255 return self.client.queue_status(targets=self._targets, verbose=verbose)
256 256
257 257 def purge_results(self, jobs=[], targets=[]):
258 258 """Instruct the controller to forget specific results."""
259 259 if targets is None or targets == 'all':
260 260 targets = self._targets
261 261 return self.client.purge_results(jobs=jobs, targets=targets)
262 262
263 263 @spin_after
264 264 def get_result(self, indices_or_msg_ids=None):
265 265 """return one or more results, specified by history index or msg_id.
266 266
267 267 See client.get_result for details.
268 268
269 269 """
270 270
271 271 if indices_or_msg_ids is None:
272 272 indices_or_msg_ids = -1
273 273 if isinstance(indices_or_msg_ids, int):
274 274 indices_or_msg_ids = self.history[indices_or_msg_ids]
275 275 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
276 276 indices_or_msg_ids = list(indices_or_msg_ids)
277 277 for i,index in enumerate(indices_or_msg_ids):
278 278 if isinstance(index, int):
279 279 indices_or_msg_ids[i] = self.history[index]
280 280 return self.client.get_result(indices_or_msg_ids)
281 281
282 282 #-------------------------------------------------------------------
283 283 # Map
284 284 #-------------------------------------------------------------------
285 285
286 286 def map(self, f, *sequences, **kwargs):
287 287 """override in subclasses"""
288 288 raise NotImplementedError
289 289
290 290 def map_async(self, f, *sequences, **kwargs):
291 291 """Parallel version of builtin `map`, using this view's engines.
292 292
293 293 This is equivalent to map(...block=False)
294 294
295 295 See `self.map` for details.
296 296 """
297 297 if 'block' in kwargs:
298 298 raise TypeError("map_async doesn't take a `block` keyword argument.")
299 299 kwargs['block'] = False
300 300 return self.map(f,*sequences,**kwargs)
301 301
302 302 def map_sync(self, f, *sequences, **kwargs):
303 303 """Parallel version of builtin `map`, using this view's engines.
304 304
305 305 This is equivalent to map(...block=True)
306 306
307 307 See `self.map` for details.
308 308 """
309 309 if 'block' in kwargs:
310 310 raise TypeError("map_sync doesn't take a `block` keyword argument.")
311 311 kwargs['block'] = True
312 312 return self.map(f,*sequences,**kwargs)
313 313
314 314 def imap(self, f, *sequences, **kwargs):
315 315 """Parallel version of `itertools.imap`.
316 316
317 317 See `self.map` for details.
318 318 """
319 319
320 320 return iter(self.map_async(f,*sequences, **kwargs))
321 321
322 322 #-------------------------------------------------------------------
323 323 # Decorators
324 324 #-------------------------------------------------------------------
325 325
326 326 def remote(self, bound=True, block=True):
327 327 """Decorator for making a RemoteFunction"""
328 328 return remote(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced)
329 329
330 330 def parallel(self, dist='b', bound=True, block=None):
331 331 """Decorator for making a ParallelFunction"""
332 332 block = self.block if block is None else block
333 333 return parallel(self.client, bound=bound, targets=self._targets, block=block, balanced=self._balanced)
334 334
335 335 @testdec.skip_doctest
336 336 class DirectView(View):
337 337 """Direct Multiplexer View of one or more engines.
338 338
339 339 These are created via indexed access to a client:
340 340
341 341 >>> dv_1 = client[1]
342 342 >>> dv_all = client[:]
343 343 >>> dv_even = client[::2]
344 344 >>> dv_some = client[1:3]
345 345
346 346 This object provides dictionary access to engine namespaces:
347 347
348 348 # push a=5:
349 349 >>> dv['a'] = 5
350 350 # pull 'foo':
351 351 >>> dv['foo']
352 352
353 353 """
354 354
355 355 def __init__(self, client=None, targets=None):
356 356 super(DirectView, self).__init__(client=client, targets=targets)
357 357 self._balanced = False
358 358
359 359 @spin_after
360 360 @save_ids
361 361 def map(self, f, *sequences, **kwargs):
362 362 """view.map(f, *sequences, block=self.block, bound=self.bound) => list|AsyncMapResult
363 363
364 364 Parallel version of builtin `map`, using this View's `targets`.
365 365
366 366 There will be one task per target, so work will be chunked
367 367 if the sequences are longer than `targets`.
368 368
369 369 Results can be iterated as they are ready, but will become available in chunks.
370 370
371 371 Parameters
372 372 ----------
373 373
374 374 f : callable
375 375 function to be mapped
376 376 *sequences: one or more sequences of matching length
377 377 the sequences to be distributed and passed to `f`
378 378 block : bool
379 379 whether to wait for the result or not [default self.block]
380 380 bound : bool
381 381 whether to have access to the engines' namespaces [default self.bound]
382 382
383 383 Returns
384 384 -------
385 385
386 386 if block=False:
387 387 AsyncMapResult
388 388 An object like AsyncResult, but which reassembles the sequence of results
389 389 into a single list. AsyncMapResults can be iterated through before all
390 390 results are complete.
391 391 else:
392 392 list
393 393 the result of map(f,*sequences)
394 394 """
395 395
396 396 block = kwargs.get('block', self.block)
397 397 bound = kwargs.get('bound', self.bound)
398 398 for k in kwargs.keys():
399 399 if k not in ['block', 'bound']:
400 400 raise TypeError("invalid keyword arg, %r"%k)
401 401
402 402 assert len(sequences) > 0, "must have some sequences to map onto!"
403 403 pf = ParallelFunction(self.client, f, block=block, bound=bound,
404 404 targets=self._targets, balanced=False)
405 405 return pf.map(*sequences)
406 406
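Concretely, on a four-engine DirectView the sequence is split into four chunks, one task per engine (the client and engine count are assumptions):

    dv = client[:]                    # e.g. a view on four engines
    amr = dv.map(lambda x: x**2, range(8), block=False)
    list(amr)                         # -> [0, 1, 4, ..., 49], reassembled in order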
407 407 @sync_results
408 408 @save_ids
409 409 def execute(self, code, block=None):
410 410 """execute some code on my targets."""
411 411
412 412 block = block if block is not None else self.block
413 413
414 414 return self.client.execute(code, block=block, targets=self._targets)
415 415
416 416 @sync_results
417 417 @save_ids
418 418 def run(self, fname, block=None):
419 419 """execute the code in a file on my targets."""
420 420
421 421 block = block if block is not None else self.block
422 422
423 423 return self.client.run(fname, block=block, targets=self._targets)
424 424
425 425 def update(self, ns):
426 426 """update remote namespace with dict `ns`"""
427 427 return self.client.push(ns, targets=self._targets, block=self.block)
428 428
429 429 def push(self, ns, block=None):
430 430 """update remote namespace with dict `ns`"""
431 431
432 432 block = block if block is not None else self.block
433 433
434 434 return self.client.push(ns, targets=self._targets, block=block)
435 435
436 436 def get(self, key_s):
437 437 """get object(s) by `key_s` from remote namespace
438 438 will return one object if it is a key.
439 439 It also takes a list of keys, and will return a list of objects."""
440 440 # block = block if block is not None else self.block
441 441 return self.client.pull(key_s, block=True, targets=self._targets)
442 442
443 443 @sync_results
444 444 @save_ids
445 445 def pull(self, key_s, block=True):
446 446 """get object(s) by `key_s` from remote namespace
447 447 will return one object if it is a key.
448 448 It also takes a list of keys, and will return a list of objects."""
449 449 block = block if block is not None else self.block
450 450 return self.client.pull(key_s, block=block, targets=self._targets)
451 451
452 452 def scatter(self, key, seq, dist='b', flatten=False, block=None):
453 453 """
454 454 Partition a Python sequence and send the partitions to a set of engines.
455 455 """
456 456 block = block if block is not None else self.block
457 457
458 458 return self.client.scatter(key, seq, dist=dist, flatten=flatten,
459 459 targets=self._targets, block=block)
460 460
461 461 @sync_results
462 462 @save_ids
463 463 def gather(self, key, dist='b', block=None):
464 464 """
465 465 Gather a partitioned sequence on a set of engines as a single local seq.
466 466 """
467 467 block = block if block is not None else self.block
468 468
469 469 return self.client.gather(key, dist=dist, targets=self._targets, block=block)
470 470
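A scatter/gather round-trip sketch (dv is a DirectView; the remote comprehension runs on each engine's partition):

    dv.scatter('x', range(8))          # each engine receives a slice of range(8), named x
    dv.execute('y = [i*2 for i in x]')
    dv.gather('y')                     # -> [0, 2, 4, 6, 8, 10, 12, 14]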
471 471 def __getitem__(self, key):
472 472 return self.get(key)
473 473
474 474 def __setitem__(self,key, value):
475 475 self.update({key:value})
476 476
477 477 def clear(self, block=False):
478 478 """Clear the remote namespaces on my engines."""
479 479 block = block if block is not None else self.block
480 480 return self.client.clear(targets=self._targets, block=block)
481 481
482 482 def kill(self, block=True):
483 483 """Kill my engines."""
484 484 block = block if block is not None else self.block
485 485 return self.client.kill(targets=self._targets, block=block)
486 486
487 487 #----------------------------------------
488 488 # activate for %px,%autopx magics
489 489 #----------------------------------------
490 490 def activate(self):
491 491 """Make this `View` active for parallel magic commands.
492 492
493 493 IPython has a magic command syntax to work with `MultiEngineClient` objects.
494 494 In a given IPython session there is a single active one. While
495 495 there can be many `Views` created and used by the user,
496 496 there is only one active one. The active `View` is used whenever
497 497 the magic commands %px and %autopx are used.
498 498
499 499 The activate() method is called on a given `View` to make it
500 500 active. Once this has been done, the magic commands can be used.
501 501 """
502 502
503 503 try:
504 504 # This is injected into __builtins__.
505 505 ip = get_ipython()
506 506 except NameError:
507 507 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
508 508 else:
509 509 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
510 510 if pmagic is not None:
511 511 pmagic.active_multiengine_client = self
512 512 else:
513 513 print "You must first load the parallelmagic extension " \
514 514 "by doing '%load_ext parallelmagic'"
515 515
516 516
517 517 @testdec.skip_doctest
518 518 class LoadBalancedView(View):
519 519 """An load-balancing View that only executes via the Task scheduler.
520 520
521 521 Load-balanced views can be created with the client's `view` method:
522 522
523 523 >>> v = client.view(balanced=True)
524 524
525 525 or targets can be specified, to restrict the potential destinations:
526 526
527 527 >>> v = client.view([1,3],balanced=True)
528 528
529 529 which would restrict loadbalancing to between engines 1 and 3.
530 530
531 531 """
532 532
533 533 _default_names = ['block', 'bound', 'follow', 'after', 'timeout']
534 534
535 535 def __init__(self, client=None, targets=None):
536 536 super(LoadBalancedView, self).__init__(client=client, targets=targets)
537 537 self._ntargets = 1
538 538 self._balanced = True
539 539
540 540 def _validate_dependency(self, dep):
541 541 """validate a dependency.
542 542
543 543 For use in `set_flags`.
544 544 """
545 545 if dep is None or isinstance(dep, (str, AsyncResult, Dependency)):
546 546 return True
547 547 elif isinstance(dep, (list,set, tuple)):
548 548 for d in dep:
549 549 if not isinstance(d, (str, AsyncResult)):
550 550 return False
551 551 elif isinstance(dep, dict):
552 552 if set(dep.keys()) != set(Dependency().as_dict().keys()):
553 553 return False
554 554 if not isinstance(dep['msg_ids'], list):
555 555 return False
556 556 for d in dep['msg_ids']:
557 557 if not isinstance(d, str):
558 558 return False
559 559 else:
560 560 return False
561 561 return True  # valid list/set/tuple or dict dependencies fall through to here

562 562 def set_flags(self, **kwargs):
563 563 """set my attribute flags by keyword.
564 564
565 565 A View is a wrapper for the Client's apply method, but with attributes
566 566 that specify keyword arguments, those attributes can be set by keyword
567 567 argument with this method.
568 568
569 569 Parameters
570 570 ----------
571 571
572 572 block : bool
573 573 whether to wait for results
574 574 bound : bool
575 575 whether to use the engine's namespace
576 576 follow : Dependency, list, msg_id, AsyncResult
577 577 the location dependencies of tasks
578 578 after : Dependency, list, msg_id, AsyncResult
579 579 the time dependencies of tasks
580 580 timeout : int,None
581 581 the timeout to be used for tasks
582 582 """
583 583
584 584 super(LoadBalancedView, self).set_flags(**kwargs)
585 585 for name in ('follow', 'after'):
586 586 if name in kwargs:
587 587 value = kwargs[name]
588 588 if self._validate_dependency(value):
589 589 setattr(self, name, value)
590 590 else:
591 591 raise ValueError("Invalid dependency: %r"%value)
592 592 if 'timeout' in kwargs:
593 593 t = kwargs['timeout']
594 594 if not isinstance(t, (int, long, float, type(None))):
595 595 raise TypeError("Invalid type for timeout: %r"%type(t))
596 596 if t is not None:
597 597 if t < 0:
598 598 raise ValueError("Invalid timeout: %s"%t)
599 599 self.timeout = t
600 600
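Dependencies can be given as msg_ids, AsyncResults, or Dependency objects; a sketch (lv is a LoadBalancedView; fetch, process, and url are placeholders):

    ar = lv.apply_async(fetch, url)
    lv.set_flags(after=ar, timeout=60)   # later tasks run only once ar has finished
    lv.apply_async(process)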
601 601 @spin_after
602 602 @save_ids
603 603 def map(self, f, *sequences, **kwargs):
604 604 """view.map(f, *sequences, block=self.block, bound=self.bound, chunk_size=1) => list|AsyncMapResult
605 605
606 606 Parallel version of builtin `map`, load-balanced by this View.
607 607
608 608 `block`, `bound`, and `chunk_size` can be specified by keyword only.
609 609
610 610 Each `chunk_size` elements will be a separate task, and will be
611 611 load-balanced. This lets individual elements be available for iteration
612 612 as soon as they arrive.
613 613
614 614 Parameters
615 615 ----------
616 616
617 617 f : callable
618 618 function to be mapped
619 619 *sequences: one or more sequences of matching length
620 620 the sequences to be distributed and passed to `f`
621 621 block : bool
622 622 whether to wait for the result or not [default self.block]
623 623 bound : bool
624 624 whether to use the engine's namespace [default self.bound]
625 625 chunk_size : int
626 626 how many elements should be in each task [default 1]
627 627
628 628 Returns
629 629 -------
630 630
631 631 if block=False:
632 632 AsyncMapResult
633 633 An object like AsyncResult, but which reassembles the sequence of results
634 634 into a single list. AsyncMapResults can be iterated through before all
635 635 results are complete.
636 636 else:
637 637 the result of map(f,*sequences)
638 638
639 639 """
640 640
641 641 # default
642 642 block = kwargs.get('block', self.block)
643 643 bound = kwargs.get('bound', self.bound)
644 644 chunk_size = kwargs.get('chunk_size', 1)
645 645
646 646 keyset = set(kwargs.keys())
647 647 extra_keys = keyset.difference(set(['block', 'bound', 'chunk_size']))
648 648 if extra_keys:
649 649 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
650 650
651 651 assert len(sequences) > 0, "must have some sequences to map onto!"
652 652
653 653 pf = ParallelFunction(self.client, f, block=block, bound=bound,
654 654 targets=self._targets, balanced=True,
655 655 chunk_size=chunk_size)
656 656 return pf.map(*sequences)
657 657
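With the default chunk_size=1 every element is its own load-balanced task, so faster engines naturally take more of the work; raising chunk_size trades scheduling overhead for per-result latency (lv, slow_func, and consume are placeholders):

    amr = lv.map(slow_func, range(100), chunk_size=4, block=False)
    for r in amr:          # results stream in as individual 4-element tasks finish
        consume(r)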
658 __all__ = ['LoadBalancedView', 'DirectView'] No newline at end of file