add '-s' for startup script in ipengine...
MinRK
@@ -1,295 +1,303 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5 """
6 6
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2008-2009 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import json
19 19 import os
20 20 import sys
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24
25 25 from .clusterdir import (
26 26 ApplicationWithClusterDir,
27 27 ClusterDirConfigLoader
28 28 )
29 29 from IPython.zmq.log import EnginePUBHandler
30 30
31 31 from IPython.parallel import factory
32 32 from IPython.parallel.engine.engine import EngineFactory
33 33 from IPython.parallel.engine.streamkernel import Kernel
34 34 from IPython.parallel.util import disambiguate_url
35 35
36 36 from IPython.utils.importstring import import_item
37 37
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Module level variables
41 41 #-----------------------------------------------------------------------------
42 42
43 43 #: The default config file name for this application
44 44 default_config_file_name = u'ipengine_config.py'
45 45
46 46
47 47 mpi4py_init = """from mpi4py import MPI as mpi
48 48 mpi.size = mpi.COMM_WORLD.Get_size()
49 49 mpi.rank = mpi.COMM_WORLD.Get_rank()
50 50 """
51 51
52 52
53 53 pytrilinos_init = """from PyTrilinos import Epetra
54 54 class SimpleStruct:
55 55 pass
56 56 mpi = SimpleStruct()
57 57 mpi.rank = 0
58 58 mpi.size = 0
59 59 """
60 60
61 61
62 62 _description = """Start an IPython engine for parallel computing.\n\n
63 63
64 64 IPython engines run in parallel and perform computations on behalf of a client
65 65 and controller. A controller needs to be started before the engines. The
66 66 engine can be configured using command line options or using a cluster
67 67 directory. Cluster directories contain config, log and security files and are
68 68 usually located in your ipython directory and named "cluster_<profile>".
69 69 See the --profile and --cluster-dir options for details.
70 70 """
71 71
72 72 #-----------------------------------------------------------------------------
73 73 # Command line options
74 74 #-----------------------------------------------------------------------------
75 75
76 76
77 77 class IPEngineAppConfigLoader(ClusterDirConfigLoader):
78 78
79 79 def _add_arguments(self):
80 80 super(IPEngineAppConfigLoader, self)._add_arguments()
81 81 paa = self.parser.add_argument
82 82 # Controller config
83 83 paa('--file', '-f',
84 84 type=unicode, dest='Global.url_file',
85 85 help='The full location of the file containing the connection information for the '
86 86 'controller. If this is not given, the file must be in the '
87 87 'security directory of the cluster directory. This location is '
88 88 'resolved using the --profile and --app-dir options.',
89 89 metavar='Global.url_file')
90 90 # MPI
91 91 paa('--mpi',
92 92 type=str, dest='MPI.use',
93 93 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).',
94 94 metavar='MPI.use')
95 95 # Global config
96 96 paa('--log-to-file',
97 97 action='store_true', dest='Global.log_to_file',
98 98 help='Log to a file in the log directory (default is stdout)')
99 99 paa('--log-url',
100 100 dest='Global.log_url',
101 101 help="url of ZMQ logger, as started with iploggerz")
102 102 # paa('--execkey',
103 103 # type=str, dest='Global.exec_key',
104 104 # help='path to a file containing an execution key.',
105 105 # metavar='keyfile')
106 106 # paa('--no-secure',
107 107 # action='store_false', dest='Global.secure',
108 108 # help='Turn off execution keys.')
109 109 # paa('--secure',
110 110 # action='store_true', dest='Global.secure',
111 111 # help='Turn on execution keys (default).')
112 112 # init command
113 113 paa('-c',
114 114 type=str, dest='Global.extra_exec_lines',
115 115 help='specify a command to be run at startup')
116 paa('-s',
117 type=unicode, dest='Global.extra_exec_file',
118 help='specify a script to be run at startup')
116 119
117 120 factory.add_session_arguments(self.parser)
118 121 factory.add_registration_arguments(self.parser)
119 122
120 123
121 124 #-----------------------------------------------------------------------------
122 125 # Main application
123 126 #-----------------------------------------------------------------------------
124 127
125 128
126 129 class IPEngineApp(ApplicationWithClusterDir):
127 130
128 131 name = u'ipengine'
129 132 description = _description
130 133 command_line_loader = IPEngineAppConfigLoader
131 134 default_config_file_name = default_config_file_name
132 135 auto_create_cluster_dir = True
133 136
134 137 def create_default_config(self):
135 138 super(IPEngineApp, self).create_default_config()
136 139
137 140 # The engine should not clean logs as we don't want to remove the
138 141 # active log files of other running engines.
139 142 self.default_config.Global.clean_logs = False
140 143 self.default_config.Global.secure = True
141 144
142 145 # Global config attributes
143 146 self.default_config.Global.exec_lines = []
144 147 self.default_config.Global.extra_exec_lines = ''
148 self.default_config.Global.extra_exec_file = u''
145 149
146 150 # Configuration related to the controller
147 151 # This must match the filename (path not included) that the controller
148 152 # used for the FURL file.
149 153 self.default_config.Global.url_file = u''
150 154 self.default_config.Global.url_file_name = u'ipcontroller-engine.json'
151 155 # If given, this is the actual location of the controller's FURL file.
152 156 # If not, this is computed using the profile, app_dir and furl_file_name
153 157 # self.default_config.Global.key_file_name = u'exec_key.key'
154 158 # self.default_config.Global.key_file = u''
155 159
156 160 # MPI related config attributes
157 161 self.default_config.MPI.use = ''
158 162 self.default_config.MPI.mpi4py = mpi4py_init
159 163 self.default_config.MPI.pytrilinos = pytrilinos_init
160 164
161 165 def post_load_command_line_config(self):
162 166 pass
163 167
164 168 def pre_construct(self):
165 169 super(IPEngineApp, self).pre_construct()
166 170 # self.find_cont_url_file()
167 171 self.find_url_file()
168 172 if self.master_config.Global.extra_exec_lines:
169 173 self.master_config.Global.exec_lines.append(self.master_config.Global.extra_exec_lines)
174 if self.master_config.Global.extra_exec_file:
175 enc = sys.getfilesystemencoding() or 'utf8'
176 cmd="execfile(%r)"%self.master_config.Global.extra_exec_file.encode(enc)
177 self.master_config.Global.exec_lines.append(cmd)
170 178
171 179 # def find_key_file(self):
172 180 # """Set the key file.
173 181 #
174 182 # Here we don't try to actually see if it exists or is valid, as that
175 183 # is handled by the connection logic.
176 184 # """
177 185 # config = self.master_config
178 186 # # Find the actual controller key file
179 187 # if not config.Global.key_file:
180 188 # try_this = os.path.join(
181 189 # config.Global.cluster_dir,
182 190 # config.Global.security_dir,
183 191 # config.Global.key_file_name
184 192 # )
185 193 # config.Global.key_file = try_this
186 194
187 195 def find_url_file(self):
188 196 """Set the key file.
189 197
190 198 Here we don't try to actually see if it exists for is valid as that
191 199 is hadled by the connection logic.
192 200 """
193 201 config = self.master_config
194 202 # Find the actual controller url file
195 203 if not config.Global.url_file:
196 204 try_this = os.path.join(
197 205 config.Global.cluster_dir,
198 206 config.Global.security_dir,
199 207 config.Global.url_file_name
200 208 )
201 209 config.Global.url_file = try_this
202 210
203 211 def construct(self):
204 212 # This is the working dir by now.
205 213 sys.path.insert(0, '')
206 214 config = self.master_config
207 215 # if os.path.exists(config.Global.key_file) and config.Global.secure:
208 216 # config.SessionFactory.exec_key = config.Global.key_file
209 217 if os.path.exists(config.Global.url_file):
210 218 with open(config.Global.url_file) as f:
211 219 d = json.loads(f.read())
212 220 for k,v in d.iteritems():
213 221 if isinstance(v, unicode):
214 222 d[k] = v.encode()
215 223 if d['exec_key']:
216 224 config.SessionFactory.exec_key = d['exec_key']
217 225 d['url'] = disambiguate_url(d['url'], d['location'])
218 226 config.RegistrationFactory.url=d['url']
219 227 config.EngineFactory.location = d['location']
220 228
221 229
222 230
223 231 config.Kernel.exec_lines = config.Global.exec_lines
224 232
225 233 self.start_mpi()
226 234
227 235 # Create the underlying shell class and EngineService
228 236 # shell_class = import_item(self.master_config.Global.shell_class)
229 237 try:
230 238 self.engine = EngineFactory(config=config, logname=self.log.name)
231 239 except:
232 240 self.log.error("Couldn't start the Engine", exc_info=True)
233 241 self.exit(1)
234 242
235 243 self.start_logging()
236 244
237 245 # Create the service hierarchy
238 246 # self.main_service = service.MultiService()
239 247 # self.engine_service.setServiceParent(self.main_service)
240 248 # self.tub_service = Tub()
241 249 # self.tub_service.setServiceParent(self.main_service)
242 250 # # This needs to be called before the connection is initiated
243 251 # self.main_service.startService()
244 252
245 253 # This initiates the connection to the controller and calls
246 254 # register_engine to tell the controller we are ready to do work
247 255 # self.engine_connector = EngineConnector(self.tub_service)
248 256
249 257 # self.log.info("Using furl file: %s" % self.master_config.Global.furl_file)
250 258
251 259 # reactor.callWhenRunning(self.call_connect)
252 260
253 261
254 262 def start_logging(self):
255 263 super(IPEngineApp, self).start_logging()
256 264 if self.master_config.Global.log_url:
257 265 context = self.engine.context
258 266 lsock = context.socket(zmq.PUB)
259 267 lsock.connect(self.master_config.Global.log_url)
260 268 handler = EnginePUBHandler(self.engine, lsock)
261 269 handler.setLevel(self.log_level)
262 270 self.log.addHandler(handler)
263 271
264 272 def start_mpi(self):
265 273 global mpi
266 274 mpikey = self.master_config.MPI.use
267 275 mpi_import_statement = self.master_config.MPI.get(mpikey, None)
268 276 if mpi_import_statement is not None:
269 277 try:
270 278 self.log.info("Initializing MPI:")
271 279 self.log.info(mpi_import_statement)
272 280 exec mpi_import_statement in globals()
273 281 except:
274 282 mpi = None
275 283 else:
276 284 mpi = None
277 285
278 286
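# --- annotation (not part of the diff): a config sketch for start_mpi().
# Setting, e.g. in ipengine_config.py or via --mpi on the command line:
#     c.MPI.use = 'mpi4py'
# makes start_mpi() look up master_config.MPI['mpi4py'] and exec the
# mpi4py_init snippet defined at the top of this file, binding a module-level
# `mpi` with .rank and .size; an empty or unknown key leaves mpi = None.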
279 287 def start_app(self):
280 288 self.engine.start()
281 289 try:
282 290 self.engine.loop.start()
283 291 except KeyboardInterrupt:
284 292 self.log.critical("Engine Interrupted, shutting down...\n")
285 293
286 294
287 295 def launch_new_instance():
288 296 """Create and run the IPython controller"""
289 297 app = IPEngineApp()
290 298 app.start()
291 299
292 300
293 301 if __name__ == '__main__':
294 302 launch_new_instance()
295 303
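For context, a minimal sketch of what the new '-s' option does (the script name
below is hypothetical): `ipengine -s startup.py` stores the path in
Global.extra_exec_file, and pre_construct() folds it into the startup exec
lines roughly like this:

    config.Global.extra_exec_file = u'startup.py'   # set by the new '-s' argument
    enc = sys.getfilesystemencoding() or 'utf8'
    cmd = "execfile(%r)" % config.Global.extra_exec_file.encode(enc)
    config.Global.exec_lines.append(cmd)            # executed by the Kernel at startup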
@@ -1,1279 +1,1281 @@
1 1 """A semi-synchronous Client for the ZMQ cluster"""
2 2 #-----------------------------------------------------------------------------
3 3 # Copyright (C) 2010 The IPython Development Team
4 4 #
5 5 # Distributed under the terms of the BSD License. The full license is in
6 6 # the file COPYING, distributed as part of this software.
7 7 #-----------------------------------------------------------------------------
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Imports
11 11 #-----------------------------------------------------------------------------
12 12
13 13 import os
14 14 import json
15 15 import time
16 16 import warnings
17 17 from datetime import datetime
18 18 from getpass import getpass
19 19 from pprint import pprint
20 20
21 21 pjoin = os.path.join
22 22
23 23 import zmq
24 24 # from zmq.eventloop import ioloop, zmqstream
25 25
26 26 from IPython.utils.path import get_ipython_dir
27 27 from IPython.utils.traitlets import (HasTraits, Int, Instance, CUnicode,
28 28 Dict, List, Bool, Str, Set)
29 29 from IPython.external.decorator import decorator
30 30 from IPython.external.ssh import tunnel
31 31
32 32 from IPython.parallel import error
33 33 from IPython.parallel import streamsession as ss
34 34 from IPython.parallel import util
35 35
36 36 from .asyncresult import AsyncResult, AsyncHubResult
37 37 from IPython.parallel.apps.clusterdir import ClusterDir, ClusterDirError
38 38 from .view import DirectView, LoadBalancedView
39 39
40 40 #--------------------------------------------------------------------------
41 41 # Decorators for Client methods
42 42 #--------------------------------------------------------------------------
43 43
44 44 @decorator
45 45 def spin_first(f, self, *args, **kwargs):
46 46 """Call spin() to sync state prior to calling the method."""
47 47 self.spin()
48 48 return f(self, *args, **kwargs)
49 49
50 50 @decorator
51 51 def default_block(f, self, *args, **kwargs):
52 52 """Default to self.block; preserve self.block."""
53 53 block = kwargs.get('block',None)
54 54 block = self.block if block is None else block
55 55 saveblock = self.block
56 56 self.block = block
57 57 try:
58 58 ret = f(self, *args, **kwargs)
59 59 finally:
60 60 self.block = saveblock
61 61 return ret
62 62
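# --- annotation (not part of the diff): the control methods below stack these
# as @spin_first @default_block. spin_first flushes pending messages before the
# call; default_block makes a block=None argument fall back to self.block for
# the duration of the call and then restores it, so e.g.
# client.clear(block=True) blocks once without changing client.block.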
63 63
64 64 #--------------------------------------------------------------------------
65 65 # Classes
66 66 #--------------------------------------------------------------------------
67 67
68 68 class Metadata(dict):
69 69 """Subclass of dict for initializing metadata values.
70 70
71 71 Attribute access works on keys.
72 72
73 73 These objects have a strict set of keys - errors will raise if you try
74 74 to add new keys.
75 75 """
76 76 def __init__(self, *args, **kwargs):
77 77 dict.__init__(self)
78 78 md = {'msg_id' : None,
79 79 'submitted' : None,
80 80 'started' : None,
81 81 'completed' : None,
82 82 'received' : None,
83 83 'engine_uuid' : None,
84 84 'engine_id' : None,
85 85 'follow' : None,
86 86 'after' : None,
87 87 'status' : None,
88 88
89 89 'pyin' : None,
90 90 'pyout' : None,
91 91 'pyerr' : None,
92 92 'stdout' : '',
93 93 'stderr' : '',
94 94 }
95 95 self.update(md)
96 96 self.update(dict(*args, **kwargs))
97 97
98 98 def __getattr__(self, key):
99 99 """getattr aliased to getitem"""
100 100 if key in self.iterkeys():
101 101 return self[key]
102 102 else:
103 103 raise AttributeError(key)
104 104
105 105 def __setattr__(self, key, value):
106 106 """setattr aliased to setitem, with strict"""
107 107 if key in self.iterkeys():
108 108 self[key] = value
109 109 else:
110 110 raise AttributeError(key)
111 111
112 112 def __setitem__(self, key, value):
113 113 """strict static key enforcement"""
114 114 if key in self.iterkeys():
115 115 dict.__setitem__(self, key, value)
116 116 else:
117 117 raise KeyError(key)
118 118
119 119
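# --- usage sketch (annotation, not part of the diff) ---
md = Metadata()
assert md.stdout == ''       # attribute access aliases item access
md['status'] = 'ok'          # fine: 'status' is one of the predefined keys
try:
    md['foo'] = 1            # strict key set: unknown keys raise
except KeyError:
    pass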
120 120 class Client(HasTraits):
121 121 """A semi-synchronous client to the IPython ZMQ cluster
122 122
123 123 Parameters
124 124 ----------
125 125
126 126 url_or_file : bytes; zmq url or path to ipcontroller-client.json
127 127 Connection information for the Hub's registration. If a json connector
128 128 file is given, then likely no further configuration is necessary.
129 129 [Default: use profile]
130 130 profile : bytes
131 131 The name of the Cluster profile to be used to find connector information.
132 132 [Default: 'default']
133 133 context : zmq.Context
134 134 Pass an existing zmq.Context instance, otherwise the client will create its own.
135 135 username : bytes
136 136 set username to be passed to the Session object
137 137 debug : bool
138 138 flag for lots of message printing for debug purposes
139 139
140 140 #-------------- ssh related args ----------------
141 141 # These are args for configuring the ssh tunnel to be used
142 142 # credentials are used to forward connections over ssh to the Controller
143 143 # Note that the ip given in `addr` needs to be relative to sshserver
144 144 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
145 145 # and set sshserver as the same machine the Controller is on. However,
146 146 # the only requirement is that sshserver is able to see the Controller
147 147 # (i.e. is within the same trusted network).
148 148
149 149 sshserver : str
150 150 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
151 151 If keyfile or password is specified, and this is not, it will default to
152 152 the ip given in addr.
153 153 sshkey : str; path to public ssh key file
154 154 This specifies a key to be used in ssh login, default None.
155 155 Regular default ssh keys will be used without specifying this argument.
156 156 password : str
157 157 Your ssh password to sshserver. Note that if this is left None,
158 158 you will be prompted for it if passwordless key based login is unavailable.
159 159 paramiko : bool
160 160 flag for whether to use paramiko instead of shell ssh for tunneling.
161 161 [default: True on win32, False else]
162 162
163 163 ------- exec authentication args -------
164 164 If even localhost is untrusted, you can have some protection against
165 165 unauthorized execution by using a key. Messages are still sent
166 166 as cleartext, so if someone can snoop your loopback traffic this will
167 167 not help against malicious attacks.
168 168
169 169 exec_key : str
170 170 an authentication key or file containing a key
171 171 default: None
172 172
173 173
174 174 Attributes
175 175 ----------
176 176
177 177 ids : list of int engine IDs
178 178 requesting the ids attribute always synchronizes
179 179 the registration state. To request ids without synchronization,
180 180 use the semi-private _ids attribute.
181 181
182 182 history : list of msg_ids
183 183 a list of msg_ids, keeping track of all the execution
184 184 messages you have submitted in order.
185 185
186 186 outstanding : set of msg_ids
187 187 a set of msg_ids that have been submitted, but whose
188 188 results have not yet been received.
189 189
190 190 results : dict
191 191 a dict of all our results, keyed by msg_id
192 192
193 193 block : bool
194 194 determines default behavior when block not specified
195 195 in execution methods
196 196
197 197 Methods
198 198 -------
199 199
200 200 spin
201 201 flushes incoming results and registration state changes
202 202 control methods spin, and requesting `ids` also ensures the state is up to date
203 203
204 204 wait
205 205 wait on one or more msg_ids
206 206
207 207 execution methods
208 208 apply
209 209 legacy: execute, run
210 210
211 211 data movement
212 212 push, pull, scatter, gather
213 213
214 214 query methods
215 215 queue_status, get_result, purge, result_status
216 216
217 217 control methods
218 218 abort, shutdown
219 219
220 220 """
221 221
222 222
223 223 block = Bool(False)
224 224 outstanding = Set()
225 225 results = Instance('collections.defaultdict', (dict,))
226 226 metadata = Instance('collections.defaultdict', (Metadata,))
227 227 history = List()
228 228 debug = Bool(False)
229 229 profile=CUnicode('default')
230 230
231 231 _outstanding_dict = Instance('collections.defaultdict', (set,))
232 232 _ids = List()
233 233 _connected=Bool(False)
234 234 _ssh=Bool(False)
235 235 _context = Instance('zmq.Context')
236 236 _config = Dict()
237 237 _engines=Instance(util.ReverseDict, (), {})
238 238 # _hub_socket=Instance('zmq.Socket')
239 239 _query_socket=Instance('zmq.Socket')
240 240 _control_socket=Instance('zmq.Socket')
241 241 _iopub_socket=Instance('zmq.Socket')
242 242 _notification_socket=Instance('zmq.Socket')
243 243 _mux_socket=Instance('zmq.Socket')
244 244 _task_socket=Instance('zmq.Socket')
245 245 _task_scheme=Str()
246 246 _closed = False
247 247 _ignored_control_replies=Int(0)
248 248 _ignored_hub_replies=Int(0)
249 249
250 250 def __init__(self, url_or_file=None, profile='default', cluster_dir=None, ipython_dir=None,
251 251 context=None, username=None, debug=False, exec_key=None,
252 252 sshserver=None, sshkey=None, password=None, paramiko=None,
253 253 timeout=10
254 254 ):
255 255 super(Client, self).__init__(debug=debug, profile=profile)
256 256 if context is None:
257 257 context = zmq.Context.instance()
258 258 self._context = context
259 259
260 260
261 261 self._setup_cluster_dir(profile, cluster_dir, ipython_dir)
262 262 if self._cd is not None:
263 263 if url_or_file is None:
264 264 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
265 265 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
266 266 " Please specify at least one of url_or_file or profile."
267 267
268 268 try:
269 269 util.validate_url(url_or_file)
270 270 except AssertionError:
271 271 if not os.path.exists(url_or_file):
272 272 if self._cd:
273 273 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
274 274 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
275 275 with open(url_or_file) as f:
276 276 cfg = json.loads(f.read())
277 277 else:
278 278 cfg = {'url':url_or_file}
279 279
280 280 # sync defaults from args, json:
281 281 if sshserver:
282 282 cfg['ssh'] = sshserver
283 283 if exec_key:
284 284 cfg['exec_key'] = exec_key
285 285 exec_key = cfg['exec_key']
286 286 sshserver=cfg['ssh']
287 287 url = cfg['url']
288 288 location = cfg.setdefault('location', None)
289 289 cfg['url'] = util.disambiguate_url(cfg['url'], location)
290 290 url = cfg['url']
291 291
292 292 self._config = cfg
293 293
294 294 self._ssh = bool(sshserver or sshkey or password)
295 295 if self._ssh and sshserver is None:
296 296 # default to ssh via localhost
297 297 sshserver = url.split('://')[1].split(':')[0]
298 298 if self._ssh and password is None:
299 299 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
300 300 password=False
301 301 else:
302 302 password = getpass("SSH Password for %s: "%sshserver)
303 303 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
304 304 if exec_key is not None and os.path.isfile(exec_key):
305 305 arg = 'keyfile'
306 306 else:
307 307 arg = 'key'
308 308 key_arg = {arg:exec_key}
309 309 if username is None:
310 310 self.session = ss.StreamSession(**key_arg)
311 311 else:
312 312 self.session = ss.StreamSession(username, **key_arg)
313 313 self._query_socket = self._context.socket(zmq.XREQ)
314 314 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
315 315 if self._ssh:
316 316 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
317 317 else:
318 318 self._query_socket.connect(url)
319 319
320 320 self.session.debug = self.debug
321 321
322 322 self._notification_handlers = {'registration_notification' : self._register_engine,
323 323 'unregistration_notification' : self._unregister_engine,
324 324 'shutdown_notification' : lambda msg: self.close(),
325 325 }
326 326 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
327 327 'apply_reply' : self._handle_apply_reply}
328 328 self._connect(sshserver, ssh_kwargs, timeout)
329 329
330 330 def __del__(self):
331 331 """cleanup sockets, but _not_ context."""
332 332 self.close()
333 333
334 334 def _setup_cluster_dir(self, profile, cluster_dir, ipython_dir):
335 335 if ipython_dir is None:
336 336 ipython_dir = get_ipython_dir()
337 337 if cluster_dir is not None:
338 338 try:
339 339 self._cd = ClusterDir.find_cluster_dir(cluster_dir)
340 340 return
341 341 except ClusterDirError:
342 342 pass
343 343 elif profile is not None:
344 344 try:
345 345 self._cd = ClusterDir.find_cluster_dir_by_profile(
346 346 ipython_dir, profile)
347 347 return
348 348 except ClusterDirError:
349 349 pass
350 350 self._cd = None
351 351
352 352 def _update_engines(self, engines):
353 353 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
354 354 for k,v in engines.iteritems():
355 355 eid = int(k)
356 356 self._engines[eid] = bytes(v) # force not unicode
357 357 self._ids.append(eid)
358 358 self._ids = sorted(self._ids)
359 359 if sorted(self._engines.keys()) != range(len(self._engines)) and \
360 360 self._task_scheme == 'pure' and self._task_socket:
361 361 self._stop_scheduling_tasks()
362 362
363 363 def _stop_scheduling_tasks(self):
364 364 """Stop scheduling tasks because an engine has been unregistered
365 365 from a pure ZMQ scheduler.
366 366 """
367 367 self._task_socket.close()
368 368 self._task_socket = None
369 369 msg = "An engine has been unregistered, and we are using pure " +\
370 370 "ZMQ task scheduling. Task farming will be disabled."
371 371 if self.outstanding:
372 372 msg += " If you were running tasks when this happened, " +\
373 373 "some `outstanding` msg_ids may never resolve."
374 374 warnings.warn(msg, RuntimeWarning)
375 375
376 376 def _build_targets(self, targets):
377 377 """Turn valid target IDs or 'all' into two lists:
378 378 (int_ids, uuids).
379 379 """
380 380 if targets is None:
381 381 targets = self._ids
382 382 elif isinstance(targets, str):
383 383 if targets.lower() == 'all':
384 384 targets = self._ids
385 385 else:
386 386 raise TypeError("%r not valid str target, must be 'all'"%(targets))
387 387 elif isinstance(targets, int):
388 388 if targets < 0:
389 389 targets = self.ids[targets]
390 390 if targets not in self.ids:
391 391 raise IndexError("No such engine: %i"%targets)
392 392 targets = [targets]
393 393
394 394 if isinstance(targets, slice):
395 395 indices = range(len(self._ids))[targets]
396 396 ids = self.ids
397 397 targets = [ ids[i] for i in indices ]
398 398
399 399 if not isinstance(targets, (tuple, list, xrange)):
400 400 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
401 401
402 402 return [self._engines[t] for t in targets], list(targets)
403 403
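# --- annotation (values illustrative): with four engines registered,
#   self._build_targets('all')       # -> (their uuids, [0, 1, 2, 3])
#   self._build_targets(-1)          # -> the last engine: ([uuid_3], [3])
#   self._build_targets(slice(0, 2)) # -> ([uuid_0, uuid_1], [0, 1])
# anything that is not None/'all'/int/slice/collection of ints raises TypeError.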
404 404 def _connect(self, sshserver, ssh_kwargs, timeout):
405 405 """setup all our socket connections to the cluster. This is called from
406 406 __init__."""
407 407
408 408 # Maybe allow reconnecting?
409 409 if self._connected:
410 410 return
411 411 self._connected=True
412 412
413 413 def connect_socket(s, url):
414 414 url = util.disambiguate_url(url, self._config['location'])
415 415 if self._ssh:
416 416 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
417 417 else:
418 418 return s.connect(url)
419 419
420 420 self.session.send(self._query_socket, 'connection_request')
421 421 r,w,x = zmq.select([self._query_socket],[],[], timeout)
422 422 if not r:
423 423 raise error.TimeoutError("Hub connection request timed out")
424 424 idents,msg = self.session.recv(self._query_socket,mode=0)
425 425 if self.debug:
426 426 pprint(msg)
427 427 msg = ss.Message(msg)
428 428 content = msg.content
429 429 self._config['registration'] = dict(content)
430 430 if content.status == 'ok':
431 431 if content.mux:
432 432 self._mux_socket = self._context.socket(zmq.XREQ)
433 433 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
434 434 connect_socket(self._mux_socket, content.mux)
435 435 if content.task:
436 436 self._task_scheme, task_addr = content.task
437 437 self._task_socket = self._context.socket(zmq.XREQ)
438 438 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
439 439 connect_socket(self._task_socket, task_addr)
440 440 if content.notification:
441 441 self._notification_socket = self._context.socket(zmq.SUB)
442 442 connect_socket(self._notification_socket, content.notification)
443 443 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
444 444 # if content.query:
445 445 # self._query_socket = self._context.socket(zmq.XREQ)
446 446 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
447 447 # connect_socket(self._query_socket, content.query)
448 448 if content.control:
449 449 self._control_socket = self._context.socket(zmq.XREQ)
450 450 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
451 451 connect_socket(self._control_socket, content.control)
452 452 if content.iopub:
453 453 self._iopub_socket = self._context.socket(zmq.SUB)
454 454 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
455 455 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
456 456 connect_socket(self._iopub_socket, content.iopub)
457 457 self._update_engines(dict(content.engines))
458 458 else:
459 459 self._connected = False
460 460 raise Exception("Failed to connect!")
461 461
462 462 #--------------------------------------------------------------------------
463 463 # handlers and callbacks for incoming messages
464 464 #--------------------------------------------------------------------------
465 465
466 466 def _unwrap_exception(self, content):
467 467 """unwrap exception, and remap engine_id to int."""
468 468 e = error.unwrap_exception(content)
469 469 # print e.traceback
470 470 if e.engine_info:
471 471 e_uuid = e.engine_info['engine_uuid']
472 472 eid = self._engines[e_uuid]
473 473 e.engine_info['engine_id'] = eid
474 474 return e
475 475
476 476 def _extract_metadata(self, header, parent, content):
477 477 md = {'msg_id' : parent['msg_id'],
478 478 'received' : datetime.now(),
479 479 'engine_uuid' : header.get('engine', None),
480 480 'follow' : parent.get('follow', []),
481 481 'after' : parent.get('after', []),
482 482 'status' : content['status'],
483 483 }
484 484
485 485 if md['engine_uuid'] is not None:
486 486 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
487 487
488 488 if 'date' in parent:
489 489 md['submitted'] = datetime.strptime(parent['date'], util.ISO8601)
490 490 if 'started' in header:
491 491 md['started'] = datetime.strptime(header['started'], util.ISO8601)
492 492 if 'date' in header:
493 493 md['completed'] = datetime.strptime(header['date'], util.ISO8601)
494 494 return md
495 495
496 496 def _register_engine(self, msg):
497 497 """Register a new engine, and update our connection info."""
498 498 content = msg['content']
499 499 eid = content['id']
500 500 d = {eid : content['queue']}
501 501 self._update_engines(d)
502 502
503 503 def _unregister_engine(self, msg):
504 504 """Unregister an engine that has died."""
505 505 content = msg['content']
506 506 eid = int(content['id'])
507 507 if eid in self._ids:
508 508 self._ids.remove(eid)
509 509 uuid = self._engines.pop(eid)
510 510
511 511 self._handle_stranded_msgs(eid, uuid)
512 512
513 513 if self._task_socket and self._task_scheme == 'pure':
514 514 self._stop_scheduling_tasks()
515 515
516 516 def _handle_stranded_msgs(self, eid, uuid):
517 517 """Handle messages known to be on an engine when the engine unregisters.
518 518
519 519 It is possible that this will fire prematurely - that is, an engine will
520 520 go down after completing a result, and the client will be notified
521 521 of the unregistration and later receive the successful result.
522 522 """
523 523
524 524 outstanding = self._outstanding_dict[uuid]
525 525
526 526 for msg_id in list(outstanding):
527 527 if msg_id in self.results:
528 528 # we already have this result
529 529 continue
530 530 try:
531 531 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
532 532 except:
533 533 content = error.wrap_exception()
534 534 # build a fake message:
535 535 parent = {}
536 536 header = {}
537 537 parent['msg_id'] = msg_id
538 538 header['engine'] = uuid
539 539 header['date'] = datetime.now().strftime(util.ISO8601)
540 540 msg = dict(parent_header=parent, header=header, content=content)
541 541 self._handle_apply_reply(msg)
542 542
543 543 def _handle_execute_reply(self, msg):
544 544 """Save the reply to an execute_request into our results.
545 545
546 546 execute messages are never actually used. apply is used instead.
547 547 """
548 548
549 549 parent = msg['parent_header']
550 550 msg_id = parent['msg_id']
551 551 if msg_id not in self.outstanding:
552 552 if msg_id in self.history:
553 553 print ("got stale result: %s"%msg_id)
554 554 else:
555 555 print ("got unknown result: %s"%msg_id)
556 556 else:
557 557 self.outstanding.remove(msg_id)
558 558 self.results[msg_id] = self._unwrap_exception(msg['content'])
559 559
560 560 def _handle_apply_reply(self, msg):
561 561 """Save the reply to an apply_request into our results."""
562 562 parent = msg['parent_header']
563 563 msg_id = parent['msg_id']
564 564 if msg_id not in self.outstanding:
565 565 if msg_id in self.history:
566 566 print ("got stale result: %s"%msg_id)
567 567 print self.results[msg_id]
568 568 print msg
569 569 else:
570 570 print ("got unknown result: %s"%msg_id)
571 571 else:
572 572 self.outstanding.remove(msg_id)
573 573 content = msg['content']
574 574 header = msg['header']
575 575
576 576 # construct metadata:
577 577 md = self.metadata[msg_id]
578 578 md.update(self._extract_metadata(header, parent, content))
579 579 # is this redundant?
580 580 self.metadata[msg_id] = md
581 581
582 582 e_outstanding = self._outstanding_dict[md['engine_uuid']]
583 583 if msg_id in e_outstanding:
584 584 e_outstanding.remove(msg_id)
585 585
586 586 # construct result:
587 587 if content['status'] == 'ok':
588 588 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
589 589 elif content['status'] == 'aborted':
590 590 self.results[msg_id] = error.TaskAborted(msg_id)
591 591 elif content['status'] == 'resubmitted':
592 592 # TODO: handle resubmission
593 593 pass
594 594 else:
595 595 self.results[msg_id] = self._unwrap_exception(content)
596 596
597 597 def _flush_notifications(self):
598 598 """Flush notifications of engine registrations waiting
599 599 in ZMQ queue."""
600 600 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
601 601 while msg is not None:
602 602 if self.debug:
603 603 pprint(msg)
604 604 msg = msg[-1]
605 605 msg_type = msg['msg_type']
606 606 handler = self._notification_handlers.get(msg_type, None)
607 607 if handler is None:
608 608 raise Exception("Unhandled message type: %s"%msg_type)
609 609 else:
610 610 handler(msg)
611 611 msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
612 612
613 613 def _flush_results(self, sock):
614 614 """Flush task or queue results waiting in ZMQ queue."""
615 615 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
616 616 while msg is not None:
617 617 if self.debug:
618 618 pprint(msg)
619 619 msg = msg[-1]
620 620 msg_type = msg['msg_type']
621 621 handler = self._queue_handlers.get(msg_type, None)
622 622 if handler is None:
623 623 raise Exception("Unhandled message type: %s"%msg_type)
624 624 else:
625 625 handler(msg)
626 626 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
627 627
628 628 def _flush_control(self, sock):
629 629 """Flush replies from the control channel waiting
630 630 in the ZMQ queue.
631 631
632 632 Currently: ignore them."""
633 633 if self._ignored_control_replies <= 0:
634 634 return
635 635 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
636 636 while msg is not None:
637 637 self._ignored_control_replies -= 1
638 638 if self.debug:
639 639 pprint(msg)
640 640 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
641 641
642 642 def _flush_ignored_control(self):
643 643 """flush ignored control replies"""
644 644 while self._ignored_control_replies > 0:
645 645 self.session.recv(self._control_socket)
646 646 self._ignored_control_replies -= 1
647 647
648 648 def _flush_ignored_hub_replies(self):
649 649 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
650 650 while msg is not None:
651 651 msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
652 652
653 653 def _flush_iopub(self, sock):
654 654 """Flush replies from the iopub channel waiting
655 655 in the ZMQ queue.
656 656 """
657 657 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
658 658 while msg is not None:
659 659 if self.debug:
660 660 pprint(msg)
661 661 msg = msg[-1]
662 662 parent = msg['parent_header']
663 663 msg_id = parent['msg_id']
664 664 content = msg['content']
665 665 header = msg['header']
666 666 msg_type = msg['msg_type']
667 667
668 668 # init metadata:
669 669 md = self.metadata[msg_id]
670 670
671 671 if msg_type == 'stream':
672 672 name = content['name']
673 673 s = md[name] or ''
674 674 md[name] = s + content['data']
675 675 elif msg_type == 'pyerr':
676 676 md.update({'pyerr' : self._unwrap_exception(content)})
677 elif msg_type == 'pyin':
678 md.update({'pyin' : content['code']})
677 679 else:
678 md.update({msg_type : content['data']})
680 md.update({msg_type : content.get('data', '')})
679 681
680 682 # redundant?
681 683 self.metadata[msg_id] = md
682 684
683 685 msg = self.session.recv(sock, mode=zmq.NOBLOCK)
684 686
685 687 #--------------------------------------------------------------------------
686 688 # len, getitem
687 689 #--------------------------------------------------------------------------
688 690
689 691 def __len__(self):
690 692 """len(client) returns # of engines."""
691 693 return len(self.ids)
692 694
693 695 def __getitem__(self, key):
694 696 """index access returns DirectView multiplexer objects
695 697
696 698 Must be int, slice, or list/tuple/xrange of ints"""
697 699 if not isinstance(key, (int, slice, tuple, list, xrange)):
698 700 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
699 701 else:
700 702 return self.direct_view(key)
701 703
702 704 #--------------------------------------------------------------------------
703 705 # Begin public methods
704 706 #--------------------------------------------------------------------------
705 707
706 708 @property
707 709 def ids(self):
708 710 """Always up-to-date ids property."""
709 711 self._flush_notifications()
710 712 # always copy:
711 713 return list(self._ids)
712 714
713 715 def close(self):
714 716 if self._closed:
715 717 return
716 718 snames = filter(lambda n: n.endswith('socket'), dir(self))
717 719 for socket in map(lambda name: getattr(self, name), snames):
718 720 if isinstance(socket, zmq.Socket) and not socket.closed:
719 721 socket.close()
720 722 self._closed = True
721 723
722 724 def spin(self):
723 725 """Flush any registration notifications and execution results
724 726 waiting in the ZMQ queue.
725 727 """
726 728 if self._notification_socket:
727 729 self._flush_notifications()
728 730 if self._mux_socket:
729 731 self._flush_results(self._mux_socket)
730 732 if self._task_socket:
731 733 self._flush_results(self._task_socket)
732 734 if self._control_socket:
733 735 self._flush_control(self._control_socket)
734 736 if self._iopub_socket:
735 737 self._flush_iopub(self._iopub_socket)
736 738 if self._query_socket:
737 739 self._flush_ignored_hub_replies()
738 740
739 741 def wait(self, jobs=None, timeout=-1):
740 742 """waits on one or more `jobs`, for up to `timeout` seconds.
741 743
742 744 Parameters
743 745 ----------
744 746
745 747 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
746 748 ints are indices to self.history
747 749 strs are msg_ids
748 750 default: wait on all outstanding messages
749 751 timeout : float
750 752 a time in seconds, after which to give up.
751 753 default is -1, which means no timeout
752 754
753 755 Returns
754 756 -------
755 757
756 758 True : when all msg_ids are done
757 759 False : timeout reached, some msg_ids still outstanding
758 760 """
759 761 tic = time.time()
760 762 if jobs is None:
761 763 theids = self.outstanding
762 764 else:
763 765 if isinstance(jobs, (int, str, AsyncResult)):
764 766 jobs = [jobs]
765 767 theids = set()
766 768 for job in jobs:
767 769 if isinstance(job, int):
768 770 # index access
769 771 job = self.history[job]
770 772 elif isinstance(job, AsyncResult):
771 773 map(theids.add, job.msg_ids)
772 774 continue
773 775 theids.add(job)
774 776 if not theids.intersection(self.outstanding):
775 777 return True
776 778 self.spin()
777 779 while theids.intersection(self.outstanding):
778 780 if timeout >= 0 and ( time.time()-tic ) > timeout:
779 781 break
780 782 time.sleep(1e-3)
781 783 self.spin()
782 784 return len(theids.intersection(self.outstanding)) == 0
783 785
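# --- usage sketch (annotation): wait on everything outstanding for up to 5s,
#   done = client.wait(timeout=5)      # True if all finished, False on timeout
# or on specific work, by AsyncResult, msg_id, or history index:
#   done = client.wait([ar, 'some-msg-id', -1])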
784 786 #--------------------------------------------------------------------------
785 787 # Control methods
786 788 #--------------------------------------------------------------------------
787 789
788 790 @spin_first
789 791 @default_block
790 792 def clear(self, targets=None, block=None):
791 793 """Clear the namespace in target(s)."""
792 794 targets = self._build_targets(targets)[0]
793 795 for t in targets:
794 796 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
795 797 error = False
796 798 if self.block:
797 799 self._flush_ignored_control()
798 800 for i in range(len(targets)):
799 801 idents,msg = self.session.recv(self._control_socket,0)
800 802 if self.debug:
801 803 pprint(msg)
802 804 if msg['content']['status'] != 'ok':
803 805 error = self._unwrap_exception(msg['content'])
804 806 else:
805 807 self._ignored_control_replies += len(targets)
806 808 if error:
807 809 raise error
808 810
809 811
810 812 @spin_first
811 813 @default_block
812 814 def abort(self, jobs=None, targets=None, block=None):
813 815 """Abort specific jobs from the execution queues of target(s).
814 816
815 817 This is a mechanism to prevent jobs that have already been submitted
816 818 from executing.
817 819
818 820 Parameters
819 821 ----------
820 822
821 823 jobs : msg_id, list of msg_ids, or AsyncResult
822 824 The jobs to be aborted
823 825
824 826
825 827 """
826 828 targets = self._build_targets(targets)[0]
827 829 msg_ids = []
828 830 if isinstance(jobs, (basestring,AsyncResult)):
829 831 jobs = [jobs]
830 832 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
831 833 if bad_ids:
832 834 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
833 835 for j in jobs:
834 836 if isinstance(j, AsyncResult):
835 837 msg_ids.extend(j.msg_ids)
836 838 else:
837 839 msg_ids.append(j)
838 840 content = dict(msg_ids=msg_ids)
839 841 for t in targets:
840 842 self.session.send(self._control_socket, 'abort_request',
841 843 content=content, ident=t)
842 844 error = False
843 845 if self.block:
844 846 self._flush_ignored_control()
845 847 for i in range(len(targets)):
846 848 idents,msg = self.session.recv(self._control_socket,0)
847 849 if self.debug:
848 850 pprint(msg)
849 851 if msg['content']['status'] != 'ok':
850 852 error = self._unwrap_exception(msg['content'])
851 853 else:
852 854 self._ignored_control_replies += len(targets)
853 855 if error:
854 856 raise error
855 857
856 858 @spin_first
857 859 @default_block
858 860 def shutdown(self, targets=None, restart=False, hub=False, block=None):
859 861 """Terminates one or more engine processes, optionally including the hub."""
860 862 if hub:
861 863 targets = 'all'
862 864 targets = self._build_targets(targets)[0]
863 865 for t in targets:
864 866 self.session.send(self._control_socket, 'shutdown_request',
865 867 content={'restart':restart},ident=t)
866 868 error = False
867 869 if block or hub:
868 870 self._flush_ignored_control()
869 871 for i in range(len(targets)):
870 872 idents,msg = self.session.recv(self._control_socket, 0)
871 873 if self.debug:
872 874 pprint(msg)
873 875 if msg['content']['status'] != 'ok':
874 876 error = self._unwrap_exception(msg['content'])
875 877 else:
876 878 self._ignored_control_replies += len(targets)
877 879
878 880 if hub:
879 881 time.sleep(0.25)
880 882 self.session.send(self._query_socket, 'shutdown_request')
881 883 idents,msg = self.session.recv(self._query_socket, 0)
882 884 if self.debug:
883 885 pprint(msg)
884 886 if msg['content']['status'] != 'ok':
885 887 error = self._unwrap_exception(msg['content'])
886 888
887 889 if error:
888 890 raise error
889 891
890 892 #--------------------------------------------------------------------------
891 893 # Execution methods
892 894 #--------------------------------------------------------------------------
893 895
894 896 @default_block
895 897 def _execute(self, code, targets='all', block=None):
896 898 """Executes `code` on `targets` in blocking or nonblocking manner.
897 899
898 900 ``execute`` is always `bound` (affects engine namespace)
899 901
900 902 Parameters
901 903 ----------
902 904
903 905 code : str
904 906 the code string to be executed
905 907 targets : int/str/list of ints/strs
906 908 the engines on which to execute
907 909 default : all
908 910 block : bool
909 911 whether or not to wait until done to return
910 912 default: self.block
911 913 """
912 914 return self[targets].execute(code, block=block)
913 915
914 916 def _maybe_raise(self, result):
915 917 """wrapper for maybe raising an exception if apply failed."""
916 918 if isinstance(result, error.RemoteError):
917 919 raise result
918 920
919 921 return result
920 922
921 923 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
922 924 ident=None):
923 925 """construct and send an apply message via a socket.
924 926
925 927 This is the principal method with which all engine execution is performed by views.
926 928 """
927 929
928 930 assert not self._closed, "cannot use me anymore, I'm closed!"
929 931 # defaults:
930 932 args = args if args is not None else []
931 933 kwargs = kwargs if kwargs is not None else {}
932 934 subheader = subheader if subheader is not None else {}
933 935
934 936 # validate arguments
935 937 if not callable(f):
936 938 raise TypeError("f must be callable, not %s"%type(f))
937 939 if not isinstance(args, (tuple, list)):
938 940 raise TypeError("args must be tuple or list, not %s"%type(args))
939 941 if not isinstance(kwargs, dict):
940 942 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
941 943 if not isinstance(subheader, dict):
942 944 raise TypeError("subheader must be dict, not %s"%type(subheader))
943 945
944 946 if not self._ids:
945 947 # flush notification socket if no engines yet
946 948 any_ids = self.ids
947 949 if not any_ids:
948 950 raise error.NoEnginesRegistered("Can't execute without any connected engines.")
949 951 # enforce types of f,args,kwargs
950 952
951 953 bufs = util.pack_apply_message(f,args,kwargs)
952 954
953 955 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
954 956 subheader=subheader, track=track)
955 957
956 958 msg_id = msg['msg_id']
957 959 self.outstanding.add(msg_id)
958 960 if ident:
959 961 # possibly routed to a specific engine
960 962 if isinstance(ident, list):
961 963 ident = ident[-1]
962 964 if ident in self._engines.values():
963 965 # save for later, in case of engine death
964 966 self._outstanding_dict[ident].add(msg_id)
965 967 self.history.append(msg_id)
966 968 self.metadata[msg_id]['submitted'] = datetime.now()
967 969
968 970 return msg
969 971
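# --- annotation (a sketch; the view-side names are assumed): a DirectView
# ends up calling roughly
#   msg = client.send_apply_message(client._mux_socket, f, args, kwargs,
#                                   ident=engine_uuid, track=False)
#   ar = AsyncResult(client, msg_ids=[msg['msg_id']])
# with replies collected later by spin() via _handle_apply_reply().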
970 972 #--------------------------------------------------------------------------
971 973 # construct a View object
972 974 #--------------------------------------------------------------------------
973 975
974 976 def load_balanced_view(self, targets=None):
975 977 """construct a DirectView object.
976 978
977 979 If no arguments are specified, create a LoadBalancedView
978 980 using all engines.
979 981
980 982 Parameters
981 983 ----------
982 984
983 985 targets: list,slice,int,etc. [default: use all engines]
984 986 The subset of engines across which to load-balance
985 987 """
986 988 if targets is not None:
987 989 targets = self._build_targets(targets)[1]
988 990 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
989 991
990 992 def direct_view(self, targets='all'):
991 993 """construct a DirectView object.
992 994
993 995 If no targets are specified, create a DirectView
994 996 using all engines.
995 997
996 998 Parameters
997 999 ----------
998 1000
999 1001 targets: list,slice,int,etc. [default: use all engines]
1000 1002 The engines to use for the View
1001 1003 """
1002 1004 single = isinstance(targets, int)
1003 1005 targets = self._build_targets(targets)[1]
1004 1006 if single:
1005 1007 targets = targets[0]
1006 1008 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1007 1009
1008 1010 #--------------------------------------------------------------------------
1009 1011 # Data movement (TO BE REMOVED)
1010 1012 #--------------------------------------------------------------------------
1011 1013
1012 1014 @default_block
1013 1015 def _push(self, ns, targets='all', block=None, track=False):
1014 1016 """Push the contents of `ns` into the namespace on `target`"""
1015 1017 if not isinstance(ns, dict):
1016 1018 raise TypeError("Must be a dict, not %s"%type(ns))
1017 1019 result = self.apply(util._push, kwargs=ns, targets=targets, block=block, bound=True, balanced=False, track=track)
1018 1020 if not block:
1019 1021 return result
1020 1022
1021 1023 @default_block
1022 1024 def _pull(self, keys, targets='all', block=None):
1023 1025 """Pull objects from `target`'s namespace by `keys`"""
1024 1026 if isinstance(keys, basestring):
1025 1027 pass
1026 1028 elif isinstance(keys, (list,tuple,set)):
1027 1029 for key in keys:
1028 1030 if not isinstance(key, basestring):
1029 1031 raise TypeError("keys must be str, not type %r"%type(key))
1030 1032 else:
1031 1033 raise TypeError("keys must be strs, not %r"%keys)
1032 1034 result = self.apply(util._pull, (keys,), targets=targets, block=block, bound=True, balanced=False)
1033 1035 return result
1034 1036
1035 1037 #--------------------------------------------------------------------------
1036 1038 # Query methods
1037 1039 #--------------------------------------------------------------------------
1038 1040
1039 1041 @spin_first
1040 1042 @default_block
1041 1043 def get_result(self, indices_or_msg_ids=None, block=None):
1042 1044 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1043 1045
1044 1046 If the client already has the results, no request to the Hub will be made.
1045 1047
1046 1048 This is a convenient way to construct AsyncResult objects, which are wrappers
1047 1049 that include metadata about execution, and allow for awaiting results that
1048 1050 were not submitted by this Client.
1049 1051
1050 1052 It can also be a convenient way to retrieve the metadata associated with
1051 1053 blocking execution, since it always retrieves the metadata as well.
1052 1054
1053 1055 Examples
1054 1056 --------
1055 1057 ::
1056 1058
1057 1059 In [10]: ar = client.get_result(-1)  # most recent task, as an AsyncResult
1058 1060
1059 1061 Parameters
1060 1062 ----------
1061 1063
1062 1064 indices_or_msg_ids : integer history index, str msg_id, or list of either
1062 1064 The indices or msg_ids to be retrieved
1064 1066
1065 1067 block : bool
1066 1068 Whether to wait for the result to be done
1067 1069
1068 1070 Returns
1069 1071 -------
1070 1072
1071 1073 AsyncResult
1072 1074 A single AsyncResult object will always be returned.
1073 1075
1074 1076 AsyncHubResult
1075 1077 A subclass of AsyncResult that retrieves results from the Hub
1076 1078
1077 1079 """
1078 1080 if indices_or_msg_ids is None:
1079 1081 indices_or_msg_ids = -1
1080 1082
1081 1083 if not isinstance(indices_or_msg_ids, (list,tuple)):
1082 1084 indices_or_msg_ids = [indices_or_msg_ids]
1083 1085
1084 1086 theids = []
1085 1087 for id in indices_or_msg_ids:
1086 1088 if isinstance(id, int):
1087 1089 id = self.history[id]
1088 1090 if not isinstance(id, str):
1089 1091 raise TypeError("indices must be str or int, not %r"%id)
1090 1092 theids.append(id)
1091 1093
1092 1094 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1093 1095 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1094 1096
1095 1097 if remote_ids:
1096 1098 ar = AsyncHubResult(self, msg_ids=theids)
1097 1099 else:
1098 1100 ar = AsyncResult(self, msg_ids=theids)
1099 1101
1100 1102 if block:
1101 1103 ar.wait()
1102 1104
1103 1105 return ar
1104 1106
1105 1107 @spin_first
1106 1108 def result_status(self, msg_ids, status_only=True):
1107 1109 """Check on the status of the result(s) of the apply request with `msg_ids`.
1108 1110
1109 1111 If status_only is False, then the actual results will be retrieved, else
1110 1112 only the status of the results will be checked.
1111 1113
1112 1114 Parameters
1113 1115 ----------
1114 1116
1115 1117 msg_ids : list of msg_ids
1116 1118 if int:
1117 1119 Passed as index to self.history for convenience.
1118 1120 status_only : bool (default: True)
1119 1121 if False:
1120 1122 Retrieve the actual results of completed tasks.
1121 1123
1122 1124 Returns
1123 1125 -------
1124 1126
1125 1127 results : dict
1126 1128 There will always be the keys 'pending' and 'completed', which will
1127 1129 be lists of msg_ids that are incomplete or complete. If `status_only`
1128 1130 is False, then completed results will be keyed by their `msg_id`.
1129 1131 """
1130 1132 if not isinstance(msg_ids, (list,tuple)):
1131 1133 msg_ids = [msg_ids]
1132 1134
1133 1135 theids = []
1134 1136 for msg_id in msg_ids:
1135 1137 if isinstance(msg_id, int):
1136 1138 msg_id = self.history[msg_id]
1137 1139 if not isinstance(msg_id, basestring):
1138 1140 raise TypeError("msg_ids must be str, not %r"%msg_id)
1139 1141 theids.append(msg_id)
1140 1142
1141 1143 completed = []
1142 1144 local_results = {}
1143 1145
1144 1146 # comment this block out to temporarily disable local shortcut:
1145 1147 for msg_id in theids:
1146 1148 if msg_id in self.results:
1147 1149 completed.append(msg_id)
1148 1150 local_results[msg_id] = self.results[msg_id]
1149 1151 theids.remove(msg_id)
1150 1152
1151 1153 if theids: # some not locally cached
1152 1154 content = dict(msg_ids=theids, status_only=status_only)
1153 1155 msg = self.session.send(self._query_socket, "result_request", content=content)
1154 1156 zmq.select([self._query_socket], [], [])
1155 1157 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1156 1158 if self.debug:
1157 1159 pprint(msg)
1158 1160 content = msg['content']
1159 1161 if content['status'] != 'ok':
1160 1162 raise self._unwrap_exception(content)
1161 1163 buffers = msg['buffers']
1162 1164 else:
1163 1165 content = dict(completed=[],pending=[])
1164 1166
1165 1167 content['completed'].extend(completed)
1166 1168
1167 1169 if status_only:
1168 1170 return content
1169 1171
1170 1172 failures = []
1171 1173 # load cached results into result:
1172 1174 content.update(local_results)
1173 1175 # update cache with results:
1174 1176 for msg_id in sorted(theids):
1175 1177 if msg_id in content['completed']:
1176 1178 rec = content[msg_id]
1177 1179 parent = rec['header']
1178 1180 header = rec['result_header']
1179 1181 rcontent = rec['result_content']
1180 1182 iodict = rec['io']
1181 1183 if isinstance(rcontent, str):
1182 1184 rcontent = self.session.unpack(rcontent)
1183 1185
1184 1186 md = self.metadata[msg_id]
1185 1187 md.update(self._extract_metadata(header, parent, rcontent))
1186 1188 md.update(iodict)
1187 1189
1188 1190 if rcontent['status'] == 'ok':
1189 1191 res,buffers = util.unserialize_object(buffers)
1190 1192 else:
1191 1193 print rcontent
1192 1194 res = self._unwrap_exception(rcontent)
1193 1195 failures.append(res)
1194 1196
1195 1197 self.results[msg_id] = res
1196 1198 content[msg_id] = res
1197 1199
1198 1200 if len(theids) == 1 and failures:
1199 1201 raise failures[0]
1200 1202
1201 1203 error.collect_exceptions(failures, "result_status")
1202 1204 return content
1203 1205
1204 1206 @spin_first
1205 1207 def queue_status(self, targets='all', verbose=False):
1206 1208 """Fetch the status of engine queues.
1207 1209
1208 1210 Parameters
1209 1211 ----------
1210 1212
1211 1213 targets : int/str/list of ints/strs
1212 1214 the engines whose states are to be queried.
1213 1215 default : all
1214 1216 verbose : bool
1215 1217 Whether to return lengths only, or lists of ids for each element
1216 1218 """
1217 1219 engine_ids = self._build_targets(targets)[1]
1218 1220 content = dict(targets=engine_ids, verbose=verbose)
1219 1221 self.session.send(self._query_socket, "queue_request", content=content)
1220 1222 idents,msg = self.session.recv(self._query_socket, 0)
1221 1223 if self.debug:
1222 1224 pprint(msg)
1223 1225 content = msg['content']
1224 1226 status = content.pop('status')
1225 1227 if status != 'ok':
1226 1228 raise self._unwrap_exception(content)
1227 1229 content = util.rekey(content)
1228 1230 if isinstance(targets, int):
1229 1231 return content[targets]
1230 1232 else:
1231 1233 return content
1232 1234
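A sketch of the two verbosity modes of queue_status (rc, the engine ids, and the counts are illustrative assumptions):

    counts = rc.queue_status(targets='all', verbose=False)
    # e.g. {0: {'queue': 2, 'completed': 10, 'tasks': 1}, 1: {...}}
    ids = rc.queue_status(targets='all', verbose=True)
    # e.g. {0: {'queue': ['msg-a', 'msg-b'], 'completed': [...], 'tasks': [...]}}
    one = rc.queue_status(targets=0)  # an int target returns just that engine's dict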
1233 1235 @spin_first
1234 1236 def purge_results(self, jobs=[], targets=[]):
1235 1237 """Tell the Hub to forget results.
1236 1238
1237 1239 Individual results can be purged by msg_id, or the entire
1238 1240 history of specific targets can be purged.
1239 1241
1240 1242 Parameters
1241 1243 ----------
1242 1244
1243 1245 jobs : str or list of str or AsyncResult objects
1244 1246 the msg_ids whose results should be forgotten.
1245 1247 targets : int/str/list of ints/strs
1246 1248 The targets, by uuid or int_id, whose entire history is to be purged.
1247 1249 Use `targets='all'` to scrub everything from the Hub's memory.
1248 1250
1249 1251 default : None
1250 1252 """
1251 1253 if not targets and not jobs:
1252 1254 raise ValueError("Must specify at least one of `targets` and `jobs`")
1253 1255 if targets:
1254 1256 targets = self._build_targets(targets)[1]
1255 1257
1256 1258 # construct msg_ids from jobs
1257 1259 msg_ids = []
1258 1260 if isinstance(jobs, (basestring,AsyncResult)):
1259 1261 jobs = [jobs]
1260 1262 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1261 1263 if bad_ids:
1262 1264 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1263 1265 for j in jobs:
1264 1266 if isinstance(j, AsyncResult):
1265 1267 msg_ids.extend(j.msg_ids)
1266 1268 else:
1267 1269 msg_ids.append(j)
1268 1270
1269 1271 content = dict(targets=targets, msg_ids=msg_ids)
1270 1272 self.session.send(self._query_socket, "purge_request", content=content)
1271 1273 idents, msg = self.session.recv(self._query_socket, 0)
1272 1274 if self.debug:
1273 1275 pprint(msg)
1274 1276 content = msg['content']
1275 1277 if content['status'] != 'ok':
1276 1278 raise self._unwrap_exception(content)
1277 1279
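Illustrative calls for purge_results (rc and the msg_id are hypothetical):

    rc.purge_results(jobs='msg-abc')   # forget a single finished job
    rc.purge_results(targets=[0, 1])   # forget engines 0 and 1's entire history
    rc.purge_results(targets='all')    # scrub everything from the Hub's memory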
1278 1280
1279 1281 __all__ = [ 'Client' ]
@@ -1,1089 +1,1091
1 1 #!/usr/bin/env python
2 2 """The IPython Controller Hub with 0MQ
3 3 This is the master object that handles connections from engines and clients,
4 4 and monitors traffic through the various queues.
5 5 """
6 6 #-----------------------------------------------------------------------------
7 7 # Copyright (C) 2010 The IPython Development Team
8 8 #
9 9 # Distributed under the terms of the BSD License. The full license is in
10 10 # the file COPYING, distributed as part of this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from datetime import datetime
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # internal:
27 27 from IPython.utils.importstring import import_item
28 28 from IPython.utils.traitlets import HasTraits, Instance, Int, CStr, Str, Dict, Set, List, Bool
29 29
30 30 from IPython.parallel import error
31 31 from IPython.parallel.factory import RegistrationFactory, LoggingFactory
32 32 from IPython.parallel.util import select_random_ports, validate_url_container, ISO8601
33 33
34 34 from .heartmonitor import HeartMonitor
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Code
38 38 #-----------------------------------------------------------------------------
39 39
40 40 def _passer(*args, **kwargs):
41 41 return
42 42
43 43 def _printer(*args, **kwargs):
44 44 print (args)
45 45 print (kwargs)
46 46
47 47 def empty_record():
48 48 """Return an empty dict with all record keys."""
49 49 return {
50 50 'msg_id' : None,
51 51 'header' : None,
52 52 'content': None,
53 53 'buffers': None,
54 54 'submitted': None,
55 55 'client_uuid' : None,
56 56 'engine_uuid' : None,
57 57 'started': None,
58 58 'completed': None,
59 59 'resubmitted': None,
60 60 'result_header' : None,
61 61 'result_content' : None,
62 62 'result_buffers' : None,
63 63 'queue' : None,
64 64 'pyin' : None,
65 65 'pyout': None,
66 66 'pyerr': None,
67 67 'stdout': '',
68 68 'stderr': '',
69 69 }
70 70
71 71 def init_record(msg):
72 72 """Initialize a TaskRecord based on a request."""
73 73 header = msg['header']
74 74 return {
75 75 'msg_id' : header['msg_id'],
76 76 'header' : header,
77 77 'content': msg['content'],
78 78 'buffers': msg['buffers'],
79 79 'submitted': datetime.strptime(header['date'], ISO8601),
80 80 'client_uuid' : None,
81 81 'engine_uuid' : None,
82 82 'started': None,
83 83 'completed': None,
84 84 'resubmitted': None,
85 85 'result_header' : None,
86 86 'result_content' : None,
87 87 'result_buffers' : None,
88 88 'queue' : None,
89 89 'pyin' : None,
90 90 'pyout': None,
91 91 'pyerr': None,
92 92 'stdout': '',
93 93 'stderr': '',
94 94 }
95 95
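To make the record schema concrete, a sketch of feeding a minimal request message through init_record (field values are placeholders; assumes this module's ISO8601 format string):

    from datetime import datetime
    msg = {
        'header' : {'msg_id': 'msg-abc', 'date': datetime.now().strftime(ISO8601)},
        'content': {'code': 'a=5'},
        'buffers': [],
    }
    rec = init_record(msg)
    assert rec['msg_id'] == 'msg-abc'
    assert rec['completed'] is None   # filled in later when the result arrives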
96 96
97 97 class EngineConnector(HasTraits):
98 98 """A simple object for accessing the various zmq connections of an object.
99 99 Attributes are:
100 100 id (int): engine ID
101 101 control (str): identity of control XREQ socket
102 102 queue (str): identity of queue's XREQ socket
103 103 registration (str): identity of registration XREQ socket
104 104 heartbeat (str): identity of heartbeat XREQ socket
105 105 """
106 106 id=Int(0)
107 107 queue=Str()
108 108 control=Str()
109 109 registration=Str()
110 110 heartbeat=Str()
111 111 pending=Set()
112 112
113 113 class HubFactory(RegistrationFactory):
114 114 """The Configurable for setting up a Hub."""
115 115
116 116 # name of a scheduler scheme
117 117 scheme = Str('leastload', config=True)
118 118
119 119 # port-pairs for monitoredqueues:
120 120 hb = Instance(list, config=True)
121 121 def _hb_default(self):
122 122 return select_random_ports(2)
123 123
124 124 mux = Instance(list, config=True)
125 125 def _mux_default(self):
126 126 return select_random_ports(2)
127 127
128 128 task = Instance(list, config=True)
129 129 def _task_default(self):
130 130 return select_random_ports(2)
131 131
132 132 control = Instance(list, config=True)
133 133 def _control_default(self):
134 134 return select_random_ports(2)
135 135
136 136 iopub = Instance(list, config=True)
137 137 def _iopub_default(self):
138 138 return select_random_ports(2)
139 139
140 140 # single ports:
141 141 mon_port = Instance(int, config=True)
142 142 def _mon_port_default(self):
143 143 return select_random_ports(1)[0]
144 144
145 145 notifier_port = Instance(int, config=True)
146 146 def _notifier_port_default(self):
147 147 return select_random_ports(1)[0]
148 148
149 149 ping = Int(1000, config=True) # ping frequency
150 150
151 151 engine_ip = CStr('127.0.0.1', config=True)
152 152 engine_transport = CStr('tcp', config=True)
153 153
154 154 client_ip = CStr('127.0.0.1', config=True)
155 155 client_transport = CStr('tcp', config=True)
156 156
157 157 monitor_ip = CStr('127.0.0.1', config=True)
158 158 monitor_transport = CStr('tcp', config=True)
159 159
160 160 monitor_url = CStr('')
161 161
162 162 db_class = CStr('IPython.parallel.controller.dictdb.DictDB', config=True)
163 163
164 164 # not configurable
165 165 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
166 166 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
167 167 subconstructors = List()
168 168 _constructed = Bool(False)
169 169
170 170 def _ip_changed(self, name, old, new):
171 171 self.engine_ip = new
172 172 self.client_ip = new
173 173 self.monitor_ip = new
174 174 self._update_monitor_url()
175 175
176 176 def _update_monitor_url(self):
177 177 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
178 178
179 179 def _transport_changed(self, name, old, new):
180 180 self.engine_transport = new
181 181 self.client_transport = new
182 182 self.monitor_transport = new
183 183 self._update_monitor_url()
184 184
185 185 def __init__(self, **kwargs):
186 186 super(HubFactory, self).__init__(**kwargs)
187 187 self._update_monitor_url()
188 188 # self.on_trait_change(self._sync_ips, 'ip')
189 189 # self.on_trait_change(self._sync_transports, 'transport')
190 190 self.subconstructors.append(self.construct_hub)
191 191
192 192
193 193 def construct(self):
194 194 assert not self._constructed, "already constructed!"
195 195
196 196 for subc in self.subconstructors:
197 197 subc()
198 198
199 199 self._constructed = True
200 200
201 201
202 202 def start(self):
203 203 assert self._constructed, "must be constructed by self.construct() first!"
204 204 self.heartmonitor.start()
205 205 self.log.info("Heartmonitor started")
206 206
207 207 def construct_hub(self):
208 208 """construct"""
209 209 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
210 210 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
211 211
212 212 ctx = self.context
213 213 loop = self.loop
214 214
215 215 # Registrar socket
216 216 q = ZMQStream(ctx.socket(zmq.XREP), loop)
217 217 q.bind(client_iface % self.regport)
218 218 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
219 219 if self.client_ip != self.engine_ip:
220 220 q.bind(engine_iface % self.regport)
221 221 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
222 222
223 223 ### Engine connections ###
224 224
225 225 # heartbeat
226 226 hpub = ctx.socket(zmq.PUB)
227 227 hpub.bind(engine_iface % self.hb[0])
228 228 hrep = ctx.socket(zmq.XREP)
229 229 hrep.bind(engine_iface % self.hb[1])
230 230 self.heartmonitor = HeartMonitor(loop=loop, pingstream=ZMQStream(hpub,loop), pongstream=ZMQStream(hrep,loop),
231 231 period=self.ping, logname=self.log.name)
232 232
233 233 ### Client connections ###
234 234 # Notifier socket
235 235 n = ZMQStream(ctx.socket(zmq.PUB), loop)
236 236 n.bind(client_iface%self.notifier_port)
237 237
238 238 ### build and launch the queues ###
239 239
240 240 # monitor socket
241 241 sub = ctx.socket(zmq.SUB)
242 242 sub.setsockopt(zmq.SUBSCRIBE, "")
243 243 sub.bind(self.monitor_url)
244 244 sub.bind('inproc://monitor')
245 245 sub = ZMQStream(sub, loop)
246 246
247 247 # connect the db
248 248 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
249 249 # cdir = self.config.Global.cluster_dir
250 250 self.db = import_item(self.db_class)(session=self.session.session, config=self.config)
251 251 time.sleep(.25)
252 252
253 253 # build connection dicts
254 254 self.engine_info = {
255 255 'control' : engine_iface%self.control[1],
256 256 'mux': engine_iface%self.mux[1],
257 257 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
258 258 'task' : engine_iface%self.task[1],
259 259 'iopub' : engine_iface%self.iopub[1],
260 260 # 'monitor' : engine_iface%self.mon_port,
261 261 }
262 262
263 263 self.client_info = {
264 264 'control' : client_iface%self.control[0],
265 265 'mux': client_iface%self.mux[0],
266 266 'task' : (self.scheme, client_iface%self.task[0]),
267 267 'iopub' : client_iface%self.iopub[0],
268 268 'notification': client_iface%self.notifier_port
269 269 }
270 270 self.log.debug("Hub engine addrs: %s"%self.engine_info)
271 271 self.log.debug("Hub client addrs: %s"%self.client_info)
272 272 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
273 273 query=q, notifier=n, db=self.db,
274 274 engine_info=self.engine_info, client_info=self.client_info,
275 275 logname=self.log.name)
276 276
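With the default tcp transport and loopback IPs, the connection dicts built above come out roughly as follows (ports are selected at random, so these values are illustrative only):

    # engine_info -- what engines connect to:
    #   {'control': 'tcp://127.0.0.1:57001',
    #    'mux': 'tcp://127.0.0.1:57003',
    #    'heartbeat': ('tcp://127.0.0.1:57005', 'tcp://127.0.0.1:57006'),
    #    'task': 'tcp://127.0.0.1:57008',
    #    'iopub': 'tcp://127.0.0.1:57010'}
    # client_info carries the scheduler scheme alongside the task url:
    #   {'task': ('leastload', 'tcp://127.0.0.1:57007'), ...}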
277 277
278 278 class Hub(LoggingFactory):
279 279 """The IPython Controller Hub with 0MQ connections
280 280
281 281 Parameters
282 282 ==========
283 283 loop: zmq IOLoop instance
284 284 session: StreamSession object
285 285 <removed> context: zmq context for creating new connections (?)
286 286 queue: ZMQStream for monitoring the command queue (SUB)
287 287 query: ZMQStream for engine registration and client queries requests (XREP)
288 288 heartbeat: HeartMonitor object checking the pulse of the engines
289 289 notifier: ZMQStream for broadcasting engine registration changes (PUB)
290 290 db: connection to db for out of memory logging of commands
291 291 NotImplemented
292 292 engine_info: dict of zmq connection information for engines to connect
293 293 to the queues.
294 294 client_info: dict of zmq connection information for clients to connect
295 295 to the queues.
296 296 """
297 297 # internal data structures:
298 298 ids=Set() # engine IDs
299 299 keytable=Dict()
300 300 by_ident=Dict()
301 301 engines=Dict()
302 302 clients=Dict()
303 303 hearts=Dict()
304 304 pending=Set()
305 305 queues=Dict() # pending msg_ids keyed by engine_id
306 306 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
307 307 completed=Dict() # completed msg_ids keyed by engine_id
308 308 all_completed=Set() # completed msg_ids from all engines
309 309 dead_engines=Set() # uuids of engines that have unregistered or died
310 310 # mia=None
311 311 incoming_registrations=Dict()
312 312 registration_timeout=Int()
313 313 _idcounter=Int(0)
314 314
315 315 # objects from constructor:
316 316 loop=Instance(ioloop.IOLoop)
317 317 query=Instance(ZMQStream)
318 318 monitor=Instance(ZMQStream)
319 319 heartmonitor=Instance(HeartMonitor)
320 320 notifier=Instance(ZMQStream)
321 321 db=Instance(object)
322 322 client_info=Dict()
323 323 engine_info=Dict()
324 324
325 325
326 326 def __init__(self, **kwargs):
327 327 """
328 328 # universal:
329 329 loop: IOLoop for creating future connections
330 330 session: streamsession for sending serialized data
331 331 # engine:
332 332 queue: ZMQStream for monitoring queue messages
333 333 query: ZMQStream for engine+client registration and client requests
334 334 heartbeat: HeartMonitor object for tracking engines
335 335 # extra:
336 336 db: ZMQStream for db connection (NotImplemented)
337 337 engine_info: zmq address/protocol dict for engine connections
338 338 client_info: zmq address/protocol dict for client connections
339 339 """
340 340
341 341 super(Hub, self).__init__(**kwargs)
342 342 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
343 343
344 344 # validate connection dicts:
345 345 for k,v in self.client_info.iteritems():
346 346 if k == 'task':
347 347 validate_url_container(v[1])
348 348 else:
349 349 validate_url_container(v)
350 350 # validate_url_container(self.client_info)
351 351 validate_url_container(self.engine_info)
352 352
353 353 # register our callbacks
354 354 self.query.on_recv(self.dispatch_query)
355 355 self.monitor.on_recv(self.dispatch_monitor_traffic)
356 356
357 357 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
358 358 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
359 359
360 360 self.monitor_handlers = { 'in' : self.save_queue_request,
361 361 'out': self.save_queue_result,
362 362 'intask': self.save_task_request,
363 363 'outtask': self.save_task_result,
364 364 'tracktask': self.save_task_destination,
365 365 'incontrol': _passer,
366 366 'outcontrol': _passer,
367 367 'iopub': self.save_iopub_message,
368 368 }
369 369
370 370 self.query_handlers = {'queue_request': self.queue_status,
371 371 'result_request': self.get_results,
372 372 'purge_request': self.purge_results,
373 373 'load_request': self.check_load,
374 374 'resubmit_request': self.resubmit_task,
375 375 'shutdown_request': self.shutdown_request,
376 376 'registration_request' : self.register_engine,
377 377 'unregistration_request' : self.unregister_engine,
378 378 'connection_request': self.connection_request,
379 379 }
380 380
381 381 self.log.info("hub::created hub")
382 382
383 383 @property
384 384 def _next_id(self):
385 385 """gemerate a new ID.
386 386
387 387 No longer reuse old ids, just count from 0."""
388 388 newid = self._idcounter
389 389 self._idcounter += 1
390 390 return newid
391 391 # newid = 0
392 392 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
393 393 # # print newid, self.ids, self.incoming_registrations
394 394 # while newid in self.ids or newid in incoming:
395 395 # newid += 1
396 396 # return newid
397 397
398 398 #-----------------------------------------------------------------------------
399 399 # message validation
400 400 #-----------------------------------------------------------------------------
401 401
402 402 def _validate_targets(self, targets):
403 403 """turn any valid targets argument into a list of integer ids"""
404 404 if targets is None:
405 405 # default to all
406 406 targets = self.ids
407 407
408 408 if isinstance(targets, (int,str,unicode)):
409 409 # only one target specified
410 410 targets = [targets]
411 411 _targets = []
412 412 for t in targets:
413 413 # map raw identities to ids
414 414 if isinstance(t, (str,unicode)):
415 415 t = self.by_ident.get(t, t)
416 416 _targets.append(t)
417 417 targets = _targets
418 418 bad_targets = [ t for t in targets if t not in self.ids ]
419 419 if bad_targets:
420 420 raise IndexError("No Such Engine: %r"%bad_targets)
421 421 if not targets:
422 422 raise IndexError("No Engines Registered")
423 423 return targets
424 424
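Illustrative behavior of _validate_targets (the hub state shown is hypothetical):

    # suppose engine 0 registered with queue identity 'queue-uuid-0', so
    # hub.by_ident == {'queue-uuid-0': 0} and hub.ids == {0}
    hub._validate_targets(None)             # -> [0]  (defaults to all engines)
    hub._validate_targets(0)                # -> [0]
    hub._validate_targets('queue-uuid-0')   # -> [0]  (uuid mapped to int id)
    hub._validate_targets(5)                # raises IndexError: No Such Engine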
425 425 #-----------------------------------------------------------------------------
426 426 # dispatch methods (1 per stream)
427 427 #-----------------------------------------------------------------------------
428 428
429 429 # def dispatch_registration_request(self, msg):
430 430 # """"""
431 431 # self.log.debug("registration::dispatch_register_request(%s)"%msg)
432 432 # idents,msg = self.session.feed_identities(msg)
433 433 # if not idents:
434 434 # self.log.error("Bad Query Message: %s"%msg, exc_info=True)
435 435 # return
436 436 # try:
437 437 # msg = self.session.unpack_message(msg,content=True)
438 438 # except:
439 439 # self.log.error("registration::got bad registration message: %s"%msg, exc_info=True)
440 440 # return
441 441 #
442 442 # msg_type = msg['msg_type']
443 443 # content = msg['content']
444 444 #
445 445 # handler = self.query_handlers.get(msg_type, None)
446 446 # if handler is None:
447 447 # self.log.error("registration::got bad registration message: %s"%msg)
448 448 # else:
449 449 # handler(idents, msg)
450 450
451 451 def dispatch_monitor_traffic(self, msg):
452 452 """all ME and Task queue messages come through here, as well as
453 453 IOPub traffic."""
454 454 self.log.debug("monitor traffic: %s"%msg[:2])
455 455 switch = msg[0]
456 456 idents, msg = self.session.feed_identities(msg[1:])
457 457 if not idents:
458 458 self.log.error("Bad Monitor Message: %s"%msg)
459 459 return
460 460 handler = self.monitor_handlers.get(switch, None)
461 461 if handler is not None:
462 462 handler(idents, msg)
463 463 else:
464 464 self.log.error("Invalid monitor topic: %s"%switch)
465 465
466 466
467 467 def dispatch_query(self, msg):
468 468 """Route registration requests and queries from clients."""
469 469 idents, msg = self.session.feed_identities(msg)
470 470 if not idents:
471 471 self.log.error("Bad Query Message: %s"%msg)
472 472 return
473 473 client_id = idents[0]
474 474 try:
475 475 msg = self.session.unpack_message(msg, content=True)
476 476 except:
477 477 content = error.wrap_exception()
478 478 self.log.error("Bad Query Message: %s"%msg, exc_info=True)
479 479 self.session.send(self.query, "hub_error", ident=client_id,
480 480 content=content)
481 481 return
482 482
483 483 # print client_id, header, parent, content
484 484 #switch on message type:
485 485 msg_type = msg['msg_type']
486 486 self.log.info("client::client %s requested %s"%(client_id, msg_type))
487 487 handler = self.query_handlers.get(msg_type, None)
488 488 try:
489 489 assert handler is not None, "Bad Message Type: %s"%msg_type
490 490 except:
491 491 content = error.wrap_exception()
492 492 self.log.error("Bad Message Type: %s"%msg_type, exc_info=True)
493 493 self.session.send(self.query, "hub_error", ident=client_id,
494 494 content=content)
495 495 return
496 496 else:
497 497 handler(idents, msg)
498 498
499 499 def dispatch_db(self, msg):
500 500 """"""
501 501 raise NotImplementedError
502 502
503 503 #---------------------------------------------------------------------------
504 504 # handler methods (1 per event)
505 505 #---------------------------------------------------------------------------
506 506
507 507 #----------------------- Heartbeat --------------------------------------
508 508
509 509 def handle_new_heart(self, heart):
510 510 """handler to attach to heartbeater.
511 511 Called when a new heart starts to beat.
512 512 Triggers completion of registration."""
513 513 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
514 514 if heart not in self.incoming_registrations:
515 515 self.log.info("heartbeat::ignoring new heart: %r"%heart)
516 516 else:
517 517 self.finish_registration(heart)
518 518
519 519
520 520 def handle_heart_failure(self, heart):
521 521 """handler to attach to heartbeater.
522 522 called when a previously registered heart fails to respond to beat request.
523 523 triggers unregistration"""
524 524 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
525 525 eid = self.hearts.get(heart, None)
526 526 if eid is None:
527 527 self.log.info("heartbeat::ignoring heart failure %r"%heart)
528 528 else:
529 529 queue = self.engines[eid].queue
530 530 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
531 531
532 532 #----------------------- MUX Queue Traffic ------------------------------
533 533
534 534 def save_queue_request(self, idents, msg):
535 535 if len(idents) < 2:
536 536 self.log.error("invalid identity prefix: %s"%idents)
537 537 return
538 538 queue_id, client_id = idents[:2]
539 539 try:
540 540 msg = self.session.unpack_message(msg, content=False)
541 541 except:
542 542 self.log.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True)
543 543 return
544 544
545 545 eid = self.by_ident.get(queue_id, None)
546 546 if eid is None:
547 547 self.log.error("queue::target %r not registered"%queue_id)
548 548 self.log.debug("queue:: valid are: %s"%(self.by_ident.keys()))
549 549 return
550 550
551 551 header = msg['header']
552 552 msg_id = header['msg_id']
553 553 record = init_record(msg)
554 554 record['engine_uuid'] = queue_id
555 555 record['client_uuid'] = client_id
556 556 record['queue'] = 'mux'
557 557
558 558 try:
559 559 # it's possible iopub arrived first:
560 560 existing = self.db.get_record(msg_id)
561 561 for key,evalue in existing.iteritems():
562 562 rvalue = record[key]
563 563 if evalue and rvalue and evalue != rvalue:
564 564 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
565 565 elif evalue and not rvalue:
566 566 record[key] = evalue
567 567 self.db.update_record(msg_id, record)
568 568 except KeyError:
569 569 self.db.add_record(msg_id, record)
570 570
571 571 self.pending.add(msg_id)
572 572 self.queues[eid].append(msg_id)
573 573
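The try/except above merges the new record with any record the iopub stream may have created first; the merge rule, sketched on plain dicts (not the actual helper):

    def merge_existing(record, existing):
        for key, evalue in existing.iteritems():
            rvalue = record[key]
            if evalue and rvalue and evalue != rvalue:
                pass  # conflicting non-empty values: logged above, new value kept
            elif evalue and not rvalue:
                record[key] = evalue  # keep what iopub already recorded
        return record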
574 574 def save_queue_result(self, idents, msg):
575 575 if len(idents) < 2:
576 576 self.log.error("invalid identity prefix: %s"%idents)
577 577 return
578 578
579 579 client_id, queue_id = idents[:2]
580 580 try:
581 581 msg = self.session.unpack_message(msg, content=False)
582 582 except:
583 583 self.log.error("queue::engine %r sent invalid message to %r: %s"%(
584 584 queue_id,client_id, msg), exc_info=True)
585 585 return
586 586
587 587 eid = self.by_ident.get(queue_id, None)
588 588 if eid is None:
589 589 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
590 590 # self.log.debug("queue:: %s"%msg[2:])
591 591 return
592 592
593 593 parent = msg['parent_header']
594 594 if not parent:
595 595 return
596 596 msg_id = parent['msg_id']
597 597 if msg_id in self.pending:
598 598 self.pending.remove(msg_id)
599 599 self.all_completed.add(msg_id)
600 600 self.queues[eid].remove(msg_id)
601 601 self.completed[eid].append(msg_id)
602 602 elif msg_id not in self.all_completed:
603 603 # it could be a result from a dead engine that died before delivering the
604 604 # result
605 605 self.log.warn("queue:: unknown msg finished %s"%msg_id)
606 606 return
607 607 # update record anyway, because the unregistration could have been premature
608 608 rheader = msg['header']
609 609 completed = datetime.strptime(rheader['date'], ISO8601)
610 610 started = rheader.get('started', None)
611 611 if started is not None:
612 612 started = datetime.strptime(started, ISO8601)
613 613 result = {
614 614 'result_header' : rheader,
615 615 'result_content': msg['content'],
616 616 'started' : started,
617 617 'completed' : completed
618 618 }
619 619
620 620 result['result_buffers'] = msg['buffers']
621 621 self.db.update_record(msg_id, result)
622 622
623 623
624 624 #--------------------- Task Queue Traffic ------------------------------
625 625
626 626 def save_task_request(self, idents, msg):
627 627 """Save the submission of a task."""
628 628 client_id = idents[0]
629 629
630 630 try:
631 631 msg = self.session.unpack_message(msg, content=False)
632 632 except:
633 633 self.log.error("task::client %r sent invalid task message: %s"%(
634 634 client_id, msg), exc_info=True)
635 635 return
636 636 record = init_record(msg)
637 637
638 638 record['client_uuid'] = client_id
639 639 record['queue'] = 'task'
640 640 header = msg['header']
641 641 msg_id = header['msg_id']
642 642 self.pending.add(msg_id)
643 643 try:
644 644 # it's possible iopub arrived first:
645 645 existing = self.db.get_record(msg_id)
646 646 for key,evalue in existing.iteritems():
647 647 rvalue = record[key]
648 648 if evalue and rvalue and evalue != rvalue:
649 649 self.log.error("conflicting initial state for record: %s:%s <> %s"%(msg_id, rvalue, evalue))
650 650 elif evalue and not rvalue:
651 651 record[key] = evalue
652 652 self.db.update_record(msg_id, record)
653 653 except KeyError:
654 654 self.db.add_record(msg_id, record)
655 655
656 656 def save_task_result(self, idents, msg):
657 657 """save the result of a completed task."""
658 658 client_id = idents[0]
659 659 try:
660 660 msg = self.session.unpack_message(msg, content=False)
661 661 except:
662 662 self.log.error("task::invalid task result message sent to %r: %s"%(
663 663 client_id, msg), exc_info=True)
664 664 return
666 666
667 667 parent = msg['parent_header']
668 668 if not parent:
669 669 # print msg
670 670 self.log.warn("Task %r had no parent!"%msg)
671 671 return
672 672 msg_id = parent['msg_id']
673 673
674 674 header = msg['header']
675 675 engine_uuid = header.get('engine', None)
676 676 eid = self.by_ident.get(engine_uuid, None)
677 677
678 678 if msg_id in self.pending:
679 679 self.pending.remove(msg_id)
680 680 self.all_completed.add(msg_id)
681 681 if eid is not None:
682 682 self.completed[eid].append(msg_id)
683 683 if msg_id in self.tasks[eid]:
684 684 self.tasks[eid].remove(msg_id)
685 685 completed = datetime.strptime(header['date'], ISO8601)
686 686 started = header.get('started', None)
687 687 if started is not None:
688 688 started = datetime.strptime(started, ISO8601)
689 689 result = {
690 690 'result_header' : header,
691 691 'result_content': msg['content'],
692 692 'started' : started,
693 693 'completed' : completed,
694 694 'engine_uuid': engine_uuid
695 695 }
696 696
697 697 result['result_buffers'] = msg['buffers']
698 698 self.db.update_record(msg_id, result)
699 699
700 700 else:
701 701 self.log.debug("task::unknown task %s finished"%msg_id)
702 702
703 703 def save_task_destination(self, idents, msg):
704 704 try:
705 705 msg = self.session.unpack_message(msg, content=True)
706 706 except:
707 707 self.log.error("task::invalid task tracking message", exc_info=True)
708 708 return
709 709 content = msg['content']
710 710 # print (content)
711 711 msg_id = content['msg_id']
712 712 engine_uuid = content['engine_id']
713 713 eid = self.by_ident[engine_uuid]
714 714
715 715 self.log.info("task::task %s arrived on %s"%(msg_id, eid))
716 716 # if msg_id in self.mia:
717 717 # self.mia.remove(msg_id)
718 718 # else:
719 719 # self.log.debug("task::task %s not listed as MIA?!"%(msg_id))
720 720
721 721 self.tasks[eid].append(msg_id)
722 722 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
723 723 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
724 724
725 725 def mia_task_request(self, idents, msg):
726 726 raise NotImplementedError
727 727 client_id = idents[0]
728 728 # content = dict(mia=self.mia,status='ok')
729 729 # self.session.send('mia_reply', content=content, idents=client_id)
730 730
731 731
732 732 #--------------------- IOPub Traffic ------------------------------
733 733
734 734 def save_iopub_message(self, topics, msg):
735 735 """save an iopub message into the db"""
736 736 # print (topics)
737 737 try:
738 738 msg = self.session.unpack_message(msg, content=True)
739 739 except:
740 740 self.log.error("iopub::invalid IOPub message", exc_info=True)
741 741 return
742 742
743 743 parent = msg['parent_header']
744 744 if not parent:
745 745 self.log.error("iopub::invalid IOPub message: %s"%msg)
746 746 return
747 747 msg_id = parent['msg_id']
748 748 msg_type = msg['msg_type']
749 749 content = msg['content']
750 750
751 751 # ensure msg_id is in db
752 752 try:
753 753 rec = self.db.get_record(msg_id)
754 754 except KeyError:
755 755 rec = empty_record()
756 756 rec['msg_id'] = msg_id
757 757 self.db.add_record(msg_id, rec)
758 758 # stream
759 759 d = {}
760 760 if msg_type == 'stream':
761 761 name = content['name']
762 762 s = rec[name] or ''
763 763 d[name] = s + content['data']
764 764
765 765 elif msg_type == 'pyerr':
766 766 d['pyerr'] = content
767 elif msg_type == 'pyin':
768 d['pyin'] = content['code']
767 769 else:
768 d[msg_type] = content['data']
770 d[msg_type] = content.get('data', '')
769 771
770 772 self.db.update_record(msg_id, d)
771 773
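Concretely, successive stream messages for the same parent request are concatenated into a single record field (msg contents here are invented):

    # rec['stdout'] == ''                                  before any stream msg
    # after {'name': 'stdout', 'data': 'hel'}   -> 'hel'
    # after {'name': 'stdout', 'data': 'lo\n'}  -> 'hello\n'
    # pyin stores just the code string; pyerr stores the whole error content dict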
772 774
773 775
774 776 #-------------------------------------------------------------------------
775 777 # Registration requests
776 778 #-------------------------------------------------------------------------
777 779
778 780 def connection_request(self, client_id, msg):
779 781 """Reply with connection addresses for clients."""
780 782 self.log.info("client::client %s connected"%client_id)
781 783 content = dict(status='ok')
782 784 content.update(self.client_info)
783 785 jsonable = {}
784 786 for k,v in self.keytable.iteritems():
785 787 if v not in self.dead_engines:
786 788 jsonable[str(k)] = v
787 789 content['engines'] = jsonable
788 790 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
789 791
790 792 def register_engine(self, reg, msg):
791 793 """Register a new engine."""
792 794 content = msg['content']
793 795 try:
794 796 queue = content['queue']
795 797 except KeyError:
796 798 self.log.error("registration::queue not specified", exc_info=True)
797 799 return
798 800 heart = content.get('heartbeat', None)
799 801 """register a new engine, and create the socket(s) necessary"""
800 802 eid = self._next_id
801 803 # print (eid, queue, reg, heart)
802 804
803 805 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
804 806
805 807 content = dict(id=eid,status='ok')
806 808 content.update(self.engine_info)
807 809 # check if requesting available IDs:
808 810 if queue in self.by_ident:
809 811 try:
810 812 raise KeyError("queue_id %r in use"%queue)
811 813 except:
812 814 content = error.wrap_exception()
813 815 self.log.error("queue_id %r in use"%queue, exc_info=True)
814 816 elif heart in self.hearts: # need to check unique hearts?
815 817 try:
816 818 raise KeyError("heart_id %r in use"%heart)
817 819 except:
818 820 self.log.error("heart_id %r in use"%heart, exc_info=True)
819 821 content = error.wrap_exception()
820 822 else:
821 823 for h, pack in self.incoming_registrations.iteritems():
822 824 if heart == h:
823 825 try:
824 826 raise KeyError("heart_id %r in use"%heart)
825 827 except:
826 828 self.log.error("heart_id %r in use"%heart, exc_info=True)
827 829 content = error.wrap_exception()
828 830 break
829 831 elif queue == pack[1]:
830 832 try:
831 833 raise KeyError("queue_id %r in use"%queue)
832 834 except:
833 835 self.log.error("queue_id %r in use"%queue, exc_info=True)
834 836 content = error.wrap_exception()
835 837 break
836 838
837 839 msg = self.session.send(self.query, "registration_reply",
838 840 content=content,
839 841 ident=reg)
840 842
841 843 if content['status'] == 'ok':
842 844 if heart in self.heartmonitor.hearts:
843 845 # already beating
844 846 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
845 847 self.finish_registration(heart)
846 848 else:
847 849 purge = lambda : self._purge_stalled_registration(heart)
848 850 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
849 851 dc.start()
850 852 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
851 853 else:
852 854 self.log.error("registration::registration %i failed: %s"%(eid, content['evalue']))
853 855 return eid
854 856
855 857 def unregister_engine(self, ident, msg):
856 858 """Unregister an engine that explicitly requested to leave."""
857 859 try:
858 860 eid = msg['content']['id']
859 861 except:
860 862 self.log.error("registration::bad engine id for unregistration: %s"%ident, exc_info=True)
861 863 return
862 864 self.log.info("registration::unregister_engine(%s)"%eid)
863 865 # print (eid)
864 866 uuid = self.keytable[eid]
865 867 content=dict(id=eid, queue=uuid)
866 868 self.dead_engines.add(uuid)
867 869 # self.ids.remove(eid)
868 870 # uuid = self.keytable.pop(eid)
869 871 #
870 872 # ec = self.engines.pop(eid)
871 873 # self.hearts.pop(ec.heartbeat)
872 874 # self.by_ident.pop(ec.queue)
873 875 # self.completed.pop(eid)
874 876 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
875 877 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
876 878 dc.start()
877 879 ############## TODO: HANDLE IT ################
878 880
879 881 if self.notifier:
880 882 self.session.send(self.notifier, "unregistration_notification", content=content)
881 883
882 884 def _handle_stranded_msgs(self, eid, uuid):
883 885 """Handle messages known to be on an engine when the engine unregisters.
884 886
885 887 It is possible that this will fire prematurely - that is, an engine will
886 888 go down after completing a result, and the client will be notified
887 889 that the result failed and later receive the actual result.
888 890 """
889 891
890 892 outstanding = self.queues[eid]
891 893
892 894 for msg_id in outstanding:
893 895 self.pending.remove(msg_id)
894 896 self.all_completed.add(msg_id)
895 897 try:
896 898 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
897 899 except:
898 900 content = error.wrap_exception()
899 901 # build a fake header:
900 902 header = {}
901 903 header['engine'] = uuid
902 904 header['date'] = datetime.now().strftime(ISO8601)
903 905 rec = dict(result_content=content, result_header=header, result_buffers=[])
904 906 rec['completed'] = header['date']
905 907 rec['engine_uuid'] = uuid
906 908 self.db.update_record(msg_id, rec)
907 909
908 910 def finish_registration(self, heart):
909 911 """Second half of engine registration, called after our HeartMonitor
910 912 has received a beat from the Engine's Heart."""
911 913 try:
912 914 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
913 915 except KeyError:
914 916 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
915 917 return
916 918 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
917 919 if purge is not None:
918 920 purge.stop()
919 921 control = queue
920 922 self.ids.add(eid)
921 923 self.keytable[eid] = queue
922 924 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
923 925 control=control, heartbeat=heart)
924 926 self.by_ident[queue] = eid
925 927 self.queues[eid] = list()
926 928 self.tasks[eid] = list()
927 929 self.completed[eid] = list()
928 930 self.hearts[heart] = eid
929 931 content = dict(id=eid, queue=self.engines[eid].queue)
930 932 if self.notifier:
931 933 self.session.send(self.notifier, "registration_notification", content=content)
932 934 self.log.info("engine::Engine Connected: %i"%eid)
933 935
934 936 def _purge_stalled_registration(self, heart):
935 937 if heart in self.incoming_registrations:
936 938 eid = self.incoming_registrations.pop(heart)[0]
937 939 self.log.info("registration::purging stalled registration: %i"%eid)
938 940 else:
939 941 pass
940 942
941 943 #-------------------------------------------------------------------------
942 944 # Client Requests
943 945 #-------------------------------------------------------------------------
944 946
945 947 def shutdown_request(self, client_id, msg):
946 948 """handle shutdown request."""
947 949 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
948 950 # also notify other clients of shutdown
949 951 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
950 952 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
951 953 dc.start()
952 954
953 955 def _shutdown(self):
954 956 self.log.info("hub::hub shutting down.")
955 957 time.sleep(0.1)
956 958 sys.exit(0)
957 959
958 960
959 961 def check_load(self, client_id, msg):
960 962 content = msg['content']
961 963 try:
962 964 targets = content['targets']
963 965 targets = self._validate_targets(targets)
964 966 except:
965 967 content = error.wrap_exception()
966 968 self.session.send(self.query, "hub_error",
967 969 content=content, ident=client_id)
968 970 return
969 971
970 972 content = dict(status='ok')
971 973 # loads = {}
972 974 for t in targets:
973 975 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
974 976 self.session.send(self.query, "load_reply", content=content, ident=client_id)
975 977
976 978
977 979 def queue_status(self, client_id, msg):
978 980 """Return the Queue status of one or more targets.
979 981 if verbose: return the msg_ids
980 982 else: return len of each type.
981 983 keys: queue (pending MUX jobs)
982 984 tasks (pending Task jobs)
983 985 completed (finished jobs from both queues)"""
984 986 content = msg['content']
985 987 targets = content['targets']
986 988 try:
987 989 targets = self._validate_targets(targets)
988 990 except:
989 991 content = error.wrap_exception()
990 992 self.session.send(self.query, "hub_error",
991 993 content=content, ident=client_id)
992 994 return
993 995 verbose = content.get('verbose', False)
994 996 content = dict(status='ok')
995 997 for t in targets:
996 998 queue = self.queues[t]
997 999 completed = self.completed[t]
998 1000 tasks = self.tasks[t]
999 1001 if not verbose:
1000 1002 queue = len(queue)
1001 1003 completed = len(completed)
1002 1004 tasks = len(tasks)
1003 1005 content[bytes(t)] = {'queue': queue, 'completed': completed, 'tasks': tasks}
1004 1006 # pending
1005 1007 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1006 1008
1007 1009 def purge_results(self, client_id, msg):
1008 1010 """Purge results from memory. This method is more valuable before we move
1009 1011 to a DB based message storage mechanism."""
1010 1012 content = msg['content']
1011 1013 msg_ids = content.get('msg_ids', [])
1012 1014 reply = dict(status='ok')
1013 1015 if msg_ids == 'all':
1014 1016 self.db.drop_matching_records(dict(completed={'$ne':None}))
1015 1017 else:
1016 1018 for msg_id in msg_ids:
1017 1019 if msg_id in self.all_completed:
1018 1020 self.db.drop_record(msg_id)
1019 1021 else:
1020 1022 if msg_id in self.pending:
1021 1023 try:
1022 1024 raise IndexError("msg pending: %r"%msg_id)
1023 1025 except:
1024 1026 reply = error.wrap_exception()
1025 1027 else:
1026 1028 try:
1027 1029 raise IndexError("No such msg: %r"%msg_id)
1028 1030 except:
1029 1031 reply = error.wrap_exception()
1030 1032 break
1031 1033 eids = content.get('engine_ids', [])
1032 1034 for eid in eids:
1033 1035 if eid not in self.engines:
1034 1036 try:
1035 1037 raise IndexError("No such engine: %i"%eid)
1036 1038 except:
1037 1039 reply = error.wrap_exception()
1038 1040 break
1039 1041 msg_ids = self.completed.pop(eid)
1040 1042 uid = self.engines[eid].queue
1041 1043 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1042 1044
1043 1045 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1044 1046
1045 1047 def resubmit_task(self, client_id, msg, buffers):
1046 1048 """Resubmit a task."""
1047 1049 raise NotImplementedError
1048 1050
1049 1051 def get_results(self, client_id, msg):
1050 1052 """Get the result of 1 or more messages."""
1051 1053 content = msg['content']
1052 1054 msg_ids = sorted(set(content['msg_ids']))
1053 1055 statusonly = content.get('status_only', False)
1054 1056 pending = []
1055 1057 completed = []
1056 1058 content = dict(status='ok')
1057 1059 content['pending'] = pending
1058 1060 content['completed'] = completed
1059 1061 buffers = []
1060 1062 if not statusonly:
1061 1063 content['results'] = {}
1062 1064 records = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1063 1065 for msg_id in msg_ids:
1064 1066 if msg_id in self.pending:
1065 1067 pending.append(msg_id)
1066 1068 elif msg_id in self.all_completed:
1067 1069 completed.append(msg_id)
1068 1070 if not statusonly:
1069 1071 rec = records[msg_id]
1070 1072 io_dict = {}
1071 1073 for key in 'pyin pyout pyerr stdout stderr'.split():
1072 1074 io_dict[key] = rec[key]
1073 1075 content[msg_id] = { 'result_content': rec['result_content'],
1074 1076 'header': rec['header'],
1075 1077 'result_header' : rec['result_header'],
1076 1078 'io' : io_dict,
1077 1079 }
1078 1080 if rec['result_buffers']:
1079 1081 buffers.extend(map(str, rec['result_buffers']))
1080 1082 else:
1081 1083 try:
1082 1084 raise KeyError('No such message: '+msg_id)
1083 1085 except:
1084 1086 content = error.wrap_exception()
1085 1087 break
1086 1088 self.session.send(self.query, "result_reply", content=content,
1087 1089 parent=msg, ident=client_id,
1088 1090 buffers=buffers)
1089 1091
@@ -1,423 +1,430
1 1 #!/usr/bin/env python
2 2 """
3 3 Kernel adapted from kernel.py to use ZMQ Streams
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010-2011 The IPython Development Team
7 7 #
8 8 # Distributed under the terms of the BSD License. The full license is in
9 9 # the file COPYING, distributed as part of this software.
10 10 #-----------------------------------------------------------------------------
11 11
12 12 #-----------------------------------------------------------------------------
13 13 # Imports
14 14 #-----------------------------------------------------------------------------
15 15
16 16 # Standard library imports.
17 17 from __future__ import print_function
18 18
19 19 import sys
20 20 import time
21 21
22 22 from code import CommandCompiler
23 23 from datetime import datetime
24 24 from pprint import pprint
25 25
26 26 # System library imports.
27 27 import zmq
28 28 from zmq.eventloop import ioloop, zmqstream
29 29
30 30 # Local imports.
31 31 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Str
32 32 from IPython.zmq.completer import KernelCompleter
33 33
34 34 from IPython.parallel.error import wrap_exception
35 35 from IPython.parallel.factory import SessionFactory
36 36 from IPython.parallel.util import serialize_object, unpack_apply_message, ISO8601
37 37
38 38 def printer(*args):
39 39 pprint(args, stream=sys.__stdout__)
40 40
41 41
42 class _Passer:
43 """Empty class that implements `send()` that does nothing."""
42 class _Passer(zmqstream.ZMQStream):
43 """Empty class that implements `send()` that does nothing.
44
45 Subclass ZMQStream for StreamSession typechecking
46
47 """
48 def __init__(self, *args, **kwargs):
49 pass
50
44 51 def send(self, *args, **kwargs):
45 52 pass
46 53 send_multipart = send
47 54
48 55
49 56 #-----------------------------------------------------------------------------
50 57 # Main kernel class
51 58 #-----------------------------------------------------------------------------
52 59
53 60 class Kernel(SessionFactory):
54 61
55 62 #---------------------------------------------------------------------------
56 63 # Kernel interface
57 64 #---------------------------------------------------------------------------
58 65
59 66 # kwargs:
60 67 int_id = Int(-1, config=True)
61 68 user_ns = Dict(config=True)
62 69 exec_lines = List(config=True)
63 70
64 71 control_stream = Instance(zmqstream.ZMQStream)
65 72 task_stream = Instance(zmqstream.ZMQStream)
66 73 iopub_stream = Instance(zmqstream.ZMQStream)
67 74 client = Instance('IPython.parallel.Client')
68 75
69 76 # internals
70 77 shell_streams = List()
71 78 compiler = Instance(CommandCompiler, (), {})
72 79 completer = Instance(KernelCompleter)
73 80
74 81 aborted = Set()
75 82 shell_handlers = Dict()
76 83 control_handlers = Dict()
77 84
78 85 def _set_prefix(self):
79 86 self.prefix = "engine.%s"%self.int_id
80 87
81 88 def _connect_completer(self):
82 89 self.completer = KernelCompleter(self.user_ns)
83 90
84 91 def __init__(self, **kwargs):
85 92 super(Kernel, self).__init__(**kwargs)
86 93 self._set_prefix()
87 94 self._connect_completer()
88 95
89 96 self.on_trait_change(self._set_prefix, 'id')
90 97 self.on_trait_change(self._connect_completer, 'user_ns')
91 98
92 99 # Build dict of handlers for message types
93 100 for msg_type in ['execute_request', 'complete_request', 'apply_request',
94 101 'clear_request']:
95 102 self.shell_handlers[msg_type] = getattr(self, msg_type)
96 103
97 104 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
98 105 self.control_handlers[msg_type] = getattr(self, msg_type)
99 106
100 107 self._initial_exec_lines()
101 108
102 109 def _wrap_exception(self, method=None):
103 110 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
104 111 content=wrap_exception(e_info)
105 112 return content
106 113
107 114 def _initial_exec_lines(self):
108 115 s = _Passer()
109 116 content = dict(silent=True, user_variables=[], user_expressions=[])
110 117 for line in self.exec_lines:
111 118 self.log.debug("executing initialization: %s"%line)
112 119 content.update({'code':line})
113 120 msg = self.session.msg('execute_request', content)
114 121 self.execute_request(s, [], msg)
115 122
116 123
117 124 #-------------------- control handlers -----------------------------
118 125 def abort_queues(self):
119 126 for stream in self.shell_streams:
120 127 if stream:
121 128 self.abort_queue(stream)
122 129
123 130 def abort_queue(self, stream):
124 131 while True:
125 132 try:
126 133 msg = self.session.recv(stream, zmq.NOBLOCK,content=True)
127 134 except zmq.ZMQError as e:
128 135 if e.errno == zmq.EAGAIN:
129 136 break
130 137 else:
131 138 return
132 139 else:
133 140 if msg is None:
134 141 return
135 142 else:
136 143 idents,msg = msg
137 144
138 145 # assert self.reply_socketly_socket.rcvmore(), "Unexpected missing message part."
139 146 # msg = self.reply_socket.recv_json()
140 147 self.log.info("Aborting:")
141 148 self.log.info(str(msg))
142 149 msg_type = msg['msg_type']
143 150 reply_type = msg_type.split('_')[0] + '_reply'
144 151 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
145 152 # self.reply_socket.send(ident,zmq.SNDMORE)
146 153 # self.reply_socket.send_json(reply_msg)
147 154 reply_msg = self.session.send(stream, reply_type,
148 155 content={'status' : 'aborted'}, parent=msg, ident=idents)[0]
149 156 self.log.debug(str(reply_msg))
150 157 # We need to wait a bit for requests to come in. This can probably
151 158 # be set shorter for true asynchronous clients.
152 159 time.sleep(0.05)
153 160
154 161 def abort_request(self, stream, ident, parent):
155 162 """abort a specifig msg by id"""
156 163 msg_ids = parent['content'].get('msg_ids', None)
157 164 if isinstance(msg_ids, basestring):
158 165 msg_ids = [msg_ids]
159 166 if not msg_ids:
160 167 self.abort_queues()
161 168 for mid in msg_ids:
162 169 self.aborted.add(str(mid))
163 170
164 171 content = dict(status='ok')
165 172 reply_msg = self.session.send(stream, 'abort_reply', content=content,
166 173 parent=parent, ident=ident)
167 174 self.log.debug(str(reply_msg))
168 175
169 176 def shutdown_request(self, stream, ident, parent):
170 177 """kill ourself. This should really be handled in an external process"""
171 178 try:
172 179 self.abort_queues()
173 180 except:
174 181 content = self._wrap_exception('shutdown')
175 182 else:
176 183 content = dict(parent['content'])
177 184 content['status'] = 'ok'
178 185 msg = self.session.send(stream, 'shutdown_reply',
179 186 content=content, parent=parent, ident=ident)
180 187 self.log.debug(str(msg))
181 188 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
182 189 dc.start()
183 190
184 191 def dispatch_control(self, msg):
185 192 idents,msg = self.session.feed_identities(msg, copy=False)
186 193 try:
187 194 msg = self.session.unpack_message(msg, content=True, copy=False)
188 195 except:
189 196 self.log.error("Invalid Message", exc_info=True)
190 197 return
191 198
192 199 header = msg['header']
193 200 msg_id = header['msg_id']
194 201
195 202 handler = self.control_handlers.get(msg['msg_type'], None)
196 203 if handler is None:
197 204 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
198 205 else:
199 206 handler(self.control_stream, idents, msg)
200 207
201 208
202 209 #-------------------- queue helpers ------------------------------
203 210
204 211 def check_dependencies(self, dependencies):
205 212 if not dependencies:
206 213 return True
207 214 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
208 215 anyorall = dependencies[0]
209 216 dependencies = dependencies[1]
210 217 else:
211 218 anyorall = 'all'
212 219 results = self.client.get_results(dependencies,status_only=True)
213 220 if results['status'] != 'ok':
214 221 return False
215 222
216 223 if anyorall == 'any':
217 224 if not results['completed']:
218 225 return False
219 226 else:
220 227 if results['pending']:
221 228 return False
222 229
223 230 return True
224 231
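The dependencies argument accepts either a bare list of msg_ids (implicitly 'all') or an ('any'|'all', list) pair; illustrative calls (the kernel instance and msg_ids are hypothetical):

    kernel.check_dependencies([])                           # no deps -> True
    kernel.check_dependencies(['msg-a', 'msg-b'])           # all must be complete
    kernel.check_dependencies(('any', ['msg-a', 'msg-b']))  # one completion suffices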
225 232 def check_aborted(self, msg_id):
226 233 return msg_id in self.aborted
227 234
228 235 #-------------------- queue handlers -----------------------------
229 236
230 237 def clear_request(self, stream, idents, parent):
231 238 """Clear our namespace."""
232 239 self.user_ns = {}
233 240 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
234 241 content = dict(status='ok'))
235 242 self._initial_exec_lines()
236 243
237 244 def execute_request(self, stream, ident, parent):
238 245 self.log.debug('execute request %s'%parent)
239 246 try:
240 247 code = parent[u'content'][u'code']
241 248 except:
242 249 self.log.error("Got bad msg: %s"%parent, exc_info=True)
243 250 return
244 251 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
245 252 ident='%s.pyin'%self.prefix)
246 253 started = datetime.now().strftime(ISO8601)
247 254 try:
248 255 comp_code = self.compiler(code, '<zmq-kernel>')
249 256 # allow for not overriding displayhook
250 257 if hasattr(sys.displayhook, 'set_parent'):
251 258 sys.displayhook.set_parent(parent)
252 259 sys.stdout.set_parent(parent)
253 260 sys.stderr.set_parent(parent)
254 261 exec comp_code in self.user_ns, self.user_ns
255 262 except:
256 263 exc_content = self._wrap_exception('execute')
257 264 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
258 265 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
259 266 ident='%s.pyerr'%self.prefix)
260 267 reply_content = exc_content
261 268 else:
262 269 reply_content = {'status' : 'ok'}
263 270
264 271 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
265 272 ident=ident, subheader = dict(started=started))
266 273 self.log.debug(str(reply_msg))
267 274 if reply_msg['content']['status'] == u'error':
268 275 self.abort_queues()
269 276
270 277 def complete_request(self, stream, ident, parent):
271 278 matches = {'matches' : self.complete(parent),
272 279 'status' : 'ok'}
273 280 completion_msg = self.session.send(stream, 'complete_reply',
274 281 matches, parent, ident)
275 282 # print >> sys.__stdout__, completion_msg
276 283
277 284 def complete(self, msg):
278 285 return self.completer.complete(msg['content']['line'], msg['content']['text'])
279 286
280 287 def apply_request(self, stream, ident, parent):
281 288 # flush previous reply, so this request won't block it
282 289 stream.flush(zmq.POLLOUT)
283 290
284 291 try:
285 292 content = parent[u'content']
286 293 bufs = parent[u'buffers']
287 294 msg_id = parent['header']['msg_id']
288 295 # bound = parent['header'].get('bound', False)
289 296 except:
290 297 self.log.error("Got bad msg: %s"%parent, exc_info=True)
291 298 return
292 299 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
293 300 # self.iopub_stream.send(pyin_msg)
294 301 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
295 302 sub = {'dependencies_met' : True, 'engine' : self.ident,
296 303 'started': datetime.now().strftime(ISO8601)}
297 304 try:
298 305 # allow for not overriding displayhook
299 306 if hasattr(sys.displayhook, 'set_parent'):
300 307 sys.displayhook.set_parent(parent)
301 308 sys.stdout.set_parent(parent)
302 309 sys.stderr.set_parent(parent)
303 310 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
304 311 working = self.user_ns
305 312 # suffix =
306 313 prefix = "_"+str(msg_id).replace("-","")+"_"
307 314
308 315 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
309 316 # if bound:
310 317 # bound_ns = Namespace(working)
311 318 # args = [bound_ns]+list(args)
312 319
313 320 fname = getattr(f, '__name__', 'f')
314 321
315 322 fname = prefix+"f"
316 323 argname = prefix+"args"
317 324 kwargname = prefix+"kwargs"
318 325 resultname = prefix+"result"
319 326
320 327 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
321 328 # print ns
322 329 working.update(ns)
323 330 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
324 331 try:
325 332 exec code in working,working
326 333 result = working.get(resultname)
327 334 finally:
328 335 for key in ns.iterkeys():
329 336 working.pop(key)
330 337 # if bound:
331 338 # working.update(bound_ns)
332 339
333 340 packed_result,buf = serialize_object(result)
334 341 result_buf = [packed_result]+buf
335 342 except:
336 343 exc_content = self._wrap_exception('apply')
337 344 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
338 345 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
339 346 ident='%s.pyerr'%self.prefix)
340 347 reply_content = exc_content
341 348 result_buf = []
342 349
343 350 if exc_content['ename'] == 'UnmetDependency':
344 351 sub['dependencies_met'] = False
345 352 else:
346 353 reply_content = {'status' : 'ok'}
347 354
348 355 # put 'ok'/'error' status in header, for scheduler introspection:
349 356 sub['status'] = reply_content['status']
350 357
351 358 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
352 359 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
353 360
354 361 # flush i/o
355 362 # should this be before reply_msg is sent, like in the single-kernel code,
356 363 # or should nothing get in the way of real results?
357 364 sys.stdout.flush()
358 365 sys.stderr.flush()
359 366
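What apply_request stages in the user namespace, sketched for a hypothetical msg_id:

    # for msg_id '1234-ab' the mangled prefix is '_1234ab_', so the exec'd
    # statement is effectively:
    #   _1234ab_result = _1234ab_f(*_1234ab_args, **_1234ab_kwargs)
    # all four names are popped from user_ns in the finally block, so a crash
    # mid-call cannot leak them into user code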
360 367 def dispatch_queue(self, stream, msg):
361 368 self.control_stream.flush()
362 369 idents,msg = self.session.feed_identities(msg, copy=False)
363 370 try:
364 371 msg = self.session.unpack_message(msg, content=True, copy=False)
365 372 except:
366 373 self.log.error("Invalid Message", exc_info=True)
367 374 return
368 375
369 376
370 377 header = msg['header']
371 378 msg_id = header['msg_id']
372 379 if self.check_aborted(msg_id):
373 380 self.aborted.remove(msg_id)
374 381 # is it safe to assume a msg_id will not be resubmitted?
375 382 reply_type = msg['msg_type'].split('_')[0] + '_reply'
376 383 reply_msg = self.session.send(stream, reply_type,
377 384 content={'status' : 'aborted'}, parent=msg, ident=idents)
378 385 return
379 386 handler = self.shell_handlers.get(msg['msg_type'], None)
380 387 if handler is None:
381 388 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
382 389 else:
383 390 handler(stream, idents, msg)
384 391
385 392 def start(self):
386 393 #### stream mode:
387 394 if self.control_stream:
388 395 self.control_stream.on_recv(self.dispatch_control, copy=False)
389 396 self.control_stream.on_err(printer)
390 397
391 398 def make_dispatcher(stream):
392 399 def dispatcher(msg):
393 400 return self.dispatch_queue(stream, msg)
394 401 return dispatcher
395 402
396 403 for s in self.shell_streams:
397 404 s.on_recv(make_dispatcher(s), copy=False)
398 405 s.on_err(printer)
399 406
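make_dispatcher exists to sidestep Python's late-binding closures: a lambda defined directly in the for loop above would see only the last stream. A minimal demonstration of the pitfall:

    fs = [lambda: s for s in (1, 2, 3)]
    print([f() for f in fs])   # [3, 3, 3] -- every closure shares the final s
    fs = [(lambda s: (lambda: s))(s) for s in (1, 2, 3)]
    print([f() for f in fs])   # [1, 2, 3] -- a factory freezes each value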
400 407 if self.iopub_stream:
401 408 self.iopub_stream.on_err(printer)
402 409
403 410 #### while True mode:
404 411 # while True:
405 412 # idle = True
406 413 # try:
407 414 # msg = self.shell_stream.socket.recv_multipart(
408 415 # zmq.NOBLOCK, copy=False)
409 416 # except zmq.ZMQError, e:
410 417 # if e.errno != zmq.EAGAIN:
411 418 # raise e
412 419 # else:
413 420 # idle=False
414 421 # self.dispatch_queue(self.shell_stream, msg)
415 422 #
416 423 # if not self.task_stream.empty():
417 424 # idle=False
418 425 # msg = self.task_stream.recv_multipart()
419 426 # self.dispatch_queue(self.task_stream, msg)
420 427 # if idle:
421 428 # # don't busywait
422 429 # time.sleep(1e-3)
423 430