##// END OF EJS Templates
Improvements to dependency handling...
MinRK -
Show More
@@ -1,115 +1,127 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Pickle related utilities. Perhaps this should be called 'can'."""
3 """Pickle related utilities. Perhaps this should be called 'can'."""
4
4
5 __docformat__ = "restructuredtext en"
5 __docformat__ = "restructuredtext en"
6
6
7 #-------------------------------------------------------------------------------
7 #-------------------------------------------------------------------------------
8 # Copyright (C) 2008 The IPython Development Team
8 # Copyright (C) 2008 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 #-------------------------------------------------------------------------------
14 #-------------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 from types import FunctionType
18 from types import FunctionType
19 import copy
19
20
20 # contents of codeutil should either be in here, or codeutil belongs in IPython/util
21 from IPython.zmq.parallel.dependency import dependent
21 from IPython.zmq.parallel.dependency import dependent
22
22 import codeutil
23 import codeutil
23
24
25 #-------------------------------------------------------------------------------
26 # Classes
27 #-------------------------------------------------------------------------------
28
29
class CannedObject(object):
    """Picklable wrapper for an object with otherwise-unpicklable attributes.

    Each attribute named in `keys` is replaced — on a shallow copy of the
    object, so the caller's instance is not mutated — with its canned
    equivalent via :func:`can`.
    """

    def __init__(self, obj, keys=None):
        """Can `obj`, canning each attribute named in `keys`.

        Parameters
        ----------
        obj : object
            The object to be canned.
        keys : sequence of str, optional
            Attribute names whose values should be canned.  Defaults to no
            attributes.  (A None sentinel replaces the original mutable
            default ``keys=[]``; behavior is unchanged.)
        """
        self.keys = keys if keys is not None else []
        # shallow copy so canning does not clobber the caller's object
        self.obj = copy.copy(obj)
        for key in self.keys:
            setattr(self.obj, key, can(getattr(obj, key)))

    def getObject(self, g=None):
        """Uncan the stored attributes and return the restored object.

        Parameters
        ----------
        g : dict, optional
            Namespace used to uncan function attributes; defaults to this
            module's globals().
        """
        if g is None:
            g = globals()
        for key in self.keys:
            setattr(self.obj, key, uncan(getattr(self.obj, key), g))
        return self.obj
38
44
39
45
40
46
class CannedFunction(CannedObject):
    """Canned form of a plain Python function.

    Only the code object (and name) is kept; the function's globals are
    re-bound when it is reconstructed by :meth:`getFunction`.
    """

    def __init__(self, f):
        """Can the function `f`; raises AssertionError for non-functions."""
        self._checkType(f)
        # keep just the code object — globals are supplied at uncan time
        self.code = f.func_code
        self.__name__ = f.__name__

    def _checkType(self, obj):
        """Assert that `obj` is a plain FunctionType."""
        assert isinstance(obj, FunctionType), "Not a function type"

    def getFunction(self, g=None):
        """Rebuild the function, binding its globals to `g`.

        Defaults to this module's globals() when `g` is None.
        """
        namespace = globals() if g is None else g
        return FunctionType(self.code, namespace)
55
62
63 #-------------------------------------------------------------------------------
64 # Functions
65 #-------------------------------------------------------------------------------
66
67
def can(obj):
    """Return a picklable stand-in for `obj`.

    Functions become CannedFunction, dependents become CannedObject with
    their 'f'/'df' callables canned, dicts and sequences are canned
    element-wise, and anything else passes through unchanged.
    """
    if isinstance(obj, FunctionType):
        return CannedFunction(obj)
    if isinstance(obj, dependent):
        return CannedObject(obj, keys=('f', 'df'))
    if isinstance(obj, dict):
        return canDict(obj)
    if isinstance(obj, (list, tuple)):
        return canSequence(obj)
    return obj
68
80
def canDict(obj):
    """Return a copy of dict `obj` with every value canned.

    Non-dict inputs are returned unchanged.  Keys are assumed to be
    picklable already; only values are canned.
    """
    if isinstance(obj, dict):
        newobj = {}
        # .items() instead of Py2-only .iteritems(): identical behavior on
        # Python 2, and keeps the helper forward-compatible with Python 3
        for k, v in obj.items():
            newobj[k] = can(v)
        return newobj
    else:
        return obj
77
89
def canSequence(obj):
    """Can every element of a list or tuple, preserving its concrete type.

    Anything that is not a list/tuple is returned unchanged.
    """
    if not isinstance(obj, (list, tuple)):
        return obj
    cls = type(obj)
    return cls(can(item) for item in obj)
84
96
def uncan(obj, g=None):
    """Inverse of :func:`can`: restore canned objects.

    Canned functions/objects are reconstructed in namespace `g`; dicts and
    sequences are uncanned element-wise; anything else passes through.
    """
    if isinstance(obj, CannedFunction):
        return obj.getFunction(g)
    if isinstance(obj, CannedObject):
        return obj.getObject(g)
    if isinstance(obj, dict):
        return uncanDict(obj)
    if isinstance(obj, (list, tuple)):
        return uncanSequence(obj)
    return obj
96
108
def uncanDict(obj, g=None):
    """Return a copy of dict `obj` with every value uncanned in namespace `g`.

    Non-dict inputs are returned unchanged.
    """
    if isinstance(obj, dict):
        newobj = {}
        # .items() instead of Py2-only .iteritems(): identical behavior on
        # Python 2, and keeps the helper forward-compatible with Python 3
        for k, v in obj.items():
            newobj[k] = uncan(v, g)
        return newobj
    else:
        return obj
105
117
def uncanSequence(obj, g=None):
    """Uncan every element of a list or tuple, preserving its concrete type.

    Anything that is not a list/tuple is returned unchanged.
    """
    if not isinstance(obj, (list, tuple)):
        return obj
    cls = type(obj)
    return cls(uncan(item, g) for item in obj)
112
124
113
125
def rebindFunctionGlobals(f, glbls):
    """Return a copy of function `f` whose global namespace is `glbls`.

    The new function shares `f`'s code object but resolves global names in
    `glbls` instead of `f`'s original module namespace.
    """
    # f.__code__ (available since Python 2.6) instead of the Py2-only
    # f.func_code — identical object, forward-compatible spelling
    return FunctionType(f.__code__, glbls)
@@ -1,188 +1,200 b''
1 """AsyncResult objects for the client"""
1 """AsyncResult objects for the client"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 from IPython.external.decorator import decorator
13 from IPython.external.decorator import decorator
14 import error
14 import error
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Classes
17 # Classes
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
@decorator
def check_ready(f, self, *args, **kwargs):
    """Guard decorator: sync state, then require a finished result.

    Calls wait(0) to flush state; raises error.TimeoutError if the result
    has not arrived, otherwise invokes the wrapped method.
    """
    self.wait(0)
    if self._ready:
        return f(self, *args, **kwargs)
    raise error.TimeoutError("result not ready")
27
27
class AsyncResult(object):
    """Class for representing results of non-blocking calls.

    Provides the same interface as :py:class:`multiprocessing.AsyncResult`.
    """
    def __init__(self, client, msg_ids, fname=''):
        # client: the Client used to barrier on and fetch results
        # msg_ids: list of message ids whose results this object tracks
        # fname: remote function name, used only for repr and error reports
        self._client = client
        self.msg_ids = msg_ids
        self._fname=fname
        self._ready = False
        self._success = None
        # a single-message result is unwrapped from its 1-element list
        self._flatten_result = len(msg_ids) == 1

    def __repr__(self):
        if self._ready:
            return "<%s: finished>"%(self.__class__.__name__)
        else:
            return "<%s: %s>"%(self.__class__.__name__,self._fname)


    def _reconstruct_result(self, res):
        """
        Override me in subclasses for turning a list of results
        into the expected form.
        """
        if self._flatten_result:
            return res[0]
        else:
            return res

    def get(self, timeout=-1):
        """Return the result when it arrives.

        If `timeout` is not ``None`` and the result does not arrive within
        `timeout` seconds then ``TimeoutError`` is raised. If the
        remote call raised an exception then that exception will be reraised
        by get().
        """
        if not self.ready():
            self.wait(timeout)

        if self._ready:
            if self._success:
                return self._result
            else:
                # re-raise the remote exception locally
                raise self._exception
        else:
            raise error.TimeoutError("Result not ready.")

    def ready(self):
        """Return whether the call has completed."""
        if not self._ready:
            # non-blocking poll to refresh state
            self.wait(0)
        return self._ready

    def wait(self, timeout=-1):
        """Wait until the result is available or until `timeout` seconds pass.
        """
        if self._ready:
            return
        self._ready = self._client.barrier(self.msg_ids, timeout)
        if self._ready:
            try:
                results = map(self._client.results.get, self.msg_ids)
                # stash the raw list first, in case collection below raises
                self._result = results
                # raises if any individual job failed remotely
                results = error.collect_exceptions(results, self._fname)
                self._result = self._reconstruct_result(results)
            except Exception, e:
                self._exception = e
                self._success = False
            else:
                self._success = True
            finally:
                # metadata is recorded whether or not the call succeeded
                self._metadata = map(self._client.metadata.get, self.msg_ids)


    def successful(self):
        """Return whether the call completed without raising an exception.

        Will raise ``AssertionError`` if the result is not ready.
        """
        assert self._ready
        return self._success

    #----------------------------------------------------------------
    # Extra methods not in mp.pool.AsyncResult
    #----------------------------------------------------------------

    def get_dict(self, timeout=-1):
        """Get the results as a dict, keyed by engine_id."""
        results = self.get(timeout)
        engine_ids = [ md['engine_id'] for md in self._metadata ]
        # sort ids by occurrence count so the most-used engine ends up last
        bycount = sorted(engine_ids, key=lambda k: engine_ids.count(k))
        maxcount = bycount.count(bycount[-1])
        if maxcount > 1:
            # two jobs on one engine would collide on the same dict key
            raise ValueError("Cannot build dict, %i jobs ran on engine #%i"%(
                    maxcount, bycount[-1]))

        return dict(zip(engine_ids,results))

    @property
    @check_ready
    def result(self):
        """result property."""
        return self._result

    # abbreviated alias:
    r = result

    @property
    @check_ready
    def metadata(self):
        """metadata property.

        Returns a single metadata dict for single-message results, else the
        full list of metadata dicts.
        """
        if self._flatten_result:
            return self._metadata[0]
        else:
            return self._metadata

    @property
    def result_dict(self):
        """result property as a dict."""
        return self.get_dict(0)

    def __dict__(self):
        # NOTE(review): defining __dict__ as a method shadows the instance
        # attribute dict and almost certainly does not behave as intended —
        # confirm whether any caller relies on this.
        return self.get_dict(0)

    #-------------------------------------
    # dict-access
    #-------------------------------------

    @check_ready
    def __getitem__(self, key):
        """getitem returns result value(s) if keyed by int/slice, or metadata if key is str.
        """
        if isinstance(key, int):
            return error.collect_exceptions([self._result[key]], self._fname)[0]
        elif isinstance(key, slice):
            return error.collect_exceptions(self._result[key], self._fname)
        elif isinstance(key, basestring):
            # string keys index into per-message metadata, flattened like
            # results when only one message is tracked
            values = [ md[key] for md in self._metadata ]
            if self._flatten_result:
                return values[0]
            else:
                return values
        else:
            raise TypeError("Invalid key type %r, must be 'int','slice', or 'str'"%type(key))

    @check_ready
    def __getattr__(self, key):
        """getattr maps to getitem for convenient access to metadata."""
        if key not in self._metadata[0].keys():
            raise AttributeError("%r object has no attribute %r"%(
                    self.__class__.__name__, key))
        return self.__getitem__(key)
171
182
172
183
class AsyncMapResult(AsyncResult):
    """Class for representing results of non-blocking gathers.

    This will properly reconstruct the gather.
    """

    def __init__(self, client, msg_ids, mapObject, fname=''):
        AsyncResult.__init__(self, client, msg_ids, fname=fname)
        # the map object knows how to rejoin the partitioned results
        self._mapObject = mapObject
        # a gathered result is always a joined sequence, never flattened
        self._flatten_result = False

    def _reconstruct_result(self, res):
        """Perform the gather on the actual results."""
        return self._mapObject.joinPartitions(res)
186
198
187
199
# public names exported by this module
__all__ = ['AsyncResult', 'AsyncMapResult']
@@ -1,1164 +1,1206 b''
1 """A semi-synchronous Client for the ZMQ controller"""
1 """A semi-synchronous Client for the ZMQ controller"""
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010 The IPython Development Team
3 # Copyright (C) 2010 The IPython Development Team
4 #
4 #
5 # Distributed under the terms of the BSD License. The full license is in
5 # Distributed under the terms of the BSD License. The full license is in
6 # the file COPYING, distributed as part of this software.
6 # the file COPYING, distributed as part of this software.
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Imports
10 # Imports
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 import os
13 import os
14 import time
14 import time
15 from getpass import getpass
15 from getpass import getpass
16 from pprint import pprint
16 from pprint import pprint
17 from datetime import datetime
17 from datetime import datetime
18
18
19 import zmq
19 import zmq
20 from zmq.eventloop import ioloop, zmqstream
20 from zmq.eventloop import ioloop, zmqstream
21
21
22 from IPython.external.decorator import decorator
22 from IPython.external.decorator import decorator
23 from IPython.zmq import tunnel
23 from IPython.zmq import tunnel
24
24
25 import streamsession as ss
25 import streamsession as ss
26 # from remotenamespace import RemoteNamespace
26 # from remotenamespace import RemoteNamespace
27 from view import DirectView, LoadBalancedView
27 from view import DirectView, LoadBalancedView
28 from dependency import Dependency, depend, require
28 from dependency import Dependency, depend, require
29 import error
29 import error
30 import map as Map
30 import map as Map
31 from asyncresult import AsyncResult, AsyncMapResult
31 from asyncresult import AsyncResult, AsyncMapResult
32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
32 from remotefunction import remote,parallel,ParallelFunction,RemoteFunction
33 from util import ReverseDict
33 from util import ReverseDict
34
34
35 #--------------------------------------------------------------------------
35 #--------------------------------------------------------------------------
36 # helpers for implementing old MEC API via client.apply
36 # helpers for implementing old MEC API via client.apply
37 #--------------------------------------------------------------------------
37 #--------------------------------------------------------------------------
38
38
def _push(ns):
    """helper method for implementing `client.push` via `client.apply`"""
    # merge the pushed namespace dict into the engine's global namespace
    globals().update(ns)
42
42
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`

    Parameters
    ----------
    keys : str or list/tuple/set of str
        Name(s) to look up in the engine's global namespace.

    Returns
    -------
    The single value for a str key, or a list of values for a sequence
    of keys.

    Raises
    ------
    NameError
        If any requested name is not defined.
    """
    g = globals()
    if isinstance(keys, (list,tuple, set)):
        for key in keys:
            # `key in g` instead of Py2-only g.has_key(key)
            if key not in g:
                raise NameError("name '%s' is not defined"%key)
        # list comprehension instead of map(): same list on Python 2,
        # and still a list (not an iterator) on Python 3
        return [g.get(key) for key in keys]
    else:
        if keys not in g:
            raise NameError("name '%s' is not defined"%keys)
        return g.get(keys)
55
55
def _clear():
    """helper method for implementing `client.clear` via `client.apply`"""
    # wipe the engine's global namespace entirely
    g = globals()
    g.clear()
59
59
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`

    Runs `code` in the engine's global namespace.
    """
    # exec(code, globals()) is valid on both Python 2 (tuple form of the
    # exec statement) and Python 3, unlike the 2-only `exec code in globals()`
    exec(code, globals())
63
63
64
64
65 #--------------------------------------------------------------------------
65 #--------------------------------------------------------------------------
66 # Decorators for Client methods
66 # Decorators for Client methods
67 #--------------------------------------------------------------------------
67 #--------------------------------------------------------------------------
68
68
@decorator
def spinfirst(f, self, *args, **kwargs):
    """Call spin() to sync state prior to calling the method."""
    # flush incoming results/registration messages before f runs
    self.spin()
    return f(self, *args, **kwargs)
74
74
@decorator
def defaultblock(f, self, *args, **kwargs):
    """Run `f` with self.block defaulting to the client's current setting.

    A `block` keyword argument, if given and not None, temporarily
    overrides self.block for the duration of the call; the previous value
    is always restored afterwards.
    """
    requested = kwargs.get('block', None)
    previous = self.block
    self.block = previous if requested is None else requested
    try:
        result = f(self, *args, **kwargs)
    finally:
        self.block = previous
    return result
87
87
88
88
89 #--------------------------------------------------------------------------
89 #--------------------------------------------------------------------------
90 # Classes
90 # Classes
91 #--------------------------------------------------------------------------
91 #--------------------------------------------------------------------------
92
92
class Metadata(dict):
    """Subclass of dict for initializing metadata values.

    Attribute access works on keys.

    These objects have a strict set of keys - errors will raise if you try
    to add new keys.
    """
    def __init__(self, *args, **kwargs):
        dict.__init__(self)
        # the full, fixed set of metadata keys with their default values
        md = {'msg_id' : None,
              'submitted' : None,
              'started' : None,
              'completed' : None,
              'received' : None,
              'engine_uuid' : None,
              'engine_id' : None,
              'follow' : None,
              'after' : None,
              'status' : None,

              'pyin' : None,
              'pyout' : None,
              'pyerr' : None,
              'stdout' : '',
              'stderr' : '',
            }
        self.update(md)
        # NOTE(review): dict.update bypasses the strict __setitem__ below,
        # so unknown keys passed to the constructor are accepted silently —
        # confirm whether callers rely on that before tightening.
        self.update(dict(*args, **kwargs))

    def __getattr__(self, key):
        """getattr aliased to getitem"""
        # `key in self` instead of Py2-only iterkeys(): same membership
        # test, forward-compatible with Python 3
        if key in self:
            return self[key]
        else:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        """setattr aliased to setitem, with strict"""
        if key in self:
            self[key] = value
        else:
            raise AttributeError(key)

    def __setitem__(self, key, value):
        """strict static key enforcement"""
        if key in self:
            dict.__setitem__(self, key, value)
        else:
            raise KeyError(key)
117
143
118
144
119 class Client(object):
145 class Client(object):
120 """A semi-synchronous client to the IPython ZMQ controller
146 """A semi-synchronous client to the IPython ZMQ controller
121
147
122 Parameters
148 Parameters
123 ----------
149 ----------
124
150
125 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
151 addr : bytes; zmq url, e.g. 'tcp://127.0.0.1:10101'
126 The address of the controller's registration socket.
152 The address of the controller's registration socket.
127 [Default: 'tcp://127.0.0.1:10101']
153 [Default: 'tcp://127.0.0.1:10101']
128 context : zmq.Context
154 context : zmq.Context
129 Pass an existing zmq.Context instance, otherwise the client will create its own
155 Pass an existing zmq.Context instance, otherwise the client will create its own
130 username : bytes
156 username : bytes
131 set username to be passed to the Session object
157 set username to be passed to the Session object
132 debug : bool
158 debug : bool
133 flag for lots of message printing for debug purposes
159 flag for lots of message printing for debug purposes
134
160
135 #-------------- ssh related args ----------------
161 #-------------- ssh related args ----------------
136 # These are args for configuring the ssh tunnel to be used
162 # These are args for configuring the ssh tunnel to be used
137 # credentials are used to forward connections over ssh to the Controller
163 # credentials are used to forward connections over ssh to the Controller
138 # Note that the ip given in `addr` needs to be relative to sshserver
164 # Note that the ip given in `addr` needs to be relative to sshserver
139 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
165 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
140 # and set sshserver as the same machine the Controller is on. However,
166 # and set sshserver as the same machine the Controller is on. However,
141 # the only requirement is that sshserver is able to see the Controller
167 # the only requirement is that sshserver is able to see the Controller
142 # (i.e. is within the same trusted network).
168 # (i.e. is within the same trusted network).
143
169
144 sshserver : str
170 sshserver : str
145 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
171 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
146 If keyfile or password is specified, and this is not, it will default to
172 If keyfile or password is specified, and this is not, it will default to
147 the ip given in addr.
173 the ip given in addr.
148 sshkey : str; path to public ssh key file
174 sshkey : str; path to public ssh key file
149 This specifies a key to be used in ssh login, default None.
175 This specifies a key to be used in ssh login, default None.
150 Regular default ssh keys will be used without specifying this argument.
176 Regular default ssh keys will be used without specifying this argument.
151 password : str;
177 password : str;
152 Your ssh password to sshserver. Note that if this is left None,
178 Your ssh password to sshserver. Note that if this is left None,
153 you will be prompted for it if passwordless key based login is unavailable.
179 you will be prompted for it if passwordless key based login is unavailable.
154
180
155 #------- exec authentication args -------
181 #------- exec authentication args -------
156 # If even localhost is untrusted, you can have some protection against
182 # If even localhost is untrusted, you can have some protection against
157 # unauthorized execution by using a key. Messages are still sent
183 # unauthorized execution by using a key. Messages are still sent
158 # as cleartext, so if someone can snoop your loopback traffic this will
184 # as cleartext, so if someone can snoop your loopback traffic this will
159 # not help anything.
185 # not help anything.
160
186
161 exec_key : str
187 exec_key : str
162 an authentication key or file containing a key
188 an authentication key or file containing a key
163 default: None
189 default: None
164
190
165
191
166 Attributes
192 Attributes
167 ----------
193 ----------
168 ids : set of int engine IDs
194 ids : set of int engine IDs
169 requesting the ids attribute always synchronizes
195 requesting the ids attribute always synchronizes
170 the registration state. To request ids without synchronization,
196 the registration state. To request ids without synchronization,
171 use semi-private _ids attributes.
197 use semi-private _ids attributes.
172
198
173 history : list of msg_ids
199 history : list of msg_ids
174 a list of msg_ids, keeping track of all the execution
200 a list of msg_ids, keeping track of all the execution
175 messages you have submitted in order.
201 messages you have submitted in order.
176
202
177 outstanding : set of msg_ids
203 outstanding : set of msg_ids
178 a set of msg_ids that have been submitted, but whose
204 a set of msg_ids that have been submitted, but whose
179 results have not yet been received.
205 results have not yet been received.
180
206
181 results : dict
207 results : dict
182 a dict of all our results, keyed by msg_id
208 a dict of all our results, keyed by msg_id
183
209
184 block : bool
210 block : bool
185 determines default behavior when block not specified
211 determines default behavior when block not specified
186 in execution methods
212 in execution methods
187
213
188 Methods
214 Methods
189 -------
215 -------
190 spin : flushes incoming results and registration state changes
216 spin : flushes incoming results and registration state changes
191 control methods spin, and requesting `ids` also ensures up to date
217 control methods spin, and requesting `ids` also ensures up to date
192
218
193 barrier : wait on one or more msg_ids
219 barrier : wait on one or more msg_ids
194
220
195 execution methods: apply/apply_bound/apply_to/apply_bound
221 execution methods: apply/apply_bound/apply_to/apply_bound
196 legacy: execute, run
222 legacy: execute, run
197
223
198 query methods: queue_status, get_result, purge
224 query methods: queue_status, get_result, purge
199
225
200 control methods: abort, kill
226 control methods: abort, kill
201
227
202 """
228 """
203
229
204
230
205 _connected=False
231 _connected=False
206 _ssh=False
232 _ssh=False
207 _engines=None
233 _engines=None
208 _addr='tcp://127.0.0.1:10101'
234 _addr='tcp://127.0.0.1:10101'
209 _registration_socket=None
235 _registration_socket=None
210 _query_socket=None
236 _query_socket=None
211 _control_socket=None
237 _control_socket=None
212 _iopub_socket=None
238 _iopub_socket=None
213 _notification_socket=None
239 _notification_socket=None
214 _mux_socket=None
240 _mux_socket=None
215 _task_socket=None
241 _task_socket=None
216 block = False
242 block = False
217 outstanding=None
243 outstanding=None
218 results = None
244 results = None
219 history = None
245 history = None
220 debug = False
246 debug = False
221 targets = None
247 targets = None
222
248
223 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
249 def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False,
224 sshserver=None, sshkey=None, password=None, paramiko=None,
250 sshserver=None, sshkey=None, password=None, paramiko=None,
225 exec_key=None,):
251 exec_key=None,):
226 if context is None:
252 if context is None:
227 context = zmq.Context()
253 context = zmq.Context()
228 self.context = context
254 self.context = context
229 self.targets = 'all'
255 self.targets = 'all'
230 self._addr = addr
256 self._addr = addr
231 self._ssh = bool(sshserver or sshkey or password)
257 self._ssh = bool(sshserver or sshkey or password)
232 if self._ssh and sshserver is None:
258 if self._ssh and sshserver is None:
233 # default to the same
259 # default to the same
234 sshserver = addr.split('://')[1].split(':')[0]
260 sshserver = addr.split('://')[1].split(':')[0]
235 if self._ssh and password is None:
261 if self._ssh and password is None:
236 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
262 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
237 password=False
263 password=False
238 else:
264 else:
239 password = getpass("SSH Password for %s: "%sshserver)
265 password = getpass("SSH Password for %s: "%sshserver)
240 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
266 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
241
267
242 if exec_key is not None and os.path.isfile(exec_key):
268 if exec_key is not None and os.path.isfile(exec_key):
243 arg = 'keyfile'
269 arg = 'keyfile'
244 else:
270 else:
245 arg = 'key'
271 arg = 'key'
246 key_arg = {arg:exec_key}
272 key_arg = {arg:exec_key}
247 if username is None:
273 if username is None:
248 self.session = ss.StreamSession(**key_arg)
274 self.session = ss.StreamSession(**key_arg)
249 else:
275 else:
250 self.session = ss.StreamSession(username, **key_arg)
276 self.session = ss.StreamSession(username, **key_arg)
251 self._registration_socket = self.context.socket(zmq.XREQ)
277 self._registration_socket = self.context.socket(zmq.XREQ)
252 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
278 self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session)
253 if self._ssh:
279 if self._ssh:
254 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
280 tunnel.tunnel_connection(self._registration_socket, addr, sshserver, **ssh_kwargs)
255 else:
281 else:
256 self._registration_socket.connect(addr)
282 self._registration_socket.connect(addr)
257 self._engines = ReverseDict()
283 self._engines = ReverseDict()
258 self._ids = set()
284 self._ids = set()
259 self.outstanding=set()
285 self.outstanding=set()
260 self.results = {}
286 self.results = {}
261 self.metadata = {}
287 self.metadata = {}
262 self.history = []
288 self.history = []
263 self.debug = debug
289 self.debug = debug
264 self.session.debug = debug
290 self.session.debug = debug
265
291
266 self._notification_handlers = {'registration_notification' : self._register_engine,
292 self._notification_handlers = {'registration_notification' : self._register_engine,
267 'unregistration_notification' : self._unregister_engine,
293 'unregistration_notification' : self._unregister_engine,
268 }
294 }
269 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
295 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
270 'apply_reply' : self._handle_apply_reply}
296 'apply_reply' : self._handle_apply_reply}
271 self._connect(sshserver, ssh_kwargs)
297 self._connect(sshserver, ssh_kwargs)
272
298
273
299
274 @property
300 @property
275 def ids(self):
301 def ids(self):
276 """Always up to date ids property."""
302 """Always up to date ids property."""
277 self._flush_notifications()
303 self._flush_notifications()
278 return self._ids
304 return self._ids
279
305
280 def _update_engines(self, engines):
306 def _update_engines(self, engines):
281 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
307 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
282 for k,v in engines.iteritems():
308 for k,v in engines.iteritems():
283 eid = int(k)
309 eid = int(k)
284 self._engines[eid] = bytes(v) # force not unicode
310 self._engines[eid] = bytes(v) # force not unicode
285 self._ids.add(eid)
311 self._ids.add(eid)
286
312
287 def _build_targets(self, targets):
313 def _build_targets(self, targets):
288 """Turn valid target IDs or 'all' into two lists:
314 """Turn valid target IDs or 'all' into two lists:
289 (int_ids, uuids).
315 (int_ids, uuids).
290 """
316 """
291 if targets is None:
317 if targets is None:
292 targets = self._ids
318 targets = self._ids
293 elif isinstance(targets, str):
319 elif isinstance(targets, str):
294 if targets.lower() == 'all':
320 if targets.lower() == 'all':
295 targets = self._ids
321 targets = self._ids
296 else:
322 else:
297 raise TypeError("%r not valid str target, must be 'all'"%(targets))
323 raise TypeError("%r not valid str target, must be 'all'"%(targets))
298 elif isinstance(targets, int):
324 elif isinstance(targets, int):
299 targets = [targets]
325 targets = [targets]
300 return [self._engines[t] for t in targets], list(targets)
326 return [self._engines[t] for t in targets], list(targets)
301
327
302 def _connect(self, sshserver, ssh_kwargs):
328 def _connect(self, sshserver, ssh_kwargs):
303 """setup all our socket connections to the controller. This is called from
329 """setup all our socket connections to the controller. This is called from
304 __init__."""
330 __init__."""
305 if self._connected:
331 if self._connected:
306 return
332 return
307 self._connected=True
333 self._connected=True
308
334
309 def connect_socket(s, addr):
335 def connect_socket(s, addr):
310 if self._ssh:
336 if self._ssh:
311 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
337 return tunnel.tunnel_connection(s, addr, sshserver, **ssh_kwargs)
312 else:
338 else:
313 return s.connect(addr)
339 return s.connect(addr)
314
340
315 self.session.send(self._registration_socket, 'connection_request')
341 self.session.send(self._registration_socket, 'connection_request')
316 idents,msg = self.session.recv(self._registration_socket,mode=0)
342 idents,msg = self.session.recv(self._registration_socket,mode=0)
317 if self.debug:
343 if self.debug:
318 pprint(msg)
344 pprint(msg)
319 msg = ss.Message(msg)
345 msg = ss.Message(msg)
320 content = msg.content
346 content = msg.content
321 if content.status == 'ok':
347 if content.status == 'ok':
322 if content.mux:
348 if content.mux:
323 self._mux_socket = self.context.socket(zmq.PAIR)
349 self._mux_socket = self.context.socket(zmq.PAIR)
324 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
350 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
325 connect_socket(self._mux_socket, content.mux)
351 connect_socket(self._mux_socket, content.mux)
326 if content.task:
352 if content.task:
327 self._task_socket = self.context.socket(zmq.PAIR)
353 self._task_socket = self.context.socket(zmq.PAIR)
328 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
354 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
329 connect_socket(self._task_socket, content.task)
355 connect_socket(self._task_socket, content.task)
330 if content.notification:
356 if content.notification:
331 self._notification_socket = self.context.socket(zmq.SUB)
357 self._notification_socket = self.context.socket(zmq.SUB)
332 connect_socket(self._notification_socket, content.notification)
358 connect_socket(self._notification_socket, content.notification)
333 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
359 self._notification_socket.setsockopt(zmq.SUBSCRIBE, "")
334 if content.query:
360 if content.query:
335 self._query_socket = self.context.socket(zmq.PAIR)
361 self._query_socket = self.context.socket(zmq.PAIR)
336 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
362 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
337 connect_socket(self._query_socket, content.query)
363 connect_socket(self._query_socket, content.query)
338 if content.control:
364 if content.control:
339 self._control_socket = self.context.socket(zmq.PAIR)
365 self._control_socket = self.context.socket(zmq.PAIR)
340 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
366 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
341 connect_socket(self._control_socket, content.control)
367 connect_socket(self._control_socket, content.control)
342 if content.iopub:
368 if content.iopub:
343 self._iopub_socket = self.context.socket(zmq.SUB)
369 self._iopub_socket = self.context.socket(zmq.SUB)
344 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
370 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, '')
345 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
371 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
346 connect_socket(self._iopub_socket, content.iopub)
372 connect_socket(self._iopub_socket, content.iopub)
347 self._update_engines(dict(content.engines))
373 self._update_engines(dict(content.engines))
348
374
349 else:
375 else:
350 self._connected = False
376 self._connected = False
351 raise Exception("Failed to connect!")
377 raise Exception("Failed to connect!")
352
378
353 #--------------------------------------------------------------------------
379 #--------------------------------------------------------------------------
354 # handlers and callbacks for incoming messages
380 # handlers and callbacks for incoming messages
355 #--------------------------------------------------------------------------
381 #--------------------------------------------------------------------------
356
382
357 def _register_engine(self, msg):
383 def _register_engine(self, msg):
358 """Register a new engine, and update our connection info."""
384 """Register a new engine, and update our connection info."""
359 content = msg['content']
385 content = msg['content']
360 eid = content['id']
386 eid = content['id']
361 d = {eid : content['queue']}
387 d = {eid : content['queue']}
362 self._update_engines(d)
388 self._update_engines(d)
363 self._ids.add(int(eid))
389 self._ids.add(int(eid))
364
390
365 def _unregister_engine(self, msg):
391 def _unregister_engine(self, msg):
366 """Unregister an engine that has died."""
392 """Unregister an engine that has died."""
367 content = msg['content']
393 content = msg['content']
368 eid = int(content['id'])
394 eid = int(content['id'])
369 if eid in self._ids:
395 if eid in self._ids:
370 self._ids.remove(eid)
396 self._ids.remove(eid)
371 self._engines.pop(eid)
397 self._engines.pop(eid)
372
398
373 def _extract_metadata(self, header, parent, content):
399 def _extract_metadata(self, header, parent, content):
374 md = {'msg_id' : parent['msg_id'],
400 md = {'msg_id' : parent['msg_id'],
375 'submitted' : datetime.strptime(parent['date'], ss.ISO8601),
376 'started' : datetime.strptime(header['started'], ss.ISO8601),
377 'completed' : datetime.strptime(header['date'], ss.ISO8601),
378 'received' : datetime.now(),
401 'received' : datetime.now(),
379 'engine_uuid' : header['engine'],
402 'engine_uuid' : header.get('engine', None),
380 'engine_id' : self._engines.get(header['engine'], None),
381 'follow' : parent['follow'],
403 'follow' : parent['follow'],
382 'after' : parent['after'],
404 'after' : parent['after'],
383 'status' : content['status'],
405 'status' : content['status'],
384 }
406 }
407
408 if md['engine_uuid'] is not None:
409 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
410
411 if 'date' in parent:
412 md['submitted'] = datetime.strptime(parent['date'], ss.ISO8601)
413 if 'started' in header:
414 md['started'] = datetime.strptime(header['started'], ss.ISO8601)
415 if 'date' in header:
416 md['completed'] = datetime.strptime(header['date'], ss.ISO8601)
385 return md
417 return md
386
418
387 def _handle_execute_reply(self, msg):
419 def _handle_execute_reply(self, msg):
388 """Save the reply to an execute_request into our results.
420 """Save the reply to an execute_request into our results.
389
421
390 execute messages are never actually used. apply is used instead.
422 execute messages are never actually used. apply is used instead.
391 """
423 """
392
424
393 parent = msg['parent_header']
425 parent = msg['parent_header']
394 msg_id = parent['msg_id']
426 msg_id = parent['msg_id']
395 if msg_id not in self.outstanding:
427 if msg_id not in self.outstanding:
396 print("got unknown result: %s"%msg_id)
428 if msg_id in self.history:
429 print ("got stale result: %s"%msg_id)
430 else:
431 print ("got unknown result: %s"%msg_id)
397 else:
432 else:
398 self.outstanding.remove(msg_id)
433 self.outstanding.remove(msg_id)
399 self.results[msg_id] = ss.unwrap_exception(msg['content'])
434 self.results[msg_id] = ss.unwrap_exception(msg['content'])
400
435
401 def _handle_apply_reply(self, msg):
436 def _handle_apply_reply(self, msg):
402 """Save the reply to an apply_request into our results."""
437 """Save the reply to an apply_request into our results."""
403 parent = msg['parent_header']
438 parent = msg['parent_header']
404 msg_id = parent['msg_id']
439 msg_id = parent['msg_id']
405 if msg_id not in self.outstanding:
440 if msg_id not in self.outstanding:
406 print ("got unknown result: %s"%msg_id)
441 if msg_id in self.history:
442 print ("got stale result: %s"%msg_id)
443 print self.results[msg_id]
444 print msg
445 else:
446 print ("got unknown result: %s"%msg_id)
407 else:
447 else:
408 self.outstanding.remove(msg_id)
448 self.outstanding.remove(msg_id)
409 content = msg['content']
449 content = msg['content']
410 header = msg['header']
450 header = msg['header']
411
451
412 # construct metadata:
452 # construct metadata:
413 md = self.metadata.setdefault(msg_id, Metadata())
453 md = self.metadata.setdefault(msg_id, Metadata())
414 md.update(self._extract_metadata(header, parent, content))
454 md.update(self._extract_metadata(header, parent, content))
415 self.metadata[msg_id] = md
455 self.metadata[msg_id] = md
416
456
417 # construct result:
457 # construct result:
418 if content['status'] == 'ok':
458 if content['status'] == 'ok':
419 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
459 self.results[msg_id] = ss.unserialize_object(msg['buffers'])[0]
420 elif content['status'] == 'aborted':
460 elif content['status'] == 'aborted':
421 self.results[msg_id] = error.AbortedTask(msg_id)
461 self.results[msg_id] = error.AbortedTask(msg_id)
422 elif content['status'] == 'resubmitted':
462 elif content['status'] == 'resubmitted':
423 # TODO: handle resubmission
463 # TODO: handle resubmission
424 pass
464 pass
425 else:
465 else:
426 e = ss.unwrap_exception(content)
466 e = ss.unwrap_exception(content)
427 e_uuid = e.engine_info['engineid']
467 if e.engine_info:
428 eid = self._engines[e_uuid]
468 e_uuid = e.engine_info['engineid']
429 e.engine_info['engineid'] = eid
469 eid = self._engines[e_uuid]
470 e.engine_info['engineid'] = eid
430 self.results[msg_id] = e
471 self.results[msg_id] = e
431
472
432 def _flush_notifications(self):
473 def _flush_notifications(self):
433 """Flush notifications of engine registrations waiting
474 """Flush notifications of engine registrations waiting
434 in ZMQ queue."""
475 in ZMQ queue."""
435 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
476 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
436 while msg is not None:
477 while msg is not None:
437 if self.debug:
478 if self.debug:
438 pprint(msg)
479 pprint(msg)
439 msg = msg[-1]
480 msg = msg[-1]
440 msg_type = msg['msg_type']
481 msg_type = msg['msg_type']
441 handler = self._notification_handlers.get(msg_type, None)
482 handler = self._notification_handlers.get(msg_type, None)
442 if handler is None:
483 if handler is None:
443 raise Exception("Unhandled message type: %s"%msg.msg_type)
484 raise Exception("Unhandled message type: %s"%msg.msg_type)
444 else:
485 else:
445 handler(msg)
486 handler(msg)
446 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
487 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
447
488
448 def _flush_results(self, sock):
489 def _flush_results(self, sock):
449 """Flush task or queue results waiting in ZMQ queue."""
490 """Flush task or queue results waiting in ZMQ queue."""
450 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
491 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
451 while msg is not None:
492 while msg is not None:
452 if self.debug:
493 if self.debug:
453 pprint(msg)
494 pprint(msg)
454 msg = msg[-1]
495 msg = msg[-1]
455 msg_type = msg['msg_type']
496 msg_type = msg['msg_type']
456 handler = self._queue_handlers.get(msg_type, None)
497 handler = self._queue_handlers.get(msg_type, None)
457 if handler is None:
498 if handler is None:
458 raise Exception("Unhandled message type: %s"%msg.msg_type)
499 raise Exception("Unhandled message type: %s"%msg.msg_type)
459 else:
500 else:
460 handler(msg)
501 handler(msg)
461 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
502 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
462
503
463 def _flush_control(self, sock):
504 def _flush_control(self, sock):
464 """Flush replies from the control channel waiting
505 """Flush replies from the control channel waiting
465 in the ZMQ queue.
506 in the ZMQ queue.
466
507
467 Currently: ignore them."""
508 Currently: ignore them."""
468 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
509 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
469 while msg is not None:
510 while msg is not None:
470 if self.debug:
511 if self.debug:
471 pprint(msg)
512 pprint(msg)
472 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
513 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
473
514
474 def _flush_iopub(self, sock):
515 def _flush_iopub(self, sock):
475 """Flush replies from the iopub channel waiting
516 """Flush replies from the iopub channel waiting
476 in the ZMQ queue.
517 in the ZMQ queue.
477 """
518 """
478 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
519 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
479 while msg is not None:
520 while msg is not None:
480 if self.debug:
521 if self.debug:
481 pprint(msg)
522 pprint(msg)
482 msg = msg[-1]
523 msg = msg[-1]
483 parent = msg['parent_header']
524 parent = msg['parent_header']
484 msg_id = parent['msg_id']
525 msg_id = parent['msg_id']
485 content = msg['content']
526 content = msg['content']
486 header = msg['header']
527 header = msg['header']
487 msg_type = msg['msg_type']
528 msg_type = msg['msg_type']
488
529
489 # init metadata:
530 # init metadata:
490 md = self.metadata.setdefault(msg_id, Metadata())
531 md = self.metadata.setdefault(msg_id, Metadata())
491
532
492 if msg_type == 'stream':
533 if msg_type == 'stream':
493 name = content['name']
534 name = content['name']
494 s = md[name] or ''
535 s = md[name] or ''
495 md[name] = s + content['data']
536 md[name] = s + content['data']
496 elif msg_type == 'pyerr':
537 elif msg_type == 'pyerr':
497 md.update({'pyerr' : ss.unwrap_exception(content)})
538 md.update({'pyerr' : ss.unwrap_exception(content)})
498 else:
539 else:
499 md.update({msg_type : content['data']})
540 md.update({msg_type : content['data']})
500
541
501 self.metadata[msg_id] = md
542 self.metadata[msg_id] = md
502
543
503 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
544 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
504
545
505 #--------------------------------------------------------------------------
546 #--------------------------------------------------------------------------
506 # getitem
547 # getitem
507 #--------------------------------------------------------------------------
548 #--------------------------------------------------------------------------
508
549
509 def __getitem__(self, key):
550 def __getitem__(self, key):
510 """Dict access returns DirectView multiplexer objects or,
551 """Dict access returns DirectView multiplexer objects or,
511 if key is None, a LoadBalancedView."""
552 if key is None, a LoadBalancedView."""
512 if key is None:
553 if key is None:
513 return LoadBalancedView(self)
554 return LoadBalancedView(self)
514 if isinstance(key, int):
555 if isinstance(key, int):
515 if key not in self.ids:
556 if key not in self.ids:
516 raise IndexError("No such engine: %i"%key)
557 raise IndexError("No such engine: %i"%key)
517 return DirectView(self, key)
558 return DirectView(self, key)
518
559
519 if isinstance(key, slice):
560 if isinstance(key, slice):
520 indices = range(len(self.ids))[key]
561 indices = range(len(self.ids))[key]
521 ids = sorted(self._ids)
562 ids = sorted(self._ids)
522 key = [ ids[i] for i in indices ]
563 key = [ ids[i] for i in indices ]
523 # newkeys = sorted(self._ids)[thekeys[k]]
564 # newkeys = sorted(self._ids)[thekeys[k]]
524
565
525 if isinstance(key, (tuple, list, xrange)):
566 if isinstance(key, (tuple, list, xrange)):
526 _,targets = self._build_targets(list(key))
567 _,targets = self._build_targets(list(key))
527 return DirectView(self, targets)
568 return DirectView(self, targets)
528 else:
569 else:
529 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
570 raise TypeError("key by int/iterable of ints only, not %s"%(type(key)))
530
571
531 #--------------------------------------------------------------------------
572 #--------------------------------------------------------------------------
532 # Begin public methods
573 # Begin public methods
533 #--------------------------------------------------------------------------
574 #--------------------------------------------------------------------------
534
575
535 @property
576 @property
536 def remote(self):
577 def remote(self):
537 """property for convenient RemoteFunction generation.
578 """property for convenient RemoteFunction generation.
538
579
539 >>> @client.remote
580 >>> @client.remote
540 ... def f():
581 ... def f():
541 import os
582 import os
542 print (os.getpid())
583 print (os.getpid())
543 """
584 """
544 return remote(self, block=self.block)
585 return remote(self, block=self.block)
545
586
546 def spin(self):
587 def spin(self):
547 """Flush any registration notifications and execution results
588 """Flush any registration notifications and execution results
548 waiting in the ZMQ queue.
589 waiting in the ZMQ queue.
549 """
590 """
550 if self._notification_socket:
591 if self._notification_socket:
551 self._flush_notifications()
592 self._flush_notifications()
552 if self._mux_socket:
593 if self._mux_socket:
553 self._flush_results(self._mux_socket)
594 self._flush_results(self._mux_socket)
554 if self._task_socket:
595 if self._task_socket:
555 self._flush_results(self._task_socket)
596 self._flush_results(self._task_socket)
556 if self._control_socket:
597 if self._control_socket:
557 self._flush_control(self._control_socket)
598 self._flush_control(self._control_socket)
558 if self._iopub_socket:
599 if self._iopub_socket:
559 self._flush_iopub(self._iopub_socket)
600 self._flush_iopub(self._iopub_socket)
560
601
561 def barrier(self, msg_ids=None, timeout=-1):
602 def barrier(self, msg_ids=None, timeout=-1):
562 """waits on one or more `msg_ids`, for up to `timeout` seconds.
603 """waits on one or more `msg_ids`, for up to `timeout` seconds.
563
604
564 Parameters
605 Parameters
565 ----------
606 ----------
566 msg_ids : int, str, or list of ints and/or strs, or one or more AsyncResult objects
607 msg_ids : int, str, or list of ints and/or strs, or one or more AsyncResult objects
567 ints are indices to self.history
608 ints are indices to self.history
568 strs are msg_ids
609 strs are msg_ids
569 default: wait on all outstanding messages
610 default: wait on all outstanding messages
570 timeout : float
611 timeout : float
571 a time in seconds, after which to give up.
612 a time in seconds, after which to give up.
572 default is -1, which means no timeout
613 default is -1, which means no timeout
573
614
574 Returns
615 Returns
575 -------
616 -------
576 True : when all msg_ids are done
617 True : when all msg_ids are done
577 False : timeout reached, some msg_ids still outstanding
618 False : timeout reached, some msg_ids still outstanding
578 """
619 """
579 tic = time.time()
620 tic = time.time()
580 if msg_ids is None:
621 if msg_ids is None:
581 theids = self.outstanding
622 theids = self.outstanding
582 else:
623 else:
583 if isinstance(msg_ids, (int, str, AsyncResult)):
624 if isinstance(msg_ids, (int, str, AsyncResult)):
584 msg_ids = [msg_ids]
625 msg_ids = [msg_ids]
585 theids = set()
626 theids = set()
586 for msg_id in msg_ids:
627 for msg_id in msg_ids:
587 if isinstance(msg_id, int):
628 if isinstance(msg_id, int):
588 msg_id = self.history[msg_id]
629 msg_id = self.history[msg_id]
589 elif isinstance(msg_id, AsyncResult):
630 elif isinstance(msg_id, AsyncResult):
590 map(theids.add, msg_id.msg_ids)
631 map(theids.add, msg_id.msg_ids)
591 continue
632 continue
592 theids.add(msg_id)
633 theids.add(msg_id)
593 if not theids.intersection(self.outstanding):
634 if not theids.intersection(self.outstanding):
594 return True
635 return True
595 self.spin()
636 self.spin()
596 while theids.intersection(self.outstanding):
637 while theids.intersection(self.outstanding):
597 if timeout >= 0 and ( time.time()-tic ) > timeout:
638 if timeout >= 0 and ( time.time()-tic ) > timeout:
598 break
639 break
599 time.sleep(1e-3)
640 time.sleep(1e-3)
600 self.spin()
641 self.spin()
601 return len(theids.intersection(self.outstanding)) == 0
642 return len(theids.intersection(self.outstanding)) == 0
602
643
603 #--------------------------------------------------------------------------
644 #--------------------------------------------------------------------------
604 # Control methods
645 # Control methods
605 #--------------------------------------------------------------------------
646 #--------------------------------------------------------------------------
606
647
607 @spinfirst
648 @spinfirst
608 @defaultblock
649 @defaultblock
609 def clear(self, targets=None, block=None):
650 def clear(self, targets=None, block=None):
610 """Clear the namespace in target(s)."""
651 """Clear the namespace in target(s)."""
611 targets = self._build_targets(targets)[0]
652 targets = self._build_targets(targets)[0]
612 for t in targets:
653 for t in targets:
613 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
654 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
614 error = False
655 error = False
615 if self.block:
656 if self.block:
616 for i in range(len(targets)):
657 for i in range(len(targets)):
617 idents,msg = self.session.recv(self._control_socket,0)
658 idents,msg = self.session.recv(self._control_socket,0)
618 if self.debug:
659 if self.debug:
619 pprint(msg)
660 pprint(msg)
620 if msg['content']['status'] != 'ok':
661 if msg['content']['status'] != 'ok':
621 error = ss.unwrap_exception(msg['content'])
662 error = ss.unwrap_exception(msg['content'])
622 if error:
663 if error:
623 return error
664 return error
624
665
625
666
626 @spinfirst
667 @spinfirst
627 @defaultblock
668 @defaultblock
628 def abort(self, msg_ids = None, targets=None, block=None):
669 def abort(self, msg_ids = None, targets=None, block=None):
629 """Abort the execution queues of target(s)."""
670 """Abort the execution queues of target(s)."""
630 targets = self._build_targets(targets)[0]
671 targets = self._build_targets(targets)[0]
631 if isinstance(msg_ids, basestring):
672 if isinstance(msg_ids, basestring):
632 msg_ids = [msg_ids]
673 msg_ids = [msg_ids]
633 content = dict(msg_ids=msg_ids)
674 content = dict(msg_ids=msg_ids)
634 for t in targets:
675 for t in targets:
635 self.session.send(self._control_socket, 'abort_request',
676 self.session.send(self._control_socket, 'abort_request',
636 content=content, ident=t)
677 content=content, ident=t)
637 error = False
678 error = False
638 if self.block:
679 if self.block:
639 for i in range(len(targets)):
680 for i in range(len(targets)):
640 idents,msg = self.session.recv(self._control_socket,0)
681 idents,msg = self.session.recv(self._control_socket,0)
641 if self.debug:
682 if self.debug:
642 pprint(msg)
683 pprint(msg)
643 if msg['content']['status'] != 'ok':
684 if msg['content']['status'] != 'ok':
644 error = ss.unwrap_exception(msg['content'])
685 error = ss.unwrap_exception(msg['content'])
645 if error:
686 if error:
646 return error
687 return error
647
688
648 @spinfirst
689 @spinfirst
649 @defaultblock
690 @defaultblock
650 def shutdown(self, targets=None, restart=False, controller=False, block=None):
691 def shutdown(self, targets=None, restart=False, controller=False, block=None):
651 """Terminates one or more engine processes, optionally including the controller."""
692 """Terminates one or more engine processes, optionally including the controller."""
652 if controller:
693 if controller:
653 targets = 'all'
694 targets = 'all'
654 targets = self._build_targets(targets)[0]
695 targets = self._build_targets(targets)[0]
655 for t in targets:
696 for t in targets:
656 self.session.send(self._control_socket, 'shutdown_request',
697 self.session.send(self._control_socket, 'shutdown_request',
657 content={'restart':restart},ident=t)
698 content={'restart':restart},ident=t)
658 error = False
699 error = False
659 if block or controller:
700 if block or controller:
660 for i in range(len(targets)):
701 for i in range(len(targets)):
661 idents,msg = self.session.recv(self._control_socket,0)
702 idents,msg = self.session.recv(self._control_socket,0)
662 if self.debug:
703 if self.debug:
663 pprint(msg)
704 pprint(msg)
664 if msg['content']['status'] != 'ok':
705 if msg['content']['status'] != 'ok':
665 error = ss.unwrap_exception(msg['content'])
706 error = ss.unwrap_exception(msg['content'])
666
707
667 if controller:
708 if controller:
668 time.sleep(0.25)
709 time.sleep(0.25)
669 self.session.send(self._query_socket, 'shutdown_request')
710 self.session.send(self._query_socket, 'shutdown_request')
670 idents,msg = self.session.recv(self._query_socket, 0)
711 idents,msg = self.session.recv(self._query_socket, 0)
671 if self.debug:
712 if self.debug:
672 pprint(msg)
713 pprint(msg)
673 if msg['content']['status'] != 'ok':
714 if msg['content']['status'] != 'ok':
674 error = ss.unwrap_exception(msg['content'])
715 error = ss.unwrap_exception(msg['content'])
675
716
676 if error:
717 if error:
677 raise error
718 raise error
678
719
679 #--------------------------------------------------------------------------
720 #--------------------------------------------------------------------------
680 # Execution methods
721 # Execution methods
681 #--------------------------------------------------------------------------
722 #--------------------------------------------------------------------------
682
723
683 @defaultblock
724 @defaultblock
684 def execute(self, code, targets='all', block=None):
725 def execute(self, code, targets='all', block=None):
685 """Executes `code` on `targets` in blocking or nonblocking manner.
726 """Executes `code` on `targets` in blocking or nonblocking manner.
686
727
687 ``execute`` is always `bound` (affects engine namespace)
728 ``execute`` is always `bound` (affects engine namespace)
688
729
689 Parameters
730 Parameters
690 ----------
731 ----------
691 code : str
732 code : str
692 the code string to be executed
733 the code string to be executed
693 targets : int/str/list of ints/strs
734 targets : int/str/list of ints/strs
694 the engines on which to execute
735 the engines on which to execute
695 default : all
736 default : all
696 block : bool
737 block : bool
697 whether or not to wait until done to return
738 whether or not to wait until done to return
698 default: self.block
739 default: self.block
699 """
740 """
700 result = self.apply(_execute, (code,), targets=targets, block=self.block, bound=True)
741 result = self.apply(_execute, (code,), targets=targets, block=self.block, bound=True)
701 return result
742 return result
702
743
703 def run(self, code, block=None):
744 def run(self, code, block=None):
704 """Runs `code` on an engine.
745 """Runs `code` on an engine.
705
746
706 Calls to this are load-balanced.
747 Calls to this are load-balanced.
707
748
708 ``run`` is never `bound` (no effect on engine namespace)
749 ``run`` is never `bound` (no effect on engine namespace)
709
750
710 Parameters
751 Parameters
711 ----------
752 ----------
712 code : str
753 code : str
713 the code string to be executed
754 the code string to be executed
714 block : bool
755 block : bool
715 whether or not to wait until done
756 whether or not to wait until done
716
757
717 """
758 """
718 result = self.apply(_execute, (code,), targets=None, block=block, bound=False)
759 result = self.apply(_execute, (code,), targets=None, block=block, bound=False)
719 return result
760 return result
720
761
721 def _maybe_raise(self, result):
762 def _maybe_raise(self, result):
722 """wrapper for maybe raising an exception if apply failed."""
763 """wrapper for maybe raising an exception if apply failed."""
723 if isinstance(result, error.RemoteError):
764 if isinstance(result, error.RemoteError):
724 raise result
765 raise result
725
766
726 return result
767 return result
727
768
728 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
769 def apply(self, f, args=None, kwargs=None, bound=True, block=None, targets=None,
729 after=None, follow=None):
770 after=None, follow=None):
730 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
771 """Call `f(*args, **kwargs)` on a remote engine(s), returning the result.
731
772
732 This is the central execution command for the client.
773 This is the central execution command for the client.
733
774
734 Parameters
775 Parameters
735 ----------
776 ----------
736
777
737 f : function
778 f : function
738 The fuction to be called remotely
779 The fuction to be called remotely
739 args : tuple/list
780 args : tuple/list
740 The positional arguments passed to `f`
781 The positional arguments passed to `f`
741 kwargs : dict
782 kwargs : dict
742 The keyword arguments passed to `f`
783 The keyword arguments passed to `f`
743 bound : bool (default: True)
784 bound : bool (default: True)
744 Whether to execute in the Engine(s) namespace, or in a clean
785 Whether to execute in the Engine(s) namespace, or in a clean
745 namespace not affecting the engine.
786 namespace not affecting the engine.
746 block : bool (default: self.block)
787 block : bool (default: self.block)
747 Whether to wait for the result, or return immediately.
788 Whether to wait for the result, or return immediately.
748 False:
789 False:
749 returns msg_id(s)
790 returns msg_id(s)
750 if multiple targets:
791 if multiple targets:
751 list of ids
792 list of ids
752 True:
793 True:
753 returns actual result(s) of f(*args, **kwargs)
794 returns actual result(s) of f(*args, **kwargs)
754 if multiple targets:
795 if multiple targets:
755 dict of results, by engine ID
796 dict of results, by engine ID
756 targets : int,list of ints, 'all', None
797 targets : int,list of ints, 'all', None
757 Specify the destination of the job.
798 Specify the destination of the job.
758 if None:
799 if None:
759 Submit via Task queue for load-balancing.
800 Submit via Task queue for load-balancing.
760 if 'all':
801 if 'all':
761 Run on all active engines
802 Run on all active engines
762 if list:
803 if list:
763 Run on each specified engine
804 Run on each specified engine
764 if int:
805 if int:
765 Run on single engine
806 Run on single engine
766
807
767 after : Dependency or collection of msg_ids
808 after : Dependency or collection of msg_ids
768 Only for load-balanced execution (targets=None)
809 Only for load-balanced execution (targets=None)
769 Specify a list of msg_ids as a time-based dependency.
810 Specify a list of msg_ids as a time-based dependency.
770 This job will only be run *after* the dependencies
811 This job will only be run *after* the dependencies
771 have been met.
812 have been met.
772
813
773 follow : Dependency or collection of msg_ids
814 follow : Dependency or collection of msg_ids
774 Only for load-balanced execution (targets=None)
815 Only for load-balanced execution (targets=None)
775 Specify a list of msg_ids as a location-based dependency.
816 Specify a list of msg_ids as a location-based dependency.
776 This job will only be run on an engine where this dependency
817 This job will only be run on an engine where this dependency
777 is met.
818 is met.
778
819
779 Returns
820 Returns
780 -------
821 -------
781 if block is False:
822 if block is False:
782 if single target:
823 if single target:
783 return msg_id
824 return msg_id
784 else:
825 else:
785 return list of msg_ids
826 return list of msg_ids
786 ? (should this be dict like block=True) ?
827 ? (should this be dict like block=True) ?
787 else:
828 else:
788 if single target:
829 if single target:
789 return result of f(*args, **kwargs)
830 return result of f(*args, **kwargs)
790 else:
831 else:
791 return dict of results, keyed by engine
832 return dict of results, keyed by engine
792 """
833 """
793
834
794 # defaults:
835 # defaults:
795 block = block if block is not None else self.block
836 block = block if block is not None else self.block
796 args = args if args is not None else []
837 args = args if args is not None else []
797 kwargs = kwargs if kwargs is not None else {}
838 kwargs = kwargs if kwargs is not None else {}
798
839
799 # enforce types of f,args,kwrags
840 # enforce types of f,args,kwrags
800 if not callable(f):
841 if not callable(f):
801 raise TypeError("f must be callable, not %s"%type(f))
842 raise TypeError("f must be callable, not %s"%type(f))
802 if not isinstance(args, (tuple, list)):
843 if not isinstance(args, (tuple, list)):
803 raise TypeError("args must be tuple or list, not %s"%type(args))
844 raise TypeError("args must be tuple or list, not %s"%type(args))
804 if not isinstance(kwargs, dict):
845 if not isinstance(kwargs, dict):
805 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
846 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
806
847
807 if isinstance(after, Dependency):
848 if isinstance(after, Dependency):
808 after = after.as_dict()
849 after = after.as_dict()
809 elif isinstance(after, AsyncResult):
850 elif isinstance(after, AsyncResult):
810 after=after.msg_ids
851 after=after.msg_ids
811 elif after is None:
852 elif after is None:
812 after = []
853 after = []
813 if isinstance(follow, Dependency):
854 if isinstance(follow, Dependency):
855 # if len(follow) > 1 and follow.mode == 'all':
856 # warn("complex follow-dependencies are not rigorously tested for reachability", UserWarning)
814 follow = follow.as_dict()
857 follow = follow.as_dict()
815 elif isinstance(follow, AsyncResult):
858 elif isinstance(follow, AsyncResult):
816 follow=follow.msg_ids
859 follow=follow.msg_ids
817 elif follow is None:
860 elif follow is None:
818 follow = []
861 follow = []
819 options = dict(bound=bound, block=block, after=after, follow=follow)
862 options = dict(bound=bound, block=block, after=after, follow=follow)
820
863
821 if targets is None:
864 if targets is None:
822 return self._apply_balanced(f, args, kwargs, **options)
865 return self._apply_balanced(f, args, kwargs, **options)
823 else:
866 else:
824 return self._apply_direct(f, args, kwargs, targets=targets, **options)
867 return self._apply_direct(f, args, kwargs, targets=targets, **options)
825
868
826 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
869 def _apply_balanced(self, f, args, kwargs, bound=True, block=None,
827 after=None, follow=None):
870 after=None, follow=None):
828 """The underlying method for applying functions in a load balanced
871 """The underlying method for applying functions in a load balanced
829 manner, via the task queue."""
872 manner, via the task queue."""
830
831 subheader = dict(after=after, follow=follow)
873 subheader = dict(after=after, follow=follow)
832 bufs = ss.pack_apply_message(f,args,kwargs)
874 bufs = ss.pack_apply_message(f,args,kwargs)
833 content = dict(bound=bound)
875 content = dict(bound=bound)
834
876
835 msg = self.session.send(self._task_socket, "apply_request",
877 msg = self.session.send(self._task_socket, "apply_request",
836 content=content, buffers=bufs, subheader=subheader)
878 content=content, buffers=bufs, subheader=subheader)
837 msg_id = msg['msg_id']
879 msg_id = msg['msg_id']
838 self.outstanding.add(msg_id)
880 self.outstanding.add(msg_id)
839 self.history.append(msg_id)
881 self.history.append(msg_id)
840 ar = AsyncResult(self, [msg_id], fname=f.__name__)
882 ar = AsyncResult(self, [msg_id], fname=f.__name__)
841 if block:
883 if block:
842 return ar.get()
884 return ar.get()
843 else:
885 else:
844 return ar
886 return ar
845
887
846 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
888 def _apply_direct(self, f, args, kwargs, bound=True, block=None, targets=None,
847 after=None, follow=None):
889 after=None, follow=None):
848 """Then underlying method for applying functions to specific engines
890 """Then underlying method for applying functions to specific engines
849 via the MUX queue."""
891 via the MUX queue."""
850
892
851 queues,targets = self._build_targets(targets)
893 queues,targets = self._build_targets(targets)
852
894
853 subheader = dict(after=after, follow=follow)
895 subheader = dict(after=after, follow=follow)
854 content = dict(bound=bound)
896 content = dict(bound=bound)
855 bufs = ss.pack_apply_message(f,args,kwargs)
897 bufs = ss.pack_apply_message(f,args,kwargs)
856
898
857 msg_ids = []
899 msg_ids = []
858 for queue in queues:
900 for queue in queues:
859 msg = self.session.send(self._mux_socket, "apply_request",
901 msg = self.session.send(self._mux_socket, "apply_request",
860 content=content, buffers=bufs,ident=queue, subheader=subheader)
902 content=content, buffers=bufs,ident=queue, subheader=subheader)
861 msg_id = msg['msg_id']
903 msg_id = msg['msg_id']
862 self.outstanding.add(msg_id)
904 self.outstanding.add(msg_id)
863 self.history.append(msg_id)
905 self.history.append(msg_id)
864 msg_ids.append(msg_id)
906 msg_ids.append(msg_id)
865 ar = AsyncResult(self, msg_ids, fname=f.__name__)
907 ar = AsyncResult(self, msg_ids, fname=f.__name__)
866 if block:
908 if block:
867 return ar.get()
909 return ar.get()
868 else:
910 else:
869 return ar
911 return ar
870
912
871 #--------------------------------------------------------------------------
913 #--------------------------------------------------------------------------
872 # Map and decorators
914 # Map and decorators
873 #--------------------------------------------------------------------------
915 #--------------------------------------------------------------------------
874
916
875 def map(self, f, *sequences):
917 def map(self, f, *sequences):
876 """Parallel version of builtin `map`, using all our engines."""
918 """Parallel version of builtin `map`, using all our engines."""
877 pf = ParallelFunction(self, f, block=self.block,
919 pf = ParallelFunction(self, f, block=self.block,
878 bound=True, targets='all')
920 bound=True, targets='all')
879 return pf.map(*sequences)
921 return pf.map(*sequences)
880
922
881 def parallel(self, bound=True, targets='all', block=True):
923 def parallel(self, bound=True, targets='all', block=True):
882 """Decorator for making a ParallelFunction."""
924 """Decorator for making a ParallelFunction."""
883 return parallel(self, bound=bound, targets=targets, block=block)
925 return parallel(self, bound=bound, targets=targets, block=block)
884
926
885 def remote(self, bound=True, targets='all', block=True):
927 def remote(self, bound=True, targets='all', block=True):
886 """Decorator for making a RemoteFunction."""
928 """Decorator for making a RemoteFunction."""
887 return remote(self, bound=bound, targets=targets, block=block)
929 return remote(self, bound=bound, targets=targets, block=block)
888
930
889 #--------------------------------------------------------------------------
931 #--------------------------------------------------------------------------
890 # Data movement
932 # Data movement
891 #--------------------------------------------------------------------------
933 #--------------------------------------------------------------------------
892
934
893 @defaultblock
935 @defaultblock
894 def push(self, ns, targets='all', block=None):
936 def push(self, ns, targets='all', block=None):
895 """Push the contents of `ns` into the namespace on `target`"""
937 """Push the contents of `ns` into the namespace on `target`"""
896 if not isinstance(ns, dict):
938 if not isinstance(ns, dict):
897 raise TypeError("Must be a dict, not %s"%type(ns))
939 raise TypeError("Must be a dict, not %s"%type(ns))
898 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
940 result = self.apply(_push, (ns,), targets=targets, block=block, bound=True)
899 return result
941 return result
900
942
901 @defaultblock
943 @defaultblock
902 def pull(self, keys, targets='all', block=None):
944 def pull(self, keys, targets='all', block=None):
903 """Pull objects from `target`'s namespace by `keys`"""
945 """Pull objects from `target`'s namespace by `keys`"""
904 if isinstance(keys, str):
946 if isinstance(keys, str):
905 pass
947 pass
906 elif isinstance(keys, (list,tuple,set)):
948 elif isinstance(keys, (list,tuple,set)):
907 for key in keys:
949 for key in keys:
908 if not isinstance(key, str):
950 if not isinstance(key, str):
909 raise TypeError
951 raise TypeError
910 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
952 result = self.apply(_pull, (keys,), targets=targets, block=block, bound=True)
911 return result
953 return result
912
954
913 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
955 def scatter(self, key, seq, dist='b', flatten=False, targets='all', block=None):
914 """
956 """
915 Partition a Python sequence and send the partitions to a set of engines.
957 Partition a Python sequence and send the partitions to a set of engines.
916 """
958 """
917 block = block if block is not None else self.block
959 block = block if block is not None else self.block
918 targets = self._build_targets(targets)[-1]
960 targets = self._build_targets(targets)[-1]
919 mapObject = Map.dists[dist]()
961 mapObject = Map.dists[dist]()
920 nparts = len(targets)
962 nparts = len(targets)
921 msg_ids = []
963 msg_ids = []
922 for index, engineid in enumerate(targets):
964 for index, engineid in enumerate(targets):
923 partition = mapObject.getPartition(seq, index, nparts)
965 partition = mapObject.getPartition(seq, index, nparts)
924 if flatten and len(partition) == 1:
966 if flatten and len(partition) == 1:
925 r = self.push({key: partition[0]}, targets=engineid, block=False)
967 r = self.push({key: partition[0]}, targets=engineid, block=False)
926 else:
968 else:
927 r = self.push({key: partition}, targets=engineid, block=False)
969 r = self.push({key: partition}, targets=engineid, block=False)
928 msg_ids.extend(r.msg_ids)
970 msg_ids.extend(r.msg_ids)
929 r = AsyncResult(self, msg_ids, fname='scatter')
971 r = AsyncResult(self, msg_ids, fname='scatter')
930 if block:
972 if block:
931 return r.get()
973 return r.get()
932 else:
974 else:
933 return r
975 return r
934
976
935 def gather(self, key, dist='b', targets='all', block=None):
977 def gather(self, key, dist='b', targets='all', block=None):
936 """
978 """
937 Gather a partitioned sequence on a set of engines as a single local seq.
979 Gather a partitioned sequence on a set of engines as a single local seq.
938 """
980 """
939 block = block if block is not None else self.block
981 block = block if block is not None else self.block
940
982
941 targets = self._build_targets(targets)[-1]
983 targets = self._build_targets(targets)[-1]
942 mapObject = Map.dists[dist]()
984 mapObject = Map.dists[dist]()
943 msg_ids = []
985 msg_ids = []
944 for index, engineid in enumerate(targets):
986 for index, engineid in enumerate(targets):
945 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
987 msg_ids.extend(self.pull(key, targets=engineid,block=False).msg_ids)
946
988
947 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
989 r = AsyncMapResult(self, msg_ids, mapObject, fname='gather')
948 if block:
990 if block:
949 return r.get()
991 return r.get()
950 else:
992 else:
951 return r
993 return r
952
994
953 #--------------------------------------------------------------------------
995 #--------------------------------------------------------------------------
954 # Query methods
996 # Query methods
955 #--------------------------------------------------------------------------
997 #--------------------------------------------------------------------------
956
998
957 @spinfirst
999 @spinfirst
958 def get_results(self, msg_ids, status_only=False):
1000 def get_results(self, msg_ids, status_only=False):
959 """Returns the result of the execute or task request with `msg_ids`.
1001 """Returns the result of the execute or task request with `msg_ids`.
960
1002
961 Parameters
1003 Parameters
962 ----------
1004 ----------
963 msg_ids : list of ints or msg_ids
1005 msg_ids : list of ints or msg_ids
964 if int:
1006 if int:
965 Passed as index to self.history for convenience.
1007 Passed as index to self.history for convenience.
966 status_only : bool (default: False)
1008 status_only : bool (default: False)
967 if False:
1009 if False:
968 return the actual results
1010 return the actual results
969
1011
970 Returns
1012 Returns
971 -------
1013 -------
972
1014
973 results : dict
1015 results : dict
974 There will always be the keys 'pending' and 'completed', which will
1016 There will always be the keys 'pending' and 'completed', which will
975 be lists of msg_ids.
1017 be lists of msg_ids.
976 """
1018 """
977 if not isinstance(msg_ids, (list,tuple)):
1019 if not isinstance(msg_ids, (list,tuple)):
978 msg_ids = [msg_ids]
1020 msg_ids = [msg_ids]
979 theids = []
1021 theids = []
980 for msg_id in msg_ids:
1022 for msg_id in msg_ids:
981 if isinstance(msg_id, int):
1023 if isinstance(msg_id, int):
982 msg_id = self.history[msg_id]
1024 msg_id = self.history[msg_id]
983 if not isinstance(msg_id, str):
1025 if not isinstance(msg_id, str):
984 raise TypeError("msg_ids must be str, not %r"%msg_id)
1026 raise TypeError("msg_ids must be str, not %r"%msg_id)
985 theids.append(msg_id)
1027 theids.append(msg_id)
986
1028
987 completed = []
1029 completed = []
988 local_results = {}
1030 local_results = {}
989 # temporarily disable local shortcut
1031 # temporarily disable local shortcut
990 # for msg_id in list(theids):
1032 # for msg_id in list(theids):
991 # if msg_id in self.results:
1033 # if msg_id in self.results:
992 # completed.append(msg_id)
1034 # completed.append(msg_id)
993 # local_results[msg_id] = self.results[msg_id]
1035 # local_results[msg_id] = self.results[msg_id]
994 # theids.remove(msg_id)
1036 # theids.remove(msg_id)
995
1037
996 if theids: # some not locally cached
1038 if theids: # some not locally cached
997 content = dict(msg_ids=theids, status_only=status_only)
1039 content = dict(msg_ids=theids, status_only=status_only)
998 msg = self.session.send(self._query_socket, "result_request", content=content)
1040 msg = self.session.send(self._query_socket, "result_request", content=content)
999 zmq.select([self._query_socket], [], [])
1041 zmq.select([self._query_socket], [], [])
1000 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1042 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1001 if self.debug:
1043 if self.debug:
1002 pprint(msg)
1044 pprint(msg)
1003 content = msg['content']
1045 content = msg['content']
1004 if content['status'] != 'ok':
1046 if content['status'] != 'ok':
1005 raise ss.unwrap_exception(content)
1047 raise ss.unwrap_exception(content)
1006 buffers = msg['buffers']
1048 buffers = msg['buffers']
1007 else:
1049 else:
1008 content = dict(completed=[],pending=[])
1050 content = dict(completed=[],pending=[])
1009
1051
1010 content['completed'].extend(completed)
1052 content['completed'].extend(completed)
1011
1053
1012 if status_only:
1054 if status_only:
1013 return content
1055 return content
1014
1056
1015 failures = []
1057 failures = []
1016 # load cached results into result:
1058 # load cached results into result:
1017 content.update(local_results)
1059 content.update(local_results)
1018 # update cache with results:
1060 # update cache with results:
1019 for msg_id in sorted(theids):
1061 for msg_id in sorted(theids):
1020 if msg_id in content['completed']:
1062 if msg_id in content['completed']:
1021 rec = content[msg_id]
1063 rec = content[msg_id]
1022 parent = rec['header']
1064 parent = rec['header']
1023 header = rec['result_header']
1065 header = rec['result_header']
1024 rcontent = rec['result_content']
1066 rcontent = rec['result_content']
1025 iodict = rec['io']
1067 iodict = rec['io']
1026 if isinstance(rcontent, str):
1068 if isinstance(rcontent, str):
1027 rcontent = self.session.unpack(rcontent)
1069 rcontent = self.session.unpack(rcontent)
1028
1070
1029 md = self.metadata.setdefault(msg_id, Metadata())
1071 md = self.metadata.setdefault(msg_id, Metadata())
1030 md.update(self._extract_metadata(header, parent, rcontent))
1072 md.update(self._extract_metadata(header, parent, rcontent))
1031 md.update(iodict)
1073 md.update(iodict)
1032
1074
1033 if rcontent['status'] == 'ok':
1075 if rcontent['status'] == 'ok':
1034 res,buffers = ss.unserialize_object(buffers)
1076 res,buffers = ss.unserialize_object(buffers)
1035 else:
1077 else:
1036 res = ss.unwrap_exception(rcontent)
1078 res = ss.unwrap_exception(rcontent)
1037 failures.append(res)
1079 failures.append(res)
1038
1080
1039 self.results[msg_id] = res
1081 self.results[msg_id] = res
1040 content[msg_id] = res
1082 content[msg_id] = res
1041
1083
1042 error.collect_exceptions(failures, "get_results")
1084 error.collect_exceptions(failures, "get_results")
1043 return content
1085 return content
1044
1086
1045 @spinfirst
1087 @spinfirst
1046 def queue_status(self, targets=None, verbose=False):
1088 def queue_status(self, targets=None, verbose=False):
1047 """Fetch the status of engine queues.
1089 """Fetch the status of engine queues.
1048
1090
1049 Parameters
1091 Parameters
1050 ----------
1092 ----------
1051 targets : int/str/list of ints/strs
1093 targets : int/str/list of ints/strs
1052 the engines on which to execute
1094 the engines on which to execute
1053 default : all
1095 default : all
1054 verbose : bool
1096 verbose : bool
1055 Whether to return lengths only, or lists of ids for each element
1097 Whether to return lengths only, or lists of ids for each element
1056 """
1098 """
1057 targets = self._build_targets(targets)[1]
1099 targets = self._build_targets(targets)[1]
1058 content = dict(targets=targets, verbose=verbose)
1100 content = dict(targets=targets, verbose=verbose)
1059 self.session.send(self._query_socket, "queue_request", content=content)
1101 self.session.send(self._query_socket, "queue_request", content=content)
1060 idents,msg = self.session.recv(self._query_socket, 0)
1102 idents,msg = self.session.recv(self._query_socket, 0)
1061 if self.debug:
1103 if self.debug:
1062 pprint(msg)
1104 pprint(msg)
1063 content = msg['content']
1105 content = msg['content']
1064 status = content.pop('status')
1106 status = content.pop('status')
1065 if status != 'ok':
1107 if status != 'ok':
1066 raise ss.unwrap_exception(content)
1108 raise ss.unwrap_exception(content)
1067 return ss.rekey(content)
1109 return ss.rekey(content)
1068
1110
1069 @spinfirst
1111 @spinfirst
1070 def purge_results(self, msg_ids=[], targets=[]):
1112 def purge_results(self, msg_ids=[], targets=[]):
1071 """Tell the controller to forget results.
1113 """Tell the controller to forget results.
1072
1114
1073 Individual results can be purged by msg_id, or the entire
1115 Individual results can be purged by msg_id, or the entire
1074 history of specific targets can be purged.
1116 history of specific targets can be purged.
1075
1117
1076 Parameters
1118 Parameters
1077 ----------
1119 ----------
1078 msg_ids : str or list of strs
1120 msg_ids : str or list of strs
1079 the msg_ids whose results should be forgotten.
1121 the msg_ids whose results should be forgotten.
1080 targets : int/str/list of ints/strs
1122 targets : int/str/list of ints/strs
1081 The targets, by uuid or int_id, whose entire history is to be purged.
1123 The targets, by uuid or int_id, whose entire history is to be purged.
1082 Use `targets='all'` to scrub everything from the controller's memory.
1124 Use `targets='all'` to scrub everything from the controller's memory.
1083
1125
1084 default : None
1126 default : None
1085 """
1127 """
1086 if not targets and not msg_ids:
1128 if not targets and not msg_ids:
1087 raise ValueError
1129 raise ValueError
1088 if targets:
1130 if targets:
1089 targets = self._build_targets(targets)[1]
1131 targets = self._build_targets(targets)[1]
1090 content = dict(targets=targets, msg_ids=msg_ids)
1132 content = dict(targets=targets, msg_ids=msg_ids)
1091 self.session.send(self._query_socket, "purge_request", content=content)
1133 self.session.send(self._query_socket, "purge_request", content=content)
1092 idents, msg = self.session.recv(self._query_socket, 0)
1134 idents, msg = self.session.recv(self._query_socket, 0)
1093 if self.debug:
1135 if self.debug:
1094 pprint(msg)
1136 pprint(msg)
1095 content = msg['content']
1137 content = msg['content']
1096 if content['status'] != 'ok':
1138 if content['status'] != 'ok':
1097 raise ss.unwrap_exception(content)
1139 raise ss.unwrap_exception(content)
1098
1140
1099 #----------------------------------------
1141 #----------------------------------------
1100 # activate for %px,%autopx magics
1142 # activate for %px,%autopx magics
1101 #----------------------------------------
1143 #----------------------------------------
1102 def activate(self):
1144 def activate(self):
1103 """Make this `View` active for parallel magic commands.
1145 """Make this `View` active for parallel magic commands.
1104
1146
1105 IPython has a magic command syntax to work with `MultiEngineClient` objects.
1147 IPython has a magic command syntax to work with `MultiEngineClient` objects.
1106 In a given IPython session there is a single active one. While
1148 In a given IPython session there is a single active one. While
1107 there can be many `Views` created and used by the user,
1149 there can be many `Views` created and used by the user,
1108 there is only one active one. The active `View` is used whenever
1150 there is only one active one. The active `View` is used whenever
1109 the magic commands %px and %autopx are used.
1151 the magic commands %px and %autopx are used.
1110
1152
1111 The activate() method is called on a given `View` to make it
1153 The activate() method is called on a given `View` to make it
1112 active. Once this has been done, the magic commands can be used.
1154 active. Once this has been done, the magic commands can be used.
1113 """
1155 """
1114
1156
1115 try:
1157 try:
1116 # This is injected into __builtins__.
1158 # This is injected into __builtins__.
1117 ip = get_ipython()
1159 ip = get_ipython()
1118 except NameError:
1160 except NameError:
1119 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
1161 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
1120 else:
1162 else:
1121 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
1163 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
1122 if pmagic is not None:
1164 if pmagic is not None:
1123 pmagic.active_multiengine_client = self
1165 pmagic.active_multiengine_client = self
1124 else:
1166 else:
1125 print "You must first load the parallelmagic extension " \
1167 print "You must first load the parallelmagic extension " \
1126 "by doing '%load_ext parallelmagic'"
1168 "by doing '%load_ext parallelmagic'"
1127
1169
1128 class AsynClient(Client):
1170 class AsynClient(Client):
1129 """An Asynchronous client, using the Tornado Event Loop.
1171 """An Asynchronous client, using the Tornado Event Loop.
1130 !!!unfinished!!!"""
1172 !!!unfinished!!!"""
1131 io_loop = None
1173 io_loop = None
1132 _queue_stream = None
1174 _queue_stream = None
1133 _notifier_stream = None
1175 _notifier_stream = None
1134 _task_stream = None
1176 _task_stream = None
1135 _control_stream = None
1177 _control_stream = None
1136
1178
1137 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
1179 def __init__(self, addr, context=None, username=None, debug=False, io_loop=None):
1138 Client.__init__(self, addr, context, username, debug)
1180 Client.__init__(self, addr, context, username, debug)
1139 if io_loop is None:
1181 if io_loop is None:
1140 io_loop = ioloop.IOLoop.instance()
1182 io_loop = ioloop.IOLoop.instance()
1141 self.io_loop = io_loop
1183 self.io_loop = io_loop
1142
1184
1143 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
1185 self._queue_stream = zmqstream.ZMQStream(self._mux_socket, io_loop)
1144 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
1186 self._control_stream = zmqstream.ZMQStream(self._control_socket, io_loop)
1145 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
1187 self._task_stream = zmqstream.ZMQStream(self._task_socket, io_loop)
1146 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
1188 self._notification_stream = zmqstream.ZMQStream(self._notification_socket, io_loop)
1147
1189
1148 def spin(self):
1190 def spin(self):
1149 for stream in (self.queue_stream, self.notifier_stream,
1191 for stream in (self.queue_stream, self.notifier_stream,
1150 self.task_stream, self.control_stream):
1192 self.task_stream, self.control_stream):
1151 stream.flush()
1193 stream.flush()
1152
1194
1153 __all__ = [ 'Client',
1195 __all__ = [ 'Client',
1154 'depend',
1196 'depend',
1155 'require',
1197 'require',
1156 'remote',
1198 'remote',
1157 'parallel',
1199 'parallel',
1158 'RemoteFunction',
1200 'RemoteFunction',
1159 'ParallelFunction',
1201 'ParallelFunction',
1160 'DirectView',
1202 'DirectView',
1161 'LoadBalancedView',
1203 'LoadBalancedView',
1162 'AsyncResult',
1204 'AsyncResult',
1163 'AsyncMapResult'
1205 'AsyncMapResult'
1164 ]
1206 ]
@@ -1,90 +1,111 b''
1 """Dependency utilities"""
1 """Dependency utilities"""
2
2
3 from IPython.external.decorator import decorator
3 from IPython.external.decorator import decorator
4 from error import UnmetDependency
5
4
6
5 # flags
7 # flags
6 ALL = 1 << 0
8 ALL = 1 << 0
7 ANY = 1 << 1
9 ANY = 1 << 1
8 HERE = 1 << 2
10 HERE = 1 << 2
9 ANYWHERE = 1 << 3
11 ANYWHERE = 1 << 3
10
12
11 class UnmetDependency(Exception):
12 pass
13
14
13
15 class depend(object):
14 class depend(object):
16 """Dependency decorator, for use with tasks."""
15 """Dependency decorator, for use with tasks."""
17 def __init__(self, f, *args, **kwargs):
16 def __init__(self, f, *args, **kwargs):
18 self.f = f
17 self.f = f
19 self.args = args
18 self.args = args
20 self.kwargs = kwargs
19 self.kwargs = kwargs
21
20
22 def __call__(self, f):
21 def __call__(self, f):
23 return dependent(f, self.f, *self.args, **self.kwargs)
22 return dependent(f, self.f, *self.args, **self.kwargs)
24
23
25 class dependent(object):
24 class dependent(object):
26 """A function that depends on another function.
25 """A function that depends on another function.
27 This is an object to prevent the closure used
26 This is an object to prevent the closure used
28 in traditional decorators, which are not picklable.
27 in traditional decorators, which are not picklable.
29 """
28 """
30
29
31 def __init__(self, f, df, *dargs, **dkwargs):
30 def __init__(self, f, df, *dargs, **dkwargs):
32 self.f = f
31 self.f = f
33 self.func_name = self.f.func_name
32 self.func_name = getattr(f, '__name__', 'f')
34 self.df = df
33 self.df = df
35 self.dargs = dargs
34 self.dargs = dargs
36 self.dkwargs = dkwargs
35 self.dkwargs = dkwargs
37
36
38 def __call__(self, *args, **kwargs):
37 def __call__(self, *args, **kwargs):
39 if self.df(*self.dargs, **self.dkwargs) is False:
38 if self.df(*self.dargs, **self.dkwargs) is False:
40 raise UnmetDependency()
39 raise UnmetDependency()
41 return self.f(*args, **kwargs)
40 return self.f(*args, **kwargs)
41
42 @property
43 def __name__(self):
44 return self.func_name
42
45
43 def _require(*names):
46 def _require(*names):
44 for name in names:
47 for name in names:
45 try:
48 try:
46 __import__(name)
49 __import__(name)
47 except ImportError:
50 except ImportError:
48 return False
51 return False
49 return True
52 return True
50
53
51 def require(*names):
54 def require(*names):
52 return depend(_require, *names)
55 return depend(_require, *names)
53
56
54 class Dependency(set):
57 class Dependency(set):
55 """An object for representing a set of dependencies.
58 """An object for representing a set of dependencies.
56
59
57 Subclassed from set()."""
60 Subclassed from set()."""
58
61
59 mode='all'
62 mode='all'
63 success_only=True
60
64
61 def __init__(self, dependencies=[], mode='all'):
65 def __init__(self, dependencies=[], mode='all', success_only=True):
62 if isinstance(dependencies, dict):
66 if isinstance(dependencies, dict):
63 # load from dict
67 # load from dict
64 dependencies = dependencies.get('dependencies', [])
65 mode = dependencies.get('mode', mode)
68 mode = dependencies.get('mode', mode)
69 success_only = dependencies.get('success_only', success_only)
70 dependencies = dependencies.get('dependencies', [])
66 set.__init__(self, dependencies)
71 set.__init__(self, dependencies)
67 self.mode = mode.lower()
72 self.mode = mode.lower()
73 self.success_only=success_only
68 if self.mode not in ('any', 'all'):
74 if self.mode not in ('any', 'all'):
69 raise NotImplementedError("Only any|all supported, not %r"%mode)
75 raise NotImplementedError("Only any|all supported, not %r"%mode)
70
76
71 def check(self, completed):
77 def check(self, completed, failed=None):
78 if failed is not None and not self.success_only:
79 completed = completed.union(failed)
72 if len(self) == 0:
80 if len(self) == 0:
73 return True
81 return True
74 if self.mode == 'all':
82 if self.mode == 'all':
75 return self.issubset(completed)
83 return self.issubset(completed)
76 elif self.mode == 'any':
84 elif self.mode == 'any':
77 return not self.isdisjoint(completed)
85 return not self.isdisjoint(completed)
78 else:
86 else:
79 raise NotImplementedError("Only any|all supported, not %r"%mode)
87 raise NotImplementedError("Only any|all supported, not %r"%mode)
80
88
89 def unreachable(self, failed):
90 if len(self) == 0 or len(failed) == 0 or not self.success_only:
91 return False
92 print self, self.success_only, self.mode, failed
93 if self.mode == 'all':
94 return not self.isdisjoint(failed)
95 elif self.mode == 'any':
96 return self.issubset(failed)
97 else:
98 raise NotImplementedError("Only any|all supported, not %r"%mode)
99
100
81 def as_dict(self):
101 def as_dict(self):
82 """Represent this dependency as a dict. For json compatibility."""
102 """Represent this dependency as a dict. For json compatibility."""
83 return dict(
103 return dict(
84 dependencies=list(self),
104 dependencies=list(self),
85 mode=self.mode
105 mode=self.mode,
106 success_only=self.success_only,
86 )
107 )
87
108
88
109
89 __all__ = ['UnmetDependency', 'depend', 'require', 'Dependency']
110 __all__ = ['depend', 'require', 'Dependency']
90
111
@@ -1,283 +1,289 b''
1 # encoding: utf-8
1 # encoding: utf-8
2
2
3 """Classes and functions for kernel related errors and exceptions."""
3 """Classes and functions for kernel related errors and exceptions."""
4 from __future__ import print_function
4 from __future__ import print_function
5
5
6 __docformat__ = "restructuredtext en"
6 __docformat__ = "restructuredtext en"
7
7
8 # Tell nose to skip this module
8 # Tell nose to skip this module
9 __test__ = {}
9 __test__ = {}
10
10
11 #-------------------------------------------------------------------------------
11 #-------------------------------------------------------------------------------
12 # Copyright (C) 2008 The IPython Development Team
12 # Copyright (C) 2008 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-------------------------------------------------------------------------------
16 #-------------------------------------------------------------------------------
17
17
18 #-------------------------------------------------------------------------------
18 #-------------------------------------------------------------------------------
19 # Error classes
19 # Error classes
20 #-------------------------------------------------------------------------------
20 #-------------------------------------------------------------------------------
21 class IPythonError(Exception):
21 class IPythonError(Exception):
22 """Base exception that all of our exceptions inherit from.
22 """Base exception that all of our exceptions inherit from.
23
23
24 This can be raised by code that doesn't have any more specific
24 This can be raised by code that doesn't have any more specific
25 information."""
25 information."""
26
26
27 pass
27 pass
28
28
29 # Exceptions associated with the controller objects
29 # Exceptions associated with the controller objects
30 class ControllerError(IPythonError): pass
30 class ControllerError(IPythonError): pass
31
31
32 class ControllerCreationError(ControllerError): pass
32 class ControllerCreationError(ControllerError): pass
33
33
34
34
35 # Exceptions associated with the Engines
35 # Exceptions associated with the Engines
36 class EngineError(IPythonError): pass
36 class EngineError(IPythonError): pass
37
37
38 class EngineCreationError(EngineError): pass
38 class EngineCreationError(EngineError): pass
39
39
40 class KernelError(IPythonError):
40 class KernelError(IPythonError):
41 pass
41 pass
42
42
43 class NotDefined(KernelError):
43 class NotDefined(KernelError):
44 def __init__(self, name):
44 def __init__(self, name):
45 self.name = name
45 self.name = name
46 self.args = (name,)
46 self.args = (name,)
47
47
48 def __repr__(self):
48 def __repr__(self):
49 return '<NotDefined: %s>' % self.name
49 return '<NotDefined: %s>' % self.name
50
50
51 __str__ = __repr__
51 __str__ = __repr__
52
52
53
53
54 class QueueCleared(KernelError):
54 class QueueCleared(KernelError):
55 pass
55 pass
56
56
57
57
58 class IdInUse(KernelError):
58 class IdInUse(KernelError):
59 pass
59 pass
60
60
61
61
62 class ProtocolError(KernelError):
62 class ProtocolError(KernelError):
63 pass
63 pass
64
64
65
65
66 class ConnectionError(KernelError):
66 class ConnectionError(KernelError):
67 pass
67 pass
68
68
69
69
70 class InvalidEngineID(KernelError):
70 class InvalidEngineID(KernelError):
71 pass
71 pass
72
72
73
73
74 class NoEnginesRegistered(KernelError):
74 class NoEnginesRegistered(KernelError):
75 pass
75 pass
76
76
77
77
78 class InvalidClientID(KernelError):
78 class InvalidClientID(KernelError):
79 pass
79 pass
80
80
81
81
82 class InvalidDeferredID(KernelError):
82 class InvalidDeferredID(KernelError):
83 pass
83 pass
84
84
85
85
86 class SerializationError(KernelError):
86 class SerializationError(KernelError):
87 pass
87 pass
88
88
89
89
90 class MessageSizeError(KernelError):
90 class MessageSizeError(KernelError):
91 pass
91 pass
92
92
93
93
94 class PBMessageSizeError(MessageSizeError):
94 class PBMessageSizeError(MessageSizeError):
95 pass
95 pass
96
96
97
97
98 class ResultNotCompleted(KernelError):
98 class ResultNotCompleted(KernelError):
99 pass
99 pass
100
100
101
101
102 class ResultAlreadyRetrieved(KernelError):
102 class ResultAlreadyRetrieved(KernelError):
103 pass
103 pass
104
104
105 class ClientError(KernelError):
105 class ClientError(KernelError):
106 pass
106 pass
107
107
108
108
109 class TaskAborted(KernelError):
109 class TaskAborted(KernelError):
110 pass
110 pass
111
111
112
112
113 class TaskTimeout(KernelError):
113 class TaskTimeout(KernelError):
114 pass
114 pass
115
115
116
116
117 class NotAPendingResult(KernelError):
117 class NotAPendingResult(KernelError):
118 pass
118 pass
119
119
120
120
121 class UnpickleableException(KernelError):
121 class UnpickleableException(KernelError):
122 pass
122 pass
123
123
124
124
125 class AbortedPendingDeferredError(KernelError):
125 class AbortedPendingDeferredError(KernelError):
126 pass
126 pass
127
127
128
128
129 class InvalidProperty(KernelError):
129 class InvalidProperty(KernelError):
130 pass
130 pass
131
131
132
132
133 class MissingBlockArgument(KernelError):
133 class MissingBlockArgument(KernelError):
134 pass
134 pass
135
135
136
136
137 class StopLocalExecution(KernelError):
137 class StopLocalExecution(KernelError):
138 pass
138 pass
139
139
140
140
141 class SecurityError(KernelError):
141 class SecurityError(KernelError):
142 pass
142 pass
143
143
144
144
145 class FileTimeoutError(KernelError):
145 class FileTimeoutError(KernelError):
146 pass
146 pass
147
147
148 class TimeoutError(KernelError):
148 class TimeoutError(KernelError):
149 pass
149 pass
150
150
151 class UnmetDependency(KernelError):
152 pass
153
154 class ImpossibleDependency(UnmetDependency):
155 pass
156
151 class RemoteError(KernelError):
157 class RemoteError(KernelError):
152 """Error raised elsewhere"""
158 """Error raised elsewhere"""
153 ename=None
159 ename=None
154 evalue=None
160 evalue=None
155 traceback=None
161 traceback=None
156 engine_info=None
162 engine_info=None
157
163
158 def __init__(self, ename, evalue, traceback, engine_info=None):
164 def __init__(self, ename, evalue, traceback, engine_info=None):
159 self.ename=ename
165 self.ename=ename
160 self.evalue=evalue
166 self.evalue=evalue
161 self.traceback=traceback
167 self.traceback=traceback
162 self.engine_info=engine_info or {}
168 self.engine_info=engine_info or {}
163 self.args=(ename, evalue)
169 self.args=(ename, evalue)
164
170
165 def __repr__(self):
171 def __repr__(self):
166 engineid = self.engine_info.get('engineid', ' ')
172 engineid = self.engine_info.get('engineid', ' ')
167 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
173 return "<Remote[%s]:%s(%s)>"%(engineid, self.ename, self.evalue)
168
174
169 def __str__(self):
175 def __str__(self):
170 sig = "%s(%s)"%(self.ename, self.evalue)
176 sig = "%s(%s)"%(self.ename, self.evalue)
171 if self.traceback:
177 if self.traceback:
172 return sig + '\n' + self.traceback
178 return sig + '\n' + self.traceback
173 else:
179 else:
174 return sig
180 return sig
175
181
176
182
177 class TaskRejectError(KernelError):
183 class TaskRejectError(KernelError):
178 """Exception to raise when a task should be rejected by an engine.
184 """Exception to raise when a task should be rejected by an engine.
179
185
180 This exception can be used to allow a task running on an engine to test
186 This exception can be used to allow a task running on an engine to test
181 if the engine (or the user's namespace on the engine) has the needed
187 if the engine (or the user's namespace on the engine) has the needed
182 task dependencies. If not, the task should raise this exception. For
188 task dependencies. If not, the task should raise this exception. For
183 the task to be retried on another engine, the task should be created
189 the task to be retried on another engine, the task should be created
184 with the `retries` argument > 1.
190 with the `retries` argument > 1.
185
191
186 The advantage of this approach over our older properties system is that
192 The advantage of this approach over our older properties system is that
187 tasks have full access to the user's namespace on the engines and the
193 tasks have full access to the user's namespace on the engines and the
188 properties don't have to be managed or tested by the controller.
194 properties don't have to be managed or tested by the controller.
189 """
195 """
190
196
191
197
192 class CompositeError(KernelError):
198 class CompositeError(KernelError):
193 """Error for representing possibly multiple errors on engines"""
199 """Error for representing possibly multiple errors on engines"""
194 def __init__(self, message, elist):
200 def __init__(self, message, elist):
195 Exception.__init__(self, *(message, elist))
201 Exception.__init__(self, *(message, elist))
196 # Don't use pack_exception because it will conflict with the .message
202 # Don't use pack_exception because it will conflict with the .message
197 # attribute that is being deprecated in 2.6 and beyond.
203 # attribute that is being deprecated in 2.6 and beyond.
198 self.msg = message
204 self.msg = message
199 self.elist = elist
205 self.elist = elist
200 self.args = [ e[0] for e in elist ]
206 self.args = [ e[0] for e in elist ]
201
207
202 def _get_engine_str(self, ei):
208 def _get_engine_str(self, ei):
203 if not ei:
209 if not ei:
204 return '[Engine Exception]'
210 return '[Engine Exception]'
205 else:
211 else:
206 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
212 return '[%i:%s]: ' % (ei['engineid'], ei['method'])
207
213
208 def _get_traceback(self, ev):
214 def _get_traceback(self, ev):
209 try:
215 try:
210 tb = ev._ipython_traceback_text
216 tb = ev._ipython_traceback_text
211 except AttributeError:
217 except AttributeError:
212 return 'No traceback available'
218 return 'No traceback available'
213 else:
219 else:
214 return tb
220 return tb
215
221
216 def __str__(self):
222 def __str__(self):
217 s = str(self.msg)
223 s = str(self.msg)
218 for en, ev, etb, ei in self.elist:
224 for en, ev, etb, ei in self.elist:
219 engine_str = self._get_engine_str(ei)
225 engine_str = self._get_engine_str(ei)
220 s = s + '\n' + engine_str + en + ': ' + str(ev)
226 s = s + '\n' + engine_str + en + ': ' + str(ev)
221 return s
227 return s
222
228
223 def __repr__(self):
229 def __repr__(self):
224 return "CompositeError(%i)"%len(self.elist)
230 return "CompositeError(%i)"%len(self.elist)
225
231
226 def print_tracebacks(self, excid=None):
232 def print_tracebacks(self, excid=None):
227 if excid is None:
233 if excid is None:
228 for (en,ev,etb,ei) in self.elist:
234 for (en,ev,etb,ei) in self.elist:
229 print (self._get_engine_str(ei))
235 print (self._get_engine_str(ei))
230 print (etb or 'No traceback available')
236 print (etb or 'No traceback available')
231 print ()
237 print ()
232 else:
238 else:
233 try:
239 try:
234 en,ev,etb,ei = self.elist[excid]
240 en,ev,etb,ei = self.elist[excid]
235 except:
241 except:
236 raise IndexError("an exception with index %i does not exist"%excid)
242 raise IndexError("an exception with index %i does not exist"%excid)
237 else:
243 else:
238 print (self._get_engine_str(ei))
244 print (self._get_engine_str(ei))
239 print (etb or 'No traceback available')
245 print (etb or 'No traceback available')
240
246
241 def raise_exception(self, excid=0):
247 def raise_exception(self, excid=0):
242 try:
248 try:
243 en,ev,etb,ei = self.elist[excid]
249 en,ev,etb,ei = self.elist[excid]
244 except:
250 except:
245 raise IndexError("an exception with index %i does not exist"%excid)
251 raise IndexError("an exception with index %i does not exist"%excid)
246 else:
252 else:
247 try:
253 try:
248 raise RemoteError(en, ev, etb, ei)
254 raise RemoteError(en, ev, etb, ei)
249 except:
255 except:
250 et,ev,tb = sys.exc_info()
256 et,ev,tb = sys.exc_info()
251
257
252
258
253 def collect_exceptions(rdict_or_list, method='unspecified'):
259 def collect_exceptions(rdict_or_list, method='unspecified'):
254 """check a result dict for errors, and raise CompositeError if any exist.
260 """check a result dict for errors, and raise CompositeError if any exist.
255 Passthrough otherwise."""
261 Passthrough otherwise."""
256 elist = []
262 elist = []
257 if isinstance(rdict_or_list, dict):
263 if isinstance(rdict_or_list, dict):
258 rlist = rdict_or_list.values()
264 rlist = rdict_or_list.values()
259 else:
265 else:
260 rlist = rdict_or_list
266 rlist = rdict_or_list
261 for r in rlist:
267 for r in rlist:
262 if isinstance(r, RemoteError):
268 if isinstance(r, RemoteError):
263 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
269 en, ev, etb, ei = r.ename, r.evalue, r.traceback, r.engine_info
264 # Sometimes we could have CompositeError in our list. Just take
270 # Sometimes we could have CompositeError in our list. Just take
265 # the errors out of them and put them in our new list. This
271 # the errors out of them and put them in our new list. This
266 # has the effect of flattening lists of CompositeErrors into one
272 # has the effect of flattening lists of CompositeErrors into one
267 # CompositeError
273 # CompositeError
268 if en=='CompositeError':
274 if en=='CompositeError':
269 for e in ev.elist:
275 for e in ev.elist:
270 elist.append(e)
276 elist.append(e)
271 else:
277 else:
272 elist.append((en, ev, etb, ei))
278 elist.append((en, ev, etb, ei))
273 if len(elist)==0:
279 if len(elist)==0:
274 return rdict_or_list
280 return rdict_or_list
275 else:
281 else:
276 msg = "one or more exceptions from call to method: %s" % (method)
282 msg = "one or more exceptions from call to method: %s" % (method)
277 # This silliness is needed so the debugger has access to the exception
283 # This silliness is needed so the debugger has access to the exception
278 # instance (e in this case)
284 # instance (e in this case)
279 try:
285 try:
280 raise CompositeError(msg, elist)
286 raise CompositeError(msg, elist)
281 except CompositeError, e:
287 except CompositeError, e:
282 raise e
288 raise e
283
289
@@ -1,426 +1,509 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6 """
6 """
7
7
8 #----------------------------------------------------------------------
8 #----------------------------------------------------------------------
9 # Imports
9 # Imports
10 #----------------------------------------------------------------------
10 #----------------------------------------------------------------------
11
11
12 from __future__ import print_function
12 from __future__ import print_function
13 from random import randint,random
13 from random import randint,random
14 import logging
14 import logging
15 from types import FunctionType
15 from types import FunctionType
16
16
17 try:
17 try:
18 import numpy
18 import numpy
19 except ImportError:
19 except ImportError:
20 numpy = None
20 numpy = None
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 # local imports
25 # local imports
26 from IPython.external.decorator import decorator
26 from IPython.external.decorator import decorator
27 from IPython.config.configurable import Configurable
27 from IPython.config.configurable import Configurable
28 from IPython.utils.traitlets import Instance, Dict, List, Set
28 from IPython.utils.traitlets import Instance, Dict, List, Set
29
29
30 import error
30 from client import Client
31 from client import Client
31 from dependency import Dependency
32 from dependency import Dependency
32 import streamsession as ss
33 import streamsession as ss
33 from entry_point import connect_logger, local_logger
34 from entry_point import connect_logger, local_logger
34
35
35
36
36 @decorator
37 @decorator
37 def logged(f,self,*args,**kwargs):
38 def logged(f,self,*args,**kwargs):
38 # print ("#--------------------")
39 # print ("#--------------------")
39 logging.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
40 logging.debug("scheduler::%s(*%s,**%s)"%(f.func_name, args, kwargs))
40 # print ("#--")
41 # print ("#--")
41 return f(self,*args, **kwargs)
42 return f(self,*args, **kwargs)
42
43
43 #----------------------------------------------------------------------
44 #----------------------------------------------------------------------
44 # Chooser functions
45 # Chooser functions
45 #----------------------------------------------------------------------
46 #----------------------------------------------------------------------
46
47
47 def plainrandom(loads):
48 def plainrandom(loads):
48 """Plain random pick."""
49 """Plain random pick."""
49 n = len(loads)
50 n = len(loads)
50 return randint(0,n-1)
51 return randint(0,n-1)
51
52
52 def lru(loads):
53 def lru(loads):
53 """Always pick the front of the line.
54 """Always pick the front of the line.
54
55
55 The content of `loads` is ignored.
56 The content of `loads` is ignored.
56
57
57 Assumes LRU ordering of loads, with oldest first.
58 Assumes LRU ordering of loads, with oldest first.
58 """
59 """
59 return 0
60 return 0
60
61
61 def twobin(loads):
62 def twobin(loads):
62 """Pick two at random, use the LRU of the two.
63 """Pick two at random, use the LRU of the two.
63
64
64 The content of loads is ignored.
65 The content of loads is ignored.
65
66
66 Assumes LRU ordering of loads, with oldest first.
67 Assumes LRU ordering of loads, with oldest first.
67 """
68 """
68 n = len(loads)
69 n = len(loads)
69 a = randint(0,n-1)
70 a = randint(0,n-1)
70 b = randint(0,n-1)
71 b = randint(0,n-1)
71 return min(a,b)
72 return min(a,b)
72
73
73 def weighted(loads):
74 def weighted(loads):
74 """Pick two at random using inverse load as weight.
75 """Pick two at random using inverse load as weight.
75
76
76 Return the less loaded of the two.
77 Return the less loaded of the two.
77 """
78 """
78 # weight 0 a million times more than 1:
79 # weight 0 a million times more than 1:
79 weights = 1./(1e-6+numpy.array(loads))
80 weights = 1./(1e-6+numpy.array(loads))
80 sums = weights.cumsum()
81 sums = weights.cumsum()
81 t = sums[-1]
82 t = sums[-1]
82 x = random()*t
83 x = random()*t
83 y = random()*t
84 y = random()*t
84 idx = 0
85 idx = 0
85 idy = 0
86 idy = 0
86 while sums[idx] < x:
87 while sums[idx] < x:
87 idx += 1
88 idx += 1
88 while sums[idy] < y:
89 while sums[idy] < y:
89 idy += 1
90 idy += 1
90 if weights[idy] > weights[idx]:
91 if weights[idy] > weights[idx]:
91 return idy
92 return idy
92 else:
93 else:
93 return idx
94 return idx
94
95
95 def leastload(loads):
96 def leastload(loads):
96 """Always choose the lowest load.
97 """Always choose the lowest load.
97
98
98 If the lowest load occurs more than once, the first
99 If the lowest load occurs more than once, the first
99 occurance will be used. If loads has LRU ordering, this means
100 occurance will be used. If loads has LRU ordering, this means
100 the LRU of those with the lowest load is chosen.
101 the LRU of those with the lowest load is chosen.
101 """
102 """
102 return loads.index(min(loads))
103 return loads.index(min(loads))
103
104
104 #---------------------------------------------------------------------
105 #---------------------------------------------------------------------
105 # Classes
106 # Classes
106 #---------------------------------------------------------------------
107 #---------------------------------------------------------------------
108 # store empty default dependency:
109 MET = Dependency([])
110
107 class TaskScheduler(Configurable):
111 class TaskScheduler(Configurable):
108 """Python TaskScheduler object.
112 """Python TaskScheduler object.
109
113
110 This is the simplest object that supports msg_id based
114 This is the simplest object that supports msg_id based
111 DAG dependencies. *Only* task msg_ids are checked, not
115 DAG dependencies. *Only* task msg_ids are checked, not
112 msg_ids of jobs submitted via the MUX queue.
116 msg_ids of jobs submitted via the MUX queue.
113
117
114 """
118 """
115
119
116 # input arguments:
120 # input arguments:
117 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
121 scheme = Instance(FunctionType, default=leastload) # function for determining the destination
118 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
122 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
119 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
123 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
120 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
124 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
121 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
125 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
122 io_loop = Instance(ioloop.IOLoop)
126 io_loop = Instance(ioloop.IOLoop)
123
127
124 # internals:
128 # internals:
125 dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
129 dependencies = Dict() # dict by msg_id of [ msg_ids that depend on key ]
126 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
130 depending = Dict() # dict by msg_id of (msg_id, raw_msg, after, follow)
127 pending = Dict() # dict by engine_uuid of submitted tasks
131 pending = Dict() # dict by engine_uuid of submitted tasks
128 completed = Dict() # dict by engine_uuid of completed tasks
132 completed = Dict() # dict by engine_uuid of completed tasks
133 failed = Dict() # dict by engine_uuid of failed tasks
134 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
129 clients = Dict() # dict by msg_id for who submitted the task
135 clients = Dict() # dict by msg_id for who submitted the task
130 targets = List() # list of target IDENTs
136 targets = List() # list of target IDENTs
131 loads = List() # list of engine loads
137 loads = List() # list of engine loads
132 all_done = Set() # set of all completed tasks
138 all_completed = Set() # set of all completed tasks
139 all_failed = Set() # set of all failed tasks
140 all_done = Set() # set of all finished tasks=union(completed,failed)
133 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
141 blacklist = Dict() # dict by msg_id of locations where a job has encountered UnmetDependency
134 session = Instance(ss.StreamSession)
142 session = Instance(ss.StreamSession)
135
143
136
144
137 def __init__(self, **kwargs):
145 def __init__(self, **kwargs):
138 super(TaskScheduler, self).__init__(**kwargs)
146 super(TaskScheduler, self).__init__(**kwargs)
139
147
140 self.session = ss.StreamSession(username="TaskScheduler")
148 self.session = ss.StreamSession(username="TaskScheduler")
141
149
142 self.engine_stream.on_recv(self.dispatch_result, copy=False)
150 self.engine_stream.on_recv(self.dispatch_result, copy=False)
143 self._notification_handlers = dict(
151 self._notification_handlers = dict(
144 registration_notification = self._register_engine,
152 registration_notification = self._register_engine,
145 unregistration_notification = self._unregister_engine
153 unregistration_notification = self._unregister_engine
146 )
154 )
147 self.notifier_stream.on_recv(self.dispatch_notification)
155 self.notifier_stream.on_recv(self.dispatch_notification)
148 logging.info("Scheduler started...%r"%self)
156 logging.info("Scheduler started...%r"%self)
149
157
150 def resume_receiving(self):
158 def resume_receiving(self):
151 """Resume accepting jobs."""
159 """Resume accepting jobs."""
152 self.client_stream.on_recv(self.dispatch_submission, copy=False)
160 self.client_stream.on_recv(self.dispatch_submission, copy=False)
153
161
154 def stop_receiving(self):
162 def stop_receiving(self):
155 """Stop accepting jobs while there are no engines.
163 """Stop accepting jobs while there are no engines.
156 Leave them in the ZMQ queue."""
164 Leave them in the ZMQ queue."""
157 self.client_stream.on_recv(None)
165 self.client_stream.on_recv(None)
158
166
159 #-----------------------------------------------------------------------
167 #-----------------------------------------------------------------------
160 # [Un]Registration Handling
168 # [Un]Registration Handling
161 #-----------------------------------------------------------------------
169 #-----------------------------------------------------------------------
162
170
163 def dispatch_notification(self, msg):
171 def dispatch_notification(self, msg):
164 """dispatch register/unregister events."""
172 """dispatch register/unregister events."""
165 idents,msg = self.session.feed_identities(msg)
173 idents,msg = self.session.feed_identities(msg)
166 msg = self.session.unpack_message(msg)
174 msg = self.session.unpack_message(msg)
167 msg_type = msg['msg_type']
175 msg_type = msg['msg_type']
168 handler = self._notification_handlers.get(msg_type, None)
176 handler = self._notification_handlers.get(msg_type, None)
169 if handler is None:
177 if handler is None:
170 raise Exception("Unhandled message type: %s"%msg_type)
178 raise Exception("Unhandled message type: %s"%msg_type)
171 else:
179 else:
172 try:
180 try:
173 handler(str(msg['content']['queue']))
181 handler(str(msg['content']['queue']))
174 except KeyError:
182 except KeyError:
175 logging.error("task::Invalid notification msg: %s"%msg)
183 logging.error("task::Invalid notification msg: %s"%msg)
176
184
177 @logged
185 @logged
178 def _register_engine(self, uid):
186 def _register_engine(self, uid):
179 """New engine with ident `uid` became available."""
187 """New engine with ident `uid` became available."""
180 # head of the line:
188 # head of the line:
181 self.targets.insert(0,uid)
189 self.targets.insert(0,uid)
182 self.loads.insert(0,0)
190 self.loads.insert(0,0)
183 # initialize sets
191 # initialize sets
184 self.completed[uid] = set()
192 self.completed[uid] = set()
193 self.failed[uid] = set()
185 self.pending[uid] = {}
194 self.pending[uid] = {}
186 if len(self.targets) == 1:
195 if len(self.targets) == 1:
187 self.resume_receiving()
196 self.resume_receiving()
188
197
189 def _unregister_engine(self, uid):
198 def _unregister_engine(self, uid):
190 """Existing engine with ident `uid` became unavailable."""
199 """Existing engine with ident `uid` became unavailable."""
191 if len(self.targets) == 1:
200 if len(self.targets) == 1:
192 # this was our only engine
201 # this was our only engine
193 self.stop_receiving()
202 self.stop_receiving()
194
203
195 # handle any potentially finished tasks:
204 # handle any potentially finished tasks:
196 self.engine_stream.flush()
205 self.engine_stream.flush()
197
206
198 self.completed.pop(uid)
207 self.completed.pop(uid)
208 self.failed.pop(uid)
209 # don't pop destinations, because it might be used later
210 # map(self.destinations.pop, self.completed.pop(uid))
211 # map(self.destinations.pop, self.failed.pop(uid))
212
199 lost = self.pending.pop(uid)
213 lost = self.pending.pop(uid)
200
214
201 idx = self.targets.index(uid)
215 idx = self.targets.index(uid)
202 self.targets.pop(idx)
216 self.targets.pop(idx)
203 self.loads.pop(idx)
217 self.loads.pop(idx)
204
218
205 self.handle_stranded_tasks(lost)
219 self.handle_stranded_tasks(lost)
206
220
207 def handle_stranded_tasks(self, lost):
221 def handle_stranded_tasks(self, lost):
208 """Deal with jobs resident in an engine that died."""
222 """Deal with jobs resident in an engine that died."""
209 # TODO: resubmit the tasks?
223 # TODO: resubmit the tasks?
210 for msg_id in lost:
224 for msg_id in lost:
211 pass
225 pass
212
226
213
227
214 #-----------------------------------------------------------------------
228 #-----------------------------------------------------------------------
215 # Job Submission
229 # Job Submission
216 #-----------------------------------------------------------------------
230 #-----------------------------------------------------------------------
217 @logged
231 @logged
218 def dispatch_submission(self, raw_msg):
232 def dispatch_submission(self, raw_msg):
219 """Dispatch job submission to appropriate handlers."""
233 """Dispatch job submission to appropriate handlers."""
220 # ensure targets up to date:
234 # ensure targets up to date:
221 self.notifier_stream.flush()
235 self.notifier_stream.flush()
222 try:
236 try:
223 idents, msg = self.session.feed_identities(raw_msg, copy=False)
237 idents, msg = self.session.feed_identities(raw_msg, copy=False)
224 except Exception as e:
238 except Exception as e:
225 logging.error("task::Invaid msg: %s"%msg)
239 logging.error("task::Invaid msg: %s"%msg)
226 return
240 return
227
241
228 # send to monitor
242 # send to monitor
229 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
243 self.mon_stream.send_multipart(['intask']+raw_msg, copy=False)
230
244
231 msg = self.session.unpack_message(msg, content=False, copy=False)
245 msg = self.session.unpack_message(msg, content=False, copy=False)
232 header = msg['header']
246 header = msg['header']
233 msg_id = header['msg_id']
247 msg_id = header['msg_id']
234
248
235 # time dependencies
249 # time dependencies
236 after = Dependency(header.get('after', []))
250 after = Dependency(header.get('after', []))
237 if after.mode == 'all':
251 if after.mode == 'all':
238 after.difference_update(self.all_done)
252 after.difference_update(self.all_completed)
239 if after.check(self.all_done):
253 if not after.success_only:
254 after.difference_update(self.all_failed)
255 if after.check(self.all_completed, self.all_failed):
240 # recast as empty set, if `after` already met,
256 # recast as empty set, if `after` already met,
241 # to prevent unnecessary set comparisons
257 # to prevent unnecessary set comparisons
242 after = Dependency([])
258 after = MET
243
259
244 # location dependencies
260 # location dependencies
245 follow = Dependency(header.get('follow', []))
261 follow = Dependency(header.get('follow', []))
246 if len(after) == 0:
262
263 # check if unreachable:
264 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
265 self.depending[msg_id] = [raw_msg,MET,MET]
266 return self.fail_unreachable(msg_id)
267
268 if after.check(self.all_completed, self.all_failed):
247 # time deps already met, try to run
269 # time deps already met, try to run
248 if not self.maybe_run(msg_id, raw_msg, follow):
270 if not self.maybe_run(msg_id, raw_msg, follow):
249 # can't run yet
271 # can't run yet
250 self.save_unmet(msg_id, raw_msg, after, follow)
272 self.save_unmet(msg_id, raw_msg, after, follow)
251 else:
273 else:
252 self.save_unmet(msg_id, raw_msg, after, follow)
274 self.save_unmet(msg_id, raw_msg, after, follow)
253
275
254 @logged
276 @logged
277 def fail_unreachable(self, msg_id):
278 """a message has become unreachable"""
279 if msg_id not in self.depending:
280 logging.error("msg %r already failed!"%msg_id)
281 return
282 raw_msg, after, follow = self.depending.pop(msg_id)
283 for mid in follow.union(after):
284 if mid in self.dependencies:
285 self.dependencies[mid].remove(msg_id)
286
287 idents,msg = self.session.feed_identities(raw_msg, copy=False)
288 msg = self.session.unpack_message(msg, copy=False, content=False)
289 header = msg['header']
290
291 try:
292 raise error.ImpossibleDependency()
293 except:
294 content = ss.wrap_exception()
295
296 self.all_done.add(msg_id)
297 self.all_failed.add(msg_id)
298
299 msg = self.session.send(self.client_stream, 'apply_reply', content,
300 parent=header, ident=idents)
301 self.session.send(self.mon_stream, msg, ident=['outtask']+idents)
302
303 self.update_dependencies(msg_id, success=False)
304
305 @logged
255 def maybe_run(self, msg_id, raw_msg, follow=None):
306 def maybe_run(self, msg_id, raw_msg, follow=None):
256 """check location dependencies, and run if they are met."""
307 """check location dependencies, and run if they are met."""
257
308
258 if follow:
309 if follow:
259 def can_run(idx):
310 def can_run(idx):
260 target = self.targets[idx]
311 target = self.targets[idx]
261 return target not in self.blacklist.get(msg_id, []) and\
312 return target not in self.blacklist.get(msg_id, []) and\
262 follow.check(self.completed[target])
313 follow.check(self.completed[target], self.failed[target])
263
314
264 indices = filter(can_run, range(len(self.targets)))
315 indices = filter(can_run, range(len(self.targets)))
265 if not indices:
316 if not indices:
317 # TODO evaluate unmeetable follow dependencies
318 if follow.mode == 'all':
319 dests = set()
320 relevant = self.all_completed if follow.success_only else self.all_done
321 for m in follow.intersection(relevant):
322 dests.add(self.destinations[m])
323 if len(dests) > 1:
324 self.fail_unreachable(msg_id)
325
326
266 return False
327 return False
267 else:
328 else:
268 indices = None
329 indices = None
269
330
270 self.submit_task(msg_id, raw_msg, indices)
331 self.submit_task(msg_id, raw_msg, indices)
271 return True
332 return True
272
333
273 @logged
334 @logged
274 def save_unmet(self, msg_id, msg, after, follow):
335 def save_unmet(self, msg_id, raw_msg, after, follow):
275 """Save a message for later submission when its dependencies are met."""
336 """Save a message for later submission when its dependencies are met."""
276 self.depending[msg_id] = (msg_id,msg,after,follow)
337 self.depending[msg_id] = [raw_msg,after,follow]
277 # track the ids in both follow/after, but not those already completed
338 # track the ids in follow or after, but not those already finished
278 for dep_id in after.union(follow).difference(self.all_done):
339 for dep_id in after.union(follow).difference(self.all_done):
279 if dep_id not in self.dependencies:
340 if dep_id not in self.dependencies:
280 self.dependencies[dep_id] = set()
341 self.dependencies[dep_id] = set()
281 self.dependencies[dep_id].add(msg_id)
342 self.dependencies[dep_id].add(msg_id)
282
343
283 @logged
344 @logged
284 def submit_task(self, msg_id, msg, follow=None, indices=None):
345 def submit_task(self, msg_id, msg, follow=None, indices=None):
285 """Submit a task to any of a subset of our targets."""
346 """Submit a task to any of a subset of our targets."""
286 if indices:
347 if indices:
287 loads = [self.loads[i] for i in indices]
348 loads = [self.loads[i] for i in indices]
288 else:
349 else:
289 loads = self.loads
350 loads = self.loads
290 idx = self.scheme(loads)
351 idx = self.scheme(loads)
291 if indices:
352 if indices:
292 idx = indices[idx]
353 idx = indices[idx]
293 target = self.targets[idx]
354 target = self.targets[idx]
294 # print (target, map(str, msg[:3]))
355 # print (target, map(str, msg[:3]))
295 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
356 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
296 self.engine_stream.send_multipart(msg, copy=False)
357 self.engine_stream.send_multipart(msg, copy=False)
297 self.add_job(idx)
358 self.add_job(idx)
298 self.pending[target][msg_id] = (msg, follow)
359 self.pending[target][msg_id] = (msg, follow)
299 content = dict(msg_id=msg_id, engine_id=target)
360 content = dict(msg_id=msg_id, engine_id=target)
300 self.session.send(self.mon_stream, 'task_destination', content=content,
361 self.session.send(self.mon_stream, 'task_destination', content=content,
301 ident=['tracktask',self.session.session])
362 ident=['tracktask',self.session.session])
302
363
303 #-----------------------------------------------------------------------
364 #-----------------------------------------------------------------------
304 # Result Handling
365 # Result Handling
305 #-----------------------------------------------------------------------
366 #-----------------------------------------------------------------------
306 @logged
367 @logged
307 def dispatch_result(self, raw_msg):
368 def dispatch_result(self, raw_msg):
308 try:
369 try:
309 idents,msg = self.session.feed_identities(raw_msg, copy=False)
370 idents,msg = self.session.feed_identities(raw_msg, copy=False)
310 except Exception as e:
371 except Exception as e:
311 logging.error("task::Invaid result: %s"%msg)
372 logging.error("task::Invaid result: %s"%msg)
312 return
373 return
313 msg = self.session.unpack_message(msg, content=False, copy=False)
374 msg = self.session.unpack_message(msg, content=False, copy=False)
314 header = msg['header']
375 header = msg['header']
315 if header.get('dependencies_met', True):
376 if header.get('dependencies_met', True):
316 self.handle_result_success(idents, msg['parent_header'], raw_msg)
377 success = (header['status'] == 'ok')
317 # send to monitor
378 self.handle_result(idents, msg['parent_header'], raw_msg, success)
379 # send to Hub monitor
318 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
380 self.mon_stream.send_multipart(['outtask']+raw_msg, copy=False)
319 else:
381 else:
320 self.handle_unmet_dependency(idents, msg['parent_header'])
382 self.handle_unmet_dependency(idents, msg['parent_header'])
321
383
322 @logged
384 @logged
323 def handle_result_success(self, idents, parent, raw_msg):
385 def handle_result(self, idents, parent, raw_msg, success=True):
324 # first, relay result to client
386 # first, relay result to client
325 engine = idents[0]
387 engine = idents[0]
326 client = idents[1]
388 client = idents[1]
327 # swap_ids for XREP-XREP mirror
389 # swap_ids for XREP-XREP mirror
328 raw_msg[:2] = [client,engine]
390 raw_msg[:2] = [client,engine]
329 # print (map(str, raw_msg[:4]))
391 # print (map(str, raw_msg[:4]))
330 self.client_stream.send_multipart(raw_msg, copy=False)
392 self.client_stream.send_multipart(raw_msg, copy=False)
331 # now, update our data structures
393 # now, update our data structures
332 msg_id = parent['msg_id']
394 msg_id = parent['msg_id']
333 self.pending[engine].pop(msg_id)
395 self.pending[engine].pop(msg_id)
334 self.completed[engine].add(msg_id)
396 if success:
397 self.completed[engine].add(msg_id)
398 self.all_completed.add(msg_id)
399 else:
400 self.failed[engine].add(msg_id)
401 self.all_failed.add(msg_id)
335 self.all_done.add(msg_id)
402 self.all_done.add(msg_id)
403 self.destinations[msg_id] = engine
336
404
337 self.update_dependencies(msg_id)
405 self.update_dependencies(msg_id, success)
338
406
339 @logged
407 @logged
340 def handle_unmet_dependency(self, idents, parent):
408 def handle_unmet_dependency(self, idents, parent):
341 engine = idents[0]
409 engine = idents[0]
342 msg_id = parent['msg_id']
410 msg_id = parent['msg_id']
343 if msg_id not in self.blacklist:
411 if msg_id not in self.blacklist:
344 self.blacklist[msg_id] = set()
412 self.blacklist[msg_id] = set()
345 self.blacklist[msg_id].add(engine)
413 self.blacklist[msg_id].add(engine)
346 raw_msg,follow = self.pending[engine].pop(msg_id)
414 raw_msg,follow = self.pending[engine].pop(msg_id)
347 if not self.maybe_run(msg_id, raw_msg, follow):
415 if not self.maybe_run(msg_id, raw_msg, follow):
348 # resubmit failed, put it back in our dependency tree
416 # resubmit failed, put it back in our dependency tree
349 self.save_unmet(msg_id, raw_msg, Dependency(), follow)
417 self.save_unmet(msg_id, raw_msg, MET, follow)
350 pass
418 pass
419
351 @logged
420 @logged
352 def update_dependencies(self, dep_id):
421 def update_dependencies(self, dep_id, success=True):
353 """dep_id just finished. Update our dependency
422 """dep_id just finished. Update our dependency
354 table and submit any jobs that just became runable."""
423 table and submit any jobs that just became runable."""
355
424 # print ("\n\n***********")
425 # pprint (dep_id)
426 # pprint (self.dependencies)
427 # pprint (self.depending)
428 # pprint (self.all_completed)
429 # pprint (self.all_failed)
430 # print ("\n\n***********\n\n")
356 if dep_id not in self.dependencies:
431 if dep_id not in self.dependencies:
357 return
432 return
358 jobs = self.dependencies.pop(dep_id)
433 jobs = self.dependencies.pop(dep_id)
359 for job in jobs:
434
360 msg_id, raw_msg, after, follow = self.depending[job]
435 for msg_id in jobs:
361 if dep_id in after:
436 raw_msg, after, follow = self.depending[msg_id]
362 after.remove(dep_id)
437 # if dep_id in after:
363 if not after: # time deps met, maybe run
438 # if after.mode == 'all' and (success or not after.success_only):
439 # after.remove(dep_id)
440
441 if after.unreachable(self.all_failed) or follow.unreachable(self.all_failed):
442 self.fail_unreachable(msg_id)
443
444 elif after.check(self.all_completed, self.all_failed): # time deps met, maybe run
445 self.depending[msg_id][1] = MET
364 if self.maybe_run(msg_id, raw_msg, follow):
446 if self.maybe_run(msg_id, raw_msg, follow):
365 self.depending.pop(job)
447
366 for mid in follow:
448 self.depending.pop(msg_id)
449 for mid in follow.union(after):
367 if mid in self.dependencies:
450 if mid in self.dependencies:
368 self.dependencies[mid].remove(msg_id)
451 self.dependencies[mid].remove(msg_id)
369
452
370 #----------------------------------------------------------------------
453 #----------------------------------------------------------------------
371 # methods to be overridden by subclasses
454 # methods to be overridden by subclasses
372 #----------------------------------------------------------------------
455 #----------------------------------------------------------------------
373
456
374 def add_job(self, idx):
457 def add_job(self, idx):
375 """Called after self.targets[idx] just got the job with header.
458 """Called after self.targets[idx] just got the job with header.
376 Override with subclasses. The default ordering is simple LRU.
459 Override with subclasses. The default ordering is simple LRU.
377 The default loads are the number of outstanding jobs."""
460 The default loads are the number of outstanding jobs."""
378 self.loads[idx] += 1
461 self.loads[idx] += 1
379 for lis in (self.targets, self.loads):
462 for lis in (self.targets, self.loads):
380 lis.append(lis.pop(idx))
463 lis.append(lis.pop(idx))
381
464
382
465
383 def finish_job(self, idx):
466 def finish_job(self, idx):
384 """Called after self.targets[idx] just finished a job.
467 """Called after self.targets[idx] just finished a job.
385 Override with subclasses."""
468 Override with subclasses."""
386 self.loads[idx] -= 1
469 self.loads[idx] -= 1
387
470
388
471
389
472
390 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, log_addr=None, loglevel=logging.DEBUG, scheme='weighted'):
473 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, log_addr=None, loglevel=logging.DEBUG, scheme='weighted'):
391 from zmq.eventloop import ioloop
474 from zmq.eventloop import ioloop
392 from zmq.eventloop.zmqstream import ZMQStream
475 from zmq.eventloop.zmqstream import ZMQStream
393
476
394 ctx = zmq.Context()
477 ctx = zmq.Context()
395 loop = ioloop.IOLoop()
478 loop = ioloop.IOLoop()
396
479
397 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
480 ins = ZMQStream(ctx.socket(zmq.XREP),loop)
398 ins.bind(in_addr)
481 ins.bind(in_addr)
399 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
482 outs = ZMQStream(ctx.socket(zmq.XREP),loop)
400 outs.bind(out_addr)
483 outs.bind(out_addr)
401 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
484 mons = ZMQStream(ctx.socket(zmq.PUB),loop)
402 mons.connect(mon_addr)
485 mons.connect(mon_addr)
403 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
486 nots = ZMQStream(ctx.socket(zmq.SUB),loop)
404 nots.setsockopt(zmq.SUBSCRIBE, '')
487 nots.setsockopt(zmq.SUBSCRIBE, '')
405 nots.connect(not_addr)
488 nots.connect(not_addr)
406
489
407 scheme = globals().get(scheme, None)
490 scheme = globals().get(scheme, None)
408 # setup logging
491 # setup logging
409 if log_addr:
492 if log_addr:
410 connect_logger(ctx, log_addr, root="scheduler", loglevel=loglevel)
493 connect_logger(ctx, log_addr, root="scheduler", loglevel=loglevel)
411 else:
494 else:
412 local_logger(loglevel)
495 local_logger(loglevel)
413
496
414 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
497 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
415 mon_stream=mons,notifier_stream=nots,
498 mon_stream=mons,notifier_stream=nots,
416 scheme=scheme,io_loop=loop)
499 scheme=scheme,io_loop=loop)
417
500
418 try:
501 try:
419 loop.start()
502 loop.start()
420 except KeyboardInterrupt:
503 except KeyboardInterrupt:
421 print ("interrupted, exiting...", file=sys.__stderr__)
504 print ("interrupted, exiting...", file=sys.__stderr__)
422
505
423
506
424 if __name__ == '__main__':
507 if __name__ == '__main__':
425 iface = 'tcp://127.0.0.1:%i'
508 iface = 'tcp://127.0.0.1:%i'
426 launch_scheduler(iface%12345,iface%1236,iface%12347,iface%12348)
509 launch_scheduler(iface%12345,iface%1236,iface%12347,iface%12348)
@@ -1,490 +1,483 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 Kernel adapted from kernel.py to use ZMQ Streams
3 Kernel adapted from kernel.py to use ZMQ Streams
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Imports
7 # Imports
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 # Standard library imports.
10 # Standard library imports.
11 from __future__ import print_function
11 from __future__ import print_function
12 import __builtin__
12 import __builtin__
13 from code import CommandCompiler
13 from code import CommandCompiler
14 import os
14 import os
15 import sys
15 import sys
16 import time
16 import time
17 import traceback
17 import traceback
18 import logging
18 import logging
19 from datetime import datetime
19 from datetime import datetime
20 from signal import SIGTERM, SIGKILL
20 from signal import SIGTERM, SIGKILL
21 from pprint import pprint
21 from pprint import pprint
22
22
23 # System library imports.
23 # System library imports.
24 import zmq
24 import zmq
25 from zmq.eventloop import ioloop, zmqstream
25 from zmq.eventloop import ioloop, zmqstream
26
26
27 # Local imports.
27 # Local imports.
28 from IPython.core import ultratb
28 from IPython.core import ultratb
29 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
29 from IPython.utils.traitlets import HasTraits, Instance, List, Int, Dict, Set, Str
30 from IPython.zmq.completer import KernelCompleter
30 from IPython.zmq.completer import KernelCompleter
31 from IPython.zmq.iostream import OutStream
31 from IPython.zmq.iostream import OutStream
32 from IPython.zmq.displayhook import DisplayHook
32 from IPython.zmq.displayhook import DisplayHook
33
33
34 from factory import SessionFactory
34 from factory import SessionFactory
35 from streamsession import StreamSession, Message, extract_header, serialize_object,\
35 from streamsession import StreamSession, Message, extract_header, serialize_object,\
36 unpack_apply_message, ISO8601, wrap_exception
36 unpack_apply_message, ISO8601, wrap_exception
37 from dependency import UnmetDependency
38 import heartmonitor
37 import heartmonitor
39 from client import Client
38 from client import Client
40
39
41 def printer(*args):
40 def printer(*args):
42 pprint(args, stream=sys.__stdout__)
41 pprint(args, stream=sys.__stdout__)
43
42
44
43
45 class _Passer:
44 class _Passer:
46 """Empty class that implements `send()` that does nothing."""
45 """Empty class that implements `send()` that does nothing."""
47 def send(self, *args, **kwargs):
46 def send(self, *args, **kwargs):
48 pass
47 pass
49 send_multipart = send
48 send_multipart = send
50
49
51
50
52 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
53 # Main kernel class
52 # Main kernel class
54 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
55
54
56 class Kernel(SessionFactory):
55 class Kernel(SessionFactory):
57
56
58 #---------------------------------------------------------------------------
57 #---------------------------------------------------------------------------
59 # Kernel interface
58 # Kernel interface
60 #---------------------------------------------------------------------------
59 #---------------------------------------------------------------------------
61
60
62 # kwargs:
61 # kwargs:
63 int_id = Int(-1, config=True)
62 int_id = Int(-1, config=True)
64 user_ns = Dict(config=True)
63 user_ns = Dict(config=True)
65 exec_lines = List(config=True)
64 exec_lines = List(config=True)
66
65
67 control_stream = Instance(zmqstream.ZMQStream)
66 control_stream = Instance(zmqstream.ZMQStream)
68 task_stream = Instance(zmqstream.ZMQStream)
67 task_stream = Instance(zmqstream.ZMQStream)
69 iopub_stream = Instance(zmqstream.ZMQStream)
68 iopub_stream = Instance(zmqstream.ZMQStream)
70 client = Instance('IPython.zmq.parallel.client.Client')
69 client = Instance('IPython.zmq.parallel.client.Client')
71
70
72 # internals
71 # internals
73 shell_streams = List()
72 shell_streams = List()
74 compiler = Instance(CommandCompiler, (), {})
73 compiler = Instance(CommandCompiler, (), {})
75 completer = Instance(KernelCompleter)
74 completer = Instance(KernelCompleter)
76
75
77 aborted = Set()
76 aborted = Set()
78 shell_handlers = Dict()
77 shell_handlers = Dict()
79 control_handlers = Dict()
78 control_handlers = Dict()
80
79
81 def _set_prefix(self):
80 def _set_prefix(self):
82 self.prefix = "engine.%s"%self.int_id
81 self.prefix = "engine.%s"%self.int_id
83
82
84 def _connect_completer(self):
83 def _connect_completer(self):
85 self.completer = KernelCompleter(self.user_ns)
84 self.completer = KernelCompleter(self.user_ns)
86
85
87 def __init__(self, **kwargs):
86 def __init__(self, **kwargs):
88 super(Kernel, self).__init__(**kwargs)
87 super(Kernel, self).__init__(**kwargs)
89 self._set_prefix()
88 self._set_prefix()
90 self._connect_completer()
89 self._connect_completer()
91
90
92 self.on_trait_change(self._set_prefix, 'id')
91 self.on_trait_change(self._set_prefix, 'id')
93 self.on_trait_change(self._connect_completer, 'user_ns')
92 self.on_trait_change(self._connect_completer, 'user_ns')
94
93
95 # Build dict of handlers for message types
94 # Build dict of handlers for message types
96 for msg_type in ['execute_request', 'complete_request', 'apply_request',
95 for msg_type in ['execute_request', 'complete_request', 'apply_request',
97 'clear_request']:
96 'clear_request']:
98 self.shell_handlers[msg_type] = getattr(self, msg_type)
97 self.shell_handlers[msg_type] = getattr(self, msg_type)
99
98
100 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
99 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
101 self.control_handlers[msg_type] = getattr(self, msg_type)
100 self.control_handlers[msg_type] = getattr(self, msg_type)
102
101
103 self._initial_exec_lines()
102 self._initial_exec_lines()
104
103
105 def _wrap_exception(self, method=None):
104 def _wrap_exception(self, method=None):
106 e_info = dict(engineid=self.ident, method=method)
105 e_info = dict(engineid=self.ident, method=method)
107 content=wrap_exception(e_info)
106 content=wrap_exception(e_info)
108 return content
107 return content
109
108
110 def _initial_exec_lines(self):
109 def _initial_exec_lines(self):
111 s = _Passer()
110 s = _Passer()
112 content = dict(silent=True, user_variable=[],user_expressions=[])
111 content = dict(silent=True, user_variable=[],user_expressions=[])
113 for line in self.exec_lines:
112 for line in self.exec_lines:
114 logging.debug("executing initialization: %s"%line)
113 logging.debug("executing initialization: %s"%line)
115 content.update({'code':line})
114 content.update({'code':line})
116 msg = self.session.msg('execute_request', content)
115 msg = self.session.msg('execute_request', content)
117 self.execute_request(s, [], msg)
116 self.execute_request(s, [], msg)
118
117
119
118
120 #-------------------- control handlers -----------------------------
119 #-------------------- control handlers -----------------------------
121 def abort_queues(self):
120 def abort_queues(self):
122 for stream in self.shell_streams:
121 for stream in self.shell_streams:
123 if stream:
122 if stream:
124 self.abort_queue(stream)
123 self.abort_queue(stream)
125
124
126 def abort_queue(self, stream):
125 def abort_queue(self, stream):
127 while True:
126 while True:
128 try:
127 try:
129 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
128 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
130 except zmq.ZMQError as e:
129 except zmq.ZMQError as e:
131 if e.errno == zmq.EAGAIN:
130 if e.errno == zmq.EAGAIN:
132 break
131 break
133 else:
132 else:
134 return
133 return
135 else:
134 else:
136 if msg is None:
135 if msg is None:
137 return
136 return
138 else:
137 else:
139 idents,msg = msg
138 idents,msg = msg
140
139
141 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
140 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
142 # msg = self.reply_socket.recv_json()
141 # msg = self.reply_socket.recv_json()
143 logging.info("Aborting:")
142 logging.info("Aborting:")
144 logging.info(str(msg))
143 logging.info(str(msg))
145 msg_type = msg['msg_type']
144 msg_type = msg['msg_type']
146 reply_type = msg_type.split('_')[0] + '_reply'
145 reply_type = msg_type.split('_')[0] + '_reply'
147 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
146 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
148 # self.reply_socket.send(ident,zmq.SNDMORE)
147 # self.reply_socket.send(ident,zmq.SNDMORE)
149 # self.reply_socket.send_json(reply_msg)
148 # self.reply_socket.send_json(reply_msg)
150 reply_msg = self.session.send(stream, reply_type,
149 reply_msg = self.session.send(stream, reply_type,
151 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
150 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
152 logging.debug(str(reply_msg))
151 logging.debug(str(reply_msg))
153 # We need to wait a bit for requests to come in. This can probably
152 # We need to wait a bit for requests to come in. This can probably
154 # be set shorter for true asynchronous clients.
153 # be set shorter for true asynchronous clients.
155 time.sleep(0.05)
154 time.sleep(0.05)
156
155
157 def abort_request(self, stream, ident, parent):
156 def abort_request(self, stream, ident, parent):
158 """abort a specifig msg by id"""
157 """abort a specifig msg by id"""
159 msg_ids = parent['content'].get('msg_ids', None)
158 msg_ids = parent['content'].get('msg_ids', None)
160 if isinstance(msg_ids, basestring):
159 if isinstance(msg_ids, basestring):
161 msg_ids = [msg_ids]
160 msg_ids = [msg_ids]
162 if not msg_ids:
161 if not msg_ids:
163 self.abort_queues()
162 self.abort_queues()
164 for mid in msg_ids:
163 for mid in msg_ids:
165 self.aborted.add(str(mid))
164 self.aborted.add(str(mid))
166
165
167 content = dict(status='ok')
166 content = dict(status='ok')
168 reply_msg = self.session.send(stream, 'abort_reply', content=content,
167 reply_msg = self.session.send(stream, 'abort_reply', content=content,
169 parent=parent, ident=ident)[0]
168 parent=parent, ident=ident)[0]
170 logging.debug(str(reply_msg))
169 logging.debug(str(reply_msg))
171
170
172 def shutdown_request(self, stream, ident, parent):
171 def shutdown_request(self, stream, ident, parent):
173 """kill ourself. This should really be handled in an external process"""
172 """kill ourself. This should really be handled in an external process"""
174 try:
173 try:
175 self.abort_queues()
174 self.abort_queues()
176 except:
175 except:
177 content = self._wrap_exception('shutdown')
176 content = self._wrap_exception('shutdown')
178 else:
177 else:
179 content = dict(parent['content'])
178 content = dict(parent['content'])
180 content['status'] = 'ok'
179 content['status'] = 'ok'
181 msg = self.session.send(stream, 'shutdown_reply',
180 msg = self.session.send(stream, 'shutdown_reply',
182 content=content, parent=parent, ident=ident)
181 content=content, parent=parent, ident=ident)
183 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
182 # msg = self.session.send(self.pub_socket, 'shutdown_reply',
184 # content, parent, ident)
183 # content, parent, ident)
185 # print >> sys.__stdout__, msg
184 # print >> sys.__stdout__, msg
186 # time.sleep(0.2)
185 # time.sleep(0.2)
187 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
186 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
188 dc.start()
187 dc.start()
189
188
190 def dispatch_control(self, msg):
189 def dispatch_control(self, msg):
191 idents,msg = self.session.feed_identities(msg, copy=False)
190 idents,msg = self.session.feed_identities(msg, copy=False)
192 try:
191 try:
193 msg = self.session.unpack_message(msg, content=True, copy=False)
192 msg = self.session.unpack_message(msg, content=True, copy=False)
194 except:
193 except:
195 logging.error("Invalid Message", exc_info=True)
194 logging.error("Invalid Message", exc_info=True)
196 return
195 return
197
196
198 header = msg['header']
197 header = msg['header']
199 msg_id = header['msg_id']
198 msg_id = header['msg_id']
200
199
201 handler = self.control_handlers.get(msg['msg_type'], None)
200 handler = self.control_handlers.get(msg['msg_type'], None)
202 if handler is None:
201 if handler is None:
203 logging.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
202 logging.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
204 else:
203 else:
205 handler(self.control_stream, idents, msg)
204 handler(self.control_stream, idents, msg)
206
205
207
206
208 #-------------------- queue helpers ------------------------------
207 #-------------------- queue helpers ------------------------------
209
208
210 def check_dependencies(self, dependencies):
209 def check_dependencies(self, dependencies):
211 if not dependencies:
210 if not dependencies:
212 return True
211 return True
213 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
212 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
214 anyorall = dependencies[0]
213 anyorall = dependencies[0]
215 dependencies = dependencies[1]
214 dependencies = dependencies[1]
216 else:
215 else:
217 anyorall = 'all'
216 anyorall = 'all'
218 results = self.client.get_results(dependencies,status_only=True)
217 results = self.client.get_results(dependencies,status_only=True)
219 if results['status'] != 'ok':
218 if results['status'] != 'ok':
220 return False
219 return False
221
220
222 if anyorall == 'any':
221 if anyorall == 'any':
223 if not results['completed']:
222 if not results['completed']:
224 return False
223 return False
225 else:
224 else:
226 if results['pending']:
225 if results['pending']:
227 return False
226 return False
228
227
229 return True
228 return True
230
229
231 def check_aborted(self, msg_id):
230 def check_aborted(self, msg_id):
232 return msg_id in self.aborted
231 return msg_id in self.aborted
233
232
234 #-------------------- queue handlers -----------------------------
233 #-------------------- queue handlers -----------------------------
235
234
236 def clear_request(self, stream, idents, parent):
235 def clear_request(self, stream, idents, parent):
237 """Clear our namespace."""
236 """Clear our namespace."""
238 self.user_ns = {}
237 self.user_ns = {}
239 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
238 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
240 content = dict(status='ok'))
239 content = dict(status='ok'))
241 self._initial_exec_lines()
240 self._initial_exec_lines()
242
241
243 def execute_request(self, stream, ident, parent):
242 def execute_request(self, stream, ident, parent):
244 logging.debug('execute request %s'%parent)
243 logging.debug('execute request %s'%parent)
245 try:
244 try:
246 code = parent[u'content'][u'code']
245 code = parent[u'content'][u'code']
247 except:
246 except:
248 logging.error("Got bad msg: %s"%parent, exc_info=True)
247 logging.error("Got bad msg: %s"%parent, exc_info=True)
249 return
248 return
250 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
249 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
251 ident='%s.pyin'%self.prefix)
250 ident='%s.pyin'%self.prefix)
252 started = datetime.now().strftime(ISO8601)
251 started = datetime.now().strftime(ISO8601)
253 try:
252 try:
254 comp_code = self.compiler(code, '<zmq-kernel>')
253 comp_code = self.compiler(code, '<zmq-kernel>')
255 # allow for not overriding displayhook
254 # allow for not overriding displayhook
256 if hasattr(sys.displayhook, 'set_parent'):
255 if hasattr(sys.displayhook, 'set_parent'):
257 sys.displayhook.set_parent(parent)
256 sys.displayhook.set_parent(parent)
258 sys.stdout.set_parent(parent)
257 sys.stdout.set_parent(parent)
259 sys.stderr.set_parent(parent)
258 sys.stderr.set_parent(parent)
260 exec comp_code in self.user_ns, self.user_ns
259 exec comp_code in self.user_ns, self.user_ns
261 except:
260 except:
262 exc_content = self._wrap_exception('execute')
261 exc_content = self._wrap_exception('execute')
263 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
262 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
264 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
263 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
265 ident='%s.pyerr'%self.prefix)
264 ident='%s.pyerr'%self.prefix)
266 reply_content = exc_content
265 reply_content = exc_content
267 else:
266 else:
268 reply_content = {'status' : 'ok'}
267 reply_content = {'status' : 'ok'}
269 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
268
270 # self.reply_socket.send(ident, zmq.SNDMORE)
271 # self.reply_socket.send_json(reply_msg)
272 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
269 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
273 ident=ident, subheader = dict(started=started))
270 ident=ident, subheader = dict(started=started))
274 logging.debug(str(reply_msg))
271 logging.debug(str(reply_msg))
275 if reply_msg['content']['status'] == u'error':
272 if reply_msg['content']['status'] == u'error':
276 self.abort_queues()
273 self.abort_queues()
277
274
278 def complete_request(self, stream, ident, parent):
275 def complete_request(self, stream, ident, parent):
279 matches = {'matches' : self.complete(parent),
276 matches = {'matches' : self.complete(parent),
280 'status' : 'ok'}
277 'status' : 'ok'}
281 completion_msg = self.session.send(stream, 'complete_reply',
278 completion_msg = self.session.send(stream, 'complete_reply',
282 matches, parent, ident)
279 matches, parent, ident)
283 # print >> sys.__stdout__, completion_msg
280 # print >> sys.__stdout__, completion_msg
284
281
285 def complete(self, msg):
282 def complete(self, msg):
286 return self.completer.complete(msg.content.line, msg.content.text)
283 return self.completer.complete(msg.content.line, msg.content.text)
287
284
288 def apply_request(self, stream, ident, parent):
285 def apply_request(self, stream, ident, parent):
289 # print (parent)
286 # print (parent)
290 try:
287 try:
291 content = parent[u'content']
288 content = parent[u'content']
292 bufs = parent[u'buffers']
289 bufs = parent[u'buffers']
293 msg_id = parent['header']['msg_id']
290 msg_id = parent['header']['msg_id']
294 bound = content.get('bound', False)
291 bound = content.get('bound', False)
295 except:
292 except:
296 logging.error("Got bad msg: %s"%parent, exc_info=True)
293 logging.error("Got bad msg: %s"%parent, exc_info=True)
297 return
294 return
298 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
295 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
299 # self.iopub_stream.send(pyin_msg)
296 # self.iopub_stream.send(pyin_msg)
300 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
297 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
301 sub = {'dependencies_met' : True, 'engine' : self.ident,
298 sub = {'dependencies_met' : True, 'engine' : self.ident,
302 'started': datetime.now().strftime(ISO8601)}
299 'started': datetime.now().strftime(ISO8601)}
303 try:
300 try:
304 # allow for not overriding displayhook
301 # allow for not overriding displayhook
305 if hasattr(sys.displayhook, 'set_parent'):
302 if hasattr(sys.displayhook, 'set_parent'):
306 sys.displayhook.set_parent(parent)
303 sys.displayhook.set_parent(parent)
307 sys.stdout.set_parent(parent)
304 sys.stdout.set_parent(parent)
308 sys.stderr.set_parent(parent)
305 sys.stderr.set_parent(parent)
309 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
306 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
310 if bound:
307 if bound:
311 working = self.user_ns
308 working = self.user_ns
312 suffix = str(msg_id).replace("-","")
309 suffix = str(msg_id).replace("-","")
313 prefix = "_"
310 prefix = "_"
314
311
315 else:
312 else:
316 working = dict()
313 working = dict()
317 suffix = prefix = "_" # prevent keyword collisions with lambda
314 suffix = prefix = "_" # prevent keyword collisions with lambda
318 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
315 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
319 # if f.fun
316 # if f.fun
320 if hasattr(f, 'func_name'):
317 fname = getattr(f, '__name__', 'f')
321 fname = f.func_name
322 else:
323 fname = f.__name__
324
318
325 fname = prefix+fname.strip('<>')+suffix
319 fname = prefix+fname.strip('<>')+suffix
326 argname = prefix+"args"+suffix
320 argname = prefix+"args"+suffix
327 kwargname = prefix+"kwargs"+suffix
321 kwargname = prefix+"kwargs"+suffix
328 resultname = prefix+"result"+suffix
322 resultname = prefix+"result"+suffix
329
323
330 ns = { fname : f, argname : args, kwargname : kwargs }
324 ns = { fname : f, argname : args, kwargname : kwargs }
331 # print ns
325 # print ns
332 working.update(ns)
326 working.update(ns)
333 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
327 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
334 exec code in working, working
328 exec code in working, working
335 result = working.get(resultname)
329 result = working.get(resultname)
336 # clear the namespace
330 # clear the namespace
337 if bound:
331 if bound:
338 for key in ns.iterkeys():
332 for key in ns.iterkeys():
339 self.user_ns.pop(key)
333 self.user_ns.pop(key)
340 else:
334 else:
341 del working
335 del working
342
336
343 packed_result,buf = serialize_object(result)
337 packed_result,buf = serialize_object(result)
344 result_buf = [packed_result]+buf
338 result_buf = [packed_result]+buf
345 except:
339 except:
346 exc_content = self._wrap_exception('apply')
340 exc_content = self._wrap_exception('apply')
347 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
341 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
348 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
342 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
349 ident='%s.pyerr'%self.prefix)
343 ident='%s.pyerr'%self.prefix)
350 reply_content = exc_content
344 reply_content = exc_content
351 result_buf = []
345 result_buf = []
352
346
353 if exc_content['ename'] == UnmetDependency.__name__:
347 if exc_content['ename'] == 'UnmetDependency':
354 sub['dependencies_met'] = False
348 sub['dependencies_met'] = False
355 else:
349 else:
356 reply_content = {'status' : 'ok'}
350 reply_content = {'status' : 'ok'}
357 # reply_msg = self.session.msg(u'execute_reply', reply_content, parent)
351
358 # self.reply_socket.send(ident, zmq.SNDMORE)
352 # put 'ok'/'error' status in header, for scheduler introspection:
359 # self.reply_socket.send_json(reply_msg)
353 sub['status'] = reply_content['status']
354
360 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
355 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
361 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
356 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
362 # print(Message(reply_msg), file=sys.__stdout__)
357
363 # if reply_msg['content']['status'] == u'error':
358 # if reply_msg['content']['status'] == u'error':
364 # self.abort_queues()
359 # self.abort_queues()
365
360
366 def dispatch_queue(self, stream, msg):
361 def dispatch_queue(self, stream, msg):
367 self.control_stream.flush()
362 self.control_stream.flush()
368 idents,msg = self.session.feed_identities(msg, copy=False)
363 idents,msg = self.session.feed_identities(msg, copy=False)
369 try:
364 try:
370 msg = self.session.unpack_message(msg, content=True, copy=False)
365 msg = self.session.unpack_message(msg, content=True, copy=False)
371 except:
366 except:
372 logging.error("Invalid Message", exc_info=True)
367 logging.error("Invalid Message", exc_info=True)
373 return
368 return
374
369
375
370
376 header = msg['header']
371 header = msg['header']
377 msg_id = header['msg_id']
372 msg_id = header['msg_id']
378 if self.check_aborted(msg_id):
373 if self.check_aborted(msg_id):
379 self.aborted.remove(msg_id)
374 self.aborted.remove(msg_id)
380 # is it safe to assume a msg_id will not be resubmitted?
375 # is it safe to assume a msg_id will not be resubmitted?
381 reply_type = msg['msg_type'].split('_')[0] + '_reply'
376 reply_type = msg['msg_type'].split('_')[0] + '_reply'
382 reply_msg = self.session.send(stream, reply_type,
377 reply_msg = self.session.send(stream, reply_type,
383 content={'status' : 'aborted'}, parent=msg, ident=idents)
378 content={'status' : 'aborted'}, parent=msg, ident=idents)
384 return
379 return
385 handler = self.shell_handlers.get(msg['msg_type'], None)
380 handler = self.shell_handlers.get(msg['msg_type'], None)
386 if handler is None:
381 if handler is None:
387 logging.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
382 logging.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
388 else:
383 else:
389 handler(stream, idents, msg)
384 handler(stream, idents, msg)
390
385
391 def start(self):
386 def start(self):
392 #### stream mode:
387 #### stream mode:
393 if self.control_stream:
388 if self.control_stream:
394 self.control_stream.on_recv(self.dispatch_control, copy=False)
389 self.control_stream.on_recv(self.dispatch_control, copy=False)
395 self.control_stream.on_err(printer)
390 self.control_stream.on_err(printer)
396
391
397 def make_dispatcher(stream):
392 def make_dispatcher(stream):
398 def dispatcher(msg):
393 def dispatcher(msg):
399 return self.dispatch_queue(stream, msg)
394 return self.dispatch_queue(stream, msg)
400 return dispatcher
395 return dispatcher
401
396
402 for s in self.shell_streams:
397 for s in self.shell_streams:
403 # s.on_recv(printer)
404 s.on_recv(make_dispatcher(s), copy=False)
398 s.on_recv(make_dispatcher(s), copy=False)
405 # s.on_err(printer)
399 s.on_err(printer)
406
400
407 if self.iopub_stream:
401 if self.iopub_stream:
408 self.iopub_stream.on_err(printer)
402 self.iopub_stream.on_err(printer)
409 # self.iopub_stream.on_send(printer)
410
403
411 #### while True mode:
404 #### while True mode:
412 # while True:
405 # while True:
413 # idle = True
406 # idle = True
414 # try:
407 # try:
415 # msg = self.shell_stream.socket.recv_multipart(
408 # msg = self.shell_stream.socket.recv_multipart(
416 # zmq.NOBLOCK, copy=False)
409 # zmq.NOBLOCK, copy=False)
417 # except zmq.ZMQError, e:
410 # except zmq.ZMQError, e:
418 # if e.errno != zmq.EAGAIN:
411 # if e.errno != zmq.EAGAIN:
419 # raise e
412 # raise e
420 # else:
413 # else:
421 # idle=False
414 # idle=False
422 # self.dispatch_queue(self.shell_stream, msg)
415 # self.dispatch_queue(self.shell_stream, msg)
423 #
416 #
424 # if not self.task_stream.empty():
417 # if not self.task_stream.empty():
425 # idle=False
418 # idle=False
426 # msg = self.task_stream.recv_multipart()
419 # msg = self.task_stream.recv_multipart()
427 # self.dispatch_queue(self.task_stream, msg)
420 # self.dispatch_queue(self.task_stream, msg)
428 # if idle:
421 # if idle:
429 # # don't busywait
422 # # don't busywait
430 # time.sleep(1e-3)
423 # time.sleep(1e-3)
431
424
432 def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
425 def make_kernel(int_id, identity, control_addr, shell_addrs, iopub_addr, hb_addrs,
433 client_addr=None, loop=None, context=None, key=None,
426 client_addr=None, loop=None, context=None, key=None,
434 out_stream_factory=OutStream, display_hook_factory=DisplayHook):
427 out_stream_factory=OutStream, display_hook_factory=DisplayHook):
435 """NO LONGER IN USE"""
428 """NO LONGER IN USE"""
436 # create loop, context, and session:
429 # create loop, context, and session:
437 if loop is None:
430 if loop is None:
438 loop = ioloop.IOLoop.instance()
431 loop = ioloop.IOLoop.instance()
439 if context is None:
432 if context is None:
440 context = zmq.Context()
433 context = zmq.Context()
441 c = context
434 c = context
442 session = StreamSession(key=key)
435 session = StreamSession(key=key)
443 # print (session.key)
436 # print (session.key)
444 # print (control_addr, shell_addrs, iopub_addr, hb_addrs)
437 # print (control_addr, shell_addrs, iopub_addr, hb_addrs)
445
438
446 # create Control Stream
439 # create Control Stream
447 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
440 control_stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
448 control_stream.setsockopt(zmq.IDENTITY, identity)
441 control_stream.setsockopt(zmq.IDENTITY, identity)
449 control_stream.connect(control_addr)
442 control_stream.connect(control_addr)
450
443
451 # create Shell Streams (MUX, Task, etc.):
444 # create Shell Streams (MUX, Task, etc.):
452 shell_streams = []
445 shell_streams = []
453 for addr in shell_addrs:
446 for addr in shell_addrs:
454 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
447 stream = zmqstream.ZMQStream(c.socket(zmq.PAIR), loop)
455 stream.setsockopt(zmq.IDENTITY, identity)
448 stream.setsockopt(zmq.IDENTITY, identity)
456 stream.connect(addr)
449 stream.connect(addr)
457 shell_streams.append(stream)
450 shell_streams.append(stream)
458
451
459 # create iopub stream:
452 # create iopub stream:
460 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
453 iopub_stream = zmqstream.ZMQStream(c.socket(zmq.PUB), loop)
461 iopub_stream.setsockopt(zmq.IDENTITY, identity)
454 iopub_stream.setsockopt(zmq.IDENTITY, identity)
462 iopub_stream.connect(iopub_addr)
455 iopub_stream.connect(iopub_addr)
463
456
464 # Redirect input streams and set a display hook.
457 # Redirect input streams and set a display hook.
465 if out_stream_factory:
458 if out_stream_factory:
466 sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
459 sys.stdout = out_stream_factory(session, iopub_stream, u'stdout')
467 sys.stdout.topic = 'engine.%i.stdout'%int_id
460 sys.stdout.topic = 'engine.%i.stdout'%int_id
468 sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
461 sys.stderr = out_stream_factory(session, iopub_stream, u'stderr')
469 sys.stderr.topic = 'engine.%i.stderr'%int_id
462 sys.stderr.topic = 'engine.%i.stderr'%int_id
470 if display_hook_factory:
463 if display_hook_factory:
471 sys.displayhook = display_hook_factory(session, iopub_stream)
464 sys.displayhook = display_hook_factory(session, iopub_stream)
472 sys.displayhook.topic = 'engine.%i.pyout'%int_id
465 sys.displayhook.topic = 'engine.%i.pyout'%int_id
473
466
474
467
475 # launch heartbeat
468 # launch heartbeat
476 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
469 heart = heartmonitor.Heart(*map(str, hb_addrs), heart_id=identity)
477 heart.start()
470 heart.start()
478
471
479 # create (optional) Client
472 # create (optional) Client
480 if client_addr:
473 if client_addr:
481 client = Client(client_addr, username=identity)
474 client = Client(client_addr, username=identity)
482 else:
475 else:
483 client = None
476 client = None
484
477
485 kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
478 kernel = Kernel(id=int_id, session=session, control_stream=control_stream,
486 shell_streams=shell_streams, iopub_stream=iopub_stream,
479 shell_streams=shell_streams, iopub_stream=iopub_stream,
487 client=client, loop=loop)
480 client=client, loop=loop)
488 kernel.start()
481 kernel.start()
489 return loop, c, kernel
482 return loop, c, kernel
490
483
@@ -1,549 +1,549 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """edited session.py to work with streams, and move msg_type to the header
2 """edited session.py to work with streams, and move msg_type to the header
3 """
3 """
4
4
5
5
6 import os
6 import os
7 import sys
7 import sys
8 import traceback
8 import traceback
9 import pprint
9 import pprint
10 import uuid
10 import uuid
11 from datetime import datetime
11 from datetime import datetime
12
12
13 import zmq
13 import zmq
14 from zmq.utils import jsonapi
14 from zmq.utils import jsonapi
15 from zmq.eventloop.zmqstream import ZMQStream
15 from zmq.eventloop.zmqstream import ZMQStream
16
16
17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
17 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
18 from IPython.utils.newserialized import serialize, unserialize
18 from IPython.utils.newserialized import serialize, unserialize
19
19
20 from IPython.zmq.parallel.error import RemoteError
20 from IPython.zmq.parallel.error import RemoteError
21
21
22 try:
22 try:
23 import cPickle
23 import cPickle
24 pickle = cPickle
24 pickle = cPickle
25 except:
25 except:
26 cPickle = None
26 cPickle = None
27 import pickle
27 import pickle
28
28
29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
29 # packer priority: jsonlib[2], cPickle, simplejson/json, pickle
30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
30 json_name = '' if not jsonapi.jsonmod else jsonapi.jsonmod.__name__
31 if json_name in ('jsonlib', 'jsonlib2'):
31 if json_name in ('jsonlib', 'jsonlib2'):
32 use_json = True
32 use_json = True
33 elif json_name:
33 elif json_name:
34 if cPickle is None:
34 if cPickle is None:
35 use_json = True
35 use_json = True
36 else:
36 else:
37 use_json = False
37 use_json = False
38 else:
38 else:
39 use_json = False
39 use_json = False
40
40
41 def squash_unicode(obj):
41 def squash_unicode(obj):
42 if isinstance(obj,dict):
42 if isinstance(obj,dict):
43 for key in obj.keys():
43 for key in obj.keys():
44 obj[key] = squash_unicode(obj[key])
44 obj[key] = squash_unicode(obj[key])
45 if isinstance(key, unicode):
45 if isinstance(key, unicode):
46 obj[squash_unicode(key)] = obj.pop(key)
46 obj[squash_unicode(key)] = obj.pop(key)
47 elif isinstance(obj, list):
47 elif isinstance(obj, list):
48 for i,v in enumerate(obj):
48 for i,v in enumerate(obj):
49 obj[i] = squash_unicode(v)
49 obj[i] = squash_unicode(v)
50 elif isinstance(obj, unicode):
50 elif isinstance(obj, unicode):
51 obj = obj.encode('utf8')
51 obj = obj.encode('utf8')
52 return obj
52 return obj
53
53
54 json_packer = jsonapi.dumps
54 json_packer = jsonapi.dumps
55 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
55 json_unpacker = lambda s: squash_unicode(jsonapi.loads(s))
56
56
57 pickle_packer = lambda o: pickle.dumps(o,-1)
57 pickle_packer = lambda o: pickle.dumps(o,-1)
58 pickle_unpacker = pickle.loads
58 pickle_unpacker = pickle.loads
59
59
60 if use_json:
60 if use_json:
61 default_packer = json_packer
61 default_packer = json_packer
62 default_unpacker = json_unpacker
62 default_unpacker = json_unpacker
63 else:
63 else:
64 default_packer = pickle_packer
64 default_packer = pickle_packer
65 default_unpacker = pickle_unpacker
65 default_unpacker = pickle_unpacker
66
66
67
67
68 DELIM="<IDS|MSG>"
68 DELIM="<IDS|MSG>"
69 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
69 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
70
70
71 def wrap_exception(engine_info={}):
71 def wrap_exception(engine_info={}):
72 etype, evalue, tb = sys.exc_info()
72 etype, evalue, tb = sys.exc_info()
73 stb = traceback.format_exception(etype, evalue, tb)
73 stb = traceback.format_exception(etype, evalue, tb)
74 exc_content = {
74 exc_content = {
75 'status' : 'error',
75 'status' : 'error',
76 'traceback' : stb,
76 'traceback' : stb,
77 'ename' : unicode(etype.__name__),
77 'ename' : unicode(etype.__name__),
78 'evalue' : unicode(evalue),
78 'evalue' : unicode(evalue),
79 'engine_info' : engine_info
79 'engine_info' : engine_info
80 }
80 }
81 return exc_content
81 return exc_content
82
82
83 def unwrap_exception(content):
83 def unwrap_exception(content):
84 err = RemoteError(content['ename'], content['evalue'],
84 err = RemoteError(content['ename'], content['evalue'],
85 ''.join(content['traceback']),
85 ''.join(content['traceback']),
86 content.get('engine_info', {}))
86 content.get('engine_info', {}))
87 return err
87 return err
88
88
89
89
90 class Message(object):
90 class Message(object):
91 """A simple message object that maps dict keys to attributes.
91 """A simple message object that maps dict keys to attributes.
92
92
93 A Message can be created from a dict and a dict from a Message instance
93 A Message can be created from a dict and a dict from a Message instance
94 simply by calling dict(msg_obj)."""
94 simply by calling dict(msg_obj)."""
95
95
96 def __init__(self, msg_dict):
96 def __init__(self, msg_dict):
97 dct = self.__dict__
97 dct = self.__dict__
98 for k, v in dict(msg_dict).iteritems():
98 for k, v in dict(msg_dict).iteritems():
99 if isinstance(v, dict):
99 if isinstance(v, dict):
100 v = Message(v)
100 v = Message(v)
101 dct[k] = v
101 dct[k] = v
102
102
103 # Having this iterator lets dict(msg_obj) work out of the box.
103 # Having this iterator lets dict(msg_obj) work out of the box.
104 def __iter__(self):
104 def __iter__(self):
105 return iter(self.__dict__.iteritems())
105 return iter(self.__dict__.iteritems())
106
106
107 def __repr__(self):
107 def __repr__(self):
108 return repr(self.__dict__)
108 return repr(self.__dict__)
109
109
110 def __str__(self):
110 def __str__(self):
111 return pprint.pformat(self.__dict__)
111 return pprint.pformat(self.__dict__)
112
112
113 def __contains__(self, k):
113 def __contains__(self, k):
114 return k in self.__dict__
114 return k in self.__dict__
115
115
116 def __getitem__(self, k):
116 def __getitem__(self, k):
117 return self.__dict__[k]
117 return self.__dict__[k]
118
118
119
119
120 def msg_header(msg_id, msg_type, username, session):
120 def msg_header(msg_id, msg_type, username, session):
121 date=datetime.now().strftime(ISO8601)
121 date=datetime.now().strftime(ISO8601)
122 return locals()
122 return locals()
123
123
124 def extract_header(msg_or_header):
124 def extract_header(msg_or_header):
125 """Given a message or header, return the header."""
125 """Given a message or header, return the header."""
126 if not msg_or_header:
126 if not msg_or_header:
127 return {}
127 return {}
128 try:
128 try:
129 # See if msg_or_header is the entire message.
129 # See if msg_or_header is the entire message.
130 h = msg_or_header['header']
130 h = msg_or_header['header']
131 except KeyError:
131 except KeyError:
132 try:
132 try:
133 # See if msg_or_header is just the header
133 # See if msg_or_header is just the header
134 h = msg_or_header['msg_id']
134 h = msg_or_header['msg_id']
135 except KeyError:
135 except KeyError:
136 raise
136 raise
137 else:
137 else:
138 h = msg_or_header
138 h = msg_or_header
139 if not isinstance(h, dict):
139 if not isinstance(h, dict):
140 h = dict(h)
140 h = dict(h)
141 return h
141 return h
142
142
143 def rekey(dikt):
143 def rekey(dikt):
144 """Rekey a dict that has been forced to use str keys where there should be
144 """Rekey a dict that has been forced to use str keys where there should be
145 ints by json. This belongs in the jsonutil added by fperez."""
145 ints by json. This belongs in the jsonutil added by fperez."""
146 for k in dikt.iterkeys():
146 for k in dikt.iterkeys():
147 if isinstance(k, str):
147 if isinstance(k, str):
148 ik=fk=None
148 ik=fk=None
149 try:
149 try:
150 ik = int(k)
150 ik = int(k)
151 except ValueError:
151 except ValueError:
152 try:
152 try:
153 fk = float(k)
153 fk = float(k)
154 except ValueError:
154 except ValueError:
155 continue
155 continue
156 if ik is not None:
156 if ik is not None:
157 nk = ik
157 nk = ik
158 else:
158 else:
159 nk = fk
159 nk = fk
160 if nk in dikt:
160 if nk in dikt:
161 raise KeyError("already have key %r"%nk)
161 raise KeyError("already have key %r"%nk)
162 dikt[nk] = dikt.pop(k)
162 dikt[nk] = dikt.pop(k)
163 return dikt
163 return dikt
164
164
165 def serialize_object(obj, threshold=64e-6):
165 def serialize_object(obj, threshold=64e-6):
166 """Serialize an object into a list of sendable buffers.
166 """Serialize an object into a list of sendable buffers.
167
167
168 Parameters
168 Parameters
169 ----------
169 ----------
170
170
171 obj : object
171 obj : object
172 The object to be serialized
172 The object to be serialized
173 threshold : float
173 threshold : float
174 The threshold for not double-pickling the content.
174 The threshold for not double-pickling the content.
175
175
176
176
177 Returns
177 Returns
178 -------
178 -------
179 ('pmd', [bufs]) :
179 ('pmd', [bufs]) :
180 where pmd is the pickled metadata wrapper,
180 where pmd is the pickled metadata wrapper,
181 bufs is a list of data buffers
181 bufs is a list of data buffers
182 """
182 """
183 databuffers = []
183 databuffers = []
184 if isinstance(obj, (list, tuple)):
184 if isinstance(obj, (list, tuple)):
185 clist = canSequence(obj)
185 clist = canSequence(obj)
186 slist = map(serialize, clist)
186 slist = map(serialize, clist)
187 for s in slist:
187 for s in slist:
188 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
188 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
189 databuffers.append(s.getData())
189 databuffers.append(s.getData())
190 s.data = None
190 s.data = None
191 return pickle.dumps(slist,-1), databuffers
191 return pickle.dumps(slist,-1), databuffers
192 elif isinstance(obj, dict):
192 elif isinstance(obj, dict):
193 sobj = {}
193 sobj = {}
194 for k in sorted(obj.iterkeys()):
194 for k in sorted(obj.iterkeys()):
195 s = serialize(can(obj[k]))
195 s = serialize(can(obj[k]))
196 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
196 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
197 databuffers.append(s.getData())
197 databuffers.append(s.getData())
198 s.data = None
198 s.data = None
199 sobj[k] = s
199 sobj[k] = s
200 return pickle.dumps(sobj,-1),databuffers
200 return pickle.dumps(sobj,-1),databuffers
201 else:
201 else:
202 s = serialize(can(obj))
202 s = serialize(can(obj))
203 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
203 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
204 databuffers.append(s.getData())
204 databuffers.append(s.getData())
205 s.data = None
205 s.data = None
206 return pickle.dumps(s,-1),databuffers
206 return pickle.dumps(s,-1),databuffers
207
207
208
208
209 def unserialize_object(bufs):
209 def unserialize_object(bufs):
210 """reconstruct an object serialized by serialize_object from data buffers."""
210 """reconstruct an object serialized by serialize_object from data buffers."""
211 bufs = list(bufs)
211 bufs = list(bufs)
212 sobj = pickle.loads(bufs.pop(0))
212 sobj = pickle.loads(bufs.pop(0))
213 if isinstance(sobj, (list, tuple)):
213 if isinstance(sobj, (list, tuple)):
214 for s in sobj:
214 for s in sobj:
215 if s.data is None:
215 if s.data is None:
216 s.data = bufs.pop(0)
216 s.data = bufs.pop(0)
217 return uncanSequence(map(unserialize, sobj)), bufs
217 return uncanSequence(map(unserialize, sobj)), bufs
218 elif isinstance(sobj, dict):
218 elif isinstance(sobj, dict):
219 newobj = {}
219 newobj = {}
220 for k in sorted(sobj.iterkeys()):
220 for k in sorted(sobj.iterkeys()):
221 s = sobj[k]
221 s = sobj[k]
222 if s.data is None:
222 if s.data is None:
223 s.data = bufs.pop(0)
223 s.data = bufs.pop(0)
224 newobj[k] = uncan(unserialize(s))
224 newobj[k] = uncan(unserialize(s))
225 return newobj, bufs
225 return newobj, bufs
226 else:
226 else:
227 if sobj.data is None:
227 if sobj.data is None:
228 sobj.data = bufs.pop(0)
228 sobj.data = bufs.pop(0)
229 return uncan(unserialize(sobj)), bufs
229 return uncan(unserialize(sobj)), bufs
230
230
231 def pack_apply_message(f, args, kwargs, threshold=64e-6):
231 def pack_apply_message(f, args, kwargs, threshold=64e-6):
232 """pack up a function, args, and kwargs to be sent over the wire
232 """pack up a function, args, and kwargs to be sent over the wire
233 as a series of buffers. Any object whose data is larger than `threshold`
233 as a series of buffers. Any object whose data is larger than `threshold`
234 will not have their data copied (currently only numpy arrays support zero-copy)"""
234 will not have their data copied (currently only numpy arrays support zero-copy)"""
235 msg = [pickle.dumps(can(f),-1)]
235 msg = [pickle.dumps(can(f),-1)]
236 databuffers = [] # for large objects
236 databuffers = [] # for large objects
237 sargs, bufs = serialize_object(args,threshold)
237 sargs, bufs = serialize_object(args,threshold)
238 msg.append(sargs)
238 msg.append(sargs)
239 databuffers.extend(bufs)
239 databuffers.extend(bufs)
240 skwargs, bufs = serialize_object(kwargs,threshold)
240 skwargs, bufs = serialize_object(kwargs,threshold)
241 msg.append(skwargs)
241 msg.append(skwargs)
242 databuffers.extend(bufs)
242 databuffers.extend(bufs)
243 msg.extend(databuffers)
243 msg.extend(databuffers)
244 return msg
244 return msg
245
245
def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs) # allow us to pop
    # at minimum: canned f, serialized args, serialized kwargs
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # non-copying recv yields zmq Message frames; extract raw bytes
        # for the three pickled header frames before unpickling
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    # print sargs, skwargs
    f = uncan(cf, g)
    # reattach out-of-band data buffers to the serialized args, in order;
    # an entry with data is None had its payload stripped by the packer
    for sa in sargs:
        if sa.data is None:
            m = bufs.pop(0)
            if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
                # zero-copy types want a buffer view rather than a bytes copy
                if copy:
                    sa.data = buffer(m)
                else:
                    sa.data = m.buffer
            else:
                if copy:
                    sa.data = m
                else:
                    sa.data = m.bytes

    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # kwargs buffers were emitted in sorted-key order by the packer
    # (see serialize_object's sorted(iterkeys) walk); consume them the same way
    for k in sorted(skwargs.iterkeys()):
        sa = skwargs[k]
        if sa.data is None:
            sa.data = bufs.pop(0)
        kwargs[k] = uncan(unserialize(sa), g)

    return f,args,kwargs
282
282
class StreamSession(object):
    """tweaked version of IPython.zmq.session.Session, for development in Parallel"""
    debug = False
    key = None

    def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None):
        """Set up identity, serialization callables, and the optional auth key.

        Parameters
        ----------
        username : str, defaults to $USER
        session : str, a uuid identifying this session (generated if omitted)
        packer : callable serializing a message dict to bytes
        unpacker : callable inverting `packer`
        key : str, shared authentication key prepended to outgoing messages
        keyfile : str, path to a file holding the key (mutually exclusive
            with `key`)
        """
        if username is None:
            username = os.environ.get('USER', 'username')
        self.username = username
        if session is None:
            self.session = str(uuid.uuid4())
        else:
            self.session = session
        self.msg_id = str(uuid.uuid4())
        if packer is None:
            self.pack = default_packer
        else:
            if not callable(packer):
                raise TypeError("packer must be callable, not %s" % type(packer))
            self.pack = packer

        if unpacker is None:
            self.unpack = default_unpacker
        else:
            if not callable(unpacker):
                raise TypeError("unpacker must be callable, not %s" % type(unpacker))
            self.unpack = unpacker

        if key is not None and keyfile is not None:
            raise TypeError("Must specify key OR keyfile, not both")
        if keyfile is not None:
            with open(keyfile) as f:
                self.key = f.read().strip()
        else:
            self.key = key
        # pre-packed empty content dict, reused when content is None
        self.none = self.pack({})

    def msg_header(self, msg_type):
        """Build a header dict, rotating self.msg_id for the next message."""
        h = msg_header(self.msg_id, msg_type, self.username, self.session)
        self.msg_id = str(uuid.uuid4())
        return h

    def msg(self, msg_type, content=None, parent=None, subheader=None):
        """Construct a complete message dict (header, parent_header, content)."""
        msg = {}
        msg['header'] = self.msg_header(msg_type)
        msg['msg_id'] = msg['header']['msg_id']
        msg['parent_header'] = {} if parent is None else extract_header(parent)
        msg['msg_type'] = msg_type
        msg['content'] = {} if content is None else content
        sub = {} if subheader is None else subheader
        msg['header'].update(sub)
        return msg

    def check_key(self, msg_or_header):
        """Check that a message's header has the right key"""
        if self.key is None:
            # no key configured: everything passes
            return True
        header = extract_header(msg_or_header)
        return header.get('key', None) == self.key

    def send(self, stream, msg_or_type, content=None, buffers=None, parent=None, subheader=None, ident=None):
        """Build and send a message via stream or socket.

        Parameters
        ----------

        stream : zmq.Socket or ZMQStream
            the socket-like object used to send the data
        msg_or_type : str or Message/dict
            Normally, msg_or_type will be a msg_type unless a message is being sent more
            than once.

        Returns
        -------
        msg : dict
            the constructed (or passed-through) message dict
        """
        if isinstance(msg_or_type, (Message, dict)):
            # we got a Message, not a msg_type
            # don't build a new Message
            msg = msg_or_type
            content = msg['content']
        else:
            msg = self.msg(msg_or_type, content, parent, subheader)
        buffers = [] if buffers is None else buffers
        to_send = []
        if isinstance(ident, list):
            # accept list of idents
            to_send.extend(ident)
        elif ident is not None:
            to_send.append(ident)
        to_send.append(DELIM)
        if self.key is not None:
            to_send.append(self.key)
        to_send.append(self.pack(msg['header']))
        to_send.append(self.pack(msg['parent_header']))

        if content is None:
            content = self.none
        elif isinstance(content, dict):
            content = self.pack(content)
        elif isinstance(content, str):
            # content is already packed, as in a relayed message
            pass
        else:
            raise TypeError("Content incorrect type: %s" % type(content))
        to_send.append(content)
        flag = 0
        if buffers:
            # more frames (the data buffers) follow the main multipart message
            flag = zmq.SNDMORE
        stream.send_multipart(to_send, flag, copy=False)
        for b in buffers[:-1]:
            stream.send(b, flag, copy=False)
        if buffers:
            # last buffer closes the multipart message (no SNDMORE)
            stream.send(buffers[-1], copy=False)
        if self.debug:
            pprint.pprint(msg)
            pprint.pprint(to_send)
            pprint.pprint(buffers)
        return msg

    def send_raw(self, stream, msg, flags=0, copy=True, ident=None):
        """Send a raw message via ident path.

        Parameters
        ----------
        msg : list of sendable buffers"""
        to_send = []
        if isinstance(ident, str):
            ident = [ident]
        if ident is not None:
            to_send.extend(ident)
        to_send.append(DELIM)
        if self.key is not None:
            to_send.append(self.key)
        to_send.extend(msg)
        # BUG FIX: previously sent the bare `msg`, silently dropping the
        # ident prefix, DELIM, and key just assembled in to_send.
        stream.send_multipart(to_send, flags, copy=copy)

    def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
        """receives and unpacks a message
        returns [idents], msg"""
        if isinstance(socket, ZMQStream):
            socket = socket.socket
        try:
            msg = socket.recv_multipart(mode)
        except zmq.ZMQError as e:
            if e.errno == zmq.EAGAIN:
                # We can convert EAGAIN to None as we know in this case
                # recv_json won't return None.
                return None
            else:
                raise
        # determine the number of idents by trying to unpack them.
        # this is terrible:
        idents, msg = self.feed_identities(msg, copy)
        try:
            return idents, self.unpack_message(msg, content=content, copy=copy)
        except Exception as e:
            print (idents, msg)
            # TODO: handle it
            raise e

    def feed_identities(self, msg, copy=True):
        """feed until DELIM is reached, then return the prefix as idents and remainder as
        msg. This is easily broken by setting an IDENT to DELIM, but that would be silly.

        Parameters
        ----------
        msg : a list of Message or bytes objects
            the message to be split
        copy : bool
            flag determining whether the arguments are bytes or Messages

        Returns
        -------
        (idents,msg) : two lists
            idents will always be a list of bytes - the identity prefix
            msg will be a list of bytes or Messages, unchanged from input
            msg should be unpackable via self.unpack_message at this point.
        """
        msg = list(msg)
        idents = []
        # header/parent/content always remain, so stop scanning at 3 frames
        while len(msg) > 3:
            if copy:
                s = msg[0]
            else:
                s = msg[0].bytes
            if s == DELIM:
                msg.pop(0)
                break
            else:
                idents.append(s)
                msg.pop(0)

        return idents, msg

    def unpack_message(self, msg, content=True, copy=True):
        """Return a message object from the format
        sent by self.send.

        Parameters:
        -----------

        content : bool (True)
            whether to unpack the content dict (True),
            or leave it serialized (False)

        copy : bool (True)
            whether to return the bytes (True),
            or the non-copying Message object in each place (False)

        """
        # if a key is configured, frame 0 is the key and everything shifts by 1
        ikey = int(self.key is not None)
        minlen = 3 + ikey
        if not len(msg) >= minlen:
            raise TypeError("malformed message, must have at least %i elements" % minlen)
        message = {}
        if not copy:
            for i in range(minlen):
                msg[i] = msg[i].bytes
        if ikey:
            if not self.key == msg[0]:
                raise KeyError("Invalid Session Key: %s" % msg[0])
        message['header'] = self.unpack(msg[ikey + 0])
        message['msg_type'] = message['header']['msg_type']
        message['parent_header'] = self.unpack(msg[ikey + 1])
        if content:
            message['content'] = self.unpack(msg[ikey + 2])
        else:
            message['content'] = msg[ikey + 2]

        # remaining frames are opaque data buffers, left as-is
        message['buffers'] = msg[ikey + 3:]
        return message
532
532
533
533
534
534
def test_msg2obj():
    """Sanity-check attribute/item access parity between a dict and Message."""
    source = dict(x=1)
    wrapped = Message(source)
    assert wrapped.x == source['x']

    # nested dicts should wrap recursively for attribute access
    source['y'] = dict(z=1)
    wrapped = Message(source)
    assert wrapped.y.z == source['y']['z']

    # item access mirrors the underlying dict
    outer, inner = 'y', 'z'
    assert wrapped[outer][inner] == source[outer][inner]

    # converting back to a plain dict preserves the values
    round_tripped = dict(wrapped)
    assert source['x'] == round_tripped['x']
    assert source['y']['z'] == round_tripped['y']['z']
General Comments 0
You need to be logged in to leave comments. Login now