##// END OF EJS Templates
Fix references to xrange
Thomas Kluyver -
Show More
@@ -1,649 +1,649 b''
1 """Base classes to manage a Client's interaction with a running kernel
1 """Base classes to manage a Client's interaction with a running kernel
2 """
2 """
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2013 The IPython Development Team
5 # Copyright (C) 2013 The IPython Development Team
6 #
6 #
7 # Distributed under the terms of the BSD License. The full license is in
7 # Distributed under the terms of the BSD License. The full license is in
8 # the file COPYING, distributed as part of this software.
8 # the file COPYING, distributed as part of this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 from __future__ import absolute_import
15 from __future__ import absolute_import
16
16
17 # Standard library imports
17 # Standard library imports
18 import atexit
18 import atexit
19 import errno
19 import errno
20 from threading import Thread
20 from threading import Thread
21 import time
21 import time
22
22
23 import zmq
23 import zmq
24 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
24 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
25 # during garbage collection of threads at exit:
25 # during garbage collection of threads at exit:
26 from zmq import ZMQError
26 from zmq import ZMQError
27 from zmq.eventloop import ioloop, zmqstream
27 from zmq.eventloop import ioloop, zmqstream
28
28
29 # Local imports
29 # Local imports
30 from .channelsabc import (
30 from .channelsabc import (
31 ShellChannelABC, IOPubChannelABC,
31 ShellChannelABC, IOPubChannelABC,
32 HBChannelABC, StdInChannelABC,
32 HBChannelABC, StdInChannelABC,
33 )
33 )
34 from IPython.utils.py3compat import string_types
34 from IPython.utils.py3compat import string_types
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Constants and exceptions
37 # Constants and exceptions
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40 class InvalidPortNumber(Exception):
40 class InvalidPortNumber(Exception):
41 pass
41 pass
42
42
43 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
44 # Utility functions
44 # Utility functions
45 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
46
46
47 # some utilities to validate message structure, these might get moved elsewhere
47 # some utilities to validate message structure, these might get moved elsewhere
48 # if they prove to have more generic utility
48 # if they prove to have more generic utility
49
49
50 def validate_string_list(lst):
50 def validate_string_list(lst):
51 """Validate that the input is a list of strings.
51 """Validate that the input is a list of strings.
52
52
53 Raises ValueError if not."""
53 Raises ValueError if not."""
54 if not isinstance(lst, list):
54 if not isinstance(lst, list):
55 raise ValueError('input %r must be a list' % lst)
55 raise ValueError('input %r must be a list' % lst)
56 for x in lst:
56 for x in lst:
57 if not isinstance(x, string_types):
57 if not isinstance(x, string_types):
58 raise ValueError('element %r in list must be a string' % x)
58 raise ValueError('element %r in list must be a string' % x)
59
59
60
60
61 def validate_string_dict(dct):
61 def validate_string_dict(dct):
62 """Validate that the input is a dict with string keys and values.
62 """Validate that the input is a dict with string keys and values.
63
63
64 Raises ValueError if not."""
64 Raises ValueError if not."""
65 for k,v in dct.iteritems():
65 for k,v in dct.iteritems():
66 if not isinstance(k, string_types):
66 if not isinstance(k, string_types):
67 raise ValueError('key %r in dict must be a string' % k)
67 raise ValueError('key %r in dict must be a string' % k)
68 if not isinstance(v, string_types):
68 if not isinstance(v, string_types):
69 raise ValueError('value %r in dict must be a string' % v)
69 raise ValueError('value %r in dict must be a string' % v)
70
70
71
71
72 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
73 # ZMQ Socket Channel classes
73 # ZMQ Socket Channel classes
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75
75
76 class ZMQSocketChannel(Thread):
76 class ZMQSocketChannel(Thread):
77 """The base class for the channels that use ZMQ sockets."""
77 """The base class for the channels that use ZMQ sockets."""
78 context = None
78 context = None
79 session = None
79 session = None
80 socket = None
80 socket = None
81 ioloop = None
81 ioloop = None
82 stream = None
82 stream = None
83 _address = None
83 _address = None
84 _exiting = False
84 _exiting = False
85 proxy_methods = []
85 proxy_methods = []
86
86
87 def __init__(self, context, session, address):
87 def __init__(self, context, session, address):
88 """Create a channel.
88 """Create a channel.
89
89
90 Parameters
90 Parameters
91 ----------
91 ----------
92 context : :class:`zmq.Context`
92 context : :class:`zmq.Context`
93 The ZMQ context to use.
93 The ZMQ context to use.
94 session : :class:`session.Session`
94 session : :class:`session.Session`
95 The session to use.
95 The session to use.
96 address : zmq url
96 address : zmq url
97 Standard (ip, port) tuple that the kernel is listening on.
97 Standard (ip, port) tuple that the kernel is listening on.
98 """
98 """
99 super(ZMQSocketChannel, self).__init__()
99 super(ZMQSocketChannel, self).__init__()
100 self.daemon = True
100 self.daemon = True
101
101
102 self.context = context
102 self.context = context
103 self.session = session
103 self.session = session
104 if isinstance(address, tuple):
104 if isinstance(address, tuple):
105 if address[1] == 0:
105 if address[1] == 0:
106 message = 'The port number for a channel cannot be 0.'
106 message = 'The port number for a channel cannot be 0.'
107 raise InvalidPortNumber(message)
107 raise InvalidPortNumber(message)
108 address = "tcp://%s:%i" % address
108 address = "tcp://%s:%i" % address
109 self._address = address
109 self._address = address
110 atexit.register(self._notice_exit)
110 atexit.register(self._notice_exit)
111
111
112 def _notice_exit(self):
112 def _notice_exit(self):
113 self._exiting = True
113 self._exiting = True
114
114
115 def _run_loop(self):
115 def _run_loop(self):
116 """Run my loop, ignoring EINTR events in the poller"""
116 """Run my loop, ignoring EINTR events in the poller"""
117 while True:
117 while True:
118 try:
118 try:
119 self.ioloop.start()
119 self.ioloop.start()
120 except ZMQError as e:
120 except ZMQError as e:
121 if e.errno == errno.EINTR:
121 if e.errno == errno.EINTR:
122 continue
122 continue
123 else:
123 else:
124 raise
124 raise
125 except Exception:
125 except Exception:
126 if self._exiting:
126 if self._exiting:
127 break
127 break
128 else:
128 else:
129 raise
129 raise
130 else:
130 else:
131 break
131 break
132
132
133 def stop(self):
133 def stop(self):
134 """Stop the channel's event loop and join its thread.
134 """Stop the channel's event loop and join its thread.
135
135
136 This calls :method:`Thread.join` and returns when the thread
136 This calls :method:`Thread.join` and returns when the thread
137 terminates. :class:`RuntimeError` will be raised if
137 terminates. :class:`RuntimeError` will be raised if
138 :method:`self.start` is called again.
138 :method:`self.start` is called again.
139 """
139 """
140 self.join()
140 self.join()
141
141
142 @property
142 @property
143 def address(self):
143 def address(self):
144 """Get the channel's address as a zmq url string.
144 """Get the channel's address as a zmq url string.
145
145
146 These URLS have the form: 'tcp://127.0.0.1:5555'.
146 These URLS have the form: 'tcp://127.0.0.1:5555'.
147 """
147 """
148 return self._address
148 return self._address
149
149
150 def _queue_send(self, msg):
150 def _queue_send(self, msg):
151 """Queue a message to be sent from the IOLoop's thread.
151 """Queue a message to be sent from the IOLoop's thread.
152
152
153 Parameters
153 Parameters
154 ----------
154 ----------
155 msg : message to send
155 msg : message to send
156
156
157 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
157 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
158 thread control of the action.
158 thread control of the action.
159 """
159 """
160 def thread_send():
160 def thread_send():
161 self.session.send(self.stream, msg)
161 self.session.send(self.stream, msg)
162 self.ioloop.add_callback(thread_send)
162 self.ioloop.add_callback(thread_send)
163
163
164 def _handle_recv(self, msg):
164 def _handle_recv(self, msg):
165 """Callback for stream.on_recv.
165 """Callback for stream.on_recv.
166
166
167 Unpacks message, and calls handlers with it.
167 Unpacks message, and calls handlers with it.
168 """
168 """
169 ident,smsg = self.session.feed_identities(msg)
169 ident,smsg = self.session.feed_identities(msg)
170 self.call_handlers(self.session.unserialize(smsg))
170 self.call_handlers(self.session.unserialize(smsg))
171
171
172
172
173
173
174 class ShellChannel(ZMQSocketChannel):
174 class ShellChannel(ZMQSocketChannel):
175 """The shell channel for issuing request/replies to the kernel."""
175 """The shell channel for issuing request/replies to the kernel."""
176
176
177 command_queue = None
177 command_queue = None
178 # flag for whether execute requests should be allowed to call raw_input:
178 # flag for whether execute requests should be allowed to call raw_input:
179 allow_stdin = True
179 allow_stdin = True
180 proxy_methods = [
180 proxy_methods = [
181 'execute',
181 'execute',
182 'complete',
182 'complete',
183 'object_info',
183 'object_info',
184 'history',
184 'history',
185 'kernel_info',
185 'kernel_info',
186 'shutdown',
186 'shutdown',
187 ]
187 ]
188
188
189 def __init__(self, context, session, address):
189 def __init__(self, context, session, address):
190 super(ShellChannel, self).__init__(context, session, address)
190 super(ShellChannel, self).__init__(context, session, address)
191 self.ioloop = ioloop.IOLoop()
191 self.ioloop = ioloop.IOLoop()
192
192
193 def run(self):
193 def run(self):
194 """The thread's main activity. Call start() instead."""
194 """The thread's main activity. Call start() instead."""
195 self.socket = self.context.socket(zmq.DEALER)
195 self.socket = self.context.socket(zmq.DEALER)
196 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
196 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
197 self.socket.connect(self.address)
197 self.socket.connect(self.address)
198 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
198 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
199 self.stream.on_recv(self._handle_recv)
199 self.stream.on_recv(self._handle_recv)
200 self._run_loop()
200 self._run_loop()
201 try:
201 try:
202 self.socket.close()
202 self.socket.close()
203 except:
203 except:
204 pass
204 pass
205
205
206 def stop(self):
206 def stop(self):
207 """Stop the channel's event loop and join its thread."""
207 """Stop the channel's event loop and join its thread."""
208 self.ioloop.stop()
208 self.ioloop.stop()
209 super(ShellChannel, self).stop()
209 super(ShellChannel, self).stop()
210
210
211 def call_handlers(self, msg):
211 def call_handlers(self, msg):
212 """This method is called in the ioloop thread when a message arrives.
212 """This method is called in the ioloop thread when a message arrives.
213
213
214 Subclasses should override this method to handle incoming messages.
214 Subclasses should override this method to handle incoming messages.
215 It is important to remember that this method is called in the thread
215 It is important to remember that this method is called in the thread
216 so that some logic must be done to ensure that the application level
216 so that some logic must be done to ensure that the application level
217 handlers are called in the application thread.
217 handlers are called in the application thread.
218 """
218 """
219 raise NotImplementedError('call_handlers must be defined in a subclass.')
219 raise NotImplementedError('call_handlers must be defined in a subclass.')
220
220
221 def execute(self, code, silent=False, store_history=True,
221 def execute(self, code, silent=False, store_history=True,
222 user_variables=None, user_expressions=None, allow_stdin=None):
222 user_variables=None, user_expressions=None, allow_stdin=None):
223 """Execute code in the kernel.
223 """Execute code in the kernel.
224
224
225 Parameters
225 Parameters
226 ----------
226 ----------
227 code : str
227 code : str
228 A string of Python code.
228 A string of Python code.
229
229
230 silent : bool, optional (default False)
230 silent : bool, optional (default False)
231 If set, the kernel will execute the code as quietly possible, and
231 If set, the kernel will execute the code as quietly possible, and
232 will force store_history to be False.
232 will force store_history to be False.
233
233
234 store_history : bool, optional (default True)
234 store_history : bool, optional (default True)
235 If set, the kernel will store command history. This is forced
235 If set, the kernel will store command history. This is forced
236 to be False if silent is True.
236 to be False if silent is True.
237
237
238 user_variables : list, optional
238 user_variables : list, optional
239 A list of variable names to pull from the user's namespace. They
239 A list of variable names to pull from the user's namespace. They
240 will come back as a dict with these names as keys and their
240 will come back as a dict with these names as keys and their
241 :func:`repr` as values.
241 :func:`repr` as values.
242
242
243 user_expressions : dict, optional
243 user_expressions : dict, optional
244 A dict mapping names to expressions to be evaluated in the user's
244 A dict mapping names to expressions to be evaluated in the user's
245 dict. The expression values are returned as strings formatted using
245 dict. The expression values are returned as strings formatted using
246 :func:`repr`.
246 :func:`repr`.
247
247
248 allow_stdin : bool, optional (default self.allow_stdin)
248 allow_stdin : bool, optional (default self.allow_stdin)
249 Flag for whether the kernel can send stdin requests to frontends.
249 Flag for whether the kernel can send stdin requests to frontends.
250
250
251 Some frontends (e.g. the Notebook) do not support stdin requests.
251 Some frontends (e.g. the Notebook) do not support stdin requests.
252 If raw_input is called from code executed from such a frontend, a
252 If raw_input is called from code executed from such a frontend, a
253 StdinNotImplementedError will be raised.
253 StdinNotImplementedError will be raised.
254
254
255 Returns
255 Returns
256 -------
256 -------
257 The msg_id of the message sent.
257 The msg_id of the message sent.
258 """
258 """
259 if user_variables is None:
259 if user_variables is None:
260 user_variables = []
260 user_variables = []
261 if user_expressions is None:
261 if user_expressions is None:
262 user_expressions = {}
262 user_expressions = {}
263 if allow_stdin is None:
263 if allow_stdin is None:
264 allow_stdin = self.allow_stdin
264 allow_stdin = self.allow_stdin
265
265
266
266
267 # Don't waste network traffic if inputs are invalid
267 # Don't waste network traffic if inputs are invalid
268 if not isinstance(code, string_types):
268 if not isinstance(code, string_types):
269 raise ValueError('code %r must be a string' % code)
269 raise ValueError('code %r must be a string' % code)
270 validate_string_list(user_variables)
270 validate_string_list(user_variables)
271 validate_string_dict(user_expressions)
271 validate_string_dict(user_expressions)
272
272
273 # Create class for content/msg creation. Related to, but possibly
273 # Create class for content/msg creation. Related to, but possibly
274 # not in Session.
274 # not in Session.
275 content = dict(code=code, silent=silent, store_history=store_history,
275 content = dict(code=code, silent=silent, store_history=store_history,
276 user_variables=user_variables,
276 user_variables=user_variables,
277 user_expressions=user_expressions,
277 user_expressions=user_expressions,
278 allow_stdin=allow_stdin,
278 allow_stdin=allow_stdin,
279 )
279 )
280 msg = self.session.msg('execute_request', content)
280 msg = self.session.msg('execute_request', content)
281 self._queue_send(msg)
281 self._queue_send(msg)
282 return msg['header']['msg_id']
282 return msg['header']['msg_id']
283
283
284 def complete(self, text, line, cursor_pos, block=None):
284 def complete(self, text, line, cursor_pos, block=None):
285 """Tab complete text in the kernel's namespace.
285 """Tab complete text in the kernel's namespace.
286
286
287 Parameters
287 Parameters
288 ----------
288 ----------
289 text : str
289 text : str
290 The text to complete.
290 The text to complete.
291 line : str
291 line : str
292 The full line of text that is the surrounding context for the
292 The full line of text that is the surrounding context for the
293 text to complete.
293 text to complete.
294 cursor_pos : int
294 cursor_pos : int
295 The position of the cursor in the line where the completion was
295 The position of the cursor in the line where the completion was
296 requested.
296 requested.
297 block : str, optional
297 block : str, optional
298 The full block of code in which the completion is being requested.
298 The full block of code in which the completion is being requested.
299
299
300 Returns
300 Returns
301 -------
301 -------
302 The msg_id of the message sent.
302 The msg_id of the message sent.
303 """
303 """
304 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
304 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
305 msg = self.session.msg('complete_request', content)
305 msg = self.session.msg('complete_request', content)
306 self._queue_send(msg)
306 self._queue_send(msg)
307 return msg['header']['msg_id']
307 return msg['header']['msg_id']
308
308
309 def object_info(self, oname, detail_level=0):
309 def object_info(self, oname, detail_level=0):
310 """Get metadata information about an object in the kernel's namespace.
310 """Get metadata information about an object in the kernel's namespace.
311
311
312 Parameters
312 Parameters
313 ----------
313 ----------
314 oname : str
314 oname : str
315 A string specifying the object name.
315 A string specifying the object name.
316 detail_level : int, optional
316 detail_level : int, optional
317 The level of detail for the introspection (0-2)
317 The level of detail for the introspection (0-2)
318
318
319 Returns
319 Returns
320 -------
320 -------
321 The msg_id of the message sent.
321 The msg_id of the message sent.
322 """
322 """
323 content = dict(oname=oname, detail_level=detail_level)
323 content = dict(oname=oname, detail_level=detail_level)
324 msg = self.session.msg('object_info_request', content)
324 msg = self.session.msg('object_info_request', content)
325 self._queue_send(msg)
325 self._queue_send(msg)
326 return msg['header']['msg_id']
326 return msg['header']['msg_id']
327
327
328 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
328 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
329 """Get entries from the kernel's history list.
329 """Get entries from the kernel's history list.
330
330
331 Parameters
331 Parameters
332 ----------
332 ----------
333 raw : bool
333 raw : bool
334 If True, return the raw input.
334 If True, return the raw input.
335 output : bool
335 output : bool
336 If True, then return the output as well.
336 If True, then return the output as well.
337 hist_access_type : str
337 hist_access_type : str
338 'range' (fill in session, start and stop params), 'tail' (fill in n)
338 'range' (fill in session, start and stop params), 'tail' (fill in n)
339 or 'search' (fill in pattern param).
339 or 'search' (fill in pattern param).
340
340
341 session : int
341 session : int
342 For a range request, the session from which to get lines. Session
342 For a range request, the session from which to get lines. Session
343 numbers are positive integers; negative ones count back from the
343 numbers are positive integers; negative ones count back from the
344 current session.
344 current session.
345 start : int
345 start : int
346 The first line number of a history range.
346 The first line number of a history range.
347 stop : int
347 stop : int
348 The final (excluded) line number of a history range.
348 The final (excluded) line number of a history range.
349
349
350 n : int
350 n : int
351 The number of lines of history to get for a tail request.
351 The number of lines of history to get for a tail request.
352
352
353 pattern : str
353 pattern : str
354 The glob-syntax pattern for a search request.
354 The glob-syntax pattern for a search request.
355
355
356 Returns
356 Returns
357 -------
357 -------
358 The msg_id of the message sent.
358 The msg_id of the message sent.
359 """
359 """
360 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
360 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
361 **kwargs)
361 **kwargs)
362 msg = self.session.msg('history_request', content)
362 msg = self.session.msg('history_request', content)
363 self._queue_send(msg)
363 self._queue_send(msg)
364 return msg['header']['msg_id']
364 return msg['header']['msg_id']
365
365
366 def kernel_info(self):
366 def kernel_info(self):
367 """Request kernel info."""
367 """Request kernel info."""
368 msg = self.session.msg('kernel_info_request')
368 msg = self.session.msg('kernel_info_request')
369 self._queue_send(msg)
369 self._queue_send(msg)
370 return msg['header']['msg_id']
370 return msg['header']['msg_id']
371
371
372 def shutdown(self, restart=False):
372 def shutdown(self, restart=False):
373 """Request an immediate kernel shutdown.
373 """Request an immediate kernel shutdown.
374
374
375 Upon receipt of the (empty) reply, client code can safely assume that
375 Upon receipt of the (empty) reply, client code can safely assume that
376 the kernel has shut down and it's safe to forcefully terminate it if
376 the kernel has shut down and it's safe to forcefully terminate it if
377 it's still alive.
377 it's still alive.
378
378
379 The kernel will send the reply via a function registered with Python's
379 The kernel will send the reply via a function registered with Python's
380 atexit module, ensuring it's truly done as the kernel is done with all
380 atexit module, ensuring it's truly done as the kernel is done with all
381 normal operation.
381 normal operation.
382 """
382 """
383 # Send quit message to kernel. Once we implement kernel-side setattr,
383 # Send quit message to kernel. Once we implement kernel-side setattr,
384 # this should probably be done that way, but for now this will do.
384 # this should probably be done that way, but for now this will do.
385 msg = self.session.msg('shutdown_request', {'restart':restart})
385 msg = self.session.msg('shutdown_request', {'restart':restart})
386 self._queue_send(msg)
386 self._queue_send(msg)
387 return msg['header']['msg_id']
387 return msg['header']['msg_id']
388
388
389
389
390
390
391 class IOPubChannel(ZMQSocketChannel):
391 class IOPubChannel(ZMQSocketChannel):
392 """The iopub channel which listens for messages that the kernel publishes.
392 """The iopub channel which listens for messages that the kernel publishes.
393
393
394 This channel is where all output is published to frontends.
394 This channel is where all output is published to frontends.
395 """
395 """
396
396
397 def __init__(self, context, session, address):
397 def __init__(self, context, session, address):
398 super(IOPubChannel, self).__init__(context, session, address)
398 super(IOPubChannel, self).__init__(context, session, address)
399 self.ioloop = ioloop.IOLoop()
399 self.ioloop = ioloop.IOLoop()
400
400
401 def run(self):
401 def run(self):
402 """The thread's main activity. Call start() instead."""
402 """The thread's main activity. Call start() instead."""
403 self.socket = self.context.socket(zmq.SUB)
403 self.socket = self.context.socket(zmq.SUB)
404 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
404 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
405 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
405 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
406 self.socket.connect(self.address)
406 self.socket.connect(self.address)
407 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
407 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
408 self.stream.on_recv(self._handle_recv)
408 self.stream.on_recv(self._handle_recv)
409 self._run_loop()
409 self._run_loop()
410 try:
410 try:
411 self.socket.close()
411 self.socket.close()
412 except:
412 except:
413 pass
413 pass
414
414
415 def stop(self):
415 def stop(self):
416 """Stop the channel's event loop and join its thread."""
416 """Stop the channel's event loop and join its thread."""
417 self.ioloop.stop()
417 self.ioloop.stop()
418 super(IOPubChannel, self).stop()
418 super(IOPubChannel, self).stop()
419
419
420 def call_handlers(self, msg):
420 def call_handlers(self, msg):
421 """This method is called in the ioloop thread when a message arrives.
421 """This method is called in the ioloop thread when a message arrives.
422
422
423 Subclasses should override this method to handle incoming messages.
423 Subclasses should override this method to handle incoming messages.
424 It is important to remember that this method is called in the thread
424 It is important to remember that this method is called in the thread
425 so that some logic must be done to ensure that the application leve
425 so that some logic must be done to ensure that the application leve
426 handlers are called in the application thread.
426 handlers are called in the application thread.
427 """
427 """
428 raise NotImplementedError('call_handlers must be defined in a subclass.')
428 raise NotImplementedError('call_handlers must be defined in a subclass.')
429
429
430 def flush(self, timeout=1.0):
430 def flush(self, timeout=1.0):
431 """Immediately processes all pending messages on the iopub channel.
431 """Immediately processes all pending messages on the iopub channel.
432
432
433 Callers should use this method to ensure that :method:`call_handlers`
433 Callers should use this method to ensure that :method:`call_handlers`
434 has been called for all messages that have been received on the
434 has been called for all messages that have been received on the
435 0MQ SUB socket of this channel.
435 0MQ SUB socket of this channel.
436
436
437 This method is thread safe.
437 This method is thread safe.
438
438
439 Parameters
439 Parameters
440 ----------
440 ----------
441 timeout : float, optional
441 timeout : float, optional
442 The maximum amount of time to spend flushing, in seconds. The
442 The maximum amount of time to spend flushing, in seconds. The
443 default is one second.
443 default is one second.
444 """
444 """
445 # We do the IOLoop callback process twice to ensure that the IOLoop
445 # We do the IOLoop callback process twice to ensure that the IOLoop
446 # gets to perform at least one full poll.
446 # gets to perform at least one full poll.
447 stop_time = time.time() + timeout
447 stop_time = time.time() + timeout
448 for i in xrange(2):
448 for i in range(2):
449 self._flushed = False
449 self._flushed = False
450 self.ioloop.add_callback(self._flush)
450 self.ioloop.add_callback(self._flush)
451 while not self._flushed and time.time() < stop_time:
451 while not self._flushed and time.time() < stop_time:
452 time.sleep(0.01)
452 time.sleep(0.01)
453
453
454 def _flush(self):
454 def _flush(self):
455 """Callback for :method:`self.flush`."""
455 """Callback for :method:`self.flush`."""
456 self.stream.flush()
456 self.stream.flush()
457 self._flushed = True
457 self._flushed = True
458
458
459
459
460 class StdInChannel(ZMQSocketChannel):
460 class StdInChannel(ZMQSocketChannel):
461 """The stdin channel to handle raw_input requests that the kernel makes."""
461 """The stdin channel to handle raw_input requests that the kernel makes."""
462
462
463 msg_queue = None
463 msg_queue = None
464 proxy_methods = ['input']
464 proxy_methods = ['input']
465
465
466 def __init__(self, context, session, address):
466 def __init__(self, context, session, address):
467 super(StdInChannel, self).__init__(context, session, address)
467 super(StdInChannel, self).__init__(context, session, address)
468 self.ioloop = ioloop.IOLoop()
468 self.ioloop = ioloop.IOLoop()
469
469
470 def run(self):
470 def run(self):
471 """The thread's main activity. Call start() instead."""
471 """The thread's main activity. Call start() instead."""
472 self.socket = self.context.socket(zmq.DEALER)
472 self.socket = self.context.socket(zmq.DEALER)
473 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
473 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
474 self.socket.connect(self.address)
474 self.socket.connect(self.address)
475 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
475 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
476 self.stream.on_recv(self._handle_recv)
476 self.stream.on_recv(self._handle_recv)
477 self._run_loop()
477 self._run_loop()
478 try:
478 try:
479 self.socket.close()
479 self.socket.close()
480 except:
480 except:
481 pass
481 pass
482
482
483 def stop(self):
483 def stop(self):
484 """Stop the channel's event loop and join its thread."""
484 """Stop the channel's event loop and join its thread."""
485 self.ioloop.stop()
485 self.ioloop.stop()
486 super(StdInChannel, self).stop()
486 super(StdInChannel, self).stop()
487
487
488 def call_handlers(self, msg):
488 def call_handlers(self, msg):
489 """This method is called in the ioloop thread when a message arrives.
489 """This method is called in the ioloop thread when a message arrives.
490
490
491 Subclasses should override this method to handle incoming messages.
491 Subclasses should override this method to handle incoming messages.
492 It is important to remember that this method is called in the thread
492 It is important to remember that this method is called in the thread
493 so that some logic must be done to ensure that the application leve
493 so that some logic must be done to ensure that the application leve
494 handlers are called in the application thread.
494 handlers are called in the application thread.
495 """
495 """
496 raise NotImplementedError('call_handlers must be defined in a subclass.')
496 raise NotImplementedError('call_handlers must be defined in a subclass.')
497
497
498 def input(self, string):
498 def input(self, string):
499 """Send a string of raw input to the kernel."""
499 """Send a string of raw input to the kernel."""
500 content = dict(value=string)
500 content = dict(value=string)
501 msg = self.session.msg('input_reply', content)
501 msg = self.session.msg('input_reply', content)
502 self._queue_send(msg)
502 self._queue_send(msg)
503
503
504
504
505 class HBChannel(ZMQSocketChannel):
505 class HBChannel(ZMQSocketChannel):
506 """The heartbeat channel which monitors the kernel heartbeat.
506 """The heartbeat channel which monitors the kernel heartbeat.
507
507
508 Note that the heartbeat channel is paused by default. As long as you start
508 Note that the heartbeat channel is paused by default. As long as you start
509 this channel, the kernel manager will ensure that it is paused and un-paused
509 this channel, the kernel manager will ensure that it is paused and un-paused
510 as appropriate.
510 as appropriate.
511 """
511 """
512
512
513 time_to_dead = 3.0
513 time_to_dead = 3.0
514 socket = None
514 socket = None
515 poller = None
515 poller = None
516 _running = None
516 _running = None
517 _pause = None
517 _pause = None
518 _beating = None
518 _beating = None
519
519
520 def __init__(self, context, session, address):
520 def __init__(self, context, session, address):
521 super(HBChannel, self).__init__(context, session, address)
521 super(HBChannel, self).__init__(context, session, address)
522 self._running = False
522 self._running = False
523 self._pause =True
523 self._pause =True
524 self.poller = zmq.Poller()
524 self.poller = zmq.Poller()
525
525
526 def _create_socket(self):
526 def _create_socket(self):
527 if self.socket is not None:
527 if self.socket is not None:
528 # close previous socket, before opening a new one
528 # close previous socket, before opening a new one
529 self.poller.unregister(self.socket)
529 self.poller.unregister(self.socket)
530 self.socket.close()
530 self.socket.close()
531 self.socket = self.context.socket(zmq.REQ)
531 self.socket = self.context.socket(zmq.REQ)
532 self.socket.setsockopt(zmq.LINGER, 0)
532 self.socket.setsockopt(zmq.LINGER, 0)
533 self.socket.connect(self.address)
533 self.socket.connect(self.address)
534
534
535 self.poller.register(self.socket, zmq.POLLIN)
535 self.poller.register(self.socket, zmq.POLLIN)
536
536
537 def _poll(self, start_time):
537 def _poll(self, start_time):
538 """poll for heartbeat replies until we reach self.time_to_dead.
538 """poll for heartbeat replies until we reach self.time_to_dead.
539
539
540 Ignores interrupts, and returns the result of poll(), which
540 Ignores interrupts, and returns the result of poll(), which
541 will be an empty list if no messages arrived before the timeout,
541 will be an empty list if no messages arrived before the timeout,
542 or the event tuple if there is a message to receive.
542 or the event tuple if there is a message to receive.
543 """
543 """
544
544
545 until_dead = self.time_to_dead - (time.time() - start_time)
545 until_dead = self.time_to_dead - (time.time() - start_time)
546 # ensure poll at least once
546 # ensure poll at least once
547 until_dead = max(until_dead, 1e-3)
547 until_dead = max(until_dead, 1e-3)
548 events = []
548 events = []
549 while True:
549 while True:
550 try:
550 try:
551 events = self.poller.poll(1000 * until_dead)
551 events = self.poller.poll(1000 * until_dead)
552 except ZMQError as e:
552 except ZMQError as e:
553 if e.errno == errno.EINTR:
553 if e.errno == errno.EINTR:
554 # ignore interrupts during heartbeat
554 # ignore interrupts during heartbeat
555 # this may never actually happen
555 # this may never actually happen
556 until_dead = self.time_to_dead - (time.time() - start_time)
556 until_dead = self.time_to_dead - (time.time() - start_time)
557 until_dead = max(until_dead, 1e-3)
557 until_dead = max(until_dead, 1e-3)
558 pass
558 pass
559 else:
559 else:
560 raise
560 raise
561 except Exception:
561 except Exception:
562 if self._exiting:
562 if self._exiting:
563 break
563 break
564 else:
564 else:
565 raise
565 raise
566 else:
566 else:
567 break
567 break
568 return events
568 return events
569
569
570 def run(self):
570 def run(self):
571 """The thread's main activity. Call start() instead."""
571 """The thread's main activity. Call start() instead."""
572 self._create_socket()
572 self._create_socket()
573 self._running = True
573 self._running = True
574 self._beating = True
574 self._beating = True
575
575
576 while self._running:
576 while self._running:
577 if self._pause:
577 if self._pause:
578 # just sleep, and skip the rest of the loop
578 # just sleep, and skip the rest of the loop
579 time.sleep(self.time_to_dead)
579 time.sleep(self.time_to_dead)
580 continue
580 continue
581
581
582 since_last_heartbeat = 0.0
582 since_last_heartbeat = 0.0
583 # io.rprint('Ping from HB channel') # dbg
583 # io.rprint('Ping from HB channel') # dbg
584 # no need to catch EFSM here, because the previous event was
584 # no need to catch EFSM here, because the previous event was
585 # either a recv or connect, which cannot be followed by EFSM
585 # either a recv or connect, which cannot be followed by EFSM
586 self.socket.send(b'ping')
586 self.socket.send(b'ping')
587 request_time = time.time()
587 request_time = time.time()
588 ready = self._poll(request_time)
588 ready = self._poll(request_time)
589 if ready:
589 if ready:
590 self._beating = True
590 self._beating = True
591 # the poll above guarantees we have something to recv
591 # the poll above guarantees we have something to recv
592 self.socket.recv()
592 self.socket.recv()
593 # sleep the remainder of the cycle
593 # sleep the remainder of the cycle
594 remainder = self.time_to_dead - (time.time() - request_time)
594 remainder = self.time_to_dead - (time.time() - request_time)
595 if remainder > 0:
595 if remainder > 0:
596 time.sleep(remainder)
596 time.sleep(remainder)
597 continue
597 continue
598 else:
598 else:
599 # nothing was received within the time limit, signal heart failure
599 # nothing was received within the time limit, signal heart failure
600 self._beating = False
600 self._beating = False
601 since_last_heartbeat = time.time() - request_time
601 since_last_heartbeat = time.time() - request_time
602 self.call_handlers(since_last_heartbeat)
602 self.call_handlers(since_last_heartbeat)
603 # and close/reopen the socket, because the REQ/REP cycle has been broken
603 # and close/reopen the socket, because the REQ/REP cycle has been broken
604 self._create_socket()
604 self._create_socket()
605 continue
605 continue
606 try:
606 try:
607 self.socket.close()
607 self.socket.close()
608 except:
608 except:
609 pass
609 pass
610
610
611 def pause(self):
611 def pause(self):
612 """Pause the heartbeat."""
612 """Pause the heartbeat."""
613 self._pause = True
613 self._pause = True
614
614
615 def unpause(self):
615 def unpause(self):
616 """Unpause the heartbeat."""
616 """Unpause the heartbeat."""
617 self._pause = False
617 self._pause = False
618
618
619 def is_beating(self):
619 def is_beating(self):
620 """Is the heartbeat running and responsive (and not paused)."""
620 """Is the heartbeat running and responsive (and not paused)."""
621 if self.is_alive() and not self._pause and self._beating:
621 if self.is_alive() and not self._pause and self._beating:
622 return True
622 return True
623 else:
623 else:
624 return False
624 return False
625
625
626 def stop(self):
626 def stop(self):
627 """Stop the channel's event loop and join its thread."""
627 """Stop the channel's event loop and join its thread."""
628 self._running = False
628 self._running = False
629 super(HBChannel, self).stop()
629 super(HBChannel, self).stop()
630
630
631 def call_handlers(self, since_last_heartbeat):
631 def call_handlers(self, since_last_heartbeat):
632 """This method is called in the ioloop thread when a message arrives.
632 """This method is called in the ioloop thread when a message arrives.
633
633
634 Subclasses should override this method to handle incoming messages.
634 Subclasses should override this method to handle incoming messages.
635 It is important to remember that this method is called in the thread
635 It is important to remember that this method is called in the thread
636 so that some logic must be done to ensure that the application level
636 so that some logic must be done to ensure that the application level
637 handlers are called in the application thread.
637 handlers are called in the application thread.
638 """
638 """
639 raise NotImplementedError('call_handlers must be defined in a subclass.')
639 raise NotImplementedError('call_handlers must be defined in a subclass.')
640
640
641
641
642 #---------------------------------------------------------------------#-----------------------------------------------------------------------------
642 #---------------------------------------------------------------------#-----------------------------------------------------------------------------
643 # ABC Registration
643 # ABC Registration
644 #-----------------------------------------------------------------------------
644 #-----------------------------------------------------------------------------
645
645
646 ShellChannelABC.register(ShellChannel)
646 ShellChannelABC.register(ShellChannel)
647 IOPubChannelABC.register(IOPubChannel)
647 IOPubChannelABC.register(IOPubChannel)
648 HBChannelABC.register(HBChannel)
648 HBChannelABC.register(HBChannel)
649 StdInChannelABC.register(StdInChannel)
649 StdInChannelABC.register(StdInChannel)
@@ -1,339 +1,339 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 A module to change reload() so that it acts recursively.
3 A module to change reload() so that it acts recursively.
4 To enable it type::
4 To enable it type::
5
5
6 import __builtin__, deepreload
6 import __builtin__, deepreload
7 __builtin__.reload = deepreload.reload
7 __builtin__.reload = deepreload.reload
8
8
9 You can then disable it with::
9 You can then disable it with::
10
10
11 __builtin__.reload = deepreload.original_reload
11 __builtin__.reload = deepreload.original_reload
12
12
13 Alternatively, you can add a dreload builtin alongside normal reload with::
13 Alternatively, you can add a dreload builtin alongside normal reload with::
14
14
15 __builtin__.dreload = deepreload.reload
15 __builtin__.dreload = deepreload.reload
16
16
17 This code is almost entirely based on knee.py, which is a Python
17 This code is almost entirely based on knee.py, which is a Python
18 re-implementation of hierarchical module import.
18 re-implementation of hierarchical module import.
19 """
19 """
20 from __future__ import print_function
20 from __future__ import print_function
21 #*****************************************************************************
21 #*****************************************************************************
22 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
22 # Copyright (C) 2001 Nathaniel Gray <n8gray@caltech.edu>
23 #
23 #
24 # Distributed under the terms of the BSD License. The full license is in
24 # Distributed under the terms of the BSD License. The full license is in
25 # the file COPYING, distributed as part of this software.
25 # the file COPYING, distributed as part of this software.
26 #*****************************************************************************
26 #*****************************************************************************
27
27
28 from contextlib import contextmanager
28 from contextlib import contextmanager
29 import imp
29 import imp
30 import sys
30 import sys
31
31
32 from types import ModuleType
32 from types import ModuleType
33 from warnings import warn
33 from warnings import warn
34
34
35 from IPython.utils.py3compat import builtin_mod, builtin_mod_name
35 from IPython.utils.py3compat import builtin_mod, builtin_mod_name
36
36
37 original_import = builtin_mod.__import__
37 original_import = builtin_mod.__import__
38
38
39 @contextmanager
39 @contextmanager
40 def replace_import_hook(new_import):
40 def replace_import_hook(new_import):
41 saved_import = builtin_mod.__import__
41 saved_import = builtin_mod.__import__
42 builtin_mod.__import__ = new_import
42 builtin_mod.__import__ = new_import
43 try:
43 try:
44 yield
44 yield
45 finally:
45 finally:
46 builtin_mod.__import__ = saved_import
46 builtin_mod.__import__ = saved_import
47
47
48 def get_parent(globals, level):
48 def get_parent(globals, level):
49 """
49 """
50 parent, name = get_parent(globals, level)
50 parent, name = get_parent(globals, level)
51
51
52 Return the package that an import is being performed in. If globals comes
52 Return the package that an import is being performed in. If globals comes
53 from the module foo.bar.bat (not itself a package), this returns the
53 from the module foo.bar.bat (not itself a package), this returns the
54 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
54 sys.modules entry for foo.bar. If globals is from a package's __init__.py,
55 the package's entry in sys.modules is returned.
55 the package's entry in sys.modules is returned.
56
56
57 If globals doesn't come from a package or a module in a package, or a
57 If globals doesn't come from a package or a module in a package, or a
58 corresponding entry is not found in sys.modules, None is returned.
58 corresponding entry is not found in sys.modules, None is returned.
59 """
59 """
60 orig_level = level
60 orig_level = level
61
61
62 if not level or not isinstance(globals, dict):
62 if not level or not isinstance(globals, dict):
63 return None, ''
63 return None, ''
64
64
65 pkgname = globals.get('__package__', None)
65 pkgname = globals.get('__package__', None)
66
66
67 if pkgname is not None:
67 if pkgname is not None:
68 # __package__ is set, so use it
68 # __package__ is set, so use it
69 if not hasattr(pkgname, 'rindex'):
69 if not hasattr(pkgname, 'rindex'):
70 raise ValueError('__package__ set to non-string')
70 raise ValueError('__package__ set to non-string')
71 if len(pkgname) == 0:
71 if len(pkgname) == 0:
72 if level > 0:
72 if level > 0:
73 raise ValueError('Attempted relative import in non-package')
73 raise ValueError('Attempted relative import in non-package')
74 return None, ''
74 return None, ''
75 name = pkgname
75 name = pkgname
76 else:
76 else:
77 # __package__ not set, so figure it out and set it
77 # __package__ not set, so figure it out and set it
78 if '__name__' not in globals:
78 if '__name__' not in globals:
79 return None, ''
79 return None, ''
80 modname = globals['__name__']
80 modname = globals['__name__']
81
81
82 if '__path__' in globals:
82 if '__path__' in globals:
83 # __path__ is set, so modname is already the package name
83 # __path__ is set, so modname is already the package name
84 globals['__package__'] = name = modname
84 globals['__package__'] = name = modname
85 else:
85 else:
86 # Normal module, so work out the package name if any
86 # Normal module, so work out the package name if any
87 lastdot = modname.rfind('.')
87 lastdot = modname.rfind('.')
88 if lastdot < 0 and level > 0:
88 if lastdot < 0 and level > 0:
89 raise ValueError("Attempted relative import in non-package")
89 raise ValueError("Attempted relative import in non-package")
90 if lastdot < 0:
90 if lastdot < 0:
91 globals['__package__'] = None
91 globals['__package__'] = None
92 return None, ''
92 return None, ''
93 globals['__package__'] = name = modname[:lastdot]
93 globals['__package__'] = name = modname[:lastdot]
94
94
95 dot = len(name)
95 dot = len(name)
96 for x in xrange(level, 1, -1):
96 for x in range(level, 1, -1):
97 try:
97 try:
98 dot = name.rindex('.', 0, dot)
98 dot = name.rindex('.', 0, dot)
99 except ValueError:
99 except ValueError:
100 raise ValueError("attempted relative import beyond top-level "
100 raise ValueError("attempted relative import beyond top-level "
101 "package")
101 "package")
102 name = name[:dot]
102 name = name[:dot]
103
103
104 try:
104 try:
105 parent = sys.modules[name]
105 parent = sys.modules[name]
106 except:
106 except:
107 if orig_level < 1:
107 if orig_level < 1:
108 warn("Parent module '%.200s' not found while handling absolute "
108 warn("Parent module '%.200s' not found while handling absolute "
109 "import" % name)
109 "import" % name)
110 parent = None
110 parent = None
111 else:
111 else:
112 raise SystemError("Parent module '%.200s' not loaded, cannot "
112 raise SystemError("Parent module '%.200s' not loaded, cannot "
113 "perform relative import" % name)
113 "perform relative import" % name)
114
114
115 # We expect, but can't guarantee, if parent != None, that:
115 # We expect, but can't guarantee, if parent != None, that:
116 # - parent.__name__ == name
116 # - parent.__name__ == name
117 # - parent.__dict__ is globals
117 # - parent.__dict__ is globals
118 # If this is violated... Who cares?
118 # If this is violated... Who cares?
119 return parent, name
119 return parent, name
120
120
121 def load_next(mod, altmod, name, buf):
121 def load_next(mod, altmod, name, buf):
122 """
122 """
123 mod, name, buf = load_next(mod, altmod, name, buf)
123 mod, name, buf = load_next(mod, altmod, name, buf)
124
124
125 altmod is either None or same as mod
125 altmod is either None or same as mod
126 """
126 """
127
127
128 if len(name) == 0:
128 if len(name) == 0:
129 # completely empty module name should only happen in
129 # completely empty module name should only happen in
130 # 'from . import' (or '__import__("")')
130 # 'from . import' (or '__import__("")')
131 return mod, None, buf
131 return mod, None, buf
132
132
133 dot = name.find('.')
133 dot = name.find('.')
134 if dot == 0:
134 if dot == 0:
135 raise ValueError('Empty module name')
135 raise ValueError('Empty module name')
136
136
137 if dot < 0:
137 if dot < 0:
138 subname = name
138 subname = name
139 next = None
139 next = None
140 else:
140 else:
141 subname = name[:dot]
141 subname = name[:dot]
142 next = name[dot+1:]
142 next = name[dot+1:]
143
143
144 if buf != '':
144 if buf != '':
145 buf += '.'
145 buf += '.'
146 buf += subname
146 buf += subname
147
147
148 result = import_submodule(mod, subname, buf)
148 result = import_submodule(mod, subname, buf)
149 if result is None and mod != altmod:
149 if result is None and mod != altmod:
150 result = import_submodule(altmod, subname, subname)
150 result = import_submodule(altmod, subname, subname)
151 if result is not None:
151 if result is not None:
152 buf = subname
152 buf = subname
153
153
154 if result is None:
154 if result is None:
155 raise ImportError("No module named %.200s" % name)
155 raise ImportError("No module named %.200s" % name)
156
156
157 return result, next, buf
157 return result, next, buf
158
158
159 # Need to keep track of what we've already reloaded to prevent cyclic evil
159 # Need to keep track of what we've already reloaded to prevent cyclic evil
160 found_now = {}
160 found_now = {}
161
161
162 def import_submodule(mod, subname, fullname):
162 def import_submodule(mod, subname, fullname):
163 """m = import_submodule(mod, subname, fullname)"""
163 """m = import_submodule(mod, subname, fullname)"""
164 # Require:
164 # Require:
165 # if mod == None: subname == fullname
165 # if mod == None: subname == fullname
166 # else: mod.__name__ + "." + subname == fullname
166 # else: mod.__name__ + "." + subname == fullname
167
167
168 global found_now
168 global found_now
169 if fullname in found_now and fullname in sys.modules:
169 if fullname in found_now and fullname in sys.modules:
170 m = sys.modules[fullname]
170 m = sys.modules[fullname]
171 else:
171 else:
172 print('Reloading', fullname)
172 print('Reloading', fullname)
173 found_now[fullname] = 1
173 found_now[fullname] = 1
174 oldm = sys.modules.get(fullname, None)
174 oldm = sys.modules.get(fullname, None)
175
175
176 if mod is None:
176 if mod is None:
177 path = None
177 path = None
178 elif hasattr(mod, '__path__'):
178 elif hasattr(mod, '__path__'):
179 path = mod.__path__
179 path = mod.__path__
180 else:
180 else:
181 return None
181 return None
182
182
183 try:
183 try:
184 # This appears to be necessary on Python 3, because imp.find_module()
184 # This appears to be necessary on Python 3, because imp.find_module()
185 # tries to import standard libraries (like io) itself, and we don't
185 # tries to import standard libraries (like io) itself, and we don't
186 # want them to be processed by our deep_import_hook.
186 # want them to be processed by our deep_import_hook.
187 with replace_import_hook(original_import):
187 with replace_import_hook(original_import):
188 fp, filename, stuff = imp.find_module(subname, path)
188 fp, filename, stuff = imp.find_module(subname, path)
189 except ImportError:
189 except ImportError:
190 return None
190 return None
191
191
192 try:
192 try:
193 m = imp.load_module(fullname, fp, filename, stuff)
193 m = imp.load_module(fullname, fp, filename, stuff)
194 except:
194 except:
195 # load_module probably removed name from modules because of
195 # load_module probably removed name from modules because of
196 # the error. Put back the original module object.
196 # the error. Put back the original module object.
197 if oldm:
197 if oldm:
198 sys.modules[fullname] = oldm
198 sys.modules[fullname] = oldm
199 raise
199 raise
200 finally:
200 finally:
201 if fp: fp.close()
201 if fp: fp.close()
202
202
203 add_submodule(mod, m, fullname, subname)
203 add_submodule(mod, m, fullname, subname)
204
204
205 return m
205 return m
206
206
207 def add_submodule(mod, submod, fullname, subname):
207 def add_submodule(mod, submod, fullname, subname):
208 """mod.{subname} = submod"""
208 """mod.{subname} = submod"""
209 if mod is None:
209 if mod is None:
210 return #Nothing to do here.
210 return #Nothing to do here.
211
211
212 if submod is None:
212 if submod is None:
213 submod = sys.modules[fullname]
213 submod = sys.modules[fullname]
214
214
215 setattr(mod, subname, submod)
215 setattr(mod, subname, submod)
216
216
217 return
217 return
218
218
219 def ensure_fromlist(mod, fromlist, buf, recursive):
219 def ensure_fromlist(mod, fromlist, buf, recursive):
220 """Handle 'from module import a, b, c' imports."""
220 """Handle 'from module import a, b, c' imports."""
221 if not hasattr(mod, '__path__'):
221 if not hasattr(mod, '__path__'):
222 return
222 return
223 for item in fromlist:
223 for item in fromlist:
224 if not hasattr(item, 'rindex'):
224 if not hasattr(item, 'rindex'):
225 raise TypeError("Item in ``from list'' not a string")
225 raise TypeError("Item in ``from list'' not a string")
226 if item == '*':
226 if item == '*':
227 if recursive:
227 if recursive:
228 continue # avoid endless recursion
228 continue # avoid endless recursion
229 try:
229 try:
230 all = mod.__all__
230 all = mod.__all__
231 except AttributeError:
231 except AttributeError:
232 pass
232 pass
233 else:
233 else:
234 ret = ensure_fromlist(mod, all, buf, 1)
234 ret = ensure_fromlist(mod, all, buf, 1)
235 if not ret:
235 if not ret:
236 return 0
236 return 0
237 elif not hasattr(mod, item):
237 elif not hasattr(mod, item):
238 import_submodule(mod, item, buf + '.' + item)
238 import_submodule(mod, item, buf + '.' + item)
239
239
240 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
240 def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
241 """Replacement for __import__()"""
241 """Replacement for __import__()"""
242 parent, buf = get_parent(globals, level)
242 parent, buf = get_parent(globals, level)
243
243
244 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
244 head, name, buf = load_next(parent, None if level < 0 else parent, name, buf)
245
245
246 tail = head
246 tail = head
247 while name:
247 while name:
248 tail, name, buf = load_next(tail, tail, name, buf)
248 tail, name, buf = load_next(tail, tail, name, buf)
249
249
250 # If tail is None, both get_parent and load_next found
250 # If tail is None, both get_parent and load_next found
251 # an empty module name: someone called __import__("") or
251 # an empty module name: someone called __import__("") or
252 # doctored faulty bytecode
252 # doctored faulty bytecode
253 if tail is None:
253 if tail is None:
254 raise ValueError('Empty module name')
254 raise ValueError('Empty module name')
255
255
256 if not fromlist:
256 if not fromlist:
257 return head
257 return head
258
258
259 ensure_fromlist(tail, fromlist, buf, 0)
259 ensure_fromlist(tail, fromlist, buf, 0)
260 return tail
260 return tail
261
261
262 modules_reloading = {}
262 modules_reloading = {}
263
263
264 def deep_reload_hook(m):
264 def deep_reload_hook(m):
265 """Replacement for reload()."""
265 """Replacement for reload()."""
266 if not isinstance(m, ModuleType):
266 if not isinstance(m, ModuleType):
267 raise TypeError("reload() argument must be module")
267 raise TypeError("reload() argument must be module")
268
268
269 name = m.__name__
269 name = m.__name__
270
270
271 if name not in sys.modules:
271 if name not in sys.modules:
272 raise ImportError("reload(): module %.200s not in sys.modules" % name)
272 raise ImportError("reload(): module %.200s not in sys.modules" % name)
273
273
274 global modules_reloading
274 global modules_reloading
275 try:
275 try:
276 return modules_reloading[name]
276 return modules_reloading[name]
277 except:
277 except:
278 modules_reloading[name] = m
278 modules_reloading[name] = m
279
279
280 dot = name.rfind('.')
280 dot = name.rfind('.')
281 if dot < 0:
281 if dot < 0:
282 subname = name
282 subname = name
283 path = None
283 path = None
284 else:
284 else:
285 try:
285 try:
286 parent = sys.modules[name[:dot]]
286 parent = sys.modules[name[:dot]]
287 except KeyError:
287 except KeyError:
288 modules_reloading.clear()
288 modules_reloading.clear()
289 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
289 raise ImportError("reload(): parent %.200s not in sys.modules" % name[:dot])
290 subname = name[dot+1:]
290 subname = name[dot+1:]
291 path = getattr(parent, "__path__", None)
291 path = getattr(parent, "__path__", None)
292
292
293 try:
293 try:
294 # This appears to be necessary on Python 3, because imp.find_module()
294 # This appears to be necessary on Python 3, because imp.find_module()
295 # tries to import standard libraries (like io) itself, and we don't
295 # tries to import standard libraries (like io) itself, and we don't
296 # want them to be processed by our deep_import_hook.
296 # want them to be processed by our deep_import_hook.
297 with replace_import_hook(original_import):
297 with replace_import_hook(original_import):
298 fp, filename, stuff = imp.find_module(subname, path)
298 fp, filename, stuff = imp.find_module(subname, path)
299 finally:
299 finally:
300 modules_reloading.clear()
300 modules_reloading.clear()
301
301
302 try:
302 try:
303 newm = imp.load_module(name, fp, filename, stuff)
303 newm = imp.load_module(name, fp, filename, stuff)
304 except:
304 except:
305 # load_module probably removed name from modules because of
305 # load_module probably removed name from modules because of
306 # the error. Put back the original module object.
306 # the error. Put back the original module object.
307 sys.modules[name] = m
307 sys.modules[name] = m
308 raise
308 raise
309 finally:
309 finally:
310 if fp: fp.close()
310 if fp: fp.close()
311
311
312 modules_reloading.clear()
312 modules_reloading.clear()
313 return newm
313 return newm
314
314
315 # Save the original hooks
315 # Save the original hooks
316 try:
316 try:
317 original_reload = builtin_mod.reload
317 original_reload = builtin_mod.reload
318 except AttributeError:
318 except AttributeError:
319 original_reload = imp.reload # Python 3
319 original_reload = imp.reload # Python 3
320
320
321 # Replacement for reload()
321 # Replacement for reload()
322 def reload(module, exclude=['sys', 'os.path', builtin_mod_name, '__main__']):
322 def reload(module, exclude=['sys', 'os.path', builtin_mod_name, '__main__']):
323 """Recursively reload all modules used in the given module. Optionally
323 """Recursively reload all modules used in the given module. Optionally
324 takes a list of modules to exclude from reloading. The default exclude
324 takes a list of modules to exclude from reloading. The default exclude
325 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
325 list contains sys, __main__, and __builtin__, to prevent, e.g., resetting
326 display, exception, and io hooks.
326 display, exception, and io hooks.
327 """
327 """
328 global found_now
328 global found_now
329 for i in exclude:
329 for i in exclude:
330 found_now[i] = 1
330 found_now[i] = 1
331 try:
331 try:
332 with replace_import_hook(deep_import_hook):
332 with replace_import_hook(deep_import_hook):
333 return deep_reload_hook(module)
333 return deep_reload_hook(module)
334 finally:
334 finally:
335 found_now = {}
335 found_now = {}
336
336
337 # Uncomment the following to automatically activate deep reloading whenever
337 # Uncomment the following to automatically activate deep reloading whenever
338 # this module is imported
338 # this module is imported
339 #builtin_mod.reload = reload
339 #builtin_mod.reload = reload
@@ -1,790 +1,790 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Python advanced pretty printer. This pretty printer is intended to
3 Python advanced pretty printer. This pretty printer is intended to
4 replace the old `pprint` python module which does not allow developers
4 replace the old `pprint` python module which does not allow developers
5 to provide their own pretty print callbacks.
5 to provide their own pretty print callbacks.
6
6
7 This module is based on ruby's `prettyprint.rb` library by `Tanaka Akira`.
7 This module is based on ruby's `prettyprint.rb` library by `Tanaka Akira`.
8
8
9
9
10 Example Usage
10 Example Usage
11 -------------
11 -------------
12
12
13 To directly print the representation of an object use `pprint`::
13 To directly print the representation of an object use `pprint`::
14
14
15 from pretty import pprint
15 from pretty import pprint
16 pprint(complex_object)
16 pprint(complex_object)
17
17
18 To get a string of the output use `pretty`::
18 To get a string of the output use `pretty`::
19
19
20 from pretty import pretty
20 from pretty import pretty
21 string = pretty(complex_object)
21 string = pretty(complex_object)
22
22
23
23
24 Extending
24 Extending
25 ---------
25 ---------
26
26
27 The pretty library allows developers to add pretty printing rules for their
27 The pretty library allows developers to add pretty printing rules for their
28 own objects. This process is straightforward. All you have to do is to
28 own objects. This process is straightforward. All you have to do is to
29 add a `_repr_pretty_` method to your object and call the methods on the
29 add a `_repr_pretty_` method to your object and call the methods on the
30 pretty printer passed::
30 pretty printer passed::
31
31
32 class MyObject(object):
32 class MyObject(object):
33
33
34 def _repr_pretty_(self, p, cycle):
34 def _repr_pretty_(self, p, cycle):
35 ...
35 ...
36
36
37 Depending on the python version you want to support you have two
37 Depending on the python version you want to support you have two
38 possibilities. The following list shows the python 2.5 version and the
38 possibilities. The following list shows the python 2.5 version and the
39 compatibility one.
39 compatibility one.
40
40
41
41
42 Here the example implementation of a `_repr_pretty_` method for a list
42 Here the example implementation of a `_repr_pretty_` method for a list
43 subclass for python 2.5 and higher (python 2.5 requires the with statement
43 subclass for python 2.5 and higher (python 2.5 requires the with statement
44 __future__ import)::
44 __future__ import)::
45
45
46 class MyList(list):
46 class MyList(list):
47
47
48 def _repr_pretty_(self, p, cycle):
48 def _repr_pretty_(self, p, cycle):
49 if cycle:
49 if cycle:
50 p.text('MyList(...)')
50 p.text('MyList(...)')
51 else:
51 else:
52 with p.group(8, 'MyList([', '])'):
52 with p.group(8, 'MyList([', '])'):
53 for idx, item in enumerate(self):
53 for idx, item in enumerate(self):
54 if idx:
54 if idx:
55 p.text(',')
55 p.text(',')
56 p.breakable()
56 p.breakable()
57 p.pretty(item)
57 p.pretty(item)
58
58
59 The `cycle` parameter is `True` if pretty detected a cycle. You *have* to
59 The `cycle` parameter is `True` if pretty detected a cycle. You *have* to
60 react to that or the result is an infinite loop. `p.text()` just adds
60 react to that or the result is an infinite loop. `p.text()` just adds
61 non breaking text to the output, `p.breakable()` either adds a whitespace
61 non breaking text to the output, `p.breakable()` either adds a whitespace
62 or breaks here. If you pass it an argument it's used instead of the
62 or breaks here. If you pass it an argument it's used instead of the
63 default space. `p.pretty` prettyprints another object using the pretty print
63 default space. `p.pretty` prettyprints another object using the pretty print
64 method.
64 method.
65
65
66 The first parameter to the `group` function specifies the extra indentation
66 The first parameter to the `group` function specifies the extra indentation
67 of the next line. In this example the next item will either be not
67 of the next line. In this example the next item will either be not
68 breaked (if the items are short enough) or aligned with the right edge of
68 breaked (if the items are short enough) or aligned with the right edge of
69 the opening bracked of `MyList`.
69 the opening bracked of `MyList`.
70
70
71 If you want to support python 2.4 and lower you can use this code::
71 If you want to support python 2.4 and lower you can use this code::
72
72
73 class MyList(list):
73 class MyList(list):
74
74
75 def _repr_pretty_(self, p, cycle):
75 def _repr_pretty_(self, p, cycle):
76 if cycle:
76 if cycle:
77 p.text('MyList(...)')
77 p.text('MyList(...)')
78 else:
78 else:
79 p.begin_group(8, 'MyList([')
79 p.begin_group(8, 'MyList([')
80 for idx, item in enumerate(self):
80 for idx, item in enumerate(self):
81 if idx:
81 if idx:
82 p.text(',')
82 p.text(',')
83 p.breakable()
83 p.breakable()
84 p.pretty(item)
84 p.pretty(item)
85 p.end_group(8, '])')
85 p.end_group(8, '])')
86
86
87 If you just want to indent something you can use the group function
87 If you just want to indent something you can use the group function
88 without open / close parameters. Under python 2.5 you can also use this
88 without open / close parameters. Under python 2.5 you can also use this
89 code::
89 code::
90
90
91 with p.indent(2):
91 with p.indent(2):
92 ...
92 ...
93
93
94 Or under python2.4 you might want to modify ``p.indentation`` by hand but
94 Or under python2.4 you might want to modify ``p.indentation`` by hand but
95 this is rather ugly.
95 this is rather ugly.
96
96
97 Inheritance diagram:
97 Inheritance diagram:
98
98
99 .. inheritance-diagram:: IPython.lib.pretty
99 .. inheritance-diagram:: IPython.lib.pretty
100 :parts: 3
100 :parts: 3
101
101
102 :copyright: 2007 by Armin Ronacher.
102 :copyright: 2007 by Armin Ronacher.
103 Portions (c) 2009 by Robert Kern.
103 Portions (c) 2009 by Robert Kern.
104 :license: BSD License.
104 :license: BSD License.
105 """
105 """
106 from __future__ import print_function
106 from __future__ import print_function
107 from contextlib import contextmanager
107 from contextlib import contextmanager
108 import sys
108 import sys
109 import types
109 import types
110 import re
110 import re
111 import datetime
111 import datetime
112 from io import StringIO
112 from io import StringIO
113 from collections import deque
113 from collections import deque
114
114
115
115
116 __all__ = ['pretty', 'pprint', 'PrettyPrinter', 'RepresentationPrinter',
116 __all__ = ['pretty', 'pprint', 'PrettyPrinter', 'RepresentationPrinter',
117 'for_type', 'for_type_by_name']
117 'for_type', 'for_type_by_name']
118
118
119
119
120 _re_pattern_type = type(re.compile(''))
120 _re_pattern_type = type(re.compile(''))
121
121
122
122
123 def pretty(obj, verbose=False, max_width=79, newline='\n'):
123 def pretty(obj, verbose=False, max_width=79, newline='\n'):
124 """
124 """
125 Pretty print the object's representation.
125 Pretty print the object's representation.
126 """
126 """
127 stream = StringIO()
127 stream = StringIO()
128 printer = RepresentationPrinter(stream, verbose, max_width, newline)
128 printer = RepresentationPrinter(stream, verbose, max_width, newline)
129 printer.pretty(obj)
129 printer.pretty(obj)
130 printer.flush()
130 printer.flush()
131 return stream.getvalue()
131 return stream.getvalue()
132
132
133
133
134 def pprint(obj, verbose=False, max_width=79, newline='\n'):
134 def pprint(obj, verbose=False, max_width=79, newline='\n'):
135 """
135 """
136 Like `pretty` but print to stdout.
136 Like `pretty` but print to stdout.
137 """
137 """
138 printer = RepresentationPrinter(sys.stdout, verbose, max_width, newline)
138 printer = RepresentationPrinter(sys.stdout, verbose, max_width, newline)
139 printer.pretty(obj)
139 printer.pretty(obj)
140 printer.flush()
140 printer.flush()
141 sys.stdout.write(newline)
141 sys.stdout.write(newline)
142 sys.stdout.flush()
142 sys.stdout.flush()
143
143
144 class _PrettyPrinterBase(object):
144 class _PrettyPrinterBase(object):
145
145
146 @contextmanager
146 @contextmanager
147 def indent(self, indent):
147 def indent(self, indent):
148 """with statement support for indenting/dedenting."""
148 """with statement support for indenting/dedenting."""
149 self.indentation += indent
149 self.indentation += indent
150 try:
150 try:
151 yield
151 yield
152 finally:
152 finally:
153 self.indentation -= indent
153 self.indentation -= indent
154
154
155 @contextmanager
155 @contextmanager
156 def group(self, indent=0, open='', close=''):
156 def group(self, indent=0, open='', close=''):
157 """like begin_group / end_group but for the with statement."""
157 """like begin_group / end_group but for the with statement."""
158 self.begin_group(indent, open)
158 self.begin_group(indent, open)
159 try:
159 try:
160 yield
160 yield
161 finally:
161 finally:
162 self.end_group(indent, close)
162 self.end_group(indent, close)
163
163
164 class PrettyPrinter(_PrettyPrinterBase):
164 class PrettyPrinter(_PrettyPrinterBase):
165 """
165 """
166 Baseclass for the `RepresentationPrinter` prettyprinter that is used to
166 Baseclass for the `RepresentationPrinter` prettyprinter that is used to
167 generate pretty reprs of objects. Contrary to the `RepresentationPrinter`
167 generate pretty reprs of objects. Contrary to the `RepresentationPrinter`
168 this printer knows nothing about the default pprinters or the `_repr_pretty_`
168 this printer knows nothing about the default pprinters or the `_repr_pretty_`
169 callback method.
169 callback method.
170 """
170 """
171
171
172 def __init__(self, output, max_width=79, newline='\n'):
172 def __init__(self, output, max_width=79, newline='\n'):
173 self.output = output
173 self.output = output
174 self.max_width = max_width
174 self.max_width = max_width
175 self.newline = newline
175 self.newline = newline
176 self.output_width = 0
176 self.output_width = 0
177 self.buffer_width = 0
177 self.buffer_width = 0
178 self.buffer = deque()
178 self.buffer = deque()
179
179
180 root_group = Group(0)
180 root_group = Group(0)
181 self.group_stack = [root_group]
181 self.group_stack = [root_group]
182 self.group_queue = GroupQueue(root_group)
182 self.group_queue = GroupQueue(root_group)
183 self.indentation = 0
183 self.indentation = 0
184
184
185 def _break_outer_groups(self):
185 def _break_outer_groups(self):
186 while self.max_width < self.output_width + self.buffer_width:
186 while self.max_width < self.output_width + self.buffer_width:
187 group = self.group_queue.deq()
187 group = self.group_queue.deq()
188 if not group:
188 if not group:
189 return
189 return
190 while group.breakables:
190 while group.breakables:
191 x = self.buffer.popleft()
191 x = self.buffer.popleft()
192 self.output_width = x.output(self.output, self.output_width)
192 self.output_width = x.output(self.output, self.output_width)
193 self.buffer_width -= x.width
193 self.buffer_width -= x.width
194 while self.buffer and isinstance(self.buffer[0], Text):
194 while self.buffer and isinstance(self.buffer[0], Text):
195 x = self.buffer.popleft()
195 x = self.buffer.popleft()
196 self.output_width = x.output(self.output, self.output_width)
196 self.output_width = x.output(self.output, self.output_width)
197 self.buffer_width -= x.width
197 self.buffer_width -= x.width
198
198
199 def text(self, obj):
199 def text(self, obj):
200 """Add literal text to the output."""
200 """Add literal text to the output."""
201 width = len(obj)
201 width = len(obj)
202 if self.buffer:
202 if self.buffer:
203 text = self.buffer[-1]
203 text = self.buffer[-1]
204 if not isinstance(text, Text):
204 if not isinstance(text, Text):
205 text = Text()
205 text = Text()
206 self.buffer.append(text)
206 self.buffer.append(text)
207 text.add(obj, width)
207 text.add(obj, width)
208 self.buffer_width += width
208 self.buffer_width += width
209 self._break_outer_groups()
209 self._break_outer_groups()
210 else:
210 else:
211 self.output.write(obj)
211 self.output.write(obj)
212 self.output_width += width
212 self.output_width += width
213
213
214 def breakable(self, sep=' '):
214 def breakable(self, sep=' '):
215 """
215 """
216 Add a breakable separator to the output. This does not mean that it
216 Add a breakable separator to the output. This does not mean that it
217 will automatically break here. If no breaking on this position takes
217 will automatically break here. If no breaking on this position takes
218 place the `sep` is inserted which default to one space.
218 place the `sep` is inserted which default to one space.
219 """
219 """
220 width = len(sep)
220 width = len(sep)
221 group = self.group_stack[-1]
221 group = self.group_stack[-1]
222 if group.want_break:
222 if group.want_break:
223 self.flush()
223 self.flush()
224 self.output.write(self.newline)
224 self.output.write(self.newline)
225 self.output.write(' ' * self.indentation)
225 self.output.write(' ' * self.indentation)
226 self.output_width = self.indentation
226 self.output_width = self.indentation
227 self.buffer_width = 0
227 self.buffer_width = 0
228 else:
228 else:
229 self.buffer.append(Breakable(sep, width, self))
229 self.buffer.append(Breakable(sep, width, self))
230 self.buffer_width += width
230 self.buffer_width += width
231 self._break_outer_groups()
231 self._break_outer_groups()
232
232
233 def break_(self):
233 def break_(self):
234 """
234 """
235 Explicitly insert a newline into the output, maintaining correct indentation.
235 Explicitly insert a newline into the output, maintaining correct indentation.
236 """
236 """
237 self.flush()
237 self.flush()
238 self.output.write(self.newline)
238 self.output.write(self.newline)
239 self.output.write(' ' * self.indentation)
239 self.output.write(' ' * self.indentation)
240 self.output_width = self.indentation
240 self.output_width = self.indentation
241 self.buffer_width = 0
241 self.buffer_width = 0
242
242
243
243
244 def begin_group(self, indent=0, open=''):
244 def begin_group(self, indent=0, open=''):
245 """
245 """
246 Begin a group. If you want support for python < 2.5 which doesn't has
246 Begin a group. If you want support for python < 2.5 which doesn't has
247 the with statement this is the preferred way:
247 the with statement this is the preferred way:
248
248
249 p.begin_group(1, '{')
249 p.begin_group(1, '{')
250 ...
250 ...
251 p.end_group(1, '}')
251 p.end_group(1, '}')
252
252
253 The python 2.5 expression would be this:
253 The python 2.5 expression would be this:
254
254
255 with p.group(1, '{', '}'):
255 with p.group(1, '{', '}'):
256 ...
256 ...
257
257
258 The first parameter specifies the indentation for the next line (usually
258 The first parameter specifies the indentation for the next line (usually
259 the width of the opening text), the second the opening text. All
259 the width of the opening text), the second the opening text. All
260 parameters are optional.
260 parameters are optional.
261 """
261 """
262 if open:
262 if open:
263 self.text(open)
263 self.text(open)
264 group = Group(self.group_stack[-1].depth + 1)
264 group = Group(self.group_stack[-1].depth + 1)
265 self.group_stack.append(group)
265 self.group_stack.append(group)
266 self.group_queue.enq(group)
266 self.group_queue.enq(group)
267 self.indentation += indent
267 self.indentation += indent
268
268
269 def end_group(self, dedent=0, close=''):
269 def end_group(self, dedent=0, close=''):
270 """End a group. See `begin_group` for more details."""
270 """End a group. See `begin_group` for more details."""
271 self.indentation -= dedent
271 self.indentation -= dedent
272 group = self.group_stack.pop()
272 group = self.group_stack.pop()
273 if not group.breakables:
273 if not group.breakables:
274 self.group_queue.remove(group)
274 self.group_queue.remove(group)
275 if close:
275 if close:
276 self.text(close)
276 self.text(close)
277
277
278 def flush(self):
278 def flush(self):
279 """Flush data that is left in the buffer."""
279 """Flush data that is left in the buffer."""
280 for data in self.buffer:
280 for data in self.buffer:
281 self.output_width += data.output(self.output, self.output_width)
281 self.output_width += data.output(self.output, self.output_width)
282 self.buffer.clear()
282 self.buffer.clear()
283 self.buffer_width = 0
283 self.buffer_width = 0
284
284
285
285
286 def _get_mro(obj_class):
286 def _get_mro(obj_class):
287 """ Get a reasonable method resolution order of a class and its superclasses
287 """ Get a reasonable method resolution order of a class and its superclasses
288 for both old-style and new-style classes.
288 for both old-style and new-style classes.
289 """
289 """
290 if not hasattr(obj_class, '__mro__'):
290 if not hasattr(obj_class, '__mro__'):
291 # Old-style class. Mix in object to make a fake new-style class.
291 # Old-style class. Mix in object to make a fake new-style class.
292 try:
292 try:
293 obj_class = type(obj_class.__name__, (obj_class, object), {})
293 obj_class = type(obj_class.__name__, (obj_class, object), {})
294 except TypeError:
294 except TypeError:
295 # Old-style extension type that does not descend from object.
295 # Old-style extension type that does not descend from object.
296 # FIXME: try to construct a more thorough MRO.
296 # FIXME: try to construct a more thorough MRO.
297 mro = [obj_class]
297 mro = [obj_class]
298 else:
298 else:
299 mro = obj_class.__mro__[1:-1]
299 mro = obj_class.__mro__[1:-1]
300 else:
300 else:
301 mro = obj_class.__mro__
301 mro = obj_class.__mro__
302 return mro
302 return mro
303
303
304
304
305 class RepresentationPrinter(PrettyPrinter):
305 class RepresentationPrinter(PrettyPrinter):
306 """
306 """
307 Special pretty printer that has a `pretty` method that calls the pretty
307 Special pretty printer that has a `pretty` method that calls the pretty
308 printer for a python object.
308 printer for a python object.
309
309
310 This class stores processing data on `self` so you must *never* use
310 This class stores processing data on `self` so you must *never* use
311 this class in a threaded environment. Always lock it or reinstanciate
311 this class in a threaded environment. Always lock it or reinstanciate
312 it.
312 it.
313
313
314 Instances also have a verbose flag callbacks can access to control their
314 Instances also have a verbose flag callbacks can access to control their
315 output. For example the default instance repr prints all attributes and
315 output. For example the default instance repr prints all attributes and
316 methods that are not prefixed by an underscore if the printer is in
316 methods that are not prefixed by an underscore if the printer is in
317 verbose mode.
317 verbose mode.
318 """
318 """
319
319
320 def __init__(self, output, verbose=False, max_width=79, newline='\n',
320 def __init__(self, output, verbose=False, max_width=79, newline='\n',
321 singleton_pprinters=None, type_pprinters=None, deferred_pprinters=None):
321 singleton_pprinters=None, type_pprinters=None, deferred_pprinters=None):
322
322
323 PrettyPrinter.__init__(self, output, max_width, newline)
323 PrettyPrinter.__init__(self, output, max_width, newline)
324 self.verbose = verbose
324 self.verbose = verbose
325 self.stack = []
325 self.stack = []
326 if singleton_pprinters is None:
326 if singleton_pprinters is None:
327 singleton_pprinters = _singleton_pprinters.copy()
327 singleton_pprinters = _singleton_pprinters.copy()
328 self.singleton_pprinters = singleton_pprinters
328 self.singleton_pprinters = singleton_pprinters
329 if type_pprinters is None:
329 if type_pprinters is None:
330 type_pprinters = _type_pprinters.copy()
330 type_pprinters = _type_pprinters.copy()
331 self.type_pprinters = type_pprinters
331 self.type_pprinters = type_pprinters
332 if deferred_pprinters is None:
332 if deferred_pprinters is None:
333 deferred_pprinters = _deferred_type_pprinters.copy()
333 deferred_pprinters = _deferred_type_pprinters.copy()
334 self.deferred_pprinters = deferred_pprinters
334 self.deferred_pprinters = deferred_pprinters
335
335
336 def pretty(self, obj):
336 def pretty(self, obj):
337 """Pretty print the given object."""
337 """Pretty print the given object."""
338 obj_id = id(obj)
338 obj_id = id(obj)
339 cycle = obj_id in self.stack
339 cycle = obj_id in self.stack
340 self.stack.append(obj_id)
340 self.stack.append(obj_id)
341 self.begin_group()
341 self.begin_group()
342 try:
342 try:
343 obj_class = getattr(obj, '__class__', None) or type(obj)
343 obj_class = getattr(obj, '__class__', None) or type(obj)
344 # First try to find registered singleton printers for the type.
344 # First try to find registered singleton printers for the type.
345 try:
345 try:
346 printer = self.singleton_pprinters[obj_id]
346 printer = self.singleton_pprinters[obj_id]
347 except (TypeError, KeyError):
347 except (TypeError, KeyError):
348 pass
348 pass
349 else:
349 else:
350 return printer(obj, self, cycle)
350 return printer(obj, self, cycle)
351 # Next walk the mro and check for either:
351 # Next walk the mro and check for either:
352 # 1) a registered printer
352 # 1) a registered printer
353 # 2) a _repr_pretty_ method
353 # 2) a _repr_pretty_ method
354 for cls in _get_mro(obj_class):
354 for cls in _get_mro(obj_class):
355 if cls in self.type_pprinters:
355 if cls in self.type_pprinters:
356 # printer registered in self.type_pprinters
356 # printer registered in self.type_pprinters
357 return self.type_pprinters[cls](obj, self, cycle)
357 return self.type_pprinters[cls](obj, self, cycle)
358 else:
358 else:
359 # deferred printer
359 # deferred printer
360 printer = self._in_deferred_types(cls)
360 printer = self._in_deferred_types(cls)
361 if printer is not None:
361 if printer is not None:
362 return printer(obj, self, cycle)
362 return printer(obj, self, cycle)
363 else:
363 else:
364 # Finally look for special method names.
364 # Finally look for special method names.
365 # Some objects automatically create any requested
365 # Some objects automatically create any requested
366 # attribute. Try to ignore most of them by checking for
366 # attribute. Try to ignore most of them by checking for
367 # callability.
367 # callability.
368 if '_repr_pretty_' in cls.__dict__:
368 if '_repr_pretty_' in cls.__dict__:
369 meth = cls._repr_pretty_
369 meth = cls._repr_pretty_
370 if callable(meth):
370 if callable(meth):
371 return meth(obj, self, cycle)
371 return meth(obj, self, cycle)
372 return _default_pprint(obj, self, cycle)
372 return _default_pprint(obj, self, cycle)
373 finally:
373 finally:
374 self.end_group()
374 self.end_group()
375 self.stack.pop()
375 self.stack.pop()
376
376
377 def _in_deferred_types(self, cls):
377 def _in_deferred_types(self, cls):
378 """
378 """
379 Check if the given class is specified in the deferred type registry.
379 Check if the given class is specified in the deferred type registry.
380
380
381 Returns the printer from the registry if it exists, and None if the
381 Returns the printer from the registry if it exists, and None if the
382 class is not in the registry. Successful matches will be moved to the
382 class is not in the registry. Successful matches will be moved to the
383 regular type registry for future use.
383 regular type registry for future use.
384 """
384 """
385 mod = getattr(cls, '__module__', None)
385 mod = getattr(cls, '__module__', None)
386 name = getattr(cls, '__name__', None)
386 name = getattr(cls, '__name__', None)
387 key = (mod, name)
387 key = (mod, name)
388 printer = None
388 printer = None
389 if key in self.deferred_pprinters:
389 if key in self.deferred_pprinters:
390 # Move the printer over to the regular registry.
390 # Move the printer over to the regular registry.
391 printer = self.deferred_pprinters.pop(key)
391 printer = self.deferred_pprinters.pop(key)
392 self.type_pprinters[cls] = printer
392 self.type_pprinters[cls] = printer
393 return printer
393 return printer
394
394
395
395
396 class Printable(object):
396 class Printable(object):
397
397
398 def output(self, stream, output_width):
398 def output(self, stream, output_width):
399 return output_width
399 return output_width
400
400
401
401
402 class Text(Printable):
402 class Text(Printable):
403
403
404 def __init__(self):
404 def __init__(self):
405 self.objs = []
405 self.objs = []
406 self.width = 0
406 self.width = 0
407
407
408 def output(self, stream, output_width):
408 def output(self, stream, output_width):
409 for obj in self.objs:
409 for obj in self.objs:
410 stream.write(obj)
410 stream.write(obj)
411 return output_width + self.width
411 return output_width + self.width
412
412
413 def add(self, obj, width):
413 def add(self, obj, width):
414 self.objs.append(obj)
414 self.objs.append(obj)
415 self.width += width
415 self.width += width
416
416
417
417
418 class Breakable(Printable):
418 class Breakable(Printable):
419
419
420 def __init__(self, seq, width, pretty):
420 def __init__(self, seq, width, pretty):
421 self.obj = seq
421 self.obj = seq
422 self.width = width
422 self.width = width
423 self.pretty = pretty
423 self.pretty = pretty
424 self.indentation = pretty.indentation
424 self.indentation = pretty.indentation
425 self.group = pretty.group_stack[-1]
425 self.group = pretty.group_stack[-1]
426 self.group.breakables.append(self)
426 self.group.breakables.append(self)
427
427
428 def output(self, stream, output_width):
428 def output(self, stream, output_width):
429 self.group.breakables.popleft()
429 self.group.breakables.popleft()
430 if self.group.want_break:
430 if self.group.want_break:
431 stream.write(self.pretty.newline)
431 stream.write(self.pretty.newline)
432 stream.write(' ' * self.indentation)
432 stream.write(' ' * self.indentation)
433 return self.indentation
433 return self.indentation
434 if not self.group.breakables:
434 if not self.group.breakables:
435 self.pretty.group_queue.remove(self.group)
435 self.pretty.group_queue.remove(self.group)
436 stream.write(self.obj)
436 stream.write(self.obj)
437 return output_width + self.width
437 return output_width + self.width
438
438
439
439
440 class Group(Printable):
440 class Group(Printable):
441
441
442 def __init__(self, depth):
442 def __init__(self, depth):
443 self.depth = depth
443 self.depth = depth
444 self.breakables = deque()
444 self.breakables = deque()
445 self.want_break = False
445 self.want_break = False
446
446
447
447
448 class GroupQueue(object):
448 class GroupQueue(object):
449
449
450 def __init__(self, *groups):
450 def __init__(self, *groups):
451 self.queue = []
451 self.queue = []
452 for group in groups:
452 for group in groups:
453 self.enq(group)
453 self.enq(group)
454
454
455 def enq(self, group):
455 def enq(self, group):
456 depth = group.depth
456 depth = group.depth
457 while depth > len(self.queue) - 1:
457 while depth > len(self.queue) - 1:
458 self.queue.append([])
458 self.queue.append([])
459 self.queue[depth].append(group)
459 self.queue[depth].append(group)
460
460
461 def deq(self):
461 def deq(self):
462 for stack in self.queue:
462 for stack in self.queue:
463 for idx, group in enumerate(reversed(stack)):
463 for idx, group in enumerate(reversed(stack)):
464 if group.breakables:
464 if group.breakables:
465 del stack[idx]
465 del stack[idx]
466 group.want_break = True
466 group.want_break = True
467 return group
467 return group
468 for group in stack:
468 for group in stack:
469 group.want_break = True
469 group.want_break = True
470 del stack[:]
470 del stack[:]
471
471
472 def remove(self, group):
472 def remove(self, group):
473 try:
473 try:
474 self.queue[group.depth].remove(group)
474 self.queue[group.depth].remove(group)
475 except ValueError:
475 except ValueError:
476 pass
476 pass
477
477
478 try:
478 try:
479 _baseclass_reprs = (object.__repr__, types.InstanceType.__repr__)
479 _baseclass_reprs = (object.__repr__, types.InstanceType.__repr__)
480 except AttributeError: # Python 3
480 except AttributeError: # Python 3
481 _baseclass_reprs = (object.__repr__,)
481 _baseclass_reprs = (object.__repr__,)
482
482
483
483
484 def _default_pprint(obj, p, cycle):
484 def _default_pprint(obj, p, cycle):
485 """
485 """
486 The default print function. Used if an object does not provide one and
486 The default print function. Used if an object does not provide one and
487 it's none of the builtin objects.
487 it's none of the builtin objects.
488 """
488 """
489 klass = getattr(obj, '__class__', None) or type(obj)
489 klass = getattr(obj, '__class__', None) or type(obj)
490 if getattr(klass, '__repr__', None) not in _baseclass_reprs:
490 if getattr(klass, '__repr__', None) not in _baseclass_reprs:
491 # A user-provided repr. Find newlines and replace them with p.break_()
491 # A user-provided repr. Find newlines and replace them with p.break_()
492 output = repr(obj)
492 output = repr(obj)
493 for idx,output_line in enumerate(output.splitlines()):
493 for idx,output_line in enumerate(output.splitlines()):
494 if idx:
494 if idx:
495 p.break_()
495 p.break_()
496 p.text(output_line)
496 p.text(output_line)
497 return
497 return
498 p.begin_group(1, '<')
498 p.begin_group(1, '<')
499 p.pretty(klass)
499 p.pretty(klass)
500 p.text(' at 0x%x' % id(obj))
500 p.text(' at 0x%x' % id(obj))
501 if cycle:
501 if cycle:
502 p.text(' ...')
502 p.text(' ...')
503 elif p.verbose:
503 elif p.verbose:
504 first = True
504 first = True
505 for key in dir(obj):
505 for key in dir(obj):
506 if not key.startswith('_'):
506 if not key.startswith('_'):
507 try:
507 try:
508 value = getattr(obj, key)
508 value = getattr(obj, key)
509 except AttributeError:
509 except AttributeError:
510 continue
510 continue
511 if isinstance(value, types.MethodType):
511 if isinstance(value, types.MethodType):
512 continue
512 continue
513 if not first:
513 if not first:
514 p.text(',')
514 p.text(',')
515 p.breakable()
515 p.breakable()
516 p.text(key)
516 p.text(key)
517 p.text('=')
517 p.text('=')
518 step = len(key) + 1
518 step = len(key) + 1
519 p.indentation += step
519 p.indentation += step
520 p.pretty(value)
520 p.pretty(value)
521 p.indentation -= step
521 p.indentation -= step
522 first = False
522 first = False
523 p.end_group(1, '>')
523 p.end_group(1, '>')
524
524
525
525
526 def _seq_pprinter_factory(start, end, basetype):
526 def _seq_pprinter_factory(start, end, basetype):
527 """
527 """
528 Factory that returns a pprint function useful for sequences. Used by
528 Factory that returns a pprint function useful for sequences. Used by
529 the default pprint for tuples, dicts, and lists.
529 the default pprint for tuples, dicts, and lists.
530 """
530 """
531 def inner(obj, p, cycle):
531 def inner(obj, p, cycle):
532 typ = type(obj)
532 typ = type(obj)
533 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
533 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
534 # If the subclass provides its own repr, use it instead.
534 # If the subclass provides its own repr, use it instead.
535 return p.text(typ.__repr__(obj))
535 return p.text(typ.__repr__(obj))
536
536
537 if cycle:
537 if cycle:
538 return p.text(start + '...' + end)
538 return p.text(start + '...' + end)
539 step = len(start)
539 step = len(start)
540 p.begin_group(step, start)
540 p.begin_group(step, start)
541 for idx, x in enumerate(obj):
541 for idx, x in enumerate(obj):
542 if idx:
542 if idx:
543 p.text(',')
543 p.text(',')
544 p.breakable()
544 p.breakable()
545 p.pretty(x)
545 p.pretty(x)
546 if len(obj) == 1 and type(obj) is tuple:
546 if len(obj) == 1 and type(obj) is tuple:
547 # Special case for 1-item tuples.
547 # Special case for 1-item tuples.
548 p.text(',')
548 p.text(',')
549 p.end_group(step, end)
549 p.end_group(step, end)
550 return inner
550 return inner
551
551
552
552
553 def _set_pprinter_factory(start, end, basetype):
553 def _set_pprinter_factory(start, end, basetype):
554 """
554 """
555 Factory that returns a pprint function useful for sets and frozensets.
555 Factory that returns a pprint function useful for sets and frozensets.
556 """
556 """
557 def inner(obj, p, cycle):
557 def inner(obj, p, cycle):
558 typ = type(obj)
558 typ = type(obj)
559 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
559 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
560 # If the subclass provides its own repr, use it instead.
560 # If the subclass provides its own repr, use it instead.
561 return p.text(typ.__repr__(obj))
561 return p.text(typ.__repr__(obj))
562
562
563 if cycle:
563 if cycle:
564 return p.text(start + '...' + end)
564 return p.text(start + '...' + end)
565 if len(obj) == 0:
565 if len(obj) == 0:
566 # Special case.
566 # Special case.
567 p.text(basetype.__name__ + '()')
567 p.text(basetype.__name__ + '()')
568 else:
568 else:
569 step = len(start)
569 step = len(start)
570 p.begin_group(step, start)
570 p.begin_group(step, start)
571 # Like dictionary keys, we will try to sort the items.
571 # Like dictionary keys, we will try to sort the items.
572 items = list(obj)
572 items = list(obj)
573 try:
573 try:
574 items.sort()
574 items.sort()
575 except Exception:
575 except Exception:
576 # Sometimes the items don't sort.
576 # Sometimes the items don't sort.
577 pass
577 pass
578 for idx, x in enumerate(items):
578 for idx, x in enumerate(items):
579 if idx:
579 if idx:
580 p.text(',')
580 p.text(',')
581 p.breakable()
581 p.breakable()
582 p.pretty(x)
582 p.pretty(x)
583 p.end_group(step, end)
583 p.end_group(step, end)
584 return inner
584 return inner
585
585
586
586
587 def _dict_pprinter_factory(start, end, basetype=None):
587 def _dict_pprinter_factory(start, end, basetype=None):
588 """
588 """
589 Factory that returns a pprint function used by the default pprint of
589 Factory that returns a pprint function used by the default pprint of
590 dicts and dict proxies.
590 dicts and dict proxies.
591 """
591 """
592 def inner(obj, p, cycle):
592 def inner(obj, p, cycle):
593 typ = type(obj)
593 typ = type(obj)
594 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
594 if basetype is not None and typ is not basetype and typ.__repr__ != basetype.__repr__:
595 # If the subclass provides its own repr, use it instead.
595 # If the subclass provides its own repr, use it instead.
596 return p.text(typ.__repr__(obj))
596 return p.text(typ.__repr__(obj))
597
597
598 if cycle:
598 if cycle:
599 return p.text('{...}')
599 return p.text('{...}')
600 p.begin_group(1, start)
600 p.begin_group(1, start)
601 keys = obj.keys()
601 keys = obj.keys()
602 try:
602 try:
603 keys.sort()
603 keys.sort()
604 except Exception as e:
604 except Exception as e:
605 # Sometimes the keys don't sort.
605 # Sometimes the keys don't sort.
606 pass
606 pass
607 for idx, key in enumerate(keys):
607 for idx, key in enumerate(keys):
608 if idx:
608 if idx:
609 p.text(',')
609 p.text(',')
610 p.breakable()
610 p.breakable()
611 p.pretty(key)
611 p.pretty(key)
612 p.text(': ')
612 p.text(': ')
613 p.pretty(obj[key])
613 p.pretty(obj[key])
614 p.end_group(1, end)
614 p.end_group(1, end)
615 return inner
615 return inner
616
616
617
617
618 def _super_pprint(obj, p, cycle):
618 def _super_pprint(obj, p, cycle):
619 """The pprint for the super type."""
619 """The pprint for the super type."""
620 p.begin_group(8, '<super: ')
620 p.begin_group(8, '<super: ')
621 p.pretty(obj.__self_class__)
621 p.pretty(obj.__self_class__)
622 p.text(',')
622 p.text(',')
623 p.breakable()
623 p.breakable()
624 p.pretty(obj.__self__)
624 p.pretty(obj.__self__)
625 p.end_group(8, '>')
625 p.end_group(8, '>')
626
626
627
627
628 def _re_pattern_pprint(obj, p, cycle):
628 def _re_pattern_pprint(obj, p, cycle):
629 """The pprint function for regular expression patterns."""
629 """The pprint function for regular expression patterns."""
630 p.text('re.compile(')
630 p.text('re.compile(')
631 pattern = repr(obj.pattern)
631 pattern = repr(obj.pattern)
632 if pattern[:1] in 'uU':
632 if pattern[:1] in 'uU':
633 pattern = pattern[1:]
633 pattern = pattern[1:]
634 prefix = 'ur'
634 prefix = 'ur'
635 else:
635 else:
636 prefix = 'r'
636 prefix = 'r'
637 pattern = prefix + pattern.replace('\\\\', '\\')
637 pattern = prefix + pattern.replace('\\\\', '\\')
638 p.text(pattern)
638 p.text(pattern)
639 if obj.flags:
639 if obj.flags:
640 p.text(',')
640 p.text(',')
641 p.breakable()
641 p.breakable()
642 done_one = False
642 done_one = False
643 for flag in ('TEMPLATE', 'IGNORECASE', 'LOCALE', 'MULTILINE', 'DOTALL',
643 for flag in ('TEMPLATE', 'IGNORECASE', 'LOCALE', 'MULTILINE', 'DOTALL',
644 'UNICODE', 'VERBOSE', 'DEBUG'):
644 'UNICODE', 'VERBOSE', 'DEBUG'):
645 if obj.flags & getattr(re, flag):
645 if obj.flags & getattr(re, flag):
646 if done_one:
646 if done_one:
647 p.text('|')
647 p.text('|')
648 p.text('re.' + flag)
648 p.text('re.' + flag)
649 done_one = True
649 done_one = True
650 p.text(')')
650 p.text(')')
651
651
652
652
653 def _type_pprint(obj, p, cycle):
653 def _type_pprint(obj, p, cycle):
654 """The pprint for classes and types."""
654 """The pprint for classes and types."""
655 mod = getattr(obj, '__module__', None)
655 mod = getattr(obj, '__module__', None)
656 if mod is None:
656 if mod is None:
657 # Heap allocated types might not have the module attribute,
657 # Heap allocated types might not have the module attribute,
658 # and others may set it to None.
658 # and others may set it to None.
659 return p.text(obj.__name__)
659 return p.text(obj.__name__)
660
660
661 if mod in ('__builtin__', 'builtins', 'exceptions'):
661 if mod in ('__builtin__', 'builtins', 'exceptions'):
662 name = obj.__name__
662 name = obj.__name__
663 else:
663 else:
664 name = mod + '.' + obj.__name__
664 name = mod + '.' + obj.__name__
665 p.text(name)
665 p.text(name)
666
666
667
667
668 def _repr_pprint(obj, p, cycle):
668 def _repr_pprint(obj, p, cycle):
669 """A pprint that just redirects to the normal repr function."""
669 """A pprint that just redirects to the normal repr function."""
670 p.text(repr(obj))
670 p.text(repr(obj))
671
671
672
672
673 def _function_pprint(obj, p, cycle):
673 def _function_pprint(obj, p, cycle):
674 """Base pprint for all functions and builtin functions."""
674 """Base pprint for all functions and builtin functions."""
675 if obj.__module__ in ('__builtin__', 'builtins', 'exceptions') or not obj.__module__:
675 if obj.__module__ in ('__builtin__', 'builtins', 'exceptions') or not obj.__module__:
676 name = obj.__name__
676 name = obj.__name__
677 else:
677 else:
678 name = obj.__module__ + '.' + obj.__name__
678 name = obj.__module__ + '.' + obj.__name__
679 p.text('<function %s>' % name)
679 p.text('<function %s>' % name)
680
680
681
681
682 def _exception_pprint(obj, p, cycle):
682 def _exception_pprint(obj, p, cycle):
683 """Base pprint for all exceptions."""
683 """Base pprint for all exceptions."""
684 if obj.__class__.__module__ in ('exceptions', 'builtins'):
684 if obj.__class__.__module__ in ('exceptions', 'builtins'):
685 name = obj.__class__.__name__
685 name = obj.__class__.__name__
686 else:
686 else:
687 name = '%s.%s' % (
687 name = '%s.%s' % (
688 obj.__class__.__module__,
688 obj.__class__.__module__,
689 obj.__class__.__name__
689 obj.__class__.__name__
690 )
690 )
691 step = len(name) + 1
691 step = len(name) + 1
692 p.begin_group(step, name + '(')
692 p.begin_group(step, name + '(')
693 for idx, arg in enumerate(getattr(obj, 'args', ())):
693 for idx, arg in enumerate(getattr(obj, 'args', ())):
694 if idx:
694 if idx:
695 p.text(',')
695 p.text(',')
696 p.breakable()
696 p.breakable()
697 p.pretty(arg)
697 p.pretty(arg)
698 p.end_group(step, ')')
698 p.end_group(step, ')')
699
699
700
700
701 #: the exception base
701 #: the exception base
702 try:
702 try:
703 _exception_base = BaseException
703 _exception_base = BaseException
704 except NameError:
704 except NameError:
705 _exception_base = Exception
705 _exception_base = Exception
706
706
707
707
708 #: printers for builtin types
708 #: printers for builtin types
709 _type_pprinters = {
709 _type_pprinters = {
710 int: _repr_pprint,
710 int: _repr_pprint,
711 long: _repr_pprint,
712 float: _repr_pprint,
711 float: _repr_pprint,
713 str: _repr_pprint,
712 str: _repr_pprint,
714 unicode: _repr_pprint,
715 tuple: _seq_pprinter_factory('(', ')', tuple),
713 tuple: _seq_pprinter_factory('(', ')', tuple),
716 list: _seq_pprinter_factory('[', ']', list),
714 list: _seq_pprinter_factory('[', ']', list),
717 dict: _dict_pprinter_factory('{', '}', dict),
715 dict: _dict_pprinter_factory('{', '}', dict),
718
716
719 set: _set_pprinter_factory('{', '}', set),
717 set: _set_pprinter_factory('{', '}', set),
720 frozenset: _set_pprinter_factory('frozenset({', '})', frozenset),
718 frozenset: _set_pprinter_factory('frozenset({', '})', frozenset),
721 super: _super_pprint,
719 super: _super_pprint,
722 _re_pattern_type: _re_pattern_pprint,
720 _re_pattern_type: _re_pattern_pprint,
723 type: _type_pprint,
721 type: _type_pprint,
724 types.FunctionType: _function_pprint,
722 types.FunctionType: _function_pprint,
725 types.BuiltinFunctionType: _function_pprint,
723 types.BuiltinFunctionType: _function_pprint,
726 types.SliceType: _repr_pprint,
724 types.SliceType: _repr_pprint,
727 types.MethodType: _repr_pprint,
725 types.MethodType: _repr_pprint,
728
726
729 datetime.datetime: _repr_pprint,
727 datetime.datetime: _repr_pprint,
730 datetime.timedelta: _repr_pprint,
728 datetime.timedelta: _repr_pprint,
731 _exception_base: _exception_pprint
729 _exception_base: _exception_pprint
732 }
730 }
733
731
734 try:
732 try:
735 _type_pprinters[types.DictProxyType] = _dict_pprinter_factory('<dictproxy {', '}>')
733 _type_pprinters[types.DictProxyType] = _dict_pprinter_factory('<dictproxy {', '}>')
736 _type_pprinters[types.ClassType] = _type_pprint
734 _type_pprinters[types.ClassType] = _type_pprint
737 except AttributeError: # Python 3
735 _type_pprinters[long] = _repr_pprint
738 pass
736 _type_pprinters[unicode] = _repr_pprint
737 except (AttributeError, NameError): # Python 3
738 _type_pprinters[bytes] = _repr_pprint
739
739
740 try:
740 try:
741 _type_pprinters[xrange] = _repr_pprint
741 _type_pprinters[xrange] = _repr_pprint
742 except NameError:
742 except NameError:
743 _type_pprinters[range] = _repr_pprint
743 _type_pprinters[range] = _repr_pprint
744
744
745 #: printers for types specified by name
745 #: printers for types specified by name
746 _deferred_type_pprinters = {
746 _deferred_type_pprinters = {
747 }
747 }
748
748
749 def for_type(typ, func):
749 def for_type(typ, func):
750 """
750 """
751 Add a pretty printer for a given type.
751 Add a pretty printer for a given type.
752 """
752 """
753 oldfunc = _type_pprinters.get(typ, None)
753 oldfunc = _type_pprinters.get(typ, None)
754 if func is not None:
754 if func is not None:
755 # To support easy restoration of old pprinters, we need to ignore Nones.
755 # To support easy restoration of old pprinters, we need to ignore Nones.
756 _type_pprinters[typ] = func
756 _type_pprinters[typ] = func
757 return oldfunc
757 return oldfunc
758
758
759 def for_type_by_name(type_module, type_name, func):
759 def for_type_by_name(type_module, type_name, func):
760 """
760 """
761 Add a pretty printer for a type specified by the module and name of a type
761 Add a pretty printer for a type specified by the module and name of a type
762 rather than the type object itself.
762 rather than the type object itself.
763 """
763 """
764 key = (type_module, type_name)
764 key = (type_module, type_name)
765 oldfunc = _deferred_type_pprinters.get(key, None)
765 oldfunc = _deferred_type_pprinters.get(key, None)
766 if func is not None:
766 if func is not None:
767 # To support easy restoration of old pprinters, we need to ignore Nones.
767 # To support easy restoration of old pprinters, we need to ignore Nones.
768 _deferred_type_pprinters[key] = func
768 _deferred_type_pprinters[key] = func
769 return oldfunc
769 return oldfunc
770
770
771
771
772 #: printers for the default singletons
772 #: printers for the default singletons
773 _singleton_pprinters = dict.fromkeys(map(id, [None, True, False, Ellipsis,
773 _singleton_pprinters = dict.fromkeys(map(id, [None, True, False, Ellipsis,
774 NotImplemented]), _repr_pprint)
774 NotImplemented]), _repr_pprint)
775
775
776
776
777 if __name__ == '__main__':
777 if __name__ == '__main__':
778 from random import randrange
778 from random import randrange
779 class Foo(object):
779 class Foo(object):
780 def __init__(self):
780 def __init__(self):
781 self.foo = 1
781 self.foo = 1
782 self.bar = re.compile(r'\s+')
782 self.bar = re.compile(r'\s+')
783 self.blub = dict.fromkeys(range(30), randrange(1, 40))
783 self.blub = dict.fromkeys(range(30), randrange(1, 40))
784 self.hehe = 23424.234234
784 self.hehe = 23424.234234
785 self.list = ["blub", "blah", self]
785 self.list = ["blub", "blah", self]
786
786
787 def get_foo(self):
787 def get_foo(self):
788 print("foo")
788 print("foo")
789
789
790 pprint(Foo(), verbose=True)
790 pprint(Foo(), verbose=True)
@@ -1,1859 +1,1854 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 from __future__ import print_function
7 from __future__ import print_function
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2010-2011 The IPython Development Team
9 # Copyright (C) 2010-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import os
19 import os
20 import json
20 import json
21 import sys
21 import sys
22 from threading import Thread, Event
22 from threading import Thread, Event
23 import time
23 import time
24 import warnings
24 import warnings
25 from datetime import datetime
25 from datetime import datetime
26 from getpass import getpass
26 from getpass import getpass
27 from pprint import pprint
27 from pprint import pprint
28
28
29 pjoin = os.path.join
29 pjoin = os.path.join
30
30
31 import zmq
31 import zmq
32 # from zmq.eventloop import ioloop, zmqstream
32 # from zmq.eventloop import ioloop, zmqstream
33
33
34 from IPython.config.configurable import MultipleInstanceError
34 from IPython.config.configurable import MultipleInstanceError
35 from IPython.core.application import BaseIPythonApplication
35 from IPython.core.application import BaseIPythonApplication
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37
37
38 from IPython.utils.capture import RichOutput
38 from IPython.utils.capture import RichOutput
39 from IPython.utils.coloransi import TermColors
39 from IPython.utils.coloransi import TermColors
40 from IPython.utils.jsonutil import rekey
40 from IPython.utils.jsonutil import rekey
41 from IPython.utils.localinterfaces import localhost, is_local_ip
41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 from IPython.utils.path import get_ipython_dir
42 from IPython.utils.path import get_ipython_dir
43 from IPython.utils.py3compat import cast_bytes, string_types
43 from IPython.utils.py3compat import cast_bytes, string_types, xrange
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 Dict, List, Bool, Set, Any)
45 Dict, List, Bool, Set, Any)
46 from IPython.external.decorator import decorator
46 from IPython.external.decorator import decorator
47 from IPython.external.ssh import tunnel
47 from IPython.external.ssh import tunnel
48
48
49 from IPython.parallel import Reference
49 from IPython.parallel import Reference
50 from IPython.parallel import error
50 from IPython.parallel import error
51 from IPython.parallel import util
51 from IPython.parallel import util
52
52
53 from IPython.kernel.zmq.session import Session, Message
53 from IPython.kernel.zmq.session import Session, Message
54 from IPython.kernel.zmq import serialize
54 from IPython.kernel.zmq import serialize
55
55
56 from .asyncresult import AsyncResult, AsyncHubResult
56 from .asyncresult import AsyncResult, AsyncHubResult
57 from .view import DirectView, LoadBalancedView
57 from .view import DirectView, LoadBalancedView
58
58
59 if sys.version_info[0] >= 3:
60 # xrange is used in a couple 'isinstance' tests in py2
61 # should be just 'range' in 3k
62 xrange = range
63
64 #--------------------------------------------------------------------------
59 #--------------------------------------------------------------------------
65 # Decorators for Client methods
60 # Decorators for Client methods
66 #--------------------------------------------------------------------------
61 #--------------------------------------------------------------------------
67
62
68 @decorator
63 @decorator
69 def spin_first(f, self, *args, **kwargs):
64 def spin_first(f, self, *args, **kwargs):
70 """Call spin() to sync state prior to calling the method."""
65 """Call spin() to sync state prior to calling the method."""
71 self.spin()
66 self.spin()
72 return f(self, *args, **kwargs)
67 return f(self, *args, **kwargs)
73
68
74
69
75 #--------------------------------------------------------------------------
70 #--------------------------------------------------------------------------
76 # Classes
71 # Classes
77 #--------------------------------------------------------------------------
72 #--------------------------------------------------------------------------
78
73
79
74
80 class ExecuteReply(RichOutput):
75 class ExecuteReply(RichOutput):
81 """wrapper for finished Execute results"""
76 """wrapper for finished Execute results"""
82 def __init__(self, msg_id, content, metadata):
77 def __init__(self, msg_id, content, metadata):
83 self.msg_id = msg_id
78 self.msg_id = msg_id
84 self._content = content
79 self._content = content
85 self.execution_count = content['execution_count']
80 self.execution_count = content['execution_count']
86 self.metadata = metadata
81 self.metadata = metadata
87
82
88 # RichOutput overrides
83 # RichOutput overrides
89
84
90 @property
85 @property
91 def source(self):
86 def source(self):
92 pyout = self.metadata['pyout']
87 pyout = self.metadata['pyout']
93 if pyout:
88 if pyout:
94 return pyout.get('source', '')
89 return pyout.get('source', '')
95
90
96 @property
91 @property
97 def data(self):
92 def data(self):
98 pyout = self.metadata['pyout']
93 pyout = self.metadata['pyout']
99 if pyout:
94 if pyout:
100 return pyout.get('data', {})
95 return pyout.get('data', {})
101
96
102 @property
97 @property
103 def _metadata(self):
98 def _metadata(self):
104 pyout = self.metadata['pyout']
99 pyout = self.metadata['pyout']
105 if pyout:
100 if pyout:
106 return pyout.get('metadata', {})
101 return pyout.get('metadata', {})
107
102
108 def display(self):
103 def display(self):
109 from IPython.display import publish_display_data
104 from IPython.display import publish_display_data
110 publish_display_data(self.source, self.data, self.metadata)
105 publish_display_data(self.source, self.data, self.metadata)
111
106
112 def _repr_mime_(self, mime):
107 def _repr_mime_(self, mime):
113 if mime not in self.data:
108 if mime not in self.data:
114 return
109 return
115 data = self.data[mime]
110 data = self.data[mime]
116 if mime in self._metadata:
111 if mime in self._metadata:
117 return data, self._metadata[mime]
112 return data, self._metadata[mime]
118 else:
113 else:
119 return data
114 return data
120
115
121 def __getitem__(self, key):
116 def __getitem__(self, key):
122 return self.metadata[key]
117 return self.metadata[key]
123
118
124 def __getattr__(self, key):
119 def __getattr__(self, key):
125 if key not in self.metadata:
120 if key not in self.metadata:
126 raise AttributeError(key)
121 raise AttributeError(key)
127 return self.metadata[key]
122 return self.metadata[key]
128
123
129 def __repr__(self):
124 def __repr__(self):
130 pyout = self.metadata['pyout'] or {'data':{}}
125 pyout = self.metadata['pyout'] or {'data':{}}
131 text_out = pyout['data'].get('text/plain', '')
126 text_out = pyout['data'].get('text/plain', '')
132 if len(text_out) > 32:
127 if len(text_out) > 32:
133 text_out = text_out[:29] + '...'
128 text_out = text_out[:29] + '...'
134
129
135 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
136
131
137 def _repr_pretty_(self, p, cycle):
132 def _repr_pretty_(self, p, cycle):
138 pyout = self.metadata['pyout'] or {'data':{}}
133 pyout = self.metadata['pyout'] or {'data':{}}
139 text_out = pyout['data'].get('text/plain', '')
134 text_out = pyout['data'].get('text/plain', '')
140
135
141 if not text_out:
136 if not text_out:
142 return
137 return
143
138
144 try:
139 try:
145 ip = get_ipython()
140 ip = get_ipython()
146 except NameError:
141 except NameError:
147 colors = "NoColor"
142 colors = "NoColor"
148 else:
143 else:
149 colors = ip.colors
144 colors = ip.colors
150
145
151 if colors == "NoColor":
146 if colors == "NoColor":
152 out = normal = ""
147 out = normal = ""
153 else:
148 else:
154 out = TermColors.Red
149 out = TermColors.Red
155 normal = TermColors.Normal
150 normal = TermColors.Normal
156
151
157 if '\n' in text_out and not text_out.startswith('\n'):
152 if '\n' in text_out and not text_out.startswith('\n'):
158 # add newline for multiline reprs
153 # add newline for multiline reprs
159 text_out = '\n' + text_out
154 text_out = '\n' + text_out
160
155
161 p.text(
156 p.text(
162 out + u'Out[%i:%i]: ' % (
157 out + u'Out[%i:%i]: ' % (
163 self.metadata['engine_id'], self.execution_count
158 self.metadata['engine_id'], self.execution_count
164 ) + normal + text_out
159 ) + normal + text_out
165 )
160 )
166
161
167
162
168 class Metadata(dict):
163 class Metadata(dict):
169 """Subclass of dict for initializing metadata values.
164 """Subclass of dict for initializing metadata values.
170
165
171 Attribute access works on keys.
166 Attribute access works on keys.
172
167
173 These objects have a strict set of keys - errors will raise if you try
168 These objects have a strict set of keys - errors will raise if you try
174 to add new keys.
169 to add new keys.
175 """
170 """
176 def __init__(self, *args, **kwargs):
171 def __init__(self, *args, **kwargs):
177 dict.__init__(self)
172 dict.__init__(self)
178 md = {'msg_id' : None,
173 md = {'msg_id' : None,
179 'submitted' : None,
174 'submitted' : None,
180 'started' : None,
175 'started' : None,
181 'completed' : None,
176 'completed' : None,
182 'received' : None,
177 'received' : None,
183 'engine_uuid' : None,
178 'engine_uuid' : None,
184 'engine_id' : None,
179 'engine_id' : None,
185 'follow' : None,
180 'follow' : None,
186 'after' : None,
181 'after' : None,
187 'status' : None,
182 'status' : None,
188
183
189 'pyin' : None,
184 'pyin' : None,
190 'pyout' : None,
185 'pyout' : None,
191 'pyerr' : None,
186 'pyerr' : None,
192 'stdout' : '',
187 'stdout' : '',
193 'stderr' : '',
188 'stderr' : '',
194 'outputs' : [],
189 'outputs' : [],
195 'data': {},
190 'data': {},
196 'outputs_ready' : False,
191 'outputs_ready' : False,
197 }
192 }
198 self.update(md)
193 self.update(md)
199 self.update(dict(*args, **kwargs))
194 self.update(dict(*args, **kwargs))
200
195
201 def __getattr__(self, key):
196 def __getattr__(self, key):
202 """getattr aliased to getitem"""
197 """getattr aliased to getitem"""
203 if key in self.iterkeys():
198 if key in self.iterkeys():
204 return self[key]
199 return self[key]
205 else:
200 else:
206 raise AttributeError(key)
201 raise AttributeError(key)
207
202
208 def __setattr__(self, key, value):
203 def __setattr__(self, key, value):
209 """setattr aliased to setitem, with strict"""
204 """setattr aliased to setitem, with strict"""
210 if key in self.iterkeys():
205 if key in self.iterkeys():
211 self[key] = value
206 self[key] = value
212 else:
207 else:
213 raise AttributeError(key)
208 raise AttributeError(key)
214
209
215 def __setitem__(self, key, value):
210 def __setitem__(self, key, value):
216 """strict static key enforcement"""
211 """strict static key enforcement"""
217 if key in self.iterkeys():
212 if key in self.iterkeys():
218 dict.__setitem__(self, key, value)
213 dict.__setitem__(self, key, value)
219 else:
214 else:
220 raise KeyError(key)
215 raise KeyError(key)
221
216
222
217
223 class Client(HasTraits):
218 class Client(HasTraits):
224 """A semi-synchronous client to the IPython ZMQ cluster
219 """A semi-synchronous client to the IPython ZMQ cluster
225
220
226 Parameters
221 Parameters
227 ----------
222 ----------
228
223
229 url_file : str/unicode; path to ipcontroller-client.json
224 url_file : str/unicode; path to ipcontroller-client.json
230 This JSON file should contain all the information needed to connect to a cluster,
225 This JSON file should contain all the information needed to connect to a cluster,
231 and is likely the only argument needed.
226 and is likely the only argument needed.
232 Connection information for the Hub's registration. If a json connector
227 Connection information for the Hub's registration. If a json connector
233 file is given, then likely no further configuration is necessary.
228 file is given, then likely no further configuration is necessary.
234 [Default: use profile]
229 [Default: use profile]
235 profile : bytes
230 profile : bytes
236 The name of the Cluster profile to be used to find connector information.
231 The name of the Cluster profile to be used to find connector information.
237 If run from an IPython application, the default profile will be the same
232 If run from an IPython application, the default profile will be the same
238 as the running application, otherwise it will be 'default'.
233 as the running application, otherwise it will be 'default'.
239 cluster_id : str
234 cluster_id : str
240 String id to added to runtime files, to prevent name collisions when using
235 String id to added to runtime files, to prevent name collisions when using
241 multiple clusters with a single profile simultaneously.
236 multiple clusters with a single profile simultaneously.
242 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
243 Since this is text inserted into filenames, typical recommendations apply:
238 Since this is text inserted into filenames, typical recommendations apply:
244 Simple character strings are ideal, and spaces are not recommended (but
239 Simple character strings are ideal, and spaces are not recommended (but
245 should generally work)
240 should generally work)
246 context : zmq.Context
241 context : zmq.Context
247 Pass an existing zmq.Context instance, otherwise the client will create its own.
242 Pass an existing zmq.Context instance, otherwise the client will create its own.
248 debug : bool
243 debug : bool
249 flag for lots of message printing for debug purposes
244 flag for lots of message printing for debug purposes
250 timeout : int/float
245 timeout : int/float
251 time (in seconds) to wait for connection replies from the Hub
246 time (in seconds) to wait for connection replies from the Hub
252 [Default: 10]
247 [Default: 10]
253
248
254 #-------------- session related args ----------------
249 #-------------- session related args ----------------
255
250
256 config : Config object
251 config : Config object
257 If specified, this will be relayed to the Session for configuration
252 If specified, this will be relayed to the Session for configuration
258 username : str
253 username : str
259 set username for the session object
254 set username for the session object
260
255
261 #-------------- ssh related args ----------------
256 #-------------- ssh related args ----------------
262 # These are args for configuring the ssh tunnel to be used
257 # These are args for configuring the ssh tunnel to be used
263 # credentials are used to forward connections over ssh to the Controller
258 # credentials are used to forward connections over ssh to the Controller
264 # Note that the ip given in `addr` needs to be relative to sshserver
259 # Note that the ip given in `addr` needs to be relative to sshserver
265 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
266 # and set sshserver as the same machine the Controller is on. However,
261 # and set sshserver as the same machine the Controller is on. However,
267 # the only requirement is that sshserver is able to see the Controller
262 # the only requirement is that sshserver is able to see the Controller
268 # (i.e. is within the same trusted network).
263 # (i.e. is within the same trusted network).
269
264
270 sshserver : str
265 sshserver : str
271 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
272 If keyfile or password is specified, and this is not, it will default to
267 If keyfile or password is specified, and this is not, it will default to
273 the ip given in addr.
268 the ip given in addr.
274 sshkey : str; path to ssh private key file
269 sshkey : str; path to ssh private key file
275 This specifies a key to be used in ssh login, default None.
270 This specifies a key to be used in ssh login, default None.
276 Regular default ssh keys will be used without specifying this argument.
271 Regular default ssh keys will be used without specifying this argument.
277 password : str
272 password : str
278 Your ssh password to sshserver. Note that if this is left None,
273 Your ssh password to sshserver. Note that if this is left None,
279 you will be prompted for it if passwordless key based login is unavailable.
274 you will be prompted for it if passwordless key based login is unavailable.
280 paramiko : bool
275 paramiko : bool
281 flag for whether to use paramiko instead of shell ssh for tunneling.
276 flag for whether to use paramiko instead of shell ssh for tunneling.
282 [default: True on win32, False else]
277 [default: True on win32, False else]
283
278
284
279
285 Attributes
280 Attributes
286 ----------
281 ----------
287
282
288 ids : list of int engine IDs
283 ids : list of int engine IDs
289 requesting the ids attribute always synchronizes
284 requesting the ids attribute always synchronizes
290 the registration state. To request ids without synchronization,
285 the registration state. To request ids without synchronization,
291 use semi-private _ids attributes.
286 use semi-private _ids attributes.
292
287
293 history : list of msg_ids
288 history : list of msg_ids
294 a list of msg_ids, keeping track of all the execution
289 a list of msg_ids, keeping track of all the execution
295 messages you have submitted in order.
290 messages you have submitted in order.
296
291
297 outstanding : set of msg_ids
292 outstanding : set of msg_ids
298 a set of msg_ids that have been submitted, but whose
293 a set of msg_ids that have been submitted, but whose
299 results have not yet been received.
294 results have not yet been received.
300
295
301 results : dict
296 results : dict
302 a dict of all our results, keyed by msg_id
297 a dict of all our results, keyed by msg_id
303
298
304 block : bool
299 block : bool
305 determines default behavior when block not specified
300 determines default behavior when block not specified
306 in execution methods
301 in execution methods
307
302
308 Methods
303 Methods
309 -------
304 -------
310
305
311 spin
306 spin
312 flushes incoming results and registration state changes
307 flushes incoming results and registration state changes
313 control methods spin, and requesting `ids` also ensures up to date
308 control methods spin, and requesting `ids` also ensures up to date
314
309
315 wait
310 wait
316 wait on one or more msg_ids
311 wait on one or more msg_ids
317
312
318 execution methods
313 execution methods
319 apply
314 apply
320 legacy: execute, run
315 legacy: execute, run
321
316
322 data movement
317 data movement
323 push, pull, scatter, gather
318 push, pull, scatter, gather
324
319
325 query methods
320 query methods
326 queue_status, get_result, purge, result_status
321 queue_status, get_result, purge, result_status
327
322
328 control methods
323 control methods
329 abort, shutdown
324 abort, shutdown
330
325
331 """
326 """
332
327
333
328
334 block = Bool(False)
329 block = Bool(False)
335 outstanding = Set()
330 outstanding = Set()
336 results = Instance('collections.defaultdict', (dict,))
331 results = Instance('collections.defaultdict', (dict,))
337 metadata = Instance('collections.defaultdict', (Metadata,))
332 metadata = Instance('collections.defaultdict', (Metadata,))
338 history = List()
333 history = List()
339 debug = Bool(False)
334 debug = Bool(False)
340 _spin_thread = Any()
335 _spin_thread = Any()
341 _stop_spinning = Any()
336 _stop_spinning = Any()
342
337
343 profile=Unicode()
338 profile=Unicode()
344 def _profile_default(self):
339 def _profile_default(self):
345 if BaseIPythonApplication.initialized():
340 if BaseIPythonApplication.initialized():
346 # an IPython app *might* be running, try to get its profile
341 # an IPython app *might* be running, try to get its profile
347 try:
342 try:
348 return BaseIPythonApplication.instance().profile
343 return BaseIPythonApplication.instance().profile
349 except (AttributeError, MultipleInstanceError):
344 except (AttributeError, MultipleInstanceError):
350 # could be a *different* subclass of config.Application,
345 # could be a *different* subclass of config.Application,
351 # which would raise one of these two errors.
346 # which would raise one of these two errors.
352 return u'default'
347 return u'default'
353 else:
348 else:
354 return u'default'
349 return u'default'
355
350
356
351
357 _outstanding_dict = Instance('collections.defaultdict', (set,))
352 _outstanding_dict = Instance('collections.defaultdict', (set,))
358 _ids = List()
353 _ids = List()
359 _connected=Bool(False)
354 _connected=Bool(False)
360 _ssh=Bool(False)
355 _ssh=Bool(False)
361 _context = Instance('zmq.Context')
356 _context = Instance('zmq.Context')
362 _config = Dict()
357 _config = Dict()
363 _engines=Instance(util.ReverseDict, (), {})
358 _engines=Instance(util.ReverseDict, (), {})
364 # _hub_socket=Instance('zmq.Socket')
359 # _hub_socket=Instance('zmq.Socket')
365 _query_socket=Instance('zmq.Socket')
360 _query_socket=Instance('zmq.Socket')
366 _control_socket=Instance('zmq.Socket')
361 _control_socket=Instance('zmq.Socket')
367 _iopub_socket=Instance('zmq.Socket')
362 _iopub_socket=Instance('zmq.Socket')
368 _notification_socket=Instance('zmq.Socket')
363 _notification_socket=Instance('zmq.Socket')
369 _mux_socket=Instance('zmq.Socket')
364 _mux_socket=Instance('zmq.Socket')
370 _task_socket=Instance('zmq.Socket')
365 _task_socket=Instance('zmq.Socket')
371 _task_scheme=Unicode()
366 _task_scheme=Unicode()
372 _closed = False
367 _closed = False
373 _ignored_control_replies=Integer(0)
368 _ignored_control_replies=Integer(0)
374 _ignored_hub_replies=Integer(0)
369 _ignored_hub_replies=Integer(0)
375
370
376 def __new__(self, *args, **kw):
371 def __new__(self, *args, **kw):
377 # don't raise on positional args
372 # don't raise on positional args
378 return HasTraits.__new__(self, **kw)
373 return HasTraits.__new__(self, **kw)
379
374
380 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
381 context=None, debug=False,
376 context=None, debug=False,
382 sshserver=None, sshkey=None, password=None, paramiko=None,
377 sshserver=None, sshkey=None, password=None, paramiko=None,
383 timeout=10, cluster_id=None, **extra_args
378 timeout=10, cluster_id=None, **extra_args
384 ):
379 ):
385 if profile:
380 if profile:
386 super(Client, self).__init__(debug=debug, profile=profile)
381 super(Client, self).__init__(debug=debug, profile=profile)
387 else:
382 else:
388 super(Client, self).__init__(debug=debug)
383 super(Client, self).__init__(debug=debug)
389 if context is None:
384 if context is None:
390 context = zmq.Context.instance()
385 context = zmq.Context.instance()
391 self._context = context
386 self._context = context
392 self._stop_spinning = Event()
387 self._stop_spinning = Event()
393
388
394 if 'url_or_file' in extra_args:
389 if 'url_or_file' in extra_args:
395 url_file = extra_args['url_or_file']
390 url_file = extra_args['url_or_file']
396 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
397
392
398 if url_file and util.is_url(url_file):
393 if url_file and util.is_url(url_file):
399 raise ValueError("single urls cannot be specified, url-files must be used.")
394 raise ValueError("single urls cannot be specified, url-files must be used.")
400
395
401 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
402
397
403 if self._cd is not None:
398 if self._cd is not None:
404 if url_file is None:
399 if url_file is None:
405 if not cluster_id:
400 if not cluster_id:
406 client_json = 'ipcontroller-client.json'
401 client_json = 'ipcontroller-client.json'
407 else:
402 else:
408 client_json = 'ipcontroller-%s-client.json' % cluster_id
403 client_json = 'ipcontroller-%s-client.json' % cluster_id
409 url_file = pjoin(self._cd.security_dir, client_json)
404 url_file = pjoin(self._cd.security_dir, client_json)
410 if url_file is None:
405 if url_file is None:
411 raise ValueError(
406 raise ValueError(
412 "I can't find enough information to connect to a hub!"
407 "I can't find enough information to connect to a hub!"
413 " Please specify at least one of url_file or profile."
408 " Please specify at least one of url_file or profile."
414 )
409 )
415
410
416 with open(url_file) as f:
411 with open(url_file) as f:
417 cfg = json.load(f)
412 cfg = json.load(f)
418
413
419 self._task_scheme = cfg['task_scheme']
414 self._task_scheme = cfg['task_scheme']
420
415
421 # sync defaults from args, json:
416 # sync defaults from args, json:
422 if sshserver:
417 if sshserver:
423 cfg['ssh'] = sshserver
418 cfg['ssh'] = sshserver
424
419
425 location = cfg.setdefault('location', None)
420 location = cfg.setdefault('location', None)
426
421
427 proto,addr = cfg['interface'].split('://')
422 proto,addr = cfg['interface'].split('://')
428 addr = util.disambiguate_ip_address(addr, location)
423 addr = util.disambiguate_ip_address(addr, location)
429 cfg['interface'] = "%s://%s" % (proto, addr)
424 cfg['interface'] = "%s://%s" % (proto, addr)
430
425
431 # turn interface,port into full urls:
426 # turn interface,port into full urls:
432 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
433 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
434
429
435 url = cfg['registration']
430 url = cfg['registration']
436
431
437 if location is not None and addr == localhost():
432 if location is not None and addr == localhost():
438 # location specified, and connection is expected to be local
433 # location specified, and connection is expected to be local
439 if not is_local_ip(location) and not sshserver:
434 if not is_local_ip(location) and not sshserver:
440 # load ssh from JSON *only* if the controller is not on
435 # load ssh from JSON *only* if the controller is not on
441 # this machine
436 # this machine
442 sshserver=cfg['ssh']
437 sshserver=cfg['ssh']
443 if not is_local_ip(location) and not sshserver:
438 if not is_local_ip(location) and not sshserver:
444 # warn if no ssh specified, but SSH is probably needed
439 # warn if no ssh specified, but SSH is probably needed
445 # This is only a warning, because the most likely cause
440 # This is only a warning, because the most likely cause
446 # is a local Controller on a laptop whose IP is dynamic
441 # is a local Controller on a laptop whose IP is dynamic
447 warnings.warn("""
442 warnings.warn("""
448 Controller appears to be listening on localhost, but not on this machine.
443 Controller appears to be listening on localhost, but not on this machine.
449 If this is true, you should specify Client(...,sshserver='you@%s')
444 If this is true, you should specify Client(...,sshserver='you@%s')
450 or instruct your controller to listen on an external IP."""%location,
445 or instruct your controller to listen on an external IP."""%location,
451 RuntimeWarning)
446 RuntimeWarning)
452 elif not sshserver:
447 elif not sshserver:
453 # otherwise sync with cfg
448 # otherwise sync with cfg
454 sshserver = cfg['ssh']
449 sshserver = cfg['ssh']
455
450
456 self._config = cfg
451 self._config = cfg
457
452
458 self._ssh = bool(sshserver or sshkey or password)
453 self._ssh = bool(sshserver or sshkey or password)
459 if self._ssh and sshserver is None:
454 if self._ssh and sshserver is None:
460 # default to ssh via localhost
455 # default to ssh via localhost
461 sshserver = addr
456 sshserver = addr
462 if self._ssh and password is None:
457 if self._ssh and password is None:
463 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
464 password=False
459 password=False
465 else:
460 else:
466 password = getpass("SSH Password for %s: "%sshserver)
461 password = getpass("SSH Password for %s: "%sshserver)
467 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
468
463
469 # configure and construct the session
464 # configure and construct the session
470 try:
465 try:
471 extra_args['packer'] = cfg['pack']
466 extra_args['packer'] = cfg['pack']
472 extra_args['unpacker'] = cfg['unpack']
467 extra_args['unpacker'] = cfg['unpack']
473 extra_args['key'] = cast_bytes(cfg['key'])
468 extra_args['key'] = cast_bytes(cfg['key'])
474 extra_args['signature_scheme'] = cfg['signature_scheme']
469 extra_args['signature_scheme'] = cfg['signature_scheme']
475 except KeyError as exc:
470 except KeyError as exc:
476 msg = '\n'.join([
471 msg = '\n'.join([
477 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
478 "If you are reusing connection files, remove them and start ipcontroller again."
473 "If you are reusing connection files, remove them and start ipcontroller again."
479 ])
474 ])
480 raise ValueError(msg.format(exc.message))
475 raise ValueError(msg.format(exc.message))
481
476
482 self.session = Session(**extra_args)
477 self.session = Session(**extra_args)
483
478
484 self._query_socket = self._context.socket(zmq.DEALER)
479 self._query_socket = self._context.socket(zmq.DEALER)
485
480
486 if self._ssh:
481 if self._ssh:
487 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
488 else:
483 else:
489 self._query_socket.connect(cfg['registration'])
484 self._query_socket.connect(cfg['registration'])
490
485
491 self.session.debug = self.debug
486 self.session.debug = self.debug
492
487
493 self._notification_handlers = {'registration_notification' : self._register_engine,
488 self._notification_handlers = {'registration_notification' : self._register_engine,
494 'unregistration_notification' : self._unregister_engine,
489 'unregistration_notification' : self._unregister_engine,
495 'shutdown_notification' : lambda msg: self.close(),
490 'shutdown_notification' : lambda msg: self.close(),
496 }
491 }
497 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
498 'apply_reply' : self._handle_apply_reply}
493 'apply_reply' : self._handle_apply_reply}
499
494
500 try:
495 try:
501 self._connect(sshserver, ssh_kwargs, timeout)
496 self._connect(sshserver, ssh_kwargs, timeout)
502 except:
497 except:
503 self.close(linger=0)
498 self.close(linger=0)
504 raise
499 raise
505
500
506 # last step: setup magics, if we are in IPython:
501 # last step: setup magics, if we are in IPython:
507
502
508 try:
503 try:
509 ip = get_ipython()
504 ip = get_ipython()
510 except NameError:
505 except NameError:
511 return
506 return
512 else:
507 else:
513 if 'px' not in ip.magics_manager.magics:
508 if 'px' not in ip.magics_manager.magics:
514 # in IPython but we are the first Client.
509 # in IPython but we are the first Client.
515 # activate a default view for parallel magics.
510 # activate a default view for parallel magics.
516 self.activate()
511 self.activate()
517
512
518 def __del__(self):
513 def __del__(self):
519 """cleanup sockets, but _not_ context."""
514 """cleanup sockets, but _not_ context."""
520 self.close()
515 self.close()
521
516
522 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
523 if ipython_dir is None:
518 if ipython_dir is None:
524 ipython_dir = get_ipython_dir()
519 ipython_dir = get_ipython_dir()
525 if profile_dir is not None:
520 if profile_dir is not None:
526 try:
521 try:
527 self._cd = ProfileDir.find_profile_dir(profile_dir)
522 self._cd = ProfileDir.find_profile_dir(profile_dir)
528 return
523 return
529 except ProfileDirError:
524 except ProfileDirError:
530 pass
525 pass
531 elif profile is not None:
526 elif profile is not None:
532 try:
527 try:
533 self._cd = ProfileDir.find_profile_dir_by_name(
528 self._cd = ProfileDir.find_profile_dir_by_name(
534 ipython_dir, profile)
529 ipython_dir, profile)
535 return
530 return
536 except ProfileDirError:
531 except ProfileDirError:
537 pass
532 pass
538 self._cd = None
533 self._cd = None
539
534
540 def _update_engines(self, engines):
535 def _update_engines(self, engines):
541 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
542 for k,v in engines.iteritems():
537 for k,v in engines.iteritems():
543 eid = int(k)
538 eid = int(k)
544 if eid not in self._engines:
539 if eid not in self._engines:
545 self._ids.append(eid)
540 self._ids.append(eid)
546 self._engines[eid] = v
541 self._engines[eid] = v
547 self._ids = sorted(self._ids)
542 self._ids = sorted(self._ids)
548 if sorted(self._engines.keys()) != range(len(self._engines)) and \
543 if sorted(self._engines.keys()) != range(len(self._engines)) and \
549 self._task_scheme == 'pure' and self._task_socket:
544 self._task_scheme == 'pure' and self._task_socket:
550 self._stop_scheduling_tasks()
545 self._stop_scheduling_tasks()
551
546
552 def _stop_scheduling_tasks(self):
547 def _stop_scheduling_tasks(self):
553 """Stop scheduling tasks because an engine has been unregistered
548 """Stop scheduling tasks because an engine has been unregistered
554 from a pure ZMQ scheduler.
549 from a pure ZMQ scheduler.
555 """
550 """
556 self._task_socket.close()
551 self._task_socket.close()
557 self._task_socket = None
552 self._task_socket = None
558 msg = "An engine has been unregistered, and we are using pure " +\
553 msg = "An engine has been unregistered, and we are using pure " +\
559 "ZMQ task scheduling. Task farming will be disabled."
554 "ZMQ task scheduling. Task farming will be disabled."
560 if self.outstanding:
555 if self.outstanding:
561 msg += " If you were running tasks when this happened, " +\
556 msg += " If you were running tasks when this happened, " +\
562 "some `outstanding` msg_ids may never resolve."
557 "some `outstanding` msg_ids may never resolve."
563 warnings.warn(msg, RuntimeWarning)
558 warnings.warn(msg, RuntimeWarning)
564
559
565 def _build_targets(self, targets):
560 def _build_targets(self, targets):
566 """Turn valid target IDs or 'all' into two lists:
561 """Turn valid target IDs or 'all' into two lists:
567 (int_ids, uuids).
562 (int_ids, uuids).
568 """
563 """
569 if not self._ids:
564 if not self._ids:
570 # flush notification socket if no engines yet, just in case
565 # flush notification socket if no engines yet, just in case
571 if not self.ids:
566 if not self.ids:
572 raise error.NoEnginesRegistered("Can't build targets without any engines")
567 raise error.NoEnginesRegistered("Can't build targets without any engines")
573
568
574 if targets is None:
569 if targets is None:
575 targets = self._ids
570 targets = self._ids
576 elif isinstance(targets, string_types):
571 elif isinstance(targets, string_types):
577 if targets.lower() == 'all':
572 if targets.lower() == 'all':
578 targets = self._ids
573 targets = self._ids
579 else:
574 else:
580 raise TypeError("%r not valid str target, must be 'all'"%(targets))
575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
581 elif isinstance(targets, int):
576 elif isinstance(targets, int):
582 if targets < 0:
577 if targets < 0:
583 targets = self.ids[targets]
578 targets = self.ids[targets]
584 if targets not in self._ids:
579 if targets not in self._ids:
585 raise IndexError("No such engine: %i"%targets)
580 raise IndexError("No such engine: %i"%targets)
586 targets = [targets]
581 targets = [targets]
587
582
588 if isinstance(targets, slice):
583 if isinstance(targets, slice):
589 indices = range(len(self._ids))[targets]
584 indices = range(len(self._ids))[targets]
590 ids = self.ids
585 ids = self.ids
591 targets = [ ids[i] for i in indices ]
586 targets = [ ids[i] for i in indices ]
592
587
593 if not isinstance(targets, (tuple, list, xrange)):
588 if not isinstance(targets, (tuple, list, xrange)):
594 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
595
590
596 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
597
592
598 def _connect(self, sshserver, ssh_kwargs, timeout):
593 def _connect(self, sshserver, ssh_kwargs, timeout):
599 """setup all our socket connections to the cluster. This is called from
594 """setup all our socket connections to the cluster. This is called from
600 __init__."""
595 __init__."""
601
596
602 # Maybe allow reconnecting?
597 # Maybe allow reconnecting?
603 if self._connected:
598 if self._connected:
604 return
599 return
605 self._connected=True
600 self._connected=True
606
601
607 def connect_socket(s, url):
602 def connect_socket(s, url):
608 if self._ssh:
603 if self._ssh:
609 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
610 else:
605 else:
611 return s.connect(url)
606 return s.connect(url)
612
607
613 self.session.send(self._query_socket, 'connection_request')
608 self.session.send(self._query_socket, 'connection_request')
614 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
615 poller = zmq.Poller()
610 poller = zmq.Poller()
616 poller.register(self._query_socket, zmq.POLLIN)
611 poller.register(self._query_socket, zmq.POLLIN)
617 # poll expects milliseconds, timeout is seconds
612 # poll expects milliseconds, timeout is seconds
618 evts = poller.poll(timeout*1000)
613 evts = poller.poll(timeout*1000)
619 if not evts:
614 if not evts:
620 raise error.TimeoutError("Hub connection request timed out")
615 raise error.TimeoutError("Hub connection request timed out")
621 idents,msg = self.session.recv(self._query_socket,mode=0)
616 idents,msg = self.session.recv(self._query_socket,mode=0)
622 if self.debug:
617 if self.debug:
623 pprint(msg)
618 pprint(msg)
624 content = msg['content']
619 content = msg['content']
625 # self._config['registration'] = dict(content)
620 # self._config['registration'] = dict(content)
626 cfg = self._config
621 cfg = self._config
627 if content['status'] == 'ok':
622 if content['status'] == 'ok':
628 self._mux_socket = self._context.socket(zmq.DEALER)
623 self._mux_socket = self._context.socket(zmq.DEALER)
629 connect_socket(self._mux_socket, cfg['mux'])
624 connect_socket(self._mux_socket, cfg['mux'])
630
625
631 self._task_socket = self._context.socket(zmq.DEALER)
626 self._task_socket = self._context.socket(zmq.DEALER)
632 connect_socket(self._task_socket, cfg['task'])
627 connect_socket(self._task_socket, cfg['task'])
633
628
634 self._notification_socket = self._context.socket(zmq.SUB)
629 self._notification_socket = self._context.socket(zmq.SUB)
635 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
636 connect_socket(self._notification_socket, cfg['notification'])
631 connect_socket(self._notification_socket, cfg['notification'])
637
632
638 self._control_socket = self._context.socket(zmq.DEALER)
633 self._control_socket = self._context.socket(zmq.DEALER)
639 connect_socket(self._control_socket, cfg['control'])
634 connect_socket(self._control_socket, cfg['control'])
640
635
641 self._iopub_socket = self._context.socket(zmq.SUB)
636 self._iopub_socket = self._context.socket(zmq.SUB)
642 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
643 connect_socket(self._iopub_socket, cfg['iopub'])
638 connect_socket(self._iopub_socket, cfg['iopub'])
644
639
645 self._update_engines(dict(content['engines']))
640 self._update_engines(dict(content['engines']))
646 else:
641 else:
647 self._connected = False
642 self._connected = False
648 raise Exception("Failed to connect!")
643 raise Exception("Failed to connect!")
649
644
650 #--------------------------------------------------------------------------
645 #--------------------------------------------------------------------------
651 # handlers and callbacks for incoming messages
646 # handlers and callbacks for incoming messages
652 #--------------------------------------------------------------------------
647 #--------------------------------------------------------------------------
653
648
654 def _unwrap_exception(self, content):
649 def _unwrap_exception(self, content):
655 """unwrap exception, and remap engine_id to int."""
650 """unwrap exception, and remap engine_id to int."""
656 e = error.unwrap_exception(content)
651 e = error.unwrap_exception(content)
657 # print e.traceback
652 # print e.traceback
658 if e.engine_info:
653 if e.engine_info:
659 e_uuid = e.engine_info['engine_uuid']
654 e_uuid = e.engine_info['engine_uuid']
660 eid = self._engines[e_uuid]
655 eid = self._engines[e_uuid]
661 e.engine_info['engine_id'] = eid
656 e.engine_info['engine_id'] = eid
662 return e
657 return e
663
658
664 def _extract_metadata(self, msg):
659 def _extract_metadata(self, msg):
665 header = msg['header']
660 header = msg['header']
666 parent = msg['parent_header']
661 parent = msg['parent_header']
667 msg_meta = msg['metadata']
662 msg_meta = msg['metadata']
668 content = msg['content']
663 content = msg['content']
669 md = {'msg_id' : parent['msg_id'],
664 md = {'msg_id' : parent['msg_id'],
670 'received' : datetime.now(),
665 'received' : datetime.now(),
671 'engine_uuid' : msg_meta.get('engine', None),
666 'engine_uuid' : msg_meta.get('engine', None),
672 'follow' : msg_meta.get('follow', []),
667 'follow' : msg_meta.get('follow', []),
673 'after' : msg_meta.get('after', []),
668 'after' : msg_meta.get('after', []),
674 'status' : content['status'],
669 'status' : content['status'],
675 }
670 }
676
671
677 if md['engine_uuid'] is not None:
672 if md['engine_uuid'] is not None:
678 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
679
674
680 if 'date' in parent:
675 if 'date' in parent:
681 md['submitted'] = parent['date']
676 md['submitted'] = parent['date']
682 if 'started' in msg_meta:
677 if 'started' in msg_meta:
683 md['started'] = msg_meta['started']
678 md['started'] = msg_meta['started']
684 if 'date' in header:
679 if 'date' in header:
685 md['completed'] = header['date']
680 md['completed'] = header['date']
686 return md
681 return md
687
682
688 def _register_engine(self, msg):
683 def _register_engine(self, msg):
689 """Register a new engine, and update our connection info."""
684 """Register a new engine, and update our connection info."""
690 content = msg['content']
685 content = msg['content']
691 eid = content['id']
686 eid = content['id']
692 d = {eid : content['uuid']}
687 d = {eid : content['uuid']}
693 self._update_engines(d)
688 self._update_engines(d)
694
689
695 def _unregister_engine(self, msg):
690 def _unregister_engine(self, msg):
696 """Unregister an engine that has died."""
691 """Unregister an engine that has died."""
697 content = msg['content']
692 content = msg['content']
698 eid = int(content['id'])
693 eid = int(content['id'])
699 if eid in self._ids:
694 if eid in self._ids:
700 self._ids.remove(eid)
695 self._ids.remove(eid)
701 uuid = self._engines.pop(eid)
696 uuid = self._engines.pop(eid)
702
697
703 self._handle_stranded_msgs(eid, uuid)
698 self._handle_stranded_msgs(eid, uuid)
704
699
705 if self._task_socket and self._task_scheme == 'pure':
700 if self._task_socket and self._task_scheme == 'pure':
706 self._stop_scheduling_tasks()
701 self._stop_scheduling_tasks()
707
702
708 def _handle_stranded_msgs(self, eid, uuid):
703 def _handle_stranded_msgs(self, eid, uuid):
709 """Handle messages known to be on an engine when the engine unregisters.
704 """Handle messages known to be on an engine when the engine unregisters.
710
705
711 It is possible that this will fire prematurely - that is, an engine will
706 It is possible that this will fire prematurely - that is, an engine will
712 go down after completing a result, and the client will be notified
707 go down after completing a result, and the client will be notified
713 of the unregistration and later receive the successful result.
708 of the unregistration and later receive the successful result.
714 """
709 """
715
710
716 outstanding = self._outstanding_dict[uuid]
711 outstanding = self._outstanding_dict[uuid]
717
712
718 for msg_id in list(outstanding):
713 for msg_id in list(outstanding):
719 if msg_id in self.results:
714 if msg_id in self.results:
720 # we already
715 # we already
721 continue
716 continue
722 try:
717 try:
723 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
724 except:
719 except:
725 content = error.wrap_exception()
720 content = error.wrap_exception()
726 # build a fake message:
721 # build a fake message:
727 msg = self.session.msg('apply_reply', content=content)
722 msg = self.session.msg('apply_reply', content=content)
728 msg['parent_header']['msg_id'] = msg_id
723 msg['parent_header']['msg_id'] = msg_id
729 msg['metadata']['engine'] = uuid
724 msg['metadata']['engine'] = uuid
730 self._handle_apply_reply(msg)
725 self._handle_apply_reply(msg)
731
726
732 def _handle_execute_reply(self, msg):
727 def _handle_execute_reply(self, msg):
733 """Save the reply to an execute_request into our results.
728 """Save the reply to an execute_request into our results.
734
729
735 execute messages are never actually used. apply is used instead.
730 execute messages are never actually used. apply is used instead.
736 """
731 """
737
732
738 parent = msg['parent_header']
733 parent = msg['parent_header']
739 msg_id = parent['msg_id']
734 msg_id = parent['msg_id']
740 if msg_id not in self.outstanding:
735 if msg_id not in self.outstanding:
741 if msg_id in self.history:
736 if msg_id in self.history:
742 print(("got stale result: %s"%msg_id))
737 print(("got stale result: %s"%msg_id))
743 else:
738 else:
744 print(("got unknown result: %s"%msg_id))
739 print(("got unknown result: %s"%msg_id))
745 else:
740 else:
746 self.outstanding.remove(msg_id)
741 self.outstanding.remove(msg_id)
747
742
748 content = msg['content']
743 content = msg['content']
749 header = msg['header']
744 header = msg['header']
750
745
751 # construct metadata:
746 # construct metadata:
752 md = self.metadata[msg_id]
747 md = self.metadata[msg_id]
753 md.update(self._extract_metadata(msg))
748 md.update(self._extract_metadata(msg))
754 # is this redundant?
749 # is this redundant?
755 self.metadata[msg_id] = md
750 self.metadata[msg_id] = md
756
751
757 e_outstanding = self._outstanding_dict[md['engine_uuid']]
752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
758 if msg_id in e_outstanding:
753 if msg_id in e_outstanding:
759 e_outstanding.remove(msg_id)
754 e_outstanding.remove(msg_id)
760
755
761 # construct result:
756 # construct result:
762 if content['status'] == 'ok':
757 if content['status'] == 'ok':
763 self.results[msg_id] = ExecuteReply(msg_id, content, md)
758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
764 elif content['status'] == 'aborted':
759 elif content['status'] == 'aborted':
765 self.results[msg_id] = error.TaskAborted(msg_id)
760 self.results[msg_id] = error.TaskAborted(msg_id)
766 elif content['status'] == 'resubmitted':
761 elif content['status'] == 'resubmitted':
767 # TODO: handle resubmission
762 # TODO: handle resubmission
768 pass
763 pass
769 else:
764 else:
770 self.results[msg_id] = self._unwrap_exception(content)
765 self.results[msg_id] = self._unwrap_exception(content)
771
766
772 def _handle_apply_reply(self, msg):
767 def _handle_apply_reply(self, msg):
773 """Save the reply to an apply_request into our results."""
768 """Save the reply to an apply_request into our results."""
774 parent = msg['parent_header']
769 parent = msg['parent_header']
775 msg_id = parent['msg_id']
770 msg_id = parent['msg_id']
776 if msg_id not in self.outstanding:
771 if msg_id not in self.outstanding:
777 if msg_id in self.history:
772 if msg_id in self.history:
778 print(("got stale result: %s"%msg_id))
773 print(("got stale result: %s"%msg_id))
779 print(self.results[msg_id])
774 print(self.results[msg_id])
780 print(msg)
775 print(msg)
781 else:
776 else:
782 print(("got unknown result: %s"%msg_id))
777 print(("got unknown result: %s"%msg_id))
783 else:
778 else:
784 self.outstanding.remove(msg_id)
779 self.outstanding.remove(msg_id)
785 content = msg['content']
780 content = msg['content']
786 header = msg['header']
781 header = msg['header']
787
782
788 # construct metadata:
783 # construct metadata:
789 md = self.metadata[msg_id]
784 md = self.metadata[msg_id]
790 md.update(self._extract_metadata(msg))
785 md.update(self._extract_metadata(msg))
791 # is this redundant?
786 # is this redundant?
792 self.metadata[msg_id] = md
787 self.metadata[msg_id] = md
793
788
794 e_outstanding = self._outstanding_dict[md['engine_uuid']]
789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
795 if msg_id in e_outstanding:
790 if msg_id in e_outstanding:
796 e_outstanding.remove(msg_id)
791 e_outstanding.remove(msg_id)
797
792
798 # construct result:
793 # construct result:
799 if content['status'] == 'ok':
794 if content['status'] == 'ok':
800 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
801 elif content['status'] == 'aborted':
796 elif content['status'] == 'aborted':
802 self.results[msg_id] = error.TaskAborted(msg_id)
797 self.results[msg_id] = error.TaskAborted(msg_id)
803 elif content['status'] == 'resubmitted':
798 elif content['status'] == 'resubmitted':
804 # TODO: handle resubmission
799 # TODO: handle resubmission
805 pass
800 pass
806 else:
801 else:
807 self.results[msg_id] = self._unwrap_exception(content)
802 self.results[msg_id] = self._unwrap_exception(content)
808
803
809 def _flush_notifications(self):
804 def _flush_notifications(self):
810 """Flush notifications of engine registrations waiting
805 """Flush notifications of engine registrations waiting
811 in ZMQ queue."""
806 in ZMQ queue."""
812 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
813 while msg is not None:
808 while msg is not None:
814 if self.debug:
809 if self.debug:
815 pprint(msg)
810 pprint(msg)
816 msg_type = msg['header']['msg_type']
811 msg_type = msg['header']['msg_type']
817 handler = self._notification_handlers.get(msg_type, None)
812 handler = self._notification_handlers.get(msg_type, None)
818 if handler is None:
813 if handler is None:
819 raise Exception("Unhandled message type: %s" % msg_type)
814 raise Exception("Unhandled message type: %s" % msg_type)
820 else:
815 else:
821 handler(msg)
816 handler(msg)
822 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
823
818
824 def _flush_results(self, sock):
819 def _flush_results(self, sock):
825 """Flush task or queue results waiting in ZMQ queue."""
820 """Flush task or queue results waiting in ZMQ queue."""
826 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
827 while msg is not None:
822 while msg is not None:
828 if self.debug:
823 if self.debug:
829 pprint(msg)
824 pprint(msg)
830 msg_type = msg['header']['msg_type']
825 msg_type = msg['header']['msg_type']
831 handler = self._queue_handlers.get(msg_type, None)
826 handler = self._queue_handlers.get(msg_type, None)
832 if handler is None:
827 if handler is None:
833 raise Exception("Unhandled message type: %s" % msg_type)
828 raise Exception("Unhandled message type: %s" % msg_type)
834 else:
829 else:
835 handler(msg)
830 handler(msg)
836 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
837
832
838 def _flush_control(self, sock):
833 def _flush_control(self, sock):
839 """Flush replies from the control channel waiting
834 """Flush replies from the control channel waiting
840 in the ZMQ queue.
835 in the ZMQ queue.
841
836
842 Currently: ignore them."""
837 Currently: ignore them."""
843 if self._ignored_control_replies <= 0:
838 if self._ignored_control_replies <= 0:
844 return
839 return
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846 while msg is not None:
841 while msg is not None:
847 self._ignored_control_replies -= 1
842 self._ignored_control_replies -= 1
848 if self.debug:
843 if self.debug:
849 pprint(msg)
844 pprint(msg)
850 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
851
846
852 def _flush_ignored_control(self):
847 def _flush_ignored_control(self):
853 """flush ignored control replies"""
848 """flush ignored control replies"""
854 while self._ignored_control_replies > 0:
849 while self._ignored_control_replies > 0:
855 self.session.recv(self._control_socket)
850 self.session.recv(self._control_socket)
856 self._ignored_control_replies -= 1
851 self._ignored_control_replies -= 1
857
852
858 def _flush_ignored_hub_replies(self):
853 def _flush_ignored_hub_replies(self):
859 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
860 while msg is not None:
855 while msg is not None:
861 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
862
857
863 def _flush_iopub(self, sock):
858 def _flush_iopub(self, sock):
864 """Flush replies from the iopub channel waiting
859 """Flush replies from the iopub channel waiting
865 in the ZMQ queue.
860 in the ZMQ queue.
866 """
861 """
867 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
868 while msg is not None:
863 while msg is not None:
869 if self.debug:
864 if self.debug:
870 pprint(msg)
865 pprint(msg)
871 parent = msg['parent_header']
866 parent = msg['parent_header']
872 # ignore IOPub messages with no parent.
867 # ignore IOPub messages with no parent.
873 # Caused by print statements or warnings from before the first execution.
868 # Caused by print statements or warnings from before the first execution.
874 if not parent:
869 if not parent:
875 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
876 continue
871 continue
877 msg_id = parent['msg_id']
872 msg_id = parent['msg_id']
878 content = msg['content']
873 content = msg['content']
879 header = msg['header']
874 header = msg['header']
880 msg_type = msg['header']['msg_type']
875 msg_type = msg['header']['msg_type']
881
876
882 # init metadata:
877 # init metadata:
883 md = self.metadata[msg_id]
878 md = self.metadata[msg_id]
884
879
885 if msg_type == 'stream':
880 if msg_type == 'stream':
886 name = content['name']
881 name = content['name']
887 s = md[name] or ''
882 s = md[name] or ''
888 md[name] = s + content['data']
883 md[name] = s + content['data']
889 elif msg_type == 'pyerr':
884 elif msg_type == 'pyerr':
890 md.update({'pyerr' : self._unwrap_exception(content)})
885 md.update({'pyerr' : self._unwrap_exception(content)})
891 elif msg_type == 'pyin':
886 elif msg_type == 'pyin':
892 md.update({'pyin' : content['code']})
887 md.update({'pyin' : content['code']})
893 elif msg_type == 'display_data':
888 elif msg_type == 'display_data':
894 md['outputs'].append(content)
889 md['outputs'].append(content)
895 elif msg_type == 'pyout':
890 elif msg_type == 'pyout':
896 md['pyout'] = content
891 md['pyout'] = content
897 elif msg_type == 'data_message':
892 elif msg_type == 'data_message':
898 data, remainder = serialize.unserialize_object(msg['buffers'])
893 data, remainder = serialize.unserialize_object(msg['buffers'])
899 md['data'].update(data)
894 md['data'].update(data)
900 elif msg_type == 'status':
895 elif msg_type == 'status':
901 # idle message comes after all outputs
896 # idle message comes after all outputs
902 if content['execution_state'] == 'idle':
897 if content['execution_state'] == 'idle':
903 md['outputs_ready'] = True
898 md['outputs_ready'] = True
904 else:
899 else:
905 # unhandled msg_type (status, etc.)
900 # unhandled msg_type (status, etc.)
906 pass
901 pass
907
902
908 # reduntant?
903 # reduntant?
909 self.metadata[msg_id] = md
904 self.metadata[msg_id] = md
910
905
911 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
912
907
913 #--------------------------------------------------------------------------
908 #--------------------------------------------------------------------------
914 # len, getitem
909 # len, getitem
915 #--------------------------------------------------------------------------
910 #--------------------------------------------------------------------------
916
911
917 def __len__(self):
912 def __len__(self):
918 """len(client) returns # of engines."""
913 """len(client) returns # of engines."""
919 return len(self.ids)
914 return len(self.ids)
920
915
921 def __getitem__(self, key):
916 def __getitem__(self, key):
922 """index access returns DirectView multiplexer objects
917 """index access returns DirectView multiplexer objects
923
918
924 Must be int, slice, or list/tuple/xrange of ints"""
919 Must be int, slice, or list/tuple/xrange of ints"""
925 if not isinstance(key, (int, slice, tuple, list, xrange)):
920 if not isinstance(key, (int, slice, tuple, list, xrange)):
926 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
927 else:
922 else:
928 return self.direct_view(key)
923 return self.direct_view(key)
929
924
930 #--------------------------------------------------------------------------
925 #--------------------------------------------------------------------------
931 # Begin public methods
926 # Begin public methods
932 #--------------------------------------------------------------------------
927 #--------------------------------------------------------------------------
933
928
934 @property
929 @property
935 def ids(self):
930 def ids(self):
936 """Always up-to-date ids property."""
931 """Always up-to-date ids property."""
937 self._flush_notifications()
932 self._flush_notifications()
938 # always copy:
933 # always copy:
939 return list(self._ids)
934 return list(self._ids)
940
935
941 def activate(self, targets='all', suffix=''):
936 def activate(self, targets='all', suffix=''):
942 """Create a DirectView and register it with IPython magics
937 """Create a DirectView and register it with IPython magics
943
938
944 Defines the magics `%px, %autopx, %pxresult, %%px`
939 Defines the magics `%px, %autopx, %pxresult, %%px`
945
940
946 Parameters
941 Parameters
947 ----------
942 ----------
948
943
949 targets: int, list of ints, or 'all'
944 targets: int, list of ints, or 'all'
950 The engines on which the view's magics will run
945 The engines on which the view's magics will run
951 suffix: str [default: '']
946 suffix: str [default: '']
952 The suffix, if any, for the magics. This allows you to have
947 The suffix, if any, for the magics. This allows you to have
953 multiple views associated with parallel magics at the same time.
948 multiple views associated with parallel magics at the same time.
954
949
955 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
956 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
957 on engine 0.
952 on engine 0.
958 """
953 """
959 view = self.direct_view(targets)
954 view = self.direct_view(targets)
960 view.block = True
955 view.block = True
961 view.activate(suffix)
956 view.activate(suffix)
962 return view
957 return view
963
958
964 def close(self, linger=None):
959 def close(self, linger=None):
965 """Close my zmq Sockets
960 """Close my zmq Sockets
966
961
967 If `linger`, set the zmq LINGER socket option,
962 If `linger`, set the zmq LINGER socket option,
968 which allows discarding of messages.
963 which allows discarding of messages.
969 """
964 """
970 if self._closed:
965 if self._closed:
971 return
966 return
972 self.stop_spin_thread()
967 self.stop_spin_thread()
973 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
974 for name in snames:
969 for name in snames:
975 socket = getattr(self, name)
970 socket = getattr(self, name)
976 if socket is not None and not socket.closed:
971 if socket is not None and not socket.closed:
977 if linger is not None:
972 if linger is not None:
978 socket.close(linger=linger)
973 socket.close(linger=linger)
979 else:
974 else:
980 socket.close()
975 socket.close()
981 self._closed = True
976 self._closed = True
982
977
983 def _spin_every(self, interval=1):
978 def _spin_every(self, interval=1):
984 """target func for use in spin_thread"""
979 """target func for use in spin_thread"""
985 while True:
980 while True:
986 if self._stop_spinning.is_set():
981 if self._stop_spinning.is_set():
987 return
982 return
988 time.sleep(interval)
983 time.sleep(interval)
989 self.spin()
984 self.spin()
990
985
991 def spin_thread(self, interval=1):
986 def spin_thread(self, interval=1):
992 """call Client.spin() in a background thread on some regular interval
987 """call Client.spin() in a background thread on some regular interval
993
988
994 This helps ensure that messages don't pile up too much in the zmq queue
989 This helps ensure that messages don't pile up too much in the zmq queue
995 while you are working on other things, or just leaving an idle terminal.
990 while you are working on other things, or just leaving an idle terminal.
996
991
997 It also helps limit potential padding of the `received` timestamp
992 It also helps limit potential padding of the `received` timestamp
998 on AsyncResult objects, used for timings.
993 on AsyncResult objects, used for timings.
999
994
1000 Parameters
995 Parameters
1001 ----------
996 ----------
1002
997
1003 interval : float, optional
998 interval : float, optional
1004 The interval on which to spin the client in the background thread
999 The interval on which to spin the client in the background thread
1005 (simply passed to time.sleep).
1000 (simply passed to time.sleep).
1006
1001
1007 Notes
1002 Notes
1008 -----
1003 -----
1009
1004
1010 For precision timing, you may want to use this method to put a bound
1005 For precision timing, you may want to use this method to put a bound
1011 on the jitter (in seconds) in `received` timestamps used
1006 on the jitter (in seconds) in `received` timestamps used
1012 in AsyncResult.wall_time.
1007 in AsyncResult.wall_time.
1013
1008
1014 """
1009 """
1015 if self._spin_thread is not None:
1010 if self._spin_thread is not None:
1016 self.stop_spin_thread()
1011 self.stop_spin_thread()
1017 self._stop_spinning.clear()
1012 self._stop_spinning.clear()
1018 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1019 self._spin_thread.daemon = True
1014 self._spin_thread.daemon = True
1020 self._spin_thread.start()
1015 self._spin_thread.start()
1021
1016
1022 def stop_spin_thread(self):
1017 def stop_spin_thread(self):
1023 """stop background spin_thread, if any"""
1018 """stop background spin_thread, if any"""
1024 if self._spin_thread is not None:
1019 if self._spin_thread is not None:
1025 self._stop_spinning.set()
1020 self._stop_spinning.set()
1026 self._spin_thread.join()
1021 self._spin_thread.join()
1027 self._spin_thread = None
1022 self._spin_thread = None
1028
1023
1029 def spin(self):
1024 def spin(self):
1030 """Flush any registration notifications and execution results
1025 """Flush any registration notifications and execution results
1031 waiting in the ZMQ queue.
1026 waiting in the ZMQ queue.
1032 """
1027 """
1033 if self._notification_socket:
1028 if self._notification_socket:
1034 self._flush_notifications()
1029 self._flush_notifications()
1035 if self._iopub_socket:
1030 if self._iopub_socket:
1036 self._flush_iopub(self._iopub_socket)
1031 self._flush_iopub(self._iopub_socket)
1037 if self._mux_socket:
1032 if self._mux_socket:
1038 self._flush_results(self._mux_socket)
1033 self._flush_results(self._mux_socket)
1039 if self._task_socket:
1034 if self._task_socket:
1040 self._flush_results(self._task_socket)
1035 self._flush_results(self._task_socket)
1041 if self._control_socket:
1036 if self._control_socket:
1042 self._flush_control(self._control_socket)
1037 self._flush_control(self._control_socket)
1043 if self._query_socket:
1038 if self._query_socket:
1044 self._flush_ignored_hub_replies()
1039 self._flush_ignored_hub_replies()
1045
1040
1046 def wait(self, jobs=None, timeout=-1):
1041 def wait(self, jobs=None, timeout=-1):
1047 """waits on one or more `jobs`, for up to `timeout` seconds.
1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1048
1043
1049 Parameters
1044 Parameters
1050 ----------
1045 ----------
1051
1046
1052 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1053 ints are indices to self.history
1048 ints are indices to self.history
1054 strs are msg_ids
1049 strs are msg_ids
1055 default: wait on all outstanding messages
1050 default: wait on all outstanding messages
1056 timeout : float
1051 timeout : float
1057 a time in seconds, after which to give up.
1052 a time in seconds, after which to give up.
1058 default is -1, which means no timeout
1053 default is -1, which means no timeout
1059
1054
1060 Returns
1055 Returns
1061 -------
1056 -------
1062
1057
1063 True : when all msg_ids are done
1058 True : when all msg_ids are done
1064 False : timeout reached, some msg_ids still outstanding
1059 False : timeout reached, some msg_ids still outstanding
1065 """
1060 """
1066 tic = time.time()
1061 tic = time.time()
1067 if jobs is None:
1062 if jobs is None:
1068 theids = self.outstanding
1063 theids = self.outstanding
1069 else:
1064 else:
1070 if isinstance(jobs, string_types + (int, AsyncResult)):
1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1071 jobs = [jobs]
1066 jobs = [jobs]
1072 theids = set()
1067 theids = set()
1073 for job in jobs:
1068 for job in jobs:
1074 if isinstance(job, int):
1069 if isinstance(job, int):
1075 # index access
1070 # index access
1076 job = self.history[job]
1071 job = self.history[job]
1077 elif isinstance(job, AsyncResult):
1072 elif isinstance(job, AsyncResult):
1078 map(theids.add, job.msg_ids)
1073 map(theids.add, job.msg_ids)
1079 continue
1074 continue
1080 theids.add(job)
1075 theids.add(job)
1081 if not theids.intersection(self.outstanding):
1076 if not theids.intersection(self.outstanding):
1082 return True
1077 return True
1083 self.spin()
1078 self.spin()
1084 while theids.intersection(self.outstanding):
1079 while theids.intersection(self.outstanding):
1085 if timeout >= 0 and ( time.time()-tic ) > timeout:
1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1086 break
1081 break
1087 time.sleep(1e-3)
1082 time.sleep(1e-3)
1088 self.spin()
1083 self.spin()
1089 return len(theids.intersection(self.outstanding)) == 0
1084 return len(theids.intersection(self.outstanding)) == 0
1090
1085
1091 #--------------------------------------------------------------------------
1086 #--------------------------------------------------------------------------
1092 # Control methods
1087 # Control methods
1093 #--------------------------------------------------------------------------
1088 #--------------------------------------------------------------------------
1094
1089
1095 @spin_first
1090 @spin_first
1096 def clear(self, targets=None, block=None):
1091 def clear(self, targets=None, block=None):
1097 """Clear the namespace in target(s)."""
1092 """Clear the namespace in target(s)."""
1098 block = self.block if block is None else block
1093 block = self.block if block is None else block
1099 targets = self._build_targets(targets)[0]
1094 targets = self._build_targets(targets)[0]
1100 for t in targets:
1095 for t in targets:
1101 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1102 error = False
1097 error = False
1103 if block:
1098 if block:
1104 self._flush_ignored_control()
1099 self._flush_ignored_control()
1105 for i in range(len(targets)):
1100 for i in range(len(targets)):
1106 idents,msg = self.session.recv(self._control_socket,0)
1101 idents,msg = self.session.recv(self._control_socket,0)
1107 if self.debug:
1102 if self.debug:
1108 pprint(msg)
1103 pprint(msg)
1109 if msg['content']['status'] != 'ok':
1104 if msg['content']['status'] != 'ok':
1110 error = self._unwrap_exception(msg['content'])
1105 error = self._unwrap_exception(msg['content'])
1111 else:
1106 else:
1112 self._ignored_control_replies += len(targets)
1107 self._ignored_control_replies += len(targets)
1113 if error:
1108 if error:
1114 raise error
1109 raise error
1115
1110
1116
1111
1117 @spin_first
1112 @spin_first
1118 def abort(self, jobs=None, targets=None, block=None):
1113 def abort(self, jobs=None, targets=None, block=None):
1119 """Abort specific jobs from the execution queues of target(s).
1114 """Abort specific jobs from the execution queues of target(s).
1120
1115
1121 This is a mechanism to prevent jobs that have already been submitted
1116 This is a mechanism to prevent jobs that have already been submitted
1122 from executing.
1117 from executing.
1123
1118
1124 Parameters
1119 Parameters
1125 ----------
1120 ----------
1126
1121
1127 jobs : msg_id, list of msg_ids, or AsyncResult
1122 jobs : msg_id, list of msg_ids, or AsyncResult
1128 The jobs to be aborted
1123 The jobs to be aborted
1129
1124
1130 If unspecified/None: abort all outstanding jobs.
1125 If unspecified/None: abort all outstanding jobs.
1131
1126
1132 """
1127 """
1133 block = self.block if block is None else block
1128 block = self.block if block is None else block
1134 jobs = jobs if jobs is not None else list(self.outstanding)
1129 jobs = jobs if jobs is not None else list(self.outstanding)
1135 targets = self._build_targets(targets)[0]
1130 targets = self._build_targets(targets)[0]
1136
1131
1137 msg_ids = []
1132 msg_ids = []
1138 if isinstance(jobs, string_types + (AsyncResult,)):
1133 if isinstance(jobs, string_types + (AsyncResult,)):
1139 jobs = [jobs]
1134 jobs = [jobs]
1140 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1135 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1141 if bad_ids:
1136 if bad_ids:
1142 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1143 for j in jobs:
1138 for j in jobs:
1144 if isinstance(j, AsyncResult):
1139 if isinstance(j, AsyncResult):
1145 msg_ids.extend(j.msg_ids)
1140 msg_ids.extend(j.msg_ids)
1146 else:
1141 else:
1147 msg_ids.append(j)
1142 msg_ids.append(j)
1148 content = dict(msg_ids=msg_ids)
1143 content = dict(msg_ids=msg_ids)
1149 for t in targets:
1144 for t in targets:
1150 self.session.send(self._control_socket, 'abort_request',
1145 self.session.send(self._control_socket, 'abort_request',
1151 content=content, ident=t)
1146 content=content, ident=t)
1152 error = False
1147 error = False
1153 if block:
1148 if block:
1154 self._flush_ignored_control()
1149 self._flush_ignored_control()
1155 for i in range(len(targets)):
1150 for i in range(len(targets)):
1156 idents,msg = self.session.recv(self._control_socket,0)
1151 idents,msg = self.session.recv(self._control_socket,0)
1157 if self.debug:
1152 if self.debug:
1158 pprint(msg)
1153 pprint(msg)
1159 if msg['content']['status'] != 'ok':
1154 if msg['content']['status'] != 'ok':
1160 error = self._unwrap_exception(msg['content'])
1155 error = self._unwrap_exception(msg['content'])
1161 else:
1156 else:
1162 self._ignored_control_replies += len(targets)
1157 self._ignored_control_replies += len(targets)
1163 if error:
1158 if error:
1164 raise error
1159 raise error
1165
1160
1166 @spin_first
1161 @spin_first
1167 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1168 """Terminates one or more engine processes, optionally including the hub.
1163 """Terminates one or more engine processes, optionally including the hub.
1169
1164
1170 Parameters
1165 Parameters
1171 ----------
1166 ----------
1172
1167
1173 targets: list of ints or 'all' [default: all]
1168 targets: list of ints or 'all' [default: all]
1174 Which engines to shutdown.
1169 Which engines to shutdown.
1175 hub: bool [default: False]
1170 hub: bool [default: False]
1176 Whether to include the Hub. hub=True implies targets='all'.
1171 Whether to include the Hub. hub=True implies targets='all'.
1177 block: bool [default: self.block]
1172 block: bool [default: self.block]
1178 Whether to wait for clean shutdown replies or not.
1173 Whether to wait for clean shutdown replies or not.
1179 restart: bool [default: False]
1174 restart: bool [default: False]
1180 NOT IMPLEMENTED
1175 NOT IMPLEMENTED
1181 whether to restart engines after shutting them down.
1176 whether to restart engines after shutting them down.
1182 """
1177 """
1183 from IPython.parallel.error import NoEnginesRegistered
1178 from IPython.parallel.error import NoEnginesRegistered
1184 if restart:
1179 if restart:
1185 raise NotImplementedError("Engine restart is not yet implemented")
1180 raise NotImplementedError("Engine restart is not yet implemented")
1186
1181
1187 block = self.block if block is None else block
1182 block = self.block if block is None else block
1188 if hub:
1183 if hub:
1189 targets = 'all'
1184 targets = 'all'
1190 try:
1185 try:
1191 targets = self._build_targets(targets)[0]
1186 targets = self._build_targets(targets)[0]
1192 except NoEnginesRegistered:
1187 except NoEnginesRegistered:
1193 targets = []
1188 targets = []
1194 for t in targets:
1189 for t in targets:
1195 self.session.send(self._control_socket, 'shutdown_request',
1190 self.session.send(self._control_socket, 'shutdown_request',
1196 content={'restart':restart},ident=t)
1191 content={'restart':restart},ident=t)
1197 error = False
1192 error = False
1198 if block or hub:
1193 if block or hub:
1199 self._flush_ignored_control()
1194 self._flush_ignored_control()
1200 for i in range(len(targets)):
1195 for i in range(len(targets)):
1201 idents,msg = self.session.recv(self._control_socket, 0)
1196 idents,msg = self.session.recv(self._control_socket, 0)
1202 if self.debug:
1197 if self.debug:
1203 pprint(msg)
1198 pprint(msg)
1204 if msg['content']['status'] != 'ok':
1199 if msg['content']['status'] != 'ok':
1205 error = self._unwrap_exception(msg['content'])
1200 error = self._unwrap_exception(msg['content'])
1206 else:
1201 else:
1207 self._ignored_control_replies += len(targets)
1202 self._ignored_control_replies += len(targets)
1208
1203
1209 if hub:
1204 if hub:
1210 time.sleep(0.25)
1205 time.sleep(0.25)
1211 self.session.send(self._query_socket, 'shutdown_request')
1206 self.session.send(self._query_socket, 'shutdown_request')
1212 idents,msg = self.session.recv(self._query_socket, 0)
1207 idents,msg = self.session.recv(self._query_socket, 0)
1213 if self.debug:
1208 if self.debug:
1214 pprint(msg)
1209 pprint(msg)
1215 if msg['content']['status'] != 'ok':
1210 if msg['content']['status'] != 'ok':
1216 error = self._unwrap_exception(msg['content'])
1211 error = self._unwrap_exception(msg['content'])
1217
1212
1218 if error:
1213 if error:
1219 raise error
1214 raise error
1220
1215
1221 #--------------------------------------------------------------------------
1216 #--------------------------------------------------------------------------
1222 # Execution related methods
1217 # Execution related methods
1223 #--------------------------------------------------------------------------
1218 #--------------------------------------------------------------------------
1224
1219
1225 def _maybe_raise(self, result):
1220 def _maybe_raise(self, result):
1226 """wrapper for maybe raising an exception if apply failed."""
1221 """wrapper for maybe raising an exception if apply failed."""
1227 if isinstance(result, error.RemoteError):
1222 if isinstance(result, error.RemoteError):
1228 raise result
1223 raise result
1229
1224
1230 return result
1225 return result
1231
1226
1232 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1233 ident=None):
1228 ident=None):
1234 """construct and send an apply message via a socket.
1229 """construct and send an apply message via a socket.
1235
1230
1236 This is the principal method with which all engine execution is performed by views.
1231 This is the principal method with which all engine execution is performed by views.
1237 """
1232 """
1238
1233
1239 if self._closed:
1234 if self._closed:
1240 raise RuntimeError("Client cannot be used after its sockets have been closed")
1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1241
1236
1242 # defaults:
1237 # defaults:
1243 args = args if args is not None else []
1238 args = args if args is not None else []
1244 kwargs = kwargs if kwargs is not None else {}
1239 kwargs = kwargs if kwargs is not None else {}
1245 metadata = metadata if metadata is not None else {}
1240 metadata = metadata if metadata is not None else {}
1246
1241
1247 # validate arguments
1242 # validate arguments
1248 if not callable(f) and not isinstance(f, Reference):
1243 if not callable(f) and not isinstance(f, Reference):
1249 raise TypeError("f must be callable, not %s"%type(f))
1244 raise TypeError("f must be callable, not %s"%type(f))
1250 if not isinstance(args, (tuple, list)):
1245 if not isinstance(args, (tuple, list)):
1251 raise TypeError("args must be tuple or list, not %s"%type(args))
1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1252 if not isinstance(kwargs, dict):
1247 if not isinstance(kwargs, dict):
1253 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1254 if not isinstance(metadata, dict):
1249 if not isinstance(metadata, dict):
1255 raise TypeError("metadata must be dict, not %s"%type(metadata))
1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1256
1251
1257 bufs = serialize.pack_apply_message(f, args, kwargs,
1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1258 buffer_threshold=self.session.buffer_threshold,
1253 buffer_threshold=self.session.buffer_threshold,
1259 item_threshold=self.session.item_threshold,
1254 item_threshold=self.session.item_threshold,
1260 )
1255 )
1261
1256
1262 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1263 metadata=metadata, track=track)
1258 metadata=metadata, track=track)
1264
1259
1265 msg_id = msg['header']['msg_id']
1260 msg_id = msg['header']['msg_id']
1266 self.outstanding.add(msg_id)
1261 self.outstanding.add(msg_id)
1267 if ident:
1262 if ident:
1268 # possibly routed to a specific engine
1263 # possibly routed to a specific engine
1269 if isinstance(ident, list):
1264 if isinstance(ident, list):
1270 ident = ident[-1]
1265 ident = ident[-1]
1271 if ident in self._engines.values():
1266 if ident in self._engines.values():
1272 # save for later, in case of engine death
1267 # save for later, in case of engine death
1273 self._outstanding_dict[ident].add(msg_id)
1268 self._outstanding_dict[ident].add(msg_id)
1274 self.history.append(msg_id)
1269 self.history.append(msg_id)
1275 self.metadata[msg_id]['submitted'] = datetime.now()
1270 self.metadata[msg_id]['submitted'] = datetime.now()
1276
1271
1277 return msg
1272 return msg
1278
1273
1279 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1280 """construct and send an execute request via a socket.
1275 """construct and send an execute request via a socket.
1281
1276
1282 """
1277 """
1283
1278
1284 if self._closed:
1279 if self._closed:
1285 raise RuntimeError("Client cannot be used after its sockets have been closed")
1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1286
1281
1287 # defaults:
1282 # defaults:
1288 metadata = metadata if metadata is not None else {}
1283 metadata = metadata if metadata is not None else {}
1289
1284
1290 # validate arguments
1285 # validate arguments
1291 if not isinstance(code, string_types):
1286 if not isinstance(code, string_types):
1292 raise TypeError("code must be text, not %s" % type(code))
1287 raise TypeError("code must be text, not %s" % type(code))
1293 if not isinstance(metadata, dict):
1288 if not isinstance(metadata, dict):
1294 raise TypeError("metadata must be dict, not %s" % type(metadata))
1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1295
1290
1296 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1297
1292
1298
1293
1299 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1300 metadata=metadata)
1295 metadata=metadata)
1301
1296
1302 msg_id = msg['header']['msg_id']
1297 msg_id = msg['header']['msg_id']
1303 self.outstanding.add(msg_id)
1298 self.outstanding.add(msg_id)
1304 if ident:
1299 if ident:
1305 # possibly routed to a specific engine
1300 # possibly routed to a specific engine
1306 if isinstance(ident, list):
1301 if isinstance(ident, list):
1307 ident = ident[-1]
1302 ident = ident[-1]
1308 if ident in self._engines.values():
1303 if ident in self._engines.values():
1309 # save for later, in case of engine death
1304 # save for later, in case of engine death
1310 self._outstanding_dict[ident].add(msg_id)
1305 self._outstanding_dict[ident].add(msg_id)
1311 self.history.append(msg_id)
1306 self.history.append(msg_id)
1312 self.metadata[msg_id]['submitted'] = datetime.now()
1307 self.metadata[msg_id]['submitted'] = datetime.now()
1313
1308
1314 return msg
1309 return msg
1315
1310
1316 #--------------------------------------------------------------------------
1311 #--------------------------------------------------------------------------
1317 # construct a View object
1312 # construct a View object
1318 #--------------------------------------------------------------------------
1313 #--------------------------------------------------------------------------
1319
1314
1320 def load_balanced_view(self, targets=None):
1315 def load_balanced_view(self, targets=None):
1321 """construct a DirectView object.
1316 """construct a DirectView object.
1322
1317
1323 If no arguments are specified, create a LoadBalancedView
1318 If no arguments are specified, create a LoadBalancedView
1324 using all engines.
1319 using all engines.
1325
1320
1326 Parameters
1321 Parameters
1327 ----------
1322 ----------
1328
1323
1329 targets: list,slice,int,etc. [default: use all engines]
1324 targets: list,slice,int,etc. [default: use all engines]
1330 The subset of engines across which to load-balance
1325 The subset of engines across which to load-balance
1331 """
1326 """
1332 if targets == 'all':
1327 if targets == 'all':
1333 targets = None
1328 targets = None
1334 if targets is not None:
1329 if targets is not None:
1335 targets = self._build_targets(targets)[1]
1330 targets = self._build_targets(targets)[1]
1336 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1337
1332
1338 def direct_view(self, targets='all'):
1333 def direct_view(self, targets='all'):
1339 """construct a DirectView object.
1334 """construct a DirectView object.
1340
1335
1341 If no targets are specified, create a DirectView using all engines.
1336 If no targets are specified, create a DirectView using all engines.
1342
1337
1343 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1344 evaluate the target engines at each execution, whereas rc[:] will connect to
1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1345 all *current* engines, and that list will not change.
1340 all *current* engines, and that list will not change.
1346
1341
1347 That is, 'all' will always use all engines, whereas rc[:] will not use
1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1348 engines added after the DirectView is constructed.
1343 engines added after the DirectView is constructed.
1349
1344
1350 Parameters
1345 Parameters
1351 ----------
1346 ----------
1352
1347
1353 targets: list,slice,int,etc. [default: use all engines]
1348 targets: list,slice,int,etc. [default: use all engines]
1354 The engines to use for the View
1349 The engines to use for the View
1355 """
1350 """
1356 single = isinstance(targets, int)
1351 single = isinstance(targets, int)
1357 # allow 'all' to be lazily evaluated at each execution
1352 # allow 'all' to be lazily evaluated at each execution
1358 if targets != 'all':
1353 if targets != 'all':
1359 targets = self._build_targets(targets)[1]
1354 targets = self._build_targets(targets)[1]
1360 if single:
1355 if single:
1361 targets = targets[0]
1356 targets = targets[0]
1362 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1363
1358
1364 #--------------------------------------------------------------------------
1359 #--------------------------------------------------------------------------
1365 # Query methods
1360 # Query methods
1366 #--------------------------------------------------------------------------
1361 #--------------------------------------------------------------------------
1367
1362
1368 @spin_first
1363 @spin_first
1369 def get_result(self, indices_or_msg_ids=None, block=None):
1364 def get_result(self, indices_or_msg_ids=None, block=None):
1370 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1371
1366
1372 If the client already has the results, no request to the Hub will be made.
1367 If the client already has the results, no request to the Hub will be made.
1373
1368
1374 This is a convenient way to construct AsyncResult objects, which are wrappers
1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1375 that include metadata about execution, and allow for awaiting results that
1370 that include metadata about execution, and allow for awaiting results that
1376 were not submitted by this Client.
1371 were not submitted by this Client.
1377
1372
1378 It can also be a convenient way to retrieve the metadata associated with
1373 It can also be a convenient way to retrieve the metadata associated with
1379 blocking execution, since it always retrieves
1374 blocking execution, since it always retrieves
1380
1375
1381 Examples
1376 Examples
1382 --------
1377 --------
1383 ::
1378 ::
1384
1379
1385 In [10]: r = client.apply()
1380 In [10]: r = client.apply()
1386
1381
1387 Parameters
1382 Parameters
1388 ----------
1383 ----------
1389
1384
1390 indices_or_msg_ids : integer history index, str msg_id, or list of either
1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1391 The indices or msg_ids of indices to be retrieved
1386 The indices or msg_ids of indices to be retrieved
1392
1387
1393 block : bool
1388 block : bool
1394 Whether to wait for the result to be done
1389 Whether to wait for the result to be done
1395
1390
1396 Returns
1391 Returns
1397 -------
1392 -------
1398
1393
1399 AsyncResult
1394 AsyncResult
1400 A single AsyncResult object will always be returned.
1395 A single AsyncResult object will always be returned.
1401
1396
1402 AsyncHubResult
1397 AsyncHubResult
1403 A subclass of AsyncResult that retrieves results from the Hub
1398 A subclass of AsyncResult that retrieves results from the Hub
1404
1399
1405 """
1400 """
1406 block = self.block if block is None else block
1401 block = self.block if block is None else block
1407 if indices_or_msg_ids is None:
1402 if indices_or_msg_ids is None:
1408 indices_or_msg_ids = -1
1403 indices_or_msg_ids = -1
1409
1404
1410 single_result = False
1405 single_result = False
1411 if not isinstance(indices_or_msg_ids, (list,tuple)):
1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1412 indices_or_msg_ids = [indices_or_msg_ids]
1407 indices_or_msg_ids = [indices_or_msg_ids]
1413 single_result = True
1408 single_result = True
1414
1409
1415 theids = []
1410 theids = []
1416 for id in indices_or_msg_ids:
1411 for id in indices_or_msg_ids:
1417 if isinstance(id, int):
1412 if isinstance(id, int):
1418 id = self.history[id]
1413 id = self.history[id]
1419 if not isinstance(id, string_types):
1414 if not isinstance(id, string_types):
1420 raise TypeError("indices must be str or int, not %r"%id)
1415 raise TypeError("indices must be str or int, not %r"%id)
1421 theids.append(id)
1416 theids.append(id)
1422
1417
1423 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1418 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1424 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1419 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1425
1420
1426 # given single msg_id initially, get_result shot get the result itself,
1421 # given single msg_id initially, get_result shot get the result itself,
1427 # not a length-one list
1422 # not a length-one list
1428 if single_result:
1423 if single_result:
1429 theids = theids[0]
1424 theids = theids[0]
1430
1425
1431 if remote_ids:
1426 if remote_ids:
1432 ar = AsyncHubResult(self, msg_ids=theids)
1427 ar = AsyncHubResult(self, msg_ids=theids)
1433 else:
1428 else:
1434 ar = AsyncResult(self, msg_ids=theids)
1429 ar = AsyncResult(self, msg_ids=theids)
1435
1430
1436 if block:
1431 if block:
1437 ar.wait()
1432 ar.wait()
1438
1433
1439 return ar
1434 return ar
1440
1435
1441 @spin_first
1436 @spin_first
1442 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1443 """Resubmit one or more tasks.
1438 """Resubmit one or more tasks.
1444
1439
1445 in-flight tasks may not be resubmitted.
1440 in-flight tasks may not be resubmitted.
1446
1441
1447 Parameters
1442 Parameters
1448 ----------
1443 ----------
1449
1444
1450 indices_or_msg_ids : integer history index, str msg_id, or list of either
1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1451 The indices or msg_ids of indices to be retrieved
1446 The indices or msg_ids of indices to be retrieved
1452
1447
1453 block : bool
1448 block : bool
1454 Whether to wait for the result to be done
1449 Whether to wait for the result to be done
1455
1450
1456 Returns
1451 Returns
1457 -------
1452 -------
1458
1453
1459 AsyncHubResult
1454 AsyncHubResult
1460 A subclass of AsyncResult that retrieves results from the Hub
1455 A subclass of AsyncResult that retrieves results from the Hub
1461
1456
1462 """
1457 """
1463 block = self.block if block is None else block
1458 block = self.block if block is None else block
1464 if indices_or_msg_ids is None:
1459 if indices_or_msg_ids is None:
1465 indices_or_msg_ids = -1
1460 indices_or_msg_ids = -1
1466
1461
1467 if not isinstance(indices_or_msg_ids, (list,tuple)):
1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1468 indices_or_msg_ids = [indices_or_msg_ids]
1463 indices_or_msg_ids = [indices_or_msg_ids]
1469
1464
1470 theids = []
1465 theids = []
1471 for id in indices_or_msg_ids:
1466 for id in indices_or_msg_ids:
1472 if isinstance(id, int):
1467 if isinstance(id, int):
1473 id = self.history[id]
1468 id = self.history[id]
1474 if not isinstance(id, string_types):
1469 if not isinstance(id, string_types):
1475 raise TypeError("indices must be str or int, not %r"%id)
1470 raise TypeError("indices must be str or int, not %r"%id)
1476 theids.append(id)
1471 theids.append(id)
1477
1472
1478 content = dict(msg_ids = theids)
1473 content = dict(msg_ids = theids)
1479
1474
1480 self.session.send(self._query_socket, 'resubmit_request', content)
1475 self.session.send(self._query_socket, 'resubmit_request', content)
1481
1476
1482 zmq.select([self._query_socket], [], [])
1477 zmq.select([self._query_socket], [], [])
1483 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1484 if self.debug:
1479 if self.debug:
1485 pprint(msg)
1480 pprint(msg)
1486 content = msg['content']
1481 content = msg['content']
1487 if content['status'] != 'ok':
1482 if content['status'] != 'ok':
1488 raise self._unwrap_exception(content)
1483 raise self._unwrap_exception(content)
1489 mapping = content['resubmitted']
1484 mapping = content['resubmitted']
1490 new_ids = [ mapping[msg_id] for msg_id in theids ]
1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1491
1486
1492 ar = AsyncHubResult(self, msg_ids=new_ids)
1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1493
1488
1494 if block:
1489 if block:
1495 ar.wait()
1490 ar.wait()
1496
1491
1497 return ar
1492 return ar
1498
1493
1499 @spin_first
1494 @spin_first
1500 def result_status(self, msg_ids, status_only=True):
1495 def result_status(self, msg_ids, status_only=True):
1501 """Check on the status of the result(s) of the apply request with `msg_ids`.
1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1502
1497
1503 If status_only is False, then the actual results will be retrieved, else
1498 If status_only is False, then the actual results will be retrieved, else
1504 only the status of the results will be checked.
1499 only the status of the results will be checked.
1505
1500
1506 Parameters
1501 Parameters
1507 ----------
1502 ----------
1508
1503
1509 msg_ids : list of msg_ids
1504 msg_ids : list of msg_ids
1510 if int:
1505 if int:
1511 Passed as index to self.history for convenience.
1506 Passed as index to self.history for convenience.
1512 status_only : bool (default: True)
1507 status_only : bool (default: True)
1513 if False:
1508 if False:
1514 Retrieve the actual results of completed tasks.
1509 Retrieve the actual results of completed tasks.
1515
1510
1516 Returns
1511 Returns
1517 -------
1512 -------
1518
1513
1519 results : dict
1514 results : dict
1520 There will always be the keys 'pending' and 'completed', which will
1515 There will always be the keys 'pending' and 'completed', which will
1521 be lists of msg_ids that are incomplete or complete. If `status_only`
1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1522 is False, then completed results will be keyed by their `msg_id`.
1517 is False, then completed results will be keyed by their `msg_id`.
1523 """
1518 """
1524 if not isinstance(msg_ids, (list,tuple)):
1519 if not isinstance(msg_ids, (list,tuple)):
1525 msg_ids = [msg_ids]
1520 msg_ids = [msg_ids]
1526
1521
1527 theids = []
1522 theids = []
1528 for msg_id in msg_ids:
1523 for msg_id in msg_ids:
1529 if isinstance(msg_id, int):
1524 if isinstance(msg_id, int):
1530 msg_id = self.history[msg_id]
1525 msg_id = self.history[msg_id]
1531 if not isinstance(msg_id, string_types):
1526 if not isinstance(msg_id, string_types):
1532 raise TypeError("msg_ids must be str, not %r"%msg_id)
1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1533 theids.append(msg_id)
1528 theids.append(msg_id)
1534
1529
1535 completed = []
1530 completed = []
1536 local_results = {}
1531 local_results = {}
1537
1532
1538 # comment this block out to temporarily disable local shortcut:
1533 # comment this block out to temporarily disable local shortcut:
1539 for msg_id in theids:
1534 for msg_id in theids:
1540 if msg_id in self.results:
1535 if msg_id in self.results:
1541 completed.append(msg_id)
1536 completed.append(msg_id)
1542 local_results[msg_id] = self.results[msg_id]
1537 local_results[msg_id] = self.results[msg_id]
1543 theids.remove(msg_id)
1538 theids.remove(msg_id)
1544
1539
1545 if theids: # some not locally cached
1540 if theids: # some not locally cached
1546 content = dict(msg_ids=theids, status_only=status_only)
1541 content = dict(msg_ids=theids, status_only=status_only)
1547 msg = self.session.send(self._query_socket, "result_request", content=content)
1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1548 zmq.select([self._query_socket], [], [])
1543 zmq.select([self._query_socket], [], [])
1549 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1550 if self.debug:
1545 if self.debug:
1551 pprint(msg)
1546 pprint(msg)
1552 content = msg['content']
1547 content = msg['content']
1553 if content['status'] != 'ok':
1548 if content['status'] != 'ok':
1554 raise self._unwrap_exception(content)
1549 raise self._unwrap_exception(content)
1555 buffers = msg['buffers']
1550 buffers = msg['buffers']
1556 else:
1551 else:
1557 content = dict(completed=[],pending=[])
1552 content = dict(completed=[],pending=[])
1558
1553
1559 content['completed'].extend(completed)
1554 content['completed'].extend(completed)
1560
1555
1561 if status_only:
1556 if status_only:
1562 return content
1557 return content
1563
1558
1564 failures = []
1559 failures = []
1565 # load cached results into result:
1560 # load cached results into result:
1566 content.update(local_results)
1561 content.update(local_results)
1567
1562
1568 # update cache with results:
1563 # update cache with results:
1569 for msg_id in sorted(theids):
1564 for msg_id in sorted(theids):
1570 if msg_id in content['completed']:
1565 if msg_id in content['completed']:
1571 rec = content[msg_id]
1566 rec = content[msg_id]
1572 parent = rec['header']
1567 parent = rec['header']
1573 header = rec['result_header']
1568 header = rec['result_header']
1574 rcontent = rec['result_content']
1569 rcontent = rec['result_content']
1575 iodict = rec['io']
1570 iodict = rec['io']
1576 if isinstance(rcontent, str):
1571 if isinstance(rcontent, str):
1577 rcontent = self.session.unpack(rcontent)
1572 rcontent = self.session.unpack(rcontent)
1578
1573
1579 md = self.metadata[msg_id]
1574 md = self.metadata[msg_id]
1580 md_msg = dict(
1575 md_msg = dict(
1581 content=rcontent,
1576 content=rcontent,
1582 parent_header=parent,
1577 parent_header=parent,
1583 header=header,
1578 header=header,
1584 metadata=rec['result_metadata'],
1579 metadata=rec['result_metadata'],
1585 )
1580 )
1586 md.update(self._extract_metadata(md_msg))
1581 md.update(self._extract_metadata(md_msg))
1587 if rec.get('received'):
1582 if rec.get('received'):
1588 md['received'] = rec['received']
1583 md['received'] = rec['received']
1589 md.update(iodict)
1584 md.update(iodict)
1590
1585
1591 if rcontent['status'] == 'ok':
1586 if rcontent['status'] == 'ok':
1592 if header['msg_type'] == 'apply_reply':
1587 if header['msg_type'] == 'apply_reply':
1593 res,buffers = serialize.unserialize_object(buffers)
1588 res,buffers = serialize.unserialize_object(buffers)
1594 elif header['msg_type'] == 'execute_reply':
1589 elif header['msg_type'] == 'execute_reply':
1595 res = ExecuteReply(msg_id, rcontent, md)
1590 res = ExecuteReply(msg_id, rcontent, md)
1596 else:
1591 else:
1597 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1598 else:
1593 else:
1599 res = self._unwrap_exception(rcontent)
1594 res = self._unwrap_exception(rcontent)
1600 failures.append(res)
1595 failures.append(res)
1601
1596
1602 self.results[msg_id] = res
1597 self.results[msg_id] = res
1603 content[msg_id] = res
1598 content[msg_id] = res
1604
1599
1605 if len(theids) == 1 and failures:
1600 if len(theids) == 1 and failures:
1606 raise failures[0]
1601 raise failures[0]
1607
1602
1608 error.collect_exceptions(failures, "result_status")
1603 error.collect_exceptions(failures, "result_status")
1609 return content
1604 return content
1610
1605
1611 @spin_first
1606 @spin_first
1612 def queue_status(self, targets='all', verbose=False):
1607 def queue_status(self, targets='all', verbose=False):
1613 """Fetch the status of engine queues.
1608 """Fetch the status of engine queues.
1614
1609
1615 Parameters
1610 Parameters
1616 ----------
1611 ----------
1617
1612
1618 targets : int/str/list of ints/strs
1613 targets : int/str/list of ints/strs
1619 the engines whose states are to be queried.
1614 the engines whose states are to be queried.
1620 default : all
1615 default : all
1621 verbose : bool
1616 verbose : bool
1622 Whether to return lengths only, or lists of ids for each element
1617 Whether to return lengths only, or lists of ids for each element
1623 """
1618 """
1624 if targets == 'all':
1619 if targets == 'all':
1625 # allow 'all' to be evaluated on the engine
1620 # allow 'all' to be evaluated on the engine
1626 engine_ids = None
1621 engine_ids = None
1627 else:
1622 else:
1628 engine_ids = self._build_targets(targets)[1]
1623 engine_ids = self._build_targets(targets)[1]
1629 content = dict(targets=engine_ids, verbose=verbose)
1624 content = dict(targets=engine_ids, verbose=verbose)
1630 self.session.send(self._query_socket, "queue_request", content=content)
1625 self.session.send(self._query_socket, "queue_request", content=content)
1631 idents,msg = self.session.recv(self._query_socket, 0)
1626 idents,msg = self.session.recv(self._query_socket, 0)
1632 if self.debug:
1627 if self.debug:
1633 pprint(msg)
1628 pprint(msg)
1634 content = msg['content']
1629 content = msg['content']
1635 status = content.pop('status')
1630 status = content.pop('status')
1636 if status != 'ok':
1631 if status != 'ok':
1637 raise self._unwrap_exception(content)
1632 raise self._unwrap_exception(content)
1638 content = rekey(content)
1633 content = rekey(content)
1639 if isinstance(targets, int):
1634 if isinstance(targets, int):
1640 return content[targets]
1635 return content[targets]
1641 else:
1636 else:
1642 return content
1637 return content
1643
1638
1644 def _build_msgids_from_target(self, targets=None):
1639 def _build_msgids_from_target(self, targets=None):
1645 """Build a list of msg_ids from the list of engine targets"""
1640 """Build a list of msg_ids from the list of engine targets"""
1646 if not targets: # needed as _build_targets otherwise uses all engines
1641 if not targets: # needed as _build_targets otherwise uses all engines
1647 return []
1642 return []
1648 target_ids = self._build_targets(targets)[0]
1643 target_ids = self._build_targets(targets)[0]
1649 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1644 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1650
1645
1651 def _build_msgids_from_jobs(self, jobs=None):
1646 def _build_msgids_from_jobs(self, jobs=None):
1652 """Build a list of msg_ids from "jobs" """
1647 """Build a list of msg_ids from "jobs" """
1653 if not jobs:
1648 if not jobs:
1654 return []
1649 return []
1655 msg_ids = []
1650 msg_ids = []
1656 if isinstance(jobs, string_types + (AsyncResult,)):
1651 if isinstance(jobs, string_types + (AsyncResult,)):
1657 jobs = [jobs]
1652 jobs = [jobs]
1658 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult)), jobs)
1653 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult)), jobs)
1659 if bad_ids:
1654 if bad_ids:
1660 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1661 for j in jobs:
1656 for j in jobs:
1662 if isinstance(j, AsyncResult):
1657 if isinstance(j, AsyncResult):
1663 msg_ids.extend(j.msg_ids)
1658 msg_ids.extend(j.msg_ids)
1664 else:
1659 else:
1665 msg_ids.append(j)
1660 msg_ids.append(j)
1666 return msg_ids
1661 return msg_ids
1667
1662
1668 def purge_local_results(self, jobs=[], targets=[]):
1663 def purge_local_results(self, jobs=[], targets=[]):
1669 """Clears the client caches of results and frees such memory.
1664 """Clears the client caches of results and frees such memory.
1670
1665
1671 Individual results can be purged by msg_id, or the entire
1666 Individual results can be purged by msg_id, or the entire
1672 history of specific targets can be purged.
1667 history of specific targets can be purged.
1673
1668
1674 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1669 Use `purge_local_results('all')` to scrub everything from the Clients's db.
1675
1670
1676 The client must have no outstanding tasks before purging the caches.
1671 The client must have no outstanding tasks before purging the caches.
1677 Raises `AssertionError` if there are still outstanding tasks.
1672 Raises `AssertionError` if there are still outstanding tasks.
1678
1673
1679 After this call all `AsyncResults` are invalid and should be discarded.
1674 After this call all `AsyncResults` are invalid and should be discarded.
1680
1675
1681 If you must "reget" the results, you can still do so by using
1676 If you must "reget" the results, you can still do so by using
1682 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1683 redownload the results from the hub if they are still available
1678 redownload the results from the hub if they are still available
1684 (i.e `client.purge_hub_results(...)` has not been called.
1679 (i.e `client.purge_hub_results(...)` has not been called.
1685
1680
1686 Parameters
1681 Parameters
1687 ----------
1682 ----------
1688
1683
1689 jobs : str or list of str or AsyncResult objects
1684 jobs : str or list of str or AsyncResult objects
1690 the msg_ids whose results should be purged.
1685 the msg_ids whose results should be purged.
1691 targets : int/str/list of ints/strs
1686 targets : int/str/list of ints/strs
1692 The targets, by int_id, whose entire results are to be purged.
1687 The targets, by int_id, whose entire results are to be purged.
1693
1688
1694 default : None
1689 default : None
1695 """
1690 """
1696 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1697
1692
1698 if not targets and not jobs:
1693 if not targets and not jobs:
1699 raise ValueError("Must specify at least one of `targets` and `jobs`")
1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1700
1695
1701 if jobs == 'all':
1696 if jobs == 'all':
1702 self.results.clear()
1697 self.results.clear()
1703 self.metadata.clear()
1698 self.metadata.clear()
1704 return
1699 return
1705 else:
1700 else:
1706 msg_ids = []
1701 msg_ids = []
1707 msg_ids.extend(self._build_msgids_from_target(targets))
1702 msg_ids.extend(self._build_msgids_from_target(targets))
1708 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1709 map(self.results.pop, msg_ids)
1704 map(self.results.pop, msg_ids)
1710 map(self.metadata.pop, msg_ids)
1705 map(self.metadata.pop, msg_ids)
1711
1706
1712
1707
1713 @spin_first
1708 @spin_first
1714 def purge_hub_results(self, jobs=[], targets=[]):
1709 def purge_hub_results(self, jobs=[], targets=[]):
1715 """Tell the Hub to forget results.
1710 """Tell the Hub to forget results.
1716
1711
1717 Individual results can be purged by msg_id, or the entire
1712 Individual results can be purged by msg_id, or the entire
1718 history of specific targets can be purged.
1713 history of specific targets can be purged.
1719
1714
1720 Use `purge_results('all')` to scrub everything from the Hub's db.
1715 Use `purge_results('all')` to scrub everything from the Hub's db.
1721
1716
1722 Parameters
1717 Parameters
1723 ----------
1718 ----------
1724
1719
1725 jobs : str or list of str or AsyncResult objects
1720 jobs : str or list of str or AsyncResult objects
1726 the msg_ids whose results should be forgotten.
1721 the msg_ids whose results should be forgotten.
1727 targets : int/str/list of ints/strs
1722 targets : int/str/list of ints/strs
1728 The targets, by int_id, whose entire history is to be purged.
1723 The targets, by int_id, whose entire history is to be purged.
1729
1724
1730 default : None
1725 default : None
1731 """
1726 """
1732 if not targets and not jobs:
1727 if not targets and not jobs:
1733 raise ValueError("Must specify at least one of `targets` and `jobs`")
1728 raise ValueError("Must specify at least one of `targets` and `jobs`")
1734 if targets:
1729 if targets:
1735 targets = self._build_targets(targets)[1]
1730 targets = self._build_targets(targets)[1]
1736
1731
1737 # construct msg_ids from jobs
1732 # construct msg_ids from jobs
1738 if jobs == 'all':
1733 if jobs == 'all':
1739 msg_ids = jobs
1734 msg_ids = jobs
1740 else:
1735 else:
1741 msg_ids = self._build_msgids_from_jobs(jobs)
1736 msg_ids = self._build_msgids_from_jobs(jobs)
1742
1737
1743 content = dict(engine_ids=targets, msg_ids=msg_ids)
1738 content = dict(engine_ids=targets, msg_ids=msg_ids)
1744 self.session.send(self._query_socket, "purge_request", content=content)
1739 self.session.send(self._query_socket, "purge_request", content=content)
1745 idents, msg = self.session.recv(self._query_socket, 0)
1740 idents, msg = self.session.recv(self._query_socket, 0)
1746 if self.debug:
1741 if self.debug:
1747 pprint(msg)
1742 pprint(msg)
1748 content = msg['content']
1743 content = msg['content']
1749 if content['status'] != 'ok':
1744 if content['status'] != 'ok':
1750 raise self._unwrap_exception(content)
1745 raise self._unwrap_exception(content)
1751
1746
1752 def purge_results(self, jobs=[], targets=[]):
1747 def purge_results(self, jobs=[], targets=[]):
1753 """Clears the cached results from both the hub and the local client
1748 """Clears the cached results from both the hub and the local client
1754
1749
1755 Individual results can be purged by msg_id, or the entire
1750 Individual results can be purged by msg_id, or the entire
1756 history of specific targets can be purged.
1751 history of specific targets can be purged.
1757
1752
1758 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1753 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1759 the Client's db.
1754 the Client's db.
1760
1755
1761 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1756 Equivalent to calling both `purge_hub_results()` and `purge_client_results()` with
1762 the same arguments.
1757 the same arguments.
1763
1758
1764 Parameters
1759 Parameters
1765 ----------
1760 ----------
1766
1761
1767 jobs : str or list of str or AsyncResult objects
1762 jobs : str or list of str or AsyncResult objects
1768 the msg_ids whose results should be forgotten.
1763 the msg_ids whose results should be forgotten.
1769 targets : int/str/list of ints/strs
1764 targets : int/str/list of ints/strs
1770 The targets, by int_id, whose entire history is to be purged.
1765 The targets, by int_id, whose entire history is to be purged.
1771
1766
1772 default : None
1767 default : None
1773 """
1768 """
1774 self.purge_local_results(jobs=jobs, targets=targets)
1769 self.purge_local_results(jobs=jobs, targets=targets)
1775 self.purge_hub_results(jobs=jobs, targets=targets)
1770 self.purge_hub_results(jobs=jobs, targets=targets)
1776
1771
1777 def purge_everything(self):
1772 def purge_everything(self):
1778 """Clears all content from previous Tasks from both the hub and the local client
1773 """Clears all content from previous Tasks from both the hub and the local client
1779
1774
1780 In addition to calling `purge_results("all")` it also deletes the history and
1775 In addition to calling `purge_results("all")` it also deletes the history and
1781 other bookkeeping lists.
1776 other bookkeeping lists.
1782 """
1777 """
1783 self.purge_results("all")
1778 self.purge_results("all")
1784 self.history = []
1779 self.history = []
1785 self.session.digest_history.clear()
1780 self.session.digest_history.clear()
1786
1781
1787 @spin_first
1782 @spin_first
1788 def hub_history(self):
1783 def hub_history(self):
1789 """Get the Hub's history
1784 """Get the Hub's history
1790
1785
1791 Just like the Client, the Hub has a history, which is a list of msg_ids.
1786 Just like the Client, the Hub has a history, which is a list of msg_ids.
1792 This will contain the history of all clients, and, depending on configuration,
1787 This will contain the history of all clients, and, depending on configuration,
1793 may contain history across multiple cluster sessions.
1788 may contain history across multiple cluster sessions.
1794
1789
1795 Any msg_id returned here is a valid argument to `get_result`.
1790 Any msg_id returned here is a valid argument to `get_result`.
1796
1791
1797 Returns
1792 Returns
1798 -------
1793 -------
1799
1794
1800 msg_ids : list of strs
1795 msg_ids : list of strs
1801 list of all msg_ids, ordered by task submission time.
1796 list of all msg_ids, ordered by task submission time.
1802 """
1797 """
1803
1798
1804 self.session.send(self._query_socket, "history_request", content={})
1799 self.session.send(self._query_socket, "history_request", content={})
1805 idents, msg = self.session.recv(self._query_socket, 0)
1800 idents, msg = self.session.recv(self._query_socket, 0)
1806
1801
1807 if self.debug:
1802 if self.debug:
1808 pprint(msg)
1803 pprint(msg)
1809 content = msg['content']
1804 content = msg['content']
1810 if content['status'] != 'ok':
1805 if content['status'] != 'ok':
1811 raise self._unwrap_exception(content)
1806 raise self._unwrap_exception(content)
1812 else:
1807 else:
1813 return content['history']
1808 return content['history']
1814
1809
1815 @spin_first
1810 @spin_first
1816 def db_query(self, query, keys=None):
1811 def db_query(self, query, keys=None):
1817 """Query the Hub's TaskRecord database
1812 """Query the Hub's TaskRecord database
1818
1813
1819 This will return a list of task record dicts that match `query`
1814 This will return a list of task record dicts that match `query`
1820
1815
1821 Parameters
1816 Parameters
1822 ----------
1817 ----------
1823
1818
1824 query : mongodb query dict
1819 query : mongodb query dict
1825 The search dict. See mongodb query docs for details.
1820 The search dict. See mongodb query docs for details.
1826 keys : list of strs [optional]
1821 keys : list of strs [optional]
1827 The subset of keys to be returned. The default is to fetch everything but buffers.
1822 The subset of keys to be returned. The default is to fetch everything but buffers.
1828 'msg_id' will *always* be included.
1823 'msg_id' will *always* be included.
1829 """
1824 """
1830 if isinstance(keys, string_types):
1825 if isinstance(keys, string_types):
1831 keys = [keys]
1826 keys = [keys]
1832 content = dict(query=query, keys=keys)
1827 content = dict(query=query, keys=keys)
1833 self.session.send(self._query_socket, "db_request", content=content)
1828 self.session.send(self._query_socket, "db_request", content=content)
1834 idents, msg = self.session.recv(self._query_socket, 0)
1829 idents, msg = self.session.recv(self._query_socket, 0)
1835 if self.debug:
1830 if self.debug:
1836 pprint(msg)
1831 pprint(msg)
1837 content = msg['content']
1832 content = msg['content']
1838 if content['status'] != 'ok':
1833 if content['status'] != 'ok':
1839 raise self._unwrap_exception(content)
1834 raise self._unwrap_exception(content)
1840
1835
1841 records = content['records']
1836 records = content['records']
1842
1837
1843 buffer_lens = content['buffer_lens']
1838 buffer_lens = content['buffer_lens']
1844 result_buffer_lens = content['result_buffer_lens']
1839 result_buffer_lens = content['result_buffer_lens']
1845 buffers = msg['buffers']
1840 buffers = msg['buffers']
1846 has_bufs = buffer_lens is not None
1841 has_bufs = buffer_lens is not None
1847 has_rbufs = result_buffer_lens is not None
1842 has_rbufs = result_buffer_lens is not None
1848 for i,rec in enumerate(records):
1843 for i,rec in enumerate(records):
1849 # relink buffers
1844 # relink buffers
1850 if has_bufs:
1845 if has_bufs:
1851 blen = buffer_lens[i]
1846 blen = buffer_lens[i]
1852 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1847 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1853 if has_rbufs:
1848 if has_rbufs:
1854 blen = result_buffer_lens[i]
1849 blen = result_buffer_lens[i]
1855 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1850 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1856
1851
1857 return records
1852 return records
1858
1853
1859 __all__ = [ 'Client' ]
1854 __all__ = [ 'Client' ]
@@ -1,369 +1,369 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 from IPython.utils.py3compat import string_types
47 from IPython.utils.py3compat import string_types
48 from IPython.kernel.zmq.log import EnginePUBHandler
48 from IPython.kernel.zmq.log import EnginePUBHandler
49 from IPython.kernel.zmq.serialize import (
49 from IPython.kernel.zmq.serialize import (
50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 )
51 )
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Classes
54 # Classes
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 class Namespace(dict):
57 class Namespace(dict):
58 """Subclass of dict for attribute access to keys."""
58 """Subclass of dict for attribute access to keys."""
59
59
60 def __getattr__(self, key):
60 def __getattr__(self, key):
61 """getattr aliased to getitem"""
61 """getattr aliased to getitem"""
62 if key in self.iterkeys():
62 if key in self.iterkeys():
63 return self[key]
63 return self[key]
64 else:
64 else:
65 raise NameError(key)
65 raise NameError(key)
66
66
67 def __setattr__(self, key, value):
67 def __setattr__(self, key, value):
68 """setattr aliased to setitem, with strict"""
68 """setattr aliased to setitem, with strict"""
69 if hasattr(dict, key):
69 if hasattr(dict, key):
70 raise KeyError("Cannot override dict keys %r"%key)
70 raise KeyError("Cannot override dict keys %r"%key)
71 self[key] = value
71 self[key] = value
72
72
73
73
74 class ReverseDict(dict):
74 class ReverseDict(dict):
75 """simple double-keyed subset of dict methods."""
75 """simple double-keyed subset of dict methods."""
76
76
77 def __init__(self, *args, **kwargs):
77 def __init__(self, *args, **kwargs):
78 dict.__init__(self, *args, **kwargs)
78 dict.__init__(self, *args, **kwargs)
79 self._reverse = dict()
79 self._reverse = dict()
80 for key, value in self.iteritems():
80 for key, value in self.iteritems():
81 self._reverse[value] = key
81 self._reverse[value] = key
82
82
83 def __getitem__(self, key):
83 def __getitem__(self, key):
84 try:
84 try:
85 return dict.__getitem__(self, key)
85 return dict.__getitem__(self, key)
86 except KeyError:
86 except KeyError:
87 return self._reverse[key]
87 return self._reverse[key]
88
88
89 def __setitem__(self, key, value):
89 def __setitem__(self, key, value):
90 if key in self._reverse:
90 if key in self._reverse:
91 raise KeyError("Can't have key %r on both sides!"%key)
91 raise KeyError("Can't have key %r on both sides!"%key)
92 dict.__setitem__(self, key, value)
92 dict.__setitem__(self, key, value)
93 self._reverse[value] = key
93 self._reverse[value] = key
94
94
95 def pop(self, key):
95 def pop(self, key):
96 value = dict.pop(self, key)
96 value = dict.pop(self, key)
97 self._reverse.pop(value)
97 self._reverse.pop(value)
98 return value
98 return value
99
99
100 def get(self, key, default=None):
100 def get(self, key, default=None):
101 try:
101 try:
102 return self[key]
102 return self[key]
103 except KeyError:
103 except KeyError:
104 return default
104 return default
105
105
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107 # Functions
107 # Functions
108 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
109
109
110 @decorator
110 @decorator
111 def log_errors(f, self, *args, **kwargs):
111 def log_errors(f, self, *args, **kwargs):
112 """decorator to log unhandled exceptions raised in a method.
112 """decorator to log unhandled exceptions raised in a method.
113
113
114 For use wrapping on_recv callbacks, so that exceptions
114 For use wrapping on_recv callbacks, so that exceptions
115 do not cause the stream to be closed.
115 do not cause the stream to be closed.
116 """
116 """
117 try:
117 try:
118 return f(self, *args, **kwargs)
118 return f(self, *args, **kwargs)
119 except Exception:
119 except Exception:
120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121
121
122
122
123 def is_url(url):
123 def is_url(url):
124 """boolean check for whether a string is a zmq url"""
124 """boolean check for whether a string is a zmq url"""
125 if '://' not in url:
125 if '://' not in url:
126 return False
126 return False
127 proto, addr = url.split('://', 1)
127 proto, addr = url.split('://', 1)
128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 return False
129 return False
130 return True
130 return True
131
131
132 def validate_url(url):
132 def validate_url(url):
133 """validate a url for zeromq"""
133 """validate a url for zeromq"""
134 if not isinstance(url, string_types):
134 if not isinstance(url, string_types):
135 raise TypeError("url must be a string, not %r"%type(url))
135 raise TypeError("url must be a string, not %r"%type(url))
136 url = url.lower()
136 url = url.lower()
137
137
138 proto_addr = url.split('://')
138 proto_addr = url.split('://')
139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 proto, addr = proto_addr
140 proto, addr = proto_addr
141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142
142
143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 # author: Remi Sabourin
144 # author: Remi Sabourin
145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146
146
147 if proto == 'tcp':
147 if proto == 'tcp':
148 lis = addr.split(':')
148 lis = addr.split(':')
149 assert len(lis) == 2, 'Invalid url: %r'%url
149 assert len(lis) == 2, 'Invalid url: %r'%url
150 addr,s_port = lis
150 addr,s_port = lis
151 try:
151 try:
152 port = int(s_port)
152 port = int(s_port)
153 except ValueError:
153 except ValueError:
154 raise AssertionError("Invalid port %r in url: %r"%(port, url))
154 raise AssertionError("Invalid port %r in url: %r"%(port, url))
155
155
156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157
157
158 else:
158 else:
159 # only validate tcp urls currently
159 # only validate tcp urls currently
160 pass
160 pass
161
161
162 return True
162 return True
163
163
164
164
165 def validate_url_container(container):
165 def validate_url_container(container):
166 """validate a potentially nested collection of urls."""
166 """validate a potentially nested collection of urls."""
167 if isinstance(container, string_types):
167 if isinstance(container, string_types):
168 url = container
168 url = container
169 return validate_url(url)
169 return validate_url(url)
170 elif isinstance(container, dict):
170 elif isinstance(container, dict):
171 container = container.itervalues()
171 container = container.itervalues()
172
172
173 for element in container:
173 for element in container:
174 validate_url_container(element)
174 validate_url_container(element)
175
175
176
176
177 def split_url(url):
177 def split_url(url):
178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 proto_addr = url.split('://')
179 proto_addr = url.split('://')
180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 proto, addr = proto_addr
181 proto, addr = proto_addr
182 lis = addr.split(':')
182 lis = addr.split(':')
183 assert len(lis) == 2, 'Invalid url: %r'%url
183 assert len(lis) == 2, 'Invalid url: %r'%url
184 addr,s_port = lis
184 addr,s_port = lis
185 return proto,addr,s_port
185 return proto,addr,s_port
186
186
187 def disambiguate_ip_address(ip, location=None):
187 def disambiguate_ip_address(ip, location=None):
188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 ones, based on the location (default interpretation of location is localhost)."""
189 ones, based on the location (default interpretation of location is localhost)."""
190 if ip in ('0.0.0.0', '*'):
190 if ip in ('0.0.0.0', '*'):
191 if location is None or is_public_ip(location) or not public_ips():
191 if location is None or is_public_ip(location) or not public_ips():
192 # If location is unspecified or cannot be determined, assume local
192 # If location is unspecified or cannot be determined, assume local
193 ip = localhost()
193 ip = localhost()
194 elif location:
194 elif location:
195 return location
195 return location
196 return ip
196 return ip
197
197
198 def disambiguate_url(url, location=None):
198 def disambiguate_url(url, location=None):
199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 ones, based on the location (default interpretation is localhost).
200 ones, based on the location (default interpretation is localhost).
201
201
202 This is for zeromq urls, such as tcp://*:10101."""
202 This is for zeromq urls, such as tcp://*:10101."""
203 try:
203 try:
204 proto,ip,port = split_url(url)
204 proto,ip,port = split_url(url)
205 except AssertionError:
205 except AssertionError:
206 # probably not tcp url; could be ipc, etc.
206 # probably not tcp url; could be ipc, etc.
207 return url
207 return url
208
208
209 ip = disambiguate_ip_address(ip,location)
209 ip = disambiguate_ip_address(ip,location)
210
210
211 return "%s://%s:%s"%(proto,ip,port)
211 return "%s://%s:%s"%(proto,ip,port)
212
212
213
213
214 #--------------------------------------------------------------------------
214 #--------------------------------------------------------------------------
215 # helpers for implementing old MEC API via view.apply
215 # helpers for implementing old MEC API via view.apply
216 #--------------------------------------------------------------------------
216 #--------------------------------------------------------------------------
217
217
218 def interactive(f):
218 def interactive(f):
219 """decorator for making functions appear as interactively defined.
219 """decorator for making functions appear as interactively defined.
220 This results in the function being linked to the user_ns as globals()
220 This results in the function being linked to the user_ns as globals()
221 instead of the module globals().
221 instead of the module globals().
222 """
222 """
223 f.__module__ = '__main__'
223 f.__module__ = '__main__'
224 return f
224 return f
225
225
226 @interactive
226 @interactive
227 def _push(**ns):
227 def _push(**ns):
228 """helper method for implementing `client.push` via `client.apply`"""
228 """helper method for implementing `client.push` via `client.apply`"""
229 user_ns = globals()
229 user_ns = globals()
230 tmp = '_IP_PUSH_TMP_'
230 tmp = '_IP_PUSH_TMP_'
231 while tmp in user_ns:
231 while tmp in user_ns:
232 tmp = tmp + '_'
232 tmp = tmp + '_'
233 try:
233 try:
234 for name, value in ns.iteritems():
234 for name, value in ns.iteritems():
235 user_ns[tmp] = value
235 user_ns[tmp] = value
236 exec("%s = %s" % (name, tmp), user_ns)
236 exec("%s = %s" % (name, tmp), user_ns)
237 finally:
237 finally:
238 user_ns.pop(tmp, None)
238 user_ns.pop(tmp, None)
239
239
240 @interactive
240 @interactive
241 def _pull(keys):
241 def _pull(keys):
242 """helper method for implementing `client.pull` via `client.apply`"""
242 """helper method for implementing `client.pull` via `client.apply`"""
243 if isinstance(keys, (list,tuple, set)):
243 if isinstance(keys, (list,tuple, set)):
244 return map(lambda key: eval(key, globals()), keys)
244 return map(lambda key: eval(key, globals()), keys)
245 else:
245 else:
246 return eval(keys, globals())
246 return eval(keys, globals())
247
247
248 @interactive
248 @interactive
249 def _execute(code):
249 def _execute(code):
250 """helper method for implementing `client.execute` via `client.apply`"""
250 """helper method for implementing `client.execute` via `client.apply`"""
251 exec(code, globals())
251 exec(code, globals())
252
252
253 #--------------------------------------------------------------------------
253 #--------------------------------------------------------------------------
254 # extra process management utilities
254 # extra process management utilities
255 #--------------------------------------------------------------------------
255 #--------------------------------------------------------------------------
256
256
257 _random_ports = set()
257 _random_ports = set()
258
258
259 def select_random_ports(n):
259 def select_random_ports(n):
260 """Selects and return n random ports that are available."""
260 """Selects and return n random ports that are available."""
261 ports = []
261 ports = []
262 for i in xrange(n):
262 for i in range(n):
263 sock = socket.socket()
263 sock = socket.socket()
264 sock.bind(('', 0))
264 sock.bind(('', 0))
265 while sock.getsockname()[1] in _random_ports:
265 while sock.getsockname()[1] in _random_ports:
266 sock.close()
266 sock.close()
267 sock = socket.socket()
267 sock = socket.socket()
268 sock.bind(('', 0))
268 sock.bind(('', 0))
269 ports.append(sock)
269 ports.append(sock)
270 for i, sock in enumerate(ports):
270 for i, sock in enumerate(ports):
271 port = sock.getsockname()[1]
271 port = sock.getsockname()[1]
272 sock.close()
272 sock.close()
273 ports[i] = port
273 ports[i] = port
274 _random_ports.add(port)
274 _random_ports.add(port)
275 return ports
275 return ports
276
276
277 def signal_children(children):
277 def signal_children(children):
278 """Relay interupt/term signals to children, for more solid process cleanup."""
278 """Relay interupt/term signals to children, for more solid process cleanup."""
279 def terminate_children(sig, frame):
279 def terminate_children(sig, frame):
280 log = Application.instance().log
280 log = Application.instance().log
281 log.critical("Got signal %i, terminating children..."%sig)
281 log.critical("Got signal %i, terminating children..."%sig)
282 for child in children:
282 for child in children:
283 child.terminate()
283 child.terminate()
284
284
285 sys.exit(sig != SIGINT)
285 sys.exit(sig != SIGINT)
286 # sys.exit(sig)
286 # sys.exit(sig)
287 for sig in (SIGINT, SIGABRT, SIGTERM):
287 for sig in (SIGINT, SIGABRT, SIGTERM):
288 signal(sig, terminate_children)
288 signal(sig, terminate_children)
289
289
290 def generate_exec_key(keyfile):
290 def generate_exec_key(keyfile):
291 import uuid
291 import uuid
292 newkey = str(uuid.uuid4())
292 newkey = str(uuid.uuid4())
293 with open(keyfile, 'w') as f:
293 with open(keyfile, 'w') as f:
294 # f.write('ipython-key ')
294 # f.write('ipython-key ')
295 f.write(newkey+'\n')
295 f.write(newkey+'\n')
296 # set user-only RW permissions (0600)
296 # set user-only RW permissions (0600)
297 # this will have no effect on Windows
297 # this will have no effect on Windows
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299
299
300
300
301 def integer_loglevel(loglevel):
301 def integer_loglevel(loglevel):
302 try:
302 try:
303 loglevel = int(loglevel)
303 loglevel = int(loglevel)
304 except ValueError:
304 except ValueError:
305 if isinstance(loglevel, str):
305 if isinstance(loglevel, str):
306 loglevel = getattr(logging, loglevel)
306 loglevel = getattr(logging, loglevel)
307 return loglevel
307 return loglevel
308
308
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 logger = logging.getLogger(logname)
310 logger = logging.getLogger(logname)
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 # don't add a second PUBHandler
312 # don't add a second PUBHandler
313 return
313 return
314 loglevel = integer_loglevel(loglevel)
314 loglevel = integer_loglevel(loglevel)
315 lsock = context.socket(zmq.PUB)
315 lsock = context.socket(zmq.PUB)
316 lsock.connect(iface)
316 lsock.connect(iface)
317 handler = handlers.PUBHandler(lsock)
317 handler = handlers.PUBHandler(lsock)
318 handler.setLevel(loglevel)
318 handler.setLevel(loglevel)
319 handler.root_topic = root
319 handler.root_topic = root
320 logger.addHandler(handler)
320 logger.addHandler(handler)
321 logger.setLevel(loglevel)
321 logger.setLevel(loglevel)
322
322
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 logger = logging.getLogger()
324 logger = logging.getLogger()
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 # don't add a second PUBHandler
326 # don't add a second PUBHandler
327 return
327 return
328 loglevel = integer_loglevel(loglevel)
328 loglevel = integer_loglevel(loglevel)
329 lsock = context.socket(zmq.PUB)
329 lsock = context.socket(zmq.PUB)
330 lsock.connect(iface)
330 lsock.connect(iface)
331 handler = EnginePUBHandler(engine, lsock)
331 handler = EnginePUBHandler(engine, lsock)
332 handler.setLevel(loglevel)
332 handler.setLevel(loglevel)
333 logger.addHandler(handler)
333 logger.addHandler(handler)
334 logger.setLevel(loglevel)
334 logger.setLevel(loglevel)
335 return logger
335 return logger
336
336
337 def local_logger(logname, loglevel=logging.DEBUG):
337 def local_logger(logname, loglevel=logging.DEBUG):
338 loglevel = integer_loglevel(loglevel)
338 loglevel = integer_loglevel(loglevel)
339 logger = logging.getLogger(logname)
339 logger = logging.getLogger(logname)
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 # don't add a second StreamHandler
341 # don't add a second StreamHandler
342 return
342 return
343 handler = logging.StreamHandler()
343 handler = logging.StreamHandler()
344 handler.setLevel(loglevel)
344 handler.setLevel(loglevel)
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 datefmt="%Y-%m-%d %H:%M:%S")
346 datefmt="%Y-%m-%d %H:%M:%S")
347 handler.setFormatter(formatter)
347 handler.setFormatter(formatter)
348
348
349 logger.addHandler(handler)
349 logger.addHandler(handler)
350 logger.setLevel(loglevel)
350 logger.setLevel(loglevel)
351 return logger
351 return logger
352
352
353 def set_hwm(sock, hwm=0):
353 def set_hwm(sock, hwm=0):
354 """set zmq High Water Mark on a socket
354 """set zmq High Water Mark on a socket
355
355
356 in a way that always works for various pyzmq / libzmq versions.
356 in a way that always works for various pyzmq / libzmq versions.
357 """
357 """
358 import zmq
358 import zmq
359
359
360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
361 opt = getattr(zmq, key, None)
361 opt = getattr(zmq, key, None)
362 if opt is None:
362 if opt is None:
363 continue
363 continue
364 try:
364 try:
365 sock.setsockopt(opt, hwm)
365 sock.setsockopt(opt, hwm)
366 except zmq.ZMQError:
366 except zmq.ZMQError:
367 pass
367 pass
368
368
369 No newline at end of file
369
@@ -1,378 +1,378 b''
1 """ Utilities for processing ANSI escape codes and special ASCII characters.
1 """ Utilities for processing ANSI escape codes and special ASCII characters.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Imports
4 # Imports
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6
6
7 # Standard library imports
7 # Standard library imports
8 from collections import namedtuple
8 from collections import namedtuple
9 import re
9 import re
10
10
11 # System library imports
11 # System library imports
12 from IPython.external.qt import QtGui
12 from IPython.external.qt import QtGui
13
13
14 # Local imports
14 # Local imports
15 from IPython.utils.py3compat import string_types
15 from IPython.utils.py3compat import string_types
16
16
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18 # Constants and datatypes
18 # Constants and datatypes
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20
20
21 # An action for erase requests (ED and EL commands).
21 # An action for erase requests (ED and EL commands).
22 EraseAction = namedtuple('EraseAction', ['action', 'area', 'erase_to'])
22 EraseAction = namedtuple('EraseAction', ['action', 'area', 'erase_to'])
23
23
24 # An action for cursor move requests (CUU, CUD, CUF, CUB, CNL, CPL, CHA, CUP,
24 # An action for cursor move requests (CUU, CUD, CUF, CUB, CNL, CPL, CHA, CUP,
25 # and HVP commands).
25 # and HVP commands).
26 # FIXME: Not implemented in AnsiCodeProcessor.
26 # FIXME: Not implemented in AnsiCodeProcessor.
27 MoveAction = namedtuple('MoveAction', ['action', 'dir', 'unit', 'count'])
27 MoveAction = namedtuple('MoveAction', ['action', 'dir', 'unit', 'count'])
28
28
29 # An action for scroll requests (SU and ST) and form feeds.
29 # An action for scroll requests (SU and ST) and form feeds.
30 ScrollAction = namedtuple('ScrollAction', ['action', 'dir', 'unit', 'count'])
30 ScrollAction = namedtuple('ScrollAction', ['action', 'dir', 'unit', 'count'])
31
31
32 # An action for the carriage return character
32 # An action for the carriage return character
33 CarriageReturnAction = namedtuple('CarriageReturnAction', ['action'])
33 CarriageReturnAction = namedtuple('CarriageReturnAction', ['action'])
34
34
35 # An action for the \n character
35 # An action for the \n character
36 NewLineAction = namedtuple('NewLineAction', ['action'])
36 NewLineAction = namedtuple('NewLineAction', ['action'])
37
37
38 # An action for the beep character
38 # An action for the beep character
39 BeepAction = namedtuple('BeepAction', ['action'])
39 BeepAction = namedtuple('BeepAction', ['action'])
40
40
41 # An action for backspace
41 # An action for backspace
42 BackSpaceAction = namedtuple('BackSpaceAction', ['action'])
42 BackSpaceAction = namedtuple('BackSpaceAction', ['action'])
43
43
44 # Regular expressions.
44 # Regular expressions.
45 CSI_COMMANDS = 'ABCDEFGHJKSTfmnsu'
45 CSI_COMMANDS = 'ABCDEFGHJKSTfmnsu'
46 CSI_SUBPATTERN = '\[(.*?)([%s])' % CSI_COMMANDS
46 CSI_SUBPATTERN = '\[(.*?)([%s])' % CSI_COMMANDS
47 OSC_SUBPATTERN = '\](.*?)[\x07\x1b]'
47 OSC_SUBPATTERN = '\](.*?)[\x07\x1b]'
48 ANSI_PATTERN = ('\x01?\x1b(%s|%s)\x02?' % \
48 ANSI_PATTERN = ('\x01?\x1b(%s|%s)\x02?' % \
49 (CSI_SUBPATTERN, OSC_SUBPATTERN))
49 (CSI_SUBPATTERN, OSC_SUBPATTERN))
50 ANSI_OR_SPECIAL_PATTERN = re.compile('(\a|\b|\r(?!\n)|\r?\n)|(?:%s)' % ANSI_PATTERN)
50 ANSI_OR_SPECIAL_PATTERN = re.compile('(\a|\b|\r(?!\n)|\r?\n)|(?:%s)' % ANSI_PATTERN)
51 SPECIAL_PATTERN = re.compile('([\f])')
51 SPECIAL_PATTERN = re.compile('([\f])')
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Classes
54 # Classes
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 class AnsiCodeProcessor(object):
57 class AnsiCodeProcessor(object):
58 """ Translates special ASCII characters and ANSI escape codes into readable
58 """ Translates special ASCII characters and ANSI escape codes into readable
59 attributes. It also supports a few non-standard, xterm-specific codes.
59 attributes. It also supports a few non-standard, xterm-specific codes.
60 """
60 """
61
61
62 # Whether to increase intensity or set boldness for SGR code 1.
62 # Whether to increase intensity or set boldness for SGR code 1.
63 # (Different terminals handle this in different ways.)
63 # (Different terminals handle this in different ways.)
64 bold_text_enabled = False
64 bold_text_enabled = False
65
65
66 # We provide an empty default color map because subclasses will likely want
66 # We provide an empty default color map because subclasses will likely want
67 # to use a custom color format.
67 # to use a custom color format.
68 default_color_map = {}
68 default_color_map = {}
69
69
70 #---------------------------------------------------------------------------
70 #---------------------------------------------------------------------------
71 # AnsiCodeProcessor interface
71 # AnsiCodeProcessor interface
72 #---------------------------------------------------------------------------
72 #---------------------------------------------------------------------------
73
73
74 def __init__(self):
74 def __init__(self):
75 self.actions = []
75 self.actions = []
76 self.color_map = self.default_color_map.copy()
76 self.color_map = self.default_color_map.copy()
77 self.reset_sgr()
77 self.reset_sgr()
78
78
79 def reset_sgr(self):
79 def reset_sgr(self):
80 """ Reset graphics attributs to their default values.
80 """ Reset graphics attributs to their default values.
81 """
81 """
82 self.intensity = 0
82 self.intensity = 0
83 self.italic = False
83 self.italic = False
84 self.bold = False
84 self.bold = False
85 self.underline = False
85 self.underline = False
86 self.foreground_color = None
86 self.foreground_color = None
87 self.background_color = None
87 self.background_color = None
88
88
89 def split_string(self, string):
89 def split_string(self, string):
90 """ Yields substrings for which the same escape code applies.
90 """ Yields substrings for which the same escape code applies.
91 """
91 """
92 self.actions = []
92 self.actions = []
93 start = 0
93 start = 0
94
94
95 # strings ending with \r are assumed to be ending in \r\n since
95 # strings ending with \r are assumed to be ending in \r\n since
96 # \n is appended to output strings automatically. Accounting
96 # \n is appended to output strings automatically. Accounting
97 # for that, here.
97 # for that, here.
98 last_char = '\n' if len(string) > 0 and string[-1] == '\n' else None
98 last_char = '\n' if len(string) > 0 and string[-1] == '\n' else None
99 string = string[:-1] if last_char is not None else string
99 string = string[:-1] if last_char is not None else string
100
100
101 for match in ANSI_OR_SPECIAL_PATTERN.finditer(string):
101 for match in ANSI_OR_SPECIAL_PATTERN.finditer(string):
102 raw = string[start:match.start()]
102 raw = string[start:match.start()]
103 substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
103 substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
104 if substring or self.actions:
104 if substring or self.actions:
105 yield substring
105 yield substring
106 self.actions = []
106 self.actions = []
107 start = match.end()
107 start = match.end()
108
108
109 groups = filter(lambda x: x is not None, match.groups())
109 groups = filter(lambda x: x is not None, match.groups())
110 g0 = groups[0]
110 g0 = groups[0]
111 if g0 == '\a':
111 if g0 == '\a':
112 self.actions.append(BeepAction('beep'))
112 self.actions.append(BeepAction('beep'))
113 yield None
113 yield None
114 self.actions = []
114 self.actions = []
115 elif g0 == '\r':
115 elif g0 == '\r':
116 self.actions.append(CarriageReturnAction('carriage-return'))
116 self.actions.append(CarriageReturnAction('carriage-return'))
117 yield None
117 yield None
118 self.actions = []
118 self.actions = []
119 elif g0 == '\b':
119 elif g0 == '\b':
120 self.actions.append(BackSpaceAction('backspace'))
120 self.actions.append(BackSpaceAction('backspace'))
121 yield None
121 yield None
122 self.actions = []
122 self.actions = []
123 elif g0 == '\n' or g0 == '\r\n':
123 elif g0 == '\n' or g0 == '\r\n':
124 self.actions.append(NewLineAction('newline'))
124 self.actions.append(NewLineAction('newline'))
125 yield g0
125 yield g0
126 self.actions = []
126 self.actions = []
127 else:
127 else:
128 params = [ param for param in groups[1].split(';') if param ]
128 params = [ param for param in groups[1].split(';') if param ]
129 if g0.startswith('['):
129 if g0.startswith('['):
130 # Case 1: CSI code.
130 # Case 1: CSI code.
131 try:
131 try:
132 params = map(int, params)
132 params = map(int, params)
133 except ValueError:
133 except ValueError:
134 # Silently discard badly formed codes.
134 # Silently discard badly formed codes.
135 pass
135 pass
136 else:
136 else:
137 self.set_csi_code(groups[2], params)
137 self.set_csi_code(groups[2], params)
138
138
139 elif g0.startswith(']'):
139 elif g0.startswith(']'):
140 # Case 2: OSC code.
140 # Case 2: OSC code.
141 self.set_osc_code(params)
141 self.set_osc_code(params)
142
142
143 raw = string[start:]
143 raw = string[start:]
144 substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
144 substring = SPECIAL_PATTERN.sub(self._replace_special, raw)
145 if substring or self.actions:
145 if substring or self.actions:
146 yield substring
146 yield substring
147
147
148 if last_char is not None:
148 if last_char is not None:
149 self.actions.append(NewLineAction('newline'))
149 self.actions.append(NewLineAction('newline'))
150 yield last_char
150 yield last_char
151
151
152 def set_csi_code(self, command, params=[]):
152 def set_csi_code(self, command, params=[]):
153 """ Set attributes based on CSI (Control Sequence Introducer) code.
153 """ Set attributes based on CSI (Control Sequence Introducer) code.
154
154
155 Parameters
155 Parameters
156 ----------
156 ----------
157 command : str
157 command : str
158 The code identifier, i.e. the final character in the sequence.
158 The code identifier, i.e. the final character in the sequence.
159
159
160 params : sequence of integers, optional
160 params : sequence of integers, optional
161 The parameter codes for the command.
161 The parameter codes for the command.
162 """
162 """
163 if command == 'm': # SGR - Select Graphic Rendition
163 if command == 'm': # SGR - Select Graphic Rendition
164 if params:
164 if params:
165 self.set_sgr_code(params)
165 self.set_sgr_code(params)
166 else:
166 else:
167 self.set_sgr_code([0])
167 self.set_sgr_code([0])
168
168
169 elif (command == 'J' or # ED - Erase Data
169 elif (command == 'J' or # ED - Erase Data
170 command == 'K'): # EL - Erase in Line
170 command == 'K'): # EL - Erase in Line
171 code = params[0] if params else 0
171 code = params[0] if params else 0
172 if 0 <= code <= 2:
172 if 0 <= code <= 2:
173 area = 'screen' if command == 'J' else 'line'
173 area = 'screen' if command == 'J' else 'line'
174 if code == 0:
174 if code == 0:
175 erase_to = 'end'
175 erase_to = 'end'
176 elif code == 1:
176 elif code == 1:
177 erase_to = 'start'
177 erase_to = 'start'
178 elif code == 2:
178 elif code == 2:
179 erase_to = 'all'
179 erase_to = 'all'
180 self.actions.append(EraseAction('erase', area, erase_to))
180 self.actions.append(EraseAction('erase', area, erase_to))
181
181
182 elif (command == 'S' or # SU - Scroll Up
182 elif (command == 'S' or # SU - Scroll Up
183 command == 'T'): # SD - Scroll Down
183 command == 'T'): # SD - Scroll Down
184 dir = 'up' if command == 'S' else 'down'
184 dir = 'up' if command == 'S' else 'down'
185 count = params[0] if params else 1
185 count = params[0] if params else 1
186 self.actions.append(ScrollAction('scroll', dir, 'line', count))
186 self.actions.append(ScrollAction('scroll', dir, 'line', count))
187
187
188 def set_osc_code(self, params):
188 def set_osc_code(self, params):
189 """ Set attributes based on OSC (Operating System Command) parameters.
189 """ Set attributes based on OSC (Operating System Command) parameters.
190
190
191 Parameters
191 Parameters
192 ----------
192 ----------
193 params : sequence of str
193 params : sequence of str
194 The parameters for the command.
194 The parameters for the command.
195 """
195 """
196 try:
196 try:
197 command = int(params.pop(0))
197 command = int(params.pop(0))
198 except (IndexError, ValueError):
198 except (IndexError, ValueError):
199 return
199 return
200
200
201 if command == 4:
201 if command == 4:
202 # xterm-specific: set color number to color spec.
202 # xterm-specific: set color number to color spec.
203 try:
203 try:
204 color = int(params.pop(0))
204 color = int(params.pop(0))
205 spec = params.pop(0)
205 spec = params.pop(0)
206 self.color_map[color] = self._parse_xterm_color_spec(spec)
206 self.color_map[color] = self._parse_xterm_color_spec(spec)
207 except (IndexError, ValueError):
207 except (IndexError, ValueError):
208 pass
208 pass
209
209
210 def set_sgr_code(self, params):
210 def set_sgr_code(self, params):
211 """ Set attributes based on SGR (Select Graphic Rendition) codes.
211 """ Set attributes based on SGR (Select Graphic Rendition) codes.
212
212
213 Parameters
213 Parameters
214 ----------
214 ----------
215 params : sequence of ints
215 params : sequence of ints
216 A list of SGR codes for one or more SGR commands. Usually this
216 A list of SGR codes for one or more SGR commands. Usually this
217 sequence will have one element per command, although certain
217 sequence will have one element per command, although certain
218 xterm-specific commands requires multiple elements.
218 xterm-specific commands requires multiple elements.
219 """
219 """
220 # Always consume the first parameter.
220 # Always consume the first parameter.
221 if not params:
221 if not params:
222 return
222 return
223 code = params.pop(0)
223 code = params.pop(0)
224
224
225 if code == 0:
225 if code == 0:
226 self.reset_sgr()
226 self.reset_sgr()
227 elif code == 1:
227 elif code == 1:
228 if self.bold_text_enabled:
228 if self.bold_text_enabled:
229 self.bold = True
229 self.bold = True
230 else:
230 else:
231 self.intensity = 1
231 self.intensity = 1
232 elif code == 2:
232 elif code == 2:
233 self.intensity = 0
233 self.intensity = 0
234 elif code == 3:
234 elif code == 3:
235 self.italic = True
235 self.italic = True
236 elif code == 4:
236 elif code == 4:
237 self.underline = True
237 self.underline = True
238 elif code == 22:
238 elif code == 22:
239 self.intensity = 0
239 self.intensity = 0
240 self.bold = False
240 self.bold = False
241 elif code == 23:
241 elif code == 23:
242 self.italic = False
242 self.italic = False
243 elif code == 24:
243 elif code == 24:
244 self.underline = False
244 self.underline = False
245 elif code >= 30 and code <= 37:
245 elif code >= 30 and code <= 37:
246 self.foreground_color = code - 30
246 self.foreground_color = code - 30
247 elif code == 38 and params and params.pop(0) == 5:
247 elif code == 38 and params and params.pop(0) == 5:
248 # xterm-specific: 256 color support.
248 # xterm-specific: 256 color support.
249 if params:
249 if params:
250 self.foreground_color = params.pop(0)
250 self.foreground_color = params.pop(0)
251 elif code == 39:
251 elif code == 39:
252 self.foreground_color = None
252 self.foreground_color = None
253 elif code >= 40 and code <= 47:
253 elif code >= 40 and code <= 47:
254 self.background_color = code - 40
254 self.background_color = code - 40
255 elif code == 48 and params and params.pop(0) == 5:
255 elif code == 48 and params and params.pop(0) == 5:
256 # xterm-specific: 256 color support.
256 # xterm-specific: 256 color support.
257 if params:
257 if params:
258 self.background_color = params.pop(0)
258 self.background_color = params.pop(0)
259 elif code == 49:
259 elif code == 49:
260 self.background_color = None
260 self.background_color = None
261
261
262 # Recurse with unconsumed parameters.
262 # Recurse with unconsumed parameters.
263 self.set_sgr_code(params)
263 self.set_sgr_code(params)
264
264
265 #---------------------------------------------------------------------------
265 #---------------------------------------------------------------------------
266 # Protected interface
266 # Protected interface
267 #---------------------------------------------------------------------------
267 #---------------------------------------------------------------------------
268
268
269 def _parse_xterm_color_spec(self, spec):
269 def _parse_xterm_color_spec(self, spec):
270 if spec.startswith('rgb:'):
270 if spec.startswith('rgb:'):
271 return tuple(map(lambda x: int(x, 16), spec[4:].split('/')))
271 return tuple(map(lambda x: int(x, 16), spec[4:].split('/')))
272 elif spec.startswith('rgbi:'):
272 elif spec.startswith('rgbi:'):
273 return tuple(map(lambda x: int(float(x) * 255),
273 return tuple(map(lambda x: int(float(x) * 255),
274 spec[5:].split('/')))
274 spec[5:].split('/')))
275 elif spec == '?':
275 elif spec == '?':
276 raise ValueError('Unsupported xterm color spec')
276 raise ValueError('Unsupported xterm color spec')
277 return spec
277 return spec
278
278
279 def _replace_special(self, match):
279 def _replace_special(self, match):
280 special = match.group(1)
280 special = match.group(1)
281 if special == '\f':
281 if special == '\f':
282 self.actions.append(ScrollAction('scroll', 'down', 'page', 1))
282 self.actions.append(ScrollAction('scroll', 'down', 'page', 1))
283 return ''
283 return ''
284
284
285
285
286 class QtAnsiCodeProcessor(AnsiCodeProcessor):
286 class QtAnsiCodeProcessor(AnsiCodeProcessor):
287 """ Translates ANSI escape codes into QTextCharFormats.
287 """ Translates ANSI escape codes into QTextCharFormats.
288 """
288 """
289
289
290 # A map from ANSI color codes to SVG color names or RGB(A) tuples.
290 # A map from ANSI color codes to SVG color names or RGB(A) tuples.
291 darkbg_color_map = {
291 darkbg_color_map = {
292 0 : 'black', # black
292 0 : 'black', # black
293 1 : 'darkred', # red
293 1 : 'darkred', # red
294 2 : 'darkgreen', # green
294 2 : 'darkgreen', # green
295 3 : 'brown', # yellow
295 3 : 'brown', # yellow
296 4 : 'darkblue', # blue
296 4 : 'darkblue', # blue
297 5 : 'darkviolet', # magenta
297 5 : 'darkviolet', # magenta
298 6 : 'steelblue', # cyan
298 6 : 'steelblue', # cyan
299 7 : 'grey', # white
299 7 : 'grey', # white
300 8 : 'grey', # black (bright)
300 8 : 'grey', # black (bright)
301 9 : 'red', # red (bright)
301 9 : 'red', # red (bright)
302 10 : 'lime', # green (bright)
302 10 : 'lime', # green (bright)
303 11 : 'yellow', # yellow (bright)
303 11 : 'yellow', # yellow (bright)
304 12 : 'deepskyblue', # blue (bright)
304 12 : 'deepskyblue', # blue (bright)
305 13 : 'magenta', # magenta (bright)
305 13 : 'magenta', # magenta (bright)
306 14 : 'cyan', # cyan (bright)
306 14 : 'cyan', # cyan (bright)
307 15 : 'white' } # white (bright)
307 15 : 'white' } # white (bright)
308
308
309 # Set the default color map for super class.
309 # Set the default color map for super class.
310 default_color_map = darkbg_color_map.copy()
310 default_color_map = darkbg_color_map.copy()
311
311
312 def get_color(self, color, intensity=0):
312 def get_color(self, color, intensity=0):
313 """ Returns a QColor for a given color code, or None if one cannot be
313 """ Returns a QColor for a given color code, or None if one cannot be
314 constructed.
314 constructed.
315 """
315 """
316 if color is None:
316 if color is None:
317 return None
317 return None
318
318
319 # Adjust for intensity, if possible.
319 # Adjust for intensity, if possible.
320 if color < 8 and intensity > 0:
320 if color < 8 and intensity > 0:
321 color += 8
321 color += 8
322
322
323 constructor = self.color_map.get(color, None)
323 constructor = self.color_map.get(color, None)
324 if isinstance(constructor, string_types):
324 if isinstance(constructor, string_types):
325 # If this is an X11 color name, we just hope there is a close SVG
325 # If this is an X11 color name, we just hope there is a close SVG
326 # color name. We could use QColor's static method
326 # color name. We could use QColor's static method
327 # 'setAllowX11ColorNames()', but this is global and only available
327 # 'setAllowX11ColorNames()', but this is global and only available
328 # on X11. It seems cleaner to aim for uniformity of behavior.
328 # on X11. It seems cleaner to aim for uniformity of behavior.
329 return QtGui.QColor(constructor)
329 return QtGui.QColor(constructor)
330
330
331 elif isinstance(constructor, (tuple, list)):
331 elif isinstance(constructor, (tuple, list)):
332 return QtGui.QColor(*constructor)
332 return QtGui.QColor(*constructor)
333
333
334 return None
334 return None
335
335
336 def get_format(self):
336 def get_format(self):
337 """ Returns a QTextCharFormat that encodes the current style attributes.
337 """ Returns a QTextCharFormat that encodes the current style attributes.
338 """
338 """
339 format = QtGui.QTextCharFormat()
339 format = QtGui.QTextCharFormat()
340
340
341 # Set foreground color
341 # Set foreground color
342 qcolor = self.get_color(self.foreground_color, self.intensity)
342 qcolor = self.get_color(self.foreground_color, self.intensity)
343 if qcolor is not None:
343 if qcolor is not None:
344 format.setForeground(qcolor)
344 format.setForeground(qcolor)
345
345
346 # Set background color
346 # Set background color
347 qcolor = self.get_color(self.background_color, self.intensity)
347 qcolor = self.get_color(self.background_color, self.intensity)
348 if qcolor is not None:
348 if qcolor is not None:
349 format.setBackground(qcolor)
349 format.setBackground(qcolor)
350
350
351 # Set font weight/style options
351 # Set font weight/style options
352 if self.bold:
352 if self.bold:
353 format.setFontWeight(QtGui.QFont.Bold)
353 format.setFontWeight(QtGui.QFont.Bold)
354 else:
354 else:
355 format.setFontWeight(QtGui.QFont.Normal)
355 format.setFontWeight(QtGui.QFont.Normal)
356 format.setFontItalic(self.italic)
356 format.setFontItalic(self.italic)
357 format.setFontUnderline(self.underline)
357 format.setFontUnderline(self.underline)
358
358
359 return format
359 return format
360
360
361 def set_background_color(self, color):
361 def set_background_color(self, color):
362 """ Given a background color (a QColor), attempt to set a color map
362 """ Given a background color (a QColor), attempt to set a color map
363 that will be aesthetically pleasing.
363 that will be aesthetically pleasing.
364 """
364 """
365 # Set a new default color map.
365 # Set a new default color map.
366 self.default_color_map = self.darkbg_color_map.copy()
366 self.default_color_map = self.darkbg_color_map.copy()
367
367
368 if color.value() >= 127:
368 if color.value() >= 127:
369 # Colors appropriate for a terminal with a light background. For
369 # Colors appropriate for a terminal with a light background. For
370 # now, only use non-bright colors...
370 # now, only use non-bright colors...
371 for i in xrange(8):
371 for i in range(8):
372 self.default_color_map[i + 8] = self.default_color_map[i]
372 self.default_color_map[i + 8] = self.default_color_map[i]
373
373
374 # ...and replace white with black.
374 # ...and replace white with black.
375 self.default_color_map[7] = self.default_color_map[15] = 'black'
375 self.default_color_map[7] = self.default_color_map[15] = 'black'
376
376
377 # Update the current color map with the new defaults.
377 # Update the current color map with the new defaults.
378 self.color_map.update(self.default_color_map)
378 self.color_map.update(self.default_color_map)
@@ -1,35 +1,37 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Utilities for working with data structures like lists, dicts and tuples.
2 """Utilities for working with data structures like lists, dicts and tuples.
3 """
3 """
4
4
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (C) 2008-2011 The IPython Development Team
6 # Copyright (C) 2008-2011 The IPython Development Team
7 #
7 #
8 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
9 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12 from .py3compat import xrange
13
12 def uniq_stable(elems):
14 def uniq_stable(elems):
13 """uniq_stable(elems) -> list
15 """uniq_stable(elems) -> list
14
16
15 Return from an iterable, a list of all the unique elements in the input,
17 Return from an iterable, a list of all the unique elements in the input,
16 but maintaining the order in which they first appear.
18 but maintaining the order in which they first appear.
17
19
18 Note: All elements in the input must be hashable for this routine
20 Note: All elements in the input must be hashable for this routine
19 to work, as it internally uses a set for efficiency reasons.
21 to work, as it internally uses a set for efficiency reasons.
20 """
22 """
21 seen = set()
23 seen = set()
22 return [x for x in elems if x not in seen and not seen.add(x)]
24 return [x for x in elems if x not in seen and not seen.add(x)]
23
25
24
26
25 def flatten(seq):
27 def flatten(seq):
26 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
28 """Flatten a list of lists (NOT recursive, only works for 2d lists)."""
27
29
28 return [x for subseq in seq for x in subseq]
30 return [x for subseq in seq for x in subseq]
29
31
30
32
31 def chop(seq, size):
33 def chop(seq, size):
32 """Chop a sequence into chunks of the given size."""
34 """Chop a sequence into chunks of the given size."""
33 return [seq[i:i+size] for i in xrange(0,len(seq),size)]
35 return [seq[i:i+size] for i in xrange(0,len(seq),size)]
34
36
35
37
@@ -1,207 +1,210 b''
1 # coding: utf-8
1 # coding: utf-8
2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
2 """Compatibility tricks for Python 3. Mainly to do with unicode."""
3 import functools
3 import functools
4 import sys
4 import sys
5 import re
5 import re
6 import types
6 import types
7
7
8 from .encoding import DEFAULT_ENCODING
8 from .encoding import DEFAULT_ENCODING
9
9
10 orig_open = open
10 orig_open = open
11
11
12 def no_code(x, encoding=None):
12 def no_code(x, encoding=None):
13 return x
13 return x
14
14
15 def decode(s, encoding=None):
15 def decode(s, encoding=None):
16 encoding = encoding or DEFAULT_ENCODING
16 encoding = encoding or DEFAULT_ENCODING
17 return s.decode(encoding, "replace")
17 return s.decode(encoding, "replace")
18
18
19 def encode(u, encoding=None):
19 def encode(u, encoding=None):
20 encoding = encoding or DEFAULT_ENCODING
20 encoding = encoding or DEFAULT_ENCODING
21 return u.encode(encoding, "replace")
21 return u.encode(encoding, "replace")
22
22
23
23
24 def cast_unicode(s, encoding=None):
24 def cast_unicode(s, encoding=None):
25 if isinstance(s, bytes):
25 if isinstance(s, bytes):
26 return decode(s, encoding)
26 return decode(s, encoding)
27 return s
27 return s
28
28
29 def cast_bytes(s, encoding=None):
29 def cast_bytes(s, encoding=None):
30 if not isinstance(s, bytes):
30 if not isinstance(s, bytes):
31 return encode(s, encoding)
31 return encode(s, encoding)
32 return s
32 return s
33
33
34 def _modify_str_or_docstring(str_change_func):
34 def _modify_str_or_docstring(str_change_func):
35 @functools.wraps(str_change_func)
35 @functools.wraps(str_change_func)
36 def wrapper(func_or_str):
36 def wrapper(func_or_str):
37 if isinstance(func_or_str, string_types):
37 if isinstance(func_or_str, string_types):
38 func = None
38 func = None
39 doc = func_or_str
39 doc = func_or_str
40 else:
40 else:
41 func = func_or_str
41 func = func_or_str
42 doc = func.__doc__
42 doc = func.__doc__
43
43
44 doc = str_change_func(doc)
44 doc = str_change_func(doc)
45
45
46 if func:
46 if func:
47 func.__doc__ = doc
47 func.__doc__ = doc
48 return func
48 return func
49 return doc
49 return doc
50 return wrapper
50 return wrapper
51
51
52 def safe_unicode(e):
52 def safe_unicode(e):
53 """unicode(e) with various fallbacks. Used for exceptions, which may not be
53 """unicode(e) with various fallbacks. Used for exceptions, which may not be
54 safe to call unicode() on.
54 safe to call unicode() on.
55 """
55 """
56 try:
56 try:
57 return unicode_type(e)
57 return unicode_type(e)
58 except UnicodeError:
58 except UnicodeError:
59 pass
59 pass
60
60
61 try:
61 try:
62 return str_to_unicode(str(e))
62 return str_to_unicode(str(e))
63 except UnicodeError:
63 except UnicodeError:
64 pass
64 pass
65
65
66 try:
66 try:
67 return str_to_unicode(repr(e))
67 return str_to_unicode(repr(e))
68 except UnicodeError:
68 except UnicodeError:
69 pass
69 pass
70
70
71 return u'Unrecoverably corrupt evalue'
71 return u'Unrecoverably corrupt evalue'
72
72
73 if sys.version_info[0] >= 3:
73 if sys.version_info[0] >= 3:
74 PY3 = True
74 PY3 = True
75
75
76 input = input
76 input = input
77 builtin_mod_name = "builtins"
77 builtin_mod_name = "builtins"
78 import builtins as builtin_mod
78 import builtins as builtin_mod
79
79
80 str_to_unicode = no_code
80 str_to_unicode = no_code
81 unicode_to_str = no_code
81 unicode_to_str = no_code
82 str_to_bytes = encode
82 str_to_bytes = encode
83 bytes_to_str = decode
83 bytes_to_str = decode
84 cast_bytes_py2 = no_code
84 cast_bytes_py2 = no_code
85
85
86 string_types = (str,)
86 string_types = (str,)
87 unicode_type = str
87 unicode_type = str
88
88
89 def isidentifier(s, dotted=False):
89 def isidentifier(s, dotted=False):
90 if dotted:
90 if dotted:
91 return all(isidentifier(a) for a in s.split("."))
91 return all(isidentifier(a) for a in s.split("."))
92 return s.isidentifier()
92 return s.isidentifier()
93
93
94 open = orig_open
94 open = orig_open
95 xrange = range
95
96
96 MethodType = types.MethodType
97 MethodType = types.MethodType
97
98
98 def execfile(fname, glob, loc=None):
99 def execfile(fname, glob, loc=None):
99 loc = loc if (loc is not None) else glob
100 loc = loc if (loc is not None) else glob
100 with open(fname, 'rb') as f:
101 with open(fname, 'rb') as f:
101 exec(compile(f.read(), fname, 'exec'), glob, loc)
102 exec(compile(f.read(), fname, 'exec'), glob, loc)
102
103
103 # Refactor print statements in doctests.
104 # Refactor print statements in doctests.
104 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
105 _print_statement_re = re.compile(r"\bprint (?P<expr>.*)$", re.MULTILINE)
105 def _print_statement_sub(match):
106 def _print_statement_sub(match):
106 expr = match.groups('expr')
107 expr = match.groups('expr')
107 return "print(%s)" % expr
108 return "print(%s)" % expr
108
109
109 @_modify_str_or_docstring
110 @_modify_str_or_docstring
110 def doctest_refactor_print(doc):
111 def doctest_refactor_print(doc):
111 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
112 """Refactor 'print x' statements in a doctest to print(x) style. 2to3
112 unfortunately doesn't pick up on our doctests.
113 unfortunately doesn't pick up on our doctests.
113
114
114 Can accept a string or a function, so it can be used as a decorator."""
115 Can accept a string or a function, so it can be used as a decorator."""
115 return _print_statement_re.sub(_print_statement_sub, doc)
116 return _print_statement_re.sub(_print_statement_sub, doc)
116
117
117 # Abstract u'abc' syntax:
118 # Abstract u'abc' syntax:
118 @_modify_str_or_docstring
119 @_modify_str_or_docstring
119 def u_format(s):
120 def u_format(s):
120 """"{u}'abc'" --> "'abc'" (Python 3)
121 """"{u}'abc'" --> "'abc'" (Python 3)
121
122
122 Accepts a string or a function, so it can be used as a decorator."""
123 Accepts a string or a function, so it can be used as a decorator."""
123 return s.format(u='')
124 return s.format(u='')
124
125
125 else:
126 else:
126 PY3 = False
127 PY3 = False
127
128
128 input = raw_input
129 input = raw_input
129 builtin_mod_name = "__builtin__"
130 builtin_mod_name = "__builtin__"
130 import __builtin__ as builtin_mod
131 import __builtin__ as builtin_mod
131
132
132 str_to_unicode = decode
133 str_to_unicode = decode
133 unicode_to_str = encode
134 unicode_to_str = encode
134 str_to_bytes = no_code
135 str_to_bytes = no_code
135 bytes_to_str = no_code
136 bytes_to_str = no_code
136 cast_bytes_py2 = cast_bytes
137 cast_bytes_py2 = cast_bytes
137
138
138 string_types = (str, unicode)
139 string_types = (str, unicode)
139 unicode_type = unicode
140 unicode_type = unicode
140
141
141 import re
142 import re
142 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
143 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
143 def isidentifier(s, dotted=False):
144 def isidentifier(s, dotted=False):
144 if dotted:
145 if dotted:
145 return all(isidentifier(a) for a in s.split("."))
146 return all(isidentifier(a) for a in s.split("."))
146 return bool(_name_re.match(s))
147 return bool(_name_re.match(s))
147
148
148 class open(object):
149 class open(object):
149 """Wrapper providing key part of Python 3 open() interface."""
150 """Wrapper providing key part of Python 3 open() interface."""
150 def __init__(self, fname, mode="r", encoding="utf-8"):
151 def __init__(self, fname, mode="r", encoding="utf-8"):
151 self.f = orig_open(fname, mode)
152 self.f = orig_open(fname, mode)
152 self.enc = encoding
153 self.enc = encoding
153
154
154 def write(self, s):
155 def write(self, s):
155 return self.f.write(s.encode(self.enc))
156 return self.f.write(s.encode(self.enc))
156
157
157 def read(self, size=-1):
158 def read(self, size=-1):
158 return self.f.read(size).decode(self.enc)
159 return self.f.read(size).decode(self.enc)
159
160
160 def close(self):
161 def close(self):
161 return self.f.close()
162 return self.f.close()
162
163
163 def __enter__(self):
164 def __enter__(self):
164 return self
165 return self
165
166
166 def __exit__(self, etype, value, traceback):
167 def __exit__(self, etype, value, traceback):
167 self.f.close()
168 self.f.close()
168
169
170 xrange = xrange
171
169 def MethodType(func, instance):
172 def MethodType(func, instance):
170 return types.MethodType(func, instance, type(instance))
173 return types.MethodType(func, instance, type(instance))
171
174
172 # don't override system execfile on 2.x:
175 # don't override system execfile on 2.x:
173 execfile = execfile
176 execfile = execfile
174
177
175 def doctest_refactor_print(func_or_str):
178 def doctest_refactor_print(func_or_str):
176 return func_or_str
179 return func_or_str
177
180
178
181
179 # Abstract u'abc' syntax:
182 # Abstract u'abc' syntax:
180 @_modify_str_or_docstring
183 @_modify_str_or_docstring
181 def u_format(s):
184 def u_format(s):
182 """"{u}'abc'" --> "u'abc'" (Python 2)
185 """"{u}'abc'" --> "u'abc'" (Python 2)
183
186
184 Accepts a string or a function, so it can be used as a decorator."""
187 Accepts a string or a function, so it can be used as a decorator."""
185 return s.format(u='u')
188 return s.format(u='u')
186
189
187 if sys.platform == 'win32':
190 if sys.platform == 'win32':
188 def execfile(fname, glob=None, loc=None):
191 def execfile(fname, glob=None, loc=None):
189 loc = loc if (loc is not None) else glob
192 loc = loc if (loc is not None) else glob
190 # The rstrip() is necessary b/c trailing whitespace in files will
193 # The rstrip() is necessary b/c trailing whitespace in files will
191 # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
194 # cause an IndentationError in Python 2.6 (this was fixed in 2.7,
192 # but we still support 2.6). See issue 1027.
195 # but we still support 2.6). See issue 1027.
193 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
196 scripttext = builtin_mod.open(fname).read().rstrip() + '\n'
194 # compile converts unicode filename to str assuming
197 # compile converts unicode filename to str assuming
195 # ascii. Let's do the conversion before calling compile
198 # ascii. Let's do the conversion before calling compile
196 if isinstance(fname, unicode):
199 if isinstance(fname, unicode):
197 filename = unicode_to_str(fname)
200 filename = unicode_to_str(fname)
198 else:
201 else:
199 filename = fname
202 filename = fname
200 exec(compile(scripttext, filename, 'exec'), glob, loc)
203 exec(compile(scripttext, filename, 'exec'), glob, loc)
201 else:
204 else:
202 def execfile(fname, *where):
205 def execfile(fname, *where):
203 if isinstance(fname, unicode):
206 if isinstance(fname, unicode):
204 filename = fname.encode(sys.getfilesystemencoding())
207 filename = fname.encode(sys.getfilesystemencoding())
205 else:
208 else:
206 filename = fname
209 filename = fname
207 builtin_mod.execfile(filename, *where)
210 builtin_mod.execfile(filename, *where)
@@ -1,759 +1,758 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for working with strings and text.
3 Utilities for working with strings and text.
4
4
5 Inheritance diagram:
5 Inheritance diagram:
6
6
7 .. inheritance-diagram:: IPython.utils.text
7 .. inheritance-diagram:: IPython.utils.text
8 :parts: 3
8 :parts: 3
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import os
22 import os
23 import re
23 import re
24 import sys
24 import sys
25 import textwrap
25 import textwrap
26 from string import Formatter
26 from string import Formatter
27
27
28 from IPython.external.path import path
28 from IPython.external.path import path
29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
29 from IPython.testing.skipdoctest import skip_doctest_py3, skip_doctest
30 from IPython.utils import py3compat
30 from IPython.utils import py3compat
31
31
32
33 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
34 # Declarations
33 # Declarations
35 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
36
35
37 # datetime.strftime date format for ipython
36 # datetime.strftime date format for ipython
38 if sys.platform == 'win32':
37 if sys.platform == 'win32':
39 date_format = "%B %d, %Y"
38 date_format = "%B %d, %Y"
40 else:
39 else:
41 date_format = "%B %-d, %Y"
40 date_format = "%B %-d, %Y"
42
41
43
42
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45 # Code
44 # Code
46 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
47
46
48 class LSString(str):
47 class LSString(str):
49 """String derivative with a special access attributes.
48 """String derivative with a special access attributes.
50
49
51 These are normal strings, but with the special attributes:
50 These are normal strings, but with the special attributes:
52
51
53 .l (or .list) : value as list (split on newlines).
52 .l (or .list) : value as list (split on newlines).
54 .n (or .nlstr): original value (the string itself).
53 .n (or .nlstr): original value (the string itself).
55 .s (or .spstr): value as whitespace-separated string.
54 .s (or .spstr): value as whitespace-separated string.
56 .p (or .paths): list of path objects
55 .p (or .paths): list of path objects
57
56
58 Any values which require transformations are computed only once and
57 Any values which require transformations are computed only once and
59 cached.
58 cached.
60
59
61 Such strings are very useful to efficiently interact with the shell, which
60 Such strings are very useful to efficiently interact with the shell, which
62 typically only understands whitespace-separated options for commands."""
61 typically only understands whitespace-separated options for commands."""
63
62
64 def get_list(self):
63 def get_list(self):
65 try:
64 try:
66 return self.__list
65 return self.__list
67 except AttributeError:
66 except AttributeError:
68 self.__list = self.split('\n')
67 self.__list = self.split('\n')
69 return self.__list
68 return self.__list
70
69
71 l = list = property(get_list)
70 l = list = property(get_list)
72
71
73 def get_spstr(self):
72 def get_spstr(self):
74 try:
73 try:
75 return self.__spstr
74 return self.__spstr
76 except AttributeError:
75 except AttributeError:
77 self.__spstr = self.replace('\n',' ')
76 self.__spstr = self.replace('\n',' ')
78 return self.__spstr
77 return self.__spstr
79
78
80 s = spstr = property(get_spstr)
79 s = spstr = property(get_spstr)
81
80
82 def get_nlstr(self):
81 def get_nlstr(self):
83 return self
82 return self
84
83
85 n = nlstr = property(get_nlstr)
84 n = nlstr = property(get_nlstr)
86
85
87 def get_paths(self):
86 def get_paths(self):
88 try:
87 try:
89 return self.__paths
88 return self.__paths
90 except AttributeError:
89 except AttributeError:
91 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
90 self.__paths = [path(p) for p in self.split('\n') if os.path.exists(p)]
92 return self.__paths
91 return self.__paths
93
92
94 p = paths = property(get_paths)
93 p = paths = property(get_paths)
95
94
96 # FIXME: We need to reimplement type specific displayhook and then add this
95 # FIXME: We need to reimplement type specific displayhook and then add this
97 # back as a custom printer. This should also be moved outside utils into the
96 # back as a custom printer. This should also be moved outside utils into the
98 # core.
97 # core.
99
98
100 # def print_lsstring(arg):
99 # def print_lsstring(arg):
101 # """ Prettier (non-repr-like) and more informative printer for LSString """
100 # """ Prettier (non-repr-like) and more informative printer for LSString """
102 # print "LSString (.p, .n, .l, .s available). Value:"
101 # print "LSString (.p, .n, .l, .s available). Value:"
103 # print arg
102 # print arg
104 #
103 #
105 #
104 #
106 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
105 # print_lsstring = result_display.when_type(LSString)(print_lsstring)
107
106
108
107
109 class SList(list):
108 class SList(list):
110 """List derivative with a special access attributes.
109 """List derivative with a special access attributes.
111
110
112 These are normal lists, but with the special attributes:
111 These are normal lists, but with the special attributes:
113
112
114 .l (or .list) : value as list (the list itself).
113 .l (or .list) : value as list (the list itself).
115 .n (or .nlstr): value as a string, joined on newlines.
114 .n (or .nlstr): value as a string, joined on newlines.
116 .s (or .spstr): value as a string, joined on spaces.
115 .s (or .spstr): value as a string, joined on spaces.
117 .p (or .paths): list of path objects
116 .p (or .paths): list of path objects
118
117
119 Any values which require transformations are computed only once and
118 Any values which require transformations are computed only once and
120 cached."""
119 cached."""
121
120
122 def get_list(self):
121 def get_list(self):
123 return self
122 return self
124
123
125 l = list = property(get_list)
124 l = list = property(get_list)
126
125
127 def get_spstr(self):
126 def get_spstr(self):
128 try:
127 try:
129 return self.__spstr
128 return self.__spstr
130 except AttributeError:
129 except AttributeError:
131 self.__spstr = ' '.join(self)
130 self.__spstr = ' '.join(self)
132 return self.__spstr
131 return self.__spstr
133
132
134 s = spstr = property(get_spstr)
133 s = spstr = property(get_spstr)
135
134
136 def get_nlstr(self):
135 def get_nlstr(self):
137 try:
136 try:
138 return self.__nlstr
137 return self.__nlstr
139 except AttributeError:
138 except AttributeError:
140 self.__nlstr = '\n'.join(self)
139 self.__nlstr = '\n'.join(self)
141 return self.__nlstr
140 return self.__nlstr
142
141
143 n = nlstr = property(get_nlstr)
142 n = nlstr = property(get_nlstr)
144
143
145 def get_paths(self):
144 def get_paths(self):
146 try:
145 try:
147 return self.__paths
146 return self.__paths
148 except AttributeError:
147 except AttributeError:
149 self.__paths = [path(p) for p in self if os.path.exists(p)]
148 self.__paths = [path(p) for p in self if os.path.exists(p)]
150 return self.__paths
149 return self.__paths
151
150
152 p = paths = property(get_paths)
151 p = paths = property(get_paths)
153
152
154 def grep(self, pattern, prune = False, field = None):
153 def grep(self, pattern, prune = False, field = None):
155 """ Return all strings matching 'pattern' (a regex or callable)
154 """ Return all strings matching 'pattern' (a regex or callable)
156
155
157 This is case-insensitive. If prune is true, return all items
156 This is case-insensitive. If prune is true, return all items
158 NOT matching the pattern.
157 NOT matching the pattern.
159
158
160 If field is specified, the match must occur in the specified
159 If field is specified, the match must occur in the specified
161 whitespace-separated field.
160 whitespace-separated field.
162
161
163 Examples::
162 Examples::
164
163
165 a.grep( lambda x: x.startswith('C') )
164 a.grep( lambda x: x.startswith('C') )
166 a.grep('Cha.*log', prune=1)
165 a.grep('Cha.*log', prune=1)
167 a.grep('chm', field=-1)
166 a.grep('chm', field=-1)
168 """
167 """
169
168
170 def match_target(s):
169 def match_target(s):
171 if field is None:
170 if field is None:
172 return s
171 return s
173 parts = s.split()
172 parts = s.split()
174 try:
173 try:
175 tgt = parts[field]
174 tgt = parts[field]
176 return tgt
175 return tgt
177 except IndexError:
176 except IndexError:
178 return ""
177 return ""
179
178
180 if isinstance(pattern, py3compat.string_types):
179 if isinstance(pattern, py3compat.string_types):
181 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
180 pred = lambda x : re.search(pattern, x, re.IGNORECASE)
182 else:
181 else:
183 pred = pattern
182 pred = pattern
184 if not prune:
183 if not prune:
185 return SList([el for el in self if pred(match_target(el))])
184 return SList([el for el in self if pred(match_target(el))])
186 else:
185 else:
187 return SList([el for el in self if not pred(match_target(el))])
186 return SList([el for el in self if not pred(match_target(el))])
188
187
189 def fields(self, *fields):
188 def fields(self, *fields):
190 """ Collect whitespace-separated fields from string list
189 """ Collect whitespace-separated fields from string list
191
190
192 Allows quick awk-like usage of string lists.
191 Allows quick awk-like usage of string lists.
193
192
194 Example data (in var a, created by 'a = !ls -l')::
193 Example data (in var a, created by 'a = !ls -l')::
195 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
194 -rwxrwxrwx 1 ville None 18 Dec 14 2006 ChangeLog
196 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
195 drwxrwxrwx+ 6 ville None 0 Oct 24 18:05 IPython
197
196
198 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
197 a.fields(0) is ['-rwxrwxrwx', 'drwxrwxrwx+']
199 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
198 a.fields(1,0) is ['1 -rwxrwxrwx', '6 drwxrwxrwx+']
200 (note the joining by space).
199 (note the joining by space).
201 a.fields(-1) is ['ChangeLog', 'IPython']
200 a.fields(-1) is ['ChangeLog', 'IPython']
202
201
203 IndexErrors are ignored.
202 IndexErrors are ignored.
204
203
205 Without args, fields() just split()'s the strings.
204 Without args, fields() just split()'s the strings.
206 """
205 """
207 if len(fields) == 0:
206 if len(fields) == 0:
208 return [el.split() for el in self]
207 return [el.split() for el in self]
209
208
210 res = SList()
209 res = SList()
211 for el in [f.split() for f in self]:
210 for el in [f.split() for f in self]:
212 lineparts = []
211 lineparts = []
213
212
214 for fd in fields:
213 for fd in fields:
215 try:
214 try:
216 lineparts.append(el[fd])
215 lineparts.append(el[fd])
217 except IndexError:
216 except IndexError:
218 pass
217 pass
219 if lineparts:
218 if lineparts:
220 res.append(" ".join(lineparts))
219 res.append(" ".join(lineparts))
221
220
222 return res
221 return res
223
222
224 def sort(self,field= None, nums = False):
223 def sort(self,field= None, nums = False):
225 """ sort by specified fields (see fields())
224 """ sort by specified fields (see fields())
226
225
227 Example::
226 Example::
228 a.sort(1, nums = True)
227 a.sort(1, nums = True)
229
228
230 Sorts a by second field, in numerical order (so that 21 > 3)
229 Sorts a by second field, in numerical order (so that 21 > 3)
231
230
232 """
231 """
233
232
234 #decorate, sort, undecorate
233 #decorate, sort, undecorate
235 if field is not None:
234 if field is not None:
236 dsu = [[SList([line]).fields(field), line] for line in self]
235 dsu = [[SList([line]).fields(field), line] for line in self]
237 else:
236 else:
238 dsu = [[line, line] for line in self]
237 dsu = [[line, line] for line in self]
239 if nums:
238 if nums:
240 for i in range(len(dsu)):
239 for i in range(len(dsu)):
241 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
240 numstr = "".join([ch for ch in dsu[i][0] if ch.isdigit()])
242 try:
241 try:
243 n = int(numstr)
242 n = int(numstr)
244 except ValueError:
243 except ValueError:
245 n = 0;
244 n = 0;
246 dsu[i][0] = n
245 dsu[i][0] = n
247
246
248
247
249 dsu.sort()
248 dsu.sort()
250 return SList([t[1] for t in dsu])
249 return SList([t[1] for t in dsu])
251
250
252
251
253 # FIXME: We need to reimplement type specific displayhook and then add this
252 # FIXME: We need to reimplement type specific displayhook and then add this
254 # back as a custom printer. This should also be moved outside utils into the
253 # back as a custom printer. This should also be moved outside utils into the
255 # core.
254 # core.
256
255
257 # def print_slist(arg):
256 # def print_slist(arg):
258 # """ Prettier (non-repr-like) and more informative printer for SList """
257 # """ Prettier (non-repr-like) and more informative printer for SList """
259 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
258 # print "SList (.p, .n, .l, .s, .grep(), .fields(), sort() available):"
260 # if hasattr(arg, 'hideonce') and arg.hideonce:
259 # if hasattr(arg, 'hideonce') and arg.hideonce:
261 # arg.hideonce = False
260 # arg.hideonce = False
262 # return
261 # return
263 #
262 #
264 # nlprint(arg) # This was a nested list printer, now removed.
263 # nlprint(arg) # This was a nested list printer, now removed.
265 #
264 #
266 # print_slist = result_display.when_type(SList)(print_slist)
265 # print_slist = result_display.when_type(SList)(print_slist)
267
266
268
267
269 def indent(instr,nspaces=4, ntabs=0, flatten=False):
268 def indent(instr,nspaces=4, ntabs=0, flatten=False):
270 """Indent a string a given number of spaces or tabstops.
269 """Indent a string a given number of spaces or tabstops.
271
270
272 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
271 indent(str,nspaces=4,ntabs=0) -> indent str by ntabs+nspaces.
273
272
274 Parameters
273 Parameters
275 ----------
274 ----------
276
275
277 instr : basestring
276 instr : basestring
278 The string to be indented.
277 The string to be indented.
279 nspaces : int (default: 4)
278 nspaces : int (default: 4)
280 The number of spaces to be indented.
279 The number of spaces to be indented.
281 ntabs : int (default: 0)
280 ntabs : int (default: 0)
282 The number of tabs to be indented.
281 The number of tabs to be indented.
283 flatten : bool (default: False)
282 flatten : bool (default: False)
284 Whether to scrub existing indentation. If True, all lines will be
283 Whether to scrub existing indentation. If True, all lines will be
285 aligned to the same indentation. If False, existing indentation will
284 aligned to the same indentation. If False, existing indentation will
286 be strictly increased.
285 be strictly increased.
287
286
288 Returns
287 Returns
289 -------
288 -------
290
289
291 str|unicode : string indented by ntabs and nspaces.
290 str|unicode : string indented by ntabs and nspaces.
292
291
293 """
292 """
294 if instr is None:
293 if instr is None:
295 return
294 return
296 ind = '\t'*ntabs+' '*nspaces
295 ind = '\t'*ntabs+' '*nspaces
297 if flatten:
296 if flatten:
298 pat = re.compile(r'^\s*', re.MULTILINE)
297 pat = re.compile(r'^\s*', re.MULTILINE)
299 else:
298 else:
300 pat = re.compile(r'^', re.MULTILINE)
299 pat = re.compile(r'^', re.MULTILINE)
301 outstr = re.sub(pat, ind, instr)
300 outstr = re.sub(pat, ind, instr)
302 if outstr.endswith(os.linesep+ind):
301 if outstr.endswith(os.linesep+ind):
303 return outstr[:-len(ind)]
302 return outstr[:-len(ind)]
304 else:
303 else:
305 return outstr
304 return outstr
306
305
307
306
308 def list_strings(arg):
307 def list_strings(arg):
309 """Always return a list of strings, given a string or list of strings
308 """Always return a list of strings, given a string or list of strings
310 as input.
309 as input.
311
310
312 :Examples:
311 :Examples:
313
312
314 In [7]: list_strings('A single string')
313 In [7]: list_strings('A single string')
315 Out[7]: ['A single string']
314 Out[7]: ['A single string']
316
315
317 In [8]: list_strings(['A single string in a list'])
316 In [8]: list_strings(['A single string in a list'])
318 Out[8]: ['A single string in a list']
317 Out[8]: ['A single string in a list']
319
318
320 In [9]: list_strings(['A','list','of','strings'])
319 In [9]: list_strings(['A','list','of','strings'])
321 Out[9]: ['A', 'list', 'of', 'strings']
320 Out[9]: ['A', 'list', 'of', 'strings']
322 """
321 """
323
322
324 if isinstance(arg, py3compat.string_types): return [arg]
323 if isinstance(arg, py3compat.string_types): return [arg]
325 else: return arg
324 else: return arg
326
325
327
326
328 def marquee(txt='',width=78,mark='*'):
327 def marquee(txt='',width=78,mark='*'):
329 """Return the input string centered in a 'marquee'.
328 """Return the input string centered in a 'marquee'.
330
329
331 :Examples:
330 :Examples:
332
331
333 In [16]: marquee('A test',40)
332 In [16]: marquee('A test',40)
334 Out[16]: '**************** A test ****************'
333 Out[16]: '**************** A test ****************'
335
334
336 In [17]: marquee('A test',40,'-')
335 In [17]: marquee('A test',40,'-')
337 Out[17]: '---------------- A test ----------------'
336 Out[17]: '---------------- A test ----------------'
338
337
339 In [18]: marquee('A test',40,' ')
338 In [18]: marquee('A test',40,' ')
340 Out[18]: ' A test '
339 Out[18]: ' A test '
341
340
342 """
341 """
343 if not txt:
342 if not txt:
344 return (mark*width)[:width]
343 return (mark*width)[:width]
345 nmark = (width-len(txt)-2)//len(mark)//2
344 nmark = (width-len(txt)-2)//len(mark)//2
346 if nmark < 0: nmark =0
345 if nmark < 0: nmark =0
347 marks = mark*nmark
346 marks = mark*nmark
348 return '%s %s %s' % (marks,txt,marks)
347 return '%s %s %s' % (marks,txt,marks)
349
348
350
349
351 ini_spaces_re = re.compile(r'^(\s+)')
350 ini_spaces_re = re.compile(r'^(\s+)')
352
351
353 def num_ini_spaces(strng):
352 def num_ini_spaces(strng):
354 """Return the number of initial spaces in a string"""
353 """Return the number of initial spaces in a string"""
355
354
356 ini_spaces = ini_spaces_re.match(strng)
355 ini_spaces = ini_spaces_re.match(strng)
357 if ini_spaces:
356 if ini_spaces:
358 return ini_spaces.end()
357 return ini_spaces.end()
359 else:
358 else:
360 return 0
359 return 0
361
360
362
361
363 def format_screen(strng):
362 def format_screen(strng):
364 """Format a string for screen printing.
363 """Format a string for screen printing.
365
364
366 This removes some latex-type format codes."""
365 This removes some latex-type format codes."""
367 # Paragraph continue
366 # Paragraph continue
368 par_re = re.compile(r'\\$',re.MULTILINE)
367 par_re = re.compile(r'\\$',re.MULTILINE)
369 strng = par_re.sub('',strng)
368 strng = par_re.sub('',strng)
370 return strng
369 return strng
371
370
372
371
373 def dedent(text):
372 def dedent(text):
374 """Equivalent of textwrap.dedent that ignores unindented first line.
373 """Equivalent of textwrap.dedent that ignores unindented first line.
375
374
376 This means it will still dedent strings like:
375 This means it will still dedent strings like:
377 '''foo
376 '''foo
378 is a bar
377 is a bar
379 '''
378 '''
380
379
381 For use in wrap_paragraphs.
380 For use in wrap_paragraphs.
382 """
381 """
383
382
384 if text.startswith('\n'):
383 if text.startswith('\n'):
385 # text starts with blank line, don't ignore the first line
384 # text starts with blank line, don't ignore the first line
386 return textwrap.dedent(text)
385 return textwrap.dedent(text)
387
386
388 # split first line
387 # split first line
389 splits = text.split('\n',1)
388 splits = text.split('\n',1)
390 if len(splits) == 1:
389 if len(splits) == 1:
391 # only one line
390 # only one line
392 return textwrap.dedent(text)
391 return textwrap.dedent(text)
393
392
394 first, rest = splits
393 first, rest = splits
395 # dedent everything but the first line
394 # dedent everything but the first line
396 rest = textwrap.dedent(rest)
395 rest = textwrap.dedent(rest)
397 return '\n'.join([first, rest])
396 return '\n'.join([first, rest])
398
397
399
398
400 def wrap_paragraphs(text, ncols=80):
399 def wrap_paragraphs(text, ncols=80):
401 """Wrap multiple paragraphs to fit a specified width.
400 """Wrap multiple paragraphs to fit a specified width.
402
401
403 This is equivalent to textwrap.wrap, but with support for multiple
402 This is equivalent to textwrap.wrap, but with support for multiple
404 paragraphs, as separated by empty lines.
403 paragraphs, as separated by empty lines.
405
404
406 Returns
405 Returns
407 -------
406 -------
408
407
409 list of complete paragraphs, wrapped to fill `ncols` columns.
408 list of complete paragraphs, wrapped to fill `ncols` columns.
410 """
409 """
411 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
410 paragraph_re = re.compile(r'\n(\s*\n)+', re.MULTILINE)
412 text = dedent(text).strip()
411 text = dedent(text).strip()
413 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
412 paragraphs = paragraph_re.split(text)[::2] # every other entry is space
414 out_ps = []
413 out_ps = []
415 indent_re = re.compile(r'\n\s+', re.MULTILINE)
414 indent_re = re.compile(r'\n\s+', re.MULTILINE)
416 for p in paragraphs:
415 for p in paragraphs:
417 # presume indentation that survives dedent is meaningful formatting,
416 # presume indentation that survives dedent is meaningful formatting,
418 # so don't fill unless text is flush.
417 # so don't fill unless text is flush.
419 if indent_re.search(p) is None:
418 if indent_re.search(p) is None:
420 # wrap paragraph
419 # wrap paragraph
421 p = textwrap.fill(p, ncols)
420 p = textwrap.fill(p, ncols)
422 out_ps.append(p)
421 out_ps.append(p)
423 return out_ps
422 return out_ps
424
423
425
424
426 def long_substr(data):
425 def long_substr(data):
427 """Return the longest common substring in a list of strings.
426 """Return the longest common substring in a list of strings.
428
427
429 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
428 Credit: http://stackoverflow.com/questions/2892931/longest-common-substring-from-more-than-two-strings-python
430 """
429 """
431 substr = ''
430 substr = ''
432 if len(data) > 1 and len(data[0]) > 0:
431 if len(data) > 1 and len(data[0]) > 0:
433 for i in range(len(data[0])):
432 for i in range(len(data[0])):
434 for j in range(len(data[0])-i+1):
433 for j in range(len(data[0])-i+1):
435 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
434 if j > len(substr) and all(data[0][i:i+j] in x for x in data):
436 substr = data[0][i:i+j]
435 substr = data[0][i:i+j]
437 elif len(data) == 1:
436 elif len(data) == 1:
438 substr = data[0]
437 substr = data[0]
439 return substr
438 return substr
440
439
441
440
442 def strip_email_quotes(text):
441 def strip_email_quotes(text):
443 """Strip leading email quotation characters ('>').
442 """Strip leading email quotation characters ('>').
444
443
445 Removes any combination of leading '>' interspersed with whitespace that
444 Removes any combination of leading '>' interspersed with whitespace that
446 appears *identically* in all lines of the input text.
445 appears *identically* in all lines of the input text.
447
446
448 Parameters
447 Parameters
449 ----------
448 ----------
450 text : str
449 text : str
451
450
452 Examples
451 Examples
453 --------
452 --------
454
453
455 Simple uses::
454 Simple uses::
456
455
457 In [2]: strip_email_quotes('> > text')
456 In [2]: strip_email_quotes('> > text')
458 Out[2]: 'text'
457 Out[2]: 'text'
459
458
460 In [3]: strip_email_quotes('> > text\\n> > more')
459 In [3]: strip_email_quotes('> > text\\n> > more')
461 Out[3]: 'text\\nmore'
460 Out[3]: 'text\\nmore'
462
461
463 Note how only the common prefix that appears in all lines is stripped::
462 Note how only the common prefix that appears in all lines is stripped::
464
463
465 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
464 In [4]: strip_email_quotes('> > text\\n> > more\\n> more...')
466 Out[4]: '> text\\n> more\\nmore...'
465 Out[4]: '> text\\n> more\\nmore...'
467
466
468 So if any line has no quote marks ('>') , then none are stripped from any
467 So if any line has no quote marks ('>') , then none are stripped from any
469 of them ::
468 of them ::
470
469
471 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
470 In [5]: strip_email_quotes('> > text\\n> > more\\nlast different')
472 Out[5]: '> > text\\n> > more\\nlast different'
471 Out[5]: '> > text\\n> > more\\nlast different'
473 """
472 """
474 lines = text.splitlines()
473 lines = text.splitlines()
475 matches = set()
474 matches = set()
476 for line in lines:
475 for line in lines:
477 prefix = re.match(r'^(\s*>[ >]*)', line)
476 prefix = re.match(r'^(\s*>[ >]*)', line)
478 if prefix:
477 if prefix:
479 matches.add(prefix.group(1))
478 matches.add(prefix.group(1))
480 else:
479 else:
481 break
480 break
482 else:
481 else:
483 prefix = long_substr(list(matches))
482 prefix = long_substr(list(matches))
484 if prefix:
483 if prefix:
485 strip = len(prefix)
484 strip = len(prefix)
486 text = '\n'.join([ ln[strip:] for ln in lines])
485 text = '\n'.join([ ln[strip:] for ln in lines])
487 return text
486 return text
488
487
489
488
490 class EvalFormatter(Formatter):
489 class EvalFormatter(Formatter):
491 """A String Formatter that allows evaluation of simple expressions.
490 """A String Formatter that allows evaluation of simple expressions.
492
491
493 Note that this version interprets a : as specifying a format string (as per
492 Note that this version interprets a : as specifying a format string (as per
494 standard string formatting), so if slicing is required, you must explicitly
493 standard string formatting), so if slicing is required, you must explicitly
495 create a slice.
494 create a slice.
496
495
497 This is to be used in templating cases, such as the parallel batch
496 This is to be used in templating cases, such as the parallel batch
498 script templates, where simple arithmetic on arguments is useful.
497 script templates, where simple arithmetic on arguments is useful.
499
498
500 Examples
499 Examples
501 --------
500 --------
502
501
503 In [1]: f = EvalFormatter()
502 In [1]: f = EvalFormatter()
504 In [2]: f.format('{n//4}', n=8)
503 In [2]: f.format('{n//4}', n=8)
505 Out [2]: '2'
504 Out [2]: '2'
506
505
507 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
506 In [3]: f.format("{greeting[slice(2,4)]}", greeting="Hello")
508 Out [3]: 'll'
507 Out [3]: 'll'
509 """
508 """
510 def get_field(self, name, args, kwargs):
509 def get_field(self, name, args, kwargs):
511 v = eval(name, kwargs)
510 v = eval(name, kwargs)
512 return v, name
511 return v, name
513
512
514
513
515 @skip_doctest_py3
514 @skip_doctest_py3
516 class FullEvalFormatter(Formatter):
515 class FullEvalFormatter(Formatter):
517 """A String Formatter that allows evaluation of simple expressions.
516 """A String Formatter that allows evaluation of simple expressions.
518
517
519 Any time a format key is not found in the kwargs,
518 Any time a format key is not found in the kwargs,
520 it will be tried as an expression in the kwargs namespace.
519 it will be tried as an expression in the kwargs namespace.
521
520
522 Note that this version allows slicing using [1:2], so you cannot specify
521 Note that this version allows slicing using [1:2], so you cannot specify
523 a format string. Use :class:`EvalFormatter` to permit format strings.
522 a format string. Use :class:`EvalFormatter` to permit format strings.
524
523
525 Examples
524 Examples
526 --------
525 --------
527
526
528 In [1]: f = FullEvalFormatter()
527 In [1]: f = FullEvalFormatter()
529 In [2]: f.format('{n//4}', n=8)
528 In [2]: f.format('{n//4}', n=8)
530 Out[2]: u'2'
529 Out[2]: u'2'
531
530
532 In [3]: f.format('{list(range(5))[2:4]}')
531 In [3]: f.format('{list(range(5))[2:4]}')
533 Out[3]: u'[2, 3]'
532 Out[3]: u'[2, 3]'
534
533
535 In [4]: f.format('{3*2}')
534 In [4]: f.format('{3*2}')
536 Out[4]: u'6'
535 Out[4]: u'6'
537 """
536 """
538 # copied from Formatter._vformat with minor changes to allow eval
537 # copied from Formatter._vformat with minor changes to allow eval
539 # and replace the format_spec code with slicing
538 # and replace the format_spec code with slicing
540 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
539 def _vformat(self, format_string, args, kwargs, used_args, recursion_depth):
541 if recursion_depth < 0:
540 if recursion_depth < 0:
542 raise ValueError('Max string recursion exceeded')
541 raise ValueError('Max string recursion exceeded')
543 result = []
542 result = []
544 for literal_text, field_name, format_spec, conversion in \
543 for literal_text, field_name, format_spec, conversion in \
545 self.parse(format_string):
544 self.parse(format_string):
546
545
547 # output the literal text
546 # output the literal text
548 if literal_text:
547 if literal_text:
549 result.append(literal_text)
548 result.append(literal_text)
550
549
551 # if there's a field, output it
550 # if there's a field, output it
552 if field_name is not None:
551 if field_name is not None:
553 # this is some markup, find the object and do
552 # this is some markup, find the object and do
554 # the formatting
553 # the formatting
555
554
556 if format_spec:
555 if format_spec:
557 # override format spec, to allow slicing:
556 # override format spec, to allow slicing:
558 field_name = ':'.join([field_name, format_spec])
557 field_name = ':'.join([field_name, format_spec])
559
558
560 # eval the contents of the field for the object
559 # eval the contents of the field for the object
561 # to be formatted
560 # to be formatted
562 obj = eval(field_name, kwargs)
561 obj = eval(field_name, kwargs)
563
562
564 # do any conversion on the resulting object
563 # do any conversion on the resulting object
565 obj = self.convert_field(obj, conversion)
564 obj = self.convert_field(obj, conversion)
566
565
567 # format the object and append to the result
566 # format the object and append to the result
568 result.append(self.format_field(obj, ''))
567 result.append(self.format_field(obj, ''))
569
568
570 return u''.join(py3compat.cast_unicode(s) for s in result)
569 return u''.join(py3compat.cast_unicode(s) for s in result)
571
570
572
571
573 @skip_doctest_py3
572 @skip_doctest_py3
574 class DollarFormatter(FullEvalFormatter):
573 class DollarFormatter(FullEvalFormatter):
575 """Formatter allowing Itpl style $foo replacement, for names and attribute
574 """Formatter allowing Itpl style $foo replacement, for names and attribute
576 access only. Standard {foo} replacement also works, and allows full
575 access only. Standard {foo} replacement also works, and allows full
577 evaluation of its arguments.
576 evaluation of its arguments.
578
577
579 Examples
578 Examples
580 --------
579 --------
581 In [1]: f = DollarFormatter()
580 In [1]: f = DollarFormatter()
582 In [2]: f.format('{n//4}', n=8)
581 In [2]: f.format('{n//4}', n=8)
583 Out[2]: u'2'
582 Out[2]: u'2'
584
583
585 In [3]: f.format('23 * 76 is $result', result=23*76)
584 In [3]: f.format('23 * 76 is $result', result=23*76)
586 Out[3]: u'23 * 76 is 1748'
585 Out[3]: u'23 * 76 is 1748'
587
586
588 In [4]: f.format('$a or {b}', a=1, b=2)
587 In [4]: f.format('$a or {b}', a=1, b=2)
589 Out[4]: u'1 or 2'
588 Out[4]: u'1 or 2'
590 """
589 """
591 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
590 _dollar_pattern = re.compile("(.*?)\$(\$?[\w\.]+)")
592 def parse(self, fmt_string):
591 def parse(self, fmt_string):
593 for literal_txt, field_name, format_spec, conversion \
592 for literal_txt, field_name, format_spec, conversion \
594 in Formatter.parse(self, fmt_string):
593 in Formatter.parse(self, fmt_string):
595
594
596 # Find $foo patterns in the literal text.
595 # Find $foo patterns in the literal text.
597 continue_from = 0
596 continue_from = 0
598 txt = ""
597 txt = ""
599 for m in self._dollar_pattern.finditer(literal_txt):
598 for m in self._dollar_pattern.finditer(literal_txt):
600 new_txt, new_field = m.group(1,2)
599 new_txt, new_field = m.group(1,2)
601 # $$foo --> $foo
600 # $$foo --> $foo
602 if new_field.startswith("$"):
601 if new_field.startswith("$"):
603 txt += new_txt + new_field
602 txt += new_txt + new_field
604 else:
603 else:
605 yield (txt + new_txt, new_field, "", None)
604 yield (txt + new_txt, new_field, "", None)
606 txt = ""
605 txt = ""
607 continue_from = m.end()
606 continue_from = m.end()
608
607
609 # Re-yield the {foo} style pattern
608 # Re-yield the {foo} style pattern
610 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
609 yield (txt + literal_txt[continue_from:], field_name, format_spec, conversion)
611
610
612 #-----------------------------------------------------------------------------
611 #-----------------------------------------------------------------------------
613 # Utils to columnize a list of string
612 # Utils to columnize a list of string
614 #-----------------------------------------------------------------------------
613 #-----------------------------------------------------------------------------
615
614
616 def _chunks(l, n):
615 def _chunks(l, n):
617 """Yield successive n-sized chunks from l."""
616 """Yield successive n-sized chunks from l."""
618 for i in xrange(0, len(l), n):
617 for i in py3compat.xrange(0, len(l), n):
619 yield l[i:i+n]
618 yield l[i:i+n]
620
619
621
620
622 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
621 def _find_optimal(rlist , separator_size=2 , displaywidth=80):
623 """Calculate optimal info to columnize a list of string"""
622 """Calculate optimal info to columnize a list of string"""
624 for nrow in range(1, len(rlist)+1) :
623 for nrow in range(1, len(rlist)+1) :
625 chk = map(max,_chunks(rlist, nrow))
624 chk = map(max,_chunks(rlist, nrow))
626 sumlength = sum(chk)
625 sumlength = sum(chk)
627 ncols = len(chk)
626 ncols = len(chk)
628 if sumlength+separator_size*(ncols-1) <= displaywidth :
627 if sumlength+separator_size*(ncols-1) <= displaywidth :
629 break;
628 break;
630 return {'columns_numbers' : ncols,
629 return {'columns_numbers' : ncols,
631 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
630 'optimal_separator_width':(displaywidth - sumlength)/(ncols-1) if (ncols -1) else 0,
632 'rows_numbers' : nrow,
631 'rows_numbers' : nrow,
633 'columns_width' : chk
632 'columns_width' : chk
634 }
633 }
635
634
636
635
637 def _get_or_default(mylist, i, default=None):
636 def _get_or_default(mylist, i, default=None):
638 """return list item number, or default if don't exist"""
637 """return list item number, or default if don't exist"""
639 if i >= len(mylist):
638 if i >= len(mylist):
640 return default
639 return default
641 else :
640 else :
642 return mylist[i]
641 return mylist[i]
643
642
644
643
645 @skip_doctest
644 @skip_doctest
646 def compute_item_matrix(items, empty=None, *args, **kwargs) :
645 def compute_item_matrix(items, empty=None, *args, **kwargs) :
647 """Returns a nested list, and info to columnize items
646 """Returns a nested list, and info to columnize items
648
647
649 Parameters
648 Parameters
650 ----------
649 ----------
651
650
652 items :
651 items :
653 list of strings to columize
652 list of strings to columize
654 empty : (default None)
653 empty : (default None)
655 default value to fill list if needed
654 default value to fill list if needed
656 separator_size : int (default=2)
655 separator_size : int (default=2)
657 How much caracters will be used as a separation between each columns.
656 How much caracters will be used as a separation between each columns.
658 displaywidth : int (default=80)
657 displaywidth : int (default=80)
659 The width of the area onto wich the columns should enter
658 The width of the area onto wich the columns should enter
660
659
661 Returns
660 Returns
662 -------
661 -------
663
662
664 Returns a tuple of (strings_matrix, dict_info)
663 Returns a tuple of (strings_matrix, dict_info)
665
664
666 strings_matrix :
665 strings_matrix :
667
666
668 nested list of string, the outer most list contains as many list as
667 nested list of string, the outer most list contains as many list as
669 rows, the innermost lists have each as many element as colums. If the
668 rows, the innermost lists have each as many element as colums. If the
670 total number of elements in `items` does not equal the product of
669 total number of elements in `items` does not equal the product of
671 rows*columns, the last element of some lists are filled with `None`.
670 rows*columns, the last element of some lists are filled with `None`.
672
671
673 dict_info :
672 dict_info :
674 some info to make columnize easier:
673 some info to make columnize easier:
675
674
676 columns_numbers : number of columns
675 columns_numbers : number of columns
677 rows_numbers : number of rows
676 rows_numbers : number of rows
678 columns_width : list of with of each columns
677 columns_width : list of with of each columns
679 optimal_separator_width : best separator width between columns
678 optimal_separator_width : best separator width between columns
680
679
681 Examples
680 Examples
682 --------
681 --------
683
682
684 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
683 In [1]: l = ['aaa','b','cc','d','eeeee','f','g','h','i','j','k','l']
685 ...: compute_item_matrix(l,displaywidth=12)
684 ...: compute_item_matrix(l,displaywidth=12)
686 Out[1]:
685 Out[1]:
687 ([['aaa', 'f', 'k'],
686 ([['aaa', 'f', 'k'],
688 ['b', 'g', 'l'],
687 ['b', 'g', 'l'],
689 ['cc', 'h', None],
688 ['cc', 'h', None],
690 ['d', 'i', None],
689 ['d', 'i', None],
691 ['eeeee', 'j', None]],
690 ['eeeee', 'j', None]],
692 {'columns_numbers': 3,
691 {'columns_numbers': 3,
693 'columns_width': [5, 1, 1],
692 'columns_width': [5, 1, 1],
694 'optimal_separator_width': 2,
693 'optimal_separator_width': 2,
695 'rows_numbers': 5})
694 'rows_numbers': 5})
696
695
697 """
696 """
698 info = _find_optimal(map(len, items), *args, **kwargs)
697 info = _find_optimal(map(len, items), *args, **kwargs)
699 nrow, ncol = info['rows_numbers'], info['columns_numbers']
698 nrow, ncol = info['rows_numbers'], info['columns_numbers']
700 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
699 return ([[ _get_or_default(items, c*nrow+i, default=empty) for c in range(ncol) ] for i in range(nrow) ], info)
701
700
702
701
703 def columnize(items, separator=' ', displaywidth=80):
702 def columnize(items, separator=' ', displaywidth=80):
704 """ Transform a list of strings into a single string with columns.
703 """ Transform a list of strings into a single string with columns.
705
704
706 Parameters
705 Parameters
707 ----------
706 ----------
708 items : sequence of strings
707 items : sequence of strings
709 The strings to process.
708 The strings to process.
710
709
711 separator : str, optional [default is two spaces]
710 separator : str, optional [default is two spaces]
712 The string that separates columns.
711 The string that separates columns.
713
712
714 displaywidth : int, optional [default is 80]
713 displaywidth : int, optional [default is 80]
715 Width of the display in number of characters.
714 Width of the display in number of characters.
716
715
717 Returns
716 Returns
718 -------
717 -------
719 The formatted string.
718 The formatted string.
720 """
719 """
721 if not items :
720 if not items :
722 return '\n'
721 return '\n'
723 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
722 matrix, info = compute_item_matrix(items, separator_size=len(separator), displaywidth=displaywidth)
724 fmatrix = [filter(None, x) for x in matrix]
723 fmatrix = [filter(None, x) for x in matrix]
725 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
724 sjoin = lambda x : separator.join([ y.ljust(w, ' ') for y, w in zip(x, info['columns_width'])])
726 return '\n'.join(map(sjoin, fmatrix))+'\n'
725 return '\n'.join(map(sjoin, fmatrix))+'\n'
727
726
728
727
729 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
728 def get_text_list(list_, last_sep=' and ', sep=", ", wrap_item_with=""):
730 """
729 """
731 Return a string with a natural enumeration of items
730 Return a string with a natural enumeration of items
732
731
733 >>> get_text_list(['a', 'b', 'c', 'd'])
732 >>> get_text_list(['a', 'b', 'c', 'd'])
734 'a, b, c and d'
733 'a, b, c and d'
735 >>> get_text_list(['a', 'b', 'c'], ' or ')
734 >>> get_text_list(['a', 'b', 'c'], ' or ')
736 'a, b or c'
735 'a, b or c'
737 >>> get_text_list(['a', 'b', 'c'], ', ')
736 >>> get_text_list(['a', 'b', 'c'], ', ')
738 'a, b, c'
737 'a, b, c'
739 >>> get_text_list(['a', 'b'], ' or ')
738 >>> get_text_list(['a', 'b'], ' or ')
740 'a or b'
739 'a or b'
741 >>> get_text_list(['a'])
740 >>> get_text_list(['a'])
742 'a'
741 'a'
743 >>> get_text_list([])
742 >>> get_text_list([])
744 ''
743 ''
745 >>> get_text_list(['a', 'b'], wrap_item_with="`")
744 >>> get_text_list(['a', 'b'], wrap_item_with="`")
746 '`a` and `b`'
745 '`a` and `b`'
747 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
746 >>> get_text_list(['a', 'b', 'c', 'd'], " = ", sep=" + ")
748 'a + b + c = d'
747 'a + b + c = d'
749 """
748 """
750 if len(list_) == 0:
749 if len(list_) == 0:
751 return ''
750 return ''
752 if wrap_item_with:
751 if wrap_item_with:
753 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
752 list_ = ['%s%s%s' % (wrap_item_with, item, wrap_item_with) for
754 item in list_]
753 item in list_]
755 if len(list_) == 1:
754 if len(list_) == 1:
756 return list_[0]
755 return list_[0]
757 return '%s%s%s' % (
756 return '%s%s%s' % (
758 sep.join(i for i in list_[:-1]),
757 sep.join(i for i in list_[:-1]),
759 last_sep, list_[-1]) No newline at end of file
758 last_sep, list_[-1])
@@ -1,116 +1,118 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Utilities for timing code execution.
3 Utilities for timing code execution.
4 """
4 """
5
5
6 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import time
17 import time
18
18
19 from .py3compat import xrange
20
19 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
20 # Code
22 # Code
21 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
22
24
23 # If possible (Unix), use the resource module instead of time.clock()
25 # If possible (Unix), use the resource module instead of time.clock()
24 try:
26 try:
25 import resource
27 import resource
26 def clocku():
28 def clocku():
27 """clocku() -> floating point number
29 """clocku() -> floating point number
28
30
29 Return the *USER* CPU time in seconds since the start of the process.
31 Return the *USER* CPU time in seconds since the start of the process.
30 This is done via a call to resource.getrusage, so it avoids the
32 This is done via a call to resource.getrusage, so it avoids the
31 wraparound problems in time.clock()."""
33 wraparound problems in time.clock()."""
32
34
33 return resource.getrusage(resource.RUSAGE_SELF)[0]
35 return resource.getrusage(resource.RUSAGE_SELF)[0]
34
36
35 def clocks():
37 def clocks():
36 """clocks() -> floating point number
38 """clocks() -> floating point number
37
39
38 Return the *SYSTEM* CPU time in seconds since the start of the process.
40 Return the *SYSTEM* CPU time in seconds since the start of the process.
39 This is done via a call to resource.getrusage, so it avoids the
41 This is done via a call to resource.getrusage, so it avoids the
40 wraparound problems in time.clock()."""
42 wraparound problems in time.clock()."""
41
43
42 return resource.getrusage(resource.RUSAGE_SELF)[1]
44 return resource.getrusage(resource.RUSAGE_SELF)[1]
43
45
44 def clock():
46 def clock():
45 """clock() -> floating point number
47 """clock() -> floating point number
46
48
47 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
49 Return the *TOTAL USER+SYSTEM* CPU time in seconds since the start of
48 the process. This is done via a call to resource.getrusage, so it
50 the process. This is done via a call to resource.getrusage, so it
49 avoids the wraparound problems in time.clock()."""
51 avoids the wraparound problems in time.clock()."""
50
52
51 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
53 u,s = resource.getrusage(resource.RUSAGE_SELF)[:2]
52 return u+s
54 return u+s
53
55
54 def clock2():
56 def clock2():
55 """clock2() -> (t_user,t_system)
57 """clock2() -> (t_user,t_system)
56
58
57 Similar to clock(), but return a tuple of user/system times."""
59 Similar to clock(), but return a tuple of user/system times."""
58 return resource.getrusage(resource.RUSAGE_SELF)[:2]
60 return resource.getrusage(resource.RUSAGE_SELF)[:2]
59 except ImportError:
61 except ImportError:
60 # There is no distinction of user/system time under windows, so we just use
62 # There is no distinction of user/system time under windows, so we just use
61 # time.clock() for everything...
63 # time.clock() for everything...
62 clocku = clocks = clock = time.clock
64 clocku = clocks = clock = time.clock
63 def clock2():
65 def clock2():
64 """Under windows, system CPU time can't be measured.
66 """Under windows, system CPU time can't be measured.
65
67
66 This just returns clock() and zero."""
68 This just returns clock() and zero."""
67 return time.clock(),0.0
69 return time.clock(),0.0
68
70
69
71
70 def timings_out(reps,func,*args,**kw):
72 def timings_out(reps,func,*args,**kw):
71 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
73 """timings_out(reps,func,*args,**kw) -> (t_total,t_per_call,output)
72
74
73 Execute a function reps times, return a tuple with the elapsed total
75 Execute a function reps times, return a tuple with the elapsed total
74 CPU time in seconds, the time per call and the function's output.
76 CPU time in seconds, the time per call and the function's output.
75
77
76 Under Unix, the return value is the sum of user+system time consumed by
78 Under Unix, the return value is the sum of user+system time consumed by
77 the process, computed via the resource module. This prevents problems
79 the process, computed via the resource module. This prevents problems
78 related to the wraparound effect which the time.clock() function has.
80 related to the wraparound effect which the time.clock() function has.
79
81
80 Under Windows the return value is in wall clock seconds. See the
82 Under Windows the return value is in wall clock seconds. See the
81 documentation for the time module for more details."""
83 documentation for the time module for more details."""
82
84
83 reps = int(reps)
85 reps = int(reps)
84 assert reps >=1, 'reps must be >= 1'
86 assert reps >=1, 'reps must be >= 1'
85 if reps==1:
87 if reps==1:
86 start = clock()
88 start = clock()
87 out = func(*args,**kw)
89 out = func(*args,**kw)
88 tot_time = clock()-start
90 tot_time = clock()-start
89 else:
91 else:
90 rng = xrange(reps-1) # the last time is executed separately to store output
92 rng = xrange(reps-1) # the last time is executed separately to store output
91 start = clock()
93 start = clock()
92 for dummy in rng: func(*args,**kw)
94 for dummy in rng: func(*args,**kw)
93 out = func(*args,**kw) # one last time
95 out = func(*args,**kw) # one last time
94 tot_time = clock()-start
96 tot_time = clock()-start
95 av_time = tot_time / reps
97 av_time = tot_time / reps
96 return tot_time,av_time,out
98 return tot_time,av_time,out
97
99
98
100
99 def timings(reps,func,*args,**kw):
101 def timings(reps,func,*args,**kw):
100 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
102 """timings(reps,func,*args,**kw) -> (t_total,t_per_call)
101
103
102 Execute a function reps times, return a tuple with the elapsed total CPU
104 Execute a function reps times, return a tuple with the elapsed total CPU
103 time in seconds and the time per call. These are just the first two values
105 time in seconds and the time per call. These are just the first two values
104 in timings_out()."""
106 in timings_out()."""
105
107
106 return timings_out(reps,func,*args,**kw)[0:2]
108 return timings_out(reps,func,*args,**kw)[0:2]
107
109
108
110
109 def timing(func,*args,**kw):
111 def timing(func,*args,**kw):
110 """timing(func,*args,**kw) -> t_total
112 """timing(func,*args,**kw) -> t_total
111
113
112 Execute a function once, return the elapsed total CPU time in
114 Execute a function once, return the elapsed total CPU time in
113 seconds. This is just the first value in timings_out()."""
115 seconds. This is just the first value in timings_out()."""
114
116
115 return timings_out(1,func,*args,**kw)[0]
117 return timings_out(1,func,*args,**kw)[0]
116
118
General Comments 0
You need to be logged in to leave comments. Login now