##// END OF EJS Templates
minor py3 fixes in IPython.parallel...
MinRK -
Show More
@@ -1,450 +1,452 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import json
27 27 import os
28 28 import socket
29 29 import stat
30 30 import sys
31 31
32 32 from multiprocessing import Process
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37
38 38 from IPython.core.profiledir import ProfileDir
39 39
40 40 from IPython.parallel.apps.baseapp import (
41 41 BaseParallelApplication,
42 42 base_aliases,
43 43 base_flags,
44 44 catch_config_error,
45 45 )
46 46 from IPython.utils.importstring import import_item
47 47 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
48 48
49 49 from IPython.zmq.session import (
50 50 Session, session_aliases, session_flags, default_secure
51 51 )
52 52
53 53 from IPython.parallel.controller.heartmonitor import HeartMonitor
54 54 from IPython.parallel.controller.hub import HubFactory
55 55 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
56 56 from IPython.parallel.controller.sqlitedb import SQLiteDB
57 57
58 from IPython.parallel.util import signal_children, split_url, asbytes, disambiguate_url
58 from IPython.parallel.util import signal_children, split_url, disambiguate_url
59 59
60 60 # conditional import of MongoDB backend class
61 61
62 62 try:
63 63 from IPython.parallel.controller.mongodb import MongoDB
64 64 except ImportError:
65 65 maybe_mongo = []
66 66 else:
67 67 maybe_mongo = [MongoDB]
68 68
69 69
70 70 #-----------------------------------------------------------------------------
71 71 # Module level variables
72 72 #-----------------------------------------------------------------------------
73 73
74 74
#: The default config file name for this application
default_config_file_name = u'ipcontroller_config.py'


# help text shown by `ipcontroller --help`
_description = """Start the IPython controller for parallel computing.

The IPython controller provides a gateway between the IPython engines and
clients. The controller needs to be started before the engines and can be
configured using command line options or using a cluster directory. Cluster
directories contain config, log and security files and are usually located in
your ipython directory and named as "profile_name". See the `profile`
and `profile-dir` options for details.
"""

_examples = """
ipcontroller --ip=192.168.0.1 --port=1000  # listen on ip, port for engines
ipcontroller --scheme=pure  # use the pure zeromq scheduler
"""
93 93
94 94
95 95 #-----------------------------------------------------------------------------
96 96 # The main application
97 97 #-----------------------------------------------------------------------------
# Command-line flags: start from the base/session flags and add
# controller-specific toggles (each maps flag -> (config dict, help string)).
flags = {}
flags.update(base_flags)
flags.update({
    'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
        'Use threads instead of processes for the schedulers'),
    'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
        'use the SQLiteDB backend'),
    'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
        'use the MongoDB backend'),
    'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
        'use the in-memory DictDB backend'),
    'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
        'reuse existing json connection files')
})

flags.update(session_flags)

# Command-line aliases: short option name -> fully qualified config trait.
aliases = dict(
    ssh = 'IPControllerApp.ssh_server',
    enginessh = 'IPControllerApp.engine_ssh_server',
    location = 'IPControllerApp.location',

    url = 'HubFactory.url',
    ip = 'HubFactory.ip',
    transport = 'HubFactory.transport',
    port = 'HubFactory.regport',

    ping = 'HeartMonitor.period',

    scheme = 'TaskScheduler.scheme_name',
    hwm = 'TaskScheduler.hwm',
)
aliases.update(base_aliases)
aliases.update(session_aliases)
132 132
133 133
class IPControllerApp(BaseParallelApplication):
    """The IPython controller application.

    Starts the Hub and the scheduler devices (iopub/mux/control/task) and
    writes the JSON connection files that engines and clients use to
    locate and authenticate with the controller.
    """

    name = u'ipcontroller'
    description = _description
    examples = _examples
    config_file_name = Unicode(default_config_file_name)
    classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo

    # change default to True
    auto_create = Bool(True, config=True,
        help="""Whether to create profile dir if it doesn't exist.""")

    reuse_files = Bool(False, config=True,
        help='Whether to reuse existing json connection files.'
    )
    ssh_server = Unicode(u'', config=True,
        help="""ssh url for clients to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    engine_ssh_server = Unicode(u'', config=True,
        help="""ssh url for engines to use when connecting to the Controller
        processes. It should be of the form: [user@]server[:port]. The
        Controller's listening addresses must be accessible from the ssh server""",
    )
    location = Unicode(u'', config=True,
        help="""The external IP or domain name of the Controller, used for disambiguating
        engine and client connections.""",
    )
    import_statements = List([], config=True,
        help="import statements to be run at startup. Necessary in some environments"
    )

    use_threads = Bool(False, config=True,
        help='Use threads instead of processes for the schedulers',
    )

    engine_json_file = Unicode('ipcontroller-engine.json', config=True,
        help="JSON filename where engine connection info will be stored.")
    client_json_file = Unicode('ipcontroller-client.json', config=True,
        help="JSON filename where client connection info will be stored.")

    def _cluster_id_changed(self, name, old, new):
        # keep the json connection-file names in sync with the cluster id,
        # which the base class folds into self.name
        super(IPControllerApp, self)._cluster_id_changed(name, old, new)
        self.engine_json_file = "%s-engine.json" % self.name
        self.client_json_file = "%s-client.json" % self.name


    # internal
    # children: scheduler devices/processes created in init_schedulers,
    # started in start()
    children = List()
    mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')

    def _use_threads_changed(self, name, old, new):
        # pick the thread- or process-based MonitoredQueue device
        self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')

    aliases = Dict(aliases)
    flags = Dict(flags)


    def save_connection_dict(self, fname, cdict):
        """save a connection dict to json file."""
        url = cdict['url']
        location = cdict['location']
        if not location:
            # no explicit location configured: try to guess an externally
            # visible IP for this machine from its hostname
            try:
                proto,ip,port = split_url(url)
            except AssertionError:
                pass
            else:
                try:
                    location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
                except (socket.gaierror, IndexError):
                    self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
                    " You may need to specify '--location=<external_ip_address>' to help"
                    " IPython decide when to connect via loopback.")
                    location = '127.0.0.1'
            cdict['location'] = location
        fname = os.path.join(self.profile_dir.security_dir, fname)
        self.log.info("writing connection info to %s", fname)
        with open(fname, 'w') as f:
            f.write(json.dumps(cdict, indent=2))
        # the file contains the exec_key, so keep it owner-read/write only
        os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)

    def load_config_from_json(self):
        """load config from existing json connector files."""
        c = self.config
        self.log.debug("loading config from JSON")
        # load from engine config
        fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
        self.log.info("loading connection info from %s", fname)
        with open(fname) as f:
            cfg = json.loads(f.read())
        key = cfg['exec_key']
        # json gives unicode, Session.key wants bytes
        c.Session.key = key.encode('ascii')
        xport,addr = cfg['url'].split('://')
        c.HubFactory.engine_transport = xport
        ip,ports = addr.split(':')
        c.HubFactory.engine_ip = ip
        c.HubFactory.regport = int(ports)
        self.location = cfg['location']
        if not self.engine_ssh_server:
            self.engine_ssh_server = cfg['ssh']
        # load client config
        fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
        self.log.info("loading connection info from %s", fname)
        with open(fname) as f:
            cfg = json.loads(f.read())
        assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
        xport,addr = cfg['url'].split('://')
        c.HubFactory.client_transport = xport
        ip,ports = addr.split(':')
        c.HubFactory.client_ip = ip
        if not self.ssh_server:
            self.ssh_server = cfg['ssh']
        assert int(ports) == c.HubFactory.regport, "regport mismatch"

    def load_secondary_config(self):
        """secondary config, loading from JSON and setting defaults"""
        if self.reuse_files:
            try:
                self.load_config_from_json()
            except (AssertionError,IOError) as e:
                # fall back to generating fresh connection files
                self.log.error("Could not load config from JSON: %s" % e)
                self.reuse_files=False
        # switch Session.key default to secure
        default_secure(self.config)
        self.log.debug("Config changed")
        self.log.debug(repr(self.config))

    def init_hub(self):
        """Construct the HubFactory and write connection files if needed."""
        c = self.config

        self.do_import_statements()

        try:
            self.factory = HubFactory(config=c, log=self.log)
            # self.start_logging()
            self.factory.init_hub()
        except TraitError:
            raise
        except Exception:
            self.log.error("Couldn't construct the Controller", exc_info=True)
            self.exit(1)

        if not self.reuse_files:
            # save to new json config files
            f = self.factory
            cdict = {'exec_key' : f.session.key.decode('ascii'),
                    'ssh' : self.ssh_server,
                    'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
                    'location' : self.location
                    }
            self.save_connection_dict(self.client_json_file, cdict)
            # copy, so mutating the engine dict can't surprise anyone holding
            # a reference to the client dict
            edict = cdict.copy()
            # engines connect via the engine interface, not the client one
            # (previously this reused client_transport/client_ip)
            edict['url'] = "%s://%s:%s" % (f.engine_transport, f.engine_ip, f.regport)
            edict['ssh'] = self.engine_ssh_server
            self.save_connection_dict(self.engine_json_file, edict)

    def init_schedulers(self):
        """Create the MonitoredQueue devices and the task scheduler."""
        children = self.children
        mq = import_item(str(self.mq_class))

        hub = self.factory
        # disambiguate url, in case of *
        monitor_url = disambiguate_url(hub.monitor_url)
        # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
        # IOPub relay (in a Process)
        q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
        q.bind_in(hub.client_info['iopub'])
        q.bind_out(hub.engine_info['iopub'])
        q.setsockopt_out(zmq.SUBSCRIBE, b'')
        q.connect_mon(monitor_url)
        q.daemon=True
        children.append(q)

        # Multiplexer Queue (in a Process)
        q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
        q.bind_in(hub.client_info['mux'])
        q.setsockopt_in(zmq.IDENTITY, b'mux')
        q.bind_out(hub.engine_info['mux'])
        q.connect_mon(monitor_url)
        q.daemon=True
        children.append(q)

        # Control Queue (in a Process)
        q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
        q.bind_in(hub.client_info['control'])
        q.setsockopt_in(zmq.IDENTITY, b'control')
        q.bind_out(hub.engine_info['control'])
        q.connect_mon(monitor_url)
        q.daemon=True
        children.append(q)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            scheme = TaskScheduler.scheme_name.get_default_value()
        # Task Queue (in a Process)
        if scheme == 'pure':
            self.log.warn("task::using pure XREQ Task scheduler")
            q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
            # q.setsockopt_out(zmq.HWM, hub.hwm)
            q.bind_in(hub.client_info['task'][1])
            q.setsockopt_in(zmq.IDENTITY, b'task')
            q.bind_out(hub.engine_info['task'])
            q.connect_mon(monitor_url)
            q.daemon=True
            children.append(q)
        elif scheme == 'none':
            self.log.warn("task::using no Task scheduler")

        else:
            # Python scheduler, either in a subprocess or in this thread
            self.log.info("task::using Python %s Task scheduler"%scheme)
            sargs = (hub.client_info['task'][1], hub.engine_info['task'],
                     monitor_url, disambiguate_url(hub.client_info['notification']))
            kwargs = dict(logname='scheduler', loglevel=self.log_level,
                          log_url = self.log_url, config=dict(self.config))
            if 'Process' in self.mq_class:
                # run the Python scheduler in a Process
                q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
                q.daemon=True
                children.append(q)
            else:
                # single-threaded Controller
                kwargs['in_thread'] = True
                launch_scheduler(*sargs, **kwargs)


    def save_urls(self):
        """save the registration urls to files."""
        sec_dir = self.profile_dir.security_dir
        cf = self.factory

        with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
            f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))

        with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
            f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))


    def do_import_statements(self):
        """exec the configured import statements in this process.

        Best-effort: a failing statement is logged and skipped.
        """
        for s in self.import_statements:
            try:
                # logging.Logger has no 'msg' method; use info/error
                # (the old self.log.msg(...) raised AttributeError)
                self.log.info("Executing statement: '%s'" % s)
                # function form of exec works on both Python 2 and 3
                exec(s, globals(), locals())
            except Exception:
                self.log.error("Error running statement: %s" % s, exc_info=True)

    def forward_logging(self):
        """Forward our log records over zmq PUB if a log_url is configured."""
        if self.log_url:
            self.log.info("Forwarding logging to %s"%self.log_url)
            context = zmq.Context.instance()
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            handler = PUBHandler(lsock)
            self.log.removeHandler(self._log_handler)
            handler.root_topic = 'controller'
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)
            self._log_handler = handler

    @catch_config_error
    def initialize(self, argv=None):
        """Parse argv/config and build the hub and schedulers (no I/O loop yet)."""
        super(IPControllerApp, self).initialize(argv)
        self.forward_logging()
        self.load_secondary_config()
        self.init_hub()
        self.init_schedulers()

    def start(self):
        """Start the heart monitor, the scheduler children, and the event loop."""
        # Start the subprocesses:
        self.factory.start()
        child_procs = []
        for child in self.children:
            child.start()
            if isinstance(child, ProcessMonitoredQueue):
                child_procs.append(child.launcher)
            elif isinstance(child, Process):
                child_procs.append(child)
        if child_procs:
            # ensure children are terminated when we are
            signal_children(child_procs)

        self.write_pid_file(overwrite=True)

        try:
            self.factory.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Interrupted, Exiting...\n")
425 427
426 428
427 429
def launch_new_instance():
    """Create and run the IPython controller"""
    if sys.platform == 'win32':
        # Guard against being re-entered from a multiprocessing child:
        # Windows has no real fork, so a vanilla-setuptools install can
        # otherwise spawn infinite Controllers. The main process is named
        # 'MainProcess'; children get names like 'Process-1'.
        import multiprocessing
        if multiprocessing.current_process().name != 'MainProcess':
            # we are a subprocess, don't start another Controller!
            return
    controller = IPControllerApp.instance()
    controller.initialize()
    controller.start()
447 449
448 450
# script entry point: run the controller when executed directly
if __name__ == '__main__':
    launch_new_instance()
@@ -1,1290 +1,1290 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import sys
22 22 import time
23 23 from datetime import datetime
24 24
25 25 import zmq
26 26 from zmq.eventloop import ioloop
27 27 from zmq.eventloop.zmqstream import ZMQStream
28 28
29 29 # internal:
30 30 from IPython.utils.importstring import import_item
31 31 from IPython.utils.traitlets import (
32 32 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 33 )
34 34
35 35 from IPython.parallel import error, util
36 36 from IPython.parallel.factory import RegistrationFactory
37 37
38 38 from IPython.zmq.session import SessionFactory
39 39
40 40 from .heartmonitor import HeartMonitor
41 41
42 42 #-----------------------------------------------------------------------------
43 43 # Code
44 44 #-----------------------------------------------------------------------------
45 45
46 46 def _passer(*args, **kwargs):
47 47 return
48 48
49 49 def _printer(*args, **kwargs):
50 50 print (args)
51 51 print (kwargs)
52 52
def empty_record():
    """Return an empty dict with all record keys."""
    # every field defaults to None except the captured output streams,
    # which start as empty strings
    record = dict.fromkeys([
        'msg_id',
        'header',
        'content',
        'buffers',
        'submitted',
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ])
    record['stdout'] = ''
    record['stderr'] = ''
    return record
76 76
def init_record(msg):
    """Initialize a TaskRecord based on a request."""
    header = msg['header']
    # fields not known at submission time start as None
    record = dict.fromkeys((
        'client_uuid',
        'engine_uuid',
        'started',
        'completed',
        'resubmitted',
        'result_header',
        'result_content',
        'result_buffers',
        'queue',
        'pyin',
        'pyout',
        'pyerr',
    ))
    # fields copied straight from the request message
    record.update(
        msg_id=header['msg_id'],
        header=header,
        content=msg['content'],
        buffers=msg['buffers'],
        submitted=header['date'],
        stdout='',
        stderr='',
    )
    return record
101 101
102 102
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (str): uuid (unused?)
    queue (str): identity of queue's XREQ socket
    registration (str): identity of registration XREQ socket
    heartbeat (str): identity of heartbeat XREQ socket
    """
    # integer engine ID assigned by the Hub
    id=Integer(0)
    # zmq identity (bytes) of the engine's queue socket
    queue=CBytes()
    # zmq identity (bytes) of the engine's control socket
    control=CBytes()
    # zmq identity (bytes) used on the registration socket
    registration=CBytes()
    # zmq identity (bytes) of the heartbeat socket
    heartbeat=CBytes()
    # presumably the set of msg_ids pending on this engine -- TODO confirm
    # against Hub code (contents not visible in this chunk)
    pending=Set()
118 118
class HubFactory(RegistrationFactory):
    """The Configurable for setting up a Hub.

    Holds the port/transport configuration for every controller socket and
    builds the Hub instance (init_hub) with all its zmq streams.
    """

    # port-pairs for monitoredqueues:
    hb = Tuple(Integer,Integer,config=True,
        help="""XREQ/SUB Port pair for Engine heartbeats""")
    def _hb_default(self):
        # pick two free ports at random when not configured
        return tuple(util.select_random_ports(2))

    mux = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for MUX queue""")

    def _mux_default(self):
        return tuple(util.select_random_ports(2))

    task = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Task queue""")
    def _task_default(self):
        return tuple(util.select_random_ports(2))

    control = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for Control queue""")

    def _control_default(self):
        return tuple(util.select_random_ports(2))

    iopub = Tuple(Integer,Integer,config=True,
        help="""Engine/Client Port pair for IOPub relay""")

    def _iopub_default(self):
        return tuple(util.select_random_ports(2))

    # single ports:
    mon_port = Integer(config=True,
        help="""Monitor (SUB) port for queue traffic""")

    def _mon_port_default(self):
        return util.select_random_ports(1)[0]

    notifier_port = Integer(config=True,
        help="""PUB port for sending engine status notifications""")

    def _notifier_port_default(self):
        return util.select_random_ports(1)[0]

    engine_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for engine connections. [default: loopback]")
    engine_transport = Unicode('tcp', config=True,
        help="0MQ transport for engine connections. [default: tcp]")

    client_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for client connections. [default: loopback]")
    client_transport = Unicode('tcp', config=True,
        help="0MQ transport for client connections. [default : tcp]")

    monitor_ip = Unicode('127.0.0.1', config=True,
        help="IP on which to listen for monitor messages. [default: loopback]")
    monitor_transport = Unicode('tcp', config=True,
        help="0MQ transport for monitor messages. [default : tcp]")

    # derived from monitor_transport/ip/port; kept in sync by the
    # _*_changed handlers below
    monitor_url = Unicode('')

    db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
        config=True, help="""The class to use for the DB backend""")

    # not configurable
    db = Instance('IPython.parallel.controller.dictdb.BaseDB')
    heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')

    def _ip_changed(self, name, old, new):
        # setting the single `ip` trait fans out to all three interfaces
        self.engine_ip = new
        self.client_ip = new
        self.monitor_ip = new
        self._update_monitor_url()

    def _update_monitor_url(self):
        # recompute the derived monitor_url from its parts
        self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)

    def _transport_changed(self, name, old, new):
        # setting the single `transport` trait fans out to all three interfaces
        self.engine_transport = new
        self.client_transport = new
        self.monitor_transport = new
        self._update_monitor_url()

    def __init__(self, **kwargs):
        super(HubFactory, self).__init__(**kwargs)
        self._update_monitor_url()


    def construct(self):
        self.init_hub()

    def start(self):
        self.heartmonitor.start()
        self.log.info("Heartmonitor started")

    def init_hub(self):
        """construct

        Bind all controller sockets, build the connection-info dicts, and
        create the Hub object. Does not start the event loop.
        """
        # interface templates; '%i' filled in with a port per socket below
        client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
        engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"

        ctx = self.context
        loop = self.loop

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        q.bind(client_iface % self.regport)
        self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
        if self.client_ip != self.engine_ip:
            # also listen on the engine interface when it differs
            q.bind(engine_iface % self.regport)
            self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))

        ### Engine connections ###

        # heartbeat: PUB pings out, ROUTER collects pongs
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(engine_iface % self.hb[0])
        hrep = ctx.socket(zmq.ROUTER)
        hrep.bind(engine_iface % self.hb[1])
        self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###
        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(client_iface%self.notifier_port)

        ### build and launch the queues ###

        # monitor socket: SUB to everything the queue devices publish
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        # NOTE(review): split() with no argument splits on whitespace, so
        # [-1] is the full dotted name — possibly split('.') was intended
        # to log just the class name; cosmetic only. Verify upstream.
        self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
        # cdir = self.config.Global.cluster_dir
        self.db = import_item(str(self.db_class))(session=self.session.session,
                                            config=self.config, log=self.log)
        # brief pause to let the db backend settle -- TODO confirm why
        time.sleep(.25)
        try:
            scheme = self.config.TaskScheduler.scheme_name
        except AttributeError:
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()
        # build connection dicts
        # convention: index [1] of each port pair is the engine-facing port,
        # [0] the client-facing port
        self.engine_info = {
            'control' : engine_iface%self.control[1],
            'mux': engine_iface%self.mux[1],
            'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
            'task' : engine_iface%self.task[1],
            'iopub' : engine_iface%self.iopub[1],
            # 'monitor' : engine_iface%self.mon_port,
            }

        self.client_info = {
            'control' : client_iface%self.control[0],
            'mux': client_iface%self.mux[0],
            'task' : (scheme, client_iface%self.task[0]),
            'iopub' : client_iface%self.iopub[0],
            'notification': client_iface%self.notifier_port
            }
        self.log.debug("Hub engine addrs: %s"%self.engine_info)
        self.log.debug("Hub client addrs: %s"%self.client_info)

        # resubmit stream: DEALER back into the task scheduler, identified
        # as this session so replies route correctly
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_info['task'][-1])
        r.setsockopt(zmq.IDENTITY, self.session.bsession)
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
298 298
299 299
300 300 class Hub(SessionFactory):
301 301 """The IPython Controller Hub with 0MQ connections
302 302
303 303 Parameters
304 304 ==========
305 305 loop: zmq IOLoop instance
306 306 session: Session object
307 307 <removed> context: zmq context for creating new connections (?)
308 308 queue: ZMQStream for monitoring the command queue (SUB)
309 309 query: ZMQStream for engine registration and client queries requests (XREP)
310 310 heartbeat: HeartMonitor object checking the pulse of the engines
311 311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
312 312 db: connection to db for out of memory logging of commands
313 313 NotImplemented
314 314 engine_info: dict of zmq connection information for engines to connect
315 315 to the queues.
316 316 client_info: dict of zmq connection information for engines to connect
317 317 to the queues.
318 318 """
319 319 # internal data structures:
320 320 ids=Set() # engine IDs
321 321 keytable=Dict()
322 322 by_ident=Dict()
323 323 engines=Dict()
324 324 clients=Dict()
325 325 hearts=Dict()
326 326 pending=Set()
327 327 queues=Dict() # pending msg_ids keyed by engine_id
328 328 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
329 329 completed=Dict() # completed msg_ids keyed by engine_id
330 330 all_completed=Set() # completed msg_ids keyed by engine_id
331 331 dead_engines=Set() # completed msg_ids keyed by engine_id
332 332 unassigned=Set() # set of task msg_ds not yet assigned a destination
333 333 incoming_registrations=Dict()
334 334 registration_timeout=Integer()
335 335 _idcounter=Integer(0)
336 336
337 337 # objects from constructor:
338 338 query=Instance(ZMQStream)
339 339 monitor=Instance(ZMQStream)
340 340 notifier=Instance(ZMQStream)
341 341 resubmit=Instance(ZMQStream)
342 342 heartmonitor=Instance(HeartMonitor)
343 343 db=Instance(object)
344 344 client_info=Dict()
345 345 engine_info=Dict()
346 346
347 347
def __init__(self, **kwargs):
    """
    # universal:
    loop: IOLoop for creating future connections
    session: streamsession for sending serialized data
    # engine:
    queue: ZMQStream for monitoring queue messages
    query: ZMQStream for engine+client registration and client requests
    heartbeat: HeartMonitor object for tracking engines
    # extra:
    db: ZMQStream for db connection (NotImplemented)
    engine_info: zmq address/protocol dict for engine connections
    client_info: zmq address/protocol dict for client connections
    """

    super(Hub, self).__init__(**kwargs)
    # give engines at least 5s (or two heartbeat periods) to finish registering
    self.registration_timeout = max(5000, 2*self.heartmonitor.period)

    # validate connection dicts:
    for k,v in self.client_info.iteritems():  # NOTE(review): py2-only iteritems
        if k == 'task':
            # the 'task' entry is (scheme, url); only the url part is a zmq url
            util.validate_url_container(v[1])
        else:
            util.validate_url_container(v)
    # util.validate_url_container(self.client_info)
    util.validate_url_container(self.engine_info)

    # register our callbacks
    self.query.on_recv(self.dispatch_query)
    self.monitor.on_recv(self.dispatch_monitor_traffic)

    self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
    self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

    # topic frame (bytes) -> handler for monitor (queue/iopub) traffic
    self.monitor_handlers = {b'in' : self.save_queue_request,
                            b'out': self.save_queue_result,
                            b'intask': self.save_task_request,
                            b'outtask': self.save_task_result,
                            b'tracktask': self.save_task_destination,
                            b'incontrol': _passer,
                            b'outcontrol': _passer,
                            b'iopub': self.save_iopub_message,
                            }

    # msg_type -> handler for client requests on the query stream
    self.query_handlers = {'queue_request': self.queue_status,
                            'result_request': self.get_results,
                            'history_request': self.get_history,
                            'db_request': self.db_query,
                            'purge_request': self.purge_results,
                            'load_request': self.check_load,
                            'resubmit_request': self.resubmit_task,
                            'shutdown_request': self.shutdown_request,
                            'registration_request' : self.register_engine,
                            'unregistration_request' : self.unregister_engine,
                            'connection_request': self.connection_request,
                            }

    # ignore resubmit replies
    self.resubmit.on_recv(lambda msg: None, copy=False)

    self.log.info("hub::created hub")
409 409
@property
def _next_id(self):
    """Generate a new engine ID.

    IDs are never reused: this simply returns the current counter value
    and advances the counter by one.
    """
    assigned, self._idcounter = self._idcounter, self._idcounter + 1
    return assigned
424 424
425 425 #-----------------------------------------------------------------------------
426 426 # message validation
427 427 #-----------------------------------------------------------------------------
428 428
def _validate_targets(self, targets):
    """Turn any valid targets argument into a list of integer engine ids.

    Accepts None (meaning all engines), a single id/identity, or a list
    of ids/identities.  Raw zmq identities are mapped to integer ids via
    self.by_ident.

    Raises IndexError if any target is unknown, or if no engines are
    registered.
    """
    # py2/py3 compatibility: `unicode` does not exist on Python 3
    try:
        string_types = (str, unicode)
    except NameError:
        string_types = (str,)

    if targets is None:
        # default to all registered engines
        targets = self.ids

    if isinstance(targets, (int,) + string_types):
        # only one target specified
        targets = [targets]
    _targets = []
    for t in targets:
        # map raw identities to ids
        if isinstance(t, string_types):
            t = self.by_ident.get(t, t)
        _targets.append(t)
    targets = _targets
    bad_targets = [ t for t in targets if t not in self.ids ]
    if bad_targets:
        raise IndexError("No Such Engine: %r"%bad_targets)
    if not targets:
        raise IndexError("No Engines Registered")
    return targets
451 451
452 452 #-----------------------------------------------------------------------------
453 453 # dispatch methods (1 per stream)
454 454 #-----------------------------------------------------------------------------
455 455
456 456
def dispatch_monitor_traffic(self, msg):
    """Route monitor-socket traffic to the handler registered for the
    message's topic frame (ME/Task queue messages and IOPub traffic)."""
    self.log.debug("monitor traffic: %r"%msg[:2])
    topic = msg[0]
    try:
        idents, msg = self.session.feed_identities(msg[1:])
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Bad Monitor Message: %r"%msg)
        return
    handler = self.monitor_handlers.get(topic, None)
    if handler is None:
        self.log.error("Invalid monitor topic: %r"%topic)
        return
    handler(idents, msg)
474 474
475 475
def dispatch_query(self, msg):
    """Route registration requests and queries from clients."""
    def reply_error(log_msg):
        # wrap the currently-active exception and report it to the client
        content = error.wrap_exception()
        self.log.error(log_msg, exc_info=True)
        self.session.send(self.query, "hub_error", ident=client_id,
                content=content)

    try:
        idents, msg = self.session.feed_identities(msg)
    except ValueError:
        idents = []
    if not idents:
        self.log.error("Bad Query Message: %r"%msg)
        return
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        reply_error("Bad Query Message: %r"%msg)
        return
    # switch on message type:
    msg_type = msg['header']['msg_type']
    self.log.info("client::client %r requested %r"%(client_id, msg_type))
    handler = self.query_handlers.get(msg_type, None)
    try:
        # the assert supplies the active exception that wrap_exception reports
        assert handler is not None, "Bad Message Type: %r"%msg_type
    except:
        reply_error("Bad Message Type: %r"%msg_type)
        return

    handler(idents, msg)
510 510
def dispatch_db(self, msg):
    """Placeholder for routing db-stream traffic; not yet implemented."""
    raise NotImplementedError
514 514
515 515 #---------------------------------------------------------------------------
516 516 # handler methods (1 per event)
517 517 #---------------------------------------------------------------------------
518 518
519 519 #----------------------- Heartbeat --------------------------------------
520 520
def handle_new_heart(self, heart):
    """Heartbeat callback: a new heart has started to beat.

    If the heart belongs to a pending registration, complete it;
    otherwise log and ignore the unknown heart.
    """
    self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
    if heart in self.incoming_registrations:
        self.finish_registration(heart)
    else:
        self.log.info("heartbeat::ignoring new heart: %r"%heart)
530 530
531 531
def handle_heart_failure(self, heart):
    """Heartbeat callback: a registered heart stopped responding.

    Looks up the engine that owned the heart and unregisters it.
    Unknown hearts are logged and ignored.
    """
    self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
    eid = self.hearts.get(heart, None)
    if eid is None:
        # unknown heart: nothing to unregister
        self.log.info("heartbeat::ignoring heart failure %r"%heart)
    else:
        # only resolve the engine's queue for known engines; previously this
        # lookup ran before the None-check and raised KeyError for unknown
        # hearts instead of taking the log-and-ignore branch
        queue = self.engines[eid].queue
        self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
543 543
544 544 #----------------------- MUX Queue Traffic ------------------------------
545 545
def save_queue_request(self, idents, msg):
    """Save a request submitted via the MUX queue into the task db.

    `idents` is expected to be [queue_id, client_id, ...].  The record is
    keyed by msg_id, merged with any record that already exists (e.g. if
    its iopub traffic arrived first), and the msg_id is added to the
    target engine's pending queue.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r"%idents)
        return
    queue_id, client_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::target %r not registered"%queue_id)
        self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
        return
    record = init_record(msg)
    msg_id = record['msg_id']
    # Unicode in records
    record['engine_uuid'] = queue_id.decode('ascii')
    record['client_uuid'] = client_id.decode('ascii')
    record['queue'] = 'mux'

    try:
        # it's possible iopub arrived first:
        existing = self.db.get_record(msg_id)
        # merge: keep existing values, warn on conflicting non-empty values
        for key,evalue in existing.iteritems():  # NOTE(review): py2-only iteritems
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
            elif evalue and not rvalue:
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
    except KeyError:
        # no existing record: this is a fresh request
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r"%msg_id, exc_info=True)


    self.pending.add(msg_id)
    self.queues[eid].append(msg_id)
591 591
def save_queue_result(self, idents, msg):
    """Save the reply to a MUX-queue request into the task db.

    `idents` is [client_id, queue_id, ...] — the reverse order of
    save_queue_request, because replies travel engine -> client.
    """
    if len(idents) < 2:
        self.log.error("invalid identity prefix: %r"%idents)
        return

    client_id, queue_id = idents[:2]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("queue::engine %r sent invalid message to %r: %r"%(
                queue_id,client_id, msg), exc_info=True)
        return

    eid = self.by_ident.get(queue_id, None)
    if eid is None:
        self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
        return

    parent = msg['parent_header']
    if not parent:
        # a reply without a parent header cannot be matched to a request
        return
    msg_id = parent['msg_id']
    if msg_id in self.pending:
        # normal case: move the message from pending to completed
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        self.queues[eid].remove(msg_id)
        self.completed[eid].append(msg_id)
    elif msg_id not in self.all_completed:
        # it could be a result from a dead engine that died before delivering the
        # result
        self.log.warn("queue:: unknown msg finished %r"%msg_id)
        return
    # update record anyway, because the unregistration could have been premature
    rheader = msg['header']
    completed = rheader['date']
    started = rheader.get('started', None)
    result = {
        'result_header' : rheader,
        'result_content': msg['content'],
        'started' : started,
        'completed' : completed
    }

    result['result_buffers'] = msg['buffers']
    try:
        self.db.update_record(msg_id, result)
    except Exception:
        self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
640 640
641 641
642 642 #--------------------- Task Queue Traffic ------------------------------
643 643
def save_task_request(self, idents, msg):
    """Save the submission of a task.

    Records the task as pending/unassigned in the db, merging with any
    record that already exists (e.g. if its iopub traffic arrived first,
    or it is a resubmission).
    """
    client_id = idents[0]

    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::client %r sent invalid task message: %r"%(
                client_id, msg), exc_info=True)
        return
    record = init_record(msg)

    # store the client identity as unicode in the record
    record['client_uuid'] = client_id.decode('ascii')
    record['queue'] = 'task'
    header = msg['header']
    msg_id = header['msg_id']
    self.pending.add(msg_id)
    self.unassigned.add(msg_id)
    try:
        # it's possible iopub arrived first:
        existing = self.db.get_record(msg_id)
        if existing['resubmitted']:
            for key in ('submitted', 'client_uuid', 'buffers'):
                # don't clobber these keys on resubmit
                # submitted and client_uuid should be different
                # and buffers might be big, and shouldn't have changed
                record.pop(key)
            # still check content,header which should not change
            # but are not expensive to compare as buffers

        for key,evalue in existing.iteritems():  # NOTE(review): py2-only iteritems
            if key.endswith('buffers'):
                # don't compare buffers
                continue
            rvalue = record.get(key, None)
            if evalue and rvalue and evalue != rvalue:
                self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
            elif evalue and not rvalue:
                record[key] = evalue
        try:
            self.db.update_record(msg_id, record)
        except Exception:
            self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
    except KeyError:
        # no existing record: this is a fresh task
        try:
            self.db.add_record(msg_id, record)
        except Exception:
            self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
    except Exception:
        self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
694 694
def save_task_result(self, idents, msg):
    """Save the result of a completed task.

    Moves the task from pending to completed, credits the engine that
    ran it (identified by the 'engine' field of the reply header), and
    stores the result in the db.
    """
    client_id = idents[0]
    try:
        msg = self.session.unserialize(msg)
    except Exception:
        self.log.error("task::invalid task result message send to %r: %r"%(
                client_id, msg), exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        # print msg
        self.log.warn("Task %r had no parent!"%msg)
        return
    msg_id = parent['msg_id']
    if msg_id in self.unassigned:
        self.unassigned.remove(msg_id)

    header = msg['header']
    engine_uuid = header.get('engine', None)
    eid = self.by_ident.get(engine_uuid, None)

    if msg_id in self.pending:
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        if eid is not None:
            # engine may already have unregistered; only credit known engines
            self.completed[eid].append(msg_id)
            if msg_id in self.tasks[eid]:
                self.tasks[eid].remove(msg_id)
        completed = header['date']
        started = header.get('started', None)
        result = {
            'result_header' : header,
            'result_content': msg['content'],
            'started' : started,
            'completed' : completed,
            'engine_uuid': engine_uuid
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)

    else:
        # result for a task we never saw submitted (or already purged)
        self.log.debug("task::unknown task %r finished"%msg_id)
743 743
def save_task_destination(self, idents, msg):
    """Record which engine a task was assigned to.

    Fired by 'tracktask' monitor traffic from the task scheduler.
    """
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("task::invalid task tracking message", exc_info=True)
        return
    content = msg['content']
    msg_id = content['msg_id']
    engine_uuid = content['engine_id']
    eid = self.by_ident[util.asbytes(engine_uuid)]

    self.log.info("task::task %r arrived on %r"%(msg_id, eid))
    # no longer unassigned once it has a destination
    self.unassigned.discard(msg_id)

    self.tasks[eid].append(msg_id)
    try:
        self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
    except Exception:
        self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
768 768
769 769
def mia_task_request(self, idents, msg):
    """MIA (missing-in-action) task queries are not supported.

    Raises NotImplementedError unconditionally.
    """
    # dead code that followed the raise has been removed; the old reply
    # sketch is kept here for reference:
    # client_id = idents[0]
    # content = dict(mia=self.mia,status='ok')
    # self.session.send('mia_reply', content=content, idents=client_id)
    raise NotImplementedError
775 775
776 776
777 777 #--------------------- IOPub Traffic ------------------------------
778 778
def save_iopub_message(self, topics, msg):
    """Save an iopub message (stream/pyin/pyerr/etc.) into the db."""
    # print (topics)
    try:
        msg = self.session.unserialize(msg, content=True)
    except Exception:
        self.log.error("iopub::invalid IOPub message", exc_info=True)
        return

    parent = msg['parent_header']
    if not parent:
        self.log.error("iopub::invalid IOPub message: %r"%msg)
        return
    msg_id = parent['msg_id']
    msg_type = msg['header']['msg_type']
    content = msg['content']

    # ensure msg_id is in db
    try:
        rec = self.db.get_record(msg_id)
    except KeyError:
        # iopub may arrive before the request itself: create a stub record
        rec = empty_record()
        rec['msg_id'] = msg_id
        self.db.add_record(msg_id, rec)
    # stream
    d = {}
    if msg_type == 'stream':
        # append to any stream output already recorded for this name
        name = content['name']
        s = rec[name] or ''
        d[name] = s + content['data']

    elif msg_type == 'pyerr':
        d['pyerr'] = content
    elif msg_type == 'pyin':
        d['pyin'] = content['code']
    else:
        # any other iopub type: store its 'data' field under the msg_type key
        d[msg_type] = content.get('data', '')

    try:
        self.db.update_record(msg_id, d)
    except Exception:
        self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
821 821
822 822
823 823
824 824 #-------------------------------------------------------------------------
825 825 # Registration requests
826 826 #-------------------------------------------------------------------------
827 827
def connection_request(self, client_id, msg):
    """Reply with connection addresses for clients.

    The reply contains the client connection info plus a json-safe
    mapping of live engine ids (as str) to their ascii uuids.
    """
    self.log.info("client::client %r connected"%client_id)
    content = dict(status='ok')
    content.update(self.client_info)
    jsonable = {}
    # items() instead of iteritems() for Python 3 compatibility
    # (behavior-identical on Python 2)
    for k,v in self.keytable.items():
        if v not in self.dead_engines:
            jsonable[str(k)] = v.decode('ascii')
    content['engines'] = jsonable
    self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
839 839
def register_engine(self, reg, msg):
    """Register a new engine.

    Validates that the requested queue identity and heartbeat id are not
    already in use (by registered engines or registrations in flight),
    replies with the engine's id and connection info, and defers
    finalization until the engine's heart starts beating.
    """
    content = msg['content']
    try:
        queue = util.asbytes(content['queue'])
    except KeyError:
        self.log.error("registration::queue not specified", exc_info=True)
        return
    heart = content.get('heartbeat', None)
    if heart:
        heart = util.asbytes(heart)
    """register a new engine, and create the socket(s) necessary"""
    eid = self._next_id
    # print (eid, queue, reg, heart)

    self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))

    content = dict(id=eid,status='ok')
    content.update(self.engine_info)
    # check if requesting available IDs:
    if queue in self.by_ident:
        try:
            raise KeyError("queue_id %r in use"%queue)
        except:
            content = error.wrap_exception()
            self.log.error("queue_id %r in use"%queue, exc_info=True)
    elif heart in self.hearts: # need to check unique hearts?
        try:
            raise KeyError("heart_id %r in use"%heart)
        except:
            self.log.error("heart_id %r in use"%heart, exc_info=True)
            content = error.wrap_exception()
    else:
        # also reject conflicts with registrations still in flight
        for h, pack in self.incoming_registrations.iteritems():  # NOTE(review): py2-only iteritems
            if heart == h:
                try:
                    raise KeyError("heart_id %r in use"%heart)
                except:
                    self.log.error("heart_id %r in use"%heart, exc_info=True)
                    content = error.wrap_exception()
                break
            elif queue == pack[1]:
                try:
                    raise KeyError("queue_id %r in use"%queue)
                except:
                    self.log.error("queue_id %r in use"%queue, exc_info=True)
                    content = error.wrap_exception()
                break

    msg = self.session.send(self.query, "registration_reply",
            content=content,
            ident=reg)

    if content['status'] == 'ok':
        if heart in self.heartmonitor.hearts:
            # already beating
            self.incoming_registrations[heart] = (eid,queue,reg[0],None)
            self.finish_registration(heart)
        else:
            # wait for the heart; purge the registration if it never beats
            purge = lambda : self._purge_stalled_registration(heart)
            dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
            dc.start()
            self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
    else:
        self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
    return eid
906 906
def unregister_engine(self, ident, msg):
    """Unregister an engine that explicitly requested to leave.

    Marks the engine's uuid dead, schedules handling of its outstanding
    messages after the registration timeout, and notifies clients.
    """
    try:
        eid = msg['content']['id']
    except:
        self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
        return
    self.log.info("registration::unregister_engine(%r)"%eid)
    # print (eid)
    uuid = self.keytable[eid]
    content=dict(id=eid, queue=uuid.decode('ascii'))
    self.dead_engines.add(uuid)
    # self.ids.remove(eid)
    # uuid = self.keytable.pop(eid)
    #
    # ec = self.engines.pop(eid)
    # self.hearts.pop(ec.heartbeat)
    # self.by_ident.pop(ec.queue)
    # self.completed.pop(eid)
    # defer cleanup of the engine's outstanding messages, in case results
    # for them are still in flight
    handleit = lambda : self._handle_stranded_msgs(eid, uuid)
    dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
    dc.start()
    ############## TODO: HANDLE IT ################

    if self.notifier:
        self.session.send(self.notifier, "unregistration_notification", content=content)
933 933
def _handle_stranded_msgs(self, eid, uuid):
    """Handle messages known to be on an engine when the engine unregisters.

    It is possible that this will fire prematurely - that is, an engine will
    go down after completing a result, and the client will be notified
    that the result failed and later receive the actual result.
    """

    outstanding = self.queues[eid]

    for msg_id in outstanding:
        # mark each stranded message as completed-with-failure
        self.pending.remove(msg_id)
        self.all_completed.add(msg_id)
        try:
            raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
        except:
            # wrap the EngineError so it is reported as the task's result
            content = error.wrap_exception()
        # build a fake header:
        header = {}
        header['engine'] = uuid
        header['date'] = datetime.now()
        rec = dict(result_content=content, result_header=header, result_buffers=[])
        rec['completed'] = header['date']
        rec['engine_uuid'] = uuid
        try:
            self.db.update_record(msg_id, rec)
        except Exception:
            self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
962 962
963 963
def finish_registration(self, heart):
    """Second half of engine registration, called after our HeartMonitor
    has received a beat from the Engine's Heart.

    Installs the engine in all of the hub's lookup tables and notifies
    clients of the new engine.
    """
    try:
        (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
    except KeyError:
        self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
        return
    self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
    # cancel the stalled-registration purge timeout, if one was scheduled
    if purge is not None:
        purge.stop()
    control = queue
    self.ids.add(eid)
    self.keytable[eid] = queue
    self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
            control=control, heartbeat=heart)
    self.by_ident[queue] = eid
    self.queues[eid] = list()
    self.tasks[eid] = list()
    self.completed[eid] = list()
    self.hearts[heart] = eid
    content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
    if self.notifier:
        self.session.send(self.notifier, "registration_notification", content=content)
    self.log.info("engine::Engine Connected: %i"%eid)
989 989
def _purge_stalled_registration(self, heart):
    """Drop a pending registration whose heart never started beating."""
    try:
        eid = self.incoming_registrations.pop(heart)[0]
    except KeyError:
        return
    self.log.info("registration::purging stalled registration: %i"%eid)
996 996
997 997 #-------------------------------------------------------------------------
998 998 # Client Requests
999 999 #-------------------------------------------------------------------------
1000 1000
def shutdown_request(self, client_id, msg):
    """Handle a client shutdown request: ack the requester, notify all
    other clients, then schedule the actual shutdown shortly after."""
    self.session.send(self.query, 'shutdown_reply',
                      content={'status': 'ok'}, ident=client_id)
    # also notify other clients of shutdown
    self.session.send(self.notifier, 'shutdown_notice',
                      content={'status': 'ok'})
    # give the replies a moment to go out before actually exiting
    delayed = ioloop.DelayedCallback(self._shutdown, 1000, self.loop)
    delayed.start()
1008 1008
def _shutdown(self):
    """Actually exit the hub process (scheduled by shutdown_request)."""
    self.log.info("hub::hub shutting down.")
    # brief pause to let final messages flush before exiting
    time.sleep(0.1)
    sys.exit(0)
1013 1013
1014 1014
def check_load(self, client_id, msg):
    """Reply to a client load query with the number of outstanding
    jobs (MUX queue + tasks) per requested engine."""
    content = msg['content']
    try:
        targets = content['targets']
        targets = self._validate_targets(targets)
    except:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return

    content = dict(status='ok')
    for t in targets:
        # key must be text: bytes(t) on Python 3 would produce a
        # zero-filled bytes object instead of the decimal string
        # (str(t) == bytes(t) on Python 2, so this is behavior-identical there)
        content[str(t)] = len(self.queues[t])+len(self.tasks[t])
    self.session.send(self.query, "load_reply", content=content, ident=client_id)
1031 1031
1032 1032
def queue_status(self, client_id, msg):
    """Return the Queue status of one or more targets.

    if verbose: return the msg_ids
    else: return len of each type.

    keys: queue (pending MUX jobs)
          tasks (pending Task jobs)
          completed (finished jobs from both queues)
    """
    content = msg['content']
    targets = content['targets']
    try:
        targets = self._validate_targets(targets)
    except:
        content = error.wrap_exception()
        self.session.send(self.query, "hub_error",
                content=content, ident=client_id)
        return
    verbose = content.get('verbose', False)
    content = dict(status='ok')
    for t in targets:
        queue, completed, tasks = self.queues[t], self.completed[t], self.tasks[t]
        if not verbose:
            queue, completed, tasks = len(queue), len(completed), len(tasks)
        content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
    content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
    self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1063 1063
def purge_results(self, client_id, msg):
    """Purge results from memory. This method is more valuable before we move
    to a DB based message storage mechanism.

    Drops db records by msg_id list, by 'all', and/or by engine id;
    refuses to drop messages that are still pending.
    """
    content = msg['content']
    self.log.info("Dropping records with %s", content)
    msg_ids = content.get('msg_ids', [])
    reply = dict(status='ok')
    if msg_ids == 'all':
        try:
            self.db.drop_matching_records(dict(completed={'$ne':None}))
        except Exception:
            reply = error.wrap_exception()
    else:
        # list comprehension, not `filter`: a Python 3 filter object is
        # always truthy and not indexable, which broke both checks below
        pending = [ m for m in msg_ids if m in self.pending ]
        if pending:
            try:
                raise IndexError("msg pending: %r"%pending[0])
            except:
                reply = error.wrap_exception()
        else:
            try:
                self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
            except Exception:
                reply = error.wrap_exception()

    if reply['status'] == 'ok':
        eids = content.get('engine_ids', [])
        for eid in eids:
            if eid not in self.engines:
                try:
                    raise IndexError("No such engine: %i"%eid)
                except:
                    reply = error.wrap_exception()
                break
            uid = self.engines[eid].queue
            try:
                self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
                break

    self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1106 1106
def resubmit_task(self, client_id, msg):
    """Resubmit one or more tasks.

    Validates that every requested msg_id exists and is not currently
    in flight, clears the existing records, and re-sends the original
    messages on the resubmit stream.
    """
    def finish(reply):
        self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)

    content = msg['content']
    msg_ids = content['msg_ids']
    reply = dict(status='ok')
    try:
        records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
            'header', 'content', 'buffers'])
    except Exception:
        self.log.error('db::db error finding tasks to resubmit', exc_info=True)
        return finish(error.wrap_exception())

    # validate msg_ids
    found_ids = [ rec['msg_id'] for rec in records ]
    # list comprehension, not `filter`: a Python 3 filter object is always
    # truthy, which made the `elif invalid_ids:` check below meaningless
    invalid_ids = [ m for m in found_ids if m in self.pending ]
    if len(records) > len(msg_ids):
        try:
            raise RuntimeError("DB appears to be in an inconsistent state."
                "More matching records were found than should exist")
        except Exception:
            return finish(error.wrap_exception())
    elif len(records) < len(msg_ids):
        missing = [ m for m in msg_ids if m not in found_ids ]
        try:
            raise KeyError("No such msg(s): %r"%missing)
        except KeyError:
            return finish(error.wrap_exception())
    elif invalid_ids:
        msg_id = invalid_ids[0]
        try:
            raise ValueError("Task %r appears to be inflight"%(msg_id))
        except Exception:
            return finish(error.wrap_exception())

    # clear the existing records
    now = datetime.now()
    rec = empty_record()
    # explicit loop, not map(): Python 3's lazy map would never call pop
    for key in ['msg_id', 'header', 'content', 'buffers', 'submitted']:
        rec.pop(key)
    rec['resubmitted'] = now
    rec['queue'] = 'task'
    rec['client_uuid'] = client_id[0]
    try:
        for msg_id in msg_ids:
            self.all_completed.discard(msg_id)
            self.db.update_record(msg_id, rec)
    except Exception:
        self.log.error('db::db error upating record', exc_info=True)
        reply = error.wrap_exception()
    else:
        # send the messages
        for rec in records:
            header = rec['header']
            # include resubmitted in header to prevent digest collision
            header['resubmitted'] = now
            msg = self.session.msg(header['msg_type'])
            msg['content'] = rec['content']
            msg['header'] = header
            msg['header']['msg_id'] = rec['msg_id']
            self.session.send(self.resubmit, msg, buffers=rec['buffers'])

    # report the actual outcome: previously a db error set `reply` but a
    # fresh {'status': 'ok'} was sent regardless, hiding the failure
    finish(reply)
1171 1171
1172 1172
def _extract_record(self, rec):
    """Decompose a TaskRecord dict into the (content, buffers) subsections
    of a reply for get_results."""
    io_dict = {}
    for key in 'pyin pyout pyerr stdout stderr'.split():
        io_dict[key] = rec[key]
    content = { 'result_content': rec['result_content'],
                'header': rec['header'],
                'result_header' : rec['result_header'],
                'io' : io_dict,
              }
    if rec['result_buffers']:
        # build a real list: on Python 3, map() returns a one-shot lazy
        # iterator, but callers expect a reusable list of byte buffers
        buffers = [ bytes(b) for b in rec['result_buffers'] ]
    else:
        buffers = []

    return content, buffers
1189 1189
def get_results(self, client_id, msg):
    """Get the result of 1 or more messages.

    Replies with lists of pending and completed msg_ids; unless
    status_only is requested, completed results (content + buffers) are
    included in the reply.
    """
    content = msg['content']
    msg_ids = sorted(set(content['msg_ids']))
    statusonly = content.get('status_only', False)
    pending = []
    completed = []
    content = dict(status='ok')
    content['pending'] = pending
    content['completed'] = completed
    buffers = []
    if not statusonly:
        try:
            matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
            # turn match list into dict, for faster lookup
            records = {}
            for rec in matches:
                records[rec['msg_id']] = rec
        except Exception:
            content = error.wrap_exception()
            self.session.send(self.query, "result_reply", content=content,
                                                parent=msg, ident=client_id)
            return
    else:
        records = {}
    for msg_id in msg_ids:
        if msg_id in self.pending:
            pending.append(msg_id)
        elif msg_id in self.all_completed:
            completed.append(msg_id)
            if not statusonly:
                c,bufs = self._extract_record(records[msg_id])
                content[msg_id] = c
                buffers.extend(bufs)
        elif msg_id in records:
            # look the record up by msg_id: previously this tested the stale
            # loop variable `rec` (an arbitrary record from the dict-building
            # loop above), so the wrong record's 'completed' was inspected
            if records[msg_id]['completed']:
                completed.append(msg_id)
                c,bufs = self._extract_record(records[msg_id])
                content[msg_id] = c
                buffers.extend(bufs)
            else:
                pending.append(msg_id)
        else:
            try:
                raise KeyError('No such message: '+msg_id)
            except:
                content = error.wrap_exception()
            break
    self.session.send(self.query, "result_reply", content=content,
                                        parent=msg, ident=client_id,
                                        buffers=buffers)
1241 1241
def get_history(self, client_id, msg):
    """Reply with a list of every msg_id in our DB records."""
    try:
        history = self.db.get_history()
    except Exception:
        reply = error.wrap_exception()
    else:
        reply = dict(status='ok', history=history)

    self.session.send(self.query, "history_reply", content=reply,
                                                    parent=msg, ident=client_id)
1253 1253
def db_query(self, client_id, msg):
    """Perform a raw query on the task record database."""
    content = msg['content']
    query = content.get('query', {})
    keys = content.get('keys', None)
    buffers = []
    empty = list()
    try:
        records = self.db.find_records(query, keys)
    except Exception:
        content = error.wrap_exception()
    else:
        # only collect buffer lengths for keys the client asked for
        # (None disables collection for that key)
        want_all = keys is None
        buffer_lens = [] if (want_all or 'buffers' in keys) else None
        result_buffer_lens = [] if (want_all or 'result_buffers' in keys) else None

        for record in records:
            # buffers may be None, so double check
            if buffer_lens is not None:
                blob = record.pop('buffers', empty) or empty
                buffer_lens.append(len(blob))
                buffers.extend(blob)
            if result_buffer_lens is not None:
                blob = record.pop('result_buffers', empty) or empty
                result_buffer_lens.append(len(blob))
                buffers.extend(blob)
        content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                    result_buffer_lens=result_buffer_lens)
    self.session.send(self.query, "db_reply", content=content,
                                        parent=msg, ident=client_id,
                                        buffers=buffers)
1290 1290
@@ -1,234 +1,234 b''
1 1 """A simple engine that talks to a controller over 0MQ.
2 2 it handles registration, etc. and launches a kernel
3 3 connected to the Controller's Schedulers.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 from __future__ import print_function
17 17
18 18 import sys
19 19 import time
20 20 from getpass import getpass
21 21
22 22 import zmq
23 23 from zmq.eventloop import ioloop, zmqstream
24 24
25 25 from IPython.external.ssh import tunnel
26 26 # internal
27 27 from IPython.utils.traitlets import (
28 28 Instance, Dict, Integer, Type, CFloat, Unicode, CBytes, Bool
29 29 )
30 # from IPython.utils.localinterfaces import LOCALHOST
30 from IPython.utils import py3compat
31 31
32 32 from IPython.parallel.controller.heartmonitor import Heart
33 33 from IPython.parallel.factory import RegistrationFactory
34 34 from IPython.parallel.util import disambiguate_url, asbytes
35 35
36 36 from IPython.zmq.session import Message
37 37
38 38 from .streamkernel import Kernel
39 39
class EngineFactory(RegistrationFactory):
    """IPython engine

    Handles registration with the controller, optional SSH tunneling of all
    connections, and launching of the Kernel once registration completes.
    """

    # configurables:
    out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
        help="""The OutStream for handling stdout/err.
        Typically 'IPython.zmq.iostream.OutStream'""")
    display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
        help="""The class for handling displayhook.
        Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
    location=Unicode(config=True,
        help="""The location (an IP address) of the controller. This is
        used for disambiguating URLs, to determine whether
        loopback should be used to connect or the public address.""")
    timeout=CFloat(2,config=True,
        help="""The time (in seconds) to wait for the Controller to respond
        to registration requests before giving up.""")
    sshserver=Unicode(config=True,
        help="""The SSH server to use for tunneling connections to the Controller.""")
    sshkey=Unicode(config=True,
        help="""The SSH private key file to use when tunneling connections to the Controller.""")
    paramiko=Bool(sys.platform == 'win32', config=True,
        help="""Whether to use paramiko instead of openssh for tunnels.""")

    # not configurable:
    user_ns=Dict()  # namespace the Kernel executes user code in
    id=Integer(allow_none=True)  # engine id, assigned by the Hub on registration
    registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
    kernel=Instance(Kernel)

    # bident mirrors ident as bytes, for use as a zmq socket IDENTITY
    bident = CBytes()
    ident = Unicode()
    def _ident_changed(self, name, old, new):
        # traitlets change handler: keep the bytes identity in sync
        self.bident = asbytes(new)
    using_ssh=Bool(False)


    def __init__(self, **kwargs):
        super(EngineFactory, self).__init__(**kwargs)
        # use the session id as this engine's identity
        self.ident = self.session.session

    def init_connector(self):
        """construct connection function, which handles tunnels."""
        self.using_ssh = bool(self.sshkey or self.sshserver)

        if self.sshkey and not self.sshserver:
            # We are using ssh directly to the controller, tunneling localhost to localhost
            self.sshserver = self.url.split('://')[1].split(':')[0]

        if self.using_ssh:
            if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
                password=False
            else:
                # passwordless login failed; prompt the user once and reuse
                password = getpass("SSH Password for %s: "%self.sshserver)
        else:
            password = False

        def connect(s, url):
            """connect socket *s* to *url*, tunneling via ssh if configured."""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
                return tunnel.tunnel_connection(s, url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            else:
                return s.connect(url)

        def maybe_tunnel(url):
            """like connect, but don't complete the connection (for use by heartbeat)"""
            url = disambiguate_url(url, self.location)
            if self.using_ssh:
                self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
                url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
                            keyfile=self.sshkey, paramiko=self.paramiko,
                            password=password,
                )
            return url
        return connect, maybe_tunnel

    def register(self):
        """send the registration_request"""

        self.log.info("Registering with controller at %s"%self.url)
        ctx = self.context
        connect,maybe_tunnel = self.init_connector()
        reg = ctx.socket(zmq.DEALER)
        # identity must be set before connecting
        reg.setsockopt(zmq.IDENTITY, self.bident)
        connect(reg, self.url)
        self.registrar = zmqstream.ZMQStream(reg, self.loop)


        content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
        # finish setup asynchronously when the Hub replies
        self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
        # print (self.session.key)
        self.session.send(self.registrar, "registration_request",content=content)

    def complete_registration(self, msg, connect, maybe_tunnel):
        """Handle the Hub's registration reply.

        On success: start the heartbeat, connect shell/control/iopub streams,
        redirect stdout/stderr/displayhook over iopub, and start the Kernel.
        On failure: raise.
        """
        # print msg
        # registration succeeded in time, so cancel the abort timeout
        self._abort_dc.stop()
        ctx = self.context
        loop = self.loop
        identity = self.bident
        idents,msg = self.session.feed_identities(msg)
        msg = Message(self.session.unserialize(msg))

        if msg.content.status == 'ok':
            self.id = int(msg.content.id)

            # launch heartbeat
            hb_addrs = msg.content.heartbeat

            # possibly forward hb ports with tunnels
            hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
            heart = Heart(*map(str, hb_addrs), heart_id=identity)
            heart.start()

            # create Shell Streams (MUX, Task, etc.):
            queue_addr = msg.content.mux
            shell_addrs = [ str(queue_addr) ]
            task_addr = msg.content.task
            if task_addr:
                shell_addrs.append(str(task_addr))

            # Uncomment this to go back to two-socket model
            # shell_streams = []
            # for addr in shell_addrs:
            #     stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            #     stream.setsockopt(zmq.IDENTITY, identity)
            #     stream.connect(disambiguate_url(addr, self.location))
            #     shell_streams.append(stream)

            # Now use only one shell stream for mux and tasks
            stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            stream.setsockopt(zmq.IDENTITY, identity)
            shell_streams = [stream]
            for addr in shell_addrs:
                connect(stream, addr)
            # end single stream-socket

            # control stream:
            control_addr = str(msg.content.control)
            control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
            control_stream.setsockopt(zmq.IDENTITY, identity)
            connect(control_stream, control_addr)

            # # Redirect input streams and set a display hook.
            if self.out_stream_factory:
                sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
                # topics must be bytes on py3 for zmq PUB routing
                sys.stdout.topic = py3compat.cast_bytes('engine.%i.stdout' % self.id)
                sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
                sys.stderr.topic = py3compat.cast_bytes('engine.%i.stderr' % self.id)
            if self.display_hook_factory:
                sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
                sys.displayhook.topic = py3compat.cast_bytes('engine.%i.pyout' % self.id)

            self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
                    control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
                    loop=loop, user_ns = self.user_ns, log=self.log)
            self.kernel.start()


        else:
            self.log.fatal("Registration Failed: %s"%msg)
            raise Exception("Registration Failed: %s"%msg)

        self.log.info("Completed registration with id %i"%self.id)


    def abort(self):
        """Give up on registration after ``timeout`` seconds and exit."""
        self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
        if self.url.startswith('127.'):
            self.log.fatal("""
            If the controller and engines are not on the same machine,
            you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
                c.HubFactory.ip='*' # for all interfaces, internal and external
                c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
            or tunnel connections via ssh.
            """)
        self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
        time.sleep(1)
        sys.exit(255)

    def start(self):
        """Begin registration, with a timeout that aborts on failure."""
        dc = ioloop.DelayedCallback(self.register, 0, self.loop)
        dc.start()
        self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
        self._abort_dc.start()
234 234
@@ -1,460 +1,463 b''
1 1 # -*- coding: utf-8 -*-
2 2 """test View objects
3 3
4 4 Authors:
5 5
6 6 * Min RK
7 7 """
8 8 #-------------------------------------------------------------------------------
9 9 # Copyright (C) 2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-------------------------------------------------------------------------------
14 14
15 15 #-------------------------------------------------------------------------------
16 16 # Imports
17 17 #-------------------------------------------------------------------------------
18 18
19 19 import sys
20 20 import time
21 21 from tempfile import mktemp
22 22 from StringIO import StringIO
23 23
24 24 import zmq
25 25 from nose import SkipTest
26 26
27 from IPython.testing import decorators as dec
28
27 29 from IPython import parallel as pmod
28 30 from IPython.parallel import error
29 31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
30 32 from IPython.parallel import DirectView
31 33 from IPython.parallel.util import interactive
32 34
33 35 from IPython.parallel.tests import add_engines
34 36
35 37 from .clienttest import ClusterTestCase, crash, wait, skip_without
36 38
def setup():
    # nose module-level setup: ensure at least 3 engines are running
    # before any test in this module executes
    add_engines(3)
39 41
class TestView(ClusterTestCase):
    """Integration tests for DirectView: push/pull, scatter/gather,
    execute, tracking, parallel magics, and unicode handling.

    These run against a live test cluster started by ``setup()``.
    """

    def test_z_crash_mux(self):
        """test graceful handling of engine death (direct)"""
        raise SkipTest("crash tests disabled, due to undesirable crash reports")
        # self.add_engines(1)
        eid = self.client.ids[-1]
        ar = self.client[eid].apply_async(crash)
        self.assertRaisesRemote(error.EngineError, ar.get, 10)
        eid = ar.engine_id
        tic = time.time()
        # wait up to 5s for the Hub to notice the dead engine
        while eid in self.client.ids and time.time()-tic < 5:
            time.sleep(.01)
            self.client.spin()
        self.assertFalse(eid in self.client.ids, "Engine should have died")

    def test_push_pull(self):
        """test pushing and pulling"""
        data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
        t = self.client.ids[-1]
        v = self.client[t]
        push = v.push
        pull = v.pull
        v.block=True
        nengines = len(self.client)
        push({'data':data})
        d = pull('data')
        self.assertEquals(d, data)
        self.client[:].push({'data':data})
        d = self.client[:].pull('data', block=True)
        self.assertEquals(d, nengines*[data])
        # nonblocking variants return AsyncResult
        ar = push({'data':data}, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        ar = self.client[:].pull('data', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        r = ar.get()
        self.assertEquals(r, nengines*[data])
        # pulling a tuple of names returns a list per engine
        self.client[:].push(dict(a=10,b=20))
        r = self.client[:].pull(('a','b'), block=True)
        self.assertEquals(r, nengines*[[10,20]])

    def test_push_pull_function(self):
        "test pushing and pulling functions"
        def testf(x):
            return 2.0*x

        t = self.client.ids[-1]
        v = self.client[t]
        v.block=True
        push = v.push
        pull = v.pull
        execute = v.execute
        push({'testf':testf})
        r = pull('testf')
        self.assertEqual(r(1.0), testf(1.0))
        execute('r = testf(10)')
        r = pull('r')
        self.assertEquals(r, testf(10))
        ar = self.client[:].push({'testf':testf}, block=False)
        ar.get()
        ar = self.client[:].pull('testf', block=False)
        rlist = ar.get()
        for r in rlist:
            self.assertEqual(r(1.0), testf(1.0))
        execute("def g(x): return x*x")
        r = pull(('testf','g'))
        self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))

    def test_push_function_globals(self):
        """test that pushed functions have access to globals"""
        @interactive
        def geta():
            return a
        # self.add_engines(1)
        v = self.client[-1]
        v.block=True
        v['f'] = geta
        # 'a' not defined yet on the engine -> NameError
        self.assertRaisesRemote(NameError, v.execute, 'b=f()')
        v.execute('a=5')
        v.execute('b=f()')
        self.assertEquals(v['b'], 5)

    def test_push_function_defaults(self):
        """test that pushed functions preserve default args"""
        def echo(a=10):
            return a
        v = self.client[-1]
        v.block=True
        v['f'] = echo
        v.execute('b=f()')
        self.assertEquals(v['b'], 10)

    def test_get_result(self):
        """test getting results from the Hub."""
        c = pmod.Client(profile='iptest')
        # self.add_engines(1)
        t = c.ids[-1]
        v = c[t]
        v2 = self.client[t]
        ar = v.apply_async(wait, 1)
        # give the monitor time to notice the message
        time.sleep(.25)
        # first fetch comes from the Hub (AsyncHubResult)...
        ahr = v2.get_result(ar.msg_ids)
        self.assertTrue(isinstance(ahr, AsyncHubResult))
        self.assertEquals(ahr.get(), ar.get())
        # ...second fetch is cached locally, so a plain AsyncResult
        ar2 = v2.get_result(ar.msg_ids)
        self.assertFalse(isinstance(ar2, AsyncHubResult))
        c.spin()
        c.close()

    def test_run_newline(self):
        """test that run appends newline to files"""
        tmpfile = mktemp()
        # file deliberately lacks a trailing newline
        with open(tmpfile, 'w') as f:
            f.write("""def g():
                return 5
                """)
        v = self.client[-1]
        v.run(tmpfile, block=True)
        self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)

    def test_apply_tracked(self):
        """test tracking for apply"""
        # self.add_engines(1)
        t = self.client.ids[-1]
        v = self.client[t]
        v.block=False
        def echo(n=1024*1024, **kwargs):
            with v.temp_flags(**kwargs):
                return v.apply(lambda x: x, 'x'*n)
        ar = echo(1, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)
        ar = echo(track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)

    def test_push_tracked(self):
        """test zmq message tracking for push"""
        t = self.client.ids[-1]
        ns = dict(x='x'*1024*1024)
        v = self.client[t]
        ar = v.push(ns, block=False, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)

        ar = v.push(ns, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        ar._tracker.wait()
        self.assertEquals(ar.sent, ar._tracker.done)
        self.assertTrue(ar.sent)
        ar.get()

    def test_scatter_tracked(self):
        """test zmq message tracking for scatter"""
        t = self.client.ids
        x='x'*1024*1024
        ar = self.client[t].scatter('x', x, block=False, track=False)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertTrue(ar.sent)

        ar = self.client[t].scatter('x', x, block=False, track=True)
        self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
        self.assertEquals(ar.sent, ar._tracker.done)
        ar._tracker.wait()
        self.assertTrue(ar.sent)
        ar.get()

    def test_remote_reference(self):
        """test that a Reference resolves to the remote value"""
        v = self.client[-1]
        v['a'] = 123
        ra = pmod.Reference('a')
        b = v.apply_sync(lambda x: x, ra)
        self.assertEquals(b, 123)


    def test_scatter_gather(self):
        """test basic scatter/gather round-trip"""
        view = self.client[:]
        seq1 = range(16)
        view.scatter('a', seq1)
        seq2 = view.gather('a', block=True)
        self.assertEquals(seq2, seq1)
        self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)

    @skip_without('numpy')
    def test_scatter_gather_numpy(self):
        """scatter/gather should round-trip numpy arrays"""
        import numpy
        from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
        view = self.client[:]
        a = numpy.arange(64)
        view.scatter('a', a)
        b = view.gather('a', block=True)
        assert_array_equal(b, a)

    def test_map(self):
        """parallel map matches builtin map"""
        view = self.client[:]
        def f(x):
            return x**2
        data = range(16)
        r = view.map_sync(f, data)
        self.assertEquals(r, map(f, data))

    def test_map_iterable(self):
        """test map on iterables (direct)"""
        view = self.client[:]
        # 101 is prime, so it won't be evenly distributed
        arr = range(101)
        # ensure it will be an iterator, even in Python 3
        it = iter(arr)
        r = view.map_sync(lambda x:x, arr)
        self.assertEquals(r, list(arr))

    def test_scatterGatherNonblocking(self):
        """nonblocking scatter/gather round-trip"""
        data = range(16)
        view = self.client[:]
        view.scatter('a', data, block=False)
        ar = view.gather('a', block=False)
        self.assertEquals(ar.get(), data)

    @skip_without('numpy')
    def test_scatter_gather_numpy_nonblocking(self):
        """nonblocking scatter/gather should round-trip numpy arrays"""
        import numpy
        from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
        a = numpy.arange(64)
        view = self.client[:]
        ar = view.scatter('a', a, block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        amr = view.gather('a', block=False)
        self.assertTrue(isinstance(amr, AsyncMapResult))
        assert_array_equal(amr.get(), a)

    def test_execute(self):
        """execute returns AsyncResult and runs on all engines"""
        view = self.client[:]
        # self.client.debug=True
        execute = view.execute
        ar = execute('c=30', block=False)
        self.assertTrue(isinstance(ar, AsyncResult))
        ar = execute('d=[0,1,2]', block=False)
        self.client.wait(ar, 1)
        self.assertEquals(len(ar.get()), len(self.client))
        for c in view['c']:
            self.assertEquals(c, 30)

    def test_abort(self):
        """abort by AsyncResult or by msg_ids raises TaskAborted"""
        view = self.client[-1]
        # keep the engine busy so the next tasks are still abortable
        ar = view.execute('import time; time.sleep(1)', block=False)
        ar2 = view.apply_async(lambda : 2)
        ar3 = view.apply_async(lambda : 3)
        view.abort(ar2)
        view.abort(ar3.msg_ids)
        self.assertRaises(error.TaskAborted, ar2.get)
        self.assertRaises(error.TaskAborted, ar3.get)

    def test_temp_flags(self):
        """temp_flags restores flags on exit"""
        view = self.client[-1]
        view.block=True
        with view.temp_flags(block=False):
            self.assertFalse(view.block)
        self.assertTrue(view.block)

    @dec.known_failure_py3
    def test_importer(self):
        """modules imported under view.importer are pushed to engines"""
        view = self.client[-1]
        view.clear(block=True)
        with view.importer:
            import re

        @interactive
        def findall(pat, s):
            # this globals() step isn't necessary in real code
            # only to prevent a closure in the test
            re = globals()['re']
            return re.findall(pat, s)

        self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())

    # parallel magic tests

    def test_magic_px_blocking(self):
        """%px in blocking mode executes remotely and relays stdout"""
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=True

        ip.magic_px('a=5')
        self.assertEquals(v['a'], 5)
        ip.magic_px('a=10')
        self.assertEquals(v['a'], 10)
        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        # just 'print a' worst ~99% of the time, but this ensures that
        # the stdout message has arrived when the result is finished:
        ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
        sys.stdout = savestdout
        buf = sio.getvalue()
        self.assertTrue('[stdout:' in buf, buf)
        self.assertTrue(buf.rstrip().endswith('10'))
        self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')

    def test_magic_px_nonblocking(self):
        """%px in nonblocking mode does not relay stdout or raise"""
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=False

        ip.magic_px('a=5')
        self.assertEquals(v['a'], 5)
        ip.magic_px('a=10')
        self.assertEquals(v['a'], 10)
        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_px('print a')
        sys.stdout = savestdout
        buf = sio.getvalue()
        self.assertFalse('[stdout:%i]'%v.targets in buf)
        ip.magic_px('1/0')
        # the remote error only surfaces when the result is fetched
        ar = v.get_result(-1)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_autopx_blocking(self):
        """%autopx (blocking) routes subsequent cells to the engine"""
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=True

        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_autopx()
        ip.run_cell('\n'.join(('a=5','b=10','c=0')))
        ip.run_cell('print b')
        ip.run_cell("b/c")
        # run_code should NOT be forwarded while autopx is on
        ip.run_code(compile('b*=2', '', 'single'))
        ip.magic_autopx()
        sys.stdout = savestdout
        output = sio.getvalue().strip()
        self.assertTrue(output.startswith('%autopx enabled'))
        self.assertTrue(output.endswith('%autopx disabled'))
        self.assertTrue('RemoteError: ZeroDivisionError' in output)
        ar = v.get_result(-2)
        self.assertEquals(v['a'], 5)
        self.assertEquals(v['b'], 20)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_autopx_nonblocking(self):
        """%autopx (nonblocking) defers errors until results are fetched"""
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v.block=False

        sio = StringIO()
        savestdout = sys.stdout
        sys.stdout = sio
        ip.magic_autopx()
        ip.run_cell('\n'.join(('a=5','b=10','c=0')))
        ip.run_cell('print b')
        ip.run_cell("b/c")
        ip.run_code(compile('b*=2', '', 'single'))
        ip.magic_autopx()
        sys.stdout = savestdout
        output = sio.getvalue().strip()
        self.assertTrue(output.startswith('%autopx enabled'))
        self.assertTrue(output.endswith('%autopx disabled'))
        # nonblocking: the error is not printed inline
        self.assertFalse('ZeroDivisionError' in output)
        ar = v.get_result(-2)
        self.assertEquals(v['a'], 5)
        self.assertEquals(v['b'], 20)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)

    def test_magic_result(self):
        """%result retrieves results by history index"""
        ip = get_ipython()
        v = self.client[-1]
        v.activate()
        v['a'] = 111
        ra = v['a']

        ar = ip.magic_result()
        self.assertEquals(ar.msg_ids, [v.history[-1]])
        self.assertEquals(ar.get(), 111)
        ar = ip.magic_result('-2')
        self.assertEquals(ar.msg_ids, [v.history[-2]])

    def test_unicode_execute(self):
        """test executing unicode strings"""
        v = self.client[-1]
        v.block=True
        # str is unicode on py3, so no u-prefix needed there
        if sys.version_info[0] >= 3:
            code="a='é'"
        else:
            code=u"a=u'é'"
        v.execute(code)
        self.assertEquals(v['a'], u'é')

    def test_unicode_apply_result(self):
        """test unicode apply results"""
        v = self.client[-1]
        r = v.apply_sync(lambda : u'é')
        self.assertEquals(r, u'é')

    def test_unicode_apply_arg(self):
        """test passing unicode arguments to apply"""
        v = self.client[-1]

        @interactive
        def check_unicode(a, check):
            assert isinstance(a, unicode), "%r is not unicode"%a
            assert isinstance(check, bytes), "%r is not bytes"%check
            assert a.encode('utf8') == check, "%s != %s"%(a,check)

        for s in [ u'é', u'ßø®∫',u'asdf' ]:
            try:
                v.apply_sync(check_unicode, s, s.encode('utf8'))
            except error.RemoteError as e:
                if e.ename == 'AssertionError':
                    self.fail(e.evalue)
                else:
                    raise e
460 463
General Comments 0
You need to be logged in to leave comments. Login now