move rekey to jsonutil from parallel.util...
MinRK
@@ -1,1373 +1,1374 @@
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 import os
19 19 import json
20 20 import time
21 21 import warnings
22 22 from datetime import datetime
23 23 from getpass import getpass
24 24 from pprint import pprint
25 25
26 26 pjoin = os.path.join
27 27
28 28 import zmq
29 29 # from zmq.eventloop import ioloop, zmqstream
30 30
31 from IPython.utils.jsonutil import rekey
31 32 from IPython.utils.path import get_ipython_dir
32 33 from IPython.utils.traitlets import (HasTraits, Int, Instance, Unicode,
33 34 Dict, List, Bool, Set)
34 35 from IPython.external.decorator import decorator
35 36 from IPython.external.ssh import tunnel
36 37
37 38 from IPython.parallel import error
38 39 from IPython.parallel import util
39 40
40 41 from IPython.zmq.session import Session, Message
41 42
42 43 from .asyncresult import AsyncResult, AsyncHubResult
43 44 from IPython.core.profiledir import ProfileDir, ProfileDirError
44 45 from .view import DirectView, LoadBalancedView
45 46
46 47 #--------------------------------------------------------------------------
47 48 # Decorators for Client methods
48 49 #--------------------------------------------------------------------------
49 50
50 51 @decorator
51 52 def spin_first(f, self, *args, **kwargs):
52 53 """Call spin() to sync state prior to calling the method."""
53 54 self.spin()
54 55 return f(self, *args, **kwargs)
55 56
56 57
57 58 #--------------------------------------------------------------------------
58 59 # Classes
59 60 #--------------------------------------------------------------------------
60 61
61 62 class Metadata(dict):
62 63 """Subclass of dict for initializing metadata values.
63 64
64 65 Attribute access works on keys.
65 66
66 67 These objects have a strict set of keys - errors will raise if you try
67 68 to add new keys.
68 69 """
69 70 def __init__(self, *args, **kwargs):
70 71 dict.__init__(self)
71 72 md = {'msg_id' : None,
72 73 'submitted' : None,
73 74 'started' : None,
74 75 'completed' : None,
75 76 'received' : None,
76 77 'engine_uuid' : None,
77 78 'engine_id' : None,
78 79 'follow' : None,
79 80 'after' : None,
80 81 'status' : None,
81 82
82 83 'pyin' : None,
83 84 'pyout' : None,
84 85 'pyerr' : None,
85 86 'stdout' : '',
86 87 'stderr' : '',
87 88 }
88 89 self.update(md)
89 90 self.update(dict(*args, **kwargs))
90 91
91 92 def __getattr__(self, key):
92 93 """getattr aliased to getitem"""
93 94 if key in self.iterkeys():
94 95 return self[key]
95 96 else:
96 97 raise AttributeError(key)
97 98
98 99 def __setattr__(self, key, value):
99 100 """setattr aliased to setitem, with strict"""
100 101 if key in self.iterkeys():
101 102 self[key] = value
102 103 else:
103 104 raise AttributeError(key)
104 105
105 106 def __setitem__(self, key, value):
106 107 """strict static key enforcement"""
107 108 if key in self.iterkeys():
108 109 dict.__setitem__(self, key, value)
109 110 else:
110 111 raise KeyError(key)
111 112
112 113
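Reviewer note: the strict-key behavior of Metadata is easy to exercise directly; a minimal sketch (the `md` name is illustrative):

    md = Metadata(status='ok')
    print (md.status)        # attribute access aliases item access -> 'ok'
    md.stdout = 'hello\n'    # fine: 'stdout' is one of the fixed keys
    try:
        md.color = 'blue'    # not a predefined key
    except AttributeError:
        print ("unknown keys are rejected")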
113 114 class Client(HasTraits):
114 115 """A semi-synchronous client to the IPython ZMQ cluster
115 116
116 117 Parameters
117 118 ----------
118 119
119 120 url_or_file : bytes; zmq url or path to ipcontroller-client.json
120 121 Connection information for the Hub's registration. If a json connector
121 122 file is given, then likely no further configuration is necessary.
122 123 [Default: use profile]
123 124 profile : bytes
124 125 The name of the Cluster profile to be used to find connector information.
125 126 [Default: 'default']
126 127 context : zmq.Context
127 128 Pass an existing zmq.Context instance, otherwise the client will create its own.
128 129 debug : bool
129 130 flag for lots of message printing for debug purposes
130 131 timeout : int/float
131 132 time (in seconds) to wait for connection replies from the Hub
132 133 [Default: 10]
133 134
134 135 #-------------- session related args ----------------
135 136
136 137 config : Config object
137 138 If specified, this will be relayed to the Session for configuration
138 139 username : str
139 140 set username for the session object
140 141 packer : str (import_string) or callable
141 142 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
142 143 function to serialize messages. Must support same input as
143 144 JSON, and output must be bytes.
144 145 You can pass a callable directly as `pack`
145 146 unpacker : str (import_string) or callable
146 147 The inverse of packer. Only necessary if packer is specified as *not* one
147 148 of 'json' or 'pickle'.
148 149
149 150 #-------------- ssh related args ----------------
150 151 # These are args for configuring the ssh tunnel to be used
151 152 # credentials are used to forward connections over ssh to the Controller
152 153 # Note that the ip given in `addr` needs to be relative to sshserver
153 154 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
154 155 # and set sshserver as the same machine the Controller is on. However,
155 156 # the only requirement is that sshserver is able to see the Controller
156 157 # (i.e. is within the same trusted network).
157 158
158 159 sshserver : str
159 160 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
160 161 If keyfile or password is specified, and this is not, it will default to
161 162 the ip given in addr.
162 163 sshkey : str; path to public ssh key file
163 164 This specifies a key to be used in ssh login, default None.
164 165 Regular default ssh keys will be used without specifying this argument.
165 166 password : str
166 167 Your ssh password to sshserver. Note that if this is left None,
167 168 you will be prompted for it if passwordless key based login is unavailable.
168 169 paramiko : bool
169 170 flag for whether to use paramiko instead of shell ssh for tunneling.
170 171 [default: True on win32, False otherwise]
171 172
172 173 ------- exec authentication args -------
173 174 If even localhost is untrusted, you can have some protection against
174 175 unauthorized execution by signing messages with HMAC digests.
175 176 Messages are still sent as cleartext, so if someone can snoop your
176 177 loopback traffic this will not protect your privacy, but will prevent
177 178 unauthorized execution.
178 179
179 180 exec_key : str
180 181 an authentication key or file containing a key
181 182 default: None
182 183
183 184
184 185 Attributes
185 186 ----------
186 187
187 188 ids : list of int engine IDs
188 189 requesting the ids attribute always synchronizes
189 190 the registration state. To request ids without synchronization,
190 191 use the semi-private _ids attribute.
191 192
192 193 history : list of msg_ids
193 194 a list of msg_ids, keeping track of all the execution
194 195 messages you have submitted in order.
195 196
196 197 outstanding : set of msg_ids
197 198 a set of msg_ids that have been submitted, but whose
198 199 results have not yet been received.
199 200
200 201 results : dict
201 202 a dict of all our results, keyed by msg_id
202 203
203 204 block : bool
204 205 determines default behavior when block not specified
205 206 in execution methods
206 207
207 208 Methods
208 209 -------
209 210
210 211 spin
211 212 flushes incoming results and registration state changes
212 213 control methods spin, and requesting `ids` also ensures the state is up to date
213 214
214 215 wait
215 216 wait on one or more msg_ids
216 217
217 218 execution methods
218 219 apply
219 220 legacy: execute, run
220 221
221 222 data movement
222 223 push, pull, scatter, gather
223 224
224 225 query methods
225 226 queue_status, get_result, purge, result_status
226 227
227 228 control methods
228 229 abort, shutdown
229 230
230 231 """
231 232
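Reviewer note: as a usage sketch (the path and server name below are illustrative, not taken from this changeset):

    from IPython.parallel import Client

    # zero-config: read ipcontroller-client.json from the 'default' profile
    rc = Client()

    # or connect explicitly, tunneled over ssh
    rc = Client('/path/to/ipcontroller-client.json', sshserver='user@server.tld')
    print (rc.ids)   # currently registered engine ids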
232 233
233 234 block = Bool(False)
234 235 outstanding = Set()
235 236 results = Instance('collections.defaultdict', (dict,))
236 237 metadata = Instance('collections.defaultdict', (Metadata,))
237 238 history = List()
238 239 debug = Bool(False)
239 240 profile=Unicode('default')
240 241
241 242 _outstanding_dict = Instance('collections.defaultdict', (set,))
242 243 _ids = List()
243 244 _connected=Bool(False)
244 245 _ssh=Bool(False)
245 246 _context = Instance('zmq.Context')
246 247 _config = Dict()
247 248 _engines=Instance(util.ReverseDict, (), {})
248 249 # _hub_socket=Instance('zmq.Socket')
249 250 _query_socket=Instance('zmq.Socket')
250 251 _control_socket=Instance('zmq.Socket')
251 252 _iopub_socket=Instance('zmq.Socket')
252 253 _notification_socket=Instance('zmq.Socket')
253 254 _mux_socket=Instance('zmq.Socket')
254 255 _task_socket=Instance('zmq.Socket')
255 256 _task_scheme=Unicode()
256 257 _closed = False
257 258 _ignored_control_replies=Int(0)
258 259 _ignored_hub_replies=Int(0)
259 260
260 261 def __init__(self, url_or_file=None, profile='default', profile_dir=None, ipython_dir=None,
261 262 context=None, debug=False, exec_key=None,
262 263 sshserver=None, sshkey=None, password=None, paramiko=None,
263 264 timeout=10, **extra_args
264 265 ):
265 266 super(Client, self).__init__(debug=debug, profile=profile)
266 267 if context is None:
267 268 context = zmq.Context.instance()
268 269 self._context = context
269 270
270 271
271 272 self._setup_profile_dir(profile, profile_dir, ipython_dir)
272 273 if self._cd is not None:
273 274 if url_or_file is None:
274 275 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
275 276 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
276 277 " Please specify at least one of url_or_file or profile."
277 278
278 279 try:
279 280 util.validate_url(url_or_file)
280 281 except AssertionError:
281 282 if not os.path.exists(url_or_file):
282 283 if self._cd:
283 284 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
284 285 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
285 286 with open(url_or_file) as f:
286 287 cfg = json.loads(f.read())
287 288 else:
288 289 cfg = {'url':url_or_file}
289 290
290 291 # sync defaults from args, json:
291 292 if sshserver:
292 293 cfg['ssh'] = sshserver
293 294 if exec_key:
294 295 cfg['exec_key'] = exec_key
295 296 exec_key = cfg['exec_key']
296 297 sshserver=cfg['ssh']
297 298 url = cfg['url']
298 299 location = cfg.setdefault('location', None)
299 300 cfg['url'] = util.disambiguate_url(cfg['url'], location)
300 301 url = cfg['url']
301 302
302 303 self._config = cfg
303 304
304 305 self._ssh = bool(sshserver or sshkey or password)
305 306 if self._ssh and sshserver is None:
306 307 # default to ssh via localhost
307 308 sshserver = url.split('://')[1].split(':')[0]
308 309 if self._ssh and password is None:
309 310 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
310 311 password=False
311 312 else:
312 313 password = getpass("SSH Password for %s: "%sshserver)
313 314 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
314 315
315 316 # configure and construct the session
316 317 if exec_key is not None:
317 318 if os.path.isfile(exec_key):
318 319 extra_args['keyfile'] = exec_key
319 320 else:
320 321 extra_args['key'] = exec_key
321 322 self.session = Session(**extra_args)
322 323
323 324 self._query_socket = self._context.socket(zmq.XREQ)
324 325 self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
325 326 if self._ssh:
326 327 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
327 328 else:
328 329 self._query_socket.connect(url)
329 330
330 331 self.session.debug = self.debug
331 332
332 333 self._notification_handlers = {'registration_notification' : self._register_engine,
333 334 'unregistration_notification' : self._unregister_engine,
334 335 'shutdown_notification' : lambda msg: self.close(),
335 336 }
336 337 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
337 338 'apply_reply' : self._handle_apply_reply}
338 339 self._connect(sshserver, ssh_kwargs, timeout)
339 340
340 341 def __del__(self):
341 342 """cleanup sockets, but _not_ context."""
342 343 self.close()
343 344
344 345 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
345 346 if ipython_dir is None:
346 347 ipython_dir = get_ipython_dir()
347 348 if profile_dir is not None:
348 349 try:
349 350 self._cd = ProfileDir.find_profile_dir(profile_dir)
350 351 return
351 352 except ProfileDirError:
352 353 pass
353 354 elif profile is not None:
354 355 try:
355 356 self._cd = ProfileDir.find_profile_dir_by_name(
356 357 ipython_dir, profile)
357 358 return
358 359 except ProfileDirError:
359 360 pass
360 361 self._cd = None
361 362
362 363 def _update_engines(self, engines):
363 364 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
364 365 for k,v in engines.iteritems():
365 366 eid = int(k)
366 367 self._engines[eid] = bytes(v) # force not unicode
367 368 self._ids.append(eid)
368 369 self._ids = sorted(self._ids)
369 370 if sorted(self._engines.keys()) != range(len(self._engines)) and \
370 371 self._task_scheme == 'pure' and self._task_socket:
371 372 self._stop_scheduling_tasks()
372 373
373 374 def _stop_scheduling_tasks(self):
374 375 """Stop scheduling tasks because an engine has been unregistered
375 376 from a pure ZMQ scheduler.
376 377 """
377 378 self._task_socket.close()
378 379 self._task_socket = None
379 380 msg = "An engine has been unregistered, and we are using pure " +\
380 381 "ZMQ task scheduling. Task farming will be disabled."
381 382 if self.outstanding:
382 383 msg += " If you were running tasks when this happened, " +\
383 384 "some `outstanding` msg_ids may never resolve."
384 385 warnings.warn(msg, RuntimeWarning)
385 386
386 387 def _build_targets(self, targets):
387 388 """Turn valid target IDs or 'all' into two lists:
388 389 (int_ids, uuids).
389 390 """
390 391 if not self._ids:
391 392 # flush notification socket if no engines yet, just in case
392 393 if not self.ids:
393 394 raise error.NoEnginesRegistered("Can't build targets without any engines")
394 395
395 396 if targets is None:
396 397 targets = self._ids
397 398 elif isinstance(targets, str):
398 399 if targets.lower() == 'all':
399 400 targets = self._ids
400 401 else:
401 402 raise TypeError("%r not valid str target, must be 'all'"%(targets))
402 403 elif isinstance(targets, int):
403 404 if targets < 0:
404 405 targets = self.ids[targets]
405 406 if targets not in self._ids:
406 407 raise IndexError("No such engine: %i"%targets)
407 408 targets = [targets]
408 409
409 410 if isinstance(targets, slice):
410 411 indices = range(len(self._ids))[targets]
411 412 ids = self.ids
412 413 targets = [ ids[i] for i in indices ]
413 414
414 415 if not isinstance(targets, (tuple, list, xrange)):
415 416 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
416 417
417 418 return [self._engines[t] for t in targets], list(targets)
418 419
419 420 def _connect(self, sshserver, ssh_kwargs, timeout):
420 421 """setup all our socket connections to the cluster. This is called from
421 422 __init__."""
422 423
423 424 # Maybe allow reconnecting?
424 425 if self._connected:
425 426 return
426 427 self._connected=True
427 428
428 429 def connect_socket(s, url):
429 430 url = util.disambiguate_url(url, self._config['location'])
430 431 if self._ssh:
431 432 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
432 433 else:
433 434 return s.connect(url)
434 435
435 436 self.session.send(self._query_socket, 'connection_request')
436 437 r,w,x = zmq.select([self._query_socket],[],[], timeout)
437 438 if not r:
438 439 raise error.TimeoutError("Hub connection request timed out")
439 440 idents,msg = self.session.recv(self._query_socket,mode=0)
440 441 if self.debug:
441 442 pprint(msg)
442 443 msg = Message(msg)
443 444 content = msg.content
444 445 self._config['registration'] = dict(content)
445 446 if content.status == 'ok':
446 447 if content.mux:
447 448 self._mux_socket = self._context.socket(zmq.XREQ)
448 449 self._mux_socket.setsockopt(zmq.IDENTITY, self.session.session)
449 450 connect_socket(self._mux_socket, content.mux)
450 451 if content.task:
451 452 self._task_scheme, task_addr = content.task
452 453 self._task_socket = self._context.socket(zmq.XREQ)
453 454 self._task_socket.setsockopt(zmq.IDENTITY, self.session.session)
454 455 connect_socket(self._task_socket, task_addr)
455 456 if content.notification:
456 457 self._notification_socket = self._context.socket(zmq.SUB)
457 458 connect_socket(self._notification_socket, content.notification)
458 459 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
459 460 # if content.query:
460 461 # self._query_socket = self._context.socket(zmq.XREQ)
461 462 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.session)
462 463 # connect_socket(self._query_socket, content.query)
463 464 if content.control:
464 465 self._control_socket = self._context.socket(zmq.XREQ)
465 466 self._control_socket.setsockopt(zmq.IDENTITY, self.session.session)
466 467 connect_socket(self._control_socket, content.control)
467 468 if content.iopub:
468 469 self._iopub_socket = self._context.socket(zmq.SUB)
469 470 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
470 471 self._iopub_socket.setsockopt(zmq.IDENTITY, self.session.session)
471 472 connect_socket(self._iopub_socket, content.iopub)
472 473 self._update_engines(dict(content.engines))
473 474 else:
474 475 self._connected = False
475 476 raise Exception("Failed to connect!")
476 477
477 478 #--------------------------------------------------------------------------
478 479 # handlers and callbacks for incoming messages
479 480 #--------------------------------------------------------------------------
480 481
481 482 def _unwrap_exception(self, content):
482 483 """unwrap exception, and remap engine_id to int."""
483 484 e = error.unwrap_exception(content)
484 485 # print e.traceback
485 486 if e.engine_info:
486 487 e_uuid = e.engine_info['engine_uuid']
487 488 eid = self._engines[e_uuid]
488 489 e.engine_info['engine_id'] = eid
489 490 return e
490 491
491 492 def _extract_metadata(self, header, parent, content):
492 493 md = {'msg_id' : parent['msg_id'],
493 494 'received' : datetime.now(),
494 495 'engine_uuid' : header.get('engine', None),
495 496 'follow' : parent.get('follow', []),
496 497 'after' : parent.get('after', []),
497 498 'status' : content['status'],
498 499 }
499 500
500 501 if md['engine_uuid'] is not None:
501 502 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
502 503
503 504 if 'date' in parent:
504 505 md['submitted'] = parent['date']
505 506 if 'started' in header:
506 507 md['started'] = header['started']
507 508 if 'date' in header:
508 509 md['completed'] = header['date']
509 510 return md
510 511
511 512 def _register_engine(self, msg):
512 513 """Register a new engine, and update our connection info."""
513 514 content = msg['content']
514 515 eid = content['id']
515 516 d = {eid : content['queue']}
516 517 self._update_engines(d)
517 518
518 519 def _unregister_engine(self, msg):
519 520 """Unregister an engine that has died."""
520 521 content = msg['content']
521 522 eid = int(content['id'])
522 523 if eid in self._ids:
523 524 self._ids.remove(eid)
524 525 uuid = self._engines.pop(eid)
525 526
526 527 self._handle_stranded_msgs(eid, uuid)
527 528
528 529 if self._task_socket and self._task_scheme == 'pure':
529 530 self._stop_scheduling_tasks()
530 531
531 532 def _handle_stranded_msgs(self, eid, uuid):
532 533 """Handle messages known to be on an engine when the engine unregisters.
533 534
534 535 It is possible that this will fire prematurely - that is, an engine will
535 536 go down after completing a result, and the client will be notified
536 537 of the unregistration and later receive the successful result.
537 538 """
538 539
539 540 outstanding = self._outstanding_dict[uuid]
540 541
541 542 for msg_id in list(outstanding):
542 543 if msg_id in self.results:
543 544 # we already have the result; skip it
544 545 continue
545 546 try:
546 547 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
547 548 except:
548 549 content = error.wrap_exception()
549 550 # build a fake message:
550 551 parent = {}
551 552 header = {}
552 553 parent['msg_id'] = msg_id
553 554 header['engine'] = uuid
554 555 header['date'] = datetime.now()
555 556 msg = dict(parent_header=parent, header=header, content=content)
556 557 self._handle_apply_reply(msg)
557 558
558 559 def _handle_execute_reply(self, msg):
559 560 """Save the reply to an execute_request into our results.
560 561
561 562 execute messages are never actually used. apply is used instead.
562 563 """
563 564
564 565 parent = msg['parent_header']
565 566 msg_id = parent['msg_id']
566 567 if msg_id not in self.outstanding:
567 568 if msg_id in self.history:
568 569 print ("got stale result: %s"%msg_id)
569 570 else:
570 571 print ("got unknown result: %s"%msg_id)
571 572 else:
572 573 self.outstanding.remove(msg_id)
573 574 self.results[msg_id] = self._unwrap_exception(msg['content'])
574 575
575 576 def _handle_apply_reply(self, msg):
576 577 """Save the reply to an apply_request into our results."""
577 578 parent = msg['parent_header']
578 579 msg_id = parent['msg_id']
579 580 if msg_id not in self.outstanding:
580 581 if msg_id in self.history:
581 582 print ("got stale result: %s"%msg_id)
582 583 print self.results[msg_id]
583 584 print msg
584 585 else:
585 586 print ("got unknown result: %s"%msg_id)
586 587 else:
587 588 self.outstanding.remove(msg_id)
588 589 content = msg['content']
589 590 header = msg['header']
590 591
591 592 # construct metadata:
592 593 md = self.metadata[msg_id]
593 594 md.update(self._extract_metadata(header, parent, content))
594 595 # is this redundant?
595 596 self.metadata[msg_id] = md
596 597
597 598 e_outstanding = self._outstanding_dict[md['engine_uuid']]
598 599 if msg_id in e_outstanding:
599 600 e_outstanding.remove(msg_id)
600 601
601 602 # construct result:
602 603 if content['status'] == 'ok':
603 604 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
604 605 elif content['status'] == 'aborted':
605 606 self.results[msg_id] = error.TaskAborted(msg_id)
606 607 elif content['status'] == 'resubmitted':
607 608 # TODO: handle resubmission
608 609 pass
609 610 else:
610 611 self.results[msg_id] = self._unwrap_exception(content)
611 612
612 613 def _flush_notifications(self):
613 614 """Flush notifications of engine registrations waiting
614 615 in ZMQ queue."""
615 616 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
616 617 while msg is not None:
617 618 if self.debug:
618 619 pprint(msg)
619 620 msg_type = msg['msg_type']
620 621 handler = self._notification_handlers.get(msg_type, None)
621 622 if handler is None:
622 623 raise Exception("Unhandled message type: %s"%msg.msg_type)
623 624 else:
624 625 handler(msg)
625 626 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
626 627
627 628 def _flush_results(self, sock):
628 629 """Flush task or queue results waiting in ZMQ queue."""
629 630 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
630 631 while msg is not None:
631 632 if self.debug:
632 633 pprint(msg)
633 634 msg_type = msg['msg_type']
634 635 handler = self._queue_handlers.get(msg_type, None)
635 636 if handler is None:
636 637 raise Exception("Unhandled message type: %s"%msg.msg_type)
637 638 else:
638 639 handler(msg)
639 640 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
640 641
641 642 def _flush_control(self, sock):
642 643 """Flush replies from the control channel waiting
643 644 in the ZMQ queue.
644 645
645 646 Currently: ignore them."""
646 647 if self._ignored_control_replies <= 0:
647 648 return
648 649 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
649 650 while msg is not None:
650 651 self._ignored_control_replies -= 1
651 652 if self.debug:
652 653 pprint(msg)
653 654 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
654 655
655 656 def _flush_ignored_control(self):
656 657 """flush ignored control replies"""
657 658 while self._ignored_control_replies > 0:
658 659 self.session.recv(self._control_socket)
659 660 self._ignored_control_replies -= 1
660 661
661 662 def _flush_ignored_hub_replies(self):
662 663 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
663 664 while msg is not None:
664 665 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
665 666
666 667 def _flush_iopub(self, sock):
667 668 """Flush replies from the iopub channel waiting
668 669 in the ZMQ queue.
669 670 """
670 671 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
671 672 while msg is not None:
672 673 if self.debug:
673 674 pprint(msg)
674 675 parent = msg['parent_header']
675 676 msg_id = parent['msg_id']
676 677 content = msg['content']
677 678 header = msg['header']
678 679 msg_type = msg['msg_type']
679 680
680 681 # init metadata:
681 682 md = self.metadata[msg_id]
682 683
683 684 if msg_type == 'stream':
684 685 name = content['name']
685 686 s = md[name] or ''
686 687 md[name] = s + content['data']
687 688 elif msg_type == 'pyerr':
688 689 md.update({'pyerr' : self._unwrap_exception(content)})
689 690 elif msg_type == 'pyin':
690 691 md.update({'pyin' : content['code']})
691 692 else:
692 693 md.update({msg_type : content.get('data', '')})
693 694
694 695 # redundant?
695 696 self.metadata[msg_id] = md
696 697
697 698 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
698 699
699 700 #--------------------------------------------------------------------------
700 701 # len, getitem
701 702 #--------------------------------------------------------------------------
702 703
703 704 def __len__(self):
704 705 """len(client) returns # of engines."""
705 706 return len(self.ids)
706 707
707 708 def __getitem__(self, key):
708 709 """index access returns DirectView multiplexer objects
709 710
710 711 Must be int, slice, or list/tuple/xrange of ints"""
711 712 if not isinstance(key, (int, slice, tuple, list, xrange)):
712 713 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
713 714 else:
714 715 return self.direct_view(key)
715 716
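Reviewer note: index access is sugar for `direct_view`; assuming a connected client `rc`:

    dv_all  = rc[:]      # DirectView on every engine
    dv_even = rc[::2]    # DirectView on engines 0, 2, 4, ...
    e0      = rc[0]      # DirectView on a single engine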
716 717 #--------------------------------------------------------------------------
717 718 # Begin public methods
718 719 #--------------------------------------------------------------------------
719 720
720 721 @property
721 722 def ids(self):
722 723 """Always up-to-date ids property."""
723 724 self._flush_notifications()
724 725 # always copy:
725 726 return list(self._ids)
726 727
727 728 def close(self):
728 729 if self._closed:
729 730 return
730 731 snames = filter(lambda n: n.endswith('socket'), dir(self))
731 732 for socket in map(lambda name: getattr(self, name), snames):
732 733 if isinstance(socket, zmq.Socket) and not socket.closed:
733 734 socket.close()
734 735 self._closed = True
735 736
736 737 def spin(self):
737 738 """Flush any registration notifications and execution results
738 739 waiting in the ZMQ queue.
739 740 """
740 741 if self._notification_socket:
741 742 self._flush_notifications()
742 743 if self._mux_socket:
743 744 self._flush_results(self._mux_socket)
744 745 if self._task_socket:
745 746 self._flush_results(self._task_socket)
746 747 if self._control_socket:
747 748 self._flush_control(self._control_socket)
748 749 if self._iopub_socket:
749 750 self._flush_iopub(self._iopub_socket)
750 751 if self._query_socket:
751 752 self._flush_ignored_hub_replies()
752 753
753 754 def wait(self, jobs=None, timeout=-1):
754 755 """waits on one or more `jobs`, for up to `timeout` seconds.
755 756
756 757 Parameters
757 758 ----------
758 759
759 760 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
760 761 ints are indices to self.history
761 762 strs are msg_ids
762 763 default: wait on all outstanding messages
763 764 timeout : float
764 765 a time in seconds, after which to give up.
765 766 default is -1, which means no timeout
766 767
767 768 Returns
768 769 -------
769 770
770 771 True : when all msg_ids are done
771 772 False : timeout reached, some msg_ids still outstanding
772 773 """
773 774 tic = time.time()
774 775 if jobs is None:
775 776 theids = self.outstanding
776 777 else:
777 778 if isinstance(jobs, (int, str, AsyncResult)):
778 779 jobs = [jobs]
779 780 theids = set()
780 781 for job in jobs:
781 782 if isinstance(job, int):
782 783 # index access
783 784 job = self.history[job]
784 785 elif isinstance(job, AsyncResult):
785 786 map(theids.add, job.msg_ids)
786 787 continue
787 788 theids.add(job)
788 789 if not theids.intersection(self.outstanding):
789 790 return True
790 791 self.spin()
791 792 while theids.intersection(self.outstanding):
792 793 if timeout >= 0 and ( time.time()-tic ) > timeout:
793 794 break
794 795 time.sleep(1e-3)
795 796 self.spin()
796 797 return len(theids.intersection(self.outstanding)) == 0
797 798
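Reviewer note: a sketch of the submit-then-wait pattern this supports (assuming a connected `rc` with at least one engine):

    ar = rc[:].apply_async(lambda x: x * 2, 21)
    if rc.wait([ar], timeout=5):
        print (ar.get())      # one result per engine, e.g. [42, ...]
    else:
        print ("still outstanding: %s" % rc.outstanding)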
798 799 #--------------------------------------------------------------------------
799 800 # Control methods
800 801 #--------------------------------------------------------------------------
801 802
802 803 @spin_first
803 804 def clear(self, targets=None, block=None):
804 805 """Clear the namespace in target(s)."""
805 806 block = self.block if block is None else block
806 807 targets = self._build_targets(targets)[0]
807 808 for t in targets:
808 809 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
809 810 error = False
810 811 if block:
811 812 self._flush_ignored_control()
812 813 for i in range(len(targets)):
813 814 idents,msg = self.session.recv(self._control_socket,0)
814 815 if self.debug:
815 816 pprint(msg)
816 817 if msg['content']['status'] != 'ok':
817 818 error = self._unwrap_exception(msg['content'])
818 819 else:
819 820 self._ignored_control_replies += len(targets)
820 821 if error:
821 822 raise error
822 823
823 824
824 825 @spin_first
825 826 def abort(self, jobs=None, targets=None, block=None):
826 827 """Abort specific jobs from the execution queues of target(s).
827 828
828 829 This is a mechanism to prevent jobs that have already been submitted
829 830 from executing.
830 831
831 832 Parameters
832 833 ----------
833 834
834 835 jobs : msg_id, list of msg_ids, or AsyncResult
835 836 The jobs to be aborted
836 837
837 838
838 839 """
839 840 block = self.block if block is None else block
840 841 targets = self._build_targets(targets)[0]
841 842 msg_ids = []
842 843 if isinstance(jobs, (basestring,AsyncResult)):
843 844 jobs = [jobs]
844 845 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
845 846 if bad_ids:
846 847 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
847 848 for j in jobs:
848 849 if isinstance(j, AsyncResult):
849 850 msg_ids.extend(j.msg_ids)
850 851 else:
851 852 msg_ids.append(j)
852 853 content = dict(msg_ids=msg_ids)
853 854 for t in targets:
854 855 self.session.send(self._control_socket, 'abort_request',
855 856 content=content, ident=t)
856 857 error = False
857 858 if block:
858 859 self._flush_ignored_control()
859 860 for i in range(len(targets)):
860 861 idents,msg = self.session.recv(self._control_socket,0)
861 862 if self.debug:
862 863 pprint(msg)
863 864 if msg['content']['status'] != 'ok':
864 865 error = self._unwrap_exception(msg['content'])
865 866 else:
866 867 self._ignored_control_replies += len(targets)
867 868 if error:
868 869 raise error
869 870
870 871 @spin_first
871 872 def shutdown(self, targets=None, restart=False, hub=False, block=None):
872 873 """Terminates one or more engine processes, optionally including the hub."""
873 874 block = self.block if block is None else block
874 875 if hub:
875 876 targets = 'all'
876 877 targets = self._build_targets(targets)[0]
877 878 for t in targets:
878 879 self.session.send(self._control_socket, 'shutdown_request',
879 880 content={'restart':restart},ident=t)
880 881 error = False
881 882 if block or hub:
882 883 self._flush_ignored_control()
883 884 for i in range(len(targets)):
884 885 idents,msg = self.session.recv(self._control_socket, 0)
885 886 if self.debug:
886 887 pprint(msg)
887 888 if msg['content']['status'] != 'ok':
888 889 error = self._unwrap_exception(msg['content'])
889 890 else:
890 891 self._ignored_control_replies += len(targets)
891 892
892 893 if hub:
893 894 time.sleep(0.25)
894 895 self.session.send(self._query_socket, 'shutdown_request')
895 896 idents,msg = self.session.recv(self._query_socket, 0)
896 897 if self.debug:
897 898 pprint(msg)
898 899 if msg['content']['status'] != 'ok':
899 900 error = self._unwrap_exception(msg['content'])
900 901
901 902 if error:
902 903 raise error
903 904
904 905 #--------------------------------------------------------------------------
905 906 # Execution related methods
906 907 #--------------------------------------------------------------------------
907 908
908 909 def _maybe_raise(self, result):
909 910 """wrapper for maybe raising an exception if apply failed."""
910 911 if isinstance(result, error.RemoteError):
911 912 raise result
912 913
913 914 return result
914 915
915 916 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
916 917 ident=None):
917 918 """construct and send an apply message via a socket.
918 919
919 920 This is the principal method with which all engine execution is performed by views.
920 921 """
921 922
922 923 assert not self._closed, "cannot use me anymore, I'm closed!"
923 924 # defaults:
924 925 args = args if args is not None else []
925 926 kwargs = kwargs if kwargs is not None else {}
926 927 subheader = subheader if subheader is not None else {}
927 928
928 929 # validate arguments
929 930 if not callable(f):
930 931 raise TypeError("f must be callable, not %s"%type(f))
931 932 if not isinstance(args, (tuple, list)):
932 933 raise TypeError("args must be tuple or list, not %s"%type(args))
933 934 if not isinstance(kwargs, dict):
934 935 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
935 936 if not isinstance(subheader, dict):
936 937 raise TypeError("subheader must be dict, not %s"%type(subheader))
937 938
938 939 bufs = util.pack_apply_message(f,args,kwargs)
939 940
940 941 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
941 942 subheader=subheader, track=track)
942 943
943 944 msg_id = msg['msg_id']
944 945 self.outstanding.add(msg_id)
945 946 if ident:
946 947 # possibly routed to a specific engine
947 948 if isinstance(ident, list):
948 949 ident = ident[-1]
949 950 if ident in self._engines.values():
950 951 # save for later, in case of engine death
951 952 self._outstanding_dict[ident].add(msg_id)
952 953 self.history.append(msg_id)
953 954 self.metadata[msg_id]['submitted'] = datetime.now()
954 955
955 956 return msg
956 957
957 958 #--------------------------------------------------------------------------
958 959 # construct a View object
959 960 #--------------------------------------------------------------------------
960 961
961 962 def load_balanced_view(self, targets=None):
962 963 """construct a DirectView object.
963 964
964 965 If no arguments are specified, create a LoadBalancedView
965 966 using all engines.
966 967
967 968 Parameters
968 969 ----------
969 970
970 971 targets: list,slice,int,etc. [default: use all engines]
971 972 The subset of engines across which to load-balance
972 973 """
973 974 if targets is not None:
974 975 targets = self._build_targets(targets)[1]
975 976 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
976 977
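Reviewer note: typical use, as a sketch:

    lview = rc.load_balanced_view()       # balance across all engines
    ar = lview.apply_async(pow, 2, 10)
    print (ar.get())                      # 1024

    lview02 = rc.load_balanced_view(targets=[0, 2])   # restrict the pool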
977 978 def direct_view(self, targets='all'):
978 979 """construct a DirectView object.
979 980
980 981 If no targets are specified, create a DirectView
981 982 using all engines.
982 983
983 984 Parameters
984 985 ----------
985 986
986 987 targets: list,slice,int,etc. [default: use all engines]
987 988 The engines to use for the View
988 989 """
989 990 single = isinstance(targets, int)
990 991 targets = self._build_targets(targets)[1]
991 992 if single:
992 993 targets = targets[0]
993 994 return DirectView(client=self, socket=self._mux_socket, targets=targets)
994 995
995 996 #--------------------------------------------------------------------------
996 997 # Query methods
997 998 #--------------------------------------------------------------------------
998 999
999 1000 @spin_first
1000 1001 def get_result(self, indices_or_msg_ids=None, block=None):
1001 1002 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1002 1003
1003 1004 If the client already has the results, no request to the Hub will be made.
1004 1005
1005 1006 This is a convenient way to construct AsyncResult objects, which are wrappers
1006 1007 that include metadata about execution, and allow for awaiting results that
1007 1008 were not submitted by this Client.
1008 1009
1009 1010 It can also be a convenient way to retrieve the metadata associated with
1010 1011 blocking execution, since it always retrieves the metadata as well.
1011 1012
1012 1013 Examples
1013 1014 --------
1014 1015 ::
1015 1016
1016 1017 In [10]: ar = client.get_result(msg_id)  # msg_id from a prior submission
1017 1018
1018 1019 Parameters
1019 1020 ----------
1020 1021
1021 1022 indices_or_msg_ids : integer history index, str msg_id, or list of either
1022 1023 The indices or msg_ids of the results to be retrieved
1023 1024
1024 1025 block : bool
1025 1026 Whether to wait for the result to be done
1026 1027
1027 1028 Returns
1028 1029 -------
1029 1030
1030 1031 AsyncResult
1031 1032 A single AsyncResult object will always be returned.
1032 1033
1033 1034 AsyncHubResult
1034 1035 A subclass of AsyncResult that retrieves results from the Hub
1035 1036
1036 1037 """
1037 1038 block = self.block if block is None else block
1038 1039 if indices_or_msg_ids is None:
1039 1040 indices_or_msg_ids = -1
1040 1041
1041 1042 if not isinstance(indices_or_msg_ids, (list,tuple)):
1042 1043 indices_or_msg_ids = [indices_or_msg_ids]
1043 1044
1044 1045 theids = []
1045 1046 for id in indices_or_msg_ids:
1046 1047 if isinstance(id, int):
1047 1048 id = self.history[id]
1048 1049 if not isinstance(id, str):
1049 1050 raise TypeError("indices must be str or int, not %r"%id)
1050 1051 theids.append(id)
1051 1052
1052 1053 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1053 1054 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1054 1055
1055 1056 if remote_ids:
1056 1057 ar = AsyncHubResult(self, msg_ids=theids)
1057 1058 else:
1058 1059 ar = AsyncResult(self, msg_ids=theids)
1059 1060
1060 1061 if block:
1061 1062 ar.wait()
1062 1063
1063 1064 return ar
1064 1065
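Reviewer note: a sketch of both paths, local cache vs. Hub (`msg_id` stands in for any known id):

    ar = rc.get_result(-1, block=True)   # most recent entry in rc.history
    print (ar.get())
    print (ar.metadata)                  # timing and engine info

    # ids from other clients (e.g. via hub_history) yield an AsyncHubResult
    ar = rc.get_result(rc.hub_history()[-1])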
1065 1066 @spin_first
1066 1067 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1067 1068 """Resubmit one or more tasks.
1068 1069
1069 1070 in-flight tasks may not be resubmitted.
1070 1071
1071 1072 Parameters
1072 1073 ----------
1073 1074
1074 1075 indices_or_msg_ids : integer history index, str msg_id, or list of either
1075 1076 The indices or msg_ids of the results to be retrieved
1076 1077
1077 1078 block : bool
1078 1079 Whether to wait for the result to be done
1079 1080
1080 1081 Returns
1081 1082 -------
1082 1083
1083 1084 AsyncHubResult
1084 1085 A subclass of AsyncResult that retrieves results from the Hub
1085 1086
1086 1087 """
1087 1088 block = self.block if block is None else block
1088 1089 if indices_or_msg_ids is None:
1089 1090 indices_or_msg_ids = -1
1090 1091
1091 1092 if not isinstance(indices_or_msg_ids, (list,tuple)):
1092 1093 indices_or_msg_ids = [indices_or_msg_ids]
1093 1094
1094 1095 theids = []
1095 1096 for id in indices_or_msg_ids:
1096 1097 if isinstance(id, int):
1097 1098 id = self.history[id]
1098 1099 if not isinstance(id, str):
1099 1100 raise TypeError("indices must be str or int, not %r"%id)
1100 1101 theids.append(id)
1101 1102
1102 1103 for msg_id in theids:
1103 1104 self.outstanding.discard(msg_id)
1104 1105 if msg_id in self.history:
1105 1106 self.history.remove(msg_id)
1106 1107 self.results.pop(msg_id, None)
1107 1108 self.metadata.pop(msg_id, None)
1108 1109 content = dict(msg_ids = theids)
1109 1110
1110 1111 self.session.send(self._query_socket, 'resubmit_request', content)
1111 1112
1112 1113 zmq.select([self._query_socket], [], [])
1113 1114 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1114 1115 if self.debug:
1115 1116 pprint(msg)
1116 1117 content = msg['content']
1117 1118 if content['status'] != 'ok':
1118 1119 raise self._unwrap_exception(content)
1119 1120
1120 1121 ar = AsyncHubResult(self, msg_ids=theids)
1121 1122
1122 1123 if block:
1123 1124 ar.wait()
1124 1125
1125 1126 return ar
1126 1127
1127 1128 @spin_first
1128 1129 def result_status(self, msg_ids, status_only=True):
1129 1130 """Check on the status of the result(s) of the apply request with `msg_ids`.
1130 1131
1131 1132 If status_only is False, then the actual results will be retrieved, else
1132 1133 only the status of the results will be checked.
1133 1134
1134 1135 Parameters
1135 1136 ----------
1136 1137
1137 1138 msg_ids : list of msg_ids
1138 1139 if int:
1139 1140 Passed as index to self.history for convenience.
1140 1141 status_only : bool (default: True)
1141 1142 if False:
1142 1143 Retrieve the actual results of completed tasks.
1143 1144
1144 1145 Returns
1145 1146 -------
1146 1147
1147 1148 results : dict
1148 1149 There will always be the keys 'pending' and 'completed', which will
1149 1150 be lists of msg_ids that are incomplete or complete. If `status_only`
1150 1151 is False, then completed results will be keyed by their `msg_id`.
1151 1152 """
1152 1153 if not isinstance(msg_ids, (list,tuple)):
1153 1154 msg_ids = [msg_ids]
1154 1155
1155 1156 theids = []
1156 1157 for msg_id in msg_ids:
1157 1158 if isinstance(msg_id, int):
1158 1159 msg_id = self.history[msg_id]
1159 1160 if not isinstance(msg_id, basestring):
1160 1161 raise TypeError("msg_ids must be str, not %r"%msg_id)
1161 1162 theids.append(msg_id)
1162 1163
1163 1164 completed = []
1164 1165 local_results = {}
1165 1166
1166 1167 # comment this block out to temporarily disable local shortcut:
1167 1168 for msg_id in list(theids):  # iterate over a copy; theids is mutated below
1168 1169 if msg_id in self.results:
1169 1170 completed.append(msg_id)
1170 1171 local_results[msg_id] = self.results[msg_id]
1171 1172 theids.remove(msg_id)
1172 1173
1173 1174 if theids: # some not locally cached
1174 1175 content = dict(msg_ids=theids, status_only=status_only)
1175 1176 msg = self.session.send(self._query_socket, "result_request", content=content)
1176 1177 zmq.select([self._query_socket], [], [])
1177 1178 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1178 1179 if self.debug:
1179 1180 pprint(msg)
1180 1181 content = msg['content']
1181 1182 if content['status'] != 'ok':
1182 1183 raise self._unwrap_exception(content)
1183 1184 buffers = msg['buffers']
1184 1185 else:
1185 1186 content = dict(completed=[],pending=[])
1186 1187
1187 1188 content['completed'].extend(completed)
1188 1189
1189 1190 if status_only:
1190 1191 return content
1191 1192
1192 1193 failures = []
1193 1194 # load cached results into result:
1194 1195 content.update(local_results)
1195 1196
1196 1197 # update cache with results:
1197 1198 for msg_id in sorted(theids):
1198 1199 if msg_id in content['completed']:
1199 1200 rec = content[msg_id]
1200 1201 parent = rec['header']
1201 1202 header = rec['result_header']
1202 1203 rcontent = rec['result_content']
1203 1204 iodict = rec['io']
1204 1205 if isinstance(rcontent, str):
1205 1206 rcontent = self.session.unpack(rcontent)
1206 1207
1207 1208 md = self.metadata[msg_id]
1208 1209 md.update(self._extract_metadata(header, parent, rcontent))
1209 1210 md.update(iodict)
1210 1211
1211 1212 if rcontent['status'] == 'ok':
1212 1213 res,buffers = util.unserialize_object(buffers)
1213 1214 else:
1214 1215 print rcontent
1215 1216 res = self._unwrap_exception(rcontent)
1216 1217 failures.append(res)
1217 1218
1218 1219 self.results[msg_id] = res
1219 1220 content[msg_id] = res
1220 1221
1221 1222 if len(theids) == 1 and failures:
1222 1223 raise failures[0]
1223 1224
1224 1225 error.collect_exceptions(failures, "result_status")
1225 1226 return content
1226 1227
1227 1228 @spin_first
1228 1229 def queue_status(self, targets='all', verbose=False):
1229 1230 """Fetch the status of engine queues.
1230 1231
1231 1232 Parameters
1232 1233 ----------
1233 1234
1234 1235 targets : int/str/list of ints/strs
1235 1236 the engines whose states are to be queried.
1236 1237 default : all
1237 1238 verbose : bool
1238 1239 Whether to return lengths only, or lists of ids for each element
1239 1240 """
1240 1241 engine_ids = self._build_targets(targets)[1]
1241 1242 content = dict(targets=engine_ids, verbose=verbose)
1242 1243 self.session.send(self._query_socket, "queue_request", content=content)
1243 1244 idents,msg = self.session.recv(self._query_socket, 0)
1244 1245 if self.debug:
1245 1246 pprint(msg)
1246 1247 content = msg['content']
1247 1248 status = content.pop('status')
1248 1249 if status != 'ok':
1249 1250 raise self._unwrap_exception(content)
1250 content = util.rekey(content)
1251 content = rekey(content)
1251 1252 if isinstance(targets, int):
1252 1253 return content[targets]
1253 1254 else:
1254 1255 return content
1255 1256
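Reviewer note: for example (output shape is indicative):

    print (rc.queue_status())
    # {0: {'queue': 0, 'completed': 3, 'tasks': 1}, ..., 'unassigned': 0}
    print (rc.queue_status(targets=0, verbose=True))   # lists of msg_ids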
1256 1257 @spin_first
1257 1258 def purge_results(self, jobs=[], targets=[]):
1258 1259 """Tell the Hub to forget results.
1259 1260
1260 1261 Individual results can be purged by msg_id, or the entire
1261 1262 history of specific targets can be purged.
1262 1263
1263 1264 Parameters
1264 1265 ----------
1265 1266
1266 1267 jobs : str or list of str or AsyncResult objects
1267 1268 the msg_ids whose results should be forgotten.
1268 1269 targets : int/str/list of ints/strs
1269 1270 The targets, by uuid or int_id, whose entire history is to be purged.
1270 1271 Use `targets='all'` to scrub everything from the Hub's memory.
1271 1272
1272 1273 default : None
1273 1274 """
1274 1275 if not targets and not jobs:
1275 1276 raise ValueError("Must specify at least one of `targets` and `jobs`")
1276 1277 if targets:
1277 1278 targets = self._build_targets(targets)[1]
1278 1279
1279 1280 # construct msg_ids from jobs
1280 1281 msg_ids = []
1281 1282 if isinstance(jobs, (basestring,AsyncResult)):
1282 1283 jobs = [jobs]
1283 1284 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1284 1285 if bad_ids:
1285 1286 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1286 1287 for j in jobs:
1287 1288 if isinstance(j, AsyncResult):
1288 1289 msg_ids.extend(j.msg_ids)
1289 1290 else:
1290 1291 msg_ids.append(j)
1291 1292
1292 1293 content = dict(targets=targets, msg_ids=msg_ids)
1293 1294 self.session.send(self._query_socket, "purge_request", content=content)
1294 1295 idents, msg = self.session.recv(self._query_socket, 0)
1295 1296 if self.debug:
1296 1297 pprint(msg)
1297 1298 content = msg['content']
1298 1299 if content['status'] != 'ok':
1299 1300 raise self._unwrap_exception(content)
1300 1301
1301 1302 @spin_first
1302 1303 def hub_history(self):
1303 1304 """Get the Hub's history
1304 1305
1305 1306 Just like the Client, the Hub has a history, which is a list of msg_ids.
1306 1307 This will contain the history of all clients, and, depending on configuration,
1307 1308 may contain history across multiple cluster sessions.
1308 1309
1309 1310 Any msg_id returned here is a valid argument to `get_result`.
1310 1311
1311 1312 Returns
1312 1313 -------
1313 1314
1314 1315 msg_ids : list of strs
1315 1316 list of all msg_ids, ordered by task submission time.
1316 1317 """
1317 1318
1318 1319 self.session.send(self._query_socket, "history_request", content={})
1319 1320 idents, msg = self.session.recv(self._query_socket, 0)
1320 1321
1321 1322 if self.debug:
1322 1323 pprint(msg)
1323 1324 content = msg['content']
1324 1325 if content['status'] != 'ok':
1325 1326 raise self._unwrap_exception(content)
1326 1327 else:
1327 1328 return content['history']
1328 1329
1329 1330 @spin_first
1330 1331 def db_query(self, query, keys=None):
1331 1332 """Query the Hub's TaskRecord database
1332 1333
1333 1334 This will return a list of task record dicts that match `query`
1334 1335
1335 1336 Parameters
1336 1337 ----------
1337 1338
1338 1339 query : mongodb query dict
1339 1340 The search dict. See mongodb query docs for details.
1340 1341 keys : list of strs [optional]
1341 1342 The subset of keys to be returned. The default is to fetch everything but buffers.
1342 1343 'msg_id' will *always* be included.
1343 1344 """
1344 1345 if isinstance(keys, basestring):
1345 1346 keys = [keys]
1346 1347 content = dict(query=query, keys=keys)
1347 1348 self.session.send(self._query_socket, "db_request", content=content)
1348 1349 idents, msg = self.session.recv(self._query_socket, 0)
1349 1350 if self.debug:
1350 1351 pprint(msg)
1351 1352 content = msg['content']
1352 1353 if content['status'] != 'ok':
1353 1354 raise self._unwrap_exception(content)
1354 1355
1355 1356 records = content['records']
1356 1357
1357 1358 buffer_lens = content['buffer_lens']
1358 1359 result_buffer_lens = content['result_buffer_lens']
1359 1360 buffers = msg['buffers']
1360 1361 has_bufs = buffer_lens is not None
1361 1362 has_rbufs = result_buffer_lens is not None
1362 1363 for i,rec in enumerate(records):
1363 1364 # relink buffers
1364 1365 if has_bufs:
1365 1366 blen = buffer_lens[i]
1366 1367 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1367 1368 if has_rbufs:
1368 1369 blen = result_buffer_lens[i]
1369 1370 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1370 1371
1371 1372 return records
1372 1373
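Reviewer note: a sketch of a typical query, assuming the backend supports the mongodb-style `$gte` operator (the mongodb and dict backends do):

    from datetime import datetime, timedelta

    since = datetime.now() - timedelta(hours=1)
    recs = rc.db_query({'completed': {'$gte': since}},
                       keys=['msg_id', 'engine_uuid'])
    for rec in recs:
        print (rec['msg_id'], rec['engine_uuid'])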
1373 1374 __all__ = [ 'Client' ]
@@ -1,473 +1,450 @@
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 # IPython imports
43 43 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
44 44 from IPython.utils.newserialized import serialize, unserialize
45 45 from IPython.zmq.log import EnginePUBHandler
46 46
47 47 #-----------------------------------------------------------------------------
48 48 # Classes
49 49 #-----------------------------------------------------------------------------
50 50
51 51 class Namespace(dict):
52 52 """Subclass of dict for attribute access to keys."""
53 53
54 54 def __getattr__(self, key):
55 55 """getattr aliased to getitem"""
56 56 if key in self.iterkeys():
57 57 return self[key]
58 58 else:
59 59 raise NameError(key)
60 60
61 61 def __setattr__(self, key, value):
62 62 """setattr aliased to setitem, with strict"""
63 63 if hasattr(dict, key):
64 64 raise KeyError("Cannot override dict keys %r"%key)
65 65 self[key] = value
66 66
67 67
68 68 class ReverseDict(dict):
69 69 """simple double-keyed subset of dict methods."""
70 70
71 71 def __init__(self, *args, **kwargs):
72 72 dict.__init__(self, *args, **kwargs)
73 73 self._reverse = dict()
74 74 for key, value in self.iteritems():
75 75 self._reverse[value] = key
76 76
77 77 def __getitem__(self, key):
78 78 try:
79 79 return dict.__getitem__(self, key)
80 80 except KeyError:
81 81 return self._reverse[key]
82 82
83 83 def __setitem__(self, key, value):
84 84 if key in self._reverse:
85 85 raise KeyError("Can't have key %r on both sides!"%key)
86 86 dict.__setitem__(self, key, value)
87 87 self._reverse[value] = key
88 88
89 89 def pop(self, key):
90 90 value = dict.pop(self, key)
91 91 self._reverse.pop(value)
92 92 return value
93 93
94 94 def get(self, key, default=None):
95 95 try:
96 96 return self[key]
97 97 except KeyError:
98 98 return default
99 99
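Reviewer note: a quick sketch of the two-way lookup (the uuid string is made up):

    rd = ReverseDict()
    rd[0] = 'engine-uuid-abc'
    print (rd[0])                    # 'engine-uuid-abc'
    print (rd['engine-uuid-abc'])    # 0 -- reverse lookup
    rd.pop(0)                        # removes both directions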
100 100 #-----------------------------------------------------------------------------
101 101 # Functions
102 102 #-----------------------------------------------------------------------------
103 103
104 104 def validate_url(url):
105 105 """validate a url for zeromq"""
106 106 if not isinstance(url, basestring):
107 107 raise TypeError("url must be a string, not %r"%type(url))
108 108 url = url.lower()
109 109
110 110 proto_addr = url.split('://')
111 111 assert len(proto_addr) == 2, 'Invalid url: %r'%url
112 112 proto, addr = proto_addr
113 113 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
114 114
115 115 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
116 116 # author: Remi Sabourin
117 117 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
118 118
119 119 if proto == 'tcp':
120 120 lis = addr.split(':')
121 121 assert len(lis) == 2, 'Invalid url: %r'%url
122 122 addr,s_port = lis
123 123 try:
124 124 port = int(s_port)
125 125 except ValueError:
126 126 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
127 127
128 128 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
129 129
130 130 else:
131 131 # only validate tcp urls currently
132 132 pass
133 133
134 134 return True
135 135
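Reviewer note: for example:

    validate_url('tcp://127.0.0.1:10101')    # True
    validate_url('ipc:///tmp/engine.sock')   # True -- non-tcp passes through
    validate_url('tcp://127.0.0.1')          # raises AssertionError: no port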
136 136
137 137 def validate_url_container(container):
138 138 """validate a potentially nested collection of urls."""
139 139 if isinstance(container, basestring):
140 140 url = container
141 141 return validate_url(url)
142 142 elif isinstance(container, dict):
143 143 container = container.itervalues()
144 144
145 145 for element in container:
146 146 validate_url_container(element)
147 147
148 148
149 149 def split_url(url):
150 150 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
151 151 proto_addr = url.split('://')
152 152 assert len(proto_addr) == 2, 'Invalid url: %r'%url
153 153 proto, addr = proto_addr
154 154 lis = addr.split(':')
155 155 assert len(lis) == 2, 'Invalid url: %r'%url
156 156 addr,s_port = lis
157 157 return proto,addr,s_port
158 158
159 159 def disambiguate_ip_address(ip, location=None):
160 160 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
161 161 ones, based on the location (default interpretation of location is localhost)."""
162 162 if ip in ('0.0.0.0', '*'):
163 163 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
164 164 if location is None or location in external_ips:
165 165 ip='127.0.0.1'
166 166 elif location:
167 167 return location
168 168 return ip
169 169
170 170 def disambiguate_url(url, location=None):
171 171 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
172 172 ones, based on the location (default interpretation is localhost).
173 173
174 174 This is for zeromq urls, such as tcp://*:10101."""
175 175 try:
176 176 proto,ip,port = split_url(url)
177 177 except AssertionError:
178 178 # probably not tcp url; could be ipc, etc.
179 179 return url
180 180
181 181 ip = disambiguate_ip_address(ip,location)
182 182
183 183 return "%s://%s:%s"%(proto,ip,port)
184 184
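Reviewer note: for example (the second call assumes 10.0.0.2 is not one of the local interfaces):

    disambiguate_url('tcp://*:10101')                    # 'tcp://127.0.0.1:10101'
    disambiguate_url('tcp://0.0.0.0:10101', '10.0.0.2')  # 'tcp://10.0.0.2:10101'
    disambiguate_url('ipc:///tmp/engine.sock')           # unchanged: not tcp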
185
186 def rekey(dikt):
187 """Rekey a dict that has been forced to use str keys where there should be
188 ints by json. This belongs in the jsonutil added by fperez."""
189 for k in dikt.iterkeys():
190 if isinstance(k, str):
191 ik=fk=None
192 try:
193 ik = int(k)
194 except ValueError:
195 try:
196 fk = float(k)
197 except ValueError:
198 continue
199 if ik is not None:
200 nk = ik
201 else:
202 nk = fk
203 if nk in dikt:
204 raise KeyError("already have key %r"%nk)
205 dikt[nk] = dikt.pop(k)
206 return dikt
207
208 185 def serialize_object(obj, threshold=64e-6):
209 186 """Serialize an object into a list of sendable buffers.
210 187
211 188 Parameters
212 189 ----------
213 190
214 191 obj : object
215 192 The object to be serialized
216 193 threshold : float
217 194 The size threshold beyond which an object's data buffer is sent as-is instead of being double-pickled.
218 195
219 196
220 197 Returns
221 198 -------
222 199 ('pmd', [bufs]) :
223 200 where pmd is the pickled metadata wrapper,
224 201 bufs is a list of data buffers
225 202 """
226 203 databuffers = []
227 204 if isinstance(obj, (list, tuple)):
228 205 clist = canSequence(obj)
229 206 slist = map(serialize, clist)
230 207 for s in slist:
231 208 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
232 209 databuffers.append(s.getData())
233 210 s.data = None
234 211 return pickle.dumps(slist,-1), databuffers
235 212 elif isinstance(obj, dict):
236 213 sobj = {}
237 214 for k in sorted(obj.iterkeys()):
238 215 s = serialize(can(obj[k]))
239 216 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
240 217 databuffers.append(s.getData())
241 218 s.data = None
242 219 sobj[k] = s
243 220 return pickle.dumps(sobj,-1),databuffers
244 221 else:
245 222 s = serialize(can(obj))
246 223 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
247 224 databuffers.append(s.getData())
248 225 s.data = None
249 226 return pickle.dumps(s,-1),databuffers
250 227
251 228
252 229 def unserialize_object(bufs):
253 230 """reconstruct an object serialized by serialize_object from data buffers."""
254 231 bufs = list(bufs)
255 232 sobj = pickle.loads(bufs.pop(0))
256 233 if isinstance(sobj, (list, tuple)):
257 234 for s in sobj:
258 235 if s.data is None:
259 236 s.data = bufs.pop(0)
260 237 return uncanSequence(map(unserialize, sobj)), bufs
261 238 elif isinstance(sobj, dict):
262 239 newobj = {}
263 240 for k in sorted(sobj.iterkeys()):
264 241 s = sobj[k]
265 242 if s.data is None:
266 243 s.data = bufs.pop(0)
267 244 newobj[k] = uncan(unserialize(s))
268 245 return newobj, bufs
269 246 else:
270 247 if sobj.data is None:
271 248 sobj.data = bufs.pop(0)
272 249 return uncan(unserialize(sobj)), bufs
273 250
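A round-trip sketch of the serialize/unserialize pair, assuming numpy is available (ndarray data always lands in the buffer list, regardless of threshold):

    import numpy as np
    data = {'a': np.arange(1000), 'b': 3.14}
    pmd, bufs = serialize_object(data)               # array data extracted into bufs, not the pickle
    obj, leftover = unserialize_object([pmd] + bufs)
    assert (obj['a'] == data['a']).all() and leftover == []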
274 251 def pack_apply_message(f, args, kwargs, threshold=64e-6):
275 252 """pack up a function, args, and kwargs to be sent over the wire
276 253 as a series of buffers. Any object whose data is larger than `threshold`
277 254 will not have its data copied (currently only numpy arrays support zero-copy)."""
278 255 msg = [pickle.dumps(can(f),-1)]
279 256 databuffers = [] # for large objects
280 257 sargs, bufs = serialize_object(args,threshold)
281 258 msg.append(sargs)
282 259 databuffers.extend(bufs)
283 260 skwargs, bufs = serialize_object(kwargs,threshold)
284 261 msg.append(skwargs)
285 262 databuffers.extend(bufs)
286 263 msg.extend(databuffers)
287 264 return msg
288 265
289 266 def unpack_apply_message(bufs, g=None, copy=True):
290 267 """unpack f,args,kwargs from buffers packed by pack_apply_message()
291 268 Returns: original f,args,kwargs"""
292 269 bufs = list(bufs) # allow us to pop
293 270 assert len(bufs) >= 3, "not enough buffers!"
294 271 if not copy:
295 272 for i in range(3):
296 273 bufs[i] = bufs[i].bytes
297 274 cf = pickle.loads(bufs.pop(0))
298 275 sargs = list(pickle.loads(bufs.pop(0)))
299 276 skwargs = dict(pickle.loads(bufs.pop(0)))
300 277 # print sargs, skwargs
301 278 f = uncan(cf, g)
302 279 for sa in sargs:
303 280 if sa.data is None:
304 281 m = bufs.pop(0)
305 282 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
306 283 # always use a buffer, until memoryviews get sorted out
307 284 sa.data = buffer(m)
308 285 # disable memoryview support
309 286 # if copy:
310 287 # sa.data = buffer(m)
311 288 # else:
312 289 # sa.data = m.buffer
313 290 else:
314 291 if copy:
315 292 sa.data = m
316 293 else:
317 294 sa.data = m.bytes
318 295
319 296 args = uncanSequence(map(unserialize, sargs), g)
320 297 kwargs = {}
321 298 for k in sorted(skwargs.iterkeys()):
322 299 sa = skwargs[k]
323 300 if sa.data is None:
324 301 m = bufs.pop(0)
325 302 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
326 303 # always use a buffer, until memoryviews get sorted out
327 304 sa.data = buffer(m)
328 305 # disable memoryview support
329 306 # if copy:
330 307 # sa.data = buffer(m)
331 308 # else:
332 309 # sa.data = m.buffer
333 310 else:
334 311 if copy:
335 312 sa.data = m
336 313 else:
337 314 sa.data = m.bytes
338 315
339 316 kwargs[k] = uncan(unserialize(sa), g)
340 317
341 318 return f,args,kwargs
342 319
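A local round-trip sketch of the apply-message wire format, assuming a simple module-level function (with copy=True the buffers are plain bytes, so the zmq-frame .bytes branch is skipped):

    def add(x, y=1):
        return x + y

    msg = pack_apply_message(add, (5,), {'y': 2})
    f, args, kwargs = unpack_apply_message(msg, copy=True)
    assert f(*args, **kwargs) == 7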
343 320 #--------------------------------------------------------------------------
344 321 # helpers for implementing old MEC API via view.apply
345 322 #--------------------------------------------------------------------------
346 323
347 324 def interactive(f):
348 325 """decorator for making functions appear as interactively defined.
349 326 This results in the function being linked to the user_ns as globals()
350 327 instead of the module globals().
351 328 """
352 329 f.__module__ = '__main__'
353 330 return f
354 331
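The point of @interactive is that a canned function unpacked on an engine resolves its globals in the engine's user namespace; a hypothetical sketch (rc is an assumed running Client):

    @interactive
    def use_remote_x():
        return x + 1                     # 'x' is looked up in the engine namespace, not this module

    # rc[:].push(dict(x=41))
    # rc[:].apply_sync(use_remote_x)     # -> [42, 42, ...]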
355 332 @interactive
356 333 def _push(ns):
357 334 """helper method for implementing `client.push` via `client.apply`"""
358 335 globals().update(ns)
359 336
360 337 @interactive
361 338 def _pull(keys):
362 339 """helper method for implementing `client.pull` via `client.apply`"""
363 340 user_ns = globals()
364 341 if isinstance(keys, (list,tuple, set)):
365 342 for key in keys:
366 343 if not user_ns.has_key(key):
367 344 raise NameError("name '%s' is not defined"%key)
368 345 return map(user_ns.get, keys)
369 346 else:
370 347 if not user_ns.has_key(keys):
371 348 raise NameError("name '%s' is not defined"%keys)
372 349 return user_ns.get(keys)
373 350
374 351 @interactive
375 352 def _execute(code):
376 353 """helper method for implementing `client.execute` via `client.apply`"""
377 354 exec code in globals()
378 355
379 356 #--------------------------------------------------------------------------
380 357 # extra process management utilities
381 358 #--------------------------------------------------------------------------
382 359
383 360 _random_ports = set()
384 361
385 362 def select_random_ports(n):
386 363 """Selects and return n random ports that are available."""
387 364 ports = []
388 365 for i in xrange(n):
389 366 sock = socket.socket()
390 367 sock.bind(('', 0))
391 368 while sock.getsockname()[1] in _random_ports:
392 369 sock.close()
393 370 sock = socket.socket()
394 371 sock.bind(('', 0))
395 372 ports.append(sock)
396 373 for i, sock in enumerate(ports):
397 374 port = sock.getsockname()[1]
398 375 sock.close()
399 376 ports[i] = port
400 377 _random_ports.add(port)
401 378 return ports
402 379
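The module-level _random_ports set prevents the same ephemeral port from being handed out twice within one process; a quick sketch:

    ports = select_random_ports(3)    # e.g. [53201, 53417, 53998]
    assert len(set(ports)) == 3       # all distinct, and remembered in _random_ports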
403 380 def signal_children(children):
404 381 """Relay interupt/term signals to children, for more solid process cleanup."""
405 382 def terminate_children(sig, frame):
406 383 logging.critical("Got signal %i, terminating children..."%sig)
407 384 for child in children:
408 385 child.terminate()
409 386
410 387 sys.exit(sig != SIGINT)
411 388 # sys.exit(sig)
412 389 for sig in (SIGINT, SIGABRT, SIGTERM):
413 390 signal(sig, terminate_children)
414 391
415 392 def generate_exec_key(keyfile):
416 393 import uuid
417 394 newkey = str(uuid.uuid4())
418 395 with open(keyfile, 'w') as f:
419 396 # f.write('ipython-key ')
420 397 f.write(newkey+'\n')
421 398 # set user-only RW permissions (0600)
422 399 # this will have no effect on Windows
423 400 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
424 401
425 402
426 403 def integer_loglevel(loglevel):
427 404 try:
428 405 loglevel = int(loglevel)
429 406 except ValueError:
430 407 if isinstance(loglevel, str):
431 408 loglevel = getattr(logging, loglevel)
432 409 return loglevel
433 410
434 411 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
435 412 logger = logging.getLogger(logname)
436 413 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
437 414 # don't add a second PUBHandler
438 415 return
439 416 loglevel = integer_loglevel(loglevel)
440 417 lsock = context.socket(zmq.PUB)
441 418 lsock.connect(iface)
442 419 handler = handlers.PUBHandler(lsock)
443 420 handler.setLevel(loglevel)
444 421 handler.root_topic = root
445 422 logger.addHandler(handler)
446 423 logger.setLevel(loglevel)
447 424
448 425 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
449 426 logger = logging.getLogger()
450 427 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
451 428 # don't add a second PUBHandler
452 429 return
453 430 loglevel = integer_loglevel(loglevel)
454 431 lsock = context.socket(zmq.PUB)
455 432 lsock.connect(iface)
456 433 handler = EnginePUBHandler(engine, lsock)
457 434 handler.setLevel(loglevel)
458 435 logger.addHandler(handler)
459 436 logger.setLevel(loglevel)
460 437 return logger
461 438
462 439 def local_logger(logname, loglevel=logging.DEBUG):
463 440 loglevel = integer_loglevel(loglevel)
464 441 logger = logging.getLogger(logname)
465 442 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
466 443 # don't add a second StreamHandler
467 444 return
468 445 handler = logging.StreamHandler()
469 446 handler.setLevel(loglevel)
470 447 logger.addHandler(handler)
471 448 logger.setLevel(loglevel)
472 449 return logger
473 450
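A minimal sketch of the logging helpers; string levels resolve via integer_loglevel, and a repeat call returns None rather than stacking a second handler:

    log = local_logger('ipython.parallel.test', 'INFO')
    log.info('engine ready')          # emitted to stderr via the StreamHandler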
@@ -1,134 +1,157 b''
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import re
15 15 import types
16 16 from datetime import datetime
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Globals and constants
20 20 #-----------------------------------------------------------------------------
21 21
22 22 # timestamp formats
23 23 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
24 24 ISO8601_PAT=re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+$")
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # Classes and functions
28 28 #-----------------------------------------------------------------------------
29 29
30 def rekey(dikt):
31 """Rekey a dict that has been forced to use str keys where there should be
32 ints by json."""
33 for k in list(dikt.iterkeys()): # copy keys: dikt is mutated during the loop
34 if isinstance(k, basestring):
35 ik=fk=None
36 try:
37 ik = int(k)
38 except ValueError:
39 try:
40 fk = float(k)
41 except ValueError:
42 continue
43 if ik is not None:
44 nk = ik
45 else:
46 nk = fk
47 if nk in dikt:
48 raise KeyError("already have key %r"%nk)
49 dikt[nk] = dikt.pop(k)
50 return dikt
51
52
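Unlike the version removed from parallel.util above, this copy checks basestring, so the unicode keys produced by json.loads are handled; a sketch:

    import json
    d = json.loads('{"0": "a", "5": "b"}')   # keys come back as u'0', u'5'
    rekey(d)                                 # -> {0: u'a', 5: u'b'}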
30 53 def extract_dates(obj):
31 54 """extract ISO8601 dates from unpacked JSON"""
32 55 if isinstance(obj, dict):
33 56 obj = dict(obj) # don't clobber
34 57 for k,v in obj.iteritems():
35 58 obj[k] = extract_dates(v)
36 59 elif isinstance(obj, (list, tuple)):
37 60 obj = [ extract_dates(o) for o in obj ]
38 61 elif isinstance(obj, basestring):
39 62 if ISO8601_PAT.match(obj):
40 63 obj = datetime.strptime(obj, ISO8601)
41 64 return obj
42 65
43 66 def squash_dates(obj):
44 67 """squash datetime objects into ISO8601 strings"""
45 68 if isinstance(obj, dict):
46 69 obj = dict(obj) # don't clobber
47 70 for k,v in obj.iteritems():
48 71 obj[k] = squash_dates(v)
49 72 elif isinstance(obj, (list, tuple)):
50 73 obj = [ squash_dates(o) for o in obj ]
51 74 elif isinstance(obj, datetime):
52 75 obj = obj.strftime(ISO8601)
53 76 return obj
54 77
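squash_dates and extract_dates are inverses for datetime values (the %f field in ISO8601 keeps microsecond precision); a sketch:

    from datetime import datetime
    now = datetime.now()
    wire = squash_dates({'submitted': now})   # datetime -> ISO8601 string
    assert extract_dates(wire)['submitted'] == now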
55 78 def date_default(obj):
56 79 """default function for packing datetime objects in JSON."""
57 80 if isinstance(obj, datetime):
58 81 return obj.strftime(ISO8601)
59 82 else:
60 83 raise TypeError("%r is not JSON serializable"%obj)
61 84
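date_default is meant to plug into json.dumps as its `default` hook; a sketch:

    import json
    from datetime import datetime
    json.dumps({'started': datetime.now()}, default=date_default)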
62 85
63 86
64 87 def json_clean(obj):
65 88 """Clean an object to ensure it's safe to encode in JSON.
66 89
67 90 Atomic, immutable objects are returned unmodified. Sets and tuples are
68 91 converted to lists, lists are copied and dicts are also copied.
69 92
70 93 Note: dicts whose keys could cause collisions upon encoding (such as a dict
71 94 with both the number 1 and the string '1' as keys) will cause a ValueError
72 95 to be raised.
73 96
74 97 Parameters
75 98 ----------
76 99 obj : any python object
77 100
78 101 Returns
79 102 -------
80 103 out : object
81 104
82 105 A version of the input which will not cause an encoding error when
83 106 encoded as JSON. Note that this function does not *encode* its inputs,
84 107 it simply sanitizes it so that there will be no encoding errors later.
85 108
86 109 Examples
87 110 --------
88 111 >>> json_clean(4)
89 112 4
90 113 >>> json_clean(range(10))
91 114 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
92 115 >>> json_clean(dict(x=1, y=2))
93 116 {'y': 2, 'x': 1}
94 117 >>> json_clean(dict(x=1, y=2, z=[1,2,3]))
95 118 {'y': 2, 'x': 1, 'z': [1, 2, 3]}
96 119 >>> json_clean(True)
97 120 True
98 121 """
99 122 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
100 123 # listed explicitly because bools pass as int instances
101 124 atomic_ok = (basestring, int, float, types.NoneType)
102 125
103 126 # containers that we need to convert into lists
104 127 container_to_list = (tuple, set, types.GeneratorType)
105 128
106 129 if isinstance(obj, atomic_ok):
107 130 return obj
108 131
109 132 if isinstance(obj, container_to_list) or (
110 133 hasattr(obj, '__iter__') and hasattr(obj, 'next')):
111 134 obj = list(obj)
112 135
113 136 if isinstance(obj, list):
114 137 return [json_clean(x) for x in obj]
115 138
116 139 if isinstance(obj, dict):
117 140 # First, validate that the dict won't lose data in conversion due to
118 141 # key collisions after stringification. This can happen with keys like
119 142 # True and 'true' or 1 and '1', which collide in JSON.
120 143 nkeys = len(obj)
121 144 nkeys_collapsed = len(set(map(str, obj)))
122 145 if nkeys != nkeys_collapsed:
123 146 raise ValueError('dict can not be safely converted to JSON: '
124 147 'key collision would lead to dropped values')
125 148 # If all OK, proceed by making the new dict that will be json-safe
126 149 out = {}
127 150 for k,v in obj.iteritems():
128 151 out[str(k)] = json_clean(v)
129 152 return out
130 153
131 154 # If we get here, we don't know how to handle the object, so we just get
132 155 # its repr and return that. This will catch lambdas, open sockets, class
133 156 # objects, and any other complicated contraption that json can't encode
134 157 return repr(obj)
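Two edge cases worth noting beyond the doctests above, shown as a sketch — container conversion and the key-collision ValueError:

    json_clean(set([1, 2, 3]))        # -> [1, 2, 3] (sets become lists; order not guaranteed)
    json_clean({1: 'a', '1': 'b'})    # raises ValueError: keys collide once stringified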