don't allow gethostbyname(gethostname()) failure to crash ipcontroller...
MinRK
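The guarded lookup this commit introduces, shown first as a standalone sketch (a minimal illustration of the pattern, not part of the diff): on hosts whose hostname has no DNS or /etc/hosts entry, socket.gethostbyname_ex(socket.gethostname()) raises socket.gaierror, and a loopback-only host can yield an empty address list (IndexError); previously either case crashed ipcontroller while it was writing its connection files.

import socket

def external_ip_or_loopback():
    """Best-effort lookup of this machine's external IP (sketch).

    Mirrors the pattern the patch applies in two places below: an
    unresolvable hostname raises socket.gaierror, and an empty address
    list raises IndexError; both now degrade to loopback.
    """
    try:
        # last entry of the resolved address list, as in the patch
        return socket.gethostbyname_ex(socket.gethostname())[2][-1]
    except (socket.gaierror, IndexError):
        # fall back instead of crashing at startup
        return '127.0.0.1'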
IPython/parallel/apps/ipcontrollerapp.py
@@ -1,425 +1,431 @@
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import os
27 27 import socket
28 28 import stat
29 29 import sys
30 30 import uuid
31 31
32 32 from multiprocessing import Process
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37 from zmq.utils import jsonapi as json
38 38
39 39 from IPython.config.application import boolean_flag
40 40 from IPython.core.profiledir import ProfileDir
41 41
42 42 from IPython.parallel.apps.baseapp import (
43 43 BaseParallelApplication,
44 44 base_aliases,
45 45 base_flags,
46 46 )
47 47 from IPython.utils.importstring import import_item
48 48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict
49 49
50 50 # from IPython.parallel.controller.controller import ControllerFactory
51 51 from IPython.zmq.session import Session
52 52 from IPython.parallel.controller.heartmonitor import HeartMonitor
53 53 from IPython.parallel.controller.hub import HubFactory
54 54 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
55 55 from IPython.parallel.controller.sqlitedb import SQLiteDB
56 56
57 57 from IPython.parallel.util import signal_children, split_url, asbytes
58 58
59 59 # conditional import of MongoDB backend class
60 60
61 61 try:
62 62 from IPython.parallel.controller.mongodb import MongoDB
63 63 except ImportError:
64 64 maybe_mongo = []
65 65 else:
66 66 maybe_mongo = [MongoDB]
67 67
68 68
69 69 #-----------------------------------------------------------------------------
70 70 # Module level variables
71 71 #-----------------------------------------------------------------------------
72 72
73 73
74 74 #: The default config file name for this application
75 75 default_config_file_name = u'ipcontroller_config.py'
76 76
77 77
78 78 _description = """Start the IPython controller for parallel computing.
79 79
80 80 The IPython controller provides a gateway between the IPython engines and
81 81 clients. The controller needs to be started before the engines and can be
82 82 configured using command line options or using a cluster directory. Cluster
83 83 directories contain config, log and security files and are usually located in
84 84 your ipython directory and named like "profile_<name>". See the `profile`
85 85 and `profile-dir` options for details.
86 86 """
87 87
88 88 _examples = """
89 89 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
90 90 ipcontroller --scheme=pure # use the pure zeromq scheduler
91 91 """
92 92
93 93
94 94 #-----------------------------------------------------------------------------
95 95 # The main application
96 96 #-----------------------------------------------------------------------------
97 97 flags = {}
98 98 flags.update(base_flags)
99 99 flags.update({
100 100 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
101 101 'Use threads instead of processes for the schedulers'),
102 102 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
103 103 'use the SQLiteDB backend'),
104 104 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
105 105 'use the MongoDB backend'),
106 106 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
107 107 'use the in-memory DictDB backend'),
108 108 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
109 109 'reuse existing json connection files')
110 110 })
111 111
112 112 flags.update(boolean_flag('secure', 'IPControllerApp.secure',
113 113 "Use HMAC digests for authentication of messages.",
114 114 "Don't authenticate messages."
115 115 ))
116 116 aliases = dict(
117 117 secure = 'IPControllerApp.secure',
118 118 ssh = 'IPControllerApp.ssh_server',
119 119 location = 'IPControllerApp.location',
120 120
121 121 ident = 'Session.session',
122 122 user = 'Session.username',
123 123 keyfile = 'Session.keyfile',
124 124
125 125 url = 'HubFactory.url',
126 126 ip = 'HubFactory.ip',
127 127 transport = 'HubFactory.transport',
128 128 port = 'HubFactory.regport',
129 129
130 130 ping = 'HeartMonitor.period',
131 131
132 132 scheme = 'TaskScheduler.scheme_name',
133 133 hwm = 'TaskScheduler.hwm',
134 134 )
135 135 aliases.update(base_aliases)
136 136
137 137
138 138 class IPControllerApp(BaseParallelApplication):
139 139
140 140 name = u'ipcontroller'
141 141 description = _description
142 142 examples = _examples
143 143 config_file_name = Unicode(default_config_file_name)
144 144 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
145 145
146 146 # change default to True
147 147 auto_create = Bool(True, config=True,
148 148 help="""Whether to create profile dir if it doesn't exist.""")
149 149
150 150 reuse_files = Bool(False, config=True,
151 151 help='Whether to reuse existing json connection files.'
152 152 )
153 153 secure = Bool(True, config=True,
154 154 help='Whether to use HMAC digests for extra message authentication.'
155 155 )
156 156 ssh_server = Unicode(u'', config=True,
157 157 help="""ssh url for clients to use when connecting to the Controller
158 158 processes. It should be of the form: [user@]server[:port]. The
159 159 Controller's listening addresses must be accessible from the ssh server""",
160 160 )
161 161 location = Unicode(u'', config=True,
162 162 help="""The external IP or domain name of the Controller, used for disambiguating
163 163 engine and client connections.""",
164 164 )
165 165 import_statements = List([], config=True,
166 166 help="import statements to be run at startup. Necessary in some environments"
167 167 )
168 168
169 169 use_threads = Bool(False, config=True,
170 170 help='Use threads instead of processes for the schedulers',
171 171 )
172 172
173 173 # internal
174 174 children = List()
175 175 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
176 176
177 177 def _use_threads_changed(self, name, old, new):
178 178 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
179 179
180 180 aliases = Dict(aliases)
181 181 flags = Dict(flags)
182 182
183 183
184 184 def save_connection_dict(self, fname, cdict):
185 185 """save a connection dict to json file."""
186 186 c = self.config
187 187 url = cdict['url']
188 188 location = cdict['location']
189 189 if not location:
190 190 try:
191 191 proto,ip,port = split_url(url)
192 192 except AssertionError:
193 193 pass
194 194 else:
195 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
195 try:
196 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
197 except (socket.gaierror, IndexError):
198 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
199 " You may need to specify '--location=<external_ip_address>' to help"
200 " IPython decide when to connect via loopback.")
201 location = '127.0.0.1'
196 202 cdict['location'] = location
197 203 fname = os.path.join(self.profile_dir.security_dir, fname)
198 204 with open(fname, 'wb') as f:
199 205 f.write(json.dumps(cdict, indent=2))
200 206 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
201 207
202 208 def load_config_from_json(self):
203 209 """load config from existing json connector files."""
204 210 c = self.config
205 211 # load from engine config
206 212 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-engine.json')) as f:
207 213 cfg = json.loads(f.read())
208 214 key = c.Session.key = asbytes(cfg['exec_key'])
209 215 xport,addr = cfg['url'].split('://')
210 216 c.HubFactory.engine_transport = xport
211 217 ip,ports = addr.split(':')
212 218 c.HubFactory.engine_ip = ip
213 219 c.HubFactory.regport = int(ports)
214 220 self.location = cfg['location']
215 221 # load client config
216 222 with open(os.path.join(self.profile_dir.security_dir, 'ipcontroller-client.json')) as f:
217 223 cfg = json.loads(f.read())
218 224 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
219 225 xport,addr = cfg['url'].split('://')
220 226 c.HubFactory.client_transport = xport
221 227 ip,ports = addr.split(':')
222 228 c.HubFactory.client_ip = ip
223 229 self.ssh_server = cfg['ssh']
224 230 assert int(ports) == c.HubFactory.regport, "regport mismatch"
225 231
226 232 def init_hub(self):
227 233 c = self.config
228 234
229 235 self.do_import_statements()
230 236 reusing = self.reuse_files
231 237 if reusing:
232 238 try:
233 239 self.load_config_from_json()
234 240 except (AssertionError,IOError):
235 241 reusing=False
236 242 # check again, because reusing may have failed:
237 243 if reusing:
238 244 pass
239 245 elif self.secure:
240 246 key = str(uuid.uuid4())
241 247 # keyfile = os.path.join(self.profile_dir.security_dir, self.exec_key)
242 248 # with open(keyfile, 'w') as f:
243 249 # f.write(key)
244 250 # os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
245 251 c.Session.key = asbytes(key)
246 252 else:
247 253 key = c.Session.key = b''
248 254
249 255 try:
250 256 self.factory = HubFactory(config=c, log=self.log)
251 257 # self.start_logging()
252 258 self.factory.init_hub()
253 259 except:
254 260 self.log.error("Couldn't construct the Controller", exc_info=True)
255 261 self.exit(1)
256 262
257 263 if not reusing:
258 264 # save to new json config files
259 265 f = self.factory
260 266 cdict = {'exec_key' : key,
261 267 'ssh' : self.ssh_server,
262 268 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
263 269 'location' : self.location
264 270 }
265 271 self.save_connection_dict('ipcontroller-client.json', cdict)
266 272 edict = cdict
267 273 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
268 274 self.save_connection_dict('ipcontroller-engine.json', edict)
269 275
270 276 #
271 277 def init_schedulers(self):
272 278 children = self.children
273 279 mq = import_item(str(self.mq_class))
274 280
275 281 hub = self.factory
276 282 # maybe_inproc = 'inproc://monitor' if self.use_threads else self.monitor_url
277 283 # IOPub relay (in a Process)
278 284 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
279 285 q.bind_in(hub.client_info['iopub'])
280 286 q.bind_out(hub.engine_info['iopub'])
281 287 q.setsockopt_out(zmq.SUBSCRIBE, b'')
282 288 q.connect_mon(hub.monitor_url)
283 289 q.daemon=True
284 290 children.append(q)
285 291
286 292 # Multiplexer Queue (in a Process)
287 293 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'in', b'out')
288 294 q.bind_in(hub.client_info['mux'])
289 295 q.setsockopt_in(zmq.IDENTITY, b'mux')
290 296 q.bind_out(hub.engine_info['mux'])
291 297 q.connect_mon(hub.monitor_url)
292 298 q.daemon=True
293 299 children.append(q)
294 300
295 301 # Control Queue (in a Process)
296 302 q = mq(zmq.XREP, zmq.XREP, zmq.PUB, b'incontrol', b'outcontrol')
297 303 q.bind_in(hub.client_info['control'])
298 304 q.setsockopt_in(zmq.IDENTITY, b'control')
299 305 q.bind_out(hub.engine_info['control'])
300 306 q.connect_mon(hub.monitor_url)
301 307 q.daemon=True
302 308 children.append(q)
303 309 try:
304 310 scheme = self.config.TaskScheduler.scheme_name
305 311 except AttributeError:
306 312 scheme = TaskScheduler.scheme_name.get_default_value()
307 313 # Task Queue (in a Process)
308 314 if scheme == 'pure':
309 315 self.log.warn("task::using pure XREQ Task scheduler")
310 316 q = mq(zmq.XREP, zmq.XREQ, zmq.PUB, b'intask', b'outtask')
311 317 # q.setsockopt_out(zmq.HWM, hub.hwm)
312 318 q.bind_in(hub.client_info['task'][1])
313 319 q.setsockopt_in(zmq.IDENTITY, b'task')
314 320 q.bind_out(hub.engine_info['task'])
315 321 q.connect_mon(hub.monitor_url)
316 322 q.daemon=True
317 323 children.append(q)
318 324 elif scheme == 'none':
319 325 self.log.warn("task::using no Task scheduler")
320 326
321 327 else:
322 328 self.log.info("task::using Python %s Task scheduler"%scheme)
323 329 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
324 330 hub.monitor_url, hub.client_info['notification'])
325 331 kwargs = dict(logname='scheduler', loglevel=self.log_level,
326 332 log_url = self.log_url, config=dict(self.config))
327 333 if 'Process' in self.mq_class:
328 334 # run the Python scheduler in a Process
329 335 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
330 336 q.daemon=True
331 337 children.append(q)
332 338 else:
333 339 # single-threaded Controller
334 340 kwargs['in_thread'] = True
335 341 launch_scheduler(*sargs, **kwargs)
336 342
337 343
338 344 def save_urls(self):
339 345 """save the registration urls to files."""
340 346 c = self.config
341 347
342 348 sec_dir = self.profile_dir.security_dir
343 349 cf = self.factory
344 350
345 351 with open(os.path.join(sec_dir, 'ipcontroller-engine.url'), 'w') as f:
346 352 f.write("%s://%s:%s"%(cf.engine_transport, cf.engine_ip, cf.regport))
347 353
348 354 with open(os.path.join(sec_dir, 'ipcontroller-client.url'), 'w') as f:
349 355 f.write("%s://%s:%s"%(cf.client_transport, cf.client_ip, cf.regport))
350 356
351 357
352 358 def do_import_statements(self):
353 359 statements = self.import_statements
354 360 for s in statements:
355 361 try:
356 362                 self.log.info("Executing statement: '%s'" % s)
357 363                 exec s in globals(), locals()
358 364             except:
359 365                 self.log.error("Error running statement: %s" % s)
360 366
361 367 def forward_logging(self):
362 368 if self.log_url:
363 369 self.log.info("Forwarding logging to %s"%self.log_url)
364 370 context = zmq.Context.instance()
365 371 lsock = context.socket(zmq.PUB)
366 372 lsock.connect(self.log_url)
367 373 handler = PUBHandler(lsock)
368 374 self.log.removeHandler(self._log_handler)
369 375 handler.root_topic = 'controller'
370 376 handler.setLevel(self.log_level)
371 377 self.log.addHandler(handler)
372 378 self._log_handler = handler
373 379 # #
374 380
375 381 def initialize(self, argv=None):
376 382 super(IPControllerApp, self).initialize(argv)
377 383 self.forward_logging()
378 384 self.init_hub()
379 385 self.init_schedulers()
380 386
381 387 def start(self):
382 388 # Start the subprocesses:
383 389 self.factory.start()
384 390 child_procs = []
385 391 for child in self.children:
386 392 child.start()
387 393 if isinstance(child, ProcessMonitoredQueue):
388 394 child_procs.append(child.launcher)
389 395 elif isinstance(child, Process):
390 396 child_procs.append(child)
391 397 if child_procs:
392 398 signal_children(child_procs)
393 399
394 400 self.write_pid_file(overwrite=True)
395 401
396 402 try:
397 403 self.factory.loop.start()
398 404 except KeyboardInterrupt:
399 405 self.log.critical("Interrupted, Exiting...\n")
400 406
401 407
402 408
403 409 def launch_new_instance():
404 410 """Create and run the IPython controller"""
405 411 if sys.platform == 'win32':
406 412 # make sure we don't get called from a multiprocessing subprocess
407 413 # this can result in infinite Controllers being started on Windows
408 414 # which doesn't have a proper fork, so multiprocessing is wonky
409 415
410 416 # this only comes up when IPython has been installed using vanilla
411 417 # setuptools, and *not* distribute.
412 418 import multiprocessing
413 419 p = multiprocessing.current_process()
414 420 # the main process has name 'MainProcess'
415 421 # subprocesses will have names like 'Process-1'
416 422 if p.name != 'MainProcess':
417 423 # we are a subprocess, don't start another Controller!
418 424 return
419 425 app = IPControllerApp.instance()
420 426 app.initialize()
421 427 app.start()
422 428
423 429
424 430 if __name__ == '__main__':
425 431 launch_new_instance()
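For context, a hedged sketch of consuming the connection files written by save_connection_dict above. The helper name is hypothetical; the field names (exec_key, ssh, url, location) and the url parsing mirror load_config_from_json in the diff.

import json
import os.path

def read_connection_file(security_dir, name='ipcontroller-client.json'):
    """Hypothetical reader for a controller connection file (sketch)."""
    with open(os.path.join(security_dir, name)) as f:
        cfg = json.load(f)
    transport, addr = cfg['url'].split('://')
    ip, port = addr.split(':')
    # 'location' is the external IP that the patch now fills in defensively
    return cfg['exec_key'], transport, ip, int(port), cfg['location'], cfg['ssh']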
IPython/parallel/util.py
@@ -1,456 +1,461 @@
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 44 from IPython.utils.newserialized import serialize, unserialize
45 45 from IPython.zmq.log import EnginePUBHandler
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Classes
49 49 #-----------------------------------------------------------------------------
50 50
51 51 class Namespace(dict):
52 52 """Subclass of dict for attribute access to keys."""
53 53
54 54 def __getattr__(self, key):
55 55 """getattr aliased to getitem"""
56 56 if key in self.iterkeys():
57 57 return self[key]
58 58 else:
59 59 raise NameError(key)
60 60
61 61 def __setattr__(self, key, value):
62 62 """setattr aliased to setitem, with strict"""
63 63 if hasattr(dict, key):
64 64 raise KeyError("Cannot override dict keys %r"%key)
65 65 self[key] = value
66 66
67 67
68 68 class ReverseDict(dict):
69 69 """simple double-keyed subset of dict methods."""
70 70
71 71 def __init__(self, *args, **kwargs):
72 72 dict.__init__(self, *args, **kwargs)
73 73 self._reverse = dict()
74 74 for key, value in self.iteritems():
75 75 self._reverse[value] = key
76 76
77 77 def __getitem__(self, key):
78 78 try:
79 79 return dict.__getitem__(self, key)
80 80 except KeyError:
81 81 return self._reverse[key]
82 82
83 83 def __setitem__(self, key, value):
84 84 if key in self._reverse:
85 85 raise KeyError("Can't have key %r on both sides!"%key)
86 86 dict.__setitem__(self, key, value)
87 87 self._reverse[value] = key
88 88
89 89 def pop(self, key):
90 90 value = dict.pop(self, key)
91 91 self._reverse.pop(value)
92 92 return value
93 93
94 94 def get(self, key, default=None):
95 95 try:
96 96 return self[key]
97 97 except KeyError:
98 98 return default
99 99
100 100 #-----------------------------------------------------------------------------
101 101 # Functions
102 102 #-----------------------------------------------------------------------------
103 103
104 104 def asbytes(s):
105 105 """ensure that an object is ascii bytes"""
106 106 if isinstance(s, unicode):
107 107 s = s.encode('ascii')
108 108 return s
109 109
110 110 def validate_url(url):
111 111 """validate a url for zeromq"""
112 112 if not isinstance(url, basestring):
113 113 raise TypeError("url must be a string, not %r"%type(url))
114 114 url = url.lower()
115 115
116 116 proto_addr = url.split('://')
117 117 assert len(proto_addr) == 2, 'Invalid url: %r'%url
118 118 proto, addr = proto_addr
119 119 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
120 120
121 121 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
122 122 # author: Remi Sabourin
123 123 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
124 124
125 125 if proto == 'tcp':
126 126 lis = addr.split(':')
127 127 assert len(lis) == 2, 'Invalid url: %r'%url
128 128 addr,s_port = lis
129 129 try:
130 130 port = int(s_port)
131 131 except ValueError:
132 132             raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
133 133
134 134 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
135 135
136 136 else:
137 137 # only validate tcp urls currently
138 138 pass
139 139
140 140 return True
141 141
142 142
143 143 def validate_url_container(container):
144 144 """validate a potentially nested collection of urls."""
145 145 if isinstance(container, basestring):
146 146 url = container
147 147 return validate_url(url)
148 148 elif isinstance(container, dict):
149 149 container = container.itervalues()
150 150
151 151 for element in container:
152 152 validate_url_container(element)
153 153
154 154
155 155 def split_url(url):
156 156 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
157 157 proto_addr = url.split('://')
158 158 assert len(proto_addr) == 2, 'Invalid url: %r'%url
159 159 proto, addr = proto_addr
160 160 lis = addr.split(':')
161 161 assert len(lis) == 2, 'Invalid url: %r'%url
162 162 addr,s_port = lis
163 163 return proto,addr,s_port
164 164
165 165 def disambiguate_ip_address(ip, location=None):
166 166 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
167 167 ones, based on the location (default interpretation of location is localhost)."""
168 168 if ip in ('0.0.0.0', '*'):
169 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
170 if location is None or location in external_ips:
169 try:
170 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
171 except (socket.gaierror, IndexError):
172 # couldn't identify this machine, assume localhost
173 external_ips = []
174 if location is None or location in external_ips or not external_ips:
175 # If location is unspecified or cannot be determined, assume local
171 176 ip='127.0.0.1'
172 177 elif location:
173 178 return location
174 179 return ip
175 180
176 181 def disambiguate_url(url, location=None):
177 182 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
178 183 ones, based on the location (default interpretation is localhost).
179 184
180 185 This is for zeromq urls, such as tcp://*:10101."""
181 186 try:
182 187 proto,ip,port = split_url(url)
183 188 except AssertionError:
184 189 # probably not tcp url; could be ipc, etc.
185 190 return url
186 191
187 192 ip = disambiguate_ip_address(ip,location)
188 193
189 194 return "%s://%s:%s"%(proto,ip,port)
190 195
191 196 def serialize_object(obj, threshold=64e-6):
192 197 """Serialize an object into a list of sendable buffers.
193 198
194 199 Parameters
195 200 ----------
196 201
197 202 obj : object
198 203 The object to be serialized
199 204 threshold : float
200 205 The threshold for not double-pickling the content.
201 206
202 207
203 208 Returns
204 209 -------
205 210 ('pmd', [bufs]) :
206 211 where pmd is the pickled metadata wrapper,
207 212 bufs is a list of data buffers
208 213 """
209 214 databuffers = []
210 215 if isinstance(obj, (list, tuple)):
211 216 clist = canSequence(obj)
212 217 slist = map(serialize, clist)
213 218 for s in slist:
214 219 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
215 220 databuffers.append(s.getData())
216 221 s.data = None
217 222 return pickle.dumps(slist,-1), databuffers
218 223 elif isinstance(obj, dict):
219 224 sobj = {}
220 225 for k in sorted(obj.iterkeys()):
221 226 s = serialize(can(obj[k]))
222 227 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
223 228 databuffers.append(s.getData())
224 229 s.data = None
225 230 sobj[k] = s
226 231 return pickle.dumps(sobj,-1),databuffers
227 232 else:
228 233 s = serialize(can(obj))
229 234 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
230 235 databuffers.append(s.getData())
231 236 s.data = None
232 237 return pickle.dumps(s,-1),databuffers
233 238
234 239
235 240 def unserialize_object(bufs):
236 241 """reconstruct an object serialized by serialize_object from data buffers."""
237 242 bufs = list(bufs)
238 243 sobj = pickle.loads(bufs.pop(0))
239 244 if isinstance(sobj, (list, tuple)):
240 245 for s in sobj:
241 246 if s.data is None:
242 247 s.data = bufs.pop(0)
243 248 return uncanSequence(map(unserialize, sobj)), bufs
244 249 elif isinstance(sobj, dict):
245 250 newobj = {}
246 251 for k in sorted(sobj.iterkeys()):
247 252 s = sobj[k]
248 253 if s.data is None:
249 254 s.data = bufs.pop(0)
250 255 newobj[k] = uncan(unserialize(s))
251 256 return newobj, bufs
252 257 else:
253 258 if sobj.data is None:
254 259 sobj.data = bufs.pop(0)
255 260 return uncan(unserialize(sobj)), bufs
256 261
257 262 def pack_apply_message(f, args, kwargs, threshold=64e-6):
258 263 """pack up a function, args, and kwargs to be sent over the wire
259 264 as a series of buffers. Any object whose data is larger than `threshold`
260 265     will not have its data copied (currently only numpy arrays support zero-copy)."""
261 266 msg = [pickle.dumps(can(f),-1)]
262 267 databuffers = [] # for large objects
263 268 sargs, bufs = serialize_object(args,threshold)
264 269 msg.append(sargs)
265 270 databuffers.extend(bufs)
266 271 skwargs, bufs = serialize_object(kwargs,threshold)
267 272 msg.append(skwargs)
268 273 databuffers.extend(bufs)
269 274 msg.extend(databuffers)
270 275 return msg
271 276
272 277 def unpack_apply_message(bufs, g=None, copy=True):
273 278 """unpack f,args,kwargs from buffers packed by pack_apply_message()
274 279 Returns: original f,args,kwargs"""
275 280 bufs = list(bufs) # allow us to pop
276 281 assert len(bufs) >= 3, "not enough buffers!"
277 282 if not copy:
278 283 for i in range(3):
279 284 bufs[i] = bufs[i].bytes
280 285 cf = pickle.loads(bufs.pop(0))
281 286 sargs = list(pickle.loads(bufs.pop(0)))
282 287 skwargs = dict(pickle.loads(bufs.pop(0)))
283 288 # print sargs, skwargs
284 289 f = uncan(cf, g)
285 290 for sa in sargs:
286 291 if sa.data is None:
287 292 m = bufs.pop(0)
288 293 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
289 294 # always use a buffer, until memoryviews get sorted out
290 295 sa.data = buffer(m)
291 296 # disable memoryview support
292 297 # if copy:
293 298 # sa.data = buffer(m)
294 299 # else:
295 300 # sa.data = m.buffer
296 301 else:
297 302 if copy:
298 303 sa.data = m
299 304 else:
300 305 sa.data = m.bytes
301 306
302 307 args = uncanSequence(map(unserialize, sargs), g)
303 308 kwargs = {}
304 309 for k in sorted(skwargs.iterkeys()):
305 310 sa = skwargs[k]
306 311 if sa.data is None:
307 312 m = bufs.pop(0)
308 313 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
309 314 # always use a buffer, until memoryviews get sorted out
310 315 sa.data = buffer(m)
311 316 # disable memoryview support
312 317 # if copy:
313 318 # sa.data = buffer(m)
314 319 # else:
315 320 # sa.data = m.buffer
316 321 else:
317 322 if copy:
318 323 sa.data = m
319 324 else:
320 325 sa.data = m.bytes
321 326
322 327 kwargs[k] = uncan(unserialize(sa), g)
323 328
324 329 return f,args,kwargs
325 330
326 331 #--------------------------------------------------------------------------
327 332 # helpers for implementing old MEC API via view.apply
328 333 #--------------------------------------------------------------------------
329 334
330 335 def interactive(f):
331 336 """decorator for making functions appear as interactively defined.
332 337 This results in the function being linked to the user_ns as globals()
333 338 instead of the module globals().
334 339 """
335 340 f.__module__ = '__main__'
336 341 return f
337 342
338 343 @interactive
339 344 def _push(ns):
340 345 """helper method for implementing `client.push` via `client.apply`"""
341 346 globals().update(ns)
342 347
343 348 @interactive
344 349 def _pull(keys):
345 350 """helper method for implementing `client.pull` via `client.apply`"""
346 351 user_ns = globals()
347 352 if isinstance(keys, (list,tuple, set)):
348 353 for key in keys:
349 354 if not user_ns.has_key(key):
350 355 raise NameError("name '%s' is not defined"%key)
351 356 return map(user_ns.get, keys)
352 357 else:
353 358 if not user_ns.has_key(keys):
354 359 raise NameError("name '%s' is not defined"%keys)
355 360 return user_ns.get(keys)
356 361
357 362 @interactive
358 363 def _execute(code):
359 364 """helper method for implementing `client.execute` via `client.apply`"""
360 365 exec code in globals()
361 366
362 367 #--------------------------------------------------------------------------
363 368 # extra process management utilities
364 369 #--------------------------------------------------------------------------
365 370
366 371 _random_ports = set()
367 372
368 373 def select_random_ports(n):
369 374     """Select and return n random ports that are available."""
370 375 ports = []
371 376 for i in xrange(n):
372 377 sock = socket.socket()
373 378 sock.bind(('', 0))
374 379 while sock.getsockname()[1] in _random_ports:
375 380 sock.close()
376 381 sock = socket.socket()
377 382 sock.bind(('', 0))
378 383 ports.append(sock)
379 384 for i, sock in enumerate(ports):
380 385 port = sock.getsockname()[1]
381 386 sock.close()
382 387 ports[i] = port
383 388 _random_ports.add(port)
384 389 return ports
385 390
386 391 def signal_children(children):
387 392     """Relay interrupt/term signals to children, for more solid process cleanup."""
388 393 def terminate_children(sig, frame):
389 394 logging.critical("Got signal %i, terminating children..."%sig)
390 395 for child in children:
391 396 child.terminate()
392 397
393 398 sys.exit(sig != SIGINT)
394 399 # sys.exit(sig)
395 400 for sig in (SIGINT, SIGABRT, SIGTERM):
396 401 signal(sig, terminate_children)
397 402
398 403 def generate_exec_key(keyfile):
399 404 import uuid
400 405 newkey = str(uuid.uuid4())
401 406 with open(keyfile, 'w') as f:
402 407 # f.write('ipython-key ')
403 408 f.write(newkey+'\n')
404 409 # set user-only RW permissions (0600)
405 410 # this will have no effect on Windows
406 411 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
407 412
408 413
409 414 def integer_loglevel(loglevel):
410 415 try:
411 416 loglevel = int(loglevel)
412 417 except ValueError:
413 418 if isinstance(loglevel, str):
414 419 loglevel = getattr(logging, loglevel)
415 420 return loglevel
416 421
417 422 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
418 423 logger = logging.getLogger(logname)
419 424 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
420 425 # don't add a second PUBHandler
421 426 return
422 427 loglevel = integer_loglevel(loglevel)
423 428 lsock = context.socket(zmq.PUB)
424 429 lsock.connect(iface)
425 430 handler = handlers.PUBHandler(lsock)
426 431 handler.setLevel(loglevel)
427 432 handler.root_topic = root
428 433 logger.addHandler(handler)
429 434 logger.setLevel(loglevel)
430 435
431 436 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
432 437 logger = logging.getLogger()
433 438 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
434 439 # don't add a second PUBHandler
435 440 return
436 441 loglevel = integer_loglevel(loglevel)
437 442 lsock = context.socket(zmq.PUB)
438 443 lsock.connect(iface)
439 444 handler = EnginePUBHandler(engine, lsock)
440 445 handler.setLevel(loglevel)
441 446 logger.addHandler(handler)
442 447 logger.setLevel(loglevel)
443 448 return logger
444 449
445 450 def local_logger(logname, loglevel=logging.DEBUG):
446 451 loglevel = integer_loglevel(loglevel)
447 452 logger = logging.getLogger(logname)
448 453 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
449 454 # don't add a second StreamHandler
450 455 return
451 456 handler = logging.StreamHandler()
452 457 handler.setLevel(loglevel)
453 458 logger.addHandler(handler)
454 459 logger.setLevel(loglevel)
455 460 return logger
456 461
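A usage sketch for the patched disambiguate_ip_address/disambiguate_url (the expected outputs are assumptions based on the logic above and depend on the host's interfaces): bind-all addresses are rewritten to something a peer can actually connect to, and an unresolvable hostname now degrades to loopback instead of raising.

from IPython.parallel.util import disambiguate_url

# no location given: a bind-all address collapses to loopback
print(disambiguate_url('tcp://*:10101'))                    # tcp://127.0.0.1:10101

# explicit location that is not one of this machine's IPs:
# callers should connect to that external address instead
print(disambiguate_url('tcp://0.0.0.0:10101', '10.0.0.5'))  # tcp://10.0.0.5:10101

# non-tcp urls pass through untouched
print(disambiguate_url('ipc:///tmp/controller.sock'))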